big change: adding prompting. many LOC, but critical. ty @atamurad for the first draft, i ended up tuning it quite a bit.
This commit is contained in:
@@ -337,6 +337,9 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights*
|
|||||||
matmul(s->logits, x, w->wcls, p->dim, p->vocab_size);
|
matmul(s->logits, x, w->wcls, p->dim, p->vocab_size);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
// functions to sample the next token from the transformer's predicted distribution
|
||||||
|
|
||||||
int sample(float* probabilities, int n) {
|
int sample(float* probabilities, int n) {
|
||||||
// sample index from probabilities, they must sum to 1
|
// sample index from probabilities, they must sum to 1
|
||||||
float r = (float)rand() / (float)RAND_MAX;
|
float r = (float)rand() / (float)RAND_MAX;
|
||||||
@@ -362,14 +365,76 @@ int argmax(float* v, int n) {
|
|||||||
}
|
}
|
||||||
return max_i;
|
return max_i;
|
||||||
}
|
}
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
// byte pair encoding (BPE) tokenizer, encodes strings into tokens so we can prompt
|
||||||
|
|
||||||
|
int str_lookup(char *str, char **vocab, int vocab_size) {
|
||||||
|
// find the first perfect match for str in vocab, return its index or -1 if not found
|
||||||
|
for (int i = 0; i < vocab_size; i++) {
|
||||||
|
if (strcmp(str, vocab[i]) == 0) {
|
||||||
|
return i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
void bpe_encode(char *text, char **vocab, float *vocab_scores, int vocab_size, unsigned int max_token_length, int *tokens, int *n_tokens) {
|
||||||
|
|
||||||
|
// a temporary buffer to merge two consecutive tokens
|
||||||
|
char* str_buffer = malloc((max_token_length*2+1) * sizeof(char)); // *2 for concat, +1 for null terminator
|
||||||
|
|
||||||
|
// first encode every individual byte in the input string
|
||||||
|
*n_tokens = 0; // the number of tokens
|
||||||
|
for (char *c = text; *c != '\0'; c++) {
|
||||||
|
sprintf(str_buffer, "%c", *c);
|
||||||
|
int id = str_lookup(str_buffer, vocab, vocab_size);
|
||||||
|
if (id == -1) { printf("not good\n"); exit(1);}
|
||||||
|
tokens[*n_tokens] = id;
|
||||||
|
(*n_tokens)++;
|
||||||
|
}
|
||||||
|
|
||||||
|
// merge the best consecutive pair each iteration, according the scores in vocab_scores
|
||||||
|
while (1) {
|
||||||
|
float best_score = -1e10;
|
||||||
|
int best_id = -1;
|
||||||
|
int best_idx = -1;
|
||||||
|
|
||||||
|
for (int i=0; i < (*n_tokens-1); i++) {
|
||||||
|
// check if we can merge the pair (tokens[i], tokens[i+1])
|
||||||
|
sprintf(str_buffer, "%s%s", vocab[tokens[i]], vocab[tokens[i+1]]);
|
||||||
|
int id = str_lookup(str_buffer, vocab, vocab_size);
|
||||||
|
if (id != -1 && vocab_scores[id] > best_score) {
|
||||||
|
// this merge pair exists in vocab! record its score and position
|
||||||
|
best_score = vocab_scores[id];
|
||||||
|
best_id = id;
|
||||||
|
best_idx = i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (best_idx == -1) {
|
||||||
|
break; // we couldn't find any more pairs to merge, so we're done
|
||||||
|
}
|
||||||
|
|
||||||
|
// merge the consecutive pair (best_idx, best_idx+1) into new token best_id
|
||||||
|
tokens[best_idx] = best_id;
|
||||||
|
// delete token at position best_idx+1, shift the entire sequence back 1
|
||||||
|
for (int i = best_idx+1; i < (*n_tokens-1); i++) {
|
||||||
|
tokens[i] = tokens[i+1];
|
||||||
|
}
|
||||||
|
(*n_tokens)--; // token length decreased
|
||||||
|
}
|
||||||
|
|
||||||
|
free(str_buffer);
|
||||||
|
}
|
||||||
|
|
||||||
// ----------------------------------------------------------------------------
|
// ----------------------------------------------------------------------------
|
||||||
|
// utilities
|
||||||
long time_in_ms() {
|
long time_in_ms() {
|
||||||
struct timespec time;
|
struct timespec time;
|
||||||
clock_gettime(CLOCK_REALTIME, &time);
|
clock_gettime(CLOCK_REALTIME, &time);
|
||||||
return time.tv_sec * 1000 + time.tv_nsec / 1000000;
|
return time.tv_sec * 1000 + time.tv_nsec / 1000000;
|
||||||
}
|
}
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
int main(int argc, char *argv[]) {
|
int main(int argc, char *argv[]) {
|
||||||
|
|
||||||
@@ -377,9 +442,11 @@ int main(int argc, char *argv[]) {
|
|||||||
char *checkpoint = NULL; // e.g. out/model.bin
|
char *checkpoint = NULL; // e.g. out/model.bin
|
||||||
float temperature = 0.9f; // e.g. 1.0, or 0.0
|
float temperature = 0.9f; // e.g. 1.0, or 0.0
|
||||||
int steps = 256; // max number of steps to run for, 0: use seq_len
|
int steps = 256; // max number of steps to run for, 0: use seq_len
|
||||||
|
char *prompt = NULL; // prompt string
|
||||||
|
|
||||||
// 'checkpoint' is necessary arg
|
// 'checkpoint' is necessary arg
|
||||||
if (argc < 2) {
|
if (argc < 2) {
|
||||||
printf("Usage: %s <checkpoint_file> [temperature] [steps]\n", argv[0]);
|
printf("Usage: %s <checkpoint_file> [temperature] [steps] [prompt]\n", argv[0]);
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
if (argc >= 2) {
|
if (argc >= 2) {
|
||||||
@@ -392,6 +459,9 @@ int main(int argc, char *argv[]) {
|
|||||||
if (argc >= 4) {
|
if (argc >= 4) {
|
||||||
steps = atoi(argv[3]);
|
steps = atoi(argv[3]);
|
||||||
}
|
}
|
||||||
|
if (argc >= 5) {
|
||||||
|
prompt = argv[4];
|
||||||
|
}
|
||||||
|
|
||||||
// seed rng with time. if you want deterministic behavior use temperature 0.0
|
// seed rng with time. if you want deterministic behavior use temperature 0.0
|
||||||
srand((unsigned int)time(NULL));
|
srand((unsigned int)time(NULL));
|
||||||
@@ -406,7 +476,7 @@ int main(int argc, char *argv[]) {
|
|||||||
FILE *file = fopen(checkpoint, "rb");
|
FILE *file = fopen(checkpoint, "rb");
|
||||||
if (!file) { printf("Couldn't open file %s\n", checkpoint); return 1; }
|
if (!file) { printf("Couldn't open file %s\n", checkpoint); return 1; }
|
||||||
// read in the config header
|
// read in the config header
|
||||||
if(fread(&config, sizeof(Config), 1, file) != 1) { return 1; }
|
if (fread(&config, sizeof(Config), 1, file) != 1) { return 1; }
|
||||||
// negative vocab size is hacky way of signaling unshared weights. bit yikes.
|
// negative vocab size is hacky way of signaling unshared weights. bit yikes.
|
||||||
int shared_weights = config.vocab_size > 0 ? 1 : 0;
|
int shared_weights = config.vocab_size > 0 ? 1 : 0;
|
||||||
config.vocab_size = abs(config.vocab_size);
|
config.vocab_size = abs(config.vocab_size);
|
||||||
@@ -427,14 +497,18 @@ int main(int argc, char *argv[]) {
|
|||||||
|
|
||||||
// read in the tokenizer.bin file
|
// read in the tokenizer.bin file
|
||||||
char** vocab = (char**)malloc(config.vocab_size * sizeof(char*));
|
char** vocab = (char**)malloc(config.vocab_size * sizeof(char*));
|
||||||
|
float* vocab_scores = (float*)malloc(config.vocab_size * sizeof(float));
|
||||||
|
unsigned int max_token_length;
|
||||||
{
|
{
|
||||||
FILE *file = fopen("tokenizer.bin", "rb");
|
FILE *file = fopen("tokenizer.bin", "rb");
|
||||||
if (!file) { printf("Couldn't load tokenizer.bin\n"); return 1; }
|
if (!file) { printf("couldn't load tokenizer.bin\n"); return 1; }
|
||||||
|
if (fread(&max_token_length, sizeof(int), 1, file) != 1) { printf("failed read\n"); return 1; }
|
||||||
int len;
|
int len;
|
||||||
for (int i = 0; i < config.vocab_size; i++) {
|
for (int i = 0; i < config.vocab_size; i++) {
|
||||||
if(fread(&len, sizeof(int), 1, file) != 1) { return 1; }
|
if (fread(vocab_scores + i, sizeof(float), 1, file) != 1) { printf("failed read\n"); return 1;}
|
||||||
|
if (fread(&len, sizeof(int), 1, file) != 1) { printf("failed read\n"); return 1; }
|
||||||
vocab[i] = (char *)malloc(len + 1);
|
vocab[i] = (char *)malloc(len + 1);
|
||||||
if(fread(vocab[i], len, 1, file) != 1) { return 1; }
|
if (fread(vocab[i], len, 1, file) != 1) { printf("failed read\n"); return 1; }
|
||||||
vocab[i][len] = '\0'; // add the string terminating token
|
vocab[i][len] = '\0'; // add the string terminating token
|
||||||
}
|
}
|
||||||
fclose(file);
|
fclose(file);
|
||||||
@@ -444,29 +518,43 @@ int main(int argc, char *argv[]) {
|
|||||||
RunState state;
|
RunState state;
|
||||||
malloc_run_state(&state, &config);
|
malloc_run_state(&state, &config);
|
||||||
|
|
||||||
// the current position we are in
|
// process the prompt, if any
|
||||||
long start = 0; // used to time our code, only initialized after first iteration
|
int *prompt_tokens = NULL;
|
||||||
int next;
|
int num_prompt_tokens = 0;
|
||||||
int token = 1; // 1 = BOS token in Llama-2 sentencepiece
|
if (prompt != NULL) {
|
||||||
int pos = 0;
|
prompt_tokens = (int*)malloc(config.seq_len * sizeof(int));
|
||||||
printf("<s>\n"); // explicit print the initial BOS token (=1), stylistically symmetric
|
bpe_encode(prompt, vocab, vocab_scores, config.vocab_size, max_token_length, prompt_tokens, &num_prompt_tokens);
|
||||||
|
}
|
||||||
|
|
||||||
|
// start the main loop
|
||||||
|
long start = 0; // used to time our code, only initialized after first iteration
|
||||||
|
int next; // will store the next token in the sequence
|
||||||
|
int token = 1; // init with token 1 (=BOS), as done in Llama-2 sentencepiece tokenizer
|
||||||
|
int pos = 0; // position in the sequence
|
||||||
|
printf("<s>\n"); // explicit print the initial BOS token for stylistic symmetry reasons
|
||||||
while (pos < steps) {
|
while (pos < steps) {
|
||||||
|
|
||||||
// forward the transformer to get logits for the next token
|
// forward the transformer to get logits for the next token
|
||||||
transformer(token, pos, &config, &state, &weights);
|
transformer(token, pos, &config, &state, &weights);
|
||||||
|
|
||||||
// sample the next token
|
if(pos < num_prompt_tokens) {
|
||||||
if(temperature == 0.0f) {
|
// if we are still processing the input prompt, force the next prompt token
|
||||||
// greedy argmax sampling
|
next = prompt_tokens[pos];
|
||||||
next = argmax(state.logits, config.vocab_size);
|
|
||||||
} else {
|
} else {
|
||||||
// apply the temperature to the logits
|
// sample the next token
|
||||||
for (int q=0; q<config.vocab_size; q++) { state.logits[q] /= temperature; }
|
if (temperature == 0.0f) {
|
||||||
// apply softmax to the logits to get the probabilities for next token
|
// greedy argmax sampling: take the token with the highest probability
|
||||||
softmax(state.logits, config.vocab_size);
|
next = argmax(state.logits, config.vocab_size);
|
||||||
// we now want to sample from this distribution to get the next token
|
} else {
|
||||||
next = sample(state.logits, config.vocab_size);
|
// apply the temperature to the logits
|
||||||
|
for (int q=0; q<config.vocab_size; q++) { state.logits[q] /= temperature; }
|
||||||
|
// apply softmax to the logits to get the probabilities for next token
|
||||||
|
softmax(state.logits, config.vocab_size);
|
||||||
|
// we sample from this distribution to get the next token
|
||||||
|
next = sample(state.logits, config.vocab_size);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// following BOS token (1), sentencepiece decoder strips any leading whitespace (see PR #89)
|
// following BOS token (1), sentencepiece decoder strips any leading whitespace (see PR #89)
|
||||||
char *token_str = (token == 1 && vocab[next][0] == ' ') ? vocab[next]+1 : vocab[next];
|
char *token_str = (token == 1 && vocab[next][0] == ' ') ? vocab[next]+1 : vocab[next];
|
||||||
printf("%s", token_str);
|
printf("%s", token_str);
|
||||||
@@ -487,6 +575,8 @@ int main(int argc, char *argv[]) {
|
|||||||
free_run_state(&state);
|
free_run_state(&state);
|
||||||
for (int i = 0; i < config.vocab_size; i++) { free(vocab[i]); }
|
for (int i = 0; i < config.vocab_size; i++) { free(vocab[i]); }
|
||||||
free(vocab);
|
free(vocab);
|
||||||
|
free(vocab_scores);
|
||||||
|
if (prompt_tokens != NULL) free(prompt_tokens);
|
||||||
if (data != MAP_FAILED) munmap(data, file_size);
|
if (data != MAP_FAILED) munmap(data, file_size);
|
||||||
if (fd != -1) close(fd);
|
if (fd != -1) close(fd);
|
||||||
return 0;
|
return 0;
|
||||||
|
|||||||
Binary file not shown.
+17
-7
@@ -3,6 +3,7 @@
|
|||||||
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
|
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import struct
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
@@ -39,26 +40,35 @@ class Tokenizer:
|
|||||||
return self.sp_model.decode(t)
|
return self.sp_model.decode(t)
|
||||||
|
|
||||||
def export(self):
|
def export(self):
|
||||||
tokens = []
|
|
||||||
|
# get all the tokens (postprocessed) and their scores as floats
|
||||||
|
tokens, scores = [], []
|
||||||
for i in range(self.n_words):
|
for i in range(self.n_words):
|
||||||
|
|
||||||
# decode the token and light postprocessing
|
# decode the token and light postprocessing
|
||||||
t = self.sp_model.id_to_piece(i)
|
t = self.sp_model.id_to_piece(i)
|
||||||
|
s = self.sp_model.get_score(i)
|
||||||
if i == self.bos_id:
|
if i == self.bos_id:
|
||||||
t = '\n<s>\n'
|
t = '\n<s>\n'
|
||||||
elif i == self.eos_id:
|
elif i == self.eos_id:
|
||||||
t = '\n</s>\n'
|
t = '\n</s>\n'
|
||||||
elif len(t) == 6 and t.startswith('<0x') and t.endswith('>'):
|
elif len(t) == 6 and t.startswith('<0x') and t.endswith('>'):
|
||||||
t = chr(int(t[3:5], 16)) # e.g. make '<0x01>' into '\x01'
|
t = chr(int(t[3:5], 16)) # e.g. make '<0x01>' into '\x01'
|
||||||
t = t.replace('▁', ' ') # sentencepiece uses this as the whitespace
|
t = t.replace('▁', ' ') # sentencepiece uses this character as whitespace
|
||||||
|
b = t.encode('utf-8') # bytes of this token, utf-8 encoded
|
||||||
|
|
||||||
tokens.append(t)
|
tokens.append(b)
|
||||||
|
scores.append(s)
|
||||||
|
|
||||||
|
# record the max token length
|
||||||
|
max_token_length = max(len(t) for t in tokens)
|
||||||
|
|
||||||
|
# write to a binary file
|
||||||
with open(TOKENIZER_BIN, 'wb') as f:
|
with open(TOKENIZER_BIN, 'wb') as f:
|
||||||
for token in tokens:
|
f.write(struct.pack("I", max_token_length))
|
||||||
bytes = token.encode('utf-8')
|
for bytes, score in zip(tokens, scores):
|
||||||
f.write((len(bytes)).to_bytes(4, 'little')) # write length of bytes
|
f.write(struct.pack("fI", score, len(bytes)))
|
||||||
f.write(bytes) # write token bytes
|
f.write(bytes)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
t = Tokenizer()
|
t = Tokenizer()
|
||||||
|
|||||||
Reference in New Issue
Block a user