diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index d848bc437df9..9a5fc60c2320 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -538,6 +538,17 @@ def kernel(x_ref, y_ref): out = self.pallas_call(kernel, out_shape=x_shape_dtype)(x) self.assertAllClose(out, func(x), atol=tol, rtol=tol) + @parameterized.product(to_dtype=[jnp.int32, jnp.int16, jnp.int8, jnp.int4]) + def test_rounding(self, to_dtype): + """Make sure rounding when casting float to int is consistent with JAX.""" + x = jnp.repeat((jnp.arange(256)[None] - 128) / 4, 64, axis=0) + def kernel(x_ref, o_ref): + o_ref[...] = x_ref[...].astype(to_dtype) + out = self.pallas_call( + kernel, out_shape=jax.ShapeDtypeStruct(x.shape, to_dtype) + )(x) + np.testing.assert_array_equal(out, x.astype(to_dtype)) + @parameterized.product(from_dtype=_DTYPES, to_dtype=_DTYPES) @hp.given(hps.data()) def test_cast(self, from_dtype, to_dtype, data):