diff --git a/tinyshakespeare.py b/tinyshakespeare.py
new file mode 100644
index 0000000..602624c
--- /dev/null
+++ b/tinyshakespeare.py
@@ -0,0 +1,150 @@
+"""
+Download, preprocess and serve the TinyShakespeare dataset as a DataLoader.
+
+Follows the same interface as the TinyStories dataset.
+"""
+
+import argparse
+import os
+import random
+
+import numpy as np
+import requests
+import torch
+import torch.distributed as dist
+from tqdm import tqdm
+
+from tokenizer import Tokenizer
+
+DATA_CACHE_DIR = "data"
+
+
+def download_file(url: str, fname: str, chunk_size=1024):
+    """Helper function to download a file from a given url"""
+    resp = requests.get(url, stream=True)
+    # fail before creating fname, so an HTTP error page is never cached as data
+    resp.raise_for_status()
+    total = int(resp.headers.get("content-length", 0))
+    with open(fname, "wb") as file, tqdm(
+        desc=fname,
+        total=total,
+        unit="iB",
+        unit_scale=True,
+        unit_divisor=1024,
+    ) as bar:
+        for data in resp.iter_content(chunk_size=chunk_size):
+            size = file.write(data)
+            bar.update(size)
+
+
+def download():
+    """Downloads the dataset to disk."""
+    os.makedirs(DATA_CACHE_DIR, exist_ok=True)
+
+    # download the TinyShakespeare dataset, unless it's already downloaded
+    data_url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
+    data_filename = os.path.join(DATA_CACHE_DIR, "tinyshakespeare.txt")
+    if not os.path.exists(data_filename):
+        print(f"Downloading {data_url} to {data_filename}...")
+        download_file(data_url, data_filename)
+    else:
+        print(f"{data_filename} already exists, skipping download...")
+
+    print("Download done.")
+
+
+def pretokenize():
+    """Tokenizes the downloaded text line by line and saves the ids as uint16."""
+    enc = Tokenizer()
+
+    data_file = os.path.join(DATA_CACHE_DIR, "tinyshakespeare.txt")
+
+    all_tokens = []
+    # explicit encoding: don't depend on the platform's default locale
+    with open(data_file, "r", encoding="utf-8") as f:
+        for line in f:
+            text = line.strip()
+            tokens = enc.encode(text, bos=True, eos=False)
+            all_tokens.extend(tokens)
+    all_tokens = np.array(all_tokens, dtype=np.uint16)
+    print(f"Total tokens: {len(all_tokens)}")
+    with open(data_file.replace(".txt", ".bin"), "wb") as f:
+        f.write(all_tokens.tobytes())
+    print(f"Saved {data_file.replace('.txt', '.bin')}")
+    print("Done.")
+
+
+class PretokDataset(torch.utils.data.IterableDataset):
+    """Loads pretokenized examples from disk and yields them as PyTorch tensors."""
+
+    def __init__(self, split, max_seq_len):
+        super().__init__()
+        self.split = split
+        self.max_seq_len = max_seq_len
+
+    def __iter__(self):
+        # get worker info within a DataLoader
+        worker_info = torch.utils.data.get_worker_info()
+        worker_id = worker_info.id if worker_info else 0
+        # get DDP rank info
+        rank = dist.get_rank() if dist.is_initialized() else 0
+        # combine the worker_id and worker_rank to create a unique seed for rng
+        seed = 42 + worker_id + 1337 * rank
+        rng = random.Random(seed)
+        print(f"Created a PretokDataset with rng seed {seed}")
+        data_file = os.path.join(DATA_CACHE_DIR, "tinyshakespeare.bin")
+        m_all = np.memmap(data_file, dtype=np.uint16, mode="r")
+
+        # split out 10% of the data for validation
+        split_ix = int(len(m_all) * 0.9)
+        if self.split == "train":
+            m = m_all[:split_ix]
+        else:
+            m = m_all[split_ix:]
+
+        num_batches = len(m) // self.max_seq_len
+        num_batches -= 1  # drop the last partial batch
+        assert num_batches > 0, "this split is way too small? investigate."
+
+        while True:
+            ixs = list(range(num_batches))
+            rng.shuffle(ixs)
+            for ix in ixs:
+                start = ix * self.max_seq_len
+                end = start + self.max_seq_len + 1
+                # calling .astype will copy the data into a new numpy array, now in RAM
+                chunk = torch.from_numpy((m[start:end]).astype(np.int64))
+                x = chunk[:-1]
+                y = chunk[1:]
+                yield x, y
+
+
+class ShakespeareTask:
+    """Serves TinyShakespeare batches; see iter_batches."""
+
+    @staticmethod
+    def iter_batches(split, batch_size, max_seq_len, device, num_workers=0):
+        """Yields (x, y) batches from PretokDataset, moved to the given device."""
+        ds = PretokDataset(split, max_seq_len)
+        dl = torch.utils.data.DataLoader(
+            ds, batch_size=batch_size, pin_memory=True, num_workers=num_workers
+        )
+        for x, y in dl:
+            x = x.to(device, non_blocking=True)
+            y = y.to(device, non_blocking=True)
+            yield x, y
+
+
+if __name__ == "__main__":
+    # stage name -> handler; the CLI choices are derived from this table so that
+    # every stage argparse accepts can actually be dispatched
+    fun = {
+        "download": download,
+        "pretokenize": pretokenize,
+    }
+    parser = argparse.ArgumentParser()
+    parser.add_argument("stage", type=str, choices=list(fun))
+    args = parser.parse_args()
+
+    # depending on the stage call the appropriate function
+    fun[args.stage]()
\ No newline at end of file
diff --git a/train.py b/train.py
index 3a6db52..70b5109 100644
--- a/train.py
+++ b/train.py
@@ -29,6 +29,7 @@ from torch.distributed import destroy_process_group, init_process_group
 from torch.nn.parallel import DistributedDataParallel as DDP
 
 from tinystories import Task
+from tinyshakespeare import ShakespeareTask
 
 # -----------------------------------------------------------------------------
 # I/O
@@ -46,6 +47,7 @@ wandb_run_name = "run" + datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
 # data
 batch_size = 128  # if gradient_accumulation_steps > 1, this is the micro-batch size
 max_seq_len = 256
+dataset = "tinystories"  # tinystories|tinyshakespeare
 # model
 dim = 288
 n_layers = 6
@@ -121,8 +123,9 @@ ctx = (
 )
 
 # task-specific setup
+task = {'tinystories': Task, 'tinyshakespeare': ShakespeareTask}[dataset]
 iter_batches = partial(
-    Task.iter_batches,
+    task.iter_batches,
     batch_size=batch_size,
     max_seq_len=max_seq_len,
     device=device,