parameter validation cleanup

This commit is contained in:
Andrej Karpathy
2023-08-21 15:17:14 +00:00
parent 2d972f1763
commit 33d94f60a5
+9 -13
View File
@@ -695,8 +695,8 @@ void error_usage() {
fprintf(stderr, "Usage: run <checkpoint> [options]\n");
fprintf(stderr, "Example: run model.bin -n 256 -i \"Once upon a time\"\n");
fprintf(stderr, "Options:\n");
fprintf(stderr, " -t <float> temperature, default 1.0\n");
fprintf(stderr, " -p <float> p value in top-p (nucleus) sampling. default 0.9\n");
fprintf(stderr, " -t <float> temperature in [0,inf], default 1.0\n");
fprintf(stderr, " -p <float> p value in top-p (nucleus) sampling in [0,1] default 0.9\n");
fprintf(stderr, " -s <int> random seed, default time(NULL)\n");
fprintf(stderr, " -n <int> number of steps to run for, default 256. 0 = max_seq_len\n");
fprintf(stderr, " -i <string> input prompt\n");
@@ -706,7 +706,7 @@ void error_usage() {
int main(int argc, char *argv[]) {
// default inits
// default parameters
char *checkpoint_path = NULL; // e.g. out/model.bin
char *tokenizer_path = "tokenizer.bin";
float temperature = 1.0f; // 0.0 = greedy deterministic. 1.0 = original. don't set higher
@@ -732,6 +732,12 @@ int main(int argc, char *argv[]) {
else { error_usage(); }
}
// parameter validation/overrides
if (rng_seed <= 0) rng_seed = (unsigned int)time(NULL);
if (temperature < 0.0) temperature = 0.0;
if (topp < 0.0 || 1.0 < topp) topp = 0.9;
if (steps <= 0) steps = 0;
// build the Transformer via the model .bin file
Transformer transformer;
build_transformer(&transformer, checkpoint_path);
@@ -744,16 +750,6 @@ int main(int argc, char *argv[]) {
Sampler sampler;
build_sampler(&sampler, transformer.config.vocab_size);
// Check for sound parameters and fallback to defaults if needed.
// This needs to stay close to the generation code since, in the future,
// it will be in the "chat loop" (one may change params from one prompt to the next).
if (rng_seed <= 0) { rng_seed = (unsigned int)time(NULL);}
if (temperature < 0.0 || 1.0 < temperature) temperature = 1.0;
if (topp < 0.0 || 1.0 < temperature) topp = 0.9;
if (steps <= 0 || transformer.config.seq_len < steps ) { steps = transformer.config.seq_len; }
// encode the (string) prompt into tokens sequence, if any is given
int *prompt_tokens = NULL; // the sequence of prompt tokens
int num_prompt_tokens = 0; // the total number of prompt tokens