diff --git a/fbgemm_gpu/fbgemm_gpu/envs.py b/fbgemm_gpu/fbgemm_gpu/envs.py
new file mode 100644
index 000000000..576c53be8
--- /dev/null
+++ b/fbgemm_gpu/fbgemm_gpu/envs.py
@@ -0,0 +1,30 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+import os
+
+from typing import Any, Callable, Dict
+
+# pyre-ignore[5]
+environment_variables: Dict[str, Callable[[], Any]] = {
+    # Decide which rounding mode to use when doing quantization and dequantization to/from MX4.
+    # Check https://fburl.com/code/rohboxgv for what's available.
+    "MX4_QUANT_ROUNDING_MODE": lambda: os.getenv("MX4_QUANT_ROUNDING_MODE", "nearest"),
+}
+
+
+# pyre-ignore[3]
+def __getattr__(name: str):
+    if name in environment_variables:
+        return environment_variables[name]()
+    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
+
+
+# pyre-ignore[3]
+def __dir__():
+    return list(environment_variables.keys())
diff --git a/fbgemm_gpu/fbgemm_gpu/quantize_utils.py b/fbgemm_gpu/fbgemm_gpu/quantize_utils.py
index 7646edf4e..7d830d751 100644
--- a/fbgemm_gpu/fbgemm_gpu/quantize_utils.py
+++ b/fbgemm_gpu/fbgemm_gpu/quantize_utils.py
@@ -11,6 +11,7 @@ from typing import Optional, Union
 
 import torch
 
+from fbgemm_gpu import envs
 from fbgemm_gpu.triton import dequantize_mx4, quantize_mx4, RoundingMode
 from fbgemm_gpu.triton.quantize_ref import py_dequantize_mx4, py_quantize_mx4
 
@@ -30,13 +31,16 @@
 TORCH_BFLOAT16_MIN: float = torch.finfo(torch.bfloat16).min
 TORCH_BFLOAT16_MAX: float = torch.finfo(torch.bfloat16).max
 
+# pyre-ignore[5]
+MX4_QUANT_ROUNDING_MODE = envs.MX4_QUANT_ROUNDING_MODE
+
 
 def fp32_to_mx4(
     tensor: torch.Tensor,
     group_size: int = 32,
     ebits: int = 2,
     mbits: int = 1,
-    rounding_mode: Optional[Union[RoundingMode, int]] = RoundingMode.even,
+    rounding_mode: Optional[Union[RoundingMode, int]] = None,
     stochastic_casting: bool = False,
     use_triton: bool = True,
 ) -> torch.Tensor:
@@ -57,8 +61,11 @@ def fp32_to_mx4(
     """
     # Accelerated MX4 is only available on cuda, if input is on cpu, use python.
     # Operate on flattened input.
-    if rounding_mode is None:
-        rounding_mode = RoundingMode.even
+    rounding_mode = (
+        RoundingMode.__members__.get(MX4_QUANT_ROUNDING_MODE, RoundingMode.even)
+        if rounding_mode is None
+        else rounding_mode
+    )
 
     if not tensor.is_cuda:
         return py_quantize_mx4(
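Note on the mechanism above: the new envs.py relies on PEP 562 module-level __getattr__, so every `envs.NAME` access invokes the matching lambda and re-reads the environment lazily; quantize_utils.py, however, snapshots the value into the module constant MX4_QUANT_ROUNDING_MODE at import time, so the variable must be set before quantize_utils is first imported. A minimal, self-contained sketch of the pattern follows (illustrative only, not fbgemm_gpu code; `sketch_envs` is a hypothetical module name):

    # sketch_envs.py -- hypothetical stand-in for fbgemm_gpu/envs.py.
    import os
    from typing import Any, Callable, Dict

    environment_variables: Dict[str, Callable[[], Any]] = {
        # Thunks defer the os.getenv call until attribute access.
        "MX4_QUANT_ROUNDING_MODE": lambda: os.getenv("MX4_QUANT_ROUNDING_MODE", "nearest"),
    }

    def __getattr__(name: str) -> Any:
        # PEP 562: only consulted when normal module attribute lookup fails.
        if name in environment_variables:
            return environment_variables[name]()
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

With this in place, `import sketch_envs; sketch_envs.MX4_QUANT_ROUNDING_MODE` yields "nearest" unless the variable is set, and fp32_to_mx4's `RoundingMode.__members__.get(..., RoundingMode.even)` then maps the string to an enum member, silently falling back to RoundingMode.even for unrecognized values.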
diff --git a/fbgemm_gpu/fbgemm_gpu/triton/quantize_ref.py b/fbgemm_gpu/fbgemm_gpu/triton/quantize_ref.py
index f0352b270..6cc17244e 100644
--- a/fbgemm_gpu/fbgemm_gpu/triton/quantize_ref.py
+++ b/fbgemm_gpu/fbgemm_gpu/triton/quantize_ref.py
@@ -83,7 +83,7 @@ def py_quantize_mx4(
 
     eg. Input with shape [1, 8192] will be quantized to [1, 4096 + 256] as
     each value contain two elements packed into an int8 and
-    there are 32 groups in each row.
+    there are 256 (8192 / group_size) groups in each row.
     """
     # Define helpful constants.
     FP32_MIN_NORMAL = 2 ** (-126)
@@ -150,16 +150,15 @@ def py_quantize_mx4(
     biased_exp = torch.bitwise_and(a, FP32_EXP_MASK)
     # Shift exponent over to least significant bits.
     biased_exp = torch.bitwise_right_shift(biased_exp, FP32_EXP_OFFSET).to(torch.int8)
-
-    # Finally extract the mantissa.
-    trailing_mantissa = torch.bitwise_and(a, FP32_MANTISSA_MASK)
     new_biased_exp = biased_exp - FP32_EXP_BIAS + FP4_EXP_BIAS
-
     # Compute difference between ideal exponent and what can be represented.
     exp_diff = torch.where(new_biased_exp <= 0, 1 - new_biased_exp, 0)
     # Clip this difference to the maximum number of fp32 mantissa bits (23 + implicit).
     exp_diff = torch.clamp(exp_diff, max=MAX_FP32_MANTISSA_BITS)
 
+    # Finally extract the mantissa.
+    trailing_mantissa = torch.bitwise_and(a, FP32_MANTISSA_MASK)
+
     # Now perform mantissa rounding down to fp4.
     is_subnorm = biased_exp == 0
     # Add implied 1 to normal values.
diff --git a/fbgemm_gpu/test/quantize/mx4_test.py b/fbgemm_gpu/test/quantize/mx4_test.py
index 03b160811..da1e89ded 100644
--- a/fbgemm_gpu/test/quantize/mx4_test.py
+++ b/fbgemm_gpu/test/quantize/mx4_test.py
@@ -319,6 +319,63 @@ def test_mx4_index_overflow_large_input(self) -> None:
         # We just need to check that everything ran without an illegal memory access.
         assert mx_dequantized[0][0] == 0
 
+    @unittest.skipIf(
+        not (
+            torch.cuda.is_available() and torch.cuda.mem_get_info()[0] / (1024**3) >= 32
+        ),
+        "Test requires a GPU with at least 32GB of memory.",
+    )
+    # pyre-ignore[56]
+    @given(
+        shape=st.sampled_from(
+            [
+                [2**31 - 1],  # Large 1D shape near the int32 limit.
+                [1024 * 1024, 1024],  # Large multi-dimensional shape.
+                [16, 1028],  # Shape with multiple padded rows.
+                [4, 30],  # Multiple small rows with padding.
+            ]
+        ),
+        group_size=st.sampled_from([32, 64]),
+        magnitude=st.sampled_from([1.0, 1e3, 1e-3]),
+        mx4_format=st.sampled_from([(2, 1)]),
+        device=st.sampled_from(["cuda"]),
+    )
+    @settings(verbosity=Verbosity.verbose, max_examples=50, deadline=None)
+    def test_mx4_large_cases(
+        self,
+        shape: List[int],
+        group_size: int,
+        magnitude: float,
+        mx4_format: Tuple[int, int],
+        device: str,
+    ) -> None:
+        """Test correctness of mx4 routines with random inputs and unusual shapes."""
+        # Unpack the mx4 format into exponent and mantissa bits.
+        ebits, mbits = mx4_format
+
+        # Generate a random input with the specified magnitude.
+        input = torch.randn(shape, device=device, dtype=torch.float32) * magnitude
+
+        # Perform quant then dequant to check that proper shape is maintained and
+        # outputs are reasonably correct.
+        mx_quantized = fp32_to_mx4(input, group_size, ebits=ebits, mbits=mbits)
+        mx_dequantized = mx4_to_fp32(mx_quantized, group_size, ebits=ebits, mbits=mbits)
+
+        # If the rows of input are not divisible by group_size, we expect the output
+        # to be padded.
+        if input.shape[-1] % group_size != 0:
+            pad = group_size - (input.shape[-1] % group_size)
+            input = torch.nn.functional.pad(input, (0, pad))
+
+        # Check that output shape matches input shape.
+        assert mx_dequantized.shape == input.shape
+
+        # Check that values are reasonably close, based on expected variance.
+        # Give quite a bit of wiggle room to make sure this isn't flaky.
+        torch.testing.assert_close(input, mx_dequantized, rtol=1.0, atol=magnitude / 2)
+        assert not torch.isnan(mx_dequantized).any()
+        assert not torch.isinf(mx_dequantized).any()
+
 
 if __name__ == "__main__":
     unittest.main()
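For reference, a usage sketch of the new override path (a sketch under assumptions, not part of the patch: it assumes a CUDA build of fbgemm_gpu is installed and that mx4_to_fp32 is importable from fbgemm_gpu.quantize_utils alongside fp32_to_mx4, as in the test above):

    import os

    # Must be set before fbgemm_gpu.quantize_utils is first imported, since that
    # module snapshots envs.MX4_QUANT_ROUNDING_MODE into a constant at import time.
    os.environ["MX4_QUANT_ROUNDING_MODE"] = "even"

    import torch
    from fbgemm_gpu.quantize_utils import fp32_to_mx4, mx4_to_fp32

    group_size = 32
    x = torch.randn(4, 30, device="cuda")  # trailing dim not divisible by group_size

    q = fp32_to_mx4(x, group_size)  # rounding_mode=None resolves via the env var
    y = mx4_to_fp32(q, group_size)

    # As the new test asserts, the dequantized output comes back with the
    # trailing dimension padded up to a multiple of group_size.
    pad = group_size - (x.shape[-1] % group_size)
    assert y.shape == torch.nn.functional.pad(x, (0, pad)).shape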