diff --git a/run.c b/run.c
index dc232e5..1f50d59 100644
--- a/run.c
+++ b/run.c
@@ -455,7 +455,8 @@ void encode(Tokenizer* t, char *text, int *tokens, int *n_tokens) {
     }
 
     // create a temporary buffer that will store merge candidates of always two consecutive tokens
-    char* str_buffer = malloc((t->max_token_length*2 +1 +2) * sizeof(char)); // *2 for concat, +1 for null terminator +2 for UTF8 (in case max_token_lenght is 1)
+    // *2 for concat, +1 for null terminator +2 for UTF8 (in case max_token_lenght is 1)
+    char* str_buffer = malloc((t->max_token_length*2 +1 +2) * sizeof(char));
     size_t str_len = 0;
 
     // add_dummy_prefix is true by default
@@ -559,22 +560,9 @@ typedef struct {
     ProbIndex* probindex; // buffer used in top-p sampling
     float temperature;
     float topp;
+    unsigned long long rng_state;
 } Sampler;
 
-// rng should technically be a state variable of the Sampler
-// leaving it global here for now for convenience, maybe move later
-unsigned long long rng_seed;
-unsigned int random_u32() {
-    // xorshift rng: https://en.wikipedia.org/wiki/Xorshift#xorshift.2A
-    rng_seed ^= rng_seed >> 12;
-    rng_seed ^= rng_seed << 25;
-    rng_seed ^= rng_seed >> 27;
-    return (rng_seed * 0x2545F4914F6CDD1Dull) >> 32;
-}
-float random_f32() { // random float32 in [0,1)
-    return (random_u32() >> 8) / 16777216.0f;
-}
-
 int sample_argmax(float* probabilities, int n) {
     // return the index that has the highest probability
     int max_i = 0;
@@ -588,13 +576,13 @@ int sample_argmax(float* probabilities, int n) {
     return max_i;
 }
 
-int sample_mult(float* probabilities, int n) {
+int sample_mult(float* probabilities, int n, float coin) {
     // sample index from probabilities (they must sum to 1!)
-    float r = random_f32();
+    // coin is a random number in [0, 1), usually from random_f32()
     float cdf = 0.0f;
     for (int i = 0; i < n; i++) {
         cdf += probabilities[i];
-        if (r < cdf) {
+        if (coin < cdf) {
             return i;
         }
     }
@@ -609,10 +597,11 @@ int compare(const void* a, const void* b) {
     return 0;
 }
 
-int sample_topp(float* probabilities, int n, float topp, ProbIndex* probindex) {
+int sample_topp(float* probabilities, int n, float topp, ProbIndex* probindex, float coin) {
     // top-p sampling (or "nucleus sampling") samples from the smallest set of
     // tokens that exceed probability topp. This way we never sample tokens that
     // have very low probabilities and are less likely to go "off the rails".
+    // coin is a random number in [0, 1), usually from random_f32()
 
     int n0 = 0;
     // quicksort indices in descending order of probabilities
@@ -640,7 +629,7 @@ int sample_topp(float* probabilities, int n, float topp, ProbIndex* probindex) {
     }
 
     // sample from the truncated list
-    float r = random_f32() * cumulative_prob;
+    float r = coin * cumulative_prob;
     float cdf = 0.0f;
     for (int i = 0; i <= last_idx; i++) {
         cdf += probindex[i].prob;
@@ -651,10 +640,11 @@ int sample_topp(float* probabilities, int n, float topp, ProbIndex* probindex) {
     return probindex[last_idx].index; // in case of rounding errors
 }
 
-void build_sampler(Sampler* sampler, int vocab_size, float temperature, float topp) {
+void build_sampler(Sampler* sampler, int vocab_size, float temperature, float topp, unsigned long long rng_seed) {
     sampler->vocab_size = vocab_size;
     sampler->temperature = temperature;
     sampler->topp = topp;
+    sampler->rng_state = rng_seed;
     // buffer only used with nucleus sampling; may not need but it's ~small
     sampler->probindex = malloc(sampler->vocab_size * sizeof(ProbIndex));
 }
@@ -663,6 +653,17 @@ void free_sampler(Sampler* sampler) {
     free(sampler->probindex);
 }
 
+unsigned int random_u32(unsigned long long *state) {
+    // xorshift rng: https://en.wikipedia.org/wiki/Xorshift#xorshift.2A
+    *state ^= *state >> 12;
+    *state ^= *state << 25;
+    *state ^= *state >> 27;
+    return (*state * 0x2545F4914F6CDD1Dull) >> 32;
+}
+float random_f32(unsigned long long *state) { // random float32 in [0,1)
+    return (random_u32(state) >> 8) / 16777216.0f;
+}
+
 int sample(Sampler* sampler, float* logits) {
     // sample the token given the logits and some hyperparameters
     int next;
@@ -674,13 +675,15 @@ int sample(Sampler* sampler, float* logits) {
         for (int q=0; q<sampler->vocab_size; q++) { logits[q] /= sampler->temperature; }
         // apply softmax to the logits to get the probabilities for next token
         softmax(logits, sampler->vocab_size);
+        // flip a (float) coin (this is our source of entropy for sampling)
+        float coin = random_f32(&sampler->rng_state);
         // we sample from this distribution to get the next token
         if (sampler->topp <= 0 || sampler->topp >= 1) {
             // simply sample from the predicted probability distribution
-            next = sample_mult(logits, sampler->vocab_size);
+            next = sample_mult(logits, sampler->vocab_size, coin);
         } else {
             // top-p (nucleus) sampling, clamping the least likely tokens to zero
-            next = sample_topp(logits, sampler->vocab_size, sampler->topp, sampler->probindex);
+            next = sample_topp(logits, sampler->vocab_size, sampler->topp, sampler->probindex, coin);
         }
     }
     return next;
@@ -775,9 +778,9 @@ int main(int argc, char *argv[]) {
     char *tokenizer_path = "tokenizer.bin";
     float temperature = 1.0f; // 0.0 = greedy deterministic. 1.0 = original. don't set higher
     float topp = 0.9f; // top-p in nucleus sampling. 1.0 = off. 0.9 works well, but slower
-    rng_seed = 0; // seed rng with time by default
     int steps = 256; // number of steps to run for
     char *prompt = NULL; // prompt string
+    unsigned long long rng_seed = 0; // seed rng with time by default
 
     // poor man's C argparse so we can override the defaults above from the command line
     if (argc >= 2) { checkpoint_path = argv[1]; } else { error_usage(); }
@@ -813,7 +816,7 @@ int main(int argc, char *argv[]) {
 
     // build the Sampler
     Sampler sampler;
-    build_sampler(&sampler, transformer.config.vocab_size, temperature, topp);
+    build_sampler(&sampler, transformer.config.vocab_size, temperature, topp, rng_seed);
 
     // run!
     generate(&transformer, &tokenizer, &sampler, prompt, steps);