add a bit less embarassing argparse that uses keyword arguments instead of positional arguments

This commit is contained in:
Andrej Karpathy
2023-08-05 17:08:11 +00:00
parent 837796e0b7
commit dcef5ff7c7
2 changed files with 30 additions and 25 deletions
+28 -23
View File
@@ -448,35 +448,40 @@ int argmax(float* v, int n) {
}
// ----------------------------------------------------------------------------
void error_usage() {
printf("Usage: run <checkpoint> [options]\n");
printf("Example: run model.bin -t 0.9 -n 256 -p \"Once upon a time\"\n");
printf("Options:\n");
printf(" -t <float> temperature, default 0.9\n");
printf(" -s <int> random seed, default time(NULL)\n");
printf(" -n <int> number of steps to run for, default 256. 0 = max_seq_len\n");
printf(" -p <string> prompt string, default none\n");
exit(EXIT_FAILURE);
}
int main(int argc, char *argv[]) {
// poor man's C argparse
// default inits
char *checkpoint = NULL; // e.g. out/model.bin
float temperature = 0.9f; // e.g. 1.0, or 0.0
int steps = 256; // max number of steps to run for, 0: use seq_len
float temperature = 0.9f; // 0.0 = greedy & deterministic, 1.0 = max uncertainty
rng_seed = (unsigned int)time(NULL); // seed rng with time by default
int steps = 256; // number of steps to run for
char *prompt = NULL; // prompt string
// 'checkpoint' is necessary arg
if (argc < 2) {
printf("Usage: %s <checkpoint_file> [temperature] [steps] [prompt]\n", argv[0]);
return 1;
// poor man's C argparse so we can override the defaults above from the command line
if (argc >= 2) { checkpoint = argv[1]; } else { error_usage(); }
for (int i = 2; i < argc; i+=2) {
// do some basic validation
if (i + 1 >= argc) { error_usage(); } // must have arg after flag
if (argv[i][0] != '-') { error_usage(); } // must start with dash
if (strlen(argv[i]) != 2) { error_usage(); } // must be -x (one dash, one letter)
// read in the args
if (argv[i][1] == 't') { temperature = atof(argv[i + 1]); }
else if (argv[i][1] == 's') { rng_seed = atoi(argv[i + 1]); }
else if (argv[i][1] == 'n') { steps = atoi(argv[i + 1]); }
else if (argv[i][1] == 'p') { prompt = argv[i + 1]; }
else { error_usage(); }
}
if (argc >= 2) {
checkpoint = argv[1];
}
if (argc >= 3) {
// optional temperature. 0.0 = (deterministic) argmax sampling. 1.0 = baseline
temperature = atof(argv[2]);
}
if (argc >= 4) {
steps = atoi(argv[3]);
}
if (argc >= 5) {
prompt = argv[4];
}
// seed rng with time. if you want deterministic behavior use temperature 0.0
rng_seed = (unsigned int)time(NULL);
// read in the model.bin file
Config config;