add dropout support to model

This commit is contained in:
Andrej Karpathy
2023-07-24 14:18:50 +00:00
parent cdfb49208a
commit 624cdfc76a
+13 -3
View File
@@ -22,6 +22,7 @@ class ModelArgs:
multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2
norm_eps: float = 1e-5
max_seq_len: int = 2048
dropout: float = 0.0
class RMSNorm(torch.nn.Module):
@@ -90,6 +91,9 @@ class Attention(nn.Module):
self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
self.attn_dropout = nn.Dropout(args.dropout)
self.resid_dropout = nn.Dropout(args.dropout)
self.dropout = args.dropout
# use flash attention or a manual implementation?
self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
@@ -126,12 +130,13 @@ class Attention(nn.Module):
# flash implementation
if self.flash:
output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=None, dropout_p=0.0, is_causal=True)
output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True)
else:
# manual implementation
scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)
scores = scores + self.mask[:, :, :seqlen, :seqlen] # (bs, n_local_heads, seqlen, cache_len + seqlen)
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
scores = self.attn_dropout(scores)
output = torch.matmul(scores, xv) # (bs, n_local_heads, seqlen, head_dim)
# restore time as batch dimension and concat heads
@@ -139,20 +144,22 @@ class Attention(nn.Module):
# final projection into the residual stream
output = self.wo(output)
output = self.resid_dropout(output)
return output
class FeedForward(nn.Module):
def __init__(self, dim: int, hidden_dim: int, multiple_of: int):
def __init__(self, dim: int, hidden_dim: int, multiple_of: int, dropout: float):
super().__init__()
hidden_dim = int(2 * hidden_dim / 3)
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
return self.w2(F.silu(self.w1(x)) * self.w3(x))
return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))
class TransformerBlock(nn.Module):
@@ -166,6 +173,7 @@ class TransformerBlock(nn.Module):
dim=args.dim,
hidden_dim=4 * args.dim,
multiple_of=args.multiple_of,
dropout=args.dropout,
)
self.layer_id = layer_id
self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
@@ -185,6 +193,7 @@ class Transformer(nn.Module):
self.n_layers = params.n_layers
self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
self.dropout = nn.Dropout(params.dropout)
self.layers = torch.nn.ModuleList()
for layer_id in range(params.n_layers):
self.layers.append(TransformerBlock(layer_id, params))
@@ -216,6 +225,7 @@ class Transformer(nn.Module):
def forward(self, tokens, targets=None):
_bsz, seqlen = tokens.shape
h = self.tok_embeddings(tokens)
h = self.dropout(h)
freqs_cis = self.freqs_cis[:seqlen]
for layer in self.layers: