Skip to content

Commit

Permalink
[Oryx] Make linear an optional parameter for cond_p.
Browse files Browse the repository at this point in the history
In preparation for jax-ml/jax#22119, add support for
`cond_p` even without the `linear` parameter included.

PiperOrigin-RevId: 647427206
  • Loading branch information
dfm authored and The oryx Authors committed Jun 27, 2024
1 parent e29141b commit 21722fa
Showing 1 changed file with 24 additions and 12 deletions.
36 changes: 24 additions & 12 deletions oryx/core/interpreters/harvest.py
Original file line number Diff line number Diff line change
Expand Up @@ -1097,7 +1097,7 @@ def _check_branch_metadata(branch_metadatas):
raise ValueError(f'Mismatched dtype between branches: \'{name}\'.')


def _reap_cond_rule(trace, *tracers, branches, linear):
def _reap_cond_rule(trace, *tracers, branches, linear=None):
"""Reaps each path of the `cond`."""
index_tracer, ops_tracers = tracers[0], tracers[1:]
index_val, ops_vals = tree_util.tree_map(lambda x: x.val,
Expand All @@ -1122,11 +1122,17 @@ def _reap_cond_rule(trace, *tracers, branches, linear):
new_branch_jaxprs, consts, out_trees = (
lcf._initial_style_jaxprs_with_common_consts( # pylint: disable=protected-access
reaped_branches, in_tree, ops_avals, lax.cond_p.name))
out = lax.cond_p.bind(
index_val,
*(tuple(consts) + ops_vals),
branches=tuple(new_branch_jaxprs),
linear=(False,) * len(tuple(consts) + linear))
if linear is None:
out = lax.cond_p.bind(
index_val,
*(tuple(consts) + ops_vals),
branches=tuple(new_branch_jaxprs))
else:
out = lax.cond_p.bind(
index_val,
*(tuple(consts) + ops_vals),
branches=tuple(new_branch_jaxprs),
linear=(False,) * len(tuple(consts) + linear))
out = jax_util.safe_map(trace.pure, out)
out, reaps, preds = tree_util.tree_unflatten(out_trees[0], out)
for k, v in reaps.items():
Expand Down Expand Up @@ -1558,7 +1564,7 @@ def new_body(*carry):
plant_custom_rules[lcf.while_p] = _plant_while_rule


def _plant_cond_rule(trace, *tracers, branches, linear):
def _plant_cond_rule(trace, *tracers, branches, linear=None):
"""Injects the same values into both branches of a conditional."""
index_tracer, ops_tracers = tracers[0], tracers[1:]
index_val, ops_vals = tree_util.tree_map(lambda x: x.val,
Expand All @@ -1584,11 +1590,17 @@ def _plant_cond_rule(trace, *tracers, branches, linear):
new_branch_jaxprs, consts, _ = (
lcf._initial_style_jaxprs_with_common_consts( # pylint: disable=protected-access
planted_branches, in_tree, ops_avals, lax.cond_p.name))
out = lax.cond_p.bind(
index_val,
*(tuple(consts) + ops_vals),
branches=tuple(new_branch_jaxprs),
linear=(False,) * len(tuple(consts) + linear))
if linear is None:
out = lax.cond_p.bind(
index_val,
*(tuple(consts) + ops_vals),
branches=tuple(new_branch_jaxprs))
else:
out = lax.cond_p.bind(
index_val,
*(tuple(consts) + ops_vals),
branches=tuple(new_branch_jaxprs),
linear=(False,) * len(tuple(consts) + linear))
return jax_util.safe_map(trace.pure, out)


Expand Down

0 comments on commit 21722fa

Please sign in to comment.