Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Added unsorted_segment_mean function #24984

Merged
merged 33 commits into from
Oct 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
2bb4a57
Added unsorted_segment_mean function
kamlishgoswami Sep 19, 2023
245045d
Error fixing in test_unsorted_segment_mean
kamlishgoswami Sep 19, 2023
af671f2
Added unsorted_segment_mean function
kamlishgoswami Sep 20, 2023
7057568
Correct the error in the unsorted_segment_mean.
kamlishgoswami Sep 20, 2023
3a7a918
docstring missing
kamlishgoswami Sep 29, 2023
583c30e
Return implementation
kamlishgoswami Oct 5, 2023
2d28c87
🤖 Lint code
ivy-branch Oct 5, 2023
65e06d8
Update backend_fw
kamlishgoswami Oct 7, 2023
3406f72
Merge branch 'unifyai:main' into kamlish
kamlishgoswami Oct 7, 2023
77fd06e
Added examples
kamlishgoswami Oct 8, 2023
3827ead
🤖 Lint code
ivy-branch Oct 8, 2023
f40ea13
TensorFlow already provides mean function
kamlishgoswami Oct 8, 2023
ab0ed43
Merge branch 'kamlish' of github.com:kamlishgoswami/ivy into kamlish
kamlishgoswami Oct 8, 2023
4099e36
🤖 Lint code
ivy-branch Oct 8, 2023
94eeae4
_min_valid_params fn to _valid_params
kamlishgoswami Oct 8, 2023
028f8e9
Merge branch 'kamlish' of github.com:kamlishgoswami/ivy into kamlish
kamlishgoswami Oct 8, 2023
8a4b102
remove unnecessarily comments
kamlishgoswami Oct 8, 2023
b5ef16d
removed if condition
kamlishgoswami Oct 8, 2023
fb04820
🤖 Lint code
ivy-branch Oct 8, 2023
67f9c73
ivy Container Example
kamlishgoswami Oct 9, 2023
ef7f8ac
remove unnecessarily print statement
kamlishgoswami Oct 9, 2023
ab7185d
rename all instances of check_unsorted_segment_min_valid_params in iv…
kamlishgoswami Oct 9, 2023
d7cf1a5
🤖 Lint code
ivy-branch Oct 9, 2023
8002d7c
ivy Container Example
kamlishgoswami Oct 9, 2023
15d1f90
chore: updates test of renamed assertion
rishabgit Oct 11, 2023
1a4d817
link checks - Formatting
kamlishgoswami Oct 12, 2023
0c6e1e7
🤖 Lint code
ivy-branch Oct 12, 2023
fa8543f
chore: updates test of renamed assertion
kamlishgoswami Oct 12, 2023
45d2734
Updated array example
kamlishgoswami Oct 12, 2023
317f6bb
updated container example
kamlishgoswami Oct 13, 2023
a4e9425
Merge branch 'main' of github.com:kamlishgoswami/ivy into kamlish
kamlishgoswami Oct 24, 2023
89f9a0d
Merge branch 'main' of github.com:kamlishgoswami/ivy into kamlish
kamlishgoswami Oct 25, 2023
41bb91f
Merge branch 'kamlish' of https://github.com/kamlishgoswami/ivy into …
rishabgit Oct 29, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions ivy/data_classes/array/experimental/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,50 @@ def mel_weight_matrix(
upper_edge_hertz,
)

def unsorted_segment_mean(
    self: ivy.Array,
    segment_ids: ivy.Array,
    num_segments: Union[int, ivy.Array],
) -> ivy.Array:
    """
    Compute the mean of the values in the array 'self' for each segment.

    This is the instance-method form of ivy.unsorted_segment_mean; it forwards
    the native data of `self` to that function.

    Parameters
    ----------
    self : ivy.Array
        The array from which to gather values.
    segment_ids : ivy.Array
        Must be the same size as the first dimension of `self`. Has to be of
        integer data type. The index-th element of `segment_ids` is the
        segment identifier for the index-th element of `self`.
    num_segments : Union[int, ivy.Array]
        An integer or array representing the total number of distinct
        segment IDs.

    Returns
    -------
    ret : ivy.Array
        The output array, representing the result of a segmented mean
        operation. For each segment, it holds the mean of the values in
        `self` whose entry in `segment_ids` equals that segment ID.

    Examples
    --------
    >>> data = ivy.array([1.0, 2.0, 3.0, 4.0])
    >>> segment_ids = ivy.array([0, 0, 0, 0])
    >>> num_segments = 1
    >>> result = data.unsorted_segment_mean(segment_ids, num_segments)
    >>> result
    ivy.array([2.5])

    >>> data = ivy.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
    >>> segment_ids = ivy.array([0, 0, 1, 1, 2, 2])
    >>> num_segments = 3
    >>> result = data.unsorted_segment_mean(segment_ids, num_segments)
    >>> result
    ivy.array([1.5, 3.5, 5.5])
    """
    return ivy.unsorted_segment_mean(self._data, segment_ids, num_segments)


