FIX: model.generate()

This patch fixes a simple bug in `generate()`: since commit `f2e34e6b0ac55accd6ba930a04c6f683f5158b29`, the model's `forward()` returns only the logits and no longer the loss, so `generate()` must not unpack two values from the call.
This commit is contained in:
Nicolas Pinto
2023-08-06 14:48:47 -07:00
committed by GitHub
parent a7a3aa09b8
commit 98b515e44d
+1 -1
View File
@@ -317,7 +317,7 @@ class Transformer(nn.Module):
# if the sequence context is growing too long we must crop it at block_size
idx_cond = idx if idx.size(1) <= self.params.max_seq_len else idx[:, -self.params.max_seq_len:]
# forward the model to get the logits for the index in the sequence
logits, _ = self(idx_cond)
logits = self(idx_cond)
logits = logits[:, -1, :] # crop to just the final time step
if temperature == 0.0:
# "sample" the single most likely index