From 6752b2ff2d8f8e469873462b411b5e156b58b7bb Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 18 Dec 2024 14:05:49 -0800 Subject: [PATCH] In progress experimention. Add StringDType to JAX's supported types. PiperOrigin-RevId: 707662268 --- jax/_src/dtypes.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/jax/_src/dtypes.py b/jax/_src/dtypes.py index 52cb3d87bbda..86e86e89e9d5 100644 --- a/jax/_src/dtypes.py +++ b/jax/_src/dtypes.py @@ -476,9 +476,17 @@ def _issubdtype_cached(a: type | np.dtype | ExtendedDType, np.dtype('complex64'), np.dtype('complex128'), ] -_jax_types = _bool_types + _int_types + _float_types + _complex_types -_jax_dtype_set = {float0, *_bool_types, *_int_types, *_float_types, *_complex_types} +_string_types: list[JAXType] = [] +try: + import numpy.dtypes as np_dtypes + if hasattr(np_dtypes, 'StringDType'): + _string_types: list[JAXType] = [np_dtypes.StringDType()] # type: ignore +except ImportError: + np_dtypes = None # type: ignore + +_jax_types = _bool_types + _int_types + _float_types + _complex_types + _string_types +_jax_dtype_set = {float0, *_bool_types, *_int_types, *_float_types, *_complex_types, *_string_types} _dtype_kinds: dict[str, set] = { 'bool': {*_bool_types},