From 98b515e44d23687258c08ec19e1e2458b57aa5ae Mon Sep 17 00:00:00 2001
From: Nicolas Pinto
Date: Sun, 6 Aug 2023 14:48:47 -0700
Subject: [PATCH] FIX: model.generate()

This patch fixes a simple bug in `generate()` due to model's `forward()`
only returning logits and not losses since
`f2e34e6b0ac55accd6ba930a04c6f683f5158b29`.
---
 model.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/model.py b/model.py
index 66304e7..f7edbb6 100644
--- a/model.py
+++ b/model.py
@@ -317,7 +317,7 @@ class Transformer(nn.Module):
             # if the sequence context is growing too long we must crop it at block_size
             idx_cond = idx if idx.size(1) <= self.params.max_seq_len else idx[:, -self.params.max_seq_len:]
             # forward the model to get the logits for the index in the sequence
-            logits, _ = self(idx_cond)
+            logits = self(idx_cond)
             logits = logits[:, -1, :] # crop to just the final time step
             if temperature == 0.0:
                 # "sample" the single most likely index