delete the save_torchscript export file, but copy its content to the new export.py for the future maybe
This commit is contained in:
@@ -14,6 +14,9 @@ Among the "output" versions of .bin files:
|
|||||||
|
|
||||||
This script aspires to provide all of these conversions.
|
This script aspires to provide all of these conversions.
|
||||||
"""
|
"""
|
||||||
|
import os
|
||||||
|
import gzip
|
||||||
|
import shutil
|
||||||
import struct
|
import struct
|
||||||
import argparse
|
import argparse
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -343,6 +346,34 @@ def model_export(model, filepath, version):
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"unknown version {version}")
|
raise ValueError(f"unknown version {version}")
|
||||||
|
|
||||||
|
def torchscript_export(model, filepath, zero_params=False, gzip_output=False):
|
||||||
|
"""
|
||||||
|
(This was submitted via a PR earlier. Leaving it here, but "orphaned" for now)
|
||||||
|
Saves the model as a TorchScript.
|
||||||
|
The resulting file can be loaded in C++ code and then used for training or
|
||||||
|
inference with:
|
||||||
|
#include <torch/script.h>
|
||||||
|
torch::jit::Module module = torch::jit::load("model.pt")
|
||||||
|
Note that the serialized model includes the initial parameters and with the default
|
||||||
|
ModelArgs the file is 59M and gzips down to 55M. If you want to serialize/distribute
|
||||||
|
the model parameters separately you can zero out the parameters before saving it and
|
||||||
|
it will gzip down to 780K.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# If requested zero params before saving the model. This is useful in
|
||||||
|
# conjunction with gzip_output.
|
||||||
|
if zero_params:
|
||||||
|
for p in model.parameters():
|
||||||
|
p.detach().zero_()
|
||||||
|
|
||||||
|
torch.jit.save(torch.jit.script(model), filepath)
|
||||||
|
|
||||||
|
if gzip_output:
|
||||||
|
with open(filepath, "rb") as f_in:
|
||||||
|
with gzip.open(f"{filepath}.gz", "wb") as f_out:
|
||||||
|
shutil.copyfileobj(f_in, f_out)
|
||||||
|
os.unlink(filepath)
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
# CLI entrypoint
|
# CLI entrypoint
|
||||||
|
|
||||||
|
|||||||
@@ -1,66 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
"""Saves the model as a TorchScript.
|
|
||||||
|
|
||||||
Usage examples:
|
|
||||||
./save_torchscript.py
|
|
||||||
./save_torchscript.py --dim=300
|
|
||||||
./save_torchscript.py --gzip_output=True --zero_params=True
|
|
||||||
|
|
||||||
The resulting file can be loaded in C++ code and then used for training or
|
|
||||||
inference with:
|
|
||||||
#include <torch/script.h>
|
|
||||||
torch::jit::Module module = torch::jit::load("model.pt")
|
|
||||||
|
|
||||||
Note that the serialized model includes the initial parameters and with the default
|
|
||||||
ModelArgs the file is 59M and gzips down to 55M. If you want to serialize/distribute
|
|
||||||
the model parameters separately you can zero out the parameters before saving it and
|
|
||||||
it will gzip down to 780K.
|
|
||||||
"""
|
|
||||||
import gzip
|
|
||||||
import os
|
|
||||||
import shutil
|
|
||||||
from inspect import signature
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from model import ModelArgs, Transformer
|
|
||||||
|
|
||||||
# Model args config
|
|
||||||
dim = 288
|
|
||||||
n_layers = 6
|
|
||||||
n_heads = 6
|
|
||||||
n_kv_heads = n_heads
|
|
||||||
multiple_of = 32
|
|
||||||
max_seq_len = 256
|
|
||||||
dropout = 0.0
|
|
||||||
vocab_size = 32000
|
|
||||||
norm_eps = 1e-5
|
|
||||||
# Save config
|
|
||||||
model_path = "model.pt"
|
|
||||||
zero_params = False
|
|
||||||
gzip_output = False
|
|
||||||
# Allow config overrides
|
|
||||||
exec(open("configurator.py").read())
|
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
|
||||||
model_args = {k: globals()[k] for k in signature(ModelArgs).parameters}
|
|
||||||
model = Transformer(ModelArgs(**model_args))
|
|
||||||
|
|
||||||
# If requested zero params before saving the model. This is useful in
|
|
||||||
# conjunction with gzip_output.
|
|
||||||
if zero_params:
|
|
||||||
for p in model.parameters():
|
|
||||||
p.detach().zero_()
|
|
||||||
|
|
||||||
torch.jit.save(torch.jit.script(model), model_path)
|
|
||||||
|
|
||||||
if gzip_output:
|
|
||||||
with open(model_path, "rb") as f_in:
|
|
||||||
with gzip.open(f"{model_path}.gz", "wb") as f_out:
|
|
||||||
shutil.copyfileobj(f_in, f_out)
|
|
||||||
os.unlink(model_path)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
Reference in New Issue
Block a user