Hi Zach! Thanks for stopping by. My reading of the DLPack spec found that it had little to say on the topic of synchronization, so I opted to be conservative and over-synchronize via the host thread, which is certainly safe, if pessimistic. If we (the collective users of the DLPack spec) can agree on a better synchronization protocol, it's probably not a big deal to implement it. There's a similar discussion happening elsewhere around the same question. I agree that my natural inclination would also be to include a cudaEvent. I'm open to suggestions!
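For illustration, here is a minimal sketch in PyTorch terms of the difference between the two approaches (the stream and event objects here are just for the example, not any existing DLPack protocol):

```python
import torch

# Some asynchronous GPU work enqueued on the current stream.
t = torch.randn(1024, device="cuda")
t = t * 2 + 1

# Conservative approach: block the host thread until the data is
# ready. Safe, but it is a CPU synchronization point.
torch.cuda.synchronize()

# Stream-ordered alternative: record an event after the producing
# work and make only the consumer's stream wait on it; the CPU is
# free to keep enqueueing work.
ready = torch.cuda.Event()
ready.record(torch.cuda.current_stream())
consumer = torch.cuda.Stream()
consumer.wait_event(ready)
with torch.cuda.stream(consumer):
    out = t.sum()  # ordered after the producer's work, GPU-side only
```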
---
I am prototyping with jax alongside pytorch, using xla to accelerate some fusible kernels inside a larger program. So far, I have code that looks like:
https://gist.github.com/zdevito/dff820e2053b29b1f688ad8db1da5f35
This enables jax (really just the XLA APIs) to operate on pytorch data on the GPU and produce new values without needing any copies.
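For readers without access to the gist, the interop is roughly of this shape (a minimal sketch, not the gist's actual code; exact DLPack APIs vary across jax versions):

```python
import jax
import jax.dlpack
import jax.numpy as jnp
import torch
from torch.utils.dlpack import from_dlpack, to_dlpack

# A PyTorch tensor living on the GPU.
t = torch.arange(8, dtype=torch.float32, device="cuda")

# Zero-copy handoff of the buffer to JAX via a DLPack capsule.
x = jax.dlpack.from_dlpack(to_dlpack(t))

# Run an XLA-compiled kernel on it.
y = jax.jit(lambda v: jnp.sin(v) * 2.0)(x)

# Zero-copy handoff of the result back to PyTorch.
t_out = from_dlpack(jax.dlpack.to_dlpack(y))
```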
However, when `buffer_to_dlpack_managed_tensor` is called, it effectively calls `block_until_ready`, waiting on the CUDA stream to ensure the data is ready. When using jax/xla to generate kernels inside a larger program, this isn't great for performance because it inserts CPU synchronization points into a program that really only needs ordering of kernels on the GPU. A similar scenario exists for translating into xla, where technically we have to sync the PyTorch work before calling the XLA kernels.

I was looking to get your thoughts on ways around this. One way would be for the dlpack-style functions to take or return cudaEvent objects that indicate when the data is ready. Naturally this only works for CUDA, and there might need to be a similar mechanism for CPUs.
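To make the cudaEvent idea concrete, here is a hedged sketch of what such a protocol could look like on the PyTorch side; `export_with_event` and `import_with_event` are hypothetical helpers for illustration, not an existing API:

```python
import torch
from torch.utils.dlpack import from_dlpack, to_dlpack

def export_with_event(t: torch.Tensor):
    """Hypothetical producer side: return a DLPack capsule plus a CUDA
    event recorded on the producing stream, instead of blocking the host."""
    ready = torch.cuda.Event()
    ready.record(torch.cuda.current_stream())  # fires once t's data is ready
    return to_dlpack(t), ready

def import_with_event(capsule, ready: torch.cuda.Event,
                      consumer: torch.cuda.Stream) -> torch.Tensor:
    """Hypothetical consumer side: order the consumer's stream after the
    producer's event. Pure GPU-side ordering; no CPU synchronization."""
    consumer.wait_event(ready)
    return from_dlpack(capsule)
```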