diff --git a/tinyshakespeare.py b/tinyshakespeare.py deleted file mode 100644 index 602624c..0000000 --- a/tinyshakespeare.py +++ /dev/null @@ -1,140 +0,0 @@ -""" -Download, preprocess and serve the TinyShakespeare dataset as a DataLoader. - -Follows the same interface as the TinyStories dataset. -""" - -import argparse -import os -import random - -import numpy as np -import requests -import torch -import torch.distributed as dist -from tqdm import tqdm - -from tokenizer import Tokenizer - -DATA_CACHE_DIR = "data" - -def download_file(url: str, fname: str, chunk_size=1024): - """Helper function to download a file from a given url""" - resp = requests.get(url, stream=True) - total = int(resp.headers.get("content-length", 0)) - with open(fname, "wb") as file, tqdm( - desc=fname, - total=total, - unit="iB", - unit_scale=True, - unit_divisor=1024, - ) as bar: - for data in resp.iter_content(chunk_size=chunk_size): - size = file.write(data) - bar.update(size) - - -def download(): - """Downloads the dataset to disk.""" - os.makedirs(DATA_CACHE_DIR, exist_ok=True) - - # download the TinyShakespeare dataset, unless it's already downloaded - data_url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt" - data_filename = os.path.join(DATA_CACHE_DIR, "tinyshakespeare.txt") - if not os.path.exists(data_filename): - print(f"Downloading {data_url} to {data_filename}...") - download_file(data_url, data_filename) - else: - print(f"{data_filename} already exists, skipping download...") - - print("Download done.") - -def pretokenize(): - enc = Tokenizer() - - data_file = os.path.join(DATA_CACHE_DIR, "tinyshakespeare.txt") - - all_tokens = [] - with open(data_file, "r") as f: - for line in f: - text = line.strip() - tokens = enc.encode(text, bos=True, eos=False) - all_tokens.extend(tokens) - all_tokens = np.array(all_tokens, dtype=np.uint16) - print(f"Total tokens: {len(all_tokens)}") - with open(data_file.replace(".txt", ".bin"), "wb") as f: - f.write(all_tokens.tobytes()) - print(f"Saved {data_file.replace('.txt', '.bin')}") - print("Done.") - - -class PretokDataset(torch.utils.data.IterableDataset): - """Loads pretokenized examples from disk and yields them as PyTorch tensors.""" - - def __init__(self, split, max_seq_len): - super().__init__() - self.split = split - self.max_seq_len = max_seq_len - - def __iter__(self): - # get worker info within a DataLoader - worker_info = torch.utils.data.get_worker_info() - worker_id = worker_info.id if worker_info else 0 - # get DDP rank info - rank = dist.get_rank() if dist.is_initialized() else 0 - # combine the worker_id and worker_rank to create a unique seed for rng - seed = 42 + worker_id + 1337 * rank - rng = random.Random(seed) - print(f"Created a PretokDataset with rng seed {seed}") - data_file = os.path.join(DATA_CACHE_DIR, "tinyshakespeare.bin") - m_all = np.memmap(data_file, dtype=np.uint16, mode="r") - - # split out 10% of the data for validation - split_ix = int(len(m_all) * 0.9) - if self.split == "train": - m = m_all[:split_ix] - else: - m = m_all[split_ix:] - - num_batches = len(m) // self.max_seq_len - num_batches -= 1 # drop the last partial batch - assert num_batches > 0, "this split is way too small? investigate." - - while True: - ixs = list(range(num_batches)) - rng.shuffle(ixs) - for ix in ixs: - start = ix * self.max_seq_len - end = start + self.max_seq_len + 1 - # calling .astype will copy the data into a new numpy array, now in RAM - chunk = torch.from_numpy((m[start:end]).astype(np.int64)) - x = chunk[:-1] - y = chunk[1:] - yield x, y - - -class ShakespeareTask: - - @staticmethod - def iter_batches(split, batch_size, max_seq_len, device, num_workers=0): - ds = PretokDataset(split, max_seq_len) - dl = torch.utils.data.DataLoader( - ds, batch_size=batch_size, pin_memory=True, num_workers=num_workers - ) - for x, y in dl: - x = x.to(device, non_blocking=True) - y = y.to(device, non_blocking=True) - yield x, y - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("stage", type=str, choices=["download", "train_tokenizer", "pretokenize"]) - args = parser.parse_args() - - # depending on the stage call the appropriate function - fun = { - "download": download, - "pretokenize": pretokenize, - } - fun[args.stage]() \ No newline at end of file diff --git a/train.py b/train.py index 662afcf..39b4f49 100644 --- a/train.py +++ b/train.py @@ -29,7 +29,6 @@ from torch.distributed import destroy_process_group, init_process_group from torch.nn.parallel import DistributedDataParallel as DDP from tinystories import Task -from tinyshakespeare import ShakespeareTask # ----------------------------------------------------------------------------- # I/O @@ -49,7 +48,6 @@ batch_size = 128 # if gradient_accumulation_steps > 1, this is the micro-batch max_seq_len = 256 vocab_source = "custom" # llama2|custom; use Lllama 2 vocab from Meta, or custom trained vocab_size = 512 -dataset = "tinystories" # tinystories|tinyshakespeare # model dim = 288 n_layers = 6 @@ -129,9 +127,8 @@ ctx = ( ) # task-specific setup -task = {'tinystories': Task, 'tinyshakespeare': ShakespeareTask}[dataset] iter_batches = partial( - task.iter_batches, + Task.iter_batches, batch_size=batch_size, max_seq_len=max_seq_len, vocab_size=vocab_size,