Is the full set of possible types which could be arguments to JAX primitives available somewhere? #16634
@patrick-kidger pointed out to me that I'm overly assuming that things will be jax.Array instances, but in practice they could be other things as well. I was hoping there would be a common ancestor of all such things, but it looks like there isn't one, e.g.:
>>> jax.Array.__mro__
(<class 'jax.Array'>, <class 'abc.ABC'>, <class 'object'>)
>>> jax.core.Token.__mro__
(<class 'jax._src.core.Token'>, <class 'object'>)
I can use [...]
Replies: 1 comment 4 replies
JAX makes frequent use of duck typing, so there is no single base class for all types that can be used in place of arrays.
If you're interested in checking whether a particular value is a valid JAX array argument, you can use isinstance(x, jax.Array). The jax.Array class is an abstract base class that has its __instancecheck__ overridden to return the right thing.
I would avoid trying to enumerate things like Token, Ref, etc., as these are considered internal implementation details that may change from release to release. If there's any operation that you cannot do via the public API, please open a bug with the details.
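As a rough illustration of the isinstance(x, jax.Array) check suggested above, here is a minimal sketch. The helper name is_jax_array, the function double, and the list seen_inside_jit are hypothetical names introduced only for this example; they are not part of JAX. The sketch assumes a reasonably recent JAX release in which jax.Array is the public array type.

```python
import jax
import jax.numpy as jnp
import numpy as np


def is_jax_array(x):
    """Hypothetical helper: True if x passes the public jax.Array check."""
    return isinstance(x, jax.Array)


concrete = is_jax_array(jnp.ones(3))  # array created by JAX
host = is_jax_array(np.ones(3))       # plain NumPy array
scalar = is_jax_array(1.0)            # Python float

seen_inside_jit = []


@jax.jit
def double(x):
    # This body runs while JAX traces the function, so x is a tracer
    # here rather than a concrete array. Record what the check reports.
    seen_inside_jit.append(is_jax_array(x))
    return x * 2


double(jnp.ones(3))

print(concrete, host, scalar, seen_inside_jit)
```

With this setup the check should accept both the concrete array and the traced value inside jit, and reject the NumPy array and the Python scalar. That is the practical upshot of the overridden __instancecheck__: code can rely on one public check instead of enumerating internal classes.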