ok this works but is super slow because we are doing all the work in fp32 still

This commit is contained in:
Andrej Karpathy
2023-08-18 03:40:18 +00:00
parent e9cbe3e84f
commit 591f1353c7
2 changed files with 169 additions and 94 deletions
+139 -72
View File
@@ -1,5 +1,6 @@
/* Inference for Llama-2 Transformer model in pure C */
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <ctype.h>
@@ -13,41 +14,49 @@
#include <unistd.h>
#include <sys/mman.h>
#endif
// ----------------------------------------------------------------------------
// Globals
int GS = 0; // group size global for quantization
// ----------------------------------------------------------------------------
// Transformer and RunState structs, and related memory management
// Model hyperparameters. main() fills this struct with a single
// fread(&config, sizeof(Config), 1, file), i.e. it is read as raw bytes from
// the checkpoint, so the field order and types have to match the file.
// NOTE(review): this span is a rendered diff without +/- markers, so every
// field is listed twice (earlier comment wording first, revised wording second).
typedef struct {
int dim; // transformer dimension
int hidden_dim; // for ffn layers
int n_layers; // number of layers
int n_heads; // number of query heads
int n_kv_heads; // number of key/value heads (can be < query heads because of multiquery)
int vocab_size; // vocabulary size, usually 256 (byte-level)
int seq_len; // max sequence length
int dim; // transformer dimension
int hidden_dim; // dimension of the inner layer in the MLP
int n_layers; // number of layers
int n_heads; // number of query heads
int n_kv_heads; // number of key & value heads (can be < query heads because of multiquery)
int vocab_size; // vocabulary size (size of the classifier weights)
int seq_len; // max sequence length the model was trained with
} Config;
/* An int8-quantized tensor: the raw quantized values plus the float
 * scaling factors used to turn them back into floats. */
