From f2da5eba0c2b03898f8218821dff3390db480a50 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Tue, 20 Aug 2024 01:09:32 +0000 Subject: [PATCH 01/13] Add Userbuffers support for column TP linear layer Signed-off-by: Tim Moon --- .../test_fusible_ops_with_userbuffers.py | 438 +++++++++++++++++ tests/pytorch/utils.py | 83 ++++ .../pytorch/ops/basic/basic_linear.py | 12 +- .../pytorch/ops/fused/__init__.py | 4 + .../pytorch/ops/fused/userbuffers_linear.py | 449 ++++++++++++++++++ transformer_engine/pytorch/ops/fuser.py | 2 + transformer_engine/pytorch/ops/linear.py | 2 + 7 files changed, 988 insertions(+), 2 deletions(-) create mode 100644 tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py create mode 100644 tests/pytorch/utils.py create mode 100644 transformer_engine/pytorch/ops/fused/userbuffers_linear.py diff --git a/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py b/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py new file mode 100644 index 0000000000..94421008ce --- /dev/null +++ b/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py @@ -0,0 +1,438 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +from __future__ import annotations + +import argparse +import dataclasses +import functools +import itertools +import os +import pathlib +import subprocess +import sys + +import pytest +import torch + +import transformer_engine +import transformer_engine.pytorch as te +from transformer_engine.pytorch.float8_tensor import Float8Tensor +from transformer_engine.pytorch.fp8 import FP8GlobalStateManager +import transformer_engine.pytorch.ops as te_ops +from transformer_engine.pytorch.ops._common import is_float8_tensor +from transformer_engine.pytorch.ops.fused import UserbuffersLinear +from transformer_engine.pytorch.utils import is_bf16_compatible +import transformer_engine_torch as tex + +# Import utility functions +_current_file = pathlib.Path(__file__).resolve() +sys.path.append(str(_current_file.parent.parent)) +from utils import dtype_tols, str_to_dtype + +# Check if FP8 is supported +fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() + +# Check if there are multiple GPUs +if torch.cuda.device_count() < 2: + pytest.skip("Userbuffers requires at least 2 GPUs.") + + +@dataclasses.dataclass +class ModelConfig: + sequence_length: int + batch_size: int + num_heads: int + head_dim: int + dtype: torch.dtype + fp8: bool + + @property + def sequence_and_batch_size(self): + return self.sequence_length * self.batch_size + + @property + def hidden_size(self): + return self.num_heads * self.head_dim + + +@functools.cache +def launcher() -> str: + if "OMPI_COMM_WORLD_SIZE" in os.environ: + return "ompi" + if "TORCHELASTIC_RUN_ID" in os.environ: + return "torchrun" + raise RuntimeError(f"{__file__} must be launched with either `mpirun` or `torchrun`") + + +@functools.cache +def world_group() -> torch.distributed.ProcessGroup: + """Get NCCL process group, initializing if needed""" + + # Get launch config from environment + if launcher() == "ompi": + # OpenMPI + world_size = int(os.getenv("OMPI_COMM_WORLD_SIZE")) + rank = int(os.getenv("OMPI_COMM_WORLD_RANK")) + local_size = int(os.getenv("OMPI_COMM_WORLD_LOCAL_SIZE")) + local_rank = int(os.getenv("OMPI_COMM_WORLD_LOCAL_RANK")) + elif launcher() == "torchrun": + # torchrun + world_size = int(os.getenv("WORLD_SIZE")) + rank = int(os.getenv("RANK")) + local_size = int(os.getenv("LOCAL_WORLD_SIZE")) + local_rank 
= int(os.getenv("LOCAL_RANK")) + else: + raise RuntimeError("Unexpected launcher ({launcher()})") + + # Construct communicator + assert local_size == world_size + torch.cuda.set_device(local_rank) + group = torch.distributed.init_process_group( + "nccl", + init_method="file:///tmp/rdzv", + world_size=world_size, + rank=rank, + device_id=torch.device(f"cuda:{local_rank}"), + ) + return group + + +def reset_rng(seed: int = 1234) -> None: + """Reset random number generators""" + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + +@torch.no_grad() +def make_reference_and_test_tensors( + shape: int | Iterable[int], + ref_dtype: torch.dtype = torch.float64, + ref_device: torch.device = "cpu", + test_dtype: torch.dtype = torch.float32, + test_device: torch.device = "cuda", + test_is_fp8: bool = False, + requires_grad: bool = True, +) -> tuple[torch.Tensor, torch.Tensor]: + """Construct tensors with the same values + + The reference tensor is intended for use in plain PyTorch + operations in high precision. The test tensor is intended for use + in Transformer Engine operations. + + """ + + # Random data + ref = torch.rand(shape, dtype=ref_dtype, device=ref_device) + + # Make copy of tensor + if test_is_fp8: + test = Float8Tensor.to_float8(ref) + else: + test = ref.to(device=test_device, dtype=test_dtype) + if test.data_ptr() == ref.data_ptr(): + test = test.clone() + + # Make sure reference and test tensors represent exact same values + ref.copy_(test) + + # Return reference and test tensors + ref.requires_grad_(requires_grad) + test.requires_grad_(requires_grad) + return ref, test + + +def _test_linear( + *, + model_config: ModelConfig, + bias: bool = True, + device: torch.device = "cuda", + tensor_parallel_mode: str = "column", + sequence_parallel: bool = True, +) -> None: + dtype = model_config.dtype + fp8_compute = model_config.fp8 + + # Distributed process group + process_group = world_group() + rank = torch.distributed.get_rank(process_group) + world_size = torch.distributed.get_world_size(process_group) + + # Tensor dimensions + out_features = model_config.hidden_size + in_features = model_config.hidden_size + batch_size = model_config.sequence_and_batch_size + in_shape = [batch_size, in_features] + out_shape = [batch_size, out_features] + + # Random data + reset_rng() + x_ref, x_test = make_reference_and_test_tensors( + in_shape, + test_dtype=dtype, + test_device=device, + test_is_fp8=fp8_compute, + ) + w_ref, w_test = make_reference_and_test_tensors( + (out_features, in_features), + test_dtype=dtype, + test_device=device, + test_is_fp8=fp8_compute, + ) + b_ref, b_test = None, None + if bias: + if tensor_parallel_mode == "row": + bias_shape = [world_size, out_features] + else: + bias_shape = [out_features] + b_ref, b_test = make_reference_and_test_tensors( + bias_shape, + test_dtype=dtype, + test_device=device, + ) + dy_ref, dy_test = make_reference_and_test_tensors( + out_shape, + test_dtype=dtype, + test_device=device, + test_is_fp8=fp8_compute, + requires_grad=False, + ) + + # Plain PyTorch implementation + y_ref = torch.nn.functional.linear(x_ref, w_ref) + if bias: + if tensor_parallel_mode == "row": + y_ref += b_ref.sum(dim=0) + else: + y_ref += b_ref + y_ref.backward(dy_ref) + + # Convert to distributed tensors + with torch.no_grad(): + dw_ref = w_ref.grad + db_ref = b_ref.grad if bias else None + dx_ref = x_ref.grad + if tensor_parallel_mode == "column": + local_out_features = out_features // world_size + local_slice = slice( + rank * local_out_features, + (rank + 1) * 
local_out_features, + ) + w_ref = w_ref[local_slice, :] + dw_ref = dw_ref[local_slice, :] + w_test = w_test[local_slice, :] + if bias: + b_ref = b_ref[local_slice] + db_ref = db_ref[local_slice] + b_test = b_test[local_slice] + y_ref = y_ref[..., local_slice] + dy_ref = dy_ref[..., local_slice] + dy_test = dy_test[..., local_slice].clone() + elif tensor_parallel_mode == "row": + local_in_features = in_features // world_size + local_slice = slice( + rank * local_in_features, + (rank + 1) * local_in_features, + ) + w_ref = w_ref[:, local_slice] + dw_ref = dw_ref[:, local_slice] + w_test = w_test[:, local_slice] + if bias: + b_ref = b_ref[rank, :] + db_ref = db_ref[rank, :] + b_test = b_test[rank, :] + x_ref = x_ref[..., local_slice] + dx_ref = dx_ref[..., local_slice] + x_test = x_test[..., local_slice].clone() + if sequence_parallel: + local_batch_size = batch_size // world_size + local_slice = slice( + rank * local_batch_size, + (rank + 1) * local_batch_size, + ) + if tensor_parallel_mode == "column": + x_ref = x_ref[local_slice, ...] + dx_ref = dx_ref[local_slice, ...] + x_test = x_test[local_slice, ...].clone() + elif tensor_parallel_mode == "row": + y_ref = y_ref[local_slice, ...] + dy_ref = dy_ref[local_slice, ...] + dy_test = dy_test[local_slice, ...].clone() + x_test.requires_grad_() + + # Implementation with fusible operation + userbuffers_options = dict( + comm_name="qkv", + ) + with te.fp8_model_init(enabled=fp8_compute): + model = te_ops.Sequential( + te_ops.Linear( + in_features, + out_features, + bias=bias, + device=device, + dtype=dtype, + tensor_parallel_mode=tensor_parallel_mode, + tensor_parallel_group=process_group, + sequence_parallel=sequence_parallel, + userbuffers_options=userbuffers_options, + ), + ) + with torch.no_grad(): + model[0].weight.copy_(w_test) + if bias: + model[0].bias.copy_(b_test) + del w_test + del b_test + with te.fp8_autocast(enabled=fp8_compute): + y_test = model(x_test) + y_test.backward(dy_test) + + # Check that forward operations have been fused + forward_ops = model._module_groups[0]._forward_ops + assert len(forward_ops) == 1 + assert isinstance(forward_ops[0][0], UserbuffersLinear) + + # Expected numerical error + tols = dtype_tols(dtype) + if dtype == torch.float32: + tols = dtype_tols(torch.float16) # TF32 GEMM + if fp8_compute: + tols = dtype_tols( + model[0].weight._fp8_dtype + if is_float8_tensor(model[0].weight) + else tex.DType.kFloat8E4M3 + ) + + # Check results + y_test = y_test.to(dtype=torch.float64, device="cpu") + dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") + dw_test = model[0].weight.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(y_test, y_ref, **tols) + torch.testing.assert_close(dx_test, dx_ref, **tols) + torch.testing.assert_close(dw_test, dw_ref, **tols) + if bias: + db_test = model[0].bias.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(db_test, db_ref, **tols) + + +def run_parallel_tests(model_config: ModelConfig) -> None: + """Run parallel tests""" + + # Distributed process group + process_group = world_group() + rank = torch.distributed.get_rank(process_group) + world_size = torch.distributed.get_world_size(process_group) + + # Linear op + for test_config in itertools.product( + (False, True), + ("column",), #("column", "row"), + ): + if rank == 0: + print(f"Running _test_linear with {test_config=}") + bias, tensor_parallel_mode = test_config + _test_linear( + model_config=model_config, + bias=bias, + tensor_parallel_mode=tensor_parallel_mode, + ) + + +# 
Parallel job sizes +_world_sizes = [] +if torch.cuda.device_count() > 1: + _world_sizes.append(torch.cuda.device_count()) + + +@pytest.mark.parametrize("world_size", _world_sizes) +def test_fuser_ops_with_userbuffers( + *, + world_size: int, + dtype: torch.dtype = torch.float32, + fp8: bool = False, +) -> None: + """Launch parallel job that runs parallel tests""" + + # Parallel job launcher + command = [] + if tex.ubuf_built_with_mpi(): + python_exe = pathlib.Path(sys.executable).resolve() + command.extend( + ("mpirun", "-np", str(world_size), "--oversubscribe", "--quiet", python_exe) + ) + else: + command.extend(("torchrun", f"--nproc_per_node={world_size}")) + + # Script invocation + command.extend( + ( + _current_file, + "--parallel", + "--batch-size", + str(world_size), + "--num-heads", + str(world_size), + "--dtype", + str(dtype), + ) + ) + if fp8: + command.append("--fp8") + + # Launch parallel job + result = subprocess.run( + command, + check=True, + ) + + +def main() -> None: + + # Parse command-line arguments + parser = argparse.ArgumentParser() + parser.add_argument("--parallel", action="store_true", help="Run parallel tests") + parser.add_argument("--sequence-length", type=int, default=16) + parser.add_argument("--batch-size", type=int, default=16) + parser.add_argument("--num-heads", type=int, default=16) + parser.add_argument("--head-dim", type=int, default=16) + parser.add_argument("--dtype", type=str, default="float32") + parser.add_argument("--fp8", action="store_true") + args = parser.parse_args() + + # Run parallel tests if needed + if args.parallel: + + # Model config + model_config = ModelConfig( + sequence_length=args.sequence_length, + batch_size=args.batch_size, + num_heads=args.num_heads, + head_dim=args.head_dim, + dtype=str_to_dtype(args.dtype), + fp8=args.fp8, + ) + + # Initialize Userbuffers + group = world_group() # Initialize NCCL + bootstrap_backend = "mpi" if launcher() == "ompi" else "nccl" + te.module.base.initialize_ub( + [ + model_config.sequence_length * model_config.batch_size, + model_config.num_heads * model_config.head_dim, + ], + torch.distributed.get_world_size(group), + use_fp8=model_config.fp8, + dtype=model_config.dtype, + bootstrap_backend=bootstrap_backend, + ) + + # Run tests + run_parallel_tests(model_config) + + # Clean up + te.module.base.destroy_ub() + +if __name__ == "__main__": + main() diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py new file mode 100644 index 0000000000..36d9f52978 --- /dev/null +++ b/tests/pytorch/utils.py @@ -0,0 +1,83 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
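
The new tests/pytorch/utils.py introduced below provides two small helpers, str_to_dtype and dtype_tols, that the Userbuffers test imports. A minimal usage sketch (hypothetical, not part of the patch; it assumes tests/pytorch is on sys.path, as the test script above arranges):

import torch
from utils import dtype_tols, str_to_dtype  # helpers defined in the hunk below

dtype = str_to_dtype("bf16")   # -> torch.bfloat16
tols = dtype_tols(dtype)       # -> {"rtol": 1.6e-2, "atol": 1e-5}
torch.testing.assert_close(
    torch.ones(8, dtype=dtype),
    torch.ones(8, dtype=torch.float64),
    check_dtype=False,
    **tols,
)
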
+ +from __future__ import annotations + +import torch + +import transformer_engine +import transformer_engine.pytorch as te +import transformer_engine_torch as tex + +def str_to_dtype(dtype: str | torch.dtype) -> torch.dtype: + """Convert type name to PyTorch dtype""" + if isinstance(dtype, torch.dtype): + return dtype + name = str(dtype).strip().lower() + if name.startswith("torch."): + name = name.replace("torch.", "", 1) + if name.startswith("fp"): + name = name.replace("fp", "float", 1) + dtype = dict( + float32=torch.float32, + float=torch.float32, + float64=torch.float64, + double=torch.float64, + float16=torch.float16, + half=torch.float16, + bfloat16=torch.bfloat16, + bf16=torch.bfloat16, + float8_e4m3fn=torch.float8_e4m3fn, + float8_e4m3=torch.float8_e4m3fn, + float8e4m3=torch.float8_e4m3fn, + float8=torch.float8_e4m3fn, + float8_e5m2=torch.float8_e5m2, + float8e5m2=torch.float8_e5m2, + uint8=torch.uint8, + byte=torch.uint8, + int8=torch.int8, + char=torch.int8, + int16=torch.int16, + short=torch.int16, + int32=torch.int32, + int=torch.int32, + int64=torch.int64, + long=torch.int64, + bool=torch.bool, + )[name] + return dtype + +def dtype_tols(dtype: torch.dtype | tex.DType) -> dict[str, float]: + """Estimated numerical error for a datatype + + Based on tolerances for torch.testing.assert_close. + + """ + + # Transformer Engine dtypes + if isinstance(dtype, tex.DType): + dtype = { + tex.DType.kByte: torch.uint8, + tex.DType.kInt32: torch.int32, + tex.DType.kFloat32: torch.float32, + tex.DType.kFloat16: torch.half, + tex.DType.kBFloat16: torch.bfloat16, + tex.DType.kFloat8E4M3: torch.float8_e4m3fn, + tex.DType.kFloat8E5M2: torch.float8_e5m2, + }[dtype] + + # PyTorch dtypes + if dtype == torch.float16: + return dict(rtol=1e-3, atol=1e-5) + if dtype == torch.bfloat16: + return dict(rtol=1.6e-2, atol=1e-5) + if dtype == torch.float32: + return dict(rtol=1.3e-6, atol=1e-5) + if dtype == torch.float64: + return dict(rtol=1e-7, atol=1e-7) + if dtype == torch.float8_e4m3fn: + return dict(rtol=0.125, atol=0.0675) # epsilon = 0.0625 + if dtype == torch.float8_e5m2: + return dict(rtol=0.25, atol=0.125) # epsilon = 0.152 + raise ValueError(f"Unsupported dtype ({dtype})") diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index 826807d1c0..a85d133afa 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -83,6 +83,10 @@ class BasicLinear(BasicOperation): autograd. The weight's `main_grad` must be set externally and there is no guarantee that `grad` will be set or be meaningful. + userbuffers_options, dict, optional + Options for overlapping tensor-parallel communication with + compute using Userbuffers. This feature is highly + experimental. 
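
A rough usage sketch for this option, mirroring the column-parallel case in the test added above (not part of the patch; the "qkv" string is just the Userbuffers buffer label registered via initialize_ub in that test, and the process group is a placeholder):

import torch
import transformer_engine.pytorch as te
import transformer_engine.pytorch.ops as te_ops

# Assumes torch.distributed and the Userbuffers workspace were initialized
# beforehand (torch.distributed.init_process_group + te.module.base.initialize_ub),
# as done in the test script above.
tp_group = torch.distributed.group.WORLD       # placeholder process group
linear = te_ops.BasicLinear(
    1024,                                      # in_features
    1024,                                      # out_features
    device="cuda",
    dtype=torch.bfloat16,
    tensor_parallel_mode="column",
    tensor_parallel_group=tp_group,
    sequence_parallel=True,
    userbuffers_options={"comm_name": "qkv"},  # opt in to the Userbuffers path
)

When such an op runs inside te_ops.Sequential, the fuser pass added in this patch replaces it with the fused Userbuffers forward op.
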
""" @@ -98,6 +102,7 @@ def __init__( sequence_parallel: bool = False, rng_state_tracker_function: Optional[Callable[[], CudaRNGStatesTracker]] = None, accumulate_into_main_grad: bool = False, + userbuffers_options: Optional[dict[str, Any]] = None, ) -> None: super().__init__() @@ -144,7 +149,7 @@ def __init__( ) # Whether weight tensor is natively in FP8 - self._with_fp8_parameters = FP8GlobalStateManager.with_fp8_parameters() + self._with_fp8_parameters: bool = FP8GlobalStateManager.with_fp8_parameters() if self._with_fp8_parameters: self._fp8_metas = self._make_fp8_metas() @@ -164,7 +169,10 @@ def __init__( self.reset_parameters() # Whether to accumulate weight gradient into main_grad - self._accumulate_into_main_grad = accumulate_into_main_grad + self._accumulate_into_main_grad: bool = accumulate_into_main_grad + + # Userbuffers options + self._userbuffers_options: Optional[dict[str, Any]] = userbuffers_options @classmethod def _canonicalize_tensor_parallelism( diff --git a/transformer_engine/pytorch/ops/fused/__init__.py b/transformer_engine/pytorch/ops/fused/__init__.py index bd832254d8..ded9a5ff80 100644 --- a/transformer_engine/pytorch/ops/fused/__init__.py +++ b/transformer_engine/pytorch/ops/fused/__init__.py @@ -16,3 +16,7 @@ ForwardLinearBiasAdd, fuse_forward_linear_bias_add, ) +from .userbuffers_linear import ( + UserbuffersLinear, + fuse_forward_userbuffers_linear, +) diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_linear.py new file mode 100644 index 0000000000..639d6a3b39 --- /dev/null +++ b/transformer_engine/pytorch/ops/fused/userbuffers_linear.py @@ -0,0 +1,449 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ +"""Linear layer with Userbuffers communication.""" + +from __future__ import annotations +from collections.abc import Iterable +from typing import Any, Optional + +import torch + +import transformer_engine_torch as tex +from ...cpp_extensions import FP8TensorMeta, fp8_gemm, gemm +from ...distributed import get_distributed_world_size +from ...float8_tensor import Float8Tensor +from ...fp8 import FP8GlobalStateManager +from ...module.base import get_ub, get_workspace +from ..basic import BasicLinear, Bias +from ..op import ( + BasicOperation, + FusedOperation, + FusibleOperation, + OperationContext, +) +from .._common import ( + canonicalize_device, + canonicalize_dtype, + convert_tensor, + is_float8_tensor, + reshape, +) + +class UserbuffersLinear(FusedOperation): + + def __init__( + self, + *, + linear: BasicLinear, + bias: Optional[Bias], + ) -> None: + + # Basic operations that comprise this fused operation + op_idxs = dict( + linear=0, + bias=None, + ) + ops = [linear] + if bias is not None: + op_idxs["bias"] = len(ops) + ops.append(bias) + + # Initialize base class + super().__init__(ops) + + # Index of each basic operations + self._op_idxs: dict[str, Optional[int]] = op_idxs + + @staticmethod + def _functional_forward( + input: torch.Tensor, # pylint: disable=redefined-builtin + weight: torch.Tensor, + *, + bias: Optional[torch.Tensor] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + tensor_parallel_mode: Optional[str] = None, + tensor_parallel_group: Optional[torch.distributed.ProcessGroup] = None, + sequence_parallel: bool = False, + with_fp8_compute: bool = False, + input_fp8_meta: Optional[dict[str, Any]] = None, + weight_fp8_meta: Optional[dict[str, Any]] = None, + output_fp8_meta: Optional[dict[str, Any]] = None, + ub_comm_name: str, + ): + + # Check device + if device is None: + device = weight.device if out is None else out.device + device = canonicalize_device(device) + if device.type != "cuda": + raise ValueError(f"Only CUDA devices are supported (got {device})") + + # Check datatype + if dtype is None: + dtype = weight.dtype if out is None else out.dtype + dtype = canonicalize_dtype(dtype) + if dtype not in (torch.float32, torch.float16, torch.bfloat16): + raise ValueError(f"Supported dtypes are float32, float16, bfloat16 (got {dtype})") + + # Check input tensor dims + input_dims = tuple(input.size()) + weight_dims = tuple(weight.size()) + if len(weight_dims) != 2: + raise ValueError(f"Weight tensor is not 2D (shape={weight_dims})") + if len(input_dims) == 0 or weight_dims[1] != input_dims[-1]: + raise ValueError( + f"Input tensor (shape={input_dims}) " + f"and weight tensor (shape={weight_dims}) " + "are not compatible" + ) + + # Check output tensor dims + output_dims: list[int] + output_dims = list(input_dims) + output_dims[0] = -1 + output_dims[-1] = weight_dims[0] + + # Check if FP8 is enabled + if with_fp8_compute: + if input_fp8_meta is None and not is_float8_tensor(input): + raise ValueError("No FP8 metadata was provided for casting input to FP8") + if weight_fp8_meta is None and not is_float8_tensor(weight): + raise ValueError("No FP8 metadata was provided for casting weight to FP8") + else: + input_fp8_meta = None + weight_fp8_meta = None + output_fp8_meta = None + with_fp8_output = with_fp8_compute and tensor_parallel_mode != "row" + with_fp8_output = with_fp8_output and output_fp8_meta is not None + + # Check tensor parallel group + tensor_parallel_group_size = get_distributed_world_size(tensor_parallel_group) + if 
tensor_parallel_group_size == 1: + tensor_parallel_mode = None + if tensor_parallel_mode not in ("column", "row"): + raise RuntimeError( + "Invalid configuration for Userbuffers " + f"({tensor_parallel_group_size=}, {tensor_parallel_mode=})" + ) + if not sequence_parallel: + raise RuntimeError( + f"Invalid configuration for Userbuffers ({sequence_parallel=})" + ) + + # Get Userbuffers communicator + ub_comm = get_ub(ub_comm_name + "_fprop") + ub_local_buffer = ub_comm.get_ubuf_output(0) + ub_global_buffer = ub_comm.get_ubuf_output(1) + with_ub_all_gather = tensor_parallel_mode == "column" + with_ub_reduce_scatter = tensor_parallel_mode == "row" + + # Choose Userbuffers communication algorithm + ub_algo = None + if with_ub_all_gather: + if with_fp8_compute and ub_comm.is_atomic_gemm(): + ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P + else: + ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P + elif with_ub_reduce_scatter: + raise NotImplementedError ### TODO Implement + else: + raise RuntimeError("Could not choose Userbuffers communication algorithm") + + # Cast input tensor to correct dtype + x_local = reshape( + input, + (-1, input_dims[-1]), + device=device, + dtype=dtype, + ) + if with_fp8_compute and not is_float8_tensor(x_local): + fp8_dtype = get_fp8_te_dtype( + input_fp8_meta["recipe"], + fprop_tensor=True, + ) + if with_ub_all_gather: + data = ub_local_buffer + else: + data = torch.empty_like(x_local, dtype=torch.uint8) + x_fp8 = Float8Tensor( + data=data, + fp8_meta=input_fp8_meta, + fp8_meta_forward=True, + fp8_meta_index=0, + fp8_dtype=fp8_dtype, + fp8_scale_inv=torch.empty([1], dtype=torch.float32, device=device), + dtype=dtype, + ) + with_cast_transpose = weight.requires_grad + if tensor_parallel_mode == "column" and sequence_parallel: + with_cast_transpose = False + if with_cast_transpose: + x_fp8.cast_transpose_(x_local) + else: + x_fp8.copy_(x_local) + x_local = x_fp8 + elif not with_fp8_compute and is_float8_tensor(x_local): + if with_ub_all_gather: + x_local = ub_local_buffer.copy_(x_local) + else: + x_local = x_local.from_float8() + + # Initialize buffers for UB all-gather if needed + x = x_local + if with_ub_all_gather: + if with_fp8_compute: + x = Float8Tensor.make_like(x_local, data=ub_global_buffer) + if x_local._data.data_ptr() != ub_local_buffer.data_ptr(): + ub_local_buffer.copy_(x_local._data) + else: + x_local._data = torch.empty_like(x_local._data) + else: + x = ub_global_buffer + if x_local.data_ptr() != ub_local_buffer.data_ptr(): + ub_local_buffer.copy_(x_local) + else: + x_local = torch.empty_like(x_local) + + # Check weight tensor + w = convert_tensor( + weight, + device=device, + dtype=dtype, + memory_format=torch.contiguous_format, + ) + if with_fp8_compute and not is_float8_tensor(w): + fp8_dtype = get_fp8_te_dtype( + weight_fp8_meta["recipe"], + fprop_tensor=True, + ) + w = Float8Tensor.to_float8( + w, + fp8_meta=weight_fp8_meta, + fp8_meta_index=0, + fp8_dtype=fp8_dtype, + ) + elif not with_fp8_compute and is_float8_tensor(w): + w = w.from_float8() + + # Check bias tensor + b = None + if bias is not None: + b = convert_tensor( + bias, + device=device, + dtype=dtype, + memory_format=torch.contiguous_format, + ) + + # Construct output tensor + ### TODO UB RS + y = None + if with_fp8_output: + fp8_dtype = get_fp8_te_dtype( + output_fp8_meta["recipe"], + fprop_tensor=True, + ) + data = torch.empty( + (x.size(0), weight_dims[0]), + dtype=torch.uint8, + device=device, + ) + y = Float8Tensor( + data=data, + fp8_meta=output_fp8_meta, + 
fp8_meta_forward=True, + fp8_meta_index=0, + fp8_dtype=fp8_dtype, + dtype=dtype, + ) + else: + y = torch.empty( + (x.size(0), weight_dims[0]), + dtype=dtype, + device=device, + ) + + # Perform GEMM + if with_fp8_compute: + kwargs = dict( + out=y, + bias=b, + use_bias=(b is not None), + ub_algo=ub_algo, + ub=ub_comm, + ) + if with_ub_all_gather: + kwargs["extra_output_tensor"] = x_local._data + if with_fp8_output: + if y._fp8_meta is None: + # Hackily create FP8TensorMeta if needed + fp8_meta = FP8TensorMeta() + fp8_meta.scale = y._scale_inv.reciprocal() + fp8_meta.amax_history = torch.empty(1, 1, dtype=torch.float32, device=device) + fp8_meta.scale_inv = y._scale_inv + fp8_meta_index = 0 + else: + # Get FP8TensorMeta from Float8Tensor + fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( + forward=y._fp8_meta_forward, + ) + fp8_meta = y._fp8_meta[fp8_meta_key] + fp8_meta_index = y._fp8_meta_index + kwargs.update( + dict( + out=y._data, + out_index=fp8_meta_index, + fp8_meta_tensor=fp8_meta, + D_dtype=y._fp8_dtype, + ) + ) + fp8_gemm( + w._data, + w._scale_inv, + 0, + w._fp8_dtype, + x._data, + x._scale_inv, + 0, + x._fp8_dtype, + y.dtype, + get_workspace(), + **kwargs, + ) + else: + kwargs = dict( + out=y, + bias=b, + use_bias=(b is not None), + ub_algo=ub_algo, + ub=ub_comm, + ) + if with_ub_all_gather: + kwargs["extra_output_tensor"] = x_local + gemm(w, x, y.dtype, get_workspace(), **kwargs) + + # Reshape output tensor + out = reshape(y, output_dims) + + return out, x_local, w + + def fuser_forward( + self, + basic_op_ctxs: list[OperationContext], + input_: torch.Tensor, + *, + basic_op_extra_inputs: list[tuple[torch.Tensor, ...]], + basic_op_prev_ops: list[Optional[BasicOperation]], + basic_op_next_ops: list[Optional[BasicOperation]], + basic_op_kwargs: list[dict[str, Any]], + ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]: + + # Get basic operations + idx = self._op_idxs["linear"] + linear_op = self.basic_ops[idx] + linear_op_ctx = basic_op_ctxs[idx] + if self._op_idxs["bias"] is None: + bias_op = None + bias = None + else: + idx = self._op_idxs["bias"] + bias_op = self.basic_ops[idx] + bias = bias_op.bias + if basic_op_kwargs[idx]: + raise ValueError("Bias operation forward does not expect keyword arguments") + + # FP8 metadata + with_fp8_compute = FP8GlobalStateManager.is_fp8_enabled() + input_fp8_meta = None + weight_fp8_meta = None + output_fp8_meta = None + grad_output_fp8_meta = None + grad_input_fp8_meta = None + if with_fp8_compute: + input_fp8_meta = linear_op.get_fp8_meta("input") + weight_fp8_meta = linear_op.get_fp8_meta("param") + next_op = basic_op_next_ops[-1] + if next_op is not None and next_op.num_fp8_scales("input") > 0: + output_fp8_meta = next_op.get_fp8_meta("input") + grad_output_fp8_meta = linear_op.get_fp8_meta("grad_output") + prev_op = basic_op_prev_ops[0] + if prev_op is not None and prev_op.num_fp8_scales("grad_output") > 0: + grad_input_fp8_meta = prev_op.get_fp8_meta("grad_output") + + # Userbuffers options + if linear_op._userbuffers_options is None: + raise RuntimeError("Linear op is missing dict for Userbuffers options") + + # Linear forward + output, x_local, _ = UserbuffersLinear._functional_forward( + input=input_, + weight=linear_op.weight, + bias=bias, + device=linear_op.device, + dtype=linear_op.dtype, + tensor_parallel_mode=linear_op.tensor_parallel_mode, + tensor_parallel_group=linear_op.tensor_parallel_group, + sequence_parallel=linear_op.sequence_parallel, + with_fp8_compute=with_fp8_compute, + input_fp8_meta=input_fp8_meta, 
+ weight_fp8_meta=weight_fp8_meta, + output_fp8_meta=output_fp8_meta, + ub_comm_name=linear_op._userbuffers_options["comm_name"], + ) + + # Save state for backward pass + linear_op_ctx.save_for_backward(x_local) + linear_op_ctx.with_fp8_compute = with_fp8_compute + linear_op_ctx.weight_fp8_meta = weight_fp8_meta + linear_op_ctx.grad_output_fp8_meta = grad_output_fp8_meta + linear_op_ctx.grad_input_fp8_meta = grad_input_fp8_meta + linear_op_ctx.input_dims = input_.size() + linear_op_ctx.input_requires_grad = input_.requires_grad + linear_op_ctx.weight_requires_grad = linear_op.weight.requires_grad + linear_op_ctx.has_prev_op = basic_op_prev_ops[0] is not None + + return output, [() for _ in range(len(self.basic_ops))] + + +def fuse_forward_userbuffers_linear( + ops: list[tuple[FusibleOperation, list[int]]], +) -> list[tuple[FusibleOperation, list[int]]]: + + # Scan through ops, fusing if possible + out = [] + window = [] + while ops: + out.extend(window) + + # Check if first op is linear + op1, _ = ops[0] + window = [ops[0]] + ops = ops[1:] + if not isinstance(op1, BasicLinear): + continue + if op1.tensor_parallel_mode not in ("column", "row"): + continue + if op1._userbuffers_options is None: + continue + + # Check if second op is bias + op2 = None + if ops and isinstance(ops[0][0], Bias): + op2, _ = ops[0] + window.append(ops[0]) + ops = ops[1:] + + # Replace window with fused op + op = UserbuffersLinear( + linear=op1, + bias=op2, + ) + basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window] + window = [(op, basic_op_idxs)] + + # Return list of ops + out.extend(window) + return out diff --git a/transformer_engine/pytorch/ops/fuser.py b/transformer_engine/pytorch/ops/fuser.py index a7c99c592d..4978c0e5f7 100644 --- a/transformer_engine/pytorch/ops/fuser.py +++ b/transformer_engine/pytorch/ops/fuser.py @@ -20,6 +20,7 @@ fuse_backward_linear_add, fuse_forward_linear_bias_activation, fuse_forward_linear_bias_add, + fuse_forward_userbuffers_linear, ) @@ -318,6 +319,7 @@ def _fuse_forward_ops( ops: list[tuple[FusibleOperation, list[int]]], ) -> list[tuple[FusibleOperation, list[int]]]: """Attempt to fuse operations in forward pass""" + ops = fuse_forward_userbuffers_linear(ops) ops = fuse_forward_linear_bias_add(ops) ops = fuse_forward_linear_bias_activation(ops) return ops diff --git a/transformer_engine/pytorch/ops/linear.py b/transformer_engine/pytorch/ops/linear.py index 13cec30fa2..3fb08368c6 100644 --- a/transformer_engine/pytorch/ops/linear.py +++ b/transformer_engine/pytorch/ops/linear.py @@ -71,6 +71,7 @@ def __init__( sequence_parallel: bool = False, rng_state_tracker_function: Optional[Callable[[], CudaRNGStatesTracker]] = None, accumulate_into_main_grad: bool = False, + userbuffers_options: Optional[dict[str, Any]] = None, ) -> None: # Tensor parallel configuration @@ -101,6 +102,7 @@ def __init__( sequence_parallel=sequence_parallel, rng_state_tracker_function=rng_state_tracker_function, accumulate_into_main_grad=accumulate_into_main_grad, + userbuffers_options=userbuffers_options, ) bias_kwargs = dict( size=out_features, From 90e0a41de6f5267070c71cfb971eaf5a0ffee213 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Tue, 20 Aug 2024 22:45:29 +0000 Subject: [PATCH 02/13] Add Userbuffers support for row TP linear layer Signed-off-by: Tim Moon --- .../test_fusible_ops_with_userbuffers.py | 41 ++++--- .../pytorch/ops/fused/userbuffers_linear.py | 103 +++++++++++++----- 2 files changed, 99 insertions(+), 45 deletions(-) diff --git 
a/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py b/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py index 94421008ce..99dffd9a26 100644 --- a/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py +++ b/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py @@ -18,13 +18,13 @@ import transformer_engine import transformer_engine.pytorch as te +import transformer_engine.pytorch.cpp_extensions as tex from transformer_engine.pytorch.float8_tensor import Float8Tensor from transformer_engine.pytorch.fp8 import FP8GlobalStateManager import transformer_engine.pytorch.ops as te_ops from transformer_engine.pytorch.ops._common import is_float8_tensor from transformer_engine.pytorch.ops.fused import UserbuffersLinear from transformer_engine.pytorch.utils import is_bf16_compatible -import transformer_engine_torch as tex # Import utility functions _current_file = pathlib.Path(__file__).resolve() @@ -146,7 +146,7 @@ def make_reference_and_test_tensors( def _test_linear( *, model_config: ModelConfig, - bias: bool = True, + bias: bool = False, device: torch.device = "cuda", tensor_parallel_mode: str = "column", sequence_parallel: bool = True, @@ -262,15 +262,18 @@ def _test_linear( x_test.requires_grad_() # Implementation with fusible operation - userbuffers_options = dict( - comm_name="qkv", - ) + userbuffers_options = {} + if tensor_parallel_mode == "column": + userbuffers_options["comm_name"] = "qkv" + if tensor_parallel_mode == "row": + userbuffers_options["comm_name"] = "fc1" with te.fp8_model_init(enabled=fp8_compute): model = te_ops.Sequential( - te_ops.Linear( + # te_op.Linear( ### TODO Restore + te_ops.BasicLinear( ### TODO Remove in_features, out_features, - bias=bias, + #bias=bias, ### TODO Restore device=device, dtype=dtype, tensor_parallel_mode=tensor_parallel_mode, @@ -327,8 +330,8 @@ def run_parallel_tests(model_config: ModelConfig) -> None: # Linear op for test_config in itertools.product( - (False, True), - ("column",), #("column", "row"), + (False,), # (False, True) # bias + ("column", "row"), # tensor_parallel_mode ): if rank == 0: print(f"Running _test_linear with {test_config=}") @@ -381,11 +384,17 @@ def test_fuser_ops_with_userbuffers( if fp8: command.append("--fp8") + # Environment + env = dict(os.environ) + if not tex.device_supports_multicast(): + env["UB_SKIPMC"] = "1" + env["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" + env["PYTORCH_JIT"] = "0" + env["NVTE_TORCH_COMPILE"] = "0" + env["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0" + # Launch parallel job - result = subprocess.run( - command, - check=True, - ) + result = subprocess.run(command, check=True, env=env) def main() -> None: @@ -393,11 +402,11 @@ def main() -> None: # Parse command-line arguments parser = argparse.ArgumentParser() parser.add_argument("--parallel", action="store_true", help="Run parallel tests") - parser.add_argument("--sequence-length", type=int, default=16) + parser.add_argument("--sequence-length", type=int, default=32) parser.add_argument("--batch-size", type=int, default=16) parser.add_argument("--num-heads", type=int, default=16) - parser.add_argument("--head-dim", type=int, default=16) - parser.add_argument("--dtype", type=str, default="float32") + parser.add_argument("--head-dim", type=int, default=32) + parser.add_argument("--dtype", type=str, default="bfloat16") parser.add_argument("--fp8", action="store_true") args = parser.parse_args() diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_linear.py 
b/transformer_engine/pytorch/ops/fused/userbuffers_linear.py index 639d6a3b39..66d7e91ea2 100644 --- a/transformer_engine/pytorch/ops/fused/userbuffers_linear.py +++ b/transformer_engine/pytorch/ops/fused/userbuffers_linear.py @@ -16,7 +16,7 @@ from ...float8_tensor import Float8Tensor from ...fp8 import FP8GlobalStateManager from ...module.base import get_ub, get_workspace -from ..basic import BasicLinear, Bias +from ..basic import BasicLinear, Bias, ReduceScatter from ..op import ( BasicOperation, FusedOperation, @@ -120,13 +120,13 @@ def _functional_forward( with_fp8_output = with_fp8_output and output_fp8_meta is not None # Check tensor parallel group - tensor_parallel_group_size = get_distributed_world_size(tensor_parallel_group) - if tensor_parallel_group_size == 1: + tensor_parallel_size = get_distributed_world_size(tensor_parallel_group) + if tensor_parallel_size == 1: tensor_parallel_mode = None if tensor_parallel_mode not in ("column", "row"): raise RuntimeError( "Invalid configuration for Userbuffers " - f"({tensor_parallel_group_size=}, {tensor_parallel_mode=})" + f"({tensor_parallel_size=}, {tensor_parallel_mode=})" ) if not sequence_parallel: raise RuntimeError( @@ -148,7 +148,21 @@ def _functional_forward( else: ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P elif with_ub_reduce_scatter: - raise NotImplementedError ### TODO Implement + if with_fp8_compute: + ub_algo = { + (True, True): tex.UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P, + (True, False): tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P, + (False, True): tex.UbufOverlapAlgo.ATOMIC_GEMM_RS, + (False, False): tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS, + }[(ub_comm.is_p2p_overlap(), ub_comm.is_atomic_gemm())] + if ub_comm.is_fp8_ubuf(): + assert output_fp8_meta is not None + ub_comm.set_ubuf_scale_inv(output_fp8_meta.scale_inv) + else: + if ub_comm.is_p2p_overlap(): + ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P + else: + ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS else: raise RuntimeError("Could not choose Userbuffers communication algorithm") @@ -239,32 +253,57 @@ def _functional_forward( ) # Construct output tensor - ### TODO UB RS y = None - if with_fp8_output: - fp8_dtype = get_fp8_te_dtype( - output_fp8_meta["recipe"], - fprop_tensor=True, - ) - data = torch.empty( - (x.size(0), weight_dims[0]), - dtype=torch.uint8, - device=device, - ) - y = Float8Tensor( - data=data, - fp8_meta=output_fp8_meta, - fp8_meta_forward=True, - fp8_meta_index=0, - fp8_dtype=fp8_dtype, - dtype=dtype, - ) - else: - y = torch.empty( - (x.size(0), weight_dims[0]), + y_local = None + if with_ub_reduce_scatter: + # Initialize buffers for UB reduce-scatter + if with_fp8_output: + fp8_dtype = get_fp8_te_dtype( + output_fp8_meta["recipe"], + fprop_tensor=True, + ) + y = Float8Tensor( + data=ub_global_buffer, + fp8_meta=output_fp8_meta, ### TODO Probably wrong + fp8_meta_forward=True, + fp8_meta_index=0, + fp8_dtype=fp8_dtype, + dtype=dtype, + ) + else: + y = ub_global_buffer + y_local = torch.empty( + (x.size(0) // tensor_parallel_size, weight_dims[0]), dtype=dtype, device=device, ) + else: + # Allocate output tensor + if with_fp8_output: + fp8_dtype = get_fp8_te_dtype( + output_fp8_meta["recipe"], + fprop_tensor=True, + ) + data = torch.empty( + (x.size(0), weight_dims[0]), + dtype=torch.uint8, + device=device, + ) + y = Float8Tensor( + data=data, + fp8_meta=output_fp8_meta, + fp8_meta_forward=True, + fp8_meta_index=0, + fp8_dtype=fp8_dtype, + dtype=dtype, + ) + else: + y = torch.empty( + (x.size(0), weight_dims[0]), + dtype=dtype, + 
device=device, + ) + y_local = y # Perform GEMM if with_fp8_compute: @@ -277,6 +316,8 @@ def _functional_forward( ) if with_ub_all_gather: kwargs["extra_output_tensor"] = x_local._data + if with_ub_reduce_scatter: + kwargs["extra_output_tensor"] = y_local._data if with_fp8_output: if y._fp8_meta is None: # Hackily create FP8TensorMeta if needed @@ -323,10 +364,12 @@ def _functional_forward( ) if with_ub_all_gather: kwargs["extra_output_tensor"] = x_local + if with_ub_reduce_scatter: + kwargs["extra_output_tensor"] = y_local gemm(w, x, y.dtype, get_workspace(), **kwargs) # Reshape output tensor - out = reshape(y, output_dims) + out = reshape(y_local, output_dims) return out, x_local, w @@ -412,6 +455,8 @@ def fuse_forward_userbuffers_linear( ops: list[tuple[FusibleOperation, list[int]]], ) -> list[tuple[FusibleOperation, list[int]]]: + ### TODO Handle separate RS op + # Scan through ops, fusing if possible out = [] window = [] @@ -431,7 +476,7 @@ def fuse_forward_userbuffers_linear( # Check if second op is bias op2 = None - if ops and isinstance(ops[0][0], Bias): + if op1.tensor_parallel_mode != "row" and ops and isinstance(ops[0][0], Bias): op2, _ = ops[0] window.append(ops[0]) ops = ops[1:] From a520974ac2951dee4a456aadbe5400fcf91ffd55 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Wed, 21 Aug 2024 02:48:14 +0000 Subject: [PATCH 03/13] Interpret linear+RS as row TP linear Signed-off-by: Tim Moon --- .../test_fusible_ops_with_userbuffers.py | 60 ++++++--- transformer_engine/pytorch/ops/__init__.py | 12 +- .../pytorch/ops/fused/userbuffers_linear.py | 121 ++++++++++++------ transformer_engine/pytorch/ops/linear.py | 2 - 4 files changed, 126 insertions(+), 69 deletions(-) diff --git a/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py b/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py index 99dffd9a26..5854ff56c4 100644 --- a/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py +++ b/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py @@ -262,30 +262,49 @@ def _test_linear( x_test.requires_grad_() # Implementation with fusible operation - userbuffers_options = {} - if tensor_parallel_mode == "column": - userbuffers_options["comm_name"] = "qkv" - if tensor_parallel_mode == "row": - userbuffers_options["comm_name"] = "fc1" with te.fp8_model_init(enabled=fp8_compute): - model = te_ops.Sequential( - # te_op.Linear( ### TODO Restore - te_ops.BasicLinear( ### TODO Remove + ops = [] + linear_op = None + bias_op = None + if tensor_parallel_mode == "column": + userbuffers_options = dict(comm_name="qkv") + linear_op = te_ops.BasicLinear( in_features, out_features, - #bias=bias, ### TODO Restore device=device, dtype=dtype, tensor_parallel_mode=tensor_parallel_mode, tensor_parallel_group=process_group, sequence_parallel=sequence_parallel, userbuffers_options=userbuffers_options, - ), - ) + ) + ops.append(linear_op) + if bias: + bias_op = te_ops.Bias( + out_features // world_size, + device=device, + dtype=dtype, + ) + ops.append(bias_op) + elif tensor_parallel_mode == "row": + userbuffers_options = dict(comm_name="fc1") + linear_op = te_ops.BasicLinear( + in_features // world_size, + out_features, + device=device, + dtype=dtype, + userbuffers_options=userbuffers_options, + ) + ops.append(linear_op) + if bias: + bias_op = te_ops.Bias(out_features, device=device, dtype=dtype) + ops.append(bias_op) + ops.append(te_ops.ReduceScatter(process_group)) + model = te_ops.Sequential(*ops) with torch.no_grad(): - model[0].weight.copy_(w_test) + 
linear_op.weight.copy_(w_test) if bias: - model[0].bias.copy_(b_test) + bias_op.bias.copy_(b_test) del w_test del b_test with te.fp8_autocast(enabled=fp8_compute): @@ -311,12 +330,12 @@ def _test_linear( # Check results y_test = y_test.to(dtype=torch.float64, device="cpu") dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") - dw_test = model[0].weight.grad.to(dtype=torch.float64, device="cpu") + dw_test = linear_op.weight.grad.to(dtype=torch.float64, device="cpu") torch.testing.assert_close(y_test, y_ref, **tols) torch.testing.assert_close(dx_test, dx_ref, **tols) torch.testing.assert_close(dw_test, dw_ref, **tols) if bias: - db_test = model[0].bias.grad.to(dtype=torch.float64, device="cpu") + db_test = bias_op.bias.grad.to(dtype=torch.float64, device="cpu") torch.testing.assert_close(db_test, db_ref, **tols) @@ -330,7 +349,7 @@ def run_parallel_tests(model_config: ModelConfig) -> None: # Linear op for test_config in itertools.product( - (False,), # (False, True) # bias + (False, True), # bias ("column", "row"), # tensor_parallel_mode ): if rank == 0: @@ -350,14 +369,19 @@ def run_parallel_tests(model_config: ModelConfig) -> None: @pytest.mark.parametrize("world_size", _world_sizes) +@pytest.mark.parametrize("fp8", (False, True)) def test_fuser_ops_with_userbuffers( *, world_size: int, dtype: torch.dtype = torch.float32, - fp8: bool = False, + fp8: bool, ) -> None: """Launch parallel job that runs parallel tests""" + # Skip invalid configurations + if fp8 and not fp8_available: + pytest.skip(reason_for_no_fp8) + # Parallel job launcher command = [] if tex.ubuf_built_with_mpi(): @@ -406,7 +430,7 @@ def main() -> None: parser.add_argument("--batch-size", type=int, default=16) parser.add_argument("--num-heads", type=int, default=16) parser.add_argument("--head-dim", type=int, default=32) - parser.add_argument("--dtype", type=str, default="bfloat16") + parser.add_argument("--dtype", type=str, default="float32") parser.add_argument("--fp8", action="store_true") args = parser.parse_args() diff --git a/transformer_engine/pytorch/ops/__init__.py b/transformer_engine/pytorch/ops/__init__.py index f437f877b4..f65433398e 100644 --- a/transformer_engine/pytorch/ops/__init__.py +++ b/transformer_engine/pytorch/ops/__init__.py @@ -8,17 +8,7 @@ """ -from transformer_engine.pytorch.ops.basic import ( - AddInPlace, - AllGather, - AllReduce, - BasicLinear, - Bias, - Identity, - MakeExtraOutput, - ReduceScatter, - Reshape, -) +from transformer_engine.pytorch.ops.basic import * from transformer_engine.pytorch.ops.linear import Linear from transformer_engine.pytorch.ops.op import FusibleOperation from transformer_engine.pytorch.ops.sequential import Sequential diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_linear.py index 66d7e91ea2..4c2c58af25 100644 --- a/transformer_engine/pytorch/ops/fused/userbuffers_linear.py +++ b/transformer_engine/pytorch/ops/fused/userbuffers_linear.py @@ -38,17 +38,22 @@ def __init__( *, linear: BasicLinear, bias: Optional[Bias], + reduce_scatter: Optional[ReduceScatter], ) -> None: # Basic operations that comprise this fused operation op_idxs = dict( linear=0, bias=None, + reduce_scatter=None, ) ops = [linear] if bias is not None: op_idxs["bias"] = len(ops) ops.append(bias) + if reduce_scatter is not None: + op_idxs["reduce_scatter"] = len(ops) + ops.append(reduce_scatter) # Initialize base class super().__init__(ops) @@ -56,6 +61,22 @@ def __init__( # Index of each basic operations self._op_idxs: 
dict[str, Optional[int]] = op_idxs + # Tensor parallelism configuration + self.tensor_parallel_mode: Optional[str] + self.tensor_parallel_group: Optional[torch.distributed.ProcessGroup] + self.tensor_parallel_size: int + self.sequence_parallel: bool + if reduce_scatter is None: + self.tensor_parallel_mode = linear.tensor_parallel_mode + self.tensor_parallel_group = linear.tensor_parallel_group + self.tensor_parallel_size = linear.tensor_parallel_size + self.sequence_parallel = linear.sequence_parallel + else: + self.tensor_parallel_mode = "row" + self.tensor_parallel_group = reduce_scatter.process_group + self.tensor_parallel_size = reduce_scatter.process_group_size + self.sequence_parallel = True + @staticmethod def _functional_forward( input: torch.Tensor, # pylint: disable=redefined-builtin @@ -66,6 +87,7 @@ def _functional_forward( dtype: Optional[torch.dtype] = None, tensor_parallel_mode: Optional[str] = None, tensor_parallel_group: Optional[torch.distributed.ProcessGroup] = None, + tensor_parallel_size: Optional[int] = None, sequence_parallel: bool = False, with_fp8_compute: bool = False, input_fp8_meta: Optional[dict[str, Any]] = None, @@ -106,21 +128,9 @@ def _functional_forward( output_dims[0] = -1 output_dims[-1] = weight_dims[0] - # Check if FP8 is enabled - if with_fp8_compute: - if input_fp8_meta is None and not is_float8_tensor(input): - raise ValueError("No FP8 metadata was provided for casting input to FP8") - if weight_fp8_meta is None and not is_float8_tensor(weight): - raise ValueError("No FP8 metadata was provided for casting weight to FP8") - else: - input_fp8_meta = None - weight_fp8_meta = None - output_fp8_meta = None - with_fp8_output = with_fp8_compute and tensor_parallel_mode != "row" - with_fp8_output = with_fp8_output and output_fp8_meta is not None - # Check tensor parallel group - tensor_parallel_size = get_distributed_world_size(tensor_parallel_group) + if tensor_parallel_size is None: + tensor_parallel_size = get_distributed_world_size(tensor_parallel_group) if tensor_parallel_size == 1: tensor_parallel_mode = None if tensor_parallel_mode not in ("column", "row"): @@ -133,6 +143,20 @@ def _functional_forward( f"Invalid configuration for Userbuffers ({sequence_parallel=})" ) + # Check if FP8 is enabled + if with_fp8_compute: + if input_fp8_meta is None and not is_float8_tensor(input): + raise ValueError("No FP8 metadata was provided for casting input to FP8") + if weight_fp8_meta is None and not is_float8_tensor(weight): + raise ValueError("No FP8 metadata was provided for casting weight to FP8") + if output_fp8_meta is None and tensor_parallel_mode == "row": + raise ValueError("No FP8 metadata was provided for casting output to FP8") + else: + input_fp8_meta = None + weight_fp8_meta = None + output_fp8_meta = None + with_fp8_output = with_fp8_compute and output_fp8_meta is not None + # Get Userbuffers communicator ub_comm = get_ub(ub_comm_name + "_fprop") ub_local_buffer = ub_comm.get_ubuf_output(0) @@ -156,7 +180,6 @@ def _functional_forward( (False, False): tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS, }[(ub_comm.is_p2p_overlap(), ub_comm.is_atomic_gemm())] if ub_comm.is_fp8_ubuf(): - assert output_fp8_meta is not None ub_comm.set_ubuf_scale_inv(output_fp8_meta.scale_inv) else: if ub_comm.is_p2p_overlap(): @@ -264,7 +287,7 @@ def _functional_forward( ) y = Float8Tensor( data=ub_global_buffer, - fp8_meta=output_fp8_meta, ### TODO Probably wrong + fp8_meta=output_fp8_meta, fp8_meta_forward=True, fp8_meta_index=0, fp8_dtype=fp8_dtype, @@ -388,15 +411,20 @@ def 
fuser_forward( idx = self._op_idxs["linear"] linear_op = self.basic_ops[idx] linear_op_ctx = basic_op_ctxs[idx] - if self._op_idxs["bias"] is None: - bias_op = None - bias = None - else: + bias_op = None + bias = None + if self._op_idxs["bias"] is not None: idx = self._op_idxs["bias"] bias_op = self.basic_ops[idx] bias = bias_op.bias if basic_op_kwargs[idx]: raise ValueError("Bias operation forward does not expect keyword arguments") + reduce_scatter_op = None + if self._op_idxs["reduce_scatter"] is not None: + idx = self._op_idxs["reduce_scatter"] + reduce_scatter_op = self.basic_ops[idx] + if basic_op_kwargs[idx]: + raise ValueError("Reduce-scatter operation forward does not expect keyword arguments") # FP8 metadata with_fp8_compute = FP8GlobalStateManager.is_fp8_enabled() @@ -427,9 +455,10 @@ def fuser_forward( bias=bias, device=linear_op.device, dtype=linear_op.dtype, - tensor_parallel_mode=linear_op.tensor_parallel_mode, - tensor_parallel_group=linear_op.tensor_parallel_group, - sequence_parallel=linear_op.sequence_parallel, + tensor_parallel_mode=self.tensor_parallel_mode, + tensor_parallel_group=self.tensor_parallel_group, + tensor_parallel_size=self.tensor_parallel_size, + sequence_parallel=self.sequence_parallel, with_fp8_compute=with_fp8_compute, input_fp8_meta=input_fp8_meta, weight_fp8_meta=weight_fp8_meta, @@ -455,8 +484,6 @@ def fuse_forward_userbuffers_linear( ops: list[tuple[FusibleOperation, list[int]]], ) -> list[tuple[FusibleOperation, list[int]]]: - ### TODO Handle separate RS op - # Scan through ops, fusing if possible out = [] window = [] @@ -464,27 +491,45 @@ def fuse_forward_userbuffers_linear( out.extend(window) # Check if first op is linear - op1, _ = ops[0] - window = [ops[0]] - ops = ops[1:] - if not isinstance(op1, BasicLinear): - continue - if op1.tensor_parallel_mode not in ("column", "row"): + window, ops = ops[:1], ops[1:] + if not isinstance(window[0][0], BasicLinear): continue - if op1._userbuffers_options is None: + linear = window[0][0] + if linear._userbuffers_options is None: continue - # Check if second op is bias - op2 = None - if op1.tensor_parallel_mode != "row" and ops and isinstance(ops[0][0], Bias): - op2, _ = ops[0] + # Check if next op is bias + bias = None + if linear.tensor_parallel_mode != "row" and ops and isinstance(ops[0][0], Bias): + bias = ops[0][0] window.append(ops[0]) ops = ops[1:] + # Check if next op is FP8 cast + ### TODO Implement + + # Check if next op is reduce-scatter + reduce_scatter = None + if linear.tensor_parallel_mode is None and ops and isinstance(ops[0][0], ReduceScatter): + reduce_scatter = ops[0][0] + window.append(ops[0]) + ops = ops[1:] + + # Check for invalid combinations + if reduce_scatter is None: + if linear.tensor_parallel_mode is None: + continue + if linear.tensor_parallel_mode == "row" and bias is not None: + continue + else: + if linear.tensor_parallel_mode is not None: + continue + # Replace window with fused op op = UserbuffersLinear( - linear=op1, - bias=op2, + linear=linear, + bias=bias, + reduce_scatter=reduce_scatter, ) basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window] window = [(op, basic_op_idxs)] diff --git a/transformer_engine/pytorch/ops/linear.py b/transformer_engine/pytorch/ops/linear.py index 3fb08368c6..13cec30fa2 100644 --- a/transformer_engine/pytorch/ops/linear.py +++ b/transformer_engine/pytorch/ops/linear.py @@ -71,7 +71,6 @@ def __init__( sequence_parallel: bool = False, rng_state_tracker_function: Optional[Callable[[], CudaRNGStatesTracker]] = None, 
accumulate_into_main_grad: bool = False, - userbuffers_options: Optional[dict[str, Any]] = None, ) -> None: # Tensor parallel configuration @@ -102,7 +101,6 @@ def __init__( sequence_parallel=sequence_parallel, rng_state_tracker_function=rng_state_tracker_function, accumulate_into_main_grad=accumulate_into_main_grad, - userbuffers_options=userbuffers_options, ) bias_kwargs = dict( size=out_features, From bb2e71462822097b96b7d3874fc371d8fac7a1c6 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Thu, 22 Aug 2024 02:37:21 +0000 Subject: [PATCH 04/13] Add Userbuffers support for FP8 row TP linear layer Assumes FP8 RS, which is not a good assumption. Signed-off-by: Tim Moon --- .../test_fusible_ops_with_userbuffers.py | 14 ++- .../pytorch/ops/basic/__init__.py | 1 + .../pytorch/ops/basic/cast_float8.py | 100 +++++++++++++++ .../pytorch/ops/fused/__init__.py | 6 +- ...inear.py => userbuffers_forward_linear.py} | 116 ++++++++++++------ transformer_engine/pytorch/ops/fuser.py | 4 +- 6 files changed, 195 insertions(+), 46 deletions(-) create mode 100644 transformer_engine/pytorch/ops/basic/cast_float8.py rename transformer_engine/pytorch/ops/fused/{userbuffers_linear.py => userbuffers_forward_linear.py} (82%) diff --git a/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py b/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py index 5854ff56c4..6320c47fc1 100644 --- a/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py +++ b/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py @@ -23,7 +23,7 @@ from transformer_engine.pytorch.fp8 import FP8GlobalStateManager import transformer_engine.pytorch.ops as te_ops from transformer_engine.pytorch.ops._common import is_float8_tensor -from transformer_engine.pytorch.ops.fused import UserbuffersLinear +from transformer_engine.pytorch.ops.fused import UserbuffersForwardLinear from transformer_engine.pytorch.utils import is_bf16_compatible # Import utility functions @@ -287,7 +287,9 @@ def _test_linear( ) ops.append(bias_op) elif tensor_parallel_mode == "row": - userbuffers_options = dict(comm_name="fc1") + userbuffers_options = dict(comm_name="proj") + if fp8_compute: + userbuffers_options["comm_name"] = "fc1" ### TODO Remove linear_op = te_ops.BasicLinear( in_features // world_size, out_features, @@ -299,6 +301,8 @@ def _test_linear( if bias: bias_op = te_ops.Bias(out_features, device=device, dtype=dtype) ops.append(bias_op) + if fp8_compute: + ops.append(te_ops.CastFloat8(backward=False)) ops.append(te_ops.ReduceScatter(process_group)) model = te_ops.Sequential(*ops) with torch.no_grad(): @@ -314,7 +318,7 @@ def _test_linear( # Check that forward operations have been fused forward_ops = model._module_groups[0]._forward_ops assert len(forward_ops) == 1 - assert isinstance(forward_ops[0][0], UserbuffersLinear) + assert isinstance(forward_ops[0][0], UserbuffersForwardLinear) # Expected numerical error tols = dtype_tols(dtype) @@ -373,7 +377,7 @@ def run_parallel_tests(model_config: ModelConfig) -> None: def test_fuser_ops_with_userbuffers( *, world_size: int, - dtype: torch.dtype = torch.float32, + dtype: torch.dtype = torch.bfloat16, fp8: bool, ) -> None: """Launch parallel job that runs parallel tests""" @@ -430,7 +434,7 @@ def main() -> None: parser.add_argument("--batch-size", type=int, default=16) parser.add_argument("--num-heads", type=int, default=16) parser.add_argument("--head-dim", type=int, default=32) - parser.add_argument("--dtype", type=str, default="float32") + parser.add_argument("--dtype", type=str, 
default="bfloat16") parser.add_argument("--fp8", action="store_true") args = parser.parse_args() diff --git a/transformer_engine/pytorch/ops/basic/__init__.py b/transformer_engine/pytorch/ops/basic/__init__.py index 1003cc0337..599bac1b4e 100644 --- a/transformer_engine/pytorch/ops/basic/__init__.py +++ b/transformer_engine/pytorch/ops/basic/__init__.py @@ -9,6 +9,7 @@ from .all_reduce import AllReduce from .basic_linear import BasicLinear from .bias import Bias +from .cast_float8 import CastFloat8 from .identity import Identity from .make_extra_output import MakeExtraOutput from .reduce_scatter import ReduceScatter diff --git a/transformer_engine/pytorch/ops/basic/cast_float8.py b/transformer_engine/pytorch/ops/basic/cast_float8.py new file mode 100644 index 0000000000..40cd350e99 --- /dev/null +++ b/transformer_engine/pytorch/ops/basic/cast_float8.py @@ -0,0 +1,100 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Fusible operation for FP8 cast.""" + +from __future__ import annotations +from typing import Optional + +import torch + +from transformer_engine.pytorch.float8_tensor import Float8Tensor +from transformer_engine.pytorch.fp8 import ( + FP8GlobalStateManager, + get_fp8_te_dtype, +) +from transformer_engine.pytorch.ops.op import ( + BasicOperation, + OperationContext, +) +from .._common import is_float8_tensor + + +class CastFloat8(BasicOperation): + """Cast tensor to FP8 + + Uses FP8 recipe from `fp8_autocast` context. When called outside + of an `fp8_autocast` context, this is an identity operation. + + Parameters + ---------- + forward: bool, default = `True` + Perform FP8 cast in forward pass + backward: bool, default = `True` + Perform FP8 cast in backward pass + + """ + + def __init__( + self, + forward: bool = True, + backward: bool = True, + ) -> None: + super().__init__() + self._cast_forward = forward + self._cast_backward = backward + + def num_fp8_scales(self, mode: str) -> int: + if mode == "input" and self._cast_forward: + return 1 + if mode == "grad_output" and self._cast_backward: + return 1 + return 0 + + def op_forward( + self, + ctx: OperationContext, + input_: torch.Tensor, + prev_op: Optional[BasicOperation] = None, + next_op: Optional[BasicOperation] = None, + ) -> torch.Tensor: + + # Check if FP8 is enabled + fp8_enabled = FP8GlobalStateManager.is_fp8_enabled() + cast_forward = fp8_enabled and self._cast_forward + cast_backward = fp8_enabled and self._cast_backward + + # Cast to FP8 if needed + out = input_ + if cast_forward and not is_float8_tensor(out): + fp8_meta = self.get_fp8_meta("input") + fp8_dtype = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) + out = Float8Tensor.to_float8( + out, + fp8_meta=fp8_meta, + fp8_meta_forward=True, + fp8_meta_index=0, + fp8_dtype=fp8_dtype, + ) + + ctx.cast_backward = cast_backward + return out + + def op_backward( + self, + ctx: OperationContext, + grad_output: torch.Tensor, + ) -> tuple[torch.Tensor, tuple[()]]: + grad_input = grad_output + if ctx.cast_backward and not is_float8_tensor(grad_input): + fp8_meta = self.get_fp8_meta("grad_output") + fp8_dtype = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=False) + grad_input = Float8Tensor.to_float8( + grad_input, + fp8_meta=fp8_meta, + fp8_meta_forward=False, + fp8_meta_index=0, + fp8_dtype=fp8_dtype, + ) + return grad_input, () diff --git a/transformer_engine/pytorch/ops/fused/__init__.py b/transformer_engine/pytorch/ops/fused/__init__.py index ded9a5ff80..85a3e31b4e 
100644 --- a/transformer_engine/pytorch/ops/fused/__init__.py +++ b/transformer_engine/pytorch/ops/fused/__init__.py @@ -16,7 +16,7 @@ ForwardLinearBiasAdd, fuse_forward_linear_bias_add, ) -from .userbuffers_linear import ( - UserbuffersLinear, - fuse_forward_userbuffers_linear, +from .userbuffers_forward_linear import ( + UserbuffersForwardLinear, + fuse_userbuffers_forward_linear, ) diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py similarity index 82% rename from transformer_engine/pytorch/ops/fused/userbuffers_linear.py rename to transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py index 4c2c58af25..3073776645 100644 --- a/transformer_engine/pytorch/ops/fused/userbuffers_linear.py +++ b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py @@ -10,13 +10,12 @@ import torch -import transformer_engine_torch as tex -from ...cpp_extensions import FP8TensorMeta, fp8_gemm, gemm +from ...cpp_extensions import FP8TensorMeta, UbufOverlapAlgo, fp8_gemm, gemm from ...distributed import get_distributed_world_size from ...float8_tensor import Float8Tensor -from ...fp8 import FP8GlobalStateManager +from ...fp8 import FP8GlobalStateManager, get_fp8_te_dtype from ...module.base import get_ub, get_workspace -from ..basic import BasicLinear, Bias, ReduceScatter +from ..basic import BasicLinear, Bias, CastFloat8, ReduceScatter from ..op import ( BasicOperation, FusedOperation, @@ -31,13 +30,14 @@ reshape, ) -class UserbuffersLinear(FusedOperation): +class UserbuffersForwardLinear(FusedOperation): def __init__( self, *, linear: BasicLinear, bias: Optional[Bias], + cast_fp8: Optional[CastFloat8], reduce_scatter: Optional[ReduceScatter], ) -> None: @@ -45,12 +45,16 @@ def __init__( op_idxs = dict( linear=0, bias=None, + cast_fp8=None, reduce_scatter=None, ) ops = [linear] if bias is not None: op_idxs["bias"] = len(ops) ops.append(bias) + if cast_fp8 is not None: + op_idxs["cast_fp8"] = len(ops) + ops.append(cast_fp8) if reduce_scatter is not None: op_idxs["reduce_scatter"] = len(ops) ops.append(reduce_scatter) @@ -168,24 +172,22 @@ def _functional_forward( ub_algo = None if with_ub_all_gather: if with_fp8_compute and ub_comm.is_atomic_gemm(): - ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P + ub_algo = UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P else: - ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P + ub_algo = UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P elif with_ub_reduce_scatter: if with_fp8_compute: ub_algo = { - (True, True): tex.UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P, - (True, False): tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P, - (False, True): tex.UbufOverlapAlgo.ATOMIC_GEMM_RS, - (False, False): tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS, + (True, True): UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P, + (True, False): UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P, + (False, True): UbufOverlapAlgo.ATOMIC_GEMM_RS, + (False, False): UbufOverlapAlgo.SPLIT_PIPELINED_RS, }[(ub_comm.is_p2p_overlap(), ub_comm.is_atomic_gemm())] - if ub_comm.is_fp8_ubuf(): - ub_comm.set_ubuf_scale_inv(output_fp8_meta.scale_inv) else: if ub_comm.is_p2p_overlap(): - ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P + ub_algo = UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P else: - ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS + ub_algo = UbufOverlapAlgo.SPLIT_PIPELINED_RS else: raise RuntimeError("Could not choose Userbuffers communication algorithm") @@ -281,6 +283,7 @@ def _functional_forward( if with_ub_reduce_scatter: # Initialize buffers for 
UB reduce-scatter if with_fp8_output: + fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(forward=True) fp8_dtype = get_fp8_te_dtype( output_fp8_meta["recipe"], fprop_tensor=True, @@ -291,8 +294,10 @@ def _functional_forward( fp8_meta_forward=True, fp8_meta_index=0, fp8_dtype=fp8_dtype, + fp8_scale_inv=output_fp8_meta[fp8_meta_key].scale_inv[0], dtype=dtype, ) + ub_comm.set_ubuf_scale_inv(y._scale_inv) else: y = ub_global_buffer y_local = torch.empty( @@ -334,13 +339,14 @@ def _functional_forward( out=y, bias=b, use_bias=(b is not None), + use_split_accumulator=False, ub_algo=ub_algo, ub=ub_comm, ) if with_ub_all_gather: kwargs["extra_output_tensor"] = x_local._data if with_ub_reduce_scatter: - kwargs["extra_output_tensor"] = y_local._data + kwargs["extra_output_tensor"] = y_local if with_fp8_output: if y._fp8_meta is None: # Hackily create FP8TensorMeta if needed @@ -419,6 +425,14 @@ def fuser_forward( bias = bias_op.bias if basic_op_kwargs[idx]: raise ValueError("Bias operation forward does not expect keyword arguments") + cast_fp8_op = None + cast_fp8_op_ctx = None + if self._op_idxs["cast_fp8"] is not None: + idx = self._op_idxs["cast_fp8"] + cast_fp8_op = self.basic_ops[idx] + cast_fp8_op_ctx = basic_op_ctxs[idx] + if basic_op_kwargs[idx]: + raise ValueError("FP8 cast operation forward does not expect keyword arguments") reduce_scatter_op = None if self._op_idxs["reduce_scatter"] is not None: idx = self._op_idxs["reduce_scatter"] @@ -436,9 +450,19 @@ def fuser_forward( if with_fp8_compute: input_fp8_meta = linear_op.get_fp8_meta("input") weight_fp8_meta = linear_op.get_fp8_meta("param") - next_op = basic_op_next_ops[-1] - if next_op is not None and next_op.num_fp8_scales("input") > 0: - output_fp8_meta = next_op.get_fp8_meta("input") + if cast_fp8_op is not None: + output_fp8_meta = cast_fp8_op.get_fp8_meta("input") + if output_fp8_meta is None: + next_op = basic_op_next_ops[-1] + if next_op is not None and next_op.num_fp8_scales("input") > 0: + output_fp8_meta = next_op.get_fp8_meta("input") + if self.tensor_parallel_mode == "row" and output_fp8_meta is None: + raise RuntimeError( + "Userbuffers implementation of row tensor-parallel linear with FP8 compute " + "casts GEMM output to FP8, but could not find FP8 metadata to perform this " + "cast. Either disable Userbuffers or insert CastFloat8 op between BasicLinear " + "and ReduceScatter." 
+ ) grad_output_fp8_meta = linear_op.get_fp8_meta("grad_output") prev_op = basic_op_prev_ops[0] if prev_op is not None and prev_op.num_fp8_scales("grad_output") > 0: @@ -449,7 +473,7 @@ def fuser_forward( raise RuntimeError("Linear op is missing dict for Userbuffers options") # Linear forward - output, x_local, _ = UserbuffersLinear._functional_forward( + output, x_local, _ = UserbuffersForwardLinear._functional_forward( input=input_, weight=linear_op.weight, bias=bias, @@ -476,44 +500,63 @@ def fuser_forward( linear_op_ctx.input_requires_grad = input_.requires_grad linear_op_ctx.weight_requires_grad = linear_op.weight.requires_grad linear_op_ctx.has_prev_op = basic_op_prev_ops[0] is not None + if cast_fp8_op is not None: + cast_fp8_op_ctx.cast_backward = with_fp8_compute and cast_fp8_op._cast_backward return output, [() for _ in range(len(self.basic_ops))] -def fuse_forward_userbuffers_linear( +def fuse_userbuffers_forward_linear( ops: list[tuple[FusibleOperation, list[int]]], ) -> list[tuple[FusibleOperation, list[int]]]: + window = [] + + def peek_next_op() -> Optional[FusibleOperation]: + nonlocal ops + if not ops: + return None + return ops[0][0] + + def pop_next_op() -> FusibleOperation: + nonlocal ops, window + window.append(ops[0]) + ops = ops[1:] + return window[-1][0] + # Scan through ops, fusing if possible out = [] - window = [] while ops: out.extend(window) + window.clear() - # Check if first op is linear - window, ops = ops[:1], ops[1:] - if not isinstance(window[0][0], BasicLinear): + # Check if next op is linear + next_op = pop_next_op() + if not isinstance(next_op, BasicLinear): continue - linear = window[0][0] + linear = next_op if linear._userbuffers_options is None: continue # Check if next op is bias bias = None - if linear.tensor_parallel_mode != "row" and ops and isinstance(ops[0][0], Bias): - bias = ops[0][0] - window.append(ops[0]) - ops = ops[1:] + if linear.tensor_parallel_mode != "row" and isinstance(peek_next_op(), Bias): + bias = pop_next_op() # Check if next op is FP8 cast - ### TODO Implement + cast_fp8 = None + next_op = peek_next_op() + if ( + linear.tensor_parallel_mode != "column" + and isinstance(next_op, CastFloat8) + and next_op.num_fp8_scales("input") > 0 + ): + cast_fp8 = pop_next_op() # Check if next op is reduce-scatter reduce_scatter = None - if linear.tensor_parallel_mode is None and ops and isinstance(ops[0][0], ReduceScatter): - reduce_scatter = ops[0][0] - window.append(ops[0]) - ops = ops[1:] + if linear.tensor_parallel_mode is None and isinstance(peek_next_op(), ReduceScatter): + reduce_scatter = pop_next_op() # Check for invalid combinations if reduce_scatter is None: @@ -526,9 +569,10 @@ def fuse_forward_userbuffers_linear( continue # Replace window with fused op - op = UserbuffersLinear( + op = UserbuffersForwardLinear( linear=linear, bias=bias, + cast_fp8=cast_fp8, reduce_scatter=reduce_scatter, ) basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window] diff --git a/transformer_engine/pytorch/ops/fuser.py b/transformer_engine/pytorch/ops/fuser.py index 4978c0e5f7..1f7dba5487 100644 --- a/transformer_engine/pytorch/ops/fuser.py +++ b/transformer_engine/pytorch/ops/fuser.py @@ -20,7 +20,7 @@ fuse_backward_linear_add, fuse_forward_linear_bias_activation, fuse_forward_linear_bias_add, - fuse_forward_userbuffers_linear, + fuse_userbuffers_forward_linear, ) @@ -319,7 +319,7 @@ def _fuse_forward_ops( ops: list[tuple[FusibleOperation, list[int]]], ) -> list[tuple[FusibleOperation, list[int]]]: """Attempt to fuse operations in forward 
pass""" - ops = fuse_forward_userbuffers_linear(ops) + ops = fuse_userbuffers_forward_linear(ops) ops = fuse_forward_linear_bias_add(ops) ops = fuse_forward_linear_bias_activation(ops) return ops From 1e54b88b7119c968373078154f741d18f6f9a056 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Thu, 22 Aug 2024 04:57:21 +0000 Subject: [PATCH 05/13] Debug bug with incorrect bias pointers in UB GEMM Bias pointers are not properly offset for different data chunks. Also removed logic for FP8 RS. Signed-off-by: Tim Moon --- .../test_fusible_ops_with_userbuffers.py | 4 -- .../pytorch/csrc/comm_gemm_overlap.h | 25 ++++++++-- .../ops/fused/userbuffers_forward_linear.py | 46 ++----------------- 3 files changed, 24 insertions(+), 51 deletions(-) diff --git a/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py b/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py index 6320c47fc1..6b81d644c1 100644 --- a/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py +++ b/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py @@ -288,8 +288,6 @@ def _test_linear( ops.append(bias_op) elif tensor_parallel_mode == "row": userbuffers_options = dict(comm_name="proj") - if fp8_compute: - userbuffers_options["comm_name"] = "fc1" ### TODO Remove linear_op = te_ops.BasicLinear( in_features // world_size, out_features, @@ -301,8 +299,6 @@ def _test_linear( if bias: bias_op = te_ops.Bias(out_features, device=device, dtype=dtype) ops.append(bias_op) - if fp8_compute: - ops.append(te_ops.CastFloat8(backward=False)) ops.append(te_ops.ReduceScatter(process_group)) model = te_ops.Sequential(*ops) with torch.no_grad(): diff --git a/transformer_engine/pytorch/csrc/comm_gemm_overlap.h b/transformer_engine/pytorch/csrc/comm_gemm_overlap.h index 3b4e126943..895f75b11d 100644 --- a/transformer_engine/pytorch/csrc/comm_gemm_overlap.h +++ b/transformer_engine/pytorch/csrc/comm_gemm_overlap.h @@ -450,13 +450,14 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { int m_chunk = m / _num_splits; int input_a_chunk_size = m_chunk * k; int output_chunk_size = n * m_chunk; + int bias_chunk_size = m_chunk; int workspace_size_chunk = workspaceSize / _stream_compute.size(); // Get input, output, and workspace data pointers char *input_a_chunk_ptr = reinterpret_cast(A.data_ptr()); char *output_buf_chunk_ptr = reinterpret_cast(_ubuf.data_ptr()); + char *bias_chunk_ptr = reinterpret_cast(bias.data_ptr()); char *workspace_ptr = reinterpret_cast(workspace.data_ptr()); - char *rs_output_ptr = reinterpret_cast(rs_output.data_ptr()); // Catch up the default torch stream @@ -477,27 +478,35 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { torch::Tensor input_a_chunk = torch::from_blob(input_a_chunk_ptr, {m_chunk, k}, A.options()); torch::Tensor output_chunk = torch::from_blob(output_buf_chunk_ptr, {n, m_chunk}, _ubuf.options()); + torch::Tensor bias_chunk = (bias_chunk_ptr == nullptr + ? 
bias + : torch::from_blob(bias_chunk_ptr, {m_chunk}, bias.options())); torch::Tensor workspace_chunk = torch::from_blob(workspace_ptr, {workspace_size_chunk}, workspace.options()); at::cuda::setCurrentCUDAStream(_stream_compute[0]); te_gemm(input_a_chunk, A_scale_inverse, A_type, transa, B, B_scale_inverse, B_type, transb, - output_chunk, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad, + output_chunk, D_scale, D_type, D_amax, bias_chunk, bias_type, pre_gelu_out, grad, workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator, _math_sms); for (int i = 1; i < _num_splits; i++) { input_a_chunk_ptr += input_a_chunk_size * B.element_size(); output_buf_chunk_ptr += output_chunk_size * _ubuf.element_size(); - + if (bias_chunk_ptr != nullptr) { + bias_chunk_ptr += bias_chunk_size * bias.element_size(); + } torch::Tensor input_a_chunk = torch::from_blob(input_a_chunk_ptr, {m_chunk, k}, A.options()); torch::Tensor output_chunk = torch::from_blob(output_buf_chunk_ptr, {n, m_chunk}, _ubuf.options()); + torch::Tensor bias_chunk = (bias_chunk_ptr == nullptr + ? bias + : torch::from_blob(bias_chunk_ptr, {m_chunk}, bias.options())); torch::Tensor workspace_chunk = torch::from_blob(workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk}, workspace.options()); at::cuda::setCurrentCUDAStream(_stream_compute[i % _stream_compute.size()]); te_gemm(input_a_chunk, A_scale_inverse, A_type, transa, B, B_scale_inverse, B_type, transb, - output_chunk, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad, + output_chunk, D_scale, D_type, D_amax, bias_chunk, bias_type, pre_gelu_out, grad, workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator, _math_sms); @@ -549,12 +558,15 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { torch::from_blob(input_a_chunk_ptr, {m_chunk, k}, A.options()); torch::Tensor output_chunk = torch::from_blob(output_buf_chunk_ptr, {n, m_chunk}, _ubuf.options()); + torch::Tensor bias_chunk = (bias_chunk_ptr == nullptr + ? 
bias + : torch::from_blob(bias_chunk_ptr, {m_chunk}, bias.options())); torch::Tensor workspace_chunk = torch::from_blob(workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk}, workspace.options()); at::cuda::setCurrentCUDAStream(_stream_compute[i % _stream_compute.size()]); te_gemm(input_a_chunk, A_scale_inverse, A_type, transa, B, B_scale_inverse, B_type, transb, - output_chunk, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad, + output_chunk, D_scale, D_type, D_amax, bias_chunk, bias_type, pre_gelu_out, grad, workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator, _math_sms); @@ -582,6 +594,9 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { rs_output_ptr += m_chunk * rs_output.element_size(); input_a_chunk_ptr += input_a_chunk_size * B.element_size(); output_buf_chunk_ptr += output_chunk_size * _ubuf.element_size(); + if (bias_chunk_ptr != nullptr) { + bias_chunk_ptr += bias_chunk_size * bias.element_size(); + } } } for (size_t i = 0; i < _stream_compute.size(); i++) { diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py index 3073776645..03dbcc7485 100644 --- a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py +++ b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py @@ -15,7 +15,7 @@ from ...float8_tensor import Float8Tensor from ...fp8 import FP8GlobalStateManager, get_fp8_te_dtype from ...module.base import get_ub, get_workspace -from ..basic import BasicLinear, Bias, CastFloat8, ReduceScatter +from ..basic import BasicLinear, Bias, ReduceScatter from ..op import ( BasicOperation, FusedOperation, @@ -37,7 +37,6 @@ def __init__( *, linear: BasicLinear, bias: Optional[Bias], - cast_fp8: Optional[CastFloat8], reduce_scatter: Optional[ReduceScatter], ) -> None: @@ -45,16 +44,12 @@ def __init__( op_idxs = dict( linear=0, bias=None, - cast_fp8=None, reduce_scatter=None, ) ops = [linear] if bias is not None: op_idxs["bias"] = len(ops) ops.append(bias) - if cast_fp8 is not None: - op_idxs["cast_fp8"] = len(ops) - ops.append(cast_fp8) if reduce_scatter is not None: op_idxs["reduce_scatter"] = len(ops) ops.append(reduce_scatter) @@ -153,8 +148,6 @@ def _functional_forward( raise ValueError("No FP8 metadata was provided for casting input to FP8") if weight_fp8_meta is None and not is_float8_tensor(weight): raise ValueError("No FP8 metadata was provided for casting weight to FP8") - if output_fp8_meta is None and tensor_parallel_mode == "row": - raise ValueError("No FP8 metadata was provided for casting output to FP8") else: input_fp8_meta = None weight_fp8_meta = None @@ -425,14 +418,6 @@ def fuser_forward( bias = bias_op.bias if basic_op_kwargs[idx]: raise ValueError("Bias operation forward does not expect keyword arguments") - cast_fp8_op = None - cast_fp8_op_ctx = None - if self._op_idxs["cast_fp8"] is not None: - idx = self._op_idxs["cast_fp8"] - cast_fp8_op = self.basic_ops[idx] - cast_fp8_op_ctx = basic_op_ctxs[idx] - if basic_op_kwargs[idx]: - raise ValueError("FP8 cast operation forward does not expect keyword arguments") reduce_scatter_op = None if self._op_idxs["reduce_scatter"] is not None: idx = self._op_idxs["reduce_scatter"] @@ -450,19 +435,9 @@ def fuser_forward( if with_fp8_compute: input_fp8_meta = linear_op.get_fp8_meta("input") weight_fp8_meta = linear_op.get_fp8_meta("param") - if cast_fp8_op is not None: - output_fp8_meta = cast_fp8_op.get_fp8_meta("input") - if 
output_fp8_meta is None: - next_op = basic_op_next_ops[-1] - if next_op is not None and next_op.num_fp8_scales("input") > 0: - output_fp8_meta = next_op.get_fp8_meta("input") - if self.tensor_parallel_mode == "row" and output_fp8_meta is None: - raise RuntimeError( - "Userbuffers implementation of row tensor-parallel linear with FP8 compute " - "casts GEMM output to FP8, but could not find FP8 metadata to perform this " - "cast. Either disable Userbuffers or insert CastFloat8 op between BasicLinear " - "and ReduceScatter." - ) + next_op = basic_op_next_ops[-1] + if next_op is not None and next_op.num_fp8_scales("input") > 0: + output_fp8_meta = next_op.get_fp8_meta("input") grad_output_fp8_meta = linear_op.get_fp8_meta("grad_output") prev_op = basic_op_prev_ops[0] if prev_op is not None and prev_op.num_fp8_scales("grad_output") > 0: @@ -500,8 +475,6 @@ def fuser_forward( linear_op_ctx.input_requires_grad = input_.requires_grad linear_op_ctx.weight_requires_grad = linear_op.weight.requires_grad linear_op_ctx.has_prev_op = basic_op_prev_ops[0] is not None - if cast_fp8_op is not None: - cast_fp8_op_ctx.cast_backward = with_fp8_compute and cast_fp8_op._cast_backward return output, [() for _ in range(len(self.basic_ops))] @@ -543,16 +516,6 @@ def pop_next_op() -> FusibleOperation: if linear.tensor_parallel_mode != "row" and isinstance(peek_next_op(), Bias): bias = pop_next_op() - # Check if next op is FP8 cast - cast_fp8 = None - next_op = peek_next_op() - if ( - linear.tensor_parallel_mode != "column" - and isinstance(next_op, CastFloat8) - and next_op.num_fp8_scales("input") > 0 - ): - cast_fp8 = pop_next_op() - # Check if next op is reduce-scatter reduce_scatter = None if linear.tensor_parallel_mode is None and isinstance(peek_next_op(), ReduceScatter): @@ -572,7 +535,6 @@ def pop_next_op() -> FusibleOperation: op = UserbuffersForwardLinear( linear=linear, bias=bias, - cast_fp8=cast_fp8, reduce_scatter=reduce_scatter, ) basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window] From 80b9d42158de953c71cff8f5d3c0c731a312fbf8 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Fri, 23 Aug 2024 04:21:24 +0000 Subject: [PATCH 06/13] Add Userbuffers support for linear dgrad Test passes with row TP, fails with col TP. 
Signed-off-by: Tim Moon --- .../test_fusible_ops_with_userbuffers.py | 18 +- .../pytorch/ops/fused/__init__.py | 4 + .../ops/fused/userbuffers_backward_linear.py | 552 ++++++++++++++++++ .../ops/fused/userbuffers_forward_linear.py | 24 +- transformer_engine/pytorch/ops/fuser.py | 2 + 5 files changed, 581 insertions(+), 19 deletions(-) create mode 100644 transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py diff --git a/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py b/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py index 6b81d644c1..843634e653 100644 --- a/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py +++ b/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py @@ -23,7 +23,10 @@ from transformer_engine.pytorch.fp8 import FP8GlobalStateManager import transformer_engine.pytorch.ops as te_ops from transformer_engine.pytorch.ops._common import is_float8_tensor -from transformer_engine.pytorch.ops.fused import UserbuffersForwardLinear +from transformer_engine.pytorch.ops.fused import ( + UserbuffersBackwardLinear, + UserbuffersForwardLinear, +) from transformer_engine.pytorch.utils import is_bf16_compatible # Import utility functions @@ -313,8 +316,11 @@ def _test_linear( # Check that forward operations have been fused forward_ops = model._module_groups[0]._forward_ops + backward_ops = model._module_groups[0]._backward_ops assert len(forward_ops) == 1 + assert len(backward_ops) == 1 assert isinstance(forward_ops[0][0], UserbuffersForwardLinear) + assert isinstance(backward_ops[0][0], UserbuffersBackwardLinear) # Expected numerical error tols = dtype_tols(dtype) @@ -330,13 +336,13 @@ def _test_linear( # Check results y_test = y_test.to(dtype=torch.float64, device="cpu") dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") - dw_test = linear_op.weight.grad.to(dtype=torch.float64, device="cpu") + # dw_test = linear_op.weight.grad.to(dtype=torch.float64, device="cpu") ### TODO Restore torch.testing.assert_close(y_test, y_ref, **tols) torch.testing.assert_close(dx_test, dx_ref, **tols) - torch.testing.assert_close(dw_test, dw_ref, **tols) - if bias: - db_test = bias_op.bias.grad.to(dtype=torch.float64, device="cpu") - torch.testing.assert_close(db_test, db_ref, **tols) + # torch.testing.assert_close(dw_test, dw_ref, **tols) ### TODO Restore + # if bias: ### TODO Restore + # db_test = bias_op.bias.grad.to(dtype=torch.float64, device="cpu") + # torch.testing.assert_close(db_test, db_ref, **tols) def run_parallel_tests(model_config: ModelConfig) -> None: diff --git a/transformer_engine/pytorch/ops/fused/__init__.py b/transformer_engine/pytorch/ops/fused/__init__.py index 85a3e31b4e..08b9f06123 100644 --- a/transformer_engine/pytorch/ops/fused/__init__.py +++ b/transformer_engine/pytorch/ops/fused/__init__.py @@ -16,6 +16,10 @@ ForwardLinearBiasAdd, fuse_forward_linear_bias_add, ) +from .userbuffers_backward_linear import ( + UserbuffersBackwardLinear, + fuse_userbuffers_backward_linear, +) from .userbuffers_forward_linear import ( UserbuffersForwardLinear, fuse_userbuffers_forward_linear, diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py new file mode 100644 index 0000000000..530993eaf3 --- /dev/null +++ b/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py @@ -0,0 +1,552 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ +"""Linear layer backward with Userbuffers communication.""" + +from __future__ import annotations +from collections.abc import Iterable +from typing import Any, Optional + +import torch + +from ...cpp_extensions import FP8TensorMeta, UbufOverlapAlgo, fp8_gemm, gemm +from ...distributed import get_distributed_world_size +from ...float8_tensor import Float8Tensor +from ...fp8 import FP8GlobalStateManager, get_fp8_te_dtype +from ...module.base import get_ub, get_workspace +from ..basic import BasicLinear, Bias, ReduceScatter +from ..op import ( + BasicOperation, + FusedOperation, + FusibleOperation, + OperationContext, +) +from .._common import ( + canonicalize_device, + canonicalize_dtype, + convert_tensor, + is_float8_tensor, + reshape, +) + +class UserbuffersBackwardLinear(FusedOperation): + + def __init__( + self, + *, + linear: BasicLinear, + bias: Optional[Bias], + reduce_scatter: Optional[ReduceScatter], + ) -> None: + + # Basic operations that comprise this fused operation + op_idxs = dict( + linear=None, + bias=None, + reduce_scatter=None, + ) + ops = [] + if reduce_scatter is not None: + op_idxs["reduce_scatter"] = len(ops) + ops.append(reduce_scatter) + if bias is not None: + op_idxs["bias"] = len(ops) + ops.append(bias) + op_idxs["linear"] = len(ops) + ops.append(linear) + + # Initialize base class + super().__init__(ops) + + # Index of each basic operations + self._op_idxs: dict[str, Optional[int]] = op_idxs + + # Tensor parallelism configuration + self.tensor_parallel_mode: Optional[str] + self.tensor_parallel_group: Optional[torch.distributed.ProcessGroup] + self.tensor_parallel_size: int + self.sequence_parallel: bool + if reduce_scatter is None: + self.tensor_parallel_mode = linear.tensor_parallel_mode + self.tensor_parallel_group = linear.tensor_parallel_group + self.tensor_parallel_size = linear.tensor_parallel_size + self.sequence_parallel = linear.sequence_parallel + else: + self.tensor_parallel_mode = "row" + self.tensor_parallel_group = reduce_scatter.process_group + self.tensor_parallel_size = reduce_scatter.process_group_size + self.sequence_parallel = True + + @staticmethod + def _functional_backward_dgrad( + grad_output_local: torch.Tensor, + weight: Optional[torch.Tensor], + *, + weight_requires_grad: bool, + device: torch.device, + dtype: torch.dtype, + tensor_parallel_mode: str, + tensor_parallel_group: Optional[torch.distributed.ProcessGroup], + tensor_parallel_size: int, + with_fp8_compute: bool, + with_fp8_grad_input: bool, + weight_fp8_meta: Optional[dict[str, Any]], + grad_output_fp8_meta: Optional[dict[str, Any]], + grad_input_fp8_meta: Optional[dict[str, Any]], + ub_comm_name: str, + ): + + # Get Userbuffers communicator + ub_comm = get_ub(ub_comm_name + "_dgrad") + ub_local_buffer = ub_comm.get_ubuf_output(0) + ub_global_buffer = ub_comm.get_ubuf_output(1) + with_ub_reduce_scatter = tensor_parallel_mode == "column" + with_ub_all_gather = tensor_parallel_mode == "row" + + # Choose Userbuffers communication algorithm + ub_algo = None + if with_ub_all_gather: + if with_fp8_compute and ub_comm.is_atomic_gemm(): + ub_algo = UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P + else: + ub_algo = UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P + elif with_ub_reduce_scatter: + ub_algo = UbufOverlapAlgo.BULK_OVERLAP_AG ### TODO Is this right? 
+ else: + raise RuntimeError("Could not choose Userbuffers communication algorithm") + + # Cast grad output tensor to correct dtype + dy_local = grad_output_local + if with_fp8_compute and not is_float8_tensor(dy_local): + fp8_dtype = get_fp8_te_dtype( + grad_output_fp8_meta["recipe"], + fprop_tensor=False, + ) + if with_ub_all_gather: + data = ub_local_buffer + else: + data = torch.empty_like(dy_local, dtype=torch.uint8) + dy_fp8 = Float8Tensor( + data=data, + fp8_meta=grad_output_fp8_meta, + fp8_meta_forward=False, + fp8_meta_index=0, + fp8_dtype=fp8_dtype, + fp8_scale_inv=torch.empty([1], dtype=torch.float32, device=device), + dtype=dtype, + ) + with_cast_transpose = weight_requires_grad and not with_ub_all_gather + if with_cast_transpose: + dy_fp8.cast_transpose_(dy_local) + else: + dy_fp8.copy_(dy_local) + dy_local = dy_fp8 + elif not with_fp8_compute and is_float8_tensor(dy_local): + if with_ub_all_gather: + dy_local = ub_local_buffer.copy_(dy_local) + else: + dy_local = dy_local.from_float8() + + # Initialize buffers for UB all-gather if needed + dy = dy_local + if with_ub_all_gather: + if with_fp8_compute: + dy = Float8Tensor.make_like(dy_local, data=ub_global_buffer) + if dy_local._data.data_ptr() != ub_local_buffer.data_ptr(): + ub_local_buffer.copy_(dy_local._data) + else: + dy = ub_global_buffer + if dy_local.data_ptr() != ub_local_buffer.data_ptr(): + ub_local_buffer.copy_(dy_local) + + # Check weight tensor + w = convert_tensor( + weight, + device=device, + dtype=dtype, + memory_format=torch.contiguous_format, + ) + if with_fp8_compute and not is_float8_tensor(w): + fp8_dtype = get_fp8_te_dtype( + weight_fp8_meta["recipe"], + fprop_tensor=True, + ) + w_fp8 = Float8Tensor( + data=torch.empty_like(w, dtype=torch.uint8), + fp8_meta=weight_fp8_meta, + fp8_meta_forward=True, + fp8_meta_index=0, + fp8_dtype=fp8_dtype, + fp8_scale_inv=torch.empty([1], dtype=torch.float32, device=device), + dtype=dtype, + ) + w_fp8.cast_transpose_(w) + w = w_fp8 + elif not with_fp8_compute and is_float8_tensor(w): + w = w.from_float8() + + # Construct grad input tensor + dx = None + dx_local = None + if with_ub_reduce_scatter: + # Initialize buffers for UB reduce-scatter + dx = ub_global_buffer + dx_local = torch.empty( + (dy.size(0) // tensor_parallel_size, w.size(-1)), + dtype=dtype, + device=device, + ) + else: + # Allocate grad input tensor + if with_fp8_grad_input: + fp8_dtype = get_fp8_te_dtype( + grad_input_fp8_meta["recipe"], + fprop_tensor=False, + ) + data = torch.empty( + (dy.size(0), w.size(-1)), + dtype=torch.uint8, + device=device, + ) + dx = Float8Tensor( + data=data, + fp8_meta=grad_input_fp8_meta, + fp8_meta_forward=False, + fp8_meta_index=0, + fp8_dtype=fp8_dtype, + dtype=dtype, + ) + else: + dx = torch.empty( + (dy.size(0), w.size(-1)), + dtype=dtype, + device=device, + ) + dx_local = dx + + # Perform dgrad GEMM + if with_fp8_compute: + kwargs = dict( + out=dx, + use_split_accumulator=False, ### TODO ? 
+ ub_algo=ub_algo, + ub=ub_comm, + ) + if with_ub_reduce_scatter: + kwargs["extra_output_tensor"] = dx_local + if with_fp8_grad_input: + if dx._fp8_meta is None: + # Hackily create FP8TensorMeta if needed + fp8_meta = FP8TensorMeta() + fp8_meta.scale = dx._scale_inv.reciprocal() + fp8_meta.amax_history = torch.empty(1, 1, dtype=torch.float32, device=device) + fp8_meta.scale_inv = dx._scale_inv + fp8_meta_index = 0 + else: + # Get FP8TensorMeta from Float8Tensor + fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( + forward=dx._fp8_meta_forward, + ) + fp8_meta = dx._fp8_meta[fp8_meta_key] + fp8_meta_index = dx._fp8_meta_index + kwargs.update( + dict( + out=dx._data, + out_index=fp8_meta_index, + fp8_meta_tensor=fp8_meta, + D_dtype=dx._fp8_dtype, + ) + ) + fp8_gemm( + w.transpose_2d(), + w._scale_inv, + 0, + w._fp8_dtype, + dy._data, + dy._scale_inv, + 0, + dy._fp8_dtype, + dy.dtype, + get_workspace(), + **kwargs, + ) + else: + kwargs = dict( + layout="NN", + out=dx, + ub_algo=ub_algo, + ub=ub_comm, + ) + if with_ub_reduce_scatter: + kwargs["extra_output_tensor"] = dx_local + gemm(w, dy, dx.dtype, get_workspace(), **kwargs) + + return dx_local, dy + + @staticmethod + def _functional_backward( + grad_output: torch.Tensor, + input: Optional[torch.Tensor], # pylint: disable=redefined-builtin + weight: Optional[torch.Tensor], + input_dims: Iterable[int], + weight_dims: Iterable[int], + *, + weight_requires_grad: bool = True, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + grad_weight: Optional[torch.Tensor] = None, + accumulate_into_grad_weight: bool = False, + grad_bias: Optional[torch.Tensor] = None, + tensor_parallel_mode: Optional[str] = None, + tensor_parallel_group: Optional[torch.distributed.ProcessGroup] = None, + tensor_parallel_size: Optional[int] = None, + sequence_parallel: bool = False, + with_fp8_compute: bool = False, + weight_fp8_meta: Optional[dict[str, Any]] = None, + grad_output_fp8_meta: Optional[dict[str, Any]] = None, + grad_input_fp8_meta: Optional[dict[str, Any]] = None, + ub_comm_name: str, + ): + + # Check device + if device is None: + device = weight.device + device = canonicalize_device(device) + if device.type != "cuda": + raise ValueError(f"Only CUDA devices are supported (got {device})") + + # Check datatype + if dtype is None: + dtype = weight.dtype + dtype = canonicalize_dtype(dtype) + if dtype not in (torch.float32, torch.float16, torch.bfloat16): + raise ValueError(f"Supported dtypes are float32, float16, bfloat16 (got {dtype})") + + # Check input tensor + output_dims = tuple(grad_output.size()) + input_dims = tuple(input_dims) + weight_dims = tuple(weight_dims) + if len(weight_dims) != 2: + raise ValueError(f"Weight tensor is not 2D (shape={weight_dims})") + if len(input_dims) == 0 or weight_dims[1] != input_dims[-1]: + raise ValueError( + f"Input tensor (shape={input_dims}) " + f"and weight tensor (shape={weight_dims}) " + "are not compatible" + ) + if weight_dims[0] != output_dims[-1]: + raise ValueError( + f"Grad output tensor (shape={output_dims}) " + f"and weight tensor (shape={weight_dims}) " + "are not compatible" + ) + + # Check tensor parallel group + if tensor_parallel_size is None: + tensor_parallel_size = get_distributed_world_size(tensor_parallel_group) + if tensor_parallel_size == 1: + tensor_parallel_mode = None + if tensor_parallel_mode not in ("column", "row"): + raise RuntimeError( + "Invalid configuration for Userbuffers " + f"({tensor_parallel_size=}, {tensor_parallel_mode=})" + ) + if not 
sequence_parallel: + raise RuntimeError( + f"Invalid configuration for Userbuffers ({sequence_parallel=})" + ) + + # Check if FP8 is enabled + if with_fp8_compute: + if grad_output_fp8_meta is None and not is_float8_tensor(grad_output): + raise ValueError("No FP8 metadata was provided for casting output gradient to FP8") + else: + weight_fp8_meta = None + grad_output_fp8_meta = None + grad_input_fp8_meta = None + with_fp8_grad_input = ( + with_fp8_compute + and tensor_parallel_mode != "column" + and grad_input_fp8_meta is not None + ) + + # Perform dgrad GEMM + dy_local = reshape( + grad_output, + (-1, output_dims[-1]), + device=device, + dtype=dtype, + ) + dx_local, dy = UserbuffersBackwardLinear._functional_backward_dgrad( + grad_output_local=dy_local, + weight=weight, + weight_requires_grad=weight_requires_grad, + device=device, + dtype=dtype, + tensor_parallel_mode=tensor_parallel_mode, + tensor_parallel_group=tensor_parallel_group, + tensor_parallel_size=tensor_parallel_size, + with_fp8_compute=with_fp8_compute, + with_fp8_grad_input=with_fp8_grad_input, + weight_fp8_meta=weight_fp8_meta, + grad_output_fp8_meta=grad_output_fp8_meta, + grad_input_fp8_meta=grad_output_fp8_meta, + ub_comm_name=ub_comm_name, + ) + grad_input = reshape(dx_local, input_dims) + + # Perform wgrad GEMM + if not weight_requires_grad: + grad_weight = None + else: + raise NotImplementedError() ### TODO Implement + + return grad_input, grad_weight + + def fuser_backward( + self, + basic_op_ctxs: list[OperationContext], + grad_output: torch.Tensor, + *, + basic_op_grad_extra_outputs: list[tuple[torch.Tensor, ...]], + ) -> tuple[ + torch.Tensor, + list[tuple[Optional[torch.Tensor], ...]], + list[tuple[()]], + ]: + + # Get basic operations + idx = self._op_idxs["linear"] + linear_op = self.basic_ops[idx] + linear_op_ctx = basic_op_ctxs[idx] + bias_op = None + bias = None + if self._op_idxs["bias"] is not None: + idx = self._op_idxs["bias"] + bias_op = self.basic_ops[idx] + bias = bias_op.bias + + # Saved tensors from forward pass + (x_local,) = linear_op_ctx.saved_tensors + + # wgrad fusion + accumulate_into_main_grad = linear_op._accumulate_into_main_grad + grad_weight = None + if linear_op_ctx.weight_requires_grad and accumulate_into_main_grad: + if not hasattr(linear_op.weight, "main_grad"): + raise RuntimeError( + "BasicLinear op is configured with " + "accumulate_into_main_grad=True, " + "but weight parameter does not have main_grad attribute" + ) + grad_weight = linear_op.weight.main_grad.detach() + else: + accumulate_into_main_grad = False + + # Linear backward pass + grad_bias = None ### TODO Implement + grad_input, grad_weight = UserbuffersBackwardLinear._functional_backward( + grad_output=grad_output, + input=x_local, + weight=linear_op.weight, + input_dims=linear_op_ctx.input_dims, + weight_dims=linear_op.weight.size(), + weight_requires_grad=linear_op_ctx.weight_requires_grad, + device=linear_op.device, + dtype=linear_op.dtype, + grad_weight=grad_weight, + accumulate_into_grad_weight=accumulate_into_main_grad, + grad_bias=grad_bias, + tensor_parallel_mode=self.tensor_parallel_mode, + tensor_parallel_group=self.tensor_parallel_group, + sequence_parallel=self.sequence_parallel, + with_fp8_compute=linear_op_ctx.with_fp8_compute, + weight_fp8_meta=linear_op_ctx.weight_fp8_meta, + grad_output_fp8_meta=linear_op_ctx.grad_output_fp8_meta, + grad_input_fp8_meta=linear_op_ctx.grad_input_fp8_meta, + ub_comm_name=linear_op._userbuffers_options["comm_name"], + ) + if accumulate_into_main_grad: + grad_weight = None + 
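+        # NOTE: grad_weight is returned as None in this case because the wgrad result is
+        # accumulated directly into weight.main_grad rather than handed back to the fuser.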
+ # Clear input tensor if possible + if linear_op_ctx.has_prev_op: + clear_tensor_data(x_local) + + # Return gradients + grad_params = [() for _ in range(len(self.basic_ops))] + grad_params[self._op_idxs["linear"]] = (grad_weight,) + if bias_op is not None: + grad_params[self._op_idxs["bias"]] = (grad_bias,) + grad_extra_inputs = [() for _ in range(len(self.basic_ops))] + return grad_input, grad_params, grad_extra_inputs + + +def fuse_userbuffers_backward_linear( + ops: list[tuple[FusibleOperation, list[int]]], +) -> list[tuple[FusibleOperation, list[int]]]: + + # Sliding window in list of ops + window = [] + + def peek_next_op() -> Optional[FusibleOperation]: + nonlocal ops + if not ops: + return None + return ops[-1][0] + + def pop_next_op() -> FusibleOperation: + nonlocal ops, window + window.insert(0, ops[-1]) + ops = ops[:-1] + return window[0][0] + + # Scan through ops in reverse order, fusing if possible + out_reversed = [] + while ops: + out_reversed.extend(reversed(window)) + window.clear() + + # Check if next op is linear + next_op = pop_next_op() + if not isinstance(next_op, BasicLinear): + continue + linear = next_op + if linear._userbuffers_options is None: + continue + + # Check if next op is bias + bias = None + if linear.tensor_parallel_mode != "row" and isinstance(peek_next_op(), Bias): + bias = pop_next_op() + + # Check if next op is reduce-scatter + reduce_scatter = None + if linear.tensor_parallel_mode is None and isinstance(peek_next_op(), ReduceScatter): + reduce_scatter = pop_next_op() + + # Check for invalid combinations + if reduce_scatter is None: + if linear.tensor_parallel_mode is None: + continue + if linear.tensor_parallel_mode == "row" and bias is not None: + continue + else: + if linear.tensor_parallel_mode is not None: + continue + + # Replace window with fused op + op = UserbuffersBackwardLinear( + linear=linear, + bias=bias, + reduce_scatter=reduce_scatter, + ) + basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window] + window = [(op, basic_op_idxs)] + + # Return list of ops + out_reversed.extend(reversed(window)) + out = out_reversed + out.reverse() + return out diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py index 03dbcc7485..c81b4dbe99 100644 --- a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py +++ b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py @@ -2,7 +2,7 @@ # # See LICENSE for license information. 
-"""Linear layer with Userbuffers communication.""" +"""Linear layer forward with Userbuffers communication.""" from __future__ import annotations from collections.abc import Iterable @@ -97,19 +97,19 @@ def _functional_forward( # Check device if device is None: - device = weight.device if out is None else out.device + device = weight.device device = canonicalize_device(device) if device.type != "cuda": raise ValueError(f"Only CUDA devices are supported (got {device})") # Check datatype if dtype is None: - dtype = weight.dtype if out is None else out.dtype + dtype = weight.dtype dtype = canonicalize_dtype(dtype) if dtype not in (torch.float32, torch.float16, torch.bfloat16): raise ValueError(f"Supported dtypes are float32, float16, bfloat16 (got {dtype})") - # Check input tensor dims + # Input tensor dims input_dims = tuple(input.size()) weight_dims = tuple(weight.size()) if len(weight_dims) != 2: @@ -121,8 +121,7 @@ def _functional_forward( "are not compatible" ) - # Check output tensor dims - output_dims: list[int] + # Output tensor dims output_dims = list(input_dims) output_dims[0] = -1 output_dims[-1] = weight_dims[0] @@ -152,7 +151,11 @@ def _functional_forward( input_fp8_meta = None weight_fp8_meta = None output_fp8_meta = None - with_fp8_output = with_fp8_compute and output_fp8_meta is not None + with_fp8_output = ( + with_fp8_compute + and tensor_parallel_mode != "row" + and output_fp8_meta is not None + ) # Get Userbuffers communicator ub_comm = get_ub(ub_comm_name + "_fprop") @@ -418,12 +421,6 @@ def fuser_forward( bias = bias_op.bias if basic_op_kwargs[idx]: raise ValueError("Bias operation forward does not expect keyword arguments") - reduce_scatter_op = None - if self._op_idxs["reduce_scatter"] is not None: - idx = self._op_idxs["reduce_scatter"] - reduce_scatter_op = self.basic_ops[idx] - if basic_op_kwargs[idx]: - raise ValueError("Reduce-scatter operation forward does not expect keyword arguments") # FP8 metadata with_fp8_compute = FP8GlobalStateManager.is_fp8_enabled() @@ -483,6 +480,7 @@ def fuse_userbuffers_forward_linear( ops: list[tuple[FusibleOperation, list[int]]], ) -> list[tuple[FusibleOperation, list[int]]]: + # Sliding window in list of ops window = [] def peek_next_op() -> Optional[FusibleOperation]: diff --git a/transformer_engine/pytorch/ops/fuser.py b/transformer_engine/pytorch/ops/fuser.py index 1f7dba5487..ed34f97a81 100644 --- a/transformer_engine/pytorch/ops/fuser.py +++ b/transformer_engine/pytorch/ops/fuser.py @@ -20,6 +20,7 @@ fuse_backward_linear_add, fuse_forward_linear_bias_activation, fuse_forward_linear_bias_add, + fuse_userbuffers_backward_linear, fuse_userbuffers_forward_linear, ) @@ -330,6 +331,7 @@ def _fuse_backward_ops( ops: list[tuple[FusibleOperation, list[int]]], ) -> list[tuple[FusibleOperation, list[int]]]: """Attempt to fuse operations in backward pass""" + ops = fuse_userbuffers_backward_linear(ops) ops = fuse_backward_linear_add(ops) return ops From e6ad5711af29a2be160de7c899da813e3d016130 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Sat, 24 Aug 2024 00:29:53 +0000 Subject: [PATCH 07/13] Add Userbuffers support for linear wgrad Signed-off-by: Tim Moon --- .../test_fusible_ops_with_userbuffers.py | 4 +- transformer_engine/pytorch/ops/_common.py | 25 +- .../ops/fused/userbuffers_backward_linear.py | 460 ++++++++++-------- .../ops/fused/userbuffers_forward_linear.py | 28 +- 4 files changed, 303 insertions(+), 214 deletions(-) diff --git a/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py 
b/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py index 843634e653..776eeb5b87 100644 --- a/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py +++ b/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py @@ -336,10 +336,10 @@ def _test_linear( # Check results y_test = y_test.to(dtype=torch.float64, device="cpu") dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") - # dw_test = linear_op.weight.grad.to(dtype=torch.float64, device="cpu") ### TODO Restore + dw_test = linear_op.weight.grad.to(dtype=torch.float64, device="cpu") torch.testing.assert_close(y_test, y_ref, **tols) torch.testing.assert_close(dx_test, dx_ref, **tols) - # torch.testing.assert_close(dw_test, dw_ref, **tols) ### TODO Restore + torch.testing.assert_close(dw_test, dw_ref, **tols) # if bias: ### TODO Restore # db_test = bias_op.bias.grad.to(dtype=torch.float64, device="cpu") # torch.testing.assert_close(db_test, db_ref, **tols) diff --git a/transformer_engine/pytorch/ops/_common.py b/transformer_engine/pytorch/ops/_common.py index 77efef4ab6..5080942437 100644 --- a/transformer_engine/pytorch/ops/_common.py +++ b/transformer_engine/pytorch/ops/_common.py @@ -9,7 +9,8 @@ import torch -from transformer_engine.pytorch.float8_tensor import Float8Tensor +from ..cpp_extensions import FP8TensorMeta +from ..float8_tensor import Float8Tensor def canonicalize_device(device: Optional[torch.device | str]) -> torch.device: @@ -150,3 +151,25 @@ def reshape( # Reshape standard PyTorch tensor return tensor.view(shape) + + +def get_fp8_meta_from_fp8_tensor(tensor: Float8Tensor) -> tuple[FP8TensorMeta, int]: + """Get FP8TensorMeta object and index corrsponding to Float8Tensor + + Constructs FP8TensorMeta if needed. + + """ + + # Check if tensor already has FP8 metadata + if tensor._fp8_meta is not None: + key = FP8GlobalStateManager.get_meta_tensor_key( + forward=tensor._fp8_meta_forward, + ) + return tensor._fp8_meta[key], tensor._fp8_meta_index + + # Create FP8TensorMeta class + fp8_meta = FP8TensorMeta() + fp8_meta.scale = tensor._scale_inv.reciprocal() + fp8_meta.amax_history = torch.empty(1, 1, dtype=torch.float32, device=tensor.device) + fp8_meta.scale_inv = tensor._scale_inv + return fp8_meta, 0 diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py index 530993eaf3..71b07fd043 100644 --- a/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py +++ b/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py @@ -10,7 +10,7 @@ import torch -from ...cpp_extensions import FP8TensorMeta, UbufOverlapAlgo, fp8_gemm, gemm +from ...cpp_extensions import UbufOverlapAlgo, fp8_gemm, gemm from ...distributed import get_distributed_world_size from ...float8_tensor import Float8Tensor from ...fp8 import FP8GlobalStateManager, get_fp8_te_dtype @@ -26,6 +26,7 @@ canonicalize_device, canonicalize_dtype, convert_tensor, + get_fp8_meta_from_fp8_tensor, is_float8_tensor, reshape, ) @@ -79,52 +80,135 @@ def __init__( self.sequence_parallel = True @staticmethod - def _functional_backward_dgrad( - grad_output_local: torch.Tensor, + def _functional_backward( + grad_output: torch.Tensor, + input: Optional[torch.Tensor], # pylint: disable=redefined-builtin weight: Optional[torch.Tensor], + input_dims: Iterable[int], + weight_dims: Iterable[int], *, - weight_requires_grad: bool, - device: torch.device, - dtype: torch.dtype, - tensor_parallel_mode: str, - tensor_parallel_group: 
Optional[torch.distributed.ProcessGroup], - tensor_parallel_size: int, - with_fp8_compute: bool, - with_fp8_grad_input: bool, - weight_fp8_meta: Optional[dict[str, Any]], - grad_output_fp8_meta: Optional[dict[str, Any]], - grad_input_fp8_meta: Optional[dict[str, Any]], + input_requires_grad: bool = True, + weight_requires_grad: bool = True, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + grad_weight: Optional[torch.Tensor] = None, + accumulate_into_grad_weight: bool = False, + grad_bias: Optional[torch.Tensor] = None, + tensor_parallel_mode: Optional[str] = None, + tensor_parallel_group: Optional[torch.distributed.ProcessGroup] = None, + tensor_parallel_size: Optional[int] = None, + sequence_parallel: bool = False, + with_fp8_compute: bool = False, + weight_fp8_meta: Optional[dict[str, Any]] = None, + grad_output_fp8_meta: Optional[dict[str, Any]] = None, + grad_input_fp8_meta: Optional[dict[str, Any]] = None, ub_comm_name: str, ): - # Get Userbuffers communicator - ub_comm = get_ub(ub_comm_name + "_dgrad") - ub_local_buffer = ub_comm.get_ubuf_output(0) - ub_global_buffer = ub_comm.get_ubuf_output(1) - with_ub_reduce_scatter = tensor_parallel_mode == "column" - with_ub_all_gather = tensor_parallel_mode == "row" - - # Choose Userbuffers communication algorithm - ub_algo = None - if with_ub_all_gather: - if with_fp8_compute and ub_comm.is_atomic_gemm(): - ub_algo = UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P - else: - ub_algo = UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P - elif with_ub_reduce_scatter: - ub_algo = UbufOverlapAlgo.BULK_OVERLAP_AG ### TODO Is this right? + assert input_requires_grad ### TODO Support optional + assert weight_requires_grad ### TODO Support optional + + # Check device + if device is None: + device = weight.device + device = canonicalize_device(device) + if device.type != "cuda": + raise ValueError(f"Only CUDA devices are supported (got {device})") + + # Check datatype + if dtype is None: + dtype = weight.dtype + dtype = canonicalize_dtype(dtype) + if dtype not in (torch.float32, torch.float16, torch.bfloat16): + raise ValueError(f"Supported dtypes are float32, float16, bfloat16 (got {dtype})") + + # Input tensor dims + output_dims = tuple(grad_output.size()) + input_dims = tuple(input_dims) + weight_dims = tuple(weight_dims) + if len(weight_dims) != 2: + raise ValueError(f"Weight tensor is not 2D (shape={weight_dims})") + if len(input_dims) == 0 or weight_dims[1] != input_dims[-1]: + raise ValueError( + f"Input tensor (shape={input_dims}) " + f"and weight tensor (shape={weight_dims}) " + "are not compatible" + ) + if weight_dims[0] != output_dims[-1]: + raise ValueError( + f"Grad output tensor (shape={output_dims}) " + f"and weight tensor (shape={weight_dims}) " + "are not compatible" + ) + + # Check tensor parallel group + if tensor_parallel_size is None: + tensor_parallel_size = get_distributed_world_size(tensor_parallel_group) + if tensor_parallel_size == 1: + tensor_parallel_mode = None + if tensor_parallel_mode not in ("column", "row"): + raise RuntimeError( + "Invalid configuration for Userbuffers " + f"({tensor_parallel_size=}, {tensor_parallel_mode=})" + ) + if not sequence_parallel: + raise RuntimeError( + f"Invalid configuration for Userbuffers ({sequence_parallel=})" + ) + + # Check if FP8 is enabled + if with_fp8_compute: + if grad_output_fp8_meta is None and not is_float8_tensor(grad_output): + raise ValueError("No FP8 metadata was provided for casting output gradient to FP8") else: - raise RuntimeError("Could not choose 
Userbuffers communication algorithm") + weight_fp8_meta = None + grad_output_fp8_meta = None + grad_input_fp8_meta = None + with_fp8_grad_input = ( + with_fp8_compute + and tensor_parallel_mode != "column" + and grad_input_fp8_meta is not None + ) - # Cast grad output tensor to correct dtype - dy_local = grad_output_local + # Get Userbuffers communicators and algorithms + with_ub_all_gather_dy = False + with_ub_reduce_scatter_dx = False + with_ub_all_gather_x = False + ub_comm_dy = None + ub_comm_dx = None + ub_comm_x = None + ub_algo_dy = None + ub_algo_dx = None + ub_algo_x = None + if tensor_parallel_mode == "row": + with_ub_all_gather_dy = True + ub_comm_dy = get_ub(ub_comm_name + "_dgrad") + if with_fp8_compute and ub_comm_dy.is_atomic_gemm(): + ub_algo_dy = UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P + else: + ub_algo_dy = UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P + elif tensor_parallel_mode == "column": + with_ub_reduce_scatter_dx = True + with_ub_all_gather_x = True + ub_comm_dx = get_ub(ub_comm_name + "_wgrad") + ub_comm_x = get_ub(ub_comm_name + "_dgrad") + ub_algo_dx = UbufOverlapAlgo.BULK_OVERLAP_RS + ub_algo_x = UbufOverlapAlgo.BULK_OVERLAP_AG + + # Check grad output tensor + dy_local = reshape( + grad_output, + (-1, output_dims[-1]), + device=device, + dtype=dtype, + ) if with_fp8_compute and not is_float8_tensor(dy_local): fp8_dtype = get_fp8_te_dtype( grad_output_fp8_meta["recipe"], fprop_tensor=False, ) - if with_ub_all_gather: - data = ub_local_buffer + if with_ub_all_gather_dy: + data = ub_comm_dy.get_ubuf_output(0) else: data = torch.empty_like(dy_local, dtype=torch.uint8) dy_fp8 = Float8Tensor( @@ -136,29 +220,54 @@ def _functional_backward_dgrad( fp8_scale_inv=torch.empty([1], dtype=torch.float32, device=device), dtype=dtype, ) - with_cast_transpose = weight_requires_grad and not with_ub_all_gather - if with_cast_transpose: - dy_fp8.cast_transpose_(dy_local) - else: + if with_ub_all_gather_dy: dy_fp8.copy_(dy_local) + else: + dy_fp8.cast_transpose_(dy_local) dy_local = dy_fp8 elif not with_fp8_compute and is_float8_tensor(dy_local): - if with_ub_all_gather: + if with_ub_all_gather_dy: + ub_local_buffer = ub_comm_dy.get_ubuf_output(0) dy_local = ub_local_buffer.copy_(dy_local) else: dy_local = dy_local.from_float8() - # Initialize buffers for UB all-gather if needed - dy = dy_local - if with_ub_all_gather: - if with_fp8_compute: - dy = Float8Tensor.make_like(dy_local, data=ub_global_buffer) - if dy_local._data.data_ptr() != ub_local_buffer.data_ptr(): - ub_local_buffer.copy_(dy_local._data) + # Check input tensor + x_local = reshape( + input, + (-1, input_dims[-1]), + device=device, + dtype=dtype, + ) + if with_fp8_compute and not is_float8_tensor(x_local): + fp8_dtype = get_fp8_te_dtype( + input_fp8_meta["recipe"], + fprop_tensor=True, + ) + if with_ub_all_gather_x: + data = ub_comm_x.get_ubuf_output(0) else: - dy = ub_global_buffer - if dy_local.data_ptr() != ub_local_buffer.data_ptr(): - ub_local_buffer.copy_(dy_local) + data = torch.empty_like(x_local, dtype=torch.uint8) + x_fp8 = Float8Tensor( + data=data, + fp8_meta=input_fp8_meta, + fp8_meta_forward=True, + fp8_meta_index=0, + fp8_dtype=fp8_dtype, + fp8_scale_inv=torch.empty([1], dtype=torch.float32, device=device), + dtype=dtype, + ) + if with_ub_all_gather_x: + x_fp8.copy_(x_local) + else: + x_fp8.cast_transpose_(x_local) + x_local = x_fp8 + elif not with_fp8_compute and is_float8_tensor(x_local): + if with_ub_all_gather_x: + ub_local_buffer = ub_comm_x.get_ubuf_output(0) + x_local = ub_local_buffer.copy_(x_local) + else: 
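+                # No Userbuffers all-gather for the input in this branch, so simply
+                # dequantize the FP8 tensor back to the compute dtype.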
+ x_local = x_local.from_float8() # Check weight tensor w = convert_tensor( @@ -186,17 +295,39 @@ def _functional_backward_dgrad( elif not with_fp8_compute and is_float8_tensor(w): w = w.from_float8() + # Initialize buffers for UB all-gather if needed + dy = dy_local + x = x_local + if with_ub_all_gather_dy: + ub_local_buffer = ub_comm_dy.get_ubuf_output(0) + ub_global_buffer = ub_comm_dy.get_ubuf_output(1) + if with_fp8_compute: + dy = Float8Tensor.make_like(dy_local, data=ub_global_buffer) + if dy_local._data.data_ptr() != ub_local_buffer.data_ptr(): + ub_local_buffer.copy_(dy_local._data) + else: + dy = ub_global_buffer + if dy_local.data_ptr() != ub_local_buffer.data_ptr(): + ub_local_buffer.copy_(dy_local) + if with_ub_all_gather_x: + ub_local_buffer = ub_comm_x.get_ubuf_output(0) + ub_global_buffer = ub_comm_x.get_ubuf_output(1) + if with_fp8_compute: + x = Float8Tensor.make_like(x_local, data=ub_global_buffer) + if x_local._data.data_ptr() != ub_local_buffer.data_ptr(): + ub_local_buffer.copy_(x_local._data) + else: + x = ub_global_buffer + if x_local.data_ptr() != ub_local_buffer.data_ptr(): + ub_local_buffer.copy_(x_local) + # Construct grad input tensor dx = None dx_local = None - if with_ub_reduce_scatter: + if with_ub_reduce_scatter_dx: # Initialize buffers for UB reduce-scatter - dx = ub_global_buffer - dx_local = torch.empty( - (dy.size(0) // tensor_parallel_size, w.size(-1)), - dtype=dtype, - device=device, - ) + dx = ub_comm_dx.get_ubuf_output(1) + dx_local = ub_comm_dx.get_ubuf_output(0) else: # Allocate grad input tensor if with_fp8_grad_input: @@ -225,31 +356,34 @@ def _functional_backward_dgrad( ) dx_local = dx + # Allocate grad input tensor + if grad_weight is None: + if accumulate_into_grad_weight: + raise ValueError( + "Attempted to accumulate into grad weight buffer" + "without providing grad weight" + ) + grad_weight = torch.empty( + weight_dims, + dtype=dtype, + device=device, + memory_format=torch.contiguous_format, + ) + # Perform dgrad GEMM if with_fp8_compute: kwargs = dict( out=dx, - use_split_accumulator=False, ### TODO ? 
- ub_algo=ub_algo, - ub=ub_comm, + use_split_accumulator=True, ) - if with_ub_reduce_scatter: - kwargs["extra_output_tensor"] = dx_local + if with_ub_all_gather_dy: + kwargs["ub_algo"] = ub_algo_dy + kwargs["ub"] = ub_comm_dy + elif with_ub_all_gather_x: + kwargs["ub_algo"] = ub_algo_x + kwargs["ub"] = ub_comm_x if with_fp8_grad_input: - if dx._fp8_meta is None: - # Hackily create FP8TensorMeta if needed - fp8_meta = FP8TensorMeta() - fp8_meta.scale = dx._scale_inv.reciprocal() - fp8_meta.amax_history = torch.empty(1, 1, dtype=torch.float32, device=device) - fp8_meta.scale_inv = dx._scale_inv - fp8_meta_index = 0 - else: - # Get FP8TensorMeta from Float8Tensor - fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( - forward=dx._fp8_meta_forward, - ) - fp8_meta = dx._fp8_meta[fp8_meta_key] - fp8_meta_index = dx._fp8_meta_index + fp8_meta, fp8_meta_index = get_fp8_meta_from_fp8_tensor(dx) kwargs.update( dict( out=dx._data, @@ -273,135 +407,53 @@ def _functional_backward_dgrad( ) else: kwargs = dict( + grad=True, layout="NN", out=dx, - ub_algo=ub_algo, - ub=ub_comm, ) - if with_ub_reduce_scatter: - kwargs["extra_output_tensor"] = dx_local + if with_ub_all_gather_dy: + kwargs["ub_algo"] = ub_algo_dy + kwargs["ub"] = ub_comm_dy + elif with_ub_all_gather_x: + kwargs["ub_algo"] = ub_algo_x + kwargs["ub"] = ub_comm_x gemm(w, dy, dx.dtype, get_workspace(), **kwargs) - return dx_local, dy - - @staticmethod - def _functional_backward( - grad_output: torch.Tensor, - input: Optional[torch.Tensor], # pylint: disable=redefined-builtin - weight: Optional[torch.Tensor], - input_dims: Iterable[int], - weight_dims: Iterable[int], - *, - weight_requires_grad: bool = True, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - grad_weight: Optional[torch.Tensor] = None, - accumulate_into_grad_weight: bool = False, - grad_bias: Optional[torch.Tensor] = None, - tensor_parallel_mode: Optional[str] = None, - tensor_parallel_group: Optional[torch.distributed.ProcessGroup] = None, - tensor_parallel_size: Optional[int] = None, - sequence_parallel: bool = False, - with_fp8_compute: bool = False, - weight_fp8_meta: Optional[dict[str, Any]] = None, - grad_output_fp8_meta: Optional[dict[str, Any]] = None, - grad_input_fp8_meta: Optional[dict[str, Any]] = None, - ub_comm_name: str, - ): - - # Check device - if device is None: - device = weight.device - device = canonicalize_device(device) - if device.type != "cuda": - raise ValueError(f"Only CUDA devices are supported (got {device})") - - # Check datatype - if dtype is None: - dtype = weight.dtype - dtype = canonicalize_dtype(dtype) - if dtype not in (torch.float32, torch.float16, torch.bfloat16): - raise ValueError(f"Supported dtypes are float32, float16, bfloat16 (got {dtype})") - - # Check input tensor - output_dims = tuple(grad_output.size()) - input_dims = tuple(input_dims) - weight_dims = tuple(weight_dims) - if len(weight_dims) != 2: - raise ValueError(f"Weight tensor is not 2D (shape={weight_dims})") - if len(input_dims) == 0 or weight_dims[1] != input_dims[-1]: - raise ValueError( - f"Input tensor (shape={input_dims}) " - f"and weight tensor (shape={weight_dims}) " - "are not compatible" - ) - if weight_dims[0] != output_dims[-1]: - raise ValueError( - f"Grad output tensor (shape={output_dims}) " - f"and weight tensor (shape={weight_dims}) " - "are not compatible" - ) - - # Check tensor parallel group - if tensor_parallel_size is None: - tensor_parallel_size = get_distributed_world_size(tensor_parallel_group) - if 
tensor_parallel_size == 1: - tensor_parallel_mode = None - if tensor_parallel_mode not in ("column", "row"): - raise RuntimeError( - "Invalid configuration for Userbuffers " - f"({tensor_parallel_size=}, {tensor_parallel_mode=})" + # Perform wgrad GEMM + if with_fp8_compute: + kwargs = dict( + accumulate=accumulate_into_grad_weight, + out=grad_weight, + use_split_accumulator=True, ) - if not sequence_parallel: - raise RuntimeError( - f"Invalid configuration for Userbuffers ({sequence_parallel=})" + if with_ub_reduce_scatter_dx: + kwargs["ub_algo"] = ub_algo_dx + kwargs["ub"] = ub_comm_dx + fp8_gemm( + x.transpose_2d(), + x._scale_inv, + 0, + x._fp8_dtype, + dy.transpose_2d(), + dy._scale_inv, + 0, + dy._fp8_dtype, + grad_weight.dtype, + get_workspace(), + **kwargs, ) - - # Check if FP8 is enabled - if with_fp8_compute: - if grad_output_fp8_meta is None and not is_float8_tensor(grad_output): - raise ValueError("No FP8 metadata was provided for casting output gradient to FP8") else: - weight_fp8_meta = None - grad_output_fp8_meta = None - grad_input_fp8_meta = None - with_fp8_grad_input = ( - with_fp8_compute - and tensor_parallel_mode != "column" - and grad_input_fp8_meta is not None - ) + kwargs = dict( + accumulate=accumulate_into_grad_weight, + layout="NT", + out=grad_weight, + ) + if with_ub_reduce_scatter_dx: + kwargs["ub_algo"] = ub_algo_dx + kwargs["ub"] = ub_comm_dx + gemm(x, dy, grad_weight.dtype, get_workspace(), **kwargs) - # Perform dgrad GEMM - dy_local = reshape( - grad_output, - (-1, output_dims[-1]), - device=device, - dtype=dtype, - ) - dx_local, dy = UserbuffersBackwardLinear._functional_backward_dgrad( - grad_output_local=dy_local, - weight=weight, - weight_requires_grad=weight_requires_grad, - device=device, - dtype=dtype, - tensor_parallel_mode=tensor_parallel_mode, - tensor_parallel_group=tensor_parallel_group, - tensor_parallel_size=tensor_parallel_size, - with_fp8_compute=with_fp8_compute, - with_fp8_grad_input=with_fp8_grad_input, - weight_fp8_meta=weight_fp8_meta, - grad_output_fp8_meta=grad_output_fp8_meta, - grad_input_fp8_meta=grad_output_fp8_meta, - ub_comm_name=ub_comm_name, - ) grad_input = reshape(dx_local, input_dims) - - # Perform wgrad GEMM - if not weight_requires_grad: - grad_weight = None - else: - raise NotImplementedError() ### TODO Implement - return grad_input, grad_weight def fuser_backward( @@ -452,6 +504,7 @@ def fuser_backward( weight=linear_op.weight, input_dims=linear_op_ctx.input_dims, weight_dims=linear_op.weight.size(), + input_requires_grad=linear_op_ctx.input_requires_grad, weight_requires_grad=linear_op_ctx.weight_requires_grad, device=linear_op.device, dtype=linear_op.dtype, @@ -467,8 +520,6 @@ def fuser_backward( grad_input_fp8_meta=linear_op_ctx.grad_input_fp8_meta, ub_comm_name=linear_op._userbuffers_options["comm_name"], ) - if accumulate_into_main_grad: - grad_weight = None # Clear input tensor if possible if linear_op_ctx.has_prev_op: @@ -476,6 +527,8 @@ def fuser_backward( # Return gradients grad_params = [() for _ in range(len(self.basic_ops))] + if accumulate_into_main_grad: + grad_weight = None grad_params[self._op_idxs["linear"]] = (grad_weight,) if bias_op is not None: grad_params[self._op_idxs["bias"]] = (grad_bias,) @@ -487,16 +540,25 @@ def fuse_userbuffers_backward_linear( ops: list[tuple[FusibleOperation, list[int]]], ) -> list[tuple[FusibleOperation, list[int]]]: + # Return immediately if environment is not distributed + if ( + not torch.distributed.is_initialized() + or torch.distributed.get_world_size() == 1 + ): + 
return ops + # Sliding window in list of ops window = [] def peek_next_op() -> Optional[FusibleOperation]: + """Get next op in list of ops""" nonlocal ops if not ops: return None return ops[-1][0] def pop_next_op() -> FusibleOperation: + """Remove next op from list of ops and add to sliding window""" nonlocal ops, window window.insert(0, ops[-1]) ops = ops[:-1] @@ -530,11 +592,15 @@ def pop_next_op() -> FusibleOperation: if reduce_scatter is None: if linear.tensor_parallel_mode is None: continue + if linear.tensor_parallel_size == 1: + continue if linear.tensor_parallel_mode == "row" and bias is not None: continue else: if linear.tensor_parallel_mode is not None: continue + if reduce_scatter.process_group_size == 1: + continue # Replace window with fused op op = UserbuffersBackwardLinear( diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py index c81b4dbe99..394802317c 100644 --- a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py +++ b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py @@ -344,20 +344,7 @@ def _functional_forward( if with_ub_reduce_scatter: kwargs["extra_output_tensor"] = y_local if with_fp8_output: - if y._fp8_meta is None: - # Hackily create FP8TensorMeta if needed - fp8_meta = FP8TensorMeta() - fp8_meta.scale = y._scale_inv.reciprocal() - fp8_meta.amax_history = torch.empty(1, 1, dtype=torch.float32, device=device) - fp8_meta.scale_inv = y._scale_inv - fp8_meta_index = 0 - else: - # Get FP8TensorMeta from Float8Tensor - fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( - forward=y._fp8_meta_forward, - ) - fp8_meta = y._fp8_meta[fp8_meta_key] - fp8_meta_index = y._fp8_meta_index + fp8_meta, fp8_meta_index = get_fp8_meta_from_fp8_tensor(y) kwargs.update( dict( out=y._data, @@ -480,16 +467,25 @@ def fuse_userbuffers_forward_linear( ops: list[tuple[FusibleOperation, list[int]]], ) -> list[tuple[FusibleOperation, list[int]]]: + # Return immediately if environment is not distributed + if ( + not torch.distributed.is_initialized() + or torch.distributed.get_world_size() == 1 + ): + return ops + # Sliding window in list of ops window = [] def peek_next_op() -> Optional[FusibleOperation]: + """Get next op in list of ops""" nonlocal ops if not ops: return None return ops[0][0] def pop_next_op() -> FusibleOperation: + """Remove next op from list of ops and add to sliding window""" nonlocal ops, window window.append(ops[0]) ops = ops[1:] @@ -523,11 +519,15 @@ def pop_next_op() -> FusibleOperation: if reduce_scatter is None: if linear.tensor_parallel_mode is None: continue + if linear.tensor_parallel_size == 1: + continue if linear.tensor_parallel_mode == "row" and bias is not None: continue else: if linear.tensor_parallel_mode is not None: continue + if reduce_scatter.process_group_size == 1: + continue # Replace window with fused op op = UserbuffersForwardLinear( From bd5c61e58685bae365a1fadaf49c65f664a95b23 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Sat, 24 Aug 2024 02:23:37 +0000 Subject: [PATCH 08/13] Add support for grad bias Signed-off-by: Tim Moon --- .../test_fusible_ops_with_userbuffers.py | 6 +- .../ops/fused/userbuffers_backward_linear.py | 81 ++++++++++++++++--- 2 files changed, 73 insertions(+), 14 deletions(-) diff --git a/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py b/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py index 776eeb5b87..c0b80e3bb6 100644 --- 
a/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py +++ b/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py @@ -340,9 +340,9 @@ def _test_linear( torch.testing.assert_close(y_test, y_ref, **tols) torch.testing.assert_close(dx_test, dx_ref, **tols) torch.testing.assert_close(dw_test, dw_ref, **tols) - # if bias: ### TODO Restore - # db_test = bias_op.bias.grad.to(dtype=torch.float64, device="cpu") - # torch.testing.assert_close(db_test, db_ref, **tols) + if bias: + db_test = bias_op.bias.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(db_test, db_ref, **tols) def run_parallel_tests(model_config: ModelConfig) -> None: diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py index 71b07fd043..8deaa80074 100644 --- a/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py +++ b/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py @@ -89,11 +89,11 @@ def _functional_backward( *, input_requires_grad: bool = True, weight_requires_grad: bool = True, + bias_requires_grad: bool = False, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, grad_weight: Optional[torch.Tensor] = None, accumulate_into_grad_weight: bool = False, - grad_bias: Optional[torch.Tensor] = None, tensor_parallel_mode: Optional[str] = None, tensor_parallel_group: Optional[torch.distributed.ProcessGroup] = None, tensor_parallel_size: Optional[int] = None, @@ -108,6 +108,9 @@ def _functional_backward( assert input_requires_grad ### TODO Support optional assert weight_requires_grad ### TODO Support optional + # Configuration-specific outputs + extra_outputs = {} + # Check device if device is None: device = weight.device @@ -196,12 +199,25 @@ def _functional_backward( ub_algo_x = UbufOverlapAlgo.BULK_OVERLAP_AG # Check grad output tensor + # Note: Possibly fuse cast with computing grad bias dy_local = reshape( grad_output, (-1, output_dims[-1]), device=device, dtype=dtype, ) + db = None + db_async = None + if bias_requires_grad and with_fp8_compute and with_ub_all_gather_dy: + # We don't have a grad bias impl that takes FP8 input. For + # cases where we cast to FP8 and all-gather, it's better + # to compute the grad bias on ungathered, non-FP8 values. + db = dy_local.sum(dim=0) + db_async = torch.distributed.all_reduce( + db, + group=tensor_parallel_group, + async_op=True, + ) if with_fp8_compute and not is_float8_tensor(dy_local): fp8_dtype = get_fp8_te_dtype( grad_output_fp8_meta["recipe"], @@ -220,10 +236,20 @@ def _functional_backward( fp8_scale_inv=torch.empty([1], dtype=torch.float32, device=device), dtype=dtype, ) - if with_ub_all_gather_dy: - dy_fp8.copy_(dy_local) - else: + if bias_requires_grad and db is None: + ### TODO Fused cast-transpose-dbias dy_fp8.cast_transpose_(dy_local) + db = dy_local.sum(dim=0) + if with_ub_all_gather_dy: + db_async = torch.distributed.all_reduce( + db, + group=tensor_parallel_group, + async_op=True, + ) + elif not with_ub_all_gather_dy: + dy_fp8.cast_transpose_(dy_local) + else: + dy_fp8.copy_(dy_local) dy_local = dy_fp8 elif not with_fp8_compute and is_float8_tensor(dy_local): if with_ub_all_gather_dy: @@ -232,6 +258,23 @@ def _functional_backward( else: dy_local = dy_local.from_float8() + if ( + bias_requires_grad + and db is None + and with_fp8_compute + and with_ub_all_gather_dy + ): + # We don't have a fused grad bias impl that takes FP8 + # input. 
For cases where we cast to FP8 and all-gather, + # it's better to compute the grad bias on ungathered, + # non-FP8 values. + db = dy_local.sum(dim=0) + db_async = torch.distributed.all_reduce( + db, + group=tensor_parallel_group, + async_op=True, + ) + # Check input tensor x_local = reshape( input, @@ -418,6 +461,7 @@ def _functional_backward( kwargs["ub_algo"] = ub_algo_x kwargs["ub"] = ub_comm_x gemm(w, dy, dx.dtype, get_workspace(), **kwargs) + grad_input = reshape(dx_local, input_dims) # Perform wgrad GEMM if with_fp8_compute: @@ -446,15 +490,30 @@ def _functional_backward( kwargs = dict( accumulate=accumulate_into_grad_weight, layout="NT", + grad=True, + use_bias=bias_requires_grad, out=grad_weight, ) if with_ub_reduce_scatter_dx: kwargs["ub_algo"] = ub_algo_dx kwargs["ub"] = ub_comm_dx - gemm(x, dy, grad_weight.dtype, get_workspace(), **kwargs) + grad_weight, db, _ = gemm( + x, + dy, + grad_weight.dtype, + get_workspace(), + **kwargs, + ) - grad_input = reshape(dx_local, input_dims) - return grad_input, grad_weight + # Compute grad bias if needed + if db_async is not None: + db_async.wait() + if bias_requires_grad: + if db is None: + db = dy.sum(dim=0) + extra_outputs["bias"] = db + + return grad_input, grad_weight, extra_outputs def fuser_backward( self, @@ -497,8 +556,7 @@ def fuser_backward( accumulate_into_main_grad = False # Linear backward pass - grad_bias = None ### TODO Implement - grad_input, grad_weight = UserbuffersBackwardLinear._functional_backward( + retval = UserbuffersBackwardLinear._functional_backward( grad_output=grad_output, input=x_local, weight=linear_op.weight, @@ -506,11 +564,11 @@ def fuser_backward( weight_dims=linear_op.weight.size(), input_requires_grad=linear_op_ctx.input_requires_grad, weight_requires_grad=linear_op_ctx.weight_requires_grad, + bias_requires_grad=(bias_op is not None), device=linear_op.device, dtype=linear_op.dtype, grad_weight=grad_weight, accumulate_into_grad_weight=accumulate_into_main_grad, - grad_bias=grad_bias, tensor_parallel_mode=self.tensor_parallel_mode, tensor_parallel_group=self.tensor_parallel_group, sequence_parallel=self.sequence_parallel, @@ -520,6 +578,7 @@ def fuser_backward( grad_input_fp8_meta=linear_op_ctx.grad_input_fp8_meta, ub_comm_name=linear_op._userbuffers_options["comm_name"], ) + grad_input, grad_weight, extra_outputs = retval # Clear input tensor if possible if linear_op_ctx.has_prev_op: @@ -531,7 +590,7 @@ def fuser_backward( grad_weight = None grad_params[self._op_idxs["linear"]] = (grad_weight,) if bias_op is not None: - grad_params[self._op_idxs["bias"]] = (grad_bias,) + grad_params[self._op_idxs["bias"]] = (extra_outputs["bias"],) grad_extra_inputs = [() for _ in range(len(self.basic_ops))] return grad_input, grad_params, grad_extra_inputs From d5f8a8b60dc9dbed158821cc67cf5e7d55e461a6 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Tue, 27 Aug 2024 00:04:49 +0000 Subject: [PATCH 09/13] Fused cast-transpose-dbias Signed-off-by: Tim Moon --- .../ops/fused/userbuffers_backward_linear.py | 26 ++++++++++++++++--- 1 file changed, 22 insertions(+), 4 deletions(-) diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py index 8deaa80074..5b6079f353 100644 --- a/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py +++ b/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py @@ -10,7 +10,12 @@ import torch -from ...cpp_extensions import UbufOverlapAlgo, fp8_gemm, gemm +from ...cpp_extensions 
import ( + UbufOverlapAlgo, + fp8_cast_transpose_bgrad_fused, + fp8_gemm, + gemm, +) from ...distributed import get_distributed_world_size from ...float8_tensor import Float8Tensor from ...fp8 import FP8GlobalStateManager, get_fp8_te_dtype @@ -237,15 +242,28 @@ def _functional_backward( dtype=dtype, ) if bias_requires_grad and db is None: - ### TODO Fused cast-transpose-dbias - dy_fp8.cast_transpose_(dy_local) - db = dy_local.sum(dim=0) + # Fused cast-transpose-bgrad + fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( + forward=dy_fp8._fp8_meta_forward, + ) + db, data, data_transpose = fp8_cast_transpose_bgrad_fused( + dy_local, + dy_fp8._fp8_meta[fp8_meta_key], + dy_fp8._fp8_meta_index, + dy_fp8._fp8_dtype, + scale_inv=dy_fp8._scale_inv, + ) if with_ub_all_gather_dy: + dy_fp8._data.copy_(data) db_async = torch.distributed.all_reduce( db, group=tensor_parallel_group, async_op=True, ) + else: + dy_fp8._data = data + dy_fp8._transpose = data_transpose + dy_fp8._transpose_invalid = False elif not with_ub_all_gather_dy: dy_fp8.cast_transpose_(dy_local) else: From cd0db1ccc16c46ac22ee66b1d4305cc7a0828fdb Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Tue, 27 Aug 2024 20:57:05 +0000 Subject: [PATCH 10/13] Support case where wgrad is optional Signed-off-by: Tim Moon --- .../test_fusible_ops_with_userbuffers.py | 28 +++- .../ops/fused/userbuffers_backward_linear.py | 133 ++++++++++++------ .../ops/fused/userbuffers_forward_linear.py | 19 +-- 3 files changed, 118 insertions(+), 62 deletions(-) diff --git a/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py b/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py index c0b80e3bb6..2d0083baa1 100644 --- a/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py +++ b/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py @@ -153,6 +153,7 @@ def _test_linear( device: torch.device = "cuda", tensor_parallel_mode: str = "column", sequence_parallel: bool = True, + weight_requires_grad: bool = True, ) -> None: dtype = model_config.dtype fp8_compute = model_config.fp8 @@ -270,7 +271,18 @@ def _test_linear( linear_op = None bias_op = None if tensor_parallel_mode == "column": - userbuffers_options = dict(comm_name="qkv") + userbuffers_options = {} + if not weight_requires_grad: + if fp8_compute: + userbuffers_options["comm_name"] = "fc1" + else: + # There is a correctness bug with overlapping + # dgrad reduce-scatter with dgrad GEMM. Fall back + # to overlapping dgrad reduce-scatter with wgrad + # GEMM, even though wgrad isn't needed. 
+ userbuffers_options["comm_name"] = "qkv" + else: + userbuffers_options["comm_name"] = "qkv" linear_op = te_ops.BasicLinear( in_features, out_features, @@ -306,6 +318,7 @@ def _test_linear( model = te_ops.Sequential(*ops) with torch.no_grad(): linear_op.weight.copy_(w_test) + linear_op.weight.requires_grad_(requires_grad=weight_requires_grad) if bias: bias_op.bias.copy_(b_test) del w_test @@ -336,10 +349,11 @@ def _test_linear( # Check results y_test = y_test.to(dtype=torch.float64, device="cpu") dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") - dw_test = linear_op.weight.grad.to(dtype=torch.float64, device="cpu") torch.testing.assert_close(y_test, y_ref, **tols) torch.testing.assert_close(dx_test, dx_ref, **tols) - torch.testing.assert_close(dw_test, dw_ref, **tols) + if weight_requires_grad: + dw_test = linear_op.weight.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(dw_test, dw_ref, **tols) if bias: db_test = bias_op.bias.grad.to(dtype=torch.float64, device="cpu") torch.testing.assert_close(db_test, db_ref, **tols) @@ -357,14 +371,16 @@ def run_parallel_tests(model_config: ModelConfig) -> None: for test_config in itertools.product( (False, True), # bias ("column", "row"), # tensor_parallel_mode + (False, True), # weight_requires_grad ): if rank == 0: print(f"Running _test_linear with {test_config=}") - bias, tensor_parallel_mode = test_config + bias, tensor_parallel_mode, weight_requires_grad = test_config _test_linear( model_config=model_config, bias=bias, tensor_parallel_mode=tensor_parallel_mode, + weight_requires_grad=weight_requires_grad, ) @@ -456,6 +472,9 @@ def main() -> None: # Initialize Userbuffers group = world_group() # Initialize NCCL bootstrap_backend = "mpi" if launcher() == "ompi" else "nccl" + userbuffer_configs = { + "fc1_dgrad": {"method": "pipeline"}, # Overlap dgrad RS with dgrad GEMM + } te.module.base.initialize_ub( [ model_config.sequence_length * model_config.batch_size, @@ -465,6 +484,7 @@ def main() -> None: use_fp8=model_config.fp8, dtype=model_config.dtype, bootstrap_backend=bootstrap_backend, + ub_cfgs=userbuffer_configs, ) # Run tests diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py index 5b6079f353..f9acee221d 100644 --- a/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py +++ b/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py @@ -7,6 +7,7 @@ from __future__ import annotations from collections.abc import Iterable from typing import Any, Optional +import warnings import torch @@ -92,7 +93,6 @@ def _functional_backward( input_dims: Iterable[int], weight_dims: Iterable[int], *, - input_requires_grad: bool = True, weight_requires_grad: bool = True, bias_requires_grad: bool = False, device: Optional[torch.device] = None, @@ -110,9 +110,6 @@ def _functional_backward( ub_comm_name: str, ): - assert input_requires_grad ### TODO Support optional - assert weight_requires_grad ### TODO Support optional - # Configuration-specific outputs extra_outputs = {} @@ -179,6 +176,10 @@ def _functional_backward( ) # Get Userbuffers communicators and algorithms + # Note: communication patterns are (1) overlap dy all-gather + # with dgrad GEMM, (2) overlap x all-gather with dgrad GEMM + # and dx reduce-scatter with wgrad GEMM, (3) overlap dx + # reduce-scatter with dgrad GEMM. 
with_ub_all_gather_dy = False with_ub_reduce_scatter_dx = False with_ub_all_gather_x = False @@ -197,11 +198,22 @@ def _functional_backward( ub_algo_dy = UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P elif tensor_parallel_mode == "column": with_ub_reduce_scatter_dx = True - with_ub_all_gather_x = True - ub_comm_dx = get_ub(ub_comm_name + "_wgrad") - ub_comm_x = get_ub(ub_comm_name + "_dgrad") - ub_algo_dx = UbufOverlapAlgo.BULK_OVERLAP_RS - ub_algo_x = UbufOverlapAlgo.BULK_OVERLAP_AG + if weight_requires_grad: + with_ub_all_gather_x = True + ub_comm_dx = get_ub(ub_comm_name + "_wgrad") + ub_comm_x = get_ub(ub_comm_name + "_dgrad") + ub_algo_dx = UbufOverlapAlgo.BULK_OVERLAP_RS + ub_algo_x = UbufOverlapAlgo.BULK_OVERLAP_AG + else: + with_ub_all_gather_x = False + ub_comm_dx = get_ub(ub_comm_name + "_dgrad") + is_atomic_gemm = with_fp8_compute and ub_comm_dx.is_atomic_gemm() + ub_algo_dx = { + (True, True): UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P, + (True, False): UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P, + (False, True): UbufOverlapAlgo.ATOMIC_GEMM_RS, + (False, False): UbufOverlapAlgo.SPLIT_PIPELINED_RS, + }[(ub_comm_dx.is_p2p_overlap(), is_atomic_gemm)] # Check grad output tensor # Note: Possibly fuse cast with computing grad bias @@ -294,41 +306,43 @@ def _functional_backward( ) # Check input tensor - x_local = reshape( - input, - (-1, input_dims[-1]), - device=device, - dtype=dtype, - ) - if with_fp8_compute and not is_float8_tensor(x_local): - fp8_dtype = get_fp8_te_dtype( - input_fp8_meta["recipe"], - fprop_tensor=True, - ) - if with_ub_all_gather_x: - data = ub_comm_x.get_ubuf_output(0) - else: - data = torch.empty_like(x_local, dtype=torch.uint8) - x_fp8 = Float8Tensor( - data=data, - fp8_meta=input_fp8_meta, - fp8_meta_forward=True, - fp8_meta_index=0, - fp8_dtype=fp8_dtype, - fp8_scale_inv=torch.empty([1], dtype=torch.float32, device=device), + x_local = None + if weight_requires_grad: + x_local = reshape( + input, + (-1, input_dims[-1]), + device=device, dtype=dtype, ) - if with_ub_all_gather_x: - x_fp8.copy_(x_local) - else: - x_fp8.cast_transpose_(x_local) - x_local = x_fp8 - elif not with_fp8_compute and is_float8_tensor(x_local): - if with_ub_all_gather_x: - ub_local_buffer = ub_comm_x.get_ubuf_output(0) - x_local = ub_local_buffer.copy_(x_local) - else: - x_local = x_local.from_float8() + if with_fp8_compute and not is_float8_tensor(x_local): + fp8_dtype = get_fp8_te_dtype( + input_fp8_meta["recipe"], + fprop_tensor=True, + ) + if with_ub_all_gather_x: + data = ub_comm_x.get_ubuf_output(0) + else: + data = torch.empty_like(x_local, dtype=torch.uint8) + x_fp8 = Float8Tensor( + data=data, + fp8_meta=input_fp8_meta, + fp8_meta_forward=True, + fp8_meta_index=0, + fp8_dtype=fp8_dtype, + fp8_scale_inv=torch.empty([1], dtype=torch.float32, device=device), + dtype=dtype, + ) + if with_ub_all_gather_x: + x_fp8.copy_(x_local) + else: + x_fp8.cast_transpose_(x_local) + x_local = x_fp8 + elif not with_fp8_compute and is_float8_tensor(x_local): + if with_ub_all_gather_x: + ub_local_buffer = ub_comm_x.get_ubuf_output(0) + x_local = ub_local_buffer.copy_(x_local) + else: + x_local = x_local.from_float8() # Check weight tensor w = convert_tensor( @@ -388,7 +402,11 @@ def _functional_backward( if with_ub_reduce_scatter_dx: # Initialize buffers for UB reduce-scatter dx = ub_comm_dx.get_ubuf_output(1) - dx_local = ub_comm_dx.get_ubuf_output(0) + ub_local_buffer = ub_comm_dx.get_ubuf_output(0) + if with_ub_all_gather_x: + dx_local = ub_local_buffer + else: + dx_local = torch.empty_like(ub_local_buffer) else: 
# Allocate grad input tensor if with_fp8_grad_input: @@ -443,6 +461,10 @@ def _functional_backward( elif with_ub_all_gather_x: kwargs["ub_algo"] = ub_algo_x kwargs["ub"] = ub_comm_x + elif with_ub_reduce_scatter_dx: + kwargs["ub_algo"] = ub_algo_dx + kwargs["ub"] = ub_comm_dx + kwargs["extra_output_tensor"] = dx_local if with_fp8_grad_input: fp8_meta, fp8_meta_index = get_fp8_meta_from_fp8_tensor(dx) kwargs.update( @@ -478,11 +500,17 @@ def _functional_backward( elif with_ub_all_gather_x: kwargs["ub_algo"] = ub_algo_x kwargs["ub"] = ub_comm_x + elif with_ub_reduce_scatter_dx: + kwargs["ub_algo"] = ub_algo_dx + kwargs["ub"] = ub_comm_dx + kwargs["extra_output_tensor"] = dx_local gemm(w, dy, dx.dtype, get_workspace(), **kwargs) grad_input = reshape(dx_local, input_dims) # Perform wgrad GEMM - if with_fp8_compute: + if not weight_requires_grad: + pass + elif with_fp8_compute: kwargs = dict( accumulate=accumulate_into_grad_weight, out=grad_weight, @@ -573,6 +601,20 @@ def fuser_backward( else: accumulate_into_main_grad = False + # Hackily workaround Userbuffers bug with non-FP8 dgrad + # reduce-scatter overlap + weight_requires_grad = linear_op_ctx.weight_requires_grad + if not linear_op_ctx.with_fp8_compute and not weight_requires_grad: + warnings.warn( + "There is a correctness bug when using Userbuffers " + "to overlap a dgrad reduce-scatter with a non-FP8 dgrad GEMM. " + "Hackily working around by overlapping dgrad reduce-scatter " + "with wgrad GEMM, even though wgrad isn't needed. " + "Please contact Transformer Engine team " + "if you encounter this use-case." + ) + weight_requires_grad = True + # Linear backward pass retval = UserbuffersBackwardLinear._functional_backward( grad_output=grad_output, @@ -580,8 +622,7 @@ def fuser_backward( weight=linear_op.weight, input_dims=linear_op_ctx.input_dims, weight_dims=linear_op.weight.size(), - input_requires_grad=linear_op_ctx.input_requires_grad, - weight_requires_grad=linear_op_ctx.weight_requires_grad, + weight_requires_grad=weight_requires_grad, bias_requires_grad=(bias_op is not None), device=linear_op.device, dtype=linear_op.dtype, diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py index 394802317c..d0f56c8880 100644 --- a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py +++ b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py @@ -172,18 +172,13 @@ def _functional_forward( else: ub_algo = UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P elif with_ub_reduce_scatter: - if with_fp8_compute: - ub_algo = { - (True, True): UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P, - (True, False): UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P, - (False, True): UbufOverlapAlgo.ATOMIC_GEMM_RS, - (False, False): UbufOverlapAlgo.SPLIT_PIPELINED_RS, - }[(ub_comm.is_p2p_overlap(), ub_comm.is_atomic_gemm())] - else: - if ub_comm.is_p2p_overlap(): - ub_algo = UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P - else: - ub_algo = UbufOverlapAlgo.SPLIT_PIPELINED_RS + is_atomic_gemm = with_fp8_compute and ub_comm.is_atomic_gemm() + ub_algo = { + (True, True): UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P, + (True, False): UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P, + (False, True): UbufOverlapAlgo.ATOMIC_GEMM_RS, + (False, False): UbufOverlapAlgo.SPLIT_PIPELINED_RS, + }[(ub_comm.is_p2p_overlap(), is_atomic_gemm)] else: raise RuntimeError("Could not choose Userbuffers communication algorithm") From 6209910e36f9f2810c79c6a3835cd1e97dc9be3e Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Tue, 27 
Aug 2024 22:10:04 +0000 Subject: [PATCH 11/13] Expand documentation Signed-off-by: Tim Moon --- qa/L1_pytorch_distributed_unittest/test.sh | 1 + .../test_fusible_ops_with_userbuffers.py | 11 +- .../pytorch/ops/basic/basic_linear.py | 2 +- .../ops/fused/userbuffers_backward_linear.py | 100 +++++++++++++++++- .../ops/fused/userbuffers_forward_linear.py | 80 +++++++++++++- 5 files changed, 181 insertions(+), 13 deletions(-) diff --git a/qa/L1_pytorch_distributed_unittest/test.sh b/qa/L1_pytorch_distributed_unittest/test.sh index fef48fd4b0..720ff1f9ca 100644 --- a/qa/L1_pytorch_distributed_unittest/test.sh +++ b/qa/L1_pytorch_distributed_unittest/test.sh @@ -6,6 +6,7 @@ set -e : ${TE_PATH:=/opt/transformerengine} pytest -v -s $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py +pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py git clone https://github.com/NVIDIA/Megatron-LM.git cd Megatron-LM diff --git a/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py b/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py index 2d0083baa1..68cb56b3a7 100644 --- a/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py +++ b/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py @@ -44,6 +44,8 @@ @dataclasses.dataclass class ModelConfig: + """Tensor dimensions in Transformer model""" + sequence_length: int batch_size: int num_heads: int @@ -51,10 +53,6 @@ class ModelConfig: dtype: torch.dtype fp8: bool - @property - def sequence_and_batch_size(self): - return self.sequence_length * self.batch_size - @property def hidden_size(self): return self.num_heads * self.head_dim @@ -62,6 +60,7 @@ def hidden_size(self): @functools.cache def launcher() -> str: + """Launcher for current parallel job""" if "OMPI_COMM_WORLD_SIZE" in os.environ: return "ompi" if "TORCHELASTIC_RUN_ID" in os.environ: @@ -166,7 +165,7 @@ def _test_linear( # Tensor dimensions out_features = model_config.hidden_size in_features = model_config.hidden_size - batch_size = model_config.sequence_and_batch_size + batch_size = model_config.sequence_length * model_config.batch_size in_shape = [batch_size, in_features] out_shape = [batch_size, out_features] @@ -398,7 +397,7 @@ def test_fuser_ops_with_userbuffers( dtype: torch.dtype = torch.bfloat16, fp8: bool, ) -> None: - """Launch parallel job that runs parallel tests""" + """Launch parallel job and run tests""" # Skip invalid configurations if fp8 and not fp8_available: diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index a85d133afa..72e5dde7f7 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -712,7 +712,7 @@ def _functional_backward( FP8 metadata for casting loss gradient w.r.t. output tensor to FP8. Required if output grad is not already in FP8. - grad_output_fp8_meta: dict, optional + grad_input_fp8_meta: dict, optional FP8 metadata for casting loss gradient w.r.t. 
input tensor to FP8 diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py index f9acee221d..391ca28538 100644 --- a/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py +++ b/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py @@ -38,6 +38,13 @@ ) class UserbuffersBackwardLinear(FusedOperation): + """Linear backward implementation using Userbuffers + + This operation is equivalent to a linear operation's backward + pass, but it uses Userbuffers to overlap tensor-parallel + communication with compute. + + """ def __init__( self, @@ -104,11 +111,80 @@ def _functional_backward( tensor_parallel_size: Optional[int] = None, sequence_parallel: bool = False, with_fp8_compute: bool = False, + input_fp8_meta: Optional[dict[str, Any]] = None, weight_fp8_meta: Optional[dict[str, Any]] = None, grad_output_fp8_meta: Optional[dict[str, Any]] = None, grad_input_fp8_meta: Optional[dict[str, Any]] = None, ub_comm_name: str, - ): + ) -> tuple[torch.Tensor, Optional[torch.Tensor], dict]: + """Functional API for backward pass + + Parameters + ---------- + grad_output: torch.Tensor + Loss gradient w.r.t. output tensor + input: torch.Tensor, optional + Input tensor. Required to compute loss gradient w.r.t. + weight. + weight: torch.Tensor, optional + Weight tensor. Required to compute loss gradient w.r.t. + input. + input_dims: iterable of int + Input tensor dimensions + weight_dims: iterable of int + Weight tensor dimensions + weight_requires_grad: bool + Whether to compute loss gradient w.r.t. weight tensor + bias_requires_grad: bool + Whether to compute loss gradient w.r.t. bias tensor + device: torch.device, default = default CUDA device + Tensor device + dtype: torch.dtype, default = default dtype + Tensor datatype + grad_weight: torch.Tensor, optional + Loss gradient w.r.t. weight tensor + accumulate_into_grad_weight: bool, default = `False` + Add result to weight grad instead of overwriting + tensor_parallel_mode: {`None`, "column", "row"}, default = `None` + Mode for tensor parallelism + tensor_parallel_group: torch.distributed.ProcessGroup, default = world group + Process group for tensor parallelism + sequence_parallel: bool, default = `False` + Whether to apply sequence parallelism together with tensor + parallelism, i.e. distributing input or output tensors + along outer dimension (sequence or batch dim) when not + distributing along inner dimension (embedding dim) + with_fp8_compute: bool, default = `False` + Whether to perform compute in FP8 + input_fp8_meta: dict, optional + FP8 metadata for casting input tensor to FP8. Required for + FP8 compute if input is not already in FP8. + weight_fp8_meta: dict, optional + FP8 metadata for casting weight tensor to FP8. Required for + FP8 compute if weight is not already in FP8. + grad_output_fp8_meta: dict, optional + FP8 metadata for casting loss gradient w.r.t. output + tensor to FP8. Required if output grad is not already in + FP8. + grad_input_fp8_meta: dict, optional + FP8 metadata for casting loss gradient w.r.t. input + tensor to FP8 + ub_comm_name: str + Layer type (e.g. "qkv", "proj", "fc1", "fc2"). This is + used to access the corresponding Userbuffers communicators + (e.g. "qkv_dgrad", "qkv_wgrad"). + + Returns + ------- + torch.Tensor + Loss gradient w.r.t. input tensor + torch.Tensor + Loss gradient w.r.t. weight tensor + dict + Extra output tensors. "grad_bias" is loss gradient w.r.t. + the bias tensor. 
+ + """ # Configuration-specific outputs extra_outputs = {} @@ -166,6 +242,7 @@ def _functional_backward( if grad_output_fp8_meta is None and not is_float8_tensor(grad_output): raise ValueError("No FP8 metadata was provided for casting output gradient to FP8") else: + input_fp8_meta = None weight_fp8_meta = None grad_output_fp8_meta = None grad_input_fp8_meta = None @@ -557,7 +634,7 @@ def _functional_backward( if bias_requires_grad: if db is None: db = dy.sum(dim=0) - extra_outputs["bias"] = db + extra_outputs["grad_bias"] = db return grad_input, grad_weight, extra_outputs @@ -638,6 +715,9 @@ def fuser_backward( ub_comm_name=linear_op._userbuffers_options["comm_name"], ) grad_input, grad_weight, extra_outputs = retval + grad_bias = None + if bias_op is not None: + grad_bias = extra_outputs["grad_bias"] # Clear input tensor if possible if linear_op_ctx.has_prev_op: @@ -649,7 +729,7 @@ def fuser_backward( grad_weight = None grad_params[self._op_idxs["linear"]] = (grad_weight,) if bias_op is not None: - grad_params[self._op_idxs["bias"]] = (extra_outputs["bias"],) + grad_params[self._op_idxs["bias"]] = (grad_bias,) grad_extra_inputs = [() for _ in range(len(self.basic_ops))] return grad_input, grad_params, grad_extra_inputs @@ -657,6 +737,20 @@ def fuser_backward( def fuse_userbuffers_backward_linear( ops: list[tuple[FusibleOperation, list[int]]], ) -> list[tuple[FusibleOperation, list[int]]]: + """Substitute linear operations with Userbuffers implementation + + Parameters + ---------- + ops: list of tuples + Forward pass operations and the indices of the corresponding + basic operations. + + Returns + ------- + ops: list of tuples + Updated forward pass operations + + """ # Return immediately if environment is not distributed if ( diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py index d0f56c8880..12bc5db38b 100644 --- a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py +++ b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py @@ -31,6 +31,13 @@ ) class UserbuffersForwardLinear(FusedOperation): + """Linear forward implementation using Userbuffers + + This operation is equivalent to a linear operation's forward pass, + but it uses Userbuffers to overlap tensor-parallel communication + with compute. + + """ def __init__( self, @@ -93,7 +100,54 @@ def _functional_forward( weight_fp8_meta: Optional[dict[str, Any]] = None, output_fp8_meta: Optional[dict[str, Any]] = None, ub_comm_name: str, - ): + ) -> tuple[torch.Tensor, dict]: + """Functional API for forward pass + + Parameters + ---------- + input: torch.Tensor + Input tensor + weight: torch.Tensor + Weight tensor + bias: torch.Tensor, optional + Bias tensor + device: torch.device, default = default CUDA device + Tensor device + dtype: torch.dtype, default = default dtype + Tensor datatype + tensor_parallel_mode: {`None`, "column", "row"}, default = `None` + Mode for tensor parallelism + tensor_parallel_group: torch.distributed.ProcessGroup, default = world group + Process group for tensor parallelism + sequence_parallel: bool, default = `False` + Whether to apply sequence parallelism together with tensor + parallelism, i.e. 
distributing input or output tensors + along outer dimension (sequence or batch dim) when not + distributing along inner dimension (embedding dim) + with_fp8_compute: bool, default = `False` + Whether to perform compute in FP8 + input_fp8_meta: dict, optional + FP8 metadata for casting input tensor to FP8. Required for + FP8 compute if input is not already in FP8. + weight_fp8_meta: dict, optional + FP8 metadata for casting weight tensor to FP8. Required for + FP8 compute if weight is not already in FP8. + output_fp8_meta: dict, optional + FP8 metadata for casting output tensor to FP8 + ub_comm_name: str + Layer type (e.g. "qkv", "proj", "fc1", "fc2"). This is + used to access the corresponding Userbuffers communicators + (e.g. "qkv_fprop"). + + Returns + ------- + torch.Tensor + Output tensor + dict + Extra output tensors. "input" is the input tensor, + possibly cast and reshaped from the provided input tensor. + + """ # Check device if device is None: @@ -378,7 +432,12 @@ def _functional_forward( # Reshape output tensor out = reshape(y_local, output_dims) - return out, x_local, w + # Return cast tensors + extra_outputs = dict( + input=x_local, + weight=w, + ) + return out, extra_outputs def fuser_forward( self, @@ -427,7 +486,7 @@ def fuser_forward( raise RuntimeError("Linear op is missing dict for Userbuffers options") # Linear forward - output, x_local, _ = UserbuffersForwardLinear._functional_forward( + output, extra_outputs = UserbuffersForwardLinear._functional_forward( input=input_, weight=linear_op.weight, bias=bias, @@ -443,6 +502,7 @@ def fuser_forward( output_fp8_meta=output_fp8_meta, ub_comm_name=linear_op._userbuffers_options["comm_name"], ) + x_local = extra_outputs["input"] # Save state for backward pass linear_op_ctx.save_for_backward(x_local) @@ -461,6 +521,20 @@ def fuser_forward( def fuse_userbuffers_forward_linear( ops: list[tuple[FusibleOperation, list[int]]], ) -> list[tuple[FusibleOperation, list[int]]]: + """Substitute linear operations with Userbuffers implementation + + Parameters + ---------- + ops: list of tuples + Forward pass operations and the indices of the corresponding + basic operations. 
+ + Returns + ------- + ops: list of tuples + Updated forward pass operations + + """ # Return immediately if environment is not distributed if ( From a98e2f2c7064737e990c326d7a00beb88fe7a472 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 27 Aug 2024 22:23:08 +0000 Subject: [PATCH 12/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../test_fusible_ops_with_userbuffers.py | 5 ++--- tests/pytorch/utils.py | 2 ++ .../pytorch/csrc/comm_gemm_overlap.h | 20 ++++++++++--------- .../ops/fused/userbuffers_backward_linear.py | 20 +++++-------------- .../ops/fused/userbuffers_forward_linear.py | 14 ++++--------- 5 files changed, 24 insertions(+), 37 deletions(-) diff --git a/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py b/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py index 68cb56b3a7..ead121f314 100644 --- a/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py +++ b/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py @@ -407,9 +407,7 @@ def test_fuser_ops_with_userbuffers( command = [] if tex.ubuf_built_with_mpi(): python_exe = pathlib.Path(sys.executable).resolve() - command.extend( - ("mpirun", "-np", str(world_size), "--oversubscribe", "--quiet", python_exe) - ) + command.extend(("mpirun", "-np", str(world_size), "--oversubscribe", "--quiet", python_exe)) else: command.extend(("torchrun", f"--nproc_per_node={world_size}")) @@ -492,5 +490,6 @@ def main() -> None: # Clean up te.module.base.destroy_ub() + if __name__ == "__main__": main() diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py index 36d9f52978..a8b181a187 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -10,6 +10,7 @@ import transformer_engine.pytorch as te import transformer_engine_torch as tex + def str_to_dtype(dtype: str | torch.dtype) -> torch.dtype: """Convert type name to PyTorch dtype""" if isinstance(dtype, torch.dtype): @@ -48,6 +49,7 @@ def str_to_dtype(dtype: str | torch.dtype) -> torch.dtype: )[name] return dtype + def dtype_tols(dtype: torch.dtype | tex.DType) -> dict[str, float]: """Estimated numerical error for a datatype diff --git a/transformer_engine/pytorch/csrc/comm_gemm_overlap.h b/transformer_engine/pytorch/csrc/comm_gemm_overlap.h index 895f75b11d..08f64ce519 100644 --- a/transformer_engine/pytorch/csrc/comm_gemm_overlap.h +++ b/transformer_engine/pytorch/csrc/comm_gemm_overlap.h @@ -478,9 +478,9 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { torch::Tensor input_a_chunk = torch::from_blob(input_a_chunk_ptr, {m_chunk, k}, A.options()); torch::Tensor output_chunk = torch::from_blob(output_buf_chunk_ptr, {n, m_chunk}, _ubuf.options()); - torch::Tensor bias_chunk = (bias_chunk_ptr == nullptr - ? bias - : torch::from_blob(bias_chunk_ptr, {m_chunk}, bias.options())); + torch::Tensor bias_chunk = + (bias_chunk_ptr == nullptr ? bias + : torch::from_blob(bias_chunk_ptr, {m_chunk}, bias.options())); torch::Tensor workspace_chunk = torch::from_blob(workspace_ptr, {workspace_size_chunk}, workspace.options()); at::cuda::setCurrentCUDAStream(_stream_compute[0]); @@ -498,9 +498,10 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { torch::from_blob(input_a_chunk_ptr, {m_chunk, k}, A.options()); torch::Tensor output_chunk = torch::from_blob(output_buf_chunk_ptr, {n, m_chunk}, _ubuf.options()); - torch::Tensor bias_chunk = (bias_chunk_ptr == nullptr - ? 
bias - : torch::from_blob(bias_chunk_ptr, {m_chunk}, bias.options())); + torch::Tensor bias_chunk = + (bias_chunk_ptr == nullptr + ? bias + : torch::from_blob(bias_chunk_ptr, {m_chunk}, bias.options())); torch::Tensor workspace_chunk = torch::from_blob(workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk}, workspace.options()); @@ -558,9 +559,10 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { torch::from_blob(input_a_chunk_ptr, {m_chunk, k}, A.options()); torch::Tensor output_chunk = torch::from_blob(output_buf_chunk_ptr, {n, m_chunk}, _ubuf.options()); - torch::Tensor bias_chunk = (bias_chunk_ptr == nullptr - ? bias - : torch::from_blob(bias_chunk_ptr, {m_chunk}, bias.options())); + torch::Tensor bias_chunk = + (bias_chunk_ptr == nullptr + ? bias + : torch::from_blob(bias_chunk_ptr, {m_chunk}, bias.options())); torch::Tensor workspace_chunk = torch::from_blob(workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk}, workspace.options()); diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py index 391ca28538..85de61285f 100644 --- a/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py +++ b/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py @@ -37,6 +37,7 @@ reshape, ) + class UserbuffersBackwardLinear(FusedOperation): """Linear backward implementation using Userbuffers @@ -233,9 +234,7 @@ def _functional_backward( f"({tensor_parallel_size=}, {tensor_parallel_mode=})" ) if not sequence_parallel: - raise RuntimeError( - f"Invalid configuration for Userbuffers ({sequence_parallel=})" - ) + raise RuntimeError(f"Invalid configuration for Userbuffers ({sequence_parallel=})") # Check if FP8 is enabled if with_fp8_compute: @@ -365,12 +364,7 @@ def _functional_backward( else: dy_local = dy_local.from_float8() - if ( - bias_requires_grad - and db is None - and with_fp8_compute - and with_ub_all_gather_dy - ): + if bias_requires_grad and db is None and with_fp8_compute and with_ub_all_gather_dy: # We don't have a fused grad bias impl that takes FP8 # input. 
For cases where we cast to FP8 and all-gather, # it's better to compute the grad bias on ungathered, @@ -516,8 +510,7 @@ def _functional_backward( if grad_weight is None: if accumulate_into_grad_weight: raise ValueError( - "Attempted to accumulate into grad weight buffer" - "without providing grad weight" + "Attempted to accumulate into grad weight bufferwithout providing grad weight" ) grad_weight = torch.empty( weight_dims, @@ -753,10 +746,7 @@ def fuse_userbuffers_backward_linear( """ # Return immediately if environment is not distributed - if ( - not torch.distributed.is_initialized() - or torch.distributed.get_world_size() == 1 - ): + if not torch.distributed.is_initialized() or torch.distributed.get_world_size() == 1: return ops # Sliding window in list of ops diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py index 12bc5db38b..f2aacadaef 100644 --- a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py +++ b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py @@ -30,6 +30,7 @@ reshape, ) + class UserbuffersForwardLinear(FusedOperation): """Linear forward implementation using Userbuffers @@ -191,9 +192,7 @@ def _functional_forward( f"({tensor_parallel_size=}, {tensor_parallel_mode=})" ) if not sequence_parallel: - raise RuntimeError( - f"Invalid configuration for Userbuffers ({sequence_parallel=})" - ) + raise RuntimeError(f"Invalid configuration for Userbuffers ({sequence_parallel=})") # Check if FP8 is enabled if with_fp8_compute: @@ -206,9 +205,7 @@ def _functional_forward( weight_fp8_meta = None output_fp8_meta = None with_fp8_output = ( - with_fp8_compute - and tensor_parallel_mode != "row" - and output_fp8_meta is not None + with_fp8_compute and tensor_parallel_mode != "row" and output_fp8_meta is not None ) # Get Userbuffers communicator @@ -537,10 +534,7 @@ def fuse_userbuffers_forward_linear( """ # Return immediately if environment is not distributed - if ( - not torch.distributed.is_initialized() - or torch.distributed.get_world_size() == 1 - ): + if not torch.distributed.is_initialized() or torch.distributed.get_world_size() == 1: return ops # Sliding window in list of ops From 7aaef65efad7434ea2f88c431985437064682498 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Tue, 27 Aug 2024 22:40:43 +0000 Subject: [PATCH 13/13] Fix linter warnings Signed-off-by: Tim Moon --- transformer_engine/pytorch/ops/_common.py | 3 +- .../pytorch/ops/basic/__init__.py | 1 - .../pytorch/ops/basic/cast_float8.py | 100 ------------------ .../ops/fused/userbuffers_backward_linear.py | 12 +-- .../ops/fused/userbuffers_forward_linear.py | 4 +- 5 files changed, 8 insertions(+), 112 deletions(-) delete mode 100644 transformer_engine/pytorch/ops/basic/cast_float8.py diff --git a/transformer_engine/pytorch/ops/_common.py b/transformer_engine/pytorch/ops/_common.py index 5080942437..6ee4705970 100644 --- a/transformer_engine/pytorch/ops/_common.py +++ b/transformer_engine/pytorch/ops/_common.py @@ -9,8 +9,9 @@ import torch -from ..cpp_extensions import FP8TensorMeta +from transformer_engine_torch import FP8TensorMeta from ..float8_tensor import Float8Tensor +from ..fp8 import FP8GlobalStateManager def canonicalize_device(device: Optional[torch.device | str]) -> torch.device: diff --git a/transformer_engine/pytorch/ops/basic/__init__.py b/transformer_engine/pytorch/ops/basic/__init__.py index 599bac1b4e..1003cc0337 100644 --- a/transformer_engine/pytorch/ops/basic/__init__.py +++ 
b/transformer_engine/pytorch/ops/basic/__init__.py @@ -9,7 +9,6 @@ from .all_reduce import AllReduce from .basic_linear import BasicLinear from .bias import Bias -from .cast_float8 import CastFloat8 from .identity import Identity from .make_extra_output import MakeExtraOutput from .reduce_scatter import ReduceScatter diff --git a/transformer_engine/pytorch/ops/basic/cast_float8.py b/transformer_engine/pytorch/ops/basic/cast_float8.py deleted file mode 100644 index 40cd350e99..0000000000 --- a/transformer_engine/pytorch/ops/basic/cast_float8.py +++ /dev/null @@ -1,100 +0,0 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. - -"""Fusible operation for FP8 cast.""" - -from __future__ import annotations -from typing import Optional - -import torch - -from transformer_engine.pytorch.float8_tensor import Float8Tensor -from transformer_engine.pytorch.fp8 import ( - FP8GlobalStateManager, - get_fp8_te_dtype, -) -from transformer_engine.pytorch.ops.op import ( - BasicOperation, - OperationContext, -) -from .._common import is_float8_tensor - - -class CastFloat8(BasicOperation): - """Cast tensor to FP8 - - Uses FP8 recipe from `fp8_autocast` context. When called outside - of an `fp8_autocast` context, this is an identity operation. - - Parameters - ---------- - forward: bool, default = `True` - Perform FP8 cast in forward pass - backward: bool, default = `True` - Perform FP8 cast in backward pass - - """ - - def __init__( - self, - forward: bool = True, - backward: bool = True, - ) -> None: - super().__init__() - self._cast_forward = forward - self._cast_backward = backward - - def num_fp8_scales(self, mode: str) -> int: - if mode == "input" and self._cast_forward: - return 1 - if mode == "grad_output" and self._cast_backward: - return 1 - return 0 - - def op_forward( - self, - ctx: OperationContext, - input_: torch.Tensor, - prev_op: Optional[BasicOperation] = None, - next_op: Optional[BasicOperation] = None, - ) -> torch.Tensor: - - # Check if FP8 is enabled - fp8_enabled = FP8GlobalStateManager.is_fp8_enabled() - cast_forward = fp8_enabled and self._cast_forward - cast_backward = fp8_enabled and self._cast_backward - - # Cast to FP8 if needed - out = input_ - if cast_forward and not is_float8_tensor(out): - fp8_meta = self.get_fp8_meta("input") - fp8_dtype = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) - out = Float8Tensor.to_float8( - out, - fp8_meta=fp8_meta, - fp8_meta_forward=True, - fp8_meta_index=0, - fp8_dtype=fp8_dtype, - ) - - ctx.cast_backward = cast_backward - return out - - def op_backward( - self, - ctx: OperationContext, - grad_output: torch.Tensor, - ) -> tuple[torch.Tensor, tuple[()]]: - grad_input = grad_output - if ctx.cast_backward and not is_float8_tensor(grad_input): - fp8_meta = self.get_fp8_meta("grad_output") - fp8_dtype = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=False) - grad_input = Float8Tensor.to_float8( - grad_input, - fp8_meta=fp8_meta, - fp8_meta_forward=False, - fp8_meta_index=0, - fp8_dtype=fp8_dtype, - ) - return grad_input, () diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py index 85de61285f..2d6e4eeaec 100644 --- a/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py +++ b/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py @@ -11,8 +11,8 @@ import torch +from transformer_engine_torch import UbufOverlapAlgo from ...cpp_extensions 
import ( - UbufOverlapAlgo, fp8_cast_transpose_bgrad_fused, fp8_gemm, gemm, @@ -21,13 +21,9 @@ from ...float8_tensor import Float8Tensor from ...fp8 import FP8GlobalStateManager, get_fp8_te_dtype from ...module.base import get_ub, get_workspace +from ...utils import clear_tensor_data from ..basic import BasicLinear, Bias, ReduceScatter -from ..op import ( - BasicOperation, - FusedOperation, - FusibleOperation, - OperationContext, -) +from ..op import FusedOperation, FusibleOperation, OperationContext from .._common import ( canonicalize_device, canonicalize_dtype, @@ -648,11 +644,9 @@ def fuser_backward( linear_op = self.basic_ops[idx] linear_op_ctx = basic_op_ctxs[idx] bias_op = None - bias = None if self._op_idxs["bias"] is not None: idx = self._op_idxs["bias"] bias_op = self.basic_ops[idx] - bias = bias_op.bias # Saved tensors from forward pass (x_local,) = linear_op_ctx.saved_tensors diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py index f2aacadaef..409d9d8757 100644 --- a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py +++ b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py @@ -10,7 +10,8 @@ import torch -from ...cpp_extensions import FP8TensorMeta, UbufOverlapAlgo, fp8_gemm, gemm +from transformer_engine_torch import UbufOverlapAlgo +from ...cpp_extensions import fp8_gemm, gemm from ...distributed import get_distributed_world_size from ...float8_tensor import Float8Tensor from ...fp8 import FP8GlobalStateManager, get_fp8_te_dtype @@ -26,6 +27,7 @@ canonicalize_device, canonicalize_dtype, convert_tensor, + get_fp8_meta_from_fp8_tensor, is_float8_tensor, reshape, )
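
Taken together, the patches above let te_ops.Sequential swap a tensor-parallel BasicLinear for the Userbuffers-fused forward/backward ops. The driver below is a minimal, illustrative sketch patterned on the test added in this series (tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py); it must be launched with torchrun or mpirun on at least 2 GPUs, and the exact initialize_ub and BasicLinear signatures (the buffer-shape list, the tp_size argument, userbuffers_options, tensor_parallel_mode, sequence_parallel) are assumptions taken from the test and docstrings in these patches, not a definitive API reference.

# userbuffers_linear_demo.py -- illustrative sketch only; names and signatures are
# assumptions patterned on tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py
import torch
import transformer_engine.pytorch as te
import transformer_engine.pytorch.ops as te_ops


def main() -> None:
    # One process per GPU, launched with torchrun or mpirun (>= 2 GPUs)
    torch.distributed.init_process_group("nccl")
    rank = torch.distributed.get_rank()
    world_size = torch.distributed.get_world_size()
    torch.cuda.set_device(rank)

    sequence_length, batch_size, hidden_size = 2048, 2, 1024
    dtype = torch.bfloat16

    # Register Userbuffers communicators ("qkv_fprop", "qkv_dgrad", "qkv_wgrad", ...)
    # before building the ops; the buffer shape and tp_size arguments follow the
    # test's call and are assumptions about the exact signature.
    te.module.base.initialize_ub(
        [sequence_length * batch_size, hidden_size],
        world_size,
        use_fp8=False,
        dtype=dtype,
    )

    # Column-parallel linear with sequence parallelism; the op fuser is expected to
    # replace it with UserbuffersForwardLinear / UserbuffersBackwardLinear.
    # tensor_parallel_group is omitted and defaults to the world group.
    model = te_ops.Sequential(
        te_ops.BasicLinear(
            hidden_size,
            hidden_size,
            device="cuda",
            dtype=dtype,
            tensor_parallel_mode="column",
            sequence_parallel=True,
            userbuffers_options={"comm_name": "qkv"},
        ),
    )

    # Local input shard: the sequence dimension is distributed under sequence parallelism
    x = torch.randn(
        sequence_length * batch_size // world_size,
        hidden_size,
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    y = model(x)        # x all-gather overlapped with the forward GEMM
    y.sum().backward()  # dgrad/wgrad GEMMs overlapped with grad communication

    te.module.base.destroy_ub()
    torch.distributed.destroy_process_group()


if __name__ == "__main__":
    main()

Launched as, for example, "torchrun --nproc_per_node=2 userbuffers_linear_demo.py" (hypothetical script name), this mirrors how run_parallel_tests exercises the column- and row-parallel configurations in the test above.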