Skip to content

Commit

Permalink
feat(jax backend): removed manual dtype casting. (#23655)
Browse files Browse the repository at this point in the history
  • Loading branch information
Madjid-CH authored Sep 20, 2023
1 parent 8b55d72 commit 8e9a211
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 13 deletions.
8 changes: 6 additions & 2 deletions ivy/functional/backends/jax/experimental/linear_algebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,15 @@
from typing import Optional, Tuple, Sequence, Union
import jax.numpy as jnp
import jax.scipy.linalg as jla

from ivy.func_wrapper import with_supported_dtypes
from ivy.functional.backends.jax import JaxArray

import ivy

from ivy.functional.ivy.experimental.linear_algebra import _check_valid_dimension_size
from ivy.utils.exceptions import IvyNotImplementedException
from . import backend_version


def diagflat(
Expand Down Expand Up @@ -114,9 +117,10 @@ def eig(
return jnp.linalg.eig(x)


@with_supported_dtypes(
{"0.4.14 and below": ("complex", "float32", "float64")}, backend_version
)
def eigvals(x: JaxArray, /) -> JaxArray:
if not ivy.dtype(x) in (ivy.float32, ivy.float64, ivy.complex64, ivy.complex128):
x = x.astype(jnp.float64)
return jnp.linalg.eigvals(x)


Expand Down
24 changes: 13 additions & 11 deletions ivy/functional/backends/jax/experimental/statistical.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,7 @@ def cov(
)


@with_unsupported_dtypes({"0.4.14 and below": ("bool",)}, backend_version)
def cummax(
x: JaxArray,
/,
Expand All @@ -301,12 +302,8 @@ def cummax(
dtype: Optional[jnp.dtype] = None,
out: Optional[JaxArray] = None,
) -> Tuple[JaxArray, JaxArray]:
if x.dtype in (jnp.bool_, jnp.float16):
x = x.astype(jnp.float64)
elif x.dtype in (jnp.int16, jnp.int8, jnp.uint8):
x = x.astype(jnp.int64)
elif x.dtype in (jnp.complex128, jnp.complex64):
x = jnp.real(x).astype(jnp.float64)
if x.dtype in (jnp.complex128, jnp.complex64):
x = x.real

if exclusive or (reverse and exclusive):
if exclusive and reverse:
Expand Down Expand Up @@ -390,7 +387,15 @@ def __get_index(lst, indices=None, prefix=None):
return indices


@with_unsupported_dtypes({"0.4.14 and below": "bfloat16"}, backend_version)
@with_unsupported_dtypes(
{
"0.4.14 and below": (
"bfloat16",
"bool",
)
},
backend_version,
)
def cummin(
x: JaxArray,
/,
Expand All @@ -405,10 +410,7 @@ def cummin(
axis = axis + len(x.shape)
dtype = ivy.as_native_dtype(dtype)
if dtype is None:
if dtype is jnp.bool_:
dtype = ivy.default_int_dtype(as_native=True)
else:
dtype = _infer_dtype(x.dtype)
dtype = _infer_dtype(x.dtype)
return jlax.cummin(x, axis, reverse=reverse).astype(dtype)


Expand Down

0 comments on commit 8e9a211

Please sign in to comment.