probindex should never have been part of RunState. i apologize for this failure of abstraction

This commit is contained in:
Andrej Karpathy
2023-08-20 17:18:06 +00:00
parent 8c93c7a30e
commit c0511de617
+9 -10
View File
@@ -50,11 +50,6 @@ typedef struct {
float* wcls;
} TransformerWeights;
typedef struct {
float prob;
int index;
} ProbIndex; // struct used when sorting probabilities during top-p sampling
typedef struct {
// current wave of activations
float *x; // activation at current time stamp (dim,)
@@ -67,7 +62,6 @@ typedef struct {
float *v; // value (dim,)
float *att; // buffer for scores/attention values (n_heads, seq_len)
float *logits; // output logits
ProbIndex *probindex; // buffer used in top-p sampling
// kv cache
float* key_cache; // (layer, seq_len, dim)
float* value_cache; // (layer, seq_len, dim)
@@ -86,13 +80,12 @@ void malloc_run_state(RunState* s, Config* p) {
s->v = calloc(kv_dim, sizeof(float));
s->att = calloc(p->n_heads * p->seq_len, sizeof(float));
s->logits = calloc(p->vocab_size, sizeof(float));
s->probindex = calloc(p->vocab_size, sizeof(ProbIndex));
s->key_cache = calloc(p->n_layers * p->seq_len * kv_dim, sizeof(float));
s->value_cache = calloc(p->n_layers * p->seq_len * kv_dim, sizeof(float));
// ensure all mallocs went fine
if (!s->x || !s->xb || !s->xb2 || !s->hb || !s->hb2 || !s->q
|| !s->k || !s->v || !s->att || !s->logits || !s->key_cache
|| !s->value_cache || !s->probindex) {
|| !s->value_cache) {
fprintf(stderr, "malloc failed!\n");
exit(EXIT_FAILURE);
}
@@ -109,7 +102,6 @@ void free_run_state(RunState* s) {
free(s->v);
free(s->att);
free(s->logits);
free(s->probindex);
free(s->key_cache);
free(s->value_cache);
}
@@ -499,6 +491,11 @@ float random_f32() { // random float32 in [0,1)
// ----------------------------------------------------------------------------
// sampling can be done in a few ways: greedy argmax, sampling, top-p sampling
typedef struct {
float prob;
int index;
} ProbIndex; // struct used when sorting probabilities during top-p sampling
int argmax(float* probabilities, int n) {
// return the index that has the highest probability
int max_i = 0;
@@ -654,6 +651,7 @@ int main(int argc, char *argv[]) {
// create and init the application RunState
RunState state;
malloc_run_state(&state, &config);
ProbIndex *probindex = malloc(config.vocab_size * sizeof(ProbIndex)); // buffer used in top-p sampling
// process the prompt, if any
int *prompt_tokens = NULL;
@@ -693,7 +691,7 @@ int main(int argc, char *argv[]) {
next = sample(state.logits, config.vocab_size);
} else {
// top-p (nucleus) sampling, clamping the least likely tokens to zero
next = sample_topp(state.logits, config.vocab_size, topp, state.probindex);
next = sample_topp(state.logits, config.vocab_size, topp, probindex);
}
}
}
@@ -734,6 +732,7 @@ int main(int argc, char *argv[]) {
// memory and file handles cleanup
free_run_state(&state);
free(probindex);
for (int i = 0; i < config.vocab_size; i++) { free(vocab[i]); }
free(vocab);
free(vocab_scores);