and finally refactor the Sampler. things are starting to look a lot cleaner I think
This commit is contained in:
@@ -164,7 +164,7 @@ void read_checkpoint(char* checkpoint, Config* config, TransformerWeights* weigh
|
||||
memory_map_weights(weights, config, weights_ptr, shared_weights);
|
||||
}
|
||||
|
||||
void build_transformer(char* checkpoint_path, Transformer *t) {
|
||||
void build_transformer(Transformer *t, char* checkpoint_path) {
|
||||
// read in the Config and the Weights from the checkpoint
|
||||
read_checkpoint(checkpoint_path, &t->config, &t->weights, &t->fd, &t->data, &t->file_size);
|
||||
// allocate the RunState buffers
|
||||
@@ -377,7 +377,7 @@ typedef struct {
|
||||
char byte_piece[2];
|
||||
} Tokenizer;
|
||||
|
||||
void build_tokenizer(char* tokenizer_path, Tokenizer* t, int vocab_size) {
|
||||
void build_tokenizer(Tokenizer* t, char* tokenizer_path, int vocab_size) {
|
||||
// i should have written the vocab_size into the tokenizer file... sigh
|
||||
t->vocab_size = vocab_size;
|
||||
// malloc space to hold the scores and the strings
|
||||
@@ -542,15 +542,21 @@ void encode(Tokenizer* t, char *text, int *tokens, int *n_tokens) {
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// utilities: time / rng
|
||||
// The Sampler, which takes logits and returns a sampled token
|
||||
// sampling can be done in a few ways: greedy argmax, sampling, top-p sampling
|
||||
|
||||
long time_in_ms() {
|
||||
// return time in milliseconds, for benchmarking the model speed
|
||||
struct timespec time;
|
||||
clock_gettime(CLOCK_REALTIME, &time);
|
||||
return time.tv_sec * 1000 + time.tv_nsec / 1000000;
|
||||
}
|
||||
typedef struct {
|
||||
float prob;
|
||||
int index;
|
||||
} ProbIndex; // struct used when sorting probabilities during top-p sampling
|
||||
|
||||
typedef struct {
|
||||
int vocab_size;
|
||||
ProbIndex* probindex; // buffer used in top-p sampling
|
||||
} Sampler;
|
||||
|
||||
// rng should technically be a state variable of the Sampler
|
||||
// leaving it global here for now for convenience, maybe move later
|
||||
unsigned long long rng_seed;
|
||||
unsigned int random_u32() {
|
||||
// xorshift rng: https://en.wikipedia.org/wiki/Xorshift#xorshift.2A
|
||||
@@ -563,15 +569,7 @@ float random_f32() { // random float32 in [0,1)
|
||||
return (random_u32() >> 8) / 16777216.0f;
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// sampling can be done in a few ways: greedy argmax, sampling, top-p sampling
|
||||
|
||||
typedef struct {
|
||||
float prob;
|
||||
int index;
|
||||
} ProbIndex; // struct used when sorting probabilities during top-p sampling
|
||||
|
||||
int argmax(float* probabilities, int n) {
|
||||
int sample_argmax(float* probabilities, int n) {
|
||||
// return the index that has the highest probability
|
||||
int max_i = 0;
|
||||
float max_p = probabilities[0];
|
||||
@@ -584,7 +582,7 @@ int argmax(float* probabilities, int n) {
|
||||
return max_i;
|
||||
}
|
||||
|
||||
int sample(float* probabilities, int n) {
|
||||
int sample_mult(float* probabilities, int n) {
|
||||
// sample index from probabilities (they must sum to 1!)
|
||||
float r = random_f32();
|
||||
float cdf = 0.0f;
|
||||
@@ -647,6 +645,48 @@ int sample_topp(float* probabilities, int n, float topp, ProbIndex* probindex) {
|
||||
return probindex[last_idx].index; // in case of rounding errors
|
||||
}
|
||||
|
||||
void build_sampler(Sampler* sampler, int vocab_size) {
|
||||
sampler->vocab_size = vocab_size;
|
||||
// probindex might not be needed, but it's a ~small buffer so we'll just malloc it
|
||||
sampler->probindex = malloc(vocab_size * sizeof(ProbIndex));
|
||||
}
|
||||
|
||||
void free_sampler(Sampler* sampler) {
|
||||
free(sampler->probindex);
|
||||
}
|
||||
|
||||
int sample(Sampler* sampler, float* logits, float temperature, float topp) {
|
||||
// sample the token given the logits and some hyperparameters
|
||||
int next;
|
||||
if (temperature == 0.0f) {
|
||||
// greedy argmax sampling: take the token with the highest probability
|
||||
next = sample_argmax(logits, sampler->vocab_size);
|
||||
} else {
|
||||
// apply the temperature to the logits
|
||||
for (int q=0; q<sampler->vocab_size; q++) { logits[q] /= temperature; }
|
||||
// apply softmax to the logits to get the probabilities for next token
|
||||
softmax(logits, sampler->vocab_size);
|
||||
// we sample from this distribution to get the next token
|
||||
if (topp <= 0 || topp >= 1) {
|
||||
// simply sample from the predicted probability distribution
|
||||
next = sample_mult(logits, sampler->vocab_size);
|
||||
} else {
|
||||
// top-p (nucleus) sampling, clamping the least likely tokens to zero
|
||||
next = sample_topp(logits, sampler->vocab_size, topp, sampler->probindex);
|
||||
}
|
||||
}
|
||||
return next;
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// utilities: time
|
||||
|
||||
long time_in_ms() {
|
||||
// return time in milliseconds, for benchmarking the model speed
|
||||
struct timespec time;
|
||||
clock_gettime(CLOCK_REALTIME, &time);
|
||||
return time.tv_sec * 1000 + time.tv_nsec / 1000000;
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// int main
|
||||
@@ -695,16 +735,18 @@ int main(int argc, char *argv[]) {
|
||||
|
||||
// build the Transformer via the model .bin file
|
||||
Transformer transformer;
|
||||
build_transformer(checkpoint_path, &transformer);
|
||||
build_transformer(&transformer, checkpoint_path);
|
||||
int vocab_size = transformer.config.vocab_size; // convenience copy
|
||||
|
||||
// build the Tokenizer via the tokenizer .bin file
|
||||
Tokenizer tokenizer;
|
||||
build_tokenizer(tokenizer_path, &tokenizer, vocab_size);
|
||||
build_tokenizer(&tokenizer, tokenizer_path, vocab_size);
|
||||
|
||||
// create and init the application RunState
|
||||
ProbIndex *probindex = malloc(vocab_size * sizeof(ProbIndex)); // buffer used in top-p sampling
|
||||
// process the prompt, if any
|
||||
// build the Sampler
|
||||
Sampler sampler;
|
||||
build_sampler(&sampler, vocab_size);
|
||||
|
||||
// encode the (string) prompt into tokens sequence, if any is given
|
||||
int *prompt_tokens = NULL;
|
||||
int num_prompt_tokens = 0;
|
||||
if (prompt != NULL) {
|
||||
@@ -723,28 +765,12 @@ int main(int argc, char *argv[]) {
|
||||
float* logits = forward(&transformer, token, pos);
|
||||
|
||||
// advance the state state machine
|
||||
if(pos < num_prompt_tokens) {
|
||||
if (pos < num_prompt_tokens) {
|
||||
// if we are still processing the input prompt, force the next prompt token
|
||||
next = prompt_tokens[pos];
|
||||
} else {
|
||||
// sample the next token
|
||||
if (temperature == 0.0f) {
|
||||
// greedy argmax sampling: take the token with the highest probability
|
||||
next = argmax(logits, vocab_size);
|
||||
} else {
|
||||
// apply the temperature to the logits
|
||||
for (int q=0; q<vocab_size; q++) { logits[q] /= temperature; }
|
||||
// apply softmax to the logits to get the probabilities for next token
|
||||
softmax(logits, vocab_size);
|
||||
// we sample from this distribution to get the next token
|
||||
if (topp <= 0 || topp >= 1) {
|
||||
// simply sample from the predicted probability distribution
|
||||
next = sample(logits, vocab_size);
|
||||
} else {
|
||||
// top-p (nucleus) sampling, clamping the least likely tokens to zero
|
||||
next = sample_topp(logits, vocab_size, topp, probindex);
|
||||
}
|
||||
}
|
||||
// otherwise sample the next token from the logits
|
||||
next = sample(&sampler, logits, temperature, topp);
|
||||
}
|
||||
pos++;
|
||||
|
||||
@@ -769,8 +795,8 @@ int main(int argc, char *argv[]) {
|
||||
}
|
||||
|
||||
// memory and file handles cleanup
|
||||
free(probindex);
|
||||
if (prompt_tokens != NULL) { free(prompt_tokens); }
|
||||
free_sampler(&sampler);
|
||||
free_tokenizer(&tokenizer);
|
||||
free_transformer(&transformer);
|
||||
return 0;
|
||||
|
||||
Reference in New Issue
Block a user