
[PyTorch] Propagate fp8 scale-inverse modification to GroupedLinear #1128

Merged
merged 13 commits into NVIDIA:main on Sep 9, 2024

Conversation

@yaox12 (Collaborator) commented Aug 22, 2024

Description

This PR

  • Propagates the fp8 scale-inverse modification to GroupedLinear.
  • Fixes a bug where the wrong scale_inv was used for weights in fp8_grouped_gemm.
  • Adds a new grouped GEMM interface that takes a separate scale_inv for weights and a single output tensor (for the forward output and backward dgrad); see the sketch below.
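
For orientation, here is a minimal plain-PyTorch sketch of the semantics the new interface targets: a separate scale_inv for the weights and a single pre-allocated output shared by all groups. The function name, argument names, and shapes are hypothetical and do not reflect the actual fp8_grouped_gemm signature; real FP8 kernels operate on quantized payloads rather than the float stand-ins used here.

```python
# Illustrative sketch only (not the Transformer Engine API): emulates, in plain
# PyTorch, what "separate scale_inv for weights and a single output" means for a
# grouped GEMM. Float tensors and per-group scale factors stand in for FP8 data.
import torch

def grouped_gemm_reference(a_list, a_scale_inv, b_list, b_scale_inv, out, m_splits):
    """For each group i, compute (a_i * a_scale_inv[i]) @ (b_i * b_scale_inv[i]).T
    and write the result into the single pre-allocated `out` tensor."""
    offset = 0
    for i, m in enumerate(m_splits):
        a_deq = a_list[i] * a_scale_inv[i]          # dequantize activations
        b_deq = b_list[i] * b_scale_inv[i]          # dequantize weights (separate scale_inv)
        out[offset:offset + m] = a_deq @ b_deq.t()  # one slice of the shared output
        offset += m
    return out

# Toy usage: 2 groups, hidden size 4, output size 8.
m_splits = [3, 5]
a = [torch.randn(m, 4) for m in m_splits]
b = [torch.randn(8, 4) for _ in m_splits]
a_scale_inv = torch.tensor([0.5, 0.25])   # per-group activation scale_inv
b_scale_inv = torch.tensor([2.0, 1.0])    # per-group weight scale_inv, kept separate
out = torch.empty(sum(m_splits), 8)
grouped_gemm_reference(a, a_scale_inv, b, b_scale_inv, out, m_splits)
```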

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactor

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@yaox12 marked this pull request as ready for review on August 22, 2024 at 09:18
Comment on lines -156 to -157
fp8_meta["scaling_fwd"].scale_inv,
_GEMM_WEIGHT,
@yaox12 (Collaborator, Author) commented Aug 22, 2024

Hi @timmoon10, I want to confirm that the previous way of getting scale_inv from fp8_meta["scaling_fwd"] in the forward pass still works but is not encouraged, right?

I'm pretty sure this was a bug, and maybe I should create an issue to make users aware of it.

Collaborator

What is the bug you are noticing? You are changing the API for fp8_grouped_gemm, but other than that it looks like using fp8_meta would be correct (but not recommended).

Collaborator Author

When casting weights to FP8 in get_fp8_workspace(), the scale_invs of the weight tensors are written to the private ones inside the Float8Tensors, while the scale_invs in fp8_meta are only updated after the first forward step completes. So for the first micro batch, the two don't match.
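
A toy timeline of that mismatch, with hypothetical names (this is not Transformer Engine's internal code): the workspace tensor records its own scale_inv at cast time, while the stand-in for fp8_meta's scale_inv still holds its pre-forward value.

```python
# Purely illustrative; names below are hypothetical and do not correspond to
# Transformer Engine internals.
import torch

fp8_meta_scale_inv = torch.tensor(1.0)   # placeholder value before any forward step

def cast_weight_to_fp8_workspace(weight):
    # At cast time, the workspace tensor records its own scale_inv privately.
    amax = weight.abs().max()
    scale = 448.0 / amax                  # e.g. map amax to the FP8 E4M3 max
    quantized = weight * scale            # stand-in for the FP8 payload
    private_scale_inv = 1.0 / scale       # stored inside the Float8Tensor-like object
    return quantized, private_scale_inv

weight = torch.randn(8, 4)
w_fp8, w_scale_inv = cast_weight_to_fp8_workspace(weight)

# First micro batch: the GEMM should use w_scale_inv (recorded at cast time);
# fp8_meta_scale_inv is only brought in line after the forward step completes.
print(w_scale_inv, fp8_meta_scale_inv)   # typically differ before the first update
```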

@yaox12 (Collaborator, Author) commented Aug 26, 2024

@timmoon10 Can I have your review?

@ksivaman (Member)

Can we not add this functionality to the existing grouped_gemm API via arguments instead of introducing another one? @yaox12

@ksivaman self-requested a review on August 26, 2024 at 14:15
@yaox12 (Collaborator, Author) commented Aug 27, 2024

Can we not add this functionality to the existing grouped_gemm API via arguments instead of introducing another one? @yaox12

Done.
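
For illustration, one way the argument-based approach can look (a hypothetical skeleton, not the actual fp8_grouped_gemm signature): the existing entry point gains optional keyword arguments whose defaults preserve the old behavior, so no second function is introduced.

```python
# Hypothetical skeleton of the argument-based approach; names and defaults are
# assumptions for illustration, not the real fp8_grouped_gemm signature.
from typing import List, Optional
import torch

def fp8_grouped_gemm_sketch(
    a: List[torch.Tensor],
    b: List[torch.Tensor],
    out: List[torch.Tensor],
    *,
    b_scale_inv: Optional[torch.Tensor] = None,  # separate weight scale_inv, if provided
    single_output: bool = False,                  # write all groups into one shared output
    m_splits: Optional[List[int]] = None,         # row counts per group, needed for single_output
) -> None:
    """Grouped GEMM entry point extended via optional keyword arguments."""
    if single_output:
        assert m_splits is not None, "single_output requires m_splits"
    # ... dispatch to the grouped GEMM kernel, falling back to the old behavior
    # when the new keyword arguments are left unset ...
```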

@yaox12 (Collaborator, Author) commented Sep 2, 2024

@ksivaman Can you trigger the CI?

@timmoon10 self-requested a review on September 3, 2024 at 18:59
@timmoon10 (Collaborator) left a comment

Overall looks reasonable to me. I just have a suggestion to make the internal API more consistent.

transformer_engine/pytorch/cpp_extensions/gemm.py (outdated, resolved)

@timmoon10 (Collaborator)

/te-ci pytorch

@yaox12 (Collaborator, Author) commented Sep 5, 2024

/te-ci pytorch

@yaox12 (Collaborator, Author) commented Sep 6, 2024

/te-ci pytorch

@yaox12 (Collaborator, Author) commented Sep 9, 2024

@timmoon10 @ksivaman Can you take another look at this PR? Thanks.

@ksivaman (Member) left a comment

LGTM

@ksivaman merged commit 047a507 into NVIDIA:main on Sep 9, 2024
26 checks passed