From f2e34e6b0ac55accd6ba930a04c6f683f5158b29 Mon Sep 17 00:00:00 2001 From: Michael Cusack Date: Fri, 4 Aug 2023 16:49:26 +0700 Subject: [PATCH 01/15] Resolve jit.save errors --- model.py | 26 ++++++++++++++++++-------- save_model.py | 21 +++++++++++++++++++++ 2 files changed, 39 insertions(+), 8 deletions(-) create mode 100644 save_model.py diff --git a/model.py b/model.py index 1600f5b..6f7a43b 100644 --- a/model.py +++ b/model.py @@ -49,7 +49,7 @@ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): assert 0 <= 1 < ndim assert freqs_cis.shape == (x.shape[1], x.shape[-1]) shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] - return freqs_cis.view(*shape) + return freqs_cis.view(shape) def apply_rotary_emb( xq: torch.Tensor, @@ -59,8 +59,8 @@ def apply_rotary_emb( ) -> Tuple[torch.Tensor, torch.Tensor]: # reshape xq and xk to match the complex representation - xq_r, xq_i = xq.float().reshape(*xq.shape[:-1], -1, 2).unbind(-1) - xk_r, xk_i = xk.float().reshape(*xk.shape[:-1], -1, 2).unbind(-1) + xq_r, xq_i = xq.float().reshape(xq.shape[:-1] + (-1, 2)).unbind(-1) + xk_r, xk_i = xk.float().reshape(xk.shape[:-1] + (-1, 2)).unbind(-1) # reshape freqs_cos and freqs_sin for broadcasting freqs_cos = reshape_for_broadcast(freqs_cos, xq_r) @@ -108,6 +108,7 @@ class Attention(nn.Module): # use flash attention or a manual implementation? self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') + if not self.flash: print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0") mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf")) @@ -142,10 +143,11 @@ class Attention(nn.Module): # flash implementation if self.flash: - output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True) + output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=None, dropout_p=self.dropout if self.training else 0.0, is_causal=True) else: # manual implementation scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim) + assert hasattr(self, 'mask') scores = scores + self.mask[:, :, :seqlen, :seqlen] # (bs, n_local_heads, seqlen, cache_len + seqlen) scores = F.softmax(scores.float(), dim=-1).type_as(xq) scores = self.attn_dropout(scores) @@ -198,6 +200,8 @@ class TransformerBlock(nn.Module): class Transformer(nn.Module): + last_loss: Optional[torch.Tensor] + def __init__(self, params: ModelArgs): super().__init__() self.params = params @@ -227,6 +231,9 @@ class Transformer(nn.Module): if pn.endswith('w3.weight') or pn.endswith('wo.weight'): torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * params.n_layers)) + # Initialize attribute for the loss of the last forward call. This will be set if the forward is called with a targets tensor. + self.last_loss = None + def _init_weights(self, module): if isinstance(module, nn.Linear): torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) @@ -235,7 +242,7 @@ class Transformer(nn.Module): elif isinstance(module, nn.Embedding): torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) - def forward(self, tokens, targets=None): + def forward(self, tokens: torch.Tensor, targets: Optional[torch.Tensor] = None) -> torch.Tensor: _bsz, seqlen = tokens.shape h = self.tok_embeddings(tokens) h = self.dropout(h) @@ -249,13 +256,16 @@ class Transformer(nn.Module): if targets is not None: # if we are given some desired targets also calculate the loss logits = self.output(h) - loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1) + self.last_loss = self.calculate_loss(logits, targets) else: # inference-time mini-optimization: only forward the output on the very last position logits = self.output(h[:, [-1], :]) # note: using list [-1] to preserve the time dim - loss = None + self.last_loss = None - return logits, loss + return logits + + def calculate_loss(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: + return F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1) def configure_optimizers(self, weight_decay, learning_rate, betas, device_type): # start with all of the candidate parameters diff --git a/save_model.py b/save_model.py new file mode 100644 index 0000000..54bf82a --- /dev/null +++ b/save_model.py @@ -0,0 +1,21 @@ +#!/usr/bin/env python3 +#!/usr/bin/env python +"""Saves the model as a TorchScript.""" + +import glob +import os +import sys +from typing import List + +import torch + +from model import ModelArgs, Transformer + +def main() -> None: + model_args = ModelArgs(dim=512, n_layers=6, n_heads=8, vocab_size=32000) + model = Transformer(model_args) + torch.jit.save(torch.jit.script(model), "model.pt") + + +if __name__ == "__main__": + main() From 11a8348dfcd7e982bab2bf5ecfa7df18690e8995 Mon Sep 17 00:00:00 2001 From: Michael Cusack Date: Fri, 4 Aug 2023 16:52:04 +0700 Subject: [PATCH 02/15] extra line --- model.py | 1 - 1 file changed, 1 deletion(-) diff --git a/model.py b/model.py index 6f7a43b..7788749 100644 --- a/model.py +++ b/model.py @@ -108,7 +108,6 @@ class Attention(nn.Module): # use flash attention or a manual implementation? self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') - if not self.flash: print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0") mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf")) From ac2b435151a644dbdb0544a623febba905623085 Mon Sep 17 00:00:00 2001 From: Michael Cusack Date: Fri, 4 Aug 2023 16:55:26 +0700 Subject: [PATCH 03/15] docs --- save_model.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/save_model.py b/save_model.py index 54bf82a..c80253a 100644 --- a/save_model.py +++ b/save_model.py @@ -1,6 +1,11 @@ #!/usr/bin/env python3 #!/usr/bin/env python -"""Saves the model as a TorchScript.""" +"""Saves the model as a TorchScript. + +The resulting file can be loaded in C++ code and then used for training or infrence with: + #include + torch::jit::Module module = torch::jit::load("model.pt") +""" import glob import os From fd5e2cc7bcb715052571a0c1040008e8bf75ed9e Mon Sep 17 00:00:00 2001 From: Michael Cusack Date: Fri, 4 Aug 2023 17:03:11 +0700 Subject: [PATCH 04/15] Updating training code for loss result --- train.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/train.py b/train.py index 70b5109..811dd8a 100644 --- a/train.py +++ b/train.py @@ -211,7 +211,8 @@ def estimate_loss(): for k in range(eval_iters): X, Y = next(batch_iter) with ctx: - logits, loss = model(X, Y) + logits = model(X, Y) + loss = model.last_loss losses[k] = loss.item() out[split] = losses.mean() model.train() @@ -294,7 +295,8 @@ while True: # looking at the source of that context manager, it just toggles this variable model.require_backward_grad_sync = micro_step == gradient_accumulation_steps - 1 with ctx: - logits, loss = model(X, Y) + logits = model(X, Y) + loss = model.last_loss loss = loss / gradient_accumulation_steps # immediately async prefetch next batch while model is doing the forward pass on the GPU X, Y = next(train_batch_iter) From f67185958b5c3e3c690422430220e3db755e9628 Mon Sep 17 00:00:00 2001 From: Michael Cusack Date: Fri, 4 Aug 2023 17:07:41 +0700 Subject: [PATCH 05/15] Model args in save script --- save_model.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/save_model.py b/save_model.py index c80253a..4a19880 100644 --- a/save_model.py +++ b/save_model.py @@ -16,9 +16,18 @@ import torch from model import ModelArgs, Transformer + def main() -> None: - model_args = ModelArgs(dim=512, n_layers=6, n_heads=8, vocab_size=32000) - model = Transformer(model_args) + model = Transformer( + ModelArgs( + dim=288, + n_layers=6, + n_heads=6, + multiple_of=32, + dropout=0.0, + vocab_size=32000, + ) + ) torch.jit.save(torch.jit.script(model), "model.pt") From f8d45f180d3339d9dd39076f10027c24c4e8904a Mon Sep 17 00:00:00 2001 From: Michael Cusack Date: Fri, 4 Aug 2023 17:21:29 +0700 Subject: [PATCH 06/15] Reinline loss function --- model.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/model.py b/model.py index 7788749..a1ce255 100644 --- a/model.py +++ b/model.py @@ -255,7 +255,7 @@ class Transformer(nn.Module): if targets is not None: # if we are given some desired targets also calculate the loss logits = self.output(h) - self.last_loss = self.calculate_loss(logits, targets) + self.last_loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1) else: # inference-time mini-optimization: only forward the output on the very last position logits = self.output(h[:, [-1], :]) # note: using list [-1] to preserve the time dim @@ -263,9 +263,6 @@ class Transformer(nn.Module): return logits - def calculate_loss(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: - return F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1) - def configure_optimizers(self, weight_decay, learning_rate, betas, device_type): # start with all of the candidate parameters param_dict = {pn: p for pn, p in self.named_parameters()} From 9f8e0857ee39ea8297153d2750c4ce2435c17807 Mon Sep 17 00:00:00 2001 From: Michael Cusack Date: Fri, 4 Aug 2023 17:22:27 +0700 Subject: [PATCH 07/15] Typo --- save_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/save_model.py b/save_model.py index 4a19880..4aebef0 100644 --- a/save_model.py +++ b/save_model.py @@ -2,7 +2,7 @@ #!/usr/bin/env python """Saves the model as a TorchScript. -The resulting file can be loaded in C++ code and then used for training or infrence with: +The resulting file can be loaded in C++ code and then used for training or inference with: #include torch::jit::Module module = torch::jit::load("model.pt") """ From d4cdd6259eafe886ccbdda44be5523faed645656 Mon Sep 17 00:00:00 2001 From: Michael Cusack Date: Fri, 4 Aug 2023 17:30:05 +0700 Subject: [PATCH 08/15] Zero'ing params docs --- save_model.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/save_model.py b/save_model.py index 4aebef0..f3b7539 100644 --- a/save_model.py +++ b/save_model.py @@ -5,6 +5,13 @@ The resulting file can be loaded in C++ code and then used for training or inference with: #include torch::jit::Module module = torch::jit::load("model.pt") + +Note that the model includes the initial parameters and with default ModelArgs the serialized model +is 59M and gzips down to 55M. If you want to serialize/distribute the model parameters separately +and the size of the model file you can zero out the parameters before saving it and it will gzip +down to 780K: + for p in model.parameters(): + p.detach().zero_() """ import glob From 34f04025014991cbc866706a0f0573f417294ab8 Mon Sep 17 00:00:00 2001 From: Michael Cusack Date: Fri, 4 Aug 2023 17:31:11 +0700 Subject: [PATCH 09/15] Zero'ing params docs --- save_model.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/save_model.py b/save_model.py index f3b7539..de102c7 100644 --- a/save_model.py +++ b/save_model.py @@ -2,14 +2,15 @@ #!/usr/bin/env python """Saves the model as a TorchScript. -The resulting file can be loaded in C++ code and then used for training or inference with: +The resulting file can be loaded in C++ code and then used for training or +inference with: #include torch::jit::Module module = torch::jit::load("model.pt") -Note that the model includes the initial parameters and with default ModelArgs the serialized model -is 59M and gzips down to 55M. If you want to serialize/distribute the model parameters separately -and the size of the model file you can zero out the parameters before saving it and it will gzip -down to 780K: +Note that the model includes the initial parameters and with default ModelArgs the +serialized model is 59M and gzips down to 55M. If you want to serialize/distribute the +model parameters separately and the size of the model file you can zero out the +parameters before saving it and it will gzip down to 780K: for p in model.parameters(): p.detach().zero_() """ From dfff7812db50feee8bfdbabb7409ba897af4f8ca Mon Sep 17 00:00:00 2001 From: Michael Cusack Date: Fri, 4 Aug 2023 17:31:31 +0700 Subject: [PATCH 10/15] Zero'ing params docs --- save_model.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/save_model.py b/save_model.py index de102c7..921d227 100644 --- a/save_model.py +++ b/save_model.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python3 #!/usr/bin/env python """Saves the model as a TorchScript. @@ -14,7 +13,6 @@ parameters before saving it and it will gzip down to 780K: for p in model.parameters(): p.detach().zero_() """ - import glob import os import sys From 305d920862dd3f8a989d2f1132fb40c775c3c29a Mon Sep 17 00:00:00 2001 From: Michael Cusack Date: Fri, 4 Aug 2023 17:33:23 +0700 Subject: [PATCH 11/15] Zero'ing params docs --- save_model.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/save_model.py b/save_model.py index 921d227..42bcd06 100644 --- a/save_model.py +++ b/save_model.py @@ -6,10 +6,10 @@ inference with: #include torch::jit::Module module = torch::jit::load("model.pt") -Note that the model includes the initial parameters and with default ModelArgs the -serialized model is 59M and gzips down to 55M. If you want to serialize/distribute the -model parameters separately and the size of the model file you can zero out the -parameters before saving it and it will gzip down to 780K: +Note that the serialized model includes the initial parameters and with the default +ModelArgs the file is 59M and gzips down to 55M. If you want to serialize/distribute +the model parameters separately and you can zero out the parameters before saving it +and it will gzip down to 780K: for p in model.parameters(): p.detach().zero_() """ From 113c675bc9ba2e8d0c459705ced8e2a2d17c04d2 Mon Sep 17 00:00:00 2001 From: Michael Cusack Date: Fri, 4 Aug 2023 20:31:44 +0700 Subject: [PATCH 12/15] Rename save_model.py --- save_model.py => save_torchscript.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename save_model.py => save_torchscript.py (100%) diff --git a/save_model.py b/save_torchscript.py similarity index 100% rename from save_model.py rename to save_torchscript.py From 4b3a41b8fce8c39aced3aa9c0088b49524024a1c Mon Sep 17 00:00:00 2001 From: Michael Cusack Date: Fri, 4 Aug 2023 23:10:14 +0700 Subject: [PATCH 13/15] Add options to save_torchscript --- save_torchscript.py | 61 ++++++++++++++++++++++++++++++++------------- 1 file changed, 43 insertions(+), 18 deletions(-) mode change 100644 => 100755 save_torchscript.py diff --git a/save_torchscript.py b/save_torchscript.py old mode 100644 new mode 100755 index 42bcd06..cb520d8 --- a/save_torchscript.py +++ b/save_torchscript.py @@ -1,6 +1,11 @@ #!/usr/bin/env python """Saves the model as a TorchScript. +Usage examples: + ./save_torchscript.py + ./save_torchscript.py --dim=300 + ./save_torchscript.py --gzip_output=True --zero_params=True + The resulting file can be loaded in C++ code and then used for training or inference with: #include @@ -8,33 +13,53 @@ inference with: Note that the serialized model includes the initial parameters and with the default ModelArgs the file is 59M and gzips down to 55M. If you want to serialize/distribute -the model parameters separately and you can zero out the parameters before saving it -and it will gzip down to 780K: - for p in model.parameters(): - p.detach().zero_() +the model parameters separately you can zero out the parameters before saving it and +it will gzip down to 780K. """ -import glob +import gzip import os -import sys -from typing import List +import shutil +from inspect import signature import torch from model import ModelArgs, Transformer +# Model args +dim = 288 +n_layers = 6 +n_heads = 6 +n_kv_heads = n_heads +multiple_of = 32 +max_seq_len = 256 +dropout = 0.0 +vocab_size = 32000 +norm_eps = 1e-5 +# Save config +model_path = "model.pt" +zero_params = False +gzip_output = False + +exec(open("configurator.py").read()) + def main() -> None: - model = Transformer( - ModelArgs( - dim=288, - n_layers=6, - n_heads=6, - multiple_of=32, - dropout=0.0, - vocab_size=32000, - ) - ) - torch.jit.save(torch.jit.script(model), "model.pt") + model_args = {k: globals()[k] for k in signature(ModelArgs).parameters} + model = Transformer(ModelArgs(**model_args)) + + # If requested zero params before saving the model. This is usful in + # conjunction with gzip_output. + if zero_params: + for p in model.parameters(): + p.detach().zero_() + + torch.jit.save(torch.jit.script(model), model_path) + + if gzip_output: + with open(model_path, "rb") as f_in: + with gzip.open(f"{model_path}.gz", "wb") as f_out: + shutil.copyfileobj(f_in, f_out) + os.unlink(model_path) if __name__ == "__main__": From f4c96b7339c5efffd1d5f3b08cf947d9e7787a4b Mon Sep 17 00:00:00 2001 From: Michael Cusack Date: Fri, 4 Aug 2023 23:11:33 +0700 Subject: [PATCH 14/15] Add options to save_torchscript --- save_torchscript.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/save_torchscript.py b/save_torchscript.py index cb520d8..5b7f1ed 100755 --- a/save_torchscript.py +++ b/save_torchscript.py @@ -25,7 +25,7 @@ import torch from model import ModelArgs, Transformer -# Model args +# Model args config dim = 288 n_layers = 6 n_heads = 6 @@ -39,7 +39,7 @@ norm_eps = 1e-5 model_path = "model.pt" zero_params = False gzip_output = False - +# Allow config overrides exec(open("configurator.py").read()) From 13f342af9e937f53c32c96be201a5dbe1936e0c0 Mon Sep 17 00:00:00 2001 From: Michael Cusack Date: Fri, 4 Aug 2023 23:12:06 +0700 Subject: [PATCH 15/15] docs typo --- save_torchscript.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/save_torchscript.py b/save_torchscript.py index 5b7f1ed..af3a299 100755 --- a/save_torchscript.py +++ b/save_torchscript.py @@ -47,7 +47,7 @@ def main() -> None: model_args = {k: globals()[k] for k in signature(ModelArgs).parameters} model = Transformer(ModelArgs(**model_args)) - # If requested zero params before saving the model. This is usful in + # If requested zero params before saving the model. This is useful in # conjunction with gzip_output. if zero_params: for p in model.parameters():