Skip to content

Commit

Permalink
Fix nextafter for FP8 FNUZ types.
Browse files Browse the repository at this point in the history
Before, it would return NaN if calling nextafter(negative_value_closest_to_zero, 1).

PiperOrigin-RevId: 681694699
  • Loading branch information
reedwm authored and Google-ML-Automation committed Oct 3, 2024
1 parent 92e3c7a commit 6349c83
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 9 deletions.
18 changes: 18 additions & 0 deletions xla/hlo/builder/lib/math.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ limitations under the License.
#include <algorithm>
#include <array>
#include <cmath>
#include <cstdint>
#include <functional>
#include <limits>
#include <vector>
Expand Down Expand Up @@ -1487,6 +1488,23 @@ XlaOp NextAfter(XlaOp from, XlaOp to) {
Broadcast(ScalarLike(from_as_int, -1), shape.dimensions()),
Broadcast(ScalarLike(from_as_int, 1), shape.dimensions()));
auto result = Add(from_as_int, magnitude_adjustment);

if (shape.element_type() == F8E5M2FNUZ ||
shape.element_type() == F8E4M3FNUZ ||
shape.element_type() == F8E4M3B11FNUZ) {
// Handle 'from' is the negative value closest to zero and 'to' is
// positive. For FNUZ dtypes, the result is +0 instead of -0 since -0
// represents a NaN value.
const int64_t min_negative = sign_mask | 1;
auto to_is_nonnegative = Not(ConvertElementType(to_sign, PRED));
auto predicate =
And(Eq(from_as_int, ScalarLike(from_as_int, min_negative)),
to_is_nonnegative);
auto result_if_predicate =
Broadcast(ScalarLike(from_as_int, 0), shape.dimensions());
result = Select(predicate, result_if_predicate, result);
}

// Handle from == ±0.
result = Select(from_is_zero,
Select(to_is_zero, result_for_both_zero,
Expand Down
3 changes: 3 additions & 0 deletions xla/python/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ py_strict_library(
"@absl_py//absl/logging",
"@absl_py//absl/testing:absltest",
"@absl_py//absl/testing:parameterized",
"@ml_dtypes",
] + if_google(["//third_party/py/numpy"]),
)

Expand All @@ -114,6 +115,7 @@ py_strict_test(
"@absl_py//absl/logging",
"@absl_py//absl/testing:absltest",
"@absl_py//absl/testing:parameterized",
"@ml_dtypes",
] + if_google(["//third_party/py/numpy"]) + xla_py_test_deps(),
)

Expand Down Expand Up @@ -150,6 +152,7 @@ py_strict_test(
"@absl_py//absl/logging",
"@absl_py//absl/testing:absltest",
"@absl_py//absl/testing:parameterized",
"@ml_dtypes",
] + if_google(
[
":xla_gpu_extension",
Expand Down
37 changes: 28 additions & 9 deletions xla/python/xla_client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from absl import logging
from absl.testing import absltest
from absl.testing import parameterized
import ml_dtypes
import numpy as np

from xla.python import xla_client
Expand Down Expand Up @@ -152,7 +153,8 @@ def TestFactory(xla_backend,
# TODO(zhangqiaorjc): test fp8 types when XLA support is complete.
# standard_dtypes is only used for BufferProtocolTest so we only test fp8
# round trip tests.
standard_dtypes += [float8_e4m3b11fnuz, float8_e4m3fn, float8_e5m2]
fp8_dtypes = [float8_e4m3b11fnuz, float8_e4m3fn, float8_e5m2]
standard_dtypes += fp8_dtypes
# TODO(reedwm): Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0.
# standard_dtypes += [float8_e3m4, float8_e4m3]
dlpack_dtypes = int_dtypes + float_dtypes + [np.bool_] + complex_dtypes
Expand Down Expand Up @@ -2191,15 +2193,32 @@ def testFft(self):
c, expected=[np.fft.irfftn(a, axes=(1, 2, 3))], rtol=2e-4
)

def testNextAfter(self):
c = self._NewComputation()
ops.NextAfter(
ops.Constant(c, np.array([1, 2], dtype=np.float32)),
ops.Constant(c, np.array([2, 1], dtype=np.float32)))
@parameterized.named_parameters({
"testcase_name": "_{}".format(dtype.__name__),
"dtype": dtype,
} for dtype in float_dtypes + fp8_dtypes)
def testNextAfter(self, dtype):
finfo = ml_dtypes.finfo(dtype)
eps = finfo.eps
c = self._NewComputation()
# Each row is (value, direction, expected), where
# 'nextafter(value, direction)' should be 'expected'.
data = np.array(
[
[1, 2, 1 + finfo.eps],
[2, 1, 2 - eps],
[-0., 1, finfo.smallest_subnormal],
[0., -1, -finfo.smallest_subnormal],
[-finfo.smallest_subnormal, 1, -0.],
[finfo.smallest_subnormal, 1, 2 * finfo.smallest_subnormal],
[finfo.smallest_subnormal, -1, 0],
],
dtype=dtype,
)

ops.NextAfter(ops.Constant(c, data[:, 0]), ops.Constant(c, data[:, 1]))
out, = self._Execute(c, ())
eps = np.finfo(np.float32).eps
np.testing.assert_equal(
np.array([eps + 1, 2 - eps], dtype=np.float32), out)
np.testing.assert_equal(data[:, 2], out)

@parameterized.named_parameters({
"testcase_name": "_{}".format(dtype.__name__),
Expand Down

0 comments on commit 6349c83

Please sign in to comment.