Speed up tinystories pretokenize command
This commit is contained in:
+25
-23
@@ -8,7 +8,7 @@ import json
|
|||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
from typing import List
|
from typing import List
|
||||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
from concurrent.futures import ProcessPoolExecutor
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import requests
|
import requests
|
||||||
@@ -66,34 +66,35 @@ def download():
|
|||||||
print(f"Number of shards: {len(shard_filenames)}")
|
print(f"Number of shards: {len(shard_filenames)}")
|
||||||
print(f"Example story:\n{data[0]}")
|
print(f"Example story:\n{data[0]}")
|
||||||
|
|
||||||
def pretokenize():
|
|
||||||
|
def process_shard(args):
|
||||||
|
shard_id, shard = args
|
||||||
enc = Tokenizer()
|
enc = Tokenizer()
|
||||||
|
with open(shard, "r") as f:
|
||||||
|
data = json.load(f)
|
||||||
|
all_tokens = []
|
||||||
|
for example in tqdm(data, position=shard_id):
|
||||||
|
text = example["story"]
|
||||||
|
text = text.strip() # get rid of leading/trailing whitespace
|
||||||
|
tokens = enc.encode(text, bos=True, eos=False) # encode the text, use BOS
|
||||||
|
all_tokens.extend(tokens)
|
||||||
|
# convert to uint16 nparray
|
||||||
|
all_tokens = np.array(all_tokens, dtype=np.uint16)
|
||||||
|
# write to disk
|
||||||
|
tokenized_filename = shard.replace(".json", ".bin")
|
||||||
|
with open(tokenized_filename, "wb") as f:
|
||||||
|
f.write(all_tokens.tobytes())
|
||||||
|
print(f"Saved {tokenized_filename}")
|
||||||
|
|
||||||
def process_shard(shard):
|
|
||||||
with open(shard, "r") as f:
|
|
||||||
data = json.load(f)
|
|
||||||
all_tokens = []
|
|
||||||
for example in tqdm(data):
|
|
||||||
text = example["story"]
|
|
||||||
text = text.strip() # get rid of leading/trailing whitespace
|
|
||||||
tokens = enc.encode(text, bos=True, eos=False) # encode the text, use BOS
|
|
||||||
all_tokens.extend(tokens)
|
|
||||||
# convert to uint16 nparray
|
|
||||||
all_tokens = np.array(all_tokens, dtype=np.uint16)
|
|
||||||
# write to disk
|
|
||||||
tokenized_filename = shard.replace(".json", ".bin")
|
|
||||||
with open(tokenized_filename, "wb") as f:
|
|
||||||
f.write(all_tokens.tobytes())
|
|
||||||
print(f"Saved {tokenized_filename}")
|
|
||||||
|
|
||||||
|
def pretokenize():
|
||||||
# iterate the shards and tokenize all of them one by one
|
# iterate the shards and tokenize all of them one by one
|
||||||
data_dir = os.path.join(DATA_CACHE_DIR, "TinyStories_all_data")
|
data_dir = os.path.join(DATA_CACHE_DIR, "TinyStories_all_data")
|
||||||
shard_filenames = sorted(glob.glob(os.path.join(data_dir, "*.json")))
|
shard_filenames = sorted(glob.glob(os.path.join(data_dir, "*.json")))
|
||||||
|
|
||||||
# process all the shards in a threadpool
|
# process all the shards in a process pool
|
||||||
with ThreadPoolExecutor(max_workers=8) as executor:
|
with ProcessPoolExecutor() as executor:
|
||||||
executor.map(process_shard, shard_filenames)
|
executor.map(process_shard, enumerate(shard_filenames))
|
||||||
|
|
||||||
print("Done.")
|
print("Done.")
|
||||||
|
|
||||||
|
|
||||||
@@ -163,4 +164,5 @@ if __name__ == "__main__":
|
|||||||
"download": download,
|
"download": download,
|
||||||
"pretokenize": pretokenize,
|
"pretokenize": pretokenize,
|
||||||
}
|
}
|
||||||
fun[args.stage]()
|
fun[args.stage]()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user