Hi, I'm running into a non-addressable-data error when running distributed on TPUs, and I'm trying to debug where I went wrong. I've seen this error before and have usually been able to track it down from the stack trace, but this time the trace isn't very helpful (it's erroring out inside JAX itself). Is there an easy way to reliably emulate the non-addressable-data condition on a laptop? Debugging there is a lot easier than on TPUs. FWIW, this is the stack trace:
This means that the constant you are closing over is a jax.Array which is not fully addressable.
I would suggest passing it as an argument to the jitted function instead and then it should work.
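A minimal sketch of the suggested change, assuming the problematic constant is an array like the hypothetical `weights` below and the jitted function is a simple elementwise op. The names `weights`, `apply_closure`, and `apply_arg` are illustrative only. The first function captures `weights` from the enclosing scope, so `jax.jit` typically treats it as a constant of the compiled program; the second receives it as an explicit argument, so `jit` handles it as an input whose sharding it can respect.

```python
import jax
import jax.numpy as jnp

# Hypothetical array. In the multi-host TPU case this would be a jax.Array
# sharded across processes, and weights.is_fully_addressable would be False.
weights = jnp.arange(4.0)

# Closure pattern: `weights` is captured from the outer scope rather than
# passed in. This is the pattern the error message points at when the
# captured array is not fully addressable.
@jax.jit
def apply_closure(x):
    return x * weights

# Suggested pattern: pass the array explicitly as an argument to the
# jitted function instead of closing over it.
@jax.jit
def apply_arg(x, w):
    return x * w

x = jnp.ones(4)
# On a single-process machine every array is fully addressable.
print(weights.is_fully_addressable)  # True on a single-process laptop
print(apply_arg(x, weights))
```

On a single-process laptop both versions run, because every array there is fully addressable; the difference only matters once the array spans devices owned by other processes, as in a multi-host TPU setup.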