Skip to content

Commit

Permalink
add dtype pp rule
Browse files Browse the repository at this point in the history
  • Loading branch information
ASEM000 committed Apr 13, 2024
1 parent cbd8a18 commit bfb8031
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 4 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
assert is_masked(masked_tree[0]) is True
```

- Add `dataclasses` rule for `tree_{repr,str}`
- Add `dataclasses`, `dtype` rule for `tree_{repr,str}`

## V0.12

Expand Down
4 changes: 1 addition & 3 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,4 @@ Apache2.0 License.
Indices
=======

* :ref:`genindex`


* :ref:`genindex`
7 changes: 7 additions & 0 deletions sepes/_src/tree_pprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -669,3 +669,10 @@ def _(node: Any) -> str:
return global_info
shard_info = tree_repr(ShapeDTypePP(shard_shape, dtype))
return f"G:{global_info}\nS:{shard_info}"

@tree_str.def_type(type(jax.numpy.float32))
@tree_repr.def_type(type(jax.numpy.float32))
def _(node, **spec: Unpack[PPSpec]) -> str:
out = str(node.dtype)
out = out.replace("float", "f").replace("int", "i").replace("complex", "c")
return f"jax.numpy.{out}"

0 comments on commit bfb8031

Please sign in to comment.