Skip to content

Commit

Permalink
feat: Added unsorted_segment_mean function (#24984)
Browse files Browse the repository at this point in the history
Co-authored-by: ivy-branch <[email protected]>
Co-authored-by: Rishab Mallick <[email protected]>
  • Loading branch information
3 people authored Oct 29, 2023
1 parent ea0eaad commit a84b6b6
Show file tree
Hide file tree
Showing 11 changed files with 334 additions and 12 deletions.
44 changes: 44 additions & 0 deletions ivy/data_classes/array/experimental/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,50 @@ def mel_weight_matrix(
upper_edge_hertz,
)

def unsorted_segment_mean(
    self: ivy.Array,
    segment_ids: ivy.Array,
    num_segments: Union[int, ivy.Array],
) -> ivy.Array:
    """
    Compute the mean of values in the array 'self' based on segment identifiers.

    This is the instance-method variant of ``ivy.unsorted_segment_mean``; it
    forwards the wrapped native array (``self._data``) to that function.

    Parameters
    ----------
    self : ivy.Array
        The array from which to gather values.
    segment_ids : ivy.Array
        Must be in the same size with the first dimension of `self`. Has to be
        of integer data type. The index-th element of `segment_ids` array is
        the segment identifier for the index-th element of `self`.
    num_segments : Union[int, ivy.Array]
        An integer or array representing the total number of distinct segment IDs.

    Returns
    -------
    ret : ivy.Array
        The output array, representing the result of a segmented mean operation.
        For each segment, it computes the mean of values in `self` where
        `segment_ids` equals the segment ID. Its first dimension has length
        `num_segments`.

    Examples
    --------
    >>> data = ivy.array([1.0, 2.0, 3.0, 4.0])
    >>> segment_ids = ivy.array([0, 0, 0, 0])
    >>> num_segments = 1
    >>> result = ivy.unsorted_segment_mean(data, segment_ids, num_segments)
    >>> result
    ivy.array([2.5])

    >>> data = ivy.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
    >>> segment_ids = ivy.array([0, 0, 1, 1, 2, 2])
    >>> num_segments = 3
    >>> result = ivy.unsorted_segment_mean(data, segment_ids, num_segments)
    >>> result
    ivy.array([1.5, 3.5, 5.5])
    """
    return ivy.unsorted_segment_mean(self._data, segment_ids, num_segments)


def polyval(
coeffs=ivy.Array,
Expand Down
107 changes: 107 additions & 0 deletions ivy/data_classes/container/experimental/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1202,6 +1202,57 @@ def mel_weight_matrix(
)

@staticmethod
def static_unsorted_segment_mean(
    data: ivy.Container,
    segment_ids: Union[ivy.Array, ivy.Container],
    num_segments: Union[int, ivy.Container],
    *,
    key_chains: Optional[Union[List[str], Dict[str, str], ivy.Container]] = None,
    to_apply: Union[bool, ivy.Container] = True,
    prune_unapplied: Union[bool, ivy.Container] = False,
    map_sequences: Union[bool, ivy.Container] = False,
) -> ivy.Container:
    """
    Compute per-segment means of the values held in a container.

    Static container variant of ``unsorted_segment_mean``: the call is
    dispatched through ``ContainerBase.cont_multi_map_in_function`` under the
    function name ``"unsorted_segment_mean"``.

    Parameters
    ----------
    data : ivy.Container
        Input container from which the values are gathered.
    segment_ids : Union[ivy.Array, ivy.Container]
        Integers giving the segment identifier of each element in 'data'.
    num_segments : Union[int, ivy.Container]
        The total number of distinct segment IDs.
    key_chains : Optional[Union[List[str], Dict[str, str], ivy.Container]]
        The key-chains to apply or not apply the method to. Default is None.
    to_apply : Union[bool, ivy.Container], optional
        If True, the method is applied to `key_chains`; otherwise those
        key-chains are skipped. Default is True.
    prune_unapplied : Union[bool, ivy.Container], optional
        Whether to prune key-chains for which the function was not applied.
        Default is False.
    map_sequences : Union[bool, ivy.Container], optional
        Whether to also map the method to sequences (lists, tuples).
        Default is False.

    Returns
    -------
    ivy.Container
        A container holding the segmented mean: for each segment, the mean of
        the values in 'data' whose entry in 'segment_ids' equals that segment ID.
    """
    # Container-mapping options are forwarded unchanged to the dispatcher.
    mapping_options = dict(
        key_chains=key_chains,
        to_apply=to_apply,
        prune_unapplied=prune_unapplied,
        map_sequences=map_sequences,
    )
    return ContainerBase.cont_multi_map_in_function(
        "unsorted_segment_mean",
        data,
        segment_ids,
        num_segments,
        **mapping_options,
    )

def static_polyval(
coeffs: ivy.Container,
x: Union[ivy.Container, int, float],
Expand Down Expand Up @@ -1253,6 +1304,62 @@ def static_polyval(
map_sequences=map_sequences,
)

def unsorted_segment_mean(
    self: ivy.Container,
    segment_ids: Union[ivy.Array, ivy.Container],
    num_segments: Union[int, ivy.Container],
) -> ivy.Container:
    """
    Compute per-segment means of the values held in this container.

    Instance-method variant that delegates to
    ``static_unsorted_segment_mean`` with this container as the data argument.

    Parameters
    ----------
    self : ivy.Container
        Input container from which the values are gathered.
    segment_ids : Union[ivy.Array, ivy.Container]
        Integers giving the segment identifier of each element in 'self'.
    num_segments : Union[int, ivy.Container]
        The total number of distinct segment IDs.

    Returns
    -------
    ivy.Container
        A container holding the segmented mean: for each segment, the mean of
        the values in 'self' whose entry in 'segment_ids' equals that segment ID.

    Examples
    --------
    >>> data = ivy.Container(a=ivy.array([0., 1., 2., 4.]),
    ...                      b=ivy.array([3., 4., 5., 6.]))
    >>> segment_ids = ivy.array([0, 0, 1, 1])
    >>> num_segments = 2
    >>> result = ivy.unsorted_segment_mean(data, segment_ids, num_segments)
    >>> print(result)
    {
        a: ivy.array([0.5, 3.0]),
        b: ivy.array([3.5, 5.5])
    }

    >>> data = ivy.Container(a=ivy.array([0., 1., 2., 4., 5., 6.]),
    ...                      b=ivy.array([3., 4., 5., 6., 7., 8.]))
    >>> segment_ids = ivy.array([0, 0, 1, 1, 2, 2])
    >>> num_segments = 3
    >>> result = ivy.unsorted_segment_mean(data, segment_ids, num_segments)
    >>> print(result)
    {
        a: ivy.array([0.5, 3.0, 5.5]),
        b: ivy.array([3.5, 5.5, 7.5])
    }
    """
    segment_means = self.static_unsorted_segment_mean(
        self, segment_ids, num_segments
    )
    return segment_means

def polyval(
self: ivy.Container,
coeffs: ivy.Container,
Expand Down
21 changes: 19 additions & 2 deletions ivy/functional/backends/jax/experimental/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def unsorted_segment_min(
num_segments: int,
) -> JaxArray:
# added this check to keep the same behaviour as tensorflow
ivy.utils.assertions.check_unsorted_segment_min_valid_params(
ivy.utils.assertions.check_unsorted_segment_valid_params(
data, segment_ids, num_segments
)
return jax.ops.segment_min(data, segment_ids, num_segments)
Expand All @@ -98,7 +98,7 @@ def unsorted_segment_sum(
# the check should be same
# Might require to change the assertion function name to
# check_unsorted_segment_valid_params
ivy.utils.assertions.check_unsorted_segment_min_valid_params(
ivy.utils.assertions.check_unsorted_segment_valid_params(
data, segment_ids, num_segments
)
return jax.ops.segment_sum(data, segment_ids, num_segments)
Expand Down Expand Up @@ -165,6 +165,23 @@ def hz_to_mel(f):
return jnp.pad(mel_weights, [[1, 0], [0, 0]])


def unsorted_segment_mean(
    data: JaxArray,
    segment_ids: JaxArray,
    num_segments: int,
) -> JaxArray:
    """
    Compute the mean of the entries of `data` that share a segment id.

    Parameters
    ----------
    data
        Values to average; entries are grouped along the first axis.
    segment_ids
        Segment identifier for each entry of `data` along the first axis.
    num_segments
        Number of segments, i.e. the length of the output's first axis.

    Returns
    -------
    ret
        Per-segment sums divided by per-segment element counts. Segments
        that receive no entries yield 0 rather than NaN.
    """
    ivy.utils.assertions.check_unsorted_segment_valid_params(
        data, segment_ids, num_segments
    )
    segment_sum = jax.ops.segment_sum(data, segment_ids, num_segments)

    # Summing ones with the same ids gives the element count of every segment,
    # already shaped like `segment_sum`.
    segment_count = jax.ops.segment_sum(jnp.ones_like(data), segment_ids, num_segments)

    # An empty segment has sum 0 and count 0; clamp the count to 1 so the
    # result is 0 instead of the NaN that 0 / 0 would produce.
    return segment_sum / jnp.maximum(segment_count, 1)


def polyval(
coeffs: JaxArray,
x: JaxArray,
Expand Down
32 changes: 30 additions & 2 deletions ivy/functional/backends/numpy/experimental/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def unsorted_segment_min(
segment_ids: np.ndarray,
num_segments: int,
) -> np.ndarray:
ivy.utils.assertions.check_unsorted_segment_min_valid_params(
ivy.utils.assertions.check_unsorted_segment_valid_params(
data, segment_ids, num_segments
)

Expand Down Expand Up @@ -143,7 +143,7 @@ def unsorted_segment_sum(
# check should be same
# Might require to change the assertion function name to
# check_unsorted_segment_valid_params
ivy.utils.assertions.check_unsorted_segment_min_valid_params(
ivy.utils.assertions.check_unsorted_segment_valid_params(
data, segment_ids, num_segments
)

Expand Down Expand Up @@ -203,6 +203,34 @@ def hz_to_mel(f):
return np.pad(mel_weights, [[1, 0], [0, 0]])


def unsorted_segment_mean(
    data: np.ndarray,
    segment_ids: np.ndarray,
    num_segments: int,
) -> np.ndarray:
    """
    Compute the mean of the entries of `data` that share a segment id.

    Parameters
    ----------
    data
        Values to average; entries are grouped along the first axis.
    segment_ids
        Integer segment identifier for each entry of `data` along the first
        axis. Ids that are not below `num_segments` are ignored.
    num_segments
        Number of segments, i.e. the length of the output's first axis.

    Returns
    -------
    ret
        Array of shape ``(num_segments,) + data.shape[1:]`` holding the
        per-segment means. Segments that receive no entries yield 0.
    """
    ivy.utils.assertions.check_unsorted_segment_valid_params(
        data, segment_ids, num_segments
    )

    out_shape = (num_segments,) + data.shape[1:]

    if len(segment_ids) == 0:
        # Nothing to accumulate: every segment is empty.
        return np.zeros(out_shape, dtype=data.dtype)

    # Per-segment running sums and element counts.
    res = np.zeros(out_shape, dtype=data.dtype)
    counts = np.zeros(num_segments, dtype=np.int64)

    # Drop ids that fall outside the requested number of segments.
    in_range = segment_ids < num_segments
    kept_ids = segment_ids[in_range]

    # np.add.at performs unbuffered accumulation, so repeated ids are all
    # added (plain fancy-index `+=` would keep only the last one). Only the
    # leading len(segment_ids) rows of `data` take part.
    np.add.at(res, kept_ids, data[: len(segment_ids)][in_range])
    np.add.at(counts, kept_ids, 1)

    # Clamp to 1 so empty segments give 0 instead of a 0 / 0 NaN.
    counts = np.maximum(counts, 1)

    # Give counts one trailing singleton axis per trailing data axis so the
    # division broadcasts along axis 0 for any rank. A fixed `[:, np.newaxis]`
    # is only right for 2-D data and turns 1-D data into an (n, n) matrix.
    return res / counts.reshape((num_segments,) + (1,) * (data.ndim - 1))


def polyval(coeffs: np.ndarray, x: np.ndarray) -> np.ndarray:
with ivy.PreciseMode(True):
promoted_type = ivy.promote_types(ivy.dtype(coeffs[0]), ivy.dtype(x[0]))
Expand Down
34 changes: 32 additions & 2 deletions ivy/functional/backends/paddle/experimental/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def unsorted_segment_min(
segment_ids: paddle.Tensor,
num_segments: Union[int, paddle.Tensor],
) -> paddle.Tensor:
ivy.utils.assertions.check_unsorted_segment_min_valid_params(
ivy.utils.assertions.check_unsorted_segment_valid_params(
data, segment_ids, num_segments
)
if data.dtype == paddle.float32:
Expand Down Expand Up @@ -156,7 +156,7 @@ def unsorted_segment_sum(
# check should be same
# Might require to change the assertion function name to
# check_unsorted_segment_valid_params
ivy.utils.assertions.check_unsorted_segment_min_valid_params(
ivy.utils.assertions.check_unsorted_segment_valid_params(
data, segment_ids, num_segments
)

Expand Down Expand Up @@ -225,6 +225,36 @@ def mel_weight_matrix(
return paddle.transpose(mel_mat, (1, 0))


def unsorted_segment_mean(
    data: paddle.Tensor,
    segment_ids: paddle.Tensor,
    num_segments: Union[int, paddle.Tensor],
) -> paddle.Tensor:
    """
    Compute the mean of the entries of `data` that share a segment id.

    Parameters
    ----------
    data
        Values to average; entries are grouped along the first axis.
    segment_ids
        Integer segment identifier for each entry of `data` along the first
        axis.
    num_segments
        Number of segments, i.e. the length of the output's first axis.

    Returns
    -------
    ret
        Per-segment sums divided by per-segment element counts. Segments
        that receive no entries are divided by 1 instead of 0. int32 input
        is computed in float32 and cast back to int32.
    """
    ivy.utils.assertions.check_unsorted_segment_valid_params(
        data, segment_ids, num_segments
    )

    # Sum computation in paddle does not support int32, so needs to
    # be converted to float32
    needs_conv = False
    if data.dtype == paddle.int32:
        data = paddle.cast(data, "float32")
        needs_conv = True

    res = unsorted_segment_sum(data, segment_ids, num_segments)

    # minlength guarantees one count per segment even when the highest
    # segment ids never occur in `segment_ids`.
    count = paddle.bincount(segment_ids, minlength=int(num_segments))
    # Avoid division by zero for empty segments; ones_like keeps the
    # replacement in the same dtype as `count`.
    count = paddle.where(count > 0, count, paddle.ones_like(count))
    # One trailing singleton axis per trailing axis of `res`, so the division
    # broadcasts along axis 0 for any rank (a fixed (-1, 1) reshape is only
    # right for 2-D data).
    count = paddle.reshape(count, [-1] + [1] * (len(res.shape) - 1))
    res = res / paddle.cast(count, res.dtype)

    # condition for converting float32 back to int32
    if needs_conv:
        res = paddle.cast(res, "int32")

    return res


@with_unsupported_device_and_dtypes(
{
"2.5.1 and below": {
Expand Down
8 changes: 8 additions & 0 deletions ivy/functional/backends/tensorflow/experimental/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,14 @@ def mel_weight_matrix(
)


def unsorted_segment_mean(
    data: tf.Tensor,
    segment_ids: tf.Tensor,
    num_segments: Union[int, tf.Tensor],
) -> tf.Tensor:
    """
    Compute the mean of the entries of `data` that share a segment id.

    Thin wrapper that forwards its arguments unchanged to
    ``tf.math.unsorted_segment_mean``.

    Parameters
    ----------
    data
        Values to average.
    segment_ids
        Segment identifier for each entry of `data`.
    num_segments
        Number of segments.

    Returns
    -------
    ret
        The tensor returned by ``tf.math.unsorted_segment_mean``.
    """
    segment_means = tf.math.unsorted_segment_mean(data, segment_ids, num_segments)
    return segment_means


@with_unsupported_dtypes(
{"2.13.0 and below": ("bool", "bfloat16", "float16", "complex")}, backend_version
)
Expand Down
29 changes: 27 additions & 2 deletions ivy/functional/backends/torch/experimental/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def unsorted_segment_min(
segment_ids: torch.Tensor,
num_segments: Union[int, torch.Tensor],
) -> torch.Tensor:
ivy.utils.assertions.check_unsorted_segment_min_valid_params(
ivy.utils.assertions.check_unsorted_segment_valid_params(
data, segment_ids, num_segments
)
if data.dtype in [torch.float32, torch.float64, torch.float16, torch.bfloat16]:
Expand Down Expand Up @@ -180,7 +180,7 @@ def unsorted_segment_sum(
# check should be same
# Might require to change the assertion function name to
# check_unsorted_segment_valid_params
ivy.utils.assertions.check_unsorted_segment_min_valid_params(
ivy.utils.assertions.check_unsorted_segment_valid_params(
data, segment_ids, num_segments
)

Expand Down Expand Up @@ -247,6 +247,31 @@ def hz_to_mel(f):
return torch.nn.functional.pad(mel_weights, (0, 0, 1, 0))


def unsorted_segment_mean(
    data: torch.Tensor,
    segment_ids: torch.Tensor,
    num_segments: Union[int, torch.Tensor],
) -> torch.Tensor:
    """
    Compute the mean of the entries of `data` that share a segment id.

    Parameters
    ----------
    data
        Values to average; entries are grouped along the first axis.
    segment_ids
        Segment identifier for each entry of `data` along the first axis.
    num_segments
        Number of segments, i.e. the length of the output's first axis.

    Returns
    -------
    ret
        Tensor of shape ``(num_segments,) + data.shape[1:]`` holding the
        per-segment means. Segments that receive no entries yield 0.
    """
    ivy.utils.assertions.check_unsorted_segment_valid_params(
        data, segment_ids, num_segments
    )

    # Initialize a tensor to store the sum of elements for each segment
    segment_sum = torch.zeros(
        (num_segments,) + data.shape[1:], dtype=data.dtype, device=data.device
    )

    # Initialize a tensor to keep track of the number of elements in each segment
    counts = torch.zeros(num_segments, dtype=torch.int64, device=data.device)

    for i, seg_id in enumerate(segment_ids):
        segment_sum[seg_id] += data[i]
        counts[seg_id] += 1

    # Clamp to 1 so empty segments give 0 instead of a 0 / 0 NaN.
    counts = counts.clamp(min=1)

    # One trailing singleton axis per trailing data axis so the division
    # broadcasts along dim 0 for any rank. A fixed `[:, None]` is only right
    # for 2-D data and turns 1-D data into an (n, n) matrix.
    counts = counts.reshape((-1,) + (1,) * (data.dim() - 1))

    return segment_sum / counts


@with_unsupported_dtypes({"2.0.1 and below": "float16"}, backend_version)
def polyval(
coeffs: torch.Tensor,
Expand Down
Loading

0 comments on commit a84b6b6

Please sign in to comment.