Skip to content

Commit

Permalink
Handle inactive args from context (#41)
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses authored Mar 9, 2024
1 parent 3cea3b1 commit d2030a6
Showing 1 changed file with 68 additions and 86 deletions.
154 changes: 68 additions & 86 deletions src/enzyme_ad/jax/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,15 @@ def to_jax(ty):
return {"f32": jnp.float32, "f64": jnp.float64}[tystr]


def activity_from_pipeline(pass_pipeline):
start = pass_pipeline.index("argTys=")
end = pass_pipeline.index(" ", start)
acts = pass_pipeline[start + len("argTys=") : end].split(",")
pre_act = pass_pipeline[: start + len("argTys=")]
post_act = pass_pipeline[end:]
return pre_act, acts, post_act


def _enzyme_primal_lowering(
ctx: jax_mlir.LoweringRuleContext,
*args_flat: ir.Value,
Expand Down Expand Up @@ -523,12 +532,10 @@ def _enzyme_primal_lowering(
if i not in in_idx_map or in_idx_map[i] in kept
)
if len(kept) != len(orig_shapes):
post = ",".join(["enzyme_dup"] * len(kept))
prev = ",".join(["enzyme_dup"] * len(orig_shapes))
pass_pipeline = pass_pipeline.replace(prev, post)
post = ",".join(["enzyme_out"] * len(kept))
prev = ",".join(["enzyme_out"] * len(orig_shapes))
pass_pipeline = pass_pipeline.replace(prev, post)
if "argTys=" in pass_pipeline:
pre_act, acts, post_act = activity_from_pipeline(pass_pipeline)
acts2 = [act for (i, act) in enumerate(acts) if i in kept]
pass_pipeline = pre_act + ",".join(acts2) + post_act

out_types = [
shape
Expand Down Expand Up @@ -917,18 +924,36 @@ def cpp_call(


def enzyme_jvp(arg_primals, arg_tangents, **kwargs):
print("arg_tan", arg_tangents)
print("kwargs", kwargs)

# TODO propagate activity info rather than make_zero
def make_zero(tan, prim):
return lax.zeros_like_array(prim) if type(tan) is ad.Zero else tan

arg_tangents = tuple(make_zero(t, p) for (t, p) in zip(arg_tangents, arg_primals))
args = tuple(v for t in zip(arg_primals, arg_tangents) for v in t)

pipeline_options = kwargs["pipeline_options"]

shadconv = None
if pipeline_options.mlir_ad() and kwargs["lang"] == LANG_MHLO:
act_tup = ",".join(["enzyme_dup" for a in arg_primals])
(in_tree, in_idx_map, out_idx_map, mfunc) = kwargs["source"]
act_tup = []
args = []

avals = {}

for idx, (v, s) in enumerate(zip(arg_primals, arg_tangents)):
avals[len(args)] = in_idx_map[idx]
args.append(v)
if type(s) is ad.Zero:
act_tup.append("enzyme_const")
else:
act_tup.append("enzyme_dup")
avals[len(args)] = in_idx_map[idx]
args.append(s)

args = tuple(args)
act_tup = ",".join(act_tup)

afterad = "arith-raise{stablehlo=true}, enzyme-hlo-opt, cse, canonicalize"
newpasses = (
"inline{default-pipeline=canonicalize max-iterations=4},"
Expand All @@ -940,10 +965,10 @@ def make_zero(tan, prim):
if pipeline_options.pass_pipeline() != "":
oldpasses = pipeline_options.pass_pipeline()
if "enzyme-wrap" in oldpasses:
start = passes.rindex("enzyme-wrap{")
end = passes.index("}", start)
prev_passes = passes[:end]
newpasses = prev_passes + afterad + newpasses + passes[end:]
start = oldpasses.rindex("enzyme-wrap{")
end = oldpasses.index("}", start)
prev_passes = oldpasses[:end]
newpasses = prev_passes + afterad + newpasses + oldpasses[end:]
else:
newpasses = newpasses + "," + oldpasses
if pipeline_options.stablehlo_inject():
Expand All @@ -954,10 +979,6 @@ def make_zero(tan, prim):
for o in kwargs["out_shapes"]:
outshapes2.append(o)
outshapes2.append(o)
(in_tree, in_idx_map, out_idx_map, mfunc) = kwargs["source"]
avals = {2 * k: v for k, v in in_idx_map.items()} | {
2 * k + 1: v for k, v in in_idx_map.items()
}
out_idx_map2 = {2 * k: v for k, v in out_idx_map.items()} | {
2 * k + 1: v for k, v in out_idx_map.items()
}
Expand All @@ -972,6 +993,10 @@ def make_zero(tan, prim):
pipeline_options=pipeline_options
)
else:
arg_tangents = tuple(
make_zero(t, p) for (t, p) in zip(arg_tangents, arg_primals)
)
args = tuple(v for t in zip(arg_primals, arg_tangents) for v in t)
shadconv = _enzyme_fwd_p.bind(
*args,
source=kwargs["source"],
Expand Down Expand Up @@ -1027,7 +1052,6 @@ def dejaxify(x):

def fwd_partial_eval(trace, *args, **kwargs):
assert len(args) % 2 == 0
nr_primals = len(args) // 2
primals, tangents = args[0::2], args[1::2]
all_primals_known = all(p.is_known() for p in primals)
some_tangents_unknown = any(not t.is_known() for t in tangents)
Expand All @@ -1050,6 +1074,9 @@ def fwd_partial_eval(trace, *args, **kwargs):


def primal_partial_eval(trace, *args, **kwargs):
print("trace ", trace)
print("args", args)
print("kwargs", kwargs)
pipeline_options = kwargs["pipeline_options"]
if (
not pipeline_options.mlir_ad()
Expand All @@ -1058,73 +1085,20 @@ def primal_partial_eval(trace, *args, **kwargs):
):
return trace.default_process_primitive(_enzyme_primal_p, args, kwargs)

assert len(args) % 2 == 0
nr_primals = len(args) // 2
primals, tangents = args[0::2], args[1::2]
all_primals_known = all(p.is_known() for p in primals)
some_tangents_unknown = any(not t.is_known() for t in tangents)

if not (all_primals_known and some_tangents_unknown):
return trace.default_process_primitive(_enzyme_primal_p, args, kwargs)

shadow_aug_args = primals + tangents

out_shapes = kwargs["out_shapes"]
out_shapes2 = out_shapes[: len(out_shapes) // 2]
del kwargs["out_shapes"]

shadows_known = trace.default_process_primitive(
_enzyme_shadow_aug_p, shadow_aug_args, kwargs | {"out_shapes": out_shapes2}
)

passes = pipeline_options.pass_pipeline()
start = passes.rindex("enzyme-wrap{")
prev_passes = passes[:start]
end = passes.index("}", start)
post_passes = passes[end + 1 :]
newpasses = prev_passes + post_passes[1:]

if pipeline_options.stablehlo_inject():
pipeline_options = JaXPipeline(newpasses)
else:
pipeline_options = NewXLAPipeline(newpasses, pipeline_options.mlir_ad())
_, acts, _ = activity_from_pipeline(pipeline_options.pass_pipeline())

(in_tree, in_idx_map, out_idx_map, mfunc) = kwargs["source"]

avals = {k // 2: v for k, v in in_idx_map.items() if k % 2 == 0}
outmap2 = {k // 2: v for k, v in out_idx_map.items() if k % 2 == 0}
source = (in_tree, avals, outmap2, mfunc)

primalret = trace.default_process_primitive(
_enzyme_primal_p,
primals,
{
"out_shapes": out_shapes2,
"source": source,
"fn": kwargs["fn"],
"argv": kwargs["argv"],
"lang": kwargs["lang"],
"pipeline_options": pipeline_options,
},
)
return primalret + shadows_known
primals = []
tangents = []
avals = {}

for idx, v in enumerate(acts):
avals[idx] = in_idx_map[len(primals) + len(tangents)]
primals.append(args[len(primals) + len(tangents)])
if v == "enzyme_dup":
tangents.append(args[len(primals) + len(tangents)])

pe.custom_partial_eval_rules[_enzyme_primal_p] = primal_partial_eval


def primal_partial_eval(trace, *args, **kwargs):
pipeline_options = kwargs["pipeline_options"]
if (
not pipeline_options.mlir_ad()
or kwargs["lang"] != LANG_MHLO
or pipeline_options.ad_level() == 0
):
return trace.default_process_primitive(_enzyme_primal_p, args, kwargs)

assert len(args) % 2 == 0
nr_primals = len(args) // 2
primals, tangents = args[0::2], args[1::2]
all_primals_known = all(p.is_known() for p in primals)
some_tangents_unknown = any(not t.is_known() for t in tangents)

Expand Down Expand Up @@ -1153,9 +1127,6 @@ def primal_partial_eval(trace, *args, **kwargs):
else:
pipeline_options = NewXLAPipeline(newpasses, pipeline_options.mlir_ad())

(in_tree, in_idx_map, out_idx_map, mfunc) = kwargs["source"]

avals = {k // 2: v for k, v in in_idx_map.items() if k % 2 == 0}
outmap2 = {k // 2: v for k, v in out_idx_map.items() if k % 2 == 0}
source = (in_tree, avals, outmap2, mfunc)

Expand All @@ -1180,14 +1151,16 @@ def primal_partial_eval(trace, *args, **kwargs):
def enzyme_vjp(shadow_rets, *prim_args, **kwargs):
pipeline_options = kwargs["pipeline_options"]
if pipeline_options.mlir_ad() and kwargs["lang"] == LANG_MHLO:
prim_args = prim_args[0 : len(prim_args) // 2]

passes = pipeline_options.pass_pipeline()
start = passes.rindex("enzyme-wrap{")
prev_passes = passes[:start]
end = passes.index("}", start)
post_passes = passes[end + 1 :]
ad_pass = passes[start : end + 1]

_, acts, _ = activity_from_pipeline(ad_pass)

ad_pass = ad_pass.replace("enzyme_dup", "enzyme_out")
ad_pass = ad_pass.replace("ForwardMode", "ReverseModeCombined")
newpasses = (
Expand All @@ -1204,7 +1177,16 @@ def enzyme_vjp(shadow_rets, *prim_args, **kwargs):

(in_tree, in_idx_map, out_idx_map, mfunc) = kwargs["source"]

avals = {k // 2: v for k, v in in_idx_map.items() if k % 2 == 0}
prim_args = prim_args[: len(acts)]

avals = {}
argidx = 0
for idx, v in enumerate(acts):
avals[idx] = in_idx_map[argidx]
argidx += 1
if v == "enzyme_dup":
argidx += 1

outmap = avals

primal_in_shapes = tuple((a.shape, jaxify(a.dtype)) for a in prim_args)
Expand Down

0 comments on commit d2030a6

Please sign in to comment.