From d73b917d3ba7a1d933c65b1f33ecb58fd8d78a92 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Tue, 22 Aug 2023 02:17:51 +0000 Subject: [PATCH] hide temperature and topp into the sampler, it's a little bit less flexible but a little bit more cleaner --- run.c | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/run.c b/run.c index 705719c..ae78f67 100644 --- a/run.c +++ b/run.c @@ -557,6 +557,8 @@ typedef struct { typedef struct { int vocab_size; ProbIndex* probindex; // buffer used in top-p sampling + float temperature; + float topp; } Sampler; // rng should technically be a state variable of the Sampler @@ -649,34 +651,36 @@ int sample_topp(float* probabilities, int n, float topp, ProbIndex* probindex) { return probindex[last_idx].index; // in case of rounding errors } -void build_sampler(Sampler* sampler, int vocab_size) { +void build_sampler(Sampler* sampler, int vocab_size, float temperature, float topp) { sampler->vocab_size = vocab_size; - // probindex might not be needed, but it's a ~small buffer so we'll just malloc it - sampler->probindex = malloc(vocab_size * sizeof(ProbIndex)); + sampler->temperature = temperature; + sampler->topp = topp; + // buffer only used with nucleus sampling; may not need but it's ~small + sampler->probindex = malloc(sampler->vocab_size * sizeof(ProbIndex)); } void free_sampler(Sampler* sampler) { free(sampler->probindex); } -int sample(Sampler* sampler, float* logits, float temperature, float topp) { +int sample(Sampler* sampler, float* logits) { // sample the token given the logits and some hyperparameters int next; - if (temperature == 0.0f) { + if (sampler->temperature == 0.0f) { // greedy argmax sampling: take the token with the highest probability next = sample_argmax(logits, sampler->vocab_size); } else { // apply the temperature to the logits - for (int q=0; qvocab_size; q++) { logits[q] /= temperature; } + for (int q=0; qvocab_size; q++) { logits[q] /= sampler->temperature; } // apply softmax to the logits to get the probabilities for next token softmax(logits, sampler->vocab_size); // we sample from this distribution to get the next token - if (topp <= 0 || topp >= 1) { + if (sampler->topp <= 0 || sampler->topp >= 1) { // simply sample from the predicted probability distribution next = sample_mult(logits, sampler->vocab_size); } else { // top-p (nucleus) sampling, clamping the least likely tokens to zero - next = sample_topp(logits, sampler->vocab_size, topp, sampler->probindex); + next = sample_topp(logits, sampler->vocab_size, sampler->topp, sampler->probindex); } } return next; @@ -753,7 +757,7 @@ int main(int argc, char *argv[]) { // build the Sampler Sampler sampler; - build_sampler(&sampler, transformer.config.vocab_size); + build_sampler(&sampler, transformer.config.vocab_size, temperature, topp); // encode the (string) prompt into tokens sequence, if any is given int *prompt_tokens = NULL; // the sequence of prompt tokens @@ -779,7 +783,7 @@ int main(int argc, char *argv[]) { next = prompt_tokens[pos]; } else { // otherwise sample the next token from the logits - next = sample(&sampler, logits, temperature, topp); + next = sample(&sampler, logits); } pos++;