diff --git a/run.c b/run.c index 3193d89..1ee4947 100644 --- a/run.c +++ b/run.c @@ -386,28 +386,28 @@ long time_in_ms() { int main(int argc, char *argv[]) { // poor man's C argparse - char *checkpoint = NULL; - float temperature = 0.9f; + char *checkpoint = NULL; // e.g. out/model.bin + float temperature = 0.9f; // e.g. 1.0, or 0.0 + int steps = 256; // max number of steps to run for, 0: use seq_len // 'checkpoint' is necessary arg if (argc < 2) { - printf("Usage: %s [temperature] [seed]\n", argv[0]); + printf("Usage: %s [temperature] [steps]\n", argv[0]); return 1; } - checkpoint = argv[1]; - // temperature is optional + if (argc >= 2) { + checkpoint = argv[1]; + } if (argc >= 3) { + // optional temperature. 0.0 = (deterministic) argmax sampling. 1.0 = baseline temperature = atof(argv[2]); } - // seed is optional if (argc >= 4) { - unsigned int seed = atoi(argv[3]); - srand(seed); - } else { - time_t current_time; - time(&current_time); - srand((unsigned int)current_time); + steps = atoi(argv[3]); } + // seed rng with time. 
if you want deterministic behavior use temperature 0.0 + srand((unsigned int)time(NULL)); + // read in the model.bin file Config config; TransformerWeights weights; @@ -424,6 +424,8 @@ int main(int argc, char *argv[]) { if(checkpoint_init_weights(&weights, &config, file)) { return 1; } fclose(file); } + // right now we cannot run for more than config.seq_len steps + if (steps <= 0 || steps > config.seq_len) { steps = config.seq_len; } // read in the tokenizer.bin file char** vocab = (char**)malloc(config.vocab_size * sizeof(char*)); @@ -450,11 +452,11 @@ int main(int argc, char *argv[]) { // the current position we are in long start = time_in_ms(); - int next; int token = 1; // 1 = BOS token in Llama-2 sentencepiece int pos = 0; - while (pos < config.seq_len) { + printf("\n"); // explicit print the initial BOS token (=1), stylistically symmetric + while (pos < steps) { // forward the transformer to get logits for the next token transformer(token, pos, &config, &state, &weights); @@ -478,11 +480,10 @@ int main(int argc, char *argv[]) { token = next; pos++; } - printf("\n"); - // report our achieved tok/s + // report achieved tok/s long end = time_in_ms(); - printf("achieved tok/s: %f\n", config.seq_len / (double)(end-start)*1000); + printf("\nachieved tok/s: %f\n", config.seq_len / (double)(end-start)*1000); // memory cleanup free_run_state(&state);