Calling PyTorch Functions from JAX? #15170
adam-hartshorne asked this question in Q&A
I stumbled across this code for calling JAX functions from PyTorch:
https://gist.github.com/mattjj/e8b51074fed081d765d2f3ff90edf0e9
Can you effectively just reverse the procedure, using the DLPack operations to call PyTorch functions from JAX? Or are there more complications? If not, would it be sensible to add torch2jax to the main API, in the same way that there is jax2tf?
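A minimal sketch of the reverse direction, assuming recent versions of JAX and PyTorch running on CPU. Outside `jax.jit`, a concrete JAX array can be handed to PyTorch through the DLPack protocol (`torch.from_dlpack`, and `jax.dlpack.from_dlpack` for the way back). Inside a jitted or differentiated function, however, the inputs are tracers, so JAX cannot run PyTorch code on them directly. This sketch therefore wraps the PyTorch call in `jax.pure_callback`, which runs it on the host, and supplies gradients with `jax.custom_vjp` backed by `torch.autograd.grad`. For simplicity it copies through NumPy rather than DLPack. `torch_fn` and `torch_fn_in_jax` are hypothetical names, not an existing torch2jax API.

```python
import numpy as np
import jax
import jax.numpy as jnp
import torch


def torch_fn(t: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for the PyTorch code you actually want to call.
    return 2.0 * torch.sin(t)


def _torch_forward(x):
    # Runs on the host: copy the incoming array into a fresh torch tensor.
    t = torch.tensor(np.array(x))
    return torch_fn(t).detach().numpy()


def _torch_backward(x, g):
    # Recompute the forward pass with autograd enabled, then pull back g.
    t = torch.tensor(np.array(x), requires_grad=True)
    out = torch_fn(t)
    (grad,) = torch.autograd.grad(out, t, grad_outputs=torch.tensor(np.array(g)))
    return grad.numpy()


@jax.custom_vjp
def torch_fn_in_jax(x):
    # Assumes the output has the same shape and dtype as the input.
    out_spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(_torch_forward, out_spec, x)


def _fwd(x):
    # Save the input as the residual needed by the backward pass.
    return torch_fn_in_jax(x), x


def _bwd(x, g):
    # The cotangent has the same shape and dtype as the input x.
    grad_spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return (jax.pure_callback(_torch_backward, grad_spec, x, g),)


torch_fn_in_jax.defvjp(_fwd, _bwd)

x = jnp.array([0.0, 1.0, 2.0])
print(jax.jit(torch_fn_in_jax)(x))
print(jax.grad(lambda v: torch_fn_in_jax(v).sum())(x))
```

Each call round-trips through host memory, so on GPU this sketch gives up the zero-copy benefit that DLPack offers in eager mode. It also handles only a single array argument, and batching it with `jax.vmap` would need extra care.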