add dropout support to model

This commit is contained in:
Andrej Karpathy
2023-07-24 14:18:50 +00:00
parent cdfb49208a
commit 624cdfc76a
+13 -3
View File
@@ -22,6 +22,7 @@ class ModelArgs:
multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2 multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2
norm_eps: float = 1e-5 norm_eps: float = 1e-5
max_seq_len: int = 2048 max_seq_len: int = 2048
dropout: float = 0.0
class RMSNorm(torch.nn.Module): class RMSNorm(torch.nn.Module):
@@ -90,6 +91,9 @@ class Attention(nn.Module):
self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False) self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False) self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False) self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
self.attn_dropout = nn.Dropout(args.dropout)
self.resid_dropout = nn.Dropout(args.dropout)
self.dropout = args.dropout
# use flash attention or a manual implementation? # use flash attention or a manual implementation?
self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
@@ -126,12 +130,13 @@ class Attention(nn.Module):
# flash implementation # flash implementation
if self.flash: if self.flash:
output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=None, dropout_p=0.0, is_causal=True) output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True)
else: else:
# manual implementation # manual implementation
scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim) scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)
scores = scores + self.mask[:, :, :seqlen, :seqlen] # (bs, n_local_heads, seqlen, cache_len + seqlen) scores = scores + self.mask[:, :, :seqlen, :seqlen] # (bs, n_local_heads, seqlen, cache_len + seqlen)
scores = F.softmax(scores.float(), dim=-1).type_as(xq) scores = F.softmax(scores.float(), dim=-1).type_as(xq)
scores = self.attn_dropout(scores)
output = torch.matmul(scores, xv) # (bs, n_local_heads, seqlen, head_dim) output = torch.matmul(scores, xv) # (bs, n_local_heads, seqlen, head_dim)
# restore time as batch dimension and concat heads # restore time as batch dimension and concat heads
@@ -139,20 +144,22 @@ class Attention(nn.Module):
# final projection into the residual stream # final projection into the residual stream
output = self.wo(output) output = self.wo(output)
output = self.resid_dropout(output)
return output return output
class FeedForward(nn.Module): class FeedForward(nn.Module):
def __init__(self, dim: int, hidden_dim: int, multiple_of: int): def __init__(self, dim: int, hidden_dim: int, multiple_of: int, dropout: float):
super().__init__() super().__init__()
hidden_dim = int(2 * hidden_dim / 3) hidden_dim = int(2 * hidden_dim / 3)
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
self.w1 = nn.Linear(dim, hidden_dim, bias=False) self.w1 = nn.Linear(dim, hidden_dim, bias=False)
self.w2 = nn.Linear(hidden_dim, dim, bias=False) self.w2 = nn.Linear(hidden_dim, dim, bias=False)
self.w3 = nn.Linear(dim, hidden_dim, bias=False) self.w3 = nn.Linear(dim, hidden_dim, bias=False)
self.dropout = nn.Dropout(dropout)
def forward(self, x): def forward(self, x):
return self.w2(F.silu(self.w1(x)) * self.w3(x)) return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))
class TransformerBlock(nn.Module): class TransformerBlock(nn.Module):
@@ -166,6 +173,7 @@ class TransformerBlock(nn.Module):
dim=args.dim, dim=args.dim,
hidden_dim=4 * args.dim, hidden_dim=4 * args.dim,
multiple_of=args.multiple_of, multiple_of=args.multiple_of,
dropout=args.dropout,
) )
self.layer_id = layer_id self.layer_id = layer_id
self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
@@ -185,6 +193,7 @@ class Transformer(nn.Module):
self.n_layers = params.n_layers self.n_layers = params.n_layers
self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim) self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
self.dropout = nn.Dropout(params.dropout)
self.layers = torch.nn.ModuleList() self.layers = torch.nn.ModuleList()
for layer_id in range(params.n_layers): for layer_id in range(params.n_layers):
self.layers.append(TransformerBlock(layer_id, params)) self.layers.append(TransformerBlock(layer_id, params))
@@ -216,6 +225,7 @@ class Transformer(nn.Module):
def forward(self, tokens, targets=None): def forward(self, tokens, targets=None):
_bsz, seqlen = tokens.shape _bsz, seqlen = tokens.shape
h = self.tok_embeddings(tokens) h = self.tok_embeddings(tokens)
h = self.dropout(h)
freqs_cis = self.freqs_cis[:seqlen] freqs_cis = self.freqs_cis[:seqlen]
for layer in self.layers: for layer in self.layers: