Skip to content

Commit

Permalink
Apply suggestions from code review
Browse files Browse the repository at this point in the history
Co-authored-by: Przemyslaw Tredak <[email protected]>
Signed-off-by: Tim Moon <[email protected]>
  • Loading branch information
timmoon10 and ptrendx authored Sep 17, 2024
1 parent 4206fa2 commit fd5afe5
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 6 deletions.
2 changes: 1 addition & 1 deletion transformer_engine/pytorch/module/layernorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def reset_parameters(self, defer_init: Optional[bool] = None) -> None:
# Check whether to defer init (deprecated)
if defer_init is not None:
warnings.warn(
'reset_parameters kwarg is deprecated. Set device to "meta" instead.',
'defer_init argument to reset_parameters function is deprecated. Set device to "meta" instead.',
DeprecationWarning,
stacklevel=2,
)
Expand Down
15 changes: 10 additions & 5 deletions transformer_engine/pytorch/module/rmsnorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,12 @@ class RMSNorm(_RMSNormOp):
`Root Mean Square Layer Normalization <https://arxiv.org/abs/1910.07467>`__
.. math::
y = \frac{x}{\sqrt{\mathrm{Var}[x] + \varepsilon}} * \gamma
y = \frac{x}{\text{RMS}_\varepsilon(x)} * \gamma
where
.. math::
\text{RMS}_\varepsilon(x) = \sqrt{\frac{1}{n}\sum_{i=1}^n x_i^2 + \varepsilon}
:math:`\gamma` is a learnable affine transform parameter that
matches the inner-most dimensions of the input tensor.
Expand Down Expand Up @@ -81,20 +86,20 @@ def __init__(
def reset_rms_norm_parameters(self) -> None:
"""Deprecated"""
warnings.warn(
"This method will be deprecated in an upcoming release. "
"Update your code to use LayerNorm.reset_parameters() instead.",
"This method is deprecated and will be removed in an upcoming release. "
"Update your code to use RMSNorm.reset_parameters() instead.",
DeprecationWarning,
stacklevel=2,
)
self.reset_parameters()

def reset_parameters(self, defer_init: Optional[bool] = None) -> None:
"""Init LayerNorm parameters"""
"""Init RMSNorm parameters"""

# Check whether to defer init (deprecated)
if defer_init is not None:
warnings.warn(
'reset_parameters kwarg is deprecated. Set device to "meta" instead.',
'defer_init argument to reset_parameters function is deprecated. Set device to "meta" instead.',
DeprecationWarning,
stacklevel=2,
)
Expand Down

0 comments on commit fd5afe5

Please sign in to comment.