From ad3e83f94509cd40ada69951dbf74993862dd7b9 Mon Sep 17 00:00:00 2001 From: Shreyansh Bardia <104841983+ShreyanshBardia@users.noreply.github.com> Date: Wed, 27 Sep 2023 17:01:52 +0530 Subject: [PATCH] fix: corrected implementation of where param numpy mean (#23477) Co-authored-by: @AnnaTz --- .../frontends/numpy/ndarray/ndarray.py | 2 +- .../statistics/averages_and_variances.py | 64 +++++-------------- .../test_numpy/test_ndarray/test_ndarray.py | 32 ++++++---- .../test_averages_and_variances.py | 14 +--- 4 files changed, 41 insertions(+), 71 deletions(-) diff --git a/ivy/functional/frontends/numpy/ndarray/ndarray.py b/ivy/functional/frontends/numpy/ndarray/ndarray.py index 9c6a652673631..8ac0ccb0458c5 100644 --- a/ivy/functional/frontends/numpy/ndarray/ndarray.py +++ b/ivy/functional/frontends/numpy/ndarray/ndarray.py @@ -165,7 +165,7 @@ def any(self, axis=None, dtype=None, out=None, keepdims=False, *, where=True): def argsort(self, *, axis=-1, kind=None, order=None): return np_frontend.argsort(self, axis=axis, kind=kind, order=order) - def mean(self, *, axis=None, dtype=None, out=None, keepdims=False, where=True): + def mean(self, axis=None, dtype=None, out=None, keepdims=False, *, where=True): return np_frontend.mean( self, axis=axis, diff --git a/ivy/functional/frontends/numpy/statistics/averages_and_variances.py b/ivy/functional/frontends/numpy/statistics/averages_and_variances.py index 1e68cc108b0a6..f22f4756a95c8 100644 --- a/ivy/functional/frontends/numpy/statistics/averages_and_variances.py +++ b/ivy/functional/frontends/numpy/statistics/averages_and_variances.py @@ -72,62 +72,30 @@ def cov( @handle_numpy_dtype @to_ivy_arrays_and_back @from_zero_dim_arrays_to_scalar -def mean( - a, - /, - *, - axis=None, - keepdims=False, - out=None, - dtype=None, - where=True, -): +def mean(a, axis=None, dtype=None, out=None, keepdims=False, *, where=True): axis = tuple(axis) if isinstance(axis, list) else axis - if dtype: - a = ivy.astype(ivy.array(a), ivy.as_ivy_dtype(dtype)) - - ret = ivy.mean(a, axis=axis, keepdims=keepdims, out=out) - if ivy.is_array(where): - ret = ivy.where(where, ret, ivy.default(out, ivy.zeros_like(ret)), out=out) + dtype = dtype or a.dtype if not ivy.is_int_dtype(a.dtype) else ivy.float64 + where = ivy.where(where, ivy.ones_like(a), 0) + if where is not True: + a = ivy.where(where, a, 0.0) + sum = ivy.sum(a, axis=axis, keepdims=keepdims, dtype=dtype) + cnt = ivy.sum(where, axis=axis, keepdims=keepdims, dtype=int) + ret = ivy.divide(sum, cnt, out=out) + else: + ret = ivy.mean(a.astype(dtype), axis=axis, keepdims=keepdims, out=out) - return ret + return ret.astype(dtype) @handle_numpy_out @handle_numpy_dtype @to_ivy_arrays_and_back @from_zero_dim_arrays_to_scalar -def nanmean( - a, - /, - *, - axis=None, - keepdims=False, - out=None, - dtype=None, - where=True, -): - is_nan = ivy.isnan(a) - axis = tuple(axis) if isinstance(axis, list) else axis - - if not ivy.any(is_nan): - if dtype: - a = ivy.astype(ivy.array(a), ivy.as_ivy_dtype(dtype)) - ret = ivy.mean(a, axis=axis, keepdims=keepdims, out=out) - - if ivy.is_array(where): - ret = ivy.where(where, ret, ivy.default(out, ivy.zeros_like(ret)), out=out) - - else: - a = [i for i in a if ivy.isnan(i) is False] - - if dtype: - a = ivy.astype(ivy.array(a), ivy.as_ivy_dtype(dtype)) - ret = ivy.mean(a, axis=axis, keepdims=keepdims, out=out) - - if ivy.is_array(where): - ret = ivy.where(where, ret, ivy.default(out, ivy.zeros_like(ret)), out=out) - +def nanmean(a, axis=None, dtype=None, out=None, keepdims=False, *, where=True): + where = ~ivy.isnan(a) & where + ret = mean(a, axis, dtype, keepdims=keepdims, where=where).ivy_array + if out is not None: + out.data = ret.data return ret diff --git a/ivy_tests/test_ivy/test_frontends/test_numpy/test_ndarray/test_ndarray.py b/ivy_tests/test_ivy/test_frontends/test_numpy/test_ndarray/test_ndarray.py index f595f36a290c3..11e761c60e2ec 100644 --- a/ivy_tests/test_ivy/test_frontends/test_numpy/test_ndarray/test_ndarray.py +++ b/ivy_tests/test_ivy/test_frontends/test_numpy/test_ndarray/test_ndarray.py @@ -28,6 +28,9 @@ from ivy_tests.test_ivy.test_frontends.test_numpy.test_manipulation_routines.test_changing_number_of_dimensions import ( # noqa _squeeze_helper, ) +from ivy_tests.test_ivy.test_functional.test_core.test_statistical import ( + _statistical_dtype_values, +) CLASS_TREE = "ivy.functional.frontends.numpy.ndarray" @@ -2766,16 +2769,16 @@ def test_numpy_ndarray_max( class_tree=CLASS_TREE, init_tree="numpy.array", method_name="mean", - dtype_x_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("float"), - min_axis=-1, - max_axis=0, - min_num_dims=1, - force_int_axis=True, - ), + dtype_and_x=_statistical_dtype_values(function="mean"), + dtype=helpers.get_dtypes("float", full=False, none=True), + where=np_frontend_helpers.where(), + keep_dims=st.booleans(), ) def test_numpy_ndarray_mean( - dtype_x_axis, + dtype_and_x, + dtype, + where, + keep_dims, frontend_method_data, init_flags, method_flags, @@ -2783,18 +2786,25 @@ def test_numpy_ndarray_mean( frontend, on_device, ): - input_dtypes, x, axis = dtype_x_axis + input_dtypes, x, axis = dtype_and_x + where, input_dtypes, test_flags = np_frontend_helpers.handle_where_and_array_bools( + where=where, + input_dtype=input_dtypes, + test_flags=method_flags, + ) helpers.test_frontend_method( init_input_dtypes=input_dtypes, backend_to_test=backend_fw, - method_input_dtypes=input_dtypes, + method_input_dtypes=input_dtypes[1:], init_all_as_kwargs_np={ "object": x[0], }, method_all_as_kwargs_np={ "axis": axis, - "dtype": "float64", + "dtype": dtype[0], "out": None, + "keepdims": keep_dims, + "where": where, }, frontend=frontend, frontend_method_data=frontend_method_data, diff --git a/ivy_tests/test_ivy/test_frontends/test_numpy/test_statistics/test_averages_and_variances.py b/ivy_tests/test_ivy/test_frontends/test_numpy/test_statistics/test_averages_and_variances.py index 38b605e86d1f8..74c69698e96e0 100644 --- a/ivy_tests/test_ivy/test_frontends/test_numpy/test_statistics/test_averages_and_variances.py +++ b/ivy_tests/test_ivy/test_frontends/test_numpy/test_statistics/test_averages_and_variances.py @@ -223,31 +223,24 @@ def test_numpy_mean( keep_dims, ): input_dtypes, x, axis = dtype_and_x - if isinstance(axis, tuple): - axis = axis[0] - where, input_dtypes, test_flags = np_frontend_helpers.handle_where_and_array_bools( where=where, input_dtype=input_dtypes, test_flags=test_flags, ) - - np_frontend_helpers.test_frontend_function( + helpers.test_frontend_function( input_dtypes=input_dtypes, frontend=frontend, backend_to_test=backend_fw, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - atol=1e-2, - rtol=1e-2, - x=x[0], + a=x[0], axis=axis, dtype=dtype[0], out=None, keepdims=keep_dims, where=where, - test_values=False, ) @@ -280,7 +273,7 @@ def test_numpy_nanmean( test_flags=test_flags, ) - np_frontend_helpers.test_frontend_function( + helpers.test_frontend_function( input_dtypes=input_dtypes, frontend=frontend, backend_to_test=backend_fw, @@ -295,7 +288,6 @@ def test_numpy_nanmean( out=None, keepdims=keep_dims, where=where, - test_values=False, )