ok i can train and sample a model with a custom tokenizer
This commit is contained in:
@@ -11,12 +11,13 @@ from torch import nn
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ModelArgs:
|
class ModelArgs:
|
||||||
|
# default hyperparameters for the Llama 7B model
|
||||||
dim: int = 4096
|
dim: int = 4096
|
||||||
n_layers: int = 32
|
n_layers: int = 32
|
||||||
n_heads: int = 32
|
n_heads: int = 32
|
||||||
n_kv_heads: Optional[int] = None
|
n_kv_heads: Optional[int] = None
|
||||||
vocab_size: int = -1 # defined later by tokenizer
|
vocab_size: int = 32000
|
||||||
multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2
|
multiple_of: int = 256 # MLP hidden layer size will be a multiple of this value
|
||||||
norm_eps: float = 1e-5
|
norm_eps: float = 1e-5
|
||||||
max_seq_len: int = 2048
|
max_seq_len: int = 2048
|
||||||
dropout: float = 0.0
|
dropout: float = 0.0
|
||||||
|
|||||||
@@ -9,6 +9,8 @@ import tiktoken
|
|||||||
from model import ModelArgs, Transformer
|
from model import ModelArgs, Transformer
|
||||||
from tokenizer import Tokenizer
|
from tokenizer import Tokenizer
|
||||||
|
|
||||||
|
from tinystories import get_tokenizer_model_path
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
out_dir = 'out' # ignored if init_from is not 'resume'
|
out_dir = 'out' # ignored if init_from is not 'resume'
|
||||||
start = "" # or "<|endoftext|>" or etc. Can also specify a file, use as: "FILE:prompt.txt"
|
start = "" # or "<|endoftext|>" or etc. Can also specify a file, use as: "FILE:prompt.txt"
|
||||||
@@ -51,7 +53,9 @@ if compile:
|
|||||||
model = torch.compile(model) # requires PyTorch 2.0 (optional)
|
model = torch.compile(model) # requires PyTorch 2.0 (optional)
|
||||||
|
|
||||||
# load the tokenizer
|
# load the tokenizer
|
||||||
enc = Tokenizer()
|
assert checkpoint["config"]["dataset"] == "tinystories" # TODO: generalize
|
||||||
|
tokenizer_model = get_tokenizer_model_path(vocab_size=gptconf.vocab_size)
|
||||||
|
enc = Tokenizer(tokenizer_model=tokenizer_model)
|
||||||
|
|
||||||
# encode the beginning of the prompt
|
# encode the beginning of the prompt
|
||||||
if start.startswith('FILE:'):
|
if start.startswith('FILE:'):
|
||||||
|
|||||||
+29
-8
@@ -120,9 +120,7 @@ def train_vocab(vocab_size):
|
|||||||
|
|
||||||
def process_shard(args, vocab_size):
|
def process_shard(args, vocab_size):
|
||||||
shard_id, shard = args
|
shard_id, shard = args
|
||||||
tokenizer_model = None
|
tokenizer_model = get_tokenizer_model_path()
|
||||||
if vocab_size > 0:
|
|
||||||
tokenizer_model = os.path.join(DATA_CACHE_DIR, f"tok{vocab_size}.model")
|
|
||||||
enc = Tokenizer(tokenizer_model)
|
enc = Tokenizer(tokenizer_model)
|
||||||
with open(shard, "r") as f:
|
with open(shard, "r") as f:
|
||||||
data = json.load(f)
|
data = json.load(f)
|
||||||
@@ -171,10 +169,12 @@ def pretokenize(vocab_size):
|
|||||||
class PretokDataset(torch.utils.data.IterableDataset):
|
class PretokDataset(torch.utils.data.IterableDataset):
|
||||||
"""Loads pretokenized examples from disk and yields them as PyTorch tensors."""
|
"""Loads pretokenized examples from disk and yields them as PyTorch tensors."""
|
||||||
|
|
||||||
def __init__(self, split, max_seq_len):
|
def __init__(self, split, max_seq_len, vocab_size, vocab_source):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.split = split
|
self.split = split
|
||||||
self.max_seq_len = max_seq_len
|
self.max_seq_len = max_seq_len
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
self.vocab_source = vocab_source
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
# get worker info within a DataLoader
|
# get worker info within a DataLoader
|
||||||
@@ -186,8 +186,14 @@ class PretokDataset(torch.utils.data.IterableDataset):
|
|||||||
seed = 42 + worker_id + 1337 * rank
|
seed = 42 + worker_id + 1337 * rank
|
||||||
rng = random.Random(seed)
|
rng = random.Random(seed)
|
||||||
print(f"Created a PretokDataset with rng seed {seed}")
|
print(f"Created a PretokDataset with rng seed {seed}")
|
||||||
data_dir = os.path.join(DATA_CACHE_DIR, "TinyStories_all_data")
|
if self.vocab_source == "llama2":
|
||||||
shard_filenames = sorted(glob.glob(os.path.join(data_dir, "*.bin")))
|
# the .bin files are right alongside the .json files
|
||||||
|
bin_dir = os.path.join(DATA_CACHE_DIR, "TinyStories_all_data")
|
||||||
|
shard_filenames = sorted(glob.glob(os.path.join(bin_dir, "*.bin")))
|
||||||
|
elif self.vocab_source == "custom":
|
||||||
|
# the .bin files are in tok{N} directory
|
||||||
|
bin_dir = os.path.join(DATA_CACHE_DIR, f"tok{self.vocab_size}")
|
||||||
|
shard_filenames = sorted(glob.glob(os.path.join(bin_dir, "*.bin")))
|
||||||
# train/test split. let's use only shard 0 for test split, rest train
|
# train/test split. let's use only shard 0 for test split, rest train
|
||||||
shard_filenames = shard_filenames[1:] if self.split == "train" else shard_filenames[:1]
|
shard_filenames = shard_filenames[1:] if self.split == "train" else shard_filenames[:1]
|
||||||
while True:
|
while True:
|
||||||
@@ -209,12 +215,25 @@ class PretokDataset(torch.utils.data.IterableDataset):
|
|||||||
y = chunk[1:]
|
y = chunk[1:]
|
||||||
yield x, y
|
yield x, y
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# public interface functions
|
||||||
|
|
||||||
|
def get_tokenizer_model_path(vocab_size):
|
||||||
|
"""
|
||||||
|
Returns path to the sentencepiece tokenizer model for a given vocab size
|
||||||
|
vocab_size = 0 designates the default Llama 2 tokenizer, in that case
|
||||||
|
None is returned.
|
||||||
|
"""
|
||||||
|
if vocab_size == 0:
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
return os.path.join(DATA_CACHE_DIR, f"tok{vocab_size}.model")
|
||||||
|
|
||||||
class Task:
|
class Task:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def iter_batches(split, batch_size, max_seq_len, device, num_workers=0):
|
def iter_batches(batch_size, device, num_workers=0, **dataset_kwargs):
|
||||||
ds = PretokDataset(split, max_seq_len)
|
ds = PretokDataset(**dataset_kwargs)
|
||||||
dl = torch.utils.data.DataLoader(
|
dl = torch.utils.data.DataLoader(
|
||||||
ds, batch_size=batch_size, pin_memory=True, num_workers=num_workers
|
ds, batch_size=batch_size, pin_memory=True, num_workers=num_workers
|
||||||
)
|
)
|
||||||
@@ -223,6 +242,8 @@ class Task:
|
|||||||
y = y.to(device, non_blocking=True)
|
y = y.to(device, non_blocking=True)
|
||||||
yield x, y
|
yield x, y
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# CLI for constructing the dataset
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -47,6 +47,8 @@ wandb_run_name = "run" + datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
|
|||||||
# data
|
# data
|
||||||
batch_size = 128 # if gradient_accumulation_steps > 1, this is the micro-batch size
|
batch_size = 128 # if gradient_accumulation_steps > 1, this is the micro-batch size
|
||||||
max_seq_len = 256
|
max_seq_len = 256
|
||||||
|
vocab_source = "custom" # llama2|custom; use Llama 2 vocab from Meta, or custom trained
|
||||||
|
vocab_size = 512
|
||||||
dataset = "tinystories" # tinystories|tinyshakespeare
|
dataset = "tinystories" # tinystories|tinyshakespeare
|
||||||
# model
|
# model
|
||||||
dim = 288
|
dim = 288
|
||||||
@@ -83,6 +85,10 @@ config = {k: globals()[k] for k in config_keys} # will be useful for logging
|
|||||||
lr_decay_iters = max_iters # should be ~= max_iters per Chinchilla
|
lr_decay_iters = max_iters # should be ~= max_iters per Chinchilla
|
||||||
min_lr = 0.0 # minimum learning rate, should be ~= learning_rate/10 per Chinchilla
|
min_lr = 0.0 # minimum learning rate, should be ~= learning_rate/10 per Chinchilla
|
||||||
|
|
||||||
|
# validating checks
|
||||||
|
assert vocab_source in ["llama2", "custom"]
|
||||||
|
assert vocab_source == "custom" or vocab_size == 32000, "The vocab from Meta has 32K tokens"
|
||||||
|
|
||||||
# various inits, derived attributes, I/O setup
|
# various inits, derived attributes, I/O setup
|
||||||
ddp = int(os.environ.get("RANK", -1)) != -1 # is this a ddp run?
|
ddp = int(os.environ.get("RANK", -1)) != -1 # is this a ddp run?
|
||||||
if ddp:
|
if ddp:
|
||||||
@@ -128,6 +134,8 @@ iter_batches = partial(
|
|||||||
task.iter_batches,
|
task.iter_batches,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
max_seq_len=max_seq_len,
|
max_seq_len=max_seq_len,
|
||||||
|
vocab_size=vocab_size,
|
||||||
|
vocab_source=vocab_source,
|
||||||
device=device,
|
device=device,
|
||||||
num_workers=0,
|
num_workers=0,
|
||||||
)
|
)
|
||||||
@@ -142,7 +150,7 @@ model_args = dict(
|
|||||||
n_layers=n_layers,
|
n_layers=n_layers,
|
||||||
n_heads=n_heads,
|
n_heads=n_heads,
|
||||||
n_kv_heads=n_heads,
|
n_kv_heads=n_heads,
|
||||||
vocab_size=32000,
|
vocab_size=vocab_size,
|
||||||
multiple_of=multiple_of,
|
multiple_of=multiple_of,
|
||||||
max_seq_len=max_seq_len,
|
max_seq_len=max_seq_len,
|
||||||
dropout=dropout,
|
dropout=dropout,
|
||||||
@@ -206,7 +214,7 @@ def estimate_loss():
|
|||||||
out = {}
|
out = {}
|
||||||
model.eval()
|
model.eval()
|
||||||
for split in ["train", "val"]:
|
for split in ["train", "val"]:
|
||||||
batch_iter = iter_batches(split)
|
batch_iter = iter_batches(split=split)
|
||||||
losses = torch.zeros(eval_iters) # keep on CPU
|
losses = torch.zeros(eval_iters) # keep on CPU
|
||||||
for k in range(eval_iters):
|
for k in range(eval_iters):
|
||||||
X, Y = next(batch_iter)
|
X, Y = next(batch_iter)
|
||||||
@@ -238,7 +246,7 @@ if wandb_log and master_process:
|
|||||||
wandb.init(project=wandb_project, name=wandb_run_name, config=config)
|
wandb.init(project=wandb_project, name=wandb_run_name, config=config)
|
||||||
|
|
||||||
# training loop
|
# training loop
|
||||||
train_batch_iter = iter_batches("train")
|
train_batch_iter = iter_batches(split="train")
|
||||||
X, Y = next(train_batch_iter) # fetch the very first batch
|
X, Y = next(train_batch_iter) # fetch the very first batch
|
||||||
t0 = time.time()
|
t0 = time.time()
|
||||||
local_iter_num = 0 # number of iterations in the lifetime of this process
|
local_iter_num = 0 # number of iterations in the lifetime of this process
|
||||||
|
|||||||
Reference in New Issue
Block a user