and finally refactor the Sampler. things are starting to look a lot cleaner I think
This commit is contained in:
@@ -164,7 +164,7 @@ void read_checkpoint(char* checkpoint, Config* config, TransformerWeights* weigh
|
|||||||
memory_map_weights(weights, config, weights_ptr, shared_weights);
|
memory_map_weights(weights, config, weights_ptr, shared_weights);
|
||||||
}
|
}
|
||||||
|
|
||||||
void build_transformer(char* checkpoint_path, Transformer *t) {
|
void build_transformer(Transformer *t, char* checkpoint_path) {
|
||||||
// read in the Config and the Weights from the checkpoint
|
// read in the Config and the Weights from the checkpoint
|
||||||
read_checkpoint(checkpoint_path, &t->config, &t->weights, &t->fd, &t->data, &t->file_size);
|
read_checkpoint(checkpoint_path, &t->config, &t->weights, &t->fd, &t->data, &t->file_size);
|
||||||
// allocate the RunState buffers
|
// allocate the RunState buffers
|
||||||
@@ -377,7 +377,7 @@ typedef struct {
|
|||||||
char byte_piece[2];
|
char byte_piece[2];
|
||||||
} Tokenizer;
|
} Tokenizer;
|
||||||
|
|
||||||
void build_tokenizer(char* tokenizer_path, Tokenizer* t, int vocab_size) {
|
void build_tokenizer(Tokenizer* t, char* tokenizer_path, int vocab_size) {
|
||||||
// i should have written the vocab_size into the tokenizer file... sigh
|
// i should have written the vocab_size into the tokenizer file... sigh
|
||||||
t->vocab_size = vocab_size;
|
t->vocab_size = vocab_size;
|
||||||
// malloc space to hold the scores and the strings
|
// malloc space to hold the scores and the strings
|
||||||
@@ -542,15 +542,21 @@ void encode(Tokenizer* t, char *text, int *tokens, int *n_tokens) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ----------------------------------------------------------------------------
|
// ----------------------------------------------------------------------------
|
||||||
// utilities: time / rng
|
// The Sampler, which takes logits and returns a sampled token
|
||||||
|
// sampling can be done in a few ways: greedy argmax, sampling, top-p sampling
|
||||||
|
|
||||||
long time_in_ms() {
|
typedef struct {
|
||||||
// return time in milliseconds, for benchmarking the model speed
|
float prob;
|
||||||
struct timespec time;
|
int index;
|
||||||
clock_gettime(CLOCK_REALTIME, &time);
|
} ProbIndex; // struct used when sorting probabilities during top-p sampling
|
||||||
return time.tv_sec * 1000 + time.tv_nsec / 1000000;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
typedef struct {
|
||||||
|
int vocab_size;
|
||||||
|
ProbIndex* probindex; // buffer used in top-p sampling
|
||||||
|
} Sampler;
|
||||||
|
|
||||||
|
// rng should technically be a state variable of the Sampler
|
||||||
|
// leaving it global here for now for convenience, maybe move later
|
||||||
unsigned long long rng_seed;
|
unsigned long long rng_seed;
|
||||||
unsigned int random_u32() {
|
unsigned int random_u32() {
|
||||||
// xorshift rng: https://en.wikipedia.org/wiki/Xorshift#xorshift.2A
|
// xorshift rng: https://en.wikipedia.org/wiki/Xorshift#xorshift.2A
|
||||||
@@ -563,15 +569,7 @@ float random_f32() { // random float32 in [0,1)
|
|||||||
return (random_u32() >> 8) / 16777216.0f;
|
return (random_u32() >> 8) / 16777216.0f;
|
||||||
}
|
}
|
||||||
|
|
||||||
// ----------------------------------------------------------------------------
|
int sample_argmax(float* probabilities, int n) {
|
||||||
// sampling can be done in a few ways: greedy argmax, sampling, top-p sampling
|
|
||||||
|
|
||||||
typedef struct {
|
|
||||||
float prob;
|
|
||||||
int index;
|
|
||||||
} ProbIndex; // struct used when sorting probabilities during top-p sampling
|
|
||||||
|
|
||||||
int argmax(float* probabilities, int n) {
|
|
||||||
// return the index that has the highest probability
|
// return the index that has the highest probability
|
||||||
int max_i = 0;
|
int max_i = 0;
|
||||||
float max_p = probabilities[0];
|
float max_p = probabilities[0];
|
||||||
@@ -584,7 +582,7 @@ int argmax(float* probabilities, int n) {
|
|||||||
return max_i;
|
return max_i;
|
||||||
}
|
}
|
||||||
|
|
||||||
int sample(float* probabilities, int n) {
|
int sample_mult(float* probabilities, int n) {
|
||||||
// sample index from probabilities (they must sum to 1!)
|
// sample index from probabilities (they must sum to 1!)
|
||||||
float r = random_f32();
|
float r = random_f32();
|
||||||
float cdf = 0.0f;
|
float cdf = 0.0f;
|
||||||
@@ -647,6 +645,48 @@ int sample_topp(float* probabilities, int n, float topp, ProbIndex* probindex) {
|
|||||||
return probindex[last_idx].index; // in case of rounding errors
|
return probindex[last_idx].index; // in case of rounding errors
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void build_sampler(Sampler* sampler, int vocab_size) {
|
||||||
|
sampler->vocab_size = vocab_size;
|
||||||
|
// probindex might not be needed, but it's a ~small buffer so we'll just malloc it
|
||||||
|
sampler->probindex = malloc(vocab_size * sizeof(ProbIndex));
|
||||||
|
}
|
||||||
|
|
||||||
|
void free_sampler(Sampler* sampler) {
|
||||||
|
free(sampler->probindex);
|
||||||
|
}
|
||||||
|
|
||||||
|
int sample(Sampler* sampler, float* logits, float temperature, float topp) {
|
||||||
|
// sample the token given the logits and some hyperparameters
|
||||||
|
int next;
|
||||||
|
if (temperature == 0.0f) {
|
||||||
|
// greedy argmax sampling: take the token with the highest probability
|
||||||
|
next = sample_argmax(logits, sampler->vocab_size);
|
||||||
|
} else {
|
||||||
|
// apply the temperature to the logits
|
||||||
|
for (int q=0; q<sampler->vocab_size; q++) { logits[q] /= temperature; }
|
||||||
|
// apply softmax to the logits to get the probabilities for next token
|
||||||
|
softmax(logits, sampler->vocab_size);
|
||||||
|
// we sample from this distribution to get the next token
|
||||||
|
if (topp <= 0 || topp >= 1) {
|
||||||
|
// simply sample from the predicted probability distribution
|
||||||
|
next = sample_mult(logits, sampler->vocab_size);
|
||||||
|
} else {
|
||||||
|
// top-p (nucleus) sampling, clamping the least likely tokens to zero
|
||||||
|
next = sample_topp(logits, sampler->vocab_size, topp, sampler->probindex);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return next;
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
// utilities: time
|
||||||
|
|
||||||
|
long time_in_ms() {
|
||||||
|
// return time in milliseconds, for benchmarking the model speed
|
||||||
|
struct timespec time;
|
||||||
|
clock_gettime(CLOCK_REALTIME, &time);
|
||||||
|
return time.tv_sec * 1000 + time.tv_nsec / 1000000;
|
||||||
|
}
|
||||||
|
|
||||||
// ----------------------------------------------------------------------------
|
// ----------------------------------------------------------------------------
|
||||||
// int main
|
// int main
|
||||||
@@ -695,16 +735,18 @@ int main(int argc, char *argv[]) {
|
|||||||
|
|
||||||
// build the Transformer via the model .bin file
|
// build the Transformer via the model .bin file
|
||||||
Transformer transformer;
|
Transformer transformer;
|
||||||
build_transformer(checkpoint_path, &transformer);
|
build_transformer(&transformer, checkpoint_path);
|
||||||
int vocab_size = transformer.config.vocab_size; // convenience copy
|
int vocab_size = transformer.config.vocab_size; // convenience copy
|
||||||
|
|
||||||
// build the Tokenizer via the tokenizer .bin file
|
// build the Tokenizer via the tokenizer .bin file
|
||||||
Tokenizer tokenizer;
|
Tokenizer tokenizer;
|
||||||
build_tokenizer(tokenizer_path, &tokenizer, vocab_size);
|
build_tokenizer(&tokenizer, tokenizer_path, vocab_size);
|
||||||
|
|
||||||
// create and init the application RunState
|
// build the Sampler
|
||||||
ProbIndex *probindex = malloc(vocab_size * sizeof(ProbIndex)); // buffer used in top-p sampling
|
Sampler sampler;
|
||||||
// process the prompt, if any
|
build_sampler(&sampler, vocab_size);
|
||||||
|
|
||||||
|
// encode the (string) prompt into tokens sequence, if any is given
|
||||||
int *prompt_tokens = NULL;
|
int *prompt_tokens = NULL;
|
||||||
int num_prompt_tokens = 0;
|
int num_prompt_tokens = 0;
|
||||||
if (prompt != NULL) {
|
if (prompt != NULL) {
|
||||||
@@ -727,24 +769,8 @@ int main(int argc, char *argv[]) {
|
|||||||
// if we are still processing the input prompt, force the next prompt token
|
// if we are still processing the input prompt, force the next prompt token
|
||||||
next = prompt_tokens[pos];
|
next = prompt_tokens[pos];
|
||||||
} else {
|
} else {
|
||||||
// sample the next token
|
// otherwise sample the next token from the logits
|
||||||
if (temperature == 0.0f) {
|
next = sample(&sampler, logits, temperature, topp);
|
||||||
// greedy argmax sampling: take the token with the highest probability
|
|
||||||
next = argmax(logits, vocab_size);
|
|
||||||
} else {
|
|
||||||
// apply the temperature to the logits
|
|
||||||
for (int q=0; q<vocab_size; q++) { logits[q] /= temperature; }
|
|
||||||
// apply softmax to the logits to get the probabilities for next token
|
|
||||||
softmax(logits, vocab_size);
|
|
||||||
// we sample from this distribution to get the next token
|
|
||||||
if (topp <= 0 || topp >= 1) {
|
|
||||||
// simply sample from the predicted probability distribution
|
|
||||||
next = sample(logits, vocab_size);
|
|
||||||
} else {
|
|
||||||
// top-p (nucleus) sampling, clamping the least likely tokens to zero
|
|
||||||
next = sample_topp(logits, vocab_size, topp, probindex);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
pos++;
|
pos++;
|
||||||
|
|
||||||
@@ -769,8 +795,8 @@ int main(int argc, char *argv[]) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// memory and file handles cleanup
|
// memory and file handles cleanup
|
||||||
free(probindex);
|
|
||||||
if (prompt_tokens != NULL) { free(prompt_tokens); }
|
if (prompt_tokens != NULL) { free(prompt_tokens); }
|
||||||
|
free_sampler(&sampler);
|
||||||
free_tokenizer(&tokenizer);
|
free_tokenizer(&tokenizer);
|
||||||
free_transformer(&transformer);
|
free_transformer(&transformer);
|
||||||
return 0;
|
return 0;
|
||||||
|
|||||||
Reference in New Issue
Block a user