
[PyTorch] Normalization ops #1033

Open · wants to merge 31 commits into main

Conversation

timmoon10 (Collaborator)

Description

This PR extends the operation-based API (see #707) with LayerNorm, RMSNorm, and FP8 cast operations.

Compare with the existing module-based API:

# Module-based API
module1 = te.LayerNormLinear(...)

# Operation-based API
module2 = te.ops.Sequential(
    te.ops.LayerNorm(...),
    te.ops.Linear(...),
)
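
For a more concrete sketch, here are the two styles side by side with made-up sizes; the exact constructor arguments of the te.ops classes are assumptions for illustration, not taken verbatim from this PR:

import torch
import transformer_engine.pytorch as te

hidden_size = 1024  # illustrative size, not from this PR

# Module-based API: LayerNorm and Linear fused in one module
module1 = te.LayerNormLinear(hidden_size, hidden_size)

# Operation-based API: the same computation composed from basic ops
module2 = te.ops.Sequential(
    te.ops.LayerNorm(hidden_size),
    te.ops.Linear(hidden_size, hidden_size),
)

x = torch.randn(16, hidden_size, device="cuda")
y1 = module1(x)
y2 = module2(x)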

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactor

Changes

Please list the changes introduced in this PR:

  • LayerNorm operation
  • FP8 cast operation
  • RMSNorm operation

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@timmoon10 added the "enhancement" (New feature or request) label on Jul 22, 2024
@timmoon10 (Collaborator Author)
/te-ci pytorch

Signed-off-by: Tim Moon <[email protected]>

@timmoon10 (Collaborator Author)
/te-ci pytorch

@timmoon10 (Collaborator Author)
/te-ci pytorch

@timmoon10 (Collaborator Author)
/te-ci pytorch

@timmoon10 (Collaborator Author)
/te-ci pytorch

@timmoon10 (Collaborator Author)
/te-ci pytorch

@timmoon10 (Collaborator Author)
/te-ci pytorch

self.reset_parameters(defer_init=(device == "meta"))
# Handle deprecated options

Member

This seems a little backwards. If dtype is supposed to be the new argument name, then why is it in kwargs? Both params_dtype and dtype should be regular parameters; there should be a deprecation warning when somebody uses params_dtype, in addition to the check for duplicate assignment like the one you have here.
Also, hidden_size and sequence_parallel should get similar treatment (especially the last one, which seems to be gone completely, so there should be some explanation that it was unused before or something similar).
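
For reference, a minimal sketch of the pattern being suggested here; the class and argument names are illustrative only, not the actual code in this PR:

import warnings

class _LayerNormSketch:
    # Hypothetical constructor showing the suggested handling: both names are
    # explicit parameters, the deprecated one emits a warning, and passing both
    # is rejected as a duplicate assignment.
    def __init__(self, hidden_size, eps=1e-5, dtype=None, params_dtype=None):
        if params_dtype is not None:
            if dtype is not None:
                raise RuntimeError(
                    "Cannot pass both `dtype` and the deprecated `params_dtype`"
                )
            warnings.warn(
                "`params_dtype` is deprecated, use `dtype` instead",
                DeprecationWarning,
            )
            dtype = params_dtype
        self.hidden_size = hidden_size
        self.eps = eps
        self.dtype = dtype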

@timmoon10 (Collaborator Author), Sep 19, 2024

My thinking is that we should forward kwargs directly to te.ops.LayerNorm as much as possible so that we only have to change the API in one place if we ever make changes in the future. We include the deprecated options as explicit kwargs since they are specific to the module.

This function signature also maintains backward compatibility for users who pass in the options as positional args, e.g.:

te.LayerNorm(
    inp_shape[1], eps, params_dtype=dtype, zero_centered_gamma=zero_centered_gamma
)

super().reset_parameters()

# Set flag for sequence parallelism (deprecated)
if getattr(self, "sequence_parallel", None) is not None:

Member

So why is the sequence_parallel option deprecated then? I believe Megatron uses it to guide their logic for the optimizer. I know it is not great, but we should not break them.

timmoon10 (Collaborator Author)

Maybe "legacy" is better. We should treat this as a weird, Megatron-specific integration.


self.reset_parameters(defer_init=(device == "meta"))
# Handle deprecated options

Member

Same issue as in LN.

from .._common import is_float8_tensor


class CastFloat8(BasicOperation):

Member

I believe you said that it is mostly a utility op for tests, right? We should probably mention that in the documentation.

Member

Also, maybe we should consider generalizing it a bit with a name that is not specific to FP8 (like just Quantize)?

timmoon10 (Collaborator Author)

It could be a helpful op for users as well. For example, if a user wants to have discrete layers for design reasons but still wants to fuse some operations with FP8 casts:

act = te.ops.Sequential(te.ops.GeLU(), te.ops.Quantize())
linear = te.ops.Sequential(te.ops.Linear())
y = act(x)
z = linear(y)
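
As a usage sketch (the sizes and the surrounding autocast are assumptions for illustration, reusing the op names from the comment above), the point is that the FP8 cast stays adjacent to the activation, so the two can be fused even though the user keeps the layers separate:

import torch
import transformer_engine.pytorch as te

act = te.ops.Sequential(te.ops.GeLU(), te.ops.Quantize())
linear = te.ops.Sequential(te.ops.Linear(1024, 1024))

x = torch.randn(16, 1024, device="cuda")
with te.fp8_autocast(enabled=True):
    y = act(x)      # GeLU output cast to FP8, a candidate for fusion
    z = linear(y)   # consumes the already-quantized tensor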

timmoon10 and others added 5 commits September 17, 2024 16:58
Co-authored-by: Przemyslaw Tredak <[email protected]>
Signed-off-by: Tim Moon <[email protected]>
Rename "CastFloat8" op to "Quantize". Add more fine-grained control for SM margin. Add docs for legacy sequence_parallel kwarg.

Signed-off-by: Tim Moon <[email protected]>
@timmoon10 (Collaborator Author)

/te-ci pytorch

Signed-off-by: Tim Moon <[email protected]>