Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add extra arguments to min #27152

Merged
merged 18 commits into from
Dec 25, 2023
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion ivy/data_classes/array/statistical.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ def min(
*,
axis: Optional[Union[int, Sequence[int]]] = None,
keepdims: bool = False,
initial: Optional[Union[int, float, complex]] = None,
where: Optional[ivy.Array] = None,
out: Optional[ivy.Array] = None,
) -> ivy.Array:
"""
Expand All @@ -37,6 +39,11 @@ def min(
array (see :ref:`broadcasting`). Otherwise, if ``False``, the
reduced axes (dimensions) must not be included in the
result. Default: ``False``.
initial
The maximum value of an output element.
Must be present to allow computation on empty slice.
where
Elements to compare for minimum
out
optional output array, for writing the result to.

Expand Down Expand Up @@ -69,7 +76,14 @@ def min(
>>> print(y)
ivy.array(0.1)
"""
return ivy.min(self._data, axis=axis, keepdims=keepdims, out=out)
return ivy.min(
self._data,
axis=axis,
keepdims=keepdims,
initial=initial,
where=where,
out=out,
)

def max(
self: ivy.Array,
Expand Down
115 changes: 103 additions & 12 deletions ivy/data_classes/container/statistical.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,100 @@


class _ContainerWithStatistical(ContainerBase):
@staticmethod
def _static_min(
x: ivy.Container,
/,
*,
axis: Optional[Union[int, Sequence[int], ivy.Container]] = None,
keepdims: Union[bool, ivy.Container] = False,
initial: Optional[Union[int, float, complex, ivy.Container]] = None,
where: Optional[Union[ivy.Array, ivy.Container]] = None,
key_chains: Optional[Union[List[str], Dict[str, str], ivy.Container]] = None,
to_apply: Union[bool, ivy.Container] = True,
prune_unapplied: Union[bool, ivy.Container] = False,
map_sequences: Union[bool, ivy.Container] = False,
out: Optional[ivy.Container] = None,
):
"""
ivy.Container static method variant of ivy.min. This method simply wraps the
function, and so the docstring for ivy.min also applies to this method with
minimal changes.

Parameters
----------
self
Input container. Should have a real-valued data type.
axis
axis or axes along which minimum values must be computed.
By default, the minimum value must be computed over the
entire array. If a tuple of integers, minimum values must
be computed over multiple axes. Default: ``None``.
keepdims
optional boolean, if ``True``, the reduced axes
(dimensions) must be included in the result as
singleton dimensions, and, accordingly, the result
must be compatible with the input array
(see :ref:`broadcasting`). Otherwise, if ``False``, the
reduced axes (dimensions) must not be included in the
result. Default: ``False``.
initial
The maximum value of an output element.
Must be present to allow computation on empty slice.
where
Elements to compare for minimum
out
optional output array, for writing the result to.
Returns
-------
ret
if the minimum value was computed over the entire array,
a zero-dimensional array containing the minimum value;
otherwise, a non-zero-dimensional array containing the
minimum values. The returned array must have the same data type
as ``x``.
Examples
--------
With :class:`ivy.Container` input:
>> > x = ivy.Container(a=ivy.array([1, 2, 3]), \
b=ivy.array([2, 3, 4]))
>> > z = x.min()
>> > print(z)
{
a: ivy.array(1),
b: ivy.array(2)
}
>>> x = ivy.Container(a=ivy.array([[1, 2, 3],[-1,0,2]]),
... b=ivy.array([[2, 3, 4], [0, 1, 2]]))
>>> z = x.min(axis=1)
>>> print(z)
{
a:ivy.array([1,-1]),
b:ivy.array([2,0])
}
"""
return ContainerBase.cont_multi_map_in_function(
"min",
x,
axis=axis,
keepdims=keepdims,
initial=initial,
where=where,
key_chains=key_chains,
to_apply=to_apply,
prune_unapplied=prune_unapplied,
map_sequences=map_sequences,
out=out,
)

def min(
self: ivy.Container,
/,
*,
axis: Optional[Union[int, Sequence[int], ivy.Container]] = None,
keepdims: Union[bool, ivy.Container] = False,
initial: Optional[Union[int, float, complex, ivy.Container]] = None,
where: Optional[Union[ivy.Array, ivy.Container]] = None,
key_chains: Optional[Union[List[str], Dict[str, str], ivy.Container]] = None,
to_apply: Union[bool, ivy.Container] = True,
prune_unapplied: Union[bool, ivy.Container] = False,
Expand Down Expand Up @@ -43,6 +131,11 @@ def min(
(see :ref:`broadcasting`). Otherwise, if ``False``, the
reduced axes (dimensions) must not be included in the
result. Default: ``False``.
initial
The maximum value of an output element.
Must be present to allow computation on empty slice.
where
Elements to compare for minimum
out
optional output array, for writing the result to.

Expand Down Expand Up @@ -77,18 +170,16 @@ def min(
b:ivy.array([2,0])
}
"""
return self.cont_handle_inplace(
self.cont_map(
lambda x_, _: (
ivy.min(x_, axis=axis, keepdims=keepdims)
if ivy.is_array(x_)
else x_
),
key_chains=key_chains,
to_apply=to_apply,
prune_unapplied=prune_unapplied,
map_sequences=map_sequences,
),
return self._static_min(
self,
axis=axis,
keepdims=keepdims,
initial=initial,
where=where,
key_chains=key_chains,
to_apply=to_apply,
prune_unapplied=prune_unapplied,
map_sequences=map_sequences,
out=out,
)

Expand Down
6 changes: 5 additions & 1 deletion ivy/functional/backends/jax/statistical.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,14 @@ def min(
*,
axis: Optional[Union[int, Sequence[int]]] = None,
keepdims: bool = False,
initial: Optional[Union[int, float, complex]] = None,
where: Optional[JaxArray] = None,
out: Optional[JaxArray] = None,
) -> JaxArray:
axis = tuple(axis) if isinstance(axis, list) else axis
return jnp.min(a=jnp.asarray(x), axis=axis, keepdims=keepdims)
return jnp.min(
a=jnp.asarray(x), axis=axis, keepdims=keepdims, initial=initial, where=where
)


def max(
Expand Down
8 changes: 7 additions & 1 deletion ivy/functional/backends/numpy/statistical.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,16 @@ def min(
*,
axis: Optional[Union[int, Sequence[int]]] = None,
keepdims: bool = False,
initial: Optional[Union[int, float, complex]] = None,
where: Optional[np.ndarray] = None,
out: Optional[np.ndarray] = None,
) -> np.ndarray:
axis = tuple(axis) if isinstance(axis, list) else axis
return np.asarray(np.amin(a=x, axis=axis, keepdims=keepdims, out=out))
return np.asarray(
np.amin(
a=x, axis=axis, keepdims=keepdims, initial=initial, where=where, out=out
)
)


min.support_native_out = True
Expand Down
17 changes: 17 additions & 0 deletions ivy/functional/backends/paddle/statistical.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ def min(
*,
axis: Optional[Union[int, Sequence[int]]] = None,
keepdims: bool = False,
initial: Optional[Union[int, float, complex]] = None,
where: Optional[paddle.Tensor] = None,
out: Optional[paddle.Tensor] = None,
) -> paddle.Tensor:
ret_dtype = x.dtype
Expand All @@ -39,6 +41,18 @@ def min(
imag = paddle.amin(x.imag(), axis=axis, keepdim=keepdims)
ret = paddle.complex(real, imag)
else:
if where is not None:
max_val = (
ivy.iinfo(x.dtype).max
if ivy.is_int_dtype(x.dtype)
else ivy.finfo(x.dtype).max
)
max_val = max_val / 10
# max_val becomes negative after multiplying with paddle.ones_like(x)
# therefore reduced it
val = paddle.ones_like(x) * max_val
val = val.astype(ret_dtype)
x = paddle.where(where, x, val)
ret = paddle.amin(x, axis=axis, keepdim=keepdims)
# The following code is to simulate other frameworks
# output shapes behaviour since min output dim is 1 in paddle
Expand All @@ -47,6 +61,9 @@ def min(
axis = None
if (x.ndim == 1 or axis is None) and not keepdims:
ret = ret.squeeze()
if initial is not None:
initial = paddle.to_tensor(initial, dtype=ret_dtype)
ret = paddle.minimum(ret, initial)
return ret.astype(ret_dtype)


Expand Down
16 changes: 14 additions & 2 deletions ivy/functional/backends/tensorflow/statistical.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,29 @@
# -------------------#


@with_unsupported_dtypes({"2.14.0 and below": ("complex",)}, backend_version)
@with_unsupported_dtypes({"2.14.0 and below": ("complex", "bool")}, backend_version)
def min(
x: Union[tf.Tensor, tf.Variable],
/,
*,
axis: Optional[Union[int, Sequence[int]]] = None,
keepdims: bool = False,
initial: Optional[Union[int, float, complex]] = None,
where: Optional[Union[tf.Tensor, tf.Variable]] = None,
out: Optional[Union[tf.Tensor, tf.Variable]] = None,
) -> Union[tf.Tensor, tf.Variable]:
axis = tuple(axis) if isinstance(axis, list) else axis
return tf.math.reduce_min(x, axis=axis, keepdims=keepdims)
if where is not None:
max_val = (
ivy.iinfo(x.dtype).max
if ivy.is_int_dtype(x.dtype)
else ivy.finfo(x.dtype).max
)
x = tf.where(where, x, tf.ones_like(x) * max_val)
result = tf.math.reduce_min(x, axis=axis, keepdims=keepdims)
if initial is not None:
result = tf.minimum(result, initial)
return result


def max(
Expand Down
19 changes: 17 additions & 2 deletions ivy/functional/backends/torch/statistical.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,31 @@ def min(
*,
axis: Optional[Union[int, Sequence[int]]] = None,
keepdims: bool = False,
initial: Optional[Union[int, float, complex]] = None,
where: Optional[torch.Tensor] = None,
out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if axis == ():
if ivy.exists(out):
return ivy.inplace_update(out, x)
else:
return x
if where is not None:
max_val = (
ivy.iinfo(x.dtype).max
if ivy.is_int_dtype(x.dtype)
else ivy.finfo(x.dtype).max
)
val = torch.ones_like(x) * max_val
val = val.type(x.dtype)
x = torch.where(where, x, val)
if not keepdims and not axis and axis != 0:
return torch.amin(input=x, out=out)
return torch.amin(input=x, dim=axis, keepdim=keepdims, out=out)
result = torch.amin(input=x, out=out)
result = torch.amin(input=x, dim=axis, keepdim=keepdims, out=out)
if initial is not None:
initial = torch.tensor(initial, dtype=x.dtype)
result = torch.minimum(result, initial)
return result


min.support_native_out = True
Expand Down
11 changes: 10 additions & 1 deletion ivy/functional/ivy/statistical.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ def min(
*,
axis: Optional[Union[int, Sequence[int]]] = None,
keepdims: bool = False,
initial: Optional[Union[int, float, complex]] = None,
where: Optional[ivy.Array] = None,
out: Optional[ivy.Array] = None,
) -> ivy.Array:
"""
Expand Down Expand Up @@ -83,6 +85,11 @@ def min(
compatible with the input array (see :ref:`broadcasting`). Otherwise,
if ``False``, the reduced axes (dimensions) must not be included in the result.
Default: ``False``.
initial
The maximum value of an output element.
Must be present to allow computation on empty slice.
where
Elements to compare for minimum
out
optional output array, for writing the result to.

Expand Down Expand Up @@ -140,7 +147,9 @@ def min(
b: ivy.array(2)
}
"""
return current_backend(x).min(x, axis=axis, keepdims=keepdims, out=out)
return current_backend(x).min(
x, axis=axis, keepdims=keepdims, initial=initial, where=where, out=out
)


@handle_exceptions
Expand Down
27 changes: 23 additions & 4 deletions ivy_tests/test_ivy/test_functional/test_core/test_statistical.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,20 @@ def _statistical_dtype_values(draw, *, function, min_value=None, max_value=None)
| helpers.floats(min_value=0, max_value=max_correction - 1)
)
return dtype, values, axis, correction
return dtype, values, axis

if isinstance(axis, tuple):
axis = axis[0]

where_shape = draw(
helpers.mutually_broadcastable_shapes(
num_shapes=1, base_shape=shape, min_dims=0, max_dims=axis
vedpatwardhan marked this conversation as resolved.
Show resolved Hide resolved
)
)
dtype3, where = draw(
helpers.dtype_and_values(available_dtypes=["bool"], shape=where_shape[0])
)

return dtype, values, axis, dtype3, where


# --- Main --- #
Expand Down Expand Up @@ -259,18 +272,24 @@ def test_mean(*, dtype_and_x, keep_dims, test_flags, backend_fw, fn_name, on_dev
fn_tree="functional.ivy.min",
dtype_and_x=_statistical_dtype_values(function="min"),
keep_dims=st.booleans(),
test_gradients=st.just(False),
initial=st.integers(min_value=-5, max_value=5),
)
def test_min(*, dtype_and_x, keep_dims, test_flags, backend_fw, fn_name, on_device):
input_dtype, x, axis = dtype_and_x
def test_min(
*, dtype_and_x, keep_dims, initial, test_flags, backend_fw, fn_name, on_device
):
input_dtype, x, axis, dtype3, where = dtype_and_x
helpers.test_function(
input_dtypes=input_dtype,
input_dtypes=[input_dtype[0], dtype3[0]],
test_flags=test_flags,
backend_to_test=backend_fw,
fn_name=fn_name,
on_device=on_device,
x=x[0],
axis=axis,
keepdims=keep_dims,
initial=initial,
where=where[0],
)


Expand Down
Loading