
[PyTorch] Remove dtype from args of permutation #1145

Merged: 5 commits merged into NVIDIA:main from xiny/fix_permute_api on Aug 29, 2024


@yaox12 (Collaborator) commented Aug 28, 2024

Description

  • Previously, moe_permute and moe_unpermute required users to pass a tex.DType as an argument. I think exposing tex.DType outside of TE should be avoided. This PR instead infers the dtype from the input tensors, including Float8Tensor.
  • Added the permutation unit tests to the QA scripts.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactor

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@yaox12 (Collaborator, Author) commented Aug 28, 2024

@phu0ngng Can I have your review? This is just an API change.

Signed-off-by: Xin Yao <[email protected]>
@ksivaman (Member) commented:

/te-ci pytorch

@ksivaman (Member) left a comment

@yaox12 On a tangent: the permute API is currently undocumented. You'll need to add it to the docs file.

@phu0ngng (Collaborator) left a comment

LGTM!

Signed-off-by: Xin Yao <[email protected]>
@phu0ngng phu0ngng merged commit 8ddac3d into NVIDIA:main Aug 29, 2024
15 checks passed
@yaox12 yaox12 deleted the xiny/fix_permute_api branch August 30, 2024 01:17

3 participants