hide temperature and topp into the sampler, it's a little bit less flexible but a little bit more cleaner

This commit is contained in:
Andrej Karpathy
2023-08-22 02:17:51 +00:00
parent 379f083b85
commit d73b917d3b
+14 -10
View File
@@ -557,6 +557,8 @@ typedef struct {
typedef struct {
int vocab_size;
ProbIndex* probindex; // buffer used in top-p sampling
float temperature;
float topp;
} Sampler;
// rng should technically be a state variable of the Sampler
@@ -649,34 +651,36 @@ int sample_topp(float* probabilities, int n, float topp, ProbIndex* probindex) {
return probindex[last_idx].index; // in case of rounding errors
}
void build_sampler(Sampler* sampler, int vocab_size) {
void build_sampler(Sampler* sampler, int vocab_size, float temperature, float topp) {
sampler->vocab_size = vocab_size;
// probindex might not be needed, but it's a ~small buffer so we'll just malloc it
sampler->probindex = malloc(vocab_size * sizeof(ProbIndex));
sampler->temperature = temperature;
sampler->topp = topp;
// buffer only used with nucleus sampling; may not need but it's ~small
sampler->probindex = malloc(sampler->vocab_size * sizeof(ProbIndex));
}
void free_sampler(Sampler* sampler) {
free(sampler->probindex);
}
int sample(Sampler* sampler, float* logits, float temperature, float topp) {
int sample(Sampler* sampler, float* logits) {
// sample the token given the logits and some hyperparameters
int next;
if (temperature == 0.0f) {
if (sampler->temperature == 0.0f) {
// greedy argmax sampling: take the token with the highest probability
next = sample_argmax(logits, sampler->vocab_size);
} else {
// apply the temperature to the logits
for (int q=0; q<sampler->vocab_size; q++) { logits[q] /= temperature; }
for (int q=0; q<sampler->vocab_size; q++) { logits[q] /= sampler->temperature; }
// apply softmax to the logits to get the probabilities for next token
softmax(logits, sampler->vocab_size);
// we sample from this distribution to get the next token
if (topp <= 0 || topp >= 1) {
if (sampler->topp <= 0 || sampler->topp >= 1) {
// simply sample from the predicted probability distribution
next = sample_mult(logits, sampler->vocab_size);
} else {
// top-p (nucleus) sampling, clamping the least likely tokens to zero
next = sample_topp(logits, sampler->vocab_size, topp, sampler->probindex);
next = sample_topp(logits, sampler->vocab_size, sampler->topp, sampler->probindex);
}
}
return next;
@@ -753,7 +757,7 @@ int main(int argc, char *argv[]) {
// build the Sampler
Sampler sampler;
build_sampler(&sampler, transformer.config.vocab_size);
build_sampler(&sampler, transformer.config.vocab_size, temperature, topp);
// encode the (string) prompt into tokens sequence, if any is given
int *prompt_tokens = NULL; // the sequence of prompt tokens
@@ -779,7 +783,7 @@ int main(int argc, char *argv[]) {
next = prompt_tokens[pos];
} else {
// otherwise sample the next token from the logits
next = sample(&sampler, logits, temperature, topp);
next = sample(&sampler, logits);
}
pos++;