diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index 4991aea909f1..cde578a5849d 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -1855,18 +1855,26 @@ def _while_typecheck(_, *in_atoms, cond_jaxpr, body_jaxpr, cond_nconsts, f'Effects not supported in `while`: {disallowed_effects}') return body_jaxpr.out_avals, joined_effects -def _while_discharge_rule(in_avals, out_avals, *args, cond_jaxpr, body_jaxpr, +def _while_partial_discharge_rule(should_discharge, in_avals, out_avals, *args, cond_jaxpr, body_jaxpr, cond_nconsts, body_nconsts): # TODO(sharadmv): enable supporting state effects in the cond if any(isinstance(eff, state.RefEffect) for eff in cond_jaxpr.effects): raise NotImplementedError + cond_consts_discharge, body_consts_discharge, carry_discharge = split_list( + should_discharge, [cond_nconsts, body_nconsts]) + + if any(cond_consts_discharge): + raise NotImplementedError cond_consts, body_consts, carry = split_list(args, [cond_nconsts, body_nconsts]) cond_consts_avals, body_consts_avals, carry_avals = split_list(in_avals, [cond_nconsts, body_nconsts]) # There shouldn't be any `Ref`s in the `cond` (because of our check above). assert not any(isinstance(aval, state.AbstractRef) for aval in cond_consts_avals) - is_ref = [isinstance(aval, state.AbstractRef) for aval in body_consts_avals] + is_ref = [ + isinstance(aval, state.AbstractRef) and should + for aval, should in zip(body_consts_avals, body_consts_discharge) + ] remaining_body_consts, refs = partition_list(is_ref, body_consts) remaining_body_const_avals, ref_avals = partition_list(is_ref, body_consts_avals) @@ -1886,7 +1894,7 @@ def _while_discharge_rule(in_avals, out_avals, *args, cond_jaxpr, body_jaxpr, # Therefore we need to rewrite the jaxpr to shuffle around the `Ref`s so that # they are part of the carry. 
discharged_body_jaxpr, discharged_consts = state_discharge.discharge_state( - body_jaxpr, ()) + body_jaxpr, (), should_discharge=[*body_consts_discharge, *carry_discharge]) if discharged_consts: raise NotImplementedError def new_body(*consts_refs_carry): @@ -1943,7 +1951,7 @@ def new_cond(*consts_refs_carry): pe.partial_eval_jaxpr_custom_rules[while_p] = _while_partial_eval_custom mlir.register_lowering(while_p, _while_lowering) core.custom_typechecks[while_p] = _while_typecheck -state_discharge.register_discharge_rule(while_p)(_while_discharge_rule) +state_discharge.register_partial_discharge_rule(while_p)(_while_partial_discharge_rule) def _pred_bcast_select_hlo(ctx, diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 037d104eea76..3a22086ad1ae 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -1430,15 +1430,6 @@ def _run_scoped_lowering_rule( ctx.module_ctx, ctx.launch_ctx, jaxpr, input_refs, consts ) - for o in outs: - # This is definitely one of the accumulators we produced. Each - # run_scoped call is responsible for dereferencing its own - # accumulators. - if isinstance(o, mgpu.WGMMAAccumulator) or ( - isinstance(o, ir.Value) and ir.MemRefType.isinstance(o.type) - ): - raise ValueError(f"No references are allowed to escape a scope. (got {o})") - assert len(outs) == len(jaxpr.outvars), (jaxpr, outs) return outs diff --git a/jax/_src/pallas/primitives.py b/jax/_src/pallas/primitives.py index 9c07bcb49b65..ed22fdebe8fd 100644 --- a/jax/_src/pallas/primitives.py +++ b/jax/_src/pallas/primitives.py @@ -896,22 +896,25 @@ def _run_scoped_discharge_rule( **_): del out_avals num_consts = len(args_flat) + # discharge_state only discharges invars, not consts, so in order to + # discharge the requested refs we need to move them to the invar set. 
jaxpr_noconst = pe.convert_constvars_jaxpr(jaxpr) num_return_values = len(jaxpr_noconst.outvars) - should_discharge = should_discharge + [ - isinstance(var.aval, state.AbstractRef) for var in jaxpr.invars - ] discharged_body, new_consts = state_discharge.discharge_state( - jaxpr_noconst, [], should_discharge=should_discharge) + jaxpr_noconst, + [], + should_discharge=should_discharge + [False] * len(jaxpr.invars), + ) if new_consts: raise NotImplementedError( "Cannot handle new consts created by state discharge.") - # Create inputs filled with uninitialized values to the body. - body_avals = [v.aval for v in discharged_body.invars[num_consts:]] - init_vals = [uninitialized_value( - aval.shape, aval.dtype) for aval in body_avals] - init_vals_with_consts = args_flat + tuple(init_vals) - out = jax_core.eval_jaxpr(discharged_body, [], *init_vals_with_consts) + + # Lowering expects the jaxpr.consts to be the eqn.invals. + discharged_body = pe.convert_invars_to_constvars(discharged_body, num_consts) + + # Run_scoped discharged the external variables but the scoped ones + # are not discharged. + out = run_scoped_p.bind(*args_flat, jaxpr=discharged_body) # Order of outputs: # (1) return values, (2) closed refs, (3) scoped refs. return_values = out[:num_return_values] @@ -919,8 +922,8 @@ def _run_scoped_discharge_rule( # We update all ref values with their updated values from the discharged # body. For other values we leave them in place. updates = [ - ref_outputs.pop(0) if isinstance(aval, pallas_core.AbstractMemoryRef) - else None for aval in in_avals] + ref_outputs.pop(0) if should and isinstance(aval, pallas_core.AbstractMemoryRef) + else None for should, aval in zip(should_discharge, in_avals)] assert len(updates) == len(in_avals), f'{len(updates)} != {len(in_avals)}' return updates, return_values