Files
llama2.c/run.c
T
2023-07-23 05:23:45 +00:00

493 lines
16 KiB
C
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
/*
Inference for Llama-2 Transformer model in pure C.
Compile simply with:
$ gcc -o run run.c
Or if that doesn't work then:
$ gcc -o run run.c -lm
Then run with:
$ ./run
*/
#include <stdio.h>
#include <stdlib.h>
#include <time.h>
#include <math.h>
// ----------------------------------------------------------------------------
// Transformer and RunState structs, and related memory management
/*
 * Model hyperparameters. main() fills this struct with a single
 * fread(&config, sizeof(Config), 1, file) from the start of the checkpoint,
 * so field order, count and types must match the checkpoint header exactly;
 * do not reorder or insert fields.
 */
typedef struct {
    int dim; // transformer (embedding) dimension
    int hidden_dim; // hidden dimension of the ffn layers
    int n_layers; // number of transformer layers
    int n_heads; // number of query heads; head_size = dim / n_heads
    int n_kv_heads; // number of key/value heads (can be < query heads because of multiquery); NOTE(review): not referenced anywhere in this file -- k/v are (dim, dim) projections, i.e. the code assumes n_kv_heads == n_heads
    int vocab_size; // vocabulary size: rows of token_embedding_table and length of the logits vector
    int seq_len; // max sequence length: capacity of the kv cache and the bound of main()'s generation loop
} Config;
/*
 * All model parameters, each stored as one flat row-major float array
 * (matmul() indexes W(d,n) as w[i * n + j]). malloc_weights allocates them;
 * checkpoint_init_weights fills them in checkpoint order, which differs from
 * the declaration order below (rms_ffn_weight is read after wo).
 */
typedef struct {
    // token embedding table; transformer() also reuses it as the output classifier
    float* token_embedding_table; // (vocab_size, dim)
    // weights for rmsnorms
    float* rms_att_weight; // (layer, dim) rmsnorm weights
    float* rms_ffn_weight; // (layer, dim)
    // weights for matmuls
    float* wq; // (layer, dim, dim)
    float* wk; // (layer, dim, dim)
    float* wv; // (layer, dim, dim)
    float* wo; // (layer, dim, dim)
    // weights for ffn
    float* w1; // (layer, hidden_dim, dim)
    float* w2; // (layer, dim, hidden_dim)
    float* w3; // (layer, hidden_dim, dim)
    // final rmsnorm
    float* rms_final_weight; // (dim,)
    // freq_cis for RoPE (rotary positional embeddings).
    // Allocated as seq_len * dim / 2 floats, but checkpoint_init_weights reads only
    // seq_len * head_size / 2 of them and transformer() indexes rows with stride
    // head_size / 2, so the layout actually used is (seq_len, head_size/2).
    float* freq_cis_real; // (seq_len, head_size/2) as used
    float* freq_cis_imag; // (seq_len, head_size/2) as used
} TransformerWeights;
/*
 * Scratch buffers for one forward pass plus the kv cache that persists across
 * positions. Allocated by malloc_run_state and released by free_run_state.
 */
typedef struct {
    // current wave of activations
    float *x; // activation at current time stamp (dim,)
    float *xb; // same, but inside a residual branch (dim,)
    float *xb2; // an additional buffer just for convenience (dim,)
    float *hb; // buffer for hidden dimension in the ffn (hidden_dim,)
    float *hb2; // buffer for hidden dimension in the ffn (hidden_dim,)
    float *q; // query (dim,)
    float *k; // key (dim,)
    float *v; // value (dim,)
    float *att; // buffer for scores/attention values (seq_len,); reused for every head
    float *logits; // output logits (vocab_size,); main() softmaxes them in place when temperature != 0
    // kv cache: transformer() writes row `pos` of every layer and attends over rows 0..pos
    float* key_cache; // (layer, seq_len, dim)
    float* value_cache; // (layer, seq_len, dim)
} RunState;
/*
 * Allocate the activation buffers and kv cache of s, sized from the config p
 * and zero-initialized. On any allocation failure, prints an error to stderr
 * (stdout carries the generated token ids) and exits with status 1.
 */
