diff --git a/train.py b/train.py index 34248b8..79b2c8e 100644 --- a/train.py +++ b/train.py @@ -179,7 +179,7 @@ scaler = torch.cuda.amp.GradScaler(enabled=(dtype == "float16")) # optimizer optimizer = model.configure_optimizers(weight_decay, learning_rate, (beta1, beta2), device_type) -if init_from == "resume": +if init_from == "resume" and "optimizer" in checkpoint: optimizer.load_state_dict(checkpoint["optimizer"]) checkpoint = None # free up memory