From 8e9a211c0bfb616187ff0ec99a1806f418ff7c17 Mon Sep 17 00:00:00 2001 From: Madjid Chergui <100947451+Madjid-CH@users.noreply.github.com> Date: Wed, 20 Sep 2023 19:52:01 +0100 Subject: [PATCH] feat(jax backend): removed manual dtype casting. (#23655) --- .../jax/experimental/linear_algebra.py | 8 +++++-- .../backends/jax/experimental/statistical.py | 24 ++++++++++--------- 2 files changed, 19 insertions(+), 13 deletions(-) diff --git a/ivy/functional/backends/jax/experimental/linear_algebra.py b/ivy/functional/backends/jax/experimental/linear_algebra.py index 897254757e7ba..f4fd65d9f9a10 100644 --- a/ivy/functional/backends/jax/experimental/linear_algebra.py +++ b/ivy/functional/backends/jax/experimental/linear_algebra.py @@ -2,12 +2,15 @@ from typing import Optional, Tuple, Sequence, Union import jax.numpy as jnp import jax.scipy.linalg as jla + +from ivy.func_wrapper import with_supported_dtypes from ivy.functional.backends.jax import JaxArray import ivy from ivy.functional.ivy.experimental.linear_algebra import _check_valid_dimension_size from ivy.utils.exceptions import IvyNotImplementedException +from . import backend_version def diagflat( @@ -114,9 +117,10 @@ def eig( return jnp.linalg.eig(x) +@with_supported_dtypes( + {"0.4.14 and below": ("complex", "float32", "float64")}, backend_version +) def eigvals(x: JaxArray, /) -> JaxArray: - if not ivy.dtype(x) in (ivy.float32, ivy.float64, ivy.complex64, ivy.complex128): - x = x.astype(jnp.float64) return jnp.linalg.eigvals(x) diff --git a/ivy/functional/backends/jax/experimental/statistical.py b/ivy/functional/backends/jax/experimental/statistical.py index 9d35e51ae6017..06d5b2ef5670e 100644 --- a/ivy/functional/backends/jax/experimental/statistical.py +++ b/ivy/functional/backends/jax/experimental/statistical.py @@ -291,6 +291,7 @@ def cov( ) +@with_unsupported_dtypes({"0.4.14 and below": ("bool",)}, backend_version) def cummax( x: JaxArray, /, @@ -301,12 +302,8 @@ def cummax( dtype: Optional[jnp.dtype] = None, out: Optional[JaxArray] = None, ) -> Tuple[JaxArray, JaxArray]: - if x.dtype in (jnp.bool_, jnp.float16): - x = x.astype(jnp.float64) - elif x.dtype in (jnp.int16, jnp.int8, jnp.uint8): - x = x.astype(jnp.int64) - elif x.dtype in (jnp.complex128, jnp.complex64): - x = jnp.real(x).astype(jnp.float64) + if x.dtype in (jnp.complex128, jnp.complex64): + x = x.real if exclusive or (reverse and exclusive): if exclusive and reverse: @@ -390,7 +387,15 @@ def __get_index(lst, indices=None, prefix=None): return indices -@with_unsupported_dtypes({"0.4.14 and below": "bfloat16"}, backend_version) +@with_unsupported_dtypes( + { + "0.4.14 and below": ( + "bfloat16", + "bool", + ) + }, + backend_version, +) def cummin( x: JaxArray, /, @@ -405,10 +410,7 @@ def cummin( axis = axis + len(x.shape) dtype = ivy.as_native_dtype(dtype) if dtype is None: - if dtype is jnp.bool_: - dtype = ivy.default_int_dtype(as_native=True) - else: - dtype = _infer_dtype(x.dtype) + dtype = _infer_dtype(x.dtype) return jlax.cummin(x, axis, reverse=reverse).astype(dtype)