diff --git a/qa/L1_pytorch_distributed_unittest/test.sh b/qa/L1_pytorch_distributed_unittest/test.sh index c22ba221be..9a11ccc008 100644 --- a/qa/L1_pytorch_distributed_unittest/test.sh +++ b/qa/L1_pytorch_distributed_unittest/test.sh @@ -10,4 +10,5 @@ pip install pytest==8.2.1 pytest -v -s $TE_PATH/tests/pytorch/distributed/test_numerics.py pytest -v -s $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops.py +pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn_with_cp.py diff --git a/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py b/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py new file mode 100644 index 0000000000..ead121f314 --- /dev/null +++ b/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py @@ -0,0 +1,495 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +from __future__ import annotations + +import argparse +import dataclasses +import functools +import itertools +import os +import pathlib +import subprocess +import sys + +import pytest +import torch + +import transformer_engine +import transformer_engine.pytorch as te +import transformer_engine.pytorch.cpp_extensions as tex +from transformer_engine.pytorch.float8_tensor import Float8Tensor +from transformer_engine.pytorch.fp8 import FP8GlobalStateManager +import transformer_engine.pytorch.ops as te_ops +from transformer_engine.pytorch.ops._common import is_float8_tensor +from transformer_engine.pytorch.ops.fused import ( + UserbuffersBackwardLinear, + UserbuffersForwardLinear, +) +from transformer_engine.pytorch.utils import is_bf16_compatible + +# Import utility functions +_current_file = pathlib.Path(__file__).resolve() +sys.path.append(str(_current_file.parent.parent)) +from utils import dtype_tols, str_to_dtype + +# Check if FP8 is supported +fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() + +# Check if there are multiple GPUs +if torch.cuda.device_count() < 2: + pytest.skip("Userbuffers requires at least 2 GPUs.") + + +@dataclasses.dataclass +class ModelConfig: + """Tensor dimensions in Transformer model""" + + sequence_length: int + batch_size: int + num_heads: int + head_dim: int + dtype: torch.dtype + fp8: bool + + @property + def hidden_size(self): + return self.num_heads * self.head_dim + + +@functools.cache +def launcher() -> str: + """Launcher for current parallel job""" + if "OMPI_COMM_WORLD_SIZE" in os.environ: + return "ompi" + if "TORCHELASTIC_RUN_ID" in os.environ: + return "torchrun" + raise RuntimeError(f"{__file__} must be launched with either `mpirun` or `torchrun`") + + +@functools.cache +def world_group() -> torch.distributed.ProcessGroup: + """Get NCCL process group, initializing if needed""" + + # Get launch config from environment + if launcher() == "ompi": + # OpenMPI + world_size = int(os.getenv("OMPI_COMM_WORLD_SIZE")) + rank = int(os.getenv("OMPI_COMM_WORLD_RANK")) + local_size = int(os.getenv("OMPI_COMM_WORLD_LOCAL_SIZE")) + local_rank = int(os.getenv("OMPI_COMM_WORLD_LOCAL_RANK")) + elif launcher() == "torchrun": + # torchrun + world_size = int(os.getenv("WORLD_SIZE")) + rank = int(os.getenv("RANK")) + local_size = int(os.getenv("LOCAL_WORLD_SIZE")) + local_rank = int(os.getenv("LOCAL_RANK")) + else: + raise RuntimeError("Unexpected launcher ({launcher()})") + + # Construct 
communicator + assert local_size == world_size + torch.cuda.set_device(local_rank) + group = torch.distributed.init_process_group( + "nccl", + init_method="file:///tmp/rdzv", + world_size=world_size, + rank=rank, + device_id=torch.device(f"cuda:{local_rank}"), + ) + return group + + +def reset_rng(seed: int = 1234) -> None: + """Reset random number generators""" + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + +@torch.no_grad() +def make_reference_and_test_tensors( + shape: int | Iterable[int], + ref_dtype: torch.dtype = torch.float64, + ref_device: torch.device = "cpu", + test_dtype: torch.dtype = torch.float32, + test_device: torch.device = "cuda", + test_is_fp8: bool = False, + requires_grad: bool = True, +) -> tuple[torch.Tensor, torch.Tensor]: + """Construct tensors with the same values + + The reference tensor is intended for use in plain PyTorch + operations in high precision. The test tensor is intended for use + in Transformer Engine operations. + + """ + + # Random data + ref = torch.rand(shape, dtype=ref_dtype, device=ref_device) + + # Make copy of tensor + if test_is_fp8: + test = Float8Tensor.to_float8(ref) + else: + test = ref.to(device=test_device, dtype=test_dtype) + if test.data_ptr() == ref.data_ptr(): + test = test.clone() + + # Make sure reference and test tensors represent exact same values + ref.copy_(test) + + # Return reference and test tensors + ref.requires_grad_(requires_grad) + test.requires_grad_(requires_grad) + return ref, test + + +def _test_linear( + *, + model_config: ModelConfig, + bias: bool = False, + device: torch.device = "cuda", + tensor_parallel_mode: str = "column", + sequence_parallel: bool = True, + weight_requires_grad: bool = True, +) -> None: + dtype = model_config.dtype + fp8_compute = model_config.fp8 + + # Distributed process group + process_group = world_group() + rank = torch.distributed.get_rank(process_group) + world_size = torch.distributed.get_world_size(process_group) + + # Tensor dimensions + out_features = model_config.hidden_size + in_features = model_config.hidden_size + batch_size = model_config.sequence_length * model_config.batch_size + in_shape = [batch_size, in_features] + out_shape = [batch_size, out_features] + + # Random data + reset_rng() + x_ref, x_test = make_reference_and_test_tensors( + in_shape, + test_dtype=dtype, + test_device=device, + test_is_fp8=fp8_compute, + ) + w_ref, w_test = make_reference_and_test_tensors( + (out_features, in_features), + test_dtype=dtype, + test_device=device, + test_is_fp8=fp8_compute, + ) + b_ref, b_test = None, None + if bias: + if tensor_parallel_mode == "row": + bias_shape = [world_size, out_features] + else: + bias_shape = [out_features] + b_ref, b_test = make_reference_and_test_tensors( + bias_shape, + test_dtype=dtype, + test_device=device, + ) + dy_ref, dy_test = make_reference_and_test_tensors( + out_shape, + test_dtype=dtype, + test_device=device, + test_is_fp8=fp8_compute, + requires_grad=False, + ) + + # Plain PyTorch implementation + y_ref = torch.nn.functional.linear(x_ref, w_ref) + if bias: + if tensor_parallel_mode == "row": + y_ref += b_ref.sum(dim=0) + else: + y_ref += b_ref + y_ref.backward(dy_ref) + + # Convert to distributed tensors + with torch.no_grad(): + dw_ref = w_ref.grad + db_ref = b_ref.grad if bias else None + dx_ref = x_ref.grad + if tensor_parallel_mode == "column": + local_out_features = out_features // world_size + local_slice = slice( + rank * local_out_features, + (rank + 1) * local_out_features, + ) + w_ref = w_ref[local_slice, :] + 
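+            # Column parallelism shards the weight along its output
+            # dimension, so each rank keeps a contiguous block of rows
+            # (rank * local_out_features to (rank + 1) * local_out_features);
+            # the weight grad, bias, and output references are sliced
+            # the same way below.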
dw_ref = dw_ref[local_slice, :] + w_test = w_test[local_slice, :] + if bias: + b_ref = b_ref[local_slice] + db_ref = db_ref[local_slice] + b_test = b_test[local_slice] + y_ref = y_ref[..., local_slice] + dy_ref = dy_ref[..., local_slice] + dy_test = dy_test[..., local_slice].clone() + elif tensor_parallel_mode == "row": + local_in_features = in_features // world_size + local_slice = slice( + rank * local_in_features, + (rank + 1) * local_in_features, + ) + w_ref = w_ref[:, local_slice] + dw_ref = dw_ref[:, local_slice] + w_test = w_test[:, local_slice] + if bias: + b_ref = b_ref[rank, :] + db_ref = db_ref[rank, :] + b_test = b_test[rank, :] + x_ref = x_ref[..., local_slice] + dx_ref = dx_ref[..., local_slice] + x_test = x_test[..., local_slice].clone() + if sequence_parallel: + local_batch_size = batch_size // world_size + local_slice = slice( + rank * local_batch_size, + (rank + 1) * local_batch_size, + ) + if tensor_parallel_mode == "column": + x_ref = x_ref[local_slice, ...] + dx_ref = dx_ref[local_slice, ...] + x_test = x_test[local_slice, ...].clone() + elif tensor_parallel_mode == "row": + y_ref = y_ref[local_slice, ...] + dy_ref = dy_ref[local_slice, ...] + dy_test = dy_test[local_slice, ...].clone() + x_test.requires_grad_() + + # Implementation with fusible operation + with te.fp8_model_init(enabled=fp8_compute): + ops = [] + linear_op = None + bias_op = None + if tensor_parallel_mode == "column": + userbuffers_options = {} + if not weight_requires_grad: + if fp8_compute: + userbuffers_options["comm_name"] = "fc1" + else: + # There is a correctness bug with overlapping + # dgrad reduce-scatter with dgrad GEMM. Fall back + # to overlapping dgrad reduce-scatter with wgrad + # GEMM, even though wgrad isn't needed. + userbuffers_options["comm_name"] = "qkv" + else: + userbuffers_options["comm_name"] = "qkv" + linear_op = te_ops.BasicLinear( + in_features, + out_features, + device=device, + dtype=dtype, + tensor_parallel_mode=tensor_parallel_mode, + tensor_parallel_group=process_group, + sequence_parallel=sequence_parallel, + userbuffers_options=userbuffers_options, + ) + ops.append(linear_op) + if bias: + bias_op = te_ops.Bias( + out_features // world_size, + device=device, + dtype=dtype, + ) + ops.append(bias_op) + elif tensor_parallel_mode == "row": + userbuffers_options = dict(comm_name="proj") + linear_op = te_ops.BasicLinear( + in_features // world_size, + out_features, + device=device, + dtype=dtype, + userbuffers_options=userbuffers_options, + ) + ops.append(linear_op) + if bias: + bias_op = te_ops.Bias(out_features, device=device, dtype=dtype) + ops.append(bias_op) + ops.append(te_ops.ReduceScatter(process_group)) + model = te_ops.Sequential(*ops) + with torch.no_grad(): + linear_op.weight.copy_(w_test) + linear_op.weight.requires_grad_(requires_grad=weight_requires_grad) + if bias: + bias_op.bias.copy_(b_test) + del w_test + del b_test + with te.fp8_autocast(enabled=fp8_compute): + y_test = model(x_test) + y_test.backward(dy_test) + + # Check that forward operations have been fused + forward_ops = model._module_groups[0]._forward_ops + backward_ops = model._module_groups[0]._backward_ops + assert len(forward_ops) == 1 + assert len(backward_ops) == 1 + assert isinstance(forward_ops[0][0], UserbuffersForwardLinear) + assert isinstance(backward_ops[0][0], UserbuffersBackwardLinear) + + # Expected numerical error + tols = dtype_tols(dtype) + if dtype == torch.float32: + tols = dtype_tols(torch.float16) # TF32 GEMM + if fp8_compute: + tols = dtype_tols( + 
model[0].weight._fp8_dtype + if is_float8_tensor(model[0].weight) + else tex.DType.kFloat8E4M3 + ) + + # Check results + y_test = y_test.to(dtype=torch.float64, device="cpu") + dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(y_test, y_ref, **tols) + torch.testing.assert_close(dx_test, dx_ref, **tols) + if weight_requires_grad: + dw_test = linear_op.weight.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(dw_test, dw_ref, **tols) + if bias: + db_test = bias_op.bias.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(db_test, db_ref, **tols) + + +def run_parallel_tests(model_config: ModelConfig) -> None: + """Run parallel tests""" + + # Distributed process group + process_group = world_group() + rank = torch.distributed.get_rank(process_group) + world_size = torch.distributed.get_world_size(process_group) + + # Linear op + for test_config in itertools.product( + (False, True), # bias + ("column", "row"), # tensor_parallel_mode + (False, True), # weight_requires_grad + ): + if rank == 0: + print(f"Running _test_linear with {test_config=}") + bias, tensor_parallel_mode, weight_requires_grad = test_config + _test_linear( + model_config=model_config, + bias=bias, + tensor_parallel_mode=tensor_parallel_mode, + weight_requires_grad=weight_requires_grad, + ) + + +# Parallel job sizes +_world_sizes = [] +if torch.cuda.device_count() > 1: + _world_sizes.append(torch.cuda.device_count()) + + +@pytest.mark.parametrize("world_size", _world_sizes) +@pytest.mark.parametrize("fp8", (False, True)) +def test_fuser_ops_with_userbuffers( + *, + world_size: int, + dtype: torch.dtype = torch.bfloat16, + fp8: bool, +) -> None: + """Launch parallel job and run tests""" + + # Skip invalid configurations + if fp8 and not fp8_available: + pytest.skip(reason_for_no_fp8) + + # Parallel job launcher + command = [] + if tex.ubuf_built_with_mpi(): + python_exe = pathlib.Path(sys.executable).resolve() + command.extend(("mpirun", "-np", str(world_size), "--oversubscribe", "--quiet", python_exe)) + else: + command.extend(("torchrun", f"--nproc_per_node={world_size}")) + + # Script invocation + command.extend( + ( + _current_file, + "--parallel", + "--batch-size", + str(world_size), + "--num-heads", + str(world_size), + "--dtype", + str(dtype), + ) + ) + if fp8: + command.append("--fp8") + + # Environment + env = dict(os.environ) + if not tex.device_supports_multicast(): + env["UB_SKIPMC"] = "1" + env["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" + env["PYTORCH_JIT"] = "0" + env["NVTE_TORCH_COMPILE"] = "0" + env["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0" + + # Launch parallel job + result = subprocess.run(command, check=True, env=env) + + +def main() -> None: + + # Parse command-line arguments + parser = argparse.ArgumentParser() + parser.add_argument("--parallel", action="store_true", help="Run parallel tests") + parser.add_argument("--sequence-length", type=int, default=32) + parser.add_argument("--batch-size", type=int, default=16) + parser.add_argument("--num-heads", type=int, default=16) + parser.add_argument("--head-dim", type=int, default=32) + parser.add_argument("--dtype", type=str, default="bfloat16") + parser.add_argument("--fp8", action="store_true") + args = parser.parse_args() + + # Run parallel tests if needed + if args.parallel: + + # Model config + model_config = ModelConfig( + sequence_length=args.sequence_length, + batch_size=args.batch_size, + num_heads=args.num_heads, + head_dim=args.head_dim, + dtype=str_to_dtype(args.dtype), + 
fp8=args.fp8, + ) + + # Initialize Userbuffers + group = world_group() # Initialize NCCL + bootstrap_backend = "mpi" if launcher() == "ompi" else "nccl" + userbuffer_configs = { + "fc1_dgrad": {"method": "pipeline"}, # Overlap dgrad RS with dgrad GEMM + } + te.module.base.initialize_ub( + [ + model_config.sequence_length * model_config.batch_size, + model_config.num_heads * model_config.head_dim, + ], + torch.distributed.get_world_size(group), + use_fp8=model_config.fp8, + dtype=model_config.dtype, + bootstrap_backend=bootstrap_backend, + ub_cfgs=userbuffer_configs, + ) + + # Run tests + run_parallel_tests(model_config) + + # Clean up + te.module.base.destroy_ub() + + +if __name__ == "__main__": + main() diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py new file mode 100644 index 0000000000..a8b181a187 --- /dev/null +++ b/tests/pytorch/utils.py @@ -0,0 +1,85 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +from __future__ import annotations + +import torch + +import transformer_engine +import transformer_engine.pytorch as te +import transformer_engine_torch as tex + + +def str_to_dtype(dtype: str | torch.dtype) -> torch.dtype: + """Convert type name to PyTorch dtype""" + if isinstance(dtype, torch.dtype): + return dtype + name = str(dtype).strip().lower() + if name.startswith("torch."): + name = name.replace("torch.", "", 1) + if name.startswith("fp"): + name = name.replace("fp", "float", 1) + dtype = dict( + float32=torch.float32, + float=torch.float32, + float64=torch.float64, + double=torch.float64, + float16=torch.float16, + half=torch.float16, + bfloat16=torch.bfloat16, + bf16=torch.bfloat16, + float8_e4m3fn=torch.float8_e4m3fn, + float8_e4m3=torch.float8_e4m3fn, + float8e4m3=torch.float8_e4m3fn, + float8=torch.float8_e4m3fn, + float8_e5m2=torch.float8_e5m2, + float8e5m2=torch.float8_e5m2, + uint8=torch.uint8, + byte=torch.uint8, + int8=torch.int8, + char=torch.int8, + int16=torch.int16, + short=torch.int16, + int32=torch.int32, + int=torch.int32, + int64=torch.int64, + long=torch.int64, + bool=torch.bool, + )[name] + return dtype + + +def dtype_tols(dtype: torch.dtype | tex.DType) -> dict[str, float]: + """Estimated numerical error for a datatype + + Based on tolerances for torch.testing.assert_close. 
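+
+    As a rough guide, dtype_tols(torch.bfloat16) yields rtol=1.6e-2 and
+    atol=1e-5, while the FP8 formats below get much looser bounds since
+    they only carry a few mantissa bits.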
+ + """ + + # Transformer Engine dtypes + if isinstance(dtype, tex.DType): + dtype = { + tex.DType.kByte: torch.uint8, + tex.DType.kInt32: torch.int32, + tex.DType.kFloat32: torch.float32, + tex.DType.kFloat16: torch.half, + tex.DType.kBFloat16: torch.bfloat16, + tex.DType.kFloat8E4M3: torch.float8_e4m3fn, + tex.DType.kFloat8E5M2: torch.float8_e5m2, + }[dtype] + + # PyTorch dtypes + if dtype == torch.float16: + return dict(rtol=1e-3, atol=1e-5) + if dtype == torch.bfloat16: + return dict(rtol=1.6e-2, atol=1e-5) + if dtype == torch.float32: + return dict(rtol=1.3e-6, atol=1e-5) + if dtype == torch.float64: + return dict(rtol=1e-7, atol=1e-7) + if dtype == torch.float8_e4m3fn: + return dict(rtol=0.125, atol=0.0675) # epsilon = 0.0625 + if dtype == torch.float8_e5m2: + return dict(rtol=0.25, atol=0.125) # epsilon = 0.152 + raise ValueError(f"Unsupported dtype ({dtype})") diff --git a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp index 59ec56f161..a663385b68 100644 --- a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp +++ b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp @@ -314,11 +314,13 @@ void CommOverlapBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorWrap size_t m_chunk = m / _num_splits; size_t input_a_chunk_size = m_chunk * k; size_t output_chunk_size = n * m_chunk; + size_t bias_chunk_size = m_chunk; size_t workspace_size_chunk = workspace.numel() / _stream_compute.size(); // Get input, output, and workspace data pointers char *input_a_chunk_ptr = reinterpret_cast(A.dptr()); char *output_buf_chunk_ptr = reinterpret_cast(_ubuf.dptr()); + char *bias_chunk_ptr = reinterpret_cast(bias.dptr()); char *workspace_ptr = reinterpret_cast(workspace.dptr()); char *rs_output_ptr = reinterpret_cast(rs_output.dptr()); @@ -337,16 +339,21 @@ void CommOverlapBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorWrap TensorWrapper(A.dptr(), {m_chunk, k}, A.dtype(), nullptr, nullptr, A.scale_inv()); auto output_chunk = TensorWrapper(_ubuf.dptr(), {m, m_chunk}, D.dtype(), D.amax(), D.scale(), nullptr); + auto bias_chunk = + TensorWrapper(bias.dptr(), {m_chunk}, bias.dtype(), nullptr, nullptr, nullptr); auto workspace_chunk = TensorWrapper( workspace.dptr(), std::vector{workspace_size_chunk}, workspace.dtype()); - nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(), + nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias_chunk.data(), pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), accumulate, use_split_accumulator, _math_sms, _stream_compute[0]); for (int i = 1; i < _num_splits; i++) { input_a_chunk_ptr += input_a_chunk_size * B.element_size(); output_buf_chunk_ptr += output_chunk_size * D.element_size(); + if (bias_chunk_ptr != nullptr) { + bias_chunk_ptr += bias_chunk_size * bias.element_size(); + } char *workspace_chunk_ptr = workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk; @@ -354,10 +361,12 @@ void CommOverlapBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorWrap A.dtype(), nullptr, nullptr, A.scale_inv()); output_chunk = TensorWrapper(reinterpret_cast(output_buf_chunk_ptr), {n, m_chunk}, D.dtype(), D.amax(), D.scale(), nullptr); + bias_chunk = TensorWrapper(reinterpret_cast(bias_chunk_ptr), {m_chunk}, bias.dtype(), + nullptr, nullptr, nullptr); workspace_chunk = TensorWrapper(reinterpret_cast(workspace_chunk_ptr), std::vector{workspace_size_chunk}, 
workspace.dtype()); - nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(), + nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias_chunk.data(), pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), accumulate, use_split_accumulator, _math_sms, _stream_compute[i % _stream_compute.size()]); @@ -409,11 +418,13 @@ void CommOverlapBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorWrap A.dtype(), nullptr, nullptr, A.scale_inv()); auto output_chunk = TensorWrapper(reinterpret_cast(output_buf_chunk_ptr), {n, m_chunk}, D.dtype(), D.amax(), D.scale(), nullptr); + auto bias_chunk = TensorWrapper(reinterpret_cast(bias_chunk_ptr), {m_chunk}, + bias.dtype(), nullptr, nullptr, nullptr); auto workspace_chunk = TensorWrapper(reinterpret_cast(workspace_chunk_ptr), std::vector{workspace_size_chunk}, workspace.dtype()); - nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(), + nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias_chunk.data(), pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), accumulate, use_split_accumulator, _math_sms, _stream_compute[i % _stream_compute.size()]); @@ -440,6 +451,9 @@ void CommOverlapBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorWrap rs_output_ptr += m_chunk * rs_output.element_size(); input_a_chunk_ptr += input_a_chunk_size * B.element_size(); output_buf_chunk_ptr += output_chunk_size * _ubuf.element_size(); + if (bias_chunk_ptr != nullptr) { + bias_chunk_ptr += bias_chunk_size * bias.element_size(); + } } } diff --git a/transformer_engine/pytorch/ops/_common.py b/transformer_engine/pytorch/ops/_common.py index 89a529a78e..b1654add98 100644 --- a/transformer_engine/pytorch/ops/_common.py +++ b/transformer_engine/pytorch/ops/_common.py @@ -9,6 +9,8 @@ import torch +from transformer_engine_torch import FP8TensorMeta +from ..fp8 import FP8GlobalStateManager from ..tensor import Float8Tensor from ..utils import ( canonicalize_device, # pylint: disable=unused-import @@ -134,3 +136,25 @@ def maybe_autocast_dtype( if torch.is_autocast_enabled(device_type): return torch.get_autocast_dtype(device_type) return canonicalize_dtype(default_dtype) + + +def get_fp8_meta_from_fp8_tensor(tensor: Float8Tensor) -> tuple[FP8TensorMeta, int]: + """Get FP8TensorMeta object and index corresponding to Float8Tensor + + Constructs FP8TensorMeta if needed. + + """ + + # Check if tensor already has FP8 metadata + if tensor._fp8_meta is not None: + key = FP8GlobalStateManager.get_meta_tensor_key( + forward=tensor._fp8_meta_forward, + ) + return tensor._fp8_meta[key], tensor._fp8_meta_index + + # Create FP8TensorMeta class + fp8_meta = FP8TensorMeta() + fp8_meta.scale = tensor._scale_inv.reciprocal() + fp8_meta.amax_history = torch.empty(1, 1, dtype=torch.float32, device=tensor.device) + fp8_meta.scale_inv = tensor._scale_inv + return fp8_meta, 0 diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index 46a72a08d2..ad86861114 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -83,6 +83,10 @@ class BasicLinear(BasicOperation): autograd. The weight's `main_grad` must be set externally and there is no guarantee that `grad` will be set or be meaningful. + userbuffers_options, dict, optional + Options for overlapping tensor-parallel communication with + compute using Userbuffers. 
This feature is highly + experimental. """ @@ -98,6 +102,7 @@ def __init__( sequence_parallel: bool = False, rng_state_tracker_function: Optional[Callable[[], CudaRNGStatesTracker]] = None, accumulate_into_main_grad: bool = False, + userbuffers_options: Optional[dict[str, Any]] = None, ) -> None: super().__init__() @@ -143,7 +148,7 @@ def __init__( ) # Whether weight tensor is natively in FP8 - self._with_fp8_parameters = FP8GlobalStateManager.with_fp8_parameters() + self._with_fp8_parameters: bool = FP8GlobalStateManager.with_fp8_parameters() if self._with_fp8_parameters: self._fp8_metas = self._make_fp8_metas() @@ -163,7 +168,10 @@ def __init__( self.reset_parameters() # Whether to accumulate weight gradient into main_grad - self._accumulate_into_main_grad = accumulate_into_main_grad + self._accumulate_into_main_grad: bool = accumulate_into_main_grad + + # Userbuffers options + self._userbuffers_options: Optional[dict[str, Any]] = userbuffers_options @classmethod def _canonicalize_tensor_parallelism( @@ -707,7 +715,7 @@ def _functional_backward( FP8 metadata for casting loss gradient w.r.t. output tensor to FP8. Required if output grad is not already in FP8. - grad_output_fp8_meta: dict, optional + grad_input_fp8_meta: dict, optional FP8 metadata for casting loss gradient w.r.t. input tensor to FP8 diff --git a/transformer_engine/pytorch/ops/fused/__init__.py b/transformer_engine/pytorch/ops/fused/__init__.py index bd832254d8..08b9f06123 100644 --- a/transformer_engine/pytorch/ops/fused/__init__.py +++ b/transformer_engine/pytorch/ops/fused/__init__.py @@ -16,3 +16,11 @@ ForwardLinearBiasAdd, fuse_forward_linear_bias_add, ) +from .userbuffers_backward_linear import ( + UserbuffersBackwardLinear, + fuse_userbuffers_backward_linear, +) +from .userbuffers_forward_linear import ( + UserbuffersForwardLinear, + fuse_userbuffers_forward_linear, +) diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py new file mode 100644 index 0000000000..907cff1c81 --- /dev/null +++ b/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py @@ -0,0 +1,781 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Linear layer backward with Userbuffers communication.""" + +from __future__ import annotations +from collections.abc import Iterable +from typing import Any, Optional +import warnings + +import torch + +from transformer_engine_torch import CommOverlapAlgo +from ...cpp_extensions import ( + fp8_cast_transpose_bgrad_fused, + fp8_gemm, + gemm, +) +from ...distributed import get_distributed_world_size +from ...float8_tensor import Float8Tensor +from ...fp8 import FP8GlobalStateManager, get_fp8_te_dtype +from ...module.base import get_ub, get_workspace +from ...utils import canonicalize_device, canonicalize_dtype, clear_tensor_data +from ..basic import BasicLinear, Bias, ReduceScatter +from ..op import FusedOperation, FusibleOperation, OperationContext +from .._common import ( + convert_tensor, + get_fp8_meta_from_fp8_tensor, + is_float8_tensor, + reshape, +) + + +class UserbuffersBackwardLinear(FusedOperation): + """Linear backward implementation using Userbuffers + + This operation is equivalent to a linear operation's backward + pass, but it uses Userbuffers to overlap tensor-parallel + communication with compute. 
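+
+    The Userbuffers workspaces are looked up by layer name via get_ub
+    (e.g. "qkv_dgrad", "qkv_wgrad"), so they are expected to have been
+    initialized with initialize_ub before this operation is executed.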
+ + """ + + def __init__( + self, + *, + linear: BasicLinear, + bias: Optional[Bias], + reduce_scatter: Optional[ReduceScatter], + ) -> None: + + # Basic operations that comprise this fused operation + op_idxs = {"linear": None, "bias": None, "reduce_scatter": None} + ops = [] + if reduce_scatter is not None: + op_idxs["reduce_scatter"] = len(ops) + ops.append(reduce_scatter) + if bias is not None: + op_idxs["bias"] = len(ops) + ops.append(bias) + op_idxs["linear"] = len(ops) + ops.append(linear) + + # Initialize base class + super().__init__(ops) + + # Index of each basic operations + self._op_idxs: dict[str, Optional[int]] = op_idxs + + # Tensor parallelism configuration + self.tensor_parallel_mode: Optional[str] + self.tensor_parallel_group: Optional[torch.distributed.ProcessGroup] + self.tensor_parallel_size: int + self.sequence_parallel: bool + if reduce_scatter is None: + self.tensor_parallel_mode = linear.tensor_parallel_mode + self.tensor_parallel_group = linear.tensor_parallel_group + self.tensor_parallel_size = linear.tensor_parallel_size + self.sequence_parallel = linear.sequence_parallel + else: + self.tensor_parallel_mode = "row" + self.tensor_parallel_group = reduce_scatter.process_group + self.tensor_parallel_size = reduce_scatter.process_group_size + self.sequence_parallel = True + + @staticmethod + def _functional_backward( + grad_output: torch.Tensor, + input: Optional[torch.Tensor], # pylint: disable=redefined-builtin + weight: Optional[torch.Tensor], + input_dims: Iterable[int], + weight_dims: Iterable[int], + *, + weight_requires_grad: bool = True, + bias_requires_grad: bool = False, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + grad_weight: Optional[torch.Tensor] = None, + accumulate_into_grad_weight: bool = False, + tensor_parallel_mode: Optional[str] = None, + tensor_parallel_group: Optional[torch.distributed.ProcessGroup] = None, + tensor_parallel_size: Optional[int] = None, + sequence_parallel: bool = False, + with_fp8_compute: bool = False, + input_fp8_meta: Optional[dict[str, Any]] = None, + weight_fp8_meta: Optional[dict[str, Any]] = None, + grad_output_fp8_meta: Optional[dict[str, Any]] = None, + grad_input_fp8_meta: Optional[dict[str, Any]] = None, + ub_comm_name: str, + ) -> tuple[torch.Tensor, Optional[torch.Tensor], dict]: + """Functional API for backward pass + + Parameters + ---------- + grad_output: torch.Tensor + Loss gradient w.r.t. output tensor + input: torch.Tensor, optional + Input tensor. Required to compute loss gradient w.r.t. + weight. + weight: torch.Tensor, optional + Weight tensor. Required to compute loss gradient w.r.t. + input. + input_dims: iterable of int + Input tensor dimensions + weight_dims: iterable of int + Weight tensor dimensions + weight_requires_grad: bool + Whether to compute loss gradient w.r.t. weight tensor + bias_requires_grad: bool + Whether to compute loss gradient w.r.t. bias tensor + device: torch.device, default = default CUDA device + Tensor device + dtype: torch.dtype, default = default dtype + Tensor datatype + grad_weight: torch.Tensor, optional + Loss gradient w.r.t. 
weight tensor + accumulate_into_grad_weight: bool, default = `False` + Add result to weight grad instead of overwriting + tensor_parallel_mode: {`None`, "column", "row"}, default = `None` + Mode for tensor parallelism + tensor_parallel_group: torch.distributed.ProcessGroup, default = world group + Process group for tensor parallelism + sequence_parallel: bool, default = `False` + Whether to apply sequence parallelism together with tensor + parallelism, i.e. distributing input or output tensors + along outer dimension (sequence or batch dim) when not + distributing along inner dimension (embedding dim) + with_fp8_compute: bool, default = `False` + Whether to perform compute in FP8 + input_fp8_meta: dict, optional + FP8 metadata for casting input tensor to FP8. Required for + FP8 compute if input is not already in FP8. + weight_fp8_meta: dict, optional + FP8 metadata for casting weight tensor to FP8. Required for + FP8 compute if weight is not already in FP8. + grad_output_fp8_meta: dict, optional + FP8 metadata for casting loss gradient w.r.t. output + tensor to FP8. Required if output grad is not already in + FP8. + grad_input_fp8_meta: dict, optional + FP8 metadata for casting loss gradient w.r.t. input + tensor to FP8 + ub_comm_name: str + Layer type (e.g. "qkv", "proj", "fc1", "fc2"). This is + used to access the corresponding Userbuffers communicators + (e.g. "qkv_dgrad", "qkv_wgrad"). + + Returns + ------- + torch.Tensor + Loss gradient w.r.t. input tensor + torch.Tensor + Loss gradient w.r.t. weight tensor + dict + Extra output tensors. "grad_bias" is loss gradient w.r.t. + the bias tensor. + + """ + + # Configuration-specific outputs + extra_outputs = {} + + # Check device + if device is None: + device = weight.device + device = canonicalize_device(device) + if device.type != "cuda": + raise ValueError(f"Only CUDA devices are supported (got {device})") + + # Check datatype + if dtype is None: + dtype = weight.dtype + dtype = canonicalize_dtype(dtype) + if dtype not in (torch.float32, torch.float16, torch.bfloat16): + raise ValueError(f"Supported dtypes are float32, float16, bfloat16 (got {dtype})") + + # Input tensor dims + output_dims = tuple(grad_output.size()) + input_dims = tuple(input_dims) + weight_dims = tuple(weight_dims) + if len(weight_dims) != 2: + raise ValueError(f"Weight tensor is not 2D (shape={weight_dims})") + if len(input_dims) == 0 or weight_dims[1] != input_dims[-1]: + raise ValueError( + f"Input tensor (shape={input_dims}) " + f"and weight tensor (shape={weight_dims}) " + "are not compatible" + ) + if weight_dims[0] != output_dims[-1]: + raise ValueError( + f"Grad output tensor (shape={output_dims}) " + f"and weight tensor (shape={weight_dims}) " + "are not compatible" + ) + + # Check tensor parallel group + if tensor_parallel_size is None: + tensor_parallel_size = get_distributed_world_size(tensor_parallel_group) + if tensor_parallel_size == 1: + tensor_parallel_mode = None + if tensor_parallel_mode not in ("column", "row"): + raise RuntimeError( + "Invalid configuration for Userbuffers " + f"({tensor_parallel_size=}, {tensor_parallel_mode=})" + ) + if not sequence_parallel: + raise RuntimeError(f"Invalid configuration for Userbuffers ({sequence_parallel=})") + + # Check if FP8 is enabled + if with_fp8_compute: + if grad_output_fp8_meta is None and not is_float8_tensor(grad_output): + raise ValueError("No FP8 metadata was provided for casting output gradient to FP8") + else: + input_fp8_meta = None + weight_fp8_meta = None + grad_output_fp8_meta = None + 
grad_input_fp8_meta = None + with_fp8_grad_input = ( + with_fp8_compute + and tensor_parallel_mode != "column" + and grad_input_fp8_meta is not None + ) + + # Get Userbuffers communicators and algorithms + # Note: communication patterns are (1) overlap dy all-gather + # with dgrad GEMM, (2) overlap x all-gather with dgrad GEMM + # and dx reduce-scatter with wgrad GEMM, (3) overlap dx + # reduce-scatter with dgrad GEMM. + with_ub_all_gather_dy = False + with_ub_reduce_scatter_dx = False + with_ub_all_gather_x = False + ub_comm_dy = None + ub_comm_dx = None + ub_comm_x = None + ub_algo_dy = None + ub_algo_dx = None + ub_algo_x = None + if tensor_parallel_mode == "row": + with_ub_all_gather_dy = True + ub_comm_dy = get_ub(ub_comm_name + "_dgrad") + if with_fp8_compute and ub_comm_dy.is_atomic_gemm(): + ub_algo_dy = CommOverlapAlgo.ATOMIC_GEMM_AG_P2P + else: + ub_algo_dy = CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P + elif tensor_parallel_mode == "column": + with_ub_reduce_scatter_dx = True + if weight_requires_grad: + with_ub_all_gather_x = True + ub_comm_dx = get_ub(ub_comm_name + "_wgrad") + ub_comm_x = get_ub(ub_comm_name + "_dgrad") + ub_algo_dx = CommOverlapAlgo.BULK_OVERLAP_RS + ub_algo_x = CommOverlapAlgo.BULK_OVERLAP_AG + else: + with_ub_all_gather_x = False + ub_comm_dx = get_ub(ub_comm_name + "_dgrad") + is_atomic_gemm = with_fp8_compute and ub_comm_dx.is_atomic_gemm() + ub_algo_dx = { + (True, True): CommOverlapAlgo.ATOMIC_GEMM_RS_P2P, + (True, False): CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P, + (False, True): CommOverlapAlgo.ATOMIC_GEMM_RS, + (False, False): CommOverlapAlgo.SPLIT_PIPELINED_RS, + }[(ub_comm_dx.is_p2p_overlap(), is_atomic_gemm)] + + # Check grad output tensor + # Note: Possibly fuse cast with computing grad bias + dy_local = reshape( + grad_output, + (-1, output_dims[-1]), + device=device, + dtype=dtype, + ) + db = None + db_async = None + if bias_requires_grad and with_fp8_compute and with_ub_all_gather_dy: + # We don't have a grad bias impl that takes FP8 input. For + # cases where we cast to FP8 and all-gather, it's better + # to compute the grad bias on ungathered, non-FP8 values. 
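+            # The bias grad all-reduce is launched asynchronously over
+            # the tensor-parallel group so it can overlap with the rest
+            # of the backward pass; it is waited on before returning.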
+ db = dy_local.sum(dim=0) + db_async = torch.distributed.all_reduce( + db, + group=tensor_parallel_group, + async_op=True, + ) + if with_fp8_compute and not is_float8_tensor(dy_local): + fp8_dtype = get_fp8_te_dtype( + grad_output_fp8_meta["recipe"], + fprop_tensor=False, + ) + if bias_requires_grad and db is None: + # Fused cast-transpose-bgrad + fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(forward=False) + fp8_scale_inv = torch.empty([1], dtype=torch.float32, device=device) + db, data, data_transpose = fp8_cast_transpose_bgrad_fused( + dy_local, + grad_output_fp8_meta[fp8_meta_key], + 0, + fp8_dtype, + scale_inv=fp8_scale_inv, + ) + if with_ub_all_gather_dy: + data = ub_comm_dy.get_ubuf_output(0).copy_(data) + dy_local = Float8Tensor( + data=data, + fp8_meta=grad_output_fp8_meta, + fp8_meta_forward=False, + fp8_meta_index=0, + fp8_dtype=fp8_dtype, + fp8_scale_inv=fp8_scale_inv, + dtype=dtype, + data_transpose=data_transpose, + ) + else: + dy_local = Float8Tensor.to_float8( + dy_local, + fp8_meta=grad_output_fp8_meta, + fp8_meta_forward=False, + fp8_meta_index=0, + fp8_dtype=fp8_dtype, + data=(ub_comm_dy.get_ubuf_output(0) if with_ub_all_gather_dy else None), + with_transpose_cache=(not with_ub_all_gather_dy), + ) + elif not with_fp8_compute and is_float8_tensor(dy_local): + if with_ub_all_gather_dy: + ub_local_buffer = ub_comm_dy.get_ubuf_output(0) + dy_local = ub_local_buffer.copy_(dy_local) + else: + dy_local = dy_local.dequantize() + + if bias_requires_grad and db is None and with_fp8_compute and with_ub_all_gather_dy: + # We don't have a fused grad bias impl that takes FP8 + # input. For cases where we cast to FP8 and all-gather, + # it's better to compute the grad bias on ungathered, + # non-FP8 values. + db = dy_local.sum(dim=0) + db_async = torch.distributed.all_reduce( + db, + group=tensor_parallel_group, + async_op=True, + ) + + # Check input tensor + x_local = None + if weight_requires_grad: + x_local = reshape( + input, + (-1, input_dims[-1]), + device=device, + dtype=dtype, + ) + if with_fp8_compute and not is_float8_tensor(x_local): + fp8_dtype = get_fp8_te_dtype( + input_fp8_meta["recipe"], + fprop_tensor=True, + ) + x_local = Float8Tensor.to_float8( + x_local, + fp8_meta=input_fp8_meta, + fp8_meta_forward=True, + fp8_meta_index=0, + fp8_dtype=fp8_dtype, + data=(ub_comm_x.get_ubuf_output(0) if with_ub_all_gather_x else None), + with_transpose_cache=(not with_ub_all_gather_x), + ) + elif not with_fp8_compute and is_float8_tensor(x_local): + if with_ub_all_gather_x: + ub_local_buffer = ub_comm_x.get_ubuf_output(0) + x_local = ub_local_buffer.copy_(x_local) + else: + x_local = x_local.dequantize() + + # Check weight tensor + w = convert_tensor( + weight, + device=device, + dtype=dtype, + memory_format=torch.contiguous_format, + ) + if with_fp8_compute and not is_float8_tensor(w): + fp8_dtype = get_fp8_te_dtype( + weight_fp8_meta["recipe"], + fprop_tensor=True, + ) + w = Float8Tensor.to_float8( + w, + fp8_meta=weight_fp8_meta, + fp8_meta_forward=True, + fp8_meta_index=0, + fp8_dtype=fp8_dtype, + with_transpose_cache=True, + ) + elif not with_fp8_compute and is_float8_tensor(w): + w = w.dequantize() + + # Initialize buffers for UB all-gather if needed + dy = dy_local + x = x_local + if with_ub_all_gather_dy: + ub_local_buffer = ub_comm_dy.get_ubuf_output(0) + ub_global_buffer = ub_comm_dy.get_ubuf_output(1) + if with_fp8_compute: + dy = Float8Tensor.make_like(dy_local, data=ub_global_buffer) + if dy_local._data.data_ptr() != ub_local_buffer.data_ptr(): + 
ub_local_buffer.copy_(dy_local._data) + else: + dy = ub_global_buffer + if dy_local.data_ptr() != ub_local_buffer.data_ptr(): + ub_local_buffer.copy_(dy_local) + if with_ub_all_gather_x: + ub_local_buffer = ub_comm_x.get_ubuf_output(0) + ub_global_buffer = ub_comm_x.get_ubuf_output(1) + if with_fp8_compute: + x = Float8Tensor.make_like(x_local, data=ub_global_buffer) + if x_local._data.data_ptr() != ub_local_buffer.data_ptr(): + ub_local_buffer.copy_(x_local._data) + else: + x = ub_global_buffer + if x_local.data_ptr() != ub_local_buffer.data_ptr(): + ub_local_buffer.copy_(x_local) + + # Construct grad input tensor + dx = None + dx_local = None + if with_ub_reduce_scatter_dx: + # Initialize buffers for UB reduce-scatter + dx = ub_comm_dx.get_ubuf_output(1) + ub_local_buffer = ub_comm_dx.get_ubuf_output(0) + if with_ub_all_gather_x: + dx_local = ub_local_buffer + else: + dx_local = torch.empty_like(ub_local_buffer) + else: + # Allocate grad input tensor + if with_fp8_grad_input: + fp8_dtype = get_fp8_te_dtype( + grad_input_fp8_meta["recipe"], + fprop_tensor=False, + ) + data = torch.empty( + (dy.size(0), w.size(-1)), + dtype=torch.uint8, + device=device, + ) + dx = Float8Tensor( + data=data, + fp8_meta=grad_input_fp8_meta, + fp8_meta_forward=False, + fp8_meta_index=0, + fp8_dtype=fp8_dtype, + dtype=dtype, + ) + else: + dx = torch.empty( + (dy.size(0), w.size(-1)), + dtype=dtype, + device=device, + ) + dx_local = dx + + # Allocate grad input tensor + if grad_weight is None: + if accumulate_into_grad_weight: + raise ValueError( + "Attempted to accumulate into grad weight bufferwithout providing grad weight" + ) + grad_weight = torch.empty( + weight_dims, + dtype=dtype, + device=device, + memory_format=torch.contiguous_format, + ) + + # Perform dgrad GEMM + if with_fp8_compute: + kwargs = {"out": dx, "use_split_accumulator": True} + if with_ub_all_gather_dy: + kwargs["ub_algo"] = ub_algo_dy + kwargs["ub"] = ub_comm_dy + elif with_ub_all_gather_x: + kwargs["ub_algo"] = ub_algo_x + kwargs["ub"] = ub_comm_x + elif with_ub_reduce_scatter_dx: + kwargs["ub_algo"] = ub_algo_dx + kwargs["ub"] = ub_comm_dx + kwargs["extra_output_tensor"] = dx_local + if with_fp8_grad_input: + fp8_meta, fp8_meta_index = get_fp8_meta_from_fp8_tensor(dx) + kwargs.update( + { + "out": dx._data, + "out_index": fp8_meta_index, + "fp8_meta_tensor": fp8_meta, + "D_dtype": dx._fp8_dtype, + } + ) + fp8_gemm( + w.transpose_2d(), + w._scale_inv, + 0, + w._fp8_dtype, + dy._data, + dy._scale_inv, + 0, + dy._fp8_dtype, + dy.dtype, + get_workspace(), + **kwargs, + ) + else: + kwargs = {"grad": True, "layout": "NN", "out": dx} + if with_ub_all_gather_dy: + kwargs["ub_algo"] = ub_algo_dy + kwargs["ub"] = ub_comm_dy + elif with_ub_all_gather_x: + kwargs["ub_algo"] = ub_algo_x + kwargs["ub"] = ub_comm_x + elif with_ub_reduce_scatter_dx: + kwargs["ub_algo"] = ub_algo_dx + kwargs["ub"] = ub_comm_dx + kwargs["extra_output_tensor"] = dx_local + gemm(w, dy, dx.dtype, get_workspace(), **kwargs) + grad_input = reshape(dx_local, input_dims) + + # Perform wgrad GEMM + if not weight_requires_grad: + pass + elif with_fp8_compute: + kwargs = { + "accumulate": accumulate_into_grad_weight, + "out": grad_weight, + "use_split_accumulator": True, + } + if with_ub_reduce_scatter_dx: + kwargs["ub_algo"] = ub_algo_dx + kwargs["ub"] = ub_comm_dx + fp8_gemm( + x.transpose_2d(), + x._scale_inv, + 0, + x._fp8_dtype, + dy.transpose_2d(), + dy._scale_inv, + 0, + dy._fp8_dtype, + grad_weight.dtype, + get_workspace(), + **kwargs, + ) + else: + kwargs = { + 
"accumulate": accumulate_into_grad_weight, + "layout": "NT", + "grad": True, + "use_bias": bias_requires_grad, + "out": grad_weight, + } + if with_ub_reduce_scatter_dx: + kwargs["ub_algo"] = ub_algo_dx + kwargs["ub"] = ub_comm_dx + grad_weight, db, _ = gemm( + x, + dy, + grad_weight.dtype, + get_workspace(), + **kwargs, + ) + + # Compute grad bias if needed + if db_async is not None: + db_async.wait() + if bias_requires_grad: + if db is None: + db = dy.sum(dim=0) + extra_outputs["grad_bias"] = db + + return grad_input, grad_weight, extra_outputs + + def fuser_backward( + self, + basic_op_ctxs: list[OperationContext], + grad_output: torch.Tensor, + *, + basic_op_grad_extra_outputs: list[tuple[torch.Tensor, ...]], + ) -> tuple[ + torch.Tensor, + list[tuple[Optional[torch.Tensor], ...]], + list[tuple[()]], + ]: + + # Get basic operations + idx = self._op_idxs["linear"] + linear_op = self.basic_ops[idx] + linear_op_ctx = basic_op_ctxs[idx] + bias_op = None + if self._op_idxs["bias"] is not None: + idx = self._op_idxs["bias"] + bias_op = self.basic_ops[idx] + + # Saved tensors from forward pass + (x_local,) = linear_op_ctx.saved_tensors + + # wgrad fusion + accumulate_into_main_grad = linear_op._accumulate_into_main_grad + grad_weight = None + if linear_op_ctx.weight_requires_grad and accumulate_into_main_grad: + if not hasattr(linear_op.weight, "main_grad"): + raise RuntimeError( + "BasicLinear op is configured with " + "accumulate_into_main_grad=True, " + "but weight parameter does not have main_grad attribute" + ) + grad_weight = linear_op.weight.main_grad.detach() + else: + accumulate_into_main_grad = False + + # Hackily workaround Userbuffers bug with non-FP8 dgrad + # reduce-scatter overlap + weight_requires_grad = linear_op_ctx.weight_requires_grad + if not linear_op_ctx.with_fp8_compute and not weight_requires_grad: + warnings.warn( + "There is a correctness bug when using Userbuffers " + "to overlap a dgrad reduce-scatter with a non-FP8 dgrad GEMM. " + "Hackily working around by overlapping dgrad reduce-scatter " + "with wgrad GEMM, even though wgrad isn't needed. " + "Please contact Transformer Engine team " + "if you encounter this use-case." 
+ ) + weight_requires_grad = True + + # Linear backward pass + retval = UserbuffersBackwardLinear._functional_backward( + grad_output=grad_output, + input=x_local, + weight=linear_op.weight, + input_dims=linear_op_ctx.input_dims, + weight_dims=linear_op.weight.size(), + weight_requires_grad=weight_requires_grad, + bias_requires_grad=(bias_op is not None), + device=linear_op.device, + dtype=linear_op_ctx.dtype, + grad_weight=grad_weight, + accumulate_into_grad_weight=accumulate_into_main_grad, + tensor_parallel_mode=self.tensor_parallel_mode, + tensor_parallel_group=self.tensor_parallel_group, + sequence_parallel=self.sequence_parallel, + with_fp8_compute=linear_op_ctx.with_fp8_compute, + weight_fp8_meta=linear_op_ctx.weight_fp8_meta, + grad_output_fp8_meta=linear_op_ctx.grad_output_fp8_meta, + grad_input_fp8_meta=linear_op_ctx.grad_input_fp8_meta, + ub_comm_name=linear_op._userbuffers_options["comm_name"], + ) + grad_input, grad_weight, extra_outputs = retval + grad_bias = None + if bias_op is not None: + grad_bias = extra_outputs["grad_bias"] + + # Clear input tensor if possible + if linear_op_ctx.has_prev_op: + clear_tensor_data(x_local) + + # Return gradients + grad_params = [() for _ in range(len(self.basic_ops))] + if accumulate_into_main_grad: + grad_weight = None + grad_params[self._op_idxs["linear"]] = (grad_weight,) + if bias_op is not None: + grad_params[self._op_idxs["bias"]] = (grad_bias,) + grad_extra_inputs = [() for _ in range(len(self.basic_ops))] + return grad_input, grad_params, grad_extra_inputs + + +def fuse_userbuffers_backward_linear( + ops: list[tuple[FusibleOperation, list[int]]], +) -> list[tuple[FusibleOperation, list[int]]]: + """Substitute linear operations with Userbuffers implementation + + Parameters + ---------- + ops: list of tuples + Forward pass operations and the indices of the corresponding + basic operations. 
+ + Returns + ------- + ops: list of tuples + Updated forward pass operations + + """ + + # Return immediately if environment is not distributed + if not torch.distributed.is_initialized() or torch.distributed.get_world_size() == 1: + return ops + + # Sliding window in list of ops + window = [] + + def peek_next_op() -> Optional[FusibleOperation]: + """Get next op in list of ops""" + nonlocal ops + if not ops: + return None + return ops[-1][0] + + def pop_next_op() -> FusibleOperation: + """Remove next op from list of ops and add to sliding window""" + nonlocal ops, window + window.insert(0, ops[-1]) + ops = ops[:-1] + return window[0][0] + + # Scan through ops in reverse order, fusing if possible + out_reversed = [] + while ops: + out_reversed.extend(reversed(window)) + window.clear() + + # Check if next op is linear + next_op = pop_next_op() + if not isinstance(next_op, BasicLinear): + continue + linear = next_op + if linear._userbuffers_options is None: + continue + + # Check if next op is bias + bias = None + if linear.tensor_parallel_mode != "row" and isinstance(peek_next_op(), Bias): + bias = pop_next_op() + + # Check if next op is reduce-scatter + reduce_scatter = None + if linear.tensor_parallel_mode is None and isinstance(peek_next_op(), ReduceScatter): + reduce_scatter = pop_next_op() + + # Check for invalid combinations + if reduce_scatter is None: + if linear.tensor_parallel_mode is None: + continue + if linear.tensor_parallel_size == 1: + continue + if linear.tensor_parallel_mode == "row" and bias is not None: + continue + else: + if linear.tensor_parallel_mode is not None: + continue + if reduce_scatter.process_group_size == 1: + continue + + # Replace window with fused op + op = UserbuffersBackwardLinear( + linear=linear, + bias=bias, + reduce_scatter=reduce_scatter, + ) + basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window] + window = [(op, basic_op_idxs)] + + # Return list of ops + out_reversed.extend(reversed(window)) + out = out_reversed + out.reverse() + return out diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py new file mode 100644 index 0000000000..a1b0ca6a9e --- /dev/null +++ b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py @@ -0,0 +1,597 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Linear layer forward with Userbuffers communication.""" + +from __future__ import annotations +from collections.abc import Iterable +from typing import Any, Optional + +import torch + +from transformer_engine_torch import CommOverlapAlgo +from ...cpp_extensions import fp8_gemm, gemm +from ...distributed import get_distributed_world_size +from ...float8_tensor import Float8Tensor +from ...fp8 import FP8GlobalStateManager, get_fp8_te_dtype +from ...module.base import get_ub, get_workspace +from ...utils import canonicalize_device, canonicalize_dtype +from ..basic import BasicLinear, Bias, ReduceScatter +from ..op import ( + BasicOperation, + FusedOperation, + FusibleOperation, + OperationContext, +) +from .._common import ( + convert_tensor, + get_fp8_meta_from_fp8_tensor, + is_float8_tensor, + reshape, +) + + +class UserbuffersForwardLinear(FusedOperation): + """Linear forward implementation using Userbuffers + + This operation is equivalent to a linear operation's forward pass, + but it uses Userbuffers to overlap tensor-parallel communication + with compute. 
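+
+    The Userbuffers workspace is looked up by layer name via get_ub
+    (e.g. "qkv_fprop"), so it is expected to have been initialized with
+    initialize_ub before this operation is executed.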
+ + """ + + def __init__( + self, + *, + linear: BasicLinear, + bias: Optional[Bias], + reduce_scatter: Optional[ReduceScatter], + ) -> None: + + # Basic operations that comprise this fused operation + op_idxs = {"linear": 0, "bias": None, "reduce_scatter": None} + ops = [linear] + if bias is not None: + op_idxs["bias"] = len(ops) + ops.append(bias) + if reduce_scatter is not None: + op_idxs["reduce_scatter"] = len(ops) + ops.append(reduce_scatter) + + # Initialize base class + super().__init__(ops) + + # Index of each basic operations + self._op_idxs: dict[str, Optional[int]] = op_idxs + + # Tensor parallelism configuration + self.tensor_parallel_mode: Optional[str] + self.tensor_parallel_group: Optional[torch.distributed.ProcessGroup] + self.tensor_parallel_size: int + self.sequence_parallel: bool + if reduce_scatter is None: + self.tensor_parallel_mode = linear.tensor_parallel_mode + self.tensor_parallel_group = linear.tensor_parallel_group + self.tensor_parallel_size = linear.tensor_parallel_size + self.sequence_parallel = linear.sequence_parallel + else: + self.tensor_parallel_mode = "row" + self.tensor_parallel_group = reduce_scatter.process_group + self.tensor_parallel_size = reduce_scatter.process_group_size + self.sequence_parallel = True + + @staticmethod + def _functional_forward( + input: torch.Tensor, # pylint: disable=redefined-builtin + weight: torch.Tensor, + *, + bias: Optional[torch.Tensor] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + tensor_parallel_mode: Optional[str] = None, + tensor_parallel_group: Optional[torch.distributed.ProcessGroup] = None, + tensor_parallel_size: Optional[int] = None, + sequence_parallel: bool = False, + with_fp8_compute: bool = False, + input_fp8_meta: Optional[dict[str, Any]] = None, + weight_fp8_meta: Optional[dict[str, Any]] = None, + output_fp8_meta: Optional[dict[str, Any]] = None, + ub_comm_name: str, + ) -> tuple[torch.Tensor, dict]: + """Functional API for forward pass + + Parameters + ---------- + input: torch.Tensor + Input tensor + weight: torch.Tensor + Weight tensor + bias: torch.Tensor, optional + Bias tensor + device: torch.device, default = default CUDA device + Tensor device + dtype: torch.dtype, default = default dtype + Tensor datatype + tensor_parallel_mode: {`None`, "column", "row"}, default = `None` + Mode for tensor parallelism + tensor_parallel_group: torch.distributed.ProcessGroup, default = world group + Process group for tensor parallelism + sequence_parallel: bool, default = `False` + Whether to apply sequence parallelism together with tensor + parallelism, i.e. distributing input or output tensors + along outer dimension (sequence or batch dim) when not + distributing along inner dimension (embedding dim) + with_fp8_compute: bool, default = `False` + Whether to perform compute in FP8 + input_fp8_meta: dict, optional + FP8 metadata for casting input tensor to FP8. Required for + FP8 compute if input is not already in FP8. + weight_fp8_meta: dict, optional + FP8 metadata for casting weight tensor to FP8. Required for + FP8 compute if weight is not already in FP8. + output_fp8_meta: dict, optional + FP8 metadata for casting output tensor to FP8 + ub_comm_name: str + Layer type (e.g. "qkv", "proj", "fc1", "fc2"). This is + used to access the corresponding Userbuffers communicators + (e.g. "qkv_fprop"). + + Returns + ------- + torch.Tensor + Output tensor + dict + Extra output tensors. "input" is the input tensor, + possibly cast and reshaped from the provided input tensor. 
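+            "weight" is the weight tensor, similarly cast to the
+            dtype (or FP8 representation) used for compute.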
+ + """ + + # Check device + if device is None: + device = weight.device + device = canonicalize_device(device) + if device.type != "cuda": + raise ValueError(f"Only CUDA devices are supported (got {device})") + + # Check datatype + if dtype is None: + dtype = weight.dtype + dtype = canonicalize_dtype(dtype) + if dtype not in (torch.float32, torch.float16, torch.bfloat16): + raise ValueError(f"Supported dtypes are float32, float16, bfloat16 (got {dtype})") + + # Input tensor dims + input_dims = tuple(input.size()) + weight_dims = tuple(weight.size()) + if len(weight_dims) != 2: + raise ValueError(f"Weight tensor is not 2D (shape={weight_dims})") + if len(input_dims) == 0 or weight_dims[1] != input_dims[-1]: + raise ValueError( + f"Input tensor (shape={input_dims}) " + f"and weight tensor (shape={weight_dims}) " + "are not compatible" + ) + + # Output tensor dims + output_dims = list(input_dims) + output_dims[0] = -1 + output_dims[-1] = weight_dims[0] + + # Check tensor parallel group + if tensor_parallel_size is None: + tensor_parallel_size = get_distributed_world_size(tensor_parallel_group) + if tensor_parallel_size == 1: + tensor_parallel_mode = None + if tensor_parallel_mode not in ("column", "row"): + raise RuntimeError( + "Invalid configuration for Userbuffers " + f"({tensor_parallel_size=}, {tensor_parallel_mode=})" + ) + if not sequence_parallel: + raise RuntimeError(f"Invalid configuration for Userbuffers ({sequence_parallel=})") + + # Check if FP8 is enabled + if with_fp8_compute: + if input_fp8_meta is None and not is_float8_tensor(input): + raise ValueError("No FP8 metadata was provided for casting input to FP8") + if weight_fp8_meta is None and not is_float8_tensor(weight): + raise ValueError("No FP8 metadata was provided for casting weight to FP8") + else: + input_fp8_meta = None + weight_fp8_meta = None + output_fp8_meta = None + with_fp8_output = ( + with_fp8_compute and tensor_parallel_mode != "row" and output_fp8_meta is not None + ) + + # Get Userbuffers communicator + ub_comm = get_ub(ub_comm_name + "_fprop") + ub_local_buffer = ub_comm.get_ubuf_output(0) + ub_global_buffer = ub_comm.get_ubuf_output(1) + with_ub_all_gather = tensor_parallel_mode == "column" + with_ub_reduce_scatter = tensor_parallel_mode == "row" + + # Choose Userbuffers communication algorithm + ub_algo = None + if with_ub_all_gather: + if with_fp8_compute and ub_comm.is_atomic_gemm(): + ub_algo = CommOverlapAlgo.ATOMIC_GEMM_AG_P2P + else: + ub_algo = CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P + elif with_ub_reduce_scatter: + is_atomic_gemm = with_fp8_compute and ub_comm.is_atomic_gemm() + ub_algo = { + (True, True): CommOverlapAlgo.ATOMIC_GEMM_RS_P2P, + (True, False): CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P, + (False, True): CommOverlapAlgo.ATOMIC_GEMM_RS, + (False, False): CommOverlapAlgo.SPLIT_PIPELINED_RS, + }[(ub_comm.is_p2p_overlap(), is_atomic_gemm)] + else: + raise RuntimeError("Could not choose Userbuffers communication algorithm") + + # Cast input tensor to correct dtype + x_local = reshape( + input, + (-1, input_dims[-1]), + device=device, + dtype=dtype, + ) + if with_fp8_compute and not is_float8_tensor(x_local): + fp8_dtype = get_fp8_te_dtype( + input_fp8_meta["recipe"], + fprop_tensor=True, + ) + with_transpose_cache = weight.requires_grad + if tensor_parallel_mode == "column" and sequence_parallel: + with_transpose_cache = False + x_local = Float8Tensor.to_float8( + x_local, + fp8_meta=input_fp8_meta, + fp8_meta_forward=True, + fp8_meta_index=0, + fp8_dtype=fp8_dtype, + data=(ub_local_buffer 
if with_ub_all_gather else None), + with_transpose_cache=with_transpose_cache, + ) + elif not with_fp8_compute and is_float8_tensor(x_local): + if with_ub_all_gather: + x_local = ub_local_buffer.copy_(x_local) + else: + x_local = x_local.dequantize() + + # Initialize buffers for UB all-gather if needed + x = x_local + if with_ub_all_gather: + if with_fp8_compute: + x = Float8Tensor.make_like(x_local, data=ub_global_buffer) + if x_local._data.data_ptr() != ub_local_buffer.data_ptr(): + ub_local_buffer.copy_(x_local._data) + else: + x_local._data = torch.empty_like(x_local._data) + else: + x = ub_global_buffer + if x_local.data_ptr() != ub_local_buffer.data_ptr(): + ub_local_buffer.copy_(x_local) + else: + x_local = torch.empty_like(x_local) + + # Check weight tensor + w = convert_tensor( + weight, + device=device, + dtype=dtype, + memory_format=torch.contiguous_format, + ) + if with_fp8_compute and not is_float8_tensor(w): + fp8_dtype = get_fp8_te_dtype( + weight_fp8_meta["recipe"], + fprop_tensor=True, + ) + w = Float8Tensor.to_float8( + w, + fp8_meta=weight_fp8_meta, + fp8_meta_index=0, + fp8_dtype=fp8_dtype, + ) + elif not with_fp8_compute and is_float8_tensor(w): + w = w.dequantize() + + # Check bias tensor + b = None + if bias is not None: + b = convert_tensor( + bias, + device=device, + dtype=dtype, + memory_format=torch.contiguous_format, + ) + + # Construct output tensor + y = None + y_local = None + if with_ub_reduce_scatter: + # Initialize buffers for UB reduce-scatter + if with_fp8_output: + fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(forward=True) + fp8_dtype = get_fp8_te_dtype( + output_fp8_meta["recipe"], + fprop_tensor=True, + ) + y = Float8Tensor( + data=ub_global_buffer, + fp8_meta=output_fp8_meta, + fp8_meta_forward=True, + fp8_meta_index=0, + fp8_dtype=fp8_dtype, + fp8_scale_inv=output_fp8_meta[fp8_meta_key].scale_inv[0], + dtype=dtype, + ) + ub_comm.set_ubuf_scale_inv(y._scale_inv) + else: + y = ub_global_buffer + y_local = torch.empty( + (x.size(0) // tensor_parallel_size, weight_dims[0]), + dtype=dtype, + device=device, + ) + else: + # Allocate output tensor + if with_fp8_output: + fp8_dtype = get_fp8_te_dtype( + output_fp8_meta["recipe"], + fprop_tensor=True, + ) + data = torch.empty( + (x.size(0), weight_dims[0]), + dtype=torch.uint8, + device=device, + ) + y = Float8Tensor( + data=data, + fp8_meta=output_fp8_meta, + fp8_meta_forward=True, + fp8_meta_index=0, + fp8_dtype=fp8_dtype, + dtype=dtype, + ) + else: + y = torch.empty( + (x.size(0), weight_dims[0]), + dtype=dtype, + device=device, + ) + y_local = y + + # Perform GEMM + if with_fp8_compute: + kwargs = { + "out": y, + "bias": b, + "use_bias": (b is not None), + "use_split_accumulator": False, + "ub_algo": ub_algo, + "ub": ub_comm, + } + if with_ub_all_gather: + kwargs["extra_output_tensor"] = x_local._data + if with_ub_reduce_scatter: + kwargs["extra_output_tensor"] = y_local + if with_fp8_output: + fp8_meta, fp8_meta_index = get_fp8_meta_from_fp8_tensor(y) + kwargs.update( + { + "out": y._data, + "out_index": fp8_meta_index, + "fp8_meta_tensor": fp8_meta, + "D_dtype": y._fp8_dtype, + } + ) + fp8_gemm( + w._data, + w._scale_inv, + 0, + w._fp8_dtype, + x._data, + x._scale_inv, + 0, + x._fp8_dtype, + y.dtype, + get_workspace(), + **kwargs, + ) + else: + kwargs = { + "out": y, + "bias": b, + "use_bias": (b is not None), + "ub_algo": ub_algo, + "ub": ub_comm, + } + if with_ub_all_gather: + kwargs["extra_output_tensor"] = x_local + if with_ub_reduce_scatter: + kwargs["extra_output_tensor"] = y_local + 
gemm(w, x, y.dtype, get_workspace(), **kwargs) + + # Reshape output tensor + out = reshape(y_local, output_dims) + + # Return cast tensors + extra_outputs = {"input": x_local, "weight": w} + return out, extra_outputs + + def fuser_forward( + self, + basic_op_ctxs: list[OperationContext], + input_: torch.Tensor, + *, + basic_op_extra_inputs: list[tuple[torch.Tensor, ...]], + basic_op_prev_ops: list[Optional[BasicOperation]], + basic_op_next_ops: list[Optional[BasicOperation]], + basic_op_kwargs: list[dict[str, Any]], + ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]: + + # Get basic operations + idx = self._op_idxs["linear"] + linear_op = self.basic_ops[idx] + linear_op_ctx = basic_op_ctxs[idx] + bias_op = None + bias = None + if self._op_idxs["bias"] is not None: + idx = self._op_idxs["bias"] + bias_op = self.basic_ops[idx] + bias = bias_op.bias + if basic_op_kwargs[idx]: + raise ValueError("Bias operation forward does not expect keyword arguments") + + # FP8 metadata + with_fp8_compute = FP8GlobalStateManager.is_fp8_enabled() + input_fp8_meta = None + weight_fp8_meta = None + output_fp8_meta = None + grad_output_fp8_meta = None + grad_input_fp8_meta = None + if with_fp8_compute: + input_fp8_meta = linear_op.get_fp8_meta("input") + weight_fp8_meta = linear_op.get_fp8_meta("param") + next_op = basic_op_next_ops[-1] + if next_op is not None and next_op.num_fp8_scales("input") > 0: + output_fp8_meta = next_op.get_fp8_meta("input") + grad_output_fp8_meta = linear_op.get_fp8_meta("grad_output") + prev_op = basic_op_prev_ops[0] + if prev_op is not None and prev_op.num_fp8_scales("grad_output") > 0: + grad_input_fp8_meta = prev_op.get_fp8_meta("grad_output") + + # Get autocast dtype if needed + dtype = None + if torch.is_autocast_enabled(): + dtype = torch.get_autocast_dtype("cuda") + + # Userbuffers options + if linear_op._userbuffers_options is None: + raise RuntimeError("Linear op is missing dict for Userbuffers options") + + # Linear forward + output, extra_outputs = UserbuffersForwardLinear._functional_forward( + input=input_, + weight=linear_op.weight, + bias=bias, + device=linear_op.device, + dtype=dtype, + tensor_parallel_mode=self.tensor_parallel_mode, + tensor_parallel_group=self.tensor_parallel_group, + tensor_parallel_size=self.tensor_parallel_size, + sequence_parallel=self.sequence_parallel, + with_fp8_compute=with_fp8_compute, + input_fp8_meta=input_fp8_meta, + weight_fp8_meta=weight_fp8_meta, + output_fp8_meta=output_fp8_meta, + ub_comm_name=linear_op._userbuffers_options["comm_name"], + ) + x_local = extra_outputs["input"] + + # Save state for backward pass + linear_op_ctx.save_for_backward(x_local) + linear_op_ctx.with_fp8_compute = with_fp8_compute + linear_op_ctx.weight_fp8_meta = weight_fp8_meta + linear_op_ctx.grad_output_fp8_meta = grad_output_fp8_meta + linear_op_ctx.grad_input_fp8_meta = grad_input_fp8_meta + linear_op_ctx.dtype = dtype + linear_op_ctx.input_dims = input_.size() + linear_op_ctx.input_requires_grad = input_.requires_grad + linear_op_ctx.weight_requires_grad = linear_op.weight.requires_grad + linear_op_ctx.has_prev_op = basic_op_prev_ops[0] is not None + + return output, [() for _ in range(len(self.basic_ops))] + + +def fuse_userbuffers_forward_linear( + ops: list[tuple[FusibleOperation, list[int]]], +) -> list[tuple[FusibleOperation, list[int]]]: + """Substitute linear operations with Userbuffers implementation + + Parameters + ---------- + ops: list of tuples + Forward pass operations and the indices of the corresponding + basic operations. 
+ + Returns + ------- + ops: list of tuples + Updated forward pass operations + + """ + + # Return immediately if environment is not distributed + if not torch.distributed.is_initialized() or torch.distributed.get_world_size() == 1: + return ops + + # Sliding window in list of ops + window = [] + + def peek_next_op() -> Optional[FusibleOperation]: + """Get next op in list of ops""" + nonlocal ops + if not ops: + return None + return ops[0][0] + + def pop_next_op() -> FusibleOperation: + """Remove next op from list of ops and add to sliding window""" + nonlocal ops, window + window.append(ops[0]) + ops = ops[1:] + return window[-1][0] + + # Scan through ops, fusing if possible + out = [] + while ops: + out.extend(window) + window.clear() + + # Check if next op is linear + next_op = pop_next_op() + if not isinstance(next_op, BasicLinear): + continue + linear = next_op + if linear._userbuffers_options is None: + continue + + # Check if next op is bias + bias = None + if linear.tensor_parallel_mode != "row" and isinstance(peek_next_op(), Bias): + bias = pop_next_op() + + # Check if next op is reduce-scatter + reduce_scatter = None + if linear.tensor_parallel_mode is None and isinstance(peek_next_op(), ReduceScatter): + reduce_scatter = pop_next_op() + + # Check for invalid combinations + if reduce_scatter is None: + if linear.tensor_parallel_mode is None: + continue + if linear.tensor_parallel_size == 1: + continue + if linear.tensor_parallel_mode == "row" and bias is not None: + continue + else: + if linear.tensor_parallel_mode is not None: + continue + if reduce_scatter.process_group_size == 1: + continue + + # Replace window with fused op + op = UserbuffersForwardLinear( + linear=linear, + bias=bias, + reduce_scatter=reduce_scatter, + ) + basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window] + window = [(op, basic_op_idxs)] + + # Return list of ops + out.extend(window) + return out diff --git a/transformer_engine/pytorch/ops/fuser.py b/transformer_engine/pytorch/ops/fuser.py index bbfb9416fc..6fcb435e5c 100644 --- a/transformer_engine/pytorch/ops/fuser.py +++ b/transformer_engine/pytorch/ops/fuser.py @@ -20,6 +20,8 @@ fuse_backward_linear_add, fuse_forward_linear_bias_activation, fuse_forward_linear_bias_add, + fuse_userbuffers_backward_linear, + fuse_userbuffers_forward_linear, ) @@ -345,6 +347,7 @@ def _fuse_forward_ops( ops: list[tuple[FusibleOperation, list[int]]], ) -> list[tuple[FusibleOperation, list[int]]]: """Attempt to fuse operations in forward pass""" + ops = fuse_userbuffers_forward_linear(ops) ops = fuse_forward_linear_bias_add(ops) ops = fuse_forward_linear_bias_activation(ops) return ops @@ -355,6 +358,7 @@ def _fuse_backward_ops( ops: list[tuple[FusibleOperation, list[int]]], ) -> list[tuple[FusibleOperation, list[int]]]: """Attempt to fuse operations in backward pass""" + ops = fuse_userbuffers_backward_linear(ops) ops = fuse_backward_linear_add(ops) return ops diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index 110059d745..36136292df 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -109,10 +109,12 @@ def forward( fp8_meta_forward: bool = True, fp8_meta_index: Optional[int] = None, fp8_dtype: TE_DType = TE_DType.kFloat8E4M3, + data: Optional[torch.Tensor] = None, scale: Optional[torch.Tensor] = None, amax: Optional[torch.Tensor] = None, scale_inv: Optional[torch.Tensor] = None, with_transpose_cache: bool = False, + 
data_transpose: Optional[torch.Tensor] = None, ) -> Float8Tensor: # pylint: disable=missing-function-docstring @@ -125,7 +127,8 @@ def forward( device = torch.device("cuda") # FP8 data buffer - data = torch.empty(tensor.size(), dtype=torch.uint8, device=device) + if data is None: + data = torch.empty(tensor.size(), dtype=torch.uint8, device=device) # Check scale if scale is None and fp8_meta is None: @@ -140,8 +143,7 @@ def forward( scale_inv = scale_inv.to(device=device, dtype=torch.float32) # Transpose cache - data_transpose = None - if with_transpose_cache: + if data_transpose is None and with_transpose_cache: data_transpose = torch.empty( (data.size(-1), data.numel() // data.size(-1)), dtype=torch.uint8, @@ -172,7 +174,7 @@ def backward( ) -> Tuple[Optional[torch.Tensor], ...]: # pylint: disable=missing-function-docstring # Assume that we want gradients in full precision - return grad, None, None, None, None, None, None, None + return grad, None, None, None, None, None, None, None, None, None class _IdentityFunc(torch.autograd.Function): @@ -688,10 +690,12 @@ def to_float8( fp8_meta_forward: bool = True, fp8_meta_index: Optional[int] = None, fp8_dtype: TE_DType = TE_DType.kFloat8E4M3, + data: Optional[torch.Tensor] = None, scale: Optional[torch.Tensor] = None, amax: Optional[torch.Tensor] = None, scale_inv: Optional[torch.Tensor] = None, with_transpose_cache: bool = False, + data_transpose: Optional[torch.Tensor] = None, ): """Construct Float8Tensor from plain PyTorch tensor""" return _ToFloat8Func.apply( @@ -700,10 +704,12 @@ def to_float8( fp8_meta_forward, fp8_meta_index, fp8_dtype, + data, scale, amax, scale_inv, with_transpose_cache, + data_transpose, ) def detach(self) -> Float8Tensor:
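
For context, a minimal sketch of how the fused Userbuffers forward path added above might be exercised. It is illustrative only: it assumes a multi-GPU torchrun launch, that the Userbuffers communicators (e.g. "qkv_fprop") have already been initialized, and that `te_ops.BasicLinear` accepts the tensor-parallel constructor arguments shown; attaching `_userbuffers_options` directly mirrors the attribute read in `fuser_forward` and is likewise an assumption, not a documented API.

    import os

    import torch
    import transformer_engine.pytorch.ops as te_ops

    # Initialize the default NCCL process group (torchrun sets RANK/WORLD_SIZE).
    torch.distributed.init_process_group("nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

    hidden = 1024  # illustrative size
    linear = te_ops.BasicLinear(
        hidden,
        hidden,
        device="cuda",
        tensor_parallel_mode="column",
        tensor_parallel_group=torch.distributed.group.WORLD,
        sequence_parallel=True,
    )
    # The fused op is only substituted when the linear op carries Userbuffers
    # options naming its communicator (read as `_userbuffers_options["comm_name"]`);
    # setting the attribute this way is an assumption for illustration.
    linear._userbuffers_options = {"comm_name": "qkv"}
    model = te_ops.Sequential(linear)

    # Local sequence shard of the input (sequence parallelism along dim 0).
    x = torch.randn(128, hidden, device="cuda", requires_grad=True)
    y = model(x)  # fuse_userbuffers_forward_linear swaps in the fused op
    y.sum().backward()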