Replies: 3 comments 1 reply
-
Hi @AuguB. An Op like this would be implemented as a ScalarOp (see `pytensor/scalar/math.py`) and then "vectorized" via `Elemwise`. It's fine for the gradients to use a numerical approximation; we do this for some functions like `GammaInc`, which you can find in the same file.

In this case SciPy also seems to have a helper to compute the derivative: https://docs.scipy.org/doc/scipy/reference/generated/scipy.special.kvp.html#scipy.special.kvp. We can try to wrap it as its own Op, or implement it directly in PyTensor for support in other backends.

For Numba support of `kv` you'll need a dispatch for the ScalarOp; there's not much we can do there automatically. In some cases we try to access SciPy's C functions directly, and you can check whether the same approach would work for this Op. Contributions are welcome, of course!
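As a quick illustration of the `scipy.special.kvp` helper mentioned above (the values of `v` and `z` here are just illustrative), its first derivative agrees with the standard Bessel recurrence dK_v(z)/dz = -(K_{v-1}(z) + K_{v+1}(z)) / 2:

```python
from scipy.special import kv, kvp

# kvp(v, z, n=1) returns the n-th derivative of K_v(z) with respect to z.
v, z = 1.5, 2.0
d1 = kvp(v, z)  # first derivative wrt the argument z

# Consistency check against the standard recurrence relation.
expected = -0.5 * (kv(v - 1, z) + kv(v + 1, z))
assert abs(d1 - expected) < 1e-10
```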
-
I see now that you were suggesting using finite differences. That is probably not good enough for us (but may be fine as your custom Op). By the way, the gradient wrt x is pretty simple, just -(K_{v-1}(x) + K_{v+1}(x)) / 2.

For the gradient wrt v, I found this paper and accompanying JAX/TFP code that you may find interesting: https://github.com/tk2lab/logbesselk/tree/main. It's for the log of Kv, not Kv directly, but it's straightforward to translate if you are not interested in the log representation itself.

I'm opening a PR to add Kv and the gradient wrt x, as it showed up in another project I was working on. We can chat about expanding it to cover the gradient wrt v as well.
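The closed-form gradient wrt x mentioned above can be verified numerically against a central finite difference (the particular `v`, `x`, and step size here are arbitrary choices for the check):

```python
import numpy as np
from scipy.special import kv

# Check dK_v(x)/dx = -(K_{v-1}(x) + K_{v+1}(x)) / 2 against a central difference.
v, x, h = 2.5, 1.3, 1e-6
closed_form = -0.5 * (kv(v - 1, x) + kv(v + 1, x))
finite_diff = (kv(v, x + h) - kv(v, x - h)) / (2 * h)
assert np.isclose(closed_form, finite_diff, rtol=1e-5)
```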
-
Hi @ricardoV94, thanks for the replies; they have been very helpful. I've currently implemented Kv (and its derivative) as `BinaryScalarOp`s, and I'm happy with how it looks and runs for my case for now (although speedups are always welcome).

Having Kv and the gradient wrt v available in PyMC sounds great, but I don't think I understand the different backends well enough to know what the best (and most general) solution is. If I understand correctly, the repo you linked provides solutions for JAX and TF, but not for Numba or pure Python. I have looked into providing C code for the Kv Op directly, because it is fast and portable, but I am not sure whether Numba, JAX, and TF support comes automatically (and my experience with C is meager at best). I would be up for working on this; what do you think is the best way to proceed?
-
Hi all, I am working on a custom Op that I have to use in a project, and I may want to submit a PR for this function (`scipy.special.kv`) and perhaps others in that module if it goes well, but I want to make sure that I am doing it correctly first.
Here's my Op as it is now:
Specifically, I need the gradient with respect to the first parameter, p, and currently I am using a numerical approximation, because no analytical expression exists for it (as far as I know). The performance of this Op in a PyMC setting is not great, so I would like to speed it up if possible.
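The numerical-approximation idea described above can be sketched as follows. This is a hypothetical illustration (the function name and step size are mine, not the actual Op): the value of K_p(x) plus a central finite-difference estimate of dK_p/dp.

```python
import numpy as np
from scipy.special import kv

def kv_with_grad_p(p, x, eps=1e-6):
    """Return K_p(x) and a central-difference estimate of dK_p/dp.

    Hypothetical sketch: eps is a fixed step size; a production Op
    would likely scale it with p or use a higher-order scheme.
    """
    value = kv(p, x)
    grad_p = (kv(p + eps, x) - kv(p - eps, x)) / (2 * eps)
    return value, grad_p

value, grad_p = kv_with_grad_p(1.5, 2.0)
```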
My questions are the following:
Any help is appreciated!