Light touch-ups to the export script so that the model directory path no longer needs to be passed with a trailing slash

This commit is contained in:
Andrej Karpathy
2023-07-27 05:08:45 +00:00
parent 5f681b64b1
commit 530ef8e778
+4 -6
View File
@@ -1,6 +1,7 @@
 """
 This script exports the Llama 2 weights in llama2c.bin format.
 """
+import os
 import sys
 import struct
 from pathlib import Path
@@ -89,16 +90,13 @@ def concat_weights(models):
 
 
 def load_and_export(model_path, output_path):
-    with open(model_path + 'params.json') as f:
+    params_path = os.path.join(model_path, 'params.json')
+    with open(params_path) as f:
         params = json.load(f)
         print(params)
 
     model_paths = sorted(list(Path(model_path).glob('consolidated.*.pth')))
-    models = []
-    for i in model_paths:
-        print(f'Loading {i}')
-        models.append(torch.load(i, map_location='cpu'))
-
+    models = [torch.load(p, map_location='cpu') for p in model_paths]
     state_dict = concat_weights(models)
     del models
     export(params, state_dict, output_path)