Skip to content

Commit

Permalink
Merge pull request #22042 from hawkinsp:numpy
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 645519817
  • Loading branch information
jax authors committed Jun 21, 2024
2 parents 56e8fe6 + 7f24837 commit 300d06a
Show file tree
Hide file tree
Showing 7 changed files with 17 additions and 41 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ Remember to align the itemized text with the first line of an item within a list

## jax 0.4.31

* Changes
* The minimum NumPy version is now 1.24.

## jaxlib 0.4.31

* Bug fixes
Expand Down
2 changes: 1 addition & 1 deletion jaxlib/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def has_ext_modules(self):
install_requires=[
'scipy>=1.9',
"scipy>=1.11.1; python_version>='3.12'",
'numpy>=1.22',
'numpy>=1.24',
'ml_dtypes>=0.2.0',
],
url='https://github.com/google/jax',
Expand Down
3 changes: 1 addition & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,7 @@ def load_version_module(pkg_path):
install_requires=[
f'jaxlib >={_minimum_jaxlib_version}, <={_jax_version}',
'ml_dtypes>=0.2.0',
'numpy>=1.22',
"numpy>=1.23.2; python_version>='3.11'",
'numpy>=1.24',
"numpy>=1.26.0; python_version>='3.12'",
'opt_einsum',
'scipy>=1.9',
Expand Down
1 change: 0 additions & 1 deletion tests/array_interoperability_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,6 @@ def testNumpyToJax(self, shape, dtype, copy):
shape=all_shapes,
dtype=numpy_dtypes,
)
@unittest.skipIf(numpy_version < (1, 23, 0), "Requires numpy 1.23 or newer")
@jtu.run_on_devices("cpu") # NumPy only accepts cpu DLPacks
def testJaxToNumpy(self, shape, dtype):
rng = jtu.rand_default(self.rng())
Expand Down
15 changes: 2 additions & 13 deletions tests/lax_metal_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1679,9 +1679,6 @@ def testDeleteMaskArray(self, shape, dtype, axis):
rng = jtu.rand_default(self.rng())
mask_size = np.zeros(shape).size if axis is None else np.zeros(shape).shape[axis]
mask = jtu.rand_int(self.rng(), low=0, high=2)(mask_size, bool)
if numpy_version == (1, 23, 0) and mask.shape == (1,):
# https://github.com/numpy/numpy/issues/21840
self.skipTest("test fails for numpy v1.23.0")
args_maker = lambda: [rng(shape, dtype)]
np_fun = lambda arg: np.delete(arg, mask, axis=axis)
jnp_fun = lambda arg: jnp.delete(arg, mask, axis=axis)
Expand Down Expand Up @@ -1943,9 +1940,6 @@ def np_fun(x, fill_value=fill_value):
@unittest.skip("jax-metal fail.")
@jtu.sample_product(dtype=inexact_dtypes)
def testUniqueNans(self, dtype):
if numpy_version == (1, 23, 0) and dtype == np.float16:
# https://github.com/numpy/numpy/issues/21838
self.skipTest("Known failure on numpy 1.23.0")
def args_maker():
x = [-0.0, 0.0, 1.0, 1.0, np.nan, -np.nan]
if np.issubdtype(dtype, np.complexfloating):
Expand All @@ -1966,8 +1960,6 @@ def np_fun(x):
@unittest.skip("jax-metal fail.")
@jtu.sample_product(dtype=inexact_dtypes, equal_nan=[True, False])
def testUniqueEqualNan(self, dtype, equal_nan):
if numpy_version < (1, 24, 0):
self.skipTest("np.unique equal_nan requires NumPy 1.24 or newer.")
shape = (20,)
rng = jtu.rand_some_nan(self.rng())
args_maker = lambda: [rng(shape, dtype)]
Expand Down Expand Up @@ -2669,10 +2661,7 @@ def testStack(self, shape, axis, dtypes, array_input, out_dtype):
else:
args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]]

if numpy_version < (1, 24):
np_fun = jtu.promote_like_jnp(lambda *args: np.stack(*args, axis=axis).astype(out_dtype))
else:
np_fun = jtu.promote_like_jnp(partial(np.stack, axis=axis, dtype=out_dtype, casting='unsafe'))
np_fun = jtu.promote_like_jnp(partial(np.stack, axis=axis, dtype=out_dtype, casting='unsafe'))

jnp_fun = partial(jnp.stack, axis=axis, dtype=out_dtype)
with jtu.strict_promotion_if_dtypes_match(dtypes):
Expand All @@ -2699,7 +2688,7 @@ def testHVDStack(self, shape, op, dtypes, array_input, out_dtype):
else:
args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]]

