Skip to content

Commit

Permalink
fix: corrected implementation of where param numpy mean (ivy-llc#23477)
Browse files Browse the repository at this point in the history
Co-authored-by: @AnnaTz
  • Loading branch information
ShreyanshBardia authored and druvdub committed Oct 14, 2023
1 parent 962b5d5 commit ad3e83f
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 71 deletions.
2 changes: 1 addition & 1 deletion ivy/functional/frontends/numpy/ndarray/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def any(self, axis=None, dtype=None, out=None, keepdims=False, *, where=True):
def argsort(self, *, axis=-1, kind=None, order=None):
return np_frontend.argsort(self, axis=axis, kind=kind, order=order)

def mean(self, *, axis=None, dtype=None, out=None, keepdims=False, where=True):
def mean(self, axis=None, dtype=None, out=None, keepdims=False, *, where=True):
return np_frontend.mean(
self,
axis=axis,
Expand Down
64 changes: 16 additions & 48 deletions ivy/functional/frontends/numpy/statistics/averages_and_variances.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,62 +72,30 @@ def cov(
@handle_numpy_dtype
@to_ivy_arrays_and_back
@from_zero_dim_arrays_to_scalar
def mean(
a,
/,
*,
axis=None,
keepdims=False,
out=None,
dtype=None,
where=True,
):
def mean(a, axis=None, dtype=None, out=None, keepdims=False, *, where=True):
axis = tuple(axis) if isinstance(axis, list) else axis
if dtype:
a = ivy.astype(ivy.array(a), ivy.as_ivy_dtype(dtype))

ret = ivy.mean(a, axis=axis, keepdims=keepdims, out=out)
if ivy.is_array(where):
ret = ivy.where(where, ret, ivy.default(out, ivy.zeros_like(ret)), out=out)
dtype = dtype or a.dtype if not ivy.is_int_dtype(a.dtype) else ivy.float64
where = ivy.where(where, ivy.ones_like(a), 0)
if where is not True:
a = ivy.where(where, a, 0.0)
sum = ivy.sum(a, axis=axis, keepdims=keepdims, dtype=dtype)
cnt = ivy.sum(where, axis=axis, keepdims=keepdims, dtype=int)
ret = ivy.divide(sum, cnt, out=out)
else:
ret = ivy.mean(a.astype(dtype), axis=axis, keepdims=keepdims, out=out)

return ret
return ret.astype(dtype)


@handle_numpy_out
@handle_numpy_dtype
@to_ivy_arrays_and_back
@from_zero_dim_arrays_to_scalar
def nanmean(
a,
/,
*,
axis=None,
keepdims=False,
out=None,
dtype=None,
where=True,
):
is_nan = ivy.isnan(a)
axis = tuple(axis) if isinstance(axis, list) else axis

if not ivy.any(is_nan):
if dtype:
a = ivy.astype(ivy.array(a), ivy.as_ivy_dtype(dtype))
ret = ivy.mean(a, axis=axis, keepdims=keepdims, out=out)

if ivy.is_array(where):
ret = ivy.where(where, ret, ivy.default(out, ivy.zeros_like(ret)), out=out)

else:
a = [i for i in a if ivy.isnan(i) is False]

if dtype:
a = ivy.astype(ivy.array(a), ivy.as_ivy_dtype(dtype))
ret = ivy.mean(a, axis=axis, keepdims=keepdims, out=out)

if ivy.is_array(where):
ret = ivy.where(where, ret, ivy.default(out, ivy.zeros_like(ret)), out=out)

def nanmean(a, axis=None, dtype=None, out=None, keepdims=False, *, where=True):
where = ~ivy.isnan(a) & where
ret = mean(a, axis, dtype, keepdims=keepdims, where=where).ivy_array
if out is not None:
out.data = ret.data
return ret


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@
from ivy_tests.test_ivy.test_frontends.test_numpy.test_manipulation_routines.test_changing_number_of_dimensions import ( # noqa
_squeeze_helper,
)
from ivy_tests.test_ivy.test_functional.test_core.test_statistical import (
_statistical_dtype_values,
)

CLASS_TREE = "ivy.functional.frontends.numpy.ndarray"

Expand Down Expand Up @@ -2766,35 +2769,42 @@ def test_numpy_ndarray_max(
class_tree=CLASS_TREE,
init_tree="numpy.array",
method_name="mean",
dtype_x_axis=helpers.dtype_values_axis(
available_dtypes=helpers.get_dtypes("float"),
min_axis=-1,
max_axis=0,
min_num_dims=1,
force_int_axis=True,
),
dtype_and_x=_statistical_dtype_values(function="mean"),
dtype=helpers.get_dtypes("float", full=False, none=True),
where=np_frontend_helpers.where(),
keep_dims=st.booleans(),
)
def test_numpy_ndarray_mean(
dtype_x_axis,
dtype_and_x,
dtype,
where,
keep_dims,
frontend_method_data,
init_flags,
method_flags,
backend_fw,
frontend,
on_device,
):
input_dtypes, x, axis = dtype_x_axis
input_dtypes, x, axis = dtype_and_x
where, input_dtypes, test_flags = np_frontend_helpers.handle_where_and_array_bools(
where=where,
input_dtype=input_dtypes,
test_flags=method_flags,
)
helpers.test_frontend_method(
init_input_dtypes=input_dtypes,
backend_to_test=backend_fw,
method_input_dtypes=input_dtypes,
method_input_dtypes=input_dtypes[1:],
init_all_as_kwargs_np={
"object": x[0],
},
method_all_as_kwargs_np={
"axis": axis,
"dtype": "float64",
"dtype": dtype[0],
"out": None,
"keepdims": keep_dims,
"where": where,
},
frontend=frontend,
frontend_method_data=frontend_method_data,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -223,31 +223,24 @@ def test_numpy_mean(
keep_dims,
):
input_dtypes, x, axis = dtype_and_x
if isinstance(axis, tuple):
axis = axis[0]

where, input_dtypes, test_flags = np_frontend_helpers.handle_where_and_array_bools(
where=where,
input_dtype=input_dtypes,
test_flags=test_flags,
)

np_frontend_helpers.test_frontend_function(
helpers.test_frontend_function(
input_dtypes=input_dtypes,
frontend=frontend,
backend_to_test=backend_fw,
test_flags=test_flags,
fn_tree=fn_tree,
on_device=on_device,
atol=1e-2,
rtol=1e-2,
x=x[0],
a=x[0],
axis=axis,
dtype=dtype[0],
out=None,
keepdims=keep_dims,
where=where,
test_values=False,
)


Expand Down Expand Up @@ -280,7 +273,7 @@ def test_numpy_nanmean(
test_flags=test_flags,
)

np_frontend_helpers.test_frontend_function(
helpers.test_frontend_function(
input_dtypes=input_dtypes,
frontend=frontend,
backend_to_test=backend_fw,
Expand All @@ -295,7 +288,6 @@ def test_numpy_nanmean(
out=None,
keepdims=keep_dims,
where=where,
test_values=False,
)


Expand Down

0 comments on commit ad3e83f

Please sign in to comment.