Fix sample.py, which was broken by the earlier tokenizer changes
This commit is contained in:
@@ -51,11 +51,16 @@ if compile:
|
|||||||
print("Compiling the model...")
|
print("Compiling the model...")
|
||||||
model = torch.compile(model) # requires PyTorch 2.0 (optional)
|
model = torch.compile(model) # requires PyTorch 2.0 (optional)
|
||||||
|
|
||||||
# load the tokenizer, either provided, or attempt to find it
|
# load the tokenizer
|
||||||
|
vocab_source = checkpoint_dict.get("vocab_source", "llama2")
|
||||||
|
vocab_size = gptconf.vocab_size
|
||||||
if tokenizer:
|
if tokenizer:
|
||||||
|
# a specific tokenizer is provided, use it
|
||||||
tokenizer_model = tokenizer
|
tokenizer_model = tokenizer
|
||||||
else:
|
else:
|
||||||
tokenizer_model = get_tokenizer_model_path(vocab_size=gptconf.vocab_size)
|
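# no tokenizer was provided, so try to locate the tokenizer model automatically: query with vocab_size 0 when the checkpoint's vocab_source is "llama2", otherwise with the model's own vocab size (a bit hacky)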
|
||||||
|
query_vocab_size = 0 if vocab_source == "llama2" else vocab_size
|
||||||
|
tokenizer_model = get_tokenizer_model_path(vocab_size=query_vocab_size)
|
||||||
enc = Tokenizer(tokenizer_model=tokenizer_model)
|
enc = Tokenizer(tokenizer_model=tokenizer_model)
|
||||||
|
|
||||||
# encode the beginning of the prompt
|
# encode the beginning of the prompt
|
||||||
|
|||||||
Reference in New Issue
Block a user