Add options to save_torchscript

This commit is contained in:
Michael Cusack
2023-08-04 23:10:14 +07:00
parent 113c675bc9
commit 4b3a41b8fc
Regular → Executable
+43 -18
View File
@@ -1,6 +1,11 @@
#!/usr/bin/env python #!/usr/bin/env python
"""Saves the model as a TorchScript. """Saves the model as a TorchScript.
Usage examples:
./save_torchscript.py
./save_torchscript.py --dim=300
./save_torchscript.py --gzip_output=True --zero_params=True
The resulting file can be loaded in C++ code and then used for training or The resulting file can be loaded in C++ code and then used for training or
inference with: inference with:
#include <torch/script.h> #include <torch/script.h>
@@ -8,33 +13,53 @@ inference with:
Note that the serialized model includes the initial parameters and with the default Note that the serialized model includes the initial parameters and with the default
ModelArgs the file is 59M and gzips down to 55M. If you want to serialize/distribute ModelArgs the file is 59M and gzips down to 55M. If you want to serialize/distribute
the model parameters separately and you can zero out the parameters before saving it the model parameters separately you can zero out the parameters before saving it and
and it will gzip down to 780K: it will gzip down to 780K.
for p in model.parameters():
p.detach().zero_()
""" """
import glob import gzip
import os import os
import sys import shutil
from typing import List from inspect import signature
import torch import torch
from model import ModelArgs, Transformer from model import ModelArgs, Transformer
# Model args
dim = 288
n_layers = 6
n_heads = 6
n_kv_heads = n_heads
multiple_of = 32
max_seq_len = 256
dropout = 0.0
vocab_size = 32000
norm_eps = 1e-5
# Save config
model_path = "model.pt"
zero_params = False
gzip_output = False
exec(open("configurator.py").read())
def main() -> None: def main() -> None:
model = Transformer( model_args = {k: globals()[k] for k in signature(ModelArgs).parameters}
ModelArgs( model = Transformer(ModelArgs(**model_args))
dim=288,
n_layers=6, # If requested zero params before saving the model. This is usful in
n_heads=6, # conjunction with gzip_output.
multiple_of=32, if zero_params:
dropout=0.0, for p in model.parameters():
vocab_size=32000, p.detach().zero_()
)
) torch.jit.save(torch.jit.script(model), model_path)
torch.jit.save(torch.jit.script(model), "model.pt")
if gzip_output:
with open(model_path, "rb") as f_in:
with gzip.open(f"{model_path}.gz", "wb") as f_out:
shutil.copyfileobj(f_in, f_out)
os.unlink(model_path)
if __name__ == "__main__": if __name__ == "__main__":