Merge branch 'mpcusack-mpcusack/jitsave'
This commit is contained in:
@@ -49,7 +49,7 @@ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
|
||||
assert 0 <= 1 < ndim
|
||||
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
|
||||
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
|
||||
return freqs_cis.view(*shape)
|
||||
return freqs_cis.view(shape)
|
||||
|
||||
def apply_rotary_emb(
|
||||
xq: torch.Tensor,
|
||||
@@ -59,8 +59,8 @@ def apply_rotary_emb(
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
|
||||
# reshape xq and xk to match the complex representation
|
||||
xq_r, xq_i = xq.float().reshape(*xq.shape[:-1], -1, 2).unbind(-1)
|
||||
xk_r, xk_i = xk.float().reshape(*xk.shape[:-1], -1, 2).unbind(-1)
|
||||
xq_r, xq_i = xq.float().reshape(xq.shape[:-1] + (-1, 2)).unbind(-1)
|
||||
xk_r, xk_i = xk.float().reshape(xk.shape[:-1] + (-1, 2)).unbind(-1)
|
||||
|
||||
# reshape freqs_cos and freqs_sin for broadcasting
|
||||
freqs_cos = reshape_for_broadcast(freqs_cos, xq_r)
|
||||
@@ -142,10 +142,11 @@ class Attention(nn.Module):
|
||||
|
||||
# flash implementation
|
||||
if self.flash:
|
||||
output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True)
|
||||
output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=None, dropout_p=self.dropout if self.training else 0.0, is_causal=True)
|
||||
else:
|
||||
# manual implementation
|
||||
scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||
assert hasattr(self, 'mask')
|
||||
scores = scores + self.mask[:, :, :seqlen, :seqlen] # (bs, n_local_heads, seqlen, cache_len + seqlen)
|
||||
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
|
||||
scores = self.attn_dropout(scores)
|
||||
@@ -198,6 +199,8 @@ class TransformerBlock(nn.Module):
|
||||
|
||||
|
||||
class Transformer(nn.Module):
|
||||
last_loss: Optional[torch.Tensor]
|
||||
|
||||
def __init__(self, params: ModelArgs):
|
||||
super().__init__()
|
||||
self.params = params
|
||||
@@ -227,6 +230,9 @@ class Transformer(nn.Module):
|
||||
if pn.endswith('w3.weight') or pn.endswith('wo.weight'):
|
||||
torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * params.n_layers))
|
||||
|
||||
# Initialize attribute for the loss of the last forward call. This will be set if the forward is called with a targets tensor.
|
||||
self.last_loss = None
|
||||
|
||||
def _init_weights(self, module):
|
||||
if isinstance(module, nn.Linear):
|
||||
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
||||
@@ -235,7 +241,7 @@ class Transformer(nn.Module):
|
||||
elif isinstance(module, nn.Embedding):
|
||||
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
||||
|
||||
def forward(self, tokens, targets=None):
|
||||
def forward(self, tokens: torch.Tensor, targets: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
_bsz, seqlen = tokens.shape
|
||||
h = self.tok_embeddings(tokens)
|
||||
h = self.dropout(h)
|
||||
@@ -249,13 +255,13 @@ class Transformer(nn.Module):
|
||||
if targets is not None:
|
||||
# if we are given some desired targets also calculate the loss
|
||||
logits = self.output(h)
|
||||
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
|
||||
self.last_loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
|
||||
else:
|
||||
# inference-time mini-optimization: only forward the output on the very last position
|
||||
logits = self.output(h[:, [-1], :]) # note: using list [-1] to preserve the time dim
|
||||
loss = None
|
||||
self.last_loss = None
|
||||
|
||||
return logits, loss
|
||||
return logits
|
||||
|
||||
def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
|
||||
# start with all of the candidate parameters
|
||||
|
||||
Executable
+66
@@ -0,0 +1,66 @@
|
||||
#!/usr/bin/env python
|
||||
"""Saves the model as a TorchScript.
|
||||
|
||||
Usage examples:
|
||||
./save_torchscript.py
|
||||
./save_torchscript.py --dim=300
|
||||
./save_torchscript.py --gzip_output=True --zero_params=True
|
||||
|
||||
The resulting file can be loaded in C++ code and then used for training or
|
||||
inference with:
|
||||
#include <torch/script.h>
|
||||
torch::jit::Module module = torch::jit::load("model.pt");
|
||||
|
||||
Note that the serialized model includes the initial parameters and with the default
|
||||
ModelArgs the file is 59M and gzips down to 55M. If you want to serialize/distribute
|
||||
the model parameters separately you can zero out the parameters before saving it and
|
||||
it will gzip down to 780K.
|
||||
"""
|
||||
import gzip
|
||||
import os
|
||||
import shutil
|
||||
from inspect import signature
|
||||
|
||||
import torch
|
||||
|
||||
from model import ModelArgs, Transformer
|
||||
|
||||
# Model args config.
# main() looks each ModelArgs constructor parameter up by name in this
# module's globals(), so these names have to match ModelArgs' parameters.
dim = 288
n_layers = 6
n_heads = 6
n_kv_heads = n_heads
multiple_of = 32
max_seq_len = 256
dropout = 0.0
vocab_size = 32000
norm_eps = 1e-5

# Save config.
model_path = "model.pt"  # path the TorchScript archive is written to
zero_params = False  # zero every model parameter before saving
gzip_output = False  # write "<model_path>.gz" and remove the uncompressed file

# Allow config overrides.
# configurator.py is executed in this module's global namespace, so it can
# rebind any of the names above (presumably from command-line flags such as
# --dim=300, as in the usage examples of the module docstring -- confirm
# against configurator.py).
# NOTE(review): exec() runs whatever code configurator.py contains; this is
# only acceptable because that file is a trusted part of the project.
# The file is read inside a `with` block so the handle is closed promptly
# instead of being left for the garbage collector.
with open("configurator.py") as _configurator_file:
    exec(_configurator_file.read())
|
||||
|
||||
|
||||
def main() -> None:
    """Build a Transformer from the module-level config and save it as TorchScript.

    Every constructor parameter of ModelArgs is looked up by name in this
    module's globals, so the configuration values defined at module level
    (possibly overridden by configurator.py) determine the model that is built.

    When ``zero_params`` is true, all parameters are zeroed in place before
    saving. When ``gzip_output`` is true, the saved file is compressed to
    ``f"{model_path}.gz"`` and the uncompressed file is deleted.
    """
    module_config = globals()
    arg_names = signature(ModelArgs).parameters
    model = Transformer(ModelArgs(**{name: module_config[name] for name in arg_names}))

    # Zeroed parameters compress far better, which is why this option is
    # meant to be combined with gzip_output.
    if zero_params:
        for param in model.parameters():
            param.detach().zero_()

    scripted_model = torch.jit.script(model)
    torch.jit.save(scripted_model, model_path)

    if not gzip_output:
        return

    # Stream the saved archive into its gzip-compressed sibling, then drop
    # the uncompressed original so only the .gz file remains.
    with open(model_path, "rb") as raw_file, gzip.open(f"{model_path}.gz", "wb") as packed_file:
        shutil.copyfileobj(raw_file, packed_file)
    os.unlink(model_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -211,7 +211,8 @@ def estimate_loss():
|
||||
for k in range(eval_iters):
|
||||
X, Y = next(batch_iter)
|
||||
with ctx:
|
||||
logits, loss = model(X, Y)
|
||||
logits = model(X, Y)
|
||||
loss = model.last_loss
|
||||
losses[k] = loss.item()
|
||||
out[split] = losses.mean()
|
||||
model.train()
|
||||
@@ -294,7 +295,8 @@ while True:
|
||||
# looking at the source of that context manager, it just toggles this variable
|
||||
model.require_backward_grad_sync = micro_step == gradient_accumulation_steps - 1
|
||||
with ctx:
|
||||
logits, loss = model(X, Y)
|
||||
logits = model(X, Y)
|
||||
loss = model.last_loss
|
||||
loss = loss / gradient_accumulation_steps
|
||||
# immediately async prefetch next batch while model is doing the forward pass on the GPU
|
||||
X, Y = next(train_batch_iter)
|
||||
|
||||
Reference in New Issue
Block a user