Tweak argparse: fix the default at steps=256, even though some models may support a longer maximum seq_len. Get rid of the seed option for now; use temp=0.0 for deterministic behavior.

This commit is contained in:
Andrej Karpathy
2023-07-24 20:59:32 +00:00
parent 90ae37c3e6
commit 791be9d991
+18 -17
View File
@@ -386,28 +386,28 @@ long time_in_ms() {
int main(int argc, char *argv[]) {
// poor man's C argparse
char *checkpoint = NULL;
float temperature = 0.9f;
char *checkpoint = NULL; // e.g. out/model.bin
float temperature = 0.9f; // e.g. 1.0, or 0.0
int steps = 256; // max number of steps to run for, 0: use seq_len
// 'checkpoint' is necessary arg
if (argc < 2) {
printf("Usage: %s <checkpoint_file> [temperature] [seed]\n", argv[0]);
printf("Usage: %s <checkpoint_file> [temperature] [steps]\n", argv[0]);
return 1;
}
checkpoint = argv[1];
// temperature is optional
if (argc >= 2) {
checkpoint = argv[1];
}
if (argc >= 3) {
// optional temperature. 0.0 = (deterministic) argmax sampling. 1.0 = baseline
temperature = atof(argv[2]);
}
// seed is optional
if (argc >= 4) {
unsigned int seed = atoi(argv[3]);
srand(seed);
} else {
time_t current_time;
time(&current_time);
srand((unsigned int)current_time);
steps = atoi(argv[3]);
}
// seed rng with time. if you want deterministic behavior use temperature 0.0
srand((unsigned int)time(NULL));
// read in the model.bin file
Config config;
TransformerWeights weights;
@@ -424,6 +424,8 @@ int main(int argc, char *argv[]) {
if(checkpoint_init_weights(&weights, &config, file)) { return 1; }
fclose(file);
}
// right now we cannot run for more than config.seq_len steps
if (steps <= 0 || steps > config.seq_len) { steps = config.seq_len; }
// read in the tokenizer.bin file
char** vocab = (char**)malloc(config.vocab_size * sizeof(char*));
@@ -450,11 +452,11 @@ int main(int argc, char *argv[]) {
// the current position we are in
long start = time_in_ms();
int next;
int token = 1; // 1 = BOS token in Llama-2 sentencepiece
int pos = 0;
while (pos < config.seq_len) {
printf("<s>\n"); // explicit print the initial BOS token (=1), stylistically symmetric
while (pos < steps) {
// forward the transformer to get logits for the next token
transformer(token, pos, &config, &state, &weights);
@@ -478,11 +480,10 @@ int main(int argc, char *argv[]) {
token = next;
pos++;
}
printf("\n");
// report our achieved tok/s
// report achieved tok/s
long end = time_in_ms();
printf("achieved tok/s: %f\n", config.seq_len / (double)(end-start)*1000);
printf("\nachieved tok/s: %f\n", config.seq_len / (double)(end-start)*1000);
// memory cleanup
free_run_state(&state);