parameter validation cleanup
This commit is contained in:
@@ -695,8 +695,8 @@ void error_usage() {
|
|||||||
fprintf(stderr, "Usage: run <checkpoint> [options]\n");
|
fprintf(stderr, "Usage: run <checkpoint> [options]\n");
|
||||||
fprintf(stderr, "Example: run model.bin -n 256 -i \"Once upon a time\"\n");
|
fprintf(stderr, "Example: run model.bin -n 256 -i \"Once upon a time\"\n");
|
||||||
fprintf(stderr, "Options:\n");
|
fprintf(stderr, "Options:\n");
|
||||||
fprintf(stderr, " -t <float> temperature, default 1.0\n");
|
fprintf(stderr, " -t <float> temperature in [0,inf], default 1.0\n");
|
||||||
fprintf(stderr, " -p <float> p value in top-p (nucleus) sampling. default 0.9\n");
|
fprintf(stderr, " -p <float> p value in top-p (nucleus) sampling in [0,1] default 0.9\n");
|
||||||
fprintf(stderr, " -s <int> random seed, default time(NULL)\n");
|
fprintf(stderr, " -s <int> random seed, default time(NULL)\n");
|
||||||
fprintf(stderr, " -n <int> number of steps to run for, default 256. 0 = max_seq_len\n");
|
fprintf(stderr, " -n <int> number of steps to run for, default 256. 0 = max_seq_len\n");
|
||||||
fprintf(stderr, " -i <string> input prompt\n");
|
fprintf(stderr, " -i <string> input prompt\n");
|
||||||
@@ -706,7 +706,7 @@ void error_usage() {
|
|||||||
|
|
||||||
int main(int argc, char *argv[]) {
|
int main(int argc, char *argv[]) {
|
||||||
|
|
||||||
// default inits
|
// default parameters
|
||||||
char *checkpoint_path = NULL; // e.g. out/model.bin
|
char *checkpoint_path = NULL; // e.g. out/model.bin
|
||||||
char *tokenizer_path = "tokenizer.bin";
|
char *tokenizer_path = "tokenizer.bin";
|
||||||
float temperature = 1.0f; // 0.0 = greedy deterministic. 1.0 = original. don't set higher
|
float temperature = 1.0f; // 0.0 = greedy deterministic. 1.0 = original. don't set higher
|
||||||
@@ -732,6 +732,12 @@ int main(int argc, char *argv[]) {
|
|||||||
else { error_usage(); }
|
else { error_usage(); }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// parameter validation/overrides
|
||||||
|
if (rng_seed <= 0) rng_seed = (unsigned int)time(NULL);
|
||||||
|
if (temperature < 0.0) temperature = 0.0;
|
||||||
|
if (topp < 0.0 || 1.0 < topp) topp = 0.9;
|
||||||
|
if (steps <= 0) steps = 0;
|
||||||
|
|
||||||
// build the Transformer via the model .bin file
|
// build the Transformer via the model .bin file
|
||||||
Transformer transformer;
|
Transformer transformer;
|
||||||
build_transformer(&transformer, checkpoint_path);
|
build_transformer(&transformer, checkpoint_path);
|
||||||
@@ -744,16 +750,6 @@ int main(int argc, char *argv[]) {
|
|||||||
Sampler sampler;
|
Sampler sampler;
|
||||||
build_sampler(&sampler, transformer.config.vocab_size);
|
build_sampler(&sampler, transformer.config.vocab_size);
|
||||||
|
|
||||||
|
|
||||||
// Check for sound parameters and fallback to defaults if needed.
|
|
||||||
// This needs to stay close to the generation code since, in the future,
|
|
||||||
// it will be in the "chat loop" (one may change params from on prompts to the other).
|
|
||||||
if (rng_seed <= 0) { rng_seed = (unsigned int)time(NULL);}
|
|
||||||
if (temperature < 0.0 || 1.0 < temperature) temperature = 1.0;
|
|
||||||
if (topp < 0.0 || 1.0 < temperature) topp = 0.9;
|
|
||||||
if (steps <= 0 || transformer.config.seq_len < steps ) { steps = transformer.config.seq_len; }
|
|
||||||
|
|
||||||
|
|
||||||
// encode the (string) prompt into tokens sequence, if any is given
|
// encode the (string) prompt into tokens sequence, if any is given
|
||||||
int *prompt_tokens = NULL; // the sequence of prompt tokens
|
int *prompt_tokens = NULL; // the sequence of prompt tokens
|
||||||
int num_prompt_tokens = 0; // the total number of prompt tokens
|
int num_prompt_tokens = 0; // the total number of prompt tokens
|
||||||
|
|||||||
Reference in New Issue
Block a user