revive tests. now that we have a tiny stories260K model this only requires a 2MB download. phew

This commit is contained in:
Andrej Karpathy
2023-08-13 21:22:44 +00:00
parent 0805cb2c31
commit f0024cfc88
2 changed files with 70 additions and 33 deletions
+58 -28
View File
@@ -4,37 +4,65 @@ $ pytest
"""
import os
import pytest # pip install pytest
import requests
import subprocess
import torch
from model import ModelArgs, Transformer
from tokenizer import Tokenizer
def test_argmax_inference():
"""
Only the simplest test for now: run inference with temperature 0
(for determinism) in both C and PyTorch, and see that the sampled tokens
are the same.
"""
test_ckpt_dir = "out" # TODO create a dummy test checkpoint for this?
# -----------------------------------------------------------------------------
# test utilities
# run C version
model_path = os.path.join(test_ckpt_dir, "model.bin")
command = ["./run", model_path, "0.0"]
proc = subprocess.Popen(command, stdout=subprocess.PIPE)
c_tokens = []
for line in proc.stdout:
token = int(line.decode('utf-8').strip())
c_tokens.append(token)
proc.wait()
#print(c_tokens)
test_ckpt_dir = "test"
# run PyTorch version
device = "cuda" if torch.cuda.is_available() else "cpu"
ckpt_path = os.path.join(test_ckpt_dir, "ckpt.pt")
checkpoint = torch.load(ckpt_path, map_location=device)
gptconf = ModelArgs(**checkpoint['model_args'])
def download_file(url, filename):
print(f"Downloading {url} to {filename}")
response = requests.get(url, stream=True)
response.raise_for_status() # Raise an HTTPError on bad status code
with open(filename, 'wb') as file:
for chunk in response.iter_content(chunk_size=8192):
file.write(chunk)
def attempt_download_files():
os.makedirs(test_ckpt_dir, exist_ok=True)
root_url = "https://huggingface.co/karpathy/tinyllamas/resolve/main/stories260K"
need = ["stories260K.bin", "stories260K.pt", "tok512.bin", "tok512.model"]
for file in need:
url = os.path.join(root_url, file)
filename = os.path.join(test_ckpt_dir, file)
if not os.path.exists(filename):
download_file(url, filename)
expected_stdout = b'Once upon a time, there was a little girl named Lily. She loved to play outside in the park. One day, she saw a big, red ball. She wanted to play with it, but it was too high.\nLily\'s mom said, "Lily, let\'s go to the park." Lily was sad and didn\'t know what to do. She said, "I want to play with your ball, but I can\'t find it."\nLily was sad and didn\'t know what to do. She said, "I\'m sorry, Lily. I didn\'t know what to do."\nLily didn\'t want to help her mom, so she'
# -----------------------------------------------------------------------------
# actual tests
def test_runc():
""" Forwards a model against a known-good desired outcome in run.c for 200 steps"""
model_path = os.path.join(test_ckpt_dir, "stories260K.bin")
tokenizer_path = os.path.join(test_ckpt_dir, "tok512.bin")
command = ["./run", model_path, "-z", tokenizer_path, "-t", "0.0", "-n", "200"]
proc = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
stdout, stderr = proc.communicate()
# strip the very last \n that is added by run.c for aesthetic reasons
stdout = stdout[:-1]
assert stdout == expected_stdout
def test_python():
""" Forwards a model against a known-good desired outcome in sample.py for 200 steps"""
device = "cpu" # stories260K is small enough to just breeze through it on CPU
checkpoint = os.path.join(test_ckpt_dir, "stories260K.pt")
checkpoint_dict = torch.load(checkpoint, map_location=device)
gptconf = ModelArgs(**checkpoint_dict['model_args'])
model = Transformer(gptconf)
state_dict = checkpoint['model']
state_dict = checkpoint_dict['model']
unwanted_prefix = '_orig_mod.'
for k,v in list(state_dict.items()):
if k.startswith(unwanted_prefix):
@@ -44,10 +72,12 @@ def test_argmax_inference():
model.to(device)
x = torch.tensor([[1]], dtype=torch.long, device=device) # 1 is BOS
with torch.inference_mode():
y = model.generate(x, max_new_tokens=gptconf.max_seq_len, temperature=0.0)
y = model.generate(x, max_new_tokens=200, temperature=0.0)
pt_tokens = y[0].tolist()
pt_tokens = pt_tokens[1:] # remove BOS
#print(pt_tokens)
# compare
assert c_tokens == pt_tokens
tokenizer_model = os.path.join(test_ckpt_dir, "tok512.model")
enc = Tokenizer(tokenizer_model=tokenizer_model)
text = enc.decode(pt_tokens)
text = text.encode('ascii') # turn into bytes
assert text == expected_stdout