void malloc_run_state(RunState* s, Config* p) {
    // element counts are computed in size_t: n_layers * seq_len * dim done in
    // int can exceed INT_MAX for larger configs, and signed overflow is UB
    size_t dim = (size_t)p->dim;
    size_t hidden_dim = (size_t)p->hidden_dim;
    size_t kv_len = (size_t)p->n_layers * (size_t)p->seq_len * dim;
    // we calloc instead of malloc to keep valgrind happy
    s->x = calloc(dim, sizeof(float));
    s->xb = calloc(dim, sizeof(float));
    s->xb2 = calloc(dim, sizeof(float));
    s->hb = calloc(hidden_dim, sizeof(float));
    s->hb2 = calloc(hidden_dim, sizeof(float));
    s->q = calloc(dim, sizeof(float));
    s->k = calloc(dim, sizeof(float));
    s->v = calloc(dim, sizeof(float));
    s->att = calloc((size_t)p->seq_len, sizeof(float));
    s->logits = calloc((size_t)p->vocab_size, sizeof(float));
    s->key_cache = calloc(kv_len, sizeof(float));
    s->value_cache = calloc(kv_len, sizeof(float));
    // ensure all mallocs went fine
    if (!s->x || !s->xb || !s->xb2 || !s->hb || !s->hb2 || !s->q
        || !s->k || !s->v || !s->att || !s->logits || !s->key_cache
        || !s->value_cache) {
        fprintf(stderr, "malloc failed!\n");
        exit(1);
    }
}
/*
 * Release every buffer allocated by malloc_run_state, in field order.
 * The config is not needed for freeing; the parameter only keeps the
 * signature parallel to malloc_run_state.
 */
void free_run_state(RunState* s, Config* p) {
    (void)p; // unused
    float* owned[] = {
        s->x, s->xb, s->xb2,
        s->hb, s->hb2,
        s->q, s->k, s->v,
        s->att, s->logits,
        s->key_cache, s->value_cache,
    };
    for (size_t i = 0; i < sizeof owned / sizeof owned[0]; i++) {
        free(owned[i]);
    }
}
/*
 * Allocate every tensor of w, sized from the config p and zero-initialized.
 * freq_cis_real/imag are sized seq_len * dim / 2 (random_init_weights fills
 * that whole range; checkpoint_init_weights reads only a prefix of it).
 * On any allocation failure, prints an error to stderr and exits with status 1.
 */
void malloc_weights(TransformerWeights* w, Config* p) {
    // element counts are computed in size_t: products such as
    // n_layers * hidden_dim * dim exceed INT_MAX for larger configs, and
    // signed int overflow is undefined behavior
    size_t dim = (size_t)p->dim;
    size_t hidden_dim = (size_t)p->hidden_dim;
    size_t n_layers = (size_t)p->n_layers;
    size_t seq_len = (size_t)p->seq_len;
    size_t vocab_size = (size_t)p->vocab_size;
    // we calloc instead of malloc to keep valgrind happy
    w->token_embedding_table = calloc(vocab_size * dim, sizeof(float));
    w->rms_att_weight = calloc(n_layers * dim, sizeof(float));
    w->rms_ffn_weight = calloc(n_layers * dim, sizeof(float));
    w->wq = calloc(n_layers * dim * dim, sizeof(float));
    w->wk = calloc(n_layers * dim * dim, sizeof(float));
    w->wv = calloc(n_layers * dim * dim, sizeof(float));
    w->wo = calloc(n_layers * dim * dim, sizeof(float));
    w->w1 = calloc(n_layers * hidden_dim * dim, sizeof(float));
    w->w2 = calloc(n_layers * dim * hidden_dim, sizeof(float));
    w->w3 = calloc(n_layers * hidden_dim * dim, sizeof(float));
    w->rms_final_weight = calloc(dim, sizeof(float));
    w->freq_cis_real = calloc(seq_len * dim / 2, sizeof(float));
    w->freq_cis_imag = calloc(seq_len * dim / 2, sizeof(float));
    // ensure all mallocs went fine
    if (!w->token_embedding_table || !w->rms_att_weight || !w->rms_ffn_weight
        || !w->wq || !w->wk || !w->wv || !w->wo || !w->w1 || !w->w2 || !w->w3
        || !w->rms_final_weight || !w->freq_cis_real || !w->freq_cis_imag) {
        fprintf(stderr, "malloc failed!\n");
        exit(1);
    }
}
/*
 * Release every tensor allocated by malloc_weights, in field order.
 * The config is not needed for freeing; the parameter only keeps the
 * signature parallel to malloc_weights.
 */
