Final piece: run.c support for the new tokenizer (adds a -z option to pass a custom tokenizer path).
This commit is contained in:
@@ -508,6 +508,7 @@ void error_usage() {
|
|||||||
fprintf(stderr, " -s <int> random seed, default time(NULL)\n");
|
fprintf(stderr, " -s <int> random seed, default time(NULL)\n");
|
||||||
fprintf(stderr, " -n <int> number of steps to run for, default 256. 0 = max_seq_len\n");
|
fprintf(stderr, " -n <int> number of steps to run for, default 256. 0 = max_seq_len\n");
|
||||||
fprintf(stderr, " -i <string> input prompt\n");
|
fprintf(stderr, " -i <string> input prompt\n");
|
||||||
|
fprintf(stderr, " -z <string> optional path to custom tokenizer\n");
|
||||||
exit(EXIT_FAILURE);
|
exit(EXIT_FAILURE);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -515,6 +516,7 @@ int main(int argc, char *argv[]) {
|
|||||||
|
|
||||||
// default inits
|
// default inits
|
||||||
char *checkpoint = NULL; // e.g. out/model.bin
|
char *checkpoint = NULL; // e.g. out/model.bin
|
||||||
|
char *tokenizer = "tokenizer.bin";
|
||||||
float temperature = 1.0f; // 0.0 = greedy deterministic. 1.0 = original. don't set higher
|
float temperature = 1.0f; // 0.0 = greedy deterministic. 1.0 = original. don't set higher
|
||||||
float topp = 1.0f; // top-p in nucleus sampling. 1.0 = off. 0.9 works well, but slower
|
float topp = 1.0f; // top-p in nucleus sampling. 1.0 = off. 0.9 works well, but slower
|
||||||
rng_seed = 0; // seed rng with time by default
|
rng_seed = 0; // seed rng with time by default
|
||||||
@@ -534,6 +536,7 @@ int main(int argc, char *argv[]) {
|
|||||||
else if (argv[i][1] == 's') { rng_seed = atoi(argv[i + 1]); }
|
else if (argv[i][1] == 's') { rng_seed = atoi(argv[i + 1]); }
|
||||||
else if (argv[i][1] == 'n') { steps = atoi(argv[i + 1]); }
|
else if (argv[i][1] == 'n') { steps = atoi(argv[i + 1]); }
|
||||||
else if (argv[i][1] == 'i') { prompt = argv[i + 1]; }
|
else if (argv[i][1] == 'i') { prompt = argv[i + 1]; }
|
||||||
|
else if (argv[i][1] == 'z') { tokenizer = argv[i + 1]; }
|
||||||
else { error_usage(); }
|
else { error_usage(); }
|
||||||
}
|
}
|
||||||
if(rng_seed == 0) { rng_seed = (unsigned int)time(NULL);}
|
if(rng_seed == 0) { rng_seed = (unsigned int)time(NULL);}
|
||||||
@@ -572,8 +575,8 @@ int main(int argc, char *argv[]) {
|
|||||||
float* vocab_scores = (float*)malloc(config.vocab_size * sizeof(float));
|
float* vocab_scores = (float*)malloc(config.vocab_size * sizeof(float));
|
||||||
unsigned int max_token_length;
|
unsigned int max_token_length;
|
||||||
{
|
{
|
||||||
FILE *file = fopen("tokenizer.bin", "rb");
|
FILE *file = fopen(tokenizer, "rb");
|
||||||
if (!file) { fprintf(stderr, "couldn't load tokenizer.bin\n"); return 1; }
|
if (!file) { fprintf(stderr, "couldn't load %s\n", tokenizer); return 1; }
|
||||||
if (fread(&max_token_length, sizeof(int), 1, file) != 1) { fprintf(stderr, "failed read\n"); return 1; }
|
if (fread(&max_token_length, sizeof(int), 1, file) != 1) { fprintf(stderr, "failed read\n"); return 1; }
|
||||||
int len;
|
int len;
|
||||||
for (int i = 0; i < config.vocab_size; i++) {
|
for (int i = 0; i < config.vocab_size; i++) {
|
||||||
|
|||||||
Reference in New Issue
Block a user