Reinline loss function
This commit is contained in:
@@ -255,7 +255,7 @@ class Transformer(nn.Module):
         if targets is not None:
             # if we are given some desired targets also calculate the loss
             logits = self.output(h)
-            self.last_loss = self.calculate_loss(logits, targets)
+            self.last_loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
         else:
             # inference-time mini-optimization: only forward the output on the very last position
             logits = self.output(h[:, [-1], :]) # note: using list [-1] to preserve the time dim
@@ -263,9 +263,6 @@ class Transformer(nn.Module):
 
         return logits
 
-    def calculate_loss(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
-        return F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
-
     def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
         # start with all of the candidate parameters
         param_dict = {pn: p for pn, p in self.named_parameters()}
Reference in New Issue
Block a user