
Run LitGPT benchmarking with a custom Attention implementation priority. #1714

Open
wprazuch opened this issue Jan 29, 2025 · 5 comments · Fixed by #1715
@wprazuch
Contributor

wprazuch commented Jan 29, 2025

PyTorch has recently changed the priority order of the attention implementations. We would like to benchmark (through benchmark_litgpt.py) the best performance of each compilation backend, to ensure we have a clear picture of how Thunder and ThunderFX are doing.

There is a context manager that can be used to select a given SDPA backend. For the purposes of our script, we call:

benchmark.train()

at one point in benchmark_litgpt.py

My idea would be to wrap this call in the following manner:

from torch.nn.attention import SDPBackend, sdpa_kernel
...
with sdpa_kernel([SDPBackend.CUDNN_ATTENTION, SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH], set_priority=True):
    benchmark.train()

Torch should then follow the order of the list when choosing an SDPA backend.

In theory this should work; however, the following happens.

🐛 Bug

When applying the above snippet and running:

python thunder/benchmarks/benchmark_litgpt.py --max_iters 10 --warmup_iters 5 --model_name Phi-3-mini-4k-instruct --compile inductor

We receive the error:

benchmark_litgpt.py", line 765, in train
    loss.backward()
  File "/usr/local/lib/python3.12/dist-packages/torch/_tensor.py", line 648, in backward
    torch.autograd.backward(
  File "/usr/local/lib/python3.12/dist-packages/torch/autograd/__init__.py", line 347, in backward
    _engine_run_backward(
  File "/usr/local/lib/python3.12/dist-packages/torch/autograd/graph.py", line 823, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/autograd/function.py", line 307, in apply
    return user_fn(self, *args)
           ^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 1958, in backward
    return impl_fn()
           ^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 1944, in impl_fn
    out = CompiledFunction._backward_impl(ctx, all_args)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 2079, in _backward_impl
    out = call_func_at_runtime_with_args(
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_functorch/_aot_autograd/utils.py", line 126, in call_func_at_runtime_with_args
    out = normalize_as_list(f(args))
                            ^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 755, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_inductor/output_code.py", line 464, in __call__
    return self.current_callable(inputs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_inductor/utils.py", line 2203, in run
    return model(new_inputs)
           ^^^^^^^^^^^^^^^^^
  File "/tmp/torchinductor/cx/ccxhth63u7h6hk5gqnevy66wnvou4lzluuzelpihkdaczuns2cn3.py", line 1156, in call
AssertionError: expected size 32==32, stride 96==393216 at dim=1; expected size 4096==4096, stride 3072==96 at dim=2

Please note that the error occurs at loss.backward(). What I think happens is that the forward call runs smoothly, but due to reshaping the shapes of the expected tensors change, and SDPA fails.

NOTE: The above issue does not happen for eager, so maybe this is just a torch.compile issue.

I would like to open a discussion and answer the following questions:

  1. Is this change necessary? In my opinion it is, because we want to benchmark the best performance of each backend, and that performance is usually achieved through cuDNN SDPA. When running torch.compile with cuDNN SDPA, we could achieve higher performance than ThunderFX, which is not the case when running with Flash Attention.
  2. Is the implementation correct or should it be applied in a different way?
  3. Is the loss reshaping really the problem here?

cc @crcrpar

@IvanYashchuk
Collaborator

@riccardofelluga, could you please take a look at what's going on with torch.compile when the cuDNN backend is used? An additional question for investigation is why it works for eager but not for torch.compile.

@wprazuch
Contributor Author

wprazuch commented Feb 12, 2025

We've seen the fallback not working properly for inductor in cases like this:

python thunder/benchmarks/benchmark_litgpt.py --max_iters 10 --warmup_iters 5 --model_name Phi-3-mini-4k-instruct --compile inductor --use_sdpa True

We met with @riccardofelluga yesterday to try to figure out why it does not work. We have another implementation, in thunder/tests/litgpt_model.py, authored by @AdamRajfer:

# (imports needed by the snippet below)
import os
from contextlib import nullcontext

import litgpt.model
import torch
from torch.nn.attention import SDPBackend, sdpa_kernel


class CausalSelfAttentionWithSDPABackend(litgpt.model.CausalSelfAttention):
    def scaled_dot_product_attention(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        def _resolve_sdpa_backend():
            """
            Resolves the SDPA backend based on the `MIXOLOGY_SDPA_BACKEND`
             variable. The allowed values are the same as in the
             `torch.nn.attention.SDPBackend` enum, plus the special value `ALL`
             to enable all the backends.
            """
            match os.getenv("MIXOLOGY_SDPA_BACKEND"):
                case "ERROR":
                    return SDPBackend.ERROR
                case "MATH":
                    return SDPBackend.MATH
                case "FLASH_ATTENTION":
                    return SDPBackend.FLASH_ATTENTION
                case "EFFICIENT_ATTENTION":
                    return SDPBackend.EFFICIENT_ATTENTION
                case "CUDNN_ATTENTION":
                    return SDPBackend.CUDNN_ATTENTION
                case "ALL":
                    return [
                        SDPBackend.CUDNN_ATTENTION,
                        SDPBackend.FLASH_ATTENTION,
                        SDPBackend.EFFICIENT_ATTENTION,
                        SDPBackend.MATH,
                    ]
                case _ as unknown:
                    print(f"Unknown `MIXOLOGY_SDPA_BACKEND`: {unknown}")

        with (
            nullcontext()
            if self.config.attention_logit_softcapping is not None
            else sdpa_kernel(_resolve_sdpa_backend(), set_priority=True)
        ):
            return super().scaled_dot_product_attention(q, k, v, mask)


# Monkey-patch LitGPT so models built afterwards use the subclass above.
if os.getenv("MIXOLOGY_SDPA_BACKEND") is not None:
    litgpt.model.CausalSelfAttention = CausalSelfAttentionWithSDPABackend

And this one seems to work. It uses an environment variable to set the SDPA priority order.
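For reference, a hypothetical invocation of the environment-variable switch might look like the line below; this assumes the patched thunder/tests/litgpt_model.py is imported by the benchmark script, and the variable name and the ALL value come from the snippet above:

MIXOLOGY_SDPA_BACKEND=ALL python thunder/benchmarks/benchmark_litgpt.py --max_iters 10 --warmup_iters 5 --model_name Phi-3-mini-4k-instruct --compile inductor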

Should we then revert the use_sdpa change and integrate this one instead? I think we should, but maybe we could do it together with a fix, or with a new way to integrate those fallbacks?

@AdamRajfer
Copy link

I think that my implementation (which Wojciech linked above) is more correct, because it applies the context manager directly to the scaled_dot_product_attention function, which is consistent with the official PyTorch documentation on how to use it (see https://pytorch.org/docs/stable/generated/torch.nn.attention.sdpa_kernel.html).
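For context, here is a minimal sketch of the usage pattern from the linked documentation; the tensor shapes and dtype are placeholders for illustration, and it assumes a CUDA device where the cuDNN backend is available:

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# Placeholder tensors; in LitGPT the real q/k/v come from the attention module.
q = k = v = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)

# The context manager wraps the scaled_dot_product_attention call itself,
# restricting dispatch to the listed backend(s) for this call only.
with sdpa_kernel(SDPBackend.CUDNN_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v)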

The current implementation applies the context manager to the whole training loop, after loading and compiling the model, which might not work as expected. Please let me know, @riccardofelluga, if my reasoning is correct.

Thanks @wprazuch for raising the issue!

@riccardofelluga
Collaborator

The current implementation applies the context manager to the whole training loop, after loading and compiling the model, which might not work as expected. Please let me know, @riccardofelluga, if my reasoning is correct.

@AdamRajfer Unfortunately, I think the reasoning there is incorrect. Looking at the traces from Inductor, we can see that in the case where the SDPA backend priorities were set using your patch, cuDNN is not actually used. Further checking the priority list, it seems that torch keeps the default one, while with #1715 the priority list is actually set to the one we want, with cuDNN first.

To check the SDPA backend priorities you can query torch using torch._C._get_sdp_priority_order(); we want to see [3, 1, 2, 0, 0] instead of the default [1, 2, 0, 3, 0].
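For example, a quick check could look like the sketch below; it relies on the private torch._C._get_sdp_priority_order() helper mentioned above and on sdpa_kernel accepting set_priority, and the exact integer values may differ between PyTorch versions:

import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

print(torch._C._get_sdp_priority_order())  # default priority order

with sdpa_kernel(
    [
        SDPBackend.CUDNN_ATTENTION,
        SDPBackend.FLASH_ATTENTION,
        SDPBackend.EFFICIENT_ATTENTION,
        SDPBackend.MATH,
    ],
    set_priority=True,
):
    # With the override active we expect cuDNN first, e.g. [3, 1, 2, 0, 0].
    print(torch._C._get_sdp_priority_order())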

I think the main question now is: why did the cuDNN check allow Inductor to choose it as a backend instead of rejecting it beforehand?

@wprazuch Let's keep this change and keep this issue open to investigate further. Meanwhile, for benchmarking, just note that Inductor does not call cuDNN even with Adam's patch.

@AdamRajfer

@riccardofelluga Two approaches have been tested so far:

  1. Applying sdpa kernel on the sdpa function itself
  2. Applying sdpa kernel on the whole training benchmark

First approach: Applying sdpa kernel on the sdpa function itself

  • [Issue] When passing a list of attention backends, this approach does not apply it, and the default list is used instead.
  • [No issue] When passing a single attention backend, this approach does apply it.

Second approach: Applying sdpa kernel on the whole training benchmark

  • [Issue] When passing a list of attention backends, this approach does apply it. However, the fallback mechanism is not applied correctly: if the first of the given attention backends fails, it should try the next backend, but instead it raises the exception immediately.
  • [No issue] When passing a single attention backend, this approach does apply it.
