refactor(ivy): removed additional framework dependencies due to pad (
vedpatwardhan authored Sep 2, 2023
1 parent 62db717 commit 446cf62
Showing 4 changed files with 68 additions and 49 deletions.
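In short (an editorial summary, not part of the commit page): this commit moves the three pad helpers `_to_tf_padding`, `_check_paddle_pad`, and `_to_paddle_padding` out of the TensorFlow and Paddle backends into the shared `ivy.functional.ivy.experimental.manipulation` module, so no backend imports helpers from another backend, and it reflows the `pad.partial_mixed_handler` lambdas for readability. The import change, condensed:

# Before (backend-to-backend imports, removed by this commit):
#   paddle backend: from ...tensorflow.experimental.manipulation import _to_tf_padding
#   torch backend:  from ...paddle.experimental.manipulation import _check_paddle_pad, _to_paddle_padding
#   torch backend:  from ...tensorflow.experimental.manipulation import _to_tf_padding

# After (every backend imports from the framework-agnostic ivy module):
from ivy.functional.ivy.experimental.manipulation import (
    _to_tf_padding,
    _check_paddle_pad,
    _to_paddle_padding,
)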
44 changes: 9 additions & 35 deletions ivy/functional/backends/paddle/experimental/manipulation.py
@@ -19,7 +19,10 @@
import ivy
import ivy.functional.backends.paddle as paddle_backend
from ivy.func_wrapper import with_supported_device_and_dtypes
-from ...tensorflow.experimental.manipulation import _to_tf_padding
+from ivy.functional.ivy.experimental.manipulation import (
+    _check_paddle_pad,
+    _to_paddle_padding,
+)

# Code from cephes for i0

@@ -148,42 +151,13 @@ def pad(
)


-pad.partial_mixed_handler = lambda *args, mode="constant", constant_values=0, reflect_type="even", **kwargs: _check_paddle_pad(
-    mode, reflect_type, args[1], args[0].shape, constant_values, 3
-)
-
-
-def _check_paddle_pad(
-    mode, reflect_type, pad_width, input_shape, constant_values, ndim_limit
-):
-    pad_width = _to_tf_padding(pad_width, len(input_shape))
-    return isinstance(constant_values, Number) and (
-        mode == "constant"
-        or (
-            (
-                (
-                    mode == "reflect"
-                    and reflect_type == "even"
-                    and all(
-                        pad_width[i][0] < s and pad_width[i][1] < s
-                        for i, s in enumerate(input_shape)
-                    )
-                )
-                or mode in ["edge", "wrap"]
-            )
-            and len(input_shape) <= ndim_limit
-        )
-    )
-
-
-def _to_paddle_padding(pad_width, ndim):
-    if isinstance(pad_width, Number):
-        pad_width = [pad_width] * (2 * ndim)
-    else:
-        if len(pad_width) == 2 and isinstance(pad_width[0], Number) and ndim != 1:
-            pad_width = pad_width * ndim
-        pad_width = [item for sublist in pad_width for item in sublist[::-1]][::-1]
-    return pad_width
+pad.partial_mixed_handler = (
+    lambda *args, mode="constant", constant_values=0, reflect_type="even", **kwargs: (
+        _check_paddle_pad(
+            mode, reflect_type, args[1], args[0].shape, constant_values, 3
+        )
+    )
+)


@with_unsupported_device_and_dtypes(
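For context (an illustrative sketch, not part of the commit): `pad.partial_mixed_handler` is the predicate Ivy evaluates before dispatching to the Paddle `pad` kernel; when it returns False, Ivy falls back to its generic compositional implementation. A hypothetical REPL-style check of the relocated `_check_paddle_pad` (defined in the last file of this diff):

from ivy.functional.ivy.experimental.manipulation import _check_paddle_pad

# 2-D "reflect" pad with widths smaller than every dim: Paddle's kernel qualifies.
print(_check_paddle_pad("reflect", "even", [[1, 1], [2, 2]], (4, 5), 0, 3))  # True

# 5-D input exceeds ndim_limit=3 for non-constant modes: fall back to ivy's pad.
print(_check_paddle_pad("reflect", "even", [[1, 1]] * 5, (2,) * 5, 0, 3))  # False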
15 changes: 5 additions & 10 deletions ivy/functional/backends/tensorflow/experimental/manipulation.py
@@ -18,6 +18,7 @@
from ivy.func_wrapper import with_unsupported_dtypes
from .. import backend_version
import ivy
+from ivy.functional.ivy.experimental.manipulation import _to_tf_padding


def moveaxis(
@@ -294,8 +295,10 @@ def pad(
)


-pad.partial_mixed_handler = lambda *args, mode="constant", constant_values=0, reflect_type="even", **kwargs: _check_tf_pad(
-    args[0].shape, args[1], mode, constant_values, reflect_type
-)
+pad.partial_mixed_handler = (
+    lambda *args, mode="constant", constant_values=0, reflect_type="even", **kwargs: (
+        _check_tf_pad(args[0].shape, args[1], mode, constant_values, reflect_type)
+    )
+)


@@ -325,14 +328,6 @@ def _check_tf_pad(input_shape, pad_width, mode, constant_values, reflect_type):
)


-def _to_tf_padding(pad_width, ndim):
-    if isinstance(pad_width, Number):
-        pad_width = [[pad_width] * 2] * ndim
-    elif len(pad_width) == 2 and isinstance(pad_width[0], Number):
-        pad_width = pad_width * ndim
-    return pad_width
-
-
def expand(
x: Union[tf.Tensor, tf.Variable],
shape: Union[List[int], List[Tuple]],
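A quick, hypothetical check (not part of the commit) of the relocated `_to_tf_padding`, whose behavior the move leaves unchanged; it normalizes scalar pad widths into one [before, after] pair per dimension:

from ivy.functional.ivy.experimental.manipulation import _to_tf_padding

# a scalar width broadcasts to every dimension
print(_to_tf_padding(2, 3))  # [[2, 2], [2, 2], [2, 2]]

# already-paired widths pass through untouched
print(_to_tf_padding([[1, 0], [2, 3]], 2))  # [[1, 0], [2, 3]]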
13 changes: 9 additions & 4 deletions ivy/functional/backends/torch/experimental/manipulation.py
@@ -18,8 +18,11 @@
from ivy.func_wrapper import with_unsupported_dtypes, with_supported_dtypes
from .. import backend_version
import ivy
-from ...paddle.experimental.manipulation import _check_paddle_pad, _to_paddle_padding
-from ...tensorflow.experimental.manipulation import _to_tf_padding
+from ivy.functional.ivy.experimental.manipulation import (
+    _to_tf_padding,
+    _check_paddle_pad,
+    _to_paddle_padding,
+)


def moveaxis(
@@ -110,8 +113,10 @@ def pad(
).squeeze(0)


-pad.partial_mixed_handler = lambda *args, mode="constant", constant_values=0, reflect_type="even", **kwargs: _check_torch_pad(
-    mode, reflect_type, args[1], args[0].shape, constant_values
-)
+pad.partial_mixed_handler = (
+    lambda *args, mode="constant", constant_values=0, reflect_type="even", **kwargs: (
+        _check_torch_pad(mode, reflect_type, args[1], args[0].shape, constant_values)
+    )
+)


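One more hypothetical check (not part of the commit): the torch backend now imports all three shared helpers, including `_to_paddle_padding`, which flattens per-dimension [before, after] pairs into the last-dimension-first layout that flat pad arguments (as in torch.nn.functional.pad) use:

from ivy.functional.ivy.experimental.manipulation import _to_paddle_padding

# a scalar width becomes a flat list of length 2 * ndim
print(_to_paddle_padding(1, 2))  # [1, 1, 1, 1]

# per-dim pairs are flattened with the last dimension's pads first
print(_to_paddle_padding([[1, 2], [3, 4]], 2))  # [3, 4, 1, 2]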
45 changes: 45 additions & 0 deletions ivy/functional/ivy/experimental/manipulation.py
@@ -33,6 +33,51 @@
from ivy.utils.exceptions import handle_exceptions


+# Helpers #
+# ------- #
+
+
+def _to_tf_padding(pad_width, ndim):
+    if isinstance(pad_width, Number):
+        pad_width = [[pad_width] * 2] * ndim
+    elif len(pad_width) == 2 and isinstance(pad_width[0], Number):
+        pad_width = pad_width * ndim
+    return pad_width
+
+
+def _check_paddle_pad(
+    mode, reflect_type, pad_width, input_shape, constant_values, ndim_limit
+):
+    pad_width = _to_tf_padding(pad_width, len(input_shape))
+    return isinstance(constant_values, Number) and (
+        mode == "constant"
+        or (
+            (
+                (
+                    mode == "reflect"
+                    and reflect_type == "even"
+                    and all(
+                        pad_width[i][0] < s and pad_width[i][1] < s
+                        for i, s in enumerate(input_shape)
+                    )
+                )
+                or mode in ["edge", "wrap"]
+            )
+            and len(input_shape) <= ndim_limit
+        )
+    )
+
+
+def _to_paddle_padding(pad_width, ndim):
+    if isinstance(pad_width, Number):
+        pad_width = [pad_width] * (2 * ndim)
+    else:
+        if len(pad_width) == 2 and isinstance(pad_width[0], Number) and ndim != 1:
+            pad_width = pad_width * ndim
+        pad_width = [item for sublist in pad_width for item in sublist[::-1]][::-1]
+    return pad_width


@handle_exceptions
@handle_nestable
@handle_partial_mixed_function
