From 7f24837eef654a5d25ae6261742d3c993f093600 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 21 Jun 2024 14:57:35 -0700 Subject: [PATCH] Update minimum NumPy version to v1.24. --- CHANGELOG.md | 3 +++ jaxlib/setup.py | 2 +- setup.py | 3 +-- tests/array_interoperability_test.py | 1 - tests/lax_metal_test.py | 15 ++------------- tests/lax_numpy_reducers_test.py | 20 +++++++------------- tests/lax_numpy_test.py | 14 +++----------- 7 files changed, 17 insertions(+), 41 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ad4dfcecf753..aa15035fe5e8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,9 @@ Remember to align the itemized text with the first line of an item within a list ## jax 0.4.31 +* Changes + * The minimum NumPy version is now 1.24. + ## jaxlib 0.4.31 * Bug fixes diff --git a/jaxlib/setup.py b/jaxlib/setup.py index aecace3a928f..0e4f422be67a 100644 --- a/jaxlib/setup.py +++ b/jaxlib/setup.py @@ -63,7 +63,7 @@ def has_ext_modules(self): install_requires=[ 'scipy>=1.9', "scipy>=1.11.1; python_version>='3.12'", - 'numpy>=1.22', + 'numpy>=1.24', 'ml_dtypes>=0.2.0', ], url='https://github.com/google/jax', diff --git a/setup.py b/setup.py index c0fe601b3afb..82b163ee5d4f 100644 --- a/setup.py +++ b/setup.py @@ -55,8 +55,7 @@ def load_version_module(pkg_path): install_requires=[ f'jaxlib >={_minimum_jaxlib_version}, <={_jax_version}', 'ml_dtypes>=0.2.0', - 'numpy>=1.22', - "numpy>=1.23.2; python_version>='3.11'", + 'numpy>=1.24', "numpy>=1.26.0; python_version>='3.12'", 'opt_einsum', 'scipy>=1.9', diff --git a/tests/array_interoperability_test.py b/tests/array_interoperability_test.py index b555576b3261..c2cd4c0f968d 100644 --- a/tests/array_interoperability_test.py +++ b/tests/array_interoperability_test.py @@ -214,7 +214,6 @@ def testNumpyToJax(self, shape, dtype, copy): shape=all_shapes, dtype=numpy_dtypes, ) - @unittest.skipIf(numpy_version < (1, 23, 0), "Requires numpy 1.23 or newer") @jtu.run_on_devices("cpu") # NumPy only accepts cpu DLPacks def testJaxToNumpy(self, shape, dtype): rng = jtu.rand_default(self.rng()) diff --git a/tests/lax_metal_test.py b/tests/lax_metal_test.py index 5069187d2334..dab26d86c0a2 100644 --- a/tests/lax_metal_test.py +++ b/tests/lax_metal_test.py @@ -1679,9 +1679,6 @@ def testDeleteMaskArray(self, shape, dtype, axis): rng = jtu.rand_default(self.rng()) mask_size = np.zeros(shape).size if axis is None else np.zeros(shape).shape[axis] mask = jtu.rand_int(self.rng(), low=0, high=2)(mask_size, bool) - if numpy_version == (1, 23, 0) and mask.shape == (1,): - # https://github.com/numpy/numpy/issues/21840 - self.skipTest("test fails for numpy v1.23.0") args_maker = lambda: [rng(shape, dtype)] np_fun = lambda arg: np.delete(arg, mask, axis=axis) jnp_fun = lambda arg: jnp.delete(arg, mask, axis=axis) @@ -1943,9 +1940,6 @@ def np_fun(x, fill_value=fill_value): @unittest.skip("jax-metal fail.") @jtu.sample_product(dtype=inexact_dtypes) def testUniqueNans(self, dtype): - if numpy_version == (1, 23, 0) and dtype == np.float16: - # https://github.com/numpy/numpy/issues/21838 - self.skipTest("Known failure on numpy 1.23.0") def args_maker(): x = [-0.0, 0.0, 1.0, 1.0, np.nan, -np.nan] if np.issubdtype(dtype, np.complexfloating): @@ -1966,8 +1960,6 @@ def np_fun(x): @unittest.skip("jax-metal fail.") @jtu.sample_product(dtype=inexact_dtypes, equal_nan=[True, False]) def testUniqueEqualNan(self, dtype, equal_nan): - if numpy_version < (1, 24, 0): - self.skipTest("np.unique equal_nan requires NumPy 1.24 or newer.") shape = (20,) rng = jtu.rand_some_nan(self.rng()) args_maker = lambda: [rng(shape, dtype)] @@ -2669,10 +2661,7 @@ def testStack(self, shape, axis, dtypes, array_input, out_dtype): else: args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]] - if numpy_version < (1, 24): - np_fun = jtu.promote_like_jnp(lambda *args: np.stack(*args, axis=axis).astype(out_dtype)) - else: - np_fun = jtu.promote_like_jnp(partial(np.stack, axis=axis, dtype=out_dtype, casting='unsafe')) + np_fun = jtu.promote_like_jnp(partial(np.stack, axis=axis, dtype=out_dtype, casting='unsafe')) jnp_fun = partial(jnp.stack, axis=axis, dtype=out_dtype) with jtu.strict_promotion_if_dtypes_match(dtypes): @@ -2699,7 +2688,7 @@ def testHVDStack(self, shape, op, dtypes, array_input, out_dtype): else: args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]] - if numpy_version < (1, 24) or op == "dstack": + if op == "dstack": np_fun = jtu.promote_like_jnp(lambda *args: getattr(np, op)(*args).astype(out_dtype)) else: np_fun = partial(jtu.promote_like_jnp(getattr(np, op)), dtype=out_dtype, diff --git a/tests/lax_numpy_reducers_test.py b/tests/lax_numpy_reducers_test.py index 37e8410ddbfb..588368cd8553 100644 --- a/tests/lax_numpy_reducers_test.py +++ b/tests/lax_numpy_reducers_test.py @@ -507,14 +507,12 @@ def testReductionWithRepeatedAxisError(self): for weights_shape in ([None, shape] if axis is None or len(shape) == 1 or isinstance(axis, tuple) else [None, (shape[axis],), shape]) ], - keepdims=([False, True] if numpy_version >= (1, 23) else [None]), + keepdims=[False, True], returned=[False, True], ) def testAverage(self, shape, dtype, axis, weights_shape, returned, keepdims): rng = jtu.rand_default(self.rng()) - kwds = dict(returned=returned) - if keepdims is not None: - kwds['keepdims'] = keepdims + kwds = dict(returned=returned, keepdims=keepdims) if weights_shape is None: np_fun = lambda x: np.average(x, axis, **kwds) jnp_fun = lambda x: jnp.average(x, axis, **kwds) @@ -527,15 +525,11 @@ def testAverage(self, shape, dtype, axis, weights_shape, returned, keepdims): tol = {dtypes.bfloat16: 2e-1, np.float16: 1e-2, np.float32: 1e-5, np.float64: 1e-12, np.complex64: 1e-5} check_dtypes = shape is not jtu.PYTHON_SCALAR_SHAPE - if numpy_version == (1, 23, 0) and keepdims and weights_shape is not None and axis is not None: - # Known failure: https://github.com/numpy/numpy/issues/21850 - pass - else: - try: - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, - check_dtypes=check_dtypes, tol=tol) - except ZeroDivisionError: - self.skipTest("don't support checking for ZeroDivisionError") + try: + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, + check_dtypes=check_dtypes, tol=tol) + except ZeroDivisionError: + self.skipTest("don't support checking for ZeroDivisionError") self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=check_dtypes, rtol=tol, atol=tol) diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 6581e69f8625..ad6efbb8ad28 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -2048,9 +2048,6 @@ def np_fun(x, fill_value=fill_value): @jtu.sample_product(dtype=inexact_dtypes) def testUniqueNans(self, dtype): - if numpy_version == (1, 23, 0) and dtype == np.float16: - # https://github.com/numpy/numpy/issues/21838 - self.skipTest("Known failure on numpy 1.23.0") def args_maker(): x = [-0.0, 0.0, 1.0, 1.0, np.nan, -np.nan] if np.issubdtype(dtype, np.complexfloating): @@ -2070,8 +2067,6 @@ def np_fun(x): @jtu.sample_product(dtype=inexact_dtypes, equal_nan=[True, False]) def testUniqueEqualNan(self, dtype, equal_nan): - if numpy_version < (1, 24, 0): - self.skipTest("np.unique equal_nan requires NumPy 1.24 or newer.") shape = (20,) rng = jtu.rand_some_nan(self.rng()) args_maker = lambda: [rng(shape, dtype)] @@ -2784,10 +2779,7 @@ def testStack(self, shape, axis, dtypes, array_input, out_dtype): else: args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]] - if numpy_version < (1, 24): - np_fun = jtu.promote_like_jnp(lambda *args: np.stack(*args, axis=axis).astype(out_dtype)) - else: - np_fun = jtu.promote_like_jnp(partial(np.stack, axis=axis, dtype=out_dtype, casting='unsafe')) + np_fun = jtu.promote_like_jnp(partial(np.stack, axis=axis, dtype=out_dtype, casting='unsafe')) jnp_fun = partial(jnp.stack, axis=axis, dtype=out_dtype) with jtu.strict_promotion_if_dtypes_match(dtypes): @@ -2814,7 +2806,7 @@ def testHVDStack(self, shape, op, dtypes, array_input, out_dtype): else: args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]] - if numpy_version < (1, 24) or op == "dstack": + if op == "dstack": np_fun = jtu.promote_like_jnp(lambda *args: getattr(np, op)(*args).astype(out_dtype)) else: np_fun = partial(jtu.promote_like_jnp(getattr(np, op)), dtype=out_dtype, @@ -5992,7 +5984,7 @@ def testWrappedSignaturesMatch(self): mismatches = {} for name, (jnp_fun, np_fun) in func_pairs.items(): - if numpy_version >= (1, 24) and name in ['histogram', 'histogram2d', 'histogramdd']: + if name in ['histogram', 'histogram2d', 'histogramdd']: # numpy 1.24 re-orders the density and weights arguments. # TODO(jakevdp): migrate histogram APIs to match newer numpy versions. continue