diff --git a/ivy/functional/backends/tensorflow/experimental/layers.py b/ivy/functional/backends/tensorflow/experimental/layers.py
index 94c60963871b5..51cd329087eff 100644
--- a/ivy/functional/backends/tensorflow/experimental/layers.py
+++ b/ivy/functional/backends/tensorflow/experimental/layers.py
@@ -30,6 +30,33 @@ def _determine_depth_max_pooling(x, kernel, strides, dims, data_format="channel_
     return x, kernel, strides, depth_pooling
 
 
+def _cal_output_shape(
+    input_shape,
+    padding,
+    kernel_size,
+    stride,
+):
+    # inverse of the pooling shape formula:
+    # L_out = (L_in - 1) * stride - 2 * padding + kernel_size
+    return [
+        (length - 1) * s - 2 * p + k
+        for length, s, p, k in zip(input_shape, stride, padding, kernel_size)
+    ]
+
+
+def _broadcast_pooling_helper(x, pool_dims: str = "2d", name: str = "padding"):
+    dims = {"1d": 1, "2d": 2, "3d": 3}
+    if isinstance(x, int):
+        return tuple(x for _ in range(dims[pool_dims]))
+    if len(x) == 1:
+        return tuple(x[0] for _ in range(dims[pool_dims]))
+    elif len(x) == dims[pool_dims]:
+        return tuple(x)
+    else:
+        raise ValueError(
+            f"`{name}` must either be a single int "
+            f"or a tuple of {dims[pool_dims]} ints."
+        )
+
+
 def max_pool1d(
     x: Union[tf.Tensor, tf.Variable],
     kernel: Union[int, Tuple[int, ...]],
@@ -1427,3 +1454,54 @@ def rfftn(
     else:
         # return result
         return tf.cast(result, tf.complex128)
+
+
+def max_unpool1d(
+    input: Union[tf.Tensor, tf.Variable],
+    indices: Union[tf.Tensor, tf.Variable],
+    kernel_size: Union[Tuple[int], int],
+    strides: Optional[Union[int, Tuple[int]]] = None,
+    padding: Union[int, Tuple[int]] = 0,
+    data_format: Optional[str] = "NCW",
+):
+    if strides is None:
+        strides = kernel_size
+    input_shape = tf.shape(input)
+    if data_format in ["NCW", "NWC"]:
+        revert = False
+        if data_format == "NWC":
+            x_len = (input_shape[1],)
+            # normalize to channel-first so the scatter below is uniform
+            input = tf.transpose(input, perm=[0, 2, 1])
+            indices = tf.transpose(indices, perm=[0, 2, 1])
+            revert = True
+        else:
+            x_len = (input_shape[-1],)
+    else:
+        raise ValueError(
+            f"`data_format` should be 'NCW' or 'NWC', but found {data_format}"
+        )
+
+    input_shape = tf.shape(input)
+    ind_dtype = indices.dtype
+    kernel_size = _broadcast_pooling_helper(kernel_size, "1d", name="kernel_size")
+    padding = _broadcast_pooling_helper(padding, "1d", name="padding")
+    strides = _broadcast_pooling_helper(strides, "1d", name="strides")
+    output_len = _cal_output_shape(x_len, padding, kernel_size, strides)
+    output_shape = list(input_shape[:-1]) + output_len
+    # build (batch, channel, position) index triples so the pooled values
+    # can be scattered back into a zero tensor of the unpooled shape
+    one_like_mask = tf.ones_like(indices, dtype=ind_dtype)
+    batch_shape = tf.concat([[input_shape[0]], [1], [1]], axis=0)
+    batch_range = tf.reshape(
+        tf.range(0, output_shape[0], dtype=ind_dtype), shape=batch_shape
+    )
+    b = one_like_mask * batch_range
+    feature_range = tf.reshape(
+        tf.range(0, output_shape[1], dtype=ind_dtype), shape=(1, -1, 1)
+    )
+    f = one_like_mask * feature_range
+    output = tf.zeros(output_shape, dtype=input.dtype)
+    indices = tf.reshape(tf.stack([b, f, indices]), [3, -1])
+    indices = tf.transpose(indices)
+    values = tf.reshape(input, (-1,))
+    output = tf.tensor_scatter_nd_update(output, indices, values)
+    if revert:
+        output = tf.transpose(output, perm=[0, 2, 1])
+    return output
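For reviewers, a minimal sketch (not part of the diff) of how the new backend function behaves. The pooled values and indices are written out by hand to mimic what a `max_pool1d(kernel=2, stride=2)` over `[[[1., 3., 2., 4.]]]` would return, so they are illustrative assumptions rather than library output:

```python
import tensorflow as tf

from ivy.functional.backends.tensorflow.experimental.layers import max_unpool1d

# NCW pooled tensor: the maxima 3. and 4. came from positions 1 and 3
pooled = tf.constant([[[3.0, 4.0]]])
indices = tf.constant([[[1, 3]]], dtype=tf.int64)

out = max_unpool1d(
    pooled, indices, kernel_size=2, strides=2, padding=0, data_format="NCW"
)
print(out)  # expected [[[0. 3. 0. 4.]]]: maxima scattered back, zeros elsewhere
```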
diff --git a/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_layers.py b/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_layers.py
index 4c963d7bdd711..1c6ff27421eb5 100644
--- a/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_layers.py
+++ b/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_layers.py
@@ -1,5 +1,6 @@
 # global
 import numpy as np
+import torch
 from hypothesis import strategies as st, assume
 
 # local
@@ -1190,23 +1191,54 @@ def test_max_pool3d(
     )
 
 
+@st.composite
+def max_unpool1d_helper(draw, **data_gen_kwargs):
+    dts, values, kernel_size, strides, _ = draw(
+        helpers.arrays_for_pooling(
+            min_dims=3,
+            max_dims=3,
+            data_format="channel_first",
+            **data_gen_kwargs,
+        )
+    )
+    dts.append("int64")
+    values = values[0]
+    padding = draw(helpers.ints(min_value=0, max_value=2))
+    # torch requires padding <= kernel_size / 2
+    if padding > (kernel_size[0] // 2):
+        padding = 0
+
+    # use torch's max_pool1d as the reference to produce matching indices
+    values, indices = torch.nn.functional.max_pool1d(
+        torch.tensor(values.astype(np.float32)),
+        kernel_size,
+        strides,
+        padding,
+        return_indices=True,
+    )
+    indices = indices.numpy().astype(np.int64)
+    # clamp indices that point past the pooled length
+    max_idx = values.shape[-1] - 1
+    indices = np.where(indices > max_idx, max_idx, indices)
+    values = values.numpy().astype(dts[0])
+    return dts, values, indices, kernel_size, strides, padding
+
+
 @handle_test(
-    fn_tree="functional.ivy.experimental.layers.max_unpool1d",
-    x_k_s_p=helpers.arrays_for_pooling(min_dims=3, max_dims=3, min_side=1, max_side=4),
-    indices=st.lists(st.integers(0, 1), min_size=1, max_size=4),
+    fn_tree="functional.ivy.experimental.max_unpool1d",
+    x_k_s_p=max_unpool1d_helper(min_side=2, max_side=5),
     ground_truth_backend="jax",
     test_gradients=st.just(False),
+    test_with_out=st.just(False),
 )
 def test_max_unpool1d(
     *,
     x_k_s_p,
-    indices,
     test_flags,
     backend_fw,
     fn_name,
     on_device,
 ):
-    dtype, x, kernel, stride, pad = x_k_s_p
+    dtype, x, ind, kernel, stride, pad = x_k_s_p
     helpers.test_function(
         input_dtypes=dtype,
         test_flags=test_flags,
@@ -1215,11 +1247,11 @@ def test_max_unpool1d(
         fn_name=fn_name,
         rtol_=1e-2,
         atol_=1e-2,
-        x=x[0],
-        kernel=kernel,
+        input=x,
+        indices=ind,
+        kernel_size=kernel,
         strides=stride,
         padding=pad,
-        indices=indices,
     )
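The test helper leans on the output-length relation `L_out = (L_in - 1) * stride - 2 * padding + kernel_size`. A quick sanity check of the two new helpers, purely illustrative since the underscore names are internal to the backend module:

```python
from ivy.functional.backends.tensorflow.experimental.layers import (
    _broadcast_pooling_helper,
    _cal_output_shape,
)

# scalars and 1-element sequences both broadcast to a 1-tuple in the 1d case
assert _broadcast_pooling_helper(3, "1d", name="kernel_size") == (3,)
assert _broadcast_pooling_helper([2], "1d", name="strides") == (2,)

# a pooled length of 5 with kernel 3, stride 2, padding 1 unpools to
# (5 - 1) * 2 - 2 * 1 + 3 = 9
assert _cal_output_shape((5,), (1,), (3,), (2,)) == [9]
```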