When I enable flash attention on an A100 I hit two errors. First, the fused kernels reject the input because it is not float16 (or bfloat16). Second, even after casting, flash attention does not appear to support a non-null attention mask.
PyTorch: 2.1.1+cu121
CUDA: 12.1
GPU: A100
```python
import torch
import soundstorm_pytorch

myattn = soundstorm_pytorch.attend.Attend(flash = True)

# q/k/v shaped (batch, heads, seq, dim_head); note torch.randint returns int64
x = torch.randint(0, 1024, (1, 8, 1024, 64)).to('cuda')

z = myattn(x, x, x)                                     # Fails: q/k/v are int64
z = myattn(x.half(), x.half(), x.half())                # Works after casting to float16

mask = torch.ones((1, 8, 1024, 1024)).to('cuda').bool()
z = myattn(x.half(), x.half(), x.half(), mask = mask)   # Fails: non-null mask
```
The error messages are:
```
/u/dhaws/TTS/soundstorm-pytorch/soundstorm_pytorch/attend.py:109: UserWarning: Memory efficient kernel not used because: (Triggered internally at ../aten/src/ATen/native/transformers/cuda/sdp_utils.cpp:367.)
  out = F.scaled_dot_product_attention(
/u/dhaws/TTS/soundstorm-pytorch/soundstorm_pytorch/attend.py:109: UserWarning: Memory Efficient attention has been runtime disabled. (Triggered internally at ../aten/src/ATen/native/transformers/sdp_utils_cpp.h:437.)
  out = F.scaled_dot_product_attention(
/u/dhaws/TTS/soundstorm-pytorch/soundstorm_pytorch/attend.py:109: UserWarning: Flash attention kernel not used because: (Triggered internally at ../aten/src/ATen/native/transformers/cuda/sdp_utils.cpp:369.)
  out = F.scaled_dot_product_attention(
/u/dhaws/TTS/soundstorm-pytorch/soundstorm_pytorch/attend.py:109: UserWarning: Expected query, key and value to all be of dtype: {Half, BFloat16}. Got Query dtype: long int, Key dtype: long int, and Value dtype: long int instead. (Triggered internally at ../aten/src/ATen/native/transformers/sdp_utils_cpp.h:92.)
  out = F.scaled_dot_product_attention(
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/speech7/dhaws6/ANACONDA3/envs/SOUNDSTORM_RVQGAN_v1/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/speech7/dhaws6/ANACONDA3/envs/SOUNDSTORM_RVQGAN_v1/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/u/dhaws/TTS/soundstorm-pytorch/soundstorm_pytorch/attend.py", line 135, in forward
    return self.flash_attn(q, k, v, mask = mask)
  File "/u/dhaws/TTS/soundstorm-pytorch/soundstorm_pytorch/attend.py", line 109, in flash_attn
    out = F.scaled_dot_product_attention(
```
and
```
/u/dhaws/TTS/soundstorm-pytorch/soundstorm_pytorch/attend.py:109: UserWarning: Both fused kernels do not support non-null attn_mask. (Triggered internally at ../aten/src/ATen/native/transformers/sdp_utils_cpp.h:261.)
  out = F.scaled_dot_product_attention(
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/speech7/dhaws6/ANACONDA3/envs/SOUNDSTORM_RVQGAN_v1/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/speech7/dhaws6/ANACONDA3/envs/SOUNDSTORM_RVQGAN_v1/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/u/dhaws/TTS/soundstorm-pytorch/soundstorm_pytorch/attend.py", line 135, in forward
    return self.flash_attn(q, k, v, mask = mask)
  File "/u/dhaws/TTS/soundstorm-pytorch/soundstorm_pytorch/attend.py", line 109, in flash_attn
    out = F.scaled_dot_product_attention(
RuntimeError: No available kernel. Aborting execution.
```
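For what it's worth, the second failure looks like a kernel-selection dead end rather than a problem with the mask itself: in PyTorch 2.1 the flash kernel only handles `is_causal`, not an arbitrary `attn_mask`, and the warnings above show memory-efficient attention was runtime-disabled, so no kernel remains that can take the mask. Here is a minimal sketch that reproduces this outside soundstorm_pytorch, assuming `Attend` restricts backends via `torch.backends.cuda.sdp_kernel` (my guess at what it does internally):

```python
import torch
import torch.nn.functional as F

q = k = v = torch.randn(1, 8, 1024, 64, device = 'cuda', dtype = torch.float16)
mask = torch.ones(1, 8, 1024, 1024, device = 'cuda', dtype = torch.bool)

# Flash kernel only: a non-null attn_mask leaves no usable kernel,
# reproducing "RuntimeError: No available kernel."
try:
    with torch.backends.cuda.sdp_kernel(enable_flash = True, enable_math = False, enable_mem_efficient = False):
        F.scaled_dot_product_attention(q, k, v, attn_mask = mask)
except RuntimeError as e:
    print('flash-only with mask:', e)

# Re-enabling the math fallback makes the same masked call succeed
# (it accepts any attn_mask), just without the fused-kernel speedup.
with torch.backends.cuda.sdp_kernel(enable_flash = True, enable_math = True, enable_mem_efficient = False):
    out = F.scaled_dot_product_attention(q, k, v, attn_mask = mask)
print(out.shape)  # torch.Size([1, 8, 1024, 64])
```

So with a non-null mask I'd only expect the math (or possibly the memory-efficient) backend to run, not flash itself.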