diff --git a/model.py b/model.py index 8d76d5a..9d04fea 100644 --- a/model.py +++ b/model.py @@ -22,6 +22,7 @@ class ModelArgs: multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2 norm_eps: float = 1e-5 max_seq_len: int = 2048 + dropout: float = 0.0 class RMSNorm(torch.nn.Module): @@ -90,6 +91,9 @@ class Attention(nn.Module): self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False) self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False) self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False) + self.attn_dropout = nn.Dropout(args.dropout) + self.resid_dropout = nn.Dropout(args.dropout) + self.dropout = args.dropout # use flash attention or a manual implementation? self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') @@ -126,12 +130,13 @@ class Attention(nn.Module): # flash implementation if self.flash: - output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=None, dropout_p=0.0, is_causal=True) + output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True) else: # manual implementation scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim) scores = scores + self.mask[:, :, :seqlen, :seqlen] # (bs, n_local_heads, seqlen, cache_len + seqlen) scores = F.softmax(scores.float(), dim=-1).type_as(xq) + scores = self.attn_dropout(scores) output = torch.matmul(scores, xv) # (bs, n_local_heads, seqlen, head_dim) # restore time as batch dimension and concat heads @@ -139,20 +144,22 @@ class Attention(nn.Module): # final projection into the residual stream output = self.wo(output) + output = self.resid_dropout(output) return output class FeedForward(nn.Module): - def __init__(self, dim: int, hidden_dim: int, multiple_of: int): + def __init__(self, dim: int, hidden_dim: int, multiple_of: int, dropout: float): super().__init__() hidden_dim = int(2 * hidden_dim / 3) hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) self.w1 = nn.Linear(dim, hidden_dim, bias=False) self.w2 = nn.Linear(hidden_dim, dim, bias=False) self.w3 = nn.Linear(dim, hidden_dim, bias=False) + self.dropout = nn.Dropout(dropout) def forward(self, x): - return self.w2(F.silu(self.w1(x)) * self.w3(x)) + return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x))) class TransformerBlock(nn.Module): @@ -166,6 +173,7 @@ class TransformerBlock(nn.Module): dim=args.dim, hidden_dim=4 * args.dim, multiple_of=args.multiple_of, + dropout=args.dropout, ) self.layer_id = layer_id self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) @@ -185,6 +193,7 @@ class Transformer(nn.Module): self.n_layers = params.n_layers self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim) + self.dropout = nn.Dropout(params.dropout) self.layers = torch.nn.ModuleList() for layer_id in range(params.n_layers): self.layers.append(TransformerBlock(layer_id, params)) @@ -216,6 +225,7 @@ class Transformer(nn.Module): def forward(self, tokens, targets=None): _bsz, seqlen = tokens.shape h = self.tok_embeddings(tokens) + h = self.dropout(h) freqs_cis = self.freqs_cis[:seqlen] for layer in self.layers: