800 lines
27 KiB
Dart
800 lines
27 KiB
Dart
import 'dart:convert';
|
||
import 'dart:developer';
|
||
import 'dart:io';
|
||
import 'dart:math';
|
||
import 'dart:typed_data';
|
||
|
||
import 'package:args/args.dart';
|
||
|
||
/// Model hyper-parameters.
///
/// Every field is `late`: it must be assigned before it is read.
class Config {
  /// Transformer dimension.
  late int dim;

  /// Hidden dimension used by the ffn layers.
  late int hidden_dim;

  /// Number of layers.
  late int n_layers;

  /// Number of query heads.
  late int n_heads;

  /// Number of key/value heads (can be < query heads because of multiquery).
  late int n_kv_heads;

  /// Vocabulary size, usually 256 (byte-level).
  late int vocab_size;

  /// Maximum sequence length.
  late int seq_len;

  /// Renders every field as `Config(name: value, ...)`.
  @override
  String toString() => "Config(dim: $dim, hidden_dim: $hidden_dim, "
      "n_layers: $n_layers, n_heads: $n_heads, n_kv_heads: $n_kv_heads, "
      "vocab_size: $vocab_size, seq_len: $seq_len)";
}
|
||
|
||
/// Size in bytes of the checkpoint header: seven 4-byte integers.
const configByteSize = 7 * 4;
|
||
|
||
// We are using 32-bit precision floats here
|
||
/// Container for all model parameters, stored as flat 32-bit float lists.
///
/// The shapes given per field are the logical layouts noted by the author;
/// every field is `late` and must be assigned before it is read.
class TransformerWeights {
  /// Token embedding table, (vocab_size, dim).
  late Float32List token_embedding_table;

  /// Rmsnorm weights applied before attention, (layer, dim).
  late Float32List rms_att_weight;

  /// Rmsnorm weights applied before the ffn, (layer, dim).
  late Float32List rms_ffn_weight;

  // Weights for matmuls. Note dim == n_heads * head_size.

  /// Query projection, (layer, dim, n_heads * head_size).
  late Float32List wq;

  /// Key projection, (layer, dim, n_kv_heads * head_size).
  late Float32List wk;

  /// Value projection, (layer, dim, n_kv_heads * head_size).
  late Float32List wv;

  /// Attention output projection, (layer, n_heads * head_size, dim).
  late Float32List wo;

  // Weights for the ffn.

  /// First ffn projection, (layer, hidden_dim, dim).
  late Float32List w1;

  /// Second ffn projection, (layer, dim, hidden_dim).
  late Float32List w2;

  /// Third ffn projection, (layer, hidden_dim, dim).
  late Float32List w3;

  /// Final rmsnorm weights, (dim,).
  late Float32List rms_final_weight;

  /// Real part of freq_cis for RoPE positional embeddings,
  /// (seq_len, head_size/2).
  late Float32List freq_cis_real;

  /// Imaginary part of freq_cis for RoPE positional embeddings,
  /// (seq_len, head_size/2).
  late Float32List freq_cis_imag;

  /// (Optional) classifier weights for the logits, on the last layer.
  late Float32List wcls;
}
|
||
|
||
/// Pairs a probability value with an integer index.
class ProbIndex {
  ProbIndex(this.prob, this.index);

  /// The probability value.
  double prob;

  /// The index associated with [prob].
  int index;
}
|
||
|
||
/// Pairs a token string with an integer id.
class TokenIndex {
  TokenIndex(this.str, this.id);

  /// The token text.
  String str;

  /// The id associated with [str].
  int id;
}
|
||
|
||
/// Scratch buffers and kv cache used while running the model.
///
/// All fields are `late` and must be assigned before use; the shapes noted
/// are the logical layouts given by the author.
class RunState {
  // Current wave of activations.

  /// Activation at the current time stamp, (dim,).
  late Float32List x;

  /// Same as [x], but inside a residual branch, (dim,).
  late Float32List xb;

  /// An additional buffer just for convenience, (dim,).
  late Float32List xb2;

