Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PyTorch] Normalization ops #1033

Merged
merged 39 commits into from
Nov 5, 2024
Merged

[PyTorch] Normalization ops #1033

merged 39 commits into from
Nov 5, 2024

Conversation

timmoon10
Copy link
Collaborator

Description

This PR extends the operation-based API (see #707) with LayerNorm, RMSNorm, and FP8 cast operations.

Compare with the existing module-based API:

# Module-based API
module1 = te.LayerNormLinear(...)

# Operation-based API
module2 = te.ops.Sequential(
    te.ops.LayerNorm(...),
    te.ops.Linear(...),
)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refractor

Changes

Please list the changes introduced in this PR:

  • LayerNorm operation
  • FP8 cast operation
  • RMSNorm operation

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@timmoon10 timmoon10 added the enhancement New feature or request label Jul 22, 2024
@timmoon10
Copy link
Collaborator Author

/te-ci pytorch

Signed-off-by: Tim Moon <[email protected]>
@timmoon10
Copy link
Collaborator Author

/te-ci pytorch

@timmoon10
Copy link
Collaborator Author

/te-ci pytorch

@timmoon10
Copy link
Collaborator Author

/te-ci pytorch

@timmoon10
Copy link
Collaborator Author

/te-ci pytorch

from .._common import is_float8_tensor


class CastFloat8(BasicOperation):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe you said that it is mostly an utility op for tests, right? We should probably mention that in this documentaiton.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also maybe we should consider generalizing it a bit with a name that is not specific to FP8? (Like just Quantize)?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It could be a helpful op for users as well. For example, if a user wants to have discrete layers for design reasons but still wants to fuse some operations with FP8 casts:

act = te.ops.Sequential(te.ops.GeLU(), te.ops.Quantize())
linear = te.ops.Sequential(te.ops.Linear())
y = act(x)
z = linear(y)

timmoon10 and others added 5 commits September 17, 2024 16:58
Co-authored-by: Przemyslaw Tredak <[email protected]>
Signed-off-by: Tim Moon <[email protected]>
Rename "CastFloat8" op to "Quantize". Add more fine-grained control for SM margin. Add docs for legacy sequence_parallel kwarg.

Signed-off-by: Tim Moon <[email protected]>
@timmoon10
Copy link
Collaborator Author

/te-ci pytorch

Signed-off-by: Tim Moon <[email protected]>
@timmoon10
Copy link
Collaborator Author

timmoon10 commented Sep 24, 2024

/te-ci pytorch

Edit: te-ci/docs failure disappears when job is rerun.

@timmoon10
Copy link
Collaborator Author

/te-ci pytorch

@timmoon10
Copy link
Collaborator Author

/te-ci pytorch

@timmoon10
Copy link
Collaborator Author

/te-ci pytorch

@timmoon10
Copy link
Collaborator Author

/te-ci pytorch

@timmoon10
Copy link
Collaborator Author

Merging with approval from @ptrendx and @ksivaman.

@timmoon10 timmoon10 merged commit 77c37d4 into NVIDIA:main Nov 5, 2024
26 checks passed
phu0ngng pushed a commit to phu0ngng/TransformerEngine that referenced this pull request Nov 5, 2024
* Add layer norm op

Signed-off-by: Tim Moon <[email protected]>

* Add FP8 cast op

Signed-off-by: Tim Moon <[email protected]>

* Add tests for linear and layernorm with FP8 output

Signed-off-by: Tim Moon <[email protected]>

* RMSNorm op

Signed-off-by: Tim Moon <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix linter warnings

Signed-off-by: Tim Moon <[email protected]>

* Replace LayerNorm module with LayerNorm op

Signed-off-by: Tim Moon <[email protected]>

* Replace RMSNorm module with RMSNorm op

Signed-off-by: Tim Moon <[email protected]>

* Add AMP support

Signed-off-by: Tim Moon <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Do not save autograd context if grad mode is disabled

Debugging ONNX export tests.

Signed-off-by: Tim Moon <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Forward args in pre_forward func to base op class

Signed-off-by: Tim Moon <[email protected]>

* Update to use QuantizedTensor class

Signed-off-by: Tim Moon <[email protected]>

* Apply suggestions from code review

Co-authored-by: Przemyslaw Tredak <[email protected]>
Signed-off-by: Tim Moon <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Review suggestions from @ptrendx

Rename "CastFloat8" op to "Quantize". Add more fine-grained control for SM margin. Add docs for legacy sequence_parallel kwarg.

Signed-off-by: Tim Moon <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix linter warnings

Signed-off-by: Tim Moon <[email protected]>

* Use weight dtype as default compute dtype

Signed-off-by: Tim Moon <[email protected]>

* Fix linter warnings

Signed-off-by: Tim Moon <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Tim Moon <[email protected]>
Signed-off-by: Tim Moon <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Przemyslaw Tredak <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants