Differentiating through simulations / reference request #24448
statsybanksy asked this question in Q&A (Unanswered)
Replies: 1 comment
-
Hi - thanks for the question! I'm not sure exactly what answer you're looking for here – differentiation through a simulation is no different than any other autodiff setting. You're computing the partial derivative of each operation's output with respect to its inputs. For example, `samples = mu + sigma * random.normal(key, (num_samples,))` is essentially this with respect to autodiff:

samples_i = mu + sigma * eps_i

where the eps_i are the drawn values, which are constants that do not depend on mu or sigma, and the autodiff rule for this expression gives d(samples_i)/d(mu) = 1 and d(samples_i)/d(sigma) = eps_i. Other expressions in the function are handled similarly, and then the chain rule is used to propagate the gradients through the sequence of operations to get the autodiff result. Does that make sense?
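As a minimal sketch (not from the original reply, and assuming we simply average the samples), this is what that looks like with `jax.grad`:

```python
import jax
import jax.numpy as jnp
from jax import random

def mc_mean(mu, sigma, key, num_samples=1000):
    # eps is fixed for a given key; it does not depend on mu or sigma,
    # so autodiff treats it as a constant.
    eps = random.normal(key, (num_samples,))
    samples = mu + sigma * eps  # d(samples)/d(mu) = 1, d(samples)/d(sigma) = eps
    return jnp.mean(samples)

key = random.PRNGKey(0)
dmu, dsigma = jax.grad(mc_mean, argnums=(0, 1))(2.0, 0.5, key)
# dmu is exactly 1.0; dsigma equals mean(eps), which is near 0 for large num_samples.
```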
-
I have been a JAX user for a long time now, but only recently did I realise that it allows differentiating through Monte Carlo simulations.
Having thought about how that's possible and then googled, I am struggling to find intuitive, engineering, or mathematical explanations of how it actually works. [For background, I do understand classic automatic differentiation.]
I was wondering if you have any intuition or resources you could share on the topic. I appreciate I probably won't understand (and may not even need to understand) the nitty-gritty details of how it works, but I would be especially keen to understand where the limitations of the approach are.
To make it specific, below is a simple example of a Monte Carlo simulation in JAX I've coded up.
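(The snippet below is an illustrative sketch of such a simulation, not necessarily the exact code from the question: a Monte Carlo estimate of E[X²] for X ~ Normal(mu, sigma), written with the reparameterization x = mu + sigma * eps so that `jax.grad` can differentiate the estimate with respect to mu and sigma.)

```python
# Illustrative sketch (stand-in for the original snippet).
import jax
import jax.numpy as jnp
from jax import random

def estimate_expectation(params, key, num_samples=10_000):
    """Monte Carlo estimate of E[X**2] for X ~ Normal(mu, sigma),
    using the reparameterization x = mu + sigma * eps so the estimate
    is a differentiable function of (mu, sigma)."""
    mu, sigma = params
    eps = random.normal(key, (num_samples,))
    x = mu + sigma * eps
    return jnp.mean(x ** 2)

key = random.PRNGKey(42)
params = jnp.array([1.0, 2.0])  # (mu, sigma)
value = estimate_expectation(params, key)
grads = jax.grad(estimate_expectation)(params, key)
# Analytically E[X**2] = mu**2 + sigma**2, so the gradient should be
# approximately (2*mu, 2*sigma) = (2.0, 4.0), up to Monte Carlo noise.
```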