From fe9b9f2f15eda96507837f1b2584e98401a61930 Mon Sep 17 00:00:00 2001 From: Jani Monoses Date: Wed, 23 Aug 2023 17:28:14 +0300 Subject: [PATCH 1/2] Train vocab in Python --- README.md | 2 +- tinystories.py | 26 ++++++++++++++++---------- 2 files changed, 17 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index e9df1f6..f4e20e9 100644 --- a/README.md +++ b/README.md @@ -163,7 +163,7 @@ python tinystories.py train_vocab --vocab_size=4096 python tinystories.py pretokenize --vocab_size=4096 ``` -The `train_vocab` stage will call the `train_vocab.sh` script, which calls the `sentencepiece` library to train the tokenizer, storing it in a new file `data/tok4096.model`. I tried to reproduce as well as I could the settings that (I think) Meta used to train their vocabulary. This uses the Byte Pair Encoding algorithm that starts out with raw utf8 byte sequences of the text data and then iteratively merges the most common consecutive pairs of tokens to form the vocabulary. Inspect the `tinystories.py` file - the custom tokenizers are stored in a special directory structure indexed by the vocab size. +The `train_vocab` stage will call the `sentencepiece` library to train the tokenizer, storing it in a new file `data/tok4096.model`. I tried to reproduce as well as I could the settings that (I think) Meta used to train their vocabulary. This uses the Byte Pair Encoding algorithm that starts out with raw utf8 byte sequences of the text data and then iteratively merges the most common consecutive pairs of tokens to form the vocabulary. Inspect the `tinystories.py` file - the custom tokenizers are stored in a special directory structure indexed by the vocab size. A quick note of interest is that vocab size of 4096 trained specifically on tinystories creates integer sequences with about the same sequence length per example as the default Llama 2 tokenizer of 32000 tokens! This means that our custom, tailored tokenizer is a lot better adapted to our specific text, and can compress it very effectively. So our trained models are smaller and faster. diff --git a/tinystories.py b/tinystories.py index 90d576b..003b1e3 100644 --- a/tinystories.py +++ b/tinystories.py @@ -13,6 +13,7 @@ from functools import partial import numpy as np import requests +import sentencepiece as spm import torch import torch.distributed as dist from tqdm import tqdm @@ -97,16 +98,21 @@ def train_vocab(vocab_size): of.write(text + "\n") print(f"Size is: {os.path.getsize(tiny_file) / 1024 / 1024:.2f} MB") - # 2) run the train_vocab.sh script that trains the sentencepiece model - print("Will now train the vocab with:") - cmd = f"bash train_vocab.sh {tiny_file} {prefix} {vocab_size}" - print(cmd) - print("OK? [y/N] ") - dec = input() - if dec.lower() != "y": - print("Exiting...") - return - os.system(cmd) + # 2) train the sentencepiece model + print("Will now train the vocab...") + + spm.SentencePieceTrainer.train(input=tiny_file, + model_prefix=prefix, + model_type="bpe", + vocab_size=vocab_size, + self_test_sample_size=0, + input_format="text", + character_coverage=1.0, + split_digits=True, + allow_whitespace_only_pieces=True, + byte_fallback=True, + unk_surface=r" \342\201\207 ", + normalization_rule_name="identity") # 3) optional cleanup, ask the user if they'd like to delete tiny.txt dec = input(f"Delete the temporary file {tiny_file}? [y/N] ") From 096325b66c2ab84095bd407cbab84d731edc65bc Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Thu, 24 Aug 2023 03:09:55 +0000 Subject: [PATCH 2/2] bring back num_threads --- tinystories.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tinystories.py b/tinystories.py index 003b1e3..800d73a 100644 --- a/tinystories.py +++ b/tinystories.py @@ -100,7 +100,6 @@ def train_vocab(vocab_size): # 2) train the sentencepiece model print("Will now train the vocab...") - spm.SentencePieceTrainer.train(input=tiny_file, model_prefix=prefix, model_type="bpe", @@ -108,6 +107,7 @@ def train_vocab(vocab_size): self_test_sample_size=0, input_format="text", character_coverage=1.0, + num_threads=os.cpu_count(), split_digits=True, allow_whitespace_only_pieces=True, byte_fallback=True,