ok i can train and sample a model with a custom tokenizer
This commit is contained in:
+29
-8
@@ -120,9 +120,7 @@ def train_vocab(vocab_size):
|
||||
|
||||
def process_shard(args, vocab_size):
|
||||
shard_id, shard = args
|
||||
tokenizer_model = None
|
||||
if vocab_size > 0:
|
||||
tokenizer_model = os.path.join(DATA_CACHE_DIR, f"tok{vocab_size}.model")
|
||||
tokenizer_model = get_tokenizer_model_path()
|
||||
enc = Tokenizer(tokenizer_model)
|
||||
with open(shard, "r") as f:
|
||||
data = json.load(f)
|
||||
@@ -171,10 +169,12 @@ def pretokenize(vocab_size):
|
||||
class PretokDataset(torch.utils.data.IterableDataset):
|
||||
"""Loads pretokenized examples from disk and yields them as PyTorch tensors."""
|
||||
|
||||
def __init__(self, split, max_seq_len):
|
||||
def __init__(self, split, max_seq_len, vocab_size, vocab_source):
|
||||
super().__init__()
|
||||
self.split = split
|
||||
self.max_seq_len = max_seq_len
|
||||
self.vocab_size = vocab_size
|
||||
self.vocab_source = vocab_source
|
||||
|
||||
def __iter__(self):
|
||||
# get worker info within a DataLoader
|
||||
@@ -186,8 +186,14 @@ class PretokDataset(torch.utils.data.IterableDataset):
|
||||
seed = 42 + worker_id + 1337 * rank
|
||||
rng = random.Random(seed)
|
||||
print(f"Created a PretokDataset with rng seed {seed}")
|
||||
data_dir = os.path.join(DATA_CACHE_DIR, "TinyStories_all_data")
|
||||
shard_filenames = sorted(glob.glob(os.path.join(data_dir, "*.bin")))
|
||||
if self.vocab_source == "llama2":
|
||||
# the .bin files are right along the .json files
|
||||
bin_dir = os.path.join(DATA_CACHE_DIR, "TinyStories_all_data")
|
||||
shard_filenames = sorted(glob.glob(os.path.join(bin_dir, "*.bin")))
|
||||
elif self.vocab_source == "custom":
|
||||
# the .bin files are in tok{N} directory
|
||||
bin_dir = os.path.join(DATA_CACHE_DIR, f"tok{self.vocab_size}")
|
||||
shard_filenames = sorted(glob.glob(os.path.join(bin_dir, "*.bin")))
|
||||
# train/test split. let's use only shard 0 for test split, rest train
|
||||
shard_filenames = shard_filenames[1:] if self.split == "train" else shard_filenames[:1]
|
||||
while True:
|
||||
@@ -209,12 +215,25 @@ class PretokDataset(torch.utils.data.IterableDataset):
|
||||
y = chunk[1:]
|
||||
yield x, y
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# public interface functions
|
||||
|
||||
def get_tokenizer_model_path(vocab_size):
|
||||
"""
|
||||
Returns path to the sentencepiece tokenizer model for a given vocab size
|
||||
vocab_size = 0 designates the default Llama 2 tokenizer, in that case
|
||||
None is returned.
|
||||
"""
|
||||
if vocab_size == 0:
|
||||
return None
|
||||
else:
|
||||
return os.path.join(DATA_CACHE_DIR, f"tok{vocab_size}.model")
|
||||
|
||||
class Task:
|
||||
|
||||
@staticmethod
|
||||
def iter_batches(split, batch_size, max_seq_len, device, num_workers=0):
|
||||
ds = PretokDataset(split, max_seq_len)
|
||||
def iter_batches(batch_size, device, num_workers=0, **dataset_kwargs):
|
||||
ds = PretokDataset(**dataset_kwargs)
|
||||
dl = torch.utils.data.DataLoader(
|
||||
ds, batch_size=batch_size, pin_memory=True, num_workers=num_workers
|
||||
)
|
||||
@@ -223,6 +242,8 @@ class Task:
|
||||
y = y.to(device, non_blocking=True)
|
||||
yield x, y
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# CLI for constructing the dataset
|
||||
|
||||
if __name__ == "__main__":
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user