diff --git a/train.py b/train.py index 70b5109..811dd8a 100644 --- a/train.py +++ b/train.py @@ -211,7 +211,8 @@ def estimate_loss(): for k in range(eval_iters): X, Y = next(batch_iter) with ctx: - logits, loss = model(X, Y) + logits = model(X, Y) + loss = model.last_loss losses[k] = loss.item() out[split] = losses.mean() model.train() @@ -294,7 +295,8 @@ while True: # looking at the source of that context manager, it just toggles this variable model.require_backward_grad_sync = micro_step == gradient_accumulation_steps - 1 with ctx: - logits, loss = model(X, Y) + logits = model(X, Y) + loss = model.last_loss loss = loss / gradient_accumulation_steps # immediately async prefetch next batch while model is doing the forward pass on the GPU X, Y = next(train_batch_iter)