diff --git a/sample.py b/sample.py index d2f56ea..c203a8c 100644 --- a/sample.py +++ b/sample.py @@ -52,7 +52,7 @@ if compile: model = torch.compile(model) # requires PyTorch 2.0 (optional) # load the tokenizer -vocab_source = checkpoint_dict.get("vocab_source", "llama2") +vocab_source = checkpoint_dict["config"].get("vocab_source", "llama2") vocab_size = gptconf.vocab_size if tokenizer: # a specific tokenizer is provided, use it