void free_weights(TransformerWeights* w, Config* p) {
    (void)p; // unused
    float* tensors[] = {
        w->token_embedding_table,
        w->rms_att_weight, w->rms_ffn_weight,
        w->wq, w->wk, w->wv, w->wo,
        w->w1, w->w2, w->w3,
        w->rms_final_weight,
        w->freq_cis_real, w->freq_cis_imag,
    };
    for (size_t i = 0; i < sizeof tensors / sizeof tensors[0]; i++) {
        free(tensors[i]);
    }
}
// ----------------------------------------------------------------------------
// initialization: random init, or read from checkpoint
// initializes weights to random numbers from -.5 to .5
/*
 * Fill w[0..size) with rand()-derived values in [-0.5, 0.5], one rand() call
 * per element in ascending index order. Does nothing when size <= 0.
 */
void init_rand(float* w, int size) {
    int i = 0;
    while (i < size) {
        float unit = (float)rand() / (float)(RAND_MAX); // in [0, 1]
        w[i++] = unit - 0.5f;
    }
}
// constant init
/* Set the first size elements of w to val; does nothing when size <= 0. */
void init_const(float* w, int size, float val) {
    int i = 0;
    while (i < size) {
        w[i++] = val;
    }
}
/*
 * Fill already-allocated weights with synthetic values. The rmsnorm gains get
 * 1.0, and every other tensor gets init_rand noise. Tensors are visited in a
 * fixed order, so the result is reproducible for a given srand() seed. Sizes
 * match malloc_weights, including seq_len * dim / 2 for freq_cis.
 * Not called by main() in this file.
 */
void random_init_weights(TransformerWeights* w, Config* p) {
    const int dim = p->dim;
    const int layers = p->n_layers;
    const int norm_len = layers * dim;                // rms_att / rms_ffn
    const int attn_len = layers * dim * dim;          // each of wq/wk/wv/wo
    const int ffn_len = layers * dim * p->hidden_dim; // each of w1/w2/w3
    const int freq_len = p->seq_len * dim / 2;        // each freq_cis table

    init_rand(w->token_embedding_table, p->vocab_size * dim);
    init_const(w->rms_att_weight, norm_len, 1.0f);
    init_const(w->rms_ffn_weight, norm_len, 1.0f);

    float* attn[] = { w->wq, w->wk, w->wv, w->wo };
    for (int m = 0; m < 4; m++) {
        init_rand(attn[m], attn_len);
    }
    float* ffn[] = { w->w1, w->w2, w->w3 };
    for (int m = 0; m < 3; m++) {
        init_rand(ffn[m], ffn_len);
    }

    init_const(w->rms_final_weight, dim, 1.0f);
    init_rand(w->freq_cis_real, freq_len);
    init_rand(w->freq_cis_imag, freq_len);
}
/*
 * Read exactly n floats from f into dst. A short read means the checkpoint is
 * truncated or does not match its config header, so abort with status 1
 * instead of continuing with partially filled (zeroed) weights.
 */
static void read_floats(float* dst, size_t n, FILE* f) {
    if (fread(dst, sizeof(float), n, f) != n) {
        fprintf(stderr, "failed to read weights from checkpoint!\n");
        exit(1);
    }
}

/*
 * Fill the weights allocated by malloc_weights from the open checkpoint f,
 * positioned just past the Config header. Tensors are read in checkpoint
 * order, which differs from the struct's declaration order: rms_ffn_weight
 * comes after wo. Only seq_len * head_size / 2 floats of each freq_cis table
 * are stored in the checkpoint.
 * Exits with status 1 if n_heads <= 0 or on any short read.
 */
