Model args in save script
This commit is contained in:
+11
-2
@@ -16,9 +16,18 @@ import torch
|
|||||||
|
|
||||||
from model import ModelArgs, Transformer
|
from model import ModelArgs, Transformer
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
def main() -> None:
|
||||||
model_args = ModelArgs(dim=512, n_layers=6, n_heads=8, vocab_size=32000)
|
model = Transformer(
|
||||||
model = Transformer(model_args)
|
ModelArgs(
|
||||||
|
dim=288,
|
||||||
|
n_layers=6,
|
||||||
|
n_heads=6,
|
||||||
|
multiple_of=32,
|
||||||
|
dropout=0.0,
|
||||||
|
vocab_size=32000,
|
||||||
|
)
|
||||||
|
)
|
||||||
torch.jit.save(torch.jit.script(model), "model.pt")
|
torch.jit.save(torch.jit.script(model), "model.pt")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user