Merge pull request #167 from mzcu/pretokenize-speedup

Speed up tinystories pretokenize command
This commit is contained in:
Andrej
2023-08-05 15:14:51 -07:00
committed by GitHub
+11 -9
View File
@@ -8,7 +8,7 @@ import json
import os import os
import random import random
from typing import List from typing import List
from concurrent.futures import ThreadPoolExecutor, as_completed from concurrent.futures import ProcessPoolExecutor
import numpy as np import numpy as np
import requests import requests
@@ -66,14 +66,14 @@ def download():
print(f"Number of shards: {len(shard_filenames)}") print(f"Number of shards: {len(shard_filenames)}")
print(f"Example story:\n{data[0]}") print(f"Example story:\n{data[0]}")
def pretokenize():
enc = Tokenizer()
def process_shard(shard): def process_shard(args):
shard_id, shard = args
enc = Tokenizer()
with open(shard, "r") as f: with open(shard, "r") as f:
data = json.load(f) data = json.load(f)
all_tokens = [] all_tokens = []
for example in tqdm(data): for example in tqdm(data, position=shard_id):
text = example["story"] text = example["story"]
text = text.strip() # get rid of leading/trailing whitespace text = text.strip() # get rid of leading/trailing whitespace
tokens = enc.encode(text, bos=True, eos=False) # encode the text, use BOS tokens = enc.encode(text, bos=True, eos=False) # encode the text, use BOS
@@ -86,14 +86,15 @@ def pretokenize():
f.write(all_tokens.tobytes()) f.write(all_tokens.tobytes())
print(f"Saved {tokenized_filename}") print(f"Saved {tokenized_filename}")
def pretokenize():
# iterate the shards and tokenize all of them one by one # iterate the shards and tokenize all of them one by one
data_dir = os.path.join(DATA_CACHE_DIR, "TinyStories_all_data") data_dir = os.path.join(DATA_CACHE_DIR, "TinyStories_all_data")
shard_filenames = sorted(glob.glob(os.path.join(data_dir, "*.json"))) shard_filenames = sorted(glob.glob(os.path.join(data_dir, "*.json")))
# process all the shards in a threadpool # process all the shards in a process pool
with ThreadPoolExecutor(max_workers=8) as executor: with ProcessPoolExecutor() as executor:
executor.map(process_shard, shard_filenames) executor.map(process_shard, enumerate(shard_filenames))
print("Done.") print("Done.")
@@ -164,3 +165,4 @@ if __name__ == "__main__":
"pretokenize": pretokenize, "pretokenize": pretokenize,
} }
fun[args.stage]() fun[args.stage]()