Model args in save script

This commit is contained in:
Michael Cusack
2023-08-04 17:07:41 +07:00
parent fd5e2cc7bc
commit f67185958b
+11 -2
View File
@@ -16,9 +16,18 @@ import torch
from model import ModelArgs, Transformer from model import ModelArgs, Transformer
def main() -> None: def main() -> None:
model_args = ModelArgs(dim=512, n_layers=6, n_heads=8, vocab_size=32000) model = Transformer(
model = Transformer(model_args) ModelArgs(
dim=288,
n_layers=6,
n_heads=6,
multiple_of=32,
dropout=0.0,
vocab_size=32000,
)
)
torch.jit.save(torch.jit.script(model), "model.pt") torch.jit.save(torch.jit.script(model), "model.pt")