Merge pull request #250 from npinto/master-1
FIX: model.generate(); forward() only returns logits now.
This commit is contained in:
@@ -317,7 +317,7 @@ class Transformer(nn.Module):
|
|||||||
# if the sequence context is growing too long we must crop it at block_size
|
# if the sequence context is growing too long we must crop it at block_size
|
||||||
idx_cond = idx if idx.size(1) <= self.params.max_seq_len else idx[:, -self.params.max_seq_len:]
|
idx_cond = idx if idx.size(1) <= self.params.max_seq_len else idx[:, -self.params.max_seq_len:]
|
||||||
# forward the model to get the logits for the index in the sequence
|
# forward the model to get the logits for the index in the sequence
|
||||||
logits, _ = self(idx_cond)
|
logits = self(idx_cond)
|
||||||
logits = logits[:, -1, :] # crop to just the final time step
|
logits = logits[:, -1, :] # crop to just the final time step
|
||||||
if temperature == 0.0:
|
if temperature == 0.0:
|
||||||
# "sample" the single most likely index
|
# "sample" the single most likely index
|
||||||
|
|||||||
Reference in New Issue
Block a user