[PyTorch] Normalization ops #1033

Merged on Nov 5, 2024 (39 commits)
Changes from 25 commits

Commits (39)
338e193
Add layer norm op
timmoon10 Jul 20, 2024
84bc1d7
Add FP8 cast op
timmoon10 Jul 22, 2024
0c40c54
Merge branch 'main' into norm-ops
timmoon10 Jul 22, 2024
a7f0228
Add tests for linear and layernorm with FP8 output
timmoon10 Jul 22, 2024
cb9c455
RMSNorm op
timmoon10 Jul 22, 2024
68635ad
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 22, 2024
cb9b4ec
Fix linter warnings
timmoon10 Jul 22, 2024
b33f367
Merge branch 'main' into norm-ops
timmoon10 Jul 24, 2024
d9fb6f4
Merge branch 'main' into norm-ops
timmoon10 Jul 26, 2024
00592d7
Replace LayerNorm module with LayerNorm op
timmoon10 Jul 29, 2024
e0a2fd9
Replace RMSNorm module with RMSNorm op
timmoon10 Jul 29, 2024
ad32d6a
Add AMP support
timmoon10 Jul 30, 2024
92d1f89
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 30, 2024
2197bca
Merge branch 'main' into norm-ops
timmoon10 Jul 30, 2024
c27a783
Merge branch 'main' into norm-ops
timmoon10 Jul 30, 2024
fb6b7e4
Merge branch 'main' into norm-ops
timmoon10 Aug 12, 2024
7be0524
Do not save autograd context if grad mode is disabled
timmoon10 Aug 12, 2024
e7c9c67
Merge branch 'main' into norm-ops
timmoon10 Aug 12, 2024
91e6a03
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 12, 2024
21086aa
Merge branch 'main' into norm-ops
timmoon10 Aug 12, 2024
28bc058
Forward args in pre_forward func to base op class
timmoon10 Aug 16, 2024
e6c5d5f
Merge branch 'main' into norm-ops
timmoon10 Aug 16, 2024
4fdc3b7
Merge branch 'main' into norm-ops
timmoon10 Sep 3, 2024
b1141f5
Merge branch 'main' into norm-ops
timmoon10 Sep 11, 2024
4206fa2
Update to use QuantizedTensor class
timmoon10 Sep 11, 2024
fd5afe5
Apply suggestions from code review
timmoon10 Sep 17, 2024
5b90e4b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 17, 2024
fd4ef97
Merge branch 'main' into norm-ops
timmoon10 Sep 19, 2024
102c64f
Review suggestions from @ptrendx
timmoon10 Sep 19, 2024
87ce450
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 19, 2024
fed61f9
Fix linter warnings
timmoon10 Sep 20, 2024
393ee66
Merge branch 'main' into norm-ops
timmoon10 Sep 24, 2024
556983e
Use weight dtype as default compute dtype
timmoon10 Sep 24, 2024
9b508df
Merge branch 'main' into norm-ops
timmoon10 Oct 1, 2024
fb16ee9
Merge branch 'main' into norm-ops
timmoon10 Oct 9, 2024
34e2985
Merge branch 'main' into norm-ops
timmoon10 Oct 18, 2024
7e61399
Fix linter warnings
timmoon10 Oct 18, 2024
8b3cf24
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 18, 2024
8ddd539
Merge branch 'main' into norm-ops
timmoon10 Nov 5, 2024
515 changes: 411 additions & 104 deletions tests/pytorch/test_fusible_ops.py

Large diffs are not rendered by default.
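
The updated test module is not rendered here. As a rough, hypothetical sketch only (the real tests in test_fusible_ops.py may be structured quite differently), a check of the kind this PR adds for the LayerNorm op could compare the fusible op against PyTorch's reference implementation:

# Hypothetical test sketch; not the actual contents of test_fusible_ops.py.
import torch
import transformer_engine.pytorch as te


def check_layer_norm_op() -> None:
    hidden_size = 256
    x = torch.randn(32, hidden_size, device="cuda", dtype=torch.float32)

    # Constructor arguments follow the docstring in the layernorm.py diff below.
    op = te.ops.LayerNorm(hidden_size, eps=1e-5, device="cuda", dtype=torch.float32)

    y = op(x)
    y_ref = torch.nn.functional.layer_norm(
        x, (hidden_size,), weight=op.weight, bias=op.bias, eps=1e-5
    )
    torch.testing.assert_close(y, y_ref, rtol=1e-5, atol=1e-5)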

1 change: 1 addition & 0 deletions transformer_engine/pytorch/__init__.py
@@ -82,6 +82,7 @@ def _load_library():
from transformer_engine.pytorch.distributed import checkpoint
from transformer_engine.pytorch.distributed import CudaRNGStatesTracker
from transformer_engine.pytorch.cpu_offload import get_cpu_offload_context
from transformer_engine.pytorch import ops
from transformer_engine.pytorch import optimizers

# Register custom op symbolic ONNX functions
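
A small illustration of what the added export enables, assuming nothing beyond the import shown above and the LayerNorm op referenced later in this diff: the fusible-ops API becomes reachable both as a package attribute and as a direct subpackage import.

# Illustrative only: both spellings refer to the same fusible-ops API.
import transformer_engine.pytorch as te
from transformer_engine.pytorch.ops import LayerNorm

assert te.ops.LayerNorm is LayerNorm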
269 changes: 102 additions & 167 deletions transformer_engine/pytorch/module/layernorm.py
@@ -3,156 +3,81 @@
# See LICENSE for license information.

"""LayerNorm API"""
import os
import warnings
from typing import Union, Tuple, Optional
from typing import Iterable, Optional, Union

import torch
from torch.nn.parameter import Parameter
from torch.nn import init

import transformer_engine_torch as tex
from .base import TransformerEngineBaseModule
from ..cpp_extensions import (
layernorm_fwd_inf,
)
from ..jit import no_torch_dynamo
from ..utils import cast_if_needed
from transformer_engine.pytorch.ops import LayerNorm as _LayerNormOp

__all__ = ["LayerNorm"]


class _LayerNorm(torch.autograd.Function):
"""functional LayerNorm"""

@staticmethod
def forward(
ctx,
inp: torch.Tensor,
ln_weight: torch.Tensor,
ln_bias: torch.Tensor,
eps: float,
fwd_ln_sm_margin: int,
bwd_ln_sm_margin: int,
inf_ln_sm_margin: int,
zero_centered_gamma: bool,
is_grad_enabled: bool,
activation_dtype: torch.dtype,
) -> torch.Tensor:
# Make sure input dimensions are compatible
in_features = ln_weight.numel()
assert inp.is_cuda, "TransformerEngine needs CUDA."
assert inp.shape[-1] == in_features, "LayerNorm not possible"
inputmat = inp.view((-1, in_features))

# Cast for native AMP
inputmat = cast_if_needed(inputmat, activation_dtype)
ln_weight = cast_if_needed(ln_weight, activation_dtype)
ln_bias = cast_if_needed(ln_bias, activation_dtype)

if is_grad_enabled:
ln_out, mu, rsigma = tex.layernorm_fwd(
inputmat, ln_weight, ln_bias, eps, fwd_ln_sm_margin, zero_centered_gamma
)
ctx.save_for_backward(inputmat, ln_weight, mu, rsigma)
ctx.inp_shape = inp.shape
ctx.bwd_ln_sm_margin = bwd_ln_sm_margin
ctx.zero_centered_gamma = zero_centered_gamma
else:
ln_out, mu, rsigma = (
layernorm_fwd_inf(
inputmat, ln_weight, ln_bias, eps, inf_ln_sm_margin, zero_centered_gamma
),
None,
None,
)
return ln_out.view_as(inp)

@staticmethod
def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
inputmat, ln_weight, mu, rsigma = ctx.saved_tensors
grad_output = grad_output.contiguous()
d_ln_out = grad_output.view(inputmat.shape)
dxmat, dgamma, dbeta = tex.layernorm_bwd(
d_ln_out, inputmat, mu, rsigma, ln_weight, ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma
)
return dxmat.view(ctx.inp_shape), dgamma, dbeta, None, None, None, None, None, None, None

class LayerNorm(_LayerNormOp):
r"""Layer Normalization

class LayerNorm(torch.nn.Module):
r"""
Applies Layer Normalization over a mini-batch of inputs as described in
the paper `Layer Normalization <https://arxiv.org/abs/1607.06450>`__

.. math::
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * \gamma + \beta
y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \varepsilon}} * \gamma + \beta

:math:`\gamma` and :math:`\beta` are learnable affine transform parameters of
size :attr:`hidden_size`
:math:`\gamma` and :math:`\beta` are learnable affine transform
parameters that match the inner-most dimensions of the input
tensor.

Parameters
----------
hidden_size : int
size of each input sample.
normalized_shape: int or iterable of int
Inner dimensions of input tensor
eps : float, default = 1e-5
a value added to the denominator of layer normalization for numerical stability.
sequence_parallel : bool, default = `False`
if set to `True`, uses sequence parallelism.
params_dtype : torch.dtype, default = `torch.get_default_dtype()`
it controls the type used to allocate the initial parameters. Useful when
the model is trained with lower precision and the original FP32 parameters
would not fit in GPU memory.
A value added to the denominator of layer normalization for
numerical stability
device: torch.device, default = default CUDA device
Tensor device
dtype: torch.dtype, default = default dtype
Tensor datatype
zero_centered_gamma : bool, default = 'False'
if set to 'True', gamma parameter in LayerNorm is initialized to 0 and
the LayerNorm formula changes to

.. math::
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} *
(1 + \gamma) + \beta
device : Union[torch.device, str], default = "cuda"
The device on which the parameters of the model will allocated. It is the user's
responsibility to ensure all parameters are moved to the GPU before running the
forward pass.
If `True`, the :math:`\gamma` parameter is initialized to zero
and the calculation changes to

.. math::
y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \varepsilon}} * (1 + \gamma) + \beta

sm_margin: int, default = 0
Number of SMs to exclude when launching CUDA kernels. This
helps overlap with other kernels, e.g. communication kernels.

"""

def __init__(
self,
hidden_size: int,
normalized_shape: Union[Iterable[int], int],
eps: float = 1e-5,
sequence_parallel: bool = False,
params_dtype: Optional[torch.dtype] = None,
sequence_parallel: Optional[bool] = None, # deprecated
params_dtype: Optional[torch.dtype] = None, # deprecated
zero_centered_gamma: bool = False,
device: Union[torch.device, str] = "cuda",
**kwargs,
) -> None:
super().__init__()
params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
self.eps = eps
self.zero_centered_gamma = zero_centered_gamma
self.weight = Parameter(
torch.empty(
hidden_size,
device=device,
dtype=params_dtype,
)
)
self.bias = Parameter(
torch.empty(
hidden_size,
device=device,
dtype=params_dtype,
)
)
self.sequence_parallel = sequence_parallel

self.reset_parameters(defer_init=(device == "meta"))
# Handle deprecated options

Review comment (Member):

This seems a little backwards. If dtype is supposed to be the new argument name, then why is it in kwargs? Both params_dtype and dtype should be regular parameters; there should be a deprecation warning when somebody uses params_dtype, as well as the check for duplicate assignment like the one you have here.
Also, hidden_size and sequence_parallel should get similar treatment (especially the latter, which seems to be gone completely, so there should be some explanation that it was unused before, or similar).

Reply from @timmoon10 (Collaborator, Author), Sep 19, 2024:

My thinking is that we should forward kwargs directly to te.ops.LayerNorm as much as possible so that we only have to change the API in one place if we ever make changes in the future. We include the deprecated options as explicit kwargs since they are specific to the module.

This function signature also maintains backward compatibility for users who pass in the options as positional args, e.g.:

te.LayerNorm(
    inp_shape[1], eps, params_dtype=dtype, zero_centered_gamma=zero_centered_gamma
)
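
For comparison, a minimal sketch of the alternative the reviewer describes, with both dtype and the deprecated params_dtype as explicit parameters and an explicit deprecation warning, might look as follows. This is a hypothetical illustration of the review discussion, not the code adopted in this PR:

# Hypothetical sketch of the reviewer's suggestion; not the merged implementation.
import warnings
from typing import Iterable, Optional, Union

import torch
from transformer_engine.pytorch.ops import LayerNorm as _LayerNormOp


class LayerNorm(_LayerNormOp):
    def __init__(
        self,
        normalized_shape: Union[Iterable[int], int],
        eps: float = 1e-5,
        zero_centered_gamma: bool = False,
        dtype: Optional[torch.dtype] = None,
        params_dtype: Optional[torch.dtype] = None,  # deprecated alias for dtype
        **kwargs,
    ) -> None:
        if params_dtype is not None:
            warnings.warn(
                "params_dtype is deprecated, use dtype instead",
                DeprecationWarning,
                stacklevel=2,
            )
            if dtype is not None:
                raise RuntimeError("Both dtype and params_dtype (deprecated) are provided")
            dtype = params_dtype
        # Only forward dtype when it was given, so the op keeps its own default otherwise.
        if dtype is not None:
            kwargs["dtype"] = dtype
        super().__init__(
            normalized_shape,
            eps=eps,
            zero_centered_gamma=zero_centered_gamma,
            **kwargs,
        )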

if params_dtype is not None:
if "dtype" in kwargs:
raise RuntimeError(
"Both `dtype` and `params_dtype` (deprecated) kwargs are provided"
)
kwargs["dtype"] = params_dtype

# Initialize layer norm operation
super().__init__(
normalized_shape,
eps=eps,
zero_centered_gamma=zero_centered_gamma,
**kwargs,
)

# These many SMs are subtracted from the total SM count when calling forward
# and backward LayerNorm C APIs. These envvars can be used to prevent the LN
# kernels from using all SMs in the device. This is useful for cases such as
# communication overlap with LN.
self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))
self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))
self.inf_ln_sm_margin = int(os.getenv("NVTE_INF_LAYERNORM_SM_MARGIN", "0"))
# Flag for sequence parallelism (deprecated)
self.sequence_parallel: Optional[bool] = sequence_parallel

def reset_layer_norm_parameters(self) -> None:
"""Init LN params"""
@@ -162,51 +87,61 @@ def reset_layer_norm_parameters(self) -> None:
DeprecationWarning,
stacklevel=2,
)
if not self.zero_centered_gamma:
init.ones_(self.weight)
else:
init.zeros_(self.weight)
init.zeros_(self.bias)
self.reset_parameters()

def reset_parameters(self, defer_init=False) -> None:
def reset_parameters(self, defer_init: Optional[bool] = None) -> None:
"""Init LayerNorm parameters"""
if defer_init:
return

if self.weight.device == torch.device("meta"):
self.weight = torch.nn.Parameter(torch.empty_like(self.weight, device="cuda"))
setattr(self.weight, "sequence_parallel", self.sequence_parallel)
init.constant_(self.weight, float(not self.zero_centered_gamma))

if self.bias.device == torch.device("meta"):
self.bias = torch.nn.Parameter(torch.empty_like(self.bias, device="cuda"))
setattr(self.bias, "sequence_parallel", self.sequence_parallel)
init.zeros_(self.bias)

@no_torch_dynamo()
def forward(self, inp: torch.Tensor) -> torch.Tensor:
"""LayerNorm FWD"""
# Set the activation type for AMP.
TransformerEngineBaseModule.set_activation_dtype(self, inp)

if torch.is_grad_enabled():
fwd_fn = _LayerNorm.apply
args = []
else:
fwd_fn = _LayerNorm.forward
args = [None]

args += (
inp,
self.weight,
self.bias,
self.eps,
self.fwd_ln_sm_margin,
self.bwd_ln_sm_margin,
self.inf_ln_sm_margin,
self.zero_centered_gamma,
torch.is_grad_enabled(),
self.activation_dtype,
)

return fwd_fn(*args)
# Check whether to defer init (deprecated)
if defer_init is not None:
warnings.warn(
'reset_parameters kwarg is deprecated. Set device to "meta" instead.',
DeprecationWarning,
stacklevel=2,
)
if defer_init:
return

# Reset parameters
super().reset_parameters()

# Set flag for sequence parallelism (deprecated)
if getattr(self, "sequence_parallel", None) is not None:
self.weight.sequence_parallel = self.sequence_parallel
self.bias.sequence_parallel = self.sequence_parallel

@property
def fwd_ln_sm_margin(self) -> int:
"""Shim for backward compatibility"""
warnings.warn("fwd_ln_sm_margin attr is deprecated", DeprecationWarning, stacklevel=2)
return self._sm_margins["fwd"]

@fwd_ln_sm_margin.setter
def fwd_ln_sm_margin(self, val: int) -> None:
"""Shim for backward compatibility"""
warnings.warn("fwd_ln_sm_margin attr is deprecated", DeprecationWarning, stacklevel=2)
self._sm_margins["fwd"] = val

@property
def bwd_ln_sm_margin(self) -> int:
"""Shim for backward compatibility"""
warnings.warn("bwd_ln_sm_margin attr is deprecated", DeprecationWarning, stacklevel=2)
return self._sm_margins["bwd"]

@bwd_ln_sm_margin.setter
def bwd_ln_sm_margin(self, val: int) -> None:
"""Shim for backward compatibility"""
warnings.warn("bwd_ln_sm_margin attr is deprecated", DeprecationWarning, stacklevel=2)
self._sm_margins["bwd"] = val

@property
def inf_ln_sm_margin(self) -> int:
"""Shim for backward compatibility"""
warnings.warn("inf_ln_sm_margin attr is deprecated", DeprecationWarning, stacklevel=2)
return self._sm_margins["inf"]

@inf_ln_sm_margin.setter
def inf_ln_sm_margin(self, val: int) -> None:
"""Shim for backward compatibility"""
warnings.warn("inf_ln_sm_margin attr is deprecated", DeprecationWarning, stacklevel=2)
self._sm_margins["inf"] = val
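
Finally, a brief usage sketch of the refactored module, illustrative only and based on the interface documented in the diff above: the module is now a thin backward-compatible wrapper over the te.ops.LayerNorm operation, so both the new dtype kwarg and the deprecated params_dtype spelling construct the same op.

# Illustrative usage of the refactored te.LayerNorm wrapper (names follow the docstring above).
import torch
import transformer_engine.pytorch as te

hidden_size = 1024
x = torch.randn(8, 128, hidden_size, device="cuda", dtype=torch.bfloat16)

# New-style construction: normalized_shape plus the documented kwargs.
ln = te.LayerNorm(hidden_size, eps=1e-5, zero_centered_gamma=True, dtype=torch.bfloat16)
y = ln(x)

# Deprecated spelling still works; params_dtype is forwarded to the op as dtype.
ln_legacy = te.LayerNorm(hidden_size, params_dtype=torch.bfloat16)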