forked from pytorch/FBGEMM
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
env variable to select rounding mode (pytorch#3515)
Summary: X-link: facebookresearch/FBGEMM#595 Accuracy issue was reported by ads team, specifically when the intput tensor is large, some times we get inf relative difference. It happens because abs diff > expected diff and a non-zero value after quant and dequant becomes 0 (so divisor is 0), meaning the root cause is the abs diff is larger than expected. We can reproduce the problem with the following small size input, specifically -502.516 will become 0 after quant and dequant ``` -180.8454,276.3368,892.1324, 1101.1176, -502.5216,-302.0942,2268.5430,-5960.6919 ``` ideally -502 should be -500. The reason it becomes 0 is that in mx4 quant, number is scaled down by 2^shared_exponent (of that group) and the value of shared_exponent is impacted by rounding method. If shared_exponent is (relatively) bigger, after scaling, many number become small so we lose a bunch of info. Out of all rounding, floor should give the smallest exponent, ceil probably gives the biggest, even and nearest hard to say since they can round up or down depending on the input but likely still be smaller than ceil, stochastic tries to round down after adding some noise, so probably better or on par with even and nearest, worse than floor. This is also verified by the unit tests. whe rounding is set to floor and stochastic, tests pass, otherwise fail This diff enables selecting rounding mode through env variable. If a rounding method is provided through function call, it takes precedence otherwise it looks at env variable. Default is nearest Differential Revision: D67425485
- Loading branch information
1 parent
a75d8fe
commit af1198d
Showing
4 changed files
with
101 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
# pyre-strict | ||
|
||
import os | ||
|
||
from typing import Any, Callable, Dict | ||
|
||
# pyre-ignore[5] | ||
environment_variables: Dict[str, Callable[[], Any]] = { | ||
# Decide which rounding mode to use when doing quantization and dequantization to/from MX4 | ||
# check https://fburl.com/code/rohboxgv for what's available | ||
"MX4_QUANT_ROUNDING_MODE": lambda: os.getenv("MX4_QUANT_ROUNDING_MODE", "nearest"), | ||
} | ||
|
||
|
||
# pyre-ignore[3] | ||
def __getattr__(name: str): | ||
if name in environment_variables: | ||
return environment_variables[name]() | ||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}") | ||
|
||
|
||
# pyre-ignore[3] | ||
def __dir__(): | ||
return list(environment_variables.keys()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters