add nucleus sampling. it costs lines of code, but i think thit is the default best way to sample, so it is important to have

This commit is contained in:
Andrej Karpathy
2023-08-06 07:22:39 +00:00
parent 49e3ff6d08
commit 8931d5092e
2 changed files with 87 additions and 21 deletions
+84 -18
View File
@@ -57,6 +57,11 @@ typedef struct {
float* wcls;
} TransformerWeights;
typedef struct {
float prob;
int index;
} ProbIndex; // struct used when sorting probabilities during top-p sampling
typedef struct {
// current wave of activations
float *x; // activation at current time stamp (dim,)
@@ -69,6 +74,7 @@ typedef struct {
float *v; // value (dim,)
float *att; // buffer for scores/attention values (n_heads, seq_len)
float *logits; // output logits
ProbIndex *probindex; // buffer used in top-p sampling
// kv cache
float* key_cache; // (layer, seq_len, dim)
float* value_cache; // (layer, seq_len, dim)
@@ -86,12 +92,13 @@ void malloc_run_state(RunState* s, Config* p) {
s->v = calloc(p->dim, sizeof(float));
s->att = calloc(p->n_heads * p->seq_len, sizeof(float));
s->logits = calloc(p->vocab_size, sizeof(float));
s->probindex = calloc(p->vocab_size, sizeof(ProbIndex));
s->key_cache = calloc(p->n_layers * p->seq_len * p->dim, sizeof(float));
s->value_cache = calloc(p->n_layers * p->seq_len * p->dim, sizeof(float));
// ensure all mallocs went fine
if (!s->x || !s->xb || !s->xb2 || !s->hb || !s->hb2 || !s->q
|| !s->k || !s->v || !s->att || !s->logits || !s->key_cache
|| !s->value_cache) {
|| !s->value_cache || !s->probindex) {
printf("malloc failed!\n");
exit(EXIT_FAILURE);
}
@@ -108,6 +115,7 @@ void free_run_state(RunState* s) {
free(s->v);
free(s->att);
free(s->logits);
free(s->probindex);
free(s->key_cache);
free(s->value_cache);
}
@@ -394,7 +402,7 @@ void bpe_encode(char *text, char **vocab, float *vocab_scores, int vocab_size, u
}
// ----------------------------------------------------------------------------
// utilities
// utilities: time / rng
long time_in_ms() {
// return time in milliseconds, for benchmarking the model speed
@@ -415,8 +423,24 @@ float random_f32() { // random float32 in [0,1)
return (random_u32() >> 8) / 16777216.0f;
}
// ----------------------------------------------------------------------------
// sampling can be done in a few ways: greedy argmax, sampling, top-p sampling
int argmax(float* probabilities, int n) {
// return the index that has the highest probability
int max_i = 0;
float max_p = probabilities[0];
for (int i = 1; i < n; i++) {
if (probabilities[i] > max_p) {
max_i = i;
max_p = probabilities[i];
}
}
return max_i;
}
int sample(float* probabilities, int n) {
// sample index from probabilities, they must sum to 1
// sample index from probabilities (they must sum to 1!)
float r = random_f32();
float cdf = 0.0f;
for (int i = 0; i < n; i++) {
@@ -428,28 +452,62 @@ int sample(float* probabilities, int n) {
return n - 1; // in case of rounding errors
}
int argmax(float* v, int n) {
// return argmax of v in elements 0..n
int max_i = 0;
float max_p = v[0];
for (int i = 1; i < n; i++) {
if (v[i] > max_p) {
max_i = i;
max_p = v[i];
int compare(const void* a, const void* b) {
ProbIndex* a_ = (ProbIndex*) a;
ProbIndex* b_ = (ProbIndex*) b;
if (a_->prob > b_->prob) return -1;
if (a_->prob < b_->prob) return 1;
return 0;
}
int sample_topp(float* probabilities, int n, float topp, ProbIndex* probindex) {
// top-p sampling (or "nucleus sampling") samples from the smallest set of
// tokens that exceed probability topp. This way we never sample tokens that
// have very low probabilities and are less likely to go "off the rails".
// quicksort indices in descending order of probabilities
for (int i = 0; i < n; i++) {
probindex[i].index = i;
probindex[i].prob = probabilities[i];
}
qsort(probindex, n, sizeof(ProbIndex), compare);
// truncate the list where cumulative probability exceeds topp
float cumulative_prob = 0.0f;
int last_idx = 0;
for (int i = 0; i < n; i++) {
cumulative_prob += probindex[i].prob;
if (cumulative_prob > topp) {
last_idx = i;
break; // we've exceeded topp by including last_idx
}
}
return max_i;
// sample from the truncated list
float r = random_f32() * cumulative_prob;
float cdf = 0.0f;
for (int i = 0; i <= last_idx; i++) {
cdf += probindex[i].prob;
if (r < cdf) {
return probindex[i].index;
}
}
return probindex[last_idx].index; // in case of rounding errors
}
// ----------------------------------------------------------------------------
// int main
void error_usage() {
printf("Usage: run <checkpoint> [options]\n");
printf("Example: run model.bin -t 0.9 -n 256 -p \"Once upon a time\"\n");
printf("Example: run model.bin -n 256 -i \"Once upon a time\"\n");
printf("Options:\n");
printf(" -t <float> temperature, default 0.9\n");
printf(" -t <float> temperature, default 1.0\n");
printf(" -p <float> p value in top-p (nucleus) sampling. default 0.9, 0 = off\n");
printf(" -s <int> random seed, default time(NULL)\n");
printf(" -n <int> number of steps to run for, default 256. 0 = max_seq_len\n");
printf(" -p <string> prompt string, default none\n");
printf(" -i <string> input prompt\n");
exit(EXIT_FAILURE);
}
@@ -457,7 +515,8 @@ int main(int argc, char *argv[]) {
// default inits
char *checkpoint = NULL; // e.g. out/model.bin
float temperature = 0.9f; // 0.0 = greedy & deterministic, 1.0 = max uncertainty
float temperature = 1.0f; // 0.0 = greedy deterministic. 1.0 = original. don't set higher
float topp = 0.9f; // top-p in nucleus sampling
rng_seed = (unsigned int)time(NULL); // seed rng with time by default
int steps = 256; // number of steps to run for
char *prompt = NULL; // prompt string
@@ -471,9 +530,10 @@ int main(int argc, char *argv[]) {
if (strlen(argv[i]) != 2) { error_usage(); } // must be -x (one dash, one letter)
// read in the args
if (argv[i][1] == 't') { temperature = atof(argv[i + 1]); }
else if (argv[i][1] == 'p') { topp = atof(argv[i + 1]); }
else if (argv[i][1] == 's') { rng_seed = atoi(argv[i + 1]); }
else if (argv[i][1] == 'n') { steps = atoi(argv[i + 1]); }
else if (argv[i][1] == 'p') { prompt = argv[i + 1]; }
else if (argv[i][1] == 'i') { prompt = argv[i + 1]; }
else { error_usage(); }
}
@@ -562,7 +622,13 @@ int main(int argc, char *argv[]) {
// apply softmax to the logits to get the probabilities for next token
softmax(state.logits, config.vocab_size);
// we sample from this distribution to get the next token
next = sample(state.logits, config.vocab_size);
if (topp <= 0) {
// simply sample from the predicted probability distribution
next = sample(state.logits, config.vocab_size);
} else {
// top-p (nucleus) sampling, clamping the least likely tokens to zero
next = sample_topp(state.logits, config.vocab_size, topp, state.probindex);
}
}
}