
The code may fail to transfer grads from TPU core memory to CPU host memory #1

Open
shawwn opened this issue Jun 29, 2021 · 2 comments


shawwn commented Jun 29, 2021

Hiya! I was talking with Skye over at jax-ml/jax#2108 (comment):

one issue with the code snippet you have at the top for transferring grads from TPU to CPU memory: device_put is a no-op inside jit (I think there may be an issue about making this less of a gotcha, but I can't find it now). Using device_put with a CPU device is the right idea though, it just has to happen outside of a compiled function.

The code snippet Skye is referring to was copied from this codebase. So I thought I should open an issue, because it sounds like that code is nonfunctional.

I’m posting this from my phone, but I’ll add more details in a little while.


shawwn commented Jun 29, 2021

Some extra details, as promised.

In swarm_layer.py:

@partial(jax.jit, static_argnums=3)
def opt_jit(grad_acc, opt_state, params, optimizer):
    total_grad = jax.tree_map(lambda x: jnp.mean(x, axis=0), grad_acc)
    cpu_device = jax.devices("cpu")[0]
    total_grad = jax.device_put(total_grad, device=cpu_device)
    cpu_params = jax.device_put(jax.tree_map(lambda x: x[0], params), device=cpu_device)

I don't think this code is functioning properly. According to Skye, device_put is a no-op inside jit, so I assume each of those calls to device_put() has no effect and the grads and params never actually leave the TPU.
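
To illustrate the gotcha in isolation, here's a minimal standalone sketch (my own toy example, not code from this repo; the comments are just my reading of Skye's explanation):

import jax
import jax.numpy as jnp

cpu_device = jax.devices("cpu")[0]

@jax.jit
def put_on_cpu(x):
    # Inside jit this call gets traced away, so it doesn't actually move the
    # result into host memory.
    return jax.device_put(x, device=cpu_device)

x = jnp.arange(4.0)
y = put_on_cpu(x)                         # y still lives on the default backend (e.g. a TPU core)
z = jax.device_put(y, device=cpu_device)  # outside of jit, this really does transfer to host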

I'm not experienced enough with Jax to know the best way to fix the problem. What do you think the right solution is?

Commenting out @partial(jax.jit, static_argnums=3) seems like the most straightforward "solution." But I don't know anything about Jax's JIT (yet), so I don't know if that makes any sense, or what the tradeoffs are.
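
Alternatively, here's a rough sketch of what Skye's suggestion might look like applied to opt_jit (untested, and the reduce_grads/opt_step split and names are just mine): keep the gradient reduction jitted, and do the device_put calls outside the compiled function.

import jax
import jax.numpy as jnp

@jax.jit
def reduce_grads(grad_acc):
    # Average the accumulated gradients across the accumulation axis, on-device.
    return jax.tree_map(lambda x: jnp.mean(x, axis=0), grad_acc)

def opt_step(grad_acc, opt_state, params, optimizer):
    cpu_device = jax.devices("cpu")[0]
    total_grad = reduce_grads(grad_acc)
    # Outside of jit, these device_put calls should actually copy the buffers to host memory.
    total_grad = jax.device_put(total_grad, device=cpu_device)
    cpu_params = jax.device_put(jax.tree_map(lambda x: x[0], params), device=cpu_device)
    # ... the rest of opt_jit (the optimizer update) would go here.

No idea whether the extra dispatch and host round-trip are acceptable performance-wise here, though.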

kingoflolz (Owner) commented

@shawwn hrm, thanks for letting me know. I was originally trying to do CPU offload of the optimizer parameters to fit a bigger model on the TPU, but the better way to do that now would be to integrate part of mesh transformer jax to perform model parallel sharding within the 8 TPU devices. I think I'll leave it for now and refactor if/when I get around to it, at which point this entire approach would be unnecessary.
