update API of sample.py to be better, small changes here

This commit is contained in:
Andrej Karpathy
2023-08-13 20:31:32 +00:00
parent 1bcb2d18d6
commit 58075b5ac5
2 changed files with 11 additions and 9 deletions
+1 -2
View File
@@ -132,8 +132,7 @@ Watch the tokens stream by, fun! We can also run the PyTorch inference script fo
```bash
wget https://huggingface.co/karpathy/tinyllamas/resolve/main/stories15M.pt -P out15M
mv out15M/stories15M.pt out15M/ckpt.pt # sorry the sample script current assumes this directory structure / filename...
python sample.py --out_dir=out15M
python sample.py --checkpoint=out15M/stories15M.pt
```
Which gives the same results. More detailed testing will be done in `test_all.py`. Currently you will need two files to test or sample: both the .bin file, and the .ckpt file inside a directory (see `test_all.py` for details). Sorry this is a bit janky right now, I have to think through running the tests without having to download 200MB of data. But run the tests with pytest:
+10 -7
View File
@@ -12,12 +12,13 @@ from tokenizer import Tokenizer
from tinystories import get_tokenizer_model_path
# -----------------------------------------------------------------------------
out_dir = 'out' # ignored if init_from is not 'resume'
checkpoint = 'out/ckpt.pt'
start = "" # or "<|endoftext|>" or etc. Can also specify a file, use as: "FILE:prompt.txt"
num_samples = 1 # number of samples to draw
max_new_tokens = 100 # number of tokens generated in each sample
temperature = 1.0 # 1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions
top_k = 300 # retain only the top_k most likely tokens, clamp others to have 0 probability
tokenizer = "" # override the tokenizer model path
seed = 1337
device = 'cuda' if torch.cuda.is_available() else 'cpu' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc.
#dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32' or 'bfloat16' or 'float16'
@@ -35,11 +36,10 @@ ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torc
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
# init from a model saved in a specific directory
ckpt_path = os.path.join(out_dir, 'ckpt.pt')
checkpoint = torch.load(ckpt_path, map_location=device)
gptconf = ModelArgs(**checkpoint['model_args'])
checkpoint_dict = torch.load(checkpoint, map_location=device)
gptconf = ModelArgs(**checkpoint_dict['model_args'])
model = Transformer(gptconf)
state_dict = checkpoint['model']
state_dict = checkpoint_dict['model']
unwanted_prefix = '_orig_mod.'
for k,v in list(state_dict.items()):
if k.startswith(unwanted_prefix):
@@ -52,8 +52,11 @@ if compile:
print("Compiling the model...")
model = torch.compile(model) # requires PyTorch 2.0 (optional)
# load the tokenizer
tokenizer_model = get_tokenizer_model_path(vocab_size=gptconf.vocab_size)
# load the tokenizer, either provided, or attempt to find it
if tokenizer:
tokenizer_model = tokenizer
else:
tokenizer_model = get_tokenizer_model_path(vocab_size=gptconf.vocab_size)
enc = Tokenizer(tokenizer_model=tokenizer_model)
# encode the beginning of the prompt