Add support for float8 activation for Int4GroupwisePreshuffleTensor #2437


Open
wants to merge 1 commit into base: jerryzh168/stack/2

Conversation

jerryzh168
Contributor

@jerryzh168 jerryzh168 commented Jun 24, 2025

Stacked PRs:


Add support for float8 activation for Int4GroupwisePreshuffleTensor

Summary:
Added basic op support (linear and bmm). Both the float8 and bf16 activation variants are handled by the same Tensor subclass, since the weight dtype is the same in both cases; the only difference is whether the activation is quantized. There are, however, some differences in implementation:

bf16 activation:

  • group_scale
  • group_zero

fp8 activation:

  • group_scale
  • row_scale
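The role of row_scale in the fp8 path can be illustrated with a small sketch. This is a hedged stand-in, not the fbgemm kernel: it assumes the fp8 path performs dynamic per-row activation quantization, mapping each row's absmax to the float8 e4m3 maximum to produce the per-row scale.

```python
# Hypothetical illustration of the fp8 activation path's row_scale.
# Assumption: dynamic per-row quantization targeting the float8 e4m3 range.
F8_E4M3_MAX = 448.0  # largest finite float8 e4m3 magnitude


def quantize_rowwise_fp8(rows):
    """Scale each activation row so its absmax maps to F8_E4M3_MAX.

    Returns (scaled_rows, row_scales); row_scales are the per-row
    dequantization multipliers kept alongside the int4 weight's group_scale.
    """
    scaled_rows, row_scales = [], []
    for row in rows:
        amax = max(abs(v) for v in row) or 1.0  # guard all-zero rows
        scale = amax / F8_E4M3_MAX
        scaled_rows.append([v / scale for v in row])
        row_scales.append(scale)
    return scaled_rows, row_scales


q, s = quantize_rowwise_fp8([[0.5, -2.0, 1.0], [4.0, 0.25, -8.0]])
# each scaled row now spans the fp8 range; s holds the per-row scales
```

In the bf16 path no such row_scale exists; the weight instead carries a group_zero for asymmetric int4 dequantization.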

Test Plan:
python test/dtypes/test_float8_activation_int4_groupwise_preshuffle.py

Reviewers:

Subscribers:

Tasks:

Tags:


pytorch-bot bot commented Jun 24, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2437

Note: Links to docs will display an error until the docs builds have been completed.

❌ 12 New Failures, 1 Cancelled Job

As of commit cc359e6 with merge base 5a50667:

NEW FAILURES - The following jobs have failed:

CANCELLED JOB - The following job was cancelled. Please retry:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

jerryzh168 added a commit that referenced this pull request Jun 24, 2025
Summary:
Note: slice is not working yet, others are working

Test Plan:
python test/dtypes/test_float8_activation_int4_groupwise_preshuffle.py

Reviewers:

Subscribers:

Tasks:

Tags:

stack-info: PR: #2437, branch: jerryzh168/stack/4
@jerryzh168 jerryzh168 force-pushed the jerryzh168/stack/4 branch from d6d3477 to 26517e8 on June 24, 2025 22:25
@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jun 24, 2025
@jerryzh168 jerryzh168 changed the base branch from jerryzh168/stack/2 to main June 24, 2025 22:26
jerryzh168 added a commit that referenced this pull request Jun 24, 2025
@jerryzh168 jerryzh168 force-pushed the jerryzh168/stack/4 branch from 26517e8 to d187f78 on June 24, 2025 22:26
@jerryzh168 jerryzh168 changed the base branch from main to jerryzh168/stack/2 June 24, 2025 22:26
@jerryzh168 jerryzh168 changed the base branch from jerryzh168/stack/2 to main June 24, 2025 22:28
@jerryzh168 jerryzh168 force-pushed the jerryzh168/stack/4 branch from d187f78 to 2fcff42 on June 24, 2025 22:28
jerryzh168 added a commit that referenced this pull request Jun 24, 2025
@jerryzh168 jerryzh168 changed the base branch from main to jerryzh168/stack/2 June 24, 2025 22:28
@jerryzh168 jerryzh168 changed the base branch from jerryzh168/stack/2 to main June 26, 2025 05:03
jerryzh168 added a commit that referenced this pull request Jun 26, 2025
@jerryzh168 jerryzh168 force-pushed the jerryzh168/stack/4 branch from 2fcff42 to 95856ed on June 26, 2025 05:03
@jerryzh168 jerryzh168 changed the base branch from main to jerryzh168/stack/2 June 26, 2025 05:03
@jerryzh168 jerryzh168 changed the base branch from jerryzh168/stack/2 to main June 27, 2025 19:36
jerryzh168 added a commit that referenced this pull request Jun 27, 2025
@jerryzh168 jerryzh168 force-pushed the jerryzh168/stack/4 branch from 95856ed to 1dec2cb on June 27, 2025 19:37
@jerryzh168 jerryzh168 changed the base branch from main to jerryzh168/stack/2 June 27, 2025 19:37
@jerryzh168 jerryzh168 changed the base branch from jerryzh168/stack/2 to main June 27, 2025 19:38
@jerryzh168 jerryzh168 force-pushed the jerryzh168/stack/4 branch from 1dec2cb to 1645c79 on June 27, 2025 19:38
jerryzh168 added a commit that referenced this pull request Jun 27, 2025
@jerryzh168 jerryzh168 changed the base branch from main to jerryzh168/stack/2 June 27, 2025 19:38
@jerryzh168 jerryzh168 changed the base branch from jerryzh168/stack/2 to main June 27, 2025 19:48
jerryzh168 added a commit that referenced this pull request Jun 27, 2025
@jerryzh168 jerryzh168 force-pushed the jerryzh168/stack/4 branch from 1645c79 to 5e9e869 on June 27, 2025 19:48
@jerryzh168 jerryzh168 changed the base branch from main to jerryzh168/stack/2 June 27, 2025 19:48
jerryzh168 added a commit that referenced this pull request Jun 27, 2025
@jerryzh168 jerryzh168 force-pushed the jerryzh168/stack/4 branch from 43dc85e to 040375e on June 27, 2025 20:09
@jerryzh168 jerryzh168 changed the base branch from main to jerryzh168/stack/2 June 27, 2025 20:09
@jerryzh168 jerryzh168 added topic: improvement Use this tag if this PR is an improvement (doesn't fit into any of the other categories) topic: new feature Use this tag if this PR adds a new feature and removed topic: improvement Use this tag if this PR is an improvement (doesn't fit into any of the other categories) labels Jun 28, 2025
@jerryzh168 jerryzh168 changed the base branch from jerryzh168/stack/2 to main June 30, 2025 21:26
jerryzh168 added a commit that referenced this pull request Jun 30, 2025
@jerryzh168 jerryzh168 force-pushed the jerryzh168/stack/4 branch from 040375e to fb2686e on June 30, 2025 21:26
@jerryzh168 jerryzh168 changed the base branch from main to jerryzh168/stack/2 June 30, 2025 21:26
@jerryzh168 jerryzh168 changed the base branch from jerryzh168/stack/2 to main June 30, 2025 23:01
@jerryzh168 jerryzh168 force-pushed the jerryzh168/stack/4 branch from fb2686e to 84bae22 on June 30, 2025 23:01
@jerryzh168 jerryzh168 changed the title Add support for Float8ActivationInt4GroupwisePreshuffleTensor for fbgemm Add support for float8 activation for Int4GroupwisePreshuffleTensor Jun 30, 2025
@jerryzh168 jerryzh168 changed the base branch from main to jerryzh168/stack/2 June 30, 2025 23:01
@jerryzh168 jerryzh168 mentioned this pull request Jun 30, 2025
@jerryzh168 jerryzh168 changed the base branch from jerryzh168/stack/2 to main July 2, 2025 01:58
@jerryzh168 jerryzh168 force-pushed the jerryzh168/stack/4 branch from 84bae22 to c13fa2b on July 2, 2025 01:58
@jerryzh168 jerryzh168 changed the base branch from main to jerryzh168/stack/2 July 2, 2025 01:58
@@ -0,0 +1,166 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
Contributor

should this test be in the test/quantization/quantize_ folder?

Contributor Author

yeah, the test file to review is test_int4_groupwise_preshuffle.py, I forgot to remove this file

@unittest.skipIf(not is_sm_at_least_90(), "Need sm90+")
class TestInt4GroupwisePreshuffleTensor(TestCase):
def setUp(self):
self.config = FbgemmConfig(
Contributor

or, maybe organize tests by the config? it would make sense for tests for everything in FbgemmConfig to be together.

Contributor Author

was planning to remove FbgemmConfig as well soon

quantized = linear(input)
self.assertTrue(compute_error(original, quantized) > 20)

# @unittest.skip("WIP: this doesn't work yet")
Contributor

remove
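The `compute_error(original, quantized) > 20` assertions in these test excerpts are SQNR checks. A minimal stand-in, assuming `compute_error` returns 20·log10(‖ref‖ / ‖ref − quantized‖) in decibels (so > 20 dB means the quantization error norm is under one tenth of the signal norm):

```python
import math


def compute_error_sketch(ref, quantized):
    """Hedged stand-in for torchao's compute_error (SQNR in dB)."""
    signal = math.sqrt(sum(r * r for r in ref))
    noise = math.sqrt(sum((r - q) ** 2 for r, q in zip(ref, quantized)))
    return float("inf") if noise == 0 else 20 * math.log10(signal / noise)
```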

self.assertTrue(compute_error(original, quantized) > 20)

# @unittest.skip("WIP: this doesn't work yet")
def test_slice(self):
Contributor

is there an integration test we can write to cover this? If this is needed for TP, maybe just have an integration test for TP which is easy to run for all of these configs?

Contributor Author

this is required to run the model in vllm, can add to https://github.com/pytorch/ao/blob/main/test/integration/test_vllm.py when the API is more mature

# making sure param.data is updated
assert param.data.packed_weight[0][0] != orig_value

def test_bmm(self):
Contributor

unit tests are nice, feels like it would be good to also have an integration test to cover all of these ops in one go

Contributor Author

e2e tests can be the ones we run in vllm/sglang integration on real models I think?

@@ -2040,6 +2040,8 @@ class FbgemmConfig(AOBaseConfig):
weight_dtype (torch.dtype): weight dtype of the kernel
output_dtype (torch.dtype): output dtype of the kernel
group_size (int): The group size for weight
preshuffle (bool): whether to preshuffle the weights or not
activation_dtype_for_int4 (str): the dtype for activation for int4 weight, either bf16 or fp8
Contributor

it's confusing to have both input_dtype and activation_dtype_for_int4, what are your thoughts

Contributor Author

yeah it's a bit confusing, although this is temporary, I'm deprecating this later in the stack: #2474
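For context, the fields being discussed can be mocked up as a plain dataclass. This is an illustrative stand-in, not the real torchao FbgemmConfig: the field names follow the docstring in the diff above, and `scales_needed` is a hypothetical helper showing which scale tensors each activation path carries.

```python
from dataclasses import dataclass


@dataclass
class FbgemmConfigSketch:
    # Hypothetical mirror of the documented fields; not the real class.
    input_dtype: str = "bf16"
    weight_dtype: str = "int4"
    output_dtype: str = "bf16"
    group_size: int = 128
    preshuffle: bool = True
    activation_dtype_for_int4: str = "bf16"  # "bf16" or "fp8"

    def scales_needed(self):
        # bf16 activation path: (group_scale, group_zero)
        # fp8 activation path:  (group_scale, row_scale)
        if self.activation_dtype_for_int4 == "fp8":
            return ("group_scale", "row_scale")
        return ("group_scale", "group_zero")
```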

)


class Float8ActivationInt4GroupwisePreshuffleTensor(TorchAOBaseTensor):
Contributor

why not extend the existing Int4GroupwisePreshuffleTensor

Contributor Author

sorry this should be removed

@jerryzh168 jerryzh168 changed the base branch from jerryzh168/stack/2 to main July 2, 2025 20:35
@jerryzh168 jerryzh168 force-pushed the jerryzh168/stack/4 branch from c13fa2b to cc359e6 on July 2, 2025 20:36
@jerryzh168 jerryzh168 changed the base branch from main to jerryzh168/stack/2 July 2, 2025 20:36
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. topic: new feature Use this tag if this PR adds a new feature
3 participants