
What's the best way to get the floating point dtype? #16139

Answered by jakevdp
NeilGirdhar asked this question in General

Use dtypes.canonicalize_dtype:

```python
from jax import dtypes
default_int = dtypes.canonicalize_dtype(int)
default_float = dtypes.canonicalize_dtype(float)
default_complex = dtypes.canonicalize_dtype(complex)
```
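A quick sketch to check what these resolve to in your own session. It assumes only `jax.dtypes.canonicalize_dtype` and `jax.numpy.zeros`. The cross-check with `jnp.zeros` relies on the fact that JAX array constructors with no explicit `dtype` produce the canonical default float.

```python
import jax.numpy as jnp
from jax import dtypes

default_int = dtypes.canonicalize_dtype(int)
default_float = dtypes.canonicalize_dtype(float)
default_complex = dtypes.canonicalize_dtype(complex)

# With 64-bit mode off (JAX's default) these are int32, float32, complex64;
# with jax_enable_x64 on they become int64, float64, complex128.
print(default_int, default_float, default_complex)

# A constructor with no explicit dtype agrees with the canonical default float.
assert jnp.zeros(3).dtype == default_float
```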

In general, be careful about caching these values: `jax.experimental.enable_x64`, for example, can enable 64-bit values within a local context, so a dtype computed earlier may no longer match the one currently in effect.
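A minimal sketch of the caching pitfall. It assumes a JAX version where `jax.experimental.enable_x64` is still available as a context manager, and `default_float_now` is a hypothetical helper name used only for illustration. The module-level constant is frozen at import time, while the function re-queries `canonicalize_dtype` on every call and so follows the active x64 setting.

```python
import numpy as np
from jax import dtypes
from jax.experimental import enable_x64  # assumed available in your JAX version

# Computed once at import time; float32 if x64 was off when this ran.
CACHED_FLOAT = dtypes.canonicalize_dtype(float)

def default_float_now():
    # Hypothetical helper: recomputed on each call, so it reflects the
    # x64 setting that is active at the moment it is called.
    return dtypes.canonicalize_dtype(float)

with enable_x64():
    print(CACHED_FLOAT)         # unchanged: still whatever it was at import
    print(default_float_now())  # float64 inside the context
```

Calling the helper wherever the dtype is needed costs little, and it avoids silently mixing precisions when code runs under a different x64 setting than the one active at import.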

Answer selected by NeilGirdhar