new model export: versions 0 (legacy) and 1

This commit is contained in:
Andrej Karpathy
2023-08-19 18:25:20 +00:00
parent bd182289c5
commit 7f551dbfd7
3 changed files with 245 additions and 53 deletions
+2 -1
View File
@@ -29,6 +29,7 @@ from torch.distributed import destroy_process_group, init_process_group
from torch.nn.parallel import DistributedDataParallel as DDP
from tinystories import Task
+from export import model_export
# -----------------------------------------------------------------------------
# I/O
@@ -287,7 +288,7 @@ while True:
}
print(f"saving checkpoint to {out_dir}")
torch.save(checkpoint, os.path.join(out_dir, "ckpt.pt"))
-raw_model.export(os.path.join(out_dir, "model.bin"))
+model_export(raw_model, os.path.join(out_dir, "model.bin"), version=0)
if iter_num == 0 and eval_only:
break