diff --git a/CHANGELOG.md b/CHANGELOG.md index 49789d01..acb1cd68 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,9 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`): ## [Unreleased] +* Fixed spurious invalid value warnings when casting between floating point + types on Mac ARM. + ## [0.3.1] - 2023-09-22 * Added support for int4 casting to wider integers such as int8 diff --git a/ml_dtypes/tests/custom_float_test.py b/ml_dtypes/tests/custom_float_test.py index 8f8fbf78..d71ae8b1 100644 --- a/ml_dtypes/tests/custom_float_test.py +++ b/ml_dtypes/tests/custom_float_test.py @@ -214,9 +214,6 @@ def testRoundTripToInt(self, float_type): self.assertEqual(v, int(float_type(v))) self.assertEqual(-v, int(float_type(-v))) - @ignore_warning( - category=RuntimeWarning, message="invalid value encountered in cast" - ) @ignore_warning(category=RuntimeWarning, message="overflow encountered") def testRoundTripToNumpy(self, float_type): for dtype in [ @@ -241,7 +238,7 @@ def testRoundTripToNumpy(self, float_type): ), ) - def testBetweenCustomTypes(self, float_type): + def testCastBetweenCustomTypes(self, float_type): for dtype in FLOAT_DTYPES: x = np.array(FLOAT_VALUES[float_type], dtype=dtype) y = x.astype(float_type) @@ -460,9 +457,6 @@ def testSort(self, float_type): sorted_float_type = np.sort(values_to_sort.astype(float_type)) # pylint: disable=too-many-function-args np.testing.assert_equal(sorted_f32, np.float32(sorted_float_type)) - @ignore_warning( - category=RuntimeWarning, message="invalid value encountered in cast" - ) def testArgmax(self, float_type): values_to_sort = np.float32( float_type(np.float32(FLOAT_VALUES[float_type])) @@ -485,9 +479,6 @@ def testArgmaxOnNegativeInfinity(self, float_type): inf = np.array([float("-inf")], dtype=np.float32) np.testing.assert_equal(np.argmax(inf.astype(float_type)), np.argmax(inf)) - @ignore_warning( - category=RuntimeWarning, message="invalid value encountered in cast" - ) def testArgmin(self, float_type): values_to_sort = np.float32( float_type(np.float32(FLOAT_VALUES[float_type])) diff --git a/setup.py b/setup.py index 031172e2..f46c91a3 100644 --- a/setup.py +++ b/setup.py @@ -30,6 +30,12 @@ COMPILE_ARGS = [ "-std=c++17", "-DEIGEN_MPL2_ONLY", + + # -ftrapping-math is necessary because NumPy looks at floating point + # exception state to determine whether to emit, e.g., invalid value + # warnings. Without this setting, on Mac ARM we see spurious "invalid + # value" warnings when running the tests. + "-ftrapping-math", ] exclude = ["third_party*"]