Merge branch 'mpcusack-mpcusack/jitsave'

This commit is contained in:
Andrej Karpathy
2023-08-05 18:13:07 +00:00
3 changed files with 84 additions and 10 deletions
+14 -8
View File
@@ -49,7 +49,7 @@ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
assert 0 <= 1 < ndim
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis.view(*shape)
return freqs_cis.view(shape)
def apply_rotary_emb(
xq: torch.Tensor,
@@ -59,8 +59,8 @@ def apply_rotary_emb(
) -> Tuple[torch.Tensor, torch.Tensor]:
# reshape xq and xk to match the complex representation
xq_r, xq_i = xq.float().reshape(*xq.shape[:-1], -1, 2).unbind(-1)
xk_r, xk_i = xk.float().reshape(*xk.shape[:-1], -1, 2).unbind(-1)
xq_r, xq_i = xq.float().reshape(xq.shape[:-1] + (-1, 2)).unbind(-1)
xk_r, xk_i = xk.float().reshape(xk.shape[:-1] + (-1, 2)).unbind(-1)
# reshape freqs_cos and freqs_sin for broadcasting
freqs_cos = reshape_for_broadcast(freqs_cos, xq_r)
@@ -142,10 +142,11 @@ class Attention(nn.Module):
# flash implementation
if self.flash:
output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True)
output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=None, dropout_p=self.dropout if self.training else 0.0, is_causal=True)
else:
# manual implementation
scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)
assert hasattr(self, 'mask')
scores = scores + self.mask[:, :, :seqlen, :seqlen] # (bs, n_local_heads, seqlen, cache_len + seqlen)
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
scores = self.attn_dropout(scores)
@@ -198,6 +199,8 @@ class TransformerBlock(nn.Module):
class Transformer(nn.Module):
last_loss: Optional[torch.Tensor]
def __init__(self, params: ModelArgs):
super().__init__()
self.params = params
@@ -227,6 +230,9 @@ class Transformer(nn.Module):
if pn.endswith('w3.weight') or pn.endswith('wo.weight'):
torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * params.n_layers))
# Initialize attribute for the loss of the last forward call. This will be set if the forward is called with a targets tensor.
self.last_loss = None
def _init_weights(self, module):
if isinstance(module, nn.Linear):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
@@ -235,7 +241,7 @@ class Transformer(nn.Module):
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
def forward(self, tokens, targets=None):
def forward(self, tokens: torch.Tensor, targets: Optional[torch.Tensor] = None) -> torch.Tensor:
_bsz, seqlen = tokens.shape
h = self.tok_embeddings(tokens)
h = self.dropout(h)
@@ -249,13 +255,13 @@ class Transformer(nn.Module):
if targets is not None:
# if we are given some desired targets also calculate the loss
logits = self.output(h)
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
self.last_loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
else:
# inference-time mini-optimization: only forward the output on the very last position
logits = self.output(h[:, [-1], :]) # note: using list [-1] to preserve the time dim
loss = None
self.last_loss = None
return logits, loss
return logits
def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
# start with all of the candidate parameters
+66
View File
@@ -0,0 +1,66 @@
#!/usr/bin/env python
"""Saves the model as a TorchScript.
Usage examples:
./save_torchscript.py
./save_torchscript.py --dim=300
./save_torchscript.py --gzip_output=True --zero_params=True
The resulting file can be loaded in C++ code and then used for training or
inference with:
#include <torch/script.h>
torch::jit::Module module = torch::jit::load("model.pt")
Note that the serialized model includes the initial parameters and with the default
ModelArgs the file is 59M and gzips down to 55M. If you want to serialize/distribute
the model parameters separately you can zero out the parameters before saving it and
it will gzip down to 780K.
"""
import gzip
import os
import shutil
from inspect import signature
import torch
from model import ModelArgs, Transformer
# Model args config
dim = 288
n_layers = 6
n_heads = 6
n_kv_heads = n_heads
multiple_of = 32
max_seq_len = 256
dropout = 0.0
vocab_size = 32000
norm_eps = 1e-5
# Save config
model_path = "model.pt"
zero_params = False
gzip_output = False
# Allow config overrides
exec(open("configurator.py").read())
def main() -> None:
model_args = {k: globals()[k] for k in signature(ModelArgs).parameters}
model = Transformer(ModelArgs(**model_args))
# If requested zero params before saving the model. This is useful in
# conjunction with gzip_output.
if zero_params:
for p in model.parameters():
p.detach().zero_()
torch.jit.save(torch.jit.script(model), model_path)
if gzip_output:
with open(model_path, "rb") as f_in:
with gzip.open(f"{model_path}.gz", "wb") as f_out:
shutil.copyfileobj(f_in, f_out)
os.unlink(model_path)
if __name__ == "__main__":
main()
+4 -2
View File
@@ -211,7 +211,8 @@ def estimate_loss():
for k in range(eval_iters):
X, Y = next(batch_iter)
with ctx:
logits, loss = model(X, Y)
logits = model(X, Y)
loss = model.last_loss
losses[k] = loss.item()
out[split] = losses.mean()
model.train()
@@ -294,7 +295,8 @@ while True:
# looking at the source of that context manager, it just toggles this variable
model.require_backward_grad_sync = micro_step == gradient_accumulation_steps - 1
with ctx:
logits, loss = model(X, Y)
logits = model(X, Y)
loss = model.last_loss
loss = loss / gradient_accumulation_steps
# immediately async prefetch next batch while model is doing the forward pass on the GPU
X, Y = next(train_batch_iter)