void checkpoint_init_weights(TransformerWeights *w, Config* p, FILE* f) {
    if (p->n_heads <= 0) {
        fprintf(stderr, "invalid n_heads in checkpoint header!\n");
        exit(1);
    }
    // element counts in size_t: n_layers * dim * hidden_dim can exceed INT_MAX
    size_t dim = (size_t)p->dim;
    size_t hidden_dim = (size_t)p->hidden_dim;
    size_t n_layers = (size_t)p->n_layers;
    size_t seq_len = (size_t)p->seq_len;
    size_t vocab_size = (size_t)p->vocab_size;
    size_t head_size = (size_t)(p->dim / p->n_heads);

    read_floats(w->token_embedding_table, vocab_size * dim, f);
    read_floats(w->rms_att_weight, n_layers * dim, f);
    read_floats(w->wq, n_layers * dim * dim, f);
    read_floats(w->wk, n_layers * dim * dim, f);
    read_floats(w->wv, n_layers * dim * dim, f);
    read_floats(w->wo, n_layers * dim * dim, f);
    read_floats(w->rms_ffn_weight, n_layers * dim, f);
    read_floats(w->w1, n_layers * dim * hidden_dim, f);
    read_floats(w->w2, n_layers * hidden_dim * dim, f);
    read_floats(w->w3, n_layers * dim * hidden_dim, f);
    read_floats(w->rms_final_weight, dim, f);
    read_floats(w->freq_cis_real, seq_len * head_size / 2, f);
    read_floats(w->freq_cis_imag, seq_len * head_size / 2, f);
}
// ----------------------------------------------------------------------------
// neural net blocks
/* Copy size floats from b into a, front to back; does nothing when size <= 0. */
void copy(float *a, float *b, int size) {
    int n = 0;
    while (n < size) {
        a[n] = b[n];
        n++;
    }
}
/*
 * In-place elementwise add: a[i] += b[i] for i in [0, size). Used for the
 * residual connections in transformer().
 */
void accum(float *a, float *b, int size) {
    int k = 0;
    while (k < size) {
        a[k] = a[k] + b[k];
        ++k;
    }
}
void rmsnorm(float* o, float* x, float* weight, int size) {
// calculate sum of squares
float ss = 0.0f;
for (int j = 0; j < size; j++) {
ss += x[j] * x[j];
}
ss /= size;
ss += 1e-5f;
ss = 1.0f / sqrt(ss);
// normalize and scale
for (int j = 0; j < size; j++) {
o[j] = weight[j] * (ss * x[j]);
}
}
void softmax(float* x, int size) {
if(size == 1) {
x[0] = 1.0f;
return;
}
// find max value (for numerical stability)
float max_val = x[0];
for (int i = 1; i < size; i++) {
if (x[i] > max_val) {
max_val = x[i];
}
}
// e^x
for (int i = 0; i < size; i++) {
x[i] = exp(x[i] - max_val);
}
// normalize
float sum = 0.0f;
for (int i = 0; i < size; i++) {
sum += x[i];
}
for (int i = 0; i < size; i++) {
x[i] /= sum;
}
}
/*
 * Matrix-vector product: W (d,n) @ x (n,) -> xout (d,), with W stored
 * row-major. Row i of W is dotted with x and written to xout[i].
 */
void matmul(float* xout, float* x, float* w, int n, int d) {
    for (int row = 0; row < d; row++) {
        const float* w_row = w + row * n;
        float acc = 0.0f;
        for (int col = 0; col < n; col++) {
            acc += w_row[col] * x[col];
        }
        xout[row] = acc;
    }
}
/*
 * One forward pass of the model for a single token at position pos.
 *
 * token: row of token_embedding_table to start from (not range-checked here).
 * pos:   position of the token. The kv cache row `pos` is written in every
 *        layer and rows 0..pos are attended over, so the caller must keep
 *        pos < p->seq_len (main's generation loop does).
 * On return, s->logits (vocab_size,) holds the unnormalized next-token scores.
 * s->x and the other RunState scratch buffers are overwritten.
 *
 * NOTE(review): p->n_kv_heads is not used. k/v are full (dim, dim)
 * projections, i.e. this assumes n_kv_heads == n_heads.
 */
