Skip to content

Commit

Permalink
env variable to select rounding mode (pytorch#3515)
Browse files Browse the repository at this point in the history
Summary:

X-link: facebookresearch/FBGEMM#595

Accuracy issue was reported by ads team, specifically when the intput tensor is large, some times we get inf relative difference. It happens because abs diff > expected diff and a non-zero value after quant and dequant becomes 0 (so divisor is 0), meaning the root cause is the abs diff is larger than expected. We can reproduce the problem with the following small size input, specifically -502.516 will become 0 after quant and dequant
```
-180.8454,276.3368,892.1324, 1101.1176, -502.5216,-302.0942,2268.5430,-5960.6919
```

ideally -502 should be -500. The reason it becomes 0 is that in mx4 quant, number is scaled down by 2^shared_exponent (of that group) and the value of shared_exponent is impacted by rounding method. If shared_exponent is (relatively) bigger, after scaling, many number become small so we lose a bunch of info. Out of all rounding, floor should give the smallest exponent, ceil probably gives the biggest, even and nearest hard to say since they can round up or down depending on the input but likely still be smaller than ceil,
stochastic tries to round down after adding some noise, so probably better or on par with even and nearest, worse than floor.

This is also verified by the unit tests. whe rounding is set to floor and stochastic, tests pass, otherwise fail

This diff enables selecting rounding mode through env variable. If a rounding method is provided through function call, it takes precedence otherwise it looks at env variable. Default is nearest

Differential Revision: D67425485
  • Loading branch information
Jingyuan Fan authored and facebook-github-bot committed Dec 20, 2024
1 parent a75d8fe commit af1198d
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 8 deletions.
30 changes: 30 additions & 0 deletions fbgemm_gpu/fbgemm_gpu/envs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import os

from typing import Any, Callable, Dict

# pyre-ignore[5]
environment_variables: Dict[str, Callable[[], Any]] = {
# Decide which rounding mode to use when doing quantization and dequantization to/from MX4
# check https://fburl.com/code/rohboxgv for what's available
"MX4_QUANT_ROUNDING_MODE": lambda: os.getenv("MX4_QUANT_ROUNDING_MODE", "nearest"),
}


# pyre-ignore[3]
def __getattr__(name: str):
if name in environment_variables:
return environment_variables[name]()
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")


# pyre-ignore[3]
def __dir__():
return list(environment_variables.keys())
13 changes: 10 additions & 3 deletions fbgemm_gpu/fbgemm_gpu/quantize_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from typing import Optional, Union

import torch
from fbgemm_gpu import envs

from fbgemm_gpu.triton import dequantize_mx4, quantize_mx4, RoundingMode
from fbgemm_gpu.triton.quantize_ref import py_dequantize_mx4, py_quantize_mx4
Expand All @@ -30,13 +31,16 @@
TORCH_BFLOAT16_MIN: float = torch.finfo(torch.bfloat16).min
TORCH_BFLOAT16_MAX: float = torch.finfo(torch.bfloat16).max

# pyre-ignore[5]
MX4_QUANT_ROUNDING_MODE = envs.MX4_QUANT_ROUNDING_MODE


def fp32_to_mx4(
tensor: torch.Tensor,
group_size: int = 32,
ebits: int = 2,
mbits: int = 1,
rounding_mode: Optional[Union[RoundingMode, int]] = RoundingMode.even,
rounding_mode: Optional[Union[RoundingMode, int]] = None,
stochastic_casting: bool = False,
use_triton: bool = True,
) -> torch.Tensor:
Expand All @@ -57,8 +61,11 @@ def fp32_to_mx4(
"""
# Accelerated MX4 is only available on cuda, if input is on cpu, use python.
# Operate on flattened input.
if rounding_mode is None:
rounding_mode = RoundingMode.even
rounding_mode = (
RoundingMode.__members__.get(MX4_QUANT_ROUNDING_MODE, RoundingMode.even)
if rounding_mode is None
else rounding_mode
)

if not tensor.is_cuda:
return py_quantize_mx4(
Expand Down
9 changes: 4 additions & 5 deletions fbgemm_gpu/fbgemm_gpu/triton/quantize_ref.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def py_quantize_mx4(
eg.
Input with shape [1, 8192] will be quantized to [1, 4096 + 256] as
each value contain two elements packed into an int8 and
there are 32 groups in each row.
there are 256 (8192 / group_size) groups in each row.
"""
# Define helpful constants.
FP32_MIN_NORMAL = 2 ** (-126)
Expand Down Expand Up @@ -150,16 +150,15 @@ def py_quantize_mx4(
biased_exp = torch.bitwise_and(a, FP32_EXP_MASK)
# Shift exponent over to least significant bits.
biased_exp = torch.bitwise_right_shift(biased_exp, FP32_EXP_OFFSET).to(torch.int8)

# Finally extract the mantissa.
trailing_mantissa = torch.bitwise_and(a, FP32_MANTISSA_MASK)
new_biased_exp = biased_exp - FP32_EXP_BIAS + FP4_EXP_BIAS

# Compute difference between ideal exponent and what can be represented.
exp_diff = torch.where(new_biased_exp <= 0, 1 - new_biased_exp, 0)
# Clip this difference to the maximum number of fp32 mantissa bits (23 + implicit).
exp_diff = torch.clamp(exp_diff, max=MAX_FP32_MANTISSA_BITS)

# Finally extract the mantissa.
trailing_mantissa = torch.bitwise_and(a, FP32_MANTISSA_MASK)

# Now perform mantissa rounding down to fp4.
is_subnorm = biased_exp == 0
# Add implied 1 to normal values.
Expand Down
57 changes: 57 additions & 0 deletions fbgemm_gpu/test/quantize/mx4_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,63 @@ def test_mx4_index_overflow_large_input(self) -> None:
# We just need to check that everything ran without an illegal memory access.
assert mx_dequantized[0][0] == 0

@unittest.skipIf(
not (
torch.cuda.is_available() and torch.cuda.mem_get_info()[0] / (1024**3) >= 32
),
"Test requires a gpu with at least 32GB of memory.",
)
# pyre-ignore[56]
@given(
shape=st.sampled_from(
[
[2 ^ 31 - 1], # Small shape with group_size = num_elements.
[1024 * 1024, 1024], # Multi dimensional shape that is padded.
[16, 1028], # Large shape with multiple padded rows.
[4, 30], # Multiple small rows with padding.
]
),
group_size=st.sampled_from([32, 64]),
magnitude=st.sampled_from([1.0, 1e3, 1e-3]),
mx4_format=st.sampled_from([(2, 1)]),
device=st.sampled_from(["cuda"]),
)
@settings(verbosity=Verbosity.verbose, max_examples=50, deadline=None)
def test_mx4_large_cases(
self,
shape: List[int],
group_size: int,
magnitude: int,
mx4_format: Tuple[int, int],
device: str,
) -> None:
"""Test correctness of mx4 routines with random inputs and unusual shapes."""
# We only want to consider total sizes that are divisible by group_size.
ebits, mbits = mx4_format

# Generate a random input with the specified magnitude.
input = torch.randn(shape, device=device, dtype=torch.float32) * magnitude

# Perform quant then dequant to check that proper shape is maintained and
# outputs are reasonably correct.
mx_quantized = fp32_to_mx4(input, group_size, ebits=ebits, mbits=mbits)
mx_dequantized = mx4_to_fp32(mx_quantized, group_size, ebits=ebits, mbits=mbits)

# If the rows of input are not divisible by group_size, we expect the output
# to be padded.
if input.shape[-1] % group_size != 0:
pad = group_size - (input.shape[-1] % group_size)
input = torch.nn.functional.pad(input, (0, pad))

# Check that output shape matches input shape.
assert mx_dequantized.shape == input.shape

# Check that values are reasonably close, based on expected variance.
# I give quite a bit of wiggle room to make sure this isnt flaky.
torch.testing.assert_close(input, mx_dequantized, rtol=1.0, atol=magnitude / 2)
assert torch.isnan(mx_dequantized).any() == False
assert torch.isinf(mx_dequantized).any() == False


if __name__ == "__main__":
unittest.main()

0 comments on commit af1198d

Please sign in to comment.