From 2d57db8bcc5cf5562e726e978c875877c478a139 Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Wed, 11 Sep 2024 06:12:03 -0700 Subject: [PATCH] [PyTorch] Proxy class for low-precision tensor (#1127) * Add base class for tensor proxies Signed-off-by: Tim Moon * Move tensor detaching logic to tensor proxy base class Signed-off-by: Tim Moon * Use Python wrappers to PyTorch extensions Signed-off-by: Tim Moon * Include transpose caching logic in proxy encode function Signed-off-by: Tim Moon * Debug dimension mismatch with amax history Signed-off-by: Tim Moon * Move dequantize logic to proxy_decode func Signed-off-by: Tim Moon * Rename to "QuantizedTensor" Signed-off-by: Tim Moon * Rename "proxy_detach" to "detach" Signed-off-by: Tim Moon * Include transpose cache in detach and clone funcs Signed-off-by: Tim Moon * Fix linter warnings Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update FP8 workspaces with QuantizedTensor functions Signed-off-by: Tim Moon * Move logic for FP8 transpose cache in FP8 workspaces to base class Signed-off-by: Tim Moon * Remove cast-transpose logic from linear op Signed-off-by: Tim Moon * Remove unnecessary args for Float8Tensor when using FP8 attr dict Signed-off-by: Tim Moon * Remove __torch_function__ to QuantizedTensor Signed-off-by: Tim Moon * Fix linter warnings Signed-off-by: Tim Moon * Update tests/pytorch/test_float8tensor.py Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> * Debug FP8 transpose test Signed-off-by: Tim Moon * Debug cast functions Signed-off-by: Tim Moon --------- Signed-off-by: Tim Moon Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Kirthi Shankar Sivamani --- .github/workflows/lint.yml | 2 +- tests/pytorch/test_float8tensor.py | 9 +- tests/pytorch/test_fusible_ops.py | 5 +- .../pytorch/cpp_extensions/_common.py | 6 +- .../pytorch/cpp_extensions/cast.py | 5 +- transformer_engine/pytorch/float8_tensor.py | 1001 +---------------- transformer_engine/pytorch/module/base.py | 96 +- .../pytorch/module/grouped_linear.py | 21 +- .../pytorch/module/layernorm_linear.py | 34 +- .../pytorch/module/layernorm_mlp.py | 19 +- transformer_engine/pytorch/module/linear.py | 30 +- transformer_engine/pytorch/ops/_common.py | 62 +- .../pytorch/ops/basic/all_reduce.py | 11 +- .../pytorch/ops/basic/basic_linear.py | 74 +- .../pytorch/ops/basic/reduce_scatter.py | 17 +- transformer_engine/pytorch/tensor/__init__.py | 8 + .../pytorch/tensor/float8_tensor.py | 972 ++++++++++++++++ .../pytorch/tensor/quantized_tensor.py | 172 +++ transformer_engine/pytorch/utils.py | 50 + 19 files changed, 1352 insertions(+), 1242 deletions(-) create mode 100644 transformer_engine/pytorch/tensor/__init__.py create mode 100644 transformer_engine/pytorch/tensor/float8_tensor.py create mode 100644 transformer_engine/pytorch/tensor/quantized_tensor.py diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index a4fdbdfdfd..d2bd865a8f 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -31,7 +31,7 @@ jobs: run: | sudo apt-get update sudo apt-get install pip -y - pip install torch + pip install torch numpy export PYTHON_ONLY=1 export TE_PATH=. bash ./qa/L0_pytorch_lint/test.sh diff --git a/tests/pytorch/test_float8tensor.py b/tests/pytorch/test_float8tensor.py index 0ea0319771..fd204f58c4 100644 --- a/tests/pytorch/test_float8tensor.py +++ b/tests/pytorch/test_float8tensor.py @@ -293,7 +293,7 @@ def test_transpose( with pytest.raises(AssertionError): torch.testing.assert_close(x_fp8_t, x, **tols) - # Caching test. + # Caching test assert x_fp8._transpose_invalid, "Transpose cache must be invalid when not caching." x_fp8 += 0.5 x = x_fp8.from_float8() @@ -302,14 +302,13 @@ def test_transpose( torch.testing.assert_close(x_fp8_t, x_t, **tols) assert not x_fp8._transpose_invalid, "Transpose cache reset incorrectly." - # Inplace update test. + # Inplace update test x_fp8 += 0.5 - assert x_fp8._transpose_invalid, "Transpose cache not invalidated properly." + assert not x_fp8._transpose_invalid, "Transpose cache reset incorrectly." x = x_fp8.from_float8() - x_fp8_t = Float8Tensor.make_like(x_fp8, data=x_fp8.transpose_2d(fill_cache=True)) + x_fp8_t = Float8Tensor.make_like(x_fp8, data=x_fp8._transpose) x_t = x.transpose(0, 1) torch.testing.assert_close(x_fp8_t, x_t, **tols) - assert not x_fp8._transpose_invalid, "Transpose cache reset incorrectly." def test_serialization( self, diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 3523e1cda5..e97dfe1efd 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -88,10 +88,7 @@ def make_reference_and_test_tensors( ref = torch.rand(shape, dtype=ref_dtype, device=ref_device) test = ref.to(device=test_device, dtype=test_dtype) if test_is_fp8: - test = Float8Tensor.to_float8(test) - test._transpose = test._data.reshape(-1, test.size(-1)).transpose(0, 1) - test._transpose = test._transpose.contiguous() - test._transpose_invalid = False + test = Float8Tensor.to_float8(test, with_transpose_cache=True) elif test.data_ptr() == ref.data_ptr(): test = test.clone() ref.copy_(test) diff --git a/transformer_engine/pytorch/cpp_extensions/_common.py b/transformer_engine/pytorch/cpp_extensions/_common.py index 6ab7d95138..b9d7288dfa 100644 --- a/transformer_engine/pytorch/cpp_extensions/_common.py +++ b/transformer_engine/pytorch/cpp_extensions/_common.py @@ -68,13 +68,13 @@ def canonicalize_fp8_scales( # Force offsets to be the same if needed if not allow_multiple_offsets and not scale_offset == amax_offset == scale_inv_offset: if scale_offset != 0: - scale = scale[scale_offset] + scale = scale[scale_offset:] scale_offset = 0 if amax_offset != 0: - amax = amax[0][amax_offset] + amax = amax[:, amax_offset:] amax_offset = 0 if scale_inv_offset != 0: - scale_inv = scale_inv[scale_inv_offset] + scale_inv = scale_inv[scale_inv_offset:] scale_inv_offset = 0 # Pack tensors and offsets into dicts diff --git a/transformer_engine/pytorch/cpp_extensions/cast.py b/transformer_engine/pytorch/cpp_extensions/cast.py index 0c78a65a6c..cd3c01c785 100644 --- a/transformer_engine/pytorch/cpp_extensions/cast.py +++ b/transformer_engine/pytorch/cpp_extensions/cast.py @@ -8,7 +8,7 @@ import torch import transformer_engine_torch as tex -from ._common import canonicalize_fp8_scales, empty_tensor +from ._common import canonicalize_fp8_scales __all__ = ["cast_to_fp8", "cast_from_fp8"] @@ -81,8 +81,7 @@ def cast_from_fp8( # Construct empty tensors if needed if scale_inv is None: - scale_inv = empty_tensor() - scale_inv_offset = 0 + raise ValueError("Did not provide either `scale_inv` or `fp8_meta_tensor`") # Launch FP8 cast kernel return torch.ops.tex_ts.cast_from_fp8_ts( diff --git a/transformer_engine/pytorch/float8_tensor.py b/transformer_engine/pytorch/float8_tensor.py index d531979868..c3d8709925 100644 --- a/transformer_engine/pytorch/float8_tensor.py +++ b/transformer_engine/pytorch/float8_tensor.py @@ -3,1004 +3,7 @@ # See LICENSE for license information. """Tensor class with FP8 data""" -from __future__ import annotations -from typing import Any, Dict, Optional, Tuple, Union -import warnings -import torch -from torch.utils._pytree import tree_map -import transformer_engine_torch as tex +from .tensor import Float8Tensor -from .constants import TE_DType -from .cpp_extensions import fp8_cast_transpose_fused -from .fp8 import FP8GlobalStateManager - -aten = torch.ops.aten -c10d = torch.ops.c10d -updated_fp8_params = {} - - -def _make_fp8_attr_property_funcs(name: str) -> Any: - """Make accessors for an FP8 attribute - - We store FP8 attributes in a dictionary so we can share them - between tensors with the same data, e.g. detached tensors. For - convenience, we also expose them as property attributes. This - function creates the accessors for property attributes. - - Parameters - ---------- - name: str - Key in dictionary of FP8 attributes - - """ - - def get_func(self) -> Any: - return self._fp8_attrs[name] - - def set_func(self, value: Any) -> None: - self._fp8_attrs[name] = value - - def del_func(self) -> None: - del self._fp8_attrs[name] - - return dict(fget=get_func, fset=set_func, fdel=del_func) - - -class _FromFloat8Func(torch.autograd.Function): - """Cast from FP8 to other dtype""" - - @staticmethod - def forward( - _ctx: torch.autograd.function.FunctionCtx, # unused - tensor: Float8Tensor, - dtype: Optional[torch.dtype] = None, - ) -> torch.Tensor: - if dtype is None: - dtype = tensor.dtype - data = tensor._data.contiguous().view(1, -1).detach() - out = tex.cast_from_fp8( - data, - tensor._scale_inv, - tensor._fp8_dtype, - TE_DType[dtype], - ) - out = out.view(tensor.size()) - return out - - @staticmethod - def backward( - _ctx: torch.autograd.function.FunctionCtx, # unused - grad: torch.Tensor, - ) -> Tuple[Optional[torch.Tensor], ...]: - # Assume that we want gradients in full precision - return grad, None - - -def post_optimizer_step_fwd_amax_reduction(param: Float8Tensor) -> None: - """Amax scale and update when there is at least 1 trainable FP8 parameter.""" - param_id = id(param._data) - - if param_id not in FP8GlobalStateManager.fp8_param_to_autocast: - return - - autocast_key = FP8GlobalStateManager.fp8_param_to_autocast[param_id] - - if autocast_key not in FP8GlobalStateManager.autocast_to_fp8_params: - return - - if autocast_key in updated_fp8_params: - updated_fp8_params[autocast_key].add(param_id) - else: - updated_fp8_params[autocast_key] = {param_id} - - current_fp8_params_set = FP8GlobalStateManager.autocast_to_fp8_params[autocast_key] - # All FP8 trainable parameters have been updated. - if updated_fp8_params[autocast_key] == current_fp8_params_set: - FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=True, fp8_weights=True) - del updated_fp8_params[autocast_key] - - -class _ToFloat8Func(torch.autograd.Function): - """Cast to FP8 from other dtype""" - - @staticmethod - def forward( - _ctx: torch.autograd.function.FunctionCtx, # unused - tensor: torch.Tensor, - fp8_meta: Optional[Dict[str, Any]] = None, - fp8_meta_forward: bool = True, - fp8_meta_index: Optional[int] = None, - fp8_dtype: tex.DType = tex.DType.kFloat8E4M3, - scale: Optional[torch.Tensor] = None, - amax: Optional[torch.Tensor] = None, - scale_inv: Optional[torch.Tensor] = None, - ) -> Float8Tensor: - - # Extract data from FP8 meta tensors if provided - if fp8_meta is not None: - fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( - forward=fp8_meta_forward, - ) - if fp8_meta_index is None: - raise ValueError( - "To initialize Float8Tensor with FP8 meta tensors, " - "the FP8 meta tensor index must also be provided" - ) - if scale is None: - scale = fp8_meta[fp8_meta_key].scale[fp8_meta_index] - if amax is None: - amax = fp8_meta[fp8_meta_key].amax_history[0][fp8_meta_index] - - # Check input tensor - tensor = tensor.contiguous().cuda().detach() - if tensor.dtype not in (torch.float32, torch.bfloat16, torch.float16): - tensor = tensor.float() - - # Check scale - if not isinstance(scale, torch.Tensor): - if scale is None: - scale = 1 - scale = torch.full( - [1], - scale, - dtype=torch.float32, - device=tensor.device, - ) - if scale.numel() != 1: - raise ValueError("Attempted to initialize Float8Tensor with invalid scale tensor") - scale = scale.to(device=tensor.device, dtype=torch.float32) - - # Check scale-inverse - if scale_inv is None: - scale_inv = torch.empty([1], dtype=torch.float32, device=tensor.device) - else: - scale_inv = scale_inv.to(device=tensor.device, dtype=torch.float32) - - # Check amax - if amax is None: - amax = torch.empty_like(scale) - if not (amax.numel() == 1 and amax.is_cuda and amax.dtype == torch.float32): - raise ValueError("Attempted to initialize Float8Tensor with invalid amax tensor") - - # Cast data to FP8 - data = tex.cast_to_fp8( - tensor.view(1, -1), - scale, - amax, - scale_inv, - fp8_dtype, - ) - data = data.view(tensor.size()) - - # Construct FP8 tensor - return Float8Tensor( - data=data, - fp8_meta=fp8_meta, - fp8_meta_forward=fp8_meta_forward, - fp8_meta_index=fp8_meta_index, - fp8_dtype=fp8_dtype, - fp8_scale_inv=scale_inv, - dtype=tensor.dtype, - ) - - @staticmethod - def backward( - _ctx: torch.autograd.function.FunctionCtx, # unused - grad: torch.Tensor, - ) -> Tuple[Optional[torch.Tensor], ...]: - # Assume that we want gradients in full precision - return grad, None, None, None, None, None, None, None - - -class _IdentityFunc(torch.autograd.Function): - """Identity function - - If constructor keyword-arguments are provided, then construct a - new Float8Tensor using the provided tensor's attributes. - - """ - - @staticmethod - def forward( - ctx, - tensor: Float8Tensor, - init_kwargs: Optional[Dict[str, Any]] = None, - ) -> torch.Tensor: - - # Return input tensor if constructor kwargs are not provided - ctx.input_dtype = tensor.dtype - if init_kwargs is None: - return tensor - - # Construct new tensor if constructor kwargs are provided - default_kwargs = dict( - data=tensor._data, - fp8_meta=tensor._fp8_meta, - fp8_meta_forward=tensor._fp8_meta_forward, - fp8_meta_index=tensor._fp8_meta_index, - fp8_dtype=tensor._fp8_dtype, - fp8_scale_inv=tensor._scale_inv, - dtype=tensor.dtype, - ) - for key, val in default_kwargs.items(): - if key not in init_kwargs: - init_kwargs[key] = val - return Float8Tensor(**init_kwargs) - - @staticmethod - def backward(ctx, grad): - return grad.to(ctx.input_dtype), None - - -class _ViewFunc(torch.autograd.Function): - """View function - - View the Float8Tensor using the provided shape. - - """ - - @staticmethod - def forward( - ctx, - tensor: torch.Tensor, - shape: Tuple[int] = None, - ) -> torch.Tensor: - - # Return input tensor if shape is not provided - ctx.shape = tensor.shape - if shape is None: - return tensor - - # Construct new tensor if shape is provided - if isinstance(tensor, Float8Tensor): - return Float8Tensor.make_like( - tensor, - data=tensor._data.view(*shape), - ) - return tensor.view(*shape) - - @staticmethod - def backward( - ctx, - grad: torch.Tensor, - ) -> Tuple[Union[torch.Tensor, None], ...]: - - if isinstance(grad, Float8Tensor): - dgrad = Float8Tensor.make_like( - grad, - data=grad._data.view(ctx.shape), - ) - return dgrad, None - return grad.view(ctx.shape), None - - -class _ReshapeFunc(torch.autograd.Function): - """Reshape function - - Reshape the Float8Tensor using the provided shape. - - """ - - @staticmethod - def forward( - ctx, - tensor: torch.Tensor, - shape: Tuple[int] = None, - ) -> torch.Tensor: - - # Return input tensor if shape is not provided - ctx.shape = tensor.shape - if shape is None: - return tensor - - # Construct new tensor if shape is provided - if isinstance(tensor, Float8Tensor): - return Float8Tensor.make_like( - tensor, - data=tensor._data.reshape(*shape), - ) - return tensor.reshape(*shape) - - @staticmethod - def backward( - ctx, - grad: torch.Tensor, - ) -> Tuple[Union[torch.Tensor, None], ...]: - - if isinstance(grad, Float8Tensor): - dgrad = Float8Tensor.make_like( - grad, - data=grad._data.reshape(ctx.shape), - ) - return dgrad, None - return grad.reshape(ctx.shape), None - - -class Float8Tensor(torch.Tensor): - """Experimental tensor class with FP8 data - - The tensor presents as having a standard, higher-precision dtype, - but the data itself is (scaled) FP8. For most tensor operations, - the data will be cast to the nominal dtype before performing the - operation. - - Parameters - ---------- - data: torch.Tensor - Raw FP8 data in a uint8 tensor - fp8_attrs: dict, optional - FP8 metadata, primarily managed by Float8Tensor. If - provided, all other FP8 configuration is ignored. - fp8_meta: dict, optional - FP8 metadata object, primarily managed by TE modules. - fp8_meta_forward: bool, default = `True` - Whether to access the FP8 metadata for the - forward pass. Ignored if fp8_meta is not - provided. - fp8_meta_index: int, optional - Index to access in FP8 meta tensors. Required if - fp8_meta is provided and otherwise ignored. - fp8_dtype: transformer_engine_torch.DType, tex.DType.kFloat8E4M3 - FP8 format. - fp8_scale_inv: torch.Tensor - Reciprocal of the scaling factor applied when - casting to FP8, i.e. the scaling factor that must - be applied when casting from FP8 to higher - precision. Can be inferred from fp8_meta if - provided. - dtype: torch.dtype, default = torch.float32 - Nominal tensor datatype. - - """ - - def __new__( - cls, - *, - data: torch.Tensor, - fp8_attrs: Optional[Dict[str, Any]] = None, - fp8_meta: Optional[Dict[str, Any]] = None, - fp8_meta_forward: bool = True, - fp8_meta_index: Optional[int] = None, - fp8_dtype: tex.DType = tex.DType.kFloat8E4M3, - fp8_scale_inv: Optional[torch.Tensor] = None, - dtype: torch.dtype = torch.float32, - ): - - # Check that data buffer is valid - if data.element_size() != 1: - raise ValueError( - f"Float8Tensor requires data buffer with 8-bit dtype (got dtype={data.dtype})" - ) - if data.requires_grad: - raise ValueError("Float8Tensor requires non-differentiable data buffer") - if not data.is_cuda: - data = data.cuda() - - # Initialize tensor object - self = torch.Tensor._make_wrapper_subclass( - cls, - data.size(), - strides=data.stride(), - storage_offset=data.storage_offset(), - dtype=dtype, - layout=data.layout, - requires_grad=data.requires_grad, - device=data.device, - ) - self._data: torch.Tensor = data - - # Initialize dict of class attributes - # Note: We store FP8 attributes in a dictionary so we can - # share them between tensors with the same data, e.g. detached - # tensors. - self._fp8_attrs: dict = {} - if fp8_attrs is not None: - self._fp8_attrs = fp8_attrs - return self - - # FP8 meta tensors - if fp8_meta is not None and fp8_meta_index is None: - raise ValueError( - "To initialize Float8Tensor with FP8 meta tensors, " - "the FP8 meta tensor index must also be provided" - ) - self._fp8_meta: Optional[Dict[str, Any]] = fp8_meta - self._fp8_meta_forward: bool = fp8_meta_forward - self._fp8_meta_index: Optional[int] = fp8_meta_index - - # FP8 dtype - assert fp8_dtype in ( - tex.DType.kFloat8E4M3, - tex.DType.kFloat8E5M2, - ), f"Unsupported fp8_dtype {fp8_dtype}." - self._fp8_dtype: tex.DType = fp8_dtype - - # Transposed version of `_data`. - self._transpose: Optional[Float8Tensor] = None - self._transpose_invalid: bool = True - - # FP8 scale-inverse - if fp8_scale_inv is None and self._fp8_meta is not None: - fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( - forward=self._fp8_meta_forward, - ) - fp8_scale_inv = self._fp8_meta[fp8_meta_key].scale_inv[self._fp8_meta_index] - fp8_scale_inv = fp8_scale_inv.detach().view(1).clone() - if fp8_scale_inv is None: - raise ValueError( - "Attempted to initialize Float8Tensor without specifying scale-inverse" - ) - if not isinstance(fp8_scale_inv, torch.Tensor): - fp8_scale_inv = torch.full( - [1], - fp8_scale_inv, - dtype=torch.float32, - device=self._data.device, - ) - if fp8_scale_inv.numel() != 1: - raise ValueError( - "Attempted to initialize Float8Tensor with invalid scale-inverse tensor" - ) - if fp8_scale_inv.dim() != 1: - fp8_scale_inv = fp8_scale_inv.reshape(1) - if fp8_scale_inv.device != self._data.device or fp8_scale_inv.dtype != torch.float32: - fp8_scale_inv = fp8_scale_inv.to( - device=self._data.device, - dtype=torch.float32, - ) - self._scale_inv: Optional[torch.Tensor] = fp8_scale_inv - - return self - - @classmethod - def make_like( - cls, - tensor: Float8Tensor, - *, - data: torch.Tensor, - fp8_attrs: Optional[Dict[str, Any]] = None, - **kwargs, - ) -> Float8Tensor: - """Use attributes of a Float8Tensor to create another Float8Tensor - - See constructor for list of keyword arguments. - - """ - default_kwargs = dict( - fp8_meta=tensor._fp8_meta, - fp8_meta_forward=tensor._fp8_meta_forward, - fp8_meta_index=tensor._fp8_meta_index, - fp8_dtype=tensor._fp8_dtype, - fp8_scale_inv=tensor._scale_inv, - dtype=tensor.dtype, - ) - for key, val in default_kwargs.items(): - if key not in kwargs: - kwargs[key] = val - return Float8Tensor(data=data, fp8_attrs=fp8_attrs, **kwargs) - - def __repr__(self): - return ( - "Float8Tensor(" - f"fp8_dtype={self._fp8_dtype}, " - f"scale_inv={self._scale_inv.item()}, " - f"data={self.from_float8(dtype=self.dtype)}" - ")" - ) - - def from_float8(self, dtype: Optional[torch.dtype] = None) -> torch.Tensor: - """ - Construct plain PyTorch tensor from Float8Tensor - - By default the resulting tensor's dtype is the - Float8Tensor's nominal dtype. - """ - return _FromFloat8Func.apply(self, dtype) - - @classmethod - def to_float8( - cls, - tensor: torch.Tensor, - *, - fp8_meta: Optional[Dict[str, Any]] = None, - fp8_meta_forward: bool = True, - fp8_meta_index: Optional[int] = None, - fp8_dtype: tex.DType = tex.DType.kFloat8E4M3, - scale: Optional[torch.Tensor] = None, - amax: Optional[torch.Tensor] = None, - scale_inv: Optional[torch.Tensor] = None, - ): - """Construct Float8Tensor from plain PyTorch tensor""" - return _ToFloat8Func.apply( - tensor, - fp8_meta, - fp8_meta_forward, - fp8_meta_index, - fp8_dtype, - scale, - amax, - scale_inv, - ) - - def float(self) -> torch.Tensor: - return self.from_float8(dtype=torch.float32) - - def bfloat16(self) -> torch.Tensor: - return self.from_float8(dtype=torch.bfloat16) - - def half(self) -> torch.Tensor: - return self.from_float8(dtype=torch.float16) - - def cpu(self) -> torch.Tensor: - return self.from_float8().cpu() - - def clone(self) -> Float8Tensor: - return _IdentityFunc.apply(self, {"data": self._data.detach().clone()}) - - def view(self, *shape: Tuple[int]) -> Float8Tensor: - return _ViewFunc.apply(self, shape) - - def reshape(self, *shape: Tuple[int]) -> Float8Tensor: - return _ReshapeFunc.apply(self, shape) - - def expand_as(self, other: torch.Tensor): - if other is self: - # Note: expand_as is hackily used to create dummy autograd nodes - # and access the backward graph (see - # https://github.com/pytorch/pytorch/blob/238fb660851268f44ff88127887041fea352fe48/torch/nn/parallel/distributed.py#L1026). - # We equally hackily add a dummy function to handle this - # case. - return _IdentityFunc.apply(self) - return super().expand_as(other) - - def contiguous( - self, - *, - memory_format: torch.memory_format = torch.contiguous_format, - ) -> Float8Tensor: - """Returns tensor with data in provided memory format - - Returns `self` if data is already in correct memory format. - - """ - if self._data.is_contiguous(memory_format=memory_format): - return self - return _IdentityFunc.apply( - self, - {"data": self._data.detach().contiguous(memory_format=memory_format)}, - ) - - def transpose_2d( - self, - *, - force_compute: bool = False, - fill_cache: bool = False, - noop_flag: Optional[torch.Tensor] = None, - cache: Optional[bool] = None, - ) -> torch.Tensor: - """ - 2D transpose with caching support. - - Parameters - ---------- - force_compute: bool, default = `False` - Force computation of transpose. Otherwise use - cached values, if possible. - fill_cache: bool, default = `False` - Cache output tensor for future function calls. - noop_flag: torch.Tensor, optional - float32 flag indicating whether to avoid updating - cached values, if possible. - cache: bool, deprecated - - """ - assert self.dim() == 2, f"{self.dim()}-D transpose not supported." - - # Handle deprecated cache kwarg - if cache is not None: - msg = ( - "cache kwarg for Float8Tensor.transpose_2d is deprecated, " - "please use force_compute and fill_cache instead" - ) - warnings.warn(msg, DeprecationWarning) - if cache: - force_compute = False - fill_cache = True - else: - force_compute = True - fill_cache = False - - # Need to compute transpose if cache is invalid - need_compute = force_compute - if self._transpose is None: - need_compute = True - elif self._transpose_invalid: - need_compute = True - - # Need to apply transpose kernel if noop flag is applied - if noop_flag is not None: - need_compute = True - - # Return cached transpose if possible - if not need_compute: - return self._transpose - - # Allocate output if needed - data = self._data.contiguous().reshape(-1, self.size(-1)) - out = self._transpose - if out is None: - out = torch.empty( - (data.size(1), data.size(0)), - dtype=torch.uint8, - device=data.device, - ) - noop_flag = None - else: - self._transpose_invalid = False - - # Apply transpose kernel - fp8_dtype = self._fp8_dtype - if noop_flag is None: - tex.fp8_transpose_noalloc(data, out, fp8_dtype) - else: - noop_flag = noop_flag.to(dtype=torch.float32, device=data.device) - tex.fp8_transpose_noalloc_noop(data, out, noop_flag, fp8_dtype) - - # Fill cache if needed - if fill_cache: - self._transpose = out - self._transpose_invalid = False - - return out - - @torch.no_grad() - def cast_transpose_( - self, - tensor: torch.Tensor, - noop_flag: Optional[torch.Tensor] = None, - ) -> None: - """Cast from tensor and populate transpose cache - - Only supported for 2D tensors. - - Parameters - ---------- - tensor: torch.Tensor - Tensor to copy from. Must have same dimensions as - destination tensor. - noop_flag: torch.Tensor, optional - float32 flag indicating whether to avoid updating - destination tensor. - - """ - - # Make sure tensor is in expected format - data = self._data - if ( - tensor.device != data.device - or tensor.dtype not in (torch.float32, torch.float16, torch.bfloat16) - or not tensor.is_contiguous() - ): - dtype = tensor.dtype - if dtype not in (torch.float32, torch.float16, torch.bfloat16): - dtype = torch.float32 - tensor = tensor.to( - device=self.device, - dtype=dtype, - memory_format=torch.contiguous_format, - ) - if tensor.size() != data.size() or data.dim() != 2: - raise ValueError( - "Invalid tensor dimensions for FP8 cast-transpose " - f"(src={tuple(tensor.size())}, dst={tuple(data.size())})" - ) - if not data.is_contiguous(): - raise ValueError( - "FP8 cast-transpose is only supported for `Float8Tensor`s with contiguous data" - ) - if self._fp8_meta is None: - raise ValueError( - "FP8 cast-transpose is only supported for `Float8Tensor`s with FP8 metadata " - ) - - # Construct transpose cache if needed - transpose = self._transpose - if transpose is None or not transpose.is_contiguous(): - transpose = torch.empty( - (data.size(1), data.size(0)), - dtype=torch.uint8, - device=data.device, - ) - self._transpose = transpose - noop_flag = None - - # Launch cast-transpose kernel - fp8_meta_index = int(self._fp8_meta_index) - fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( - forward=self._fp8_meta_forward, - ) - fp8_meta = self._fp8_meta[fp8_meta_key] - fp8_cast_transpose_fused( - tensor, - fp8_meta, - fp8_meta_index, - self._fp8_dtype, - cast_out=data, - transpose_out=transpose, - scale_inv=self._scale_inv, - noop_flag=noop_flag, - ) - self._transpose_invalid = False - - @torch.no_grad() - def reset_fp8_meta_scale_inv(self) -> None: - """Replace FP8 meta tensor scale-inverse with cached value - - The FP8 meta tensor scale_inv entry corresponding to this - tensor is replaced with the scale_inv value used to construct - the tensor. - - """ - assert self._fp8_meta is not None, "FP8 meta tensors not found." - fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( - forward=self._fp8_meta_forward, - ) - self._fp8_meta[fp8_meta_key].scale_inv[self._fp8_meta_index].copy_(self._scale_inv[0]) - - def to_dtype(self, dtype: torch.dtype) -> Float8Tensor: - """Create `Float8Tensor` with given nominal dtype - - The new tensor has the same underlying FP8 data. - - """ - return Float8Tensor.make_like( - self, - data=self._data, - fp8_attrs=self._fp8_attrs, - dtype=dtype, - ) - - def _reset_caches(self) -> None: - """ - Set transpose cache as invalid. - Should be called after any in-place operation. - """ - self._transpose_invalid = True - - @classmethod - def __torch_dispatch__(cls, func, types, args, kwargs=None): - - # In-place copy op - if func == aten.copy_.default: - - # Check tensors - dst = args[0] - src = args[1] - if not isinstance(dst, torch.Tensor): - raise RuntimeError("Attempted to copy into something that isn't a PyTorch tensor") - if not isinstance(src, torch.Tensor): - raise RuntimeError("Attempted to copy from something that isn't a PyTorch tensor") - - # Special handling based on which tensors are FP8 - dst_is_fp8 = isinstance(dst, Float8Tensor) - src_is_fp8 = isinstance(src, Float8Tensor) - if dst_is_fp8 and src_is_fp8: - - # Directly copy FP8 data if possible - if dst._fp8_dtype == src._fp8_dtype: - dst._data.copy_(src._data) - dst._scale_inv.copy_(src._scale_inv.detach()) - if dst._fp8_meta is not None: - if src._fp8_meta is None: - src_min, src_max = src.from_float8().aminmax() - src_amax = torch.maximum(-src_min, src_max) - else: - fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( - forward=src._fp8_meta_forward, - ) - fp8_meta_index = src._fp8_meta_index - src_amax = src._fp8_meta[fp8_meta_key].amax_history[0][fp8_meta_index] - fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( - forward=dst._fp8_meta_forward, - ) - fp8_meta_index = dst._fp8_meta_index - dst_amax = dst._fp8_meta[fp8_meta_key].amax_history[0][fp8_meta_index] - torch.maximum(src_amax, dst_amax, out=dst_amax) - else: - dst.copy_(src.from_float8()) - - elif not dst_is_fp8 and src_is_fp8: - - # Cast source tensor to higher precision - dst.copy_(src.from_float8()) - - elif dst_is_fp8 and not src_is_fp8: - # Make sure input is in expected format - src = src.expand(dst.size()) - src = src.to( - device=dst.device, - memory_format=torch.contiguous_format, - ) - - # Update scaling factor if FP8 meta tensors are available - if dst._fp8_meta is None: - scale = dst._scale_inv.reciprocal() - amax = torch.empty_like(scale) - else: - fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( - forward=dst._fp8_meta_forward, - ) - fp8_meta_index = dst._fp8_meta_index - scale = dst._fp8_meta[fp8_meta_key].scale[fp8_meta_index] - amax = dst._fp8_meta[fp8_meta_key].amax_history[0][fp8_meta_index] - - # Cast to FP8 - if not dst._data.is_contiguous(): - raise RuntimeError("Transformer Engine cast kernels require contiguous data") - tex.cast_to_fp8_noalloc( - src.view(1, -1), - scale, - dst._data.view(1, -1), - amax, - dst._scale_inv, - dst._fp8_dtype, - ) - - # This branch is where the FP8 parameters are updated in-place during optimization. - # Handle forward amax reduction. - post_optimizer_step_fwd_amax_reduction(dst) - else: - - # Invalid case - raise RuntimeError("Using Float8Tensor copy logic, but no Float8Tensor found") - - # Nothing to return for in-place ops - if dst_is_fp8: - dst._reset_caches() - - return None - - # Slice op - if func == aten.slice.Tensor: - tensor = args[0] - data = tensor._data - data_slice = data.__torch_dispatch__( - func, - types, - [data] + list(args[1:]), - kwargs, - ) - return Float8Tensor.make_like(tensor, data=data_slice) - - # Detach op - if func == aten.detach.default: - # Simply return a new Float8Tensor with the same attrs - return Float8Tensor.make_like( - args[0], - data=args[0]._data, - fp8_attrs=args[0]._fp8_attrs, - ) - - # View op - if func == aten.view.default: - tensor = args[0] - data = tensor._data - data_view = data.__torch_dispatch__( - func, - types, - [data] + list(args[1:]), - kwargs, - ) - return Float8Tensor.make_like( - tensor, - data=data_view, - fp8_attrs=tensor._fp8_attrs, - ) - - def maybe_unwrap(t): - if isinstance(t, Float8Tensor): - return t.from_float8() - return t - - def maybe_update_inplace(arg, new_arg, schema_arg): - """Update values of FP8 tensors - - Keep the same FP8 scaling factors. - - """ - if ( - isinstance(arg, Float8Tensor) - and isinstance(new_arg, torch.Tensor) - and hasattr(schema_arg, "alias_info") - and hasattr(schema_arg.alias_info, "is_write") - and schema_arg.alias_info.is_write - ): - arg.copy_(new_arg) - arg._reset_caches() - - # In-place op - if func._schema.is_mutable: - # Cast to higher precision, perform op, and cast values - # back to original FP8 buffers - new_args = tree_map(maybe_unwrap, args) - new_kwargs = tree_map(maybe_unwrap, kwargs) - schema_args = func._schema.arguments - args_len = len(args) - out = super().__torch_dispatch__(func, types, new_args, new_kwargs) - for arg, new_arg, schema_arg in zip(args, new_args, schema_args): - maybe_update_inplace(arg, new_arg, schema_arg) - for kwarg, new_kwarg, schema_arg in zip(kwargs, new_kwargs, schema_args[args_len:]): - assert kwarg == new_kwarg == schema_arg.name, "name of the kw argument should match" - maybe_update_inplace(kwargs[kwarg], new_kwargs[new_kwarg], schema_arg) - return None - - # Default op - # Note: cast to higher precision and perform op - args = tree_map(maybe_unwrap, args) - if kwargs is not None: - kwargs = tree_map(maybe_unwrap, kwargs) - out = super().__torch_dispatch__(func, types, args, kwargs) - return out - - @classmethod - def _make_in_reduce_ex( - cls, - data: torch.Tensor, - fp8_dtype: tex.DType, - fp8_scale_inv: torch.Tensor, - dtype: torch.dtype, - ) -> Float8Tensor: - """Build Float8Tensor, for use in __reduce__ - - __reduce_ex__ assumes object constructor has positional - arguments. - - """ - return Float8Tensor( - data=data, - fp8_dtype=fp8_dtype, - fp8_scale_inv=fp8_scale_inv, - dtype=dtype, - ) - - def __reduce_ex__(self, protocol: int) -> tuple: - """Custom pickling to remove references to FP8 metadata objects""" - return ( - Float8Tensor._make_in_reduce_ex, - (self._data, self._fp8_dtype, self._scale_inv, self.dtype), - ) - - def _get_data(self) -> Float8Tensor: - """Get tensor data property""" - return super().data - - def _set_data(self, tensor: torch.Tensor) -> None: - """Set tensor data property - - Cast tensor to FP8 and store in FP8 buffer. - - """ - with torch.no_grad(): - self.copy_(tensor) - - # Cast to FP8 when setting Float8Tensor.data - data = property(_get_data, _set_data) - - # Accessors for objects in self._fp8_attrs - # Note: We store FP8 attributes in a dictionary so we can share - # them between tensors with the same data, e.g. detached tensors. - # For convenience, we also expose them as property attributes. - _fp8_meta = property(**_make_fp8_attr_property_funcs("fp8_meta")) - _fp8_meta_forward = property(**_make_fp8_attr_property_funcs("fp8_meta_forward")) - _fp8_meta_index = property(**_make_fp8_attr_property_funcs("fp8_meta_index")) - _fp8_dtype = property(**_make_fp8_attr_property_funcs("dtype")) - _transpose = property(**_make_fp8_attr_property_funcs("transpose")) - _transpose_invalid = property(**_make_fp8_attr_property_funcs("transpose_invalid")) - _scale_inv = property(**_make_fp8_attr_property_funcs("scale_inv")) - - @classmethod - def __torch_function__(cls, func, types, args=(), kwargs=None): - if kwargs is None: - kwargs = {} - return torch._C._disabled_torch_function_impl(func, types, args, kwargs) +__all__ = ["Float8Tensor"] diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 3613e1fa5e..3375b8ab7d 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -865,11 +865,17 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None: # If primary weights are in fp8, wrap the parameter as Float8Tensor fp8_meta_index = self.param_init_meta[name].fp8_meta_index if self.primary_weights_in_fp8 and fp8_meta_index is not None: + dummy_amax = torch.empty( + (1, 1), + dtype=torch.float32, + device=param.device, + ) # Dummy buffer to avoid overwriting amax history param = Float8Tensor.to_float8( param, fp8_meta=self.fp8_meta, fp8_meta_index=fp8_meta_index, - amax=torch.empty(1, device="cuda"), # Dummy amax to avoid overwriting history. + amax=dummy_amax, + with_transpose_cache=torch.is_grad_enabled(), ) # Redo parameter wrap in case we broke it above @@ -891,7 +897,6 @@ def get_fp8_workspace( cache_name: Optional[str] = None, update_workspace: bool = True, skip_update_flag: Optional[torch.Tensor] = None, - with_transpose: bool = False, fsdp_group: dist_group_type = None, ) -> Float8Tensor: """Get FP8 workspace buffer and maybe update its values @@ -917,27 +922,30 @@ def get_fp8_workspace( skip_update_flag: torch.Tensor, optional GPU flag to skip updating the workspace. Take precedence over `update_workspace` if provided. - with_transpose: bool, default = `False` - Whether to initialize cached transpose in workspace. fsdp_group: bool, default = None FSDP process group that the weights are distributed over. """ - # Construct workspace if needed + # Try getting workspace from cache out = None if cache_name is not None: out = self._fp8_workspaces.get(cache_name, None) - # Gather cached Fp8 workspace if it's distributed - # NOTE: FSDP sharding is supported only for Fp8 buffers and will not work - # for models initialized with Fp8 primary weights. - if ( - not isinstance(out, Float8Tensor) - and fsdp_group is not None - and out._data.shape != tensor.data.shape - ): - _fsdp_gather_tensors(fsdp_group, [tensor.data.shape], out) + # Gather cached Fp8 workspace if it's distributed + # NOTE: FSDP sharding is supported only for Fp8 buffers and will not work + # for models initialized with Fp8 primary weights. + if ( + out is not None + and not isinstance(out, Float8Tensor) + and fsdp_group is not None + and out._data.shape != tensor.data.shape + ): + _fsdp_gather_tensors(fsdp_group, [tensor.data.shape], out) + + # Construct workspace if needed if out is None: + + # FP8 data if tensor is None or fp8_meta_forward is None or fp8_meta_index is None: raise ValueError( "tensor, fp8_meta_forward, and fp8_meta_index kwargs " @@ -947,16 +955,38 @@ def get_fp8_workspace( self.fp8_meta["recipe"], fprop_tensor=fp8_meta_forward, ) + data = torch.empty_like(tensor, dtype=torch.uint8) scale_inv = torch.empty([1], dtype=torch.float32, device=tensor.device) + + # Transpose cache + with_transpose_cache = torch.is_grad_enabled() + if ( + not with_transpose_cache + and is_fp8_activation_recompute_enabled() + and not in_fp8_activation_recompute_phase() + ): + with_transpose_cache = True + data_transpose = None + if with_transpose_cache: + data_transpose = torch.empty( + (tensor.size(-1), tensor.numel() // tensor.size(-1)), + dtype=torch.uint8, + device=tensor.device, + ) + + # Construct FP8 tensor out = Float8Tensor( - data=torch.empty_like(tensor, dtype=torch.uint8), + data=data, fp8_meta=self.fp8_meta, fp8_meta_forward=fp8_meta_forward, fp8_meta_index=fp8_meta_index, fp8_dtype=fp8_dtype, fp8_scale_inv=scale_inv, dtype=tensor.dtype, + data_transpose=data_transpose, ) + + # Update cache if cache_name is not None: self._fp8_workspaces[cache_name] = out update_workspace = True @@ -968,33 +998,17 @@ def get_fp8_workspace( if update_workspace: if tensor is None: raise ValueError("tensor kwarg must be provided to update FP8 workspace") - if with_transpose: - out.cast_transpose_( - tensor, - noop_flag=skip_update_flag, - ) + if is_in_onnx_export_mode(): + # ONNX export does not support fused cast-transpose + # kernel and requires that FP8 scales can be + # represented with constant ops. + transpose_cache = out._transpose + out._transpose = None + out.quantize_(tensor) + out._scale_inv.fill_(out._scale_inv.item()) + out._transpose = transpose_cache else: - fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( - forward=out._fp8_meta_forward, - ) - fp8_meta = out._fp8_meta[fp8_meta_key] - fp8_meta_index = out._fp8_meta_index - cast_to_fp8( - tensor, - fp8_meta, - fp8_meta_index, - out._fp8_dtype, - out=out._data, - ) - if is_in_onnx_export_mode(): - # ONNX export expects FP8 scales can be - # represented with constant ops. However, copying - # into a buffer involves an expand op for array - # broadcasting. We work around this by filling the - # buffer instead. - out._scale_inv.fill_(fp8_meta.scale_inv[fp8_meta_index].item()) - else: - out._scale_inv.copy_(fp8_meta.scale_inv[fp8_meta_index]) + out.quantize_(tensor, noop_flag=skip_update_flag) return out diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index ca100392c7..10c8d91551 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -28,8 +28,6 @@ from ..distributed import ( set_tensor_model_parallel_attributes, get_distributed_world_size, - is_fp8_activation_recompute_enabled, - in_fp8_activation_recompute_phase, ) from ..cpp_extensions import ( cast_to_fp8, @@ -760,22 +758,12 @@ def forward( weight_tensors_fp8 = [None] * self.num_gemms if self.fp8: - with_transpose = torch.is_grad_enabled() - if ( - not with_transpose - and is_fp8_activation_recompute_enabled() - and not in_fp8_activation_recompute_phase() - ): - with_transpose = True for i in range(self.num_gemms): if isinstance(weight_tensors[i], Float8Tensor): - # Fill transpose cache in FP8 tensor if needed - update_transpose_cache = with_transpose - if update_transpose_cache: - update_transpose_cache = ( - is_first_microbatch or skip_fp8_weight_update is not None - ) - if update_transpose_cache: + # Make sure transpose cache is valid, if present + # Note: Transpose cache may have been invalidated + # externally, e.g. by optimizer. + if weight_tensors[i]._transpose is not None: weight_tensors[i].transpose_2d( fill_cache=True, noop_flag=skip_fp8_weight_update, @@ -790,7 +778,6 @@ def forward( cache_name=(None if is_first_microbatch is None else f"weight{i}"), update_workspace=update_workspace, skip_update_flag=skip_fp8_weight_update, - with_transpose=with_transpose, ) from ..cpu_offload import CPUOffloadEnabled diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 9586d6d345..da77879e06 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -36,8 +36,6 @@ allreduce, reduce_scatter_along_first_dim, gather_along_first_dim, - is_fp8_activation_recompute_enabled, - in_fp8_activation_recompute_phase, _fsdp_scatter_tensors, _fsdp_gather_tensors, ) @@ -47,6 +45,7 @@ from ._common import _apply_normalization, _noop_cat from ..float8_tensor import Float8Tensor from ..export import is_in_onnx_export_mode +from ..tensor import QuantizedTensor __all__ = ["LayerNormLinear"] @@ -1151,14 +1150,14 @@ def forward( # Get concatenated weight and bias tensors unfused_weights = [getattr(self, name) for name in self.weight_names] - if any(isinstance(w, Float8Tensor) for w in unfused_weights): + if any(isinstance(w, QuantizedTensor) for w in unfused_weights): if self.fp8: if len(unfused_weights) != 1: raise RuntimeError( - "Splitting Float8Tensor into multiple params is not supported" + "Splitting QuantizedTensor into multiple params is not supported" ) else: - unfused_weights = [w.from_float8() for w in unfused_weights] + unfused_weights = [w.dequantize() for w in unfused_weights] weight_tensor = _noop_cat(unfused_weights) if self.use_bias: bias_tensor = _noop_cat( @@ -1170,32 +1169,18 @@ def forward( # Initialize FP8 weights if needed weight_fp8 = None if self.fp8: - with_transpose = torch.is_grad_enabled() - if ( - not with_transpose - and is_fp8_activation_recompute_enabled() - and not in_fp8_activation_recompute_phase() - ): - with_transpose = True if isinstance(weight_tensor, Float8Tensor): - # Fill transpose cache in FP8 tensor if needed - update_transpose_cache = with_transpose - if update_transpose_cache: - update_transpose_cache = ( - is_first_microbatch or skip_fp8_weight_update is not None - ) - if update_transpose_cache: + # Make sure transpose cache is valid, if present + # Note: Transpose cache may have been invalidated + # externally, e.g. by optimizer. + if weight_tensor._transpose is not None: weight_tensor.transpose_2d( fill_cache=True, noop_flag=skip_fp8_weight_update, ) else: # FP8 cast to workspace buffer - update_workspace = ( - is_first_microbatch is None - or is_first_microbatch - or skip_fp8_weight_update is not None - ) + update_workspace = is_first_microbatch is None or is_first_microbatch weight_fp8 = self.get_fp8_workspace( tensor=weight_tensor, fp8_meta_forward=True, @@ -1203,7 +1188,6 @@ def forward( cache_name=(None if is_first_microbatch is None else "weight"), update_workspace=update_workspace, skip_update_flag=skip_fp8_weight_update, - with_transpose=with_transpose, ) from ..cpu_offload import CPUOffloadEnabled diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index dc9bef645f..b802c972d4 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -42,8 +42,6 @@ allreduce, reduce_scatter_along_first_dim, gather_along_first_dim, - is_fp8_activation_recompute_enabled, - in_fp8_activation_recompute_phase, use_reentrant_activation_recompute, _fsdp_scatter_tensors, _fsdp_gather_tensors, @@ -1485,19 +1483,8 @@ def forward( fc2_weight_fp8 = None if self.fp8: update_workspace = is_first_microbatch is None or is_first_microbatch - with_transpose = torch.is_grad_enabled() - if ( - is_fp8_activation_recompute_enabled() - and not in_fp8_activation_recompute_phase() - ): - with_transpose = True - update_transpose_cache = with_transpose - if update_transpose_cache: - update_transpose_cache = ( - is_first_microbatch or skip_fp8_weight_update is not None - ) if isinstance(fc1_weight, Float8Tensor): - if update_transpose_cache: + if fc1_weight._transpose is not None: fc1_weight.transpose_2d( fill_cache=True, noop_flag=skip_fp8_weight_update, @@ -1513,10 +1500,9 @@ def forward( cache_name=cache_name, update_workspace=update_workspace, skip_update_flag=skip_fp8_weight_update, - with_transpose=with_transpose, ) if isinstance(fc2_weight, Float8Tensor): - if update_transpose_cache: + if fc2_weight._transpose is not None: fc2_weight.transpose_2d( fill_cache=True, noop_flag=skip_fp8_weight_update, @@ -1532,7 +1518,6 @@ def forward( cache_name=cache_name, update_workspace=update_workspace, skip_update_flag=skip_fp8_weight_update, - with_transpose=with_transpose, ) # Disable bias_gelu_nvfusion for determinism checkpointing in non-reentrant mode diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index f92a2db2d9..a7be82ccf1 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -33,8 +33,6 @@ allreduce, reduce_scatter_along_first_dim, gather_along_first_dim, - is_fp8_activation_recompute_enabled, - in_fp8_activation_recompute_phase, _fsdp_scatter_tensors, _fsdp_gather_tensors, ) @@ -49,6 +47,7 @@ from ..graph import is_graph_capturing from ..float8_tensor import Float8Tensor from ..export import is_in_onnx_export_mode +from ..tensor import QuantizedTensor __all__ = ["Linear"] @@ -938,19 +937,19 @@ def forward( with self.prepare_forward( inp, is_first_microbatch, - allow_non_contiguous=isinstance(inp, Float8Tensor), + allow_non_contiguous=isinstance(inp, QuantizedTensor), ) as inp: # Get concatenated weight and bias tensors unfused_weights = [getattr(self, name) for name in self.weight_names] - if any(isinstance(w, Float8Tensor) for w in unfused_weights): + if any(isinstance(w, QuantizedTensor) for w in unfused_weights): if self.fp8: if len(unfused_weights) != 1: raise RuntimeError( - "Splitting Float8Tensor into multiple params is not supported" + "Splitting QuantizedTensor into multiple params is not supported" ) else: - unfused_weights = [w.from_float8() for w in unfused_weights] + unfused_weights = [w.dequantize() for w in unfused_weights] weight_tensor = _noop_cat(unfused_weights) if self.use_bias: bias_tensor = _noop_cat( @@ -962,21 +961,11 @@ def forward( # Initialize FP8 weights if needed weight_fp8 = None if self.fp8: - with_transpose = torch.is_grad_enabled() - if ( - not with_transpose - and is_fp8_activation_recompute_enabled() - and not in_fp8_activation_recompute_phase() - ): - with_transpose = True if isinstance(weight_tensor, Float8Tensor): - # Fill transpose cache in FP8 tensor if needed - update_transpose_cache = with_transpose - if update_transpose_cache: - update_transpose_cache = ( - is_first_microbatch or skip_fp8_weight_update is not None - ) - if update_transpose_cache: + # Make sure transpose cache is valid, if present + # Note: Transpose cache may have been invalidated + # externally, e.g. by optimizer. + if weight_tensor._transpose is not None: weight_tensor.transpose_2d( fill_cache=True, noop_flag=skip_fp8_weight_update, @@ -991,7 +980,6 @@ def forward( cache_name=(None if is_first_microbatch is None else "weight"), update_workspace=update_workspace, skip_update_flag=skip_fp8_weight_update, - with_transpose=with_transpose, fsdp_group=self.fsdp_group, ) diff --git a/transformer_engine/pytorch/ops/_common.py b/transformer_engine/pytorch/ops/_common.py index 77efef4ab6..12270d8340 100644 --- a/transformer_engine/pytorch/ops/_common.py +++ b/transformer_engine/pytorch/ops/_common.py @@ -9,54 +9,12 @@ import torch -from transformer_engine.pytorch.float8_tensor import Float8Tensor - - -def canonicalize_device(device: Optional[torch.device | str]) -> torch.device: - """Canonicalize PyTorch device - - If `None`, then returns the default CUDA device. - - """ - if device is None: - # Use default CUDA device - device = torch.get_default_device() - if device.type != "cuda": - device = torch.device("cuda", torch.cuda.current_device()) - elif not isinstance(device, torch.device): - device = torch.device(device) - if device.type == "cuda" and device.index is None: - device = torch.device("cuda", torch.cuda.current_device()) - return device - - -def canonicalize_dtype(dtype: Optional[torch.dtype]) -> torch.dtype: - """Canonicalize PyTorch datatype - - If `None`, then returns the default PyTorch datatype. - - """ - if dtype is None: - # Use default dtype - dtype = torch.get_default_dtype() - return dtype - - -def devices_match(device1: torch.device, device2: torch.device) -> bool: - """Whether two devices are the same""" - device1 = torch.device(device1) - device2 = torch.device(device2) - if device1.type != device2.type: - return False - if device1.type == "cuda": - index1 = device1.index - index2 = device2.index - if index1 is None: - index1 = torch.cuda.current_device() - if index2 is None: - index2 = torch.cuda.current_device() - return index1 == index2 - return device1 == device2 +from ..tensor import Float8Tensor +from ..utils import ( + canonicalize_device, # pylint: disable=unused-import + canonicalize_dtype, # pylint: disable=unused-import + devices_match, # pylint: disable=unused-import +) def is_float8_tensor(tensor: Any) -> bool: @@ -92,7 +50,13 @@ def convert_tensor( # Convert FP8 tensor if is_float8_tensor(tensor): - data = tensor._data.to(device=device, memory_format=memory_format) + data = tensor._data + if not devices_match(device, data.device): + data = data.to(device=device) + if memory_format != torch.preserve_format and not data.is_contiguous( + memory_format=memory_format + ): + data = data.contiguous(memory_format=memory_format) return Float8Tensor.make_like( tensor, data=data, diff --git a/transformer_engine/pytorch/ops/basic/all_reduce.py b/transformer_engine/pytorch/ops/basic/all_reduce.py index 622346b1c5..f466ade3a3 100644 --- a/transformer_engine/pytorch/ops/basic/all_reduce.py +++ b/transformer_engine/pytorch/ops/basic/all_reduce.py @@ -9,11 +9,8 @@ import torch -from transformer_engine.pytorch.ops.op import ( - BasicOperation, - OperationContext, -) -from .._common import is_float8_tensor +from ...tensor import QuantizedTensor +from ..op import BasicOperation, OperationContext class AllReduce(BasicOperation): @@ -54,8 +51,8 @@ def op_forward( # Perform all-reduce x = input_ - if is_float8_tensor(x): - x = x.from_float8() + if isinstance(x, QuantizedTensor): + x = x.dequantize() x = x.contiguous() torch.distributed.all_reduce(x, group=self.process_group) return x diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index 826807d1c0..ce72dd8a55 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -289,10 +289,18 @@ def reset_parameters(self) -> None: # Cast to FP8 if needed if self._with_fp8_parameters: + dummy_amax = torch.empty( + (1, 1), + dtype=torch.float32, + device=self.device, + ) # Dummy buffer to avoid overwriting amax history weight = Float8Tensor.to_float8( weight, fp8_meta=self.get_fp8_meta("param"), + fp8_meta_forward=True, fp8_meta_index=0, + amax=dummy_amax, + with_transpose_cache=torch.is_grad_enabled(), ) # Save updated parameter @@ -467,25 +475,19 @@ def _functional_forward( input_fp8_meta["recipe"], fprop_tensor=True, ) - x_fp8 = Float8Tensor( - data=torch.empty_like(x_local, dtype=torch.uint8), + with_transpose_cache = weight.requires_grad + if tensor_parallel_mode == "column" and sequence_parallel: + with_transpose_cache = False + x_local = Float8Tensor.to_float8( + x_local, fp8_meta=input_fp8_meta, fp8_meta_forward=True, fp8_meta_index=0, fp8_dtype=fp8_dtype, - fp8_scale_inv=torch.empty([1], dtype=torch.float32, device=device), - dtype=dtype, + with_transpose_cache=with_transpose_cache, ) - with_cast_transpose = weight.requires_grad - if tensor_parallel_mode == "column" and sequence_parallel: - with_cast_transpose = False - if with_cast_transpose: - x_fp8.cast_transpose_(x_local) - else: - x_fp8.copy_(x_local) - x_local = x_fp8 elif not with_fp8_compute and is_float8_tensor(x_local): - x_local = x_local.from_float8() + x_local = x_local.dequantize() x = x_local x_async = None if tensor_parallel_mode == "column" and sequence_parallel: @@ -510,11 +512,12 @@ def _functional_forward( w = Float8Tensor.to_float8( w, fp8_meta=weight_fp8_meta, + fp8_meta_forward=True, fp8_meta_index=0, fp8_dtype=fp8_dtype, ) elif not with_fp8_compute and is_float8_tensor(w): - w = w.from_float8() + w = w.dequantize() # Check bias tensor b = None @@ -815,25 +818,19 @@ def _functional_backward( grad_output_fp8_meta["recipe"], fprop_tensor=False, ) - dy_fp8 = Float8Tensor( - data=torch.empty_like(dy, dtype=torch.uint8), + with_transpose_cache = weight_requires_grad + if tensor_parallel_mode == "row" and sequence_parallel: + with_transpose_cache = False + dy = Float8Tensor.to_float8( + dy, fp8_meta=grad_output_fp8_meta, fp8_meta_forward=False, fp8_meta_index=0, fp8_dtype=fp8_dtype, - fp8_scale_inv=torch.empty([1], dtype=torch.float32, device=device), - dtype=dtype, + with_transpose_cache=with_transpose_cache, ) - with_cast_transpose = weight_requires_grad - if tensor_parallel_mode == "row" and sequence_parallel: - with_cast_transpose = False - if with_cast_transpose: - dy_fp8.cast_transpose_(dy) - else: - dy_fp8.copy_(dy) - dy = dy_fp8 elif not with_fp8_compute and is_float8_tensor(dy): - dy = dy.from_float8() + dy = dy.dequantize() if tensor_parallel_mode == "row" and sequence_parallel: dy, dy_async = gather_along_first_dim( dy, @@ -853,26 +850,24 @@ def _functional_backward( device=device, dtype=dtype, ) + x_is_sharded = tensor_parallel_mode == "column" and sequence_parallel if with_fp8_compute and not is_float8_tensor(x_local): fp8_dtype = get_fp8_te_dtype( input_fp8_meta["recipe"], fprop_tensor=True, ) - x_fp8 = Float8Tensor( - data=torch.empty_like(x_local, dtype=torch.uint8), + x_local = Float8Tensor.to_float8( + x_local, fp8_meta=input_fp8_meta, fp8_meta_forward=True, fp8_meta_index=0, fp8_dtype=fp8_dtype, - fp8_scale_inv=torch.empty([1], dtype=torch.float32, device=device), - dtype=dtype, + with_transpose_cache=(not x_is_sharded), ) - x_fp8.cast_transpose_(x_local) - x_local = x_fp8 elif not with_fp8_compute and is_float8_tensor(x_local): x_local = x_local.from_float8() x = x_local - if tensor_parallel_mode == "column" and sequence_parallel: + if x_is_sharded: x, x_async = gather_along_first_dim( x_local, tensor_parallel_group, @@ -898,19 +893,16 @@ def _functional_backward( weight_fp8_meta["recipe"], fprop_tensor=True, ) - w_fp8 = Float8Tensor( - data=torch.empty_like(w, dtype=torch.uint8), + w = Float8Tensor.to_float8( + w, fp8_meta=weight_fp8_meta, fp8_meta_forward=True, fp8_meta_index=0, fp8_dtype=fp8_dtype, - fp8_scale_inv=torch.empty([1], dtype=torch.float32, device=device), - dtype=dtype, + with_transpose_cache=True, ) - w_fp8.cast_transpose_(w) - w = w_fp8 elif not with_fp8_compute and is_float8_tensor(w): - w = w.from_float8() + w = w.dequantize() # Construct grad input tensor if grad_input is not None: diff --git a/transformer_engine/pytorch/ops/basic/reduce_scatter.py b/transformer_engine/pytorch/ops/basic/reduce_scatter.py index 996ca2da31..c78dbc2877 100644 --- a/transformer_engine/pytorch/ops/basic/reduce_scatter.py +++ b/transformer_engine/pytorch/ops/basic/reduce_scatter.py @@ -9,12 +9,9 @@ import torch -from transformer_engine.pytorch.float8_tensor import Float8Tensor -from transformer_engine.pytorch.ops.op import ( - BasicOperation, - OperationContext, -) -from .._common import convert_tensor, is_float8_tensor +from ...tensor import Float8Tensor, QuantizedTensor +from ..op import BasicOperation, OperationContext +from .._common import convert_tensor class ReduceScatter(BasicOperation): @@ -63,8 +60,8 @@ def op_forward( # Check input tensor x = input_ - if is_float8_tensor(x): - x = x.from_float8() + if isinstance(x, QuantizedTensor): + x = x.dequantize() x = x.contiguous() # Perform reduce-scatter @@ -96,7 +93,7 @@ def op_backward( # Perform all-gather dy = convert_tensor(grad_output, memory_format=torch.contiguous_format) dx = None - if is_float8_tensor(dy): + if isinstance(dy, Float8Tensor): dx = Float8Tensor.make_like( dy, data=torch.empty( @@ -111,6 +108,8 @@ def op_backward( group=self.process_group, ) else: + if isinstance(dy, QuantizedTensor): + dy = dy.dequantize() dx = torch.empty(input_dims, dtype=dy.dtype, device=dy.device) torch.distributed.all_gather_into_tensor( dx, diff --git a/transformer_engine/pytorch/tensor/__init__.py b/transformer_engine/pytorch/tensor/__init__.py new file mode 100644 index 0000000000..2bad862768 --- /dev/null +++ b/transformer_engine/pytorch/tensor/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Custom tensor classes""" + +from .float8_tensor import Float8Tensor +from .quantized_tensor import QuantizedTensor diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py new file mode 100644 index 0000000000..610523a10d --- /dev/null +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -0,0 +1,972 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Tensor class with FP8 data""" +from __future__ import annotations +from typing import Any, Dict, Optional, Tuple +import warnings + +import torch +import transformer_engine_torch as tex + +from transformer_engine_torch import DType as TE_DType +from ..constants import TE_DType as torch_to_transformer_engine_dtype +from ..cpp_extensions import ( + cast_from_fp8, + cast_to_fp8, + fp8_cast_transpose_fused, +) +from ..fp8 import FP8GlobalStateManager +from ..utils import devices_match +from .quantized_tensor import QuantizedTensor + +aten = torch.ops.aten +updated_fp8_params = {} + + +def _make_fp8_attr_property_funcs(name: str) -> Any: + """Make accessors for an FP8 attribute + + We store FP8 attributes in a dictionary so we can share them + between tensors with the same data, e.g. detached tensors. For + convenience, we also expose them as property attributes. This + function creates the accessors for property attributes. + + Parameters + ---------- + name: str + Key in dictionary of FP8 attributes + + """ + + def get_func(self) -> Any: + return self._fp8_attrs[name] + + def set_func(self, value: Any) -> None: + self._fp8_attrs[name] = value + + def del_func(self) -> None: + del self._fp8_attrs[name] + + return dict(fget=get_func, fset=set_func, fdel=del_func) + + +class _FromFloat8Func(torch.autograd.Function): + """Cast from FP8 to other dtype""" + + @staticmethod + def forward( + _ctx: torch.autograd.function.FunctionCtx, # unused + tensor: Float8Tensor, + dtype: Optional[torch.dtype] = None, + ) -> torch.Tensor: + return tensor.dequantize(dtype=dtype) + + @staticmethod + def backward( + _ctx: torch.autograd.function.FunctionCtx, # unused + grad: torch.Tensor, + ) -> Tuple[Optional[torch.Tensor], ...]: + # Assume that we want gradients in full precision + return grad, None + + +def post_optimizer_step_fwd_amax_reduction(param: Float8Tensor) -> None: + """Amax scale and update when there is at least 1 trainable FP8 parameter.""" + param_id = id(param._data) + + if param_id not in FP8GlobalStateManager.fp8_param_to_autocast: + return + + autocast_key = FP8GlobalStateManager.fp8_param_to_autocast[param_id] + + if autocast_key not in FP8GlobalStateManager.autocast_to_fp8_params: + return + + if autocast_key in updated_fp8_params: + updated_fp8_params[autocast_key].add(param_id) + else: + updated_fp8_params[autocast_key] = {param_id} + + current_fp8_params_set = FP8GlobalStateManager.autocast_to_fp8_params[autocast_key] + # All FP8 trainable parameters have been updated. + if updated_fp8_params[autocast_key] == current_fp8_params_set: + FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=True, fp8_weights=True) + del updated_fp8_params[autocast_key] + + +class _ToFloat8Func(torch.autograd.Function): + """Cast to FP8 from other dtype""" + + @staticmethod + def forward( + _ctx: torch.autograd.function.FunctionCtx, # unused + tensor: torch.Tensor, + fp8_meta: Optional[Dict[str, Any]] = None, + fp8_meta_forward: bool = True, + fp8_meta_index: Optional[int] = None, + fp8_dtype: TE_DType = TE_DType.kFloat8E4M3, + scale: Optional[torch.Tensor] = None, + amax: Optional[torch.Tensor] = None, + scale_inv: Optional[torch.Tensor] = None, + with_transpose_cache: bool = False, + ) -> Float8Tensor: + + # Tensor attributes + dtype = tensor.dtype + if dtype not in (torch.float32, torch.bfloat16, torch.float16): + dtype = torch.float32 + device = tensor.device + if device.type != "cuda": + device = torch.device("cuda") + + # FP8 data buffer + data = torch.empty(tensor.size(), dtype=torch.uint8, device=device) + + # Check scale + if scale is None and fp8_meta is None: + scale = 1 + if scale is not None: + if isinstance(scale, torch.Tensor): + scale = scale.to(device=device, dtype=torch.float32) + else: + scale = torch.full([1], scale, dtype=torch.float32, device=device) + + # Check scale-inverse + if scale_inv is None: + scale_inv = torch.empty([1], dtype=torch.float32, device=device) + elif not devices_match(scale_inv.device, device) or scale_inv.dtype != dtype: + scale_inv = scale_inv.to(device=device, dtype=torch.float32) + + # Transpose cache + data_transpose = None + if with_transpose_cache: + data_transpose = torch.empty( + (data.size(-1), data.numel() // data.size(-1)), + dtype=torch.uint8, + device=tensor.device, + ) + + # Construct FP8 tensor + out = Float8Tensor( + data=data, + fp8_meta=fp8_meta, + fp8_meta_forward=fp8_meta_forward, + fp8_meta_index=fp8_meta_index, + fp8_dtype=fp8_dtype, + fp8_scale_inv=scale_inv, + dtype=dtype, + data_transpose=data_transpose, + ) + + # Cast to FP8 tensor + out.quantize_(tensor, scale=scale, amax=amax) + + return out + + @staticmethod + def backward( + _ctx: torch.autograd.function.FunctionCtx, # unused + grad: torch.Tensor, + ) -> Tuple[Optional[torch.Tensor], ...]: + # Assume that we want gradients in full precision + return grad, None, None, None, None, None, None, None + + +class _IdentityFunc(torch.autograd.Function): + """Identity function + + If constructor keyword-arguments are provided, then construct a + new Float8Tensor using the provided tensor's attributes. + + """ + + @staticmethod + def forward( + ctx, + tensor: Float8Tensor, + init_kwargs: Optional[Dict[str, Any]] = None, + ) -> torch.Tensor: + + # Return input tensor if constructor kwargs are not provided + ctx.input_dtype = tensor.dtype + if init_kwargs is None: + return tensor + + # Construct new tensor if constructor kwargs are provided + default_kwargs = dict( + data=tensor._data, + fp8_meta=tensor._fp8_meta, + fp8_meta_forward=tensor._fp8_meta_forward, + fp8_meta_index=tensor._fp8_meta_index, + fp8_dtype=tensor._fp8_dtype, + fp8_scale_inv=tensor._scale_inv, + dtype=tensor.dtype, + ) + for key, val in default_kwargs.items(): + if key not in init_kwargs: + init_kwargs[key] = val + return Float8Tensor(**init_kwargs) + + @staticmethod + def backward(ctx, grad): + return grad.to(ctx.input_dtype), None + + +class _ViewFunc(torch.autograd.Function): + """View function + + View the Float8Tensor using the provided shape. + + """ + + @staticmethod + def forward( + ctx, + tensor: torch.Tensor, + shape: Tuple[int] = None, + ) -> torch.Tensor: + + # Return input tensor if shape is not provided + ctx.shape = tensor.shape + if shape is None: + return tensor + + # Construct new tensor if shape is provided + if isinstance(tensor, Float8Tensor): + return Float8Tensor.make_like( + tensor, + data=tensor._data.view(*shape), + ) + return tensor.view(*shape) + + @staticmethod + def backward( + ctx, + grad: torch.Tensor, + ) -> Tuple[Optional[torch.Tensor], ...]: + + if isinstance(grad, Float8Tensor): + dgrad = Float8Tensor.make_like( + grad, + data=grad._data.view(ctx.shape), + ) + return dgrad, None + return grad.view(ctx.shape), None + + +class _ReshapeFunc(torch.autograd.Function): + """Reshape function + + Reshape the Float8Tensor using the provided shape. + + """ + + @staticmethod + def forward( + ctx, + tensor: torch.Tensor, + shape: Tuple[int] = None, + ) -> torch.Tensor: + + # Return input tensor if shape is not provided + ctx.shape = tensor.shape + if shape is None: + return tensor + + # Construct new tensor if shape is provided + if isinstance(tensor, Float8Tensor): + return Float8Tensor.make_like( + tensor, + data=tensor._data.reshape(*shape), + ) + return tensor.reshape(*shape) + + @staticmethod + def backward( + ctx, + grad: torch.Tensor, + ) -> Tuple[Optional[torch.Tensor], ...]: + + if isinstance(grad, Float8Tensor): + dgrad = Float8Tensor.make_like( + grad, + data=grad._data.reshape(ctx.shape), + ) + return dgrad, None + return grad.reshape(ctx.shape), None + + +class Float8Tensor(QuantizedTensor): + """Experimental tensor class with FP8 data + + The tensor presents as having a standard, higher-precision dtype, + but the data itself is (scaled) FP8. For most tensor operations, + the data will be cast to the nominal dtype before performing the + operation. + + Parameters + ---------- + data: torch.Tensor + Raw FP8 data in a uint8 tensor + fp8_attrs: dict, optional + FP8 metadata, primarily managed by Float8Tensor. If + provided, all other FP8 configuration is ignored. + fp8_meta: dict, optional + FP8 metadata object, primarily managed by TE modules. + fp8_meta_forward: bool, default = `True` + Whether to access the FP8 metadata for the + forward pass. Ignored if fp8_meta is not + provided. + fp8_meta_index: int, optional + Index to access in FP8 meta tensors. Required if + fp8_meta is provided and otherwise ignored. + fp8_dtype: transformer_engine_torch.DType, default = kFloat8E4M3 + FP8 format. + fp8_scale_inv: torch.Tensor + Reciprocal of the scaling factor applied when + casting to FP8, i.e. the scaling factor that must + be applied when casting from FP8 to higher + precision. Can be inferred from fp8_meta if + provided. + dtype: torch.dtype, default = torch.float32 + Nominal tensor datatype. + + """ + + def __new__( + cls, + *, + data: torch.Tensor, + fp8_attrs: Optional[Dict[str, Any]] = None, + fp8_meta: Optional[Dict[str, Any]] = None, + fp8_meta_forward: bool = True, + fp8_meta_index: Optional[int] = None, + fp8_dtype: TE_DType = TE_DType.kFloat8E4M3, + fp8_scale_inv: Optional[torch.Tensor] = None, + dtype: torch.dtype = torch.float32, + data_transpose: Optional[torch.Tensor] = None, + ): + + # Check that data buffer is valid + if data.element_size() != 1: + raise ValueError( + f"Float8Tensor requires data buffer with 8-bit dtype (got dtype={data.dtype})" + ) + if data.requires_grad: + raise ValueError("Float8Tensor requires non-differentiable data buffer") + if not data.is_cuda: + data = data.cuda() + + # Initialize tensor object + self = torch.Tensor._make_wrapper_subclass( + cls, + data.size(), + strides=data.stride(), + storage_offset=data.storage_offset(), + dtype=dtype, + layout=data.layout, + requires_grad=data.requires_grad, + device=data.device, + ) + self._data: torch.Tensor = data + + # Initialize dict of class attributes + # Note: We store FP8 attributes in a dictionary so we can + # share them between tensors with the same data, e.g. detached + # tensors. + self._fp8_attrs: dict + if fp8_attrs is None: + self._fp8_attrs = {} + else: + self._fp8_attrs = fp8_attrs + return self + + # FP8 meta tensors + if fp8_meta is not None and fp8_meta_index is None: + raise ValueError( + "To initialize Float8Tensor with FP8 meta tensors, " + "the FP8 meta tensor index must also be provided" + ) + self._fp8_meta: Optional[Dict[str, Any]] = fp8_meta + self._fp8_meta_forward: bool = fp8_meta_forward + self._fp8_meta_index: Optional[int] = fp8_meta_index + + # FP8 dtype + assert fp8_dtype in ( + TE_DType.kFloat8E4M3, + TE_DType.kFloat8E5M2, + ), f"Unsupported fp8_dtype {fp8_dtype}." + self._fp8_dtype: TE_DType = fp8_dtype + + # FP8 scale-inverse + if fp8_scale_inv is None and self._fp8_meta is not None: + fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( + forward=self._fp8_meta_forward, + ) + fp8_scale_inv = self._fp8_meta[fp8_meta_key].scale_inv[self._fp8_meta_index] + fp8_scale_inv = fp8_scale_inv.detach().view(1).clone() + if fp8_scale_inv is None: + raise ValueError( + "Attempted to initialize Float8Tensor without specifying scale-inverse" + ) + if not isinstance(fp8_scale_inv, torch.Tensor): + fp8_scale_inv = torch.full( + [1], + fp8_scale_inv, + dtype=torch.float32, + device=self._data.device, + ) + if fp8_scale_inv.numel() != 1: + raise ValueError( + "Attempted to initialize Float8Tensor with invalid scale-inverse tensor" + ) + if fp8_scale_inv.dim() != 1: + fp8_scale_inv = fp8_scale_inv.reshape(1) + if ( + not devices_match(fp8_scale_inv.device, self._data.device) + or fp8_scale_inv.dtype != torch.float32 + ): + fp8_scale_inv = fp8_scale_inv.to( + device=self._data.device, + dtype=torch.float32, + ) + self._scale_inv: Optional[torch.Tensor] = fp8_scale_inv + + # FP8 transpose cache + self._transpose: Optional[Float8Tensor] = data_transpose + self._transpose_invalid: bool = self._transpose is None + + return self + + @classmethod + def make_like( + cls, + tensor: Float8Tensor, + *, + data: torch.Tensor, + fp8_attrs: Optional[Dict[str, Any]] = None, + **kwargs, + ) -> Float8Tensor: + """Use attributes of a Float8Tensor to create another Float8Tensor + + See constructor for list of keyword arguments. + + """ + default_kwargs = dict( + fp8_meta=tensor._fp8_meta, + fp8_meta_forward=tensor._fp8_meta_forward, + fp8_meta_index=tensor._fp8_meta_index, + fp8_dtype=tensor._fp8_dtype, + fp8_scale_inv=tensor._scale_inv, + dtype=tensor.dtype, + ) + for key, val in default_kwargs.items(): + if key not in kwargs: + kwargs[key] = val + return Float8Tensor(data=data, fp8_attrs=fp8_attrs, **kwargs) + + def __repr__(self): + return ( + "Float8Tensor(" + f"fp8_dtype={self._fp8_dtype}, " + f"scale_inv={self._scale_inv.item()}, " + f"data={self.from_float8(dtype=self.dtype)}" + ")" + ) + + def dequantize(self, dtype: Optional[torch.dtype] = None) -> torch.Tensor: + + # Convert PyTorch dtype to TE dtype + if dtype is None: + dtype = self.dtype + dtype = torch_to_transformer_engine_dtype[dtype] + + # Make sure FP8 data is in expected format + data = self._data + if data.device.type != "cuda": + data = data.cuda() + if not data.is_contiguous(): + data = data.contiguous() + if data.dim() != 2: + data = data.view(1, -1) + + # Cast from FP8 + out = cast_from_fp8( + data.view(1, -1), + None, # fp8_meta_tensor + None, # fp8_tensor + self._fp8_dtype, + dtype, + scale_inv=self._scale_inv, + ) + + # Make sure output is in expected format + if out.size() != self.size(): + out = out.view(self.size()) + return out + + def from_float8(self, dtype: Optional[torch.dtype] = None) -> torch.Tensor: + """ + Construct plain PyTorch tensor from Float8Tensor + + By default the resulting tensor's dtype is the + Float8Tensor's nominal dtype. + """ + return _FromFloat8Func.apply(self, dtype) + + def quantize_( + self, + tensor: torch.Tensor, + *, + scale: Optional[torch.Tensor] = None, + amax: Optional[torch.Tensor] = None, + noop_flag: Optional[torch.Tensor] = None, + ) -> Float8Tensor: + """Update FP8 data + + Parameters + ---------- + tensor: torch.Tensor + Tensor to copy from + scale: torch.Tensor, optional + Scaling factor to use for FP8 quantization + amax: torch.Tensor, optional + History of maximum absolute values. The first entry will + be updated with the absmax of `tensor`. + noop_flag: torch.Tensor, optional + float32 flag indicating whether to avoid performing update + + """ + src = tensor + dst = self + + # In-place operations invalidate transpose cache + self._reset_caches() + + # Special logic if other tensor is Float8Tensor + if isinstance(src, Float8Tensor): + + # Cast to plain tensor if FP8 dtypes don't match + if dst._fp8_dtype != src._fp8_dtype: + return dst.quantize_(src.dequantize()) + + # Directly copy FP8 data + dst._data.copy_(src._data.detach()) + dst._scale_inv.copy_(src._scale_inv.detach()) + if amax is not None or dst._fp8_meta is not None: + src_amax: torch.Tensor + if src._fp8_meta is None: + src_min, src_max = src.dequantize().aminmax() + src_amax = torch.maximum(-src_min, src_max) + else: + fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( + forward=src._fp8_meta_forward, + ) + fp8_meta_index = src._fp8_meta_index + src_amax = src._fp8_meta[fp8_meta_key].amax_history[0, fp8_meta_index] + dst_amax: torch.Tensor + if amax is None: + fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( + forward=dst._fp8_meta_forward, + ) + fp8_meta_index = dst._fp8_meta_index + dst_amax = dst._fp8_meta[fp8_meta_key].amax_history[0, fp8_meta_index] + else: + dst_amax = amax + if dst_amax.dim() > 0: + dst_amax = dst_amax[tuple([0] * dst_amax.dim())] + torch.maximum(src_amax, dst_amax, out=dst_amax) + if dst._transpose is not None: + if src._transpose is None: + dst.transpose_2d(force_compute=True, fill_cache=True) + else: + dst._transpose.copy_(src._transpose) + dst._transpose_invalid = False + return self + + # Convert QuantizedTensor to plain tensor + if isinstance(src, QuantizedTensor): + return dst.quantize_(src.dequantize()) + + # Make sure input is in expected format + if src.size() != dst.size(): + src = src.expand(dst.size()) + if not devices_match(src.device, dst.device): + src = src.to(device=dst.device) + if src.dtype not in (torch.float32, torch.bfloat16, torch.float16): + src = src.float() + if not src.is_contiguous(): + src = src.contiguous() + + # Make sure FP8 scaling factors are in expected format + if scale is not None: + if isinstance(scale, torch.Tensor): + if not devices_match(scale.device, dst.device) or scale.dtype != torch.float32: + scale = scale.to(device=dst.device, dtype=torch.float32) + else: + scale = torch.full([1], scale, dtype=torch.float32, device=dst.device) + if amax is not None: + while amax.dim() < 2: + amax = amax.unsqueeze(0) + if not devices_match(amax.device, dst.device): + raise ValueError( + f"Invalid device for amax (expected {dst.device}, found {amax.device})" + ) + if amax.dtype != torch.float32: + raise ValueError(f"Invalid dtype for amax (expected float32, found {amax.type})") + + # Default FP8 scaling factors + fp8_meta = None + if dst._fp8_meta is None: + if scale is None: + scale = dst._scale_inv.reciprocal() + if amax is None: + amax = torch.empty((1, 1), dtype=torch.float32, device=dst.device) + else: + fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( + forward=dst._fp8_meta_forward, + ) + fp8_meta = dst._fp8_meta[fp8_meta_key] + + # Check local data + if not dst._data.is_contiguous(): + raise RuntimeError("Transformer Engine cast kernels require contiguous data") + + # Perform FP8 cast + if dst._transpose is None: + dst_data = dst._data + if src.dim() != 2: + src = src.view(1, -1) + dst_data = dst_data.view(1, -1) + cast_to_fp8( + src, + fp8_meta, + dst._fp8_meta_index, + dst._fp8_dtype, + out=dst_data, + scale=scale, + amax=amax, + scale_inv=dst._scale_inv, + ) + else: + fp8_cast_transpose_fused( + src.view(-1, src.size(-1)), + fp8_meta, + dst._fp8_meta_index, + dst._fp8_dtype, + cast_out=dst._data, + transpose_out=dst._transpose, + scale=scale, + amax=amax, + scale_inv=dst._scale_inv, + noop_flag=noop_flag, + ) + dst._transpose_invalid = False + + # Callback hook to perform amax reduction after optimizer step + post_optimizer_step_fwd_amax_reduction(self) + + return self + + @classmethod + def to_float8( + cls, + tensor: torch.Tensor, + *, + fp8_meta: Optional[Dict[str, Any]] = None, + fp8_meta_forward: bool = True, + fp8_meta_index: Optional[int] = None, + fp8_dtype: TE_DType = TE_DType.kFloat8E4M3, + scale: Optional[torch.Tensor] = None, + amax: Optional[torch.Tensor] = None, + scale_inv: Optional[torch.Tensor] = None, + with_transpose_cache: bool = False, + ): + """Construct Float8Tensor from plain PyTorch tensor""" + return _ToFloat8Func.apply( + tensor, + fp8_meta, + fp8_meta_forward, + fp8_meta_index, + fp8_dtype, + scale, + amax, + scale_inv, + with_transpose_cache, + ) + + def detach(self) -> Float8Tensor: + return Float8Tensor.make_like( + self, + data=self._data, + fp8_attrs=self._fp8_attrs, + ) + + def clone(self) -> Float8Tensor: + data = self._data.detach().clone() + data_transpose = None + if self._transpose is not None: + data_transpose = self._transpose.detach().clone() + return _IdentityFunc.apply( + self, + dict( + data=data, + data_transpose=data_transpose, + ), + ) + + def view(self, *shape: Tuple[int]) -> Float8Tensor: + return _ViewFunc.apply(self, shape) + + def reshape(self, *shape: Tuple[int]) -> Float8Tensor: + return _ReshapeFunc.apply(self, shape) + + def contiguous( + self, + *, + memory_format: torch.memory_format = torch.contiguous_format, + ) -> Float8Tensor: + """Returns tensor with data in provided memory format + + Returns `self` if data is already in correct memory format. + + """ + if self._data.is_contiguous(memory_format=memory_format): + return self + return _IdentityFunc.apply( + self, + {"data": self._data.detach().contiguous(memory_format=memory_format)}, + ) + + def transpose_2d( + self, + *, + force_compute: bool = False, + fill_cache: bool = False, + noop_flag: Optional[torch.Tensor] = None, + cache: Optional[bool] = None, + ) -> torch.Tensor: + """ + 2D transpose with caching support. + + Parameters + ---------- + force_compute: bool, default = `False` + Force computation of transpose. Otherwise use + cached values, if possible. + fill_cache: bool, default = `False` + Cache output tensor for future function calls. + noop_flag: torch.Tensor, optional + float32 flag indicating whether to avoid updating + cached values, if possible. + cache: bool, deprecated + + """ + + # Handle deprecated cache kwarg + if cache is not None: + msg = ( + "cache kwarg for Float8Tensor.transpose_2d is deprecated, " + "please use force_compute and fill_cache instead" + ) + warnings.warn(msg, DeprecationWarning) + if cache: + force_compute = False + fill_cache = True + else: + force_compute = True + fill_cache = False + + # Need to compute transpose if cache is invalid + need_compute = force_compute + if self._transpose is None: + need_compute = True + elif self._transpose_invalid: + need_compute = True + + # Need to apply transpose kernel if noop flag is applied + if noop_flag is not None: + need_compute = True + + # Return cached transpose if possible + if not need_compute: + return self._transpose + + # Allocate output if needed + data = self._data.contiguous().reshape(-1, self.size(-1)) + out = self._transpose + if out is None: + out = torch.empty( + (data.size(1), data.size(0)), + dtype=torch.uint8, + device=data.device, + ) + noop_flag = None + else: + self._transpose_invalid = False + + # Apply transpose kernel + fp8_dtype = self._fp8_dtype + if noop_flag is None: + tex.fp8_transpose_noalloc(data, out, fp8_dtype) + else: + noop_flag = noop_flag.to(dtype=torch.float32, device=data.device) + tex.fp8_transpose_noalloc_noop(data, out, noop_flag, fp8_dtype) + + # Fill cache if needed + if fill_cache: + self._transpose = out + self._transpose_invalid = False + + return out + + @torch.no_grad() + def cast_transpose_( + self, + tensor: torch.Tensor, + noop_flag: Optional[torch.Tensor] = None, + ) -> None: + """Cast from tensor and populate transpose cache + + Tensor is reshaped as a 2D matrix. + + Parameters + ---------- + tensor: torch.Tensor + Tensor to copy from. Must have same dimensions as + destination tensor. + noop_flag: torch.Tensor, optional + float32 flag indicating whether to avoid updating + destination tensor. + + """ + if self._transpose is None: + self._transpose = torch.empty( + (self.size(-1), self.numel() // self.size(-1)), + dtype=torch.uint8, + device=self.device, + ) + self.quantize_(tensor, noop_flag=noop_flag) + + @torch.no_grad() + def reset_fp8_meta_scale_inv(self) -> None: + """Replace FP8 meta tensor scale-inverse with cached value + + The FP8 meta tensor scale_inv entry corresponding to this + tensor is replaced with the scale_inv value used to construct + the tensor. + + """ + assert self._fp8_meta is not None, "FP8 meta tensors not found." + fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( + forward=self._fp8_meta_forward, + ) + self._fp8_meta[fp8_meta_key].scale_inv[self._fp8_meta_index].copy_(self._scale_inv[0]) + + def to_dtype(self, dtype: torch.dtype) -> Float8Tensor: + """Create `Float8Tensor` with given nominal dtype + + The new tensor has the same underlying FP8 data. + + """ + return Float8Tensor.make_like( + self, + data=self._data, + fp8_attrs=self._fp8_attrs, + dtype=dtype, + ) + + def _reset_caches(self) -> None: + """ + Set transpose cache as invalid. + Should be called after any in-place operation. + """ + self._transpose_invalid = True + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs=None): + + # Slice op + if func == aten.slice.Tensor: + tensor = args[0] + data = tensor._data + data_slice = data.__torch_dispatch__( + func, + types, + [data] + list(args[1:]), + kwargs, + ) + return Float8Tensor.make_like(tensor, data=data_slice) + + # View op + if func == aten.view.default: + tensor = args[0] + data = tensor._data + data_view = data.__torch_dispatch__( + func, + types, + [data] + list(args[1:]), + kwargs, + ) + return Float8Tensor.make_like(tensor, data=data_view) + + # Default case + return super().__torch_dispatch__(func, types, args, kwargs) + + @classmethod + def _make_in_reduce_ex( + cls, + data: torch.Tensor, + fp8_dtype: TE_DType, + fp8_scale_inv: torch.Tensor, + dtype: torch.dtype, + ) -> Float8Tensor: + """Build Float8Tensor, for use in __reduce__ + + __reduce_ex__ assumes object constructor has positional + arguments. + + """ + return Float8Tensor( + data=data, + fp8_dtype=fp8_dtype, + fp8_scale_inv=fp8_scale_inv, + dtype=dtype, + ) + + def __reduce_ex__(self, protocol: int) -> tuple: + """Custom pickling to remove references to FP8 metadata objects""" + return ( + Float8Tensor._make_in_reduce_ex, + (self._data, self._fp8_dtype, self._scale_inv, self.dtype), + ) + + def _get_data(self) -> Float8Tensor: + """Get tensor data property""" + return super().data + + def _set_data(self, tensor: torch.Tensor) -> None: + """Set tensor data property + + Cast tensor to FP8 and store in FP8 buffer. + + """ + with torch.no_grad(): + self.copy_(tensor) + + # Cast to FP8 when setting Float8Tensor.data + data = property(_get_data, _set_data) + + # Accessors for objects in self._fp8_attrs + # Note: We store FP8 attributes in a dictionary so we can share + # them between tensors with the same data, e.g. detached tensors. + # For convenience, we also expose them as property attributes. + _fp8_meta = property(**_make_fp8_attr_property_funcs("fp8_meta")) + _fp8_meta_forward = property(**_make_fp8_attr_property_funcs("fp8_meta_forward")) + _fp8_meta_index = property(**_make_fp8_attr_property_funcs("fp8_meta_index")) + _fp8_dtype = property(**_make_fp8_attr_property_funcs("dtype")) + _transpose = property(**_make_fp8_attr_property_funcs("transpose")) + _transpose_invalid = property(**_make_fp8_attr_property_funcs("transpose_invalid")) + _scale_inv = property(**_make_fp8_attr_property_funcs("scale_inv")) diff --git a/transformer_engine/pytorch/tensor/quantized_tensor.py b/transformer_engine/pytorch/tensor/quantized_tensor.py new file mode 100644 index 0000000000..f890b0878a --- /dev/null +++ b/transformer_engine/pytorch/tensor/quantized_tensor.py @@ -0,0 +1,172 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Tensor with quantized data""" + +from __future__ import annotations +from typing import Optional, Tuple + +import torch +from torch.utils._pytree import tree_map + + +class _DequantizeFunc(torch.autograd.Function): + """Autograd function to convert quantized tensor to standard tensor""" + + @staticmethod + def forward( + _ctx: torch.autograd.function.FunctionCtx, # unused + tensor: QuantizedTensor, + dtype: Optional[torch.dtype] = None, + ) -> torch.Tensor: + return tensor.dequantize(dtype=dtype) + + @staticmethod + def backward( + _ctx: torch.autograd.function.FunctionCtx, # unused + grad: torch.Tensor, + ) -> Tuple[Optional[torch.Tensor], ...]: + return grad, None + + +class _IdentityFunc(torch.autograd.Function): + """Autograd function to create quantized tensor with same data""" + + @staticmethod + def forward( + _ctx: torch.autograd.function.FunctionCtx, # unused + tensor: QuantizedTensor, + ) -> QuantizedTensor: + return tensor.detach() + + @staticmethod + def backward( + _ctx: torch.autograd.function.FunctionCtx, # unused + grad: torch.Tensor, + ) -> torch.Tensor: + return grad + + +class QuantizedTensor(torch.Tensor): + """Abstract base class for tensor with quantized data + + This is a proxy class with the interface of a standard PyTorch + tensor, but with data that has been encoded with some quantization + scheme. Derived classes should implement the quantization scheme + by overriding the `quantize_` and `dequantize` functions. + + """ + + def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor: + """Convert quantized data to standard PyTorch tensor""" + raise NotImplementedError( + f"{self.__class__.__name__} class does not implement dequantize function" + ) + + def quantize_(self, tensor: torch.Tensor) -> QuantizedTensor: + """Update quantized data in-place""" + raise NotImplementedError( + f"{self.__class__.__name__} class does not implement quantize_ function" + ) + + def detach(self) -> QuantizedTensor: + """Create new quantized tensor with same data + + Output tensor must be detached from the current autograd + graph. + + """ + raise NotImplementedError( + f"{self.__class__.__name__} class does not implement detach function" + ) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(data={self.dequantize(dtype=self.dtype)})" + + def float(self) -> torch.Tensor: + return _DequantizeFunc.apply(self, torch.float32) + + def bfloat16(self) -> torch.Tensor: + return _DequantizeFunc.apply(self, torch.bfloat16) + + def half(self) -> torch.Tensor: + return _DequantizeFunc.apply(self, torch.float16) + + def cpu(self) -> torch.Tensor: + return _DequantizeFunc.apply(self).cpu() + + def expand_as(self, other: torch.Tensor) -> torch.Tensor: + if other is self: + # Note: expand_as is hackily used to create dummy autograd nodes + # and access the backward graph (see + # https://github.com/pytorch/pytorch/blob/238fb660851268f44ff88127887041fea352fe48/torch/nn/parallel/distributed.py#L1026). + # We hackily add a dummy function to handle this case. + return _IdentityFunc.apply(self) + return super().expand_as(other) + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs=None): + + # Detach op + if func == torch.ops.aten.detach.default: + return args[0].detach() + + # In-place copy op + if func == torch.ops.aten.copy_.default: + dst = args[0] + src = args[1] + if isinstance(dst, QuantizedTensor): + dst.quantize_(src) + else: + if isinstance(src, QuantizedTensor): + src = src.dequantize() + dst.copy_(src) + return None + + # View op + if func == torch.ops.aten.view.default: + raise NotImplementedError("{cls.__name__} class does not support tensor views") + + def maybe_unwrap(arg): + if isinstance(arg, QuantizedTensor): + return arg.dequantize(dtype=arg.dtype) + return arg + + def maybe_update_inplace(arg, new_arg, schema_arg): + if ( + isinstance(arg, QuantizedTensor) + and isinstance(new_arg, torch.Tensor) + and hasattr(schema_arg, "alias_info") + and hasattr(schema_arg.alias_info, "is_write") + and schema_arg.alias_info.is_write + ): + arg.quantize_(new_arg) + + # In-place op: dequantize, perform op, and quantize + if func._schema.is_mutable: + new_args = tree_map(maybe_unwrap, args) + new_kwargs = tree_map(maybe_unwrap, kwargs) + schema_args = func._schema.arguments + args_len = len(args) + super().__torch_dispatch__(func, types, new_args, new_kwargs) + for arg, new_arg, schema_arg in zip(args, new_args, schema_args): + maybe_update_inplace(arg, new_arg, schema_arg) + for kwarg, new_kwarg, schema_arg in zip(kwargs, new_kwargs, schema_args[args_len:]): + assert kwarg == new_kwarg == schema_arg.name, "name of the kw argument should match" + maybe_update_inplace(kwargs[kwarg], new_kwargs[new_kwarg], schema_arg) + return None + + # Default op: dequantize and perform op + args = tree_map(maybe_unwrap, args) + if kwargs is not None: + kwargs = tree_map(maybe_unwrap, kwargs) + out = super().__torch_dispatch__(func, types, args, kwargs) + return out + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + # Do not force the QuantizedTensor type on the returned tensor + return torch._C._disabled_torch_function_impl(func, types, args, kwargs) diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index 5e3fa05f52..d5145455b8 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -3,6 +3,7 @@ # See LICENSE for license information. """Utility functions for Transformer Engine modules""" +from __future__ import annotations import functools import math from typing import Any, Callable, Optional, Tuple @@ -251,3 +252,52 @@ def get_cudnn_version() -> Tuple[int, int, int]: major, encoded_version = divmod(encoded_version, major_version_magnitude) minor, patch = divmod(encoded_version, 100) return (major, minor, patch) + + +def canonicalize_device(device: Optional[torch.device | str]) -> torch.device: + """Canonicalize PyTorch device + + If `None`, then returns the default CUDA device. + + """ + if device is None: + # Use default CUDA device + device = torch.get_default_device() + if device.type != "cuda": + device = torch.device("cuda", torch.cuda.current_device()) + elif not isinstance(device, torch.device): + device = torch.device(device) + if device.type == "cuda" and device.index is None: + device = torch.device("cuda", torch.cuda.current_device()) + return device + + +def canonicalize_dtype(dtype: Optional[torch.dtype]) -> torch.dtype: + """Canonicalize PyTorch datatype + + If `None`, then returns the default PyTorch datatype. + + """ + if dtype is None: + # Use default dtype + dtype = torch.get_default_dtype() + return dtype + + +def devices_match(device1: torch.device, device2: torch.device) -> bool: + """Whether two devices are the same""" + device1 = torch.device(device1) + device2 = torch.device(device2) + if device1.type != device2.type: + return False + if device1.type == "cuda": + index1 = device1.index + index2 = device2.index + if index1 == index2: + return True + if index1 is None: + index1 = torch.cuda.current_device() + if index2 is None: + index2 = torch.cuda.current_device() + return index1 == index2 + return device1 == device2