diff --git a/run.c b/run.c index df95e6f..56ceff5 100644 --- a/run.c +++ b/run.c @@ -89,8 +89,8 @@ void malloc_run_state(RunState* s, Config* p) { s->hb = calloc(p->hidden_dim, sizeof(float)); s->hb2 = calloc(p->hidden_dim, sizeof(float)); s->q = calloc(p->dim, sizeof(float)); - s->k = calloc(p->dim, sizeof(float)); - s->v = calloc(p->dim, sizeof(float)); + s->k = calloc(kv_dim, sizeof(float)); + s->v = calloc(kv_dim, sizeof(float)); s->att = calloc(p->n_heads * p->seq_len, sizeof(float)); s->logits = calloc(p->vocab_size, sizeof(float)); s->probindex = calloc(p->vocab_size, sizeof(ProbIndex));