ok this first version works but i don't think is ready to merge, have to think on more
This commit is contained in:
@@ -339,11 +339,16 @@ class Transformer(nn.Module):
|
||||
|
||||
return idx
|
||||
|
||||
def export(self, filepath='model.bin', group_size=64):
|
||||
def export(self, filepath='model.bin'):
|
||||
"""export the model weights in Q8_0 into .bin file to be read from C"""
|
||||
hidden_dim = self.layers[0].feed_forward.w1.weight.shape[0]
|
||||
out_file = open(filepath, 'wb')
|
||||
|
||||
# find the max group size that fits hidden_dim using backoff
|
||||
group_size = 64 # a good desired group size default
|
||||
while self.params.dim % group_size != 0:
|
||||
group_size //= 2
|
||||
print(f"using group size {group_size} for quantization")
|
||||
|
||||
def serialize_fp32(t):
|
||||
""" writes one fp32 tensor to file """
|
||||
d = t.detach().cpu().view(-1).numpy().astype(np.float32)
|
||||
@@ -392,6 +397,7 @@ class Transformer(nn.Module):
|
||||
nbytes += 4
|
||||
# 3) write the params, which will be 7 ints
|
||||
p = self.params
|
||||
hidden_dim = self.layers[0].feed_forward.w1.weight.shape[0]
|
||||
n_kv_heads = p.n_heads if p.n_kv_heads is None else p.n_kv_heads
|
||||
header = struct.pack('iiiiiii', p.dim, hidden_dim, p.n_layers, p.n_heads,
|
||||
n_kv_heads, p.vocab_size, p.max_seq_len)
|
||||
|
||||
@@ -18,7 +18,7 @@
|
||||
// ----------------------------------------------------------------------------
|
||||
// Globals
|
||||
|
||||
int GS = 0; // group size global for quantization
|
||||
int GS = 0; // group size global for quantization of weights
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Transformer and RunState structs, and related memory management
|
||||
@@ -71,6 +71,8 @@ typedef struct {
|
||||
float *xb2; // an additional buffer just for convenience (dim,)
|
||||
float *hb; // buffer for hidden dimension in the ffn (hidden_dim,)
|
||||
float *hb2; // buffer for hidden dimension in the ffn (hidden_dim,)
|
||||
QuantizedTensor xq; // quantized x (dim,)
|
||||
QuantizedTensor hq; // quantized hb (hidden_dim,)
|
||||
float *q; // query (dim,)
|
||||
float *k; // key (dim,)
|
||||
float *v; // value (dim,)
|
||||
@@ -90,6 +92,8 @@ void malloc_run_state(RunState* s, Config* p) {
|
||||
s->xb2 = calloc(p->dim, sizeof(float));
|
||||
s->hb = calloc(p->hidden_dim, sizeof(float));
|
||||
s->hb2 = calloc(p->hidden_dim, sizeof(float));
|
||||
s->xq = (QuantizedTensor) { .q = calloc(p->dim, sizeof(int8_t)), .s = calloc(p->dim, sizeof(float)) };
|
||||
s->hq = (QuantizedTensor) { .q = calloc(p->hidden_dim, sizeof(int8_t)), .s = calloc(p->hidden_dim, sizeof(float)) };
|
||||
s->q = calloc(p->dim, sizeof(float));
|
||||
s->k = calloc(kv_dim, sizeof(float));
|
||||
s->v = calloc(kv_dim, sizeof(float));
|
||||
@@ -113,6 +117,10 @@ void free_run_state(RunState* s) {
|
||||
free(s->xb2);
|
||||
free(s->hb);
|
||||
free(s->hb2);
|
||||
free(s->xq.q);
|
||||
free(s->xq.s);
|
||||
free(s->hq.q);
|
||||
free(s->hq.s);
|
||||
free(s->q);
|
||||
free(s->k);
|
||||
free(s->v);
|
||||
@@ -227,20 +235,29 @@ void softmax(float* x, int size) {
|
||||
}
|
||||
}
|
||||
|
||||
void matmul(float* xout, float* x, int8_t* q, float* s, int n, int d) {
|
||||
void matmul(float* xout, int8_t* xq, float* xs, int8_t* wq, float* ws, int n, int d) {
|
||||
// W (d,n) @ x (n,) -> xout (d,)
|
||||
// by far the most amount of time is spent inside this little function
|
||||
// inputs to this function are both quantized
|
||||
|
||||
// do the matmul
|
||||
int i;
|
||||
#pragma omp parallel for private(i)
|
||||
for (i = 0; i < d; i++) {
|
||||
|
||||
float val = 0.0f;
|
||||
for (int j = 0; j < n; j++) {
|
||||
int ix = i * n + j;
|
||||
float wij = q[ix] * s[ix / GS];
|
||||
val += wij * x[j];
|
||||
int32_t ival = 0;
|
||||
int in = i * n;
|
||||
|
||||
// do the matmul in groups of GS
|
||||
int j;
|
||||
for (j = 0; j <= n - GS; j += GS) {
|
||||
for (int k = 0; k < GS; k++) {
|
||||
ival += ((int32_t) xq[j + k]) * ((int32_t) wq[in + j + k]);
|
||||
}
|
||||
val += ((float) ival) * ws[(in + j) / GS] * xs[j / GS];
|
||||
ival = 0;
|
||||
}
|
||||
|
||||
xout[i] = val;
|
||||
}
|
||||
}
|
||||
@@ -251,6 +268,34 @@ void dequantize(int8_t* q, float* s, float* x, int n) {
|
||||
}
|
||||
}
|
||||
|
||||
void quantize(float* x, int8_t* q, float* s, int n) {
|
||||
int num_groups = n / GS;
|
||||
float Q_MAX = 127.0f;
|
||||
|
||||
for (int group = 0; group < num_groups; group++) {
|
||||
|
||||
// find the max absolute value in the current group
|
||||
float wmax = 0.0;
|
||||
for (int i = 0; i < GS; i++) {
|
||||
float val = fabs(x[group * GS + i]);
|
||||
if (val > wmax) {
|
||||
wmax = val;
|
||||
}
|
||||
}
|
||||
|
||||
// calculate and write the scaling factor
|
||||
float scale = wmax / Q_MAX;
|
||||
s[group] = scale;
|
||||
|
||||
// calculate and write the quantized values
|
||||
for (int i = 0; i < GS; i++) {
|
||||
float quant_value = x[group * GS + i] / scale; // scale
|
||||
int8_t quantized = (int8_t) round(quant_value); // round and clamp
|
||||
q[group * GS + i] = quantized;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights* w) {
|
||||
|
||||
// a few convenience variables
|
||||
@@ -272,9 +317,10 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights*
|
||||
rmsnorm(s->xb, x, w->rms_att_weight + l*dim, dim);
|
||||
|
||||
// qkv matmuls for this position
|
||||
matmul(s->q, s->xb, w->wq.q + l*dim*dim, w->wq.s + l*dim*dim/GS, dim, dim);
|
||||
matmul(s->k, s->xb, w->wk.q + l*dim*kv_dim, w->wk.s + l*dim*kv_dim/GS, dim, kv_dim);
|
||||
matmul(s->v, s->xb, w->wv.q + l*dim*kv_dim, w->wv.s + l*dim*kv_dim/GS, dim, kv_dim);
|
||||
quantize(s->xb, s->xq.q, s->xq.s, dim);
|
||||
matmul(s->q, s->xq.q, s->xq.s, w->wq.q + l*dim*dim, w->wq.s + l*dim*dim/GS, dim, dim);
|
||||
matmul(s->k, s->xq.q, s->xq.s, w->wk.q + l*dim*kv_dim, w->wk.s + l*dim*kv_dim/GS, dim, kv_dim);
|
||||
matmul(s->v, s->xq.q, s->xq.s, w->wv.q + l*dim*kv_dim, w->wv.s + l*dim*kv_dim/GS, dim, kv_dim);
|
||||
|
||||
// RoPE relative positional encoding: complex-valued rotate q and k in each head
|
||||
for (int i = 0; i < dim; i+=2) {
|
||||
@@ -341,7 +387,8 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights*
|
||||
}
|
||||
|
||||
// final matmul to get the output of the attention
|
||||
matmul(s->xb2, s->xb, w->wo.q + l*dim*dim, w->wo.s + l*dim*dim/GS, dim, dim);
|
||||
quantize(s->xb, s->xq.q, s->xq.s, dim);
|
||||
matmul(s->xb2, s->xq.q, s->xq.s, w->wo.q + l*dim*dim, w->wo.s + l*dim*dim/GS, dim, dim);
|
||||
|
||||
// residual connection back into x
|
||||
for (int i = 0; i < dim; i++) {
|
||||
@@ -353,8 +400,9 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights*
|
||||
|
||||
// Now for FFN in PyTorch we have: self.w2(F.silu(self.w1(x)) * self.w3(x))
|
||||
// first calculate self.w1(x) and self.w3(x)
|
||||
matmul(s->hb, s->xb, w->w1.q + l*dim*hidden_dim, w->w1.s + l*dim*hidden_dim/GS, dim, hidden_dim);
|
||||
matmul(s->hb2, s->xb, w->w3.q + l*dim*hidden_dim, w->w3.s + l*dim*hidden_dim/GS, dim, hidden_dim);
|
||||
quantize(s->xb, s->xq.q, s->xq.s, dim);
|
||||
matmul(s->hb, s->xq.q, s->xq.s, w->w1.q + l*dim*hidden_dim, w->w1.s + l*dim*hidden_dim/GS, dim, hidden_dim);
|
||||
matmul(s->hb2, s->xq.q, s->xq.s, w->w3.q + l*dim*hidden_dim, w->w3.s + l*dim*hidden_dim/GS, dim, hidden_dim);
|
||||
|
||||
// F.silu; silu(x)=x*σ(x),where σ(x) is the logistic sigmoid
|
||||
for (int i = 0; i < hidden_dim; i++) {
|
||||
@@ -367,7 +415,8 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights*
|
||||
}
|
||||
|
||||
// final matmul to get the output of the ffn
|
||||
matmul(s->xb, s->hb, w->w2.q + l*dim*hidden_dim, w->w2.s + l*dim*hidden_dim/GS, hidden_dim, dim);
|
||||
quantize(s->hb, s->hq.q, s->hq.s, hidden_dim);
|
||||
matmul(s->xb, s->hq.q, s->hq.s, w->w2.q + l*dim*hidden_dim, w->w2.s + l*dim*hidden_dim/GS, hidden_dim, dim);
|
||||
|
||||
// residual connection
|
||||
for (int i = 0; i < dim; i++) {
|
||||
@@ -379,7 +428,8 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights*
|
||||
rmsnorm(x, x, w->rms_final_weight, dim);
|
||||
|
||||
// classifier into logits
|
||||
matmul(s->logits, x, w->wcls.q, w->wcls.s, dim, p->vocab_size);
|
||||
quantize(x, s->xq.q, s->xq.s, dim);
|
||||
matmul(s->logits, s->xq.q, s->xq.s, w->wcls.q, w->wcls.s, dim, p->vocab_size);
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
Reference in New Issue
Block a user