Add tinyshakespeare dataset

This commit is contained in:
Will Lamond
2023-07-30 06:14:38 -07:00
parent a8f3e1c499
commit e592ed5d64
2 changed files with 144 additions and 1 deletions
+4 -1
View File
@@ -29,6 +29,7 @@ from torch.distributed import destroy_process_group, init_process_group
from torch.nn.parallel import DistributedDataParallel as DDP
from tinystories import Task
from tinyshakespeare import ShakespeareTask
# -----------------------------------------------------------------------------
# I/O
@@ -46,6 +47,7 @@ wandb_run_name = "run" + datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
# data
batch_size = 128 # if gradient_accumulation_steps > 1, this is the micro-batch size
max_seq_len = 256
dataset = "tinystories" # tinystories|tinyshakespeare
# model
dim = 288
n_layers = 6
@@ -121,8 +123,9 @@ ctx = (
)
# task-specific setup
task = {'tinystories': Task, 'tinyshakespeare': ShakespeareTask}[dataset]
iter_batches = partial(
Task.iter_batches,
task.iter_batches,
batch_size=batch_size,
max_seq_len=max_seq_len,
device=device,