How does JAX lower jaxpr control flow into XLA? #10990

Answered by sharadmv
adamantboy asked this question in General

Control flow primitives are lowered just like any other primitive, in the sense that at lowering time we are given 1) the MHLO values for the primitive's arguments and 2) the primitive's parameters, which include the body jaxpr (or several jaxprs, in the case of `while`/`cond`).
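To make this concrete, here is a toy sketch in plain Python. It is not JAX's real code: `Jaxpr`, `Eqn`, `RULES`, `add_rule`, `while_rule` and `lower_eqn` are all hypothetical names, and the strings it produces only imitate MHLO. It assumes each lowering rule receives the already-lowered operands plus the equation's params as keyword arguments, so a `while` equation is dispatched exactly like an `add`; the only difference is that its params happen to contain jaxprs. The param names `cond_jaxpr`/`body_jaxpr` mirror the ones shown for `while` in a printed jaxpr.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Callable


# Toy stand-ins for a jaxpr and its equations (hypothetical, not JAX's classes).
@dataclass
class Jaxpr:
    invars: list[str]
    eqns: list[Eqn]
    outvars: list[str]


@dataclass
class Eqn:
    primitive: str
    invars: list[str]
    outvars: list[str]
    params: dict[str, Any] = field(default_factory=dict)


# Every toy rule has the same shape: lowered operands in, params as keyword
# arguments, lowered results out. The caller never special-cases control flow.
RULES: dict[str, Callable[..., list[str]]] = {}


def add_rule(operands: list[str]) -> list[str]:
    x, y = operands
    return [f"mhlo.add {x}, {y}"]


def while_rule(operands: list[str], *, cond_jaxpr: Jaxpr, body_jaxpr: Jaxpr) -> list[str]:
    # The sub-jaxprs arrive as ordinary params. A real rule would lower them
    # into the regions of a while op; this toy only reports that it got them.
    return [
        f"mhlo.while({', '.join(operands)}) "
        f"cond<{len(cond_jaxpr.eqns)} eqns> body<{len(body_jaxpr.eqns)} eqns>"
    ]


RULES["add"] = add_rule
RULES["while"] = while_rule


def lower_eqn(eqn: Eqn, env: dict[str, str]) -> list[str]:
    """Fetch lowered operands, call the rule with the eqn's params, bind results."""
    operands = [env[v] for v in eqn.invars]
    results = RULES[eqn.primitive](operands, **eqn.params)
    env.update(zip(eqn.outvars, results))
    return results


# Usage: both equations go through the same dispatch path.
loop_cond = Jaxpr(["i", "n"], [Eqn("lt", ["i", "n"], ["p"])], ["p"])
loop_body = Jaxpr(["i", "n"], [Eqn("add", ["i", "i"], ["j"])], ["j", "n"])
env = {"x": "%arg0", "n": "%arg1"}
print(lower_eqn(Eqn("add", ["x", "x"], ["y"]), env))
# ['mhlo.add %arg0, %arg0']
print(lower_eqn(Eqn("while", ["x", "n"], ["z", "m"],
                    {"cond_jaxpr": loop_cond, "body_jaxpr": loop_body}), env))
# ['mhlo.while(%arg0, %arg1) cond<1 eqns> body<1 eqns>']
```

The design point is that the dispatcher stays uniform: control flow is "special" only inside its own rule, never in the code that walks the jaxpr.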

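At a high level, the lowering rule builds the corresponding MLIR control flow op and, while emitting into that op's blocks (for example the condition and body regions of a while loop), recursively calls `mlir.jaxpr_subcomp` on the body jaxprs to lower them (see here for an example).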
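The recursion can be sketched the same way. The toy below is again plain Python with hypothetical names (`Emitter`, `jaxpr_subcomp`, `lower_region`, `while_rule`) and emits MLIR-like text rather than real MHLO: the `while` rule opens an op with a `cond` region and a `body` region, creates fresh block arguments for each, and calls the toy `jaxpr_subcomp` on the matching sub-jaxpr inside that region. JAX's actual `mlir.jaxpr_subcomp` also handles constants and works through the real MLIR builder API, all of which is omitted here.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Callable


# Toy stand-ins for a jaxpr and its equations (hypothetical, not JAX's classes).
@dataclass
class Jaxpr:
    invars: list[str]
    eqns: list[Eqn]
    outvars: list[str]


@dataclass
class Eqn:
    primitive: str
    invars: list[str]
    outvars: list[str]
    params: dict[str, Any] = field(default_factory=dict)


class Emitter:
    """Collects indented MLIR-like text lines and hands out fresh SSA names."""

    def __init__(self) -> None:
        self.lines: list[str] = []
        self.depth = 0
        self.counter = 0

    def fresh(self) -> str:
        name = f"%{self.counter}"
        self.counter += 1
        return name

    def emit(self, text: str) -> None:
        self.lines.append("  " * self.depth + text)


def jaxpr_subcomp(em: Emitter, jaxpr: Jaxpr, args: list[str]) -> list[str]:
    """Toy analogue of mlir.jaxpr_subcomp: lower every eqn, return lowered outvars."""
    env = dict(zip(jaxpr.invars, args))
    for eqn in jaxpr.eqns:
        operands = [env[v] for v in eqn.invars]
        results = RULES[eqn.primitive](em, operands, **eqn.params)
        env.update(zip(eqn.outvars, results))
    return [env[v] for v in jaxpr.outvars]


def binary_rule(op: str) -> Callable[..., list[str]]:
    """Build a rule for a simple two-operand primitive."""
    def rule(em: Emitter, operands: list[str]) -> list[str]:
        out = em.fresh()
        em.emit(f"{out} = {op} {', '.join(operands)}")
        return [out]
    return rule


def lower_region(em: Emitter, label: str, jaxpr: Jaxpr, n_args: int) -> None:
    """Emit one region: fresh block arguments, the recursively lowered jaxpr, a return."""
    block_args = [em.fresh() for _ in range(n_args)]
    em.emit(f"{label}({', '.join(block_args)}):")
    em.depth += 1
    outs = jaxpr_subcomp(em, jaxpr, block_args)  # the recursive call
    em.emit(f"return {', '.join(outs)}")
    em.depth -= 1


def while_rule(em: Emitter, operands: list[str], *,
               cond_jaxpr: Jaxpr, body_jaxpr: Jaxpr) -> list[str]:
    results = [em.fresh() for _ in operands]
    em.emit(f"{', '.join(results)} = while({', '.join(operands)}) {{")
    em.depth += 1
    lower_region(em, "cond", cond_jaxpr, len(operands))
    lower_region(em, "body", body_jaxpr, len(operands))
    em.depth -= 1
    em.emit("}")
    return results


RULES = {"add": binary_rule("add"), "lt": binary_rule("compare LT"), "while": while_rule}

# Usage: carry (i, n); loop while i < n, doubling i each iteration.
loop_cond = Jaxpr(["i", "n"], [Eqn("lt", ["i", "n"], ["p"])], ["p"])
loop_body = Jaxpr(["i", "n"], [Eqn("add", ["i", "i"], ["j"])], ["j", "n"])
top = Jaxpr(["x", "n"],
            [Eqn("while", ["x", "n"], ["y", "m"],
                 {"cond_jaxpr": loop_cond, "body_jaxpr": loop_body})],
            ["y"])
em = Emitter()
jaxpr_subcomp(em, top, ["%arg0", "%arg1"])
print("\n".join(em.lines))
```

Running it prints a `while` op whose regions contain the recursively lowered equations:

```text
%0, %1 = while(%arg0, %arg1) {
  cond(%2, %3):
    %4 = compare LT %2, %3
    return %4
  body(%5, %6):
    %7 = add %5, %5
    return %7, %6
}
```

Because `lower_region` goes back through `jaxpr_subcomp`, a `while` nested inside a loop body would simply recurse one level deeper.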

Does that answer your question, or did you want more specific answers for each primitive?

sharadmv Jun 7, 2022
Collaborator

Answer selected by adamantboy