Skip to content

Commit

Permalink
Fixes jax.core.Var is deprecated. Use jax.extend.core.Var instead.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 705401152
  • Loading branch information
tomhennigan authored and copybara-github committed Dec 12, 2024
1 parent 19a2ed6 commit 01e149b
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 16 deletions.
1 change: 1 addition & 0 deletions haiku/_src/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -571,6 +571,7 @@ hk_py_library(
deps = [
":summarise",
# pip: jax
# pip: jax/extend:core
],
)

Expand Down
37 changes: 21 additions & 16 deletions haiku/_src/jaxpr_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,16 @@
import logging
import os
import sys
from typing import Any, NamedTuple
from typing import Any, NamedTuple, TypeAlias

from haiku._src import summarise
import jax
import jax.core
from jax.extend import core as jax_core

# TODO(tomhennigan): Update to use symbols from jax.extend.core when available.
Atom: TypeAlias = jax.core.Atom
DropVar: TypeAlias = jax.core.DropVar


@dataclasses.dataclass
Expand Down Expand Up @@ -98,7 +103,7 @@ class Expression:
name_stack: Sequence[str] = dataclasses.field(default_factory=list)


ComputeFlopsFn = Callable[[jax.core.JaxprEqn, Expression], int]
ComputeFlopsFn = Callable[[jax_core.JaxprEqn, Expression], int]


def make_model_info(
Expand Down Expand Up @@ -221,7 +226,7 @@ class _ModuleScope(NamedTuple):
# identify which computations we've already seen.
named_call_id: str

def join(self, eqn: jax.core.JaxprEqn) -> '_ModuleScope':
def join(self, eqn: jax_core.JaxprEqn) -> '_ModuleScope':
return _ModuleScope(
named_call_id=self.named_call_id + '/' + str(id(eqn)))

Expand All @@ -231,9 +236,9 @@ def _format_shape(var):


def _mark_seen(
binder_idx: dict[jax.core.Var, int],
binder_idx: dict[jax_core.Var, int],
seen: set[str],
var: jax.core.Var,
var: jax_core.Var,
scope: _ModuleScope,
) -> bool:
"""Marks a variable as seen. Returns True if it was not previously seen."""
Expand All @@ -250,14 +255,14 @@ def _var_sort_key(s: str):


def _var_to_str(
binder_idx: dict[jax.core.Var, int], atom: jax.core.Atom
binder_idx: dict[jax_core.Var, int], atom: Atom
) -> str:
"""Returns an atom name based on var binding order in its containing jaxpr."""
if isinstance(atom, jax.core.DropVar):
if isinstance(atom, DropVar):
return '_'
if isinstance(atom, jax.core.Literal):
if isinstance(atom, jax_core.Literal):
return str(atom)
assert isinstance(atom, jax.core.Var)
assert isinstance(atom, jax_core.Var)
n = binder_idx[atom]
s = ''
while not s or n:
Expand All @@ -267,13 +272,13 @@ def _var_to_str(


def _process_eqn(
eqn: jax.core.JaxprEqn,
eqn: jax_core.JaxprEqn,
seen: set[str],
eqns_by_output: Mapping[str, jax.core.JaxprEqn],
eqns_by_output: Mapping[str, jax_core.JaxprEqn],
compute_flops: ComputeFlopsFn | None,
scope: _ModuleScope,
module: Module,
binder_idx: dict[jax.core.Var, int],
binder_idx: dict[jax_core.Var, int],
) -> int | None:
"""Recursive walks the JaxprEqn to compute the flops it takes."""
for out_var in eqn.outvars:
Expand Down Expand Up @@ -369,7 +374,7 @@ def _process_eqn(
module.expressions.append(expression)

for var in eqn.invars:
if isinstance(var, jax.core.Literal):
if isinstance(var, jax_core.Literal):
continue

key = _var_to_str(binder_idx, var)
Expand All @@ -391,14 +396,14 @@ def _process_eqn(


def _process_jaxpr(
jaxpr: jax.core.Jaxpr,
jaxpr: jax_core.Jaxpr,
compute_flops: ComputeFlopsFn | None,
scope: _ModuleScope,
seen: set[str],
module: Module,
) -> int | None:
"""Computes the flops used for a JAX expression, tracking module scope."""
if isinstance(jaxpr, jax.core.ClosedJaxpr):
if isinstance(jaxpr, jax_core.ClosedJaxpr):
return _process_jaxpr(jaxpr.jaxpr, compute_flops, scope, seen, module)

# Label variables by the order in which they're introduced.
Expand All @@ -420,7 +425,7 @@ def _process_jaxpr(
# Recursively walk the computation graph.
flops = None if compute_flops is None else 0
for var in jaxpr.outvars:
if (isinstance(var, jax.core.Var) and
if (isinstance(var, jax_core.Var) and
_mark_seen(binder_idx, seen, var, scope)):
f = _process_eqn(eqns_by_output[_var_to_str(binder_idx, var)], seen,
eqns_by_output, compute_flops, scope, module, binder_idx)
Expand Down

0 comments on commit 01e149b

Please sign in to comment.