From bfb80312a8f74dccac8adb972a2052a61ad8e85a Mon Sep 17 00:00:00 2001 From: ASEM000 Date: Sat, 13 Apr 2024 21:02:04 +0900 Subject: [PATCH] add dtype pp rule --- CHANGELOG.md | 2 +- docs/index.rst | 4 +--- sepes/_src/tree_pprint.py | 7 +++++++ 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f853bc5..c0230a2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -39,7 +39,7 @@ assert is_masked(masked_tree[0]) is True ``` -- Add `dataclasses` rule for `tree_{repr,str}` +- Add `dataclasses`, `dtype` rule for `tree_{repr,str}` ## V0.12 diff --git a/docs/index.rst b/docs/index.rst index caf145a..da44229 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -31,6 +31,4 @@ Apache2.0 License. Indices ======= -* :ref:`genindex` - - +* :ref:`genindex` \ No newline at end of file diff --git a/sepes/_src/tree_pprint.py b/sepes/_src/tree_pprint.py index af9bb0c..f7ce9ff 100644 --- a/sepes/_src/tree_pprint.py +++ b/sepes/_src/tree_pprint.py @@ -669,3 +669,10 @@ def _(node: Any) -> str: return global_info shard_info = tree_repr(ShapeDTypePP(shard_shape, dtype)) return f"G:{global_info}\nS:{shard_info}" + + @tree_str.def_type(type(jax.numpy.float32)) + @tree_repr.def_type(type(jax.numpy.float32)) + def _(node, **spec: Unpack[PPSpec]) -> str: + out = str(node.dtype) + out = out.replace("float", "f").replace("int", "i").replace("complex", "c") + return f"jax.numpy.{out}"