
How to test if a value is a JAX dtype? #25497

Answered by jakevdp
samuela asked this question in Q&A

You can use isinstance(thing, jnp.dtype) if you have a dtype; for example:

import jax.numpy as jnp

dt = jnp.dtype('int8')
isinstance(dt, jnp.dtype)  # True

Where it gets confusing is that jnp.int8 is not a dtype; it is a scalar constructor. JAX inherits this behavior from NumPy, where numpy.dtype('int8') is a dtype, while numpy.int8 is a scalar constructor that is not an instance of dtype.
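A small sketch of the distinction described above, assuming a standard JAX and NumPy install: a dtype object passes the isinstance check, while the scalar constructors jnp.int8 and numpy.int8 do not. Calling jnp.int8 produces a value (a zero-dimensional JAX array) rather than a dtype.

```python
import numpy as np
import jax.numpy as jnp

# A dtype object is an instance of jnp.dtype.
dt = jnp.dtype('int8')
print(isinstance(dt, jnp.dtype))        # True

# Scalar constructors are not dtype instances, in NumPy or in JAX.
print(isinstance(np.int8, np.dtype))    # False
print(isinstance(jnp.int8, jnp.dtype))  # False

# Calling a scalar constructor builds a value; the value carries a dtype.
x = jnp.int8(3)
print(x.dtype)                          # int8
```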

If you want to work explicitly with dtypes, you can always convert scalar constructors to dtypes using something like jnp.dtype(jnp.int8); if that doesn't help, perhaps you could share more about the problem you're trying to solve.
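One way to apply the conversion suggested above is to normalize any "dtype-like" input (a dtype, a scalar constructor, or a dtype name) through jnp.dtype. The helper as_dtype_or_none below is hypothetical, not part of JAX; it is a minimal sketch that relies on jnp.dtype raising TypeError for inputs it cannot interpret. None is rejected explicitly because, following NumPy, jnp.dtype(None) returns the default float64 dtype instead of raising.

```python
import jax.numpy as jnp

def as_dtype_or_none(value):
    """Hypothetical helper: return jnp.dtype(value), or None if value is not dtype-like."""
    # jnp.dtype(None) silently yields float64 (NumPy behavior), so reject it up front.
    if value is None:
        return None
    try:
        # Accepts dtypes, scalar constructors like jnp.int8, and names like 'float32'.
        return jnp.dtype(value)
    except TypeError:
        # Raised for inputs that cannot be interpreted as a data type.
        return None

print(jnp.dtype(jnp.int8) == jnp.dtype('int8'))  # True
print(as_dtype_or_none(jnp.int8))                # int8
print(as_dtype_or_none('float32'))               # float32
print(as_dtype_or_none('not a dtype'))           # None
```

Whether to treat scalar constructors and strings as acceptable "dtypes" is a design choice; if only true dtype objects should pass, the plain isinstance(value, jnp.dtype) check is stricter.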

Answer selected by samuela