diff --git a/test/prototype/mx_formats/test_mx_dtensor.py b/test/prototype/mx_formats/test_mx_dtensor.py
new file mode 100644
index 0000000000..bfc930c579
--- /dev/null
+++ b/test/prototype/mx_formats/test_mx_dtensor.py
@@ -0,0 +1,98 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+"""
+Test numerics of manually defined float16 TP vs mxfp8 TP of toy models
+
+Note: for now, this does not run in CI.
+TODO(future): make this run in CI
+"""
+
+import os
+
+import pytest
+import torch
+
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_7
+
+if not TORCH_VERSION_AT_LEAST_2_7:
+    pytest.skip("Unsupported PyTorch version", allow_module_level=True)
+
+from torch.distributed._tensor import DTensor, Shard, distribute_tensor
+from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
+from tqdm import tqdm
+
+from torchao.prototype.mx_formats import MXLinearConfig
+from torchao.prototype.mx_formats.mx_tensor import MXTensor
+from torchao.testing.training.dtensor_utils import (
+    _test_lowp_mlp_tensor_parallelism_base,
+)
+
+torch.set_float32_matmul_precision("high")
+
+
+def setup_distributed():
+    world_size = int(os.environ.get("WORLD_SIZE", -1))
+    device_mesh = init_device_mesh("cuda", (world_size,))
+    # seed must be the same in all processes
+    torch.manual_seed(1)
+    local_rank = torch.distributed.get_rank()
+    torch.cuda.set_device(local_rank)
+    return device_mesh
+
+
+def _test_dtensor_cast_to_mxfp8(mesh: DeviceMesh, size=4):
+    device = mesh.device_type
+
+    x_fp32 = torch.rand(size, size, device=device)
+    x_fp8 = MXTensor.to_mx(x_fp32, torch.float8_e4m3fn, block_size=size // 2)
+
+    dist_x_fp32 = distribute_tensor(x_fp32, mesh, [Shard(0)])
+    dist_x_fp8 = MXTensor.to_mx(dist_x_fp32, torch.float8_e4m3fn, block_size=size // 2)
+    assert isinstance(dist_x_fp8, DTensor)
+
+    # Verify that the result of to_mx with DTensor matches the slice of the
+    # result of to_mx without DTensor. This will fail on numeric op mismatches.
+    local_rank = torch.distributed.get_rank()
+    world_size = torch.distributed.get_world_size()
+    assert size % world_size == 0, "unsupported"
+    x_fp8_fp32 = x_fp8.to_dtype(torch.float32)
+    rows_per_slice = size // world_size
+    slice_start = local_rank * rows_per_slice
+    slice_end = (local_rank + 1) * rows_per_slice
+    x_fp8_fp32_slice = x_fp8_fp32[slice_start:slice_end]
+    torch.testing.assert_close(
+        x_fp8_fp32_slice, dist_x_fp8.to_local().to_dtype(torch.float32), atol=0, rtol=0
+    )
+
+
+def _test_mxfp8_mlp_tensor_parallelism_eager(mesh: DeviceMesh, size=16):
+    config = MXLinearConfig.from_recipe_name("mxfp8_emulated")
+    # TODO(future PR): assert that the K dim must be divisible by block size,
+    # today this is silently incorrect if block_size is greater than K
+    config.block_size = 16
+    _test_lowp_mlp_tensor_parallelism_base(
+        mesh, config, size, compile=False, allgather_in_lowp=False
+    )
+
+    # TODO(future PR): compile
+
+
+if __name__ == "__main__":
+    device_mesh = setup_distributed()
+    tests = [
+        _test_dtensor_cast_to_mxfp8,
+        # TODO(next PR): enable this (current PR got too large, so splitting)
+        # _test_mxfp8_mlp_tensor_parallelism_eager,
+    ]
+
+    for test in tqdm(tests, desc="Running tests"):
+        try:
+            test(device_mesh)
+        except Exception as e:
+            print(f"Test {test.__name__} failed with error: {e}")
+            raise e
+
+    torch.distributed.destroy_process_group()
diff --git a/test/prototype/mx_formats/test_mx_dtensor.sh b/test/prototype/mx_formats/test_mx_dtensor.sh
new file mode 100755
index 0000000000..abf9424e3c
--- /dev/null
+++ b/test/prototype/mx_formats/test_mx_dtensor.sh
@@ -0,0 +1,17 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+#!/bin/bash
+
+# terminate script on first error
+set -e
+
+if python -c 'import torch;print(torch.cuda.is_available())' | grep -q "False"; then
+    echo "Skipping test_dtensor.sh because no CUDA devices are available."
+    exit
+fi
+
+# integration tests for TP/SP
+NCCL_DEBUG=WARN torchrun --nproc_per_node 2 test/prototype/mx_formats/test_mx_dtensor.py
diff --git a/torchao/prototype/mx_formats/kernels.py b/torchao/prototype/mx_formats/kernels.py
index eacf0ac5df..f96e73a55a 100644
--- a/torchao/prototype/mx_formats/kernels.py
+++ b/torchao/prototype/mx_formats/kernels.py
@@ -1102,15 +1102,12 @@ def _triton_calculate_scale(x, axis):
     bf16_mbits = 7
     bf16_exp_bias = 127
     fp32_mbits = 23
-    # We use a small epsilon to avoid division by zero
-    epsilon = 1e-10
 
     # Find the maximum absolute value for each row
     max_abs = tl.max(x, axis=axis)
 
     # Calculate the e8m0 scale by extracting the exponent (floor)
     # TODO(future PR): support other exponent extraction types (ceil, RNE)
-    max_abs = max_abs + epsilon
     max_abs = max_abs.to(tl.bfloat16)
     max_abs_int16 = max_abs.to(tl.int16, bitcast=True)
     extracted_pow2 = ((max_abs_int16 >> bf16_mbits) & 0b11111111) - bf16_exp_bias
diff --git a/torchao/prototype/mx_formats/mx_tensor.py b/torchao/prototype/mx_formats/mx_tensor.py
index 784d3eda6d..ef9ae42fcd 100644
--- a/torchao/prototype/mx_formats/mx_tensor.py
+++ b/torchao/prototype/mx_formats/mx_tensor.py
@@ -21,6 +21,7 @@
 from typing import Callable, Dict, Union
 
 import torch
+from torch.distributed._tensor import DTensor
 
 from torchao.prototype.mx_formats.config import MXGemmKernelChoice
 from torchao.prototype.mx_formats.constants import (
@@ -166,6 +167,8 @@ def to_mx(
 
     # calculate the scale in e8m0 format
     orig_shape = data_hp.shape
+    # TODO(future PR): fix this line for TP, currently this reshape does not work
+    # for rank 3 tensor where dim1 is sharded
     data_hp = data_hp.reshape(-1, block_size)
 
     # find max value of the data
@@ -174,10 +177,6 @@ def to_mx(
     # section 6.3.
     max_abs = torch.amax(torch.abs(data_hp), 1)
 
-    # Add an epsilon to prevent the log2 function call for returning -inf
-    # where the values are zero.
-    eps = F32_MIN_NORMAL * (max_abs == 0).type(max_abs.dtype)
-
     # Set X to be the largest power-of-two less than or equal to
     # max_abs(v), divided by the largest power of two representable
     # in the element data type, and get the mbits at the same time
@@ -233,8 +232,12 @@ def to_mx(
     )
 
     # Calculate the scale for different modes
-    max_abs_int32 = (max_abs + eps).view(hp_int_dtype)
-    extracted_pow2 = ((max_abs_int32 >> hp_mbits) & 0b11111111) - hp_exp_bias
+    max_abs_int32 = max_abs.view(hp_int_dtype)
+    # For now, use `torch.bitwise_right_shift` instead of `>>` to support DTensor
+    # See https://github.com/pytorch/pytorch/issues/156533.
+    extracted_pow2 = (
+        (torch.bitwise_right_shift(max_abs_int32, hp_mbits)) & 0b11111111
+    ) - hp_exp_bias
 
     if scaling_mode in (ScaleCalculationMode.FLOOR, ScaleCalculationMode.EVEN):
         scale_e8m0_unbiased = extracted_pow2 - target_max_pow2
@@ -266,9 +269,11 @@ def to_mx(
     )
 
     # For now, calculate the scale in floating point.
-    scale_fp32 = (scale_e8m0_biased.to(torch.int32) << MBITS_F32).view(
-        torch.float32
-    )
+    # For now, use `torch.bitwise_left_shift` instead of `<<` to support DTensor
+    # See https://github.com/pytorch/pytorch/issues/156533.
+    scale_fp32 = (
+        torch.bitwise_left_shift(scale_e8m0_biased.to(torch.int32), MBITS_F32)
+    ).view(torch.float32)
 
     # Today, 2**-127 returns 0 in compile+inductor+triton because it is in the
     # float32 denormal range. For now, manually adjust the fp scale. This is
@@ -597,6 +602,28 @@ def to_mx(
         scale_e8m0_biased, data_lp = to_mx(
             data_hp, elem_dtype, block_size, scaling_mode, pack_fp6
         )
+        if isinstance(scale_e8m0_biased, DTensor):
+            assert isinstance(data_lp, DTensor), "unsupported"
+            local_scale_e8m0_biased = scale_e8m0_biased.to_local()
+            local_data_lp = data_lp.to_local()
+            inner_mx_tensor = MXTensor(
+                local_scale_e8m0_biased,
+                local_data_lp,
+                elem_dtype,
+                block_size,
+                data_hp.dtype,
+                use_fp4_custom_triton_dequant_kernel,
+                gemm_kernel_choice,
+                pack_fp6,
+            )
+            return DTensor.from_local(
+                inner_mx_tensor,
+                data_lp.device_mesh,
+                data_lp.placements,
+                run_check=False,
+                shape=data_lp.size(),
+                stride=data_lp.stride(),
+            )
         return MXTensor(
             scale_e8m0_biased,
             data_lp,
diff --git a/torchao/testing/training/dtensor_utils.py b/torchao/testing/training/dtensor_utils.py
index 7ac0360363..815ee20969 100644
--- a/torchao/testing/training/dtensor_utils.py
+++ b/torchao/testing/training/dtensor_utils.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
 import copy
+from typing import Union
 
 import torch
 import torch.nn as nn
@@ -24,6 +25,8 @@
     Float8RowwiseParallel,
     PrepareFloat8ModuleInput,
 )
+from torchao.prototype.mx_formats.config import MXLinearConfig
+from torchao.quantization import quantize_
 
 
 class FeedForward(nn.Module):
@@ -36,7 +39,9 @@ def __init__(self):
         self.out_proj = nn.Linear(32, 16, bias=False)
 
     def forward(self, x):
-        return self.out_proj(F.silu(self.w1(x)) * self.w2(x))
+        x = F.silu(self.w1(x)) * self.w2(x)
+        x = self.out_proj(x)
+        return x
 
 
 class ToyModel(nn.Module):
@@ -50,20 +55,26 @@ def forward(self, x):
 
 def _test_lowp_mlp_tensor_parallelism_base(
     mesh: DeviceMesh,
-    config: Float8LinearConfig,
+    config: Union[Float8LinearConfig, MXLinearConfig],
     size=16,
     compile: bool = False,
     allgather_in_lowp: bool = False,
 ):
     device = mesh.device_type
 
+    # TODO(future): remove this once float8 training works with `quantize_` API
+    convert_model_func = convert_to_float8_training
+    if isinstance(config, MXLinearConfig):
+        convert_model_func = quantize_
+
     toy_model = ToyModel().to(device)
-    toy_model_fp8 = convert_to_float8_training(toy_model, config=config)
+    toy_model_fp8 = copy.deepcopy(toy_model)
+    convert_model_func(toy_model_fp8, config=config)
 
     tp_model = copy.deepcopy(toy_model)
-    tp_model = convert_to_float8_training(tp_model, config=config)
+    convert_model_func(tp_model, config=config)
     sp_model = copy.deepcopy(toy_model)
-    sp_model = convert_to_float8_training(sp_model, config=config)
+    convert_model_func(sp_model, config=config)
 
     # For tensorwise scaling, enable float8 all_gather.
     # For rowwise scaling, keep high precision all_gather. Motivation for
@@ -108,7 +119,7 @@ def _test_lowp_mlp_tensor_parallelism_base(
 
     # prepare_input_cls with specific submodule fqn
     sp_model2 = copy.deepcopy(toy_model)
-    sp_model2 = convert_to_float8_training(sp_model2, config=config)
+    convert_model_func(sp_model2, config=config)
 
     if not allgather_in_lowp:
         prepare_input = prepare_input_cls(
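
Below is a minimal standalone sketch (not part of the patch) of the e8m0 scale math that `to_mx` now spells with `torch.bitwise_right_shift` / `torch.bitwise_left_shift`, the DTensor-friendly equivalents of `>>` / `<<`. The constant names and example values are illustrative only.

    import torch

    FP32_MBITS = 23      # mantissa bits of float32 (MBITS_F32 in the patch)
    FP32_EXP_BIAS = 127  # exponent bias of float32 (hp_exp_bias for fp32 inputs)

    max_abs = torch.tensor([0.75, 3.0, 1024.0], dtype=torch.float32)

    # Reinterpret the float32 bits as int32 and pull out the biased exponent
    # field: this computes floor(log2(max_abs)) with no floating-point log.
    max_abs_int32 = max_abs.view(torch.int32)
    extracted_pow2 = (
        torch.bitwise_right_shift(max_abs_int32, FP32_MBITS) & 0b11111111
    ) - FP32_EXP_BIAS
    print(extracted_pow2)  # values: -1, 1, 10

    # Rebuild a power-of-two float32 scale by shifting a biased exponent back
    # into the exponent field and reinterpreting the bits as float32.
    scale_e8m0_biased = (extracted_pow2 + FP32_EXP_BIAS).to(torch.int32)
    scale_fp32 = torch.bitwise_left_shift(scale_e8m0_biased, FP32_MBITS).view(torch.float32)
    print(scale_fp32)  # values: 0.5, 2.0, 1024.0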
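
Similarly, a hedged usage sketch (also not part of the patch) of the conversion path `_test_lowp_mlp_tensor_parallelism_base` now takes for MX configs: `quantize_` with an `MXLinearConfig` instead of `convert_to_float8_training`. The toy module shape and the CUDA device are assumptions for illustration.

    import torch.nn as nn

    from torchao.prototype.mx_formats import MXLinearConfig
    from torchao.quantization import quantize_

    # Same recipe the eager TP test above uses.
    config = MXLinearConfig.from_recipe_name("mxfp8_emulated")
    config.block_size = 16  # block size should divide the K dim of each converted linear

    # In-place conversion: eligible nn.Linear modules are swapped for MX linears.
    model = nn.Sequential(nn.Linear(32, 32, bias=False)).to("cuda")
    quantize_(model, config=config)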