Absorb our RNG state into the Sampler. I feel that this is correct because it makes our use of entropy very explicit and localized, and the Sampler is now well-contained without any global state. The code is becoming increasingly beautiful.
This commit is contained in:
@@ -455,7 +455,8 @@ void encode(Tokenizer* t, char *text, int *tokens, int *n_tokens) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// create a temporary buffer that will store merge candidates of always two consecutive tokens
|
// create a temporary buffer that will store merge candidates of always two consecutive tokens
|
||||||
char* str_buffer = malloc((t->max_token_length*2 +1 +2) * sizeof(char)); // *2 for concat, +1 for null terminator +2 for UTF8 (in case max_token_lenght is 1)
|
// *2 for concat, +1 for null terminator, +2 for UTF8 (in case max_token_length is 1)
|
||||||
|
char* str_buffer = malloc((t->max_token_length*2 +1 +2) * sizeof(char));
|
||||||
size_t str_len = 0;
|
size_t str_len = 0;
|
||||||
|
|
||||||
// add_dummy_prefix is true by default
|
// add_dummy_prefix is true by default
|
||||||
@@ -559,22 +560,9 @@ typedef struct {
|
|||||||
ProbIndex* probindex; // buffer used in top-p sampling
|
ProbIndex* probindex; // buffer used in top-p sampling
|
||||||
float temperature;
|
float temperature;
|
||||||
float topp;
|
float topp;
|
||||||
|
unsigned long long rng_state;
|
||||||
} Sampler;
|
} Sampler;
|
||||||
|
|
||||||
// rng should technically be a state variable of the Sampler
|
|
||||||
// leaving it global here for now for convenience, maybe move later
|
|
||||||
unsigned long long rng_seed;
|
|
||||||
unsigned int random_u32() {
|
|
||||||
// xorshift rng: https://en.wikipedia.org/wiki/Xorshift#xorshift.2A
|
|
||||||
rng_seed ^= rng_seed >> 12;
|
|
||||||
rng_seed ^= rng_seed << 25;
|
|
||||||
rng_seed ^= rng_seed >> 27;
|
|
||||||
return (rng_seed * 0x2545F4914F6CDD1Dull) >> 32;
|
|
||||||
}
|
|
||||||
float random_f32() { // random float32 in [0,1)
|
|
||||||
return (random_u32() >> 8) / 16777216.0f;
|
|
||||||
}
|
|
||||||
|
|
||||||
int sample_argmax(float* probabilities, int n) {
|
int sample_argmax(float* probabilities, int n) {
|
||||||
// return the index that has the highest probability
|
// return the index that has the highest probability
|
||||||
int max_i = 0;
|
int max_i = 0;
|
||||||
@@ -588,13 +576,13 @@ int sample_argmax(float* probabilities, int n) {
|
|||||||
return max_i;
|
return max_i;
|
||||||
}
|
}
|
||||||
|
|
||||||
int sample_mult(float* probabilities, int n) {
|
int sample_mult(float* probabilities, int n, float coin) {
|
||||||
// sample index from probabilities (they must sum to 1!)
|
// sample index from probabilities (they must sum to 1!)
|
||||||
float r = random_f32();
|
// coin is a random number in [0, 1), usually from random_f32()
|
||||||
float cdf = 0.0f;
|
float cdf = 0.0f;
|
||||||
for (int i = 0; i < n; i++) {
|
for (int i = 0; i < n; i++) {
|
||||||
cdf += probabilities[i];
|
cdf += probabilities[i];
|
||||||
if (r < cdf) {
|
if (coin < cdf) {
|
||||||
return i;
|
return i;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -609,10 +597,11 @@ int compare(const void* a, const void* b) {
|
|||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
int sample_topp(float* probabilities, int n, float topp, ProbIndex* probindex) {
|
int sample_topp(float* probabilities, int n, float topp, ProbIndex* probindex, float coin) {
|
||||||
// top-p sampling (or "nucleus sampling") samples from the smallest set of
|
// top-p sampling (or "nucleus sampling") samples from the smallest set of
|
||||||
// tokens that exceed probability topp. This way we never sample tokens that
|
// tokens that exceed probability topp. This way we never sample tokens that
|
||||||
// have very low probabilities and are less likely to go "off the rails".
|
// have very low probabilities and are less likely to go "off the rails".
|
||||||
|
// coin is a random number in [0, 1), usually from random_f32()
|
||||||
|
|
||||||
int n0 = 0;
|
int n0 = 0;
|
||||||
// quicksort indices in descending order of probabilities
|
// quicksort indices in descending order of probabilities
|
||||||
@@ -640,7 +629,7 @@ int sample_topp(float* probabilities, int n, float topp, ProbIndex* probindex) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// sample from the truncated list
|
// sample from the truncated list
|
||||||
float r = random_f32() * cumulative_prob;
|
float r = coin * cumulative_prob;
|
||||||
float cdf = 0.0f;
|
float cdf = 0.0f;
|
||||||
for (int i = 0; i <= last_idx; i++) {
|
for (int i = 0; i <= last_idx; i++) {
|
||||||
cdf += probindex[i].prob;
|
cdf += probindex[i].prob;
|
||||||
@@ -651,10 +640,11 @@ int sample_topp(float* probabilities, int n, float topp, ProbIndex* probindex) {
|
|||||||
return probindex[last_idx].index; // in case of rounding errors
|
return probindex[last_idx].index; // in case of rounding errors
|
||||||
}
|
}
|
||||||
|
|
||||||
void build_sampler(Sampler* sampler, int vocab_size, float temperature, float topp) {
|
void build_sampler(Sampler* sampler, int vocab_size, float temperature, float topp, unsigned long long rng_seed) {
|
||||||
sampler->vocab_size = vocab_size;
|
sampler->vocab_size = vocab_size;
|
||||||
sampler->temperature = temperature;
|
sampler->temperature = temperature;
|
||||||
sampler->topp = topp;
|
sampler->topp = topp;
|
||||||
|
sampler->rng_state = rng_seed;
|
||||||
// buffer only used with nucleus sampling; may not need but it's ~small
|
// buffer only used with nucleus sampling; may not need but it's ~small
|
||||||
sampler->probindex = malloc(sampler->vocab_size * sizeof(ProbIndex));
|
sampler->probindex = malloc(sampler->vocab_size * sizeof(ProbIndex));
|
||||||
}
|
}
|
||||||
@@ -663,6 +653,17 @@ void free_sampler(Sampler* sampler) {
|
|||||||
free(sampler->probindex);
|
free(sampler->probindex);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
unsigned int random_u32(unsigned long long *state) {
|
||||||
|
// xorshift rng: https://en.wikipedia.org/wiki/Xorshift#xorshift.2A
|
||||||
|
*state ^= *state >> 12;
|
||||||
|
*state ^= *state << 25;
|
||||||
|
*state ^= *state >> 27;
|
||||||
|
return (*state * 0x2545F4914F6CDD1Dull) >> 32;
|
||||||
|
}
|
||||||
|
float random_f32(unsigned long long *state) { // random float32 in [0,1)
|
||||||
|
return (random_u32(state) >> 8) / 16777216.0f;
|
||||||
|
}
|
||||||
|
|
||||||
int sample(Sampler* sampler, float* logits) {
|
int sample(Sampler* sampler, float* logits) {
|
||||||
// sample the token given the logits and some hyperparameters
|
// sample the token given the logits and some hyperparameters
|
||||||
int next;
|
int next;
|
||||||
@@ -674,13 +675,15 @@ int sample(Sampler* sampler, float* logits) {
|
|||||||
for (int q=0; q<sampler->vocab_size; q++) { logits[q] /= sampler->temperature; }
|
for (int q=0; q<sampler->vocab_size; q++) { logits[q] /= sampler->temperature; }
|
||||||
// apply softmax to the logits to get the probabilities for next token
|
// apply softmax to the logits to get the probabilities for next token
|
||||||
softmax(logits, sampler->vocab_size);
|
softmax(logits, sampler->vocab_size);
|
||||||
|
// flip a (float) coin (this is our source of entropy for sampling)
|
||||||
|
float coin = random_f32(&sampler->rng_state);
|
||||||
// we sample from this distribution to get the next token
|
// we sample from this distribution to get the next token
|
||||||
if (sampler->topp <= 0 || sampler->topp >= 1) {
|
if (sampler->topp <= 0 || sampler->topp >= 1) {
|
||||||
// simply sample from the predicted probability distribution
|
// simply sample from the predicted probability distribution
|
||||||
next = sample_mult(logits, sampler->vocab_size);
|
next = sample_mult(logits, sampler->vocab_size, coin);
|
||||||
} else {
|
} else {
|
||||||
// top-p (nucleus) sampling, clamping the least likely tokens to zero
|
// top-p (nucleus) sampling, clamping the least likely tokens to zero
|
||||||
next = sample_topp(logits, sampler->vocab_size, sampler->topp, sampler->probindex);
|
next = sample_topp(logits, sampler->vocab_size, sampler->topp, sampler->probindex, coin);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return next;
|
return next;
|
||||||
@@ -775,9 +778,9 @@ int main(int argc, char *argv[]) {
|
|||||||
char *tokenizer_path = "tokenizer.bin";
|
char *tokenizer_path = "tokenizer.bin";
|
||||||
float temperature = 1.0f; // 0.0 = greedy deterministic. 1.0 = original. don't set higher
|
float temperature = 1.0f; // 0.0 = greedy deterministic. 1.0 = original. don't set higher
|
||||||
float topp = 0.9f; // top-p in nucleus sampling. 1.0 = off. 0.9 works well, but slower
|
float topp = 0.9f; // top-p in nucleus sampling. 1.0 = off. 0.9 works well, but slower
|
||||||
rng_seed = 0; // seed rng with time by default
|
|
||||||
int steps = 256; // number of steps to run for
|
int steps = 256; // number of steps to run for
|
||||||
char *prompt = NULL; // prompt string
|
char *prompt = NULL; // prompt string
|
||||||
|
unsigned long long rng_seed = 0; // seed rng with time by default
|
||||||
|
|
||||||
// poor man's C argparse so we can override the defaults above from the command line
|
// poor man's C argparse so we can override the defaults above from the command line
|
||||||
if (argc >= 2) { checkpoint_path = argv[1]; } else { error_usage(); }
|
if (argc >= 2) { checkpoint_path = argv[1]; } else { error_usage(); }
|
||||||
@@ -813,7 +816,7 @@ int main(int argc, char *argv[]) {
|
|||||||
|
|
||||||
// build the Sampler
|
// build the Sampler
|
||||||
Sampler sampler;
|
Sampler sampler;
|
||||||
build_sampler(&sampler, transformer.config.vocab_size, temperature, topp);
|
build_sampler(&sampler, transformer.config.vocab_size, temperature, topp, rng_seed);
|
||||||
|
|
||||||
// run!
|
// run!
|
||||||
generate(&transformer, &tokenizer, &sampler, prompt, steps);
|
generate(&transformer, &tokenizer, &sampler, prompt, steps);
|
||||||
|
|||||||
Reference in New Issue
Block a user