void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights* w) {
    // a few convenience variables
    float *x = s->x;
    int dim = p->dim;
    int hidden_dim = p->hidden_dim;
    int head_size = dim / p->n_heads;
    // flat-array offsets are computed in size_t: products such as
    // l * dim * hidden_dim exceed INT_MAX for larger configs (signed overflow is UB)
    size_t dim_sz = (size_t)dim;
    size_t hidden_sz = (size_t)hidden_dim;
    size_t head_sz = (size_t)head_size;
    // the attention score scale is the same for every score; compute it once
    float score_scale = sqrtf(head_size);

    // copy the token embedding into x
    float* content_row = w->token_embedding_table + (size_t)token * dim_sz;
    copy(x, content_row, dim);

    // pluck out the "pos" row of freq_cis_real and freq_cis_imag
    float* freq_cis_real_row = w->freq_cis_real + (size_t)pos * head_sz / 2;
    float* freq_cis_imag_row = w->freq_cis_imag + (size_t)pos * head_sz / 2;

    // forward all the layers
    for (int l = 0; l < p->n_layers; l++) {
        size_t attn_off = (size_t)l * dim_sz * dim_sz;   // layer l in wq/wk/wv/wo
        size_t ffn_off = (size_t)l * dim_sz * hidden_sz; // layer l in w1/w2/w3

        // attention rmsnorm
        rmsnorm(s->xb, x, w->rms_att_weight + (size_t)l * dim_sz, dim);

        // qkv matmuls for this position
        matmul(s->q, s->xb, w->wq + attn_off, dim, dim);
        matmul(s->k, s->xb, w->wk + attn_off, dim, dim);
        matmul(s->v, s->xb, w->wv + attn_off, dim, dim);

        // apply RoPE rotation to the q and k vectors for each head:
        // each (even, odd) pair is rotated as a complex number by freq_cis
        for (int h = 0; h < p->n_heads; h++) {
            float* q = s->q + h * head_size;
            float* k = s->k + h * head_size;
            for (int i = 0; i < head_size; i += 2) {
                float q0 = q[i];
                float q1 = q[i+1];
                float k0 = k[i];
                float k1 = k[i+1];
                float fcr = freq_cis_real_row[i/2];
                float fci = freq_cis_imag_row[i/2];
                q[i] = q0 * fcr - q1 * fci;
                q[i+1] = q0 * fci + q1 * fcr;
                k[i] = k0 * fcr - k1 * fci;
                k[i+1] = k0 * fci + k1 * fcr;
            }
        }

        // save key,value at this time step (pos) to our kv cache
        size_t loff = (size_t)l * (size_t)p->seq_len * dim_sz; // kv cache layer offset
        float* key_cache_row = s->key_cache + loff + (size_t)pos * dim_sz;
        float* value_cache_row = s->value_cache + loff + (size_t)pos * dim_sz;
        copy(key_cache_row, s->k, dim);
        copy(value_cache_row, s->v, dim);

        // multihead attention. iterate over all heads
        for (int h = 0; h < p->n_heads; h++) {
            // get the query vector for this head
            float* q = s->q + h * head_size;
            // iterate over all timesteps, including the current one
            for (int t = 0; t <= pos; t++) {
                // get the key vector for this head and at this timestep
                float* k = s->key_cache + loff + (size_t)t * dim_sz + (size_t)h * head_sz;
                // scaled dot-product attention score
                float score = 0.0f;
                for (int i = 0; i < head_size; i++) {
                    score += q[i] * k[i];
                }
                score /= score_scale;
                s->att[t] = score;
            }
            // softmax the scores to get attention weights, from 0..pos inclusively
            softmax(s->att, pos + 1);
            // weighted sum of the values, store back into xb
            for (int i = 0; i < head_size; i++) {
                float val = 0.0f;
                for (int t = 0; t <= pos; t++) {
                    // note bad locality: strides over timesteps
                    val += s->att[t] * s->value_cache[loff + (size_t)t * dim_sz + (size_t)h * head_sz + (size_t)i];
                }
                s->xb[h * head_size + i] = val;
            }
        }

        // final matmul to get the output of the attention
        matmul(s->xb2, s->xb, w->wo + attn_off, dim, dim);
        // residual connection back into x
        accum(x, s->xb2, dim);

        // ffn rmsnorm
        rmsnorm(s->xb, x, w->rms_ffn_weight + (size_t)l * dim_sz, dim);

        // Now for FFN in PyTorch we have: self.w2(F.silu(self.w1(x)) * self.w3(x))
        // first calculate self.w1(x) and self.w3(x)
        matmul(s->hb, s->xb, w->w1 + ffn_off, dim, hidden_dim);
        matmul(s->hb2, s->xb, w->w3 + ffn_off, dim, hidden_dim);
        // F.silu; silu(x)=x*sigmoid(x), where sigmoid is the logistic function
        for (int i = 0; i < hidden_dim; i++) {
            s->hb[i] = s->hb[i] * (1.0f / (1.0f + expf(-s->hb[i])));
        }
        // elementwise multiply with w3(x)
        for (int i = 0; i < hidden_dim; i++) {
            s->hb[i] = s->hb[i] * s->hb2[i];
        }
        // final matmul to get the output of the ffn
        matmul(s->xb, s->hb, w->w2 + ffn_off, hidden_dim, dim);
        // residual connection
        accum(x, s->xb, dim);
    }

    // final rmsnorm
    rmsnorm(x, x, w->rms_final_weight, dim);
    // classifier into logits (weights shared with the token embedding table)
    matmul(s->logits, x, w->token_embedding_table, dim, p->vocab_size);
}
/*
 * Draw an index from the distribution `probabilities` (expected to sum to 1)
 * by inverse-CDF sampling with one rand() draw. Falls back to n - 1 when
 * rounding leaves the running sum below the draw.
 */
