Hide temperature and topp inside the Sampler; this is a little less flexible, but a little cleaner
This commit is contained in:
@@ -557,6 +557,8 @@ typedef struct {
|
|||||||
typedef struct {
|
typedef struct {
|
||||||
int vocab_size;
|
int vocab_size;
|
||||||
ProbIndex* probindex; // buffer used in top-p sampling
|
ProbIndex* probindex; // buffer used in top-p sampling
|
||||||
|
float temperature;
|
||||||
|
float topp;
|
||||||
} Sampler;
|
} Sampler;
|
||||||
|
|
||||||
// rng should technically be a state variable of the Sampler
|
// rng should technically be a state variable of the Sampler
|
||||||
@@ -649,34 +651,36 @@ int sample_topp(float* probabilities, int n, float topp, ProbIndex* probindex) {
|
|||||||
return probindex[last_idx].index; // in case of rounding errors
|
return probindex[last_idx].index; // in case of rounding errors
|
||||||
}
|
}
|
||||||
|
|
||||||
void build_sampler(Sampler* sampler, int vocab_size) {
|
void build_sampler(Sampler* sampler, int vocab_size, float temperature, float topp) {
|
||||||
sampler->vocab_size = vocab_size;
|
sampler->vocab_size = vocab_size;
|
||||||
// probindex might not be needed, but it's a ~small buffer so we'll just malloc it
|
sampler->temperature = temperature;
|
||||||
sampler->probindex = malloc(vocab_size * sizeof(ProbIndex));
|
sampler->topp = topp;
|
||||||
|
// buffer only used with nucleus sampling; may not need but it's ~small
|
||||||
|
sampler->probindex = malloc(sampler->vocab_size * sizeof(ProbIndex));
|
||||||
}
|
}
|
||||||
|
|
||||||
void free_sampler(Sampler* sampler) {
|
void free_sampler(Sampler* sampler) {
|
||||||
free(sampler->probindex);
|
free(sampler->probindex);
|
||||||
}
|
}
|
||||||
|
|
||||||
int sample(Sampler* sampler, float* logits, float temperature, float topp) {
|
int sample(Sampler* sampler, float* logits) {
|
||||||
// sample the token given the logits and some hyperparameters
|
// sample the token given the logits and some hyperparameters
|
||||||
int next;
|
int next;
|
||||||
if (temperature == 0.0f) {
|
if (sampler->temperature == 0.0f) {
|
||||||
// greedy argmax sampling: take the token with the highest probability
|
// greedy argmax sampling: take the token with the highest probability
|
||||||
next = sample_argmax(logits, sampler->vocab_size);
|
next = sample_argmax(logits, sampler->vocab_size);
|
||||||
} else {
|
} else {
|
||||||
// apply the temperature to the logits
|
// apply the temperature to the logits
|
||||||
for (int q=0; q<sampler->vocab_size; q++) { logits[q] /= temperature; }
|
for (int q=0; q<sampler->vocab_size; q++) { logits[q] /= sampler->temperature; }
|
||||||
// apply softmax to the logits to get the probabilities for next token
|
// apply softmax to the logits to get the probabilities for next token
|
||||||
softmax(logits, sampler->vocab_size);
|
softmax(logits, sampler->vocab_size);
|
||||||
// we sample from this distribution to get the next token
|
// we sample from this distribution to get the next token
|
||||||
if (topp <= 0 || topp >= 1) {
|
if (sampler->topp <= 0 || sampler->topp >= 1) {
|
||||||
// simply sample from the predicted probability distribution
|
// simply sample from the predicted probability distribution
|
||||||
next = sample_mult(logits, sampler->vocab_size);
|
next = sample_mult(logits, sampler->vocab_size);
|
||||||
} else {
|
} else {
|
||||||
// top-p (nucleus) sampling, clamping the least likely tokens to zero
|
// top-p (nucleus) sampling, clamping the least likely tokens to zero
|
||||||
next = sample_topp(logits, sampler->vocab_size, topp, sampler->probindex);
|
next = sample_topp(logits, sampler->vocab_size, sampler->topp, sampler->probindex);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return next;
|
return next;
|
||||||
@@ -753,7 +757,7 @@ int main(int argc, char *argv[]) {
|
|||||||
|
|
||||||
// build the Sampler
|
// build the Sampler
|
||||||
Sampler sampler;
|
Sampler sampler;
|
||||||
build_sampler(&sampler, transformer.config.vocab_size);
|
build_sampler(&sampler, transformer.config.vocab_size, temperature, topp);
|
||||||
|
|
||||||
// encode the (string) prompt into tokens sequence, if any is given
|
// encode the (string) prompt into tokens sequence, if any is given
|
||||||
int *prompt_tokens = NULL; // the sequence of prompt tokens
|
int *prompt_tokens = NULL; // the sequence of prompt tokens
|
||||||
@@ -779,7 +783,7 @@ int main(int argc, char *argv[]) {
|
|||||||
next = prompt_tokens[pos];
|
next = prompt_tokens[pos];
|
||||||
} else {
|
} else {
|
||||||
// otherwise sample the next token from the logits
|
// otherwise sample the next token from the logits
|
||||||
next = sample(&sampler, logits, temperature, topp);
|
next = sample(&sampler, logits);
|
||||||
}
|
}
|
||||||
pos++;
|
pos++;
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user