Skip to content

Commit e62f724

Browse files
hawkinspGoogle-ML-Automation
authored andcommitted
Set __module__ on jax.ShapeDtypeStruct.
PiperOrigin-RevId: 826026748
1 parent 0772c77 commit e62f724

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

jax/_src/core.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050
tuple_delete, cache,
5151
HashableFunction, HashableWrapper, weakref_lru_cache,
5252
partition_list, StrictABCMeta, foreach,
53-
weakref_cache_key_types)
53+
weakref_cache_key_types, set_module)
5454
import jax._src.pretty_printer as pp
5555
from jax._src.named_sharding import NamedSharding
5656
from jax._src.sharding import Sharding
@@ -3653,6 +3653,7 @@ def _check_map(ctx_factory, prim, in_avals, params):
36533653

36543654
# ------------------- ShapeDtypeStruct -------------------
36553655

3656+
@set_module("jax")
36563657
class ShapeDtypeStruct:
36573658
"""A container for the shape, dtype, and other static attributes of an array.
36583659

0 commit comments

Comments
 (0)