diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py
index 3523e1cda5..bc983bcae0 100644
--- a/tests/pytorch/test_fusible_ops.py
+++ b/tests/pytorch/test_fusible_ops.py
@@ -450,36 +450,45 @@ def test_reshape(
         torch.testing.assert_close(y_test, y_ref, **tols)
         torch.testing.assert_close(dx_test, x_ref.grad, **tols)

-    @pytest.mark.parametrize("size", (1, 7, 32))
+    @pytest.mark.parametrize("size", (7, 32))
     @pytest.mark.parametrize("in_shape", ((-1,), (1, 3, -1), (2, 3, 4, -1)))
     @pytest.mark.parametrize("dtype", _dtypes)
-    @pytest.mark.parametrize("device", _devices)
-    @pytest.mark.parametrize("fp8", (False, True))
+    @pytest.mark.parametrize("fp8_grad_input", (False, True))
     def test_bias(
         self,
         *,
         size: int,
         in_shape: Iterable[int],
         dtype: torch.dtype,
-        device: torch.device,
-        fp8: bool,
+        device: torch.device = "cuda",
+        fp8_input: bool = False,
+        fp8_grad_output: bool = False,
+        fp8_grad_input: bool,
     ) -> None:

         # Make input and bias shapes consistent
         in_shape = list(in_shape)[:-1] + [size]

         # Skip invalid configurations
-        if fp8 and not fp8_available:
-            pytest.skip(reason_for_no_fp8)
-        if fp8 and torch.device(device).type != "cuda":
-            pytest.skip("FP8 is only supported on CUDA devices")
+        if fp8_input or fp8_grad_output or fp8_grad_input:
+            if not fp8_available:
+                pytest.skip(reason_for_no_fp8)
+            if torch.device(device).type != "cuda":
+                pytest.skip("FP8 is only supported on CUDA devices")
+
+        # FP8 recipe
+        fp8_recipe = None
+        if fp8_grad_input:
+            fp8_recipe = transformer_engine.common.recipe.DelayedScaling(
+                fp8_format=transformer_engine.common.recipe.Format.E4M3,
+            )

         # Random data
         x_ref, x_test = make_reference_and_test_tensors(
             in_shape,
             test_dtype=dtype,
             test_device=device,
-            test_is_fp8=fp8,
+            test_is_fp8=fp8_input,
         )
         b_ref, b_test = make_reference_and_test_tensors(
             size,
@@ -490,26 +499,34 @@ def test_bias(
             in_shape,
             test_dtype=dtype,
             test_device=device,
+            test_is_fp8=(fp8_grad_output or fp8_grad_input),
             requires_grad=False,
         )
+        if not fp8_grad_output and is_float8_tensor(dy_test):
+            dy_test = dy_test.from_float8()

         # Plain PyTorch implementation
         y_ref = x_ref + b_ref.reshape([1] * (len(in_shape) - 1) + [size])
         y_ref.backward(dy_ref)

         # Implementation with fusible operation
-        op = te_ops.Bias(size, device=device, dtype=dtype)
+        bias_op = te_ops.Bias(size, device=device, dtype=dtype)
         with torch.no_grad():
-            op.bias.copy_(b_test)
+            bias_op.bias.copy_(b_test)
             del b_test
-        y_test = op(x_test)
+        forward = te_ops.Sequential(
+            te_ops.CastFloat8(forward=False, backward=fp8_grad_input),
+            bias_op,
+        )
+        with te.fp8_autocast(enabled=fp8_grad_input, fp8_recipe=fp8_recipe):
+            y_test = forward(x_test)
         y_test.backward(dy_test)

         # Check results
         tols = dtype_tols(dtype)
         y_test = y_test.to(dtype=torch.float64, device="cpu")
         dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
-        db_test = op.bias.grad.to(dtype=torch.float64, device="cpu")
+        db_test = bias_op.bias.grad.to(dtype=torch.float64, device="cpu")
         torch.testing.assert_close(y_test, y_ref, **tols)
         torch.testing.assert_close(dx_test, x_ref.grad, **tols)
         torch.testing.assert_close(db_test, b_ref.grad, **tols)
diff --git a/transformer_engine/pytorch/ops/__init__.py b/transformer_engine/pytorch/ops/__init__.py
index f437f877b4..f65433398e 100644
--- a/transformer_engine/pytorch/ops/__init__.py
+++ b/transformer_engine/pytorch/ops/__init__.py
@@ -8,17 +8,7 @@
 """

-from transformer_engine.pytorch.ops.basic import (
-    AddInPlace,
-    AllGather,
-    AllReduce,
-    BasicLinear,
-    Bias,
-    Identity,
-    MakeExtraOutput,
-    ReduceScatter,
-    Reshape,
-)
+from transformer_engine.pytorch.ops.basic import *
 from transformer_engine.pytorch.ops.linear import Linear
 from transformer_engine.pytorch.ops.op import FusibleOperation
 from transformer_engine.pytorch.ops.sequential import Sequential
diff --git a/transformer_engine/pytorch/ops/basic/__init__.py b/transformer_engine/pytorch/ops/basic/__init__.py
index 1003cc0337..599bac1b4e 100644
--- a/transformer_engine/pytorch/ops/basic/__init__.py
+++ b/transformer_engine/pytorch/ops/basic/__init__.py
@@ -9,6 +9,7 @@
 from .all_reduce import AllReduce
 from .basic_linear import BasicLinear
 from .bias import Bias
+from .cast_float8 import CastFloat8
 from .identity import Identity
 from .make_extra_output import MakeExtraOutput
 from .reduce_scatter import ReduceScatter
diff --git a/transformer_engine/pytorch/ops/basic/bias.py b/transformer_engine/pytorch/ops/basic/bias.py
index b8e8cc5e56..d8124fd845 100644
--- a/transformer_engine/pytorch/ops/basic/bias.py
+++ b/transformer_engine/pytorch/ops/basic/bias.py
@@ -9,13 +9,15 @@

 import torch

-from transformer_engine.pytorch.ops.op import (
-    BasicOperation,
-    OperationContext,
-)
+from ...cpp_extensions import fp8_cast_transpose_bgrad_fused
+from ...float8_tensor import Float8Tensor
+from ...fp8 import FP8GlobalStateManager, get_fp8_te_dtype
+from ..op import BasicOperation, OperationContext
 from .._common import (
     canonicalize_device,
     canonicalize_dtype,
+    is_float8_tensor,
+    reshape,
 )


@@ -60,9 +62,14 @@ def __init__(
         if device.type == "meta":
             defer_param_init = True
             device = canonicalize_device(None)
+        if device.type != "cuda":
+            raise ValueError(f"Only CUDA devices are supported (got {device})")
         self.device: torch.device = device

         # Bias tensor datatype
+        dtype = canonicalize_dtype(dtype)
+        if dtype not in (torch.float32, torch.float16, torch.bfloat16):
+            raise ValueError(f"Supported dtypes are float32, float16, bfloat16 (got {dtype})")
         self.dtype: torch.dtype = canonicalize_dtype(dtype)

         # Tensor parallel configuration
@@ -101,7 +108,7 @@ def reset_parameters(self) -> None:

         # Make sure parameter is initialized
         bias = self.bias
-        if bias.device.type != "cuda":
+        if bias.device.type != self.device.type:
             bias = torch.empty_like(bias, device=self.device)
         bias = bias.to(device=self.device, dtype=self.dtype)

@@ -125,18 +132,89 @@ def op_forward(
         prev_op: Optional[BasicOperation] = None,
         next_op: Optional[BasicOperation] = None,
     ) -> torch.Tensor:
+
+        # Apply bias
         x = input_
         b = self.bias.reshape([1] * (x.dim() - 1) + [self.local_size])
-        return x + b
+        y = x + b
+
+        # Save state for backward pass
+        ctx.bias_requires_grad = self.bias.requires_grad
+        ctx.with_fp8_compute = FP8GlobalStateManager.is_fp8_enabled()
+        ctx.prev_op = prev_op
+
+        return y

     def op_backward(
         self,
         ctx: OperationContext,
         grad_output: torch.Tensor,
     ) -> tuple[torch.Tensor, tuple[()]]:
+
+        # Check if FP8 is enabled
+        with_fp8_grad_input = (
+            ctx.with_fp8_compute
+            and ctx.prev_op is not None
+            and ctx.prev_op.num_fp8_scales("grad_output") > 0
+            and grad_output.size(-1) % 16 == 0
+            and grad_output.numel() // grad_output.size(-1) % 16 == 0
+        )
+
+        # Compute grad bias
         dy = grad_output
-        if dy.dim() > 1:
-            db = dy.sum(tuple(range(dy.dim() - 1)))
+        db = None
+        dx: torch.Tensor
+        if not ctx.bias_requires_grad:
+            # Trivial case: Don't compute bgrad, don't do anything
+            # with dgrad
+            dx = dy
+        elif not with_fp8_grad_input or is_float8_tensor(dy):
+            # Non-FP8 case: Compute bgrad, don't do anything with
+            # dgrad
+            if dy.dim() > 1:
+                db = dy.sum(tuple(range(dy.dim() - 1)))
+            else:
+                db = dy
+            dx = dy
         else:
-            db = dy
-        return dy, (db,)
+            # FP8 case: Call fused kernel to compute bgrad and cast
+            # dgrad to FP8
+
+            # Check grad output tensor
+            output_dims = grad_output.size()
+            dy = reshape(
+                dy,
+                (-1, output_dims[-1]),
+                device=self.device,
+                dtype=self.dtype,
+            )
+
+            # Call fused kernel for bgrad and casting dgrad to FP8
+            fp8_meta = ctx.prev_op.get_fp8_meta("grad_output")
+            fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(forward=False)
+            fp8_dtype = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=False)
+            fp8_scale_inv = torch.empty([1], dtype=torch.float32, device=self.device)
+            db, dx_data, dx_data_transpose = fp8_cast_transpose_bgrad_fused(
+                dy,
+                fp8_meta[fp8_meta_key],
+                0,
+                fp8_dtype,
+                scale_inv=fp8_scale_inv,
+            )
+
+            # Construct grad input tensor
+            if dx_data.size() != output_dims:
+                dx_data = dx_data.reshape(output_dims)
+            dx = Float8Tensor(
+                data=dx_data,
+                fp8_meta=fp8_meta,
+                fp8_meta_forward=False,
+                fp8_meta_index=0,
+                fp8_dtype=fp8_dtype,
+                fp8_scale_inv=fp8_scale_inv,
+                dtype=self.dtype,
+            )
+            dx._transpose = dx_data_transpose
+            dx._transpose_invalid = False
+
+        return dx, (db,)
diff --git a/transformer_engine/pytorch/ops/basic/cast_float8.py b/transformer_engine/pytorch/ops/basic/cast_float8.py
new file mode 100644
index 0000000000..482c66833e
--- /dev/null
+++ b/transformer_engine/pytorch/ops/basic/cast_float8.py
@@ -0,0 +1,100 @@
+# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+"""Fusible operation for FP8 cast."""
+
+from __future__ import annotations
+from typing import Optional
+
+import torch
+
+from transformer_engine.pytorch.float8_tensor import Float8Tensor
+from transformer_engine.pytorch.fp8 import (
+    FP8GlobalStateManager,
+    get_fp8_te_dtype,
+)
+from transformer_engine.pytorch.ops.op import (
+    BasicOperation,
+    OperationContext,
+)
+from .._common import is_float8_tensor
+
+
+class CastFloat8(BasicOperation):
+    """Cast tensor to FP8
+
+    Uses FP8 recipe from `fp8_autocast` context. When called outside
+    of an `fp8_autocast` context, this is an identity operation.
+
+    Parameters
+    ----------
+    forward: bool, default = `True`
+        Perform FP8 cast in forward pass
+    backward: bool, default = `False`
+        Perform FP8 cast in backward pass
+
+    """
+
+    def __init__(
+        self,
+        forward: bool = True,
+        backward: bool = False,
+    ) -> None:
+        super().__init__()
+        self._cast_forward = forward
+        self._cast_backward = backward
+
+    def num_fp8_scales(self, mode: str) -> int:
+        if mode == "input" and self._cast_forward:
+            return 1
+        if mode == "grad_output" and self._cast_backward:
+            return 1
+        return 0
+
+    def op_forward(
+        self,
+        ctx: OperationContext,
+        input_: torch.Tensor,
+        prev_op: Optional[BasicOperation] = None,
+        next_op: Optional[BasicOperation] = None,
+    ) -> torch.Tensor:
+
+        # Check if FP8 is enabled
+        fp8_enabled = FP8GlobalStateManager.is_fp8_enabled()
+        cast_forward = fp8_enabled and self._cast_forward
+        cast_backward = fp8_enabled and self._cast_backward
+
+        # Cast to FP8 if needed
+        out = input_
+        if cast_forward and not is_float8_tensor(out):
+            fp8_meta = self.get_fp8_meta("input")
+            fp8_dtype = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
+            out = Float8Tensor.to_float8(
+                out,
+                fp8_meta=fp8_meta,
+                fp8_meta_forward=True,
+                fp8_meta_index=0,
+                fp8_dtype=fp8_dtype,
+            )
+
+        ctx.cast_backward = cast_backward
+        return out
+
+    def op_backward(
+        self,
+        ctx: OperationContext,
+        grad_output: torch.Tensor,
+    ) -> tuple[torch.Tensor, tuple[()]]:
+        grad_input = grad_output
+        if ctx.cast_backward and not is_float8_tensor(grad_input):
+            fp8_meta = self.get_fp8_meta("grad_output")
+            fp8_dtype = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=False)
+            grad_input = Float8Tensor.to_float8(
+                grad_input,
+                fp8_meta=fp8_meta,
+                fp8_meta_forward=False,
+                fp8_meta_index=0,
+                fp8_dtype=fp8_dtype,
+            )
+        return grad_input, ()
diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py
index 5fd52405e4..097da43a7e 100644
--- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py
+++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py
@@ -72,12 +72,13 @@ def fuser_forward(
         idx = self._op_idxs["linear"]
         linear_op = self.basic_ops[idx]
         linear_op_ctx = basic_op_ctxs[idx]
-        if self._op_idxs["bias"] is None:
-            bias_op = None
-            bias = None
-        else:
+        bias_op = None
+        bias_op_ctx = None
+        bias = None
+        if self._op_idxs["bias"] is not None:
             idx = self._op_idxs["bias"]
             bias_op = self.basic_ops[idx]
+            bias_op_ctx = basic_op_ctxs[idx]
             bias = bias_op.bias
             if basic_op_kwargs[idx]:
                 raise ValueError("Bias operation forward does not expect keyword arguments")
@@ -129,7 +130,11 @@ def fuser_forward(
         linear_op_ctx.input_dims = input_.size()
         linear_op_ctx.input_requires_grad = input_.requires_grad
         linear_op_ctx.weight_requires_grad = linear_op.weight.requires_grad
-        linear_op_ctx.has_prev_op = basic_op_prev_ops[0] is not None
+        linear_op_ctx.has_prev_op = basic_op_prev_ops[self._op_idxs["linear"]] is not None
+        if bias_op_ctx is not None:
+            bias_op_ctx.bias_requires_grad = bias.requires_grad
+            bias_op_ctx.with_fp8_compute = with_fp8_compute
+            bias_op_ctx.prev_op = basic_op_prev_ops[self._op_idxs["bias"]]

         return output, [() for _ in range(len(self.basic_ops))]
diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py
index 6ddee2849a..67cfdbf173 100644
--- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py
+++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py
@@ -70,12 +70,13 @@ def fuser_forward(
         idx = self._op_idxs["linear"]
         linear_op = self.basic_ops[idx]
         linear_op_ctx = basic_op_ctxs[idx]
-        if self._op_idxs["bias"] is None:
-            bias_op = None
-            bias = None
-        else:
+        bias_op = None
+        bias_op_ctx = None
+        bias = None
+        if self._op_idxs["bias"] is not None:
             idx = self._op_idxs["bias"]
             bias_op = self.basic_ops[idx]
+            bias_op_ctx = basic_op_ctxs[idx]
             bias = bias_op.bias
             if basic_op_kwargs[idx]:
                 raise ValueError("Bias operation forward does not expect keyword arguments")
@@ -123,7 +124,11 @@ def fuser_forward(
         linear_op_ctx.input_dims = input_.size()
         linear_op_ctx.input_requires_grad = input_.requires_grad
         linear_op_ctx.weight_requires_grad = linear_op.weight.requires_grad
-        linear_op_ctx.has_prev_op = basic_op_prev_ops[0] is not None
+        linear_op_ctx.has_prev_op = basic_op_prev_ops[self._op_idxs["linear"]] is not None
+        if bias_op_ctx is not None:
+            bias_op_ctx.bias_requires_grad = bias.requires_grad
+            bias_op_ctx.with_fp8_compute = with_fp8_compute
+            bias_op_ctx.prev_op = basic_op_prev_ops[self._op_idxs["bias"]]

         return output, [() for _ in range(len(self.basic_ops))]
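
For reference, a minimal usage sketch of the new CastFloat8 op, mirroring the updated test_bias: the cast is placed before the op whose input gradient should be produced in FP8 and the model runs under fp8_autocast. This is only a sketch; it assumes a CUDA device with FP8 support, and the tensor shape, dtype, and recipe below are illustrative.

    import torch
    import transformer_engine.pytorch as te
    import transformer_engine.pytorch.ops as te_ops
    from transformer_engine.common import recipe

    # CastFloat8 is a no-op in the forward pass here; in the backward pass it
    # requests an FP8 grad_output, so the Bias op takes the fused bgrad +
    # FP8-cast path when the gradient dims are multiples of 16.
    model = te_ops.Sequential(
        te_ops.CastFloat8(forward=False, backward=True),
        te_ops.Bias(32, device="cuda", dtype=torch.bfloat16),
    )

    fp8_recipe = recipe.DelayedScaling(fp8_format=recipe.Format.E4M3)
    x = torch.randn(16, 32, device="cuda", dtype=torch.bfloat16, requires_grad=True)
    with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
        y = model(x)
    y.backward(torch.randn_like(y))  # dgrad flows back through the FP8 path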