Hello, thanks a lot for your work on this library! I am trying to implement this function I wrote:

    def drift_correction(diffusion, x):
        return jnp.einsum('jk,ikj->i', diffusion(x), jax.jacfwd(diffusion)(x))

and wanted to know if it was possible to implement it using `jax.jvp`. WDYT?

Yvann
Replies: 1 comment
Hi, I think your approach is fair. As you mentioned, we can use `jax.jvp` to avoid the need for storing the matrix `jax.jacfwd(diffusion)(x)`, which would otherwise require space of order $O(n \times m)$. Below is an example for your case (see this in Diffrax):

    def drift_correction(diffusion, x):
        b = diffusion(x)
        # jax.jvp requires the tangent (here b) to have the same shape as x.
        _, corrected = jax.jvp(diffusion, (x,), (b,))
        return corrected

If you're interested in SDEs, you may want to check out Diffrax, which is designed for solving differential equations in JAX.
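If `diffusion(x)` is matrix-valued with shape (n, m), as the einsum in the question implies, the tangent passed to `jax.jvp` still has to match the shape of `x`, so one option is to take one JVP per column of the diffusion. The following is a minimal sketch under that assumption, not the Diffrax implementation; `drift_correction_columns` and `example_diffusion` are hypothetical names introduced here for illustration, and the comparison against the `jacfwd` version is only a sanity check.

```python
import jax
import jax.numpy as jnp


def drift_correction_reference(diffusion, x):
    # Formulation from the question: materializes the (n, m, n) Jacobian.
    return jnp.einsum('jk,ikj->i', diffusion(x), jax.jacfwd(diffusion)(x))


def drift_correction_columns(diffusion, x):
    # Hypothetical helper: assumes diffusion(x) has shape (n, m) with n == x.shape[0].
    b = diffusion(x)
    total = jnp.zeros(b.shape[0], dtype=b.dtype)
    for k in range(b.shape[1]):
        # Directional derivative of the whole diffusion matrix along column k of b.
        _, db = jax.jvp(diffusion, (x,), (b[:, k],))
        # Only column k of that derivative contributes to the einsum.
        total = total + db[:, k]
    return total


def example_diffusion(x):
    # Hypothetical (n, 2) diffusion used only to exercise the two functions.
    return jnp.stack([jnp.sin(x) * x[0], jnp.cos(x) + x**2], axis=1)


x = jnp.array([0.3, -1.2, 0.7])
reference = drift_correction_reference(example_diffusion, x)
columns = drift_correction_columns(example_diffusion, x)
print(jnp.allclose(reference, columns, atol=1e-5))
```

This costs m JVPs and never forms the full Jacobian. When the diffusion has a single column returned as a vector of the same shape as `x`, the sum has one term and the loop reduces to the one-JVP version in the reply.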