new model export: versions 0 (legacy) and 1
This commit is contained in:
@@ -29,6 +29,7 @@ from torch.distributed import destroy_process_group, init_process_group
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
|
||||
from tinystories import Task
|
||||
from export import model_export
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# I/O
|
||||
@@ -287,7 +288,7 @@ while True:
|
||||
}
|
||||
print(f"saving checkpoint to {out_dir}")
|
||||
torch.save(checkpoint, os.path.join(out_dir, "ckpt.pt"))
|
||||
raw_model.export(os.path.join(out_dir, "model.bin"))
|
||||
model_export(raw_model, os.path.join(out_dir, "model.bin"), version=0)
|
||||
if iter_num == 0 and eval_only:
|
||||
break
|
||||
|
||||
|
||||
Reference in New Issue
Block a user