-
Notifications
You must be signed in to change notification settings - Fork 326
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[PyTorch] Normalization ops #1033
Conversation
Signed-off-by: Tim Moon <[email protected]>
Signed-off-by: Tim Moon <[email protected]>
Signed-off-by: Tim Moon <[email protected]>
Signed-off-by: Tim Moon <[email protected]>
for more information, see https://pre-commit.ci
/te-ci pytorch |
Signed-off-by: Tim Moon <[email protected]>
/te-ci pytorch |
Signed-off-by: Tim Moon <[email protected]>
Signed-off-by: Tim Moon <[email protected]>
Signed-off-by: Tim Moon <[email protected]>
for more information, see https://pre-commit.ci
/te-ci pytorch |
/te-ci pytorch |
Signed-off-by: Tim Moon <[email protected]>
Debugging ONNX export tests. Signed-off-by: Tim Moon <[email protected]>
for more information, see https://pre-commit.ci
/te-ci pytorch |
Signed-off-by: Tim Moon <[email protected]>
from .._common import is_float8_tensor | ||
|
||
|
||
class CastFloat8(BasicOperation): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I believe you said that it is mostly an utility op for tests, right? We should probably mention that in this documentaiton.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also maybe we should consider generalizing it a bit with a name that is not specific to FP8? (Like just Quantize)?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It could be a helpful op for users as well. For example, if a user wants to have discrete layers for design reasons but still wants to fuse some operations with FP8 casts:
act = te.ops.Sequential(te.ops.GeLU(), te.ops.Quantize())
linear = te.ops.Sequential(te.ops.Linear())
y = act(x)
z = linear(y)
Co-authored-by: Przemyslaw Tredak <[email protected]> Signed-off-by: Tim Moon <[email protected]>
for more information, see https://pre-commit.ci
Rename "CastFloat8" op to "Quantize". Add more fine-grained control for SM margin. Add docs for legacy sequence_parallel kwarg. Signed-off-by: Tim Moon <[email protected]>
for more information, see https://pre-commit.ci
/te-ci pytorch |
Signed-off-by: Tim Moon <[email protected]>
Signed-off-by: Tim Moon <[email protected]>
/te-ci pytorch Edit: |
Signed-off-by: Tim Moon <[email protected]>
/te-ci pytorch |
Signed-off-by: Tim Moon <[email protected]>
/te-ci pytorch |
Signed-off-by: Tim Moon <[email protected]>
Signed-off-by: Tim Moon <[email protected]>
for more information, see https://pre-commit.ci
/te-ci pytorch |
/te-ci pytorch |
* Add layer norm op Signed-off-by: Tim Moon <[email protected]> * Add FP8 cast op Signed-off-by: Tim Moon <[email protected]> * Add tests for linear and layernorm with FP8 output Signed-off-by: Tim Moon <[email protected]> * RMSNorm op Signed-off-by: Tim Moon <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix linter warnings Signed-off-by: Tim Moon <[email protected]> * Replace LayerNorm module with LayerNorm op Signed-off-by: Tim Moon <[email protected]> * Replace RMSNorm module with RMSNorm op Signed-off-by: Tim Moon <[email protected]> * Add AMP support Signed-off-by: Tim Moon <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Do not save autograd context if grad mode is disabled Debugging ONNX export tests. Signed-off-by: Tim Moon <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Forward args in pre_forward func to base op class Signed-off-by: Tim Moon <[email protected]> * Update to use QuantizedTensor class Signed-off-by: Tim Moon <[email protected]> * Apply suggestions from code review Co-authored-by: Przemyslaw Tredak <[email protected]> Signed-off-by: Tim Moon <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Review suggestions from @ptrendx Rename "CastFloat8" op to "Quantize". Add more fine-grained control for SM margin. Add docs for legacy sequence_parallel kwarg. Signed-off-by: Tim Moon <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix linter warnings Signed-off-by: Tim Moon <[email protected]> * Use weight dtype as default compute dtype Signed-off-by: Tim Moon <[email protected]> * Fix linter warnings Signed-off-by: Tim Moon <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Tim Moon <[email protected]> Signed-off-by: Tim Moon <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Przemyslaw Tredak <[email protected]>
Description
This PR extends the operation-based API (see #707) with LayerNorm, RMSNorm, and FP8 cast operations.
Compare with the existing module-based API:
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: