Run LitGPT benchmarking with a custom Attention implementation priority. #1714
Comments
@riccardofelluga, could you please take a look at what's going on with torch.compile when the cuDNN backend is used? An additional question for investigation is why it works for eager and not for torch.compile.
We've seen the fallback not working properly for inductor in cases like this:
We met with @riccardofelluga yesterday and were trying to figure out why it does not work. We have another implementation, done by subclassing the LitGPT attention:

import os
from contextlib import nullcontext

import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

import litgpt.model


class CausalSelfAttentionWithSDPABackend(litgpt.model.CausalSelfAttention):
    def scaled_dot_product_attention(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        def _resolve_sdpa_backend():
            """
            Resolves the SDPA backend based on the `MIXOLOGY_SDPA_BACKEND`
            variable. The allowed values are the same as in the
            `torch.nn.attention.SDPBackend` enum, plus the special value `ALL`
            to enable all the backends.
            """
            match os.getenv("MIXOLOGY_SDPA_BACKEND"):
                case "ERROR":
                    return SDPBackend.ERROR
                case "MATH":
                    return SDPBackend.MATH
                case "FLASH_ATTENTION":
                    return SDPBackend.FLASH_ATTENTION
                case "EFFICIENT_ATTENTION":
                    return SDPBackend.EFFICIENT_ATTENTION
                case "CUDNN_ATTENTION":
                    return SDPBackend.CUDNN_ATTENTION
                case "ALL":
                    return [
                        SDPBackend.CUDNN_ATTENTION,
                        SDPBackend.FLASH_ATTENTION,
                        SDPBackend.EFFICIENT_ATTENTION,
                        SDPBackend.MATH,
                    ]
                case _ as unknown:
                    # Unknown values fall through and return None.
                    print(f"Unknown `MIXOLOGY_SDPA_BACKEND`: {unknown}")

        # When attention-logit softcapping is enabled, keep PyTorch's default
        # backend selection instead of pinning the SDPA backend.
        with (
            nullcontext()
            if self.config.attention_logit_softcapping is not None
            else sdpa_kernel(_resolve_sdpa_backend(), set_priority=True)
        ):
            return super().scaled_dot_product_attention(q, k, v, mask)


if os.getenv("MIXOLOGY_SDPA_BACKEND") is not None:
    litgpt.model.CausalSelfAttention = CausalSelfAttentionWithSDPABackend

And this one seems to work. It uses an environment variable to tune the SDPA priority order. Should we then revert back the …
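For reference, a minimal sketch of driving the patch above: the environment variable has to be set in the launching process before the module containing the patch is imported (the exact way benchmark_litgpt.py is launched is not shown here):

import os

# Put cuDNN first in the SDPA priority order for the patched attention class.
# This must run before the patch module is imported, since the monkey-patch is
# applied at import time based on this variable.
os.environ["MIXOLOGY_SDPA_BACKEND"] = "CUDNN_ATTENTION"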
I think that my implementation (which Wojciech linked above) is more correct because it applies the context manager directly on the scaled_dot_product_attention call. The current implementation applies the context manager on the whole training loop, after loading and compiling the model, which might not work as expected. Please let me know @riccardofelluga if my reasoning is correct. Thanks @wprazuch for raising the issue!
@AdamRajfer Unfortunately I think the reasoning there is incorrect. Looking at the traces from inductor, we can see that in the case where the SDPA backend priorities were set using your patch, cuDNN is not actually used. By further checking the priority list, it seems that torch keeps the default one, while with #1715 the priority list is actually set to be the one we want, with cuDNN first. To check SDPA backend priorities you can query torch using …

I think the main question now is: why did the cuDNN check allow inductor to choose it as a backend instead of stopping it before? @wprazuch Let's keep this change and keep this issue open so we can investigate further. Meanwhile, for benchmarking, just note that inductor does not call cuDNN even with Adam's patch.
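As a side note, a minimal sketch for inspecting which SDPA backends are currently enabled; this reports the enable flags only, not the priority order discussed above (cudnn_sdp_enabled is only available in newer PyTorch versions):

import torch

# Whether each fused SDPA backend is enabled at all; this is separate from
# the backend priority order.
print("flash:", torch.backends.cuda.flash_sdp_enabled())
print("mem_efficient:", torch.backends.cuda.mem_efficient_sdp_enabled())
print("math:", torch.backends.cuda.math_sdp_enabled())
print("cudnn:", torch.backends.cuda.cudnn_sdp_enabled())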
@riccardofelluga Two approaches have been tested so far:

First approach: applying the sdpa_kernel context manager on the SDPA function itself (as in the subclass above).
Second approach: applying the sdpa_kernel context manager around the whole training benchmark (sketched below).
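A minimal sketch of the second approach, with hypothetical stand-ins for the model, optimizer, and data (the real loop lives in benchmark_litgpt.py):

import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

# Stand-ins for the real LitGPT model and data; a Linear layer does not itself
# call SDPA, the point is only where the context manager sits relative to the
# training loop.
model = torch.nn.Linear(16, 16).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
dataloader = [torch.randn(4, 16, device="cuda") for _ in range(3)]

# cuDNN-first priority order, as in the "ALL" branch of the patch above.
priority = [
    SDPBackend.CUDNN_ATTENTION,
    SDPBackend.FLASH_ATTENTION,
    SDPBackend.EFFICIENT_ATTENTION,
    SDPBackend.MATH,
]

# Second approach: the context manager wraps the whole (possibly compiled)
# training loop rather than the scaled_dot_product_attention call itself.
with sdpa_kernel(priority, set_priority=True):
    for batch in dataloader:
        loss = model(batch).sum()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()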
PyTorch has recently changed the priority order of the attention implementations. We would like to benchmark (through benchmark_litgpt.py) the best performance of each compilation backend, so that we have a clear picture of how Thunder and ThunderFX are doing.

There exists a context manager that can be used to select a given SDPA backend. For the sake of our script, we call it at one point in benchmark_litgpt.py.
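Roughly, a minimal sketch of selecting a single backend with that context manager (torch.nn.attention.sdpa_kernel); the toy tensors and the FLASH_ATTENTION choice are only illustrative, not necessarily what the script uses:

import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

# Toy tensors in (batch, heads, sequence, head_dim) layout.
q = k = v = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)

# Only the chosen backend is allowed for SDPA calls inside the block.
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v)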
My idea would be to wrap this call in the following manner:
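A sketch of that wrapping, consistent with the description (the cuDNN-first order and set_priority=True are assumptions about the intended snippet; toy tensors as above):

import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

q = k = v = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)

# With set_priority=True the list is a priority order rather than an
# allow-list: cuDNN is tried first, then flash, memory-efficient, and math.
with sdpa_kernel(
    [
        SDPBackend.CUDNN_ATTENTION,
        SDPBackend.FLASH_ATTENTION,
        SDPBackend.EFFICIENT_ATTENTION,
        SDPBackend.MATH,
    ],
    set_priority=True,
):
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v)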
Torch should then follow the order in the list in calling the sdp backends.
In theory this should work; however, the following happens.
🐛 Bug
When applying the above snippet, and running:
We receive the error:
Please note the error occurring at loss.backward(). What I think happens is that the forward call runs smoothly, but due to reshaping, the shapes of the expected tensors change and SDPA fails.

NOTE: The above issue does not happen for eager, so maybe this is just a torch.compile issue.

I would like to open a discussion and answer the following questions: with torch.compile and cuDNN SDPA, we could achieve higher performance than ThunderFX, which is not the case when executed with Flash Attention.

cc @crcrpar