Skip to content

Commit

Permalink
Build with -ftrapping-math.
Browse files Browse the repository at this point in the history
This fixes the spurious cast warnings seen in jax-ml#106, which occured when, for example, casting a NaN float to a bfloat16. Casting NaNs between floating point types should not warn in NumPy.

Reverts jax-ml#107, since we now fix rather than suppressing the errors.
  • Loading branch information
hawkinsp committed Oct 9, 2023
1 parent 3570723 commit 79e4a1a
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 10 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`):

## [Unreleased]

* Fixed spurious invalid value warnings when casting between floating point
types on Mac ARM.

## [0.3.1] - 2023-09-22

* Added support for int4 casting to wider integers such as int8
Expand Down
11 changes: 1 addition & 10 deletions ml_dtypes/tests/custom_float_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,9 +214,6 @@ def testRoundTripToInt(self, float_type):
self.assertEqual(v, int(float_type(v)))
self.assertEqual(-v, int(float_type(-v)))

@ignore_warning(
category=RuntimeWarning, message="invalid value encountered in cast"
)
@ignore_warning(category=RuntimeWarning, message="overflow encountered")
def testRoundTripToNumpy(self, float_type):
for dtype in [
Expand All @@ -241,7 +238,7 @@ def testRoundTripToNumpy(self, float_type):
),
)

def testBetweenCustomTypes(self, float_type):
def testCastBetweenCustomTypes(self, float_type):
for dtype in FLOAT_DTYPES:
x = np.array(FLOAT_VALUES[float_type], dtype=dtype)
y = x.astype(float_type)
Expand Down Expand Up @@ -460,9 +457,6 @@ def testSort(self, float_type):
sorted_float_type = np.sort(values_to_sort.astype(float_type)) # pylint: disable=too-many-function-args
np.testing.assert_equal(sorted_f32, np.float32(sorted_float_type))

@ignore_warning(
category=RuntimeWarning, message="invalid value encountered in cast"
)
def testArgmax(self, float_type):
values_to_sort = np.float32(
float_type(np.float32(FLOAT_VALUES[float_type]))
Expand All @@ -485,9 +479,6 @@ def testArgmaxOnNegativeInfinity(self, float_type):
inf = np.array([float("-inf")], dtype=np.float32)
np.testing.assert_equal(np.argmax(inf.astype(float_type)), np.argmax(inf))

@ignore_warning(
category=RuntimeWarning, message="invalid value encountered in cast"
)
def testArgmin(self, float_type):
values_to_sort = np.float32(
float_type(np.float32(FLOAT_VALUES[float_type]))
Expand Down
5 changes: 5 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@
COMPILE_ARGS = [
"-std=c++17",
"-DEIGEN_MPL2_ONLY",
# -ftrapping-math is necessary because NumPy looks at floating point
# exception state to determine whether to emit, e.g., invalid value
# warnings. Without this setting, on Mac ARM we see spurious "invalid
# value" warnings when running the tests.
"-ftrapping-math",
]

exclude = ["third_party*"]
Expand Down

0 comments on commit 79e4a1a

Please sign in to comment.