+274
@@ -0,0 +1,274 @@
|
||||
"""
|
||||
Download, preprocess and serve the TinyStories dataset as a DataLoader.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import glob
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
from typing import List
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from tqdm import tqdm
|
||||
|
||||
from tokenizer import Tokenizer
|
||||
|
||||
DATA_CACHE_DIR = "data"
|
||||
|
||||
def download_file(url: str, fname: str, chunk_size=1024):
    """Download `url` to the local path `fname`, showing a tqdm progress bar.

    The body is streamed in pieces of `chunk_size` bytes into a temporary
    `fname + ".part"` file, which is renamed to `fname` only once the whole
    transfer has finished. An interrupted download therefore never leaves a
    truncated file at `fname`.

    Raises:
        requests.HTTPError: if the server responds with an error status.
    """
    partial_name = fname + ".part"
    # the context manager releases the connection even if writing fails
    with requests.get(url, stream=True) as resp:
        # fail loudly instead of saving an HTTP error page as if it were data
        resp.raise_for_status()
        # content-length may be absent; 0 makes tqdm show an open-ended bar
        total = int(resp.headers.get("content-length", 0))
        with open(partial_name, "wb") as file, tqdm(
            desc=fname,
            total=total,
            unit="iB",
            unit_scale=True,
            unit_divisor=1024,
        ) as bar:
            for data in resp.iter_content(chunk_size=chunk_size):
                size = file.write(data)
                bar.update(size)
    # only a complete download is ever visible under the final name
    os.replace(partial_name, fname)
|
||||
|
||||
|
||||
def download():
    """Downloads the TinyStories dataset to DATA_CACHE_DIR.

    Steps, each skipped when its output already exists:
      1. fetch TinyStories_all_data.tar.gz into DATA_CACHE_DIR
      2. unpack it into DATA_CACHE_DIR/TinyStories_all_data
    Finally prints the shard count and the first story of the first shard.

    Raises:
        subprocess.CalledProcessError: if `tar` exits with a non-zero status.
        FileNotFoundError: if no .json shards are present after unpacking.
    """
    # local import so this function does not depend on the file's import block
    import subprocess

    os.makedirs(DATA_CACHE_DIR, exist_ok=True)

    # download the TinyStories dataset, unless it's already downloaded
    data_url = "https://huggingface.co/datasets/roneneldan/TinyStories/resolve/main/TinyStories_all_data.tar.gz"
    data_filename = os.path.join(DATA_CACHE_DIR, "TinyStories_all_data.tar.gz")
    if not os.path.exists(data_filename):
        print(f"Downloading {data_url} to {data_filename}...")
        download_file(data_url, data_filename)
    else:
        print(f"{data_filename} already exists, skipping download...")

    # unpack the tar.gz file into all the data shards (json files)
    data_dir = os.path.join(DATA_CACHE_DIR, "TinyStories_all_data")
    if not os.path.exists(data_dir):
        os.makedirs(data_dir, exist_ok=True)
        print(f"Unpacking {data_filename}...")
        # argument list (no shell) and check=True: a failed unpack raises
        # instead of being silently ignored
        subprocess.run(["tar", "-xzf", data_filename, "-C", data_dir], check=True)
    else:
        print(f"{data_dir} already exists, skipping unpacking...")

    # print a single example just for debugging and such
    shard_filenames = sorted(glob.glob(os.path.join(data_dir, "*.json")))
    if not shard_filenames:
        # an earlier failed unpack leaves an empty directory that is then skipped
        raise FileNotFoundError(
            f"No .json shards found in {data_dir}; delete that directory and re-run to unpack again."
        )
    with open(shard_filenames[0], "r") as f:
        data = json.load(f)
    print("Download done.")
    print(f"Number of shards: {len(shard_filenames)}")
    print(f"Example story:\n{data[0]}")
|
||||
|
||||
def train_vocab(vocab_size):
    """
    Build a custom sentencepiece vocabulary from the TinyStories shards.

    The stories of the first 10 shards are concatenated into
    DATA_CACHE_DIR/tiny.txt, then (after interactive confirmation) the
    train_vocab.sh script is run on that file with the output prefix
    DATA_CACHE_DIR/tok{N}, where N is the vocab size. The user is finally
    asked whether the temporary text file should be deleted.
    """
    assert vocab_size > 0, "Vocab size must be positive"

    # only a handful of shards are used, to keep vocab training cheap
    shards_for_vocab = 10

    stories_dir = os.path.join(DATA_CACHE_DIR, "TinyStories_all_data")
    corpus_path = os.path.join(DATA_CACHE_DIR, "tiny.txt")
    # prefix handed to the training script for its output files
    model_prefix = os.path.join(DATA_CACHE_DIR, f"tok{vocab_size}")
    all_shards = sorted(glob.glob(os.path.join(stories_dir, "*.json")))

    # step 1: dump one stripped story per line into the corpus file
    print(f"Writing temporary file {corpus_path} with {shards_for_vocab} shards...")
    with open(corpus_path, "w") as corpus:
        for shard_path in tqdm(all_shards[:shards_for_vocab]):
            with open(shard_path, "r") as shard_file:
                stories = json.load(shard_file)
            for story in stories:
                corpus.write(story["story"].strip() + "\n")
    print(f"Size is: {os.path.getsize(corpus_path) / 1024 / 1024:.2f} MB")

    # step 2: show the training command and run it only if the user agrees
    print("Will now train the vocab with:")
    train_cmd = f"bash train_vocab.sh {corpus_path} {model_prefix} {vocab_size}"
    print(train_cmd)
    print("OK? [y/N] ")
    if input().lower() != "y":
        print("Exiting...")
        return
    os.system(train_cmd)

    # step 3: optionally remove the temporary corpus file
    answer = input(f"Delete the temporary file {corpus_path}? [y/N] ")
    if answer.lower() == "y":
        os.remove(corpus_path)
        print(f"Deleted {corpus_path}")

    print(f"Trained tokenizer is in {model_prefix}.model")
    print("Done.")
|
||||
|
||||
|
||||
def process_shard(args, vocab_size):
    """
    Tokenize one JSON shard and write the token ids to a .bin file.

    `args` is a (shard_id, shard_path) pair; shard_id is only used as the
    tqdm bar position. Every story is stripped and encoded with a leading
    BOS token and no EOS token, and all token ids are concatenated and
    stored as raw uint16 values.

    With vocab_size == 0 the .bin file is written next to the .json file,
    otherwise into DATA_CACHE_DIR/tok{vocab_size}.
    """
    shard_id, shard = args
    enc = Tokenizer(get_tokenizer_model_path(vocab_size))

    with open(shard, "r") as f:
        stories = json.load(f)

    token_stream = []
    for story in tqdm(stories, position=shard_id):
        # strip surrounding whitespace, then encode with BOS and without EOS
        token_stream += enc.encode(story["story"].strip(), bos=True, eos=False)

    # token ids are stored on disk as uint16
    token_arr = np.array(token_stream, dtype=np.uint16)

    # pick where the .bin file goes
    if vocab_size == 0:
        # Llama 2 tokenizer: keep the output alongside the source shard
        out_path = shard.replace(".json", ".bin")
    else:
        # custom tokenizer: outputs are collected in the tok{N} directory
        out_name = os.path.basename(shard).replace(".json", ".bin")
        out_path = os.path.join(DATA_CACHE_DIR, f"tok{vocab_size}", out_name)

    with open(out_path, "wb") as out:
        out.write(token_arr.tobytes())

    # sequences are delimited by BOS (id 1), so tokens per BOS is the mean length
    bos_count = (token_arr == 1).sum()
    avg_seq_len = token_arr.size / bos_count
    print(f"Saved {out_path}, average seqlen: {avg_seq_len:.2f}")
|
||||
|
||||
|
||||
def pretokenize(vocab_size):
    """
    Tokenize every TinyStories JSON shard in a pool of worker processes.

    Each shard in DATA_CACHE_DIR/TinyStories_all_data is passed to
    process_shard together with its index. For vocab_size > 0 the
    DATA_CACHE_DIR/tok{vocab_size} output directory is created first.

    Raises:
        FileNotFoundError: if there are no .json shards to tokenize.
        Exception: whatever a worker raised while processing a shard.
    """
    # iterate the shards and tokenize all of them one by one
    data_dir = os.path.join(DATA_CACHE_DIR, "TinyStories_all_data")
    shard_filenames = sorted(glob.glob(os.path.join(data_dir, "*.json")))
    if not shard_filenames:
        # otherwise we would report "Done." without having tokenized anything
        raise FileNotFoundError(f"No .json shards found in {data_dir}; run the download stage first.")
    if vocab_size > 0:
        # .bin files will be saved into tok{N} directory, create it once here
        bin_dir = os.path.join(DATA_CACHE_DIR, f"tok{vocab_size}")
        os.makedirs(bin_dir, exist_ok=True)

    # process all the shards in a process pool
    fun = partial(process_shard, vocab_size=vocab_size)
    with ProcessPoolExecutor() as executor:
        # Executor.map re-raises a worker's exception only when its result is
        # retrieved, so drain the iterator to keep failures from being lost
        for _ in executor.map(fun, enumerate(shard_filenames)):
            pass
    print("Done.")
|
||||
|
||||
|
||||
class PretokDataset(torch.utils.data.IterableDataset):
    """Loads pretokenized examples from disk and yields them as PyTorch tensors.

    Iterating never terminates: the shards are reshuffled and replayed
    forever. Each item is an (x, y) pair of int64 tensors of length
    max_seq_len, where y is x shifted left by one token.
    """

    def __init__(self, split, max_seq_len, vocab_size, vocab_source):
        super().__init__()
        # "train" selects every shard but the first; anything else selects the first
        self.split = split
        # number of tokens in each yielded x / y tensor
        self.max_seq_len = max_seq_len
        # only used to locate the tok{N} directory when vocab_source == "custom"
        self.vocab_size = vocab_size
        # "llama2" or "custom"; decides where the .bin shards are looked up
        self.vocab_source = vocab_source

    def __iter__(self):
        # resolve the shard directory first, so that an unsupported
        # vocab_source fails with a clear error instead of an unbound variable
        if self.vocab_source == "llama2":
            # the .bin files are right along the .json files
            bin_dir = os.path.join(DATA_CACHE_DIR, "TinyStories_all_data")
        elif self.vocab_source == "custom":
            # the .bin files are in tok{N} directory
            bin_dir = os.path.join(DATA_CACHE_DIR, f"tok{self.vocab_size}")
        else:
            raise ValueError(
                f"Unknown vocab_source {self.vocab_source!r}, expected 'llama2' or 'custom'"
            )

        # get worker info within a DataLoader
        worker_info = torch.utils.data.get_worker_info()
        worker_id = worker_info.id if worker_info else 0
        # get DDP rank info
        rank = dist.get_rank() if dist.is_initialized() else 0
        # combine the worker_id and worker_rank to create a unique seed for rng
        seed = 42 + worker_id + 1337 * rank
        rng = random.Random(seed)
        print(f"Created a PretokDataset with rng seed {seed}")

        shard_filenames = sorted(glob.glob(os.path.join(bin_dir, "*.bin")))
        # train/test split. let's use only shard 0 for test split, rest train
        shard_filenames = shard_filenames[1:] if self.split == "train" else shard_filenames[:1]
        if not shard_filenames:
            # with no shards the endless loop below would spin without ever yielding
            raise FileNotFoundError(
                f"No .bin shards available in {bin_dir} for split {self.split!r}; run the pretokenize stage first."
            )
        while True:
            rng.shuffle(shard_filenames)
            for shard in shard_filenames:
                # open the dataset for reading but keep it on disk with memmap
                m = np.memmap(shard, dtype=np.uint16, mode="r")
                num_batches = len(m) // self.max_seq_len
                # every window reads max_seq_len + 1 tokens, so the final
                # window is dropped to stay inside the file
                num_batches -= 1  # drop the last partial batch
                assert num_batches > 0, "this shard is way too small? investigate."
                ixs = list(range(num_batches))
                rng.shuffle(ixs)
                for ix in ixs:
                    start = ix * self.max_seq_len
                    end = start + self.max_seq_len + 1
                    # calling .astype will copy the data into a new numpy array, now in RAM
                    chunk = torch.from_numpy((m[start:end]).astype(np.int64))
                    # inputs and next-token targets: the same window offset by one
                    x = chunk[:-1]
                    y = chunk[1:]
                    yield x, y
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# public interface functions
|
||||
|
||||
def get_tokenizer_model_path(vocab_size):
    """
    Locate the sentencepiece tokenizer model for the given vocab size.

    A vocab_size of 0 stands for the default Llama 2 tokenizer, for which
    None is returned. Any other value maps to
    DATA_CACHE_DIR/tok{vocab_size}.model.
    """
    if vocab_size != 0:
        return os.path.join(DATA_CACHE_DIR, f"tok{vocab_size}.model")
    return None
|
||||
|
||||
class Task:
    """Entry point for obtaining batches from the pretokenized dataset."""

    @staticmethod
    def iter_batches(batch_size, device, num_workers=0, **dataset_kwargs):
        """
        Yield (x, y) batches from a PretokDataset, moved to `device`.

        `dataset_kwargs` are forwarded unchanged to the PretokDataset
        constructor. Batches are produced by a DataLoader with pinned
        memory and transferred with non_blocking=True.
        """
        dataset = PretokDataset(**dataset_kwargs)
        loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            pin_memory=True,
            num_workers=num_workers,
        )
        for inputs, targets in loader:
            yield inputs.to(device, non_blocking=True), targets.to(device, non_blocking=True)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# CLI for constructing the dataset
|
||||
|
||||
if __name__ == "__main__":
    # These stages are designed to be run in order.
    #
    # To tokenize data with the Llama 2 tokenizer:
    #   python tinystories.py download
    #   python tinystories.py pretokenize
    #
    # To tokenize data with a custom tokenizer we train ourselves with
    # sentencepiece, e.g.:
    #   python tinystories.py download
    #   python tinystories.py train_vocab --vocab_size=2048
    #   python tinystories.py pretokenize --vocab_size=2048
    parser = argparse.ArgumentParser()
    parser.add_argument("stage", type=str, choices=["download", "pretokenize", "train_vocab"])
    parser.add_argument("--vocab_size", type=int, default=0, help="pretokenization vocab size. 0 = use Llama 2 tokenizer.")
    args = parser.parse_args()

    # map each stage name to the call that carries it out
    stage_runners = {
        "download": lambda: download(),
        "train_vocab": lambda: train_vocab(vocab_size=args.vocab_size),
        "pretokenize": lambda: pretokenize(vocab_size=args.vocab_size),
    }
    run_stage = stage_runners.get(args.stage)
    if run_stage is None:
        raise ValueError(f"Unknown stage {args.stage}")
    run_stage()
|
||||
Reference in New Issue
Block a user