Merge remote-tracking branch 'upstream/master'
This commit is contained in:
@@ -4,9 +4,11 @@
|
||||
<img src="assets/llama_cute.jpg" width="300" height="300" alt="Cute Llama">
|
||||
</p>
|
||||
|
||||
With the code in this repo you can train the Llama 2 LLM architecture from scratch in PyTorch, then export the weights to a binary file, and load that into one ~simple 500-line C file ([run.c](run.c)) that inferences the model. Alternatively, you can load, finetune, and inference Meta's Llama 2 (but this is still being actively fleshed out). Hence, this repo is a "fullstack" train + inference solution for Llama 2 LLM, with a focus on minimalism and simplicity. You might think that you need many billion parameter LLMs to do anything useful, but in fact very small LLMs can have surprisingly strong performance if you make the domain narrow enough. I recommend looking at the [TinyStories](https://huggingface.co/datasets/roneneldan/TinyStories) paper for inspiration.
|
||||
Train the Llama 2 LLM architecture in PyTorch then inference it with one simple 700-line C file ([run.c](run.c)). You might think that you need many billion parameter LLMs to do anything useful, but in fact very small LLMs can have surprisingly strong performance if you make the domain narrow enough (ref: [TinyStories](https://huggingface.co/datasets/roneneldan/TinyStories) paper). This repo is a "fullstack" train + inference solution for Llama 2 LLM, with focus on minimalism and simplicity.
|
||||
|
||||
Please note that this started recently as just a fun weekend project: I took my earlier [nanoGPT](https://github.com/karpathy/nanoGPT), tuned it to implement the Llama-2 architecture instead of GPT-2, and the meat of it was writing the C inference engine in [run.c](run.c). So the project is young and moving quickly. Hat tip to the awesome [llama.cpp](https://github.com/ggerganov/llama.cpp) for inspiring this project. I wanted something super minimal so I chose to hard-code the Llama 2 architecture, stick to fp32, and just roll one inference file of pure C with no dependencies.
|
||||
As the architecture is identical, you can also load and inference Meta's Llama 2 models. However, the current code only inferences models in fp32, so you will most likely not be able to productively load models larger than 7B. Work on model quantization is currently ongoing.
|
||||
|
||||
Please note that this repo started recently as a fun weekend project: I took my earlier [nanoGPT](https://github.com/karpathy/nanoGPT), tuned it to implement the Llama-2 architecture instead of GPT-2, and the meat of it was writing the C inference engine in [run.c](run.c). So the project is young and moving quickly. Hat tip to the awesome [llama.cpp](https://github.com/ggerganov/llama.cpp) for inspiring this project. Compred to llama.cpp, I wanted something super simple, minimal, and educational so I chose to hard-code the Llama 2 architecture and just roll one inference file of pure C with no dependencies.
|
||||
|
||||
## feel the magic
|
||||
|
||||
@@ -56,7 +58,9 @@ You can also prompt the model with a prefix or a number of additional command li
|
||||
|
||||
> One day, Lily met a Shoggoth. He was very shy, but was also very generous. Lily said “Hello Shoggy! Can I be your friend?” Shoggy was happy to have a friend and said “Yes, let’s explore the universe together!” So they set off on a journey to explore the universe. As they travelled, Shoggy was happy to explain to Lily about all the wonderful things in the universe. At the end of the day, Lily and Shoggy had gathered lots of wonderful things from the universe, and they both felt very proud. They promised to explore the universe as one big pair and to never stop being generous to each other.
|
||||
|
||||
There is also an even better 110M param model available, see [models](#models). Quick note on sampling, the recommendation for good results is to use `-t 1.0 -p 0.9`, i.e. top-p sampling at 0.9 with temperature 1.0 (this is the default). To control the diversity of samples use either the temperature (i.e. vary `-t` between 0 and 1 and keep top-p off with `-p 0`) or the top-p value (i.e. vary `-p` between 0 and 1 and keep `-t 1`), but not both. Nice explainers on LLM sampling strategies include [this](https://peterchng.com/blog/2023/05/02/token-selection-strategies-top-k-top-p-and-temperature/), [this](https://docs.cohere.com/docs/controlling-generation-with-top-k-top-p) or [this](https://huggingface.co/blog/how-to-generate).
|
||||
There is also an even better 110M param model available, see [models](#models).
|
||||
|
||||
Quick note on sampling, the recommendation for ~best results is to sample with `-t 1.0 -p 0.9`, i.e. temperature 1.0 (default) but also top-p sampling at 0.9 (default). Intuitively, top-p ensures that tokens with tiny probabilities do not get sampled, so we can't get "unlucky" during sampling, and we are less likely to go "off the rails" afterwards. More generally, to control the diversity of samples use either the temperature (i.e. vary `-t` between 0 and 1 and keep top-p off with `-p 0`) or the top-p value (i.e. vary `-p` between 0 and 1 and keep `-t 1`), but not both. Nice explainers on LLM sampling strategies include [this](https://peterchng.com/blog/2023/05/02/token-selection-strategies-top-k-top-p-and-temperature/), [this](https://docs.cohere.com/docs/controlling-generation-with-top-k-top-p) or [this](https://huggingface.co/blog/how-to-generate).
|
||||
|
||||
## Meta's Llama 2 models
|
||||
|
||||
@@ -83,11 +87,12 @@ base models... ¯\\_(ツ)_/¯. Since we can inference the base model, it should
|
||||
|
||||
For the sake of examples of smaller, from-scratch models, I trained a small model series on TinyStories. All of these trained in a few hours on my training setup (4X A100 40GB GPUs). The 110M took around 24 hours. I am hosting them on huggingface hub [tinyllamas](https://huggingface.co/karpathy/tinyllamas), both in the original PyTorch .pt, and also in the llama2.c format .bin:
|
||||
|
||||
| model | dim | n_layers | n_heads | max context length | parameters | val loss | download
|
||||
| --- | --- | --- | --- | --- | --- | --- | --- |
|
||||
| OG | 288 | 6 | 6 | 256 | 15M | 1.072 | [stories15M.bin](https://huggingface.co/karpathy/tinyllamas/resolve/main/stories15M.bin) |
|
||||
| 42M| 512 | 8 | 8 | 1024 | 42M | 0.847 | [stories42M.bin](https://huggingface.co/karpathy/tinyllamas/resolve/main/stories42M.bin) |
|
||||
| 110M| 768 | 12 | 12 | 1024 | 110M | 0.760 | [stories110M.bin](https://huggingface.co/karpathy/tinyllamas/resolve/main/stories110M.bin) |
|
||||
| model | dim | n_layers | n_heads | n_kv_heads | max context length | parameters | val loss | download
|
||||
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
|
||||
| 260K | 64 | 5 | 8 | 4 | 512 | 260K | 1.297 | [stories260K](https://huggingface.co/karpathy/tinyllamas/tree/main/stories260K)
|
||||
| OG | 288 | 6 | 6 | 6 | 256 | 15M | 1.072 | [stories15M.bin](https://huggingface.co/karpathy/tinyllamas/resolve/main/stories15M.bin) |
|
||||
| 42M| 512 | 8 | 8 | 8 | 1024 | 42M | 0.847 | [stories42M.bin](https://huggingface.co/karpathy/tinyllamas/resolve/main/stories42M.bin) |
|
||||
| 110M| 768 | 12 | 12 | 12 | 1024 | 110M | 0.760 | [stories110M.bin](https://huggingface.co/karpathy/tinyllamas/resolve/main/stories110M.bin) |
|
||||
|
||||
You'll notice that the 110M model is equivalent to GPT-1 in size. Alternatively, this is also the smallest model in the GPT-2 series (`GPT-2 small`), except the max context length is only 1024 instead of 2048. The only notable changes from GPT-1/2 architecture is that Llama uses RoPE relatively positional embeddings instead of absolute/learned positional embeddings, a bit more fancy SwiGLU non-linearity in the MLP, RMSNorm instead of LayerNorm, bias=False on all Linear layers, and is optionally multiquery (but this is not yet supported in llama2.c).
|
||||
|
||||
@@ -130,15 +135,53 @@ Watch the tokens stream by, fun! We can also run the PyTorch inference script fo
|
||||
|
||||
```bash
|
||||
wget https://huggingface.co/karpathy/tinyllamas/resolve/main/stories15M.pt -P out15M
|
||||
mv out15M/stories15M.pt out15M/ckpt.pt # sorry the sample script current assumes this directory structure / filename...
|
||||
python sample.py --out_dir=out15M
|
||||
python sample.py --checkpoint=out15M/stories15M.pt
|
||||
```
|
||||
|
||||
Which gives the same results. More detailed testing will be done in `test_all.py`. Currently you will need two files to test or sample: both the .bin file, and the .ckpt file inside a directory (see `test_all.py` for details). Sorry this is a bit janky right now, I have to think through running the tests without having to download 200MB of data. But run the tests with pytest:
|
||||
Which gives the same results.
|
||||
|
||||
## custom tokenizers
|
||||
|
||||
In everything above, we've assumed the custom Lllama 2 tokenizer with 32,000 tokens. However, in many boutique LLMs, using vocabulary this big might be an overkill. If you have a small application you have in mind, you might be much better off training your own tokenizers. This can make everything nicer - with smaller vocabs your model has fewer parameters (because the token embedding table is a lot smaller), the inference is faster (because there are fewer tokens to predict), and your average sequence length per example could also get smaller (because the compression is a lot more efficient on your data). So let's see how we train a custom tokenizer.
|
||||
|
||||
By default, to pretokenize the tinystories dataset we had to run, in order:
|
||||
|
||||
```bash
|
||||
$ pytest
|
||||
```
|
||||
python tinystories.py download
|
||||
python tinystories.py pretokenize
|
||||
```
|
||||
|
||||
The `pretokenize` stage here loads the Llama 2 tokenizer (vocab size 32,000) and uses it to convert the downloaded text into integers, and saves that to file. We now change this as follows, to train an example 4096-token tokenizer:
|
||||
|
||||
```
|
||||
python tinystories.py download
|
||||
python tinystories.py train_vocab --vocab_size=4096
|
||||
python tinystories.py pretokenize --vocab_size=4096
|
||||
```
|
||||
|
||||
The `train_vocab` stage will call the `train_vocab.sh` script, which calls the `sentencepiece` library to train the tokenizer, storing it in a new file `data/tok4096.model`. I tried to reproduce as well as I could the settings that (I think) Meta used to train their vocabulary. This uses the Byte Pair Encoding algorithm that starts out with raw utf8 byte sequences of the text data and then iteratively merges the most common consecutive pairs of tokens to form the vocabulary. Inspect the `tinystories.py` file - the custom tokenizers are stored in a special directory structure indexed by the vocab size.
|
||||
|
||||
A quick note of interest is that vocab size of 4096 trained specifically on tinystories creates integer sequences with about the same sequence length per example as the default Llama 2 tokenizer of 32000 tokens! This means that our custom, tailored tokenizer is a lot better adapted to our specific text, and can compress it very effectively. So our trained models are smaller and faster.
|
||||
|
||||
Now that we have pretokenized the dataset with our custom tokenizer, we can train the model. The training script `train.py` doesn't care about the exact tokens, it only cares about the vocabulary size so it can correctly initialize the model. So when training your model, make sure to pass in
|
||||
|
||||
```
|
||||
python train.py --vocab_source=custom --vocab_size=4096
|
||||
```
|
||||
|
||||
(The defaults are `llama2` and `32000` respectively, which indicates the default Llama 2 tokenizer). This trains the model. Finally we are ready to run inference with our `run.c` script. For that we need two things. Number one, we have to export our tokenizer in the `.bin` format, do that with:
|
||||
|
||||
```
|
||||
python tokenizer.py --tokenizer-model=data/tok4096.model
|
||||
```
|
||||
|
||||
This writes the tokenizer to `data/tok4096.bin`. Now we can run inference, pointing it to this tokenizer using the `-z` flag:
|
||||
|
||||
```
|
||||
./run out/model.bin -z data/tok4096.bin
|
||||
```
|
||||
|
||||
This should print the samples. If you leave out the `-z` flag, it will use the default Llama 2 tokenizer, which would generate a good sequence of integers, but they would get translated using a different vocabulary to text, so it would look like gibberish.
|
||||
|
||||
## performance
|
||||
|
||||
@@ -184,6 +227,17 @@ On **Centos 7**, **Amazon Linux 2018** use `rungnu` Makefile target: `make rungn
|
||||
|
||||
On **Mac**, use clang from brew for openmp build. Install clang as `brew install llvm` and use the installed clang binary to compile with openmp: `make runomp CC=/opt/homebrew/opt/llvm/bin/clang`
|
||||
|
||||
## tests
|
||||
|
||||
You can run tests simply with pytest:
|
||||
|
||||
```bash
|
||||
$ pip install pytest
|
||||
$ pytest
|
||||
```
|
||||
|
||||
This will currently invoke two tests inside `test_all.py`, which forward the model in both C and Python for 200 steps and check the output against a known good expected output. The tests currently run in only a few seconds, but will have to download and cache the stories260K models in a temporary `test` directory (only ~2MB download).
|
||||
|
||||
## ack
|
||||
|
||||
I trained the llama2.c storyteller models on a 4X A100 40GB box graciously provided by the excellent [Lambda labs](https://lambdalabs.com/service/gpu-cloud), thank you.
|
||||
@@ -216,6 +270,7 @@ If your candidate PRs have elements of these it doesn't mean they won't get merg
|
||||
- [llama2.rs](https://github.com/gaxler/llama2.rs) by @[gaxler](https://github.com/gaxler): a Rust port of this project
|
||||
- [llama2.rs](https://github.com/leo-du/llama2.rs) by @[leo-du](https://github.com/leo-du): A Rust port of this project
|
||||
- [llama2-rs](https://github.com/danielgrittner/llama2-rs) by @[danielgrittner](https://github.com/danielgrittner): a Rust port of this project
|
||||
- [llama2.rs](https://github.com/lintian06/llama2.rs) by @[lintian06](https://github.com/lintian06): A Rust port of this project
|
||||
- Go
|
||||
- [go-llama2](https://github.com/tmc/go-llama2) by @[tmc](https://github.com/tmc): a Go port of this project
|
||||
- [llama2.go](https://github.com/nikolaydubina/llama2.go) by @[nikolaydubina](https://github.com/nikolaydubina): a Go port of this project
|
||||
@@ -228,6 +283,7 @@ If your candidate PRs have elements of these it doesn't mean they won't get merg
|
||||
- [llama2.cpp](https://github.com/leloykun/llama2.cpp) by @[leloykun](https://github.com/leloykun): a C++ port of this project
|
||||
- JavaScript
|
||||
- [llama2.js](https://github.com/epicure/llama2.js) by @[epicure](https://github.com/epicure): a JavaScript port of this project
|
||||
- [llama2.ts](https://github.com/wizzard0/llama2.ts) by @[oleksandr_now](https://twitter.com/oleksandr_now): a TypeScript port of this project. Full Llama2-7B capable.
|
||||
- [llama2.c-emscripten](https://github.com/gohai/llama2.c-emscripten) by @[gohai](https://github.com/gohai): Emscripten (JavaScript) port, based on @ggerganov's initial prototype
|
||||
- Zig
|
||||
- [llama2.zig](https://github.com/cgbur/llama2.zig) by @[cgbur](https://github.com/cgbur): A Zig port of this project
|
||||
@@ -245,16 +301,16 @@ If your candidate PRs have elements of these it doesn't mean they won't get merg
|
||||
- [llama2.py](https://github.com/tairov/llama2.py) by @[tairov](https://github.com/tairov): a simple one file pure Python port of this project with zero dependencies
|
||||
- C#
|
||||
- [llama2.cs](https://github.com/trrahul/llama2.cs) by @[trrahul](https://github.com/trrahul): a C# port of this project
|
||||
- WebAssembly
|
||||
- [icpp-llm](https://github.com/icppWorld/icpp-llm): LLMs for the Internet Computer
|
||||
- [llama2.c - Llama 2 Everywhere](https://github.com/trholding/llama2.c) by @[trholding](https://github.com/trholding): Standalone, Bootable & Portable Binary Llama 2
|
||||
|
||||
## unsorted todos
|
||||
|
||||
- add multiquery support into run.c
|
||||
- add custom bpe training code and the ability to train a smaller vocabulary (32K is to much)
|
||||
- make it easier to add a new dataset with not too much pain
|
||||
- should calculate freq_cis online in the script run.c instead of loading them
|
||||
- int4/8 quantization
|
||||
- export the model in a more sensible output format with a proper header, etc.
|
||||
- train a tiny Llama test model (committed to repo) and use it as reference in unit tests
|
||||
- support Llama 2 7B Chat models and tune run.c to Chat UI/UX
|
||||
- llama2.cu investigate and merge
|
||||
- (LoRA) finetuning and export of Llama 2 models
|
||||
|
||||
Reference in New Issue
Block a user