
Efficient implementation of Stratonovich drift correction #16449

Answered by anh-tong
ylefay asked this question in Q&A

Hi, I think your approach is reasonable. As you mentioned, we can use `jax.jvp` to avoid storing the matrix `jax.jacfwd(diffusion)(x)`, which would otherwise require memory of order $O(n \times m)$. Below is an example for your case (see this in Diffrax):

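```python
import jax

def drift_correction(diffusion, x):
    b = diffusion(x)  # must have the same shape as x to serve as a tangent
    # Computes J_b(x) @ b without materializing the Jacobian J_b(x)
    _, corrected = jax.jvp(diffusion, (x, ), (b, ))
    return corrected  # for scalar noise, the Itô-Stratonovich correction is 0.5 * corrected
```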
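The snippet above assumes `diffusion(x)` has the same shape as `x` (a single Brownian motion). If the diffusion instead returns an $n \times m$ matrix with one column $g_j$ per Brownian motion, the same trick extends column by column: the Itô-Stratonovich drift correction is $\tfrac{1}{2}\sum_{j=1}^{m} (\partial g_j/\partial x)\, g_j$, and each term is one `jax.jvp` call with tangent $g_j$. Below is a minimal sketch under that assumption; `stratonovich_correction` is a hypothetical name, the shape convention `(n,) -> (n, m)` is assumed rather than taken from the question, and this is not Diffrax's implementation.

```python
import jax
import jax.numpy as jnp


def stratonovich_correction(diffusion, x):
    """Return 0.5 * sum_j J_{g_j}(x) @ g_j(x), where g_j = diffusion(x)[:, j].

    Assumption (hypothetical convention): diffusion maps x of shape (n,)
    to a matrix of shape (n, m), one column per Brownian motion.
    """
    g = diffusion(x)  # shape (n, m)

    def column_term(args):
        v, j = args  # v = g[:, j], shape (n,); j = column index
        # Directional derivative of the full (n, m) diffusion along v,
        # computed with one forward-mode pass (no Jacobian is formed).
        _, tangent = jax.jvp(diffusion, (x,), (v,))
        return tangent[:, j]  # sum_k d g_{ij}/d x_k * g_{kj}, shape (n,)

    # Process one noise column at a time: result has shape (m, n).
    terms = jax.lax.map(column_term, (g.T, jnp.arange(g.shape[1])))
    return 0.5 * terms.sum(axis=0)
```

`jax.lax.map` evaluates the $m$ JVPs one after another, so only one $(n, m)$ tangent is alive at a time. Swapping it for `jax.vmap` runs them in parallel at the cost of an $(m, n, m)$ intermediate. As a sanity check, for geometric Brownian motion with `diffusion = lambda x: sigma * x[:, None]` the result should equal `0.5 * sigma**2 * x`.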

If you're interested in SDEs, you may want to check out Diffrax, a library for solving differential equations in JAX.

Replies: 1 comment

Answer selected by ylefay