if numpy_version < (1, 24) or op == "dstack":
if op == "dstack":
np_fun = jtu.promote_like_jnp(lambda *args: getattr(np, op)(*args).astype(out_dtype))
else:
np_fun = partial(jtu.promote_like_jnp(getattr(np, op)), dtype=out_dtype,
Expand Down
20 changes: 7 additions & 13 deletions tests/lax_numpy_reducers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,14 +507,12 @@ def testReductionWithRepeatedAxisError(self):
for weights_shape in ([None, shape] if axis is None or len(shape) == 1 or isinstance(axis, tuple)
else [None, (shape[axis],), shape])
],
keepdims=([False, True] if numpy_version >= (1, 23) else [None]),
keepdims=[False, True],
returned=[False, True],
)
def testAverage(self, shape, dtype, axis, weights_shape, returned, keepdims):
rng = jtu.rand_default(self.rng())
kwds = dict(returned=returned)
if keepdims is not None:
kwds['keepdims'] = keepdims
kwds = dict(returned=returned, keepdims=keepdims)
if weights_shape is None:
np_fun = lambda x: np.average(x, axis, **kwds)
jnp_fun = lambda x: jnp.average(x, axis, **kwds)
Expand All @@ -527,15 +525,11 @@ def testAverage(self, shape, dtype, axis, weights_shape, returned, keepdims):
tol = {dtypes.bfloat16: 2e-1, np.float16: 1e-2, np.float32: 1e-5,
np.float64: 1e-12, np.complex64: 1e-5}
check_dtypes = shape is not jtu.PYTHON_SCALAR_SHAPE
if numpy_version == (1, 23, 0) and keepdims and weights_shape is not None and axis is not None:
# Known failure: https://github.com/numpy/numpy/issues/21850
pass
else:
try:
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker,
check_dtypes=check_dtypes, tol=tol)
except ZeroDivisionError:
self.skipTest("don't support checking for ZeroDivisionError")
try:
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker,
check_dtypes=check_dtypes, tol=tol)
except ZeroDivisionError:
self.skipTest("don't support checking for ZeroDivisionError")
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=check_dtypes,
rtol=tol, atol=tol)

Expand Down
14 changes: 3 additions & 11 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2048,9 +2048,6 @@ def np_fun(x, fill_value=fill_value):

@jtu.sample_product(dtype=inexact_dtypes)
def testUniqueNans(self, dtype):
if numpy_version == (1, 23, 0) and dtype == np.float16:
# https://github.com/numpy/numpy/issues/21838
self.skipTest("Known failure on numpy 1.23.0")
def args_maker():
x = [-0.0, 0.0, 1.0, 1.0, np.nan, -np.nan]
if np.issubdtype(dtype, np.complexfloating):
Expand All @@ -2070,8 +2067,6 @@ def np_fun(x):

@jtu.sample_product(dtype=inexact_dtypes, equal_nan=[True, False])
def testUniqueEqualNan(self, dtype, equal_nan):
if numpy_version < (1, 24, 0):
self.skipTest("np.unique equal_nan requires NumPy 1.24 or newer.")
shape = (20,)
rng = jtu.rand_some_nan(self.rng())
args_maker = lambda: [rng(shape, dtype)]
Expand Down Expand Up @@ -2784,10 +2779,7 @@ def testStack(self, shape, axis, dtypes, array_input, out_dtype):
else:
args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]]

if numpy_version < (1, 24):
np_fun = jtu.promote_like_jnp(lambda *args: np.stack(*args, axis=axis).astype(out_dtype))
else:
np_fun = jtu.promote_like_jnp(partial(np.stack, axis=axis, dtype=out_dtype, casting='unsafe'))
np_fun = jtu.promote_like_jnp(partial(np.stack, axis=axis, dtype=out_dtype, casting='unsafe'))

jnp_fun = partial(jnp.stack, axis=axis, dtype=out_dtype)
with jtu.strict_promotion_if_dtypes_match(dtypes):
Expand All @@ -2814,7 +2806,7 @@ def testHVDStack(self, shape, op, dtypes, array_input, out_dtype):
else:
args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]]

if numpy_version < (1, 24) or op == "dstack":
if op == "dstack":
np_fun = jtu.promote_like_jnp(lambda *args: getattr(np, op)(*args).astype(out_dtype))
else:
np_fun = partial(jtu.promote_like_jnp(getattr(np, op)), dtype=out_dtype,
Expand Down Expand Up @@ -5992,7 +5984,7 @@ def testWrappedSignaturesMatch(self):
mismatches = {}

for name, (jnp_fun, np_fun) in func_pairs.items():
if numpy_version >= (1, 24) and name in ['histogram', 'histogram2d', 'histogramdd']:
if name in ['histogram', 'histogram2d', 'histogramdd']:
# numpy 1.24 re-orders the density and weights arguments.
# TODO(jakevdp): migrate histogram APIs to match newer numpy versions.
continue
Expand Down

0 comments on commit 300d06a

Please sign in to comment.