add a bit less embarassing argparse that uses keyword arguments instead of positional arguments
This commit is contained in:
@@ -448,35 +448,40 @@ int argmax(float* v, int n) {
|
||||
}
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
void error_usage() {
|
||||
printf("Usage: run <checkpoint> [options]\n");
|
||||
printf("Example: run model.bin -t 0.9 -n 256 -p \"Once upon a time\"\n");
|
||||
printf("Options:\n");
|
||||
printf(" -t <float> temperature, default 0.9\n");
|
||||
printf(" -s <int> random seed, default time(NULL)\n");
|
||||
printf(" -n <int> number of steps to run for, default 256. 0 = max_seq_len\n");
|
||||
printf(" -p <string> prompt string, default none\n");
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
int main(int argc, char *argv[]) {
|
||||
|
||||
// poor man's C argparse
|
||||
// default inits
|
||||
char *checkpoint = NULL; // e.g. out/model.bin
|
||||
float temperature = 0.9f; // e.g. 1.0, or 0.0
|
||||
int steps = 256; // max number of steps to run for, 0: use seq_len
|
||||
float temperature = 0.9f; // 0.0 = greedy & deterministic, 1.0 = max uncertainty
|
||||
rng_seed = (unsigned int)time(NULL); // seed rng with time by default
|
||||
int steps = 256; // number of steps to run for
|
||||
char *prompt = NULL; // prompt string
|
||||
|
||||
// 'checkpoint' is necessary arg
|
||||
if (argc < 2) {
|
||||
printf("Usage: %s <checkpoint_file> [temperature] [steps] [prompt]\n", argv[0]);
|
||||
return 1;
|
||||
// poor man's C argparse so we can override the defaults above from the command line
|
||||
if (argc >= 2) { checkpoint = argv[1]; } else { error_usage(); }
|
||||
for (int i = 2; i < argc; i+=2) {
|
||||
// do some basic validation
|
||||
if (i + 1 >= argc) { error_usage(); } // must have arg after flag
|
||||
if (argv[i][0] != '-') { error_usage(); } // must start with dash
|
||||
if (strlen(argv[i]) != 2) { error_usage(); } // must be -x (one dash, one letter)
|
||||
// read in the args
|
||||
if (argv[i][1] == 't') { temperature = atof(argv[i + 1]); }
|
||||
else if (argv[i][1] == 's') { rng_seed = atoi(argv[i + 1]); }
|
||||
else if (argv[i][1] == 'n') { steps = atoi(argv[i + 1]); }
|
||||
else if (argv[i][1] == 'p') { prompt = argv[i + 1]; }
|
||||
else { error_usage(); }
|
||||
}
|
||||
if (argc >= 2) {
|
||||
checkpoint = argv[1];
|
||||
}
|
||||
if (argc >= 3) {
|
||||
// optional temperature. 0.0 = (deterministic) argmax sampling. 1.0 = baseline
|
||||
temperature = atof(argv[2]);
|
||||
}
|
||||
if (argc >= 4) {
|
||||
steps = atoi(argv[3]);
|
||||
}
|
||||
if (argc >= 5) {
|
||||
prompt = argv[4];
|
||||
}
|
||||
|
||||
// seed rng with time. if you want deterministic behavior use temperature 0.0
|
||||
rng_seed = (unsigned int)time(NULL);
|
||||
|
||||
// read in the model.bin file
|
||||
Config config;
|
||||
|
||||
Reference in New Issue
Block a user