Compare commits
4 Commits
master
...
feature/int8
| Author | SHA1 | Date | |
|---|---|---|---|
| 039a9713c2 | |||
| 591f1353c7 | |||
| e9cbe3e84f | |||
| 5e2e5b28f4 |
@@ -340,53 +340,125 @@ class Transformer(nn.Module):
|
||||
return idx
|
||||
|
||||
def export(self, filepath='model.bin'):
|
||||
"""export the model weights in fp32 into .bin file to be read from C"""
|
||||
f = open(filepath, 'wb')
|
||||
"""export the model weights in Q8_0 into .bin file to be read from C"""
|
||||
out_file = open(filepath, 'wb')
|
||||
|
||||
def serialize(t):
|
||||
# find the max group size that fits hidden_dim using backoff
|
||||
group_size = 64 # a good desired group size default
|
||||
while self.params.dim % group_size != 0:
|
||||
group_size //= 2
|
||||
print(f"using group size {group_size} for quantization")
|
||||
|
||||
def serialize_fp32(t):
|
||||
""" writes one fp32 tensor to file """
|
||||
d = t.detach().cpu().view(-1).numpy().astype(np.float32)
|
||||
b = struct.pack(f'{len(d)}f', *d)
|
||||
f.write(b)
|
||||
out_file.write(b)
|
||||
|
||||
# first write out the header
|
||||
hidden_dim = self.layers[0].feed_forward.w1.weight.shape[0]
|
||||
def serialize_int8(t):
    """
    Write one tensor to the output file as raw int8 values.

    The tensor is detached, moved to CPU, flattened and cast to int8; the
    resulting bytes (one per element, in flattened order) are appended to
    `out_file`, which comes from the enclosing scope.
    """
    d = t.detach().cpu().view(-1).numpy().astype(np.int8)
    # write the array's buffer directly: same bytes as struct.pack(f'{len(d)}b', *d),
    # without expanding every element into a separate Python argument first
    out_file.write(d.tobytes())
|
||||
|
||||
def quantize_q80(w):
    """
    Quantize a tensor to Q8_0: symmetric int8 in [-127, 127] with one fp32
    scale per group of `group_size` elements (taken from the enclosing scope).

    The tensor is flattened into rows of `group_size` values; each row gets a
    scale chosen so that float ~= quant * scale.

    Returns a tuple (int8val, scale, maxerr):
      int8val: torch.int8 tensor of shape (numel // group_size, group_size)
      scale:   float32 tensor of shape (numel // group_size,); 0 for an all-zero group
      maxerr:  largest absolute difference between the dequantized and the
               original values, as a Python float
    """
    assert w.numel() % group_size == 0
    w = w.float()  # convert to float32
    w = w.reshape(-1, group_size)
    # find the max in each group
    wmax = torch.abs(w).max(dim=1).values
    # calculate the scaling factor such that float = quant * scale
    scale = wmax / 127.0
    # an all-zero group has scale == 0: dividing by it would give 0/0 = NaN, and
    # the result of casting NaN to int8 is not something we can rely on. Divide
    # such groups by 1 instead so they quantize to exact zeros; the scale that is
    # returned (and later written to file) stays 0.
    divisor = torch.where(scale > 0, scale, torch.ones_like(scale))
    # scale into range [-127, 127]
    quant = w / divisor[:, None]
    # round to nearest integer
    int8val = torch.round(quant).to(torch.int8)
    # dequantize by rescaling
    fp32val = int8val.float() * scale[:, None]
    # calculate the max error in each group
    err = torch.abs(fp32val - w).max(dim=1).values
    # find the max error across all groups
    maxerr = err.max().item()
    return int8val, scale, maxerr
