-
Notifications
You must be signed in to change notification settings - Fork 2.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support Hessian of gamma-distributed samples #21432
base: main
Are you sure you want to change the base?
Conversation
f5cbd12
to
38c61c0
Compare
from jax._src import dtypes | ||
from jax._src.interpreters import ad | ||
from jax._src.interpreters import mlir | ||
from jax._src.lib.mlir.dialects import chlo | ||
from jax._src.typing import Array, ArrayLike | ||
|
||
def _while_loop_scan(cond_fun, body_fun, init_val, max_iter): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should I move this to api_util?
38c61c0
to
612fdda
Compare
@jakevdp Is there any interest in merging something like this? It's really hard to do the same thing in client code without copying thousands of lines of Jax code and then keeping them updated (which is extremely time-consuming). Or perhaps you have an alternative idea for how I can accomplish this? |
612fdda
to
2386fcd
Compare
2386fcd
to
ab2f665
Compare
Hi @NeilGirdhar, sorry for the delay on this. I'm hoping @froystig or @mattjj can weigh-in here. I'm a bit concerned about the approach here, because the bounded while-loop might have memory impacts when computing the second derivative for large numbers of samples: my understanding is it would require statically allocating a buffer 256 times larger than the buffer of samples you're generating. It may be that Perhaps there are ways to compute this second derivative more directly, without differentiating through the implementation of the first derivative? |
Absolutely no problem about the delay. Congratulations on completing the Jax implementation of the Array API so quickly!
Your concerns make perfect sense to me. I'm going to let the others weight in, but in the interest of eliminating back-and-forth, I'll make some comments and suggestions if that's okay 😄 First of all, the reason for my frequent force-pushes (sorry if that was noisy?) is because this feature is so important to me that I am now pointing my repo to my PR branch rather than to Jax directly. This way I have access to this feature. I tried lifting the gamma-random-generation out of Jax, but it's a large mass of code that has changed over the last year, so that was too much work to keep updated. As for the time and space concerns, I want to first remind readers that only the second derivative code is slow. I agree with your point that this could be a footgun. The ideal approach is probably to replace the for loop with solving a fixed point. (So, it would go back to being just a while loop in all cases.) Is there any precedent to fixed point optimization in Jax's source code? I know that JaxOpt and tjax have fixed point solvers. It would be some work, but should be possible to recast the algorithm slightly so that it fits the fixed point interface. An alternative approach would be to tune the 256 constant. I think it's about ten times too big. I didn't think about speed or memory, and I just wanted to get it working. Tuning this constant might solve the speed problem, but the fixed point solution makes the memory cost constant, I think. What do you think? |
ab2f665
to
ae6d6e8
Compare
ae6d6e8
to
5e60da4
Compare
c4245a1
to
f8872d8
Compare
3be8213
to
184c189
Compare
184c189
to
3120f5d
Compare
3120f5d
to
a8f66d1
Compare
Fixes #16076