Skip to content

Commit

Permalink
[pallas] Support DMA start partial discharge and run_scoped() does it…
Browse files Browse the repository at this point in the history
…s own partial discharge.

This CL lays the groundwork for a future CL that makes run_scoped discharge not request the discharge of the temporary buffers it creates. This causes issues because:

a) dma_start can't discharge some but not all of its references
b) run_scoped() lowering depends on run_scoped discharge to remove the run_scoped operation (otherwise it goes into an infinite loop).

PiperOrigin-RevId: 722126566
  • Loading branch information
cperivol authored and Google-ML-Automation committed Feb 1, 2025
1 parent eb04fcb commit 8649132
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 37 deletions.
90 changes: 64 additions & 26 deletions jax/_src/pallas/mosaic/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,8 +550,8 @@ def _dma_start_pp_eqn(eqn: jax_core.JaxprEqn,

jax_core.pp_eqn_rules[dma_start_p] = _dma_start_pp_eqn

def dma_start_discharge_rule(in_avals, out_avals,
*args, tree, device_id_type):
def dma_start_partial_discharge_rule(should_discharge, in_avals, out_avals,
*args, tree, device_id_type):
(
src_ref,
src_transforms,
Expand All @@ -575,7 +575,22 @@ def dma_start_discharge_rule(in_avals, out_avals,
_,
) = tree_util.tree_unflatten(tree, in_avals)
del out_avals

(
_,
_,
dst_discharge,
_,
dst_sem_discharge,
_,
*maybe_src_sem_discharge,
) = tree_util.tree_unflatten(tree, should_discharge)
is_remote = device_id is not None
src_sem_discharge = None

if is_remote:
src_sem_discharge = maybe_src_sem_discharge[0]

if not is_remote:
# Local async copies only use one semaphore.
assert src_sem is None
Expand All @@ -586,7 +601,7 @@ def dma_start_discharge_rule(in_avals, out_avals,
num_src_transform_vals = len(tree_util.tree_leaves(src_transforms_avals))
num_dst_transform_vals = len(tree_util.tree_leaves(dst_transforms_avals))

updates = state_discharge.transform_array(src_ref, src_transforms)
updates = state_discharge.transform_array(src_ref[...], src_transforms)
local_src = updates

if is_remote:
Expand Down Expand Up @@ -641,47 +656,61 @@ def dma_start_discharge_rule(in_avals, out_avals,
global_dst_transforms,
)

_, new_dst = state_discharge.transform_swap_array(
dst_ref, dst_transforms, updates
)
def do_discharge_dst(dst_ref=dst_ref):
_, ret = state_discharge.transform_swap_array(
dst_ref, dst_transforms, updates
)
return ret

# Update semaphore values.
# TODO(justinfu): Potentially handle asymmetric copy sizes.
recv_size = jnp.minimum(updates.size, pl_core.SEMAPHORE_MAX_VALUE)
recv_size = jnp.array(recv_size, dtype=pl_core.SEMAPHORE_INTERPRET_DTYPE)
dst_sem_value = _transform_semaphore(
dst_sem, dst_sem_transforms, dst_sem_aval
)
_, new_dst_sem = state_discharge.transform_swap_array(
dst_sem, dst_sem_transforms, dst_sem_value + recv_size
)
if is_remote:
def do_discharge_dst_sem(dst_sem=dst_sem):
recv_size = jnp.minimum(updates.size, pl_core.SEMAPHORE_MAX_VALUE)
recv_size = jnp.array(recv_size, dtype=pl_core.SEMAPHORE_INTERPRET_DTYPE)
dst_sem_value = _transform_semaphore(
dst_sem, dst_sem_transforms, dst_sem_aval
)
_, ret = state_discharge.transform_swap_array(
dst_sem, dst_sem_transforms, dst_sem_value[...] + recv_size
)
return ret

def do_discharge_src_sem(src_sem=src_sem):
send_size = jnp.minimum(local_src.size, pl_core.SEMAPHORE_MAX_VALUE)
send_size = jnp.array(send_size, dtype=pl_core.SEMAPHORE_INTERPRET_DTYPE)
src_sem_value = _transform_semaphore(
src_sem, src_sem_transforms, src_sem_aval
)
_, new_src_sem = state_discharge.transform_swap_array(
src_sem, src_sem_transforms, src_sem_value + send_size
_, ret = state_discharge.transform_swap_array(
src_sem, src_sem_transforms, src_sem_value[...] + send_size
)
else:
new_src_sem = None
return ret

new_vals = (None,) # src_val
new_vals += (None,) * num_src_transform_vals
new_vals += (new_dst,) # dst_val
new_vals += (do_discharge_dst() if dst_discharge else None,) # dst_val
new_vals += (None,) * num_dst_transform_vals
new_vals += (new_dst_sem,) # dst_sem
new_vals += (do_discharge_dst_sem() if dst_sem_discharge else None,) # dst_sem
new_vals += (None,) * num_dst_sem_transforms
if is_remote:
new_vals += (new_src_sem,) # src_sem
new_vals += (do_discharge_src_sem() if src_sem_discharge else None,) # src_sem
new_vals += (None,) * num_src_sem_transforms
new_vals += (None,) # device_id
assert (len(new_vals) ==
len(in_avals)), f"{len(new_vals), new_vals} != {len(in_avals)}"

# If we didn't discharge everything we could we should keep writes
# to the references that are left over.
if not dst_discharge:
sp.ref_set(dst_ref, None, do_discharge_dst(dst_ref=dst_ref[...]))
if not dst_sem_discharge:
sp.ref_set(dst_sem, None, do_discharge_dst_sem(dst_sem=dst_sem[...]))
if is_remote and not src_sem_discharge:
sp.ref_set(src_sem, None, do_discharge_src_sem(src_sem=src_sem[...]))

return new_vals, []

state_discharge.register_discharge_rule(dma_start_p)(dma_start_discharge_rule)
state_discharge.register_partial_discharge_rule(dma_start_p)(dma_start_partial_discharge_rule)


dma_wait_p = jax_core.Primitive('dma_wait')
Expand Down Expand Up @@ -719,8 +748,9 @@ def _dma_wait_pp_eqn(eqn: jax_core.JaxprEqn,

jax_core.pp_eqn_rules[dma_wait_p] = _dma_wait_pp_eqn

def dma_wait_discharge_rule(in_avals, out_avals,
*args, tree, device_id_type):
def dma_wait_partial_discharge_rule(should_discharge,
in_avals, out_avals,
*args, tree, device_id_type):
# TODO(b/370563115): perform ref update in dma_wait discharge rule instead of dma_start
del out_avals, device_id_type
_, _, dst_ref, dst_ref_transforms, dst_sem, dst_sem_transforms, _, _, _ = (
Expand All @@ -735,6 +765,14 @@ def dma_wait_discharge_rule(in_avals, out_avals,
src_sem_transforms_avals,
device_id_aval,
) = tree_util.tree_unflatten(tree, in_avals)

# The only one we can discharge is the dst semaphore. The provided
# buffers are only specified for their types and not their value so
# it's completely irrelevant for us here if they are discharged.
should_discharge_unflattened = tree_util.tree_unflatten(tree, should_discharge)
if not should_discharge_unflattened[4]:
return (None,) * len(in_avals), []

num_sem_transforms = len(tree_util.tree_leaves(dst_sem_transforms_avals))
num_transforms = len(tree_util.tree_leaves(dst_ref_transforms_avals))
updates = state_discharge.transform_array(dst_ref, dst_ref_transforms)
Expand All @@ -754,7 +792,7 @@ def dma_wait_discharge_rule(in_avals, out_avals,
new_vals += (None,) * len(tree_util.tree_leaves(src_sem_transforms_avals))
new_vals += (None,) * len(tree_util.tree_leaves(device_id_aval)) # device_id
return new_vals, []
state_discharge.register_discharge_rule(dma_wait_p)(dma_wait_discharge_rule)
state_discharge.register_partial_discharge_rule(dma_wait_p)(dma_wait_partial_discharge_rule)

def _get_ref_and_transforms(ref):
if isinstance(ref, state.TransformedRef):
Expand Down
25 changes: 14 additions & 11 deletions jax/_src/pallas/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -931,17 +931,20 @@ def _run_scoped_discharge_rule(

@functools.partial(mlir.register_lowering, run_scoped_p)
def _run_scoped_lowering_rule(ctx, *args, jaxpr):
# This lowering rule gets triggered when run_scoped is not discharged.
# In this case there are no stateful effects to handle.
should_discharge = [
isinstance(aval, state.AbstractRef) for aval in ctx.avals_in
]
jaxpr_noconst = pe.convert_constvars_jaxpr(jaxpr)
num_return_values = len(jaxpr_noconst.outvars)
discharged_body, new_consts = state_discharge.discharge_state(
jaxpr_noconst, [], should_discharge=True)
if new_consts: raise NotImplementedError(
"Cannot handle new consts created by state discharge.")

def _lower_fun(*lower_fun_args):
updates, out = _run_scoped_discharge_rule(
should_discharge,
[], [], *lower_fun_args,
jaxpr=jaxpr)
assert len(updates) == 0, 'Cannot lower run_scoped with effects.'
return out
# Create inputs filled with uninitialized values to the body.
num_consts = len(lower_fun_args)
body_avals = [v.aval for v in discharged_body.invars[num_consts:]]
init_vals = [uninitialized_value(
aval.shape, aval.dtype) for aval in body_avals]
out = jax_core.eval_jaxpr(discharged_body, [], *lower_fun_args, *init_vals)
return out[:num_return_values]

return mlir.lower_fun(_lower_fun, multiple_results=True)(ctx, *args)

0 comments on commit 8649132

Please sign in to comment.