Add tinyshakespeare dataset
This commit is contained in:
@@ -0,0 +1,140 @@
|
||||
"""
|
||||
Download, preprocess and serve the TinyShakespeare dataset as a DataLoader.
|
||||
|
||||
Follows the same interface as the TinyStories dataset.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from tqdm import tqdm
|
||||
|
||||
from tokenizer import Tokenizer
|
||||
|
||||
DATA_CACHE_DIR = "data"
|
||||
|
||||
def download_file(url: str, fname: str, chunk_size=1024):
|
||||
"""Helper function to download a file from a given url"""
|
||||
resp = requests.get(url, stream=True)
|
||||
total = int(resp.headers.get("content-length", 0))
|
||||
with open(fname, "wb") as file, tqdm(
|
||||
desc=fname,
|
||||
total=total,
|
||||
unit="iB",
|
||||
unit_scale=True,
|
||||
unit_divisor=1024,
|
||||
) as bar:
|
||||
for data in resp.iter_content(chunk_size=chunk_size):
|
||||
size = file.write(data)
|
||||
bar.update(size)
|
||||
|
||||
|
||||
def download():
|
||||
"""Downloads the dataset to disk."""
|
||||
os.makedirs(DATA_CACHE_DIR, exist_ok=True)
|
||||
|
||||
# download the TinyShakespeare dataset, unless it's already downloaded
|
||||
data_url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
|
||||
data_filename = os.path.join(DATA_CACHE_DIR, "tinyshakespeare.txt")
|
||||
if not os.path.exists(data_filename):
|
||||
print(f"Downloading {data_url} to {data_filename}...")
|
||||
download_file(data_url, data_filename)
|
||||
else:
|
||||
print(f"{data_filename} already exists, skipping download...")
|
||||
|
||||
print("Download done.")
|
||||
|
||||
def pretokenize():
|
||||
enc = Tokenizer()
|
||||
|
||||
data_file = os.path.join(DATA_CACHE_DIR, "tinyshakespeare.txt")
|
||||
|
||||
all_tokens = []
|
||||
with open(data_file, "r") as f:
|
||||
for line in f:
|
||||
text = line.strip()
|
||||
tokens = enc.encode(text, bos=True, eos=False)
|
||||
all_tokens.extend(tokens)
|
||||
all_tokens = np.array(all_tokens, dtype=np.uint16)
|
||||
print(f"Total tokens: {len(all_tokens)}")
|
||||
with open(data_file.replace(".txt", ".bin"), "wb") as f:
|
||||
f.write(all_tokens.tobytes())
|
||||
print(f"Saved {data_file.replace('.txt', '.bin')}")
|
||||
print("Done.")
|
||||
|
||||
|
||||
class PretokDataset(torch.utils.data.IterableDataset):
|
||||
"""Loads pretokenized examples from disk and yields them as PyTorch tensors."""
|
||||
|
||||
def __init__(self, split, max_seq_len):
|
||||
super().__init__()
|
||||
self.split = split
|
||||
self.max_seq_len = max_seq_len
|
||||
|
||||
def __iter__(self):
|
||||
# get worker info within a DataLoader
|
||||
worker_info = torch.utils.data.get_worker_info()
|
||||
worker_id = worker_info.id if worker_info else 0
|
||||
# get DDP rank info
|
||||
rank = dist.get_rank() if dist.is_initialized() else 0
|
||||
# combine the worker_id and worker_rank to create a unique seed for rng
|
||||
seed = 42 + worker_id + 1337 * rank
|
||||
rng = random.Random(seed)
|
||||
print(f"Created a PretokDataset with rng seed {seed}")
|
||||
data_file = os.path.join(DATA_CACHE_DIR, "tinyshakespeare.bin")
|
||||
m_all = np.memmap(data_file, dtype=np.uint16, mode="r")
|
||||
|
||||
# split out 10% of the data for validation
|
||||
split_ix = int(len(m_all) * 0.9)
|
||||
if self.split == "train":
|
||||
m = m_all[:split_ix]
|
||||
else:
|
||||
m = m_all[split_ix:]
|
||||
|
||||
num_batches = len(m) // self.max_seq_len
|
||||
num_batches -= 1 # drop the last partial batch
|
||||
assert num_batches > 0, "this split is way too small? investigate."
|
||||
|
||||
while True:
|
||||
ixs = list(range(num_batches))
|
||||
rng.shuffle(ixs)
|
||||
for ix in ixs:
|
||||
start = ix * self.max_seq_len
|
||||
end = start + self.max_seq_len + 1
|
||||
# calling .astype will copy the data into a new numpy array, now in RAM
|
||||
chunk = torch.from_numpy((m[start:end]).astype(np.int64))
|
||||
x = chunk[:-1]
|
||||
y = chunk[1:]
|
||||
yield x, y
|
||||
|
||||
|
||||
class ShakespeareTask:
|
||||
|
||||
@staticmethod
|
||||
def iter_batches(split, batch_size, max_seq_len, device, num_workers=0):
|
||||
ds = PretokDataset(split, max_seq_len)
|
||||
dl = torch.utils.data.DataLoader(
|
||||
ds, batch_size=batch_size, pin_memory=True, num_workers=num_workers
|
||||
)
|
||||
for x, y in dl:
|
||||
x = x.to(device, non_blocking=True)
|
||||
y = y.to(device, non_blocking=True)
|
||||
yield x, y
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("stage", type=str, choices=["download", "train_tokenizer", "pretokenize"])
|
||||
args = parser.parse_args()
|
||||
|
||||
# depending on the stage call the appropriate function
|
||||
fun = {
|
||||
"download": download,
|
||||
"pretokenize": pretokenize,
|
||||
}
|
||||
fun[args.stage]()
|
||||
@@ -29,6 +29,7 @@ from torch.distributed import destroy_process_group, init_process_group
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
|
||||
from tinystories import Task
|
||||
from tinyshakespeare import ShakespeareTask
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# I/O
|
||||
@@ -46,6 +47,7 @@ wandb_run_name = "run" + datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
|
||||
# data
|
||||
batch_size = 128 # if gradient_accumulation_steps > 1, this is the micro-batch size
|
||||
max_seq_len = 256
|
||||
dataset = "tinystories" # tinystories|tinyshakespeare
|
||||
# model
|
||||
dim = 288
|
||||
n_layers = 6
|
||||
@@ -121,8 +123,9 @@ ctx = (
|
||||
)
|
||||
|
||||
# task-specific setup
|
||||
task = {'tinystories': Task, 'tinyshakespeare': ShakespeareTask}[dataset]
|
||||
iter_batches = partial(
|
||||
Task.iter_batches,
|
||||
task.iter_batches,
|
||||
batch_size=batch_size,
|
||||
max_seq_len=max_seq_len,
|
||||
device=device,
|
||||
|
||||
Reference in New Issue
Block a user