probindex should never have been part of RunState. i apologize for this failure of abstraction
This commit is contained in:
@@ -50,11 +50,6 @@ typedef struct {
|
||||
float* wcls;
|
||||
} TransformerWeights;
|
||||
|
||||
typedef struct {
|
||||
float prob;
|
||||
int index;
|
||||
} ProbIndex; // struct used when sorting probabilities during top-p sampling
|
||||
|
||||
typedef struct {
|
||||
// current wave of activations
|
||||
float *x; // activation at current time stamp (dim,)
|
||||
@@ -67,7 +62,6 @@ typedef struct {
|
||||
float *v; // value (dim,)
|
||||
float *att; // buffer for scores/attention values (n_heads, seq_len)
|
||||
float *logits; // output logits
|
||||
ProbIndex *probindex; // buffer used in top-p sampling
|
||||
// kv cache
|
||||
float* key_cache; // (layer, seq_len, dim)
|
||||
float* value_cache; // (layer, seq_len, dim)
|
||||
@@ -86,13 +80,12 @@ void malloc_run_state(RunState* s, Config* p) {
|
||||
s->v = calloc(kv_dim, sizeof(float));
|
||||
s->att = calloc(p->n_heads * p->seq_len, sizeof(float));
|
||||
s->logits = calloc(p->vocab_size, sizeof(float));
|
||||
s->probindex = calloc(p->vocab_size, sizeof(ProbIndex));
|
||||
s->key_cache = calloc(p->n_layers * p->seq_len * kv_dim, sizeof(float));
|
||||
s->value_cache = calloc(p->n_layers * p->seq_len * kv_dim, sizeof(float));
|
||||
// ensure all mallocs went fine
|
||||
if (!s->x || !s->xb || !s->xb2 || !s->hb || !s->hb2 || !s->q
|
||||
|| !s->k || !s->v || !s->att || !s->logits || !s->key_cache
|
||||
|| !s->value_cache || !s->probindex) {
|
||||
|| !s->value_cache) {
|
||||
fprintf(stderr, "malloc failed!\n");
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
@@ -109,7 +102,6 @@ void free_run_state(RunState* s) {
|
||||
free(s->v);
|
||||
free(s->att);
|
||||
free(s->logits);
|
||||
free(s->probindex);
|
||||
free(s->key_cache);
|
||||
free(s->value_cache);
|
||||
}
|
||||
@@ -499,6 +491,11 @@ float random_f32() { // random float32 in [0,1)
|
||||
// ----------------------------------------------------------------------------
|
||||
// sampling can be done in a few ways: greedy argmax, sampling, top-p sampling
|
||||
|
||||
typedef struct {
|
||||
float prob;
|
||||
int index;
|
||||
} ProbIndex; // struct used when sorting probabilities during top-p sampling
|
||||
|
||||
int argmax(float* probabilities, int n) {
|
||||
// return the index that has the highest probability
|
||||
int max_i = 0;
|
||||
@@ -654,6 +651,7 @@ int main(int argc, char *argv[]) {
|
||||
// create and init the application RunState
|
||||
RunState state;
|
||||
malloc_run_state(&state, &config);
|
||||
ProbIndex *probindex = malloc(config.vocab_size * sizeof(ProbIndex)); // buffer used in top-p sampling
|
||||
|
||||
// process the prompt, if any
|
||||
int *prompt_tokens = NULL;
|
||||
@@ -693,7 +691,7 @@ int main(int argc, char *argv[]) {
|
||||
next = sample(state.logits, config.vocab_size);
|
||||
} else {
|
||||
// top-p (nucleus) sampling, clamping the least likely tokens to zero
|
||||
next = sample_topp(state.logits, config.vocab_size, topp, state.probindex);
|
||||
next = sample_topp(state.logits, config.vocab_size, topp, probindex);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -734,6 +732,7 @@ int main(int argc, char *argv[]) {
|
||||
|
||||
// memory and file handles cleanup
|
||||
free_run_state(&state);
|
||||
free(probindex);
|
||||
for (int i = 0; i < config.vocab_size; i++) { free(vocab[i]); }
|
||||
free(vocab);
|
||||
free(vocab_scores);
|
||||
|
||||
Reference in New Issue
Block a user