device_put + jit vs pmap for data parallel training of neural networks #16282
Unanswered
YunfanZhang42 asked this question in Q&A
Hi,

I have a question regarding best practices for data parallel training of complicated neural networks. I am using JAX + Flax to train a ViT-based model in a data parallel/SPMD manner on TPUs. After reading the documentation, I see two ways of performing SPMD training of neural networks:

1. Replicate the model parameters and optimizer state with `flax.jax_utils.replicate`, use `pmean` in the `train_step` function to average the gradients and batch statistics across devices, and finally use `pmap` to parallelize the `train_step` function. This is what google-research/vision_transformer and google-research/scenic do, so I also implemented my code in this way. (A minimal sketch follows this list.)
2. Use `device_put` to shard the inputs and replicate the model parameters and optimizer state, write the `train_step` function as usual, and then `jit` the `train_step`. This method seems more elegant and does not require averaging gradients and batch statistics manually, but I am not sure it is feasible/recommended for a complex neural network. (See the second sketch below.)

I am wondering what the preferred way to perform SPMD training of neural nets will be moving forward. Also, I am wondering whether more documentation on this would make sense. Thanks for your help!
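For concreteness, here is a minimal sketch of option 1 (`pmap` + `replicate` + `pmean`). The tiny MLP, the Adam optimizer, and the batch shapes are illustrative placeholders, not the ViT setup described above.

```python
# Option 1 sketch: replicate state, pmean gradients inside train_step, pmap the step.
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax
from flax import jax_utils
from flax.training import train_state


class MLP(nn.Module):  # placeholder model, not the ViT from the question
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(128)(x)
        x = nn.relu(x)
        return nn.Dense(10)(x)


def create_state(rng):
    model = MLP()
    params = model.init(rng, jnp.ones((1, 784)))["params"]
    return train_state.TrainState.create(
        apply_fn=model.apply, params=params, tx=optax.adam(1e-3)
    )


def train_step(state, batch):
    def loss_fn(params):
        logits = state.apply_fn({"params": params}, batch["image"])
        return optax.softmax_cross_entropy_with_integer_labels(
            logits, batch["label"]
        ).mean()

    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    # Average gradients (and the loss, for logging) across devices.
    grads = jax.lax.pmean(grads, axis_name="batch")
    loss = jax.lax.pmean(loss, axis_name="batch")
    return state.apply_gradients(grads=grads), loss


# axis_name must match the name used in the pmean calls above.
p_train_step = jax.pmap(train_step, axis_name="batch")

state = jax_utils.replicate(create_state(jax.random.PRNGKey(0)))

# Each batch needs a leading device axis: [n_devices, per_device_batch, ...].
n_dev = jax.local_device_count()
batch = {
    "image": jnp.ones((n_dev, 32, 784)),
    "label": jnp.zeros((n_dev, 32), dtype=jnp.int32),
}
state, loss = p_train_step(state, batch)
```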
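And a corresponding sketch of option 2, using `jax.device_put` with the `jax.sharding` API (JAX 0.4+) and a plain `jit`-compiled `train_step`. The model, the mesh axis name `"data"`, and the batch sizes are again assumptions for illustration; the point is that the step function contains no explicit `pmean`, and the compiler inserts the cross-device gradient reduction.

```python
# Option 2 sketch: shard the batch with device_put, replicate the state, jit the step.
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax
from flax.training import train_state
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


class MLP(nn.Module):  # placeholder model, not the ViT from the question
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(128)(x)
        x = nn.relu(x)
        return nn.Dense(10)(x)


def create_state(rng):
    model = MLP()
    params = model.init(rng, jnp.ones((1, 784)))["params"]
    return train_state.TrainState.create(
        apply_fn=model.apply, params=params, tx=optax.adam(1e-3)
    )


# Written exactly like a single-device step: no pmean, no leading device axis.
def train_step(state, batch):
    def loss_fn(params):
        logits = state.apply_fn({"params": params}, batch["image"])
        return optax.softmax_cross_entropy_with_integer_labels(
            logits, batch["label"]
        ).mean()

    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    return state.apply_gradients(grads=grads), loss


jit_train_step = jax.jit(train_step)

# One mesh axis over all local devices; shard data along it, replicate params.
mesh = Mesh(jax.local_devices(), axis_names=("data",))
data_sharding = NamedSharding(mesh, P("data"))
replicated = NamedSharding(mesh, P())

state = jax.device_put(create_state(jax.random.PRNGKey(0)), replicated)

# Global batch; its leading dimension must be divisible by the device count.
batch = jax.device_put(
    {"image": jnp.ones((256, 784)), "label": jnp.zeros((256,), dtype=jnp.int32)},
    data_sharding,
)
state, loss = jit_train_step(state, batch)
```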
Replies: 1 comment

After some digging on previous issues, it seems that the difference is mostly in efficiency.