Force float32 matmuls in examples_test.
This test started failing when we changed our CI to use L4 GPUs. Using
highest precision resolves the problem.
hawkinsp committed May 10, 2024
1 parent c2d78ab · commit 24b4731
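For background, the mechanism at work here is JAX's default matmul precision: on Ampere/Ada-generation GPUs such as the L4, XLA may execute float32 matmuls in TF32 (roughly 10 mantissa bits) by default, and jax.default_matmul_precision overrides that default for everything in its scope. Below is a minimal sketch of both usage forms; the array shapes are illustrative assumptions, not taken from the test.

import jax
import jax.numpy as jnp

x = jnp.ones((256, 256), dtype=jnp.float32)

# As a context manager: every matmul in the block runs in full float32
# precision rather than the device default (e.g. TF32 on an L4 GPU).
with jax.default_matmul_precision("float32"):
    y = jnp.dot(x, x)

# As a decorator, the form used in this commit: the override covers all
# matmuls executed while the function runs.
@jax.default_matmul_precision("float32")
def gram(a):
    return jnp.dot(a, a.T)

z = gram(x)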
Showing 1 changed file, examples/examples_test.py, with 2 additions and 2 deletions.
@@ -23,7 +23,6 @@
 import numpy as np
 
 import jax
-from jax import lax
 from jax import random
 import jax.numpy as jnp
 from jax._src import test_util as jtu
@@ -55,12 +54,13 @@ def testKernelRegressionGram(self):
     kernel = lambda x, y: jnp.dot(x, y)
     np.testing.assert_allclose(kernel_lsq.gram(kernel, xs), jnp.dot(xs, xs.T), atol=1E-5)
 
+  @jax.default_matmul_precision("float32")
   def testKernelRegressionTrainAndPredict(self):
     n, d = 100, 20
     truth = self.rng.normal(size=d)
     xs = self.rng.normal(size=(n, d))
     ys = jnp.dot(xs, truth)
-    kernel = lambda x, y: jnp.dot(x, y, precision=lax.Precision.HIGH)
+    kernel = lambda x, y: jnp.dot(x, y)
     predict = kernel_lsq.train(kernel, xs, ys)
     np.testing.assert_allclose(predict(xs), ys, atol=1e-3, rtol=1e-3)
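A note on the design choice, as a reading of the diff rather than anything stated in the commit: the removed per-call override raised precision only for the single dot inside the kernel, along the lines of

kernel = lambda x, y: jnp.dot(x, y, precision=jax.lax.Precision.HIGHEST)

whereas the method-wide decorator also covers the matmuls performed inside kernel_lsq.train and the returned predict function, which the per-call form left at the device default.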

