Skip to content

Commit

Permalink
Fix for a change in JAX's promotion rules for floor and ceil. (ke…
Browse files Browse the repository at this point in the history
…ras-team#19946)

JAX has an unreleased change jax-ml/jax#21441 that make it so that `int`s and `bool`s are no longer promoted to floats by `floor` and `ceil`.

Our current implementation follows the Numpy promotion rule, i.e. `int`s and `bool`s are promoted to floats. Therefore:
- updated unit tests, which use `jax.numpy` as the reference.
- updated our JAX implementation of `floor` and `ceil` (note that `ceil` would already cast).
  • Loading branch information
hertschuh authored Jul 1, 2024
1 parent 38ef5f2 commit f81009b
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 4 deletions.
8 changes: 6 additions & 2 deletions keras/src/backend/jax/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,8 @@ def ceil(x):
dtype = config.floatx()
else:
dtype = dtypes.result_type(x.dtype, float)
return cast(jnp.ceil(x), dtype)
x = cast(x, dtype)
return jnp.ceil(x)


def clip(x, x_min, x_max):
Expand Down Expand Up @@ -558,7 +559,10 @@ def flip(x, axis=None):
def floor(x):
x = convert_to_tensor(x)
if standardize_dtype(x.dtype) == "int64":
x = cast(x, config.floatx())
dtype = config.floatx()
else:
dtype = dtypes.result_type(x.dtype, float)
x = cast(x, dtype)
return jnp.floor(x)


Expand Down
7 changes: 5 additions & 2 deletions keras/src/ops/numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from keras.src import backend
from keras.src import testing
from keras.src.backend.common import dtypes
from keras.src.backend.common import is_int_dtype
from keras.src.backend.common import standardize_dtype
from keras.src.backend.common.keras_tensor import KerasTensor
from keras.src.ops import numpy as knp
Expand Down Expand Up @@ -5794,7 +5795,8 @@ def test_ceil(self, dtype):
x = knp.array(value, dtype=dtype)
x_jax = jnp.array(value, dtype=dtype)
expected_dtype = standardize_dtype(jnp.ceil(x_jax).dtype)
if dtype == "int64":
# Here, we follow Numpy's rule, not JAX's; ints are promoted to floats.
if dtype == "bool" or is_int_dtype(dtype):
expected_dtype = backend.floatx()

self.assertEqual(standardize_dtype(knp.ceil(x).dtype), expected_dtype)
Expand Down Expand Up @@ -6377,7 +6379,8 @@ def test_floor(self, dtype):
x = knp.ones((1,), dtype=dtype)
x_jax = jnp.ones((1,), dtype=dtype)
expected_dtype = standardize_dtype(jnp.floor(x_jax).dtype)
if dtype == "int64":
# Here, we follow Numpy's rule, not JAX's; ints are promoted to floats.
if dtype == "bool" or is_int_dtype(dtype):
expected_dtype = backend.floatx()

self.assertEqual(standardize_dtype(knp.floor(x).dtype), expected_dtype)
Expand Down

0 comments on commit f81009b

Please sign in to comment.