Light touch-ups to the export script so the model path no longer needs a trailing slash

This commit is contained in:
Andrej Karpathy
2023-07-27 05:08:45 +00:00
parent 5f681b64b1
commit 530ef8e778
+4 -6
View File
@@ -1,6 +1,7 @@
"""
This script exports the Llama 2 weights in llama2c.bin format.
"""
import os
import sys
import struct
from pathlib import Path
@@ -89,16 +90,13 @@ def concat_weights(models):
def load_and_export(model_path, output_path):
    """Load a sharded checkpoint directory and export it via ``export``.

    Reads ``params.json`` from ``model_path``, loads every
    ``consolidated.*.pth`` file found there onto the CPU, merges the
    shards with ``concat_weights`` and passes the result to ``export``.

    Args:
        model_path: Directory containing ``params.json`` and the
            ``consolidated.*.pth`` files. The path is joined with
            ``os.path.join``, so it works with or without a trailing slash.
        output_path: Forwarded unchanged to ``export``.
    """
    # Join rather than concatenate: 'dir' + 'params.json' would silently
    # produce 'dirparams.json' when the caller omits the trailing slash.
    params_path = os.path.join(model_path, 'params.json')
    with open(params_path) as f:
        params = json.load(f)
    print(params)

    # Sort the shard paths so they are always loaded in the same
    # (lexicographic) order regardless of directory listing order.
    model_paths = sorted(Path(model_path).glob('consolidated.*.pth'))

    # Load each shard exactly once; map_location='cpu' keeps the tensors
    # on the CPU.
    models = [torch.load(p, map_location='cpu') for p in model_paths]
    state_dict = concat_weights(models)
    # Drop the reference to the per-shard list before exporting
    # (presumably to lower peak memory -- only state_dict is used below).
    del models
    export(params, state_dict, output_path)