Parallelisation #16
Hi Jack, thanks for bringing this up. Generally, JAX supports functions suited for parallelisation, namely `jax.pmap` and `jax.vmap`. However, if you are interested in parallelising a message passing neural network (meaning you want to parallelise across graph nodes / atoms), you have to write your code such that it is compatible with these two functions. If you only want to parallelise e.g. along the batch dimension during training, one could maybe make use of the code around line 53 in bb1411b and the `update_fn` at line 48 in bb1411b.

However, so far I have never tried it, so I can't provide any insight on potential pitfalls or best practices, but in case you want to dive deeper into parallelisation, either across atoms or along the batch dimension, I am happy to assist in any way I can.

Best,
Thorben
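For reference, a minimal, untested sketch of what pmapping an update step along the batch dimension could look like is below. The `loss_fn`, `update_fn`, and the toy linear model are illustrative placeholders, not the repository's actual functions; the general pattern is to shard the batch across devices, replicate the parameters, and average gradients with `jax.lax.pmean`:

```python
import jax
import jax.numpy as jnp

# Placeholder loss: mean squared error of a linear model. In practice this
# would be the energy/force loss of the message passing network.
def loss_fn(params, batch):
    inputs, targets = batch
    preds = inputs @ params["w"] + params["b"]
    return jnp.mean((preds - targets) ** 2)

# One gradient-descent step. jax.lax.pmean averages the gradients across
# devices so every replica applies the same parameter update.
def update_fn(params, batch, lr=1e-2):
    grads = jax.grad(loss_fn)(params, batch)
    grads = jax.lax.pmean(grads, axis_name="devices")
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

n_devices = jax.local_device_count()

# Replicate the parameters onto every device and shard the batch along its
# leading axis, so each device sees batch_size // n_devices samples.
params = {"w": jnp.zeros((4, 1)), "b": jnp.zeros((1,))}
params = jax.device_put_replicated(params, jax.local_devices())

inputs = jnp.ones((n_devices, 8, 4))   # (devices, per-device batch, features)
targets = jnp.ones((n_devices, 8, 1))

p_update = jax.pmap(update_fn, axis_name="devices")
params = p_update(params, (inputs, targets))
```

On a machine with a single device this runs unchanged, since `jax.local_device_count()` is then simply 1.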
Thanks Thorben for the detailed explanation. It does sound a bit involved. I'll give it a try with the basic setup first.
Hi,
Thanks for sharing this code. I had a little play with it on a garnet material and it works fairly well on training sets that include finite-temperature displaced structures generated with Phonopy.
So far I have only managed to run this on a single core. Just wondering, how can I run the training on a parallel architecture? Can you provide an example of how to set it up? I only got into neural networks recently, so I'm not quite familiar with JAX and the other libraries used in your code.
Thanks a lot
Jack