remove the tinyshakespeare dataset until i can bring it back later in a nicer form, otherwise right now we just have a ton of copy paste code here
This commit is contained in:
@@ -1,140 +0,0 @@
|
|||||||
"""
|
|
||||||
Download, preprocess and serve the TinyShakespeare dataset as a DataLoader.
|
|
||||||
|
|
||||||
Follows the same interface as the TinyStories dataset.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import os
|
|
||||||
import random
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import requests
|
|
||||||
import torch
|
|
||||||
import torch.distributed as dist
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
from tokenizer import Tokenizer
|
|
||||||
|
|
||||||
DATA_CACHE_DIR = "data"
|
|
||||||
|
|
||||||
def download_file(url: str, fname: str, chunk_size=1024):
|
|
||||||
"""Helper function to download a file from a given url"""
|
|
||||||
resp = requests.get(url, stream=True)
|
|
||||||
total = int(resp.headers.get("content-length", 0))
|
|
||||||
with open(fname, "wb") as file, tqdm(
|
|
||||||
desc=fname,
|
|
||||||
total=total,
|
|
||||||
unit="iB",
|
|
||||||
unit_scale=True,
|
|
||||||
unit_divisor=1024,
|
|
||||||
) as bar:
|
|
||||||
for data in resp.iter_content(chunk_size=chunk_size):
|
|
||||||
size = file.write(data)
|
|
||||||
bar.update(size)
|
|
||||||
|
|
||||||
|
|
||||||
def download():
|
|
||||||
"""Downloads the dataset to disk."""
|
|
||||||
os.makedirs(DATA_CACHE_DIR, exist_ok=True)
|
|
||||||
|
|
||||||
# download the TinyShakespeare dataset, unless it's already downloaded
|
|
||||||
data_url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
|
|
||||||
data_filename = os.path.join(DATA_CACHE_DIR, "tinyshakespeare.txt")
|
|
||||||
if not os.path.exists(data_filename):
|
|
||||||
print(f"Downloading {data_url} to {data_filename}...")
|
|
||||||
download_file(data_url, data_filename)
|
|
||||||
else:
|
|
||||||
print(f"{data_filename} already exists, skipping download...")
|
|
||||||
|
|
||||||
print("Download done.")
|
|
||||||
|
|
||||||
def pretokenize():
|
|
||||||
enc = Tokenizer()
|
|
||||||
|
|
||||||
data_file = os.path.join(DATA_CACHE_DIR, "tinyshakespeare.txt")
|
|
||||||
|
|
||||||
all_tokens = []
|
|
||||||
with open(data_file, "r") as f:
|
|
||||||
for line in f:
|
|
||||||
text = line.strip()
|
|
||||||
tokens = enc.encode(text, bos=True, eos=False)
|
|
||||||
all_tokens.extend(tokens)
|
|
||||||
all_tokens = np.array(all_tokens, dtype=np.uint16)
|
|
||||||
print(f"Total tokens: {len(all_tokens)}")
|
|
||||||
with open(data_file.replace(".txt", ".bin"), "wb") as f:
|
|
||||||
f.write(all_tokens.tobytes())
|
|
||||||
print(f"Saved {data_file.replace('.txt', '.bin')}")
|
|
||||||
print("Done.")
|
|
||||||
|
|
||||||
|
|
||||||
class PretokDataset(torch.utils.data.IterableDataset):
|
|
||||||
"""Loads pretokenized examples from disk and yields them as PyTorch tensors."""
|
|
||||||
|
|
||||||
def __init__(self, split, max_seq_len):
|
|
||||||
super().__init__()
|
|
||||||
self.split = split
|
|
||||||
self.max_seq_len = max_seq_len
|
|
||||||
|
|
||||||
def __iter__(self):
|
|
||||||
# get worker info within a DataLoader
|
|
||||||
worker_info = torch.utils.data.get_worker_info()
|
|
||||||
worker_id = worker_info.id if worker_info else 0
|
|
||||||
# get DDP rank info
|
|
||||||
rank = dist.get_rank() if dist.is_initialized() else 0
|
|
||||||
# combine the worker_id and worker_rank to create a unique seed for rng
|
|
||||||
seed = 42 + worker_id + 1337 * rank
|
|
||||||
rng = random.Random(seed)
|
|
||||||
print(f"Created a PretokDataset with rng seed {seed}")
|
|
||||||
data_file = os.path.join(DATA_CACHE_DIR, "tinyshakespeare.bin")
|
|
||||||
m_all = np.memmap(data_file, dtype=np.uint16, mode="r")
|
|
||||||
|
|
||||||
# split out 10% of the data for validation
|
|
||||||
split_ix = int(len(m_all) * 0.9)
|
|
||||||
if self.split == "train":
|
|
||||||
m = m_all[:split_ix]
|
|
||||||
else:
|
|
||||||
m = m_all[split_ix:]
|
|
||||||
|
|
||||||
num_batches = len(m) // self.max_seq_len
|
|
||||||
num_batches -= 1 # drop the last partial batch
|
|
||||||
assert num_batches > 0, "this split is way too small? investigate."
|
|
||||||
|
|
||||||
while True:
|
|
||||||
ixs = list(range(num_batches))
|
|
||||||
rng.shuffle(ixs)
|
|
||||||
for ix in ixs:
|
|
||||||
start = ix * self.max_seq_len
|
|
||||||
end = start + self.max_seq_len + 1
|
|
||||||
# calling .astype will copy the data into a new numpy array, now in RAM
|
|
||||||
chunk = torch.from_numpy((m[start:end]).astype(np.int64))
|
|
||||||
x = chunk[:-1]
|
|
||||||
y = chunk[1:]
|
|
||||||
yield x, y
|
|
||||||
|
|
||||||
|
|
||||||
class ShakespeareTask:
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def iter_batches(split, batch_size, max_seq_len, device, num_workers=0):
|
|
||||||
ds = PretokDataset(split, max_seq_len)
|
|
||||||
dl = torch.utils.data.DataLoader(
|
|
||||||
ds, batch_size=batch_size, pin_memory=True, num_workers=num_workers
|
|
||||||
)
|
|
||||||
for x, y in dl:
|
|
||||||
x = x.to(device, non_blocking=True)
|
|
||||||
y = y.to(device, non_blocking=True)
|
|
||||||
yield x, y
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
parser.add_argument("stage", type=str, choices=["download", "train_tokenizer", "pretokenize"])
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
# depending on the stage call the appropriate function
|
|
||||||
fun = {
|
|
||||||
"download": download,
|
|
||||||
"pretokenize": pretokenize,
|
|
||||||
}
|
|
||||||
fun[args.stage]()
|
|
||||||
@@ -29,7 +29,6 @@ from torch.distributed import destroy_process_group, init_process_group
|
|||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
|
|
||||||
from tinystories import Task
|
from tinystories import Task
|
||||||
from tinyshakespeare import ShakespeareTask
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
# I/O
|
# I/O
|
||||||
@@ -49,7 +48,6 @@ batch_size = 128 # if gradient_accumulation_steps > 1, this is the micro-batch
|
|||||||
max_seq_len = 256
|
max_seq_len = 256
|
||||||
vocab_source = "custom" # llama2|custom; use Lllama 2 vocab from Meta, or custom trained
|
vocab_source = "custom" # llama2|custom; use Lllama 2 vocab from Meta, or custom trained
|
||||||
vocab_size = 512
|
vocab_size = 512
|
||||||
dataset = "tinystories" # tinystories|tinyshakespeare
|
|
||||||
# model
|
# model
|
||||||
dim = 288
|
dim = 288
|
||||||
n_layers = 6
|
n_layers = 6
|
||||||
@@ -129,9 +127,8 @@ ctx = (
|
|||||||
)
|
)
|
||||||
|
|
||||||
# task-specific setup
|
# task-specific setup
|
||||||
task = {'tinystories': Task, 'tinyshakespeare': ShakespeareTask}[dataset]
|
|
||||||
iter_batches = partial(
|
iter_batches = partial(
|
||||||
task.iter_batches,
|
Task.iter_batches,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
max_seq_len=max_seq_len,
|
max_seq_len=max_seq_len,
|
||||||
vocab_size=vocab_size,
|
vocab_size=vocab_size,
|
||||||
|
|||||||
Reference in New Issue
Block a user