diff --git a/ivy/functional/frontends/torch/reduction_ops.py b/ivy/functional/frontends/torch/reduction_ops.py index d4a5305ed6642..c03e6400fb980 100644 --- a/ivy/functional/frontends/torch/reduction_ops.py +++ b/ivy/functional/frontends/torch/reduction_ops.py @@ -191,6 +191,42 @@ def nanmean(input, dim=None, keepdim=False, *, dtype=None, out=None): return ivy.nanmean(input, axis=dim, keepdims=keepdim, dtype=dtype, out=out) +@numpy_to_torch_style_args +@to_ivy_arrays_and_back +def nanmedian(input, dim=None, keepdim=False, *, out=None): + if dim is None: + flattened_input = ivy.flatten(input) + sorted_input = ivy.sort(flattened_input) + nonnan_index = int(sorted_input.shape[0] - ivy.isnan(sorted_input).sum()) + return sorted_input[(nonnan_index - 1) // 2] + + nanmedian_tuple = namedtuple("nanmedian", ["values", "indices"]) + + if input.ndim == 0: + result = nanmedian_tuple(input, ivy.array(0)) + else: + sorted_indices = ivy.argsort(input, axis=dim) + nonnan_index = ( + sorted_indices.shape[dim] - ivy.isnan(input).sum(axis=1) - 1 + ) // 2 + nonnan_index = ivy.expand_dims(nonnan_index, axis=1) + nanmedian_indices = ivy.gather_nd(sorted_indices, nonnan_index, batch_dims=1) + nanmedian_values = ivy.take_along_axis( + input, ivy.expand_dims(nanmedian_indices, axis=dim), dim + ).squeeze(axis=dim) + + if keepdim: + nanmedian_values = ivy.expand_dims(nanmedian_values, axis=dim) + nanmedian_indices = ivy.expand_dims(nanmedian_tuple, axis=dim) + + result = nanmedian_tuple(nanmedian_values, nanmedian_indices) + if out is not None: + ivy.inplace_update(out[0], result.values) + ivy.inplace_update(out[1], result.indices) + return out + return result + + @to_ivy_arrays_and_back @with_supported_dtypes( {"2.0.1 and below": ("float", "int")}, diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_reduction_ops.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_reduction_ops.py index 98fd3b4793061..3e38dd7e88825 100644 --- a/ivy_tests/test_ivy/test_frontends/test_torch/test_reduction_ops.py +++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_reduction_ops.py @@ -672,6 +672,40 @@ def test_torch_nanmean( ) +@handle_frontend_test( + fn_tree="torch.nanmedian", + dtype_input_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("numeric"), + min_num_dims=1, + valid_axis=True, + force_int_axis=True, + ), + keepdim=st.booleans(), +) +def test_torch_nanmedian( + *, + dtype_input_axis, + keepdim, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, input, dim = dtype_input_axis + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=input[0], + dim=dim, + keepdim=keepdim, + ) + + @handle_frontend_test( fn_tree="torch.nansum", dtype_and_x=_get_castable_dtype(