diff --git a/tinystories.py b/tinystories.py
index 9a69cb8..419e0d5 100644
--- a/tinystories.py
+++ b/tinystories.py
@@ -8,7 +8,7 @@ import json
 import os
 import random
 from typing import List
-from concurrent.futures import ThreadPoolExecutor, as_completed
+from concurrent.futures import ProcessPoolExecutor
 
 import numpy as np
 import requests
@@ -66,34 +66,48 @@ def download():
     print(f"Number of shards: {len(shard_filenames)}")
     print(f"Example story:\n{data[0]}")
 
-def pretokenize():
+
+def process_shard(args):
+    """Tokenize one JSON shard and write its token ids to a .bin file.
+
+    Submitted to a process pool, so it has to stay a top-level function.
+
+    Args:
+        args: (shard_id, shard) tuple; shard_id is passed to tqdm as the
+            progress bar position, shard is the path of the .json shard.
+    """
+    shard_id, shard = args
     enc = Tokenizer()
+    with open(shard, "r", encoding="utf-8") as f:
+        data = json.load(f)
+    all_tokens = []
+    for example in tqdm(data, position=shard_id):
+        text = example["story"]
+        text = text.strip()  # get rid of leading/trailing whitespace
+        tokens = enc.encode(text, bos=True, eos=False)  # encode the text, use BOS
+        all_tokens.extend(tokens)
+    # convert to uint16 nparray
+    all_tokens = np.array(all_tokens, dtype=np.uint16)
+    # write to disk
+    # swap only the extension; str.replace would also rewrite any ".json"
+    # that happens to appear earlier in the path
+    tokenized_filename = os.path.splitext(shard)[0] + ".bin"
+    with open(tokenized_filename, "wb") as f:
+        f.write(all_tokens.tobytes())
+    print(f"Saved {tokenized_filename}")
 
-    def process_shard(shard):
-        with open(shard, "r") as f:
-            data = json.load(f)
-        all_tokens = []
-        for example in tqdm(data):
-            text = example["story"]
-            text = text.strip()  # get rid of leading/trailing whitespace
-            tokens = enc.encode(text, bos=True, eos=False)  # encode the text, use BOS
-            all_tokens.extend(tokens)
-        # convert to uint16 nparray
-        all_tokens = np.array(all_tokens, dtype=np.uint16)
-        # write to disk
-        tokenized_filename = shard.replace(".json", ".bin")
-        with open(tokenized_filename, "wb") as f:
-            f.write(all_tokens.tobytes())
-        print(f"Saved {tokenized_filename}")
 
+def pretokenize():
+    """Tokenize every shard in TinyStories_all_data using a process pool."""
     # iterate the shards and tokenize all of them one by one
     data_dir = os.path.join(DATA_CACHE_DIR, "TinyStories_all_data")
     shard_filenames = sorted(glob.glob(os.path.join(data_dir, "*.json")))
 
-    # process all the shards in a threadpool
-    with ThreadPoolExecutor(max_workers=8) as executor:
-        executor.map(process_shard, shard_filenames)
-
+    # process all the shards in a process pool
+    with ProcessPoolExecutor() as executor:
+        # consume the results: a worker's exception is only re-raised when
+        # its result is retrieved, otherwise a failed shard goes unnoticed
+        list(executor.map(process_shard, enumerate(shard_filenames)))
     print("Done.")
 
 
@@ -163,4 +177,5 @@ if __name__ == "__main__":
         "download": download,
         "pretokenize": pretokenize,
     }
-    fun[args.stage]()
\ No newline at end of file
+    fun[args.stage]()
+