diff --git a/sample.py b/sample.py
index b26e277..d2f56ea 100644
--- a/sample.py
+++ b/sample.py
@@ -51,11 +51,16 @@ if compile:
     print("Compiling the model...")
     model = torch.compile(model) # requires PyTorch 2.0 (optional)
 
-# load the tokenizer, either provided, or attempt to find it
+# load the tokenizer
+vocab_source = checkpoint_dict.get("vocab_source", "llama2")
+vocab_size = gptconf.vocab_size
 if tokenizer:
+    # a specific tokenizer is provided, use it
     tokenizer_model = tokenizer
 else:
-    tokenizer_model = get_tokenizer_model_path(vocab_size=gptconf.vocab_size)
+    # let's try to find the tokenizer model automatically. bit gross here...
+    query_vocab_size = 0 if vocab_source == "llama2" else vocab_size
+    tokenizer_model = get_tokenizer_model_path(vocab_size=query_vocab_size)
 enc = Tokenizer(tokenizer_model=tokenizer_model)
 
 # encode the beginning of the prompt