Reinline loss function

This commit is contained in:
Michael Cusack
2023-08-04 17:21:29 +07:00
parent f67185958b
commit f8d45f180d
+1 -4
View File
@@ -255,7 +255,7 @@ class Transformer(nn.Module):
if targets is not None:
# if we are given some desired targets also calculate the loss
logits = self.output(h)
self.last_loss = self.calculate_loss(logits, targets)
self.last_loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
else:
# inference-time mini-optimization: only forward the output on the very last position
logits = self.output(h[:, [-1], :]) # note: using list [-1] to preserve the time dim
@@ -263,9 +263,6 @@ class Transformer(nn.Module):
return logits
def calculate_loss(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
return F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
# start with all of the candidate parameters
param_dict = {pn: p for pn, p in self.named_parameters()}