And finally, refactor the Sampler. Things are starting to look a lot cleaner, I think.

This commit is contained in:
Andrej Karpathy
2023-08-21 04:23:02 +00:00
parent 8a377a1d31
commit 3868f732a4
+70 -44
View File
@@ -164,7 +164,7 @@ void read_checkpoint(char* checkpoint, Config* config, TransformerWeights* weigh
memory_map_weights(weights, config, weights_ptr, shared_weights);
}
void build_transformer(char* checkpoint_path, Transformer *t) {
void build_transformer(Transformer *t, char* checkpoint_path) {
// read in the Config and the Weights from the checkpoint
read_checkpoint(checkpoint_path, &t->config, &t->weights, &t->fd, &t->data, &t->file_size);
// allocate the RunState buffers
@@ -377,7 +377,7 @@ typedef struct {
char byte_piece[2];
} Tokenizer;
void build_tokenizer(char* tokenizer_path, Tokenizer* t, int vocab_size) {
void build_tokenizer(Tokenizer* t, char* tokenizer_path, int vocab_size) {
// i should have written the vocab_size into the tokenizer file... sigh
t->vocab_size = vocab_size;
// malloc space to hold the scores and the strings
@@ -542,15 +542,21 @@ void encode(Tokenizer* t, char *text, int *tokens, int *n_tokens) {
}
// ----------------------------------------------------------------------------
// utilities: time / rng
// The Sampler, which takes logits and returns a sampled token
// sampling can be done in a few ways: greedy argmax, sampling, top-p sampling
long time_in_ms() {
    // wall-clock timestamp expressed in milliseconds; callers subtract two
    // readings to benchmark the model speed
    enum { MS_PER_SEC = 1000, NS_PER_MS = 1000000 };
    struct timespec now;
    clock_gettime(CLOCK_REALTIME, &now);
    // whole seconds scaled up to ms, plus the sub-second part scaled down to ms
    return now.tv_sec * MS_PER_SEC + now.tv_nsec / NS_PER_MS;
}
// Pairs a probability with an index so that entries can be sorted by
// probability while remembering which entry each value came from.
typedef struct {
float prob; // the probability value of this entry
int index; // presumably the position of this probability in the source array; sample_topp returns this field as the chosen token
} ProbIndex; // struct used when sorting probabilities during top-p sampling
// State carried by the sampling code. Set up by build_sampler() and
// released by free_sampler().
typedef struct {
int vocab_size; // number of logits/probabilities that sample() processes per call
ProbIndex* probindex; // buffer used in top-p sampling (build_sampler allocates vocab_size entries)
} Sampler;
// The RNG state should technically be a state variable of the Sampler.
// It is left global here for now for convenience; maybe move it later.
// NOTE(review): presumably this is the state advanced by random_u32() — its
// body is not visible here, confirm before relying on it.
unsigned long long rng_seed;
unsigned int random_u32() {
// xorshift rng: https://en.wikipedia.org/wiki/Xorshift#xorshift.2A
@@ -563,15 +569,7 @@ float random_f32() { // random float32 in [0,1)
return (random_u32() >> 8) / 16777216.0f;
}
// ----------------------------------------------------------------------------
// sampling can be done in a few ways: greedy argmax, sampling, top-p sampling
// Pairs a probability with an index so that entries can be sorted by
// probability while remembering which entry each value came from.
typedef struct {
float prob; // the probability value of this entry
int index; // presumably the position of this probability in the source array; sample_topp returns this field as the chosen token
} ProbIndex; // struct used when sorting probabilities during top-p sampling
int argmax(float* probabilities, int n) {
int sample_argmax(float* probabilities, int n) {
// return the index that has the highest probability
int max_i = 0;
float max_p = probabilities[0];
@@ -584,7 +582,7 @@ int argmax(float* probabilities, int n) {
return max_i;
}
int sample(float* probabilities, int n) {
int sample_mult(float* probabilities, int n) {
// sample index from probabilities (they must sum to 1!)
float r = random_f32();
float cdf = 0.0f;
@@ -647,6 +645,48 @@ int sample_topp(float* probabilities, int n, float topp, ProbIndex* probindex) {
return probindex[last_idx].index; // in case of rounding errors
}
// Initialize a Sampler for a vocabulary of vocab_size tokens.
//   sampler    - caller-owned struct to fill in
//   vocab_size - number of entries in the logits arrays that will be sampled
// Records vocab_size and allocates the probindex scratch buffer, one ProbIndex
// per vocabulary entry. The buffer is owned by the Sampler and is released by
// free_sampler(). If the allocation fails, an error is printed to stderr and
// the process exits with EXIT_FAILURE instead of leaving a NULL buffer to be
// dereferenced later during top-p sampling.
void build_sampler(Sampler* sampler, int vocab_size) {
    sampler->vocab_size = vocab_size;
    sampler->probindex = NULL;
    // a non-positive vocabulary has nothing to sample from; skip the allocation
    // (this also keeps a negative count from wrapping around to a huge size_t)
    if (vocab_size <= 0) { return; }
    // probindex might not be needed, but it's a ~small buffer so we'll just malloc it
    sampler->probindex = malloc((size_t)vocab_size * sizeof(*sampler->probindex));
    if (sampler->probindex == NULL) {
        fprintf(stderr, "malloc failed!\n");
        exit(EXIT_FAILURE);
    }
}
// Release the resources owned by a Sampler.
// Frees the probindex buffer and resets the pointer to NULL, so the struct
// does not keep a dangling pointer and calling free_sampler() a second time
// is harmless (free(NULL) is a no-op). The Sampler struct itself is not freed.
void free_sampler(Sampler* sampler) {
    free(sampler->probindex);
    sampler->probindex = NULL;
}
// Pick the next token from the logits, given the sampling hyperparameters.
//   sampler     - provides vocab_size and the top-p scratch buffer
//   logits      - vocab_size values; modified in place (scaled by temperature
//                 and converted to probabilities) unless temperature is 0
//   temperature - exactly 0.0f selects greedy argmax; otherwise divides the logits
//   topp        - values strictly inside (0, 1) select top-p (nucleus) sampling;
//                 anything <= 0 or >= 1 samples from the full distribution
// Returns the index of the chosen token.
int sample(Sampler* sampler, float* logits, float temperature, float topp) {
    const int n_vocab = sampler->vocab_size;

    // greedy path: no randomness, just take the highest-scoring token
    if (temperature == 0.0f) {
        return sample_argmax(logits, n_vocab);
    }

    // scale the logits by the temperature ...
    for (int i = 0; i < n_vocab; i++) {
        logits[i] /= temperature;
    }
    // ... and turn them into a probability distribution over the next token
    softmax(logits, n_vocab);

    // top-p is only active for a cutoff strictly between 0 and 1
    const int use_topp = !(topp <= 0 || topp >= 1);
    if (use_topp) {
        // nucleus sampling, clamping the least likely tokens to zero
        return sample_topp(logits, n_vocab, topp, sampler->probindex);
    }
    // plain sampling from the predicted probability distribution
    return sample_mult(logits, n_vocab);
}
// ----------------------------------------------------------------------------
// utilities: time
long time_in_ms() {
    // wall-clock timestamp expressed in milliseconds; callers subtract two
    // readings to benchmark the model speed
    enum { MS_PER_SEC = 1000, NS_PER_MS = 1000000 };
    struct timespec now;
    clock_gettime(CLOCK_REALTIME, &now);
    // whole seconds scaled up to ms, plus the sub-second part scaled down to ms
    return now.tv_sec * MS_PER_SEC + now.tv_nsec / NS_PER_MS;
}
// ----------------------------------------------------------------------------
// int main
@@ -695,16 +735,18 @@ int main(int argc, char *argv[]) {
// build the Transformer via the model .bin file
Transformer transformer;
build_transformer(checkpoint_path, &transformer);
build_transformer(&transformer, checkpoint_path);
int vocab_size = transformer.config.vocab_size; // convenience copy
// build the Tokenizer via the tokenizer .bin file
Tokenizer tokenizer;
build_tokenizer(tokenizer_path, &tokenizer, vocab_size);
build_tokenizer(&tokenizer, tokenizer_path, vocab_size);
// create and init the application RunState
ProbIndex *probindex = malloc(vocab_size * sizeof(ProbIndex)); // buffer used in top-p sampling
// process the prompt, if any
// build the Sampler
Sampler sampler;
build_sampler(&sampler, vocab_size);
// encode the (string) prompt into tokens sequence, if any is given
int *prompt_tokens = NULL;
int num_prompt_tokens = 0;
if (prompt != NULL) {
@@ -723,28 +765,12 @@ int main(int argc, char *argv[]) {
float* logits = forward(&transformer, token, pos);
// advance the state machine
if(pos < num_prompt_tokens) {
if (pos < num_prompt_tokens) {
// if we are still processing the input prompt, force the next prompt token
next = prompt_tokens[pos];
} else {
// sample the next token
if (temperature == 0.0f) {
// greedy argmax sampling: take the token with the highest probability
next = argmax(logits, vocab_size);
} else {
// apply the temperature to the logits
for (int q=0; q<vocab_size; q++) { logits[q] /= temperature; }
// apply softmax to the logits to get the probabilities for next token
softmax(logits, vocab_size);
// we sample from this distribution to get the next token
if (topp <= 0 || topp >= 1) {
// simply sample from the predicted probability distribution
next = sample(logits, vocab_size);
} else {
// top-p (nucleus) sampling, clamping the least likely tokens to zero
next = sample_topp(logits, vocab_size, topp, probindex);
}
}
// otherwise sample the next token from the logits
next = sample(&sampler, logits, temperature, topp);
}
pos++;
@@ -769,8 +795,8 @@ int main(int argc, char *argv[]) {
}
// memory and file handles cleanup
free(probindex);
if (prompt_tokens != NULL) { free(prompt_tokens); }
free_sampler(&sampler);
free_tokenizer(&tokenizer);
free_transformer(&transformer);
return 0;