diff --git a/run.c b/run.c index 818f6da..e5b1233 100644 --- a/run.c +++ b/run.c @@ -695,8 +695,8 @@ void error_usage() { fprintf(stderr, "Usage: run [options]\n"); fprintf(stderr, "Example: run model.bin -n 256 -i \"Once upon a time\"\n"); fprintf(stderr, "Options:\n"); - fprintf(stderr, " -t temperature, default 1.0\n"); - fprintf(stderr, " -p p value in top-p (nucleus) sampling. default 0.9\n"); + fprintf(stderr, " -t temperature in [0,inf], default 1.0\n"); + fprintf(stderr, " -p p value in top-p (nucleus) sampling in [0,1] default 0.9\n"); fprintf(stderr, " -s random seed, default time(NULL)\n"); fprintf(stderr, " -n number of steps to run for, default 256. 0 = max_seq_len\n"); fprintf(stderr, " -i input prompt\n"); @@ -706,7 +706,7 @@ void error_usage() { int main(int argc, char *argv[]) { - // default inits + // default parameters char *checkpoint_path = NULL; // e.g. out/model.bin char *tokenizer_path = "tokenizer.bin"; float temperature = 1.0f; // 0.0 = greedy deterministic. 1.0 = original. don't set higher @@ -732,6 +732,12 @@ int main(int argc, char *argv[]) { else { error_usage(); } } + // parameter validation/overrides + if (rng_seed <= 0) rng_seed = (unsigned int)time(NULL); + if (temperature < 0.0) temperature = 0.0; + if (topp < 0.0 || 1.0 < topp) topp = 0.9; + if (steps <= 0) steps = 0; + // build the Transformer via the model .bin file Transformer transformer; build_transformer(&transformer, checkpoint_path); @@ -744,16 +750,6 @@ int main(int argc, char *argv[]) { Sampler sampler; build_sampler(&sampler, transformer.config.vocab_size); - - // Check for sound parameters and fallback to defaults if needed. - // This needs to stay close to the generation code since, in the future, - // it will be in the "chat loop" (one may change params from on prompts to the other). - if (rng_seed <= 0) { rng_seed = (unsigned int)time(NULL);} - if (temperature < 0.0 || 1.0 < temperature) temperature = 1.0; - if (topp < 0.0 || 1.0 < temperature) topp = 0.9; - if (steps <= 0 || transformer.config.seq_len < steps ) { steps = transformer.config.seq_len; } - - // encode the (string) prompt into tokens sequence, if any is given int *prompt_tokens = NULL; // the sequence of prompt tokens int num_prompt_tokens = 0; // the total number of prompt tokens