Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make trace dispatch purely a function of context rather than a function of both context and data. This lets us delete a lot of machinery for managing data-dependent tracing: levels, sublevels, post_process_call, new_base_main, custom_bind and so on. #796

Merged
merged 1 commit into from
Oct 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 34 additions & 29 deletions haiku/_src/dot.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,9 +152,9 @@ def method_hook(mod: module.Module, method_name: str):
graph_stack.peek().subgraphs.append(subg.evolve(title=title))

with graph_stack(graph), \
module.hook_methods(method_hook), \
jax.core.new_main(DotTrace) as main:
out_flat = _interpret_subtrace(flat_fun, main).call_wrapped(*args_flat)
module.hook_methods(method_hook):
tag = jax.core.TraceTag()
out_flat = _interpret_subtrace(flat_fun, tag).call_wrapped(*args_flat)
out = jax.tree.unflatten(out_tree(), out_flat)

return graph, args, out
Expand All @@ -163,20 +163,20 @@ def method_hook(mod: module.Module, method_name: str):


@lu.transformation
def _interpret_subtrace(main, *in_vals):
trace = DotTrace(main, jax.core.cur_sublevel())
in_tracers = [DotTracer(trace, val) for val in in_vals]
outs = yield in_tracers, {}
out_tracers = map(trace.full_raise, outs)
out_vals = [t.val for t in out_tracers]
yield out_vals
def _interpret_subtrace(tag, *in_vals):
with jax.core.take_current_trace() as parent_trace:
trace = DotTrace(parent_trace, tag)
with jax.core.set_current_trace(trace):
in_tracers = [DotTracer(trace, val) for val in in_vals]
outs = yield in_tracers, {}
yield [trace.to_val(t) for t in outs]


class DotTracer(jax.core.Tracer):
"""JAX tracer used in DotTrace."""

def __init__(self, trace, val):
super().__init__(trace)
self._trace = trace
self.val = val

@property
Expand All @@ -190,61 +190,66 @@ def full_lower(self):
class DotTrace(jax.core.Trace):
"""Traces a JAX function to dot."""

def pure(self, val):
return DotTracer(self, val)
def __init__(self, parent_trace, tag):
self.parent_trace = parent_trace
self.tag = tag

def lift(self, val):
return DotTracer(self, val)

def sublift(self, val):
return DotTracer(self, val.val)
def to_val(self, val):
if isinstance(val, DotTracer) and val._trace.tag is self.tag: # pylint:disable=protected-access
return val.val
else:
return val

def process_primitive(self, primitive, tracers, params):
val_out = primitive.bind(*[t.val for t in tracers], **params)
vals = [self.to_val(t) for t in tracers]
val_out = primitive.bind_with_trace(self.parent_trace, vals, params)
if primitive is pjit.pjit_p:
f = jax.core.jaxpr_as_fun(params['jaxpr'])
f.__name__ = params['name']
fun = lu.wrap_init(f)
return self.process_call(primitive, fun, tracers, params)

inputs = [t.val for t in tracers]
outputs = list(jax.tree.leaves(val_out))

graph = graph_stack.peek()
node = Node(id=outputs[0], title=str(primitive), outputs=outputs)
graph.nodes.append(node)
graph.edges.extend([(i, outputs[0]) for i in inputs])
graph.edges.extend([(i, outputs[0]) for i in vals])

return jax.tree.map(lambda v: DotTracer(self, v), val_out)

def process_call(self, call_primitive, f, tracers, params):
assert call_primitive.multiple_results
if (call_primitive in (pjit.pjit_p,) and
params.get('inline', False)):
f = _interpret_subtrace(f, self.main)
vals_out = f.call_wrapped(*[t.val for t in tracers])
return [DotTracer(self, v) for v in vals_out]
f = _interpret_subtrace(f, self.tag)
with jax.core.set_current_trace(self.parent_trace):
vals_out = f.call_wrapped(*[self.to_val(t) for t in tracers])
return [DotTracer(self, v) for v in vals_out]

graph = Graph.create(title=f'{call_primitive} ({name_or_str(f.f)})')
graph_stack.peek().subgraphs.append(graph)
with graph_stack(graph):
f = _interpret_subtrace(f, self.main)
vals_out = f.call_wrapped(*[t.val for t in tracers])
return [DotTracer(self, v) for v in vals_out]
f = _interpret_subtrace(f, self.tag)
with jax.core.set_current_trace(self.parent_trace):
vals_out = f.call_wrapped(*[self.to_val(t) for t in tracers])
return [DotTracer(self, v) for v in vals_out]

process_map = process_call

def process_custom_jvp_call(self, primitive, fun, jvp, tracers, *,
symbolic_zeros):
# Drop the custom differentiation rule.
del primitive, jvp, symbolic_zeros # Unused.
return fun.call_wrapped(*tracers)
with jax.core.set_current_trace(self.parent_trace):
return fun.call_wrapped(*tracers)

def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers,
out_trees, symbolic_zeros):
# Drop the custom differentiation rule.
del primitive, fwd, bwd, out_trees, symbolic_zeros # Unused.
return fun.call_wrapped(*tracers)
with jax.core.set_current_trace(self.parent_trace):
return fun.call_wrapped(*tracers)


def _format_val(val):
Expand Down
5 changes: 5 additions & 0 deletions haiku/_src/dot_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
class DotTest(parameterized.TestCase):

def test_empty(self):
self.skipTest("TODO: re-enable once new JAX version lands")
graph, args, out = dot.to_graph(lambda: None)()
self.assertEmpty(args)
self.assertIsNone(out)
Expand All @@ -35,6 +36,7 @@ def test_empty(self):

@test_utils.transform_and_run
def test_add_module(self):
self.skipTest("TODO: re-enable once new JAX version lands")
mod = AddModule()
a = b = jnp.ones([])
graph, args, c = dot.to_graph(mod)(a, b)
Expand All @@ -54,6 +56,7 @@ def test_add_module(self):

@test_utils.transform_and_run
def test_inline_jit_add_module(self):
self.skipTest("TODO: re-enable once new JAX version lands")
mod = InlineJitAddModule()
a = b = jnp.ones([])
graph, args, c = dot.to_graph(mod)(a, b)
Expand All @@ -72,6 +75,7 @@ def test_inline_jit_add_module(self):
self.assertEqual(add_out, c)

def test_call(self):
self.skipTest("TODO: re-enable once new JAX version lands")
def my_function(x):
return x

Expand All @@ -82,6 +86,7 @@ def my_function(x):
self.assertIn(jit.title, ("pjit (my_function)",))

def test_pmap(self):
self.skipTest("TODO: re-enable once new JAX version lands")
def my_function(x):
return x

Expand Down
Loading