[PyTorch] Normalization ops #1033
Conversation
self.reset_parameters(defer_init=(device == "meta"))
# Handle deprecated options
This seems a little backwards. If `dtype` is supposed to be the new argument name, then why is it in `kwargs`? Both `params_dtype` and `dtype` should be regular parameters; there should be a deprecation warning when somebody uses `params_dtype`, plus the check for duplicate assignment like the one you have here.
Also, similar treatment should be done for `hidden_size` and `sequence_parallel` (especially the last one, which seems to be gone completely, so there should be some explanation that it was unused before or something).
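For concreteness, a minimal sketch of the signature this comment is asking for (class name, defaults, and warning text are assumptions, not the PR's actual code):
import warnings
import torch
class LayerNorm(torch.nn.Module):
    # Sketch only: `dtype` is the new name, `params_dtype` survives as a
    # deprecated alias, and passing both is rejected explicitly.
    def __init__(self, hidden_size, eps=1e-5, dtype=None, params_dtype=None):
        super().__init__()
        if params_dtype is not None:
            if dtype is not None:
                raise RuntimeError("Cannot pass both `dtype` and `params_dtype` (deprecated)")
            warnings.warn("`params_dtype` is deprecated, use `dtype` instead", DeprecationWarning)
            dtype = params_dtype
        self.eps = eps
        self.dtype = dtype if dtype is not None else torch.get_default_dtype()
        self.weight = torch.nn.Parameter(torch.ones(hidden_size, dtype=self.dtype))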
My thinking is that we should forward kwargs directly to `te.ops.LayerNorm` as much as possible, so that we only have to change the API in one place if we ever make changes in the future. We include the deprecated options as explicit kwargs since they are specific to the module.
This function signature also maintains backward compatibility for users who pass in the options as positional args, e.g.:
TransformerEngine/tests/pytorch/test_onnx_export.py, lines 676 to 678 in 0ee5ccd:
te.LayerNorm(
    inp_shape[1], eps, params_dtype=dtype, zero_centered_gamma=zero_centered_gamma
)
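A rough sketch of the forwarding approach described above (only `params_dtype` is handled explicitly here; the other argument names are assumptions about the `te.ops.LayerNorm` signature, not the PR's actual code):
import transformer_engine.pytorch as te
def make_layernorm(hidden_size, eps=1e-5, params_dtype=None, **kwargs):
    # Translate the deprecated, module-specific option locally...
    if params_dtype is not None:
        kwargs.setdefault("dtype", params_dtype)
    # ...and forward everything else, so the op's API only needs to change in one place.
    return te.ops.LayerNorm(hidden_size, eps=eps, **kwargs)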
super().reset_parameters()

# Set flag for sequence parallelism (deprecated)
if getattr(self, "sequence_parallel", None) is not None:
So why is the `sequence_parallel` option deprecated then? I believe Megatron uses it to guide its logic for the optimizer. I know it is not great, but we should not break them.
Maybe "legacy" is better. We should treat this as a weird, Megatron-specific integration.
self.reset_parameters(defer_init=(device == "meta"))
# Handle deprecated options
Same issue as in LN.
from .._common import is_float8_tensor

class CastFloat8(BasicOperation):
I believe you said that it is mostly a utility op for tests, right? We should probably mention that in this documentation.
Also, maybe we should consider generalizing it a bit with a name that is not specific to FP8 (like just `Quantize`)?
It could be a helpful op for users as well. For example, if a user wants to have discrete layers for design reasons but still wants to fuse some operations with FP8 casts:
act = te.ops.Sequential(te.ops.GeLU(), te.ops.Quantize())
linear = te.ops.Sequential(te.ops.Linear())
y = act(x)
z = linear(y)
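For context, a rough sketch of how such a split might be run end to end (the sizes, device, and `fp8_autocast` usage here are assumptions, not part of the PR):
import torch
import transformer_engine.pytorch as te
act = te.ops.Sequential(te.ops.GeLU(), te.ops.Quantize())
linear = te.ops.Sequential(te.ops.Linear(1024, 1024))
x = torch.randn(32, 1024, device="cuda")
with te.fp8_autocast(enabled=True):
    y = act(x)     # GeLU output is cast to FP8 once, up front
    z = linear(y)  # the linear op consumes the FP8 tensor directly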
Rename "CastFloat8" op to "Quantize". Add more fine-grained control for SM margin. Add docs for legacy sequence_parallel kwarg. Signed-off-by: Tim Moon <[email protected]>
Description
This PR extends the operation-based API (see #707) with LayerNorm, RMSNorm, and FP8 cast operations.
Compare with the existing module-based API:
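For illustration (this comparison is reconstructed rather than copied from the PR; the layer size is an assumption):
import torch
import transformer_engine.pytorch as te
x = torch.randn(32, 1024, device="cuda")
# Existing module-based API
ln_module = te.LayerNorm(1024, eps=1e-5)
y_module = ln_module(x)
# Operation-based API extended by this PR, composable with other ops
ln_ops = te.ops.Sequential(te.ops.LayerNorm(1024, eps=1e-5))
y_ops = ln_ops(x)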