Skip to content

Commit

Permalink
disable certain tests in lax_numpy_test when numpy 2.0 is used
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 673051226
  • Loading branch information
trax-robot authored and copybara-github committed Sep 10, 2024
1 parent 6002f18 commit f896973
Showing 1 changed file with 95 additions and 44 deletions.
139 changes: 95 additions & 44 deletions trax/tf_numpy/jax_tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -971,16 +971,33 @@ def onp_fun(lhs, rhs):
self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=False, atol=tol,
rtol=tol, check_incomplete_shape=True)

@named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_amin={}_amax={}".format(
jtu.format_shape_dtype_string(shape, dtype), a_min, a_max),
"shape": shape, "dtype": dtype, "a_min": a_min, "a_max": a_max,
"rng_factory": jtu.rand_default}
for shape in all_shapes for dtype in minus(number_dtypes, complex_dtypes)
for a_min, a_max in [(-1, None), (None, 1), (-1, 1),
(-onp.ones(1), None),
(None, onp.ones(1)),
(-onp.ones(1), onp.ones(1))]))
@named_parameters(
jtu.cases_from_list(
{
"testcase_name": "_{}_amin={}_amax={}".format(
jtu.format_shape_dtype_string(shape, dtype), a_min, a_max
),
"shape": shape,
"dtype": dtype,
"a_min": a_min,
"a_max": a_max,
"rng_factory": jtu.rand_default,
}
for shape in all_shapes
for dtype in minus(number_dtypes, complex_dtypes)
for a_min, a_max in [
(-1, None),
(None, 1),
(-onp.ones(1), None),
(None, onp.ones(1)),
]
+ (
[]
if onp.__version__ >= onp.lib.NumpyVersion("2.0.0")
else [(-1, 1), (-onp.ones(1), onp.ones(1))]
)
)
)
def testClipStaticBounds(self, shape, dtype, a_min, a_max, rng_factory):
rng = rng_factory()
onp_fun = lambda x: onp.clip(x, a_min=a_min, a_max=a_max)
Expand Down Expand Up @@ -1357,7 +1374,6 @@ def testDiagIndices(self, ndim, n):
onp.testing.assert_equal(onp.diag_indices(n, ndim),
lnp.diag_indices(n, ndim))


@named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_k={}".format(
jtu.format_shape_dtype_string(shape, dtype), k),
Expand Down Expand Up @@ -1951,7 +1967,6 @@ def testFlipud(self, shape, dtype, rng_factory):
self._CompileAndCheck(
lnp_op, args_maker, check_dtypes=True, check_incomplete_shape=True)


@named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}".format(
jtu.format_shape_dtype_string(shape, dtype)),
Expand All @@ -1968,7 +1983,6 @@ def testFliplr(self, shape, dtype, rng_factory):
self._CompileAndCheck(
lnp_op, args_maker, check_dtypes=True, check_incomplete_shape=True)


@named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_k={}_axes={}".format(
jtu.format_shape_dtype_string(shape, dtype), k, axes),
Expand Down Expand Up @@ -2295,7 +2309,6 @@ def onp_fun(*args):
tol=tol)
self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True, rtol=tol)


@named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}".format(
jtu.format_shape_dtype_string(shape, dtype)),
Expand All @@ -2318,7 +2331,6 @@ def testWhereOneArgument(self, shape, dtype):
check_unknown_rank=False,
check_experimental_compile=False, check_xla_forced_compile=False)


@named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}".format("_".join(
jtu.format_shape_dtype_string(shape, dtype)
Expand Down Expand Up @@ -2373,7 +2385,6 @@ def onp_fun(condlist, choicelist, default):
check_incomplete_shape=True,
rtol={onp.float64: 1e-7, onp.complex128: 1e-7})


@jtu.disable
def testIssue330(self):
x = lnp.full((1, 1), lnp.array([1])[0]) # doesn't crash
Expand Down Expand Up @@ -2429,7 +2440,6 @@ def testAtLeastNdLiterals(self, pytype, dtype, op):
self._CompileAndCheck(
lnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=True)


