Skip to content

Commit

Permalink
Add mlir print option
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Mar 21, 2024
1 parent 98699ab commit 17695db
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions src/enzyme_ad/jax/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,6 +535,10 @@ def _enzyme_primal_lowering(

if lang == LANG_MHLO:
(in_tree, in_idx_map, out_idx_map, mfunc, jit_options) = source
print_mlir = False
if "print_mlir" in jit_options:
print_mlir = jit_options["print_mlir"]
del jit_options["print_mlir"]
assert len(out_idx_map) == len(out_shapes)

orig_shapes = []
Expand Down Expand Up @@ -586,6 +590,8 @@ def _enzyme_primal_lowering(
fns.append(f.sym_name.value)

name, nmod = enzyme_call.run_pass_pipeline(fns, source, pass_pipeline)
if print_mlir:
print(str(nmod), flush=True)
nmod = ir.Module.parse(nmod)
fn = None
for f in nmod.body:
Expand Down

0 comments on commit 17695db

Please sign in to comment.