
Add Float8BlockwiseLinear with Triton kernels for quantization and GEMMs #2592

Open

danielvegamyhre wants to merge 1 commit into main from danielvegamyhre/stack/15

Conversation

@danielvegamyhre (Contributor) commented Jul 24, 2025

Stacked PRs:


Summary

  • Add "Float8BlockwiseLinear" and make it differentiable with autograd func to support training
  • Add Triton kernels for (1) activation quant, (2) weight quant, and (3) GEMM based on these DeepGemm inference kernel and rename GEMM to explicitly include expected scaling granularity of operands in the function names (fp8_gemm => blockwise_fp8_gemm_1x128_128x128).
  • Add new Triton kernel need for backward: blockwise_fp8_gemm_1x128_1x128_kernel for dW calculation where both left and right operands have activation scaling granularity (1 x block_size). This is a modified version of the kernel above, so it accepts 1x128 scaling for both operands.
  • GEMM kernels do accumulation in fp32 and cast output to bfloat16.
  • Modify all quantization kernels to use EPS to guard against NaNs upon division by 0.
  • Add tests verifying numerics by enforcing reasonable SQNR
  • Added benchmarking script for comparing Triton kernels vs FBGEMM vs DeepGEMM quantization kernels.
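
For orientation, here is a minimal sketch of how the autograd wiring fits the kernels listed above. The kernel names match the PR description, but the wrapper/import names, signatures, and operand layout conventions below are assumptions, not the actual implementation (which also has to handle nn.Linear's weight layout, casts, and shape bookkeeping):

import torch

# Hypothetical wrapper names for the kernels described above; the real module
# paths, signatures, and operand layout conventions in the PR may differ.
# For this sketch, assume each GEMM computes A @ B with row-major operands.
from torchao.prototype.blockwise_fp8 import (  # assumed import path
    fp8_blockwise_act_quant,            # returns (fp8 data, 1x128 scales)
    fp8_blockwise_weight_quant,         # returns (fp8 data, 128x128 scales)
    blockwise_fp8_gemm_1x128_128x128,   # fp32 accumulation, bf16 output
    blockwise_fp8_gemm_1x128_1x128,
)


class _Float8BlockwiseMM(torch.autograd.Function):
    """Sketch of the differentiable blockwise-fp8 matmul: out = x @ w."""

    @staticmethod
    def forward(ctx, x, w):  # x: (M, K), w: (K, N)
        ctx.save_for_backward(x, w)
        x_fp8, x_s = fp8_blockwise_act_quant(x)      # 1x128 scales
        w_fp8, w_s = fp8_blockwise_weight_quant(w)   # 128x128 scales
        return blockwise_fp8_gemm_1x128_128x128(x_fp8, x_s, w_fp8, w_s)

    @staticmethod
    def backward(ctx, grad_out):  # grad_out: (M, N)
        x, w = ctx.saved_tensors
        go_fp8, go_s = fp8_blockwise_act_quant(grad_out)

        # dX = grad_out @ w.T: activation-style (1x128) times weight-style (128x128) scales.
        wt_fp8, wt_s = fp8_blockwise_weight_quant(w.t().contiguous())
        grad_x = blockwise_fp8_gemm_1x128_128x128(go_fp8, go_s, wt_fp8, wt_s)

        # dW = x.T @ grad_out: both operands carry 1x128 scales, which is why the
        # new blockwise_fp8_gemm_1x128_1x128 kernel is needed.
        xt_fp8, xt_s = fp8_blockwise_act_quant(x.t().contiguous())
        grad_w = blockwise_fp8_gemm_1x128_1x128(xt_fp8, xt_s, go_fp8, go_s)
        return grad_x, grad_w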

Why use Triton kernels instead of DeepGEMM CUTLASS kernels?

The GEMM APIs used in @vkuzo's PoC here no longer exist in DeepGemm. I tried using the new GEMM APIs (fp8_gemm_nt, etc.) and:

  • On B200, with both Vasiliy's PR and my PR, I got device-side asserts on this line that were not immediately clear how to resolve.
  • On H100, I only tried Vasiliy's PR, but got an undefined symbols error from CUDA, despite using CUDA toolkit 12.8+ as stated in the readme.

Since our only goal is a functional skeleton rather than performance, instead of spending more time on this I just used the existing Triton kernels we had and made a modified GEMM (a 1-line change) to support blockwise_fp8_gemm_1x128_1x128_kernel.

If we want to replace these Triton GEMMs with the CUTLASS ones later to see if perf is better (it probably is), we can do that.

Note on numerics

Interestingly, the reference DeepGemm Triton quantization kernels do NOT use EPS/clamping to prevent division by 0. This resulted in my unit tests passing (where inputs were drawn from a normal distribution), but NaNs occurring in TorchTitan training runs, where actual activation values sometimes had an amax of 0.

I updated the kernels to use the same EPS guards as torchao.float8, and this fixed the NaNs.
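
As a minimal illustration of the fix (the exact EPS value, fp8 dtype, and scale convention used by the kernels are assumptions here, and the real computation happens inside Triton):

import torch

# Assumed constants; torchao.float8 uses a small EPS like this to avoid div-by-zero.
EPS = 1e-12
FP8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0

def blockwise_scales(x: torch.Tensor, block_size: int = 128) -> torch.Tensor:
    """Per-block scales for a 2D tensor, guarded against amax == 0.

    Sketch only; assumes K is a multiple of block_size, and the scale
    convention (amax / fp8_max vs. fp8_max / amax) may differ from the PR.
    """
    m, k = x.shape
    blocks = x.abs().reshape(m, k // block_size, block_size)
    amax = blocks.amax(dim=-1).to(torch.float32)
    # The key fix: clamp amax to EPS so all-zero blocks yield a finite scale.
    return torch.clamp(amax, min=EPS) / FP8_E4M3_MAX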

Test plan

  • pytest test/prototype/blockwise_fp8/test_blockwise_kernels.py
  • pytest test/prototype/blockwise_fp8/test_blockwise_linear.py

Torchtitan PoC integration results

  • Logs show the converted model and a nicely decreasing loss, but very low throughput (~5.1k vs ~9.8k for bfloat16).
  • I think perf may be bad because of (1) all the eager .t().contiguous()-style transformations, and (2) the DeepGemm Triton GEMM requiring the B tensor to be row-major, then doing strided reads to get it into column-major in SMEM for the fp8 wgmma (same issue as with fp8 flex attention, I think).


pytorch-bot bot commented Jul 24, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2592


❌ 4 New Failures

As of commit d0631c0 with merge base 0e00df3:

@danielvegamyhre (Contributor Author)

cc @vkuzo @drisspg for review

error = torch.norm(C - C_q) / torch.norm(C)
print(f"Relative Error: {error.item():.6f}")

assert error < 0.1, "Quantize gemm error is too high"

Contributor

Can you use SQNR everywhere to match w/ existing numerics testing?

Contributor Author

Updated to use SQNR
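
For context, the SQNR check used in existing torchao numerics tests looks roughly like the sketch below (signal power over quantization-noise power, in dB); the helper name and the example threshold are assumptions:

import torch

def sqnr_db(ref: torch.Tensor, actual: torch.Tensor) -> torch.Tensor:
    # Signal-to-quantization-noise ratio in decibels.
    signal = torch.norm(ref)
    noise = torch.norm(ref - actual)
    return 20 * torch.log10(signal / noise)

# Example threshold in the spirit of existing numerics tests; the exact bound
# used in this PR's tests is an assumption.
# assert sqnr_db(C, C_q) > 25.0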


# original implementation from fbgemm_gpu:
# https://github.com/pytorch/FBGEMM/blob/b19401e913fcdff536dc097fa3013a0a9d66256e/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py#L3091
def triton_quantize_fp8_block(

@drisspg (Contributor) Jul 24, 2025

Since we have an optional runtime dependency on fbgemm, can we just call their kernel directly?

@danielvegamyhre (Contributor Author) Jul 25, 2025

Yes, that is the desired end state. For now I have had repeated problems getting it (fbgemm-gpu-genai) to work, e.g. undefined symbols. I tried on both H100 and B200 and got different undefined symbol errors.
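
For reference, the direct call would presumably look something like the sketch below once fbgemm-gpu-genai installs cleanly; the module path comes from the fbgemm_gpu link above, but the argument names and defaults are assumptions about its signature:

import torch

# Module path taken from the fbgemm_gpu source linked above; the exact
# argument names/defaults are assumptions, not verified against the API.
from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import (
    triton_quantize_fp8_block,
)

x = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)
# Hypothetical call: quantize in 128x128 blocks, returning fp8 data + scales.
x_fp8, x_scales = triton_quantize_fp8_block(x, block_m=128, block_k=128)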

@drisspg (Contributor) commented Jul 24, 2025

(32768, 32768)  (1,128)          5732.42      7960.21        7830.56
(32768, 32768)  (128,128)       13692.4        669.664       7831.14

These numbers look odd to me; do you have memory bandwidth calcs? I don't immediately get why there is a 10x delta in groupwise vs. blockwise.

@danielvegamyhre (Contributor Author)

> These numbers look odd to me; do you have memory bandwidth calcs? I don't immediately get why there is a 10x delta in groupwise vs. blockwise.

Yeah, I agree it's odd. I will try adding some mem bw calcs, and was also thinking about checking with Josh / the fbgemm team whether perhaps there is a different kernel they use for activation quant.
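
A rough bandwidth calc could look like the sketch below, assuming the kernel reads a bf16 input and writes an fp8 output plus one fp32 scale per block, and that the times in the table above are in microseconds (both assumptions, not taken from the benchmark script):

import torch

def quant_kernel_gbps(shape, block_shape, time_us):
    """Rough achieved memory bandwidth (GB/s) for a blockwise quantization kernel.

    Assumes: read bf16 input (2 B/elem), write fp8 output (1 B/elem) and one
    fp32 scale (4 B) per block. Compare against peak HBM bandwidth
    (e.g. ~3.35 TB/s on H100 SXM) to see how memory-bound the kernel is.
    """
    m, k = shape
    block_m, block_k = block_shape
    n_elems = m * k
    n_blocks = (m // block_m) * (k // block_k)
    bytes_moved = n_elems * 2 + n_elems * 1 + n_blocks * 4
    return bytes_moved / (time_us * 1e-6) / 1e9

# Example using the (32768, 32768) row with (1, 128) blocks from the table above.
print(quant_kernel_gbps((32768, 32768), (1, 128), 5732.42))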

@danielvegamyhre (Contributor Author) commented Jul 26, 2025

@drisspg @vkuzo this is ready for another look. Numerics look good and the torchtitan loss curve looks good; perf is bad for Llama3 8B right now.

(I accidentally squashed my stack into a single PR with stack-pr; sorry for the large PR, I can break it up if necessary.)
