Light touch-ups to the export script so that the model directory path no longer needs a trailing slash
This commit is contained in:
@@ -1,6 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
This script exports the Llama 2 weights in llama2c.bin format.
|
This script exports the Llama 2 weights in llama2c.bin format.
|
||||||
"""
|
"""
|
||||||
|
import os
|
||||||
import sys
|
import sys
|
||||||
import struct
|
import struct
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -89,16 +90,13 @@ def concat_weights(models):
|
|||||||
|
|
||||||
|
|
||||||
def load_and_export(model_path, output_path):
|
def load_and_export(model_path, output_path):
|
||||||
with open(model_path + 'params.json') as f:
|
params_path = os.path.join(model_path, 'params.json')
|
||||||
|
with open(params_path) as f:
|
||||||
params = json.load(f)
|
params = json.load(f)
|
||||||
print(params)
|
print(params)
|
||||||
|
|
||||||
model_paths = sorted(list(Path(model_path).glob('consolidated.*.pth')))
|
model_paths = sorted(list(Path(model_path).glob('consolidated.*.pth')))
|
||||||
models = []
|
models = [torch.load(p, map_location='cpu') for p in model_paths]
|
||||||
for i in model_paths:
|
|
||||||
print(f'Loading {i}')
|
|
||||||
models.append(torch.load(i, map_location='cpu'))
|
|
||||||
|
|
||||||
state_dict = concat_weights(models)
|
state_dict = concat_weights(models)
|
||||||
del models
|
del models
|
||||||
export(params, state_dict, output_path)
|
export(params, state_dict, output_path)
|
||||||
|
|||||||
Reference in New Issue
Block a user