Skip to content

Commit

Permalink
feat(numpy backend): removed manual dtype casting. (#23681)
Browse files Browse the repository at this point in the history
  • Loading branch information
Madjid-CH authored Sep 18, 2023
1 parent 1e1aaa0 commit cfcfbea
Show file tree
Hide file tree
Showing 5 changed files with 12 additions and 29 deletions.
10 changes: 4 additions & 6 deletions ivy/functional/backends/numpy/experimental/linear_algebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np

import ivy
from ivy.func_wrapper import with_supported_dtypes
from ivy.func_wrapper import with_supported_dtypes, with_unsupported_dtypes
from ivy.utils.exceptions import IvyNotImplementedException
from .. import backend_version

Expand Down Expand Up @@ -114,24 +114,22 @@ def matrix_exp(
return exp_mat.astype(x.dtype)


@with_unsupported_dtypes({"1.25.2 and below": ("float16",)}, backend_version)
def eig(
x: np.ndarray,
/,
*,
out: Optional[np.ndarray] = None,
) -> Tuple[np.ndarray]:
if ivy.dtype(x) == ivy.float16:
x = x.astype(np.float32)
) -> Tuple[np.ndarray, np.ndarray]:
e, v = np.linalg.eig(x)
return e.astype(complex), v.astype(complex)


eig.support_native_out = False


@with_unsupported_dtypes({"1.25.2 and below": ("float16",)}, backend_version)
def eigvals(x: np.ndarray, /) -> np.ndarray:
if ivy.dtype(x) == ivy.float16:
x = x.astype(np.float32)
e = np.linalg.eigvals(x)
return e.astype(complex)

Expand Down
12 changes: 1 addition & 11 deletions ivy/functional/backends/numpy/experimental/statistical.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,13 +428,6 @@ def cummax(
dtype: Optional[np.dtype] = None,
out: Optional[np.ndarray] = None,
) -> Tuple[np.ndarray, np.ndarray]:
if x.dtype in (np.bool_, np.float16):
x = x.astype(np.float64)
elif x.dtype in (np.int16, np.int8, np.uint8):
x = x.astype(np.int64)
elif x.dtype in (np.complex128, np.complex64):
x = np.real(x).astype(np.float64)

if exclusive or reverse:
if exclusive and reverse:
indices = __find_cummax_indices(np.flip(x, axis=axis), axis=axis)
Expand Down Expand Up @@ -527,10 +520,7 @@ def cummin(
out: Optional[np.ndarray] = None,
) -> np.ndarray:
if dtype is None:
if x.dtype == "bool":
dtype = ivy.default_int_dtype(as_native=True)
else:
dtype = _infer_dtype(x.dtype)
dtype = _infer_dtype(x.dtype)
if not (reverse):
return np.minimum.accumulate(x, axis, dtype=dtype, out=out)
elif reverse:
Expand Down
11 changes: 2 additions & 9 deletions ivy/functional/backends/numpy/statistical.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def var(
# ------#


@with_unsupported_dtypes({"1.25.2 and below": "bfloat16"}, backend_version)
@with_unsupported_dtypes({"1.25.2 and below": ("bfloat16",)}, backend_version)
def cumprod(
x: np.ndarray,
/,
Expand All @@ -183,10 +183,7 @@ def cumprod(
out: Optional[np.ndarray] = None,
) -> np.ndarray:
if dtype is None:
if x.dtype == "bool":
dtype = ivy.default_int_dtype(as_native=True)
else:
dtype = _infer_dtype(x.dtype)
dtype = _infer_dtype(x.dtype)
if not (exclusive or reverse):
return np.cumprod(x, axis, dtype=dtype, out=out)
elif exclusive and reverse:
Expand Down Expand Up @@ -218,10 +215,6 @@ def cumsum(
out: Optional[np.ndarray] = None,
) -> np.ndarray:
if dtype is None:
if x.dtype == "bool":
dtype = ivy.default_int_dtype(as_native=True)
if ivy.is_int_dtype(x.dtype):
dtype = ivy.promote_types(x.dtype, ivy.default_int_dtype(as_native=True))
dtype = _infer_dtype(x.dtype)

if exclusive or reverse:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

@st.composite
def _get_castable_dtype(draw, min_value=None, max_value=None):
available_dtypes = helpers.get_dtypes("numeric")
available_dtypes = helpers.get_dtypes("valid")
shape = draw(helpers.get_shape(min_num_dims=1, max_num_dims=4, max_dim_size=6))
dtype, values = draw(
helpers.dtype_and_values(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -895,12 +895,13 @@ def test_dot(*, data, test_flags, backend_fw, fn_name, on_device):
test_with_out=st.just(False),
test_gradients=st.just(False),
)
def test_eig(dtype_x, test_flags, backend_fw, fn_name):
def test_eig(dtype_x, test_flags, backend_fw, fn_name, on_device):
dtype, x = dtype_x
helpers.test_function(
input_dtypes=dtype,
test_flags=test_flags,
backend_to_test=backend_fw,
on_device=on_device,
fn_name=fn_name,
test_values=False,
x=x[0],
Expand Down Expand Up @@ -1002,12 +1003,13 @@ def test_eigh_tridiagonal(
test_with_out=st.just(False),
test_gradients=st.just(False),
)
def test_eigvals(dtype_x, test_flags, backend_fw, fn_name):
def test_eigvals(dtype_x, test_flags, backend_fw, fn_name, on_device):
dtype, x = dtype_x
helpers.test_function(
input_dtypes=dtype,
test_flags=test_flags,
backend_to_test=backend_fw,
on_device=on_device,
fn_name=fn_name,
test_values=False,
x=x[0],
Expand Down

0 comments on commit cfcfbea

Please sign in to comment.