
[question] Myelin: attention fusion and FlashAttention #3243

Open
vadimkantorov opened this issue Aug 21, 2023 · 15 comments
Labels: triaged (Issue has been triaged by maintainers)

@vadimkantorov

Hi! When the attention op gets fused into a single op by Myelin, the TREx tooltip doesn't say whether it's using FlashAttention / proper fusion or not (and if it's using quantization under the hood, especially for the implicit quantization mode). How can we tell whether it's using a fused attention implementation like FlashAttention? Thanks :)

@zerollzeng (Collaborator)

@nvpohanh ^ ^

@zerollzeng added the triaged label on Aug 22, 2023

@nvpohanh (Collaborator)

For now, you can only check the Nsight Systems profiles: https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#nvprof

If the MHA is fused, there should be kernel names with _mha in the profile.

In the next TRT version, you will be able to get this info by using the IEngineInspector (or --profilingVerbosity=detailed --dumpLayerInfo if you're using trtexec).
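
For reference, a minimal Python sketch of the IEngineInspector route could look like the following; it assumes an engine built with detailed profiling verbosity and a TRT release whose Python API exposes create_engine_inspector(), and "plan.engine" is a placeholder file name.

```python
# Minimal sketch: look for fused MHA kernels via IEngineInspector.
# Assumes the engine was built with profiling verbosity set to DETAILED
# (e.g. trtexec --profilingVerbosity=detailed) and that the installed
# TensorRT exposes create_engine_inspector(); "plan.engine" is a placeholder.
import json
import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
runtime = trt.Runtime(logger)

with open("plan.engine", "rb") as f:
    engine = runtime.deserialize_cuda_engine(f.read())

inspector = engine.create_engine_inspector()
info = json.loads(inspector.get_engine_information(trt.LayerInformationFormat.JSON))

# Fused attention kernels show up with "mha" in their names/metadata.
for layer in info.get("Layers", []):
    text = json.dumps(layer)
    if "mha" in text.lower():
        print(text)
```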

@nvpohanh (Collaborator)

(and if it's using quantization under the hood, especially for the implicit quantization mode)

TRT's MHA fusion does not support implicit quantization yet. Please use explicit quantization instead: add Q/DQ ops before the two batch GEMMs in the MHA and also add Q/DQ ops before the ResidualAdd.

Here is an ugly example: [screenshot of the suggested Q/DQ placement]
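
For concreteness, below is a rough PyTorch sketch of that Q/DQ placement. It only illustrates where the quantize/dequantize pairs go, not the exact graph from the screenshot; the qdq helper, the fixed scale, and the module names are all made up for the example, with torch.fake_quantize_per_tensor_affine standing in for an INT8 Q/DQ pair (it exports to ONNX QuantizeLinear/DequantizeLinear).

```python
# Rough sketch of the Q/DQ placement described above; all names and the
# fixed scale are illustrative, not TRT requirements.
import torch
import torch.nn.functional as F

def qdq(x, scale=0.05):
    # Stand-in for an INT8 Q/DQ pair with a pre-computed (calibrated) scale.
    return torch.fake_quantize_per_tensor_affine(x, scale, 0, -128, 127)

class QDQSelfAttention(torch.nn.Module):
    def __init__(self, dim, heads):
        super().__init__()
        self.heads = heads
        self.qkv = torch.nn.Linear(dim, 3 * dim)
        self.proj = torch.nn.Linear(dim, dim)

    def forward(self, x):
        b, s, d = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        q = q.view(b, s, self.heads, -1).transpose(1, 2)
        k = k.view(b, s, self.heads, -1).transpose(1, 2)
        v = v.view(b, s, self.heads, -1).transpose(1, 2)

        # Q/DQ before the first batched GEMM (Q @ K^T), applied to both inputs.
        scores = qdq(q) @ qdq(k).transpose(-2, -1)
        attn = F.softmax(scores / q.shape[-1] ** 0.5, dim=-1)

        # Q/DQ before the second batched GEMM (softmax output @ V).
        out = qdq(attn) @ qdq(v)
        out = out.transpose(1, 2).reshape(b, s, d)
        out = self.proj(out)

        # Q/DQ before the residual add (both inputs; see the later comment).
        return qdq(out) + qdq(x)

# Example: y = QDQSelfAttention(256, 8)(torch.randn(2, 128, 256))
```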

@Jeremalloch (Contributor)

Hey, a couple of questions to tack on:

  1. What version of TensorRT will this be included in? Is that TRT 8.7? And if --profilingVerbosity=detailed dumps the data, does that mean the TensorRT Engine Explorer (TREx) tooling will show the kernels (instead of a single Myelin kernel)? I'd like to be able to [partially] quantize a model, and being able to see which kernels are executing in INT8 vs FP16 would be very helpful.
  2. Does TRT support INT8 flash attention?
  3. In your attached figure, should we be inserting a QDQ node between the bias addition and the residual connection addition, or should we just be inserting it on the identity / shortcut branch? For ResNet-style CNN architectures, I know it's recommended not to insert on the residual side so that a conv -> add pattern can be fused into a single kernel. Is this the case for a linear layer -> add in transformers as well? (TRT doc link)

@nvpohanh (Collaborator)

The INT8 MHA fused kernels are already integrated in TRT 8.6. The only caveat is that SeqLen must be 512 or below.

It does use flash attention if applicable.

In your attached figure, should we be inserting a QDQ node between the bias addition and the residual connection addition, or should we just be inserting it on the identity / shortcut branch? For ResNet style CNN architectures, I know its recommended to not insert on the residual side so that a conv -> add pattern can be fused into a single kernel. Is this the case for a linear layer -> add in transformers as well?

For Transformers, it is recommended to add Q/DQs on both the inputs of the ResidualAdd. This is because in ConvNets, we fuse the ResidualAdd with the Conv, but for Transformers, we fuse the ResidualAdd with the LayerNorm that comes right after the ResidualAdd.
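
As a small sketch of that recommendation (hypothetical names again, with the same fake-quant stand-in for a Q/DQ pair): both inputs of the residual add carry a Q/DQ, and the LayerNorm immediately after the add is the op the add is expected to fuse into.

```python
# Q/DQ on BOTH inputs of the residual add, with the LayerNorm right after it
# (the fusion target for Transformers). In ConvNets, by contrast, only the
# Conv-side input carries Q/DQ so the Add can fuse into the Conv.
# Names and the fixed scale are made up for illustration.
import torch

def qdq(x, scale=0.05):
    return torch.fake_quantize_per_tensor_affine(x, scale, 0, -128, 127)

class QDQResidualLayerNorm(torch.nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.norm = torch.nn.LayerNorm(dim)

    def forward(self, branch_out, skip):
        added = qdq(branch_out) + qdq(skip)  # Q/DQ on both Add inputs
        return self.norm(added)              # ResidualAdd fuses into this LayerNorm
```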

@Jeremalloch (Contributor)

Hey, one follow-up question: how does the fusion scheme work for pre-norm transformers (as the layer norm would only be applied on the residual branch, and not on the identity branch)? Does a norm-first transformer come with a performance penalty / less optimized kernels?
[screenshot of the pre-norm block in question]

@nvpohanh (Collaborator)

TRT should be able to fuse the add_1 with norm2, so it should not cause any perf issue.
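
To make that concrete, here is a minimal pre-norm block sketch in plain PyTorch (module names are hypothetical, not taken from the screenshot): the output of the first residual add still feeds the second LayerNorm directly, which is the Add + LayerNorm pattern being fused.

```python
# Even though norm1 only sits on the branch feeding the attention, the output
# of the first residual add (add_1) flows straight into norm2, so the
# Add -> LayerNorm fusion still applies. Names are hypothetical.
import torch

class PreNormBlock(torch.nn.Module):
    def __init__(self, dim, heads):
        super().__init__()
        self.norm1 = torch.nn.LayerNorm(dim)
        self.attn = torch.nn.MultiheadAttention(dim, heads, batch_first=True)
        self.norm2 = torch.nn.LayerNorm(dim)
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(dim, 4 * dim),
            torch.nn.GELU(),
            torch.nn.Linear(4 * dim, dim),
        )

    def forward(self, x):
        h = self.norm1(x)
        a, _ = self.attn(h, h, h)
        x = x + a                        # add_1: its output feeds norm2 directly ...
        x = x + self.mlp(self.norm2(x))  # ... so the add can fuse with norm2
        return x
```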

@WeixiangXu

TRT's MHA fusion does not support implicit quantizations yet. Please use explicit quantization instead: Add Q/DQ ops before the two batch gemms in the MHA and also add Q/DQ ops before the ResidualAdd.

@nvpohanh How does the INT8 attention speed compare with FP16? Do you have more detailed documentation on how to do explicit quantization on attention? Thanks!

@zhexinli

zhexinli commented Mar 14, 2024

The INT8 MHA fused kernels are already integrated in TRT 8.6. The only caveat is that SeqLen must be 512 or below.

It does use flash attention if applicable.

In your attached figure, should we be inserting a QDQ node between the bias addition and the residual connection addition, or should we just be inserting it on the identity / shortcut branch? For ResNet style CNN architectures, I know its recommended to not insert on the residual side so that a conv -> add pattern can be fused into a single kernel. Is this the case for a linear layer -> add in transformers as well?

For Transformers, it is recommended to add Q/DQs on both the inputs of the ResidualAdd. This is because in ConvNets, we fuse the ResidualAdd with the Conv, but for Transformers, we fuse the ResidualAdd with the LayerNorm that comes right after the ResidualAdd.

Hi, I'm new to TRT and not familiar with TRT's docs. Is there somewhere we can view the features and constraints of all the TRT-supported kernels / fusion patterns / plugins, such as how to insert Q/DQ for MHA and its supported layouts?

@Aktcob

Aktcob commented May 10, 2024

(and if it's using quantization under the hood, especially for the implicit quantization mode)

TRT's MHA fusion does not support implicit quantizations yet. Please use explicit quantization instead: Add Q/DQ ops before the two batch gemms in the MHA and also add Q/DQ ops before the ResidualAdd.

Here is an ugly example: [screenshot]

Hi @nvpohanh, I tried to convert this ONNX to a TensorRT engine, but there is no kernel name with _mha.
TensorRT version: 8.6.2
[screenshot]

@nvpohanh (Collaborator)

@Aktcob Could you share your trtexec command and the ONNX? Also, could you try TRT 10.0.1.6 GA release and make sure you have enabled FP16?

@Aktcob

Aktcob commented May 10, 2024

@Aktcob Could you share your trtexec command and the ONNX? Also, could you try TRT 10.0.1.6 GA release and make sure you have enabled FP16?

@nvpohanh Thanks for the reply!
/usr/src/tensorrt/bin/trtexec --onnx=selfattention.onnx --fp16 --iterations=300 --avgRuns=100 --dumpProfile --workspace=1000 --saveEngine=selfattention.engine

I'm trying it on a Jetson Orin devkit, so I cannot try it with TRT 10.0.1.6 GA.

ONNX file: https://wenshu.sankuai.com/file/share/download/35BB3D5239CBFD33D79A5FE4DA4F17BFF7B46221
Password: 3xv7u6

@Aktcob

Aktcob commented May 10, 2024

@nvpohanh I tried it on another SelfAttention module without an AttentionMask, and there is a kernel named mha_v2 which fuses matmul + softmax + matmul.
So the question is: how do we get the fused mha kernel for self-attention with an AttentionMask?

@ecilay

ecilay commented Jun 10, 2024

Does Myelin fuse the attention at runtime when the engine is run, or ahead of time when the engine is built?

@Aktcob

Aktcob commented Nov 27, 2024

@nvpohanh Hi, I followed your example and tested my MHA module.

The ONNX model is shown here:
[image of the ONNX graph]

I tested it on a Jetson Orin device with TensorRT 8.6.0. It builds into a TensorRT engine with an mha kernel; I used Nsight Systems to confirm this.
[image of the Nsight Systems profile]

However, the INT8 mha kernel is much slower than the FP16 mha kernel (800 us vs 200 us). How can I solve this problem? Thanks for your reply.

(and if it's using quantization under the hood, especially for the implicit quantization mode)

TRT's MHA fusion does not support implicit quantizations yet. Please use explicit quantization instead: Add Q/DQ ops before the two batch gemms in the MHA and also add Q/DQ ops before the ResidualAdd.

Here is an ugly example: [screenshot]
