diff --git a/save_model.py b/save_model.py index c80253a..4a19880 100644 --- a/save_model.py +++ b/save_model.py @@ -16,9 +16,18 @@ import torch from model import ModelArgs, Transformer + def main() -> None: - model_args = ModelArgs(dim=512, n_layers=6, n_heads=8, vocab_size=32000) - model = Transformer(model_args) + model = Transformer( + ModelArgs( + dim=288, + n_layers=6, + n_heads=6, + multiple_of=32, + dropout=0.0, + vocab_size=32000, + ) + ) torch.jit.save(torch.jit.script(model), "model.pt")