What's the best way to get the floating point dtype? #16139
NeilGirdhar asked this question in General
Right now, I'm doing:

    from functools import cache

    import jax.numpy as jnp
    from jax.typing import DTypeLike

    @cache
    def int_dtype() -> DTypeLike:
        return jnp.asarray(0).dtype

    @cache
    def float_dtype() -> DTypeLike:
        return jnp.empty(1).dtype

    @cache
    def complex_dtype() -> DTypeLike:
        return jnp.asarray(1j).dtype
Answered by jakevdp on May 25, 2023
Use dtypes.canonicalize_dtype:

    from jax import dtypes

    default_int = dtypes.canonicalize_dtype(int)
    default_float = dtypes.canonicalize_dtype(float)
    default_complex = dtypes.canonicalize_dtype(complex)

In general you should be careful about caching these values, because, e.g., jax.experimental.enable_x64 can enable 64-bit values in a local context, after which a cached dtype would be stale.

Answer selected by NeilGirdhar.
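A minimal sketch of the caching pitfall, assuming a JAX version that still exposes jax.experimental.enable_x64 as a context manager, as the answer describes. The helper names cached_float_dtype and default_float_dtype are hypothetical: the first mirrors the question's @cache-d float_dtype(), and the second calls dtypes.canonicalize_dtype on every use.

```python
from functools import cache

import jax.numpy as jnp
from jax import dtypes
from jax.experimental import enable_x64


@cache
def cached_float_dtype():
    # Hypothetical helper mirroring the question's float_dtype():
    # the result of the first call is memoized forever.
    return jnp.empty(1).dtype


def default_float_dtype():
    # Hypothetical helper: not cached, so it reflects the x64 setting
    # that is active at the moment of the call.
    return dtypes.canonicalize_dtype(float)


# Outside any context the default precision applies
# (float32 unless JAX_ENABLE_X64 is set in the environment).
outer_cached = cached_float_dtype()
outer_fresh = default_float_dtype()

with enable_x64():
    # 64-bit values are enabled only inside this block.
    inner_cached = cached_float_dtype()  # stale: returns the memoized outer value
    inner_fresh = default_float_dtype()  # float64

print(outer_cached, outer_fresh, inner_cached, inner_fresh)
```

Calling dtypes.canonicalize_dtype at the point of use, rather than memoizing its result, keeps the helper in step with whatever x64 setting is active, including a temporary one set by a local context.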