ok i can train and sample a model with a custom tokenizer
This commit is contained in:
@@ -47,6 +47,8 @@ wandb_run_name = "run" + datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
|
||||
# data
|
||||
batch_size = 128 # if gradient_accumulation_steps > 1, this is the micro-batch size
|
||||
max_seq_len = 256
|
||||
vocab_source = "custom" # llama2|custom; use the Llama 2 vocab from Meta, or a custom-trained vocab
|
||||
vocab_size = 512
|
||||
dataset = "tinystories" # tinystories|tinyshakespeare
|
||||
# model
|
||||
dim = 288
|
||||
@@ -83,6 +85,10 @@ config = {k: globals()[k] for k in config_keys} # will be useful for logging
|
||||
lr_decay_iters = max_iters # should be ~= max_iters per Chinchilla
|
||||
min_lr = 0.0 # minimum learning rate, should be ~= learning_rate/10 per Chinchilla
|
||||
|
||||
# validating checks
|
||||
assert vocab_source in ["llama2", "custom"]
|
||||
assert vocab_source == "custom" or vocab_size == 32000, "The vocab from Meta has 32K tokens"
|
||||
|
||||
# various inits, derived attributes, I/O setup
|
||||
ddp = int(os.environ.get("RANK", -1)) != -1 # is this a ddp run?
|
||||
if ddp:
|
||||
@@ -128,6 +134,8 @@ iter_batches = partial(
|
||||
task.iter_batches,
|
||||
batch_size=batch_size,
|
||||
max_seq_len=max_seq_len,
|
||||
vocab_size=vocab_size,
|
||||
vocab_source=vocab_source,
|
||||
device=device,
|
||||
num_workers=0,
|
||||
)
|
||||
@@ -142,7 +150,7 @@ model_args = dict(
|
||||
n_layers=n_layers,
|
||||
n_heads=n_heads,
|
||||
n_kv_heads=n_heads,
|
||||
vocab_size=32000,
|
||||
vocab_size=vocab_size,
|
||||
multiple_of=multiple_of,
|
||||
max_seq_len=max_seq_len,
|
||||
dropout=dropout,
|
||||
@@ -206,7 +214,7 @@ def estimate_loss():
|
||||
out = {}
|
||||
model.eval()
|
||||
for split in ["train", "val"]:
|
||||
batch_iter = iter_batches(split)
|
||||
batch_iter = iter_batches(split=split)
|
||||
losses = torch.zeros(eval_iters) # keep on CPU
|
||||
for k in range(eval_iters):
|
||||
X, Y = next(batch_iter)
|
||||
@@ -238,7 +246,7 @@ if wandb_log and master_process:
|
||||
wandb.init(project=wandb_project, name=wandb_run_name, config=config)
|
||||
|
||||
# training loop
|
||||
train_batch_iter = iter_batches("train")
|
||||
train_batch_iter = iter_batches(split="train")
|
||||
X, Y = next(train_batch_iter) # fetch the very first batch
|
||||
t0 = time.time()
|
||||
local_iter_num = 0 # number of iterations in the lifetime of this process
|
||||
|
||||
Reference in New Issue
Block a user