diff --git a/model.py b/model.py index 66304e7..f7edbb6 100644 --- a/model.py +++ b/model.py @@ -317,7 +317,7 @@ class Transformer(nn.Module): # if the sequence context is growing too long we must crop it at block_size idx_cond = idx if idx.size(1) <= self.params.max_seq_len else idx[:, -self.params.max_seq_len:] # forward the model to get the logits for the index in the sequence - logits, _ = self(idx_cond) + logits = self(idx_cond) logits = logits[:, -1, :] # crop to just the final time step if temperature == 0.0: # "sample" the single most likely index