diff --git a/model.py b/model.py index 1600f5b..6f7a43b 100644 --- a/model.py +++ b/model.py @@ -49,7 +49,7 @@ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): assert 0 <= 1 < ndim assert freqs_cis.shape == (x.shape[1], x.shape[-1]) shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] - return freqs_cis.view(*shape) + return freqs_cis.view(shape) def apply_rotary_emb( xq: torch.Tensor, @@ -59,8 +59,8 @@ def apply_rotary_emb( ) -> Tuple[torch.Tensor, torch.Tensor]: # reshape xq and xk to match the complex representation - xq_r, xq_i = xq.float().reshape(*xq.shape[:-1], -1, 2).unbind(-1) - xk_r, xk_i = xk.float().reshape(*xk.shape[:-1], -1, 2).unbind(-1) + xq_r, xq_i = xq.float().reshape(xq.shape[:-1] + (-1, 2)).unbind(-1) + xk_r, xk_i = xk.float().reshape(xk.shape[:-1] + (-1, 2)).unbind(-1) # reshape freqs_cos and freqs_sin for broadcasting freqs_cos = reshape_for_broadcast(freqs_cos, xq_r) @@ -108,6 +108,7 @@ class Attention(nn.Module): # use flash attention or a manual implementation? self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') + if not self.flash: print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0") mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf")) @@ -142,10 +143,11 @@ class Attention(nn.Module): # flash implementation if self.flash: - output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True) + output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=None, dropout_p=self.dropout if self.training else 0.0, is_causal=True) else: # manual implementation scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim) + assert hasattr(self, 'mask') scores = scores + self.mask[:, :, :seqlen, :seqlen] # (bs, n_local_heads, seqlen, cache_len + seqlen) scores = F.softmax(scores.float(), dim=-1).type_as(xq) scores = self.attn_dropout(scores) @@ -198,6 +200,8 @@ class TransformerBlock(nn.Module): class Transformer(nn.Module): + last_loss: Optional[torch.Tensor] + def __init__(self, params: ModelArgs): super().__init__() self.params = params @@ -227,6 +231,9 @@ class Transformer(nn.Module): if pn.endswith('w3.weight') or pn.endswith('wo.weight'): torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * params.n_layers)) + # Initialize attribute for the loss of the last forward call. This will be set if the forward is called with a targets tensor. + self.last_loss = None + def _init_weights(self, module): if isinstance(module, nn.Linear): torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) @@ -235,7 +242,7 @@ class Transformer(nn.Module): elif isinstance(module, nn.Embedding): torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) - def forward(self, tokens, targets=None): + def forward(self, tokens: torch.Tensor, targets: Optional[torch.Tensor] = None) -> torch.Tensor: _bsz, seqlen = tokens.shape h = self.tok_embeddings(tokens) h = self.dropout(h) @@ -249,13 +256,16 @@ class Transformer(nn.Module): if targets is not None: # if we are given some desired targets also calculate the loss logits = self.output(h) - loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1) + self.last_loss = self.calculate_loss(logits, targets) else: # inference-time mini-optimization: only forward the output on the very last position logits = self.output(h[:, [-1], :]) # note: using list [-1] to preserve the time dim - loss = None + self.last_loss = None - return logits, loss + return logits + + def calculate_loss(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: + return F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1) def configure_optimizers(self, weight_decay, learning_rate, betas, device_type): # start with all of the candidate parameters diff --git a/save_model.py b/save_model.py new file mode 100644 index 0000000..54bf82a --- /dev/null +++ b/save_model.py @@ -0,0 +1,21 @@ +#!/usr/bin/env python3 +#!/usr/bin/env python +"""Saves the model as a TorchScript.""" + +import glob +import os +import sys +from typing import List + +import torch + +from model import ModelArgs, Transformer + +def main() -> None: + model_args = ModelArgs(dim=512, n_layers=6, n_heads=8, vocab_size=32000) + model = Transformer(model_args) + torch.jit.save(torch.jit.script(model), "model.pt") + + +if __name__ == "__main__": + main()