Isolate read_checkpoint, because I'd like to now make it support both version 0 and version 1
This commit is contained in:
@@ -148,6 +148,28 @@ void checkpoint_init_weights(TransformerWeights *w, Config* p, float* ptr, int s
|
|||||||
w->wcls = shared_weights ? w->token_embedding_table : ptr;
|
w->wcls = shared_weights ? w->token_embedding_table : ptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/*
 * Load a model checkpoint from disk.
 *
 * Reads the Config header from the start of `checkpoint`, determines the file
 * size, memory-maps the whole file read-only, and hands the address just past
 * the header to checkpoint_init_weights() so `weights` points into the mapping.
 *
 * checkpoint - path of the checkpoint file
 * config     - out: header read from the file; vocab_size is stored as its
 *              absolute value (the sign is only a flag, see below)
 * weights    - out: filled in by checkpoint_init_weights()
 * fd         - out: file descriptor backing the mapping
 * data       - out: base address of the mapping
 * file_size  - out: size of the file, and of the mapping, in bytes
 *
 * Every failure prints a message to stderr and calls exit(EXIT_FAILURE).
 * This function neither closes *fd nor unmaps *data before returning.
 */
void read_checkpoint(char* checkpoint, Config* config, TransformerWeights* weights,
                     int* fd, float** data, ssize_t* file_size) {
    FILE *file = fopen(checkpoint, "rb");
    if (!file) { fprintf(stderr, "Couldn't open file %s\n", checkpoint); exit(EXIT_FAILURE); }
    // read in the config header
    if (fread(config, sizeof(Config), 1, file) != 1) {
        fprintf(stderr, "Couldn't read config header from %s\n", checkpoint);
        exit(EXIT_FAILURE);
    }
    // negative vocab size is hacky way of signaling unshared weights. bit yikes.
    int shared_weights = config->vocab_size > 0 ? 1 : 0;
    config->vocab_size = abs(config->vocab_size);
    // figure out the file size
    if (fseek(file, 0, SEEK_END) != 0) { // move file pointer to end of file
        fprintf(stderr, "fseek failed!\n");
        exit(EXIT_FAILURE);
    }
    long end_offset = ftell(file); // get the file size, in bytes
    // ftell reports failure as -1; that must never reach mmap as a length
    if (end_offset < 0) { fprintf(stderr, "ftell failed!\n"); exit(EXIT_FAILURE); }
    // a file holding only the header has no weights to point at
    if ((size_t)end_offset <= sizeof(Config)) {
        fprintf(stderr, "Checkpoint %s contains no weights\n", checkpoint);
        exit(EXIT_FAILURE);
    }
    *file_size = end_offset;
    fclose(file);
    // memory map the Transformer weights into the data pointer
    *fd = open(checkpoint, O_RDONLY); // open in read only mode
    if (*fd == -1) { fprintf(stderr, "open failed!\n"); exit(EXIT_FAILURE); }
    *data = mmap(NULL, *file_size, PROT_READ, MAP_PRIVATE, *fd, 0);
    if (*data == MAP_FAILED) { fprintf(stderr, "mmap failed!\n"); exit(EXIT_FAILURE); }
    // the weights start right after the header; step over it in float units
    float* weights_ptr = *data + sizeof(Config)/sizeof(float);
    checkpoint_init_weights(weights, config, weights_ptr, shared_weights);
}
|
||||||
|
|
||||||
// ----------------------------------------------------------------------------
|
// ----------------------------------------------------------------------------
|
||||||
// neural net blocks
|
// neural net blocks
|
||||||
|
|
||||||
@@ -604,27 +626,9 @@ int main(int argc, char *argv[]) {
|
|||||||
TransformerWeights weights;
|
TransformerWeights weights;
|
||||||
int fd = 0; // file descriptor for memory mapping
|
int fd = 0; // file descriptor for memory mapping
|
||||||
float* data = NULL; // memory mapped data pointer
|
float* data = NULL; // memory mapped data pointer
|
||||||
ssize_t file_size; // size of the checkpoint file in bytes
|
ssize_t file_size; // size of the checkpoint file in bytes
|
||||||
{
|
read_checkpoint(checkpoint, &config, &weights, &fd, &data, &file_size);
|
||||||
FILE *file = fopen(checkpoint, "rb");
|
|
||||||
if (!file) { fprintf(stderr, "Couldn't open file %s\n", checkpoint); return 1; }
|
|
||||||
// read in the config header
|
|
||||||
if (fread(&config, sizeof(Config), 1, file) != 1) { return 1; }
|
|
||||||
// negative vocab size is hacky way of signaling unshared weights. bit yikes.
|
|
||||||
int shared_weights = config.vocab_size > 0 ? 1 : 0;
|
|
||||||
config.vocab_size = abs(config.vocab_size);
|
|
||||||
// figure out the file size
|
|
||||||
fseek(file, 0, SEEK_END); // move file pointer to end of file
|
|
||||||
file_size = ftell(file); // get the file size, in bytes
|
|
||||||
fclose(file);
|
|
||||||
// memory map the Transformer weights into the data pointer
|
|
||||||
fd = open(checkpoint, O_RDONLY); // open in read only mode
|
|
||||||
if (fd == -1) { fprintf(stderr, "open failed!\n"); return 1; }
|
|
||||||
data = mmap(NULL, file_size, PROT_READ, MAP_PRIVATE, fd, 0);
|
|
||||||
if (data == MAP_FAILED) { fprintf(stderr, "mmap failed!\n"); return 1; }
|
|
||||||
float* weights_ptr = data + sizeof(Config)/sizeof(float);
|
|
||||||
checkpoint_init_weights(&weights, &config, weights_ptr, shared_weights);
|
|
||||||
}
|
|
||||||
// right now we cannot run for more than config.seq_len steps
|
// right now we cannot run for more than config.seq_len steps
|
||||||
if (steps <= 0 || steps > config.seq_len) { steps = config.seq_len; }
|
if (steps <= 0 || steps > config.seq_len) { steps = config.seq_len; }
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user