int sample(float* probabilities, int n) {
    float r = (float)rand() / (float)RAND_MAX; // uniform in [0, 1]
    float cdf = 0.0f;
    int i = 0;
    while (i < n) {
        cdf += probabilities[i];
        if (cdf > r) {
            return i;
        }
        i++;
    }
    return n - 1; // in case of rounding errors
}
/*
 * Index of the largest element of v[0..n). Ties resolve to the lowest index;
 * returns 0 when n <= 1.
 */
int argmax(float* v, int n) {
    int best = 0;
    for (int i = 1; i < n; i++) {
        if (v[i] > v[best]) {
            best = i;
        }
    }
    return best;
}
// ----------------------------------------------------------------------------
int main(int argc, char *argv[]) {
setbuf(stdout, NULL); // disable stdout buffering
// poor man's C argparse
char *checkpoint = NULL;
float temperature = 0.9f;
// 'checkpoint' is necessary arg
if (argc < 2) {
printf("Usage: %s <checkpoint_file> [temperature] [seed]\n", argv[0]);
return 1;
}
checkpoint = argv[1];
// temperature is optional
if (argc >= 3) {
temperature = atof(argv[2]);
}
// seed is optional
if (argc >= 4) {
unsigned int seed = atoi(argv[3]);
srand(seed);
} else {
time_t current_time;
time(&current_time);
srand((unsigned int)current_time);
}
// read in the config header
Config config;
FILE *file = fopen(checkpoint, "rb");
if (!file) {
printf("Unable to open file!");
return 1;
}
fread(&config, sizeof(Config), 1, file);
// create and init the Transformer
TransformerWeights weights;
malloc_weights(&weights, &config);
checkpoint_init_weights(&weights, &config, file);
fclose(file);
// create and init the application RunState
RunState state;
malloc_run_state(&state, &config);
// the current position we are in
int next;
int token = 1; // 1 = BOS token in Llama-2 sentencepiece
int pos = 0;
while (pos < config.seq_len) {
// forward the transformer to get logits for the next token
transformer(token, pos, &config, &state, &weights);
// sample the next token
if(temperature == 0.0f) {
// greedy argmax sampling
next = argmax(state.logits, config.vocab_size);
} else {
// apply the temperature to the logits
for (int q=0; q<config.vocab_size; q++) { state.logits[q] /= temperature; }
// apply softmax to the logits to get the probabilities for next token
softmax(state.logits, config.vocab_size);
// we now want to sample from this distribution to get the next token
next = sample(state.logits, config.vocab_size);
}
printf("%d\n", next);
// advance forward
token = next;
pos++;
}
free_run_state(&state, &config);
free_weights(&weights, &config);
return 0;
}