def polyval(
coeffs=ivy.Array,
Expand Down
107 changes: 107 additions & 0 deletions ivy/data_classes/container/experimental/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1202,6 +1202,57 @@ def mel_weight_matrix(
)

@staticmethod
def static_unsorted_segment_mean(
    data: ivy.Container,
    segment_ids: Union[ivy.Array, ivy.Container],
    num_segments: Union[int, ivy.Container],
    *,
    key_chains: Optional[Union[List[str], Dict[str, str], ivy.Container]] = None,
    to_apply: Union[bool, ivy.Container] = True,
    prune_unapplied: Union[bool, ivy.Container] = False,
    map_sequences: Union[bool, ivy.Container] = False,
) -> ivy.Container:
    """
    Compute per-segment means of the values held in a container.

    The work is delegated to ``ContainerBase.cont_multi_map_in_function``
    under the function name "unsorted_segment_mean".

    Parameters
    ----------
    data : ivy.Container
        Input array or container from which to gather the input.
    segment_ids : Union[ivy.Array, ivy.Container]
        An array of integers indicating the segment identifier for each
        element in 'data'.
    num_segments : Union[int, ivy.Container]
        An integer or array representing the total number of distinct
        segment IDs.
    key_chains : Optional[Union[List[str], Dict[str, str], ivy.Container]], optional
        The key-chains to apply or not apply the method to. Default is None.
    to_apply : Union[bool, ivy.Container], optional
        If True, the method will be applied to key-chains, otherwise
        key-chains will be skipped. Default is True.
    prune_unapplied : Union[bool, ivy.Container], optional
        Whether to prune key-chains for which the function was not applied.
        Default is False.
    map_sequences : Union[bool, ivy.Container], optional
        Whether to also map method to sequences (lists, tuples).
        Default is False.

    Returns
    -------
    ivy.Container
        A container representing the result of a segmented mean operation.
        For each segment, it holds the mean of the values in 'data' where
        'segment_ids' equals the segment ID.
    """
    # Collect the container-mapping options once and forward them unchanged.
    mapping_options = {
        "key_chains": key_chains,
        "to_apply": to_apply,
        "prune_unapplied": prune_unapplied,
        "map_sequences": map_sequences,
    }
    return ContainerBase.cont_multi_map_in_function(
        "unsorted_segment_mean",
        data,
        segment_ids,
        num_segments,
        **mapping_options,
    )

rishabgit marked this conversation as resolved.
Show resolved Hide resolved
def static_polyval(
coeffs: ivy.Container,
x: Union[ivy.Container, int, float],
Expand Down Expand Up @@ -1253,6 +1304,62 @@ def static_polyval(
map_sequences=map_sequences,
)

def unsorted_segment_mean(
    self: ivy.Container,
    segment_ids: Union[ivy.Array, ivy.Container],
    num_segments: Union[int, ivy.Container],
    *,
    key_chains: Optional[Union[List[str], Dict[str, str], ivy.Container]] = None,
    to_apply: Union[bool, ivy.Container] = True,
    prune_unapplied: Union[bool, ivy.Container] = False,
    map_sequences: Union[bool, ivy.Container] = False,
) -> ivy.Container:
    """
    Compute the mean of values in the input array or container based on
    segment identifiers.

    This is the instance-method form of static_unsorted_segment_mean; all
    arguments are forwarded to it unchanged.

    Parameters
    ----------
    self : ivy.Container
        Input array or container from which to gather the input.
    segment_ids : Union[ivy.Array, ivy.Container]
        An array of integers indicating the segment identifier for each
        element in 'self'.
    num_segments : Union[int, ivy.Container]
        An integer or array representing the total number of distinct
        segment IDs.
    key_chains : Optional[Union[List[str], Dict[str, str], ivy.Container]], optional
        The key-chains to apply or not apply the method to. Default is None.
    to_apply : Union[bool, ivy.Container], optional
        If True, the method will be applied to key-chains, otherwise
        key-chains will be skipped. Default is True.
    prune_unapplied : Union[bool, ivy.Container], optional
        Whether to prune key-chains for which the function was not applied.
        Default is False.
    map_sequences : Union[bool, ivy.Container], optional
        Whether to also map method to sequences (lists, tuples).
        Default is False.

    Returns
    -------
    ivy.Container
        A container representing the result of a segmented mean operation.
        For each segment, it holds the mean of the values in 'self' where
        'segment_ids' equals the segment ID.

    Examples
    --------
    >>> data = ivy.Container(a=ivy.array([0., 1., 2., 4.]),
    ...                      b=ivy.array([3., 4., 5., 6.]))
    >>> segment_ids = ivy.array([0, 0, 1, 1])
    >>> num_segments = 2
    >>> result = data.unsorted_segment_mean(segment_ids, num_segments)
    >>> print(result)
    {
        a: ivy.array([0.5, 3.0]),
        b: ivy.array([3.5, 5.5])
    }

    >>> data = ivy.Container(a=ivy.array([0., 1., 2., 4., 5., 6.]),
    ...                      b=ivy.array([3., 4., 5., 6., 7., 8.]))
    >>> segment_ids = ivy.array([0, 0, 1, 1, 2, 2])
    >>> num_segments = 3
    >>> result = data.unsorted_segment_mean(segment_ids, num_segments)
    >>> print(result)
    {
        a: ivy.array([0.5, 3.0, 5.5]),
        b: ivy.array([3.5, 5.5, 7.5])
    }
    """
    return self.static_unsorted_segment_mean(
        self,
        segment_ids,
        num_segments,
        key_chains=key_chains,
        to_apply=to_apply,
        prune_unapplied=prune_unapplied,
        map_sequences=map_sequences,
    )

def polyval(
self: ivy.Container,
coeffs: ivy.Container,
Expand Down
21 changes: 19 additions & 2 deletions ivy/functional/backends/jax/experimental/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def unsorted_segment_min(
num_segments: int,
) -> JaxArray:
# added this check to keep the same behaviour as tensorflow
ivy.utils.assertions.check_unsorted_segment_min_valid_params(
ivy.utils.assertions.check_unsorted_segment_valid_params(
data, segment_ids, num_segments
)
return jax.ops.segment_min(data, segment_ids, num_segments)
Expand All @@ -98,7 +98,7 @@ def unsorted_segment_sum(
# the check should be same
# Might require to change the assertion function name to
# check_unsorted_segment_valid_params
ivy.utils.assertions.check_unsorted_segment_min_valid_params(
ivy.utils.assertions.check_unsorted_segment_valid_params(
data, segment_ids, num_segments
)
return jax.ops.segment_sum(data, segment_ids, num_segments)
Expand Down Expand Up @@ -165,6 +165,23 @@ def hz_to_mel(f):
return jnp.pad(mel_weights, [[1, 0], [0, 0]])


def unsorted_segment_mean(
    data: JaxArray,
    segment_ids: JaxArray,
    num_segments: int,
) -> JaxArray:
    """
    Compute the mean of `data` over each segment given by `segment_ids`.

    Parameters
    ----------
    data
        Values to average; grouped along the first axis.
    segment_ids
        Segment identifier for each entry along the first axis of `data`.
    num_segments
        Number of output segments.

    Returns
    -------
    ret
        Per-segment sum divided by the per-segment element count. A segment
        that receives no elements yields 0 rather than NaN.
    """
    # Same parameter validation as the other unsorted_segment_* functions.
    ivy.utils.assertions.check_unsorted_segment_valid_params(
        data, segment_ids, num_segments
    )
    segment_sum = jax.ops.segment_sum(data, segment_ids, num_segments)

    # Summing ones shaped like `data` gives a count with the same shape as
    # `segment_sum`, so the division below needs no extra broadcasting.
    segment_count = jax.ops.segment_sum(jnp.ones_like(data), segment_ids, num_segments)

    # An empty segment has sum 0 and count 0; clamp the count to 1 so the
    # result is 0 instead of the NaN produced by 0 / 0.
    return segment_sum / jnp.maximum(segment_count, 1)


def polyval(
coeffs: JaxArray,
x: JaxArray,
Expand Down
32 changes: 30 additions & 2 deletions ivy/functional/backends/numpy/experimental/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def unsorted_segment_min(
segment_ids: np.ndarray,
num_segments: int,
) -> np.ndarray:
ivy.utils.assertions.check_unsorted_segment_min_valid_params(
ivy.utils.assertions.check_unsorted_segment_valid_params(
data, segment_ids, num_segments
)

Expand Down Expand Up @@ -146,7 +146,7 @@ def unsorted_segment_sum(
# check should be same
# Might require to change the assertion function name to
# check_unsorted_segment_valid_params
ivy.utils.assertions.check_unsorted_segment_min_valid_params(
ivy.utils.assertions.check_unsorted_segment_valid_params(
data, segment_ids, num_segments
)

Expand Down Expand Up @@ -206,6 +206,34 @@ def hz_to_mel(f):
return np.pad(mel_weights, [[1, 0], [0, 0]])


def unsorted_segment_mean(
    data: np.ndarray,
    segment_ids: np.ndarray,
    num_segments: int,
) -> np.ndarray:
    """
    Compute the mean of `data` over each segment given by `segment_ids`.

    Parameters
    ----------
    data
        Values to average; grouped along the first axis.
    segment_ids
        Segment identifier for each entry along the first axis of `data`.
    num_segments
        Number of output segments.

    Returns
    -------
    ret
        Array of shape ``(num_segments,) + data.shape[1:]``. A segment that
        receives no elements yields 0. When `segment_ids` is empty, an
        all-zero array with the dtype of `data` is returned.
    """
    ivy.utils.assertions.check_unsorted_segment_valid_params(
        data, segment_ids, num_segments
    )

    out_shape = (num_segments,) + data.shape[1:]

    if len(segment_ids) == 0:
        # If segment_ids is empty, return an all-zero array of the correct shape
        return np.zeros(out_shape, dtype=data.dtype)

    # Per-segment running sum and element count
    sums = np.zeros(out_shape, dtype=data.dtype)
    counts = np.zeros(num_segments, dtype=np.int64)

    # Ignore ids that fall outside the requested number of segments.
    # NOTE(review): assumes the validation above guarantees that segment_ids
    # has one entry per row of `data` - confirm.
    in_range = segment_ids < num_segments
    kept_ids = segment_ids[in_range]

    # np.add.at accumulates correctly when the same id occurs repeatedly,
    # unlike `sums[kept_ids] += ...`, which would apply each id only once.
    np.add.at(sums, kept_ids, data[in_range])
    np.add.at(counts, kept_ids, 1)

    # Clamp to 1 so an empty segment gives 0 instead of 0 / 0, and append one
    # singleton axis per trailing data axis so the counts broadcast along the
    # segment axis for data of any rank.
    divisor = np.maximum(counts, 1).reshape((-1,) + (1,) * (data.ndim - 1))

    return sums / divisor


def polyval(coeffs: np.ndarray, x: np.ndarray) -> np.ndarray:
with ivy.PreciseMode(True):
promoted_type = ivy.promote_types(ivy.dtype(coeffs[0]), ivy.dtype(x[0]))
Expand Down
34 changes: 32 additions & 2 deletions ivy/functional/backends/paddle/experimental/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def unsorted_segment_min(
segment_ids: paddle.Tensor,
num_segments: Union[int, paddle.Tensor],
) -> paddle.Tensor:
ivy.utils.assertions.check_unsorted_segment_min_valid_params(
ivy.utils.assertions.check_unsorted_segment_valid_params(
data, segment_ids, num_segments
)
if data.dtype == paddle.float32:
Expand Down Expand Up @@ -156,7 +156,7 @@ def unsorted_segment_sum(
# check should be same
# Might require to change the assertion function name to
# check_unsorted_segment_valid_params
ivy.utils.assertions.check_unsorted_segment_min_valid_params(
ivy.utils.assertions.check_unsorted_segment_valid_params(
data, segment_ids, num_segments
)

Expand Down Expand Up @@ -225,6 +225,36 @@ def mel_weight_matrix(
return paddle.transpose(mel_mat, (1, 0))


def unsorted_segment_mean(
    data: paddle.Tensor,
    segment_ids: paddle.Tensor,
    num_segments: Union[int, paddle.Tensor],
) -> paddle.Tensor:
    """
    Compute the mean of `data` over each segment given by `segment_ids`.

    Parameters
    ----------
    data
        Values to average; grouped along the first axis.
    segment_ids
        Segment identifier for each entry along the first axis of `data`.
    num_segments
        Number of output segments.

    Returns
    -------
    ret
        Per-segment sum (from unsorted_segment_sum) divided by the
        per-segment element count. A segment that receives no elements
        yields 0. int32 input is computed in float32 and cast back to int32.
    """
    ivy.utils.assertions.check_unsorted_segment_valid_params(
        data, segment_ids, num_segments
    )

    # Sum computation in paddle does not support int32, so needs to
    # be converted to float32
    needs_conv = False
    if data.dtype == paddle.int32:
        data = paddle.cast(data, "float32")
        needs_conv = True

    res = unsorted_segment_sum(data, segment_ids, num_segments)

    # minlength makes the count cover every segment, including trailing empty
    # ones that bincount would otherwise leave out.
    # NOTE(review): assumes every id is below num_segments (expected to be
    # enforced by the validation above) - confirm.
    count = paddle.bincount(segment_ids, minlength=int(num_segments))

    # Clamp to 1 so an empty segment divides 0 by 1 instead of by 0.
    count = paddle.maximum(count, paddle.ones_like(count))

    # One singleton axis per trailing data axis, so the counts broadcast along
    # the segment axis for data of any rank; cast so both operands of the
    # division share a dtype.
    count = paddle.reshape(count, [-1] + [1] * (len(res.shape) - 1))
    res = res / paddle.cast(count, res.dtype)

    # condition for converting float32 back to int32
    if needs_conv:
        res = paddle.cast(res, "int32")

    return res


@with_unsupported_device_and_dtypes(
{
"2.5.1 and below": {
Expand Down
8 changes: 8 additions & 0 deletions ivy/functional/backends/tensorflow/experimental/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,14 @@ def mel_weight_matrix(
)


def unsorted_segment_mean(
    data: tf.Tensor,
    segment_ids: tf.Tensor,
    num_segments: Union[int, tf.Tensor],
) -> tf.Tensor:
    """
    Compute the mean of `data` over each segment given by `segment_ids`.

    Thin wrapper that forwards its arguments unchanged to
    ``tf.math.unsorted_segment_mean``.

    Parameters
    ----------
    data
        Values to average.
    segment_ids
        Segment identifier for each entry of `data`.
    num_segments
        Number of output segments.

    Returns
    -------
    ret
        The tensor returned by ``tf.math.unsorted_segment_mean``.
    """
    segment_means = tf.math.unsorted_segment_mean(
        data=data,
        segment_ids=segment_ids,
        num_segments=num_segments,
    )
    return segment_means


@with_unsupported_dtypes(
{"2.13.0 and below": ("bool", "bfloat16", "float16", "complex")}, backend_version
)
Expand Down
29 changes: 27 additions & 2 deletions ivy/functional/backends/torch/experimental/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def unsorted_segment_min(
segment_ids: torch.Tensor,
num_segments: Union[int, torch.Tensor],
) -> torch.Tensor:
ivy.utils.assertions.check_unsorted_segment_min_valid_params(
ivy.utils.assertions.check_unsorted_segment_valid_params(
data, segment_ids, num_segments
)
if data.dtype in [torch.float32, torch.float64, torch.float16, torch.bfloat16]:
Expand Down Expand Up @@ -180,7 +180,7 @@ def unsorted_segment_sum(
# check should be same
# Might require to change the assertion function name to
# check_unsorted_segment_valid_params
ivy.utils.assertions.check_unsorted_segment_min_valid_params(
ivy.utils.assertions.check_unsorted_segment_valid_params(
data, segment_ids, num_segments
)

Expand Down Expand Up @@ -247,6 +247,31 @@ def hz_to_mel(f):
return torch.nn.functional.pad(mel_weights, (0, 0, 1, 0))


def unsorted_segment_mean(
    data: torch.Tensor,
    segment_ids: torch.Tensor,
    num_segments: Union[int, torch.Tensor],
) -> torch.Tensor:
    """
    Compute the mean of `data` over each segment given by `segment_ids`.

    Parameters
    ----------
    data
        Values to average; grouped along the first axis.
    segment_ids
        Segment identifier for each entry along the first axis of `data`.
    num_segments
        Number of output segments.

    Returns
    -------
    ret
        Tensor of shape ``(num_segments,) + data.shape[1:]`` on the device of
        `data`. A segment that receives no elements yields 0.
    """
    ivy.utils.assertions.check_unsorted_segment_valid_params(
        data, segment_ids, num_segments
    )

    # Initialize a tensor to store the sum of elements for each segment
    segment_sum = torch.zeros(
        (num_segments,) + data.shape[1:], dtype=data.dtype, device=data.device
    )

    # Initialize a tensor to keep track of the number of elements in each segment
    counts = torch.zeros(num_segments, dtype=torch.int64, device=data.device)

    for i, seg_id in enumerate(segment_ids):
        segment_sum[seg_id] += data[i]
        counts[seg_id] += 1

    # Clamp to 1 so an empty segment gives 0 instead of 0 / 0, and append one
    # singleton axis per trailing data axis so the counts broadcast along the
    # segment axis for data of any rank.
    divisor = counts.clamp(min=1).reshape((-1,) + (1,) * (data.dim() - 1))

    return segment_sum / divisor


@with_unsupported_dtypes({"2.0.1 and below": "float16"}, backend_version)
def polyval(
coeffs: torch.Tensor,
Expand Down
Loading
Loading