diff --git a/model.py b/model.py
index 7788749..a1ce255 100644
--- a/model.py
+++ b/model.py
@@ -255,7 +255,7 @@ class Transformer(nn.Module):
         if targets is not None:
             # if we are given some desired targets also calculate the loss
             logits = self.output(h)
-            self.last_loss = self.calculate_loss(logits, targets)
+            self.last_loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
         else:
             # inference-time mini-optimization: only forward the output on the very last position
             logits = self.output(h[:, [-1], :]) # note: using list [-1] to preserve the time dim
@@ -263,9 +263,6 @@ class Transformer(nn.Module):
 
         return logits
 
-    def calculate_loss(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
-        return F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
-
     def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
         # start with all of the candidate parameters
         param_dict = {pn: p for pn, p in self.named_parameters()}