Light touch-ups to the export script so that the model path no longer needs a trailing slash
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
This script exports the Llama 2 weights in llama2c.bin format.
|
||||
"""
|
||||
import os
|
||||
import sys
|
||||
import struct
|
||||
from pathlib import Path
|
||||
@@ -89,16 +90,13 @@ def concat_weights(models):
|
||||
|
||||
|
||||
def load_and_export(model_path, output_path):
    """Load the checkpoint found in ``model_path`` and export it.

    Reads ``params.json`` from ``model_path``, loads every
    ``consolidated.*.pth`` shard in that directory onto the CPU, merges the
    shards with ``concat_weights`` and passes the result to ``export``.

    Args:
        model_path: Directory holding ``params.json`` and the
            ``consolidated.*.pth`` shards. A trailing slash is optional
            because the path is built with ``os.path.join``.
        output_path: Destination forwarded unchanged to ``export``.
    """
    # Join instead of concatenating so 'dir' and 'dir/' behave the same.
    params_path = os.path.join(model_path, 'params.json')
    with open(params_path) as f:
        params = json.load(f)
        print(params)

    # Sort the shard paths so they are always loaded in the same order.
    model_paths = sorted(Path(model_path).glob('consolidated.*.pth'))
    # Load each shard exactly once.
    models = [torch.load(p, map_location='cpu') for p in model_paths]
    state_dict = concat_weights(models)
    # Drop the per-shard objects before exporting; only state_dict is needed.
    del models
    export(params, state_dict, output_path)
|
||||
|
||||
Reference in New Issue
Block a user