unbound axis name error using xmap with odeint and SerialLoop #15807
DanPuzzuoli asked this question in General
This looks like a bug.
Hi, I'm trying to use `xmap` to parallelize the mapping of a function over multiple CPU cores and am running into an error. Below is a small reproduction. I have:

- `f`, whose input and return types are both a single float
- `input_array`, an array of floats of length 16

I want to parallelize the mapping of `f` over `input_array` using the 8 cores. In this case, 2 inputs to `f` will be evaluated on each core; however, I want these to be evaluated using a serial loop. My code attempting this, with a simple version of `f`, is below:

This yields the error:
What's confusing me is that if I change the definition of `f`, e.g. to:

then the code seems to work. Also, if I remove the `SerialLoop(2)` specification, the code also works. However, for the full version of my function, I definitely want to use a serial loop, as it doesn't vectorize well.

I'm not sure whether I'm using `SerialLoop` improperly, or whether there is something about `odeint` that doesn't combine well with `xmap`.

I'm using `jax==0.4.8` and `jaxlib==0.4.7`.