Add tinyshakespeare dataset

This commit is contained in:
Will Lamond
2023-07-30 06:14:38 -07:00
parent a8f3e1c499
commit e592ed5d64
2 changed files with 144 additions and 1 deletions
+140
View File
@@ -0,0 +1,140 @@
"""
Download, preprocess and serve the TinyShakespeare dataset as a DataLoader.
Follows the same interface as the TinyStories dataset.
"""
import argparse
import os
import random
import numpy as np
import requests
import torch
import torch.distributed as dist
from tqdm import tqdm
from tokenizer import Tokenizer
DATA_CACHE_DIR = "data"
def download_file(url: str, fname: str, chunk_size=1024):
"""Helper function to download a file from a given url"""
resp = requests.get(url, stream=True)
total = int(resp.headers.get("content-length", 0))
with open(fname, "wb") as file, tqdm(
desc=fname,
total=total,
unit="iB",
unit_scale=True,
unit_divisor=1024,
) as bar:
for data in resp.iter_content(chunk_size=chunk_size):
size = file.write(data)
bar.update(size)
def download():
"""Downloads the dataset to disk."""
os.makedirs(DATA_CACHE_DIR, exist_ok=True)
# download the TinyShakespeare dataset, unless it's already downloaded
data_url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
data_filename = os.path.join(DATA_CACHE_DIR, "tinyshakespeare.txt")
if not os.path.exists(data_filename):
print(f"Downloading {data_url} to {data_filename}...")
download_file(data_url, data_filename)
else:
print(f"{data_filename} already exists, skipping download...")
print("Download done.")
def pretokenize():
enc = Tokenizer()
data_file = os.path.join(DATA_CACHE_DIR, "tinyshakespeare.txt")
all_tokens = []
with open(data_file, "r") as f:
for line in f:
text = line.strip()
tokens = enc.encode(text, bos=True, eos=False)
all_tokens.extend(tokens)
all_tokens = np.array(all_tokens, dtype=np.uint16)
print(f"Total tokens: {len(all_tokens)}")
with open(data_file.replace(".txt", ".bin"), "wb") as f:
f.write(all_tokens.tobytes())
print(f"Saved {data_file.replace('.txt', '.bin')}")
print("Done.")
class PretokDataset(torch.utils.data.IterableDataset):
"""Loads pretokenized examples from disk and yields them as PyTorch tensors."""
def __init__(self, split, max_seq_len):
super().__init__()
self.split = split
self.max_seq_len = max_seq_len
def __iter__(self):
# get worker info within a DataLoader
worker_info = torch.utils.data.get_worker_info()
worker_id = worker_info.id if worker_info else 0
# get DDP rank info
rank = dist.get_rank() if dist.is_initialized() else 0
# combine the worker_id and worker_rank to create a unique seed for rng
seed = 42 + worker_id + 1337 * rank
rng = random.Random(seed)
print(f"Created a PretokDataset with rng seed {seed}")
data_file = os.path.join(DATA_CACHE_DIR, "tinyshakespeare.bin")
m_all = np.memmap(data_file, dtype=np.uint16, mode="r")
# split out 10% of the data for validation
split_ix = int(len(m_all) * 0.9)
if self.split == "train":
m = m_all[:split_ix]
else:
m = m_all[split_ix:]
num_batches = len(m) // self.max_seq_len
num_batches -= 1 # drop the last partial batch
assert num_batches > 0, "this split is way too small? investigate."
while True:
ixs = list(range(num_batches))
rng.shuffle(ixs)
for ix in ixs:
start = ix * self.max_seq_len
end = start + self.max_seq_len + 1
# calling .astype will copy the data into a new numpy array, now in RAM
chunk = torch.from_numpy((m[start:end]).astype(np.int64))
x = chunk[:-1]
y = chunk[1:]
yield x, y
class ShakespeareTask:
@staticmethod
def iter_batches(split, batch_size, max_seq_len, device, num_workers=0):
ds = PretokDataset(split, max_seq_len)
dl = torch.utils.data.DataLoader(
ds, batch_size=batch_size, pin_memory=True, num_workers=num_workers
)
for x, y in dl:
x = x.to(device, non_blocking=True)
y = y.to(device, non_blocking=True)
yield x, y
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("stage", type=str, choices=["download", "train_tokenizer", "pretokenize"])
args = parser.parse_args()
# depending on the stage call the appropriate function
fun = {
"download": download,
"pretokenize": pretokenize,
}
fun[args.stage]()
+4 -1
View File
@@ -29,6 +29,7 @@ from torch.distributed import destroy_process_group, init_process_group
from torch.nn.parallel import DistributedDataParallel as DDP
from tinystories import Task
from tinyshakespeare import ShakespeareTask
# -----------------------------------------------------------------------------
# I/O
@@ -46,6 +47,7 @@ wandb_run_name = "run" + datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
# data
batch_size = 128 # if gradient_accumulation_steps > 1, this is the micro-batch size
max_seq_len = 256
dataset = "tinystories" # tinystories|tinyshakespeare
# model
dim = 288
n_layers = 6
@@ -121,8 +123,9 @@ ctx = (
)
# task-specific setup
task = {'tinystories': Task, 'tinyshakespeare': ShakespeareTask}[dataset]
iter_batches = partial(
Task.iter_batches,
task.iter_batches,
batch_size=batch_size,
max_seq_len=max_seq_len,
device=device,