probindex should never have been part of RunState. I apologize for this failure of abstraction
This commit is contained in:
@@ -50,11 +50,6 @@ typedef struct {
|
|||||||
float* wcls;
|
float* wcls;
|
||||||
} TransformerWeights;
|
} TransformerWeights;
|
||||||
|
|
||||||
typedef struct {
|
|
||||||
float prob;
|
|
||||||
int index;
|
|
||||||
} ProbIndex; // struct used when sorting probabilities during top-p sampling
|
|
||||||
|
|
||||||
typedef struct {
|
typedef struct {
|
||||||
// current wave of activations
|
// current wave of activations
|
||||||
float *x; // activation at current time stamp (dim,)
|
float *x; // activation at current time stamp (dim,)
|
||||||
@@ -67,7 +62,6 @@ typedef struct {
|
|||||||
float *v; // value (dim,)
|
float *v; // value (dim,)
|
||||||
float *att; // buffer for scores/attention values (n_heads, seq_len)
|
float *att; // buffer for scores/attention values (n_heads, seq_len)
|
||||||
float *logits; // output logits
|
float *logits; // output logits
|
||||||
ProbIndex *probindex; // buffer used in top-p sampling
|
|
||||||
// kv cache
|
// kv cache
|
||||||
float* key_cache; // (layer, seq_len, dim)
|
float* key_cache; // (layer, seq_len, dim)
|
||||||
float* value_cache; // (layer, seq_len, dim)
|
float* value_cache; // (layer, seq_len, dim)
|
||||||
@@ -86,13 +80,12 @@ void malloc_run_state(RunState* s, Config* p) {
|
|||||||
s->v = calloc(kv_dim, sizeof(float));
|
s->v = calloc(kv_dim, sizeof(float));
|
||||||
s->att = calloc(p->n_heads * p->seq_len, sizeof(float));
|
s->att = calloc(p->n_heads * p->seq_len, sizeof(float));
|
||||||
s->logits = calloc(p->vocab_size, sizeof(float));
|
s->logits = calloc(p->vocab_size, sizeof(float));
|
||||||
s->probindex = calloc(p->vocab_size, sizeof(ProbIndex));
|
|
||||||
s->key_cache = calloc(p->n_layers * p->seq_len * kv_dim, sizeof(float));
|
s->key_cache = calloc(p->n_layers * p->seq_len * kv_dim, sizeof(float));
|
||||||
s->value_cache = calloc(p->n_layers * p->seq_len * kv_dim, sizeof(float));
|
s->value_cache = calloc(p->n_layers * p->seq_len * kv_dim, sizeof(float));
|
||||||
// ensure all mallocs went fine
|
// ensure all mallocs went fine
|
||||||
if (!s->x || !s->xb || !s->xb2 || !s->hb || !s->hb2 || !s->q
|
if (!s->x || !s->xb || !s->xb2 || !s->hb || !s->hb2 || !s->q
|
||||||
|| !s->k || !s->v || !s->att || !s->logits || !s->key_cache
|
|| !s->k || !s->v || !s->att || !s->logits || !s->key_cache
|
||||||
|| !s->value_cache || !s->probindex) {
|
|| !s->value_cache) {
|
||||||
fprintf(stderr, "malloc failed!\n");
|
fprintf(stderr, "malloc failed!\n");
|
||||||
exit(EXIT_FAILURE);
|
exit(EXIT_FAILURE);
|
||||||
}
|
}
|
||||||
@@ -109,7 +102,6 @@ void free_run_state(RunState* s) {
|
|||||||
free(s->v);
|
free(s->v);
|
||||||
free(s->att);
|
free(s->att);
|
||||||
free(s->logits);
|
free(s->logits);
|
||||||
free(s->probindex);
|
|
||||||
free(s->key_cache);
|
free(s->key_cache);
|
||||||
free(s->value_cache);
|
free(s->value_cache);
|
||||||
}
|
}
|
||||||
@@ -499,6 +491,11 @@ float random_f32() { // random float32 in [0,1)
|
|||||||
// ----------------------------------------------------------------------------
|
// ----------------------------------------------------------------------------
|
||||||
// sampling can be done in a few ways: greedy argmax, sampling, top-p sampling
|
// sampling can be done in a few ways: greedy argmax, sampling, top-p sampling
|
||||||
|
|
||||||
|
typedef struct {
|
||||||
|
float prob;
|
||||||
|
int index;
|
||||||
|
} ProbIndex; // struct used when sorting probabilities during top-p sampling
|
||||||
|
|
||||||
int argmax(float* probabilities, int n) {
|
int argmax(float* probabilities, int n) {
|
||||||
// return the index that has the highest probability
|
// return the index that has the highest probability
|
||||||
int max_i = 0;
|
int max_i = 0;
|
||||||
@@ -654,6 +651,7 @@ int main(int argc, char *argv[]) {
|
|||||||
// create and init the application RunState
|
// create and init the application RunState
|
||||||
RunState state;
|
RunState state;
|
||||||
malloc_run_state(&state, &config);
|
malloc_run_state(&state, &config);
|
||||||
|
ProbIndex *probindex = malloc(config.vocab_size * sizeof(ProbIndex)); // buffer used in top-p sampling
|
||||||
|
|
||||||
// process the prompt, if any
|
// process the prompt, if any
|
||||||
int *prompt_tokens = NULL;
|
int *prompt_tokens = NULL;
|
||||||
@@ -693,7 +691,7 @@ int main(int argc, char *argv[]) {
|
|||||||
next = sample(state.logits, config.vocab_size);
|
next = sample(state.logits, config.vocab_size);
|
||||||
} else {
|
} else {
|
||||||
// top-p (nucleus) sampling, clamping the least likely tokens to zero
|
// top-p (nucleus) sampling, clamping the least likely tokens to zero
|
||||||
next = sample_topp(state.logits, config.vocab_size, topp, state.probindex);
|
next = sample_topp(state.logits, config.vocab_size, topp, probindex);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -734,6 +732,7 @@ int main(int argc, char *argv[]) {
|
|||||||
|
|
||||||
// memory and file handles cleanup
|
// memory and file handles cleanup
|
||||||
free_run_state(&state);
|
free_run_state(&state);
|
||||||
|
free(probindex);
|
||||||
for (int i = 0; i < config.vocab_size; i++) { free(vocab[i]); }
|
for (int i = 0; i < config.vocab_size; i++) { free(vocab[i]); }
|
||||||
free(vocab);
|
free(vocab);
|
||||||
free(vocab_scores);
|
free(vocab_scores);
|
||||||
|
|||||||
Reference in New Issue
Block a user