remove the tinyshakespeare dataset until i can bring it back later in a nicer form, otherwise right now we just have a ton of copy paste code here

This commit is contained in:
Andrej Karpathy
2023-08-13 02:18:30 +00:00
parent f5fc0c245f
commit 00a61dc7f9
2 changed files with 1 additions and 144 deletions
-140
View File
@@ -1,140 +0,0 @@
"""
Download, preprocess and serve the TinyShakespeare dataset as a DataLoader.
Follows the same interface as the TinyStories dataset.
"""
import argparse
import os
import random
import numpy as np
import requests
import torch
import torch.distributed as dist
from tqdm import tqdm
from tokenizer import Tokenizer
DATA_CACHE_DIR = "data"
def download_file(url: str, fname: str, chunk_size=1024):
"""Helper function to download a file from a given url"""
resp = requests.get(url, stream=True)
total = int(resp.headers.get("content-length", 0))
with open(fname, "wb") as file, tqdm(
desc=fname,
total=total,
unit="iB",
unit_scale=True,
unit_divisor=1024,
) as bar:
for data in resp.iter_content(chunk_size=chunk_size):
size = file.write(data)
bar.update(size)
def download():
"""Downloads the dataset to disk."""
os.makedirs(DATA_CACHE_DIR, exist_ok=True)
# download the TinyShakespeare dataset, unless it's already downloaded
data_url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
data_filename = os.path.join(DATA_CACHE_DIR, "tinyshakespeare.txt")
if not os.path.exists(data_filename):
print(f"Downloading {data_url} to {data_filename}...")
download_file(data_url, data_filename)
else:
print(f"{data_filename} already exists, skipping download...")
print("Download done.")
def pretokenize():
enc = Tokenizer()
data_file = os.path.join(DATA_CACHE_DIR, "tinyshakespeare.txt")
all_tokens = []
with open(data_file, "r") as f:
for line in f:
text = line.strip()
tokens = enc.encode(text, bos=True, eos=False)
all_tokens.extend(tokens)
all_tokens = np.array(all_tokens, dtype=np.uint16)
print(f"Total tokens: {len(all_tokens)}")
with open(data_file.replace(".txt", ".bin"), "wb") as f:
f.write(all_tokens.tobytes())
print(f"Saved {data_file.replace('.txt', '.bin')}")
print("Done.")
class PretokDataset(torch.utils.data.IterableDataset):
"""Loads pretokenized examples from disk and yields them as PyTorch tensors."""
def __init__(self, split, max_seq_len):
super().__init__()
self.split = split
self.max_seq_len = max_seq_len
def __iter__(self):
# get worker info within a DataLoader
worker_info = torch.utils.data.get_worker_info()
worker_id = worker_info.id if worker_info else 0
# get DDP rank info
rank = dist.get_rank() if dist.is_initialized() else 0
# combine the worker_id and worker_rank to create a unique seed for rng
seed = 42 + worker_id + 1337 * rank
rng = random.Random(seed)
print(f"Created a PretokDataset with rng seed {seed}")
data_file = os.path.join(DATA_CACHE_DIR, "tinyshakespeare.bin")
m_all = np.memmap(data_file, dtype=np.uint16, mode="r")
# split out 10% of the data for validation
split_ix = int(len(m_all) * 0.9)
if self.split == "train":
m = m_all[:split_ix]
else:
m = m_all[split_ix:]
num_batches = len(m) // self.max_seq_len
num_batches -= 1 # drop the last partial batch
assert num_batches > 0, "this split is way too small? investigate."
while True:
ixs = list(range(num_batches))
rng.shuffle(ixs)
for ix in ixs:
start = ix * self.max_seq_len
end = start + self.max_seq_len + 1
# calling .astype will copy the data into a new numpy array, now in RAM
chunk = torch.from_numpy((m[start:end]).astype(np.int64))
x = chunk[:-1]
y = chunk[1:]
yield x, y
class ShakespeareTask:
@staticmethod
def iter_batches(split, batch_size, max_seq_len, device, num_workers=0):
ds = PretokDataset(split, max_seq_len)
dl = torch.utils.data.DataLoader(
ds, batch_size=batch_size, pin_memory=True, num_workers=num_workers
)
for x, y in dl:
x = x.to(device, non_blocking=True)
y = y.to(device, non_blocking=True)
yield x, y
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("stage", type=str, choices=["download", "train_tokenizer", "pretokenize"])
args = parser.parse_args()
# depending on the stage call the appropriate function
fun = {
"download": download,
"pretokenize": pretokenize,
}
fun[args.stage]()
+1 -4
View File
@@ -29,7 +29,6 @@ from torch.distributed import destroy_process_group, init_process_group
from torch.nn.parallel import DistributedDataParallel as DDP
from tinystories import Task
from tinyshakespeare import ShakespeareTask
# -----------------------------------------------------------------------------
# I/O
@@ -49,7 +48,6 @@ batch_size = 128 # if gradient_accumulation_steps > 1, this is the micro-batch
max_seq_len = 256
vocab_source = "custom" # llama2|custom; use Lllama 2 vocab from Meta, or custom trained
vocab_size = 512
dataset = "tinystories" # tinystories|tinyshakespeare
# model
dim = 288
n_layers = 6
@@ -129,9 +127,8 @@ ctx = (
)
# task-specific setup
task = {'tinystories': Task, 'tinyshakespeare': ShakespeareTask}[dataset]
iter_batches = partial(
task.iter_batches,
Task.iter_batches,
batch_size=batch_size,
max_seq_len=max_seq_len,
vocab_size=vocab_size,