And finally, refactor the Sampler. Things are starting to look a lot cleaner, I think.

This commit is contained in:
Andrej Karpathy
2023-08-21 04:23:02 +00:00
parent 8a377a1d31
commit 3868f732a4
+69 -43
View File
@@ -164,7 +164,7 @@ void read_checkpoint(char* checkpoint, Config* config, TransformerWeights* weigh
memory_map_weights(weights, config, weights_ptr, shared_weights); memory_map_weights(weights, config, weights_ptr, shared_weights);
} }
void build_transformer(char* checkpoint_path, Transformer *t) { void build_transformer(Transformer *t, char* checkpoint_path) {
// read in the Config and the Weights from the checkpoint // read in the Config and the Weights from the checkpoint
read_checkpoint(checkpoint_path, &t->config, &t->weights, &t->fd, &t->data, &t->file_size); read_checkpoint(checkpoint_path, &t->config, &t->weights, &t->fd, &t->data, &t->file_size);
// allocate the RunState buffers // allocate the RunState buffers
@@ -377,7 +377,7 @@ typedef struct {
char byte_piece[2]; char byte_piece[2];
} Tokenizer; } Tokenizer;
void build_tokenizer(char* tokenizer_path, Tokenizer* t, int vocab_size) { void build_tokenizer(Tokenizer* t, char* tokenizer_path, int vocab_size) {
// i should have written the vocab_size into the tokenizer file... sigh // i should have written the vocab_size into the tokenizer file... sigh
t->vocab_size = vocab_size; t->vocab_size = vocab_size;
// malloc space to hold the scores and the strings // malloc space to hold the scores and the strings
@@ -542,15 +542,21 @@ void encode(Tokenizer* t, char *text, int *tokens, int *n_tokens) {
} }
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
// utilities: time / rng // The Sampler, which takes logits and returns a sampled token
// sampling can be done in a few ways: greedy argmax, sampling, top-p sampling
long time_in_ms() { typedef struct {
// return time in milliseconds, for benchmarking the model speed float prob;
struct timespec time; int index;
clock_gettime(CLOCK_REALTIME, &time); } ProbIndex; // struct used when sorting probabilities during top-p sampling
return time.tv_sec * 1000 + time.tv_nsec / 1000000;
}
// State for the sampling stage, which turns logits into a sampled token.
// Initialized by build_sampler() and released by free_sampler().
typedef struct {
int vocab_size; // number of entries in the logits array passed to sample()
ProbIndex* probindex; // buffer used in top-p sampling (build_sampler mallocs vocab_size entries)
} Sampler;
// rng should technically be a state variable of the Sampler
// leaving it global here for now for convenience, maybe move later
unsigned long long rng_seed; unsigned long long rng_seed;
unsigned int random_u32() { unsigned int random_u32() {
// xorshift rng: https://en.wikipedia.org/wiki/Xorshift#xorshift.2A // xorshift rng: https://en.wikipedia.org/wiki/Xorshift#xorshift.2A
@@ -563,15 +569,7 @@ float random_f32() { // random float32 in [0,1)
return (random_u32() >> 8) / 16777216.0f; return (random_u32() >> 8) / 16777216.0f;
} }
// ---------------------------------------------------------------------------- int sample_argmax(float* probabilities, int n) {
// sampling can be done in a few ways: greedy argmax, sampling, top-p sampling
// Pairs a probability with the index it came from, so the pairing
// is kept when entries are reordered.
typedef struct {
float prob; // probability value of this entry
int index; // index associated with prob; sample_topp returns this field as its result
} ProbIndex; // struct used when sorting probabilities during top-p sampling
int argmax(float* probabilities, int n) {
// return the index that has the highest probability // return the index that has the highest probability
int max_i = 0; int max_i = 0;
float max_p = probabilities[0]; float max_p = probabilities[0];
@@ -584,7 +582,7 @@ int argmax(float* probabilities, int n) {
return max_i; return max_i;
} }
int sample(float* probabilities, int n) { int sample_mult(float* probabilities, int n) {
// sample index from probabilities (they must sum to 1!) // sample index from probabilities (they must sum to 1!)
float r = random_f32(); float r = random_f32();
float cdf = 0.0f; float cdf = 0.0f;
@@ -647,6 +645,48 @@ int sample_topp(float* probabilities, int n, float topp, ProbIndex* probindex) {
return probindex[last_idx].index; // in case of rounding errors return probindex[last_idx].index; // in case of rounding errors
} }
// Initialize `sampler` for a vocabulary of `vocab_size` entries.
// Records vocab_size and allocates the probindex scratch buffer (one
// ProbIndex per entry) that sample() passes on to sample_topp().
// The buffer is owned by the Sampler; release it with free_sampler().
// If the allocation fails for a non-empty vocabulary the process is aborted,
// so the sampling code never runs with a NULL probindex buffer.
void build_sampler(Sampler* sampler, int vocab_size) {
    sampler->vocab_size = vocab_size;
    // probindex might not be needed, but it's a ~small buffer so we'll just malloc it
    sampler->probindex = malloc(vocab_size * sizeof *sampler->probindex);
    if (sampler->probindex == NULL && vocab_size > 0) {
        // out of memory: fail fast here instead of dereferencing NULL later
        abort();
    }
}
// Release the probindex buffer owned by `sampler`.
// The pointer is reset to NULL afterwards so that calling free_sampler()
// again on the same Sampler is harmless (free(NULL) is a no-op) and a stale
// pointer is never left behind. The Sampler struct itself is not freed.
void free_sampler(Sampler* sampler) {
    free(sampler->probindex);
    sampler->probindex = NULL;
}
// Pick the next token from `logits` (sampler->vocab_size entries).
//   temperature == 0 : greedy, returns the argmax of the raw logits
//   otherwise        : logits are divided by temperature and passed through
//                      softmax, then either sampled directly (topp <= 0 or
//                      topp >= 1) or with top-p (nucleus) sampling.
// Note: on the non-greedy path `logits` is overwritten in place.
int sample(Sampler* sampler, float* logits, float temperature, float topp) {
    const int n = sampler->vocab_size;

    // greedy argmax sampling: take the token with the highest score
    if (temperature == 0.0f) {
        return sample_argmax(logits, n);
    }

    // scale the logits by the temperature
    for (int i = 0; i < n; i++) {
        logits[i] = logits[i] / temperature;
    }
    // turn the scaled logits into a probability distribution
    softmax(logits, n);

    // top-p (nucleus) sampling, clamping the least likely tokens to zero
    if (!(topp <= 0 || topp >= 1)) {
        return sample_topp(logits, n, topp, sampler->probindex);
    }
    // otherwise simply sample from the predicted probability distribution
    return sample_mult(logits, n);
}
// ----------------------------------------------------------------------------
// utilities: time
// Current CLOCK_REALTIME reading converted to milliseconds.
// Intended for benchmarking the model speed: take two readings and subtract.
long time_in_ms() {
    enum { MS_PER_SEC = 1000, NS_PER_MS = 1000000 };
    struct timespec now;
    clock_gettime(CLOCK_REALTIME, &now);
    // whole seconds scaled up to ms, plus the sub-second part scaled down to ms
    return now.tv_sec * MS_PER_SEC + now.tv_nsec / NS_PER_MS;
}
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
// int main // int main
@@ -695,16 +735,18 @@ int main(int argc, char *argv[]) {
// build the Transformer via the model .bin file // build the Transformer via the model .bin file
Transformer transformer; Transformer transformer;
build_transformer(checkpoint_path, &transformer); build_transformer(&transformer, checkpoint_path);
int vocab_size = transformer.config.vocab_size; // convenience copy int vocab_size = transformer.config.vocab_size; // convenience copy
// build the Tokenizer via the tokenizer .bin file // build the Tokenizer via the tokenizer .bin file
Tokenizer tokenizer; Tokenizer tokenizer;
build_tokenizer(tokenizer_path, &tokenizer, vocab_size); build_tokenizer(&tokenizer, tokenizer_path, vocab_size);
// create and init the application RunState // build the Sampler
ProbIndex *probindex = malloc(vocab_size * sizeof(ProbIndex)); // buffer used in top-p sampling Sampler sampler;
// process the prompt, if any build_sampler(&sampler, vocab_size);
// encode the (string) prompt into tokens sequence, if any is given
int *prompt_tokens = NULL; int *prompt_tokens = NULL;
int num_prompt_tokens = 0; int num_prompt_tokens = 0;
if (prompt != NULL) { if (prompt != NULL) {
@@ -727,24 +769,8 @@ int main(int argc, char *argv[]) {
// if we are still processing the input prompt, force the next prompt token // if we are still processing the input prompt, force the next prompt token
next = prompt_tokens[pos]; next = prompt_tokens[pos];
} else { } else {
// sample the next token // otherwise sample the next token from the logits
if (temperature == 0.0f) { next = sample(&sampler, logits, temperature, topp);
// greedy argmax sampling: take the token with the highest probability
next = argmax(logits, vocab_size);
} else {
// apply the temperature to the logits
for (int q=0; q<vocab_size; q++) { logits[q] /= temperature; }
// apply softmax to the logits to get the probabilities for next token
softmax(logits, vocab_size);
// we sample from this distribution to get the next token
if (topp <= 0 || topp >= 1) {
// simply sample from the predicted probability distribution
next = sample(logits, vocab_size);
} else {
// top-p (nucleus) sampling, clamping the least likely tokens to zero
next = sample_topp(logits, vocab_size, topp, probindex);
}
}
} }
pos++; pos++;
@@ -769,8 +795,8 @@ int main(int argc, char *argv[]) {
} }
// memory and file handles cleanup // memory and file handles cleanup
free(probindex);
if (prompt_tokens != NULL) { free(prompt_tokens); } if (prompt_tokens != NULL) { free(prompt_tokens); }
free_sampler(&sampler);
free_tokenizer(&tokenizer); free_tokenizer(&tokenizer);
free_transformer(&transformer); free_transformer(&transformer);
return 0; return 0;