def testLongLong(self):
self.assertAllClose(
onp.int64(7), npe.jit(lambda x: x)(onp.longlong(7)), check_dtypes=True)
Expand Down Expand Up @@ -2676,19 +2686,40 @@ def testMeshGrid(self, shapes, dtype, indexing, sparse, rng_factory):

@named_parameters(
jtu.cases_from_list(
{"testcase_name": ("_start_shape={}_stop_shape={}_num={}_endpoint={}"
"_retstep={}_dtype={}").format(
start_shape, stop_shape, num, endpoint, retstep, dtype),
"start_shape": start_shape, "stop_shape": stop_shape,
"num": num, "endpoint": endpoint, "retstep": retstep,
"dtype": dtype, "rng_factory": rng_factory}
for start_shape in [(), (2,), (2, 2)]
for stop_shape in [(), (2,), (2, 2)]
for num in [0, 1, 2, 5, 20]
for endpoint in [True, False]
for retstep in [True, False]
for dtype in number_dtypes + [None,]
for rng_factory in [jtu.rand_default]))
{
"testcase_name": (
"_start_shape={}_stop_shape={}_num={}_endpoint={}"
"_retstep={}_dtype={}"
).format(start_shape, stop_shape, num, endpoint, retstep, dtype),
"start_shape": start_shape,
"stop_shape": stop_shape,
"num": num,
"endpoint": endpoint,
"retstep": retstep,
"dtype": dtype,
"rng_factory": rng_factory,
}
for start_shape in [(), (2,), (2, 2)]
for stop_shape in [(), (2,), (2, 2)]
for num in [0, 1, 2, 5, 20]
for endpoint in [True, False]
for retstep in [True, False]
for dtype in (
(
float_dtypes
+ complex_dtypes
+ [
None,
]
)
if onp.__version__ >= onp.lib.NumpyVersion("2.0.0")
else ([
number_dtypes + None,
])
)
for rng_factory in [jtu.rand_default]
)
)
def testLinspace(self, start_shape, stop_shape, num, endpoint,
retstep, dtype, rng_factory):
if not endpoint and onp.issubdtype(dtype, onp.integer):
Expand Down Expand Up @@ -2770,20 +2801,40 @@ def testLogspace(self, start_shape, stop_shape, num,

@named_parameters(
jtu.cases_from_list(
{"testcase_name": ("_start_shape={}_stop_shape={}_num={}_endpoint={}"
"_dtype={}").format(
start_shape, stop_shape, num, endpoint, dtype),
"start_shape": start_shape,
"stop_shape": stop_shape,
"num": num, "endpoint": endpoint,
"dtype": dtype, "rng_factory": rng_factory}
for start_shape in [(), (2,), (2, 2)]
for stop_shape in [(), (2,), (2, 2)]
for num in [0, 1, 2, 5, 20]
for endpoint in [True, False]
# NB: numpy's geomspace gives nonsense results on integer types
for dtype in inexact_dtypes + [None,]
for rng_factory in [jtu.rand_default]))
{
"testcase_name": (
"_start_shape={}_stop_shape={}_num={}_endpoint={}_dtype={}"
).format(start_shape, stop_shape, num, endpoint, dtype),
"start_shape": start_shape,
"stop_shape": stop_shape,
"num": num,
"endpoint": endpoint,
"dtype": dtype,
"rng_factory": rng_factory,
}
for start_shape in [(), (2,), (2, 2)]
for stop_shape in [(), (2,), (2, 2)]
for num in [0, 1, 2, 5, 20]
for endpoint in [True, False]
# NB: numpy's geomspace gives nonsense results on integer types
for dtype in (
(
float_dtypes
+ [
None,
]
)
if onp.__version__ >= onp.lib.NumpyVersion("2.0.0")
else (
inexact_dtypes
+ [
None,
]
)
)
for rng_factory in [jtu.rand_default]
)
)
def testGeomspace(self, start_shape, stop_shape, num,
endpoint, dtype, rng_factory):
rng = rng_factory()
Expand Down

0 comments on commit f896973

Please sign in to comment.