Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[pallas/pallas_mgpu] Discharging run_scoped should not be discharging the intermediates #25639

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions jax/_src/lax/control_flow/loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1848,18 +1848,26 @@ def _while_typecheck(_, *in_atoms, cond_jaxpr, body_jaxpr, cond_nconsts,
f'Effects not supported in `while`: {disallowed_effects}')
return body_jaxpr.out_avals, joined_effects

def _while_discharge_rule(in_avals, out_avals, *args, cond_jaxpr, body_jaxpr,
def _while_partial_discharge_rule(should_discharge, in_avals, out_avals, *args, cond_jaxpr, body_jaxpr,
cond_nconsts, body_nconsts):
# TODO(sharadmv): enable supporting state effects in the cond
if any(isinstance(eff, state.RefEffect) for eff in cond_jaxpr.effects):
raise NotImplementedError
cond_consts_discharge, body_consts_discharge, carry_discharge = split_list(
should_discharge, [cond_nconsts, body_nconsts])

if any(cond_consts_discharge):
raise NotImplementedError
cond_consts, body_consts, carry = split_list(args, [cond_nconsts, body_nconsts])
cond_consts_avals, body_consts_avals, carry_avals = split_list(in_avals,
[cond_nconsts,
body_nconsts])
# There shouldn't be any `Ref`s in the `cond` (because of our check above).
assert not any(isinstance(aval, state.AbstractRef) for aval in cond_consts_avals)
is_ref = [isinstance(aval, state.AbstractRef) for aval in body_consts_avals]
is_ref = [
isinstance(aval, state.AbstractRef) and should
for aval, should in zip(body_consts_avals, body_consts_discharge)
]
remaining_body_consts, refs = partition_list(is_ref, body_consts)
remaining_body_const_avals, ref_avals = partition_list(is_ref,
body_consts_avals)
Expand All @@ -1879,7 +1887,7 @@ def _while_discharge_rule(in_avals, out_avals, *args, cond_jaxpr, body_jaxpr,
# Therefore we need to rewrite the jaxpr to shuffle around the `Ref`s so that
# they are part of the carry.
discharged_body_jaxpr, discharged_consts = state_discharge.discharge_state(
body_jaxpr, ())
body_jaxpr, (), should_discharge=[*body_consts_discharge, *carry_discharge])
if discharged_consts: raise NotImplementedError

def new_body(*consts_refs_carry):
Expand Down Expand Up @@ -1936,7 +1944,7 @@ def new_cond(*consts_refs_carry):
pe.partial_eval_jaxpr_custom_rules[while_p] = _while_partial_eval_custom
mlir.register_lowering(while_p, _while_lowering)
core.custom_typechecks[while_p] = _while_typecheck
state_discharge.register_discharge_rule(while_p)(_while_discharge_rule)
state_discharge.register_partial_discharge_rule(while_p)(_while_partial_discharge_rule)


def _pred_bcast_select_hlo(ctx,
Expand Down
9 changes: 0 additions & 9 deletions jax/_src/pallas/mosaic_gpu/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -1426,15 +1426,6 @@ def _run_scoped_lowering_rule(
ctx.module_ctx, ctx.launch_ctx, jaxpr, input_refs, consts
)

for o in outs:
# This is definitely one of the accumulators we produced. Each
# run_scoped call is responsible for dereferencing its own
# accumulators.
if isinstance(o, mgpu.WGMMAAccumulator) or (
isinstance(o, ir.Value) and ir.MemRefType.isinstance(o.type)
):
raise ValueError(f"No references are allowed to escape a scope. (got {o})")

assert len(outs) == len(jaxpr.outvars), (jaxpr, outs)
return outs

Expand Down
27 changes: 15 additions & 12 deletions jax/_src/pallas/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -893,31 +893,34 @@ def _run_scoped_discharge_rule(
**_):
del out_avals
num_consts = len(args_flat)
# discharge_state only discharges invars, not consts, so in order to
# discharge the requested refs we need to move them to the invar set.
jaxpr_noconst = pe.convert_constvars_jaxpr(jaxpr)
num_return_values = len(jaxpr_noconst.outvars)
should_discharge = should_discharge + [
isinstance(var.aval, state.AbstractRef) for var in jaxpr.invars
]
discharged_body, new_consts = state_discharge.discharge_state(
jaxpr_noconst, [], should_discharge=should_discharge)
jaxpr_noconst,
[],
should_discharge=should_discharge + [False] * len(jaxpr.invars),
)
if new_consts:
raise NotImplementedError(
"Cannot handle new consts created by state discharge.")
# Create inputs filled with uninitialized values to the body.
body_avals = [v.aval for v in discharged_body.invars[num_consts:]]
init_vals = [uninitialized_value(
aval.shape, aval.dtype) for aval in body_avals]
init_vals_with_consts = args_flat + tuple(init_vals)
out = jax_core.eval_jaxpr(discharged_body, [], *init_vals_with_consts)

# Lowering expects that the jaxpr.consts to be the eqn.invals.
discharged_body = pe.convert_invars_to_constvars(discharged_body, num_consts)

# Run_scoped discharged the external variables but the scoped ones
# are not discharged.
out = run_scoped_p.bind(*args_flat, jaxpr=discharged_body)
# Order of outputs:
# (1) return values, (2) closed refs, (3) scoped refs.
return_values = out[:num_return_values]
ref_outputs = out[num_return_values:]
# We update all ref values with their updated values from the discharged
# body. For other values we leave them in place.
updates = [
ref_outputs.pop(0) if isinstance(aval, pallas_core.AbstractMemoryRef)
else None for aval in in_avals]
ref_outputs.pop(0) if should and isinstance(aval, pallas_core.AbstractMemoryRef)
else None for should, aval in zip(should_discharge, in_avals)]
assert len(updates) == len(in_avals), f'{len(updates)} != {len(in_avals)}'
return updates, return_values

Expand Down
Loading