If you could provide a minimal reproducible example of the kind of operation you're trying to do, this question would be much easier to answer, because the best approach depends on the details. (I also saw what I assume is your Stack Overflow question on the same topic, but someone there had already asked for a minimal reproducible example, so I was waiting for your update.)
Dear all,
I'm working with a JAX implementation that contains a large for loop (more than 1 million iterations). The iterations, each of which calls several JIT-compiled functions, are independent of one another. Does JAX provide anything similar to Numba's `prange` for parallelizing such a loop? I know mpi4jax could distribute the workload across multiple CPUs/GPUs, but my resources are limited to fewer than 100 CPUs and a single GPU. Ideally I would use the GPU's cores (e.g., the 16,384 CUDA cores on an RTX 4090) to execute the loop iterations in parallel. Any insights or suggestions would be greatly appreciated.
Thank you!
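For context on what such a minimal example and one common JAX answer might look like, here is a sketch. It assumes each iteration is a pure function of its own input with no Python control flow on traced values; `step_a`, `step_b`, `one_iteration`, `run_loop`, and `run_batched` are hypothetical stand-ins for the real loop body. Rather than a `prange`-style parallel loop, it uses `jax.vmap` (a technique swapped in here, not something confirmed in this thread) to turn the per-iteration function into one batched computation that XLA can spread across the GPU.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-ins for the "multiple JIT-compiled functions"
# called inside each loop iteration.
@jax.jit
def step_a(x):
    return jnp.sin(x) * 2.0

@jax.jit
def step_b(y):
    return y ** 2 + 1.0

def one_iteration(x):
    # Body of a single, independent iteration.
    return step_b(step_a(x))

def run_loop(xs):
    # Sequential Python loop: one dispatch per iteration (slow for 1M items).
    return jnp.stack([one_iteration(x) for x in xs])

# Vectorized version: vmap maps one_iteration over the leading axis of its
# input, and jit compiles the whole batch into a single XLA computation.
run_batched = jax.jit(jax.vmap(one_iteration))

xs = jnp.linspace(0.0, 1.0, 1000)
ys = run_batched(xs)
```

With around a million iterations, the batched intermediates may not fit in GPU memory at once; one option is to split the inputs into chunks and call `run_batched` on each chunk in turn.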