  /// Buffer for the hidden dimension in the ffn, (hidden_dim,).
  late Float32List hb;

  /// Second buffer for the hidden dimension in the ffn, (hidden_dim,).
  late Float32List hb2;

  /// Query, (dim,).
  late Float32List q;

  /// Key, (dim,).
  late Float32List k;

  /// Value, (dim,).
  late Float32List v;

  /// Buffer for scores/attention values, (n_heads, seq_len).
  late Float32List att;

  /// Output logits.
  late Float32List logits;

  /// Buffer used in top-p sampling.
  late List<ProbIndex> probindex;

  // kv cache

  /// Cached keys, (layer, seq_len, dim).
  late Float32List key_cache;

  /// Cached values, (layer, seq_len, dim).
  late Float32List value_cache;
}
|
||
|
||
/// Allocates every buffer of [s] with sizes derived from [config].
///
/// `Float32List(n)` yields zero-filled storage, so all activations and both
/// kv caches start at 0.0. `probindex` starts as an empty list.
initialize_run_state(RunState s, Config config) {
  // Width of one key/value row: dim scaled by the kv-head / query-head ratio.
  final kvWidth = (config.dim * config.n_kv_heads) ~/ config.n_heads;
  final cacheLength = config.n_layers * config.seq_len * kvWidth;

  s
    ..x = Float32List(config.dim)
    ..xb = Float32List(config.dim)
    ..xb2 = Float32List(config.dim)
    ..hb = Float32List(config.hidden_dim)
    ..hb2 = Float32List(config.hidden_dim)
    ..q = Float32List(config.dim)
    ..k = Float32List(kvWidth)
    ..v = Float32List(kvWidth)
    ..att = Float32List(config.n_heads * config.seq_len)
    ..logits = Float32List(config.vocab_size)
    ..probindex = []
    ..key_cache = Float32List(cacheLength)
    ..value_cache = Float32List(cacheLength);
}
|
||
|
||
/// Byte-pair-encoding tokenizer backed by a vocabulary and per-token scores.
class Tokenizer {
  /// Token strings; a token's id is its position in this list.
  List<String> vocab;

  /// Merge score for each entry of [vocab], at the same index.
  List<double> vocab_scores;

  Tokenizer(
    this.vocab,
    this.vocab_scores,
  );

