Skip to content

Commit

Permalink
🤖 Lint code
Browse files Browse the repository at this point in the history
  • Loading branch information
ivy-branch committed Sep 29, 2023
1 parent 93443b7 commit 253016f
Show file tree
Hide file tree
Showing 9 changed files with 67 additions and 59 deletions.
6 changes: 3 additions & 3 deletions ivy/data_classes/array/experimental/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,9 +266,9 @@ def mel_weight_matrix(
)

def unsorted_segment_mean(
self: ivy.Array,
segment_ids: ivy.Array,
num_segments: Union[int, ivy.Array],
self: ivy.Array,
segment_ids: ivy.Array,
num_segments: Union[int, ivy.Array],
) -> ivy.Array:
"""
Computes the mean of values in the array 'self' based on segment identifiers.
Expand Down
25 changes: 13 additions & 12 deletions ivy/data_classes/container/experimental/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1203,14 +1203,14 @@ def mel_weight_matrix(

@staticmethod
def static_unsorted_segment_mean(
data: ivy.Container,
segment_ids: ivy.Container,
num_segments: Union[int, ivy.Container],
*,
key_chains: Optional[Union[List[str], Dict[str, str], ivy.Container]] = None,
to_apply: Union[bool, ivy.Container] = True,
prune_unapplied: Union[bool, ivy.Container] = False,
map_sequences: Union[bool, ivy.Container] = False,
data: ivy.Container,
segment_ids: ivy.Container,
num_segments: Union[int, ivy.Container],
*,
key_chains: Optional[Union[List[str], Dict[str, str], ivy.Container]] = None,
to_apply: Union[bool, ivy.Container] = True,
prune_unapplied: Union[bool, ivy.Container] = False,
map_sequences: Union[bool, ivy.Container] = False,
) -> ivy.Container:
"""
Computes the mean of values in the input data based on segment identifiers.
Expand Down Expand Up @@ -1241,12 +1241,13 @@ def static_unsorted_segment_mean(
"""

def unsorted_segment_mean(
self: ivy.Container,
segment_ids: ivy.Container,
num_segments: Union[int, ivy.Container],
self: ivy.Container,
segment_ids: ivy.Container,
num_segments: Union[int, ivy.Container],
) -> ivy.Container:
"""
Computes the mean of values in the input array or container based on segment identifiers.
Computes the mean of values in the input array or container based on segment
identifiers.
Parameters
----------
Expand Down
1 change: 1 addition & 0 deletions ivy/functional/backends/jax/experimental/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ def mel_weight_matrix(
mel_weights = jnp.maximum(zero, jnp.minimum(lower_slopes, upper_slopes))
return jnp.pad(mel_weights, [[1, 0], [0, 0]])


def unsorted_segment_mean(
data: JaxArray,
segment_ids: JaxArray,
Expand Down
6 changes: 3 additions & 3 deletions ivy/functional/backends/numpy/experimental/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,9 +204,9 @@ def mel_weight_matrix(


def unsorted_segment_mean(
data: np.ndarray,
segment_ids: np.ndarray,
num_segments: int,
data: np.ndarray,
segment_ids: np.ndarray,
num_segments: int,
) -> np.ndarray:
# Check if the parameters are valid
ivy.utils.assertions.check_unsorted_segment_min_valid_params(
Expand Down
3 changes: 2 additions & 1 deletion ivy/functional/backends/paddle/experimental/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,7 @@ def mel_weight_matrix(
)
return paddle.transpose(mel_mat, (1, 0))


def unsorted_segment_mean(
data: paddle.Tensor,
segment_ids: paddle.Tensor,
Expand All @@ -250,7 +251,7 @@ def unsorted_segment_mean(
res = paddle.zeros((num_segments,) + tuple(data.shape[1:]), dtype=data.dtype)

count = paddle.bincount(segment_ids)
count = paddle.where(count > 0, count, paddle.to_tensor([1], dtype='int32'))
count = paddle.where(count > 0, count, paddle.to_tensor([1], dtype="int32"))
res = unsorted_segment_sum(data, segment_ids, num_segments)
res = res / paddle.reshape(count, (-1, 1))

Expand Down
10 changes: 6 additions & 4 deletions ivy/functional/backends/tensorflow/experimental/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,15 +157,17 @@ def mel_weight_matrix(


def unsorted_segment_mean(
data: tf.Tensor,
segment_ids: tf.Tensor,
num_segments: Union[int, tf.Tensor],
data: tf.Tensor,
segment_ids: tf.Tensor,
num_segments: Union[int, tf.Tensor],
) -> tf.Tensor:
# Calculate the sum along segments
segment_sum = tf.math.unsorted_segment_sum(data, segment_ids, num_segments)

# Calculate the count of elements in each segment
segment_count = tf.math.unsorted_segment_max(tf.ones_like(data), segment_ids, num_segments)
segment_count = tf.math.unsorted_segment_max(
tf.ones_like(data), segment_ids, num_segments
)

# Calculate the mean
segment_mean = segment_sum / segment_count
Expand Down
10 changes: 6 additions & 4 deletions ivy/functional/backends/torch/experimental/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,17 +245,19 @@ def mel_weight_matrix(


def unsorted_segment_mean(
data: torch.Tensor,
segment_ids: torch.Tensor,
num_segments: Union[int, torch.Tensor],
data: torch.Tensor,
segment_ids: torch.Tensor,
num_segments: Union[int, torch.Tensor],
) -> torch.Tensor:
# Check if the parameters are valid
ivy.utils.assertions.check_unsorted_segment_min_valid_params(
data, segment_ids, num_segments
)

# Initialize an array to store the sum of elements for each segment
segment_sum = torch.zeros((num_segments,) + data.shape[1:], dtype=data.dtype, device=data.device)
segment_sum = torch.zeros(
(num_segments,) + data.shape[1:], dtype=data.dtype, device=data.device
)

# Initialize an array to keep track of the number of elements in each segment
counts = torch.zeros(num_segments, dtype=torch.int64, device=data.device)
Expand Down
2 changes: 1 addition & 1 deletion ivy/functional/ivy/experimental/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1080,6 +1080,7 @@ def mel_weight_matrix(
upper_edge_hertz,
)


# unsorted_segment_mean
@handle_exceptions
@handle_nestable
Expand Down Expand Up @@ -1115,4 +1116,3 @@ def unsorted_segment_mean(
"""
# Get the current backend and call its `unsorted_segment_mean` function
return ivy.current_backend().unsorted_segment_mean(data, segment_ids, num_segments)

Original file line number Diff line number Diff line change
Expand Up @@ -685,6 +685,38 @@ def test_trilu(*, dtype_and_x, k, upper, test_flags, backend_fw, fn_name, on_dev
)


@handle_test(
fn_tree="functional.ivy.experimental.unsorted_segment_mean",
d_x_n_s=valid_unsorted_segment_min_inputs(),
test_with_out=st.just(False),
test_gradients=st.just(False),
)
def test_unsorted_segment_mean(
*,
d_x_n_s,
test_flags,
backend_fw,
fn_name,
on_device,
):
dtypes, data, num_segments, segment_ids = d_x_n_s
if backend_fw == "numpy":
# Modify the test case to ensure correct shape
if fn_name == "unsorted_segment_mean":
data = np.array([data]) # Wrap data in a list or array if it's not already
segment_ids = np.array([segment_ids]) # Wrap segment_ids similarly
helpers.test_function(
input_dtypes=dtypes,
test_flags=test_flags,
on_device=on_device,
backend_to_test=backend_fw,
fn_name=fn_name,
data=data,
segment_ids=segment_ids,
num_segments=num_segments,
)


# unsorted_segment_min
@handle_test(
fn_tree="functional.ivy.experimental.unsorted_segment_min",
Expand Down Expand Up @@ -767,34 +799,3 @@ def test_vorbis_window(
window_length=int(x[0]),
dtype=dtype[0],
)

@handle_test(
fn_tree="functional.ivy.experimental.unsorted_segment_mean",
d_x_n_s=valid_unsorted_segment_min_inputs(),
test_with_out=st.just(False),
test_gradients=st.just(False),
)
def test_unsorted_segment_mean(
*,
d_x_n_s,
test_flags,
backend_fw,
fn_name,
on_device,
):
dtypes, data, num_segments, segment_ids = d_x_n_s
if backend_fw == 'numpy':
# Modify the test case to ensure correct shape
if fn_name == 'unsorted_segment_mean':
data = np.array([data]) # Wrap data in a list or array if it's not already
segment_ids = np.array([segment_ids]) # Wrap segment_ids similarly
helpers.test_function(
input_dtypes=dtypes,
test_flags=test_flags,
on_device=on_device,
backend_to_test=backend_fw,
fn_name=fn_name,
data=data,
segment_ids=segment_ids,
num_segments=num_segments,
)

0 comments on commit 253016f

Please sign in to comment.