Fast transfer of network weights to other processes #10438
-
Hello! I am currently using a custom-built shared memory of pytrees, so all the actor processes can access the weights after the learner process updates them. However, updating the shared memory from the learner takes a noticeable amount of time (approximately 15-25% of a training step). Is there a more efficient way to store/retrieve/transfer the network's weights, given that they are already on the GPU? For example, could the weights be copied directly from the GPU to another part of the GPU (or to RAM) so the other processes can access them from there? Thanks!
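
For context, here is a minimal sketch of the kind of host-side shared-memory handoff described above, assuming a fixed pytree structure and float32 leaves; the function and buffer names (`publish_params`, `read_params`, `params_buf`) are hypothetical and not part of the original setup.

```python
# Hypothetical sketch of the host-side handoff described above: the learner
# copies its weights from the GPU into a shared-memory block that actor
# processes map. Assumes a fixed pytree structure and float32 leaves.
from multiprocessing import shared_memory

import jax
import numpy as np


def publish_params(params, shm_name="params_buf"):
    """Learner side: flatten the pytree and copy it into shared memory."""
    leaves = jax.tree_util.tree_leaves(jax.device_get(params))  # GPU -> host
    flat = np.concatenate([np.ravel(l) for l in leaves]).astype(np.float32)
    shm = shared_memory.SharedMemory(name=shm_name, create=True, size=flat.nbytes)
    np.ndarray(flat.shape, dtype=flat.dtype, buffer=shm.buf)[:] = flat
    return shm  # keep a reference so the block stays alive


def read_params(template, shm_name="params_buf"):
    """Actor side: rebuild the pytree from the shared block."""
    treedef = jax.tree_util.tree_structure(template)
    shapes = [np.shape(l) for l in jax.tree_util.tree_leaves(template)]
    total = sum(int(np.prod(s)) for s in shapes)
    shm = shared_memory.SharedMemory(name=shm_name)
    flat = np.ndarray((total,), dtype=np.float32, buffer=shm.buf)
    leaves, offset = [], 0
    for s in shapes:
        n = int(np.prod(s))
        leaves.append(flat[offset:offset + n].reshape(s).copy())
        offset += n
    return jax.tree_util.tree_unflatten(treedef, leaves)
```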
Replies: 1 comment
-
You may be interested in https://github.com/mpi4jax/mpi4jax
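
A rough sketch of how the weights could be broadcast from the learner to the actors with mpi4jax, assuming one MPI rank per process and the learner at rank 0; with a CUDA-aware MPI build the transfer can stay on the device. The exact return convention (the `(data, token)` tuple) may differ between mpi4jax versions, so treat this as an illustration rather than a definitive recipe.

```python
# Hypothetical sketch: broadcast the learner's weights (rank 0) to all
# actor ranks with mpi4jax. Launch under MPI, e.g.
#   mpirun -n 4 python broadcast_params.py
from mpi4py import MPI
import jax
import jax.numpy as jnp
import mpi4jax

comm = MPI.COMM_WORLD
rank = comm.Get_rank()


def broadcast_params(params):
    """Broadcast every leaf of a parameter pytree from rank 0 to all ranks."""
    def bcast_leaf(leaf):
        # mpi4jax operations work on single arrays and return (data, token).
        new_leaf, _ = mpi4jax.bcast(leaf, 0, comm=comm)
        return new_leaf
    return jax.tree_util.tree_map(bcast_leaf, params)


# Toy pytree standing in for the real network weights.
params = {"dense": {"w": jnp.full((128, 128), float(rank)), "b": jnp.zeros(128)}}

# After this call every rank holds rank 0's weights; with CUDA-aware MPI
# the copy never has to pass through host pickling.
params = broadcast_params(params)
print(f"rank {rank}: w[0, 0] = {params['dense']['w'][0, 0]}")
```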