diff --git a/tinystories.py b/tinystories.py index 690cb02..90d576b 100644 --- a/tinystories.py +++ b/tinystories.py @@ -196,6 +196,7 @@ class PretokDataset(torch.utils.data.IterableDataset): shard_filenames = sorted(glob.glob(os.path.join(bin_dir, "*.bin"))) # train/test split. let's use only shard 0 for test split, rest train shard_filenames = shard_filenames[1:] if self.split == "train" else shard_filenames[:1] + assert len(shard_filenames)>0, f"No bin files found in {bin_dir}" while True: rng.shuffle(shard_filenames) for shard in shard_filenames: