differentiable fori_loop? #9699
According to The Sharp Bits, fori_loop is not reverse-mode differentiable. However, this actually works:

    import jax
    import jax.numpy as jnp

    @jax.jit
    def bar(x):
        def _body_fun(i, val):
            return val + x[i]**2
        return jax.lax.fori_loop(0, x.shape[0], _body_fun, 0.0)

    x = jnp.arange(4, dtype=jnp.float32)
    print(jax.grad(bar)(x))  # prints [0. 2. 4. 6.]

To my understanding, jax.grad uses reverse-mode differentiation. Thank you in advance for the explanation.
Replies: 1 comment 1 reply
Thanks for the question! That part of the doc should be updated: while it's true that in general fori_loop is not reverse-mode differentiable, in the special case of concrete start and end-points, we lower fori_loop to scan to allow for reverse-mode differentiation. Here's where it happens in the source: https://github.com/google/jax/blob/d5a1c64d135ae8519c61e15a2f32a75d8de36ab3/jax/_src/lax/control_flow.py#L205-L217
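To make the distinction concrete, here is a minimal sketch contrasting the two cases. The helper names (sum_squares, grad_traced_bound, jacfwd_traced_bound) are made up for illustration and are not part of JAX. It assumes that passing the upper bound as an ordinary (non-static) jit argument makes it a traced value, so fori_loop can no longer take the scan path, and that the resulting reverse-mode failure surfaces as a ValueError.

```python
import jax
import jax.numpy as jnp


def sum_squares(x, n):
    # Hypothetical helper: sums x[i]**2 for i in [0, n).
    def body(i, acc):
        return acc + x[i] ** 2
    return jax.lax.fori_loop(0, n, body, 0.0)


x = jnp.arange(4, dtype=jnp.float32)

# Case 1: the trip count x.shape[0] is a concrete Python int,
# so reverse-mode differentiation (jax.grad) works.
grad_concrete = jax.grad(lambda v: sum_squares(v, v.shape[0]))(x)
print(grad_concrete)


# Case 2: n is a regular jit argument, hence a traced value inside the
# function. Reverse mode is expected to fail here.
@jax.jit
def grad_traced_bound(x, n):
    return jax.grad(sum_squares)(x, n)


try:
    grad_traced_bound(x, 4)
    reverse_mode_failed = False
except ValueError as err:
    reverse_mode_failed = True
    print("reverse mode failed:", err)


# Forward mode does not need a known trip count, so it still works
# with a traced bound.
@jax.jit
def jacfwd_traced_bound(x, n):
    return jax.jacfwd(sum_squares)(x, n)


fwd_grad = jacfwd_traced_bound(x, 4)
print(fwd_grad)
```

If the bound really has to be dynamic and you still need jax.grad, one option is to mark it static with static_argnums on jax.jit, at the cost of a recompile for each distinct value.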