diff --git a/train.py b/train.py index 811dd8a..dbf0b24 100644 --- a/train.py +++ b/train.py @@ -212,7 +212,7 @@ def estimate_loss(): X, Y = next(batch_iter) with ctx: logits = model(X, Y) - loss = model.last_loss + loss = raw_model.last_loss losses[k] = loss.item() out[split] = losses.mean() model.train() @@ -296,7 +296,7 @@ while True: model.require_backward_grad_sync = micro_step == gradient_accumulation_steps - 1 with ctx: logits = model(X, Y) - loss = model.last_loss + loss = raw_model.last_loss loss = loss / gradient_accumulation_steps # immediately async prefetch next batch while model is doing the forward pass on the GPU X, Y = next(train_batch_iter)