Tweak argparse: default steps to 256, even though some models may support a longer maximum seq_len. Remove the seed option for now; use temperature 0.0 for deterministic behavior.
This commit is contained in:
@@ -386,28 +386,28 @@ long time_in_ms() {
|
|||||||
int main(int argc, char *argv[]) {
|
int main(int argc, char *argv[]) {
|
||||||
|
|
||||||
// poor man's C argparse
|
// poor man's C argparse
|
||||||
char *checkpoint = NULL;
|
char *checkpoint = NULL; // e.g. out/model.bin
|
||||||
float temperature = 0.9f;
|
float temperature = 0.9f; // e.g. 1.0, or 0.0
|
||||||
|
int steps = 256; // max number of steps to run for, 0: use seq_len
|
||||||
// 'checkpoint' is necessary arg
|
// 'checkpoint' is necessary arg
|
||||||
if (argc < 2) {
|
if (argc < 2) {
|
||||||
printf("Usage: %s <checkpoint_file> [temperature] [seed]\n", argv[0]);
|
printf("Usage: %s <checkpoint_file> [temperature] [steps]\n", argv[0]);
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
checkpoint = argv[1];
|
if (argc >= 2) {
|
||||||
// temperature is optional
|
checkpoint = argv[1];
|
||||||
|
}
|
||||||
if (argc >= 3) {
|
if (argc >= 3) {
|
||||||
|
// optional temperature. 0.0 = (deterministic) argmax sampling. 1.0 = baseline
|
||||||
temperature = atof(argv[2]);
|
temperature = atof(argv[2]);
|
||||||
}
|
}
|
||||||
// seed is optional
|
|
||||||
if (argc >= 4) {
|
if (argc >= 4) {
|
||||||
unsigned int seed = atoi(argv[3]);
|
steps = atoi(argv[3]);
|
||||||
srand(seed);
|
|
||||||
} else {
|
|
||||||
time_t current_time;
|
|
||||||
time(&current_time);
|
|
||||||
srand((unsigned int)current_time);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// seed rng with time. if you want deterministic behavior use temperature 0.0
|
||||||
|
srand((unsigned int)time(NULL));
|
||||||
|
|
||||||
// read in the model.bin file
|
// read in the model.bin file
|
||||||
Config config;
|
Config config;
|
||||||
TransformerWeights weights;
|
TransformerWeights weights;
|
||||||
@@ -424,6 +424,8 @@ int main(int argc, char *argv[]) {
|
|||||||
if(checkpoint_init_weights(&weights, &config, file)) { return 1; }
|
if(checkpoint_init_weights(&weights, &config, file)) { return 1; }
|
||||||
fclose(file);
|
fclose(file);
|
||||||
}
|
}
|
||||||
|
// right now we cannot run for more than config.seq_len steps
|
||||||
|
if (steps <= 0 || steps > config.seq_len) { steps = config.seq_len; }
|
||||||
|
|
||||||
// read in the tokenizer.bin file
|
// read in the tokenizer.bin file
|
||||||
char** vocab = (char**)malloc(config.vocab_size * sizeof(char*));
|
char** vocab = (char**)malloc(config.vocab_size * sizeof(char*));
|
||||||
@@ -450,11 +452,11 @@ int main(int argc, char *argv[]) {
|
|||||||
|
|
||||||
// the current position we are in
|
// the current position we are in
|
||||||
long start = time_in_ms();
|
long start = time_in_ms();
|
||||||
|
|
||||||
int next;
|
int next;
|
||||||
int token = 1; // 1 = BOS token in Llama-2 sentencepiece
|
int token = 1; // 1 = BOS token in Llama-2 sentencepiece
|
||||||
int pos = 0;
|
int pos = 0;
|
||||||
while (pos < config.seq_len) {
|
printf("<s>\n"); // explicit print the initial BOS token (=1), stylistically symmetric
|
||||||
|
while (pos < steps) {
|
||||||
|
|
||||||
// forward the transformer to get logits for the next token
|
// forward the transformer to get logits for the next token
|
||||||
transformer(token, pos, &config, &state, &weights);
|
transformer(token, pos, &config, &state, &weights);
|
||||||
@@ -478,11 +480,10 @@ int main(int argc, char *argv[]) {
|
|||||||
token = next;
|
token = next;
|
||||||
pos++;
|
pos++;
|
||||||
}
|
}
|
||||||
printf("\n");
|
|
||||||
|
|
||||||
// report our achieved tok/s
|
// report achieved tok/s
|
||||||
long end = time_in_ms();
|
long end = time_in_ms();
|
||||||
printf("achieved tok/s: %f\n", config.seq_len / (double)(end-start)*1000);
|
printf("\nachieved tok/s: %f\n", config.seq_len / (double)(end-start)*1000);
|
||||||
|
|
||||||
// memory cleanup
|
// memory cleanup
|
||||||
free_run_state(&state);
|
free_run_state(&state);
|
||||||
|
|||||||
Reference in New Issue
Block a user