test rowwise fp32 #2431

Open · y-sq wants to merge 1 commit into main from export-D73552660

Conversation

@y-sq (Contributor) commented Jun 24, 2025

Summary:
Running rowwise scaling on fp32 tensors fails with the following error (internal paste P1794222725):

```
RuntimeError: Only bf16 high precision output types are supported for row-wise scaling.
```

This PR adds an option to explicitly use bfloat16 as the output of the rowwise-scaled matmul, then cast the result back to the original precision.

It can be enabled by setting:

```
config = dataclasses.replace(config, convert_dtypes_for_rowwise_scaled_mm=True)
```

Differential Revision: D73552660
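
For context, a minimal sketch of what the workaround does under the hood (function and argument names here are illustrative, not the exact torchao internals):

```
import torch

def rowwise_scaled_mm(a_fp8, b_fp8, a_scale, b_scale, output_dtype):
    # torch._scaled_mm only accepts bf16 as the high-precision output type
    # for row-wise scaling, so compute in bf16 and cast back afterwards.
    out = torch._scaled_mm(
        a_fp8,            # fp8 input, row-major
        b_fp8,            # fp8 weight, column-major
        scale_a=a_scale,  # per-row scales for a, shape (M, 1), float32
        scale_b=b_scale,  # per-column scales for b, shape (1, N), float32
        out_dtype=torch.bfloat16,
    )
    return out.to(output_dtype)  # e.g. back to torch.float32
```

The extra cast costs one elementwise copy of the output, which is typically negligible next to the matmul itself.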

@pytorch-bot (bot) commented Jun 24, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2431

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure

As of commit b0240a2 with merge base 2025b75:

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label on Jun 24, 2025
@facebook-github-bot (Contributor) commented:

This pull request was exported from Phabricator. Differential Revision: D73552660

@y-sq requested review from vkuzo, danielvegamyhre, and drisspg, and removed the request for vkuzo on June 24, 2025 06:35
@y-sq added the float8 and topic: not user facing labels on Jun 24, 2025

```
if convert_dtypes_for_rowwise_scaled_mm and is_rowwise_scaling:
    output_dtype = torch.bfloat16
```

@vkuzo (Contributor) commented Jun 24, 2025

instead of adding a flag, TBH I think we can just enable this on by default, like this:

- file an issue in PyTorch core to add float32 output to scaled_mm
- apply the workaround unconditionally:

```
output_dtype_to_use = output_dtype
if is_rowwise_scaling:
    # work around torch._scaled_mm not having a float32 output type
    # TODO(issue number): remove this once torch._scaled_mm supports float32 output
    output_dtype_to_use = torch.bfloat16
output = torch._scaled_mm(..., output_dtype_to_use, ...)
...
if is_rowwise_scaling and output_dtype == torch.float32:
    # work around torch._scaled_mm not having a float32 output type
    # TODO(issue number): remove this once torch._scaled_mm supports float32 output
    output = output.to(output_dtype)
```

@y-sq (Author) replied:

Makes sense, I'll change it to be enabled by default and file an issue.

@vkuzo (Contributor) left a comment:

Can we file an issue in core to add this to torch._scaled_mm, and enable the workaround without a config for now? Also add a test?

@y-sq (Author) commented Jun 25, 2025

Updated to enable the workaround by default, and included fp16 and fp32 dtypes in the existing test cases. The remaining changes are formatting fixes generated by the linter.
(It may take some time for this PR to be updated with the diff.)

The PyTorch issue: pytorch/pytorch#156771
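
For reference, a rough sketch of the kind of coverage this adds, assuming torchao's float8 training API (the actual change extends the repo's existing parametrized tests):

```
import torch
from torchao.float8 import Float8LinearConfig, convert_to_float8_training

# Rowwise-scaled float8 should now work when the module runs in fp16 or
# fp32, with outputs cast back to the module's original precision.
for dtype in (torch.bfloat16, torch.float16, torch.float32):
    m = torch.nn.Sequential(torch.nn.Linear(256, 256)).to(device="cuda", dtype=dtype)
    config = Float8LinearConfig.from_recipe_name("rowwise")
    m = convert_to_float8_training(m, config=config)
    x = torch.randn(32, 256, device="cuda", dtype=dtype)
    y = m(x)
    assert y.dtype == dtype  # output is cast back to the original dtype
    y.sum().backward()  # gradients flow through the rowwise-scaled matmul
```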

@vkuzo (Contributor) commented Jun 25, 2025

@y-sq, maybe export again?

@y-sq force-pushed the export-D73552660 branch from 0a45ccd to 4ab8986 on July 3, 2025 21:36
y-sq added a commit to y-sq/ao that referenced this pull request on Jul 3, 2025; the commit message duplicates the PR summary above.

@y-sq (Author) commented Jul 3, 2025

@vkuzo sorry, there were some un-synced files between GitHub and fbcode, so the previous exports all failed. The PR should be updated now.

@y-sq force-pushed the export-D73552660 branch from 4ab8986 to d19f362 on July 3, 2025 22:34
@y-sq force-pushed the export-D73552660 branch from d19f362 to 92c3668 on July 3, 2025 23:40
@y-sq force-pushed the export-D73552660 branch 2 times, most recently from abbed3a to 303d6a6, on July 3, 2025 23:55
@y-sq force-pushed the export-D73552660 branch from 303d6a6 to b0240a2 on July 3, 2025 23:59
Labels: CLA Signed, fb-exported, float8, topic: not user facing

3 participants