Train vocab in Python
This commit is contained in:
+16
-10
@@ -13,6 +13,7 @@ from functools import partial
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from tqdm import tqdm
|
||||
@@ -97,16 +98,21 @@ def train_vocab(vocab_size):
|
||||
of.write(text + "\n")
|
||||
print(f"Size is: {os.path.getsize(tiny_file) / 1024 / 1024:.2f} MB")
|
||||
|
||||
# 2) run the train_vocab.sh script that trains the sentencepiece model
|
||||
print("Will now train the vocab with:")
|
||||
cmd = f"bash train_vocab.sh {tiny_file} {prefix} {vocab_size}"
|
||||
print(cmd)
|
||||
print("OK? [y/N] ")
|
||||
dec = input()
|
||||
if dec.lower() != "y":
|
||||
print("Exiting...")
|
||||
return
|
||||
os.system(cmd)
|
||||
# 2) train the sentencepiece model
|
||||
print("Will now train the vocab...")
|
||||
|
||||
spm.SentencePieceTrainer.train(input=tiny_file,
|
||||
model_prefix=prefix,
|
||||
model_type="bpe",
|
||||
vocab_size=vocab_size,
|
||||
self_test_sample_size=0,
|
||||
input_format="text",
|
||||
character_coverage=1.0,
|
||||
split_digits=True,
|
||||
allow_whitespace_only_pieces=True,
|
||||
byte_fallback=True,
|
||||
unk_surface=r" \342\201\207 ",
|
||||
normalization_rule_name="identity")
|
||||
|
||||
# 3) optional cleanup, ask the user if they'd like to delete tiny.txt
|
||||
dec = input(f"Delete the temporary file {tiny_file}? [y/N] ")
|
||||
|
||||
Reference in New Issue
Block a user