From f67185958b5c3e3c690422430220e3db755e9628 Mon Sep 17 00:00:00 2001 From: Michael Cusack Date: Fri, 4 Aug 2023 17:07:41 +0700 Subject: [PATCH] Model args in save script --- save_model.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/save_model.py b/save_model.py index c80253a..4a19880 100644 --- a/save_model.py +++ b/save_model.py @@ -16,9 +16,18 @@ import torch from model import ModelArgs, Transformer + def main() -> None: - model_args = ModelArgs(dim=512, n_layers=6, n_heads=8, vocab_size=32000) - model = Transformer(model_args) + model = Transformer( + ModelArgs( + dim=288, + n_layers=6, + n_heads=6, + multiple_of=32, + dropout=0.0, + vocab_size=32000, + ) + ) torch.jit.save(torch.jit.script(model), "model.pt")