Skip to content

Commit

Permalink
In progress experimention. Add StringDType to JAX's supported types.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 707662268
  • Loading branch information
Google-ML-Automation committed Jan 7, 2025
1 parent f1777d5 commit bdc0aee
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 7 deletions.
22 changes: 20 additions & 2 deletions jax/_src/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,9 +478,27 @@ def _issubdtype_cached(a: type | np.dtype | ExtendedDType,
np.dtype('complex64'),
np.dtype('complex128'),
]
_jax_types = _bool_types + _int_types + _float_types + _complex_types
_jax_dtype_set = {float0, *_bool_types, *_int_types, *_float_types, *_complex_types}

_string_types: list[JAXType] = []
try:
import numpy.dtypes as np_dtypes

if hasattr(np_dtypes, 'StringDType'):
_string_types: list[JAXType] = [np_dtypes.StringDType()] # type: ignore
except ImportError:
np_dtypes = None # type: ignore

_jax_types = (
_bool_types + _int_types + _float_types + _complex_types + _string_types
)
_jax_dtype_set = {
float0,
*_bool_types,
*_int_types,
*_float_types,
*_complex_types,
*_string_types,
}

_dtype_kinds: dict[str, set] = {
'bool': {*_bool_types},
Expand Down
17 changes: 14 additions & 3 deletions tests/dtypes_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,12 @@

import numpy as np

try:
import numpy.dtypes as np_dtypes
except ImportError:
np_dtypes = None # type: ignore


import jax
from jax import numpy as jnp
from jax._src import earray
Expand Down Expand Up @@ -770,16 +776,21 @@ def g(x):


class TestPromotionTables(jtu.JaxTestCase):
# Not all types are promotable. For example, currently StringDType is not
# promotable.
promotable_types = [
x for x in dtypes._jax_types if not isinstance(x, np_dtypes.StringDType)
]

@parameterized.named_parameters(
{"testcase_name": f"_{jaxtype=}", "jaxtype": jaxtype}
for jaxtype in dtypes._jax_types + dtypes._weak_types)
for jaxtype in promotable_types + dtypes._weak_types)
def testJaxTypeFromType(self, jaxtype):
self.assertIs(dtypes._jax_type(*dtypes._dtype_and_weaktype(jaxtype)), jaxtype)

@parameterized.named_parameters(
{"testcase_name": f"_{jaxtype=}", "jaxtype": jaxtype}
for jaxtype in dtypes._jax_types + dtypes._weak_types)
for jaxtype in promotable_types + dtypes._weak_types)
def testJaxTypeFromVal(self, jaxtype):
try:
val = jaxtype(0)
Expand All @@ -789,7 +800,7 @@ def testJaxTypeFromVal(self, jaxtype):

@parameterized.named_parameters(
{"testcase_name": f"_{dtype=}", "dtype": dtype}
for dtype in dtypes._jax_types)
for dtype in promotable_types)
def testJaxTypeWeak(self, dtype):
jax_type = dtypes._jax_type(dtype, weak_type=True)
if dtypes.issubdtype(jax_type, np.complexfloating):
Expand Down
13 changes: 11 additions & 2 deletions tests/export_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,11 @@

import numpy as np

try:
import numpy.dtypes as np_dtypes
except ImportError:
np_dtypes = None # type: ignore

# ruff: noqa: F401
try:
import flatbuffers
Expand Down Expand Up @@ -414,7 +419,6 @@ def f(x1, x2):
self.assertEqual(tree_util.tree_structure(res2),
tree_util.tree_structure(res))


def test_error_wrong_intree(self):
def f(a_b_pair, *, c):
return jnp.sin(a_b_pair[0]) + jnp.cos(a_b_pair[1]) + c
Expand Down Expand Up @@ -1009,6 +1013,12 @@ def f_jax(x): # x: bool[b]
for dtype in dtypes._jax_types if dtype != np.dtype("bool")
])
def test_poly_numeric_dtypes(self, dtype=np.int32):
if hasattr(np_dtypes, "StringDType") and isinstance(
dtype, np_dtypes.StringDType
):
self.skipTest(
"StringDType is not a numeric type"
) # TODO(jmudigonda): revisit.
if str(dtype) in {"float8_e4m3b11fnuz",
"float8_e4m3fnuz",
"float8_e5m2fnuz",
Expand Down Expand Up @@ -1624,7 +1634,6 @@ def test_multi_platform_unknown_platform(self):
platforms=("tpu", "cpu", "cuda", "other"))(x)
self.assertEqual(exp.platforms, ("tpu", "cpu", "cuda", "other"))


def test_multi_platform_with_donation(self):
f = jax.jit(jnp.sin, donate_argnums=(0,))
x = np.arange(3, dtype=np.float32)
Expand Down

0 comments on commit bdc0aee

Please sign in to comment.