Skip to content

Commit

Permalink
feat: Updated jax version to 0.4.21 (#27442)
Browse files Browse the repository at this point in the history
  • Loading branch information
Sai-Suraj-27 authored Dec 7, 2023
1 parent bd8435c commit 67566f7
Show file tree
Hide file tree
Showing 24 changed files with 112 additions and 112 deletions.
24 changes: 12 additions & 12 deletions ivy/functional/backends/jax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def _array_unflatten(aux_data, children):

# update these to add new dtypes
valid_dtypes = {
"0.4.20 and below": (
"0.4.21 and below": (
ivy.int8,
ivy.int16,
ivy.int32,
Expand All @@ -120,7 +120,7 @@ def _array_unflatten(aux_data, children):
)
}
valid_numeric_dtypes = {
"0.4.20 and below": (
"0.4.21 and below": (
ivy.int8,
ivy.int16,
ivy.int32,
Expand All @@ -139,7 +139,7 @@ def _array_unflatten(aux_data, children):
}

valid_int_dtypes = {
"0.4.20 and below": (
"0.4.21 and below": (
ivy.int8,
ivy.int16,
ivy.int32,
Expand All @@ -152,12 +152,12 @@ def _array_unflatten(aux_data, children):
}

valid_uint_dtypes = {
"0.4.20 and below": (ivy.uint8, ivy.uint16, ivy.uint32, ivy.uint64)
"0.4.21 and below": (ivy.uint8, ivy.uint16, ivy.uint32, ivy.uint64)
}
valid_float_dtypes = {
"0.4.20 and below": (ivy.bfloat16, ivy.float16, ivy.float32, ivy.float64)
"0.4.21 and below": (ivy.bfloat16, ivy.float16, ivy.float32, ivy.float64)
}
valid_complex_dtypes = {"0.4.20 and below": (ivy.complex64, ivy.complex128)}
valid_complex_dtypes = {"0.4.21 and below": (ivy.complex64, ivy.complex128)}


# leave these untouched
Expand All @@ -172,12 +172,12 @@ def _array_unflatten(aux_data, children):
# invalid data types

# update these to add new dtypes
invalid_dtypes = {"0.4.20 and below": ()}
invalid_numeric_dtypes = {"0.4.20 and below": ()}
invalid_int_dtypes = {"0.4.20 and below": ()}
invalid_float_dtypes = {"0.4.20 and below": ()}
invalid_uint_dtypes = {"0.4.20 and below": ()}
invalid_complex_dtypes = {"0.4.20 and below": ()}
invalid_dtypes = {"0.4.21 and below": ()}
invalid_numeric_dtypes = {"0.4.21 and below": ()}
invalid_int_dtypes = {"0.4.21 and below": ()}
invalid_float_dtypes = {"0.4.21 and below": ()}
invalid_uint_dtypes = {"0.4.21 and below": ()}
invalid_complex_dtypes = {"0.4.21 and below": ()}

# leave these untouched
invalid_dtypes = _dtype_from_version(invalid_dtypes, backend_version)
Expand Down
28 changes: 14 additions & 14 deletions ivy/functional/backends/jax/elementwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def atanh(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray:
return jnp.arctanh(x)


@with_unsupported_dtypes({"0.4.20 and below": ("complex",)}, backend_version)
@with_unsupported_dtypes({"0.4.21 and below": ("complex",)}, backend_version)
def bitwise_and(
x1: Union[int, JaxArray],
x2: Union[int, JaxArray],
Expand All @@ -84,14 +84,14 @@ def bitwise_and(
return jnp.bitwise_and(x1, x2)


@with_unsupported_dtypes({"0.4.20 and below": ("complex",)}, backend_version)
@with_unsupported_dtypes({"0.4.21 and below": ("complex",)}, backend_version)
def bitwise_invert(
x: Union[int, JaxArray], /, *, out: Optional[JaxArray] = None
) -> JaxArray:
return jnp.bitwise_not(x)


@with_unsupported_dtypes({"0.4.20 and below": ("complex",)}, backend_version)
@with_unsupported_dtypes({"0.4.21 and below": ("complex",)}, backend_version)
def bitwise_left_shift(
x1: Union[int, JaxArray],
x2: Union[int, JaxArray],
Expand All @@ -103,7 +103,7 @@ def bitwise_left_shift(
return jnp.left_shift(x1, x2)


@with_unsupported_dtypes({"0.4.20 and below": ("complex",)}, backend_version)
@with_unsupported_dtypes({"0.4.21 and below": ("complex",)}, backend_version)
def bitwise_or(
x1: Union[int, JaxArray],
x2: Union[int, JaxArray],
Expand All @@ -115,7 +115,7 @@ def bitwise_or(
return jnp.bitwise_or(x1, x2)


@with_unsupported_dtypes({"0.4.20 and below": ("complex",)}, backend_version)
@with_unsupported_dtypes({"0.4.21 and below": ("complex",)}, backend_version)
def bitwise_right_shift(
x1: Union[int, JaxArray],
x2: Union[int, JaxArray],
Expand All @@ -127,7 +127,7 @@ def bitwise_right_shift(
return jnp.right_shift(x1, x2)


@with_unsupported_dtypes({"0.4.20 and below": ("complex",)}, backend_version)
@with_unsupported_dtypes({"0.4.21 and below": ("complex",)}, backend_version)
def bitwise_xor(
x1: Union[int, JaxArray],
x2: Union[int, JaxArray],
Expand All @@ -139,7 +139,7 @@ def bitwise_xor(
return jnp.bitwise_xor(x1, x2)


@with_unsupported_dtypes({"0.4.20 and below": ("complex",)}, backend_version)
@with_unsupported_dtypes({"0.4.21 and below": ("complex",)}, backend_version)
def ceil(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray:
if "int" in str(x.dtype):
return x
Expand All @@ -151,7 +151,7 @@ def cos(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray:
return jnp.cos(x)


@with_unsupported_dtypes({"0.4.20 and below": ("float16",)}, backend_version)
@with_unsupported_dtypes({"0.4.21 and below": ("float16",)}, backend_version)
def cosh(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray:
return jnp.cosh(x)

Expand Down Expand Up @@ -191,15 +191,15 @@ def expm1(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray:
return jnp.expm1(x)


@with_unsupported_dtypes({"0.4.20 and below": ("complex",)}, backend_version)
@with_unsupported_dtypes({"0.4.21 and below": ("complex",)}, backend_version)
def floor(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray:
if "int" in str(x.dtype):
return x
else:
return jnp.floor(x)


@with_unsupported_dtypes({"0.4.20 and below": ("complex",)}, backend_version)
@with_unsupported_dtypes({"0.4.21 and below": ("complex",)}, backend_version)
def floor_divide(
x1: Union[float, JaxArray],
x2: Union[float, JaxArray],
Expand Down Expand Up @@ -427,7 +427,7 @@ def pow(
return jnp.power(x1, x2)


@with_unsupported_dtypes({"0.4.20 and below": ("complex",)}, backend_version)
@with_unsupported_dtypes({"0.4.21 and below": ("complex",)}, backend_version)
def remainder(
x1: Union[float, JaxArray],
x2: Union[float, JaxArray],
Expand Down Expand Up @@ -524,7 +524,7 @@ def tanh(
return jnp.tanh(x)


@with_unsupported_dtypes({"0.4.20 and below": ("complex",)}, backend_version)
@with_unsupported_dtypes({"0.4.21 and below": ("complex",)}, backend_version)
def trunc(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray:
if "int" in str(x.dtype):
return x
Expand Down Expand Up @@ -564,7 +564,7 @@ def angle(
# ------#


@with_unsupported_dtypes({"0.4.20 and below": ("complex",)}, backend_version)
@with_unsupported_dtypes({"0.4.21 and below": ("complex",)}, backend_version)
def erf(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray:
return jax.scipy.special.erf(x)

Expand Down Expand Up @@ -615,7 +615,7 @@ def isreal(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray:
return jnp.isreal(x)


@with_unsupported_dtypes({"0.4.20 and below": ("complex",)}, backend_version)
@with_unsupported_dtypes({"0.4.21 and below": ("complex",)}, backend_version)
def fmod(
x1: JaxArray,
x2: JaxArray,
Expand Down
2 changes: 1 addition & 1 deletion ivy/functional/backends/jax/experimental/elementwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def sinc(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray:


@with_supported_dtypes(
{"0.4.20 and below": ("float16", "float32", "float64")}, backend_version
{"0.4.21 and below": ("float16", "float32", "float64")}, backend_version
)
def lgamma(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray:
return jlax.lgamma(x)
Expand Down
6 changes: 3 additions & 3 deletions ivy/functional/backends/jax/experimental/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,7 @@ def avg_pool3d(
return res


@with_supported_dtypes({"0.4.20 and below": ("float32", "float64")}, backend_version)
@with_supported_dtypes({"0.4.21 and below": ("float32", "float64")}, backend_version)
def dct(
x: JaxArray,
/,
Expand Down Expand Up @@ -822,7 +822,7 @@ def ifftn(


@with_unsupported_dtypes(
{"0.4.20 and below": ("bfloat16", "float16", "complex")}, backend_version
{"0.4.21 and below": ("bfloat16", "float16", "complex")}, backend_version
)
def embedding(
weights: JaxArray,
Expand Down Expand Up @@ -870,7 +870,7 @@ def rfft(
return ret


@with_unsupported_dtypes({"0.4.20 and below": ("float16", "complex")}, backend_version)
@with_unsupported_dtypes({"0.4.21 and below": ("float16", "complex")}, backend_version)
def rfftn(
x: JaxArray,
s: Optional[Sequence[int]] = None,
Expand Down
2 changes: 1 addition & 1 deletion ivy/functional/backends/jax/experimental/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def beta(
return jax.random.beta(rng_input, a, b, shape, dtype)


@with_unsupported_dtypes({"0.4.20 and below": ("bfloat16",)}, backend_version)
@with_unsupported_dtypes({"0.4.21 and below": ("bfloat16",)}, backend_version)
def gamma(
alpha: Union[float, JaxArray],
beta: Union[float, JaxArray],
Expand Down
2 changes: 1 addition & 1 deletion ivy/functional/backends/jax/experimental/sorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def invert_permutation(


# lexsort
@with_unsupported_dtypes({"0.4.20 and below": ("bfloat16",)}, backend_version)
@with_unsupported_dtypes({"0.4.21 and below": ("bfloat16",)}, backend_version)
def lexsort(
keys: JaxArray,
/,
Expand Down
6 changes: 3 additions & 3 deletions ivy/functional/backends/jax/experimental/statistical.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@


@with_unsupported_dtypes(
{"0.4.20 and below": ("bfloat16",)},
{"0.4.21 and below": ("bfloat16",)},
backend_version,
)
def histogram(
Expand Down Expand Up @@ -121,7 +121,7 @@ def histogram(


@with_unsupported_dtypes(
{"0.4.20 and below": ("complex64", "complex128")}, backend_version
{"0.4.21 and below": ("complex64", "complex128")}, backend_version
)
def median(
input: JaxArray,
Expand Down Expand Up @@ -406,7 +406,7 @@ def __get_index(lst, indices=None, prefix=None):

@with_unsupported_dtypes(
{
"0.4.20 and below": (
"0.4.21 and below": (
"bfloat16",
"bool",
)
Expand Down
4 changes: 2 additions & 2 deletions ivy/functional/backends/jax/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def array_equal(x0: JaxArray, x1: JaxArray, /) -> bool:
return bool(jnp.array_equal(x0, x1))


@with_unsupported_dtypes({"0.4.20 and below": ("bfloat16",)}, backend_version)
@with_unsupported_dtypes({"0.4.21 and below": ("bfloat16",)}, backend_version)
def to_numpy(x: JaxArray, /, *, copy: bool = True) -> np.ndarray:
if copy:
return np.array(_to_array(x))
Expand Down Expand Up @@ -420,7 +420,7 @@ def vmap(
)


@with_unsupported_dtypes({"0.4.20 and below": ("float16", "bfloat16")}, backend_version)
@with_unsupported_dtypes({"0.4.21 and below": ("float16", "bfloat16")}, backend_version)
def isin(
elements: JaxArray,
test_elements: JaxArray,
Expand Down
Loading

0 comments on commit 67566f7

Please sign in to comment.