Trouble using flash attention on A100 #27

Open
dchaws opened this issue Dec 27, 2023 · 0 comments
dchaws commented Dec 27, 2023

When I enable flash attention on an A100 I get two errors. First, the inputs are not float16. Even after casting them to half precision, flash attention does not seem to support a non-empty mask.

PyTorch: 2.1.1+cu121
CUDA: 12.1
GPU: A100

```python
import torch
import soundstorm_pytorch

myattn = soundstorm_pytorch.attend.Attend(flash=True)

x = torch.randint(0, 1024, (1, 8, 1024, 64)).to('cuda')  # (batch, heads, seq, dim_head); randint yields int64

z = myattn(x, x, x)                                   # fails: inputs are not half precision
z = myattn(x.half(), x.half(), x.half())              # works

mask = torch.ones((1, 8, 1024, 1024)).to('cuda').bool()
z = myattn(x.half(), x.half(), x.half(), mask=mask)   # fails: no kernel accepts the mask
```

The error messages are:

```
/u/dhaws/TTS/soundstorm-pytorch/soundstorm_pytorch/attend.py:109: UserWarning: Memory efficient kernel not used because: (Triggered internally at ../aten/src/ATen/native/transformers/cuda/sdp_utils.cpp:367.)
  out = F.scaled_dot_product_attention(
/u/dhaws/TTS/soundstorm-pytorch/soundstorm_pytorch/attend.py:109: UserWarning: Memory Efficient attention has been runtime disabled. (Triggered internally at ../aten/src/ATen/native/transformers/sdp_utils_cpp.h:437.)
  out = F.scaled_dot_product_attention(
/u/dhaws/TTS/soundstorm-pytorch/soundstorm_pytorch/attend.py:109: UserWarning: Flash attention kernel not used because: (Triggered internally at ../aten/src/ATen/native/transformers/cuda/sdp_utils.cpp:369.)
  out = F.scaled_dot_product_attention(
/u/dhaws/TTS/soundstorm-pytorch/soundstorm_pytorch/attend.py:109: UserWarning: Expected query, key and value to all be of dtype: {Half, BFloat16}. Got Query dtype: long int, Key dtype: long int, and Value dtype: long int instead. (Triggered internally at ../aten/src/ATen/native/transformers/sdp_utils_cpp.h:92.)
  out = F.scaled_dot_product_attention(
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/speech7/dhaws6/ANACONDA3/envs/SOUNDSTORM_RVQGAN_v1/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/speech7/dhaws6/ANACONDA3/envs/SOUNDSTORM_RVQGAN_v1/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/u/dhaws/TTS/soundstorm-pytorch/soundstorm_pytorch/attend.py", line 135, in forward
    return self.flash_attn(q, k, v, mask = mask)
  File "/u/dhaws/TTS/soundstorm-pytorch/soundstorm_pytorch/attend.py", line 109, in flash_attn
    out = F.scaled_dot_product_attention(
```

and

```
/u/dhaws/TTS/soundstorm-pytorch/soundstorm_pytorch/attend.py:109: UserWarning: Both fused kernels do not support non-null attn_mask. (Triggered internally at ../aten/src/ATen/native/transformers/sdp_utils_cpp.h:261.)
  out = F.scaled_dot_product_attention(
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/speech7/dhaws6/ANACONDA3/envs/SOUNDSTORM_RVQGAN_v1/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/speech7/dhaws6/ANACONDA3/envs/SOUNDSTORM_RVQGAN_v1/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/u/dhaws/TTS/soundstorm-pytorch/soundstorm_pytorch/attend.py", line 135, in forward
    return self.flash_attn(q, k, v, mask = mask)
  File "/u/dhaws/TTS/soundstorm-pytorch/soundstorm_pytorch/attend.py", line 109, in flash_attn
    out = F.scaled_dot_product_attention(
RuntimeError: No available kernel. Aborting execution.
```
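
As a side note, the global SDPA backend toggles can be inspected directly. If these all print True, then I assume the "Memory Efficient attention has been runtime disabled" warning comes from an `sdp_kernel` context inside Attend's flash path rather than from a global setting:

```python
import torch

# Report the global runtime toggles for the fused SDPA backends.
print(torch.backends.cuda.flash_sdp_enabled())          # flash kernel
print(torch.backends.cuda.mem_efficient_sdp_enabled())  # memory-efficient kernel
print(torch.backends.cuda.math_sdp_enabled())           # math fallback
```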

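For comparison, here is a minimal standalone sketch (independent of soundstorm_pytorch, written from memory rather than a verified run) that calls `F.scaled_dot_product_attention` directly and uses the `torch.backends.cuda.sdp_kernel` context manager from PyTorch 2.1 to restrict the backends. My understanding is that the flash kernel handles the unmasked half-precision call, while a non-null `attn_mask` requires falling back to the memory-efficient or math backends:

```python
import torch
import torch.nn.functional as F
from torch.backends.cuda import sdp_kernel

# Same shapes as the repro above, but already in half precision.
q = k = v = torch.randn(1, 8, 1024, 64, device='cuda', dtype=torch.float16)
mask = torch.ones(1, 8, 1024, 1024, device='cuda', dtype=torch.bool)

# Flash only, no mask: should be accepted (fp16, head_dim 64, A100).
with sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    out = F.scaled_dot_product_attention(q, k, v)

# Flash rejects a non-null attn_mask, so allow the other backends for the masked call.
with sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=True):
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
```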