How does JAX lower jaxpr control flow into XLA? #10990
-
How does `jaxpr_subcomp` handle control flow primitives such as `cond`, `while`, and `scan` when their equations are lowered into XLA HLO?
-
Control flow primitives are lowered just like any other primitive, in the sense that at lowering time we are given 1) the MHLO arguments to the primitive and 2) the primitive's parameters, which include the body jaxpr (multiple jaxprs in the case of `while`/`cond`).
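To make point 2) concrete, one can trace small functions with `jax.make_jaxpr` and look at the parameters of the resulting control flow equation. This is a minimal sketch assuming a recent JAX release: `f_cond`, `f_while`, `f_scan`, and `control_flow_eqn` are hypothetical helpers, and the parameter names used (`branches`, `cond_jaxpr`, `body_jaxpr`, `jaxpr`) are the ones these primitives carry in recent versions, which may differ in other releases.

```python
import jax
import jax.numpy as jnp
from jax import lax

def f_cond(x):
    # The predicate depends on a traced value, so `cond` stays a primitive.
    return lax.cond(x > 0, lambda v: v + 1.0, lambda v: v - 1.0, x)

def f_while(x):
    return lax.while_loop(lambda v: v < 10.0, lambda v: v * 2.0, x)

def f_scan(xs):
    def step(carry, x):
        return carry + x, carry
    return lax.scan(step, jnp.float32(0.0), xs)

def control_flow_eqn(fn, *args):
    """Hypothetical helper: trace `fn`, return its first cond/while/scan equation."""
    jaxpr = jax.make_jaxpr(fn)(*args).jaxpr
    return next(e for e in jaxpr.eqns
                if e.primitive.name in ("cond", "while", "scan"))

x = jnp.float32(1.0)
xs = jnp.arange(3.0, dtype=jnp.float32)

cond_eqn = control_flow_eqn(f_cond, x)
while_eqn = control_flow_eqn(f_while, x)
scan_eqn = control_flow_eqn(f_scan, xs)

# The bodies travel as jaxpr-valued parameters of the equation; these are
# what the lowering rule later hands to `mlir.jaxpr_subcomp`.
print(len(cond_eqn.params["branches"]))  # one jaxpr per branch
print(while_eqn.params["cond_jaxpr"])
print(while_eqn.params["body_jaxpr"])
print(scan_eqn.params["jaxpr"])
```

Printing `jax.make_jaxpr(f_while)(x)` directly shows the same nested jaxprs inline, which is often the quickest way to see what a lowering rule will receive.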
At a high level, we recursively call `mlir.jaxpr_subcomp` on these body jaxprs (see here for an example) while inside MLIR control flow blocks.

Does that answer your question, or did you want more specific answers for each primitive?
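To see where those recursive `mlir.jaxpr_subcomp` calls end up, one can lower the same kind of functions and look for control flow ops in the emitted IR. This is a sketch assuming a recent JAX, where `Lowered.as_text()` returns StableHLO (older releases emitted MHLO, so the regex accepts both dialects); `OP_RE` and `control_flow_ops` are hypothetical helpers. In the versions I am aware of, `cond` lowers to a `case` op and `while_loop` to a `while` op, with each body's ops nested in the op's regions, while `scan` is typically lowered by way of a while loop rather than a dedicated op.

```python
import re

import jax
import jax.numpy as jnp
from jax import lax

# Matches control flow op names in either the StableHLO or the older MHLO dialect.
OP_RE = re.compile(r"\b(?:stablehlo|mhlo)\.(while|case)\b")

def control_flow_ops(fn, *args):
    """Hypothetical helper: lower `fn` via jax.jit and list control flow op kinds."""
    text = jax.jit(fn).lower(*args).as_text()
    return sorted(set(OP_RE.findall(text)))

def f_cond(x):
    return lax.cond(x > 0, lambda v: v + 1.0, lambda v: v - 1.0, x)

def f_while(x):
    return lax.while_loop(lambda v: v < 10.0, lambda v: v * 2.0, x)

def f_scan(xs):
    def step(carry, x):
        return carry + x, carry
    return lax.scan(step, jnp.float32(0.0), xs)

x = jnp.float32(1.0)
xs = jnp.arange(3.0, dtype=jnp.float32)

print(control_flow_ops(f_cond, x))
print(control_flow_ops(f_while, x))
print(control_flow_ops(f_scan, xs))

# The full module shows the body ops (e.g. the multiply from the while body)
# nested inside the regions of the control flow op.
print(jax.jit(f_while).lower(x).as_text())
```

Matching on the textual IR keeps the sketch independent of JAX's internal MLIR bindings, at the cost of being a rough check rather than a structural one.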