Skip to content

Commit

Permalink
added max_unpool1d implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
Mohammed committed Sep 1, 2023
1 parent d9b21ed commit 1df7a40
Show file tree
Hide file tree
Showing 2 changed files with 118 additions and 8 deletions.
78 changes: 78 additions & 0 deletions ivy/functional/backends/tensorflow/experimental/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,33 @@ def _determine_depth_max_pooling(x, kernel, strides, dims, data_format="channel_
return x, kernel, strides, depth_pooling


def _cal_output_shape(
input_shape,
padding,
kernel_size,
stride,
):
shape = [(l - 1) * s - 2 * p + k for l, s, p, k in zip(input_shape, stride, padding, kernel_size)]
return shape


def _broadcast_pooling_helper(x, pool_dims: str = "2d", name: str = "padding"):
dims = {"1d": 1, "2d": 2, "3d": 3}

if isinstance(x, int):
return tuple([x for _ in range(dims[pool_dims])])

if len(x) == 1:
return tuple([x[0] for _ in range(dims[pool_dims])])
elif len(x) == dims[pool_dims]:
return tuple(x)
elif len(x) != dims[pool_dims]:
raise ValueError(
f"`{name}` must either be a single int, "
f"or a tuple of {dims[pool_dims]} ints. "
)


def max_pool1d(
x: Union[tf.Tensor, tf.Variable],
kernel: Union[int, Tuple[int, ...]],
Expand Down Expand Up @@ -1427,3 +1454,54 @@ def rfftn(
else:
# return result
return tf.cast(result, tf.complex128)


def max_unpool1d(
input: Union[tf.Tensor, tf.Variable],
indices: Union[tf.Tensor, tf.Variable],
kernel_size: Union[Tuple[int], int],
strides: Union[int, Tuple[int]] = None,
padding: Union[int, Tuple[int]] = 0,
data_format: Optional[str] = "NCW",
):
if strides is None:
strides = kernel_size
input_shape = tf.shape(input)
if data_format in ["NCW", "NWC"]:
revert = False
if data_format == "NWC":
x_len = (input_shape[1],)
input = tf.transpose(input, perm=[0, 2, 1])
indices = tf.transpose(indices, perm=[0, 2, 1])
revert = True
else:
x_len = (input_shape[-1],)
else:
raise ValueError(f"data_format attr should be NCW or NWC but found {data_format}")

input_shape = tf.shape(input)
ind_dtype = indices.dtype
kernel_size = _broadcast_pooling_helper(kernel_size, "1d", name="kernel_size")
padding = _broadcast_pooling_helper(padding, "1d", name="padding")
strides = _broadcast_pooling_helper(strides, "1d", name="strides")
output_len = _cal_output_shape(x_len, padding, kernel_size, strides)
output_shape = list(input_shape[:-1]) + output_len
one_like_mask = tf.ones_like(indices, dtype=ind_dtype)
batch_shape = tf.concat([[input_shape[0]], [1], [1]], axis=0)
batch_range = tf.reshape(
tf.range(0, output_shape[0], dtype=ind_dtype), shape=batch_shape
)
b = one_like_mask * batch_range
feature_range = tf.reshape(
tf.range(0, output_shape[1], dtype=ind_dtype), shape=(1, -1, 1)
)
f = one_like_mask * feature_range
output = tf.zeros(output_shape, dtype=input.dtype)
indices = tf.reshape(tf.stack([b, f, indices]), [3, -1])
indices = tf.transpose(indices)
values = tf.reshape(input, (-1,))
output = tf.tensor_scatter_nd_update(output, indices, values)
if revert:
output = tf.transpose(output, perm=[0, 2, 1])
return output

Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# global
import numpy as np
import torch
from hypothesis import strategies as st, assume

# local
Expand Down Expand Up @@ -1190,23 +1191,54 @@ def test_max_pool3d(
)


@st.composite
def max_unpool1d_helper(
draw,
**data_gen_kwargs
):
dts, values, kernel_size, strides, _ = draw(
helpers.arrays_for_pooling(
min_dims=3,
max_dims=3,
data_format="channel_first",
**data_gen_kwargs)
)
dts.extend(["int64"])
values = values[0]
padding = draw(helpers.ints(min_value=0, max_value=2))
if padding > (kernel_size[0] // 2):
padding = 0

values, indices = torch.nn.functional.max_pool1d(
torch.tensor(values.astype(np.float32)),
kernel_size,
strides,
padding,
return_indices=True,
)
indices = indices.numpy().astype(np.int64)
max_idx = values.shape[-1] - 1
indices = np.where(indices > max_idx, max_idx, indices)
values = values.numpy().astype(dts[0])
return dts, values, indices, kernel_size, strides, padding


@handle_test(
fn_tree="functional.ivy.experimental.layers.max_unpool1d",
x_k_s_p=helpers.arrays_for_pooling(min_dims=3, max_dims=3, min_side=1, max_side=4),
indices=st.lists(st.integers(0, 1), min_size=1, max_size=4),
fn_tree="functional.ivy.experimental.max_unpool1d",
x_k_s_p=max_unpool1d_helper(min_side=2, max_side=5),
ground_truth_backend="jax",
test_gradients=st.just(False),
test_with_out=st.just(False),
)
def test_max_unpool1d(
*,
x_k_s_p,
indices,
test_flags,
backend_fw,
fn_name,
on_device,
):
dtype, x, kernel, stride, pad = x_k_s_p
dtype, x, ind, kernel, stride, pad = x_k_s_p
helpers.test_function(
input_dtypes=dtype,
test_flags=test_flags,
Expand All @@ -1215,11 +1247,11 @@ def test_max_unpool1d(
fn_name=fn_name,
rtol_=1e-2,
atol_=1e-2,
x=x[0],
kernel=kernel,
input=x,
indices=ind,
kernel_size=kernel,
strides=stride,
padding=pad,
indices=indices,
)


Expand Down

0 comments on commit 1df7a40

Please sign in to comment.