Replies: 2 comments 2 replies
-
Have you seen |
2 replies
-
You can synthesize something like what you want by doing a while-loop of scans. You'll need an iteration number in your carry that records when the while-loop condition has been achieved, to (1) stop updating the carry, and (2) tell you how much of the output should be concatenated. You can then concatenate the scan results from iterations of the external while loop.
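A minimal sketch of that idea, under some assumptions: the state is a single float32 array, the outer loop runs in Python rather than under jit, and the names scan_until, step_fn, cond_fn, chunk_size and max_chunks are hypothetical helpers, not part of JAX. Each chunk is a fixed-length jax.lax.scan whose carry holds the state plus a count of valid steps; once the condition fails, the carry is frozen and the count stops growing.

```python
import jax
import jax.numpy as jnp
from jax import lax


def scan_until(step_fn, cond_fn, init, chunk_size=8, max_chunks=100):
    """Apply step_fn while cond_fn(state) holds, recording each new state.

    Hypothetical helper: runs fixed-size scans from a Python-level loop.
    """

    def body(carry, _):
        state, n_valid = carry
        keep_going = jnp.asarray(cond_fn(state))
        new_state = step_fn(state)
        # (1) Stop updating the carry once the condition has failed.
        state = jnp.where(keep_going, new_state, state)
        # (2) Count how many outputs of this chunk are valid.
        n_valid = n_valid + keep_going.astype(jnp.int32)
        return (state, n_valid), state

    @jax.jit
    def run_chunk(state):
        return lax.scan(body, (state, jnp.int32(0)), None, length=chunk_size)

    state = jnp.asarray(init, dtype=jnp.float32)
    chunks = []
    for _ in range(max_chunks):  # external loop; max_chunks is a safety bound
        (state, n_valid), ys = run_chunk(state)
        n_valid = int(n_valid)
        chunks.append(ys[:n_valid])  # keep only the valid prefix
        if n_valid < chunk_size:  # condition failed inside this chunk
            break
    return state, jnp.concatenate(chunks)


# Usage: keep doubling while the value is below 100.
final, history = scan_until(lambda x: x * 2.0, lambda x: x < 100.0, 1.0, chunk_size=4)
```

Here the condition is checked before each step, as in jax.lax.while_loop, so the recorded history ends with the first state that fails the condition. The trade-off is wasted work in the last chunk: up to chunk_size - 1 steps are computed and then discarded.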
0 replies
-
Hi, I'm wondering if there are any plans (or if it's feasible?) to implement jax.lax.scan for while loops. So far, I've been using the jax.experimental module and it works great, but it would be helpful to have an equivalent to scan in cases where we only need an initial condition and a function that gets called over and over until a condition is met.