Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add extra arguments to min #27152

Merged
merged 18 commits into from
Dec 25, 2023
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion ivy/data_classes/array/statistical.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ def min(
*,
axis: Optional[Union[int, Sequence[int]]] = None,
keepdims: bool = False,
initial: Optional[Union[int, float, complex]] = None,
where: Optional[ivy.Array] = None,
out: Optional[ivy.Array] = None,
) -> ivy.Array:
"""
Expand All @@ -37,6 +39,11 @@ def min(
array (see :ref:`broadcasting`). Otherwise, if ``False``, the
reduced axes (dimensions) must not be included in the
result. Default: ``False``.
initial
The maximum value of an output element.
Must be present to allow computation on empty slice.
where
Elements to compare for minimum
out
optional output array, for writing the result to.

Expand Down Expand Up @@ -69,7 +76,14 @@ def min(
>>> print(y)
ivy.array(0.1)
"""
return ivy.min(self._data, axis=axis, keepdims=keepdims, out=out)
return ivy.min(
self._data,
axis=axis,
keepdims=keepdims,
initial=initial,
where=where,
out=out,
)

def max(
self: ivy.Array,
Expand Down
115 changes: 103 additions & 12 deletions ivy/data_classes/container/statistical.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,100 @@


class _ContainerWithStatistical(ContainerBase):
@staticmethod
def _static_min(
x: ivy.Container,
/,
*,
axis: Optional[Union[int, Sequence[int], ivy.Container]] = None,
keepdims: Union[bool, ivy.Container] = False,
initial: Optional[Union[int, float, complex, ivy.Container]] = None,
where: Optional[Union[ivy.Array, ivy.Container]] = None,
key_chains: Optional[Union[List[str], Dict[str, str], ivy.Container]] = None,
to_apply: Union[bool, ivy.Container] = True,
prune_unapplied: Union[bool, ivy.Container] = False,
map_sequences: Union[bool, ivy.Container] = False,
out: Optional[ivy.Container] = None,
):
"""
ivy.Container static method variant of ivy.min. This method simply wraps the
function, and so the docstring for ivy.min also applies to this method with
minimal changes.

Parameters
----------
self
Input container. Should have a real-valued data type.
axis
axis or axes along which minimum values must be computed.
By default, the minimum value must be computed over the
entire array. If a tuple of integers, minimum values must
be computed over multiple axes. Default: ``None``.
keepdims
optional boolean, if ``True``, the reduced axes
(dimensions) must be included in the result as
singleton dimensions, and, accordingly, the result
must be compatible with the input array
(see :ref:`broadcasting`). Otherwise, if ``False``, the
reduced axes (dimensions) must not be included in the
result. Default: ``False``.
initial
The maximum value of an output element.
Must be present to allow computation on empty slice.
where
Elements to compare for minimum
out
optional output array, for writing the result to.
Returns
-------
ret
if the minimum value was computed over the entire array,
a zero-dimensional array containing the minimum value;
otherwise, a non-zero-dimensional array containing the
minimum values. The returned array must have the same data type
as ``x``.
Examples
--------
With :class:`ivy.Container` input:
>> > x = ivy.Container(a=ivy.array([1, 2, 3]), \
b=ivy.array([2, 3, 4]))
>> > z = x.min()
>> > print(z)
{
a: ivy.array(1),
b: ivy.array(2)
}
>>> x = ivy.Container(a=ivy.array([[1, 2, 3],[-1,0,2]]),
... b=ivy.array([[2, 3, 4], [0, 1, 2]]))
>>> z = x.min(axis=1)
>>> print(z)
{
a:ivy.array([1,-1]),
b:ivy.array([2,0])
}
"""
return ContainerBase.cont_multi_map_in_function(
"min",
x,
axis=axis,
keepdims=keepdims,
initial=initial,
where=where,
key_chains=key_chains,
to_apply=to_apply,
prune_unapplied=prune_unapplied,
map_sequences=map_sequences,
out=out,
)

def min(
self: ivy.Container,
/,
*,
axis: Optional[Union[int, Sequence[int], ivy.Container]] = None,
keepdims: Union[bool, ivy.Container] = False,
initial: Optional[Union[int, float, complex, ivy.Container]] = None,
where: Optional[Union[ivy.Array, ivy.Container]] = None,
key_chains: Optional[Union[List[str], Dict[str, str], ivy.Container]] = None,
to_apply: Union[bool, ivy.Container] = True,
prune_unapplied: Union[bool, ivy.Container] = False,
Expand Down Expand Up @@ -43,6 +131,11 @@ def min(
(see :ref:`broadcasting`). Otherwise, if ``False``, the
reduced axes (dimensions) must not be included in the
result. Default: ``False``.
initial
The maximum value of an output element.
Must be present to allow computation on empty slice.
where
Elements to compare for minimum
out
optional output array, for writing the result to.

Expand Down Expand Up @@ -77,18 +170,16 @@ def min(
b:ivy.array([2,0])
}
"""
return self.cont_handle_inplace(
self.cont_map(
lambda x_, _: (
ivy.min(x_, axis=axis, keepdims=keepdims)
if ivy.is_array(x_)
else x_
),
key_chains=key_chains,
to_apply=to_apply,
prune_unapplied=prune_unapplied,
map_sequences=map_sequences,
),
return self._static_min(
self,
axis=axis,
keepdims=keepdims,
initial=initial,
where=where,
key_chains=key_chains,
to_apply=to_apply,
prune_unapplied=prune_unapplied,
map_sequences=map_sequences,
out=out,
)

Expand Down
6 changes: 5 additions & 1 deletion ivy/functional/backends/jax/statistical.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,14 @@ def min(
*,
axis: Optional[Union[int, Sequence[int]]] = None,
keepdims: bool = False,
initial: Optional[Union[int, float, complex]] = None,
where: Optional[JaxArray] = None,
out: Optional[JaxArray] = None,
) -> JaxArray:
axis = tuple(axis) if isinstance(axis, list) else axis
return jnp.min(a=jnp.asarray(x), axis=axis, keepdims=keepdims)
return jnp.min(
a=jnp.asarray(x), axis=axis, keepdims=keepdims, initial=initial, where=where
)


def max(
Expand Down
10 changes: 9 additions & 1 deletion ivy/functional/backends/numpy/statistical.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,18 @@ def min(
*,
axis: Optional[Union[int, Sequence[int]]] = None,
keepdims: bool = False,
initial: Optional[Union[int, float, complex]] = None,
where: Optional[np.ndarray] = None,
out: Optional[np.ndarray] = None,
) -> np.ndarray:
if where is None:
vedpatwardhan marked this conversation as resolved.
Show resolved Hide resolved
where = True
axis = tuple(axis) if isinstance(axis, list) else axis
return np.asarray(np.amin(a=x, axis=axis, keepdims=keepdims, out=out))
return np.asarray(
np.amin(
a=x, axis=axis, keepdims=keepdims, initial=initial, where=where, out=out
)
)


min.support_native_out = True
Expand Down
18 changes: 18 additions & 0 deletions ivy/functional/backends/paddle/statistical.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ def min(
*,
axis: Optional[Union[int, Sequence[int]]] = None,
keepdims: bool = False,
initial: Optional[Union[int, float, complex]] = None,
where: Optional[paddle.Tensor] = None,
out: Optional[paddle.Tensor] = None,
) -> paddle.Tensor:
ret_dtype = x.dtype
Expand All @@ -39,14 +41,30 @@ def min(
imag = paddle.amin(x.imag(), axis=axis, keepdim=keepdims)
ret = paddle.complex(real, imag)
else:
if where is not None:
if x.dtype == paddle.int32:
max_val = 2147483647
elif x.dtype == paddle.int64:
max_val = 922337203685477580
else:
max_val = float("inf")
val = paddle.ones_like(x) * max_val
# print("val=",val)
val = val.astype(ret_dtype)
x = paddle.where(where, x, val)
# print(x)
ret = paddle.amin(x, axis=axis, keepdim=keepdims)
# print(ret)
# The following code is to simulate other frameworks
# output shapes behaviour since min output dim is 1 in paddle
if isinstance(axis, Sequence):
if len(axis) == x.ndim:
axis = None
if (x.ndim == 1 or axis is None) and not keepdims:
ret = ret.squeeze()
if initial is not None:
initial = paddle.to_tensor(initial, dtype=ret_dtype)
ret = paddle.minimum(ret, initial)
return ret.astype(ret_dtype)


Expand Down
31 changes: 29 additions & 2 deletions ivy/functional/backends/tensorflow/statistical.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,44 @@
# -------------------#


@with_unsupported_dtypes({"2.14.0 and below": ("complex",)}, backend_version)
@with_unsupported_dtypes(
{"2.14.0 and below": ("complex", "bool", "uint64")}, backend_version
Aaryan562 marked this conversation as resolved.
Show resolved Hide resolved
)
def min(
x: Union[tf.Tensor, tf.Variable],
/,
*,
axis: Optional[Union[int, Sequence[int]]] = None,
keepdims: bool = False,
initial: Optional[Union[int, float, complex]] = None,
where: Optional[Union[tf.Tensor, tf.Variable]] = None,
out: Optional[Union[tf.Tensor, tf.Variable]] = None,
) -> Union[tf.Tensor, tf.Variable]:
axis = tuple(axis) if isinstance(axis, list) else axis
return tf.math.reduce_min(x, axis=axis, keepdims=keepdims)
if where is not None:
if x.dtype == tf.int8:
vedpatwardhan marked this conversation as resolved.
Show resolved Hide resolved
max_val = tf.constant(127, dtype=x.dtype)
elif x.dtype == tf.int16:
max_val = tf.constant(32767, dtype=x.dtype)
elif x.dtype == tf.int32:
max_val = tf.constant(2147483647, dtype=x.dtype)
elif x.dtype == tf.int64:
max_val = tf.constant(9223372036854775807, dtype=x.dtype)
elif x.dtype == tf.uint8:
max_val = tf.constant(255, dtype=x.dtype)
elif x.dtype == tf.uint16:
max_val = tf.constant(65535, dtype=x.dtype)
elif x.dtype == tf.uint32:
max_val = tf.constant(4294967295, dtype=x.dtype)
elif x.dtype == tf.uint64:
max_val = tf.constant(18446744073709551615, dtype=x.dtype)
else:
max_val = tf.constant(float("inf"), dtype=x.dtype)
x = tf.where(where, x, tf.ones_like(x) * max_val)
result = tf.math.reduce_min(x, axis=axis, keepdims=keepdims)
if initial is not None:
result = tf.minimum(result, initial)
return result


def max(
Expand Down
30 changes: 27 additions & 3 deletions ivy/functional/backends/torch/statistical.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,23 +14,47 @@
# -------------------#


@with_unsupported_dtypes({"2.1.0 and below": ("complex",)}, backend_version)
@with_unsupported_dtypes({"2.1.0 and below": ("complex", "bool")}, backend_version)
Aaryan562 marked this conversation as resolved.
Show resolved Hide resolved
def min(
x: torch.Tensor,
/,
*,
axis: Optional[Union[int, Sequence[int]]] = None,
keepdims: bool = False,
initial: Optional[Union[int, float, complex]] = None,
where: Optional[torch.Tensor] = None,
out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if axis == ():
if ivy.exists(out):
return ivy.inplace_update(out, x)
else:
return x

Aaryan562 marked this conversation as resolved.
Show resolved Hide resolved
if where is not None:
if x.dtype == torch.int8:
max_val = 127
elif x.dtype == torch.int16:
max_val = 32767
elif x.dtype == torch.int32:
max_val = 2147483647
elif x.dtype == torch.int64:
max_val = 922337203685477580
elif x.dtype == torch.uint8:
max_val = 255
else:
max_val = float("inf")
val = torch.ones_like(x) * max_val
# print("val=",val)
val = val.type(x.dtype)
x = torch.where(where, x, val)
if not keepdims and not axis and axis != 0:
return torch.amin(input=x, out=out)
return torch.amin(input=x, dim=axis, keepdim=keepdims, out=out)
result = torch.amin(input=x, out=out)
result = torch.amin(input=x, dim=axis, keepdim=keepdims, out=out)
if initial is not None:
initial = torch.tensor(initial, dtype=x.dtype)
result = torch.minimum(result, initial)
return result


min.support_native_out = True
Expand Down
11 changes: 10 additions & 1 deletion ivy/functional/ivy/statistical.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ def min(
*,
axis: Optional[Union[int, Sequence[int]]] = None,
keepdims: bool = False,
initial: Optional[Union[int, float, complex]] = None,
where: Optional[ivy.Array] = None,
out: Optional[ivy.Array] = None,
) -> ivy.Array:
"""
Expand Down Expand Up @@ -83,6 +85,11 @@ def min(
compatible with the input array (see :ref:`broadcasting`). Otherwise,
if ``False``, the reduced axes (dimensions) must not be included in the result.
Default: ``False``.
initial
The maximum value of an output element.
Must be present to allow computation on empty slice.
where
Elements to compare for minimum
out
optional output array, for writing the result to.

Expand Down Expand Up @@ -140,7 +147,9 @@ def min(
b: ivy.array(2)
}
"""
return current_backend(x).min(x, axis=axis, keepdims=keepdims, out=out)
return current_backend(x).min(
x, axis=axis, keepdims=keepdims, initial=initial, where=where, out=out
)


@handle_exceptions
Expand Down
Loading
Loading