diff --git a/run.c b/run.c index afe695f..14469ad 100644 --- a/run.c +++ b/run.c @@ -508,6 +508,7 @@ void error_usage() { fprintf(stderr, " -s random seed, default time(NULL)\n"); fprintf(stderr, " -n number of steps to run for, default 256. 0 = max_seq_len\n"); fprintf(stderr, " -i input prompt\n"); + fprintf(stderr, " -z optional path to custom tokenizer\n"); exit(EXIT_FAILURE); } @@ -515,6 +516,7 @@ int main(int argc, char *argv[]) { // default inits char *checkpoint = NULL; // e.g. out/model.bin + char *tokenizer = "tokenizer.bin"; float temperature = 1.0f; // 0.0 = greedy deterministic. 1.0 = original. don't set higher float topp = 1.0f; // top-p in nucleus sampling. 1.0 = off. 0.9 works well, but slower rng_seed = 0; // seed rng with time by default @@ -534,6 +536,7 @@ int main(int argc, char *argv[]) { else if (argv[i][1] == 's') { rng_seed = atoi(argv[i + 1]); } else if (argv[i][1] == 'n') { steps = atoi(argv[i + 1]); } else if (argv[i][1] == 'i') { prompt = argv[i + 1]; } + else if (argv[i][1] == 'z') { tokenizer = argv[i + 1]; } else { error_usage(); } } if(rng_seed == 0) { rng_seed = (unsigned int)time(NULL);} @@ -567,13 +570,13 @@ int main(int argc, char *argv[]) { // right now we cannot run for more than config.seq_len steps if (steps <= 0 || steps > config.seq_len) { steps = config.seq_len; } - // read in the tokenizer.bin file + // read in the tokenizer .bin file char** vocab = (char**)malloc(config.vocab_size * sizeof(char*)); float* vocab_scores = (float*)malloc(config.vocab_size * sizeof(float)); unsigned int max_token_length; { - FILE *file = fopen("tokenizer.bin", "rb"); - if (!file) { fprintf(stderr, "couldn't load tokenizer.bin\n"); return 1; } + FILE *file = fopen(tokenizer, "rb"); + if (!file) { fprintf(stderr, "couldn't load %s\n", tokenizer); return 1; } if (fread(&max_token_length, sizeof(int), 1, file) != 1) { fprintf(stderr, "failed read\n"); return 1; } int len; for (int i = 0; i < config.vocab_size; i++) {