From 587753ca9237de42522b628d32dd32edf19ea434 Mon Sep 17 00:00:00 2001 From: ASEM000 Date: Thu, 21 Mar 2024 01:46:03 +0900 Subject: [PATCH] print tracer type in tree repr/str --- sepes/_src/tree_base.py | 28 ++++++++++++++-------------- sepes/_src/tree_pprint.py | 24 ++---------------------- sepes/_src/tree_util.py | 3 ++- tests/test_pprint.py | 4 ++-- 4 files changed, 20 insertions(+), 39 deletions(-) diff --git a/sepes/_src/tree_base.py b/sepes/_src/tree_base.py index d53dcb4..f624363 100644 --- a/sepes/_src/tree_base.py +++ b/sepes/_src/tree_base.py @@ -155,11 +155,11 @@ class TreeClass(metaclass=TreeClassMeta): the tree. for example: >>> @sp.leafwise - ... @sp.autoinit ... class Tree(sp.TreeClass): - ... a:int = 1 - ... b:float = 2.0 - >>> tree = Tree() + ... def __init__(self, a:int, b:float): + ... self.a = a + ... self.b = b + >>> tree = Tree(a=1, b=2.0) >>> tree + 1 # will add 1 to each leaf Tree(a=2, b=3.0) @@ -168,11 +168,11 @@ class TreeClass(metaclass=TreeClassMeta): used to ``get``, ``set``, or ``apply`` a function to a leaf or a group of leaves using ``leaf`` name, index or by a boolean mask. - >>> @sp.autoinit - ... class Tree(sp.TreeClass): - ... a:int = 1 - ... b:float = 2.0 - >>> tree = Tree() + >>> class Tree(sp.TreeClass): + ... def __init__(self, a:int, b:float): + ... self.a = a + ... self.b = b + >>> tree = Tree(a=1, b=2.0) >>> tree.at["a"].get() Tree(a=1, b=None) >>> tree.at[0].get() @@ -274,14 +274,14 @@ def at(self) -> AtIndexer[Self]: Example: >>> import sepes as sp - >>> @sp.autoinit - ... class Tree(sp.TreeClass): - ... a: int = 1 - ... b: float = 2.0 + >>> class Tree(sp.TreeClass): + ... def __init__(self, a:int, b:float): + ... self.a = a + ... self.b = b ... def add(self, x: int) -> int: ... self.a += x ... return self.a - >>> tree = Tree() + >>> tree = Tree(a=1, b=2.0) >>> tree.at["a"].get() Tree(a=1, b=None) >>> tree.at["a"].set(100) diff --git a/sepes/_src/tree_pprint.py b/sepes/_src/tree_pprint.py index 198a266..923196d 100644 --- a/sepes/_src/tree_pprint.py +++ b/sepes/_src/tree_pprint.py @@ -182,7 +182,7 @@ def _(node: ft.partial, **spec: Unpack[PPSpec]) -> str: func = tree_str.pp(node.func, **spec) args = tree_str.pps(tree_str.pp, node.args, **spec) keywords = tree_str.pps(tree_str.kv_pp, node.keywords, **spec) - return f"Partial(" + ",".join([func, args, keywords]) + ")" + return "partial(" + ",".join([func, args, keywords]) + ")" @tree_str.def_type(list) @@ -448,26 +448,6 @@ def tree_summary( │Σ │list │6 │12.00B│ └─────────┴──────┴─────┴──────┘ - Example: - Set custom type display for ``jax`` jaxprs - - >>> import jax - >>> import sepes as sp - >>> ClosedJaxprType = type(jax.make_jaxpr(lambda x: x)(1)) - >>> @sp.tree_summary.def_type(ClosedJaxprType) - ... def _(expr: ClosedJaxprType) -> str: - ... jaxpr = expr.jaxpr - ... return f"Jaxpr({jaxpr.invars}, {jaxpr.outvars})" - >>> def func(x, y): - ... return x - >>> jaxpr = jax.make_jaxpr(func)(1, 2) - >>> print(sp.tree_summary(jaxpr)) - ┌────┬──────────────────┬─────┬────┐ - │Name│Type │Count│Size│ - ├────┼──────────────────┼─────┼────┤ - │Σ │Jaxpr([a, b], [a])│1 │ │ - └────┴──────────────────┴─────┴────┘ - Example: Display flops of a function in tree summary @@ -628,7 +608,7 @@ def _(node, **spec: Unpack[PPSpec]) -> str: shape = node.aval.shape dtype = node.aval.dtype string = tree_repr.dispatch(ShapeDTypePP(shape, dtype), **spec) - return f"Tracer({string})" + return f"{type(node).__name__}({string})" # handle the sharding info if it is sharded @tree_summary.def_type(jax.Array) diff --git a/sepes/_src/tree_util.py b/sepes/_src/tree_util.py index bbf705b..d21c685 100644 --- a/sepes/_src/tree_util.py +++ b/sepes/_src/tree_util.py @@ -226,7 +226,8 @@ def leafwise(klass: type[T]) -> type[T]: The decorated class. Example: - >>> # use ``numpy`` functions on :class:`TreeClass`` classes decorated with ``leafwise`` + Use ``numpy`` functions on :class:`TreeClass`` classes decorated with :func:`leafwise` + >>> import sepes as sp >>> import jax.numpy as jnp >>> @sp.leafwise diff --git a/tests/test_pprint.py b/tests/test_pprint.py index 018300b..0d29cd6 100644 --- a/tests/test_pprint.py +++ b/tests/test_pprint.py @@ -243,9 +243,9 @@ def test_tracer_repr(): @jax.jit def f(x): out = tree_repr(x) - assert out == "Tracer(f32[10,10])" + assert out == "DynamicJaxprTracer(f32[10,10])" out = tree_str(x) - assert out == "Tracer(f32[10,10])" + assert out == "DynamicJaxprTracer(f32[10,10])" return x f(jax.numpy.ones((10, 10)))