diff --git a/jax/_src/dtypes.py b/jax/_src/dtypes.py index 52cb3d87bbda..86e86e89e9d5 100644 --- a/jax/_src/dtypes.py +++ b/jax/_src/dtypes.py @@ -476,9 +476,17 @@ def _issubdtype_cached(a: type | np.dtype | ExtendedDType, np.dtype('complex64'), np.dtype('complex128'), ] -_jax_types = _bool_types + _int_types + _float_types + _complex_types -_jax_dtype_set = {float0, *_bool_types, *_int_types, *_float_types, *_complex_types} +_string_types: list[JAXType] = [] +try: + import numpy.dtypes as np_dtypes + if hasattr(np_dtypes, 'StringDType'): + _string_types: list[JAXType] = [np_dtypes.StringDType()] # type: ignore +except ImportError: + np_dtypes = None # type: ignore + +_jax_types = _bool_types + _int_types + _float_types + _complex_types + _string_types +_jax_dtype_set = {float0, *_bool_types, *_int_types, *_float_types, *_complex_types, *_string_types} _dtype_kinds: dict[str, set] = { 'bool': {*_bool_types},