Skip to content

Commit

Permalink
Preserve lowering parameters (like backend)
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Mar 24, 2024
1 parent e64d2e4 commit e6a6ccc
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion src/enzyme_ad/jax/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,7 +556,10 @@ def _enzyme_primal_lowering(
orig_types.append(in_types[i])
avals = [ctx.avals_in[seen[i]] for i in seen]
avals_in = jax.tree_util.tree_unflatten(in_tree, avals)
lowered_func = jax.jit(mfunc, **jit_options).lower(*avals_in)
lowered_func = jax.jit(mfunc, **jit_options).lower(
*avals_in,
_experimental_lowering_parameters=ctx.module_context.lowering_parameters
)
mhlo = lowered_func.compiler_ir(dialect="stablehlo")
source = mhlo.operation.get_asm(enable_debug_info=True)
kept = lowered_func.compile()._executable._kept_var_idx
Expand Down

0 comments on commit e6a6ccc

Please sign in to comment.