run.c: move the top-p ProbIndex scratch buffer out of RunState into main()

The ProbIndex typedef moves next to the sampling helpers, and the buffer used
in top-p sampling becomes a local in main() that is passed to sample_topp()
and freed during cleanup.

The new allocation in main() is checked for NULL. The check that previously
covered this buffer lived in malloc_run_state() and is removed by this patch,
so without the added check an allocation failure would go unnoticed.

NOTE(review): the indentation of the hunk bodies below was reconstructed from
a whitespace-collapsed copy of this patch; confirm it against run.c, or apply
with whitespace ignored. The post-image blob id on the index line predates the
added NULL check.

diff --git a/run.c b/run.c
index 59b8b29..614f18b 100644
--- a/run.c
+++ b/run.c
@@ -50,11 +50,6 @@ typedef struct {
     float* wcls;
 } TransformerWeights;
 
-typedef struct {
-    float prob;
-    int index;
-} ProbIndex; // struct used when sorting probabilities during top-p sampling
-
 typedef struct {
     // current wave of activations
     float *x; // activation at current time stamp (dim,)
@@ -67,7 +62,6 @@ typedef struct {
     float *v; // value (dim,)
     float *att; // buffer for scores/attention values (n_heads, seq_len)
     float *logits; // output logits
-    ProbIndex *probindex; // buffer used in top-p sampling
     // kv cache
     float* key_cache; // (layer, seq_len, dim)
     float* value_cache; // (layer, seq_len, dim)
@@ -86,13 +80,12 @@ void malloc_run_state(RunState* s, Config* p) {
     s->v = calloc(kv_dim, sizeof(float));
     s->att = calloc(p->n_heads * p->seq_len, sizeof(float));
     s->logits = calloc(p->vocab_size, sizeof(float));
-    s->probindex = calloc(p->vocab_size, sizeof(ProbIndex));
     s->key_cache = calloc(p->n_layers * p->seq_len * kv_dim, sizeof(float));
     s->value_cache = calloc(p->n_layers * p->seq_len * kv_dim, sizeof(float));
     // ensure all mallocs went fine
     if (!s->x || !s->xb || !s->xb2 || !s->hb || !s->hb2 || !s->q
      || !s->k || !s->v || !s->att || !s->logits || !s->key_cache
-     || !s->value_cache || !s->probindex) {
+     || !s->value_cache) {
         fprintf(stderr, "malloc failed!\n");
         exit(EXIT_FAILURE);
     }
@@ -109,7 +102,6 @@ void free_run_state(RunState* s) {
     free(s->v);
     free(s->att);
     free(s->logits);
-    free(s->probindex);
     free(s->key_cache);
     free(s->value_cache);
 }
@@ -499,6 +491,11 @@ float random_f32() { // random float32 in [0,1)
 // ----------------------------------------------------------------------------
 // sampling can be done in a few ways: greedy argmax, sampling, top-p sampling
 
+typedef struct {
+    float prob;
+    int index;
+} ProbIndex; // struct used when sorting probabilities during top-p sampling
+
 int argmax(float* probabilities, int n) {
     // return the index that has the highest probability
     int max_i = 0;
@@ -654,6 +651,11 @@ int main(int argc, char *argv[]) {
     // create and init the application RunState
     RunState state;
     malloc_run_state(&state, &config);
+    ProbIndex *probindex = malloc(config.vocab_size * sizeof(ProbIndex)); // buffer used in top-p sampling
+    if (!probindex) { // this buffer is no longer covered by the check in malloc_run_state
+        fprintf(stderr, "malloc failed!\n");
+        exit(EXIT_FAILURE);
+    }
 
     // process the prompt, if any
     int *prompt_tokens = NULL;
@@ -693,7 +695,7 @@ int main(int argc, char *argv[]) {
                     next = sample(state.logits, config.vocab_size);
                 } else {
                     // top-p (nucleus) sampling, clamping the least likely tokens to zero
-                    next = sample_topp(state.logits, config.vocab_size, topp, state.probindex);
+                    next = sample_topp(state.logits, config.vocab_size, topp, probindex);
                 }
             }
         }
@@ -734,6 +736,7 @@ int main(int argc, char *argv[]) {
 
     // memory and file handles cleanup
     free_run_state(&state);
+    free(probindex);
     for (int i = 0; i < config.vocab_size; i++) { free(vocab[i]); }
     free(vocab);
     free(vocab_scores);