From 591f1353c75da9bd94618e5f73e6412f4840ed30 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Fri, 18 Aug 2023 03:40:18 +0000 Subject: [PATCH] ok this works but is super slow because we are doing all the work in fp32 still --- model.py | 52 ++++++++------ run.c | 211 ++++++++++++++++++++++++++++++++++++------------------- 2 files changed, 169 insertions(+), 94 deletions(-) diff --git a/model.py b/model.py index f12f84c..7bdf6a7 100644 --- a/model.py +++ b/model.py @@ -339,7 +339,7 @@ class Transformer(nn.Module): return idx - def export(self, filepath='model.bin'): + def export(self, filepath='model.bin', group_size=64): """export the model weights in Q8_0 into .bin file to be read from C""" hidden_dim = self.layers[0].feed_forward.w1.weight.shape[0] out_file = open(filepath, 'wb') @@ -356,7 +356,7 @@ class Transformer(nn.Module): b = struct.pack(f'{len(d)}b', *d) out_file.write(b) - def quantize_q80(w, group_size=32): + def quantize_q80(w): """ takes a tensor and returns the Q8_0 quantized version i.e. symmetric quantization into int8, range [-127,127] @@ -367,35 +367,44 @@ class Transformer(nn.Module): w = w.reshape(-1, group_size) # find the max in each group wmax = torch.abs(w).max(dim=1).values + # calculate the scaling factor such that float = quant * scale + scale = wmax / 127.0 # scale into range [-127, 127] - scaled = w/wmax[:,None]*127 + quant = w / scale[:,None] # round to nearest integer - int8val = torch.round(scaled).to(torch.int8) + int8val = torch.round(quant).to(torch.int8) # dequantize by rescaling - fp32val = (int8val.float()*wmax[:,None]/127.0).view(-1) + fp32val = (int8val.float() * scale[:,None]).view(-1) fp32valr = fp32val.reshape(-1, group_size) # calculate the max error in each group err = torch.abs(fp32valr - w).max(dim=1).values # find the max error across all groups maxerr = err.max().item() - return int8val, wmax, maxerr + return int8val, scale, maxerr # first write out the header. 
the header will be 256 bytes + nbytes = 0 # 1) write magic, which will be uint32 of "ak42" in ASCII out_file.write(struct.pack('I', 0x616b3432)) - # 2) write version, which will be uint32 - out_file.write(struct.pack('I', 1)) + nbytes += 4 + # 2) write version, which will be int + out_file.write(struct.pack('i', 1)) + nbytes += 4 # 3) write the params, which will be 7 ints p = self.params n_kv_heads = p.n_heads if p.n_kv_heads is None else p.n_kv_heads - header = struct.pack('IIIIIII', p.dim, hidden_dim, p.n_layers, p.n_heads, + header = struct.pack('iiiiiii', p.dim, hidden_dim, p.n_layers, p.n_heads, n_kv_heads, p.vocab_size, p.max_seq_len) out_file.write(header) + nbytes += 7*4 # 4) write some other flags shared_classifier = 1 # we do share a classifier, write flag as a byte out_file.write(struct.pack('B', shared_classifier)) - # ok so we so far used 4 + 4 + 7*4 + 1 = 37 bytes - pad = 256 - 37 # pad the rest with zeros + nbytes += 1 + out_file.write(struct.pack('i', group_size)) # group size used for quantization + nbytes += 4 + pad = 256 - nbytes # pad the rest with zeros + assert pad >= 0 out_file.write(b'\0' * pad) # now that the header is done, let's write out the model @@ -420,26 +429,25 @@ class Transformer(nn.Module): ] ew = [] + scales = [] for i, w in enumerate(weights): - - # find a good group size for this weight tensor - gs = 64 # group size we want - while w.numel() % gs != 0: - gs //= 2 # but fall back as needed - if gs <= 8: - print(f"WARNING: weight of shape {tuple(w.shape)} caused group size to fall down to {gs}") + assert w.numel() % group_size == 0, f"weight {i} has numel {w.numel()}, not a multiple of group_size {group_size}" # quantize this weight - q, s, err = quantize_q80(w, group_size=gs) + q, s, err = quantize_q80(w) # save to file - out_file.write(struct.pack('I', gs)) # save the group size as uint32 serialize_int8(q) # save the tensor in int8 - serialize_fp32(s) # save the scaling factors in fp32 + scales.append(s) # we'll do all the 
scales after all the qs # logging ew.append((err, w.shape)) - print(f"{i:3d} quantized {tuple(w.shape)} to Q8_0 with group size {gs} and max error {err}") + print(f"{i+1}/{len(weights)} quantized {tuple(w.shape)} to Q8_0 with max error {err}") + + # save the scaling factors in fp32 here + # this is done to keep all the weights contiguous, making pointer arithmetic easier in C + for s in scales: + serialize_fp32(s) # print the highest error across all weights, should be very small, e.g. O(~0.001) ew.sort(reverse=True) diff --git a/run.c b/run.c index 10d468b..b33209a 100644 --- a/run.c +++ b/run.c @@ -1,5 +1,6 @@ /* Inference for Llama-2 Transformer model in pure C */ +#include #include #include #include @@ -13,41 +14,49 @@ #include #include #endif + +// ---------------------------------------------------------------------------- +// Globals + +int GS = 0; // group size global for quantization + // ---------------------------------------------------------------------------- // Transformer and RunState structs, and related memory management typedef struct { - int dim; // transformer dimension - int hidden_dim; // for ffn layers - int n_layers; // number of layers - int n_heads; // number of query heads - int n_kv_heads; // number of key/value heads (can be < query heads because of multiquery) - int vocab_size; // vocabulary size, usually 256 (byte-level) - int seq_len; // max sequence length + int dim; // transformer dimension + int hidden_dim; // dimension of the inner layer in the MLP + int n_layers; // number of layers + int n_heads; // number of query heads + int n_kv_heads; // number of key & value heads (can be < query heads because of multiquery) + int vocab_size; // vocabulary size (size of the classifier weights) + int seq_len; // max sequence length the model was trained with } Config; +typedef struct { + int8_t* q; // quantized values + float* s; // scaling factors +} QuantizedTensor; + typedef struct { // token embedding table - float* 
token_embedding_table; // (vocab_size, dim) + QuantizedTensor token_embedding_table; // (vocab_size, dim) // weights for rmsnorms float* rms_att_weight; // (layer, dim) rmsnorm weights float* rms_ffn_weight; // (layer, dim) // weights for matmuls. note dim == n_heads * head_size - float* wq; // (layer, dim, n_heads * head_size) - float* wk; // (layer, dim, n_kv_heads * head_size) - float* wv; // (layer, dim, n_kv_heads * head_size) - float* wo; // (layer, n_heads * head_size, dim) + QuantizedTensor wq; // (layer, dim, n_heads * head_size) + QuantizedTensor wk; // (layer, dim, n_kv_heads * head_size) + QuantizedTensor wv; // (layer, dim, n_kv_heads * head_size) + QuantizedTensor wo; // (layer, n_heads * head_size, dim) // weights for ffn - float* w1; // (layer, hidden_dim, dim) - float* w2; // (layer, dim, hidden_dim) - float* w3; // (layer, hidden_dim, dim) + QuantizedTensor w1; // (layer, hidden_dim, dim) + QuantizedTensor w2; // (layer, dim, hidden_dim) + QuantizedTensor w3; // (layer, hidden_dim, dim) // final rmsnorm float* rms_final_weight; // (dim,) - // freq_cis for RoPE relatively positional embeddings (not used anymore) - float* freq_cis_real; // (seq_len, head_size/2) - float* freq_cis_imag; // (seq_len, head_size/2) // (optional) classifier weights for the logits, on the last layer - float* wcls; + QuantizedTensor wcls; // (dim, vocab_size) } TransformerWeights; typedef struct { @@ -117,35 +126,67 @@ void free_run_state(RunState* s) { // ---------------------------------------------------------------------------- // initialization: read from checkpoint -void checkpoint_init_weights(TransformerWeights *w, Config* p, float* ptr, int shared_weights) { +void checkpoint_init_weights(TransformerWeights *w, Config* p, void* ptr, uint8_t shared_classifier) { int head_size = p->dim / p->n_heads; - w->token_embedding_table = ptr; - ptr += p->vocab_size * p->dim; - w->rms_att_weight = ptr; - ptr += p->n_layers * p->dim; - w->wq = ptr; - ptr += p->n_layers * p->dim 
* (p->n_heads * head_size); - w->wk = ptr; - ptr += p->n_layers * p->dim * (p->n_kv_heads * head_size); - w->wv = ptr; - ptr += p->n_layers * p->dim * (p->n_kv_heads * head_size); - w->wo = ptr; - ptr += p->n_layers * (p->n_heads * head_size) * p->dim; - w->rms_ffn_weight = ptr; - ptr += p->n_layers * p->dim; - w->w1 = ptr; - ptr += p->n_layers * p->dim * p->hidden_dim; - w->w2 = ptr; - ptr += p->n_layers * p->hidden_dim * p->dim; - w->w3 = ptr; - ptr += p->n_layers * p->dim * p->hidden_dim; - w->rms_final_weight = ptr; - ptr += p->dim; - w->freq_cis_real = ptr; - ptr += p->seq_len * head_size / 2; - w->freq_cis_imag = ptr; - ptr += p->seq_len * head_size / 2; - w->wcls = shared_weights ? w->token_embedding_table : ptr; + + // first are the parameters that are kept in fp32 (the rmsnorm (1D) weights) + float* fptr = (float*) ptr; // cast our pointer to float* + w->rms_att_weight = fptr; + fptr += p->n_layers * p->dim; + w->rms_ffn_weight = fptr; + fptr += p->n_layers * p->dim; + w->rms_final_weight = fptr; + fptr += p->dim; + + // now read all the quantized weights + int8_t* qptr = (int8_t*) fptr; // now cast the pointer to int8_t* + w->token_embedding_table.q = qptr; + qptr += p->vocab_size * p->dim; + w->wq.q = qptr; + qptr += p->n_layers * p->dim * (p->n_heads * head_size); + w->wk.q = qptr; + qptr += p->n_layers * p->dim * (p->n_kv_heads * head_size); + w->wv.q = qptr; + qptr += p->n_layers * p->dim * (p->n_kv_heads * head_size); + w->wo.q = qptr; + qptr += p->n_layers * (p->n_heads * head_size) * p->dim; + w->w1.q = qptr; + qptr += p->n_layers * p->dim * p->hidden_dim; + w->w2.q = qptr; + qptr += p->n_layers * p->hidden_dim * p->dim; + w->w3.q = qptr; + qptr += p->n_layers * p->dim * p->hidden_dim; + if (shared_classifier) { + w->wcls.q = w->token_embedding_table.q; + } else { + w->wcls.q = qptr; + qptr += p->dim * p->vocab_size; + } + + // and finally all the associated scaling factors + float* sptr = (float*) qptr; // cast pointer back to float* + 
w->token_embedding_table.s = sptr; + sptr += p->vocab_size * p->dim / GS; + w->wq.s = sptr; + sptr += p->n_layers * p->dim * (p->n_heads * head_size) / GS; + w->wk.s = sptr; + sptr += p->n_layers * p->dim * (p->n_kv_heads * head_size) / GS; + w->wv.s = sptr; + sptr += p->n_layers * p->dim * (p->n_kv_heads * head_size) / GS; + w->wo.s = sptr; + sptr += p->n_layers * (p->n_heads * head_size) * p->dim / GS; + w->w1.s = sptr; + sptr += p->n_layers * p->dim * p->hidden_dim / GS; + w->w2.s = sptr; + sptr += p->n_layers * p->hidden_dim * p->dim / GS; + w->w3.s = sptr; + sptr += p->n_layers * p->dim * p->hidden_dim / GS; + if (shared_classifier) { + w->wcls.s = w->token_embedding_table.s; + } else { + w->wcls.s = sptr; + sptr += p->dim * p->vocab_size / GS; + } } // ---------------------------------------------------------------------------- @@ -186,20 +227,30 @@ void softmax(float* x, int size) { } } -void matmul(float* xout, float* x, float* w, int n, int d) { +void matmul(float* xout, float* x, int8_t* q, float* s, int n, int d) { // W (d,n) @ x (n,) -> xout (d,) // by far the most amount of time is spent inside this little function + + // do the matmul int i; #pragma omp parallel for private(i) for (i = 0; i < d; i++) { float val = 0.0f; for (int j = 0; j < n; j++) { - val += w[i * n + j] * x[j]; + int ix = i * n + j; + float wij = q[ix] * s[ix / GS]; + val += wij * x[j]; } xout[i] = val; } } +void dequantize(int8_t* q, float* s, float* x, int n) { + for (int i = 0; i < n; i++) { + x[i] = q[i] * s[i / GS]; + } +} + void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights* w) { // a few convenience variables @@ -210,9 +261,9 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights* int hidden_dim = p->hidden_dim; int head_size = dim / p->n_heads; - // copy the token embedding into x - float* content_row = &(w->token_embedding_table[token * dim]); - memcpy(x, content_row, dim*sizeof(*x)); + // dequantize the token 
embedding into a float x + QuantizedTensor tok = w->token_embedding_table; + dequantize(tok.q + token * dim, tok.s + token * dim / GS, x, dim); // forward all the layers for(int l = 0; l < p->n_layers; l++) { @@ -221,9 +272,9 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights* rmsnorm(s->xb, x, w->rms_att_weight + l*dim, dim); // qkv matmuls for this position - matmul(s->q, s->xb, w->wq + l*dim*dim, dim, dim); - matmul(s->k, s->xb, w->wk + l*dim*kv_dim, dim, kv_dim); - matmul(s->v, s->xb, w->wv + l*dim*kv_dim, dim, kv_dim); + matmul(s->q, s->xb, w->wq.q + l*dim*dim, w->wq.s + l*dim*dim/GS, dim, dim); + matmul(s->k, s->xb, w->wk.q + l*dim*kv_dim, w->wk.s + l*dim*kv_dim/GS, dim, kv_dim); + matmul(s->v, s->xb, w->wv.q + l*dim*kv_dim, w->wv.s + l*dim*kv_dim/GS, dim, kv_dim); // RoPE relative positional encoding: complex-valued rotate q and k in each head for (int i = 0; i < dim; i+=2) { @@ -290,7 +341,7 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights* } // final matmul to get the output of the attention - matmul(s->xb2, s->xb, w->wo + l*dim*dim, dim, dim); + matmul(s->xb2, s->xb, w->wo.q + l*dim*dim, w->wo.s + l*dim*dim/GS, dim, dim); // residual connection back into x for (int i = 0; i < dim; i++) { @@ -302,8 +353,8 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights* // Now for FFN in PyTorch we have: self.w2(F.silu(self.w1(x)) * self.w3(x)) // first calculate self.w1(x) and self.w3(x) - matmul(s->hb, s->xb, w->w1 + l*dim*hidden_dim, dim, hidden_dim); - matmul(s->hb2, s->xb, w->w3 + l*dim*hidden_dim, dim, hidden_dim); + matmul(s->hb, s->xb, w->w1.q + l*dim*hidden_dim, w->w1.s + l*dim*hidden_dim/GS, dim, hidden_dim); + matmul(s->hb2, s->xb, w->w3.q + l*dim*hidden_dim, w->w3.s + l*dim*hidden_dim/GS, dim, hidden_dim); // F.silu; silu(x)=x*σ(x),where σ(x) is the logistic sigmoid for (int i = 0; i < hidden_dim; i++) { @@ -316,7 +367,7 @@ void transformer(int token, int pos, 
Config* p, RunState* s, TransformerWeights* } // final matmul to get the output of the ffn - matmul(s->xb, s->hb, w->w2 + l*dim*hidden_dim, hidden_dim, dim); + matmul(s->xb, s->hb, w->w2.q + l*dim*hidden_dim, w->w2.s + l*dim*hidden_dim/GS, hidden_dim, dim); // residual connection for (int i = 0; i < dim; i++) { @@ -328,7 +379,7 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights* rmsnorm(x, x, w->rms_final_weight, dim); // classifier into logits - matmul(s->logits, x, w->wcls, p->dim, p->vocab_size); + matmul(s->logits, x, w->wcls.q, w->wcls.s, dim, p->vocab_size); } // ---------------------------------------------------------------------------- @@ -597,35 +648,51 @@ int main(int argc, char *argv[]) { else if (argv[i][1] == 'z') { tokenizer = argv[i + 1]; } else { error_usage(); } } - if(rng_seed == 0) { rng_seed = (unsigned int)time(NULL);} + // input validations + // our rng cannot accept 0 as a seed, so might as well use time(NULL) as default + if(rng_seed == 0) { rng_seed = (unsigned int)time(NULL); } // read in the model.bin file Config config; TransformerWeights weights; int fd = 0; // file descriptor for memory mapping - float* data = NULL; // memory mapped data pointer - ssize_t file_size; // size of the checkpoint file in bytes + void* data = NULL; // memory mapped data pointer + ssize_t file_size; // size of the checkpoint file in bytes { + // first "peek" the checkpoint and extract metadata FILE *file = fopen(checkpoint, "rb"); if (!file) { fprintf(stderr, "Couldn't open file %s\n", checkpoint); return 1; } - // read in the config header + // read in magic number (uint32), has to be 0x616b3432, i.e. 
"ak42" in ASCII + uint32_t magic_number; + if (fread(&magic_number, sizeof(uint32_t), 1, file) != 1) { return 1; } + if (magic_number != 0x616b3432) { fprintf(stderr, "Bad magic number\n"); return 1; } + // read in the version number (int), has to be 1 + int version; + if (fread(&version, sizeof(int), 1, file) != 1) { return 1; } + if (version != 1) { fprintf(stderr, "Bad version number\n"); return 1; } + int header_size = 256; // the header size for version 1 in bytes + // read in the Config if (fread(&config, sizeof(Config), 1, file) != 1) { return 1; } - // negative vocab size is hacky way of signaling unshared weights. bit yikes. - int shared_weights = config.vocab_size > 0 ? 1 : 0; - config.vocab_size = abs(config.vocab_size); - // figure out the file size + // read in flags + uint8_t shared_classifier; // a byte to indicate if the classifier is shared + if (fread(&shared_classifier, sizeof(uint8_t), 1, file) != 1) { return 1; } + int group_size; // the group size used in quantization + if (fread(&group_size, sizeof(int), 1, file) != 1) { return 1; } + GS = group_size; // set as global, as it will be used in many places + // seek all the way to the end to figure out the full file size fseek(file, 0, SEEK_END); // move file pointer to end of file file_size = ftell(file); // get the file size, in bytes fclose(file); - // memory map the Transformer weights into the data pointer + + // now memory map the Transformer weights into the data pointer fd = open(checkpoint, O_RDONLY); // open in read only mode if (fd == -1) { fprintf(stderr, "open failed!\n"); return 1; } data = mmap(NULL, file_size, PROT_READ, MAP_PRIVATE, fd, 0); if (data == MAP_FAILED) { fprintf(stderr, "mmap failed!\n"); return 1; } - float* weights_ptr = data + sizeof(Config)/sizeof(float); - checkpoint_init_weights(&weights, &config, weights_ptr, shared_weights); + void* weights_ptr = (char*)data + header_size; // skip header bytes. 
char is 1 byte + checkpoint_init_weights(&weights, &config, weights_ptr, shared_classifier); } - // right now we cannot run for more than config.seq_len steps + // we should not run for more than config.seq_len steps if (steps <= 0 || steps > config.seq_len) { steps = config.seq_len; } // read in the tokenizer .bin file