Hide temperature and topp inside the Sampler; it's a little less flexible but a little cleaner.
This commit is contained in:
@@ -557,6 +557,8 @@ typedef struct {
|
||||
typedef struct {
|
||||
int vocab_size;
|
||||
ProbIndex* probindex; // buffer used in top-p sampling
|
||||
float temperature;
|
||||
float topp;
|
||||
} Sampler;
|
||||
|
||||
// rng should technically be a state variable of the Sampler
|
||||
@@ -649,34 +651,36 @@ int sample_topp(float* probabilities, int n, float topp, ProbIndex* probindex) {
|
||||
return probindex[last_idx].index; // in case of rounding errors
|
||||
}
|
||||
|
||||
/*
 * Initialise a Sampler in place.
 *
 * sampler     - caller-owned struct to fill in
 * vocab_size  - number of logits that will be passed to sample()
 * temperature - stored on the sampler; 0.0f makes sample() take the argmax
 * topp        - stored on the sampler; used for nucleus sampling when in (0, 1)
 *
 * Allocates sampler->probindex (vocab_size entries); release it with
 * free_sampler(). Aborts the process if the allocation fails, so callers
 * never see a half-built sampler.
 */
void build_sampler(Sampler* sampler, int vocab_size, float temperature, float topp) {
    sampler->vocab_size = vocab_size;
    sampler->temperature = temperature;
    sampler->topp = topp;
    // buffer only used with nucleus sampling; may not need but it's ~small
    sampler->probindex = malloc(sampler->vocab_size * sizeof *sampler->probindex);
    // malloc(0) is allowed to return NULL, so only treat NULL as a failure
    // when a non-empty buffer was actually requested
    if (sampler->probindex == NULL && vocab_size > 0) {
        abort();
    }
}
|
||||
|
||||
/*
 * Release the buffer owned by a Sampler.
 *
 * The Sampler struct itself is not freed. Passing NULL is a no-op, and the
 * function may be called more than once on the same sampler: the pointer is
 * cleared after the first call so later calls reduce to free(NULL).
 */
void free_sampler(Sampler* sampler) {
    if (sampler == NULL) {
        return;
    }
    free(sampler->probindex);
    // clear the pointer so a stale sampler cannot double-free or reuse the buffer
    sampler->probindex = NULL;
}
|
||||
|
||||
int sample(Sampler* sampler, float* logits, float temperature, float topp) {
|
||||
int sample(Sampler* sampler, float* logits) {
|
||||
// sample the token given the logits and some hyperparameters
|
||||
int next;
|
||||
if (temperature == 0.0f) {
|
||||
if (sampler->temperature == 0.0f) {
|
||||
// greedy argmax sampling: take the token with the highest probability
|
||||
next = sample_argmax(logits, sampler->vocab_size);
|
||||
} else {
|
||||
// apply the temperature to the logits
|
||||
for (int q=0; q<sampler->vocab_size; q++) { logits[q] /= temperature; }
|
||||
for (int q=0; q<sampler->vocab_size; q++) { logits[q] /= sampler->temperature; }
|
||||
// apply softmax to the logits to get the probabilities for next token
|
||||
softmax(logits, sampler->vocab_size);
|
||||
// we sample from this distribution to get the next token
|
||||
if (topp <= 0 || topp >= 1) {
|
||||
if (sampler->topp <= 0 || sampler->topp >= 1) {
|
||||
// simply sample from the predicted probability distribution
|
||||
next = sample_mult(logits, sampler->vocab_size);
|
||||
} else {
|
||||
// top-p (nucleus) sampling, clamping the least likely tokens to zero
|
||||
next = sample_topp(logits, sampler->vocab_size, topp, sampler->probindex);
|
||||
next = sample_topp(logits, sampler->vocab_size, sampler->topp, sampler->probindex);
|
||||
}
|
||||
}
|
||||
return next;
|
||||
@@ -753,7 +757,7 @@ int main(int argc, char *argv[]) {
|
||||
|
||||
// build the Sampler
|
||||
Sampler sampler;
|
||||
build_sampler(&sampler, transformer.config.vocab_size);
|
||||
build_sampler(&sampler, transformer.config.vocab_size, temperature, topp);
|
||||
|
||||
// encode the (string) prompt into tokens sequence, if any is given
|
||||
int *prompt_tokens = NULL; // the sequence of prompt tokens
|
||||
@@ -779,7 +783,7 @@ int main(int argc, char *argv[]) {
|
||||
next = prompt_tokens[pos];
|
||||
} else {
|
||||
// otherwise sample the next token from the logits
|
||||
next = sample(&sampler, logits, temperature, topp);
|
||||
next = sample(&sampler, logits);
|
||||
}
|
||||
pos++;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user