Skip to content

Commit

Permalink
print tracer type in tree repr/str
Browse files Browse the repository at this point in the history
  • Loading branch information
ASEM000 committed Mar 20, 2024
1 parent 6e09014 commit 587753c
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 39 deletions.
28 changes: 14 additions & 14 deletions sepes/_src/tree_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,11 +155,11 @@ class TreeClass(metaclass=TreeClassMeta):
the tree. for example:
>>> @sp.leafwise
... @sp.autoinit
... class Tree(sp.TreeClass):
... a:int = 1
... b:float = 2.0
>>> tree = Tree()
... def __init__(self, a:int, b:float):
... self.a = a
... self.b = b
>>> tree = Tree(a=1, b=2.0)
>>> tree + 1 # will add 1 to each leaf
Tree(a=2, b=3.0)
Expand All @@ -168,11 +168,11 @@ class TreeClass(metaclass=TreeClassMeta):
used to ``get``, ``set``, or ``apply`` a function to a leaf or a group of
leaves using ``leaf`` name, index or by a boolean mask.
>>> @sp.autoinit
... class Tree(sp.TreeClass):
... a:int = 1
... b:float = 2.0
>>> tree = Tree()
>>> class Tree(sp.TreeClass):
... def __init__(self, a:int, b:float):
... self.a = a
... self.b = b
>>> tree = Tree(a=1, b=2.0)
>>> tree.at["a"].get()
Tree(a=1, b=None)
>>> tree.at[0].get()
Expand Down Expand Up @@ -274,14 +274,14 @@ def at(self) -> AtIndexer[Self]:
Example:
>>> import sepes as sp
>>> @sp.autoinit
... class Tree(sp.TreeClass):
... a: int = 1
... b: float = 2.0
>>> class Tree(sp.TreeClass):
... def __init__(self, a:int, b:float):
... self.a = a
... self.b = b
... def add(self, x: int) -> int:
... self.a += x
... return self.a
>>> tree = Tree()
>>> tree = Tree(a=1, b=2.0)
>>> tree.at["a"].get()
Tree(a=1, b=None)
>>> tree.at["a"].set(100)
Expand Down
24 changes: 2 additions & 22 deletions sepes/_src/tree_pprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def _(node: ft.partial, **spec: Unpack[PPSpec]) -> str:
func = tree_str.pp(node.func, **spec)
args = tree_str.pps(tree_str.pp, node.args, **spec)
keywords = tree_str.pps(tree_str.kv_pp, node.keywords, **spec)
return f"Partial(" + ",".join([func, args, keywords]) + ")"
return "partial(" + ",".join([func, args, keywords]) + ")"


@tree_str.def_type(list)
Expand Down Expand Up @@ -448,26 +448,6 @@ def tree_summary(
│Σ │list │6 │12.00B│
└─────────┴──────┴─────┴──────┘
Example:
Set custom type display for ``jax`` jaxprs
>>> import jax
>>> import sepes as sp
>>> ClosedJaxprType = type(jax.make_jaxpr(lambda x: x)(1))
>>> @sp.tree_summary.def_type(ClosedJaxprType)
... def _(expr: ClosedJaxprType) -> str:
... jaxpr = expr.jaxpr
... return f"Jaxpr({jaxpr.invars}, {jaxpr.outvars})"
>>> def func(x, y):
... return x
>>> jaxpr = jax.make_jaxpr(func)(1, 2)
>>> print(sp.tree_summary(jaxpr))
┌────┬──────────────────┬─────┬────┐
│Name│Type │Count│Size│
├────┼──────────────────┼─────┼────┤
│Σ │Jaxpr([a, b], [a])│1 │ │
└────┴──────────────────┴─────┴────┘
Example:
Display flops of a function in tree summary
Expand Down Expand Up @@ -628,7 +608,7 @@ def _(node, **spec: Unpack[PPSpec]) -> str:
shape = node.aval.shape
dtype = node.aval.dtype
string = tree_repr.dispatch(ShapeDTypePP(shape, dtype), **spec)
return f"Tracer({string})"
return f"{type(node).__name__}({string})"

# handle the sharding info if it is sharded
@tree_summary.def_type(jax.Array)
Expand Down
3 changes: 2 additions & 1 deletion sepes/_src/tree_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,8 @@ def leafwise(klass: type[T]) -> type[T]:
The decorated class.
Example:
>>> # use ``numpy`` functions on :class:`TreeClass`` classes decorated with ``leafwise``
Use ``numpy`` functions on :class:`TreeClass`` classes decorated with :func:`leafwise`
>>> import sepes as sp
>>> import jax.numpy as jnp
>>> @sp.leafwise
Expand Down
4 changes: 2 additions & 2 deletions tests/test_pprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,9 +243,9 @@ def test_tracer_repr():
@jax.jit
def f(x):
out = tree_repr(x)
assert out == "Tracer(f32[10,10])"
assert out == "DynamicJaxprTracer(f32[10,10])"
out = tree_str(x)
assert out == "Tracer(f32[10,10])"
assert out == "DynamicJaxprTracer(f32[10,10])"
return x

f(jax.numpy.ones((10, 10)))
Expand Down

0 comments on commit 587753c

Please sign in to comment.