Compare commits
72 Commits
| Author | SHA1 | Date |
|---|---|---|
| | 4a7a62bd21 | |
| | 5c6427e4d7 | |
| | cbc2488b82 | |
| | fbe324fc5a | |
| | 6def77d4ba | |
| | 19cfbeca71 | |
| | d7cd98633d | |
| | 3d787b2463 | |
| | 40fb902cf0 | |
| | c7a26264a2 | |
| | 446c1c0df3 | |
| | 096325b66c | |
| | 90104db721 | |
| | 9bc72acab0 | |
| | c5e0e7fce4 | |
| | fe9b9f2f15 | |
| | 7ac65cb2c2 | |
| | 4b3e66021a | |
| | d1eb18b8ec | |
| | d26a499207 | |
| | ac6cf8d6e8 | |
| | ad7a1ef525 | |
| | 0e362f735f | |
| | d73b917d3b | |
| | 379f083b85 | |
| | 5eaca535cd | |
| | 83287ff254 | |
| | c2834c8a1f | |
| | ee95b1bf29 | |
| | d02e0c90d8 | |
| | 33d94f60a5 | |
| | 2d972f1763 | |
| | 8a3ea7b433 | |
| | 61c26d5392 | |
| | 36a78af5e1 | |
| | de005474d3 | |
| | 4444575c4e | |
| | dd61b13e57 | |
| | ea44f53568 | |
| | 801c68f5a1 | |
| | 74a68eeb35 | |
| | 288b3cec09 | |
| | 14275bd623 | |
| | 3868f732a4 | |
| | 8a377a1d31 | |
| | ae2e4f8d88 | |
| | 0dd82158f6 | |
| | 155475a523 | |
| | d7704bdeaa | |
| | 09db52c69e | |
| | a72b3b0206 | |
| | c74456f3f0 | |
| | 1e335a41cf | |
| | c0511de617 | |
| | 8c93c7a30e | |
| | 13dcee493a | |
| | f3db92a2dc | |
| | fa8dfd854e | |
| | 4df5e2e939 | |
| | 4212bd6d43 | |
| | 7f551dbfd7 | |
| | 6c5d78fa41 | |
| | db1a722816 | |
| | d2a546c577 | |
| | fbefeec1b1 | |
| | 978c311b30 | |
| | 882e480bc0 | |
| | d09ebbb32b | |
| | bc7cb7d0e8 | |
| | 01df3731d6 | |
| | 8607b11ea1 | |
| | 52fe3653e5 | |
@@ -55,6 +55,14 @@ test:
 testc:
 	pytest -k runc
 
+# run the C tests, without touching pytest / python
+# to increase verbosity level run e.g. as `make testcc VERBOSITY=1`
+VERBOSITY ?= 0
+.PHONY: testcc
+testcc:
+	$(CC) -DVERBOSITY=$(VERBOSITY) -O3 -o testc test.c -lm
+	./testc
+
 .PHONY: clean
 clean:
 	rm -f run
@@ -8,7 +8,7 @@ Train the Llama 2 LLM architecture in PyTorch then inference it with one simple
 
 As the architecture is identical, you can also load and inference Meta's Llama 2 models. However, the current code only inferences models in fp32, so you will most likely not be able to productively load models larger than 7B. Work on model quantization is currently ongoing.
 
-Please note that this repo started recently as a fun weekend project: I took my earlier [nanoGPT](https://github.com/karpathy/nanoGPT), tuned it to implement the Llama-2 architecture instead of GPT-2, and the meat of it was writing the C inference engine in [run.c](run.c). So the project is young and moving quickly. Hat tip to the awesome [llama.cpp](https://github.com/ggerganov/llama.cpp) for inspiring this project. Compred to llama.cpp, I wanted something super simple, minimal, and educational so I chose to hard-code the Llama 2 architecture and just roll one inference file of pure C with no dependencies.
+Please note that this repo started recently as a fun weekend project: I took my earlier [nanoGPT](https://github.com/karpathy/nanoGPT), tuned it to implement the Llama-2 architecture instead of GPT-2, and the meat of it was writing the C inference engine in [run.c](run.c). So the project is young and moving quickly. Hat tip to the awesome [llama.cpp](https://github.com/ggerganov/llama.cpp) for inspiring this project. Compared to llama.cpp, I wanted something super simple, minimal, and educational so I chose to hard-code the Llama 2 architecture and just roll one inference file of pure C with no dependencies.
 
 ## feel the magic
 
@@ -65,13 +65,13 @@ Quick note on sampling, the recommendation for ~best results is to sample with `
 ## Meta's Llama 2 models
 
 As the neural net architecture is identical, we can also inference the Llama 2 models released by Meta. Sadly there is a bit of friction here due to licensing (I can't directly upload the checkpoints, I think). So Step 1, get the Llama 2 checkpoints by following the [Meta instructions](https://github.com/facebookresearch/llama). Once we have those checkpoints, we have to convert them into the llama2.c format.
-For this we need to install the python dependencies (`pip install -r requirements.txt`) and then use the `export_meta_llama_bin.py` file, e.g. for 7B model:
+For this we need to install the python dependencies (`pip install -r requirements.txt`) and then use the `export.py` file, e.g. for 7B model:
 
 ```bash
-python export_meta_llama_bin.py path/to/llama/model/7B llama2_7b.bin
+python export.py llama2_7b.bin --meta-llama path/to/llama/model/7B
 ```
 
-The export will take ~10 minutes or so and generate a 26GB file (the weights of the 7B model in float32) called `llama2_7b.bin` in the current directory. It has been [reported](https://github.com/karpathy/llama2.c/pull/85) that despite efforts, the 13B export currently doesn't work for unknown reasons (accepting PRs for fix). We can run the model as normal:
+The export will take ~10 minutes or so and generate a 26GB file (the weights of the 7B model in float32) called `llama2_7b.bin` in the current directory. I would not attempt to run anything above 7B right now for two reasons: first, as [reported](https://github.com/karpathy/llama2.c/pull/85), 13B+ currently doesn't work because of integer overflow in pointer arithmetic, which is yet to be fixed, and second, even if it were fixed, this repo is doing float32 inference right now, so it would be fairly unusably slow. Once the export is done, we can run it:
 
 ```bash
 ./run llama2_7b.bin
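The 13B+ limitation mentioned in the hunk above can be made concrete with a little arithmetic. The sketch below assumes that weight offsets are computed as products of header fields held in signed 32-bit integers, and it uses the `dim`, `hidden_dim`, and `n_layers` values from Meta's published 7B and 13B configurations; the helper `wrap_int32` is hypothetical and only emulates the wraparound that typically results (signed overflow is formally undefined behavior in C).

```python
INT32_MAX = 2**31 - 1


def wrap_int32(n):
    """Emulate two's-complement wraparound of a signed 32-bit integer."""
    return (n + 2**31) % 2**32 - 2**31


# (dim, hidden_dim, n_layers) for the two model sizes (values assumed from Meta's configs)
configs = {"7B": (4096, 11008, 32), "13B": (5120, 13824, 40)}

results = {}
for name, (dim, hidden_dim, n_layers) in configs.items():
    # number of floats in all w1 matrices: one (hidden_dim x dim) matrix per layer
    numel = n_layers * dim * hidden_dim
    results[name] = (numel, numel > INT32_MAX, wrap_int32(numel))
    print(name, numel, "overflows int32" if numel > INT32_MAX else "fits in int32")
```

Under these assumptions the 7B product still fits in a 32-bit `int`, while the 13B product does not and wraps to a negative offset, which is consistent with the failure described above. Widening the offsets to a 64-bit type would be one plausible fix.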
@@ -83,6 +83,22 @@ This ran at about 4 tokens/s compiled with [OpenMP](#OpenMP) on 96 threads on my
 
 base models... ¯\\_(ツ)_/¯. Since we can inference the base model, it should be possible to also inference the chat model quite easily, and have a conversation with it. And if we can find a way to run 7B more efficiently, we can start adding LoRA to our training script, and going wild with finetunes all within the repo!
 
+You can also chat with the Llama Chat models. Export the chat model exactly as above:
+
+```bash
+python export.py llama2_7b_chat.bin --meta-llama /path/to/7B-chat
+```
+
+Then chat with it by specifying the chat mode using the `-m` flag, e.g.:
+
+```bash
+./run llama2_7b_chat.bin -m chat
+```
+
+## huggingface models
+
+We can load any huggingface models that use the Llama 2 architecture. See the script [export.py](export.py) and the `--hf` flag to export the model .bin file.
+
 ## models
 
 For the sake of examples of smaller, from-scratch models, I trained a small model series on TinyStories. All of these trained in a few hours on my training setup (4X A100 40GB GPUs). The 110M took around 24 hours. I am hosting them on huggingface hub [tinyllamas](https://huggingface.co/karpathy/tinyllamas), both in the original PyTorch .pt, and also in the llama2.c format .bin:
@@ -159,7 +175,7 @@ python tinystories.py train_vocab --vocab_size=4096
 python tinystories.py pretokenize --vocab_size=4096
 ```
 
-The `train_vocab` stage will call the `train_vocab.sh` script, which calls the `sentencepiece` library to train the tokenizer, storing it in a new file `data/tok4096.model`. I tried to reproduce as well as I could the settings that (I think) Meta used to train their vocabulary. This uses the Byte Pair Encoding algorithm that starts out with raw utf8 byte sequences of the text data and then iteratively merges the most common consecutive pairs of tokens to form the vocabulary. Inspect the `tinystories.py` file - the custom tokenizers are stored in a special directory structure indexed by the vocab size.
+The `train_vocab` stage will call the `sentencepiece` library to train the tokenizer, storing it in a new file `data/tok4096.model`. I tried to reproduce as well as I could the settings that (I think) Meta used to train their vocabulary. This uses the Byte Pair Encoding algorithm that starts out with raw utf8 byte sequences of the text data and then iteratively merges the most common consecutive pairs of tokens to form the vocabulary. Inspect the `tinystories.py` file - the custom tokenizers are stored in a special directory structure indexed by the vocab size.
 
 A quick note of interest is that vocab size of 4096 trained specifically on tinystories creates integer sequences with about the same sequence length per example as the default Llama 2 tokenizer of 32000 tokens! This means that our custom, tailored tokenizer is a lot better adapted to our specific text, and can compress it very effectively. So our trained models are smaller and faster.
 
@@ -203,8 +219,7 @@ You can also experiment with replacing `gcc` with `clang`.
 
 If compiling with gcc, try experimenting with `-funroll-all-loops`, see PR [#183](https://github.com/karpathy/llama2.c/pull/183)
 
-### OpenMP
-Big improvements can also be achieved by compiling with OpenMP, which "activates" the `#pragma omp parallel for` inside the matmul and attention, allowing the work in the loops to be split up over multiple processors.
+**OpenMP**. Big improvements can also be achieved by compiling with OpenMP, which "activates" the `#pragma omp parallel for` inside the matmul and attention, allowing the work in the loops to be split up over multiple processors.
 You'll need to install the OpenMP library and the clang compiler first (e.g. `apt install clang libomp-dev` on ubuntu). Then you can compile with `make runomp`, which does:
 
 ```bash
@@ -217,7 +232,8 @@ When you run inference make sure to use OpenMP flags to set the number of thread
 OMP_NUM_THREADS=4 ./run out/model.bin
 ```
 
-Depending on your system resources you may want to tweak these hyperparameters and use more threads. But more is not always better, usually this is a bit U shaped.
+Depending on your system resources you may want to tweak these hyperparameters and use more threads. But more is not always better, usually this is a bit U shaped. In particular, if your CPU has SMT (multithreading), try setting the number of threads to the number of physical cores rather than logical cores. The performance difference can be large due to cache thrashing and communication overhead. The PyTorch documentation [CPU specific optimizations
+](https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html#cpu-specific-optimizations) has some good information that applies here too.
 
 ## platforms
 
@@ -238,6 +254,14 @@ $ pytest
 
 This will currently invoke two tests inside `test_all.py`, which forward the model in both C and Python for 200 steps and check the output against a known good expected output. The tests currently run in only a few seconds, but will have to download and cache the stories260K models in a temporary `test` directory (only ~2MB download).
 
+There are also some tests in C, in the file [test.c](test.c). You can run these with `make testcc`, or to see more stuff printed:
+
+```
+make testcc VERBOSITY=1
+```
+
+Call for help: help add more tests.
+
 ## ack
 
 I trained the llama2.c storyteller models on a 4X A100 40GB box graciously provided by the excellent [Lambda labs](https://lambdalabs.com/service/gpu-cloud), thank you.
@@ -271,6 +295,7 @@ If your candidate PRs have elements of these it doesn't mean they won't get merg
   - [llama2.rs](https://github.com/leo-du/llama2.rs) by @[leo-du](https://github.com/leo-du): A Rust port of this project
   - [llama2-rs](https://github.com/danielgrittner/llama2-rs) by @[danielgrittner](https://github.com/danielgrittner): a Rust port of this project
   - [llama2.rs](https://github.com/lintian06/llama2.rs) by @[lintian06](https://github.com/lintian06): A Rust port of this project
+  - [pecca.rs](https://github.com/rahoua/pecca-rs) by @[rahoua](https://github.com/rahoua): A Rust port leveraging [ndarray](https://github.com/rust-ndarray/ndarray), supports BLAS.
 - Go
   - [go-llama2](https://github.com/tmc/go-llama2) by @[tmc](https://github.com/tmc): a Go port of this project
   - [llama2.go](https://github.com/nikolaydubina/llama2.go) by @[nikolaydubina](https://github.com/nikolaydubina): a Go port of this project
@@ -301,6 +326,8 @@ If your candidate PRs have elements of these it doesn't mean they won't get merg
   - [llama2.py](https://github.com/tairov/llama2.py) by @[tairov](https://github.com/tairov): a simple one file pure Python port of this project with zero dependencies
 - C#
   - [llama2.cs](https://github.com/trrahul/llama2.cs) by @[trrahul](https://github.com/trrahul): a C# port of this project
+- Dart
+  - [llama2.dart](https://github.com/yiminghan/llama2.dart) by @[yiminghan](https://github.com/yiminghan/llama2.dart): one-file dart port of this project, works with Flutter!
 - WebAssembly
   - [icpp-llm](https://github.com/icppWorld/icpp-llm): LLMs for the Internet Computer
   - [llama2.c - Llama 2 Everywhere](https://github.com/trholding/llama2.c) by @[trholding](https://github.com/trholding): Standalone, Bootable & Portable Binary Llama 2
@@ -308,12 +335,12 @@ If your candidate PRs have elements of these it doesn't mean they won't get merg
 
 ## unsorted todos
 
+- add support in run.c of reading version 1+ files from export, later deprecate "version 0"
+- runq.c (int8 quantization) add
+- run.cu (CUDA) investigate and merge
+- add more tests inside [test.c](test.c)
 - add Engine class for use in sample.py that does efficient inference in PyTorch, e.g. KV cache keeping
 - make it easier to add a new dataset with not too much pain
-- should calculate freq_cis online in the script run.c instead of loading them
-- int4/8 quantization
-- export the model in a more sensible output format with a proper header, etc.
-- support Llama 2 7B Chat models and tune run.c to Chat UI/UX
-- llama2.cu investigate and merge
+- (LoRA) finetuning and export of Llama 2 models
 
 ## License
 
@@ -0,0 +1,470 @@
+"""
+This script has functions and utilities for model export.
+Basically, we have a bunch of versions of the model, and we
+want to export them to .bin files to be read from and inferenced in C.
+
+Among the "input" versions of PyTorch files/models:
+- Official Llama 2 weights released by Meta
+- Huggingface weights available on the hub
+- llama2.c (this repo) trained models
+
+Among the "output" versions of .bin files:
+- v0: Legacy files of the original llama2.c repo (will eventually be DEPRECATED)
+- v1-vN: Improved .bin files with a proper header, cache alignment, etc.
+
+This script aspires to provide all of these conversions.
+"""
+import os
+import gzip
+import shutil
+import struct
+import argparse
+import json
+from pathlib import Path
+
+import numpy as np
+import torch
+from torch import nn
+
+from model import ModelArgs, Transformer
+
+# -----------------------------------------------------------------------------
+# common utilities
+
+def serialize_fp32(file, tensor):
+    """ writes one fp32 tensor to file that is open in wb mode """
+    d = tensor.detach().cpu().view(-1).to(torch.float32).numpy()
+    b = struct.pack(f'{len(d)}f', *d)
+    file.write(b)
+
+def serialize_int8(file, tensor):
+    """ writes one int8 tensor to file that is open in wb mode """
+    d = tensor.detach().cpu().view(-1).numpy().astype(np.int8)
+    b = struct.pack(f'{len(d)}b', *d)
+    file.write(b)
+
+def quantize_q80(w, group_size):
+    """
+    takes a tensor and returns the Q8_0 quantized version
+    i.e. symmetric quantization into int8, range [-127,127]
+    """
+    assert w.numel() % group_size == 0
+    ori_shape = w.shape
+    w = w.float() # convert to float32
+    w = w.reshape(-1, group_size)
+    # find the max in each group
+    wmax = torch.abs(w).max(dim=1).values
+    # calculate the scaling factor such that float = quant * scale
+    scale = wmax / 127.0
+    # scale into range [-127, 127]
+    quant = w / scale[:,None]
+    # round to nearest integer
+    int8val = torch.round(quant).to(torch.int8)
+    # dequantize by rescaling
+    fp32val = (int8val.float() * scale[:,None]).view(-1)
+    fp32valr = fp32val.reshape(-1, group_size)
+    # calculate the max error in each group
+    err = torch.abs(fp32valr - w).max(dim=1).values
+    # find the max error across all groups
+    maxerr = err.max().item()
+    return int8val, scale, maxerr
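To make the Q8_0 scheme in `quantize_q80` easier to follow, here is a minimal sketch of the same idea on a single group, written in plain Python without PyTorch. The helper names `quantize_group` and `dequantize_group` are hypothetical, and the fallback for an all-zero group is an added assumption that the function above does not include.

```python
def quantize_group(values):
    """Quantize one group of floats to the int8 range so that value ~= q * scale."""
    wmax = max(abs(v) for v in values)
    # Assumption: fall back to scale 1.0 for an all-zero group to avoid 0/0.
    scale = wmax / 127.0 if wmax > 0 else 1.0
    quant = [round(v / scale) for v in values]
    return quant, scale


def dequantize_group(quant, scale):
    """Recover approximate floats from the int8 values and the group scale."""
    return [q * scale for q in quant]


group = [0.5, -1.0, 0.25, 0.0]
quant, scale = quantize_group(group)
restored = dequantize_group(quant, scale)
# the rounding step bounds the per-element error by half a quantization step
max_err = max(abs(a - b) for a, b in zip(group, restored))
print(quant, scale, max_err)
```

The largest magnitude in the group maps to ±127, and every other element lands within half a step (`scale / 2`) of its original value, which is why smaller groups reduce the impact of outliers.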
+
+# -----------------------------------------------------------------------------
+# legacy
+
+def legacy_export(model, filepath):
+    """ Original export of llama2.c bin files, i.e. version v0 """
+    out_file = open(filepath, 'wb')
+
+    # first write out the header
+    hidden_dim = model.layers[0].feed_forward.w1.weight.shape[0]
+    p = model.params
+    shared_classifier = torch.equal(model.tok_embeddings.weight, model.output.weight)
+    # legacy format uses negative/positive vocab size as a shared classifier flag
+    if not shared_classifier:
+        p.vocab_size = -p.vocab_size
+    n_kv_heads = p.n_heads if p.n_kv_heads is None else p.n_kv_heads
+    header = struct.pack('iiiiiii', p.dim, hidden_dim, p.n_layers, p.n_heads,
+                         n_kv_heads, p.vocab_size, p.max_seq_len)
+    out_file.write(header)
+
+    # next write out the embedding weights
+    serialize_fp32(out_file, model.tok_embeddings.weight)
+
+    # now all the layers
+    # attention weights
+    for layer in model.layers:
+        serialize_fp32(out_file, layer.attention_norm.weight)
+    for layer in model.layers:
+        serialize_fp32(out_file, layer.attention.wq.weight)
+    for layer in model.layers:
+        serialize_fp32(out_file, layer.attention.wk.weight)
+    for layer in model.layers:
+        serialize_fp32(out_file, layer.attention.wv.weight)
+    for layer in model.layers:
+        serialize_fp32(out_file, layer.attention.wo.weight)
+    # ffn weights
+    for layer in model.layers:
+        serialize_fp32(out_file, layer.ffn_norm.weight)
+    for layer in model.layers:
+        serialize_fp32(out_file, layer.feed_forward.w1.weight)
+    for layer in model.layers:
+        serialize_fp32(out_file, layer.feed_forward.w2.weight)
+    for layer in model.layers:
+        serialize_fp32(out_file, layer.feed_forward.w3.weight)
+    # final rmsnorm
+    serialize_fp32(out_file, model.norm.weight)
+    # freqs_cis
+    serialize_fp32(out_file, model.freqs_cos[:p.max_seq_len])
+    serialize_fp32(out_file, model.freqs_sin[:p.max_seq_len])
+
+    # final classifier weights
+    if not shared_classifier:
+        serialize_fp32(out_file, model.output.weight)
+
+    # write to binary file
+    out_file.close()
+    print(f"wrote {filepath}")
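The legacy (v0) header above folds the shared-classifier flag into the sign of `vocab_size`. A small round-trip sketch, using only the standard library, shows how a reader could recover both values. The function names and the returned dictionary layout are hypothetical, chosen only for illustration.

```python
import struct

LEGACY_FORMAT = 'iiiiiii'  # seven native ints, as in legacy_export


def pack_legacy_header(dim, hidden_dim, n_layers, n_heads, n_kv_heads,
                       vocab_size, max_seq_len, shared_classifier):
    """Pack the v0 header; a negative vocab_size marks an unshared classifier."""
    signed_vocab = vocab_size if shared_classifier else -vocab_size
    return struct.pack(LEGACY_FORMAT, dim, hidden_dim, n_layers, n_heads,
                       n_kv_heads, signed_vocab, max_seq_len)


def unpack_legacy_header(raw):
    """Decode the v0 header and split the sign flag from the vocab size."""
    size = struct.calcsize(LEGACY_FORMAT)
    fields = struct.unpack(LEGACY_FORMAT, raw[:size])
    dim, hidden_dim, n_layers, n_heads, n_kv_heads, signed_vocab, max_seq_len = fields
    return {
        'dim': dim, 'hidden_dim': hidden_dim, 'n_layers': n_layers,
        'n_heads': n_heads, 'n_kv_heads': n_kv_heads,
        'vocab_size': abs(signed_vocab), 'max_seq_len': max_seq_len,
        'shared_classifier': signed_vocab > 0,
    }


raw = pack_legacy_header(4096, 11008, 32, 32, 32, 32000, 2048, shared_classifier=False)
header = unpack_legacy_header(raw)
print(header)
```

Overloading the sign this way keeps the header at seven integers, at the cost of an implicit convention, which is presumably part of the motivation for the explicit flag byte in the later versions.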
+
+# -----------------------------------------------------------------------------
+# new version
+
+def version1_export(model, filepath):
+    """
+    Export the model weights in full float32 .bin file to be read from C.
+    This is same as legacy_export, but with a proper header.
+    """
+    version = 1
+
+    out_file = open(filepath, 'wb')
+    # first write out the header. the header will be 256 bytes
+    # 1) write magic, which will be uint32 of "ak42" in ASCII
+    out_file.write(struct.pack('I', 0x616b3432))
+    # 2) write version, which will be int
+    out_file.write(struct.pack('i', version))
+    # 3) write the params, which will be 7 ints
+    p = model.params
+    hidden_dim = model.layers[0].feed_forward.w1.weight.shape[0]
+    n_kv_heads = p.n_heads if p.n_kv_heads is None else p.n_kv_heads
+    header = struct.pack('iiiiiii', p.dim, hidden_dim, p.n_layers, p.n_heads,
+                         n_kv_heads, p.vocab_size, p.max_seq_len)
+    out_file.write(header)
+    # 4) write some other flags
+    shared_classifier = torch.equal(model.tok_embeddings.weight, model.output.weight)
+    out_file.write(struct.pack('B', int(shared_classifier)))
+    pad = 256 - out_file.tell() # pad rest with zeros; tell returns current pos
+    assert pad >= 0
+    out_file.write(b'\0' * pad)
+
+    # now let's write out all the params
+    weights = [
+        *[layer.attention_norm.weight for layer in model.layers],
+        *[layer.ffn_norm.weight for layer in model.layers],
+        model.norm.weight,
+        model.tok_embeddings.weight,
+        *[layer.attention.wq.weight for layer in model.layers],
+        *[layer.attention.wk.weight for layer in model.layers],
+        *[layer.attention.wv.weight for layer in model.layers],
+        *[layer.attention.wo.weight for layer in model.layers],
+        *[layer.feed_forward.w1.weight for layer in model.layers],
+        *[layer.feed_forward.w2.weight for layer in model.layers],
+        *[layer.feed_forward.w3.weight for layer in model.layers],
+    ]
+    if not shared_classifier:
+        weights.append(model.output.weight)
+    for w in weights:
+        serialize_fp32(out_file, w)
+
+    # write to binary file
+    out_file.close()
+    print(f"wrote {filepath}")
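The 256-byte header written by `version1_export` can be checked with a short reader. The sketch below rebuilds the header layout described in the comments above (magic, version, seven ints, one flag byte, zero padding) and parses it back. `build_header` and `parse_header` are hypothetical helper names, and the example parameter values are arbitrary.

```python
import struct

HEADER_SIZE = 256
MAGIC = 0x616b3432  # uint32 spelling "ak42" in ASCII


def build_header(version, params, shared_classifier):
    """Pack magic, version, 7 int params and one flag byte, zero-padded to 256 bytes."""
    head = struct.pack('I', MAGIC)
    head += struct.pack('i', version)
    head += struct.pack('iiiiiii', *params)
    head += struct.pack('B', int(shared_classifier))
    return head + b'\0' * (HEADER_SIZE - len(head))


def parse_header(raw):
    """Inverse of build_header; raises if the magic number does not match."""
    offset = 0
    (magic,) = struct.unpack_from('I', raw, offset)
    offset += struct.calcsize('I')
    if magic != MAGIC:
        raise ValueError(f"bad magic {magic:#x}")
    (version,) = struct.unpack_from('i', raw, offset)
    offset += struct.calcsize('i')
    params = struct.unpack_from('iiiiiii', raw, offset)
    offset += struct.calcsize('iiiiiii')
    (shared,) = struct.unpack_from('B', raw, offset)
    return version, params, bool(shared)


example_params = (288, 768, 6, 6, 6, 32000, 256)  # arbitrary illustrative values
raw_header = build_header(1, example_params, shared_classifier=True)
print(len(raw_header), parse_header(raw_header))
```

Because each field is packed separately, the fields are laid out back to back with no alignment padding between them, so a C reader would need to read them in the same order and with the same widths.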
+
+def version2_export(model, filepath, group_size=64):
+    """
+    Export the model weights in Q8_0 into .bin file to be read from C.
+    That is:
+    - quantize all weights to symmetric int8, in range [-127, 127]
+    - all other tensors (the rmsnorm params) are kept and exported in fp32
+    - quantization is done in groups of group_size to reduce the effects of any outliers
+    """
+    version = 2
+
+    # let's first do some validation for this export type
+    while model.params.dim % group_size != 0:
+        group_size //= 2
+        print(f"BACKOFF: reducing group size to {group_size} to fit hidden_dim")
+    weights = [
+        model.tok_embeddings.weight,
+        *[layer.attention.wq.weight for layer in model.layers],
+        *[layer.attention.wk.weight for layer in model.layers],
+        *[layer.attention.wv.weight for layer in model.layers],
+        *[layer.attention.wo.weight for layer in model.layers],
+        *[layer.feed_forward.w1.weight for layer in model.layers],
+        *[layer.feed_forward.w2.weight for layer in model.layers],
+        *[layer.feed_forward.w3.weight for layer in model.layers],
+    ]
+    shared_classifier = torch.equal(model.tok_embeddings.weight, model.output.weight)
+    if not shared_classifier:
+        weights.append(model.output.weight)
+    for i, w in enumerate(weights):
+        assert w.numel() % group_size == 0, f"weight {i} has numel {w.numel()}, not a multiple of group_size {group_size}"
+
+    # write
+    out_file = open(filepath, 'wb')
+    # first write out the header. the header will be 256 bytes
+    # 1) write magic, which will be uint32 of "ak42" in ASCII
+    out_file.write(struct.pack('I', 0x616b3432))
+    # 2) write version, which will be int
+    out_file.write(struct.pack('i', version))
+    # 3) write the params, which will be 7 ints
+    p = model.params
+    hidden_dim = model.layers[0].feed_forward.w1.weight.shape[0]
+    n_kv_heads = p.n_heads if p.n_kv_heads is None else p.n_kv_heads
+    header = struct.pack('iiiiiii', p.dim, hidden_dim, p.n_layers, p.n_heads,
+                         n_kv_heads, p.vocab_size, p.max_seq_len)
+    out_file.write(header)
+    # 4) write some other flags
+    out_file.write(struct.pack('B', int(shared_classifier)))
+    out_file.write(struct.pack('i', group_size)) # group size used for quantization
+    pad = 256 - out_file.tell() # pad rest with zeros; tell returns current pos
+    assert pad >= 0
+    out_file.write(b'\0' * pad)
+    # now that the header is done, let's write out the model
+
+    # first let's write out all the params that we are keeping in fp32: the norms
+    for layer in model.layers: # attention norms
+        serialize_fp32(out_file, layer.attention_norm.weight)
+    for layer in model.layers: # MLP norms
+        serialize_fp32(out_file, layer.ffn_norm.weight)
+    serialize_fp32(out_file, model.norm.weight) # final pre-classifier norm
+
+    # now let's write out all the params that we are quantizing to Q8_0
+    # note we skip classifier weights, which are shared with the embedding
+    ew = []
+    scales = []
+    for i, w in enumerate(weights):
+        # quantize this weight
+        q, s, err = quantize_q80(w, group_size)
+        # save the int8 weights to file
+        serialize_int8(out_file, q) # save the tensor in int8
+        scales.append(s) # we'll do all the scales after all the qs
+        # logging
+        ew.append((err, w.shape))
+        print(f"{i+1}/{len(weights)} quantized {tuple(w.shape)} to Q8_0 with max error {err}")
+
+    # save the scaling factors in fp32 here
+    # this is done to keep all the weights contiguous, making pointer arithmetic easier in C
+    for s in scales:
+        serialize_fp32(out_file, s)
+
+    # print the highest error across all weights, should be very small, e.g. O(~0.001)
+    ew.sort(reverse=True)
+    print(f"max quantization group error across all weights: {ew[0][0]}")
+
+    # write to binary file
+    out_file.close()
+    print(f"wrote {filepath}")
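Two details of `version2_export` are easy to check in isolation: the group-size backoff and the resulting storage cost. The sketch below is a minimal illustration with hypothetical helper names; it assumes, as in the code above, one int8 per weight plus one fp32 scale per group.

```python
def backoff_group_size(dim, group_size=64):
    """Halve group_size until it divides dim, mirroring the BACKOFF loop."""
    while dim % group_size != 0:
        group_size //= 2
    return group_size


def q80_nbytes(numel, group_size):
    """Bytes needed for one quantized tensor: int8 values plus fp32 group scales."""
    assert numel % group_size == 0
    return numel + (numel // group_size) * 4


dim = 288  # illustrative model dimension that is not a multiple of 64
gs = backoff_group_size(dim)
quantized = q80_nbytes(dim * dim, gs)
full_precision = dim * dim * 4
print(gs, quantized, full_precision)
```

For a 288 x 288 matrix the group size backs off to 32, and the quantized tensor takes a little over a quarter of the float32 size; the scales add 4 bytes per group on top of the 1 byte per weight.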
+
+
+# -----------------------------------------------------------------------------
+# Load / import functions
+
+def load_checkpoint(checkpoint):
+
+    # load the provided model checkpoint
+    checkpoint_dict = torch.load(checkpoint, map_location='cpu')
+    gptconf = ModelArgs(**checkpoint_dict['model_args'])
+    model = Transformer(gptconf)
+    state_dict = checkpoint_dict['model']
+    unwanted_prefix = '_orig_mod.'
+    for k,v in list(state_dict.items()):
+        if k.startswith(unwanted_prefix):
+            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
+    model.load_state_dict(state_dict, strict=False)
+    model.eval()
+    return model
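The prefix handling in `load_checkpoint` is a plain dictionary rewrite, so it can be illustrated without PyTorch. The `_orig_mod.` prefix is, as far as I know, what `torch.compile` adds to parameter names of a wrapped module; the helper name `strip_prefix` and the sample keys are hypothetical.

```python
def strip_prefix(state_dict, prefix='_orig_mod.'):
    """Rename keys in place, dropping the given prefix where present."""
    # iterate over a snapshot of the keys because the dict is mutated in the loop
    for key in list(state_dict.keys()):
        if key.startswith(prefix):
            state_dict[key[len(prefix):]] = state_dict.pop(key)
    return state_dict


checkpoint_keys = {
    '_orig_mod.tok_embeddings.weight': 'E',
    '_orig_mod.layers.0.attention.wq.weight': 'Wq',
    'norm.weight': 'N',  # already clean, left untouched
}
cleaned = strip_prefix(dict(checkpoint_keys))
print(sorted(cleaned))
```

Taking a snapshot with `list(...)` before the loop matters here: mutating a dictionary while iterating over it directly raises a `RuntimeError` in Python.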
+
+def load_meta_model(model_path):
+    params_path = os.path.join(model_path, 'params.json')
+    with open(params_path) as f:
+        params = json.load(f)
+        print(params)
+
+    model_paths = sorted(list(Path(model_path).glob('consolidated.*.pth')))
+    models = [torch.load(p, map_location='cpu') for p in model_paths]
+
+    def concat_weights(models):
+        state_dict = {}
+        for name in list(models[0]):
+            tensors = [model[name] for model in models]
+            if len(tensors) == 1 or len(tensors[0].shape) == 1:
+                state_dict[name] = tensors[0]
+                continue
+            is_axis_1 = (
+                name.startswith('tok_embeddings.')
+                or name.endswith('.attention.wo.weight')
+                or name.endswith('.feed_forward.w2.weight')
+            )
+            axis = 1 if is_axis_1 else 0
+            state_dict[name] = torch.cat(tensors, dim=axis)
+            for model in models:
+                del model[name]
+        return state_dict
+
+    state_dict = concat_weights(models)
+    del models
+
+    # set ModelArgs
+    config = ModelArgs()
+    config.dim = params["dim"]
+    config.n_layers = params["n_layers"]
+    config.n_heads = params["n_heads"]
+    config.n_kv_heads = params.get('n_kv_heads') or params['n_heads']
+    config.multiple_of = params["multiple_of"]
+    config.norm_eps = params["norm_eps"]
+
+    config.vocab_size = 32000
+    config.max_seq_len = 2048
+
+    # create a new Transformer object and set weights
+    model = Transformer(config)
+
+    model.tok_embeddings.weight = nn.Parameter(state_dict['tok_embeddings.weight'])
+    model.norm.weight = nn.Parameter(state_dict['norm.weight'])
+
+    for layer in model.layers:
+        i = layer.layer_id
+        layer.attention_norm.weight = nn.Parameter(state_dict[f'layers.{i}.attention_norm.weight'])
+        layer.attention.wq.weight = nn.Parameter(state_dict[f'layers.{i}.attention.wq.weight'])
+        layer.attention.wk.weight = nn.Parameter(state_dict[f'layers.{i}.attention.wk.weight'])
+        layer.attention.wv.weight = nn.Parameter(state_dict[f'layers.{i}.attention.wv.weight'])
+        layer.attention.wo.weight = nn.Parameter(state_dict[f'layers.{i}.attention.wo.weight'])
+        layer.ffn_norm.weight = nn.Parameter(state_dict[f'layers.{i}.ffn_norm.weight'])
+        layer.feed_forward.w1.weight = nn.Parameter(state_dict[f'layers.{i}.feed_forward.w1.weight'])
+        layer.feed_forward.w2.weight = nn.Parameter(state_dict[f'layers.{i}.feed_forward.w2.weight'])
+        layer.feed_forward.w3.weight = nn.Parameter(state_dict[f'layers.{i}.feed_forward.w3.weight'])
+
+    # final classifier
+    model.output.weight = nn.Parameter(state_dict['output.weight'])
+    model.eval()
+    return model
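The axis rule inside `concat_weights` decides how sharded checkpoints are stitched back together. The following sketch reproduces that rule on nested Python lists instead of tensors, so it runs without PyTorch. `concat_axis` and `concat_shards` are hypothetical names, and the two tiny shards are made-up data.

```python
def concat_axis(name):
    """Return the concatenation axis used for a 2-D weight with this name."""
    is_axis_1 = (
        name.startswith('tok_embeddings.')
        or name.endswith('.attention.wo.weight')
        or name.endswith('.feed_forward.w2.weight')
    )
    return 1 if is_axis_1 else 0


def concat_shards(shards, axis):
    """Concatenate equally shaped 2-D lists along rows (axis 0) or columns (axis 1)."""
    if axis == 0:
        return [row for shard in shards for row in shard]
    return [sum((shard[r] for shard in shards), []) for r in range(len(shards[0]))]


shard_a = [[1, 2], [3, 4]]
shard_b = [[5, 6], [7, 8]]

wq_name = 'layers.0.attention.wq.weight'
wo_name = 'layers.0.attention.wo.weight'
wq = concat_shards([shard_a, shard_b], concat_axis(wq_name))  # stacked by rows
wo = concat_shards([shard_a, shard_b], concat_axis(wo_name))  # joined by columns
print(wq, wo)
```

The split by name reflects how the shards were cut: most projection matrices are sharded along their output rows, while the embedding, `wo`, and `w2` are sharded along columns. One-dimensional tensors such as the norm weights are simply taken from the first shard, as in the function above.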
|
||||
def load_hf_model(model_path):
|
||||
|
||||
try:
|
||||
from transformers import AutoModelForCausalLM
|
||||
except ImportError:
|
||||
print("Error: transformers package is required to load huggingface models")
|
||||
print("Please run `pip install transformers` to install it")
|
||||
return None
|
||||
|
||||
# load HF model
|
||||
hf_model = AutoModelForCausalLM.from_pretrained(model_path)
|
||||
hf_dict = hf_model.state_dict()
|
||||
|
||||
# convert LlamaConfig to ModelArgs
|
||||
config = ModelArgs()
|
||||
config.dim = hf_model.config.hidden_size
|
||||
config.n_layers = hf_model.config.num_hidden_layers
|
||||
    config.n_heads = hf_model.config.num_attention_heads
    config.n_kv_heads = hf_model.config.num_attention_heads
    config.vocab_size = hf_model.config.vocab_size
    config.hidden_dim = hf_model.config.intermediate_size
    config.norm_eps = hf_model.config.rms_norm_eps
    config.max_seq_len = hf_model.config.max_position_embeddings

    # create a new Transformer object and set weights
    model = Transformer(config)

    model.tok_embeddings.weight = nn.Parameter(hf_dict['model.embed_tokens.weight'])
    model.norm.weight = nn.Parameter(hf_dict['model.norm.weight'])

    # huggingface permutes WQ and WK, this function reverses it
    def permute_reverse(w, n_heads=config.n_heads, dim1=config.dim, dim2=config.dim):
        return w.view(n_heads, 2, dim1 // n_heads // 2, dim2).transpose(1, 2).reshape(dim1, dim2)

    for layer in model.layers:
        i = layer.layer_id
        layer.attention_norm.weight = nn.Parameter(hf_dict[f'model.layers.{i}.input_layernorm.weight'])
        layer.attention.wq.weight = nn.Parameter(permute_reverse(hf_dict[f'model.layers.{i}.self_attn.q_proj.weight']))
        layer.attention.wk.weight = nn.Parameter(permute_reverse(hf_dict[f'model.layers.{i}.self_attn.k_proj.weight']))
        layer.attention.wv.weight = nn.Parameter(hf_dict[f'model.layers.{i}.self_attn.v_proj.weight'])
        layer.attention.wo.weight = nn.Parameter(hf_dict[f'model.layers.{i}.self_attn.o_proj.weight'])
        layer.ffn_norm.weight = nn.Parameter(hf_dict[f'model.layers.{i}.post_attention_layernorm.weight'])
        layer.feed_forward.w1.weight = nn.Parameter(hf_dict[f'model.layers.{i}.mlp.gate_proj.weight'])
        layer.feed_forward.w2.weight = nn.Parameter(hf_dict[f'model.layers.{i}.mlp.down_proj.weight'])
        layer.feed_forward.w3.weight = nn.Parameter(hf_dict[f'model.layers.{i}.mlp.up_proj.weight'])

    # final classifier
    model.output.weight = nn.Parameter(hf_dict['lm_head.weight'])
    model.eval()
    return model
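
To see what `permute_reverse` undoes, the row shuffle can be replayed on plain row indices. This is a minimal sketch in pure Python rather than PyTorch. The forward helper `hf_permute_rows` is an assumption about how the Hugging Face conversion reorders each head's rows (interleaved pairs become two contiguous halves), and both function names are hypothetical.

```python
def hf_permute_rows(rows, n_heads):
    """Assumed forward shuffle: per head, even-offset rows first, then odd-offset rows."""
    hs = len(rows) // n_heads  # rows per head
    out = []
    for h in range(n_heads):
        head = rows[h * hs:(h + 1) * hs]
        out.extend(head[0::2] + head[1::2])
    return out


def permute_reverse_rows(rows, n_heads):
    """Mirror of permute_reverse: re-interleave the two halves of each head."""
    hs = len(rows) // n_heads
    half = hs // 2
    out = []
    for h in range(n_heads):
        head = rows[h * hs:(h + 1) * hs]
        for a, b in zip(head[:half], head[half:]):
            out.extend([a, b])
    return out


rows = list(range(8))                       # 2 heads, 4 rows each
shuffled = hf_permute_rows(rows, n_heads=2)
restored = permute_reverse_rows(shuffled, n_heads=2)
```

Under these assumptions the reverse step restores the original row order, which is why only `wq` and `wk` (the weights touched by rotary embeddings) go through `permute_reverse`.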

# -----------------------------------------------------------------------------
# API entrypoint

def model_export(model, filepath, version):
    if version == 0:
        legacy_export(model, filepath)
    elif version == 1:
        version1_export(model, filepath)
    elif version == 2:
        version2_export(model, filepath)
    else:
        raise ValueError(f"unknown version {version}")

def torchscript_export(model, filepath, zero_params=False, gzip_output=False):
    """
    (This was submitted via a PR earlier. Leaving it here, but "orphaned" for now)
    Saves the model as a TorchScript.
    The resulting file can be loaded in C++ code and then used for training or
    inference with:
        #include <torch/script.h>
        torch::jit::Module module = torch::jit::load("model.pt");
    Note that the serialized model includes the initial parameters and with the default
    ModelArgs the file is 59M and gzips down to 55M. If you want to serialize/distribute
    the model parameters separately you can zero out the parameters before saving it and
    it will gzip down to 780K.
    """

    # If requested zero params before saving the model. This is useful in
    # conjunction with gzip_output.
    if zero_params:
        for p in model.parameters():
            p.detach().zero_()

    torch.jit.save(torch.jit.script(model), filepath)

    if gzip_output:
        with open(filepath, "rb") as f_in:
            with gzip.open(f"{filepath}.gz", "wb") as f_out:
                shutil.copyfileobj(f_in, f_out)
        os.unlink(filepath)

# -----------------------------------------------------------------------------
# CLI entrypoint

if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument("filepath", type=str, help="the output filepath")
    parser.add_argument("--version", default=0, type=int, help="the version to export with")
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument("--checkpoint", type=str, help="model checkpoint, .pt file")
    group.add_argument("--meta-llama", type=str, help="meta llama model path")
    group.add_argument("--hf", type=str, help="huggingface model path")
    args = parser.parse_args()

    if args.checkpoint:
        model = load_checkpoint(args.checkpoint)
    elif args.meta_llama:
        model = load_meta_model(args.meta_llama)
    elif args.hf:
        model = load_hf_model(args.hf)

    if model is None:
        parser.error("Can't load input model!")

    # export
    model_export(model, args.filepath, args.version)
@@ -1,112 +0,0 @@
|
||||
"""
|
||||
This script exports the Llama 2 weights in llama2c.bin format.
|
||||
"""
|
||||
import os
|
||||
import sys
|
||||
import struct
|
||||
from pathlib import Path
|
||||
import json
|
||||
|
||||
import torch
|
||||
|
||||
from model import precompute_freqs_cis
|
||||
|
||||
|
||||
def export(p, state_dict, filepath='model.bin'):
|
||||
"""export the model weights in fp32 into .bin file to be read from C"""
|
||||
f = open(filepath, 'wb')
|
||||
|
||||
def serialize(key):
|
||||
print(f"writing {key}...")
|
||||
t = state_dict[key].contiguous().view(-1).type(torch.float32).numpy()
|
||||
f.write(memoryview(t))
|
||||
del state_dict[key]
|
||||
|
||||
# first write out the header
|
||||
hidden_dim = state_dict['layers.0.feed_forward.w1.weight'].shape[0]
|
||||
p['vocab_size'] = 32000
|
||||
p['max_seq_len'] = 2048
|
||||
|
||||
n_kv_heads = p.get('n_kv_heads') or p['n_heads']
|
||||
header = struct.pack(
|
||||
'iiiiiii',
|
||||
p['dim'], hidden_dim, p['n_layers'], p['n_heads'],
|
||||
n_kv_heads, -p['vocab_size'], p['max_seq_len']
|
||||
)
|
||||
# NOTE ABOVE: -ve vocab_size is indicating that the classifier weights are present
|
||||
# in the checkpoint and should be loaded.
|
||||
f.write(header)
|
||||
|
||||
# next write out the embedding weights
|
||||
print("writing tok_embeddings...")
|
||||
serialize('tok_embeddings.weight')
|
||||
|
||||
# now all the layers
|
||||
# attention weights
|
||||
for i in range(p['n_layers']): serialize(f'layers.{i}.attention_norm.weight')
|
||||
for i in range(p['n_layers']): serialize(f'layers.{i}.attention.wq.weight')
|
||||
for i in range(p['n_layers']): serialize(f'layers.{i}.attention.wk.weight')
|
||||
for i in range(p['n_layers']): serialize(f'layers.{i}.attention.wv.weight')
|
||||
for i in range(p['n_layers']): serialize(f'layers.{i}.attention.wo.weight')
|
||||
# ffn weights
|
||||
for i in range(p['n_layers']): serialize(f'layers.{i}.ffn_norm.weight')
|
||||
for i in range(p['n_layers']): serialize(f'layers.{i}.feed_forward.w1.weight')
|
||||
for i in range(p['n_layers']): serialize(f'layers.{i}.feed_forward.w2.weight')
|
||||
for i in range(p['n_layers']): serialize(f'layers.{i}.feed_forward.w3.weight')
|
||||
|
||||
# final rmsnorm
|
||||
serialize('norm.weight')
|
||||
# freqs_cos, freqs_sin
|
||||
freqs_cos, freqs_sin = precompute_freqs_cis(p['dim'] // p['n_heads'], p['max_seq_len'] * 2)
|
||||
state_dict['freqs_cos'] = freqs_cos[:p['max_seq_len']]
|
||||
state_dict['freqs_sin'] = freqs_sin[:p['max_seq_len']]
|
||||
serialize('freqs_cos')
|
||||
serialize('freqs_sin')
|
||||
|
||||
# finally write the output weights
|
||||
serialize('output.weight')
|
||||
|
||||
f.close()
|
||||
print(f"wrote {filepath}")
|
||||
|
||||
|
||||
def concat_weights(models):
|
||||
state_dict = {}
|
||||
for name in list(models[0]):
|
||||
tensors = [model[name] for model in models]
|
||||
if len(tensors) == 1 or len(tensors[0].shape) == 1:
|
||||
state_dict[name] = tensors[0]
|
||||
continue
|
||||
is_axis_1 = (
|
||||
name.startswith('tok_embeddings.')
|
||||
or name.endswith('.attention.wo.weight')
|
||||
or name.endswith('.feed_forward.w2.weight')
|
||||
)
|
||||
axis = 1 if is_axis_1 else 0
|
||||
state_dict[name] = torch.cat(tensors, dim=axis)
|
||||
for model in models:
|
||||
del model[name]
|
||||
return state_dict
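
The axis rule in `concat_weights` can be illustrated without PyTorch. The sketch below uses nested lists as stand-in matrices; `concat_axis_for` and `concat` are hypothetical helpers, and the reading that the axis-1 tensors are the ones sharded along their input dimension is an assumption, not something the script states.

```python
def concat_axis_for(name):
    """Same name test as concat_weights: these three tensor families join along axis 1."""
    is_axis_1 = (
        name.startswith('tok_embeddings.')
        or name.endswith('.attention.wo.weight')
        or name.endswith('.feed_forward.w2.weight')
    )
    return 1 if is_axis_1 else 0


def concat(shards, axis):
    """Concatenate list-of-lists matrices along rows (axis 0) or columns (axis 1)."""
    if axis == 0:
        return [row for shard in shards for row in shard]
    n_rows = len(shards[0])
    return [[x for shard in shards for x in shard[r]] for r in range(n_rows)]


shard_a = [[1, 2]]
shard_b = [[3, 4]]
stacked = concat([shard_a, shard_b], axis=0)   # two rows
widened = concat([shard_a, shard_b], axis=1)   # one longer row
```

One-dimensional tensors such as the norm weights are not concatenated at all: the script keeps the first shard's copy.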
|
||||
|
||||
|
||||
def load_and_export(model_path, output_path):
|
||||
params_path = os.path.join(model_path, 'params.json')
|
||||
with open(params_path) as f:
|
||||
params = json.load(f)
|
||||
print(params)
|
||||
|
||||
model_paths = sorted(list(Path(model_path).glob('consolidated.*.pth')))
|
||||
models = [torch.load(p, map_location='cpu') for p in model_paths]
|
||||
state_dict = concat_weights(models)
|
||||
del models
|
||||
export(params, state_dict, output_path)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if len(sys.argv) == 1:
|
||||
print('[Llama model folder path] [output path]')
|
||||
exit()
|
||||
|
||||
model_path = sys.argv[1]
|
||||
output_path = sys.argv[2]
|
||||
load_and_export(model_path, output_path)
|
||||
@@ -1,113 +0,0 @@
|
||||
"""
|
||||
This script exports the Llama 2 weights in llama2c.bin format.
|
||||
"""
|
||||
import os
|
||||
import sys
|
||||
import struct
|
||||
from pathlib import Path
|
||||
import json
|
||||
|
||||
import torch
|
||||
|
||||
from model import precompute_freqs_cis
|
||||
|
||||
|
||||
def export(p, state_dict, filepath='model.bin'):
|
||||
"""export the model weights in fp32 into .bin file to be read from C"""
|
||||
f = open(filepath, 'wb')
|
||||
|
||||
def serialize(key):
|
||||
print(f"writing {key}...")
|
||||
t = state_dict[key].contiguous().view(-1).type(torch.float32).numpy()
|
||||
f.write(memoryview(t))
|
||||
del state_dict[key]
|
||||
|
||||
# first write out the header
|
||||
hidden_dim = state_dict['model.layers.0.mlp.gate_proj.weight'].shape[0]
|
||||
p['vocab_size'] = 32000
|
||||
p['max_seq_len'] = 2048
|
||||
|
||||
n_kv_heads = p.get('n_kv_heads') or p['n_heads']
|
||||
header = struct.pack(
|
||||
'iiiiiii',
|
||||
p['dim'], hidden_dim, p['n_layers'], p['n_heads'],
|
||||
n_kv_heads, -p['vocab_size'], p['max_seq_len']
|
||||
)
|
||||
# NOTE ABOVE: -ve vocab_size is indicating that the classifier weights are present
|
||||
# in the checkpoint and should be loaded.
|
||||
f.write(header)
|
||||
|
||||
# next write out the embedding weights
|
||||
print("writing tok_embeddings...")
|
||||
serialize('model.embed_tokens.weight')
|
||||
|
||||
# now all the layers
|
||||
# attention weights
|
||||
for i in range(p['n_layers']): serialize(f'model.layers.{i}.input_layernorm.weight')
|
||||
for i in range(p['n_layers']): serialize(f'model.layers.{i}.self_attn.q_proj.weight')
|
||||
for i in range(p['n_layers']): serialize(f'model.layers.{i}.self_attn.k_proj.weight')
|
||||
for i in range(p['n_layers']): serialize(f'model.layers.{i}.self_attn.v_proj.weight')
|
||||
for i in range(p['n_layers']): serialize(f'model.layers.{i}.self_attn.o_proj.weight')
|
||||
# ffn weights
|
||||
for i in range(p['n_layers']): serialize(f'model.layers.{i}.post_attention_layernorm.weight')
|
||||
for i in range(p['n_layers']): serialize(f'model.layers.{i}.mlp.gate_proj.weight')
|
||||
for i in range(p['n_layers']): serialize(f'model.layers.{i}.mlp.down_proj.weight')
|
||||
for i in range(p['n_layers']): serialize(f'model.layers.{i}.mlp.up_proj.weight')
|
||||
|
||||
# final rmsnorm
|
||||
serialize('model.norm.weight')
|
||||
# freqs_cos, freqs_sin
|
||||
freqs_cos, freqs_sin = precompute_freqs_cis(p['dim'] // p['n_heads'], p['max_seq_len'] * 2)
|
||||
state_dict['freqs_cos'] = freqs_cos[:p['max_seq_len']]
|
||||
state_dict['freqs_sin'] = freqs_sin[:p['max_seq_len']]
|
||||
# check if this requires additional conversion
|
||||
serialize('freqs_cos')
|
||||
serialize('freqs_sin')
|
||||
|
||||
# finally write the output weights
|
||||
serialize('lm_head.weight')
|
||||
|
||||
f.close()
|
||||
print(f"wrote {filepath}")
|
||||
|
||||
|
||||
def concat_weights(models):
|
||||
state_dict = {}
|
||||
for name in list(models[0]):
|
||||
tensors = [model[name] for model in models]
|
||||
if len(tensors) == 1 or len(tensors[0].shape) == 1:
|
||||
state_dict[name] = tensors[0]
|
||||
continue
|
||||
is_axis_1 = (
|
||||
name.startswith('model.embed_tokens.weight')
|
||||
or name.endswith('.self_attn.o_proj.weight')
|
||||
or name.endswith('.mlp.down_proj.weight')
|
||||
)
|
||||
axis = 1 if is_axis_1 else 0
|
||||
state_dict[name] = torch.cat(tensors, dim=axis)
|
||||
for model in models:
|
||||
del model[name]
|
||||
return state_dict
|
||||
|
||||
|
||||
def load_and_export(model_path, output_path):
|
||||
params_path = os.path.join(model_path, 'params.json')
|
||||
with open(params_path) as f:
|
||||
params = json.load(f)
|
||||
print(params)
|
||||
|
||||
model_paths = sorted(list(Path(model_path).glob('consolidated.*.pth')))
|
||||
models = [torch.load(p, map_location='cpu') for p in model_paths]
|
||||
state_dict = concat_weights(models)
|
||||
del models
|
||||
export(params, state_dict, output_path)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if len(sys.argv) == 1:
|
||||
print('[Llama model folder path] [output path]')
|
||||
exit()
|
||||
|
||||
model_path = sys.argv[1]
|
||||
output_path = sys.argv[2]
|
||||
load_and_export(model_path, output_path)
|
||||
@@ -17,6 +17,7 @@ class ModelArgs:
     n_heads: int = 32
     n_kv_heads: Optional[int] = None
     vocab_size: int = 32000
+    hidden_dim: Optional[int] = None
     multiple_of: int = 256  # MLP hidden layer size will be multiple of
     norm_eps: float = 1e-5
     max_seq_len: int = 2048
@@ -166,8 +167,10 @@ class Attention(nn.Module):
 class FeedForward(nn.Module):
     def __init__(self, dim: int, hidden_dim: int, multiple_of: int, dropout: float):
         super().__init__()
-        hidden_dim = int(2 * hidden_dim / 3)
-        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
+        if hidden_dim is None:
+            hidden_dim = 4 * dim
+            hidden_dim = int(2 * hidden_dim / 3)
+            hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
         self.w1 = nn.Linear(dim, hidden_dim, bias=False)
         self.w2 = nn.Linear(hidden_dim, dim, bias=False)
         self.w3 = nn.Linear(dim, hidden_dim, bias=False)
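
The effect of the new `hidden_dim` default can be checked with plain arithmetic. The helper below is a hypothetical standalone copy of the sizing rule. It assumes that the 2/3 scaling and the rounding up to `multiple_of` apply only when no explicit `hidden_dim` is passed, so a value taken from a loaded checkpoint is used unchanged.

```python
def ffn_hidden_dim(dim, hidden_dim=None, multiple_of=256):
    """Hypothetical standalone version of the FeedForward sizing rule."""
    if hidden_dim is None:
        hidden_dim = 4 * dim
        hidden_dim = int(2 * hidden_dim / 3)
        # round up to the next multiple of `multiple_of`
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
    return hidden_dim


default_7b = ffn_hidden_dim(4096)                     # derived from dim
explicit = ffn_hidden_dim(4096, hidden_dim=11008)     # taken as given
small = ffn_hidden_dim(288, multiple_of=32)
```

With `dim=4096` and `multiple_of=256` the derived size is 11008, so passing `hidden_dim=None` reproduces the previous behaviour for that configuration.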
@@ -186,7 +189,7 @@ class TransformerBlock(nn.Module):
         self.attention = Attention(args)
         self.feed_forward = FeedForward(
             dim=args.dim,
-            hidden_dim=4 * args.dim,
+            hidden_dim=args.hidden_dim,
             multiple_of=args.multiple_of,
             dropout=args.dropout,
         )
@@ -338,55 +341,3 @@ class Transformer(nn.Module):
|
||||
idx = torch.cat((idx, idx_next), dim=1)
|
||||
|
||||
return idx
|
||||
|
||||
def export(self, filepath='model.bin'):
|
||||
"""export the model weights in fp32 into .bin file to be read from C"""
|
||||
f = open(filepath, 'wb')
|
||||
|
||||
def serialize(t):
|
||||
d = t.detach().cpu().view(-1).numpy().astype(np.float32)
|
||||
b = struct.pack(f'{len(d)}f', *d)
|
||||
f.write(b)
|
||||
|
||||
# first write out the header
|
||||
hidden_dim = self.layers[0].feed_forward.w1.weight.shape[0]
|
||||
p = self.params
|
||||
n_kv_heads = p.n_heads if p.n_kv_heads is None else p.n_kv_heads
|
||||
header = struct.pack('iiiiiii', p.dim, hidden_dim, p.n_layers, p.n_heads,
|
||||
n_kv_heads, p.vocab_size, p.max_seq_len)
|
||||
f.write(header)
|
||||
|
||||
# next write out the embedding weights
|
||||
serialize(self.tok_embeddings.weight)
|
||||
|
||||
# now all the layers
|
||||
# attention weights
|
||||
for layer in self.layers:
|
||||
serialize(layer.attention_norm.weight)
|
||||
for layer in self.layers:
|
||||
serialize(layer.attention.wq.weight)
|
||||
for layer in self.layers:
|
||||
serialize(layer.attention.wk.weight)
|
||||
for layer in self.layers:
|
||||
serialize(layer.attention.wv.weight)
|
||||
for layer in self.layers:
|
||||
serialize(layer.attention.wo.weight)
|
||||
# ffn weights
|
||||
for layer in self.layers:
|
||||
serialize(layer.ffn_norm.weight)
|
||||
for layer in self.layers:
|
||||
serialize(layer.feed_forward.w1.weight)
|
||||
for layer in self.layers:
|
||||
serialize(layer.feed_forward.w2.weight)
|
||||
for layer in self.layers:
|
||||
serialize(layer.feed_forward.w3.weight)
|
||||
# final rmsnorm
|
||||
serialize(self.norm.weight)
|
||||
# note: no need to write final classifier weights due to weight sharing
|
||||
# freqs_cis
|
||||
serialize(self.freqs_cos[:p.max_seq_len])
|
||||
serialize(self.freqs_sin[:p.max_seq_len])
|
||||
|
||||
# write to binary file
|
||||
f.close()
|
||||
print(f"wrote {filepath}")
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
#include <sys/mman.h>
|
||||
#endif
|
||||
// ----------------------------------------------------------------------------
|
||||
// Transformer and RunState structs, and related memory management
|
||||
// Transformer model
|
||||
|
||||
typedef struct {
|
||||
int dim; // transformer dimension
|
||||
@@ -43,18 +43,10 @@ typedef struct {
|
||||
float* w3; // (layer, hidden_dim, dim)
|
||||
// final rmsnorm
|
||||
float* rms_final_weight; // (dim,)
|
||||
// freq_cis for RoPE relatively positional embeddings (not used anymore)
|
||||
float* freq_cis_real; // (seq_len, head_size/2)
|
||||
float* freq_cis_imag; // (seq_len, head_size/2)
|
||||
// (optional) classifier weights for the logits, on the last layer
|
||||
float* wcls;
|
||||
} TransformerWeights;
|
||||
|
||||
typedef struct {
|
||||
float prob;
|
||||
int index;
|
||||
} ProbIndex; // struct used when sorting probabilities during top-p sampling
|
||||
|
||||
typedef struct {
|
||||
// current wave of activations
|
||||
float *x; // activation at current time stamp (dim,)
|
||||
@@ -67,12 +59,21 @@ typedef struct {
|
||||
float *v; // value (dim,)
|
||||
float *att; // buffer for scores/attention values (n_heads, seq_len)
|
||||
float *logits; // output logits
|
||||
ProbIndex *probindex; // buffer used in top-p sampling
|
||||
// kv cache
|
||||
float* key_cache; // (layer, seq_len, dim)
|
||||
float* value_cache; // (layer, seq_len, dim)
|
||||
} RunState;
|
||||
|
||||
typedef struct {
|
||||
Config config; // the hyperparameters of the architecture (the blueprint)
|
||||
TransformerWeights weights; // the weights of the model
|
||||
RunState state; // buffers for the "wave" of activations in the forward pass
|
||||
// some more state needed to properly clean up the memory mapping (sigh)
|
||||
int fd; // file descriptor for memory mapping
|
||||
float* data; // memory mapped data pointer
|
||||
ssize_t file_size; // size of the checkpoint file in bytes
|
||||
} Transformer;
|
||||
|
||||
void malloc_run_state(RunState* s, Config* p) {
|
||||
// we calloc instead of malloc to keep valgrind happy
|
||||
int kv_dim = (p->dim * p->n_kv_heads) / p->n_heads;
|
||||
@@ -86,13 +87,12 @@ void malloc_run_state(RunState* s, Config* p) {
|
||||
s->v = calloc(kv_dim, sizeof(float));
|
||||
s->att = calloc(p->n_heads * p->seq_len, sizeof(float));
|
||||
s->logits = calloc(p->vocab_size, sizeof(float));
|
||||
s->probindex = calloc(p->vocab_size, sizeof(ProbIndex));
|
||||
s->key_cache = calloc(p->n_layers * p->seq_len * kv_dim, sizeof(float));
|
||||
s->value_cache = calloc(p->n_layers * p->seq_len * kv_dim, sizeof(float));
|
||||
// ensure all mallocs went fine
|
||||
if (!s->x || !s->xb || !s->xb2 || !s->hb || !s->hb2 || !s->q
|
||||
|| !s->k || !s->v || !s->att || !s->logits || !s->key_cache
|
||||
|| !s->value_cache || !s->probindex) {
|
||||
|| !s->value_cache) {
|
||||
fprintf(stderr, "malloc failed!\n");
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
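
The `kv_dim` expression in `malloc_run_state` determines how large the key and value caches get. As a rough sketch, written in Python for brevity rather than C, the element count of one cache can be computed as follows; `kv_cache_floats` is a hypothetical helper and the example sizes are illustrative, not taken from the diff.

```python
def kv_cache_floats(dim, n_heads, n_kv_heads, n_layers, seq_len):
    """Number of floats in key_cache (value_cache has the same size)."""
    kv_dim = (dim * n_kv_heads) // n_heads  # same integer arithmetic as the C code
    return n_layers * seq_len * kv_dim


# Illustrative configuration: 32 layers, dim 4096, 32 query heads, 2048 positions.
full = kv_cache_floats(4096, 32, 32, 32, 2048)     # one KV head per query head
grouped = kv_cache_floats(4096, 32, 8, 32, 2048)   # 8 shared KV heads
full_bytes = full * 4                              # float32
```

Reducing `n_kv_heads` shrinks both caches proportionally, which is why the buffers are sized with `kv_dim` instead of `dim`.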
@@ -109,15 +109,11 @@ void free_run_state(RunState* s) {
|
||||
free(s->v);
|
||||
free(s->att);
|
||||
free(s->logits);
|
||||
free(s->probindex);
|
||||
free(s->key_cache);
|
||||
free(s->value_cache);
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// initialization: read from checkpoint
|
||||
|
||||
void checkpoint_init_weights(TransformerWeights *w, Config* p, float* ptr, int shared_weights) {
|
||||
void memory_map_weights(TransformerWeights *w, Config* p, float* ptr, int shared_weights) {
|
||||
int head_size = p->dim / p->n_heads;
|
||||
w->token_embedding_table = ptr;
|
||||
ptr += p->vocab_size * p->dim;
|
||||
@@ -141,15 +137,50 @@ void checkpoint_init_weights(TransformerWeights *w, Config* p, float* ptr, int s
|
||||
ptr += p->n_layers * p->dim * p->hidden_dim;
|
||||
w->rms_final_weight = ptr;
|
||||
ptr += p->dim;
|
||||
w->freq_cis_real = ptr;
|
||||
ptr += p->seq_len * head_size / 2;
|
||||
w->freq_cis_imag = ptr;
|
||||
ptr += p->seq_len * head_size / 2;
|
||||
ptr += p->seq_len * head_size / 2; // skip what used to be freq_cis_real (for RoPE)
|
||||
ptr += p->seq_len * head_size / 2; // skip what used to be freq_cis_imag (for RoPE)
|
||||
w->wcls = shared_weights ? w->token_embedding_table : ptr;
|
||||
}
|
||||
|
||||
void read_checkpoint(char* checkpoint, Config* config, TransformerWeights* weights,
|
||||
int* fd, float** data, ssize_t* file_size) {
|
||||
FILE *file = fopen(checkpoint, "rb");
|
||||
if (!file) { fprintf(stderr, "Couldn't open file %s\n", checkpoint); exit(EXIT_FAILURE); }
|
||||
// read in the config header
|
||||
if (fread(config, sizeof(Config), 1, file) != 1) { exit(EXIT_FAILURE); }
|
||||
// negative vocab size is hacky way of signaling unshared weights. bit yikes.
|
||||
int shared_weights = config->vocab_size > 0 ? 1 : 0;
|
||||
config->vocab_size = abs(config->vocab_size);
|
||||
// figure out the file size
|
||||
fseek(file, 0, SEEK_END); // move file pointer to end of file
|
||||
*file_size = ftell(file); // get the file size, in bytes
|
||||
fclose(file);
|
||||
// memory map the Transformer weights into the data pointer
|
||||
*fd = open(checkpoint, O_RDONLY); // open in read only mode
|
||||
if (*fd == -1) { fprintf(stderr, "open failed!\n"); exit(EXIT_FAILURE); }
|
||||
*data = mmap(NULL, *file_size, PROT_READ, MAP_PRIVATE, *fd, 0);
|
||||
if (*data == MAP_FAILED) { fprintf(stderr, "mmap failed!\n"); exit(EXIT_FAILURE); }
|
||||
float* weights_ptr = *data + sizeof(Config)/sizeof(float);
|
||||
memory_map_weights(weights, config, weights_ptr, shared_weights);
|
||||
}
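
The header convention that `read_checkpoint` relies on can be sketched from the exporter's side. The snippet below is written in Python for brevity and assumes the seven-`int` layout used by the `struct.pack('iiiiiii', ...)` calls in the export scripts; `parse_header` is a hypothetical name.

```python
import struct

HEADER_FMT = "iiiiiii"  # dim, hidden_dim, n_layers, n_heads, n_kv_heads, vocab_size, seq_len


def parse_header(buf):
    """Decode the config header and the shared-weights flag hidden in the sign of vocab_size."""
    dim, hidden_dim, n_layers, n_heads, n_kv_heads, vocab_size, seq_len = \
        struct.unpack_from(HEADER_FMT, buf)
    shared_weights = vocab_size > 0   # a negative value signals a separate classifier
    config = {
        "dim": dim, "hidden_dim": hidden_dim, "n_layers": n_layers,
        "n_heads": n_heads, "n_kv_heads": n_kv_heads,
        "vocab_size": abs(vocab_size), "seq_len": seq_len,
    }
    return config, shared_weights


# Illustrative values only.
header = struct.pack(HEADER_FMT, 288, 768, 6, 6, 6, -32000, 256)
config, shared = parse_header(header)
header_size = struct.calcsize(HEADER_FMT)
```

The weights start immediately after these bytes, which corresponds to the `sizeof(Config)/sizeof(float)` offset applied to the mapped pointer.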
|
||||
|
||||
void build_transformer(Transformer *t, char* checkpoint_path) {
|
||||
// read in the Config and the Weights from the checkpoint
|
||||
read_checkpoint(checkpoint_path, &t->config, &t->weights, &t->fd, &t->data, &t->file_size);
|
||||
// allocate the RunState buffers
|
||||
malloc_run_state(&t->state, &t->config);
|
||||
}
|
||||
|
||||
void free_transformer(Transformer* t) {
|
||||
// close the memory mapping
|
||||
if (t->data != MAP_FAILED) { munmap(t->data, t->file_size); }
|
||||
if (t->fd != -1) { close(t->fd); }
|
||||
// free the RunState buffers
|
||||
free_run_state(&t->state);
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// neural net blocks
|
||||
// neural net blocks; the dynamics of the Transformer
|
||||
|
||||
void rmsnorm(float* o, float* x, float* weight, int size) {
|
||||
// calculate sum of squares
|
||||
@@ -200,9 +231,12 @@ void matmul(float* xout, float* x, float* w, int n, int d) {
|
||||
}
|
||||
}
|
||||
|
||||
void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights* w) {
|
||||
float* forward(Transformer* transformer, int token, int pos) {
|
||||
|
||||
// a few convenience variables
|
||||
Config* p = &transformer->config;
|
||||
TransformerWeights* w = &transformer->weights;
|
||||
RunState* s = &transformer->state;
|
||||
float *x = s->x;
|
||||
int dim = p->dim;
|
||||
int kv_dim = (p->dim * p->n_kv_heads) / p->n_heads;
|
||||
@@ -211,7 +245,7 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights*
|
||||
int head_size = dim / p->n_heads;
|
||||
|
||||
// copy the token embedding into x
|
||||
float* content_row = &(w->token_embedding_table[token * dim]);
|
||||
float* content_row = w->token_embedding_table + token * dim;
|
||||
memcpy(x, content_row, dim*sizeof(*x));
|
||||
|
||||
// forward all the layers
|
||||
@@ -305,14 +339,14 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights*
|
||||
matmul(s->hb, s->xb, w->w1 + l*dim*hidden_dim, dim, hidden_dim);
|
||||
matmul(s->hb2, s->xb, w->w3 + l*dim*hidden_dim, dim, hidden_dim);
|
||||
|
||||
// F.silu; silu(x)=x*σ(x),where σ(x) is the logistic sigmoid
|
||||
// SwiGLU non-linearity
|
||||
for (int i = 0; i < hidden_dim; i++) {
|
||||
s->hb[i] = s->hb[i] * (1.0f / (1.0f + expf(-s->hb[i])));
|
||||
}
|
||||
|
||||
// elementwise multiply with w3(x)
|
||||
for (int i = 0; i < hidden_dim; i++) {
|
||||
s->hb[i] = s->hb[i] * s->hb2[i];
|
||||
float val = s->hb[i];
|
||||
// silu(x)=x*σ(x), where σ(x) is the logistic sigmoid
|
||||
val *= (1.0f / (1.0f + expf(-val)));
|
||||
// elementwise multiply with w3(x)
|
||||
val *= s->hb2[i];
|
||||
s->hb[i] = val;
|
||||
}
|
||||
|
||||
// final matmul to get the output of the ffn
|
||||
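
The fused loop computes `silu(w1(x)) * w3(x)` one element at a time. A minimal Python sketch of the same arithmetic, with the hypothetical name `swiglu` and plain lists standing in for the `hb` and `hb2` buffers:

```python
import math


def swiglu(hb, hb2):
    """Elementwise silu(hb[i]) * hb2[i], matching the single fused loop."""
    out = []
    for a, b in zip(hb, hb2):
        val = a * (1.0 / (1.0 + math.exp(-a)))  # silu(a) = a * sigmoid(a)
        val *= b                                # gate with w3(x)
        out.append(val)
    return out


gated = swiglu([0.0, 1.0], [5.0, 2.0])
```

Merging the two loops does not change the result. It only avoids a second pass over the `hidden_dim` elements.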
@@ -329,20 +363,90 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights*
|
||||
|
||||
// classifier into logits
|
||||
matmul(s->logits, x, w->wcls, p->dim, p->vocab_size);
|
||||
return s->logits;
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// byte pair encoding (BPE) tokenizer, encodes strings into tokens so we can prompt
|
||||
// The Byte Pair Encoding (BPE) Tokenizer that translates strings <-> tokens
|
||||
|
||||
typedef struct {
|
||||
char *str;
|
||||
int id;
|
||||
} TokenIndex;
|
||||
|
||||
typedef struct {
|
||||
char** vocab;
|
||||
float* vocab_scores;
|
||||
TokenIndex *sorted_vocab;
|
||||
int vocab_size;
|
||||
unsigned int max_token_length;
|
||||
unsigned char byte_pieces[512]; // stores all single-byte strings
|
||||
} Tokenizer;
|
||||
|
||||
int compare_tokens(const void *a, const void *b) {
|
||||
return strcmp(((TokenIndex*)a)->str, ((TokenIndex*)b)->str);
|
||||
}
|
||||
|
||||
void build_tokenizer(Tokenizer* t, char* tokenizer_path, int vocab_size) {
|
||||
// i should have written the vocab_size into the tokenizer file... sigh
|
||||
t->vocab_size = vocab_size;
|
||||
// malloc space to hold the scores and the strings
|
||||
t->vocab = (char**)malloc(vocab_size * sizeof(char*));
|
||||
t->vocab_scores = (float*)malloc(vocab_size * sizeof(float));
|
||||
t->sorted_vocab = NULL; // initialized lazily
|
||||
for (int i = 0; i < 256; i++) {
|
||||
t->byte_pieces[i * 2] = (unsigned char)i;
|
||||
t->byte_pieces[i * 2 + 1] = '\0';
|
||||
}
|
||||
// read in the file
|
||||
FILE *file = fopen(tokenizer_path, "rb");
|
||||
if (!file) { fprintf(stderr, "couldn't load %s\n", tokenizer_path); exit(EXIT_FAILURE); }
|
||||
if (fread(&t->max_token_length, sizeof(int), 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); }
|
||||
int len;
|
||||
for (int i = 0; i < vocab_size; i++) {
|
||||
if (fread(t->vocab_scores + i, sizeof(float), 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE);}
|
||||
if (fread(&len, sizeof(int), 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); }
|
||||
t->vocab[i] = (char *)malloc(len + 1);
|
||||
if (fread(t->vocab[i], len, 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); }
|
||||
t->vocab[i][len] = '\0'; // add the string terminating token
|
||||
}
|
||||
fclose(file);
|
||||
}
|
||||
|
||||
void free_tokenizer(Tokenizer* t) {
|
||||
for (int i = 0; i < t->vocab_size; i++) { free(t->vocab[i]); }
|
||||
free(t->vocab);
|
||||
free(t->vocab_scores);
|
||||
free(t->sorted_vocab);
|
||||
}
|
||||
|
||||
char* decode(Tokenizer* t, int prev_token, int token) {
|
||||
char *piece = t->vocab[token];
|
||||
// following BOS (1) token, sentencepiece decoder strips any leading whitespace (see PR #89)
|
||||
if (prev_token == 1 && piece[0] == ' ') { piece++; }
|
||||
// careful, some tokens designate raw bytes, and look like e.g. '<0x01>'
|
||||
// parse this and convert and return the actual byte
|
||||
unsigned char byte_val;
|
||||
if (sscanf(piece, "<0x%02hhX>", &byte_val) == 1) {
|
||||
piece = (char*)t->byte_pieces + byte_val * 2;
|
||||
}
|
||||
return piece;
|
||||
}
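
The raw-byte handling in `decode` can be sketched in Python as follows. The function name `decode_piece` and the toy vocabulary are hypothetical. The regular expression is also stricter than the `sscanf` pattern, since it requires exactly two hex digits and the closing `>`.

```python
import re

_BYTE_TOKEN = re.compile(r"<0x([0-9A-Fa-f]{2})>")


def decode_piece(vocab, prev_token, token, bos_id=1):
    """Return the bytes to emit for `token`, given the previous token."""
    piece = vocab[token]
    # after BOS, a single leading space is stripped
    if prev_token == bos_id and piece.startswith(" "):
        piece = piece[1:]
    # tokens such as '<0x0A>' stand for one raw byte
    m = _BYTE_TOKEN.fullmatch(piece)
    if m:
        return bytes([int(m.group(1), 16)])
    return piece.encode("utf-8")


toy_vocab = ["<unk>", "<s>", "</s>", "<0x0A>", " hello"]
after_bos = decode_piece(toy_vocab, 1, 4)
mid_text = decode_piece(toy_vocab, 4, 4)
newline = decode_piece(toy_vocab, 4, 3)
```

In the C code the single-byte strings live in the `byte_pieces` table, so `decode` can return a pointer without allocating.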
|
||||
|
||||
void safe_printf(char *piece) {
|
||||
// piece might be a raw byte token, and we only want to print printable chars or whitespace
|
||||
// because some of the other bytes can be various control codes, backspace, etc.
|
||||
if (piece == NULL) { return; }
|
||||
if (piece[0] == '\0') { return; }
|
||||
if (piece[1] == '\0') {
|
||||
unsigned char byte_val = piece[0];
|
||||
if (!(isprint(byte_val) || isspace(byte_val))) {
|
||||
return; // bad byte, don't print it
|
||||
}
|
||||
}
|
||||
printf("%s", piece);
|
||||
}
|
||||
|
||||
int str_lookup(char *str, TokenIndex *sorted_vocab, int vocab_size) {
|
||||
// efficiently find the perfect match for str in vocab, return its index or -1 if not found
|
||||
TokenIndex tok = { .str = str }; // acts as the key to search for
|
||||
@@ -350,23 +454,40 @@ int str_lookup(char *str, TokenIndex *sorted_vocab, int vocab_size) {
|
||||
return res != NULL ? res->id : -1;
|
||||
}
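
The lookup strategy, sorting the vocabulary once and then binary-searching it, can be sketched in Python with `bisect`. The helper names are hypothetical, and encoding to UTF-8 bytes is an assumption made so that ordering resembles the byte-wise comparison of `strcmp`.

```python
import bisect


def build_sorted_vocab(vocab):
    """Pair each token string (as bytes) with its id and sort by the string."""
    return sorted((s.encode("utf-8"), i) for i, s in enumerate(vocab))


def str_lookup(s, sorted_vocab):
    """Return the id of an exact match for `s`, or -1 if it is not in the vocabulary."""
    key = s.encode("utf-8")
    pos = bisect.bisect_left(sorted_vocab, (key,))
    if pos < len(sorted_vocab) and sorted_vocab[pos][0] == key:
        return sorted_vocab[pos][1]
    return -1


toy_vocab = ["<unk>", "<s>", "</s>", " ", "a", "b", "ab"]
sorted_vocab = build_sorted_vocab(toy_vocab)
```

Building the sorted table costs one sort of the vocabulary, so caching it on the tokenizer and reusing it across calls avoids repeating that work for every prompt.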
|
||||
|
||||
void bpe_encode(char *text, char **vocab, float *vocab_scores, int vocab_size, unsigned int max_token_length, int *tokens, int *n_tokens) {
|
||||
void encode(Tokenizer* t, char *text, int8_t bos, int8_t eos, int *tokens, int *n_tokens) {
|
||||
// encode the string text (input) into an upper-bound preallocated tokens[] array
|
||||
// bos != 0 means prepend the BOS token (=1), eos != 0 means append the EOS token (=2)
|
||||
if (text == NULL) { fprintf(stderr, "cannot encode NULL text\n"); exit(EXIT_FAILURE); }
|
||||
|
||||
// sort vocabulary
|
||||
TokenIndex *sorted_vocab = malloc(vocab_size * sizeof(TokenIndex));
|
||||
for (int i = 0; i < vocab_size; i++) {
|
||||
sorted_vocab[i].str = vocab[i];
|
||||
sorted_vocab[i].id = i;
|
||||
if (t->sorted_vocab == NULL) {
|
||||
// lazily malloc and sort the vocabulary
|
||||
t->sorted_vocab = malloc(t->vocab_size * sizeof(TokenIndex));
|
||||
for (int i = 0; i < t->vocab_size; i++) {
|
||||
t->sorted_vocab[i].str = t->vocab[i];
|
||||
t->sorted_vocab[i].id = i;
|
||||
}
|
||||
qsort(t->sorted_vocab, t->vocab_size, sizeof(TokenIndex), compare_tokens);
|
||||
}
|
||||
qsort(sorted_vocab, vocab_size, sizeof(TokenIndex), compare_tokens);
|
||||
|
||||
// create a temporary buffer that will store merge candidates of always two consecutive tokens
|
||||
char* str_buffer = malloc((max_token_length*2 +1 +2) * sizeof(char)); // *2 for concat, +1 for null terminator +2 for UTF8 (in case max_token_lenght is 1)
|
||||
// *2 for concat, +1 for null terminator +2 for UTF8 (in case max_token_length is 1)
|
||||
char* str_buffer = malloc((t->max_token_length*2 +1 +2) * sizeof(char));
|
||||
size_t str_len = 0;
|
||||
|
||||
// start at 0 tokens
|
||||
*n_tokens = 0;
|
||||
|
||||
// add optional BOS (=1) token, if desired
|
||||
if (bos) tokens[(*n_tokens)++] = 1;
|
||||
|
||||
// add_dummy_prefix is true by default
|
||||
tokens[0] = str_lookup(" ", sorted_vocab, vocab_size);
|
||||
*n_tokens = 1; // the number of tokens
|
||||
// so prepend a dummy prefix token to the input string, but only if text != ""
|
||||
// TODO: pretty sure this isn't correct in the general case but I don't have the
|
||||
// energy to read more of the sentencepiece code to figure out what it's doing
|
||||
if (text[0] != '\0') {
|
||||
int dummy_prefix = str_lookup(" ", t->sorted_vocab, t->vocab_size);
|
||||
tokens[(*n_tokens)++] = dummy_prefix;
|
||||
}
|
||||
|
||||
// Okay UTF-8 time. This will get messy. Here is the reference from Wikipedia:
|
||||
// Code point ↔ UTF-8 conversion
|
||||
@@ -401,7 +522,7 @@ void bpe_encode(char *text, char **vocab, float *vocab_scores, int vocab_size, u
|
||||
}
|
||||
|
||||
// ok c+1 is not a continuation byte, so we've read in a full codepoint
|
||||
int id = str_lookup(str_buffer, sorted_vocab, vocab_size);
|
||||
int id = str_lookup(str_buffer, t->sorted_vocab, t->vocab_size);
|
||||
|
||||
if (id != -1) {
|
||||
// we found this codepoint in vocab, add it as a token
|
||||
@@ -425,11 +546,11 @@ void bpe_encode(char *text, char **vocab, float *vocab_scores, int vocab_size, u
|
||||
|
||||
for (int i=0; i < (*n_tokens-1); i++) {
|
||||
// check if we can merge the pair (tokens[i], tokens[i+1])
|
||||
sprintf(str_buffer, "%s%s", vocab[tokens[i]], vocab[tokens[i+1]]);
|
||||
int id = str_lookup(str_buffer, sorted_vocab, vocab_size);
|
||||
if (id != -1 && vocab_scores[id] > best_score) {
|
||||
sprintf(str_buffer, "%s%s", t->vocab[tokens[i]], t->vocab[tokens[i+1]]);
|
||||
int id = str_lookup(str_buffer, t->sorted_vocab, t->vocab_size);
|
||||
if (id != -1 && t->vocab_scores[id] > best_score) {
|
||||
// this merge pair exists in vocab! record its score and position
|
||||
best_score = vocab_scores[id];
|
||||
best_score = t->vocab_scores[id];
|
||||
best_id = id;
|
||||
best_idx = i;
|
||||
}
|
||||
@@ -448,36 +569,30 @@ void bpe_encode(char *text, char **vocab, float *vocab_scores, int vocab_size, u
|
||||
(*n_tokens)--; // token length decreased
|
||||
}
|
||||
|
||||
// add optional EOS (=2) token, if desired
|
||||
if (eos) tokens[(*n_tokens)++] = 2;
|
||||
|
||||
free(str_buffer);
|
||||
free(sorted_vocab);
|
||||
}
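
The merge phase of `encode` repeatedly joins the adjacent pair whose concatenation has the best vocabulary score. A compact Python sketch of that loop follows. `bpe_merge` is a hypothetical name, a dictionary replaces the sorted-vocabulary search, and the starting score of `-1e10` is an assumption because the initialization falls outside the lines shown in the diff.

```python
def bpe_merge(tokens, vocab, vocab_scores):
    """Greedily merge adjacent token pairs until no merged string is in the vocabulary."""
    lookup = {s: i for i, s in enumerate(vocab)}
    tokens = list(tokens)
    while True:
        best_score, best_id, best_idx = -1e10, -1, -1
        for i in range(len(tokens) - 1):
            merged = vocab[tokens[i]] + vocab[tokens[i + 1]]
            idx = lookup.get(merged, -1)
            if idx != -1 and vocab_scores[idx] > best_score:
                # this pair exists in the vocabulary: remember the best one seen so far
                best_score, best_id, best_idx = vocab_scores[idx], idx, i
        if best_idx == -1:
            break  # nothing left to merge
        tokens[best_idx:best_idx + 2] = [best_id]  # token count shrinks by one
    return tokens


toy_vocab = ["a", "b", "c", "ab", "abc", "bc"]
toy_scores = [0.0, 0.0, 0.0, -1.0, -2.0, -0.5]
merged_tokens = bpe_merge([0, 1, 2], toy_vocab, toy_scores)
```

In the toy run, `"bc"` is merged first because its score is higher than that of `"ab"`, and the remaining pair then merges into `"abc"`.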
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// utilities: time / rng
|
||||
|
||||
long time_in_ms() {
|
||||
// return time in milliseconds, for benchmarking the model speed
|
||||
struct timespec time;
|
||||
clock_gettime(CLOCK_REALTIME, &time);
|
||||
return time.tv_sec * 1000 + time.tv_nsec / 1000000;
|
||||
}
|
||||
|
||||
unsigned long long rng_seed;
|
||||
unsigned int random_u32() {
|
||||
// xorshift rng: https://en.wikipedia.org/wiki/Xorshift#xorshift.2A
|
||||
rng_seed ^= rng_seed >> 12;
|
||||
rng_seed ^= rng_seed << 25;
|
||||
rng_seed ^= rng_seed >> 27;
|
||||
return (rng_seed * 0x2545F4914F6CDD1Dull) >> 32;
|
||||
}
|
||||
float random_f32() { // random float32 in [0,1)
|
||||
return (random_u32() >> 8) / 16777216.0f;
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// The Sampler, which takes logits and returns a sampled token
|
||||
// sampling can be done in a few ways: greedy argmax, sampling, top-p sampling
|
||||
|
||||
int argmax(float* probabilities, int n) {
|
||||
typedef struct {
|
||||
float prob;
|
||||
int index;
|
||||
} ProbIndex; // struct used when sorting probabilities during top-p sampling
|
||||
|
||||
typedef struct {
|
||||
int vocab_size;
|
||||
ProbIndex* probindex; // buffer used in top-p sampling
|
||||
float temperature;
|
||||
float topp;
|
||||
unsigned long long rng_state;
|
||||
} Sampler;
|
||||
|
||||
int sample_argmax(float* probabilities, int n) {
|
||||
// return the index that has the highest probability
|
||||
int max_i = 0;
|
||||
float max_p = probabilities[0];
|
||||
@@ -490,13 +605,13 @@ int argmax(float* probabilities, int n) {
|
||||
return max_i;
|
||||
}
|
||||
|
||||
int sample(float* probabilities, int n) {
|
||||
int sample_mult(float* probabilities, int n, float coin) {
|
||||
// sample index from probabilities (they must sum to 1!)
|
||||
float r = random_f32();
|
||||
// coin is a random number in [0, 1), usually from random_f32()
|
||||
float cdf = 0.0f;
|
||||
for (int i = 0; i < n; i++) {
|
||||
cdf += probabilities[i];
|
||||
if (r < cdf) {
|
||||
if (coin < cdf) {
|
||||
return i;
|
||||
}
|
||||
}
|
||||
@@ -511,10 +626,11 @@ int compare(const void* a, const void* b) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
int sample_topp(float* probabilities, int n, float topp, ProbIndex* probindex) {
|
||||
int sample_topp(float* probabilities, int n, float topp, ProbIndex* probindex, float coin) {
|
||||
// top-p sampling (or "nucleus sampling") samples from the smallest set of
|
||||
// tokens that exceed probability topp. This way we never sample tokens that
|
||||
// have very low probabilities and are less likely to go "off the rails".
|
||||
// coin is a random number in [0, 1), usually from random_f32()
|
||||
|
||||
int n0 = 0;
|
||||
// quicksort indices in descending order of probabilities
|
||||
@@ -542,7 +658,7 @@ int sample_topp(float* probabilities, int n, float topp, ProbIndex* probindex) {
|
||||
}
|
||||
|
||||
// sample from the truncated list
|
||||
float r = random_f32() * cumulative_prob;
|
||||
float r = coin * cumulative_prob;
|
||||
float cdf = 0.0f;
|
||||
for (int i = 0; i <= last_idx; i++) {
|
||||
cdf += probindex[i].prob;
|
||||
@@ -553,167 +669,107 @@ int sample_topp(float* probabilities, int n, float topp, ProbIndex* probindex) {
|
||||
return probindex[last_idx].index; // in case of rounding errors
|
||||
}
|
||||
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// int main
|
||||
|
||||
void error_usage() {
|
||||
fprintf(stderr, "Usage: run <checkpoint> [options]\n");
|
||||
fprintf(stderr, "Example: run model.bin -n 256 -i \"Once upon a time\"\n");
|
||||
fprintf(stderr, "Options:\n");
|
||||
fprintf(stderr, " -t <float> temperature, default 1.0\n");
|
||||
fprintf(stderr, " -p <float> p value in top-p (nucleus) sampling. default 0.9\n");
|
||||
fprintf(stderr, " -s <int> random seed, default time(NULL)\n");
|
||||
fprintf(stderr, " -n <int> number of steps to run for, default 256. 0 = max_seq_len\n");
|
||||
fprintf(stderr, " -i <string> input prompt\n");
|
||||
fprintf(stderr, " -z <string> optional path to custom tokenizer\n");
|
||||
exit(EXIT_FAILURE);
|
||||
void build_sampler(Sampler* sampler, int vocab_size, float temperature, float topp, unsigned long long rng_seed) {
|
||||
sampler->vocab_size = vocab_size;
|
||||
sampler->temperature = temperature;
|
||||
sampler->topp = topp;
|
||||
sampler->rng_state = rng_seed;
|
||||
// buffer only used with nucleus sampling; may not need but it's ~small
|
||||
sampler->probindex = malloc(sampler->vocab_size * sizeof(ProbIndex));
|
||||
}
|
||||
|
||||
int main(int argc, char *argv[]) {
|
||||
void free_sampler(Sampler* sampler) {
|
||||
free(sampler->probindex);
|
||||
}
|
||||
|
||||
// default inits
|
||||
char *checkpoint = NULL; // e.g. out/model.bin
|
||||
char *tokenizer = "tokenizer.bin";
|
||||
float temperature = 1.0f; // 0.0 = greedy deterministic. 1.0 = original. don't set higher
|
||||
float topp = 0.9f; // top-p in nucleus sampling. 1.0 = off. 0.9 works well, but slower
|
||||
rng_seed = 0; // seed rng with time by default
|
||||
int steps = 256; // number of steps to run for
|
||||
char *prompt = NULL; // prompt string
|
||||
unsigned int random_u32(unsigned long long *state) {
|
||||
// xorshift rng: https://en.wikipedia.org/wiki/Xorshift#xorshift.2A
|
||||
*state ^= *state >> 12;
|
||||
*state ^= *state << 25;
|
||||
*state ^= *state >> 27;
|
||||
return (*state * 0x2545F4914F6CDD1Dull) >> 32;
|
||||
}
|
||||
float random_f32(unsigned long long *state) { // random float32 in [0,1)
|
||||
return (random_u32(state) >> 8) / 16777216.0f;
|
||||
}
|
||||
|
||||
// poor man's C argparse so we can override the defaults above from the command line
|
||||
if (argc >= 2) { checkpoint = argv[1]; } else { error_usage(); }
|
||||
for (int i = 2; i < argc; i+=2) {
|
||||
// do some basic validation
|
||||
if (i + 1 >= argc) { error_usage(); } // must have arg after flag
|
||||
if (argv[i][0] != '-') { error_usage(); } // must start with dash
|
||||
if (strlen(argv[i]) != 2) { error_usage(); } // must be -x (one dash, one letter)
|
||||
// read in the args
|
||||
if (argv[i][1] == 't') { temperature = atof(argv[i + 1]); }
|
||||
else if (argv[i][1] == 'p') { topp = atof(argv[i + 1]); }
|
||||
else if (argv[i][1] == 's') { rng_seed = atoi(argv[i + 1]); }
|
||||
else if (argv[i][1] == 'n') { steps = atoi(argv[i + 1]); }
|
||||
else if (argv[i][1] == 'i') { prompt = argv[i + 1]; }
|
||||
else if (argv[i][1] == 'z') { tokenizer = argv[i + 1]; }
|
||||
else { error_usage(); }
|
||||
}
|
||||
if(rng_seed == 0) { rng_seed = (unsigned int)time(NULL);}
|
||||
|
||||
// read in the model.bin file
|
||||
Config config;
|
||||
TransformerWeights weights;
|
||||
int fd = 0; // file descriptor for memory mapping
|
||||
float* data = NULL; // memory mapped data pointer
|
||||
ssize_t file_size; // size of the checkpoint file in bytes
|
||||
{
|
||||
FILE *file = fopen(checkpoint, "rb");
|
||||
if (!file) { fprintf(stderr, "Couldn't open file %s\n", checkpoint); return 1; }
|
||||
// read in the config header
|
||||
if (fread(&config, sizeof(Config), 1, file) != 1) { return 1; }
|
||||
// negative vocab size is hacky way of signaling unshared weights. bit yikes.
|
||||
int shared_weights = config.vocab_size > 0 ? 1 : 0;
|
||||
config.vocab_size = abs(config.vocab_size);
|
||||
// figure out the file size
|
||||
fseek(file, 0, SEEK_END); // move file pointer to end of file
|
||||
file_size = ftell(file); // get the file size, in bytes
|
||||
fclose(file);
|
||||
// memory map the Transformer weights into the data pointer
|
||||
fd = open(checkpoint, O_RDONLY); // open in read only mode
|
||||
if (fd == -1) { fprintf(stderr, "open failed!\n"); return 1; }
|
||||
data = mmap(NULL, file_size, PROT_READ, MAP_PRIVATE, fd, 0);
|
||||
if (data == MAP_FAILED) { fprintf(stderr, "mmap failed!\n"); return 1; }
|
||||
float* weights_ptr = data + sizeof(Config)/sizeof(float);
|
||||
checkpoint_init_weights(&weights, &config, weights_ptr, shared_weights);
|
||||
}
|
||||
// right now we cannot run for more than config.seq_len steps
|
||||
if (steps <= 0 || steps > config.seq_len) { steps = config.seq_len; }
|
||||
|
||||
// read in the tokenizer .bin file
|
||||
char** vocab = (char**)malloc(config.vocab_size * sizeof(char*));
|
||||
float* vocab_scores = (float*)malloc(config.vocab_size * sizeof(float));
|
||||
unsigned int max_token_length;
|
||||
{
|
||||
FILE *file = fopen(tokenizer, "rb");
|
||||
if (!file) { fprintf(stderr, "couldn't load %s\n", tokenizer); return 1; }
|
||||
if (fread(&max_token_length, sizeof(int), 1, file) != 1) { fprintf(stderr, "failed read\n"); return 1; }
|
||||
int len;
|
||||
for (int i = 0; i < config.vocab_size; i++) {
|
||||
if (fread(vocab_scores + i, sizeof(float), 1, file) != 1) { fprintf(stderr, "failed read\n"); return 1;}
|
||||
if (fread(&len, sizeof(int), 1, file) != 1) { fprintf(stderr, "failed read\n"); return 1; }
|
||||
vocab[i] = (char *)malloc(len + 1);
|
||||
if (fread(vocab[i], len, 1, file) != 1) { fprintf(stderr, "failed read\n"); return 1; }
|
||||
vocab[i][len] = '\0'; // add the string terminating token
|
||||
int sample(Sampler* sampler, float* logits) {
|
||||
// sample the token given the logits and some hyperparameters
|
||||
int next;
|
||||
if (sampler->temperature == 0.0f) {
|
||||
// greedy argmax sampling: take the token with the highest probability
|
||||
next = sample_argmax(logits, sampler->vocab_size);
|
||||
} else {
|
||||
// apply the temperature to the logits
|
||||
for (int q=0; q<sampler->vocab_size; q++) { logits[q] /= sampler->temperature; }
|
||||
// apply softmax to the logits to get the probabilities for next token
|
||||
softmax(logits, sampler->vocab_size);
|
||||
// flip a (float) coin (this is our source of entropy for sampling)
|
||||
float coin = random_f32(&sampler->rng_state);
|
||||
// we sample from this distribution to get the next token
|
||||
if (sampler->topp <= 0 || sampler->topp >= 1) {
|
||||
// simply sample from the predicted probability distribution
|
||||
next = sample_mult(logits, sampler->vocab_size, coin);
|
||||
} else {
|
||||
// top-p (nucleus) sampling, clamping the least likely tokens to zero
|
||||
next = sample_topp(logits, sampler->vocab_size, sampler->topp, sampler->probindex, coin);
|
||||
}
|
||||
fclose(file);
|
||||
}
|
||||
return next;
|
||||
}
|
||||
|
||||
// create and init the application RunState
|
||||
RunState state;
|
||||
malloc_run_state(&state, &config);
|
||||
// ----------------------------------------------------------------------------
|
||||
// utilities: time
|
||||
|
||||
// process the prompt, if any
|
||||
int *prompt_tokens = NULL;
|
||||
long time_in_ms() {
|
||||
// return time in milliseconds, for benchmarking the model speed
|
||||
struct timespec time;
|
||||
clock_gettime(CLOCK_REALTIME, &time);
|
||||
return time.tv_sec * 1000 + time.tv_nsec / 1000000;
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// generation loop
|
||||
|
||||
void generate(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler, char *prompt, int steps) {
|
||||
char *empty_prompt = "";
|
||||
if (prompt == NULL) { prompt = empty_prompt; }
|
||||
|
||||
// encode the (string) prompt into a token sequence
|
||||
int num_prompt_tokens = 0;
|
||||
if (prompt != NULL) {
|
||||
prompt_tokens = (int*)malloc((strlen(prompt)+1) * sizeof(int));
|
||||
bpe_encode(prompt, vocab, vocab_scores, config.vocab_size, max_token_length, prompt_tokens, &num_prompt_tokens);
|
||||
int* prompt_tokens = (int*)malloc((strlen(prompt)+3) * sizeof(int)); // +3 for '\0', ?BOS, ?EOS
|
||||
encode(tokenizer, prompt, 1, 0, prompt_tokens, &num_prompt_tokens);
|
||||
if (num_prompt_tokens < 1) {
|
||||
fprintf(stderr, "something is wrong, expected at least 1 prompt token\n");
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
// start the main loop
|
||||
long start = 0; // used to time our code, only initialized after first iteration
|
||||
int next; // will store the next token in the sequence
|
||||
int token = 1; // init with token 1 (=BOS), as done in Llama-2 sentencepiece tokenizer
|
||||
int token = prompt_tokens[0]; // kick off with the first token in the prompt
|
||||
int pos = 0; // position in the sequence
|
||||
while (pos < steps) {
|
||||
|
||||
// forward the transformer to get logits for the next token
|
||||
transformer(token, pos, &config, &state, &weights);
|
||||
float* logits = forward(transformer, token, pos);
|
||||
|
||||
// advance the state machine
|
||||
if(pos < num_prompt_tokens) {
|
||||
if (pos < num_prompt_tokens - 1) {
|
||||
// if we are still processing the input prompt, force the next prompt token
|
||||
next = prompt_tokens[pos];
|
||||
next = prompt_tokens[pos + 1];
|
||||
} else {
|
||||
// sample the next token
|
||||
if (temperature == 0.0f) {
|
||||
// greedy argmax sampling: take the token with the highest probability
|
||||
next = argmax(state.logits, config.vocab_size);
|
||||
} else {
|
||||
// apply the temperature to the logits
|
||||
for (int q=0; q<config.vocab_size; q++) { state.logits[q] /= temperature; }
|
||||
// apply softmax to the logits to get the probabilities for next token
|
||||
softmax(state.logits, config.vocab_size);
|
||||
// we sample from this distribution to get the next token
|
||||
if (topp <= 0 || topp >= 1) {
|
||||
// simply sample from the predicted probability distribution
|
||||
next = sample(state.logits, config.vocab_size);
|
||||
} else {
|
||||
// top-p (nucleus) sampling, clamping the least likely tokens to zero
|
||||
next = sample_topp(state.logits, config.vocab_size, topp, state.probindex);
|
||||
}
|
||||
}
|
||||
// otherwise sample the next token from the logits
|
||||
next = sample(sampler, logits);
|
||||
}
|
||||
pos++;
|
||||
|
||||
// data-dependent terminating condition: the BOS (1) token delimits sequences
|
||||
// data-dependent terminating condition: the BOS (=1) token delimits sequences
|
||||
if (next == 1) { break; }
|
||||
|
||||
// following BOS (1) token, sentencepiece decoder strips any leading whitespace (see PR #89)
|
||||
char *token_str = (token == 1 && vocab[next][0] == ' ') ? vocab[next]+1 : vocab[next];
|
||||
// careful, some tokens designate raw bytes, and look like e.g. '<0x01>'
|
||||
unsigned char byte_val;
|
||||
if (sscanf(token_str, "<0x%02hhX>", &byte_val) == 1) {
|
||||
// ok this token is a raw byte token, careful to only print printable chars or whitespace
|
||||
// some of the other bytes can be various control codes, backspace, etc. => skip
|
||||
if (isprint(byte_val) || isspace(byte_val)) {
|
||||
char byte_piece[2];
|
||||
byte_piece[0] = byte_val;
|
||||
byte_piece[1] = '\0';
|
||||
printf("%s", byte_piece);
|
||||
}
|
||||
} else {
|
||||
printf("%s", token_str);
|
||||
}
|
||||
// print the token as string, decode it with the Tokenizer object
|
||||
char* piece = decode(tokenizer, token, next);
|
||||
safe_printf(piece); // same as printf("%s", piece), but skips "unsafe" bytes
|
||||
fflush(stdout);
|
||||
token = next;
|
||||
|
||||
@@ -728,13 +784,195 @@ int main(int argc, char *argv[]) {
|
||||
fprintf(stderr, "achieved tok/s: %f\n", (pos-1) / (double)(end-start)*1000);
|
||||
}
|
||||
|
||||
free(prompt_tokens);
|
||||
}
|
||||
|
||||
void read_stdin(const char* guide, char* buffer, size_t bufsize) {
|
||||
// read a line from stdin, up to but not including \n
|
||||
printf("%s", guide);
|
||||
if (fgets(buffer, bufsize, stdin) != NULL) {
|
||||
size_t len = strlen(buffer);
|
||||
if (len > 0 && buffer[len - 1] == '\n') {
|
||||
buffer[len - 1] = '\0'; // strip newline
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// chat loop
|
||||
// I manually inspected the tokens for a few chat conversations compared to
|
||||
// python reference and that seemed ok, but this was not thoroughly tested and
|
||||
// is not safely implemented, it's more a proof of concept atm.
|
||||
|
||||
void chat(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler,
|
||||
char *cli_user_prompt, char *cli_system_prompt, int steps) {
|
||||
|
||||
// buffers for reading the system prompt and user prompt from stdin
|
||||
// you'll notice they are somewhat haphazardly and unsafely set atm
|
||||
char system_prompt[512];
|
||||
char user_prompt[512];
|
||||
char rendered_prompt[1152];
|
||||
int num_prompt_tokens = 0;
|
||||
int* prompt_tokens = (int*)malloc(1152 * sizeof(int));
|
||||
int user_idx;
|
||||
|
||||
// start the main loop
|
||||
int8_t user_turn = 1; // user starts
|
||||
int next; // will store the next token in the sequence
|
||||
int token; // stores the current token to feed into the transformer
|
||||
int prev_token;
|
||||
int pos = 0; // position in the sequence
|
||||
while (pos < steps) {
|
||||
|
||||
// when it is the user's turn to contribute tokens to the dialog...
|
||||
if (user_turn) {
|
||||
// get the (optional) system prompt at position 0
|
||||
if (pos == 0) {
|
||||
// at position 0, the user can also contribute a system prompt
|
||||
if (cli_system_prompt == NULL) {
|
||||
// system prompt was not passed in, attempt to get it from stdin
|
||||
read_stdin("Enter system prompt (optional): ", system_prompt, sizeof(system_prompt));
|
||||
} else {
|
||||
// system prompt was passed in, use it
|
||||
strcpy(system_prompt, cli_system_prompt);
|
||||
}
|
||||
}
|
||||
// get the user prompt
|
||||
if (pos == 0 && cli_user_prompt != NULL) {
|
||||
// user prompt for position 0 was passed in, use it
|
||||
strcpy(user_prompt, cli_user_prompt);
|
||||
} else {
|
||||
// otherwise get user prompt from stdin
|
||||
read_stdin("User: ", user_prompt, sizeof(user_prompt));
|
||||
}
|
||||
// render user/system prompts into the Llama 2 Chat schema
|
||||
if (pos == 0 && system_prompt[0] != '\0') {
|
||||
char system_template[] = "[INST] <<SYS>>\n%s\n<</SYS>>\n\n%s [/INST]";
|
||||
sprintf(rendered_prompt, system_template, system_prompt, user_prompt);
|
||||
} else {
|
||||
char user_template[] = "[INST] %s [/INST]";
|
||||
sprintf(rendered_prompt, user_template, user_prompt);
|
||||
}
|
||||
// encode the rendered prompt into tokens
|
||||
encode(tokenizer, rendered_prompt, 1, 0, prompt_tokens, &num_prompt_tokens);
|
||||
user_idx = 0; // reset the user index
|
||||
user_turn = 0;
|
||||
printf("Assistant: ");
|
||||
}
|
||||
|
||||
// determine the token to pass into the transformer next
|
||||
if (user_idx < num_prompt_tokens) {
|
||||
// if we are still processing the input prompt, force the next prompt token
|
||||
token = prompt_tokens[user_idx++];
|
||||
} else {
|
||||
// otherwise use the next token sampled from previous turn
|
||||
token = next;
|
||||
}
|
||||
// EOS (=2) token ends the Assistant turn
|
||||
if (token == 2) { user_turn = 1; }
|
||||
|
||||
// forward the transformer to get logits for the next token
|
||||
float* logits = forward(transformer, token, pos);
|
||||
next = sample(sampler, logits);
|
||||
pos++;
|
||||
|
||||
if (user_idx >= num_prompt_tokens && next != 2) {
|
||||
// the Assistant is responding, so print its output
|
||||
char* piece = decode(tokenizer, token, next);
|
||||
safe_printf(piece); // same as printf("%s", piece), but skips "unsafe" bytes
|
||||
fflush(stdout);
|
||||
}
|
||||
if (next == 2) { printf("\n"); }
|
||||
}
|
||||
printf("\n");
|
||||
free(prompt_tokens);
|
||||
}
|
||||
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// CLI, include only if not testing
|
||||
#ifndef TESTING
|
||||
|
||||
void error_usage() {
|
||||
fprintf(stderr, "Usage: run <checkpoint> [options]\n");
|
||||
fprintf(stderr, "Example: run model.bin -n 256 -i \"Once upon a time\"\n");
|
||||
fprintf(stderr, "Options:\n");
|
||||
fprintf(stderr, " -t <float> temperature in [0,inf], default 1.0\n");
|
||||
fprintf(stderr, " -p <float> p value in top-p (nucleus) sampling in [0,1] default 0.9\n");
|
||||
fprintf(stderr, " -s <int> random seed, default time(NULL)\n");
|
||||
fprintf(stderr, " -n <int> number of steps to run for, default 256. 0 = max_seq_len\n");
|
||||
fprintf(stderr, " -i <string> input prompt\n");
|
||||
fprintf(stderr, " -z <string> optional path to custom tokenizer\n");
|
||||
fprintf(stderr, " -m <string> mode: generate|chat, default: generate\n");
|
||||
fprintf(stderr, " -y <string> (optional) system prompt in chat mode\n");
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
int main(int argc, char *argv[]) {
|
||||
|
||||
// default parameters
|
||||
char *checkpoint_path = NULL; // e.g. out/model.bin
|
||||
char *tokenizer_path = "tokenizer.bin";
|
||||
float temperature = 1.0f; // 0.0 = greedy deterministic. 1.0 = original. don't set higher
|
||||
float topp = 0.9f; // top-p in nucleus sampling. 1.0 = off. 0.9 works well, but slower
|
||||
int steps = 256; // number of steps to run for
|
||||
char *prompt = NULL; // prompt string
|
||||
unsigned long long rng_seed = 0; // seed rng with time by default
|
||||
char *mode = "generate"; // generate|chat
|
||||
char *system_prompt = NULL; // the (optional) system prompt to use in chat mode
|
||||
|
||||
// poor man's C argparse so we can override the defaults above from the command line
|
||||
if (argc >= 2) { checkpoint_path = argv[1]; } else { error_usage(); }
|
||||
for (int i = 2; i < argc; i+=2) {
|
||||
// do some basic validation
|
||||
if (i + 1 >= argc) { error_usage(); } // must have arg after flag
|
||||
if (argv[i][0] != '-') { error_usage(); } // must start with dash
|
||||
if (strlen(argv[i]) != 2) { error_usage(); } // must be -x (one dash, one letter)
|
||||
// read in the args
|
||||
if (argv[i][1] == 't') { temperature = atof(argv[i + 1]); }
|
||||
else if (argv[i][1] == 'p') { topp = atof(argv[i + 1]); }
|
||||
else if (argv[i][1] == 's') { rng_seed = atoi(argv[i + 1]); }
|
||||
else if (argv[i][1] == 'n') { steps = atoi(argv[i + 1]); }
|
||||
else if (argv[i][1] == 'i') { prompt = argv[i + 1]; }
|
||||
else if (argv[i][1] == 'z') { tokenizer_path = argv[i + 1]; }
|
||||
else if (argv[i][1] == 'm') { mode = argv[i + 1]; }
|
||||
else if (argv[i][1] == 'y') { system_prompt = argv[i + 1]; }
|
||||
else { error_usage(); }
|
||||
}
|
||||
|
||||
// parameter validation/overrides
|
||||
if (rng_seed <= 0) rng_seed = (unsigned int)time(NULL);
|
||||
if (temperature < 0.0) temperature = 0.0;
|
||||
if (topp < 0.0 || 1.0 < topp) topp = 0.9;
|
||||
if (steps < 0) steps = 0;
|
||||
|
||||
// build the Transformer via the model .bin file
|
||||
Transformer transformer;
|
||||
build_transformer(&transformer, checkpoint_path);
|
||||
if (steps == 0 || steps > transformer.config.seq_len) steps = transformer.config.seq_len; // override to ~max length
|
||||
|
||||
// build the Tokenizer via the tokenizer .bin file
|
||||
Tokenizer tokenizer;
|
||||
build_tokenizer(&tokenizer, tokenizer_path, transformer.config.vocab_size);
|
||||
|
||||
// build the Sampler
|
||||
Sampler sampler;
|
||||
build_sampler(&sampler, transformer.config.vocab_size, temperature, topp, rng_seed);
|
||||
|
||||
// run!
|
||||
if (strcmp(mode, "generate") == 0) {
|
||||
generate(&transformer, &tokenizer, &sampler, prompt, steps);
|
||||
} else if (strcmp(mode, "chat") == 0) {
|
||||
chat(&transformer, &tokenizer, &sampler, prompt, system_prompt, steps);
|
||||
} else {
|
||||
fprintf(stderr, "unknown mode: %s\n", mode);
|
||||
error_usage();
|
||||
}
|
||||
|
||||
// memory and file handles cleanup
|
||||
free_run_state(&state);
|
||||
for (int i = 0; i < config.vocab_size; i++) { free(vocab[i]); }
|
||||
free(vocab);
|
||||
free(vocab_scores);
|
||||
if (prompt_tokens != NULL) free(prompt_tokens);
|
||||
if (data != MAP_FAILED) munmap(data, file_size);
|
||||
if (fd != -1) close(fd);
|
||||
free_sampler(&sampler);
|
||||
free_tokenizer(&tokenizer);
|
||||
free_transformer(&transformer);
|
||||
return 0;
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -52,7 +52,7 @@ if compile:
|
||||
model = torch.compile(model) # requires PyTorch 2.0 (optional)
|
||||
|
||||
# load the tokenizer
|
||||
vocab_source = checkpoint_dict.get("vocab_source", "llama2")
|
||||
vocab_source = checkpoint_dict["config"].get("vocab_source", "llama2")
|
||||
vocab_size = gptconf.vocab_size
|
||||
if tokenizer:
|
||||
# a specific tokenizer is provided, use it
|
||||
|
||||
@@ -1,66 +0,0 @@
|
||||
#!/usr/bin/env python
"""Saves the model as a TorchScript.

Usage examples:
    ./save_torchscript.py
    ./save_torchscript.py --dim=300
    ./save_torchscript.py --gzip_output=True --zero_params=True

The resulting file can be loaded in C++ code and then used for training or
inference with:
    #include <torch/script.h>
    torch::jit::Module module = torch::jit::load("model.pt")

Note that the serialized model includes the initial parameters and with the default
ModelArgs the file is 59M and gzips down to 55M. If you want to serialize/distribute
the model parameters separately you can zero out the parameters before saving it and
it will gzip down to 780K.
"""
import gzip
import os
import shutil
from inspect import signature

import torch

from model import ModelArgs, Transformer

# Model args config
dim = 288
n_layers = 6
n_heads = 6
n_kv_heads = n_heads
multiple_of = 32
max_seq_len = 256
dropout = 0.0
vocab_size = 32000
norm_eps = 1e-5
# Save config
model_path = "model.pt"
zero_params = False
gzip_output = False
# Allow config overrides
exec(open("configurator.py").read())


def main() -> None:
    model_args = {k: globals()[k] for k in signature(ModelArgs).parameters}
    model = Transformer(ModelArgs(**model_args))

    # If requested zero params before saving the model. This is useful in
    # conjunction with gzip_output.
    if zero_params:
        for p in model.parameters():
            p.detach().zero_()

    torch.jit.save(torch.jit.script(model), model_path)

    if gzip_output:
        with open(model_path, "rb") as f_in:
            with gzip.open(f"{model_path}.gz", "wb") as f_out:
                shutil.copyfileobj(f_in, f_out)
        os.unlink(model_path)


if __name__ == "__main__":
    main()
|
||||
@@ -0,0 +1,81 @@
|
||||
#define TESTING
#include "run.c"

void assert_eq(int a, int b) {
    if (a != b) {
        printf("Assertion failed: %d != %d\n", a, b);
        exit(EXIT_FAILURE);
    }
}

void test_prompt_encoding(Tokenizer* tokenizer, char* prompt, int* expected_tokens, int num_expected_tokens) {
    // encode
    int* prompt_tokens = (int*)malloc((strlen(prompt)+3) * sizeof(int));
    int num_prompt_tokens = 0; // the total number of prompt tokens
    encode(tokenizer, prompt, 1, 0, prompt_tokens, &num_prompt_tokens);

    #if VERBOSITY == 1
    // print maybe
    printf("expected tokens:\n");
    for (int i = 0; i < num_expected_tokens; i++) printf("%d ", expected_tokens[i]);
    printf("\n");
    printf("actual tokens:\n");
    for (int i = 0; i < num_prompt_tokens; i++) printf("%d ", prompt_tokens[i]);
    printf("\n");
    #endif

    // verify
    assert_eq(num_prompt_tokens, num_expected_tokens);
    for (int i = 0; i < num_prompt_tokens; i++) {
        assert_eq(prompt_tokens[i], expected_tokens[i]);
    }

    #if VERBOSITY == 1
    printf("OK\n");
    printf("---\n");
    #endif
    free(prompt_tokens);
}

void test_prompt_encodings() {
    // let's verify that the Tokenizer works as expected

    char *tokenizer_path = "tokenizer.bin";
    int vocab_size = 32000;
    Tokenizer tokenizer;
    build_tokenizer(&tokenizer, tokenizer_path, vocab_size);

    // test 0 (test the empty string) (I added this as a simple case)
    char *prompt0 = "";
    int expected_tokens0[] = {1};
    test_prompt_encoding(&tokenizer, prompt0, expected_tokens0, sizeof(expected_tokens0) / sizeof(int));

    // the tests below are taken from the Meta Llama 2 repo example code
    // https://github.com/facebookresearch/llama/blob/main/example_text_completion.py
    // and the expected tokens come from me breaking in the debugger in Python

    // test 1
    char *prompt = "I believe the meaning of life is";
    int expected_tokens[] = {1, 306, 4658, 278, 6593, 310, 2834, 338};
    test_prompt_encoding(&tokenizer, prompt, expected_tokens, sizeof(expected_tokens) / sizeof(int));

    // test 2
    char* prompt2 = "Simply put, the theory of relativity states that ";
    int expected_tokens2[] = {1, 3439, 17632, 1925, 29892, 278, 6368, 310, 14215, 537, 5922, 393, 29871};
    test_prompt_encoding(&tokenizer, prompt2, expected_tokens2, sizeof(expected_tokens2) / sizeof(int));

    // test 3
    char* prompt3 = "A brief message congratulating the team on the launch:\n\n Hi everyone,\n\n I just ";
    int expected_tokens3[] = {1, 319, 11473, 2643, 378, 629, 271, 18099, 278, 3815, 373, 278, 6826, 29901, 13, 13, 4706, 6324, 14332, 29892, 13, 13, 4706, 306, 925, 29871};
    test_prompt_encoding(&tokenizer, prompt3, expected_tokens3, sizeof(expected_tokens3) / sizeof(int));

    // test 4
    char* prompt4 = "Translate English to French:\n\n sea otter => loutre de mer\n peppermint => menthe poivrée\n plush girafe => girafe peluche\n cheese =>";
    int expected_tokens4[] = {1, 4103, 9632, 4223, 304, 5176, 29901, 13, 13, 4706, 7205, 4932, 357, 1149, 301, 449, 276, 316, 2778, 13, 4706, 1236, 407, 837, 524, 1149, 6042, 354, 772, 440, 29878, 1318, 13, 4706, 715, 1878, 330, 3055, 1725, 1149, 330, 3055, 1725, 4639, 28754, 13, 4706, 923, 968, 1149};
    test_prompt_encoding(&tokenizer, prompt4, expected_tokens4, sizeof(expected_tokens4) / sizeof(int));
}

int main(int argc, char *argv[]) {
    test_prompt_encodings();
    printf("ALL OK\n");
}
|
||||
+17
-10
@@ -13,6 +13,7 @@ from functools import partial
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from tqdm import tqdm
|
||||
@@ -97,16 +98,21 @@ def train_vocab(vocab_size):
|
||||
of.write(text + "\n")
|
||||
print(f"Size is: {os.path.getsize(tiny_file) / 1024 / 1024:.2f} MB")
|
||||
|
||||
# 2) run the train_vocab.sh script that trains the sentencepiece model
|
||||
print("Will now train the vocab with:")
|
||||
cmd = f"bash train_vocab.sh {tiny_file} {prefix} {vocab_size}"
|
||||
print(cmd)
|
||||
print("OK? [y/N] ")
|
||||
dec = input()
|
||||
if dec.lower() != "y":
|
||||
print("Exiting...")
|
||||
return
|
||||
os.system(cmd)
|
||||
# 2) train the sentencepiece model
|
||||
print("Will now train the vocab...")
|
||||
spm.SentencePieceTrainer.train(input=tiny_file,
|
||||
model_prefix=prefix,
|
||||
model_type="bpe",
|
||||
vocab_size=vocab_size,
|
||||
self_test_sample_size=0,
|
||||
input_format="text",
|
||||
character_coverage=1.0,
|
||||
num_threads=os.cpu_count(),
|
||||
split_digits=True,
|
||||
allow_whitespace_only_pieces=True,
|
||||
byte_fallback=True,
|
||||
unk_surface=r" \342\201\207 ",
|
||||
normalization_rule_name="identity")
|
||||
|
||||
# 3) optional cleanup, ask the user if they'd like to delete tiny.txt
|
||||
dec = input(f"Delete the temporary file {tiny_file}? [y/N] ")
|
||||
@@ -196,6 +202,7 @@ class PretokDataset(torch.utils.data.IterableDataset):
|
||||
shard_filenames = sorted(glob.glob(os.path.join(bin_dir, "*.bin")))
|
||||
# train/test split. let's use only shard 0 for test split, rest train
|
||||
shard_filenames = shard_filenames[1:] if self.split == "train" else shard_filenames[:1]
|
||||
assert len(shard_filenames)>0, f"No bin files found in {bin_dir}"
|
||||
while True:
|
||||
rng.shuffle(shard_filenames)
|
||||
for shard in shard_filenames:
|
||||
|
||||
@@ -29,6 +29,7 @@ from torch.distributed import destroy_process_group, init_process_group
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
|
||||
from tinystories import Task
|
||||
from export import model_export
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# I/O
|
||||
@@ -270,7 +271,7 @@ while True:
|
||||
"loss/val": losses["val"],
|
||||
"lr": lr,
|
||||
"mfu": running_mfu * 100, # convert to percentage
|
||||
}
|
||||
}, step = iter_num
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"logging to wandb failed: {e}")
|
||||
@@ -287,7 +288,7 @@ while True:
|
||||
}
|
||||
print(f"saving checkpoint to {out_dir}")
|
||||
torch.save(checkpoint, os.path.join(out_dir, "ckpt.pt"))
|
||||
raw_model.export(os.path.join(out_dir, "model.bin"))
|
||||
model_export(raw_model, os.path.join(out_dir, "model.bin"), version=0)
|
||||
if iter_num == 0 and eval_only:
|
||||
break
|
||||
|
||||
|
||||