update API of sample.py to be better, small changes here
This commit is contained in:
@@ -132,8 +132,7 @@ Watch the tokens stream by, fun! We can also run the PyTorch inference script fo
|
|||||||
|
|
||||||
```bash
|
```bash
|
||||||
wget https://huggingface.co/karpathy/tinyllamas/resolve/main/stories15M.pt -P out15M
|
wget https://huggingface.co/karpathy/tinyllamas/resolve/main/stories15M.pt -P out15M
|
||||||
mv out15M/stories15M.pt out15M/ckpt.pt # sorry the sample script currently assumes this directory structure / filename...
|
python sample.py --checkpoint=out15M/stories15M.pt
|
||||||
python sample.py --out_dir=out15M
|
|
||||||
```
|
```
|
||||||
|
|
||||||
Which gives the same results. More detailed testing will be done in `test_all.py`. Currently you will need two files to test or sample: both the .bin file and the .ckpt file inside a directory (see `test_all.py` for details). Sorry, this is a bit janky right now; I still have to think through how to run the tests without having to download 200MB of data. For now, run the tests with pytest:
|
Which gives the same results. More detailed testing will be done in `test_all.py`. Currently you will need two files to test or sample: both the .bin file and the .ckpt file inside a directory (see `test_all.py` for details). Sorry, this is a bit janky right now; I still have to think through how to run the tests without having to download 200MB of data. For now, run the tests with pytest:
|
||||||
|
|||||||
@@ -12,12 +12,13 @@ from tokenizer import Tokenizer
|
|||||||
from tinystories import get_tokenizer_model_path
|
from tinystories import get_tokenizer_model_path
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
out_dir = 'out' # ignored if init_from is not 'resume'
|
checkpoint = 'out/ckpt.pt'
|
||||||
start = "" # or "<|endoftext|>" or etc. Can also specify a file, use as: "FILE:prompt.txt"
|
start = "" # or "<|endoftext|>" or etc. Can also specify a file, use as: "FILE:prompt.txt"
|
||||||
num_samples = 1 # number of samples to draw
|
num_samples = 1 # number of samples to draw
|
||||||
max_new_tokens = 100 # number of tokens generated in each sample
|
max_new_tokens = 100 # number of tokens generated in each sample
|
||||||
temperature = 1.0 # 1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions
|
temperature = 1.0 # 1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions
|
||||||
top_k = 300 # retain only the top_k most likely tokens, clamp others to have 0 probability
|
top_k = 300 # retain only the top_k most likely tokens, clamp others to have 0 probability
|
||||||
|
tokenizer = "" # override the tokenizer model path
|
||||||
seed = 1337
|
seed = 1337
|
||||||
device = 'cuda' if torch.cuda.is_available() else 'cpu' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc.
|
device = 'cuda' if torch.cuda.is_available() else 'cpu' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc.
|
||||||
#dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32' or 'bfloat16' or 'float16'
|
#dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32' or 'bfloat16' or 'float16'
|
||||||
@@ -35,11 +36,10 @@ ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torc
|
|||||||
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
|
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
|
||||||
|
|
||||||
# init from a model saved in a specific directory
|
# init from a model saved in a specific directory
|
||||||
ckpt_path = os.path.join(out_dir, 'ckpt.pt')
|
checkpoint_dict = torch.load(checkpoint, map_location=device)
|
||||||
checkpoint = torch.load(ckpt_path, map_location=device)
|
gptconf = ModelArgs(**checkpoint_dict['model_args'])
|
||||||
gptconf = ModelArgs(**checkpoint['model_args'])
|
|
||||||
model = Transformer(gptconf)
|
model = Transformer(gptconf)
|
||||||
state_dict = checkpoint['model']
|
state_dict = checkpoint_dict['model']
|
||||||
unwanted_prefix = '_orig_mod.'
|
unwanted_prefix = '_orig_mod.'
|
||||||
for k,v in list(state_dict.items()):
|
for k,v in list(state_dict.items()):
|
||||||
if k.startswith(unwanted_prefix):
|
if k.startswith(unwanted_prefix):
|
||||||
@@ -52,8 +52,11 @@ if compile:
|
|||||||
print("Compiling the model...")
|
print("Compiling the model...")
|
||||||
model = torch.compile(model) # requires PyTorch 2.0 (optional)
|
model = torch.compile(model) # requires PyTorch 2.0 (optional)
|
||||||
|
|
||||||
# load the tokenizer
|
# load the tokenizer, either provided, or attempt to find it
|
||||||
tokenizer_model = get_tokenizer_model_path(vocab_size=gptconf.vocab_size)
|
if tokenizer:
|
||||||
|
tokenizer_model = tokenizer
|
||||||
|
else:
|
||||||
|
tokenizer_model = get_tokenizer_model_path(vocab_size=gptconf.vocab_size)
|
||||||
enc = Tokenizer(tokenizer_model=tokenizer_model)
|
enc = Tokenizer(tokenizer_model=tokenizer_model)
|
||||||
|
|
||||||
# encode the beginning of the prompt
|
# encode the beginning of the prompt
|
||||||
|
|||||||
Reference in New Issue
Block a user