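Following up on @jakevdp's comment here, I have a case where I'm wondering of

Concretely, I have something like

```python
import jax
import jax.numpy as jnp

def f(X, k):
    # state1 starts empty; update_state may change its shape on each call
    state1 = jnp.empty((0, 0))
    state2 = X
    # k is static, so this Python loop is unrolled k times when traced
    for _ in range(k):
        state1, state2 = update_state(state1, state2)
    return state1, state2

f_jit = jax.jit(f, static_argnames=("k",))
```

where

And in particular, this means that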
Replies: 1 comment
In general, I think the best solution here would be to change `update_state` so that it returns padded arrays of the same shape as the input, and then use a `fori_loop`. Without more information about what `update_state` does, though, it's hard to say how feasible that is.

In other words, I think this is a case where the details are important.
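To make the padding idea concrete, here is a minimal sketch. The real `update_state` isn't shown in this thread, so `update_state_padded` below is a purely hypothetical stand-in: it assumes each step adds one row to `state1` (the column means of `state2`) and leaves the shape of `state2` unchanged. Under that assumption, `state1` can be preallocated at its final size and filled in row by row, so the loop carry keeps a fixed shape and `jax.lax.fori_loop` applies.

```python
from functools import partial

import jax
import jax.numpy as jnp


def update_state_padded(i, state1, state2):
    # Hypothetical update (not the original update_state): write the column
    # means of state2 into row i of the fixed-size buffer, then halve state2.
    # Assumes state2 has a floating-point dtype so the carry dtype is stable.
    state1 = state1.at[i].set(state2.mean(axis=0))
    state2 = 0.5 * state2
    return state1, state2


@partial(jax.jit, static_argnames=("k",))
def f(X, k):
    # Preallocate state1 at its maximum size; rows not yet written are padding.
    # k stays static here only because it fixes the buffer shape.
    state1 = jnp.zeros((k, X.shape[1]), dtype=X.dtype)

    def body(i, carry):
        s1, s2 = carry
        return update_state_padded(i, s1, s2)

    # The carry has the same shapes and dtypes on every iteration, so the loop
    # body is traced once instead of being unrolled k times.
    return jax.lax.fori_loop(0, k, body, (state1, X))
```

Calling `f(jnp.ones((4, 3)), k=3)` returns a `(3, 3)` buffer for `state1` and a `(4, 3)` array for `state2`. Whether a similar rewrite works for the real `update_state` depends on whether its output sizes can be bounded ahead of time.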