I checked that both jax.jit and jax.pmap return a function that is compiled. I added the line:
But I hit an error:
Thanks for the question!

You need to give jax.pmap(f).lower arguments like the ones you'd give to jax.pmap(f) itself. In particular, the argument (0,) can't be mapped over, hence the error. Also, pass in_axes and axis_name to jax.pmap as usual, not to .lower.

Maybe you want something more like this?

    import jax
    import jax.numpy as jnp
    example_x = jnp.zeros(8)  # leading axis is split across devices, so its size must not exceed the device count
    lowered = jax.pmap(_wrapped_step_fun, in_axes=(0,), axis_name=...).lower(example_x)
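For reference, here is a self-contained sketch of the same ahead-of-time pattern that should run on any machine. It assumes a hypothetical step function, step_fun, that uses a collective over an axis named "batch" (both names are illustrative, not from the question), and it sizes the leading axis with jax.local_device_count() so it also works on a single CPU device. The split is the one described above: in_axes and axis_name configure jax.pmap, while .lower only receives example arguments shaped like a real call, and .compile() returns an executable that can be called with arrays of the same shape and dtype.

```python
import jax
import jax.numpy as jnp

# Hypothetical per-device step function (stand-in for _wrapped_step_fun):
# subtracts the mean taken across all devices on the named axis "batch".
def step_fun(x):
    return x - jax.lax.pmean(x, axis_name="batch")

n = jax.local_device_count()

# in_axes and axis_name are configuration for pmap itself...
pmapped = jax.pmap(step_fun, in_axes=(0,), axis_name="batch")

# ...while .lower only takes example arguments, just like a normal call:
# a leading axis of size n, one slice per device.
example_x = jnp.zeros((n, 4))
lowered = pmapped.lower(example_x)
compiled = lowered.compile()

# The compiled executable accepts arrays with the same shape and dtype.
out = compiled(jnp.ones((n, 4)))
print(out.shape)
```

Because every device holds ones, the cross-device mean is 1, so each element of out should be 0.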