diff --git a/src/enzyme_ad/jax/primitives.py b/src/enzyme_ad/jax/primitives.py index 449214491..0be56d894 100644 --- a/src/enzyme_ad/jax/primitives.py +++ b/src/enzyme_ad/jax/primitives.py @@ -556,7 +556,10 @@ def _enzyme_primal_lowering( orig_types.append(in_types[i]) avals = [ctx.avals_in[seen[i]] for i in seen] avals_in = jax.tree_util.tree_unflatten(in_tree, avals) - lowered_func = jax.jit(mfunc, **jit_options).lower(*avals_in) + lowered_func = jax.jit(mfunc, **jit_options).lower( + *avals_in, + _experimental_lowering_parameters=ctx.module_context.lowering_parameters + ) mhlo = lowered_func.compiler_ir(dialect="stablehlo") source = mhlo.operation.get_asm(enable_debug_info=True) kept = lowered_func.compile()._executable._kept_var_idx