From a84b6b6632dfa9a336882b6aff6c2320555e9383 Mon Sep 17 00:00:00 2001 From: Kamlish <125588073+kamlishgoswami@users.noreply.github.com> Date: Sun, 29 Oct 2023 16:51:35 +0500 Subject: [PATCH] feat: Added unsorted_segment_mean function (#24984) Co-authored-by: ivy-branch Co-authored-by: Rishab Mallick --- .../array/experimental/creation.py | 44 +++++++ .../container/experimental/creation.py | 107 ++++++++++++++++++ .../backends/jax/experimental/creation.py | 21 +++- .../backends/numpy/experimental/creation.py | 32 +++++- .../backends/paddle/experimental/creation.py | 34 +++++- .../tensorflow/experimental/creation.py | 8 ++ .../backends/torch/experimental/creation.py | 29 ++++- ivy/functional/ivy/experimental/creation.py | 36 ++++++ ivy/utils/assertions.py | 2 +- .../test_core/test_creation.py | 27 +++++ .../test_ivy/test_misc/test_assertions.py | 6 +- 11 files changed, 334 insertions(+), 12 deletions(-) diff --git a/ivy/data_classes/array/experimental/creation.py b/ivy/data_classes/array/experimental/creation.py index 8fc3459e5cd2b..fdca1bcffabf3 100644 --- a/ivy/data_classes/array/experimental/creation.py +++ b/ivy/data_classes/array/experimental/creation.py @@ -265,6 +265,50 @@ def mel_weight_matrix( upper_edge_hertz, ) + def unsorted_segment_mean( + self: ivy.Array, + segment_ids: ivy.Array, + num_segments: Union[int, ivy.Array], + ) -> ivy.Array: + """ + Compute the mean of values in the array 'self' based on segment identifiers. + + Parameters + ---------- + self : ivy.Array + The array from which to gather values. + segment_ids : ivy.Array + Must be in the same size with the first dimension of `self`. Has to be + of integer data type. The index-th element of `segment_ids` array is + the segment identifier for the index-th element of `self`. + num_segments : Union[int, ivy.Array] + An integer or array representing the total number of distinct segment IDs. 
+ + Returns + ------- + ret : ivy.Array + The output array, representing the result of a segmented mean operation. + For each segment, it computes the mean of values in `self` where + `segment_ids` equals the segment ID. + + Examples + -------- + >>> data = ivy.array([1.0, 2.0, 3.0, 4.0]) + >>> segment_ids = ivy.array([0, 0, 0, 0]) + >>> num_segments = 1 + >>> result = ivy.unsorted_segment_mean(data, segment_ids, num_segments) + >>> result + ivy.array([2.5]) + + >>> data = ivy.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) + >>> segment_ids = ivy.array([0, 0, 1, 1, 2, 2]) + >>> num_segments = 3 + >>> result = ivy.unsorted_segment_mean(data, segment_ids, num_segments) + >>> result + ivy.array([1.5, 3.5, 5.5]) + """ + return ivy.unsorted_segment_mean(self._data, segment_ids, num_segments) + def polyval( coeffs=ivy.Array, diff --git a/ivy/data_classes/container/experimental/creation.py b/ivy/data_classes/container/experimental/creation.py index b708253e7adfb..b760f16ddde3f 100644 --- a/ivy/data_classes/container/experimental/creation.py +++ b/ivy/data_classes/container/experimental/creation.py @@ -1202,6 +1202,57 @@ def mel_weight_matrix( ) @staticmethod + def static_unsorted_segment_mean( + data: ivy.Container, + segment_ids: Union[ivy.Array, ivy.Container], + num_segments: Union[int, ivy.Container], + *, + key_chains: Optional[Union[List[str], Dict[str, str], ivy.Container]] = None, + to_apply: Union[bool, ivy.Container] = True, + prune_unapplied: Union[bool, ivy.Container] = False, + map_sequences: Union[bool, ivy.Container] = False, + ) -> ivy.Container: + """ + Compute the mean of values in the input data based on segment identifiers. + + Parameters + ---------- + data : ivy.Container + Input array or container from which to gather the input. + segment_ids : ivy.Container + An array of integers indicating the segment identifier for each element in + 'data'. 
+ num_segments : Union[int, ivy.Container] + An integer or array representing the total number of distinct segment IDs. + key_chains : Optional[Union[List[str], Dict[str, str], ivy.Container]], optional + The key-chains to apply or not apply the method to. Default is None. + to_apply : Union[bool, ivy.Container], optional + If True, the method will be applied to key-chains, otherwise key-chains will + be skipped. Default is True. + prune_unapplied : Union[bool, ivy.Container], optional + Whether to prune key-chains for which the function was not applied. + Default is False. + map_sequences : Union[bool, ivy.Container], optional + Whether to also map method to sequences (lists, tuples). Default is False. + + Returns + ------- + ivy.Container + A container representing the result of a segmented mean operation. + For each segment, it computes the mean of values in 'data' where + 'segment_ids' equals the segment ID. + """ + return ContainerBase.cont_multi_map_in_function( + "unsorted_segment_mean", + data, + segment_ids, + num_segments, + key_chains=key_chains, + to_apply=to_apply, + prune_unapplied=prune_unapplied, + map_sequences=map_sequences, + ) + def static_polyval( coeffs: ivy.Container, x: Union[ivy.Container, int, float], @@ -1253,6 +1304,62 @@ def static_polyval( map_sequences=map_sequences, ) + def unsorted_segment_mean( + self: ivy.Container, + segment_ids: Union[ivy.Array, ivy.Container], + num_segments: Union[int, ivy.Container], + ) -> ivy.Container: + """ + Compute the mean of values in the input array or container based on segment + identifiers. + + Parameters + ---------- + self : ivy.Container + Input array or container from which to gather the input. + segment_ids : ivy.Container + An array of integers indicating the segment identifier for each element + in 'self'. + num_segments : Union[int, ivy.Container] + An integer or array representing the total number of distinct segment IDs. 
+ + Returns + ------- + ivy.Container + A container representing the result of a segmented mean operation. + For each segment, it computes the mean of values in 'self' where + 'segment_ids' equals the segment ID. + + Examples + -------- + >>> data = ivy.Container(a=ivy.array([0., 1., 2., 4.]), + ... b=ivy.array([3., 4., 5., 6.])) + >>> segment_ids = ivy.array([0, 0, 1, 1]) + >>> num_segments = 2 + >>> result = ivy.unsorted_segment_mean(data, segment_ids, num_segments) + >>> print(result) + { + a: ivy.array([0.5, 3.0]), + b: ivy.array([3.5, 5.5]) + } + + >>> data = ivy.Container(a=ivy.array([0., 1., 2., 4., 5., 6.]), + ... b=ivy.array([3., 4., 5., 6., 7., 8.])) + >>> segment_ids = ivy.array([0, 0, 1, 1, 2, 2]) + >>> num_segments = 3 + >>> result = ivy.unsorted_segment_mean(data, segment_ids, num_segments) + >>> print(result) + { + a: ivy.array([0.5, 3.0, 5.5]), + b: ivy.array([3.5, 5.5, 7.5]) + } + """ + return self.static_unsorted_segment_mean( + self, + segment_ids, + num_segments, + ) + def polyval( self: ivy.Container, coeffs: ivy.Container, diff --git a/ivy/functional/backends/jax/experimental/creation.py b/ivy/functional/backends/jax/experimental/creation.py index 85c018578c38d..19e1833db5e4f 100644 --- a/ivy/functional/backends/jax/experimental/creation.py +++ b/ivy/functional/backends/jax/experimental/creation.py @@ -83,7 +83,7 @@ def unsorted_segment_min( num_segments: int, ) -> JaxArray: # added this check to keep the same behaviour as tensorflow - ivy.utils.assertions.check_unsorted_segment_min_valid_params( + ivy.utils.assertions.check_unsorted_segment_valid_params( data, segment_ids, num_segments ) return jax.ops.segment_min(data, segment_ids, num_segments) @@ -98,7 +98,7 @@ def unsorted_segment_sum( # the check should be same # Might require to change the assertion function name to # check_unsorted_segment_valid_params - ivy.utils.assertions.check_unsorted_segment_min_valid_params( + ivy.utils.assertions.check_unsorted_segment_valid_params( data, 
segment_ids, num_segments ) return jax.ops.segment_sum(data, segment_ids, num_segments) @@ -165,6 +165,23 @@ def hz_to_mel(f): return jnp.pad(mel_weights, [[1, 0], [0, 0]]) +def unsorted_segment_mean( + data: JaxArray, + segment_ids: JaxArray, + num_segments: int, +) -> JaxArray: + ivy.utils.assertions.check_unsorted_segment_valid_params( + data, segment_ids, num_segments + ) + segment_sum = jax.ops.segment_sum(data, segment_ids, num_segments) + + segment_count = jax.ops.segment_sum(jnp.ones_like(data), segment_ids, num_segments) + + segment_mean = segment_sum / segment_count + + return segment_mean + + def polyval( coeffs: JaxArray, x: JaxArray, diff --git a/ivy/functional/backends/numpy/experimental/creation.py b/ivy/functional/backends/numpy/experimental/creation.py index 6916cb31ef88f..e6c4b5a064779 100644 --- a/ivy/functional/backends/numpy/experimental/creation.py +++ b/ivy/functional/backends/numpy/experimental/creation.py @@ -89,7 +89,7 @@ def unsorted_segment_min( segment_ids: np.ndarray, num_segments: int, ) -> np.ndarray: - ivy.utils.assertions.check_unsorted_segment_min_valid_params( + ivy.utils.assertions.check_unsorted_segment_valid_params( data, segment_ids, num_segments ) @@ -143,7 +143,7 @@ def unsorted_segment_sum( # check should be same # Might require to change the assertion function name to # check_unsorted_segment_valid_params - ivy.utils.assertions.check_unsorted_segment_min_valid_params( + ivy.utils.assertions.check_unsorted_segment_valid_params( data, segment_ids, num_segments ) @@ -203,6 +203,34 @@ def hz_to_mel(f): return np.pad(mel_weights, [[1, 0], [0, 0]]) +def unsorted_segment_mean( + data: np.ndarray, + segment_ids: np.ndarray, + num_segments: int, +) -> np.ndarray: + ivy.utils.assertions.check_unsorted_segment_valid_params( + data, segment_ids, num_segments + ) + + if len(segment_ids) == 0: + # If segment_ids is empty, return an empty array of the correct shape + return np.zeros((num_segments,) + data.shape[1:], dtype=data.dtype) + 
+ # Initialize an array to store the sum of elements for each segment + res = np.zeros((num_segments,) + data.shape[1:], dtype=data.dtype) + + # Initialize an array to keep track of the number of elements in each segment + counts = np.zeros(num_segments, dtype=np.int64) + + for i in range(len(segment_ids)): + seg_id = segment_ids[i] + if seg_id < num_segments: + res[seg_id] += data[i] + counts[seg_id] += 1 + + return res / counts[:, np.newaxis] + + def polyval(coeffs: np.ndarray, x: np.ndarray) -> np.ndarray: with ivy.PreciseMode(True): promoted_type = ivy.promote_types(ivy.dtype(coeffs[0]), ivy.dtype(x[0])) diff --git a/ivy/functional/backends/paddle/experimental/creation.py b/ivy/functional/backends/paddle/experimental/creation.py index 6f7d33f083002..bfbecb80e654c 100644 --- a/ivy/functional/backends/paddle/experimental/creation.py +++ b/ivy/functional/backends/paddle/experimental/creation.py @@ -103,7 +103,7 @@ def unsorted_segment_min( segment_ids: paddle.Tensor, num_segments: Union[int, paddle.Tensor], ) -> paddle.Tensor: - ivy.utils.assertions.check_unsorted_segment_min_valid_params( + ivy.utils.assertions.check_unsorted_segment_valid_params( data, segment_ids, num_segments ) if data.dtype == paddle.float32: @@ -156,7 +156,7 @@ def unsorted_segment_sum( # check should be same # Might require to change the assertion function name to # check_unsorted_segment_valid_params - ivy.utils.assertions.check_unsorted_segment_min_valid_params( + ivy.utils.assertions.check_unsorted_segment_valid_params( data, segment_ids, num_segments ) @@ -225,6 +225,36 @@ def mel_weight_matrix( return paddle.transpose(mel_mat, (1, 0)) +def unsorted_segment_mean( + data: paddle.Tensor, + segment_ids: paddle.Tensor, + num_segments: Union[int, paddle.Tensor], +) -> paddle.Tensor: + ivy.utils.assertions.check_unsorted_segment_valid_params( + data, segment_ids, num_segments + ) + + # Sum computation in paddle does not support int32, so needs to + # be converted to float32 + needs_conv = 
False + if data.dtype == paddle.int32: + data = paddle.cast(data, "float32") + needs_conv = True + + res = paddle.zeros((num_segments,) + tuple(data.shape[1:]), dtype=data.dtype) + + count = paddle.bincount(segment_ids) + count = paddle.where(count > 0, count, paddle.to_tensor([1], dtype="int32")) + res = unsorted_segment_sum(data, segment_ids, num_segments) + res = res / paddle.reshape(count, (-1, 1)) + + # condition for converting float32 back to int32 + if needs_conv is True: + res = paddle.cast(res, "int32") + + return res + + @with_unsupported_device_and_dtypes( { "2.5.1 and below": { diff --git a/ivy/functional/backends/tensorflow/experimental/creation.py b/ivy/functional/backends/tensorflow/experimental/creation.py index e86263c7d3f76..86af4561091c2 100644 --- a/ivy/functional/backends/tensorflow/experimental/creation.py +++ b/ivy/functional/backends/tensorflow/experimental/creation.py @@ -160,6 +160,14 @@ def mel_weight_matrix( ) +def unsorted_segment_mean( + data: tf.Tensor, + segment_ids: tf.Tensor, + num_segments: Union[int, tf.Tensor], +) -> tf.Tensor: + return tf.math.unsorted_segment_mean(data, segment_ids, num_segments) + + @with_unsupported_dtypes( {"2.13.0 and below": ("bool", "bfloat16", "float16", "complex")}, backend_version ) diff --git a/ivy/functional/backends/torch/experimental/creation.py b/ivy/functional/backends/torch/experimental/creation.py index 770e94654a1f4..953970525fd66 100644 --- a/ivy/functional/backends/torch/experimental/creation.py +++ b/ivy/functional/backends/torch/experimental/creation.py @@ -131,7 +131,7 @@ def unsorted_segment_min( segment_ids: torch.Tensor, num_segments: Union[int, torch.Tensor], ) -> torch.Tensor: - ivy.utils.assertions.check_unsorted_segment_min_valid_params( + ivy.utils.assertions.check_unsorted_segment_valid_params( data, segment_ids, num_segments ) if data.dtype in [torch.float32, torch.float64, torch.float16, torch.bfloat16]: @@ -180,7 +180,7 @@ def unsorted_segment_sum( # check should be same # 
Might require to change the assertion function name to # check_unsorted_segment_valid_params - ivy.utils.assertions.check_unsorted_segment_min_valid_params( + ivy.utils.assertions.check_unsorted_segment_valid_params( data, segment_ids, num_segments ) @@ -247,6 +247,31 @@ def hz_to_mel(f): return torch.nn.functional.pad(mel_weights, (0, 0, 1, 0)) +def unsorted_segment_mean( + data: torch.Tensor, + segment_ids: torch.Tensor, + num_segments: Union[int, torch.Tensor], +) -> torch.Tensor: + ivy.utils.assertions.check_unsorted_segment_valid_params( + data, segment_ids, num_segments + ) + + # Initialize an array to store the sum of elements for each segment + segment_sum = torch.zeros( + (num_segments,) + data.shape[1:], dtype=data.dtype, device=data.device + ) + + # Initialize an array to keep track of the number of elements in each segment + counts = torch.zeros(num_segments, dtype=torch.int64, device=data.device) + + for i in range(len(segment_ids)): + seg_id = segment_ids[i] + segment_sum[seg_id] += data[i] + counts[seg_id] += 1 + + return segment_sum / counts[:, None] + + @with_unsupported_dtypes({"2.0.1 and below": "float16"}, backend_version) def polyval( coeffs: torch.Tensor, diff --git a/ivy/functional/ivy/experimental/creation.py b/ivy/functional/ivy/experimental/creation.py index 710a5b4b5e74d..db01811743ebb 100644 --- a/ivy/functional/ivy/experimental/creation.py +++ b/ivy/functional/ivy/experimental/creation.py @@ -1139,6 +1139,42 @@ def mel_weight_matrix( ) +# unsorted_segment_mean +@handle_exceptions +@handle_nestable +@to_native_arrays_and_back +def unsorted_segment_mean( + data: Union[ivy.Array, ivy.NativeArray], + segment_ids: Union[ivy.Array, ivy.NativeArray], + num_segments: Union[int, ivy.Array, ivy.NativeArray], +) -> ivy.Array: + """ + Compute the mean of elements along segments of an array. Segments are defined by an + integer array of segment IDs. 
+ + Parameters + ---------- + data : Union[ivy.Array, ivy.NativeArray] + The array from which to gather values. + + segment_ids : Union[ivy.Array, ivy.NativeArray] + Must have the same size as the first dimension of `data`. Has to be + of integer data type. The index-th element of `segment_ids` array is + the segment identifier for the index-th element of `data`. + + num_segments : Union[int, ivy.Array, ivy.NativeArray] + An integer or array representing the total number of distinct segment IDs. + + Returns + ------- + ivy.Array + The output array, representing the result of a segmented mean operation. + For each segment, it computes the mean value in `data` where `segment_ids` + equals the segment ID. + """ + return ivy.current_backend().unsorted_segment_mean(data, segment_ids, num_segments) + + @handle_exceptions @handle_nestable @handle_array_function diff --git a/ivy/utils/assertions.py b/ivy/utils/assertions.py index 0a5f653f7fd8f..3ce9cb927f8be 100644 --- a/ivy/utils/assertions.py +++ b/ivy/utils/assertions.py @@ -182,7 +182,7 @@ def check_same_dtype(x1, x2, message=""): # -------- # -def check_unsorted_segment_min_valid_params(data, segment_ids, num_segments): +def check_unsorted_segment_valid_params(data, segment_ids, num_segments): if not isinstance(num_segments, int): raise ValueError("num_segments must be of integer type") diff --git a/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_creation.py b/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_creation.py index 51b14a8d29b15..151cc3fcb344d 100644 --- a/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_creation.py +++ b/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_creation.py @@ -799,6 +799,33 @@ def test_trilu(*, dtype_and_x, k, upper, test_flags, backend_fw, fn_name, on_dev ) +@handle_test( + fn_tree="functional.ivy.experimental.unsorted_segment_mean", + d_x_n_s=valid_unsorted_segment_min_inputs(), + 
test_with_out=st.just(False), + test_gradients=st.just(False), +) +def test_unsorted_segment_mean( + *, + d_x_n_s, + test_flags, + backend_fw, + fn_name, + on_device, +): + dtypes, data, num_segments, segment_ids = d_x_n_s + helpers.test_function( + input_dtypes=dtypes, + test_flags=test_flags, + on_device=on_device, + backend_to_test=backend_fw, + fn_name=fn_name, + data=data, + segment_ids=segment_ids, + num_segments=num_segments, + ) + + # unsorted_segment_min @handle_test( fn_tree="functional.ivy.experimental.unsorted_segment_min", diff --git a/ivy_tests/test_ivy/test_misc/test_assertions.py b/ivy_tests/test_ivy/test_misc/test_assertions.py index b6e4c948c1dea..8d6484d4bacb4 100644 --- a/ivy_tests/test_ivy/test_misc/test_assertions.py +++ b/ivy_tests/test_ivy/test_misc/test_assertions.py @@ -23,7 +23,7 @@ check_shape, check_shapes_broadcastable, check_true, - check_unsorted_segment_min_valid_params, + check_unsorted_segment_valid_params, ) from ivy.utils.assertions import _check_jax_x64_flag import ivy @@ -852,7 +852,7 @@ def test_check_true(expression): (ivy.array([1, 2, 3]), ivy.array([0, 1, 0], dtype=ivy.int32), ivy.array([2])), ], ) -def test_check_unsorted_segment_min_valid_params(data, segment_ids, num_segments): +def test_check_unsorted_segment_valid_params(data, segment_ids, num_segments): filename = "except_out.txt" orig_stdout = sys.stdout @@ -860,7 +860,7 @@ def test_check_unsorted_segment_min_valid_params(data, segment_ids, num_segments sys.stdout = f lines = "" try: - check_unsorted_segment_min_valid_params(data, segment_ids, num_segments) + check_unsorted_segment_valid_params(data, segment_ids, num_segments) local_vars = {**locals()} except Exception as e: local_vars = {**locals()}