Merge pull request #250 from npinto/master-1
FIX: model.generate(); forward() only returns logits now.
This commit is contained in:
@@ -317,7 +317,7 @@ class Transformer(nn.Module):
|
||||
# if the sequence context is growing too long we must crop it at block_size
|
||||
idx_cond = idx if idx.size(1) <= self.params.max_seq_len else idx[:, -self.params.max_seq_len:]
|
||||
# forward the model to get the logits for the index in the sequence
|
||||
logits, _ = self(idx_cond)
|
||||
logits = self(idx_cond)
|
||||
logits = logits[:, -1, :] # crop to just the final time step
|
||||
if temperature == 0.0:
|
||||
# "sample" the single most likely index
|
||||
|
||||
Reference in New Issue
Block a user