add nucleus sampling. it costs lines of code, but i think this is the default best way to sample, so it is important to have
This commit is contained in:
@@ -57,6 +57,11 @@ typedef struct {
|
||||
float* wcls;
|
||||
} TransformerWeights;
|
||||
|
||||
// Pairs a probability with the position it came from, so the position is
// still known after the entries have been reordered by sorting.
typedef struct {
|
||||
float prob; // probability value copied from the distribution being sampled
|
||||
int index; // position of that probability in the original array
|
||||
} ProbIndex; // struct used when sorting probabilities during top-p sampling
|
||||
|
||||
typedef struct {
|
||||
// current wave of activations
|
||||
float *x; // activation at current time stamp (dim,)
|
||||
@@ -69,6 +74,7 @@ typedef struct {
|
||||
float *v; // value (dim,)
|
||||
float *att; // buffer for scores/attention values (n_heads, seq_len)
|
||||
float *logits; // output logits
|
||||
ProbIndex *probindex; // buffer used in top-p sampling
|
||||
// kv cache
|
||||
float* key_cache; // (layer, seq_len, dim)
|
||||
float* value_cache; // (layer, seq_len, dim)
|
||||
@@ -86,12 +92,13 @@ void malloc_run_state(RunState* s, Config* p) {
|
||||
s->v = calloc(p->dim, sizeof(float));
|
||||
s->att = calloc(p->n_heads * p->seq_len, sizeof(float));
|
||||
s->logits = calloc(p->vocab_size, sizeof(float));
|
||||
s->probindex = calloc(p->vocab_size, sizeof(ProbIndex));
|
||||
s->key_cache = calloc(p->n_layers * p->seq_len * p->dim, sizeof(float));
|
||||
s->value_cache = calloc(p->n_layers * p->seq_len * p->dim, sizeof(float));
|
||||
// ensure all mallocs went fine
|
||||
if (!s->x || !s->xb || !s->xb2 || !s->hb || !s->hb2 || !s->q
|
||||
|| !s->k || !s->v || !s->att || !s->logits || !s->key_cache
|
||||
|| !s->value_cache) {
|
||||
|| !s->value_cache || !s->probindex) {
|
||||
printf("malloc failed!\n");
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
@@ -108,6 +115,7 @@ void free_run_state(RunState* s) {
|
||||
free(s->v);
|
||||
free(s->att);
|
||||
free(s->logits);
|
||||
free(s->probindex);
|
||||
free(s->key_cache);
|
||||
free(s->value_cache);
|
||||
}
|
||||
@@ -394,7 +402,7 @@ void bpe_encode(char *text, char **vocab, float *vocab_scores, int vocab_size, u
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// utilities
|
||||
// utilities: time / rng
|
||||
|
||||
long time_in_ms() {
|
||||
// return time in milliseconds, for benchmarking the model speed
|
||||
@@ -415,8 +423,24 @@ float random_f32() { // random float32 in [0,1)
|
||||
return (random_u32() >> 8) / 16777216.0f;
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// sampling can be done in a few ways: greedy argmax, sampling, top-p sampling
|
||||
|
||||
int argmax(float* probabilities, int n) {
    // Return the index of the largest value in probabilities[0..n-1].
    // Ties resolve to the lowest index because the comparison is strict.
    // For n <= 0 there is nothing to inspect, so return 0 without reading
    // the array (avoids an out-of-bounds read of probabilities[0]).
    if (n <= 0) {
        return 0;
    }
    int max_i = 0;
    float max_p = probabilities[0];
    for (int i = 1; i < n; i++) {
        if (probabilities[i] > max_p) {
            max_i = i;
            max_p = probabilities[i];
        }
    }
    return max_i;
}
|
||||
|
||||
int sample(float* probabilities, int n) {
|
||||
// sample index from probabilities, they must sum to 1
|
||||
// sample index from probabilities (they must sum to 1!)
|
||||
float r = random_f32();
|
||||
float cdf = 0.0f;
|
||||
for (int i = 0; i < n; i++) {
|
||||
@@ -428,28 +452,62 @@ int sample(float* probabilities, int n) {
|
||||
return n - 1; // in case of rounding errors
|
||||
}
|
||||
|
||||
int argmax(float* v, int n) {
|
||||
// return argmax of v in elements 0..n
|
||||
int max_i = 0;
|
||||
float max_p = v[0];
|
||||
for (int i = 1; i < n; i++) {
|
||||
if (v[i] > max_p) {
|
||||
max_i = i;
|
||||
max_p = v[i];
|
||||
int compare(const void* a, const void* b) {
|
||||
ProbIndex* a_ = (ProbIndex*) a;
|
||||
ProbIndex* b_ = (ProbIndex*) b;
|
||||
if (a_->prob > b_->prob) return -1;
|
||||
if (a_->prob < b_->prob) return 1;
|
||||
return 0;
|
||||
}
|
||||
|
||||
int sample_topp(float* probabilities, int n, float topp, ProbIndex* probindex) {
    // top-p sampling (or "nucleus sampling") samples from the smallest set of
    // tokens whose cumulative probability exceeds topp. This way we never
    // sample tokens that have very low probabilities and are less likely to
    // go "off the rails".
    //
    // probabilities: n values to sample from (read only here)
    // topp:          cumulative probability threshold for the truncation
    // probindex:     caller-provided scratch buffer with room for n entries;
    //                it is overwritten on every call
    // returns:       the original index (0..n-1) of the sampled entry

    // sort indices in descending order of probabilities
    for (int i = 0; i < n; i++) {
        probindex[i].index = i;
        probindex[i].prob = probabilities[i];
    }
    qsort(probindex, n, sizeof(ProbIndex), compare);

    // truncate the list where cumulative probability exceeds topp.
    // last_idx starts at the final element so that, if the running sum never
    // exceeds topp (e.g. topp >= 1 or float rounding), every entry stays
    // eligible instead of collapsing to the single most likely one.
    float cumulative_prob = 0.0f;
    int last_idx = n - 1;
    for (int i = 0; i < n; i++) {
        cumulative_prob += probindex[i].prob;
        if (cumulative_prob > topp) {
            last_idx = i;
            break; // we've exceeded topp by including last_idx
        }
    }

    // sample from the truncated list; r is scaled by the retained mass so the
    // kept entries act as a renormalized distribution
    float r = random_f32() * cumulative_prob;
    float cdf = 0.0f;
    for (int i = 0; i <= last_idx; i++) {
        cdf += probindex[i].prob;
        if (r < cdf) {
            return probindex[i].index;
        }
    }
    return probindex[last_idx].index; // in case of rounding errors
}
|
||||
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// int main
|
||||
|
||||
void error_usage() {
    // Print the command-line help and terminate with EXIT_FAILURE.
    // Each flag is listed exactly once: -p is the top-p value and -i is the
    // input prompt, matching the argument parsing in main.
    printf("Usage: run <checkpoint> [options]\n");
    printf("Example: run model.bin -n 256 -i \"Once upon a time\"\n");
    printf("Options:\n");
    printf(" -t <float> temperature, default 1.0\n");
    printf(" -p <float> p value in top-p (nucleus) sampling. default 0.9, 0 = off\n");
    printf(" -s <int> random seed, default time(NULL)\n");
    printf(" -n <int> number of steps to run for, default 256. 0 = max_seq_len\n");
    printf(" -i <string> input prompt\n");
    exit(EXIT_FAILURE);
}
|
||||
|
||||
@@ -457,7 +515,8 @@ int main(int argc, char *argv[]) {
|
||||
|
||||
// default inits
|
||||
char *checkpoint = NULL; // e.g. out/model.bin
|
||||
float temperature = 0.9f; // 0.0 = greedy & deterministic, 1.0 = max uncertainty
|
||||
float temperature = 1.0f; // 0.0 = greedy deterministic. 1.0 = original. don't set higher
|
||||
float topp = 0.9f; // top-p in nucleus sampling
|
||||
rng_seed = (unsigned int)time(NULL); // seed rng with time by default
|
||||
int steps = 256; // number of steps to run for
|
||||
char *prompt = NULL; // prompt string
|
||||
@@ -471,9 +530,10 @@ int main(int argc, char *argv[]) {
|
||||
if (strlen(argv[i]) != 2) { error_usage(); } // must be -x (one dash, one letter)
|
||||
// read in the args
|
||||
if (argv[i][1] == 't') { temperature = atof(argv[i + 1]); }
|
||||
else if (argv[i][1] == 'p') { topp = atof(argv[i + 1]); }
|
||||
else if (argv[i][1] == 's') { rng_seed = atoi(argv[i + 1]); }
|
||||
else if (argv[i][1] == 'n') { steps = atoi(argv[i + 1]); }
|
||||
else if (argv[i][1] == 'p') { prompt = argv[i + 1]; }
|
||||
else if (argv[i][1] == 'i') { prompt = argv[i + 1]; }
|
||||
else { error_usage(); }
|
||||
}
|
||||
|
||||
@@ -562,7 +622,13 @@ int main(int argc, char *argv[]) {
|
||||
// apply softmax to the logits to get the probabilities for next token
|
||||
softmax(state.logits, config.vocab_size);
|
||||
// we sample from this distribution to get the next token
|
||||
next = sample(state.logits, config.vocab_size);
|
||||
if (topp <= 0) {
|
||||
// simply sample from the predicted probability distribution
|
||||
next = sample(state.logits, config.vocab_size);
|
||||
} else {
|
||||
// top-p (nucleus) sampling, clamping the least likely tokens to zero
|
||||
next = sample_topp(state.logits, config.vocab_size, topp, state.probindex);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user