|
||||
|
||||
# first write out the header. the header will be 256 bytes
|
||||
nbytes = 0
|
||||
# 1) write magic, which will be uint32 of "ak42" in ASCII
|
||||
out_file.write(struct.pack('I', 0x616b3432))
|
||||
nbytes += 4
|
||||
# 2) write version, which will be int
|
||||
out_file.write(struct.pack('i', 1))
|
||||
nbytes += 4
|
||||
# 3) write the params, which will be 7 ints
|
||||
p = self.params
|
||||
hidden_dim = self.layers[0].feed_forward.w1.weight.shape[0]
|
||||
n_kv_heads = p.n_heads if p.n_kv_heads is None else p.n_kv_heads
|
||||
header = struct.pack('iiiiiii', p.dim, hidden_dim, p.n_layers, p.n_heads,
|
||||
n_kv_heads, p.vocab_size, p.max_seq_len)
|
||||
f.write(header)
|
||||
out_file.write(header)
|
||||
nbytes += 7*4
|
||||
# 4) write some other flags
|
||||
shared_classifier = 1 # we do share a classifier, write flag as a byte
|
||||
out_file.write(struct.pack('B', shared_classifier))
|
||||
nbytes += 1
|
||||
out_file.write(struct.pack('i', group_size)) # group size used for quantization
|
||||
nbytes += 4
|
||||
pad = 256 - nbytes # pad the rest with zeros
|
||||
assert pad >= 0
|
||||
out_file.write(b'\0' * pad)
|
||||
# now that the header is done, let's write out the model
|
||||
|
||||
# next write out the embedding weights
|
||||
serialize(self.tok_embeddings.weight)
|
||||
# first let's write out all the params that we are keeping in fp32: the norms
|
||||
for layer in self.layers: # attention norms
|
||||
serialize_fp32(layer.attention_norm.weight)
|
||||
for layer in self.layers: # MLP norms
|
||||
serialize_fp32(layer.ffn_norm.weight)
|
||||
serialize_fp32(self.norm.weight) # final pre-classifier norm
|
||||
|
||||
# now all the layers
|
||||
# attention weights
|
||||
for layer in self.layers:
|
||||
serialize(layer.attention_norm.weight)
|
||||
for layer in self.layers:
|
||||
serialize(layer.attention.wq.weight)
|
||||
for layer in self.layers:
|
||||
serialize(layer.attention.wk.weight)
|
||||
for layer in self.layers:
|
||||
serialize(layer.attention.wv.weight)
|
||||
for layer in self.layers:
|
||||
serialize(layer.attention.wo.weight)
|
||||
# ffn weights
|
||||
for layer in self.layers:
|
||||
serialize(layer.ffn_norm.weight)
|
||||
for layer in self.layers:
|
||||
serialize(layer.feed_forward.w1.weight)
|
||||
for layer in self.layers:
|
||||
serialize(layer.feed_forward.w2.weight)
|
||||
for layer in self.layers:
|
||||
serialize(layer.feed_forward.w3.weight)
|
||||
# final rmsnorm
|
||||
serialize(self.norm.weight)
|
||||
# note: no need to write final classifier weights due to weight sharing
|
||||
# freqs_cis
|
||||
serialize(self.freqs_cos[:p.max_seq_len])
|
||||
serialize(self.freqs_sin[:p.max_seq_len])
|
||||
# now let's write out all the params that we are quantizing to Q8_0
|
||||
# note we skip classifier weights, which are shared with the embedding
|
||||
weights = [
|
||||
self.tok_embeddings.weight,
|
||||
*[layer.attention.wq.weight for layer in self.layers],
|
||||
*[layer.attention.wk.weight for layer in self.layers],
|
||||
*[layer.attention.wv.weight for layer in self.layers],
|
||||
*[layer.attention.wo.weight for layer in self.layers],
|
||||
*[layer.feed_forward.w1.weight for layer in self.layers],
|
||||
*[layer.feed_forward.w2.weight for layer in self.layers],
|
||||
*[layer.feed_forward.w3.weight for layer in self.layers],
|
||||
]
|
||||
|
||||
ew = []
|
||||
scales = []
|
||||
for i, w in enumerate(weights):
|
||||
assert w.numel() % group_size == 0, f"weight {i} has numel {w.numel()}, not a multiple of group_size {group_size}"
|
||||
|
||||
# quantize this weight
|
||||
q, s, err = quantize_q80(w)
|
||||
|
||||
# save to file
|
||||
serialize_int8(q) # save the tensor in int8
|
||||
scales.append(s) # we'll do all the scales after all the qs
|
||||
|
||||
# logging
|
||||
ew.append((err, w.shape))
|
||||
print(f"{i+1}/{len(weights)} quantized {tuple(w.shape)} to Q8_0 with max error {err}")
|
||||
|
||||
# save the scaling factors in fp32 here
|
||||
# this is done to keep all the weights contiguous, making pointer arithmetic easier in C
|
||||
for s in scales:
|
||||
serialize_fp32(s)
|
||||
|
||||
# print the highest error across all weights, should be very small, e.g. O(~0.001)
|
||||
ew.sort(reverse=True)
|
||||
print(f"max quantization group error across all weights: {ew[0][0]}")
|
||||
|
||||
# write to binary file
|
||||
f.close()
|
||||
out_file.close()
|
||||
print(f"wrote {filepath}")
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
/* Inference for Llama-2 Transformer model in pure C */
|
||||
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <ctype.h>
|
||||
@@ -13,41 +14,49 @@
|
||||
#include <unistd.h>
|
||||
#include <sys/mman.h>
|
||||
#endif
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Globals
|
||||
|
||||
int GS = 0; // group size global for quantization of weights
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Transformer and RunState structs, and related memory management
|
||||
|
||||
typedef struct {
|
||||
int dim; // transformer dimension
|
||||
int hidden_dim; // for ffn layers
|
||||
int n_layers; // number of layers
|
||||
int n_heads; // number of query heads
|
||||
int n_kv_heads; // number of key/value heads (can be < query heads because of multiquery)
|
||||
int vocab_size; // vocabulary size, usually 256 (byte-level)
|
||||
int seq_len; // max sequence length
|
||||
int dim; // transformer dimension
|
||||
int hidden_dim; // dimension of the inner layer in the MLP
|
||||
int n_layers; // number of layers
|
||||
int n_heads; // number of query heads
|
||||
int n_kv_heads; // number of key & value heads (can be < query heads because of multiquery)
|
||||
int vocab_size; // vocabulary size (size of the classifier weights)
|
||||
int seq_len; // max sequence length the model was trained with
|
||||
} Config;
|
||||
|
||||
typedef struct {
|
||||
int8_t* q; // quantized values
|
||||
float* s; // scaling factors
|
||||
} QuantizedTensor;
|
||||
|
||||
typedef struct {
|
||||
// token embedding table
|
||||
float* token_embedding_table; // (vocab_size, dim)
|
||||
QuantizedTensor token_embedding_table; // (vocab_size, dim)
|
||||
// weights for rmsnorms
|
||||
float* rms_att_weight; // (layer, dim) rmsnorm weights
|
||||
float* rms_ffn_weight; // (layer, dim)
|
||||
// weights for matmuls. note dim == n_heads * head_size
|
||||
float* wq; // (layer, dim, n_heads * head_size)
|
||||
float* wk; // (layer, dim, n_kv_heads * head_size)
|
||||
float* wv; // (layer, dim, n_kv_heads * head_size)
|
||||
float* wo; // (layer, n_heads * head_size, dim)
|
||||
QuantizedTensor wq; // (layer, dim, n_heads * head_size)
|
||||
QuantizedTensor wk; // (layer, dim, n_kv_heads * head_size)
|
||||
QuantizedTensor wv; // (layer, dim, n_kv_heads * head_size)
|
||||
QuantizedTensor wo; // (layer, n_heads * head_size, dim)
|
||||
// weights for ffn
|
||||
float* w1; // (layer, hidden_dim, dim)
|
||||
float* w2; // (layer, dim, hidden_dim)
|
||||
float* w3; // (layer, hidden_dim, dim)
|
||||
QuantizedTensor w1; // (layer, hidden_dim, dim)
|
||||
QuantizedTensor w2; // (layer, dim, hidden_dim)
|
||||
QuantizedTensor w3; // (layer, hidden_dim, dim)
|
||||
// final rmsnorm
|
||||
float* rms_final_weight; // (dim,)
|
||||
// freq_cis for RoPE relative positional embeddings (not used anymore)
|
||||
float* freq_cis_real; // (seq_len, head_size/2)
|
||||
float* freq_cis_imag; // (seq_len, head_size/2)
|
||||
// (optional) classifier weights for the logits, on the last layer
|
||||
float* wcls;
|
||||
QuantizedTensor wcls; // (dim, vocab_size)
|
||||
} TransformerWeights;
|
||||
|
||||
typedef struct {
|
||||
@@ -62,6 +71,8 @@ typedef struct {
|
||||
float *xb2; // an additional buffer just for convenience (dim,)
|
||||
float *hb; // buffer for hidden dimension in the ffn (hidden_dim,)
|
||||
float *hb2; // buffer for hidden dimension in the ffn (hidden_dim,)
|
||||
QuantizedTensor xq; // quantized x (dim,)
|
||||
QuantizedTensor hq; // quantized hb (hidden_dim,)
|
||||
float *q; // query (dim,)
|
||||
float *k; // key (dim,)
|
||||
float *v; // value (dim,)
|
||||
@@ -81,6 +92,8 @@ void malloc_run_state(RunState* s, Config* p) {
|
||||
s->xb2 = calloc(p->dim, sizeof(float));
|
||||
s->hb = calloc(p->hidden_dim, sizeof(float));
|
||||
s->hb2 = calloc(p->hidden_dim, sizeof(float));
|
||||
s->xq = (QuantizedTensor) { .q = calloc(p->dim, sizeof(int8_t)), .s = calloc(p->dim, sizeof(float)) };
|
||||
s->hq = (QuantizedTensor) { .q = calloc(p->hidden_dim, sizeof(int8_t)), .s = calloc(p->hidden_dim, sizeof(float)) };
|
||||
s->q = calloc(p->dim, sizeof(float));
|
||||
s->k = calloc(kv_dim, sizeof(float));
|
||||
s->v = calloc(kv_dim, sizeof(float));
|
||||
@@ -104,6 +117,10 @@ void free_run_state(RunState* s) {
|
||||
free(s->xb2);
|
||||
free(s->hb);
|
||||
free(s->hb2);
|
||||
free(s->xq.q);
|
||||
free(s->xq.s);
|
||||
free(s->hq.q);
|
||||
free(s->hq.s);
|
||||
free(s->q);
|
||||
free(s->k);
|
||||
free(s->v);
|
||||
@@ -117,35 +134,67 @@ void free_run_state(RunState* s) {
|
||||
// ----------------------------------------------------------------------------
|
||||
// initialization: read from checkpoint
|
||||
|
||||
void checkpoint_init_weights(TransformerWeights *w, Config* p, float* ptr, int shared_weights) {
|
||||
void checkpoint_init_weights(TransformerWeights *w, Config* p, void* ptr, uint8_t shared_classifier) {
|
||||
int head_size = p->dim / p->n_heads;
|
||||
w->token_embedding_table = ptr;
|
||||
ptr += p->vocab_size * p->dim;
|
||||
w->rms_att_weight = ptr;
|
||||
ptr += p->n_layers * p->dim;
|
||||
w->wq = ptr;
|
||||
ptr += p->n_layers * p->dim * (p->n_heads * head_size);
|
||||
w->wk = ptr;
|
||||
ptr += p->n_layers * p->dim * (p->n_kv_heads * head_size);
|
||||
w->wv = ptr;
|
||||
ptr += p->n_layers * p->dim * (p->n_kv_heads * head_size);
|
||||
w->wo = ptr;
|
||||
ptr += p->n_layers * (p->n_heads * head_size) * p->dim;
|
||||
w->rms_ffn_weight = ptr;
|
||||
ptr += p->n_layers * p->dim;
|
||||
w->w1 = ptr;
|
||||
ptr += p->n_layers * p->dim * p->hidden_dim;
|
||||
w->w2 = ptr;
|
||||
ptr += p->n_layers * p->hidden_dim * p->dim;
|
||||
w->w3 = ptr;
|
||||
ptr += p->n_layers * p->dim * p->hidden_dim;
|
||||
w->rms_final_weight = ptr;
|
||||
ptr += p->dim;
|
||||
w->freq_cis_real = ptr;
|
||||
ptr += p->seq_len * head_size / 2;
|
||||
w->freq_cis_imag = ptr;
|
||||
ptr += p->seq_len * head_size / 2;
|
||||
w->wcls = shared_weights ? w->token_embedding_table : ptr;
|
||||
|
||||
// first are the parameters that are kept in fp32 (the rmsnorm (1D) weights)
|
||||
float* fptr = (float*) ptr; // cast our pointer to float*
|
||||
w->rms_att_weight = fptr;
|
||||
fptr += p->n_layers * p->dim;
|
||||
w->rms_ffn_weight = fptr;
|
||||
fptr += p->n_layers * p->dim;
|
||||
w->rms_final_weight = fptr;
|
||||
fptr += p->dim;
|
||||
|
||||
// now read all the quantized weights
|
||||
int8_t* qptr = (int8_t*) fptr; // now cast the pointer to int8_t*
|
||||
w->token_embedding_table.q = qptr;
|
||||
qptr += p->vocab_size * p->dim;
|
||||
w->wq.q = qptr;
|
||||
qptr += p->n_layers * p->dim * (p->n_heads * head_size);
|
||||
w->wk.q = qptr;
|
||||
qptr += p->n_layers * p->dim * (p->n_kv_heads * head_size);
|
||||
w->wv.q = qptr;
|
||||
qptr += p->n_layers * p->dim * (p->n_kv_heads * head_size);
|
||||
w->wo.q = qptr;
|
||||
qptr += p->n_layers * (p->n_heads * head_size) * p->dim;
|
||||
w->w1.q = qptr;
|
||||
qptr += p->n_layers * p->dim * p->hidden_dim;
|
||||
w->w2.q = qptr;
|
||||
qptr += p->n_layers * p->hidden_dim * p->dim;
|
||||
w->w3.q = qptr;
|
||||
qptr += p->n_layers * p->dim * p->hidden_dim;
|
||||
if (shared_classifier) {
|
||||
w->wcls.q = w->token_embedding_table.q;
|
||||
} else {
|
||||
w->wcls.q = qptr;
|
||||
qptr += p->dim * p->vocab_size;
|
||||
}
|
||||
|
||||
// and finally all the associated scaling factors
|
||||
float* sptr = (float*) qptr; // cast pointer back to float*
|
||||
w->token_embedding_table.s = sptr;
|
||||
sptr += p->vocab_size * p->dim / GS;
|
||||
w->wq.s = sptr;
|
||||
sptr += p->n_layers * p->dim * (p->n_heads * head_size) / GS;
|
||||
w->wk.s = sptr;
|
||||
sptr += p->n_layers * p->dim * (p->n_kv_heads * head_size) / GS;
|
||||
w->wv.s = sptr;
|
||||
sptr += p->n_layers * p->dim * (p->n_kv_heads * head_size) / GS;
|
||||
w->wo.s = sptr;
|
||||
sptr += p->n_layers * (p->n_heads * head_size) * p->dim / GS;
|
||||
w->w1.s = sptr;
|
||||
sptr += p->n_layers * p->dim * p->hidden_dim / GS;
|
||||
w->w2.s = sptr;
|
||||
sptr += p->n_layers * p->hidden_dim * p->dim / GS;
|
||||
w->w3.s = sptr;
|
||||
sptr += p->n_layers * p->dim * p->hidden_dim / GS;
|
||||
if (shared_classifier) {
|
||||
w->wcls.s = w->token_embedding_table.s;
|
||||
} else {
|
||||
w->wcls.s = sptr;
|
||||
sptr += p->dim * p->vocab_size / GS;
|
||||
}
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
@@ -186,20 +235,67 @@ void softmax(float* x, int size) {
|
||||
}
|
||||
}
|
||||
|
||||
void matmul(float* xout, float* x, float* w, int n, int d) {
|
||||
void matmul(float* xout, int8_t* xq, float* xs, int8_t* wq, float* ws, int n, int d) {
|
||||
// W (d,n) @ x (n,) -> xout (d,)
|
||||
// by far the most amount of time is spent inside this little function
|
||||
// inputs to this function are both quantized
|
||||
|
||||
int i;
|
||||
#pragma omp parallel for private(i)
|
||||
for (i = 0; i < d; i++) {
|
||||
|
||||
float val = 0.0f;
|
||||
for (int j = 0; j < n; j++) {
|
||||
val += w[i * n + j] * x[j];
|
||||
int32_t ival = 0;
|
||||
int in = i * n;
|
||||
|
||||
// do the matmul in groups of GS
|
||||
int j;
|
||||
for (j = 0; j <= n - GS; j += GS) {
|
||||
for (int k = 0; k < GS; k++) {
|
||||
ival += ((int32_t) xq[j + k]) * ((int32_t) wq[in + j + k]);
|
||||
}
|
||||
val += ((float) ival) * ws[(in + j) / GS] * xs[j / GS];
|
||||
ival = 0;
|
||||
}
|
||||
|
||||
xout[i] = val;
|
||||
}
|
||||
}
|
||||
|
||||
void dequantize(int8_t* q, float* s, float* x, int n) {
|
||||
for (int i = 0; i < n; i++) {
|
||||
x[i] = q[i] * s[i / GS];
|
||||
}
|
||||
}
|
||||
|
||||
void quantize(float* x, int8_t* q, float* s, int n) {
|
||||
int num_groups = n / GS;
|
||||
float Q_MAX = 127.0f;
|
||||
|
||||
for (int group = 0; group < num_groups; group++) {
|
||||
|
||||
// find the max absolute value in the current group
|
||||
float wmax = 0.0;
|
||||
for (int i = 0; i < GS; i++) {
|
||||
float val = fabs(x[group * GS + i]);
|
||||
if (val > wmax) {
|
||||
wmax = val;
|
||||
}
|
||||
}
|
||||
|
||||
// calculate and write the scaling factor
|
||||
float scale = wmax / Q_MAX;
|
||||
s[group] = scale;
|
||||
|
||||
// calculate and write the quantized values
|
||||
for (int i = 0; i < GS; i++) {
|
||||
float quant_value = x[group * GS + i] / scale; // scale
|
||||
int8_t quantized = (int8_t) round(quant_value); // round and clamp
|
||||
q[group * GS + i] = quantized;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights* w) {
|
||||
|
||||
// a few convenience variables
|
||||
@@ -210,9 +306,9 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights*
|
||||
int hidden_dim = p->hidden_dim;
|
||||
int head_size = dim / p->n_heads;
|
||||
|
||||
// copy the token embedding into x
|
||||
float* content_row = &(w->token_embedding_table[token * dim]);
|
||||
memcpy(x, content_row, dim*sizeof(*x));
|
||||
// dequantize the token embedding into a float x
|
||||
QuantizedTensor tok = w->token_embedding_table;
|
||||
dequantize(tok.q + token * dim, tok.s + token * dim / GS, x, dim);
|
||||
|
||||
// forward all the layers
|
||||
for(int l = 0; l < p->n_layers; l++) {
|
||||
@@ -221,9 +317,10 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights*
|
||||
rmsnorm(s->xb, x, w->rms_att_weight + l*dim, dim);
|
||||
|
||||
// qkv matmuls for this position
|
||||
matmul(s->q, s->xb, w->wq + l*dim*dim, dim, dim);
|
||||
matmul(s->k, s->xb, w->wk + l*dim*kv_dim, dim, kv_dim);
|
||||
matmul(s->v, s->xb, w->wv + l*dim*kv_dim, dim, kv_dim);
|
||||
quantize(s->xb, s->xq.q, s->xq.s, dim);
|
||||
matmul(s->q, s->xq.q, s->xq.s, w->wq.q + l*dim*dim, w->wq.s + l*dim*dim/GS, dim, dim);
|
||||
matmul(s->k, s->xq.q, s->xq.s, w->wk.q + l*dim*kv_dim, w->wk.s + l*dim*kv_dim/GS, dim, kv_dim);
|
||||
matmul(s->v, s->xq.q, s->xq.s, w->wv.q + l*dim*kv_dim, w->wv.s + l*dim*kv_dim/GS, dim, kv_dim);
|
||||
|
||||
// RoPE relative positional encoding: complex-valued rotate q and k in each head
|
||||
for (int i = 0; i < dim; i+=2) {
|
||||
@@ -290,7 +387,8 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights*
|
||||
}
|
||||
|
||||
// final matmul to get the output of the attention
|
||||
matmul(s->xb2, s->xb, w->wo + l*dim*dim, dim, dim);
|
||||
quantize(s->xb, s->xq.q, s->xq.s, dim);
|
||||
matmul(s->xb2, s->xq.q, s->xq.s, w->wo.q + l*dim*dim, w->wo.s + l*dim*dim/GS, dim, dim);
|
||||
|
||||
// residual connection back into x
|
||||
for (int i = 0; i < dim; i++) {
|
||||
@@ -302,8 +400,9 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights*
|
||||
|
||||
// Now for FFN in PyTorch we have: self.w2(F.silu(self.w1(x)) * self.w3(x))
|
||||
// first calculate self.w1(x) and self.w3(x)
|
||||
matmul(s->hb, s->xb, w->w1 + l*dim*hidden_dim, dim, hidden_dim);
|
||||
matmul(s->hb2, s->xb, w->w3 + l*dim*hidden_dim, dim, hidden_dim);
|
||||
quantize(s->xb, s->xq.q, s->xq.s, dim);
|
||||
matmul(s->hb, s->xq.q, s->xq.s, w->w1.q + l*dim*hidden_dim, w->w1.s + l*dim*hidden_dim/GS, dim, hidden_dim);
|
||||
matmul(s->hb2, s->xq.q, s->xq.s, w->w3.q + l*dim*hidden_dim, w->w3.s + l*dim*hidden_dim/GS, dim, hidden_dim);
|
||||
|
||||
// F.silu; silu(x)=x*σ(x),where σ(x) is the logistic sigmoid
|
||||
for (int i = 0; i < hidden_dim; i++) {
|
||||
@@ -316,7 +415,8 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights*
|
||||
}
|
||||
|
||||
// final matmul to get the output of the ffn
|
||||
matmul(s->xb, s->hb, w->w2 + l*dim*hidden_dim, hidden_dim, dim);
|
||||
quantize(s->hb, s->hq.q, s->hq.s, hidden_dim);
|
||||
matmul(s->xb, s->hq.q, s->hq.s, w->w2.q + l*dim*hidden_dim, w->w2.s + l*dim*hidden_dim/GS, hidden_dim, dim);
|
||||
|
||||
// residual connection
|
||||
for (int i = 0; i < dim; i++) {
|
||||
@@ -328,7 +428,8 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights*
|
||||
rmsnorm(x, x, w->rms_final_weight, dim);
|
||||
|
||||
// classifier into logits
|
||||
matmul(s->logits, x, w->wcls, p->dim, p->vocab_size);
|
||||
quantize(x, s->xq.q, s->xq.s, dim);
|
||||
matmul(s->logits, s->xq.q, s->xq.s, w->wcls.q, w->wcls.s, dim, p->vocab_size);
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
@@ -597,35 +698,51 @@ int main(int argc, char *argv[]) {
|
||||
else if (argv[i][1] == 'z') { tokenizer = argv[i + 1]; }
|
||||
else { error_usage(); }
|
||||
}
|
||||
if(rng_seed == 0) { rng_seed = (unsigned int)time(NULL);}
|
||||
// input validations
|
||||
// our rng cannot accept 0 as a seed, so might as well use time(NULL) as default
|
||||
if(rng_seed == 0) { rng_seed = (unsigned int)time(NULL); }
|
||||
|
||||
// read in the model.bin file
|
||||
Config config;
|
||||
TransformerWeights weights;
|
||||
int fd = 0; // file descriptor for memory mapping
|
||||
float* data = NULL; // memory mapped data pointer
|
||||
ssize_t file_size; // size of the checkpoint file in bytes
|
||||
void* data = NULL; // memory mapped data pointer
|
||||
ssize_t file_size; // size of the checkpoint file in bytes
|
||||
{
|
||||
// first "peek" at the checkpoint and extract metadata
|
||||
FILE *file = fopen(checkpoint, "rb");
|
||||
if (!file) { fprintf(stderr, "Couldn't open file %s\n", checkpoint); return 1; }
|
||||
// read in the config header
|
||||
// read in magic number (uint32), has to be 0x616b3432, i.e. "ak42" in ASCII
|
||||
uint32_t magic_number;
|
||||
if (fread(&magic_number, sizeof(uint32_t), 1, file) != 1) { return 1; }
|
||||
if (magic_number != 0x616b3432) { fprintf(stderr, "Bad magic number\n"); return 1; }
|
||||
// read in the version number (int), has to be 1
|
||||
int version;
|
||||
if (fread(&version, sizeof(int), 1, file) != 1) { return 1; }
|
||||
if (version != 1) { fprintf(stderr, "Bad version number\n"); return 1; }
|
||||
int header_size = 256; // the header size for version 1 in bytes
|
||||
// read in the Config
|
||||
if (fread(&config, sizeof(Config), 1, file) != 1) { return 1; }
|
||||
// negative vocab size is hacky way of signaling unshared weights. bit yikes.
|
||||
int shared_weights = config.vocab_size > 0 ? 1 : 0;
|
||||
config.vocab_size = abs(config.vocab_size);
|
||||
// figure out the file size
|
||||
// read in flags
|
||||
uint8_t shared_classifier; // a byte to indicate if the classifier is shared
|
||||
if (fread(&shared_classifier, sizeof(uint8_t), 1, file) != 1) { return 1; }
|
||||
int group_size; // the group size used in quantization
|
||||
if (fread(&group_size, sizeof(int), 1, file) != 1) { return 1; }
|
||||
GS = group_size; // set as global, as it will be used in many places
|
||||
// seek all the way to the end to figure out the full file size
|
||||
fseek(file, 0, SEEK_END); // move file pointer to end of file
|
||||
file_size = ftell(file); // get the file size, in bytes
|
||||
fclose(file);
|
||||
// memory map the Transformer weights into the data pointer
|
||||
|
||||
// now memory map the Transformer weights into the data pointer
|
||||
fd = open(checkpoint, O_RDONLY); // open in read only mode
|
||||
if (fd == -1) { fprintf(stderr, "open failed!\n"); return 1; }
|
||||
data = mmap(NULL, file_size, PROT_READ, MAP_PRIVATE, fd, 0);
|
||||
if (data == MAP_FAILED) { fprintf(stderr, "mmap failed!\n"); return 1; }
|
||||
float* weights_ptr = data + sizeof(Config)/sizeof(float);
|
||||
checkpoint_init_weights(&weights, &config, weights_ptr, shared_weights);
|
||||
void* weights_ptr = (char*)data + header_size; // skip header bytes. char is 1 byte
|
||||
checkpoint_init_weights(&weights, &config, weights_ptr, shared_classifier);
|
||||
}
|
||||
// right now we cannot run for more than config.seq_len steps
|
||||
// we should not run for more than config.seq_len steps
|
||||
if (steps <= 0 || steps > config.seq_len) { steps = config.seq_len; }
|
||||
|
||||
// read in the tokenizer .bin file
|
||||
|
||||
Reference in New Issue
Block a user