One issue with the code snippet you have at the top for transferring grads from TPU to CPU memory: device_put is a no-op inside jit (I think there may be an issue about making this less of a gotcha, but I can't find it now). Using device_put with a CPU device is the right idea, though; it just has to happen outside of a compiled function.
The code snippet Skye is referring to was copied from this codebase. So I thought I should open an issue, because it sounds like that code is nonfunctional.
I’m posting this from my phone, but I’ll add more details in a little while.
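Skye's suggestion can be sketched as follows. This is a minimal illustration assuming a toy linear-regression loss; loss_fn, compute_grads, and grads_to_host are hypothetical names, not taken from this codebase. The gradient computation stays inside jax.jit, and the device_put to a CPU device happens in ordinary Python afterwards, where it actually moves the data. On a CPU-only machine the transfer is trivially a no-op, but the structure is the same on TPU.

```python
import jax
import jax.numpy as jnp


def loss_fn(params, x, y):
    # Hypothetical toy model: linear regression with mean squared error.
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


@jax.jit
def compute_grads(params, x, y):
    # Everything in here is traced and compiled for the default backend
    # (TPU if present). No device transfers belong in this function.
    return jax.grad(loss_fn)(params, x, y)


def grads_to_host(params, x, y):
    grads = compute_grads(params, x, y)
    cpu = jax.devices("cpu")[0]
    # device_put runs outside the compiled function, so it really moves
    # every leaf of the gradient pytree into CPU memory.
    return jax.device_put(grads, cpu)


params = {"w": jnp.ones((3,)), "b": jnp.array(0.0)}
x = jnp.ones((4, 3))
y = jnp.zeros((4,))
host_grads = grads_to_host(params, x, y)
```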
I think this code isn't functioning properly. According to Skye, device_put is a no-op inside jit, so I assume each of those calls to device_put() has no effect.
I'm not experienced enough with JAX to know the best way to fix the problem. What do you think the right solution is?
Commenting out @partial(jax.jit, static_argnums=3) seems like the most straightforward "solution," but I don't know anything about JAX's JIT (yet), so I don't know whether that makes sense or what the tradeoffs are.
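Removing the decorator outright would make JAX execute the function op by op instead of as one compiled program. One alternative is to keep the jitted math and move the transfers into an unjitted wrapper. This is a sketch under assumptions: sgd_step, offloaded_step, and the clip_value argument are hypothetical stand-ins for the real function, with clip_value playing the role of the static argument at index 3.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnums=3)
def sgd_step(params, grads, lr, clip_value):
    # clip_value is static (a hashable Python value such as a float or None),
    # so this `if` is resolved at trace time rather than inside the program.
    if clip_value is not None:
        grads = jax.tree_util.tree_map(
            lambda g: jnp.clip(g, -clip_value, clip_value), grads
        )
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)


def offloaded_step(params_host, grads_host, lr, clip_value):
    accel = jax.devices()[0]     # default backend device (TPU if present)
    cpu = jax.devices("cpu")[0]
    # The transfers live here, outside jit, so they actually take effect.
    params_dev = jax.device_put(params_host, accel)
    grads_dev = jax.device_put(grads_host, accel)
    new_params = sgd_step(params_dev, grads_dev, lr, clip_value)
    return jax.device_put(new_params, cpu)
```

Because clip_value is static, each distinct value of it triggers a fresh trace and compile; that is the usual tradeoff of static_argnums.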
@shawwn Hrm, thanks for letting me know. I was initially trying to offload the optimizer parameters to CPU so I could fit a bigger model on the TPU, but the better way to do that now would be to integrate part of Mesh Transformer JAX to perform model-parallel sharding across the 8 TPU devices. I think I'll leave it for now and refactor if/when I get around to it, which would make this entire approach unnecessary.
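For the model-parallel direction mentioned above, here is a minimal sketch that uses the jax.sharding API (Mesh, NamedSharding, PartitionSpec) to split a weight matrix across all local devices. It is an illustration under assumptions, not necessarily how Mesh Transformer JAX does it; the shapes, the forward function, and the "mp" axis name are arbitrary choices.

```python
import numpy as np

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# One mesh axis over every local device: e.g. the 8 TPU devices on a TPU host,
# or a single CPU device when run locally.
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("mp",))

# Shard the output dimension (1024 columns) of the weight matrix across "mp",
# so each device holds only its slice of the parameters.
w = jnp.zeros((512, 1024))
w_sharded = jax.device_put(w, NamedSharding(mesh, P(None, "mp")))


@jax.jit
def forward(x, w):
    # jit propagates the input sharding; the compiler inserts any collectives.
    return x @ w


y = forward(jnp.ones((4, 512)), w_sharded)
```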
Hiya! I was talking with Skye over at jax-ml/jax#2108.