Skip to content

Commit

Permalink
In progress experimention. Add StringDType to JAX's supported types.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 707662268
  • Loading branch information
Google-ML-Automation committed Dec 20, 2024
1 parent 4216f8f commit 6752b2f
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions jax/_src/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,9 +476,17 @@ def _issubdtype_cached(a: type | np.dtype | ExtendedDType,
np.dtype('complex64'),
np.dtype('complex128'),
]
_jax_types = _bool_types + _int_types + _float_types + _complex_types
_jax_dtype_set = {float0, *_bool_types, *_int_types, *_float_types, *_complex_types}

_string_types: list[JAXType] = []
try:
import numpy.dtypes as np_dtypes
if hasattr(np_dtypes, 'StringDType'):
_string_types: list[JAXType] = [np_dtypes.StringDType()] # type: ignore
except ImportError:
np_dtypes = None # type: ignore

_jax_types = _bool_types + _int_types + _float_types + _complex_types + _string_types
_jax_dtype_set = {float0, *_bool_types, *_int_types, *_float_types, *_complex_types, *_string_types}

_dtype_kinds: dict[str, set] = {
'bool': {*_bool_types},
Expand Down

0 comments on commit 6752b2f

Please sign in to comment.