typedef struct {
    int8_t *q; /* quantized values */
    float *s;  /* scaling factors */
} QuantizedTensor;
// Pointers to every weight tensor of the model; checkpoint_init_weights() points
// them into the checkpoint buffer.
// NOTE(review): this span is a rendered diff without +/- markers: the float*
// declarations and the QuantizedTensor declarations of the same field name are
// the two sides of the change and cannot both be present in one struct.
typedef struct {
// token embedding table
float* token_embedding_table; // (vocab_size, dim)
QuantizedTensor token_embedding_table; // (vocab_size, dim)
// weights for rmsnorms
float* rms_att_weight; // (layer, dim) rmsnorm weights
float* rms_ffn_weight; // (layer, dim)
// weights for matmuls. note dim == n_heads * head_size
float* wq; // (layer, dim, n_heads * head_size)
float* wk; // (layer, dim, n_kv_heads * head_size)
float* wv; // (layer, dim, n_kv_heads * head_size)
float* wo; // (layer, n_heads * head_size, dim)
QuantizedTensor wq; // (layer, dim, n_heads * head_size)
QuantizedTensor wk; // (layer, dim, n_kv_heads * head_size)
QuantizedTensor wv; // (layer, dim, n_kv_heads * head_size)
QuantizedTensor wo; // (layer, n_heads * head_size, dim)
// weights for ffn
float* w1; // (layer, hidden_dim, dim)
float* w2; // (layer, dim, hidden_dim)
float* w3; // (layer, hidden_dim, dim)
QuantizedTensor w1; // (layer, hidden_dim, dim)
QuantizedTensor w2; // (layer, dim, hidden_dim)
QuantizedTensor w3; // (layer, hidden_dim, dim)
// final rmsnorm
float* rms_final_weight; // (dim,)
// freq_cis for RoPE relative positional embeddings (not used anymore)
float* freq_cis_real; // (seq_len, head_size/2)
float* freq_cis_imag; // (seq_len, head_size/2)
// (optional) classifier weights for the logits, on the last layer
float* wcls;
QuantizedTensor wcls; // (dim, vocab_size)
} TransformerWeights;
typedef struct {
@@ -117,35 +126,67 @@ void free_run_state(RunState* s) {
// ----------------------------------------------------------------------------
// initialization: read from checkpoint
// checkpoint_init_weights: point every field of *w at its slice of the
// checkpoint buffer that starts at ptr. Nothing is copied; only pointers are set.
//   w   - weights struct to fill in
//   p   - model config; its dimensions determine the size of each slice
//   ptr - start of the weight data (main() passes the mapped file plus header_size)
//   shared_weights / shared_classifier - non-zero: wcls aliases the token embedding table
// NOTE(review): this span is a rendered diff without +/- markers. It holds two
// signatures and two bodies back to back: the lines that walk `ptr` directly are
// the fp32 layout, the lines that walk fptr/qptr/sptr are the quantized layout.
void checkpoint_init_weights(TransformerWeights *w, Config* p, float* ptr, int shared_weights) {
void checkpoint_init_weights(TransformerWeights *w, Config* p, void* ptr, uint8_t shared_classifier) {
int head_size = p->dim / p->n_heads;
// --- fp32 layout: every tensor is float and stored in this order ---
w->token_embedding_table = ptr;
ptr += p->vocab_size * p->dim;
w->rms_att_weight = ptr;
ptr += p->n_layers * p->dim;
w->wq = ptr;
ptr += p->n_layers * p->dim * (p->n_heads * head_size);
w->wk = ptr;
ptr += p->n_layers * p->dim * (p->n_kv_heads * head_size);
w->wv = ptr;
ptr += p->n_layers * p->dim * (p->n_kv_heads * head_size);
w->wo = ptr;
ptr += p->n_layers * (p->n_heads * head_size) * p->dim;
w->rms_ffn_weight = ptr;
ptr += p->n_layers * p->dim;
w->w1 = ptr;
ptr += p->n_layers * p->dim * p->hidden_dim;
w->w2 = ptr;
ptr += p->n_layers * p->hidden_dim * p->dim;
w->w3 = ptr;
ptr += p->n_layers * p->dim * p->hidden_dim;
w->rms_final_weight = ptr;
ptr += p->dim;
w->freq_cis_real = ptr;
ptr += p->seq_len * head_size / 2;
w->freq_cis_imag = ptr;
ptr += p->seq_len * head_size / 2;
w->wcls = shared_weights ? w->token_embedding_table : ptr;
// --- quantized layout: three consecutive sections (fp32 norms, int8 values, fp32 scales) ---
// first are the parameters that are kept in fp32 (the rmsnorm (1D) weights)
float* fptr = (float*) ptr; // cast our pointer to float*
w->rms_att_weight = fptr;
fptr += p->n_layers * p->dim;
w->rms_ffn_weight = fptr;
fptr += p->n_layers * p->dim;
w->rms_final_weight = fptr;
fptr += p->dim;
// now read all the quantized weights
// NOTE(review): all element counts below are computed in int; for large enough
// configs the products could exceed INT_MAX - confirm the supported model sizes.
int8_t* qptr = (int8_t*) fptr; // now cast the pointer to int8_t*
w->token_embedding_table.q = qptr;
qptr += p->vocab_size * p->dim;
w->wq.q = qptr;
qptr += p->n_layers * p->dim * (p->n_heads * head_size);
w->wk.q = qptr;
qptr += p->n_layers * p->dim * (p->n_kv_heads * head_size);
w->wv.q = qptr;
qptr += p->n_layers * p->dim * (p->n_kv_heads * head_size);
w->wo.q = qptr;
qptr += p->n_layers * (p->n_heads * head_size) * p->dim;
w->w1.q = qptr;
qptr += p->n_layers * p->dim * p->hidden_dim;
w->w2.q = qptr;
qptr += p->n_layers * p->hidden_dim * p->dim;
w->w3.q = qptr;
qptr += p->n_layers * p->dim * p->hidden_dim;
// a shared classifier reuses the embedding table and consumes no bytes of its own
if (shared_classifier) {
w->wcls.q = w->token_embedding_table.q;
} else {
w->wcls.q = qptr;
qptr += p->dim * p->vocab_size;
}
// and finally all the associated scaling factors
// each tensor has (number of weights / GS) scales, stored in the same tensor order as above.
// GS is the file-level group size; it is initialised to 0, so it has to be set
// to a non-zero value before this function runs or the divisions below are undefined.
// NOTE(review): qptr is only byte-aligned at this point; nothing visible here
// guarantees the scales begin on a float boundary - confirm against the exporter.
float* sptr = (float*) qptr; // cast pointer back to float*
w->token_embedding_table.s = sptr;
sptr += p->vocab_size * p->dim / GS;
w->wq.s = sptr;
sptr += p->n_layers * p->dim * (p->n_heads * head_size) / GS;
w->wk.s = sptr;
sptr += p->n_layers * p->dim * (p->n_kv_heads * head_size) / GS;
w->wv.s = sptr;
sptr += p->n_layers * p->dim * (p->n_kv_heads * head_size) / GS;
w->wo.s = sptr;
sptr += p->n_layers * (p->n_heads * head_size) * p->dim / GS;
w->w1.s = sptr;
sptr += p->n_layers * p->dim * p->hidden_dim / GS;
w->w2.s = sptr;
sptr += p->n_layers * p->hidden_dim * p->dim / GS;
w->w3.s = sptr;
sptr += p->n_layers * p->dim * p->hidden_dim / GS;
if (shared_classifier) {
w->wcls.s = w->token_embedding_table.s;
} else {
w->wcls.s = sptr;
sptr += p->dim * p->vocab_size / GS;
}
}
// ----------------------------------------------------------------------------
@@ -186,20 +227,30 @@ void softmax(float* x, int size) {
}
}
// matmul: matrix-vector product xout = W @ x, with W stored row-major as (d, n).
//   xout - output vector, d floats
//   x    - input vector, n floats
//   w    - fp32 weights (first signature)
//   q, s - int8 weights and their float scales (second signature)
// NOTE(review): this span is a rendered diff without +/- markers, so the fp32
// signature and accumulate line sit next to the int8 ones; only one set can
// exist in the real function.
void matmul(float* xout, float* x, float* w, int n, int d) {
void matmul(float* xout, float* x, int8_t* q, float* s, int n, int d) {
// W (d,n) @ x (n,) -> xout (d,)
// by far the most amount of time is spent inside this little function
// do the matmul
int i;
// each iteration of the outer loop writes a different xout[i]
#pragma omp parallel for private(i)
for (i = 0; i < d; i++) {
float val = 0.0f;
for (int j = 0; j < n; j++) {
val += w[i * n + j] * x[j];
// flat row-major index of W[i][j]
// NOTE(review): computed in int - confirm d*n stays below INT_MAX
int ix = i * n + j;
// dequantize on the fly: s holds one scale per GS consecutive weights
float wij = q[ix] * s[ix / GS];
val += wij * x[j];
}
xout[i] = val;
}
}
/* Expand n quantized values into floats: x[k] = q[k] * s[k / GS].
 *   q - quantized int8 values (n of them)
 *   s - scaling factors, one per GS consecutive values
 *   x - output buffer, receives n floats
 *   n - number of elements to convert
 * GS is the file-level quantization group size. */
void dequantize(int8_t* q, float* s, float* x, int n) {
    int idx = 0;
    while (idx < n) {
        const float scale = s[idx / GS];
        x[idx] = (float)q[idx] * scale;
        idx++;
    }
}
void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights* w) {
// a few convenience variables
@@ -210,9 +261,9 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights*
int hidden_dim = p->hidden_dim;
int head_size = dim / p->n_heads;
// copy the token embedding into x
float* content_row = &(w->token_embedding_table[token * dim]);
memcpy(x, content_row, dim*sizeof(*x));
// dequantize the token embedding into a float x
QuantizedTensor tok = w->token_embedding_table;
dequantize(tok.q + token * dim, tok.s + token * dim / GS, x, dim);
// forward all the layers
for(int l = 0; l < p->n_layers; l++) {
@@ -221,9 +272,9 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights*
rmsnorm(s->xb, x, w->rms_att_weight + l*dim, dim);
// qkv matmuls for this position
matmul(s->q, s->xb, w->wq + l*dim*dim, dim, dim);
matmul(s->k, s->xb, w->wk + l*dim*kv_dim, dim, kv_dim);
matmul(s->v, s->xb, w->wv + l*dim*kv_dim, dim, kv_dim);
matmul(s->q, s->xb, w->wq.q + l*dim*dim, w->wq.s + l*dim*dim/GS, dim, dim);
matmul(s->k, s->xb, w->wk.q + l*dim*kv_dim, w->wk.s + l*dim*kv_dim/GS, dim, kv_dim);
matmul(s->v, s->xb, w->wv.q + l*dim*kv_dim, w->wv.s + l*dim*kv_dim/GS, dim, kv_dim);
// RoPE relative positional encoding: complex-valued rotate q and k in each head
for (int i = 0; i < dim; i+=2) {
@@ -290,7 +341,7 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights*
}
// final matmul to get the output of the attention
matmul(s->xb2, s->xb, w->wo + l*dim*dim, dim, dim);
matmul(s->xb2, s->xb, w->wo.q + l*dim*dim, w->wo.s + l*dim*dim/GS, dim, dim);
// residual connection back into x
for (int i = 0; i < dim; i++) {
@@ -302,8 +353,8 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights*
// Now for FFN in PyTorch we have: self.w2(F.silu(self.w1(x)) * self.w3(x))
// first calculate self.w1(x) and self.w3(x)
matmul(s->hb, s->xb, w->w1 + l*dim*hidden_dim, dim, hidden_dim);
matmul(s->hb2, s->xb, w->w3 + l*dim*hidden_dim, dim, hidden_dim);
matmul(s->hb, s->xb, w->w1.q + l*dim*hidden_dim, w->w1.s + l*dim*hidden_dim/GS, dim, hidden_dim);
matmul(s->hb2, s->xb, w->w3.q + l*dim*hidden_dim, w->w3.s + l*dim*hidden_dim/GS, dim, hidden_dim);
// F.silu; silu(x)=x*σ(x),where σ(x) is the logistic sigmoid
for (int i = 0; i < hidden_dim; i++) {
@@ -316,7 +367,7 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights*
}
// final matmul to get the output of the ffn
matmul(s->xb, s->hb, w->w2 + l*dim*hidden_dim, hidden_dim, dim);
matmul(s->xb, s->hb, w->w2.q + l*dim*hidden_dim, w->w2.s + l*dim*hidden_dim/GS, hidden_dim, dim);
// residual connection
for (int i = 0; i < dim; i++) {
@@ -328,7 +379,7 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights*
rmsnorm(x, x, w->rms_final_weight, dim);
// classifier into logits
matmul(s->logits, x, w->wcls, p->dim, p->vocab_size);
matmul(s->logits, x, w->wcls.q, w->wcls.s, dim, p->vocab_size);
}
// ----------------------------------------------------------------------------
@@ -597,35 +648,51 @@ int main(int argc, char *argv[]) {
else if (argv[i][1] == 'z') { tokenizer = argv[i + 1]; }
else { error_usage(); }
}
if(rng_seed == 0) { rng_seed = (unsigned int)time(NULL);}
// input validations
// our rng cannot accept 0 as a seed, so might as well use time(NULL) as default
if(rng_seed == 0) { rng_seed = (unsigned int)time(NULL); }
// read in the model.bin file
Config config;
TransformerWeights weights;
int fd = 0; // file descriptor for memory mapping
float* data = NULL; // memory mapped data pointer
ssize_t file_size; // size of the checkpoint file in bytes
void* data = NULL; // memory mapped data pointer
ssize_t file_size; // size of the checkpoint file in bytes
{
// first "peek" at the checkpoint and extract metadata
FILE *file = fopen(checkpoint, "rb");
if (!file) { fprintf(stderr, "Couldn't open file %s\n", checkpoint); return 1; }
// read in the config header
// read in magic number (uint32), has to be 0x616b3432, i.e. "ak42" in ASCII
uint32_t magic_number;
if (fread(&magic_number, sizeof(uint32_t), 1, file) != 1) { return 1; }
if (magic_number != 0x616b3432) { fprintf(stderr, "Bad magic number\n"); return 1; }
// read in the version number (uint32), has to be 1
int version;
if (fread(&version, sizeof(int), 1, file) != 1) { return 1; }
if (version != 1) { fprintf(stderr, "Bad version number\n"); return 1; }
int header_size = 256; // the header size for version 1 in bytes
// read in the Config
if (fread(&config, sizeof(Config), 1, file) != 1) { return 1; }
// negative vocab size is hacky way of signaling unshared weights. bit yikes.
int shared_weights = config.vocab_size > 0 ? 1 : 0;
config.vocab_size = abs(config.vocab_size);
// figure out the file size
// read in flags
uint8_t shared_classifier; // a byte to indicate if the classifier is shared
if (fread(&shared_classifier, sizeof(uint8_t), 1, file) != 1) { return 1; }
int group_size; // the group size used in quantization
if (fread(&group_size, sizeof(int), 1, file) != 1) { return 1; }
GS = group_size; // set as global, as it will be used in many places
// seek all the way to the end to figure out the full file size
fseek(file, 0, SEEK_END); // move file pointer to end of file
file_size = ftell(file); // get the file size, in bytes
fclose(file);
// memory map the Transformer weights into the data pointer
// now memory map the Transformer weights into the data pointer
fd = open(checkpoint, O_RDONLY); // open in read only mode
if (fd == -1) { fprintf(stderr, "open failed!\n"); return 1; }
data = mmap(NULL, file_size, PROT_READ, MAP_PRIVATE, fd, 0);
if (data == MAP_FAILED) { fprintf(stderr, "mmap failed!\n"); return 1; }
float* weights_ptr = data + sizeof(Config)/sizeof(float);
checkpoint_init_weights(&weights, &config, weights_ptr, shared_weights);
void* weights_ptr = (char*)data + header_size; // skip header bytes. char is 1 byte
checkpoint_init_weights(&weights, &config, weights_ptr, shared_classifier);
}
// right now we cannot run for more than config.seq_len steps
// we should not run for more than config.seq_len steps
if (steps <= 0 || steps > config.seq_len) { steps = config.seq_len; }
// read in the tokenizer .bin file