Tweak argparse: default steps to 256, even though some models may support a longer maximum seq_len. Remove the seed option for now; use temperature 0.0 for deterministic behavior.
This commit is contained in:
@@ -386,28 +386,28 @@ long time_in_ms() {
|
||||
int main(int argc, char *argv[]) {
|
||||
|
||||
// poor man's C argparse
|
||||
char *checkpoint = NULL;
|
||||
float temperature = 0.9f;
|
||||
char *checkpoint = NULL; // e.g. out/model.bin
|
||||
float temperature = 0.9f; // e.g. 1.0, or 0.0
|
||||
int steps = 256; // max number of steps to run for, 0: use seq_len
|
||||
// 'checkpoint' is necessary arg
|
||||
if (argc < 2) {
|
||||
printf("Usage: %s <checkpoint_file> [temperature] [seed]\n", argv[0]);
|
||||
printf("Usage: %s <checkpoint_file> [temperature] [steps]\n", argv[0]);
|
||||
return 1;
|
||||
}
|
||||
if (argc >= 2) {
|
||||
checkpoint = argv[1];
|
||||
// temperature is optional
|
||||
}
|
||||
if (argc >= 3) {
|
||||
// optional temperature. 0.0 = (deterministic) argmax sampling. 1.0 = baseline
|
||||
temperature = atof(argv[2]);
|
||||
}
|
||||
// seed is optional
|
||||
if (argc >= 4) {
|
||||
unsigned int seed = atoi(argv[3]);
|
||||
srand(seed);
|
||||
} else {
|
||||
time_t current_time;
|
||||
time(&current_time);
|
||||
srand((unsigned int)current_time);
|
||||
steps = atoi(argv[3]);
|
||||
}
|
||||
|
||||
// seed rng with time. if you want deterministic behavior use temperature 0.0
|
||||
srand((unsigned int)time(NULL));
|
||||
|
||||
// read in the model.bin file
|
||||
Config config;
|
||||
TransformerWeights weights;
|
||||
@@ -424,6 +424,8 @@ int main(int argc, char *argv[]) {
|
||||
if(checkpoint_init_weights(&weights, &config, file)) { return 1; }
|
||||
fclose(file);
|
||||
}
|
||||
// right now we cannot run for more than config.seq_len steps
|
||||
if (steps <= 0 || steps > config.seq_len) { steps = config.seq_len; }
|
||||
|
||||
// read in the tokenizer.bin file
|
||||
char** vocab = (char**)malloc(config.vocab_size * sizeof(char*));
|
||||
@@ -450,11 +452,11 @@ int main(int argc, char *argv[]) {
|
||||
|
||||
// the current position we are in
|
||||
long start = time_in_ms();
|
||||
|
||||
int next;
|
||||
int token = 1; // 1 = BOS token in Llama-2 sentencepiece
|
||||
int pos = 0;
|
||||
while (pos < config.seq_len) {
|
||||
printf("<s>\n"); // explicit print the initial BOS token (=1), stylistically symmetric
|
||||
while (pos < steps) {
|
||||
|
||||
// forward the transformer to get logits for the next token
|
||||
transformer(token, pos, &config, &state, &weights);
|
||||
@@ -478,11 +480,10 @@ int main(int argc, char *argv[]) {
|
||||
token = next;
|
||||
pos++;
|
||||
}
|
||||
printf("\n");
|
||||
|
||||
// report our achieved tok/s
|
||||
// report achieved tok/s
|
||||
long end = time_in_ms();
|
||||
printf("achieved tok/s: %f\n", config.seq_len / (double)(end-start)*1000);
|
||||
printf("\nachieved tok/s: %f\n", config.seq_len / (double)(end-start)*1000);
|
||||
|
||||
// memory cleanup
|
||||
free_run_state(&state);
|
||||
|
||||
Reference in New Issue
Block a user