Skip to content

Commit

Permalink
feat(frontends): Add nanmedian to PyTorch reduction ops (#26467)
Browse files Browse the repository at this point in the history
Co-authored-by: ivy-branch <[email protected]>
  • Loading branch information
spacefarers and ivy-branch authored Oct 4, 2023
1 parent 23d2890 commit e5d4dd1
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 0 deletions.
36 changes: 36 additions & 0 deletions ivy/functional/frontends/torch/reduction_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,42 @@ def nanmean(input, dim=None, keepdim=False, *, dtype=None, out=None):
return ivy.nanmean(input, axis=dim, keepdims=keepdim, dtype=dtype, out=out)


@numpy_to_torch_style_args
@to_ivy_arrays_and_back
def nanmedian(input, dim=None, keepdim=False, *, out=None):
if dim is None:
flattened_input = ivy.flatten(input)
sorted_input = ivy.sort(flattened_input)
nonnan_index = int(sorted_input.shape[0] - ivy.isnan(sorted_input).sum())
return sorted_input[(nonnan_index - 1) // 2]

nanmedian_tuple = namedtuple("nanmedian", ["values", "indices"])

if input.ndim == 0:
result = nanmedian_tuple(input, ivy.array(0))
else:
sorted_indices = ivy.argsort(input, axis=dim)
nonnan_index = (
sorted_indices.shape[dim] - ivy.isnan(input).sum(axis=1) - 1
) // 2
nonnan_index = ivy.expand_dims(nonnan_index, axis=1)
nanmedian_indices = ivy.gather_nd(sorted_indices, nonnan_index, batch_dims=1)
nanmedian_values = ivy.take_along_axis(
input, ivy.expand_dims(nanmedian_indices, axis=dim), dim
).squeeze(axis=dim)

if keepdim:
nanmedian_values = ivy.expand_dims(nanmedian_values, axis=dim)
nanmedian_indices = ivy.expand_dims(nanmedian_tuple, axis=dim)

result = nanmedian_tuple(nanmedian_values, nanmedian_indices)
if out is not None:
ivy.inplace_update(out[0], result.values)
ivy.inplace_update(out[1], result.indices)
return out
return result


@to_ivy_arrays_and_back
@with_supported_dtypes(
{"2.0.1 and below": ("float", "int")},
Expand Down
34 changes: 34 additions & 0 deletions ivy_tests/test_ivy/test_frontends/test_torch/test_reduction_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -672,6 +672,40 @@ def test_torch_nanmean(
)


@handle_frontend_test(
fn_tree="torch.nanmedian",
dtype_input_axis=helpers.dtype_values_axis(
available_dtypes=helpers.get_dtypes("numeric"),
min_num_dims=1,
valid_axis=True,
force_int_axis=True,
),
keepdim=st.booleans(),
)
def test_torch_nanmedian(
*,
dtype_input_axis,
keepdim,
on_device,
fn_tree,
frontend,
test_flags,
backend_fw,
):
input_dtype, input, dim = dtype_input_axis
helpers.test_frontend_function(
input_dtypes=input_dtype,
backend_to_test=backend_fw,
frontend=frontend,
test_flags=test_flags,
fn_tree=fn_tree,
on_device=on_device,
input=input[0],
dim=dim,
keepdim=keepdim,
)


@handle_frontend_test(
fn_tree="torch.nansum",
dtype_and_x=_get_castable_dtype(
Expand Down

0 comments on commit e5d4dd1

Please sign in to comment.