diff --git a/ivy/functional/backends/jax/__init__.py b/ivy/functional/backends/jax/__init__.py index ca2e47c668e58..171ffe5a794d4 100644 --- a/ivy/functional/backends/jax/__init__.py +++ b/ivy/functional/backends/jax/__init__.py @@ -101,7 +101,7 @@ def _array_unflatten(aux_data, children): # update these to add new dtypes valid_dtypes = { - "0.4.20 and below": ( + "0.4.21 and below": ( ivy.int8, ivy.int16, ivy.int32, @@ -120,7 +120,7 @@ def _array_unflatten(aux_data, children): ) } valid_numeric_dtypes = { - "0.4.20 and below": ( + "0.4.21 and below": ( ivy.int8, ivy.int16, ivy.int32, @@ -139,7 +139,7 @@ def _array_unflatten(aux_data, children): } valid_int_dtypes = { - "0.4.20 and below": ( + "0.4.21 and below": ( ivy.int8, ivy.int16, ivy.int32, @@ -152,12 +152,12 @@ def _array_unflatten(aux_data, children): } valid_uint_dtypes = { - "0.4.20 and below": (ivy.uint8, ivy.uint16, ivy.uint32, ivy.uint64) + "0.4.21 and below": (ivy.uint8, ivy.uint16, ivy.uint32, ivy.uint64) } valid_float_dtypes = { - "0.4.20 and below": (ivy.bfloat16, ivy.float16, ivy.float32, ivy.float64) + "0.4.21 and below": (ivy.bfloat16, ivy.float16, ivy.float32, ivy.float64) } -valid_complex_dtypes = {"0.4.20 and below": (ivy.complex64, ivy.complex128)} +valid_complex_dtypes = {"0.4.21 and below": (ivy.complex64, ivy.complex128)} # leave these untouched @@ -172,12 +172,12 @@ def _array_unflatten(aux_data, children): # invalid data types # update these to add new dtypes -invalid_dtypes = {"0.4.20 and below": ()} -invalid_numeric_dtypes = {"0.4.20 and below": ()} -invalid_int_dtypes = {"0.4.20 and below": ()} -invalid_float_dtypes = {"0.4.20 and below": ()} -invalid_uint_dtypes = {"0.4.20 and below": ()} -invalid_complex_dtypes = {"0.4.20 and below": ()} +invalid_dtypes = {"0.4.21 and below": ()} +invalid_numeric_dtypes = {"0.4.21 and below": ()} +invalid_int_dtypes = {"0.4.21 and below": ()} +invalid_float_dtypes = {"0.4.21 and below": ()} +invalid_uint_dtypes = {"0.4.21 and below": ()} +invalid_complex_dtypes = {"0.4.21 and below": ()} # leave these untouched invalid_dtypes = _dtype_from_version(invalid_dtypes, backend_version) diff --git a/ivy/functional/backends/jax/elementwise.py b/ivy/functional/backends/jax/elementwise.py index 23e3211ccffb5..909da651c29e6 100644 --- a/ivy/functional/backends/jax/elementwise.py +++ b/ivy/functional/backends/jax/elementwise.py @@ -72,7 +72,7 @@ def atanh(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray: return jnp.arctanh(x) -@with_unsupported_dtypes({"0.4.20 and below": ("complex",)}, backend_version) +@with_unsupported_dtypes({"0.4.21 and below": ("complex",)}, backend_version) def bitwise_and( x1: Union[int, JaxArray], x2: Union[int, JaxArray], @@ -84,14 +84,14 @@ def bitwise_and( return jnp.bitwise_and(x1, x2) -@with_unsupported_dtypes({"0.4.20 and below": ("complex",)}, backend_version) +@with_unsupported_dtypes({"0.4.21 and below": ("complex",)}, backend_version) def bitwise_invert( x: Union[int, JaxArray], /, *, out: Optional[JaxArray] = None ) -> JaxArray: return jnp.bitwise_not(x) -@with_unsupported_dtypes({"0.4.20 and below": ("complex",)}, backend_version) +@with_unsupported_dtypes({"0.4.21 and below": ("complex",)}, backend_version) def bitwise_left_shift( x1: Union[int, JaxArray], x2: Union[int, JaxArray], @@ -103,7 +103,7 @@ def bitwise_left_shift( return jnp.left_shift(x1, x2) -@with_unsupported_dtypes({"0.4.20 and below": ("complex",)}, backend_version) +@with_unsupported_dtypes({"0.4.21 and below": ("complex",)}, backend_version) def bitwise_or( x1: Union[int, JaxArray], x2: Union[int, JaxArray], @@ -115,7 +115,7 @@ def bitwise_or( return jnp.bitwise_or(x1, x2) -@with_unsupported_dtypes({"0.4.20 and below": ("complex",)}, backend_version) +@with_unsupported_dtypes({"0.4.21 and below": ("complex",)}, backend_version) def bitwise_right_shift( x1: Union[int, JaxArray], x2: Union[int, JaxArray], @@ -127,7 +127,7 @@ def bitwise_right_shift( return jnp.right_shift(x1, x2) -@with_unsupported_dtypes({"0.4.20 and below": ("complex",)}, backend_version) +@with_unsupported_dtypes({"0.4.21 and below": ("complex",)}, backend_version) def bitwise_xor( x1: Union[int, JaxArray], x2: Union[int, JaxArray], @@ -139,7 +139,7 @@ def bitwise_xor( return jnp.bitwise_xor(x1, x2) -@with_unsupported_dtypes({"0.4.20 and below": ("complex",)}, backend_version) +@with_unsupported_dtypes({"0.4.21 and below": ("complex",)}, backend_version) def ceil(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray: if "int" in str(x.dtype): return x @@ -151,7 +151,7 @@ def cos(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray: return jnp.cos(x) -@with_unsupported_dtypes({"0.4.20 and below": ("float16",)}, backend_version) +@with_unsupported_dtypes({"0.4.21 and below": ("float16",)}, backend_version) def cosh(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray: return jnp.cosh(x) @@ -191,7 +191,7 @@ def expm1(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray: return jnp.expm1(x) -@with_unsupported_dtypes({"0.4.20 and below": ("complex",)}, backend_version) +@with_unsupported_dtypes({"0.4.21 and below": ("complex",)}, backend_version) def floor(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray: if "int" in str(x.dtype): return x @@ -199,7 +199,7 @@ def floor(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray: return jnp.floor(x) -@with_unsupported_dtypes({"0.4.20 and below": ("complex",)}, backend_version) +@with_unsupported_dtypes({"0.4.21 and below": ("complex",)}, backend_version) def floor_divide( x1: Union[float, JaxArray], x2: Union[float, JaxArray], @@ -427,7 +427,7 @@ def pow( return jnp.power(x1, x2) -@with_unsupported_dtypes({"0.4.20 and below": ("complex",)}, backend_version) +@with_unsupported_dtypes({"0.4.21 and below": ("complex",)}, backend_version) def remainder( x1: Union[float, JaxArray], x2: Union[float, JaxArray], @@ -524,7 +524,7 @@ def tanh( return jnp.tanh(x) -@with_unsupported_dtypes({"0.4.20 and below": ("complex",)}, backend_version) +@with_unsupported_dtypes({"0.4.21 and below": ("complex",)}, backend_version) def trunc(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray: if "int" in str(x.dtype): return x @@ -564,7 +564,7 @@ def angle( # ------# -@with_unsupported_dtypes({"0.4.20 and below": ("complex",)}, backend_version) +@with_unsupported_dtypes({"0.4.21 and below": ("complex",)}, backend_version) def erf(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray: return jax.scipy.special.erf(x) @@ -615,7 +615,7 @@ def isreal(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray: return jnp.isreal(x) -@with_unsupported_dtypes({"0.4.20 and below": ("complex",)}, backend_version) +@with_unsupported_dtypes({"0.4.21 and below": ("complex",)}, backend_version) def fmod( x1: JaxArray, x2: JaxArray, diff --git a/ivy/functional/backends/jax/experimental/elementwise.py b/ivy/functional/backends/jax/experimental/elementwise.py index 4a005bd28c203..1ebf9d8c99ff8 100644 --- a/ivy/functional/backends/jax/experimental/elementwise.py +++ b/ivy/functional/backends/jax/experimental/elementwise.py @@ -50,7 +50,7 @@ def sinc(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray: @with_supported_dtypes( - {"0.4.20 and below": ("float16", "float32", "float64")}, backend_version + {"0.4.21 and below": ("float16", "float32", "float64")}, backend_version ) def lgamma(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray: return jlax.lgamma(x) diff --git a/ivy/functional/backends/jax/experimental/layers.py b/ivy/functional/backends/jax/experimental/layers.py index 95fbb1277a19c..475749cadcc31 100644 --- a/ivy/functional/backends/jax/experimental/layers.py +++ b/ivy/functional/backends/jax/experimental/layers.py @@ -440,7 +440,7 @@ def avg_pool3d( return res -@with_supported_dtypes({"0.4.20 and below": ("float32", "float64")}, backend_version) +@with_supported_dtypes({"0.4.21 and below": ("float32", "float64")}, backend_version) def dct( x: JaxArray, /, @@ -822,7 +822,7 @@ def ifftn( @with_unsupported_dtypes( - {"0.4.20 and below": ("bfloat16", "float16", "complex")}, backend_version + {"0.4.21 and below": ("bfloat16", "float16", "complex")}, backend_version ) def embedding( weights: JaxArray, @@ -870,7 +870,7 @@ def rfft( return ret -@with_unsupported_dtypes({"0.4.20 and below": ("float16", "complex")}, backend_version) +@with_unsupported_dtypes({"0.4.21 and below": ("float16", "complex")}, backend_version) def rfftn( x: JaxArray, s: Optional[Sequence[int]] = None, diff --git a/ivy/functional/backends/jax/experimental/random.py b/ivy/functional/backends/jax/experimental/random.py index 0afa2f35dc9a1..4675cf0d8293d 100644 --- a/ivy/functional/backends/jax/experimental/random.py +++ b/ivy/functional/backends/jax/experimental/random.py @@ -56,7 +56,7 @@ def beta( return jax.random.beta(rng_input, a, b, shape, dtype) -@with_unsupported_dtypes({"0.4.20 and below": ("bfloat16",)}, backend_version) +@with_unsupported_dtypes({"0.4.21 and below": ("bfloat16",)}, backend_version) def gamma( alpha: Union[float, JaxArray], beta: Union[float, JaxArray], diff --git a/ivy/functional/backends/jax/experimental/sorting.py b/ivy/functional/backends/jax/experimental/sorting.py index 232bf3d4be6e2..9dcc596eae53c 100644 --- a/ivy/functional/backends/jax/experimental/sorting.py +++ b/ivy/functional/backends/jax/experimental/sorting.py @@ -23,7 +23,7 @@ def invert_permutation( # lexsort -@with_unsupported_dtypes({"0.4.20 and below": ("bfloat16",)}, backend_version) +@with_unsupported_dtypes({"0.4.21 and below": ("bfloat16",)}, backend_version) def lexsort( keys: JaxArray, /, diff --git a/ivy/functional/backends/jax/experimental/statistical.py b/ivy/functional/backends/jax/experimental/statistical.py index 65beaf0363722..f83e98bee4c88 100644 --- a/ivy/functional/backends/jax/experimental/statistical.py +++ b/ivy/functional/backends/jax/experimental/statistical.py @@ -10,7 +10,7 @@ @with_unsupported_dtypes( - {"0.4.20 and below": ("bfloat16",)}, + {"0.4.21 and below": ("bfloat16",)}, backend_version, ) def histogram( @@ -121,7 +121,7 @@ def histogram( @with_unsupported_dtypes( - {"0.4.20 and below": ("complex64", "complex128")}, backend_version + {"0.4.21 and below": ("complex64", "complex128")}, backend_version ) def median( input: JaxArray, @@ -406,7 +406,7 @@ def __get_index(lst, indices=None, prefix=None): @with_unsupported_dtypes( { - "0.4.20 and below": ( + "0.4.21 and below": ( "bfloat16", "bool", ) diff --git a/ivy/functional/backends/jax/general.py b/ivy/functional/backends/jax/general.py index c2f643afcd9cc..cb9c1edceab73 100644 --- a/ivy/functional/backends/jax/general.py +++ b/ivy/functional/backends/jax/general.py @@ -101,7 +101,7 @@ def array_equal(x0: JaxArray, x1: JaxArray, /) -> bool: return bool(jnp.array_equal(x0, x1)) -@with_unsupported_dtypes({"0.4.20 and below": ("bfloat16",)}, backend_version) +@with_unsupported_dtypes({"0.4.21 and below": ("bfloat16",)}, backend_version) def to_numpy(x: JaxArray, /, *, copy: bool = True) -> np.ndarray: if copy: return np.array(_to_array(x)) @@ -420,7 +420,7 @@ def vmap( ) -@with_unsupported_dtypes({"0.4.20 and below": ("float16", "bfloat16")}, backend_version) +@with_unsupported_dtypes({"0.4.21 and below": ("float16", "bfloat16")}, backend_version) def isin( elements: JaxArray, test_elements: JaxArray, diff --git a/ivy/functional/backends/jax/linear_algebra.py b/ivy/functional/backends/jax/linear_algebra.py index 0e622f58b8bed..c5243e5c404de 100644 --- a/ivy/functional/backends/jax/linear_algebra.py +++ b/ivy/functional/backends/jax/linear_algebra.py @@ -20,7 +20,7 @@ @with_unsupported_dtypes( - {"0.4.20 and below": ("bfloat16", "float16", "complex")}, + {"0.4.21 and below": ("bfloat16", "float16", "complex")}, backend_version, ) def cholesky( @@ -34,7 +34,7 @@ def cholesky( return ret -@with_unsupported_dtypes({"0.4.20 and below": ("complex",)}, backend_version) +@with_unsupported_dtypes({"0.4.21 and below": ("complex",)}, backend_version) def cross( x1: JaxArray, x2: JaxArray, @@ -51,14 +51,14 @@ def cross( @with_unsupported_dtypes( - {"0.4.20 and below": ("bfloat16", "float16", "complex")}, + {"0.4.21 and below": ("bfloat16", "float16", "complex")}, backend_version, ) def det(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray: return jnp.linalg.det(x) -@with_unsupported_dtypes({"0.4.20 and below": ("float16", "bfloat16")}, backend_version) +@with_unsupported_dtypes({"0.4.21 and below": ("float16", "bfloat16")}, backend_version) def eig(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> Tuple[JaxArray]: result_tuple = NamedTuple( "eig", [("eigenvalues", JaxArray), ("eigenvectors", JaxArray)] @@ -67,7 +67,7 @@ def eig(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> Tuple[JaxArray]: return result_tuple(eigenvalues, eigenvectors) -@with_unsupported_dtypes({"0.4.20 and below": ("complex",)}, backend_version) +@with_unsupported_dtypes({"0.4.21 and below": ("complex",)}, backend_version) def diagonal( x: JaxArray, /, @@ -104,7 +104,7 @@ def tensorsolve( @with_unsupported_dtypes( - {"0.4.20 and below": ("bfloat16", "float16", "complex")}, + {"0.4.21 and below": ("bfloat16", "float16", "complex")}, backend_version, ) def eigh( @@ -118,7 +118,7 @@ def eigh( @with_unsupported_dtypes( - {"0.4.20 and below": ("bfloat16", "float16", "complex")}, + {"0.4.21 and below": ("bfloat16", "float16", "complex")}, backend_version, ) def eigvalsh( @@ -127,14 +127,14 @@ def eigvalsh( return jnp.linalg.eigvalsh(x, UPLO=UPLO) -@with_unsupported_dtypes({"0.4.20 and below": ("complex",)}, backend_version) +@with_unsupported_dtypes({"0.4.21 and below": ("complex",)}, backend_version) def inner(x1: JaxArray, x2: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray: x1, x2 = ivy.promote_types_of_inputs(x1, x2) return jnp.inner(x1, x2) @with_unsupported_dtypes( - {"0.4.20 and below": ("bfloat16", "float16", "complex")}, + {"0.4.21 and below": ("bfloat16", "float16", "complex")}, backend_version, ) def inv( @@ -155,7 +155,7 @@ def inv( @with_unsupported_dtypes( - {"0.4.20 and below": ("bfloat16", "float16", "complex")}, + {"0.4.21 and below": ("bfloat16", "float16", "complex")}, backend_version, ) def matmul( @@ -181,7 +181,7 @@ def matmul( @with_unsupported_dtypes( - {"0.4.20 and below": ("bfloat16", "float16", "complex")}, + {"0.4.21 and below": ("bfloat16", "float16", "complex")}, backend_version, ) def matrix_norm( @@ -202,13 +202,13 @@ def matrix_norm( return jnp.linalg.norm(x, ord=ord, axis=axis, keepdims=keepdims) -@with_unsupported_dtypes({"0.4.20 and below": ("complex",)}, backend_version) +@with_unsupported_dtypes({"0.4.21 and below": ("complex",)}, backend_version) def matrix_power(x: JaxArray, n: int, /, *, out: Optional[JaxArray] = None) -> JaxArray: return jnp.linalg.matrix_power(x, n) @with_unsupported_dtypes( - {"0.4.20 and below": ("bfloat16", "float16", "complex")}, + {"0.4.21 and below": ("bfloat16", "float16", "complex")}, backend_version, ) def matrix_rank( @@ -239,7 +239,7 @@ def matrix_rank( @with_unsupported_dtypes( - {"0.4.20 and below": ("int", "float16", "complex")}, + {"0.4.21 and below": ("int", "float16", "complex")}, backend_version, ) def matrix_transpose( @@ -251,7 +251,7 @@ def matrix_transpose( @with_unsupported_dtypes( - {"0.4.20 and below": ("bfloat16", "float16", "complex")}, + {"0.4.21 and below": ("bfloat16", "float16", "complex")}, backend_version, ) def outer( @@ -266,7 +266,7 @@ def outer( @with_unsupported_dtypes( - {"0.4.20 and below": ("bfloat16", "float16", "complex")}, + {"0.4.21 and below": ("bfloat16", "float16", "complex")}, backend_version, ) def pinv( @@ -284,7 +284,7 @@ def pinv( @with_unsupported_dtypes( - {"0.4.20 and below": ("bfloat16", "float16", "complex")}, + {"0.4.21 and below": ("bfloat16", "float16", "complex")}, backend_version, ) def qr( @@ -296,7 +296,7 @@ def qr( @with_unsupported_dtypes( - {"0.4.20 and below": ("bfloat16", "float16", "complex")}, + {"0.4.21 and below": ("bfloat16", "float16", "complex")}, backend_version, ) def slogdet( @@ -309,7 +309,7 @@ def slogdet( @with_unsupported_dtypes( - {"0.4.20 and below": ("bfloat16", "float16", "complex")}, + {"0.4.21 and below": ("bfloat16", "float16", "complex")}, backend_version, ) def solve( @@ -351,7 +351,7 @@ def solve( @with_unsupported_dtypes( - {"0.4.20 and below": ("bfloat16", "float16", "complex")}, + {"0.4.21 and below": ("bfloat16", "float16", "complex")}, backend_version, ) def svd( @@ -368,7 +368,7 @@ def svd( @with_unsupported_dtypes( - {"0.4.20 and below": ("bfloat16", "float16", "complex")}, + {"0.4.21 and below": ("bfloat16", "float16", "complex")}, backend_version, ) def svdvals( @@ -378,7 +378,7 @@ def svdvals( return jnp.linalg.svd(x, compute_uv=False) -@with_unsupported_dtypes({"0.4.20 and below": ("complex",)}, backend_version) +@with_unsupported_dtypes({"0.4.21 and below": ("complex",)}, backend_version) def tensordot( x1: JaxArray, x2: JaxArray, @@ -392,7 +392,7 @@ def tensordot( @with_unsupported_dtypes( - {"0.4.20 and below": ("bfloat16", "float16", "complex")}, + {"0.4.21 and below": ("bfloat16", "float16", "complex")}, backend_version, ) def trace( @@ -407,7 +407,7 @@ def trace( return jnp.trace(x, offset=offset, axis1=axis1, axis2=axis2, out=out) -@with_unsupported_dtypes({"0.4.20 and below": ("complex",)}, backend_version) +@with_unsupported_dtypes({"0.4.21 and below": ("complex",)}, backend_version) def vecdot( x1: JaxArray, x2: JaxArray, /, *, axis: int = -1, out: Optional[JaxArray] = None ) -> JaxArray: @@ -415,7 +415,7 @@ def vecdot( return jnp.tensordot(x1, x2, axes=(axis, axis)) -@with_unsupported_dtypes({"0.4.20 and below": ("complex",)}, backend_version) +@with_unsupported_dtypes({"0.4.21 and below": ("complex",)}, backend_version) def vector_norm( x: JaxArray, /, @@ -445,7 +445,7 @@ def vector_norm( # ------# -@with_unsupported_dtypes({"0.4.20 and below": ("complex",)}, backend_version) +@with_unsupported_dtypes({"0.4.21 and below": ("complex",)}, backend_version) def diag( x: JaxArray, /, @@ -457,7 +457,7 @@ def diag( @with_unsupported_dtypes( - {"0.4.20 and below": ("bfloat16", "float16", "complex")}, + {"0.4.21 and below": ("bfloat16", "float16", "complex")}, backend_version, ) def vander( @@ -473,7 +473,7 @@ def vander( @with_unsupported_dtypes( { - "0.4.20 and below": ( + "0.4.21 and below": ( "complex", "unsigned", ) diff --git a/ivy/functional/backends/jax/manipulation.py b/ivy/functional/backends/jax/manipulation.py index 161ba13f9f082..c7bc18b42c636 100644 --- a/ivy/functional/backends/jax/manipulation.py +++ b/ivy/functional/backends/jax/manipulation.py @@ -226,7 +226,7 @@ def clip( return x -@with_unsupported_dtypes({"0.4.20 and below": ("uint64",)}, backend_version) +@with_unsupported_dtypes({"0.4.21 and below": ("uint64",)}, backend_version) def constant_pad( x: JaxArray, /, diff --git a/ivy/functional/backends/jax/random.py b/ivy/functional/backends/jax/random.py index 506eb095fae5c..f277860762f4c 100644 --- a/ivy/functional/backends/jax/random.py +++ b/ivy/functional/backends/jax/random.py @@ -82,7 +82,7 @@ def random_normal( return jax.random.normal(rng_input, shape, dtype=dtype) * std + mean -@with_unsupported_dtypes({"0.4.20 and below": ("bfloat16",)}, backend_version) +@with_unsupported_dtypes({"0.4.21 and below": ("bfloat16",)}, backend_version) def multinomial( population_size: int, num_samples: int, diff --git a/ivy/functional/backends/jax/searching.py b/ivy/functional/backends/jax/searching.py index 679c02981e49e..a2a6a7381ec2d 100644 --- a/ivy/functional/backends/jax/searching.py +++ b/ivy/functional/backends/jax/searching.py @@ -12,7 +12,7 @@ # ------------------ # -@with_unsupported_dtypes({"0.4.20 and below": ("complex",)}, backend_version) +@with_unsupported_dtypes({"0.4.21 and below": ("complex",)}, backend_version) def argmax( x: JaxArray, /, @@ -38,7 +38,7 @@ def argmax( return ret -@with_unsupported_dtypes({"0.4.20 and below": ("complex",)}, backend_version) +@with_unsupported_dtypes({"0.4.21 and below": ("complex",)}, backend_version) def argmin( x: JaxArray, /, diff --git a/ivy/functional/backends/jax/sorting.py b/ivy/functional/backends/jax/sorting.py index 1dc7eecb5b2c4..a633b8044004a 100644 --- a/ivy/functional/backends/jax/sorting.py +++ b/ivy/functional/backends/jax/sorting.py @@ -80,7 +80,7 @@ def searchsorted( # msort -@with_unsupported_dtypes({"0.4.20 and below": ("complex",)}, backend_version) +@with_unsupported_dtypes({"0.4.21 and below": ("complex",)}, backend_version) def msort( a: Union[JaxArray, list, tuple], /, diff --git a/ivy/functional/backends/jax/statistical.py b/ivy/functional/backends/jax/statistical.py index 9e0cef560e067..3524c4ff88d36 100644 --- a/ivy/functional/backends/jax/statistical.py +++ b/ivy/functional/backends/jax/statistical.py @@ -38,7 +38,7 @@ def max( @with_unsupported_dtypes( - {"0.4.20 and below": "bfloat16"}, + {"0.4.21 and below": "bfloat16"}, backend_version, ) def mean( @@ -143,7 +143,7 @@ def var( # ------# -@with_unsupported_dtypes({"0.4.20 and below": "bfloat16"}, backend_version) +@with_unsupported_dtypes({"0.4.21 and below": "bfloat16"}, backend_version) def cumprod( x: JaxArray, /, diff --git a/ivy/functional/frontends/__init__.py b/ivy/functional/frontends/__init__.py index 74b5c3b1c4b17..712f483d3df7e 100644 --- a/ivy/functional/frontends/__init__.py +++ b/ivy/functional/frontends/__init__.py @@ -5,7 +5,7 @@ "torch": "2.1.1", "tensorflow": "2.15.0", "numpy": "1.25.2", - "jax": "0.4.14", + "jax": "0.4.21", "scipy": "1.10.1", "paddle": "2.5.2", "sklearn": "1.3.0", diff --git a/ivy/functional/frontends/jax/lax/operators.py b/ivy/functional/frontends/jax/lax/operators.py index b2ce5a8631c81..935911d628424 100644 --- a/ivy/functional/frontends/jax/lax/operators.py +++ b/ivy/functional/frontends/jax/lax/operators.py @@ -157,7 +157,7 @@ def broadcast(operand, sizes): @with_supported_dtypes( { - "0.4.20 and below": ( + "0.4.21 and below": ( "float16", "float32", "float64", @@ -309,7 +309,7 @@ def cosh(x): @with_unsupported_dtypes( - {"0.4.20 and below": ("bfloat16", "float16", "bool", "complex64", "complex128")}, + {"0.4.21 and below": ("bfloat16", "float16", "bool", "complex64", "complex128")}, "jax", ) @to_ivy_arrays_and_back @@ -400,7 +400,7 @@ def erf(x): @with_supported_dtypes( { - "0.4.20 and below": ( + "0.4.21 and below": ( "float16", "float32", "float64", @@ -465,7 +465,7 @@ def imag(x): @with_unsupported_dtypes( - {"0.4.20 and below": ("bool", "bfloat16")}, + {"0.4.21 and below": ("bool", "bfloat16")}, "jax", ) @to_ivy_arrays_and_back diff --git a/ivy/functional/frontends/jax/nn/non_linear_activations.py b/ivy/functional/frontends/jax/nn/non_linear_activations.py index 0cc778cf320cb..726114cbac3e1 100644 --- a/ivy/functional/frontends/jax/nn/non_linear_activations.py +++ b/ivy/functional/frontends/jax/nn/non_linear_activations.py @@ -289,7 +289,7 @@ def sigmoid(x): @with_supported_dtypes( - {"0.4.20 and below": ("complex", "float")}, + {"0.4.21 and below": ("complex", "float")}, "jax", ) @to_ivy_arrays_and_back diff --git a/ivy/functional/frontends/jax/numpy/creation.py b/ivy/functional/frontends/jax/numpy/creation.py index 93a8cba0eca7b..b83c1ff83d27e 100644 --- a/ivy/functional/frontends/jax/numpy/creation.py +++ b/ivy/functional/frontends/jax/numpy/creation.py @@ -17,7 +17,7 @@ @with_unsupported_device_and_dtypes( { - "0.4.20 and below": { + "0.4.21 and below": { "cpu": ( "float16", "bflooat16", @@ -196,7 +196,7 @@ def iterable(y): @to_ivy_arrays_and_back @with_unsupported_dtypes( { - "0.4.20 and below": ( + "0.4.21 and below": ( "float16", "bfloat16", ) @@ -217,7 +217,7 @@ def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis @to_ivy_arrays_and_back @with_unsupported_dtypes( { - "0.4.20 and below": ( + "0.4.21 and below": ( "float16", "bfloat16", ) diff --git a/ivy/functional/frontends/jax/numpy/linalg.py b/ivy/functional/frontends/jax/numpy/linalg.py index 1a34220713653..04c18eec1313f 100644 --- a/ivy/functional/frontends/jax/numpy/linalg.py +++ b/ivy/functional/frontends/jax/numpy/linalg.py @@ -88,7 +88,7 @@ def multi_dot(arrays, *, precision=None): @to_ivy_arrays_and_back @with_supported_dtypes( - {"0.4.20 and below": ("float32", "float64")}, + {"0.4.21 and below": ("float32", "float64")}, "jax", ) def norm(x, ord=None, axis=None, keepdims=False): @@ -127,7 +127,7 @@ def svd(a, /, *, full_matrices=True, compute_uv=True, hermitian=None): @to_ivy_arrays_and_back -@with_unsupported_dtypes({"0.4.20 and below": ("float16", "bfloat16")}, "jax") +@with_unsupported_dtypes({"0.4.21 and below": ("float16", "bfloat16")}, "jax") def tensorinv(a, ind=2): old_shape = ivy.shape(a) prod = 1 diff --git a/ivy/functional/frontends/jax/numpy/logic.py b/ivy/functional/frontends/jax/numpy/logic.py index c6d79f193d91d..d5a28d2803509 100644 --- a/ivy/functional/frontends/jax/numpy/logic.py +++ b/ivy/functional/frontends/jax/numpy/logic.py @@ -101,7 +101,7 @@ def equal(x1, x2, /): @to_ivy_arrays_and_back -@with_unsupported_dtypes({"0.4.20 and below": ("bfloat16",)}, "jax") +@with_unsupported_dtypes({"0.4.21 and below": ("bfloat16",)}, "jax") def fromfunction(function, shape, *, dtype=float, **kwargs): def canonicalize_shape(shape, context="shape argument"): if isinstance(shape, int): @@ -285,7 +285,7 @@ def right_shift(x1, x2, /): @to_ivy_arrays_and_back -@with_unsupported_dtypes({"0.4.20 and below": ("bfloat16", "bool")}, "jax") +@with_unsupported_dtypes({"0.4.21 and below": ("bfloat16", "bool")}, "jax") def setxor1d(ar1, ar2, assume_unique=False): common_dtype = ivy.promote_types(ivy.dtype(ar1), ivy.dtype(ar2)) ar1 = ivy.asarray(ar1, dtype=common_dtype) diff --git a/ivy/functional/frontends/jax/numpy/mathematical_functions.py b/ivy/functional/frontends/jax/numpy/mathematical_functions.py index c8a607f689ced..9e661cad4746e 100644 --- a/ivy/functional/frontends/jax/numpy/mathematical_functions.py +++ b/ivy/functional/frontends/jax/numpy/mathematical_functions.py @@ -76,7 +76,7 @@ def around(a, decimals=0, out=None): @with_unsupported_dtypes( - {"0.4.20 and below": ("bfloat16",)}, + {"0.4.21 and below": ("bfloat16",)}, "jax", ) @to_ivy_arrays_and_back @@ -420,7 +420,7 @@ def expm1( @with_unsupported_dtypes( - {"0.4.20 and below": ("uint16",)}, + {"0.4.21 and below": ("uint16",)}, "jax", ) @to_ivy_arrays_and_back @@ -584,7 +584,7 @@ def minimum(x1, x2, /): @to_ivy_arrays_and_back -@with_unsupported_dtypes({"0.4.20 and below": ("complex",)}, "jax") +@with_unsupported_dtypes({"0.4.21 and below": ("complex",)}, "jax") def mod(x1, x2, /): x1, x2 = promote_types_of_jax_inputs(x1, x2) return ivy.remainder(x1, x2) @@ -626,7 +626,7 @@ def negative( @with_unsupported_dtypes( { - "0.4.20 and below": ( + "0.4.21 and below": ( "bfloat16", "float16", ) @@ -671,7 +671,7 @@ def polyadd(a1, a2): @with_unsupported_dtypes( - {"0.4.20 and below": ("float16",)}, + {"0.4.21 and below": ("float16",)}, "jax", ) @to_ivy_arrays_and_back @@ -713,7 +713,7 @@ def polydiv(u, v, *, trim_leading_zeros=False): @with_unsupported_dtypes( - {"0.4.20 and below": ("float16",)}, + {"0.4.21 and below": ("float16",)}, "jax", ) @to_ivy_arrays_and_back diff --git a/ivy/functional/frontends/jax/numpy/searching_sorting.py b/ivy/functional/frontends/jax/numpy/searching_sorting.py index 3ad8b7b6d835f..bd692378a6749 100644 --- a/ivy/functional/frontends/jax/numpy/searching_sorting.py +++ b/ivy/functional/frontends/jax/numpy/searching_sorting.py @@ -15,7 +15,7 @@ @to_ivy_arrays_and_back @with_unsupported_dtypes( { - "0.4.20 and below": ( + "0.4.21 and below": ( "float16", "bfloat16", ) @@ -58,7 +58,7 @@ def argwhere(a, /, *, size=None, fill_value=None): @with_unsupported_dtypes( { - "0.4.20 and below": ( + "0.4.21 and below": ( "uint8", "int8", "bool", diff --git a/ivy/functional/frontends/jax/numpy/statistical.py b/ivy/functional/frontends/jax/numpy/statistical.py index bb40e941322a4..43ed0fe0499b7 100644 --- a/ivy/functional/frontends/jax/numpy/statistical.py +++ b/ivy/functional/frontends/jax/numpy/statistical.py @@ -103,7 +103,7 @@ def corrcoef(x, y=None, rowvar=True): @to_ivy_arrays_and_back -@with_unsupported_dtypes({"0.4.20 and below": ("float16", "bfloat16")}, "jax") +@with_unsupported_dtypes({"0.4.21 and below": ("float16", "bfloat16")}, "jax") def correlate(a, v, mode="valid", precision=None): if ivy.get_num_dims(a) != 1 or ivy.get_num_dims(v) != 1: raise ValueError("correlate() only support 1-dimensional inputs.") @@ -572,7 +572,7 @@ def ptp(a, axis=None, out=None, keepdims=False): @to_ivy_arrays_and_back @with_unsupported_dtypes( - {"0.4.20 and below": ("complex64", "complex128", "bfloat16", "bool", "float16")}, + {"0.4.21 and below": ("complex64", "complex128", "bfloat16", "bool", "float16")}, "jax", ) def quantile( @@ -597,7 +597,7 @@ def quantile( @handle_jax_dtype -@with_unsupported_dtypes({"0.4.20 and below": ("bfloat16",)}, "jax") +@with_unsupported_dtypes({"0.4.21 and below": ("bfloat16",)}, "jax") @to_ivy_arrays_and_back def std(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False, *, where=None): axis = tuple(axis) if isinstance(axis, list) else axis diff --git a/ivy/functional/frontends/jax/random.py b/ivy/functional/frontends/jax/random.py index f8adcc6778a04..898b45aa05a5a 100644 --- a/ivy/functional/frontends/jax/random.py +++ b/ivy/functional/frontends/jax/random.py @@ -38,7 +38,7 @@ def PRNGKey(seed): @to_ivy_arrays_and_back @with_supported_dtypes( { - "0.4.20 and below": ( + "0.4.21 and below": ( "float32", "float64", ) @@ -70,7 +70,7 @@ def bernoulli(key, p=0.5, shape=None): @to_ivy_arrays_and_back @with_unsupported_dtypes( { - "0.4.20 and below": ( + "0.4.21 and below": ( "float16", "bfloat16", ) @@ -85,7 +85,7 @@ def beta(key, a, b, shape=None, dtype=None): @to_ivy_arrays_and_back @with_unsupported_dtypes( { - "0.4.20 and below": ( + "0.4.21 and below": ( "float16", "bfloat16", ) @@ -133,7 +133,7 @@ def cauchy(key, shape=(), dtype="float64"): @to_ivy_arrays_and_back @with_unsupported_dtypes( { - "0.4.20 and below": ( + "0.4.21 and below": ( "float16", "bfloat16", ) @@ -149,7 +149,7 @@ def dirichlet(key, alpha, shape=None, dtype="float32"): @handle_jax_dtype @to_ivy_arrays_and_back @with_unsupported_dtypes( - {"0.4.20 and below": "uint32"}, + {"0.4.21 and below": "uint32"}, "jax", ) def double_sided_maxwell(key, loc, scale, shape=(), dtype="float64"): @@ -168,7 +168,7 @@ def double_sided_maxwell(key, loc, scale, shape=(), dtype="float64"): @to_ivy_arrays_and_back @with_unsupported_dtypes( { - "0.4.20 and below": ( + "0.4.21 and below": ( "float16", "bfloat16", ) @@ -196,7 +196,7 @@ def fold_in(key, data): @to_ivy_arrays_and_back @with_unsupported_dtypes( { - "0.4.20 and below": ( + "0.4.21 and below": ( "float16", "bfloat16", ) @@ -212,7 +212,7 @@ def gamma(key, a, shape=None, dtype="float64"): @to_ivy_arrays_and_back @with_unsupported_dtypes( { - "0.4.20 and below": ( + "0.4.21 and below": ( "float16", "bfloat16", ) @@ -231,7 +231,7 @@ def generalized_normal(key, p, shape=(), dtype="float64"): @to_ivy_arrays_and_back @with_unsupported_dtypes( { - "0.4.20 and below": ( + "0.4.21 and below": ( "float16", "bfloat16", ) @@ -255,7 +255,7 @@ def gumbel(key, shape=(), dtype="float64"): @to_ivy_arrays_and_back @with_unsupported_dtypes( { - "0.4.20 and below": ( + "0.4.21 and below": ( "float16", "bfloat16", ) @@ -270,7 +270,7 @@ def loggamma(key, a, shape=None, dtype="float64"): @handle_jax_dtype @to_ivy_arrays_and_back @with_unsupported_dtypes( - {"0.4.20 and below": ("float16", "bfloat16")}, + {"0.4.21 and below": ("float16", "bfloat16")}, "jax", ) def logistic(key, shape=(), dtype="float64"): @@ -301,7 +301,7 @@ def maxwell(key, shape, dtype="float64"): @to_ivy_arrays_and_back @with_unsupported_dtypes( { - "0.4.20 and below": ( + "0.4.21 and below": ( "float16", "bfloat16", ) @@ -358,7 +358,7 @@ def orthogonal(key, n, shape=(), dtype=None): @to_ivy_arrays_and_back @with_unsupported_dtypes( { - "0.4.20 and below": ( + "0.4.21 and below": ( "float16", "bfloat16", ) @@ -393,7 +393,7 @@ def permutation(key, x, axis=0, independent=False): @handle_jax_dtype @to_ivy_arrays_and_back @with_unsupported_dtypes( - {"0.4.20 and below": ("unsigned", "int8", "int16")}, + {"0.4.21 and below": ("unsigned", "int8", "int16")}, "jax", ) def poisson(key, lam, shape=None, dtype=None): @@ -404,7 +404,7 @@ def poisson(key, lam, shape=None, dtype=None): @handle_jax_dtype @to_ivy_arrays_and_back @with_unsupported_dtypes( - {"0.4.20 and below": ("unsigned", "int8", "int16")}, + {"0.4.21 and below": ("unsigned", "int8", "int16")}, "jax", ) def rademacher(key, shape, dtype="int64"): @@ -418,7 +418,7 @@ def rademacher(key, shape, dtype="int64"): @handle_jax_dtype @to_ivy_arrays_and_back @with_unsupported_dtypes( - {"0.4.20 and below": ("unsigned", "int8", "int16")}, + {"0.4.21 and below": ("unsigned", "int8", "int16")}, "jax", ) def randint(key, shape, minval, maxval, dtype="int64"):