modifiying test_all so it can safely run on windows

This commit is contained in:
Ruhollah Majdoddin
2023-08-15 16:01:53 +00:00
parent 66c9f5e6c8
commit 87b11edf27
+11 -5
View File
@@ -30,7 +30,7 @@ def attempt_download_files():
root_url = "https://huggingface.co/karpathy/tinyllamas/resolve/main/stories260K" root_url = "https://huggingface.co/karpathy/tinyllamas/resolve/main/stories260K"
need = ["stories260K.bin", "stories260K.pt", "tok512.bin", "tok512.model"] need = ["stories260K.bin", "stories260K.pt", "tok512.bin", "tok512.model"]
for file in need: for file in need:
url = os.path.join(root_url, file) url = root_url + '/' + file #os.path.join inserts \\ on windows
filename = os.path.join(test_ckpt_dir, file) filename = os.path.join(test_ckpt_dir, file)
if not os.path.exists(filename): if not os.path.exists(filename):
download_file(url, filename) download_file(url, filename)
@@ -47,12 +47,16 @@ def test_runc():
model_path = os.path.join(test_ckpt_dir, "stories260K.bin") model_path = os.path.join(test_ckpt_dir, "stories260K.bin")
tokenizer_path = os.path.join(test_ckpt_dir, "tok512.bin") tokenizer_path = os.path.join(test_ckpt_dir, "tok512.bin")
command = ["./run", model_path, "-z", tokenizer_path, "-t", "0.0", "-n", "200"] command = ["./run", model_path, "-z", tokenizer_path, "-t", "0.0", "-n", "200"]
proc = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE) with open('err.txt', mode='wb') as fe:
with open('stdout.txt', mode='wb') as fo:
stdout, stderr = proc.communicate() proc = subprocess.Popen(command, stdout=fo, stderr=fe) #pipe in windows terminal does funny things like replacing \n with \r\n
proc.wait()
with open('stdout.txt', mode='r') as f:
stdout = f.read()
# strip the very last \n that is added by run.c for aesthetic reasons # strip the very last \n that is added by run.c for aesthetic reasons
stdout = stdout[:-1] stdout = stdout[:-1].encode('ascii')
assert stdout == expected_stdout assert stdout == expected_stdout
def test_python(): def test_python():
@@ -83,3 +87,5 @@ def test_python():
text = text.encode('ascii') # turn into bytes text = text.encode('ascii') # turn into bytes
assert text == expected_stdout assert text == expected_stdout
test_runc()