Add options to save_torchscript

This commit is contained in:
Michael Cusack
2023-08-04 23:10:14 +07:00
parent 113c675bc9
commit 4b3a41b8fc
Regular → Executable
+43 -18
View File
@@ -1,6 +1,11 @@
#!/usr/bin/env python
"""Saves the model as a TorchScript.
Usage examples:
./save_torchscript.py
./save_torchscript.py --dim=300
./save_torchscript.py --gzip_output=True --zero_params=True
The resulting file can be loaded in C++ code and then used for training or
inference with:
#include <torch/script.h>
@@ -8,33 +13,53 @@ inference with:
Note that the serialized model includes the initial parameters and with the default
ModelArgs the file is 59M and gzips down to 55M. If you want to serialize/distribute
the model parameters separately and you can zero out the parameters before saving it
and it will gzip down to 780K:
for p in model.parameters():
p.detach().zero_()
the model parameters separately you can zero out the parameters before saving it and
it will gzip down to 780K.
"""
import glob
import gzip
import os
import sys
from typing import List
import shutil
from inspect import signature
import torch
from model import ModelArgs, Transformer
# Model args
# Module-level defaults. main() builds ModelArgs by looking up a global with
# the same name as each ModelArgs parameter, so every ModelArgs parameter
# needs a matching name defined here.
dim = 288
n_layers = 6
n_heads = 6
n_kv_heads = n_heads
multiple_of = 32
max_seq_len = 256
dropout = 0.0
vocab_size = 32000
norm_eps = 1e-5
# Save config
model_path = "model.pt"  # where the TorchScript file is written
zero_params = False  # zero every parameter before saving
gzip_output = False  # write "<model_path>.gz" and remove the uncompressed file
# Run configurator.py in this module's namespace so it can rebind the defaults
# above (presumably from command-line flags such as --dim=300, as in the usage
# examples -- confirm against configurator.py).
# NOTE(review): exec() runs whatever code that file contains; only use it with
# a trusted configurator.py.
with open("configurator.py") as config_file:
    exec(config_file.read())
# Build a Transformer, script it with TorchScript and save it to disk,
# optionally zeroing its parameters first and gzipping the result.
# NOTE(review): this span comes from a rendered diff whose indentation was
# stripped; the first statements (hard-coded ModelArgs, save to "model.pt")
# appear to be the removed pre-commit body and the rest the added body --
# confirm against the repository before treating both as live code.
def main() -> None:
# Presumably the pre-commit version: fixed model hyperparameters.
model = Transformer(
ModelArgs(
dim=288,
n_layers=6,
n_heads=6,
multiple_of=32,
dropout=0.0,
vocab_size=32000,
)
)
# Presumably the pre-commit version: fixed output path.
torch.jit.save(torch.jit.script(model), "model.pt")
# For every parameter name in ModelArgs's signature, take the module-level
# global of the same name; a parameter with no matching global raises KeyError.
model_args = {k: globals()[k] for k in signature(ModelArgs).parameters}
model = Transformer(ModelArgs(**model_args))
# If requested zero params before saving the model. This is useful in
# conjunction with gzip_output.
if zero_params:
for p in model.parameters():
# Zero the parameter tensor in place.
p.detach().zero_()
# Script the model and write it to model_path.
torch.jit.save(torch.jit.script(model), model_path)
if gzip_output:
# Stream the saved file into "<model_path>.gz".
with open(model_path, "rb") as f_in:
with gzip.open(f"{model_path}.gz", "wb") as f_out:
shutil.copyfileobj(f_in, f_out)
# Remove the uncompressed file (presumably only on the gzip_output path;
# nesting is not recoverable from this view -- confirm).
os.unlink(model_path)
if __name__ == "__main__":