diff --git a/run.c b/run.c index 1242596..b7b6183 100644 --- a/run.c +++ b/run.c @@ -164,7 +164,7 @@ void read_checkpoint(char* checkpoint, Config* config, TransformerWeights* weigh memory_map_weights(weights, config, weights_ptr, shared_weights); } -void build_transformer(char* checkpoint_path, Transformer *t) { +void build_transformer(Transformer *t, char* checkpoint_path) { // read in the Config and the Weights from the checkpoint read_checkpoint(checkpoint_path, &t->config, &t->weights, &t->fd, &t->data, &t->file_size); // allocate the RunState buffers @@ -377,7 +377,7 @@ typedef struct { char byte_piece[2]; } Tokenizer; -void build_tokenizer(char* tokenizer_path, Tokenizer* t, int vocab_size) { +void build_tokenizer(Tokenizer* t, char* tokenizer_path, int vocab_size) { // i should have written the vocab_size into the tokenizer file... sigh t->vocab_size = vocab_size; // malloc space to hold the scores and the strings @@ -542,15 +542,21 @@ void encode(Tokenizer* t, char *text, int *tokens, int *n_tokens) { } // ---------------------------------------------------------------------------- -// utilities: time / rng +// The Sampler, which takes logits and returns a sampled token +// sampling can be done in a few ways: greedy argmax, sampling, top-p sampling -long time_in_ms() { - // return time in milliseconds, for benchmarking the model speed - struct timespec time; - clock_gettime(CLOCK_REALTIME, &time); - return time.tv_sec * 1000 + time.tv_nsec / 1000000; -} +typedef struct { + float prob; + int index; +} ProbIndex; // struct used when sorting probabilities during top-p sampling +typedef struct { + int vocab_size; + ProbIndex* probindex; // buffer used in top-p sampling +} Sampler; + +// rng should technically be a state variable of the Sampler +// leaving it global here for now for convenience, maybe move later unsigned long long rng_seed; unsigned int random_u32() { // xorshift rng: https://en.wikipedia.org/wiki/Xorshift#xorshift.2A @@ -563,15 +569,7 @@ float random_f32() { // random float32 in [0,1) return (random_u32() >> 8) / 16777216.0f; } -// ---------------------------------------------------------------------------- -// sampling can be done in a few ways: greedy argmax, sampling, top-p sampling - -typedef struct { - float prob; - int index; -} ProbIndex; // struct used when sorting probabilities during top-p sampling - -int argmax(float* probabilities, int n) { +int sample_argmax(float* probabilities, int n) { // return the index that has the highest probability int max_i = 0; float max_p = probabilities[0]; @@ -584,7 +582,7 @@ int argmax(float* probabilities, int n) { return max_i; } -int sample(float* probabilities, int n) { +int sample_mult(float* probabilities, int n) { // sample index from probabilities (they must sum to 1!) float r = random_f32(); float cdf = 0.0f; @@ -647,6 +645,48 @@ int sample_topp(float* probabilities, int n, float topp, ProbIndex* probindex) { return probindex[last_idx].index; // in case of rounding errors } +void build_sampler(Sampler* sampler, int vocab_size) { + sampler->vocab_size = vocab_size; + // probindex might not be needed, but it's a ~small buffer so we'll just malloc it + sampler->probindex = malloc(vocab_size * sizeof(ProbIndex)); +} + +void free_sampler(Sampler* sampler) { + free(sampler->probindex); +} + +int sample(Sampler* sampler, float* logits, float temperature, float topp) { + // sample the token given the logits and some hyperparameters + int next; + if (temperature == 0.0f) { + // greedy argmax sampling: take the token with the highest probability + next = sample_argmax(logits, sampler->vocab_size); + } else { + // apply the temperature to the logits + for (int q=0; qvocab_size; q++) { logits[q] /= temperature; } + // apply softmax to the logits to get the probabilities for next token + softmax(logits, sampler->vocab_size); + // we sample from this distribution to get the next token + if (topp <= 0 || topp >= 1) { + // simply sample from the predicted probability distribution + next = sample_mult(logits, sampler->vocab_size); + } else { + // top-p (nucleus) sampling, clamping the least likely tokens to zero + next = sample_topp(logits, sampler->vocab_size, topp, sampler->probindex); + } + } + return next; +} + +// ---------------------------------------------------------------------------- +// utilities: time + +long time_in_ms() { + // return time in milliseconds, for benchmarking the model speed + struct timespec time; + clock_gettime(CLOCK_REALTIME, &time); + return time.tv_sec * 1000 + time.tv_nsec / 1000000; +} // ---------------------------------------------------------------------------- // int main @@ -695,16 +735,18 @@ int main(int argc, char *argv[]) { // build the Transformer via the model .bin file Transformer transformer; - build_transformer(checkpoint_path, &transformer); + build_transformer(&transformer, checkpoint_path); int vocab_size = transformer.config.vocab_size; // convenience copy // build the Tokenizer via the tokenizer .bin file Tokenizer tokenizer; - build_tokenizer(tokenizer_path, &tokenizer, vocab_size); + build_tokenizer(&tokenizer, tokenizer_path, vocab_size); - // create and init the application RunState - ProbIndex *probindex = malloc(vocab_size * sizeof(ProbIndex)); // buffer used in top-p sampling - // process the prompt, if any + // build the Sampler + Sampler sampler; + build_sampler(&sampler, vocab_size); + + // encode the (string) prompt into tokens sequence, if any is given int *prompt_tokens = NULL; int num_prompt_tokens = 0; if (prompt != NULL) { @@ -723,28 +765,12 @@ int main(int argc, char *argv[]) { float* logits = forward(&transformer, token, pos); // advance the state state machine - if(pos < num_prompt_tokens) { + if (pos < num_prompt_tokens) { // if we are still processing the input prompt, force the next prompt token next = prompt_tokens[pos]; } else { - // sample the next token - if (temperature == 0.0f) { - // greedy argmax sampling: take the token with the highest probability - next = argmax(logits, vocab_size); - } else { - // apply the temperature to the logits - for (int q=0; q= 1) { - // simply sample from the predicted probability distribution - next = sample(logits, vocab_size); - } else { - // top-p (nucleus) sampling, clamping the least likely tokens to zero - next = sample_topp(logits, vocab_size, topp, probindex); - } - } + // otherwise sample the next token from the logits + next = sample(&sampler, logits, temperature, topp); } pos++; @@ -769,8 +795,8 @@ int main(int argc, char *argv[]) { } // memory and file handles cleanup - free(probindex); if (prompt_tokens != NULL) { free(prompt_tokens); } + free_sampler(&sampler); free_tokenizer(&tokenizer); free_transformer(&transformer); return 0;