we can inference Meta's Llama 2 7B, yay

This commit is contained in:
Andrej Karpathy
2023-07-25 04:21:07 +00:00
parent 133ad3ffff
commit c3e0d73bd2
3 changed files with 126 additions and 4 deletions
+11 -3
View File
@@ -50,6 +50,8 @@ typedef struct {
// freq_cis for RoPE relative positional embeddings
float* freq_cis_real; // (seq_len, head_size/2)
float* freq_cis_imag; // (seq_len, head_size/2)
// (optional) classifier weights for the logits, on the last layer
float* wcls;
} TransformerWeights;
typedef struct {
@@ -110,7 +112,7 @@ void free_run_state(RunState* s) {
// ----------------------------------------------------------------------------
// initialization: read from checkpoint
void checkpoint_init_weights(TransformerWeights *w, Config* p, float* f) {
void checkpoint_init_weights(TransformerWeights *w, Config* p, float* f, int shared_weights) {
float* ptr = f;
w->token_embedding_table = ptr;
ptr += p->vocab_size * p->dim;
@@ -138,6 +140,8 @@ void checkpoint_init_weights(TransformerWeights *w, Config* p, float* f) {
int head_size = p->dim / p->n_heads;
ptr += p->seq_len * head_size / 2;
w->freq_cis_imag = ptr;
ptr += p->seq_len * head_size / 2;
w->wcls = shared_weights ? w->token_embedding_table : ptr;
}
// ----------------------------------------------------------------------------
@@ -319,7 +323,7 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights*
rmsnorm(x, x, w->rms_final_weight, dim);
// classifier into logits
matmul(s->logits, x, w->token_embedding_table, p->dim, p->vocab_size);
matmul(s->logits, x, w->wcls, p->dim, p->vocab_size);
}
int sample(float* probabilities, int n) {
@@ -395,6 +399,9 @@ int main(int argc, char *argv[]) {
}
// read in the config header
if(fread(&config, sizeof(Config), 1, file) != 1) { return 1; }
// a negative vocab_size in the header is a hacky way of signaling unshared
// weights (classifier stored separately from the token embedding). bit yikes.
int shared_weights = config.vocab_size > 0 ? 1 : 0;
config.vocab_size = abs(config.vocab_size);
// figure out the file size
fseek(file, 0, SEEK_END); // move file pointer to end of file
file_size = ftell(file); // get the file size, in bytes
@@ -404,7 +411,8 @@ int main(int argc, char *argv[]) {
if (fd == -1) { printf("open failed!\n"); return 1; }
data = mmap(NULL, file_size, PROT_READ, MAP_PRIVATE, fd, 0);
if (data == MAP_FAILED) { printf("mmap failed!\n"); return 1; }
checkpoint_init_weights(&weights, &config, data + sizeof(Config)/sizeof(float));
float* weights_ptr = data + sizeof(Config)/sizeof(float);
checkpoint_init_weights(&weights, &config, weights_ptr, shared_weights);
}
// right now we cannot run for more than config.seq_len steps
if (steps <= 0 || steps > config.seq_len) { steps = config.seq_len; }