Merge pull request #250 from npinto/master-1

FIX: model.generate(); forward() only returns logits now.
This commit is contained in:
Andrej
2023-08-06 18:43:01 -07:00
committed by GitHub
+1 -1
View File
@@ -317,7 +317,7 @@ class Transformer(nn.Module):
# if the sequence context is growing too long we must crop it at block_size
idx_cond = idx if idx.size(1) <= self.params.max_seq_len else idx[:, -self.params.max_seq_len:]
# forward the model to get the logits for the index in the sequence
logits, _ = self(idx_cond)
logits = self(idx_cond)
logits = logits[:, -1, :] # crop to just the final time step
if temperature == 0.0:
# "sample" the single most likely index