FIX: model.generate()
This patch fixes a simple bug in `generate()` due to model's `forward()` only returning logits and not losses since `f2e34e6b0ac55accd6ba930a04c6f683f5158b29`.
This commit is contained in:
@@ -317,7 +317,7 @@ class Transformer(nn.Module):
|
||||
# if the sequence context is growing too long we must crop it at block_size
|
||||
idx_cond = idx if idx.size(1) <= self.params.max_seq_len else idx[:, -self.params.max_seq_len:]
|
||||
# forward the model to get the logits for the index in the sequence
|
||||
logits, _ = self(idx_cond)
|
||||
logits = self(idx_cond)
|
||||
logits = logits[:, -1, :] # crop to just the final time step
|
||||
if temperature == 0.0:
|
||||
# "sample" the single most likely index
|
||||
|
||||
Reference in New Issue
Block a user