Resolve jit.save errors

This commit is contained in:
Michael Cusack
2023-08-04 16:49:26 +07:00
parent af8708d87b
commit f2e34e6b0a
2 changed files with 39 additions and 8 deletions
+18 -8
View File
@@ -49,7 +49,7 @@ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
assert 0 <= 1 < ndim
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis.view(*shape)
return freqs_cis.view(shape)
def apply_rotary_emb(
xq: torch.Tensor,
@@ -59,8 +59,8 @@ def apply_rotary_emb(
) -> Tuple[torch.Tensor, torch.Tensor]:
# reshape xq and xk to match the complex representation
xq_r, xq_i = xq.float().reshape(*xq.shape[:-1], -1, 2).unbind(-1)
xk_r, xk_i = xk.float().reshape(*xk.shape[:-1], -1, 2).unbind(-1)
xq_r, xq_i = xq.float().reshape(xq.shape[:-1] + (-1, 2)).unbind(-1)
xk_r, xk_i = xk.float().reshape(xk.shape[:-1] + (-1, 2)).unbind(-1)
# reshape freqs_cos and freqs_sin for broadcasting
freqs_cos = reshape_for_broadcast(freqs_cos, xq_r)
@@ -108,6 +108,7 @@ class Attention(nn.Module):
# use flash attention or a manual implementation?
self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
if not self.flash:
print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
@@ -142,10 +143,11 @@ class Attention(nn.Module):
# flash implementation
if self.flash:
output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True)
output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=None, dropout_p=self.dropout if self.training else 0.0, is_causal=True)
else:
# manual implementation
scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)
assert hasattr(self, 'mask')
scores = scores + self.mask[:, :, :seqlen, :seqlen] # (bs, n_local_heads, seqlen, cache_len + seqlen)
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
scores = self.attn_dropout(scores)
@@ -198,6 +200,8 @@ class TransformerBlock(nn.Module):
class Transformer(nn.Module):
last_loss: Optional[torch.Tensor]
def __init__(self, params: ModelArgs):
super().__init__()
self.params = params
@@ -227,6 +231,9 @@ class Transformer(nn.Module):
if pn.endswith('w3.weight') or pn.endswith('wo.weight'):
torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * params.n_layers))
# Initialize attribute for the loss of the last forward call. This will be set if the forward is called with a targets tensor.
self.last_loss = None
def _init_weights(self, module):
if isinstance(module, nn.Linear):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
@@ -235,7 +242,7 @@ class Transformer(nn.Module):
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
def forward(self, tokens, targets=None):
def forward(self, tokens: torch.Tensor, targets: Optional[torch.Tensor] = None) -> torch.Tensor:
_bsz, seqlen = tokens.shape
h = self.tok_embeddings(tokens)
h = self.dropout(h)
@@ -249,13 +256,16 @@ class Transformer(nn.Module):
if targets is not None:
# if we are given some desired targets also calculate the loss
logits = self.output(h)
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
self.last_loss = self.calculate_loss(logits, targets)
else:
# inference-time mini-optimization: only forward the output on the very last position
logits = self.output(h[:, [-1], :]) # note: using list [-1] to preserve the time dim
loss = None
self.last_loss = None
return logits, loss
return logits
def calculate_loss(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
return F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
# start with all of the candidate parameters
+21
View File
@@ -0,0 +1,21 @@
#!/usr/bin/env python3
#!/usr/bin/env python
"""Saves the model as a TorchScript."""
import glob
import os
import sys
from typing import List
import torch
from model import ModelArgs, Transformer
def main() -> None:
model_args = ModelArgs(dim=512, n_layers=6, n_heads=8, vocab_size=32000)
model = Transformer(model_args)
torch.jit.save(torch.jit.script(model), "model.pt")
if __name__ == "__main__":
main()