From af1198d426a6a2bea2d6fe7ea7a6eecb97cb3b60 Mon Sep 17 00:00:00 2001 From: Jingyuan Fan Date: Thu, 19 Dec 2024 17:36:15 -0800 Subject: [PATCH] env variable to select rounding mode (#3515) Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/595 Accuracy issue was reported by ads team, specifically when the intput tensor is large, some times we get inf relative difference. It happens because abs diff > expected diff and a non-zero value after quant and dequant becomes 0 (so divisor is 0), meaning the root cause is the abs diff is larger than expected. We can reproduce the problem with the following small size input, specifically -502.516 will become 0 after quant and dequant ``` -180.8454,276.3368,892.1324, 1101.1176, -502.5216,-302.0942,2268.5430,-5960.6919 ``` ideally -502 should be -500. The reason it becomes 0 is that in mx4 quant, number is scaled down by 2^shared_exponent (of that group) and the value of shared_exponent is impacted by rounding method. If shared_exponent is (relatively) bigger, after scaling, many number become small so we lose a bunch of info. Out of all rounding, floor should give the smallest exponent, ceil probably gives the biggest, even and nearest hard to say since they can round up or down depending on the input but likely still be smaller than ceil, stochastic tries to round down after adding some noise, so probably better or on par with even and nearest, worse than floor. This is also verified by the unit tests. whe rounding is set to floor and stochastic, tests pass, otherwise fail This diff enables selecting rounding mode through env variable. If a rounding method is provided through function call, it takes precedence otherwise it looks at env variable. Default is nearest Differential Revision: D67425485 --- fbgemm_gpu/fbgemm_gpu/envs.py | 30 +++++++++++ fbgemm_gpu/fbgemm_gpu/quantize_utils.py | 13 +++-- fbgemm_gpu/fbgemm_gpu/triton/quantize_ref.py | 9 ++-- fbgemm_gpu/test/quantize/mx4_test.py | 57 ++++++++++++++++++++ 4 files changed, 101 insertions(+), 8 deletions(-) create mode 100644 fbgemm_gpu/fbgemm_gpu/envs.py diff --git a/fbgemm_gpu/fbgemm_gpu/envs.py b/fbgemm_gpu/fbgemm_gpu/envs.py new file mode 100644 index 0000000000..576c53be86 --- /dev/null +++ b/fbgemm_gpu/fbgemm_gpu/envs.py @@ -0,0 +1,30 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import os + +from typing import Any, Callable, Dict + +# pyre-ignore[5] +environment_variables: Dict[str, Callable[[], Any]] = { + # Decide which rounding mode to use when doing quantization and dequantization to/from MX4 + # check https://fburl.com/code/rohboxgv for what's available + "MX4_QUANT_ROUNDING_MODE": lambda: os.getenv("MX4_QUANT_ROUNDING_MODE", "nearest"), +} + + +# pyre-ignore[3] +def __getattr__(name: str): + if name in environment_variables: + return environment_variables[name]() + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + +# pyre-ignore[3] +def __dir__(): + return list(environment_variables.keys()) diff --git a/fbgemm_gpu/fbgemm_gpu/quantize_utils.py b/fbgemm_gpu/fbgemm_gpu/quantize_utils.py index 7646edf4ef..7d830d7515 100644 --- a/fbgemm_gpu/fbgemm_gpu/quantize_utils.py +++ b/fbgemm_gpu/fbgemm_gpu/quantize_utils.py @@ -11,6 +11,7 @@ from typing import Optional, Union import torch +from fbgemm_gpu import envs from fbgemm_gpu.triton import dequantize_mx4, quantize_mx4, RoundingMode from fbgemm_gpu.triton.quantize_ref import py_dequantize_mx4, py_quantize_mx4 @@ -30,13 +31,16 @@ TORCH_BFLOAT16_MIN: float = torch.finfo(torch.bfloat16).min TORCH_BFLOAT16_MAX: float = torch.finfo(torch.bfloat16).max +# pyre-ignore[5] +MX4_QUANT_ROUNDING_MODE = envs.MX4_QUANT_ROUNDING_MODE + def fp32_to_mx4( tensor: torch.Tensor, group_size: int = 32, ebits: int = 2, mbits: int = 1, - rounding_mode: Optional[Union[RoundingMode, int]] = RoundingMode.even, + rounding_mode: Optional[Union[RoundingMode, int]] = None, stochastic_casting: bool = False, use_triton: bool = True, ) -> torch.Tensor: @@ -57,8 +61,11 @@ def fp32_to_mx4( """ # Accelerated MX4 is only available on cuda, if input is on cpu, use python. # Operate on flattened input. - if rounding_mode is None: - rounding_mode = RoundingMode.even + rounding_mode = ( + RoundingMode.__members__.get(MX4_QUANT_ROUNDING_MODE, RoundingMode.even) + if rounding_mode is None + else rounding_mode + ) if not tensor.is_cuda: return py_quantize_mx4( diff --git a/fbgemm_gpu/fbgemm_gpu/triton/quantize_ref.py b/fbgemm_gpu/fbgemm_gpu/triton/quantize_ref.py index f0352b2707..6cc17244ed 100644 --- a/fbgemm_gpu/fbgemm_gpu/triton/quantize_ref.py +++ b/fbgemm_gpu/fbgemm_gpu/triton/quantize_ref.py @@ -83,7 +83,7 @@ def py_quantize_mx4( eg. Input with shape [1, 8192] will be quantized to [1, 4096 + 256] as each value contain two elements packed into an int8 and - there are 32 groups in each row. + there are 256 (8192 / group_size) groups in each row. """ # Define helpful constants. FP32_MIN_NORMAL = 2 ** (-126) @@ -150,16 +150,15 @@ def py_quantize_mx4( biased_exp = torch.bitwise_and(a, FP32_EXP_MASK) # Shift exponent over to least significant bits. biased_exp = torch.bitwise_right_shift(biased_exp, FP32_EXP_OFFSET).to(torch.int8) - - # Finally extract the mantissa. - trailing_mantissa = torch.bitwise_and(a, FP32_MANTISSA_MASK) new_biased_exp = biased_exp - FP32_EXP_BIAS + FP4_EXP_BIAS - # Compute difference between ideal exponent and what can be represented. exp_diff = torch.where(new_biased_exp <= 0, 1 - new_biased_exp, 0) # Clip this difference to the maximum number of fp32 mantissa bits (23 + implicit). exp_diff = torch.clamp(exp_diff, max=MAX_FP32_MANTISSA_BITS) + # Finally extract the mantissa. + trailing_mantissa = torch.bitwise_and(a, FP32_MANTISSA_MASK) + # Now perform mantissa rounding down to fp4. is_subnorm = biased_exp == 0 # Add implied 1 to normal values. diff --git a/fbgemm_gpu/test/quantize/mx4_test.py b/fbgemm_gpu/test/quantize/mx4_test.py index 03b1608116..da1e89ded3 100644 --- a/fbgemm_gpu/test/quantize/mx4_test.py +++ b/fbgemm_gpu/test/quantize/mx4_test.py @@ -319,6 +319,63 @@ def test_mx4_index_overflow_large_input(self) -> None: # We just need to check that everything ran without an illegal memory access. assert mx_dequantized[0][0] == 0 + @unittest.skipIf( + not ( + torch.cuda.is_available() and torch.cuda.mem_get_info()[0] / (1024**3) >= 32 + ), + "Test requires a gpu with at least 32GB of memory.", + ) + # pyre-ignore[56] + @given( + shape=st.sampled_from( + [ + [2 ^ 31 - 1], # Small shape with group_size = num_elements. + [1024 * 1024, 1024], # Multi dimensional shape that is padded. + [16, 1028], # Large shape with multiple padded rows. + [4, 30], # Multiple small rows with padding. + ] + ), + group_size=st.sampled_from([32, 64]), + magnitude=st.sampled_from([1.0, 1e3, 1e-3]), + mx4_format=st.sampled_from([(2, 1)]), + device=st.sampled_from(["cuda"]), + ) + @settings(verbosity=Verbosity.verbose, max_examples=50, deadline=None) + def test_mx4_large_cases( + self, + shape: List[int], + group_size: int, + magnitude: int, + mx4_format: Tuple[int, int], + device: str, + ) -> None: + """Test correctness of mx4 routines with random inputs and unusual shapes.""" + # We only want to consider total sizes that are divisible by group_size. + ebits, mbits = mx4_format + + # Generate a random input with the specified magnitude. + input = torch.randn(shape, device=device, dtype=torch.float32) * magnitude + + # Perform quant then dequant to check that proper shape is maintained and + # outputs are reasonably correct. + mx_quantized = fp32_to_mx4(input, group_size, ebits=ebits, mbits=mbits) + mx_dequantized = mx4_to_fp32(mx_quantized, group_size, ebits=ebits, mbits=mbits) + + # If the rows of input are not divisible by group_size, we expect the output + # to be padded. + if input.shape[-1] % group_size != 0: + pad = group_size - (input.shape[-1] % group_size) + input = torch.nn.functional.pad(input, (0, pad)) + + # Check that output shape matches input shape. + assert mx_dequantized.shape == input.shape + + # Check that values are reasonably close, based on expected variance. + # I give quite a bit of wiggle room to make sure this isnt flaky. + torch.testing.assert_close(input, mx_dequantized, rtol=1.0, atol=magnitude / 2) + assert torch.isnan(mx_dequantized).any() == False + assert torch.isinf(mx_dequantized).any() == False + if __name__ == "__main__": unittest.main()