Add tinyshakespeare dataset
This commit is contained in:
@@ -29,6 +29,7 @@ from torch.distributed import destroy_process_group, init_process_group
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
|
||||
from tinystories import Task
|
||||
+from tinyshakespeare import ShakespeareTask
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# I/O
|
||||
@@ -46,6 +47,7 @@ wandb_run_name = "run" + datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
|
||||
# data
|
||||
batch_size = 128 # if gradient_accumulation_steps > 1, this is the micro-batch size
|
||||
max_seq_len = 256
|
||||
+dataset = "tinystories" # tinystories|tinyshakespeare
|
||||
# model
|
||||
dim = 288
|
||||
n_layers = 6
|
||||
@@ -121,8 +123,9 @@ ctx = (
|
||||
)
|
||||
|
||||
# task-specific setup
|
||||
+task = {'tinystories': Task, 'tinyshakespeare': ShakespeareTask}[dataset]
|
||||
iter_batches = partial(
|
||||
-Task.iter_batches,
|
||||
+task.iter_batches,
|
||||
batch_size=batch_size,
|
||||
max_seq_len=max_seq_len,
|
||||
device=device,
|
||||
|
||||
Reference in New Issue
Block a user