diff --git a/model.py b/model.py index 6f7a43b..7788749 100644 --- a/model.py +++ b/model.py @@ -108,7 +108,6 @@ class Attention(nn.Module): # use flash attention or a manual implementation? self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') - if not self.flash: print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0") mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))