diff --git a/export_meta_llama_bin.py b/export_meta_llama_bin.py index 801077b..3d07c1c 100644 --- a/export_meta_llama_bin.py +++ b/export_meta_llama_bin.py @@ -1,6 +1,7 @@ """ This script exports the Llama 2 weights in llama2c.bin format. """ +import os import sys import struct from pathlib import Path @@ -89,16 +90,13 @@ def concat_weights(models): def load_and_export(model_path, output_path): - with open(model_path + 'params.json') as f: + params_path = os.path.join(model_path, 'params.json') + with open(params_path) as f: params = json.load(f) print(params) model_paths = sorted(list(Path(model_path).glob('consolidated.*.pth'))) - models = [] - for i in model_paths: - print(f'Loading {i}') - models.append(torch.load(i, map_location='cpu')) - + models = [torch.load(p, map_location='cpu') for p in model_paths] state_dict = concat_weights(models) del models export(params, state_dict, output_path)