What's the plan for XLA_FLAGS=--xla_cpu_use_thunk_runtime=false?
#25711
Unanswered
patrick-kidger asked this question in General
Replies: 1 comment
-
I'm the developer of

import jax.numpy as jnp
import jax
import numpy as np
import time

# print_environment_info() prints the report itself and returns None,
# so it does not need to be wrapped in print().
jax.print_environment_info()

# Data generation
seed = 3  # for reproducibility of results
np.random.seed(seed)
n = 3
N = 3000000
U = jnp.array(np.random.randn(N, n, 1))
x0 = jnp.zeros((n, 1))

@jax.jit
def sys(x, u):
    # First-order linear recurrence: the new state is both carry and output
    x = .95*x + u
    return x, x

t_scan = time.time()
_, Y = jax.lax.scan(sys, x0, U)
print("CPU time: ", time.time()-t_scan)

See logs:
Setting
but it is still worse than with jax 0.4.31. I hope this further helps to bring
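A note on measuring this: the snippet above times a single call, so the number includes tracing and compilation, and because JAX dispatches work asynchronously the timer may stop before the result is actually ready. Below is a minimal sketch of a more careful harness, not taken from the original post. The names `step`, `run_scan`, and `time_scan`, the smaller default `N`, and the choice to jit the whole scan rather than only the step function are all assumptions made for illustration.

```python
import time

import jax
import jax.numpy as jnp
import numpy as np


def step(x, u):
    # Same first-order recurrence as in the snippet above.
    x = 0.95 * x + u
    return x, x


@jax.jit
def run_scan(x0, U):
    # Assumption: jit the whole scan instead of only the step function.
    return jax.lax.scan(step, x0, U)


def time_scan(N=100_000, n=3, seed=3, repeats=3):
    """Return (best wall-clock seconds over `repeats` runs, scan outputs)."""
    rng = np.random.default_rng(seed)
    U = jnp.asarray(rng.standard_normal((N, n, 1)))
    x0 = jnp.zeros((n, 1))

    # Warm-up call: triggers tracing and compilation, excluded from timing.
    jax.block_until_ready(run_scan(x0, U))

    times = []
    for _ in range(repeats):
        t0 = time.perf_counter()
        _, Y = run_scan(x0, U)
        jax.block_until_ready(Y)  # wait for the asynchronous computation
        times.append(time.perf_counter() - t0)
    return min(times), Y


best, Y = time_scan()
print(f"best of 3 runs: {best:.4f} s, output shape {Y.shape}")
```

To compare runtimes, the same script can be run once per configuration in a fresh process; raising `N` towards the 3000000 used above gives a closer match to the original measurement.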
-
Similar to #24501.
Over in patrick-kidger/diffrax#549 we're observing a 3x performance slowdown when upgrading the JAX version, and it looks like the new XLA runtime is the culprit.
I'm curious what the plan is for the new runtime. (In particular, I'd just like to be sure that we won't be left with only the current, slower performance going forward!)
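For anyone wanting to compare the two runtimes, the flag from the title can be set from Python as well as from the shell. The sketch below is only an illustration: `add_xla_flag` is a hypothetical helper, not part of JAX, and it assumes the usual advice that `XLA_FLAGS` should be in the environment before the first `import jax` in the process.

```python
import os


def add_xla_flag(flag, env=None):
    """Append `flag` to XLA_FLAGS unless already present; return the new value.

    Hypothetical helper for illustration; `env` defaults to os.environ.
    """
    env = os.environ if env is None else env
    flags = env.get("XLA_FLAGS", "").split()
    if flag not in flags:
        flags.append(flag)
    env["XLA_FLAGS"] = " ".join(flags)
    return env["XLA_FLAGS"]


# Assumption: this runs before the first `import jax` in the process.
add_xla_flag("--xla_cpu_use_thunk_runtime=false")
```

Appending rather than overwriting keeps any flags that are already set in the environment.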