""" Sample from the trained model with PyTorch """ import os import pickle from contextlib import nullcontext import torch import tiktoken from model import ModelArgs, Transformer from tokenizer import Tokenizer from tinystories import get_tokenizer_model_path # ----------------------------------------------------------------------------- out_dir = 'out' # ignored if init_from is not 'resume' start = "" # or "<|endoftext|>" or etc. Can also specify a file, use as: "FILE:prompt.txt" num_samples = 1 # number of samples to draw max_new_tokens = 100 # number of tokens generated in each sample temperature = 1.0 # 1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions top_k = 300 # retain only the top_k most likely tokens, clamp others to have 0 probability seed = 1337 device = 'cuda' if torch.cuda.is_available() else 'cpu' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc. #dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32' or 'bfloat16' or 'float16' dtype = "float32" compile = False # use PyTorch 2.0 to compile the model to be faster exec(open('configurator.py').read()) # overrides from command line or config file # ----------------------------------------------------------------------------- torch.manual_seed(seed) torch.cuda.manual_seed(seed) torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype] ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype) # init from a model saved in a specific directory ckpt_path = os.path.join(out_dir, 'ckpt.pt') checkpoint = torch.load(ckpt_path, map_location=device) gptconf = ModelArgs(**checkpoint['model_args']) model = Transformer(gptconf) state_dict = checkpoint['model'] unwanted_prefix = '_orig_mod.' for k,v in list(state_dict.items()): if k.startswith(unwanted_prefix): state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k) model.load_state_dict(state_dict, strict=False) model.eval() model.to(device) if compile: print("Compiling the model...") model = torch.compile(model) # requires PyTorch 2.0 (optional) # load the tokenizer assert checkpoint["config"]["dataset"] == "tinystories" # TODO: generalize tokenizer_model = get_tokenizer_model_path(vocab_size=gptconf.vocab_size) enc = Tokenizer(tokenizer_model=tokenizer_model) # encode the beginning of the prompt if start.startswith('FILE:'): with open(start[5:], 'r', encoding='utf-8') as f: start = f.read() start_ids = enc.encode(start, bos=True, eos=False) x = (torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...]) # run generation with torch.no_grad(): with ctx: for k in range(num_samples): y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k) print(enc.decode(y[0].tolist())) print('---------------')