diff --git a/model.py b/model.py index 9ca91fe..9e31d81 100644 --- a/model.py +++ b/model.py @@ -49,7 +49,7 @@ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): assert 0 <= 1 < ndim assert freqs_cis.shape == (x.shape[1], x.shape[-1]) shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] - return freqs_cis.view(*shape) + return freqs_cis.view(shape) def apply_rotary_emb( xq: torch.Tensor, @@ -59,8 +59,8 @@ def apply_rotary_emb( ) -> Tuple[torch.Tensor, torch.Tensor]: # reshape xq and xk to match the complex representation - xq_r, xq_i = xq.float().reshape(*xq.shape[:-1], -1, 2).unbind(-1) - xk_r, xk_i = xk.float().reshape(*xk.shape[:-1], -1, 2).unbind(-1) + xq_r, xq_i = xq.float().reshape(xq.shape[:-1] + (-1, 2)).unbind(-1) + xk_r, xk_i = xk.float().reshape(xk.shape[:-1] + (-1, 2)).unbind(-1) # reshape freqs_cos and freqs_sin for broadcasting freqs_cos = reshape_for_broadcast(freqs_cos, xq_r) @@ -142,10 +142,11 @@ class Attention(nn.Module): # flash implementation if self.flash: - output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True) + output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=None, dropout_p=self.dropout if self.training else 0.0, is_causal=True) else: # manual implementation scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim) + assert hasattr(self, 'mask') scores = scores + self.mask[:, :, :seqlen, :seqlen] # (bs, n_local_heads, seqlen, cache_len + seqlen) scores = F.softmax(scores.float(), dim=-1).type_as(xq) scores = self.attn_dropout(scores) @@ -198,6 +199,8 @@ class TransformerBlock(nn.Module): class Transformer(nn.Module): + last_loss: Optional[torch.Tensor] + def __init__(self, params: ModelArgs): super().__init__() self.params = params @@ -227,6 +230,9 @@ class Transformer(nn.Module): if pn.endswith('w3.weight') or pn.endswith('wo.weight'): torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * params.n_layers)) + # Initialize attribute for the loss of the last forward call. This will be set if the forward is called with a targets tensor. + self.last_loss = None + def _init_weights(self, module): if isinstance(module, nn.Linear): torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) @@ -235,7 +241,7 @@ class Transformer(nn.Module): elif isinstance(module, nn.Embedding): torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) - def forward(self, tokens, targets=None): + def forward(self, tokens: torch.Tensor, targets: Optional[torch.Tensor] = None) -> torch.Tensor: _bsz, seqlen = tokens.shape h = self.tok_embeddings(tokens) h = self.dropout(h) @@ -249,13 +255,13 @@ class Transformer(nn.Module): if targets is not None: # if we are given some desired targets also calculate the loss logits = self.output(h) - loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1) + self.last_loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1) else: # inference-time mini-optimization: only forward the output on the very last position logits = self.output(h[:, [-1], :]) # note: using list [-1] to preserve the time dim - loss = None + self.last_loss = None - return logits, loss + return logits def configure_optimizers(self, weight_decay, learning_rate, betas, device_type): # start with all of the candidate parameters diff --git a/save_torchscript.py b/save_torchscript.py new file mode 100755 index 0000000..af3a299 --- /dev/null +++ b/save_torchscript.py @@ -0,0 +1,66 @@ +#!/usr/bin/env python +"""Saves the model as a TorchScript. + +Usage examples: + ./save_torchscript.py + ./save_torchscript.py --dim=300 + ./save_torchscript.py --gzip_output=True --zero_params=True + +The resulting file can be loaded in C++ code and then used for training or +inference with: + #include + torch::jit::Module module = torch::jit::load("model.pt") + +Note that the serialized model includes the initial parameters and with the default +ModelArgs the file is 59M and gzips down to 55M. If you want to serialize/distribute +the model parameters separately you can zero out the parameters before saving it and +it will gzip down to 780K. +""" +import gzip +import os +import shutil +from inspect import signature + +import torch + +from model import ModelArgs, Transformer + +# Model args config +dim = 288 +n_layers = 6 +n_heads = 6 +n_kv_heads = n_heads +multiple_of = 32 +max_seq_len = 256 +dropout = 0.0 +vocab_size = 32000 +norm_eps = 1e-5 +# Save config +model_path = "model.pt" +zero_params = False +gzip_output = False +# Allow config overrides +exec(open("configurator.py").read()) + + +def main() -> None: + model_args = {k: globals()[k] for k in signature(ModelArgs).parameters} + model = Transformer(ModelArgs(**model_args)) + + # If requested zero params before saving the model. This is useful in + # conjunction with gzip_output. + if zero_params: + for p in model.parameters(): + p.detach().zero_() + + torch.jit.save(torch.jit.script(model), model_path) + + if gzip_output: + with open(model_path, "rb") as f_in: + with gzip.open(f"{model_path}.gz", "wb") as f_out: + shutil.copyfileobj(f_in, f_out) + os.unlink(model_path) + + +if __name__ == "__main__": + main() diff --git a/train.py b/train.py index 70b5109..811dd8a 100644 --- a/train.py +++ b/train.py @@ -211,7 +211,8 @@ def estimate_loss(): for k in range(eval_iters): X, Y = next(batch_iter) with ctx: - logits, loss = model(X, Y) + logits = model(X, Y) + loss = model.last_loss losses[k] = loss.item() out[split] = losses.mean() model.train() @@ -294,7 +295,8 @@ while True: # looking at the source of that context manager, it just toggles this variable model.require_backward_grad_sync = micro_step == gradient_accumulation_steps - 1 with ctx: - logits, loss = model(X, Y) + logits = model(X, Y) + loss = model.last_loss loss = loss / gradient_accumulation_steps # immediately async prefetch next batch while model is doing the forward pass on the GPU X, Y = next(train_batch_iter)