clean up swiglu a little bit
This commit is contained in:
@@ -339,14 +339,14 @@ float* forward(Transformer* transformer, int token, int pos) {
|
||||
matmul(s->hb, s->xb, w->w1 + l*dim*hidden_dim, dim, hidden_dim);
|
||||
matmul(s->hb2, s->xb, w->w3 + l*dim*hidden_dim, dim, hidden_dim);
|
||||
|
||||
// F.silu; silu(x)=x*σ(x),where σ(x) is the logistic sigmoid
|
||||
// SwiGLU non-linearity
|
||||
for (int i = 0; i < hidden_dim; i++) {
|
||||
s->hb[i] = s->hb[i] * (1.0f / (1.0f + expf(-s->hb[i])));
|
||||
}
|
||||
|
||||
// elementwise multiply with w3(x)
|
||||
for (int i = 0; i < hidden_dim; i++) {
|
||||
s->hb[i] = s->hb[i] * s->hb2[i];
|
||||
float val = s->hb[i];
|
||||
// silu(x)=x*σ(x), where σ(x) is the logistic sigmoid
|
||||
val *= (1.0f / (1.0f + expf(-val)));
|
||||
// elementwise multiply with w3(x)
|
||||
val *= s->hb2[i];
|
||||
s->hb[i] = val;
|
||||
}
|
||||
|
||||
// final matmul to get the output of the ffn
|
||||
|
||||
Reference in New Issue
Block a user