diff --git a/examples/examples_test.py b/examples/examples_test.py index fd705a4ef799..007e8e65824d 100644 --- a/examples/examples_test.py +++ b/examples/examples_test.py @@ -23,7 +23,6 @@ import numpy as np import jax -from jax import lax from jax import random import jax.numpy as jnp from jax._src import test_util as jtu @@ -55,12 +54,13 @@ def testKernelRegressionGram(self): kernel = lambda x, y: jnp.dot(x, y) np.testing.assert_allclose(kernel_lsq.gram(kernel, xs), jnp.dot(xs, xs.T), atol=1E-5) + @jax.default_matmul_precision("float32") def testKernelRegressionTrainAndPredict(self): n, d = 100, 20 truth = self.rng.normal(size=d) xs = self.rng.normal(size=(n, d)) ys = jnp.dot(xs, truth) - kernel = lambda x, y: jnp.dot(x, y, precision=lax.Precision.HIGH) + kernel = lambda x, y: jnp.dot(x, y) predict = kernel_lsq.train(kernel, xs, ys) np.testing.assert_allclose(predict(xs), ys, atol=1e-3, rtol=1e-3)