we can inference Meta's Llama 2 7B, yay
This commit is contained in:
@@ -50,6 +50,8 @@ typedef struct {
|
||||
// freq_cis for RoPE relatively positional embeddings
|
||||
float* freq_cis_real; // (seq_len, dim/2)
|
||||
float* freq_cis_imag; // (seq_len, dim/2)
|
||||
// (optional) classifier weights for the logits, on the last layer
|
||||
float* wcls;
|
||||
} TransformerWeights;
|
||||
|
||||
typedef struct {
|
||||
@@ -110,7 +112,7 @@ void free_run_state(RunState* s) {
|
||||
// ----------------------------------------------------------------------------
|
||||
// initialization: read from checkpoint
|
||||
|
||||
void checkpoint_init_weights(TransformerWeights *w, Config* p, float* f) {
|
||||
void checkpoint_init_weights(TransformerWeights *w, Config* p, float* f, int shared_weights) {
|
||||
float* ptr = f;
|
||||
w->token_embedding_table = ptr;
|
||||
ptr += p->vocab_size * p->dim;
|
||||
@@ -138,6 +140,8 @@ void checkpoint_init_weights(TransformerWeights *w, Config* p, float* f) {
|
||||
int head_size = p->dim / p->n_heads;
|
||||
ptr += p->seq_len * head_size / 2;
|
||||
w->freq_cis_imag = ptr;
|
||||
ptr += p->seq_len * head_size / 2;
|
||||
w->wcls = shared_weights ? w->token_embedding_table : ptr;
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
@@ -319,7 +323,7 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights*
|
||||
rmsnorm(x, x, w->rms_final_weight, dim);
|
||||
|
||||
// classifier into logits
|
||||
matmul(s->logits, x, w->token_embedding_table, p->dim, p->vocab_size);
|
||||
matmul(s->logits, x, w->wcls, p->dim, p->vocab_size);
|
||||
}
|
||||
|
||||
int sample(float* probabilities, int n) {
|
||||
@@ -395,6 +399,9 @@ int main(int argc, char *argv[]) {
|
||||
}
|
||||
// read in the config header
|
||||
if(fread(&config, sizeof(Config), 1, file) != 1) { return 1; }
|
||||
// negative vocab size is hacky way of signaling unshared weights. bit yikes.
|
||||
int shared_weights = config.vocab_size > 0 ? 1 : 0;
|
||||
config.vocab_size = abs(config.vocab_size);
|
||||
// figure out the file size
|
||||
fseek(file, 0, SEEK_END); // move file pointer to end of file
|
||||
file_size = ftell(file); // get the file size, in bytes
|
||||
@@ -404,7 +411,8 @@ int main(int argc, char *argv[]) {
|
||||
if (fd == -1) { printf("open failed!\n"); return 1; }
|
||||
data = mmap(NULL, file_size, PROT_READ, MAP_PRIVATE, fd, 0);
|
||||
if (data == MAP_FAILED) { printf("mmap failed!\n"); return 1; }
|
||||
checkpoint_init_weights(&weights, &config, data + sizeof(Config)/sizeof(float));
|
||||
float* weights_ptr = data + sizeof(Config)/sizeof(float);
|
||||
checkpoint_init_weights(&weights, &config, weights_ptr, shared_weights);
|
||||
}
|
||||
// right now we cannot run for more than config.seq_len steps
|
||||
if (steps <= 0 || steps > config.seq_len) { steps = config.seq_len; }
|
||||
|
||||
Reference in New Issue
Block a user