extra line
This commit is contained in:
@@ -108,7 +108,6 @@ class Attention(nn.Module):
|
|||||||
|
|
||||||
# use flash attention or a manual implementation?
|
# use flash attention or a manual implementation?
|
||||||
self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
|
self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
|
||||||
|
|
||||||
if not self.flash:
|
if not self.flash:
|
||||||
print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
|
print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
|
||||||
mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
|
mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
|
||||||
|
|||||||
Reference in New Issue
Block a user