diff --git a/README.md b/README.md index 6d29986..9b054dc 100644 --- a/README.md +++ b/README.md @@ -297,7 +297,6 @@ If your candidate PRs have elements of these it doesn't mean they won't get merg - revive tests; train a tiny Llama test model (committed to repo) and use it as reference in unit tests - make it easier to add a new dataset with not too much pain -- add multiquery support into run.c - should calculate freq_cis online in the script run.c instead of loading them - int4/8 quantization - export the model in a more sensible output format with a proper header, etc. diff --git a/model.py b/model.py index 7329d6c..c8c82a9 100644 --- a/model.py +++ b/model.py @@ -94,6 +94,7 @@ class Attention(nn.Module): def __init__(self, args: ModelArgs): super().__init__() self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads + assert args.n_heads % self.n_kv_heads == 0 model_parallel_size = 1 self.n_local_heads = args.n_heads // model_parallel_size self.n_local_kv_heads = self.n_kv_heads // model_parallel_size diff --git a/run.c b/run.c index 14469ad..4a6e8c2 100644 --- a/run.c +++ b/run.c @@ -39,11 +39,11 @@ typedef struct { // weights for rmsnorms float* rms_att_weight; // (layer, dim) rmsnorm weights float* rms_ffn_weight; // (layer, dim) - // weights for matmuls - float* wq; // (layer, dim, dim) - float* wk; // (layer, dim, dim) - float* wv; // (layer, dim, dim) - float* wo; // (layer, dim, dim) + // weights for matmuls. note dim == n_heads * head_size + float* wq; // (layer, dim, n_heads * head_size) + float* wk; // (layer, dim, n_kv_heads * head_size) + float* wv; // (layer, dim, n_kv_heads * head_size) + float* wo; // (layer, n_heads * head_size, dim) // weights for ffn float* w1; // (layer, hidden_dim, dim) float* w2; // (layer, dim, hidden_dim) @@ -82,6 +82,7 @@ typedef struct { void malloc_run_state(RunState* s, Config* p) { // we calloc instead of malloc to keep valgrind happy + int kv_dim = (p->dim * p->n_kv_heads) / p->n_heads; s->x = calloc(p->dim, sizeof(float)); s->xb = calloc(p->dim, sizeof(float)); s->xb2 = calloc(p->dim, sizeof(float)); @@ -93,8 +94,8 @@ void malloc_run_state(RunState* s, Config* p) { s->att = calloc(p->n_heads * p->seq_len, sizeof(float)); s->logits = calloc(p->vocab_size, sizeof(float)); s->probindex = calloc(p->vocab_size, sizeof(ProbIndex)); - s->key_cache = calloc(p->n_layers * p->seq_len * p->dim, sizeof(float)); - s->value_cache = calloc(p->n_layers * p->seq_len * p->dim, sizeof(float)); + s->key_cache = calloc(p->n_layers * p->seq_len * kv_dim, sizeof(float)); + s->value_cache = calloc(p->n_layers * p->seq_len * kv_dim, sizeof(float)); // ensure all mallocs went fine if (!s->x || !s->xb || !s->xb2 || !s->hb || !s->hb2 || !s->q || !s->k || !s->v || !s->att || !s->logits || !s->key_cache @@ -124,19 +125,20 @@ void free_run_state(RunState* s) { // initialization: read from checkpoint void checkpoint_init_weights(TransformerWeights *w, Config* p, float* f, int shared_weights) { + int head_size = p->dim / p->n_heads; float* ptr = f; w->token_embedding_table = ptr; ptr += p->vocab_size * p->dim; w->rms_att_weight = ptr; ptr += p->n_layers * p->dim; w->wq = ptr; - ptr += p->n_layers * p->dim * p->dim; + ptr += p->n_layers * p->dim * (p->n_heads * head_size); w->wk = ptr; - ptr += p->n_layers * p->dim * p->dim; + ptr += p->n_layers * p->dim * (p->n_kv_heads * head_size); w->wv = ptr; - ptr += p->n_layers * p->dim * p->dim; + ptr += p->n_layers * p->dim * (p->n_kv_heads * head_size); w->wo = ptr; - ptr += p->n_layers * p->dim * p->dim; + ptr += p->n_layers * (p->n_heads * head_size) * p->dim; w->rms_ffn_weight = ptr; ptr += p->n_layers * p->dim; w->w1 = ptr; @@ -148,7 +150,6 @@ void checkpoint_init_weights(TransformerWeights *w, Config* p, float* f, int sha w->rms_final_weight = ptr; ptr += p->dim; w->freq_cis_real = ptr; - int head_size = p->dim / p->n_heads; ptr += p->seq_len * head_size / 2; w->freq_cis_imag = ptr; ptr += p->seq_len * head_size / 2; @@ -218,6 +219,8 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights* // a few convenience variables float *x = s->x; int dim = p->dim; + int kv_dim = (p->dim * p->n_kv_heads) / p->n_heads; + int kv_mul = p->n_heads / p->n_kv_heads; // integer multiplier of the kv sharing in multiquery int hidden_dim = p->hidden_dim; int head_size = dim / p->n_heads; @@ -237,29 +240,33 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights* // qkv matmuls for this position matmul(s->q, s->xb, w->wq + l*dim*dim, dim, dim); - matmul(s->k, s->xb, w->wk + l*dim*dim, dim, dim); - matmul(s->v, s->xb, w->wv + l*dim*dim, dim, dim); + matmul(s->k, s->xb, w->wk + l*dim*kv_dim, dim, kv_dim); + matmul(s->v, s->xb, w->wv + l*dim*kv_dim, dim, kv_dim); // RoPE relative positional encoding: complex-valued rotate q and k by freq_cis in each head for (int i = 0; i < dim; i+=2) { float q0 = s->q[i]; float q1 = s->q[i+1]; - float k0 = s->k[i]; - float k1 = s->k[i+1]; float fcr = freq_cis_real_row[(i % head_size) / 2]; float fci = freq_cis_imag_row[(i % head_size) / 2]; s->q[i] = q0 * fcr - q1 * fci; s->q[i+1] = q0 * fci + q1 * fcr; + } + for (int i = 0; i < kv_dim; i+=2) { + float k0 = s->k[i]; + float k1 = s->k[i+1]; + float fcr = freq_cis_real_row[(i % head_size) / 2]; + float fci = freq_cis_imag_row[(i % head_size) / 2]; s->k[i] = k0 * fcr - k1 * fci; s->k[i+1] = k0 * fci + k1 * fcr; } // save key,value at this time step (pos) to our kv cache - int loff = l * p->seq_len * dim; // kv cache layer offset for convenience - float* key_cache_row = s->key_cache + loff + pos * dim; - float* value_cache_row = s->value_cache + loff + pos * dim; - memcpy(key_cache_row, s->k, dim*sizeof(*key_cache_row)); - memcpy(value_cache_row, s->v, dim*sizeof(*value_cache_row)); + int loff = l * p->seq_len * kv_dim; // kv cache layer offset for convenience + float* key_cache_row = s->key_cache + loff + pos * kv_dim; + float* value_cache_row = s->value_cache + loff + pos * kv_dim; + memcpy(key_cache_row, s->k, kv_dim * sizeof(*key_cache_row)); + memcpy(value_cache_row, s->v, kv_dim * sizeof(*value_cache_row)); // multihead attention. iterate over all heads int h; @@ -272,7 +279,7 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights* // iterate over all timesteps, including the current one for (int t = 0; t <= pos; t++) { // get the key vector for this head and at this timestep - float* k = s->key_cache + loff + t * dim + h * head_size; + float* k = s->key_cache + loff + t * kv_dim + (h / kv_mul) * head_size; // calculate the attention score as the dot product of q and k float score = 0.0f; for (int i = 0; i < head_size; i++) { @@ -291,7 +298,7 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights* memset(xb, 0, head_size * sizeof(float)); for (int t = 0; t <= pos; t++) { // get the value vector for this head and at this timestep - float* v = s->value_cache + loff + t * dim + h * head_size; + float* v = s->value_cache + loff + t * kv_dim + (h / kv_mul) * head_size; // get the attention weight for this timestep float a = att[t]; // accumulate the weighted value into xb diff --git a/sample.py b/sample.py index 93c9407..2f66e7f 100644 --- a/sample.py +++ b/sample.py @@ -53,7 +53,6 @@ if compile: model = torch.compile(model) # requires PyTorch 2.0 (optional) # load the tokenizer -assert checkpoint["config"]["dataset"] == "tinystories" # TODO: generalize tokenizer_model = get_tokenizer_model_path(vocab_size=gptconf.vocab_size) enc = Tokenizer(tokenizer_model=tokenizer_model) diff --git a/train.py b/train.py index 24d6fa6..b1972dc 100644 --- a/train.py +++ b/train.py @@ -52,6 +52,7 @@ vocab_size = 32000 # the Llama 2 tokenizer has 32K tokens dim = 288 n_layers = 6 n_heads = 6 +n_kv_heads = 6 multiple_of = 32 dropout = 0.0 # adamw optimizer @@ -146,7 +147,7 @@ model_args = dict( dim=dim, n_layers=n_layers, n_heads=n_heads, - n_kv_heads=n_heads, + n_kv_heads=n_kv_heads, vocab_size=vocab_size, multiple_of=multiple_of, max_seq_len=max_seq_len,