  /// Encodes [text] into a list of token ids and returns it.
  ///
  /// The incoming [tokens] list and [n_tokens] are ignored (kept for
  /// signature compatibility); a fresh list is always returned.
  ///
  /// Pass 1 maps each Unicode code point of [text] to the vocabulary entry
  /// with the same single-character string; characters that are not in the
  /// vocabulary are dropped. Pass 2 repeatedly merges the adjacent pair whose
  /// concatenation is in the vocabulary with the highest score, until no pair
  /// can be merged.
  List<int> bpe_encode(String text, List<int> tokens, int n_tokens) {
    // String -> id lookup. putIfAbsent keeps the first occurrence, which is
    // the same id `vocab.indexOf` would return for duplicated entries.
    final lookup = <String, int>{};
    for (var i = 0; i < vocab.length; i++) {
      lookup.putIfAbsent(vocab[i], () => i);
    }

    tokens = [];

    // First pass: one token per code point. A rune is a Unicode code point,
    // not a UTF-8 byte, so it is turned into a string with fromCharCode.
    for (final rune in text.runes) {
      final id = lookup[String.fromCharCode(rune)];
      if (id != null) tokens.add(id);
    }

    // Second pass: greedily merge the best-scoring adjacent pair.
    while (true) {
      double best_score = -1e10;
      int best_id = -1;
      int best_index = -1;

      for (int i = 0; i < tokens.length - 1; i++) {
        final merged = lookup[vocab[tokens[i]] + vocab[tokens[i + 1]]];
        if (merged != null && vocab_scores[merged] > best_score) {
          best_score = vocab_scores[merged];
          best_id = merged;
          best_index = i;
        }
      }

      // No mergeable pair left.
      if (best_index == -1) break;

      // Replace the pair (best_index, best_index + 1) with the merged token.
      tokens[best_index] = best_id;
      tokens.removeAt(best_index + 1);
    }
    return tokens;
  }
}
|
||
|
||
// ----------------------------------------------------------------------------
|
||
// sampling can be done in a few ways: greedy argmax, sampling, top-p sampling
|
||
|
||
/// Returns the index of the largest entry in [probabilities].
///
/// On ties the lowest index wins, because a later entry only replaces the
/// current best when it is strictly greater.
int argmax(Float32List probabilities) {
  var bestIndex = 0;
  var bestValue = probabilities[0];
  var i = 1;
  while (i < probabilities.length) {
    final candidate = probabilities[i];
    if (candidate > bestValue) {
      bestValue = candidate;
      bestIndex = i;
    }
    i++;
  }
  return bestIndex;
}
|
||
|
||
/// Samples an index from the distribution [probabilities].
///
/// The entries are expected to sum to 1. A uniform number in [0, 1) is drawn
/// and the first index whose cumulative probability exceeds it is returned;
/// if rounding keeps the cumulative sum below the draw, the last index is
/// returned.
///
/// [rng] is the random source. When omitted a fresh unseeded [Random] is
/// created for the call; pass a seeded [Random] to get reproducible results.
int sample(Float32List probabilities, [Random? rng]) {
  final r = (rng ?? Random()).nextDouble();
  var cdf = 0.0;
  for (var i = 0; i < probabilities.length; i++) {
    cdf += probabilities[i];
    if (r < cdf) return i;
  }
  return probabilities.length - 1; // in case of rounding errors
}
|
||
|
||
/// Top-p ("nucleus") sampling over [probabilities].
///
/// Samples from the smallest set of highest-probability tokens whose
/// cumulative probability exceeds [topp], so very unlikely tokens are never
/// chosen. Returns the sampled index into [probabilities].
///
/// [rng] is the random source. When omitted a fresh unseeded [Random] is
/// created for the call; pass a seeded [Random] to get reproducible results.
///
/// Throws an [ArgumentError] when [probabilities] is empty.
int sample_topp(Float32List probabilities, double topp, [Random? rng]) {
  final n = probabilities.length;
  if (n == 0) {
    throw ArgumentError.value(
        probabilities, 'probabilities', 'must not be empty');
  }
  // With a single token there is nothing to choose, and the cutoff below
  // would divide by zero.
  if (n == 1) return 0;

  // Values smaller than (1 - topp) / (n - 1) cannot be part of the result, so
  // they are cropped out as candidates before sorting.
  final cutoff = (1.0 - topp) / (n - 1);
  var candidates = <int>[
    for (var i = 0; i < n; i++)
      if (probabilities[i] >= cutoff) i,
  ];
  // If nothing survives the crop (very small topp, or probabilities that do
  // not sum to 1), fall back to considering every token.
  if (candidates.isEmpty) {
    candidates = List<int>.generate(n, (i) => i);
  }

  // Sort candidate indices in descending order of probability.
  candidates.sort((a, b) => probabilities[b].compareTo(probabilities[a]));

  // Truncate the list where the cumulative probability exceeds topp.
  var cumulative_prob = 0.0;
  var last_idx = candidates.length - 1; // rounding errors: consider all
  for (var i = 0; i < candidates.length; i++) {
    cumulative_prob += probabilities[candidates[i]];
    if (cumulative_prob > topp) {
      last_idx = i;
      break; // we've exceeded topp by including last_idx
    }
  }

  // Sample from the truncated list, scaling the draw to its total mass.
  final r = (rng ?? Random()).nextDouble() * cumulative_prob;
  var cdf = 0.0;
  for (var i = 0; i <= last_idx; i++) {
    cdf += probabilities[candidates[i]];
    if (r < cdf) return candidates[i];
  }
  return candidates[last_idx]; // in case of rounding errors
}
|
||
|
||
/// Writes the RMS-normalised, [weight]-scaled version of [x] into [out].
///
/// Each output is `weight[j] * x[j] / sqrt(mean(x^2) + 1e-5)`. All three
/// lists must have the same length (checked with asserts). [out] may be the
/// same list as [x].
rmsnorm(Float32List out, Float32List x, Float32List weight) {
  assert(out.length == x.length);
  assert(x.length == weight.length);

  final n = x.length;

  // Sum of squares of the input.
  var sumSquares = 0.0;
  for (final value in x) {
    sumSquares += value * value;
  }

  // Reciprocal root-mean-square, with a small epsilon for stability.
  final invRms = 1.0 / sqrt(sumSquares / n + 1e-5);

  // Normalize and scale.
  for (var j = 0; j < n; j++) {
    out[j] = weight[j] * (invRms * x[j]);
  }
}
|
||
|
||
/// Applies softmax in place to the first [size] entries of [x].
///
/// The maximum of the range is subtracted before exponentiating, for
/// numerical stability. Entries at index [size] and beyond are untouched.
void softmax(Float32List x, int size) {
  // Largest value in the range.
  var peak = x[0];
  for (var i = 1; i < size; i++) {
    final value = x[i];
    peak = value > peak ? value : peak;
  }

  // Exponentiate and accumulate the total of the stored values.
  var total = 0.0;
  var i = 0;
  while (i < size) {
    x[i] = exp(x[i] - peak);
    total += x[i];
    i++;
  }

  // Normalize so the range sums to 1.
  for (var j = 0; j < size; j++) {
    x[j] /= total;
  }
}
|
||
|
||
/// Matrix-vector product: W (d,n) @ x (n,) -> out (d,).
///
/// [w] is stored row-major, so row `i` occupies `w[i * n .. i * n + n - 1]`.
/// The list lengths are checked against [n] and [d] with asserts.
/// By far the most time is spent inside this little function.
void matmul(Float32List out, Float32List x, Float32List w, int n, int d) {
  assert(out.length == d);
  assert(x.length == n);
  assert(w.length == n * d);

  var rowStart = 0;
  for (var row = 0; row < d; row++, rowStart += n) {
    var acc = 0.0;
    for (var col = 0; col < n; col++) {
      acc += w[rowStart + col] * x[col];
    }
    out[row] = acc;
  }
}
|
||
|
||
/// Runs one forward step of the transformer for [token] at position [pos].
///
/// Shapes come from [config] and parameters from [weights]. The buffers in
/// [state] are used as scratch space; this position's key and value vectors
/// are written into `state.key_cache` / `state.value_cache`, and the
/// resulting logits are left in `state.logits`.
transformer(int token, int pos, Config config, RunState state,
    TransformerWeights weights) {
  int dim = config.dim;
  int kv_dim = config.dim * config.n_kv_heads ~/ config.n_heads;
  // Number of query heads that share one key/value head (multiquery).
  // It is 1 when n_kv_heads == n_heads.
  int kv_mul = config.n_heads ~/ config.n_kv_heads;
  int hidden_dim = config.hidden_dim;
  int head_size = config.dim ~/ config.n_heads;

  // copy the token embedding into x
  state.x.setRange(
      0,
      dim,
      Float32List.sublistView(
          weights.token_embedding_table, token * dim, (token + 1) * dim));

  // RoPE coefficients for this position.
  // Note: divide by 2 because RoPE parameters repeat after every 2 dimensions.
  // These rows are only read, so views are used instead of copies.
  Float32List freq_cis_real_row = Float32List.sublistView(
      weights.freq_cis_real, pos * head_size ~/ 2, (pos + 1) * head_size ~/ 2);
  Float32List freq_cis_imag_row = Float32List.sublistView(
      weights.freq_cis_imag, pos * head_size ~/ 2, (pos + 1) * head_size ~/ 2);

  // forward all the layers
  for (int l = 0; l < config.n_layers; l++) {
    // attention rmsnorm
    rmsnorm(
        state.xb,
        state.x,
        Float32List.sublistView(
            weights.rms_att_weight, l * dim, (l + 1) * dim));

    // qkv matmuls for this position
    // q = wq @ xb, with wq of size dim * dim
    matmul(
        state.q,
        state.xb,
        Float32List.sublistView(weights.wq, l * dim * dim, (l + 1) * dim * dim),
        dim,
        dim);

    // k = wk @ xb, with wk of size dim * kv_dim
    matmul(
        state.k,
        state.xb,
        Float32List.sublistView(
            weights.wk, l * dim * kv_dim, (l + 1) * dim * kv_dim),
        dim,
        kv_dim);

    // v = wv @ xb, with wv of size dim * kv_dim
    matmul(
        state.v,
        state.xb,
        Float32List.sublistView(
            weights.wv, l * dim * kv_dim, (l + 1) * dim * kv_dim),
        dim,
        kv_dim);

    // RoPE relative positional encoding: complex-valued rotate q and k by
    // freq_cis in each head.
    // https://arxiv.org/pdf/2104.09864v4.pdf
    // The same loop body is reused for the query (v == 0) and the key (v == 1).
    for (int v = 0; v < 2; v++) {
      Float32List vec = v == 0 ? state.q : state.k; // the vector to rotate
      int vec_size = v == 0 ? dim : kv_dim; // the size of the vector

      // Elements are rotated in pairs.
      for (int i = 0; i < vec_size; i += 2) {
        double v0 = vec[i];
        double v1 = vec[i + 1];
        double fcr = freq_cis_real_row[(i % head_size) ~/ 2];
        double fci = freq_cis_imag_row[(i % head_size) ~/ 2];
        // See the RoPE paper, section 3.4.2 "Computational efficient
        // realization of rotary matrix multiplication".
        // x1' = x1 * cos(mθ) - x2 * sin(mθ)
        vec[i] = v0 * fcr - v1 * fci;
        // x2' = x1 * sin(mθ) + x2 * cos(mθ)
        vec[i + 1] = v0 * fci + v1 * fcr;
      }
    }

    // save key,value at this time step (pos) to our kv cache
    // kv cache layer offset: l * seq_len * kv_dim
    int loff = l * config.seq_len * kv_dim;
    // row for this position: loff + pos * kv_dim
    int key_cache_row_offset = loff + pos * kv_dim;
    state.key_cache.setRange(
        key_cache_row_offset, key_cache_row_offset + state.k.length, state.k);
    state.value_cache.setRange(
        key_cache_row_offset, key_cache_row_offset + state.v.length, state.v);

    // multihead attention. iterate over all heads
    for (int h = 0; h < config.n_heads; h++) {
      // get the query vector for this head
      Float32List q =
          Float32List.sublistView(state.q, h * head_size, (h + 1) * head_size);
      // attention scores for this head
      Float32List att = Float32List.sublistView(
          state.att, h * config.seq_len, (h + 1) * config.seq_len);
      // Offset of the key/value head used by this query head inside a cache
      // row; several query heads can share one kv head.
      int kv_head_offset = (h ~/ kv_mul) * head_size;

      // iterate over all timesteps, including the current one
      for (int t = 0; t <= pos; t++) {
        // get the key vector for this head and at this timestep; it is
        // head_size long (not kv_dim, which would overrun the cache row)
        int key_cache_offset = loff + t * kv_dim + kv_head_offset;
        Float32List k = Float32List.sublistView(
            state.key_cache, key_cache_offset, key_cache_offset + head_size);
        // calculate the attention score as the dot product of q and k
        double score = 0.0;
        for (int ll = 0; ll < head_size; ll++) {
          score += q[ll] * k[ll];
        }
        // scale by 1 / sqrt(head_size)
        score /= sqrt(head_size);
        // save the score to the attention buffer
        att[t] = score;
      }

      // softmax the scores to get attention weights, from 0..pos inclusively;
      // this happens before the weights are applied to the values
      softmax(att, pos + 1);

      // weighted sum of the values, stored back into this head's slice of xb
      Float32List xb_off =
          Float32List.sublistView(state.xb, h * head_size, (h + 1) * head_size);
      // clear the slice before accumulating
      xb_off.fillRange(0, head_size, 0.0);

      for (int t = 0; t <= pos; t++) {
        // get the value vector for this head and at this timestep
        int v_cache_offset = loff + t * kv_dim + kv_head_offset;
        Float32List v = Float32List.sublistView(
            state.value_cache, v_cache_offset, v_cache_offset + head_size);
        // get the attention weight for this timestep
        double a = att[t];
        // accumulate the weighted value into xb
        for (int i = 0; i < head_size; i++) {
          xb_off[i] += a * v[i];
        }
      }
    }

    // final matmul to get the output of the attention:
    // the aggregate output of all the attention heads
    matmul(
        state.xb2,
        state.xb,
        Float32List.sublistView(weights.wo, l * dim * dim, (l + 1) * dim * dim),
        dim,
        dim);

    // residual connection back into x
    for (int i = 0; i < dim; i++) {
      state.x[i] += state.xb2[i];
    }

    // ffn rmsnorm
    rmsnorm(
        state.xb,
        state.x,
        Float32List.sublistView(
            weights.rms_ffn_weight, l * dim, (l + 1) * dim));

    // In PyTorch the FFN is: self.w2(F.silu(self.w1(x)) * self.w3(x))
    // first calculate self.w1(x) and self.w3(x)
    matmul(
        state.hb,
        state.xb,
        Float32List.sublistView(
            weights.w1, l * dim * hidden_dim, (l + 1) * dim * hidden_dim),
        dim,
        hidden_dim);

    matmul(
        state.hb2,
        state.xb,
        Float32List.sublistView(
            weights.w3, l * dim * hidden_dim, (l + 1) * dim * hidden_dim),
        dim,
        hidden_dim);

    // F.silu; silu(x) = x * σ(x), where σ(x) is the logistic sigmoid
    for (int i = 0; i < hidden_dim; i++) {
      state.hb[i] = state.hb[i] * (1.0 / (1.0 + exp(-state.hb[i])));
    }

    // elementwise multiply with w3(x): F.silu(self.w1(x)) * self.w3(x)
    for (int i = 0; i < hidden_dim; i++) {
      state.hb[i] = state.hb[i] * state.hb2[i];
    }

    // final matmul to get the output of the ffn (xb is reused as the target)
    matmul(
        state.xb,
        state.hb,
        Float32List.sublistView(
            weights.w2, l * dim * hidden_dim, (l + 1) * dim * hidden_dim),
        hidden_dim,
        dim);

    // residual connection
    for (int i = 0; i < dim; i++) {
      state.x[i] += state.xb[i];
    }
  }

  // final rmsnorm (in place)
  rmsnorm(state.x, state.x, weights.rms_final_weight);

  // classifier into logits
  matmul(state.logits, state.x, weights.wcls, config.dim, config.vocab_size);
}
|
||
|
||
/// Command-line entry point.
///
/// Parses the options, loads the checkpoint (config header plus weights) and
/// the tokenizer file, encodes the prompt, then generates up to `steps`
/// tokens, writing each decoded token to stdout and finally reporting the
/// achieved tokens per second.
///
/// Options: `-c/--checkpoint_path`, `-z/--tokenizer_path`, `-t/--temp`,
/// `-p/--topp`, `-s/--seed`, `-n/--steps`, `-i/--prompt`.
void main(List<String> args) {
  String checkpoint_path = "./stories15M.bin";
  String tokenizer_path = "tokenizer.bin";
  double temperature = 1.0;
  double top_p = 0.9;
  int rng_seed = 0; // 0 means: seed with the current time
  int steps = 256; // number of steps to run for
  String prompt = ""; // same value as the 'prompt' option default

  // Each callback also runs when its option is absent; it then receives the
  // option's default, or null when no default is declared. The null guards
  // keep the defaults above in that case.
  var parser = ArgParser();
  parser.addOption('checkpoint_path', abbr: 'c', callback: (value) {
    if (value != null) checkpoint_path = value;
  });
  parser.addOption('temp', abbr: 't', defaultsTo: "1.0", callback: (value) {
    if (value != null) temperature = double.parse(value);
  });
  parser.addOption('topp', abbr: 'p', defaultsTo: "0.9", callback: (value) {
    if (value != null) top_p = double.parse(value);
  });
  parser.addOption('seed', abbr: 's', defaultsTo: "0", callback: (value) {
    if (value != null) rng_seed = int.parse(value);
  });
  parser.addOption('steps', abbr: 'n', defaultsTo: "256", callback: (value) {
    if (value != null) steps = int.parse(value);
  });
  parser.addOption('prompt', abbr: 'i', defaultsTo: "", callback: (value) {
    if (value != null) prompt = value;
  });
  parser.addOption('tokenizer_path', abbr: 'z', callback: (value) {
    if (value != null) tokenizer_path = value;
  });

  parser.parse(args);

  // NOTE(review): rng_seed is only reported below; the sampling calls in the
  // generation loop are not given a random source built from it.
  if (rng_seed == 0) rng_seed = Timeline.now;

  print("===========llama2.dart===========");
  print("check_point_path: $checkpoint_path");
  print("tokenizer_path: $tokenizer_path");
  print("temperature: $temperature");
  print("top_p: $top_p");
  print("rng_seed: $rng_seed");
  print("steps: $steps");
  print("prompt: $prompt");

  var config = Config();
  var weights = TransformerWeights();

  print("========= Reading Weights =========");

  // Read Weights and Config from file
  {
    Uint8List checkpoint_bytes = File(checkpoint_path).readAsBytesSync();
    print("Read ${checkpoint_bytes.length} bytes from $checkpoint_path");

    // Reading Config: seven 32-bit ints at the start of the file. The header
    // is copied so the Int32List view starts at offset 0 of its own buffer.
    Int32List config_ints =
        checkpoint_bytes.sublist(0, configByteSize).buffer.asInt32List();
    config.dim = config_ints[0];
    config.hidden_dim = config_ints[1];
    config.n_layers = config_ints[2];
    config.n_heads = config_ints[3];
    config.n_kv_heads = config_ints[4];
    config.vocab_size = config_ints[5];
    config.seq_len = config_ints[6];
    print("Read Config: $config");

    // negative vocab size is hacky way of signaling unshared weights.
    bool shared_weights = config.vocab_size > 0;
    config.vocab_size = config.vocab_size.abs();

    // Everything after the header is a flat run of 32-bit floats.
    Float32List weight_floats =
        checkpoint_bytes.buffer.asFloat32List(configByteSize);
    int head_size = config.dim ~/ config.n_heads;
    int offset = 0;

    // Returns a view (no copy) of the next [count] floats, advances the read
    // offset and reports the tensor size.
    Float32List take(String name, int count) {
      final chunk =
          Float32List.sublistView(weight_floats, offset, offset + count);
      offset += count;
      print("Read ${chunk.lengthInBytes} bytes into $name");
      return chunk;
    }

    // The order of these reads is the order of the tensors in the file.
    weights.token_embedding_table =
        take("token_embedding_table", config.vocab_size * config.dim);
    weights.rms_att_weight =
        take("rms_att_weight", config.n_layers * config.dim);
    weights.wq = take(
        "wq", config.n_layers * config.dim * config.n_heads * head_size);
    weights.wk = take(
        "wk", config.n_layers * config.dim * config.n_kv_heads * head_size);
    weights.wv = take(
        "wv", config.n_layers * config.dim * config.n_kv_heads * head_size);
    weights.wo = take(
        "wo", config.n_layers * config.n_heads * head_size * config.dim);
    weights.rms_ffn_weight =
        take("rms_ffn_weight", config.n_layers * config.dim);
    weights.w1 = take("w1", config.n_layers * config.hidden_dim * config.dim);
    weights.w2 = take("w2", config.n_layers * config.dim * config.hidden_dim);
    weights.w3 = take("w3", config.n_layers * config.hidden_dim * config.dim);
    weights.rms_final_weight = take("rms_final_weight", config.dim);
    weights.freq_cis_real =
        take("freq_cis_real", config.seq_len * head_size ~/ 2);
    weights.freq_cis_imag =
        take("freq_cis_imag", config.seq_len * head_size ~/ 2);

    if (shared_weights) {
      print("Read shared weights into wcls");
      weights.wcls = weights.token_embedding_table;
    } else {
      weights.wcls = take("wcls", config.vocab_size * config.dim);
    }
  }

  // clamp number of steps to supported range
  if (steps <= 0 || steps > config.seq_len) {
    steps = config.seq_len;
  }

  // read in the tokenizer .bin file
  List<Uint8List> vocab = List.filled(config.vocab_size, Uint8List(0));
  Float32List vocab_scores = Float32List(config.vocab_size);
  {
    ByteData tokenizer_bytes =
        File(tokenizer_path).readAsBytesSync().buffer.asByteData(0);
    // The file starts with a uint32 max token length, which is not used.
    int offset = 4;
    for (int i = 0; i < config.vocab_size; i++) {
      // Each entry: float32 score, uint32 byte length, then the token bytes.
      vocab_scores[i] = tokenizer_bytes.getFloat32(offset, Endian.little);
      offset += 4;
      int next_str_length = tokenizer_bytes.getUint32(offset, Endian.little);
      offset += 4;
      vocab[i] = tokenizer_bytes.buffer.asUint8List(offset, next_str_length);
      offset += next_str_length;
    }
  }

  print("=====beginning generation=====");

  Tokenizer tokenizer = Tokenizer(
      vocab.map((e) => utf8.decode(e, allowMalformed: true)).toList(),
      vocab_scores);

  // process the prompt, if any (an empty prompt encodes to no tokens)
  List<int> prompt_tokens = tokenizer.bpe_encode(prompt, [], 0);

  RunState state = RunState();
  initialize_run_state(state, config);

  // Finally! the main loop
  // used to time our code, only initialized after first iteration
  int start = 0;
  int next; // will store the next token in the sequence
  // init with token 1 (=BOS), as done in Llama-2 sentencepiece tokenizer
  int token = 1;
  int pos = 0; // position in the sequence

  while (pos < steps) {
    // transformer! Run the model
    transformer(token, pos, config, state, weights);

    // advance the state machine
    if (pos < prompt_tokens.length) {
      // still processing the input prompt: force the next prompt token
      next = prompt_tokens[pos];
    } else if (temperature == 0.0) {
      // greedy argmax sampling: take the token with the highest probability
      next = argmax(state.logits);
    } else {
      // apply the temperature to the logits
      for (int q = 0; q < config.vocab_size; q++) {
        state.logits[q] /= temperature;
      }
      // apply softmax to the logits to get the probabilities for next token
      softmax(state.logits, state.logits.length);

      // we sample from this distribution to get the next token
      if (top_p <= 0 || top_p >= 1) {
        // simply sample from the predicted probability distribution
        next = sample(state.logits);
      } else {
        // top-p (nucleus) sampling, clamping the least likely tokens to zero
        next = sample_topp(state.logits, top_p);
      }
    }
    pos++;

    // data-dependent terminating condition: the BOS (1) token delimits
    // sequences
    if (next == 1) {
      break;
    }

    // following BOS (1) token, sentencepiece decoder strips any leading
    // whitespace: drop the first byte of the token when it is a space (0x20)
    Uint8List token_str = vocab[next];
    if (token == 1 && token_str.isNotEmpty && token_str[0] == 0x20) {
      token_str = Uint8List.sublistView(token_str, 1);
    }

    // careful, some tokens designate raw bytes, and look like e.g. '<0x01>'
    String str = utf8.decode(token_str, allowMalformed: true);

    // The original llama2.c checks for many special tokens; only the newline
    // byte token is translated here.
    if (str == "<0x0A>") {
      str = "\n";
    }
    stdout.write("$str");
    token = next;

    // init the timer here because the first iteration can be slower
    if (start == 0) {
      start = DateTime.now().millisecondsSinceEpoch;
    }
  }
  stdout.write("\n");

  // report achieved tok/s (pos-1 because the timer starts after first
  // iteration)
  if (pos > 1) {
    int end = DateTime.now().millisecondsSinceEpoch;
    print("achieved tok/s: ${(pos - 1) / (end - start) * 1000} \n");
  }
}
|