Skip to content

Commit

Permalink
Revert "Merge pull request #13 from ASEM000/padding_kwargs"
Browse files Browse the repository at this point in the history
This reverts commit 3130931, reversing
changes made to 5d798c5.
  • Loading branch information
ASEM000 committed Jul 11, 2023
1 parent 3130931 commit c54dde0
Show file tree
Hide file tree
Showing 5 changed files with 9 additions and 54 deletions.
2 changes: 1 addition & 1 deletion kernex/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,4 @@
"offset_kernel_scan",
)

__version__ = "0.2.1"
__version__ = "0.2.0"
10 changes: 2 additions & 8 deletions kernex/_src/map.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,9 @@ def kernel_map(
relative: bool = False,
map_kind: MapKind = "vmap",
map_kwargs: dict[str, Any] | None = None,
padding_kwargs: dict[str, Any] | None = None,
) -> Callable:

map_kwargs = map_kwargs or {}
padding_kwargs = padding_kwargs or {}
padding_kwargs.pop("pad_width", None) # handled by border

map_tranform = transform_func_map[map_kind]
pad_width = _calculate_pad_width(border)
args = (shape, kernel_size, strides, border)
Expand All @@ -86,7 +82,7 @@ def kernel_map(
slices = tuple(func_map.values())

def single_call_wrapper(array: jax.Array, *a, **k):
padded_array = jnp.pad(array, pad_width=pad_width, **padding_kwargs)
padded_array = jnp.pad(array, pad_width)

# convert the function to a callable that takes a view and an array
# and returns the result of the function applied to the view
Expand All @@ -102,7 +98,7 @@ def map_func(view):
return result.reshape(*output_shape, *result.shape[1:])

def multi_call_wrapper(array: jax.Array, *a, **k):
padded_array = jnp.pad(array, pad_width=pad_width, **padding_kwargs)
padded_array = jnp.pad(array, pad_width)
# convert the functions to a callable that takes a view and an array
# and returns the result of the function applied to the view
# the result is a 1D array of the same length as the number of views
Expand Down Expand Up @@ -137,7 +133,6 @@ def offset_kernel_map(
relative: bool = False,
map_kind: MapKind = "vmap",
map_kwargs: dict[str, Any] = None,
offset_kwargs: dict[str, Any] = None,
):

func = kernel_map(
Expand All @@ -149,7 +144,6 @@ def offset_kernel_map(
relative=relative,
map_kind=map_kind,
map_kwargs=map_kwargs,
padding_kwargs=offset_kwargs,
)
set_indices = _get_set_indices(shape, strides, offset)

Expand Down
9 changes: 2 additions & 7 deletions kernex/_src/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,9 @@ def kernel_scan(
relative: bool = False,
scan_kind: ScanKind = "scan", # dummy to make signature consistent with kernel_map
scan_kwargs: dict[str, Any] | None = None,
padding_kwargs: dict[str, Any] | None = None,
):

scan_kwargs = scan_kwargs or {}
padding_kwargs = padding_kwargs or {}
padding_kwargs.pop("pad_width", None)
scan_transform = transform_func_map[scan_kind]
pad_width = _calculate_pad_width(border)
args = (shape, kernel_size, strides, border)
Expand All @@ -85,7 +82,7 @@ def kernel_scan(
slices = tuple(func_map.values())

def single_call_wrapper(array: jax.Array, *a, **k):
padded_array = jnp.pad(array, pad_width=pad_width, **padding_kwargs)
padded_array = jnp.pad(array, pad_width)
func0 = next(iter(func_map))
reduced_func = _transform_scan_func(func0, kernel_size, relative)(*a, **k)

Expand All @@ -98,7 +95,7 @@ def scan_body(padded_array: jax.Array, view: jax.Array):
return result.reshape(output_shape)

def multi_call_wrapper(array: jax.Array, *a, **k):
padded_array = jnp.pad(array, pad_width=pad_width, **padding_kwargs)
padded_array = jnp.pad(array, pad_width)

reduced_funcs = tuple(
_transform_scan_func(func, kernel_size, relative)(*a, **k)
Expand Down Expand Up @@ -127,7 +124,6 @@ def offset_kernel_scan(
relative: bool = False,
scan_kind: ScanKind = "scan",
scan_kwargs: dict[str, Any] | None = None,
offset_kwargs: dict[str, Any] | None = None,
):

func = kernel_scan(
Expand All @@ -139,7 +135,6 @@ def offset_kernel_scan(
relative=relative,
scan_kind=scan_kind,
scan_kwargs=scan_kwargs,
padding_kwargs=offset_kwargs,
)
set_indices = _get_set_indices(shape, strides, offset)

Expand Down
24 changes: 4 additions & 20 deletions kernex/interface/kernel_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@
)

BorderType = Union[
int, # single int to pad all axes before and after the array
Tuple[int, ...], # tuple of ints to pad before and after each axis
Tuple[Tuple[int, int], ...], # tuple of tuples to pad before and after each axis
Literal["valid", "same", "SAME", "VALID"], # string to use a predefined padding
int,
Tuple[int, ...],
Tuple[Tuple[int, int], ...],
Literal["valid", "same", "SAME", "VALID"],
]

StridesType = Union[Tuple[int, ...], int]
Expand All @@ -56,7 +56,6 @@ def __init__(
container: dict[Callable, slice | int] | None = None,
transform_kind: MapKind | ScanKind | None = None,
transform_kwargs: dict[str, Any] | None = None,
border_kwargs: dict[str, Any] | None = None,
):
self.kernel_size = kernel_size
self.strides = strides
Expand All @@ -72,7 +71,6 @@ def __init__(
)
self.transform_kind = transform_kind
self.transform_kwargs = transform_kwargs
self.border_kwargs = border_kwargs

def __repr__(self) -> str:
return (
Expand Down Expand Up @@ -126,7 +124,6 @@ def _wrap_mesh(self, array: jax.Array, *a, **k):
self.relative,
self.transform_kind,
self.transform_kwargs,
self.border_kwargs,
)(array, *a, **k)

def _wrap_decorator(self, func):
Expand Down Expand Up @@ -157,7 +154,6 @@ def call(array, *args, **kwargs):
self.relative,
self.transform_kind,
self.transform_kwargs,
self.border_kwargs,
)(array, *args, **kwargs)

return call
Expand Down Expand Up @@ -256,7 +252,6 @@ def __init__(
named_axis=named_axis,
transform_kind=scan_kind,
transform_kwargs=scan_kwargs,
border_kwargs=None,
)


Expand Down Expand Up @@ -342,7 +337,6 @@ def __init__(
named_axis=named_axis,
transform_kind=map_kind,
transform_kwargs=map_kwargs,
border_kwargs=None,
)


Expand All @@ -356,7 +350,6 @@ def __init__(
named_axis: dict[int, str] = None,
scan_kind: ScanKind = "scan",
scan_kwargs: dict[str, Any] | None = None,
padding_kwargs: dict[str, Any] | None = None,
):
"""Apply a function to a sliding window of the input array sequentially.
Expand All @@ -375,9 +368,6 @@ def __init__(
scan_kwargs: optional kwargs to be passed to the scan function.
for example, `scan_kwargs={'reverse': True}` will reverse the
application of the function.
padding_kwargs: optional kwargs to be passed to the padding function.
for example, `padding_kwargs=dict(constant_values=10)` will pad
the input array with 10 for same padding.
Returns:
A function that takes an array as input and returns the result of
Expand Down Expand Up @@ -410,7 +400,6 @@ def __init__(
named_axis=named_axis,
transform_kind=scan_kind,
transform_kwargs=scan_kwargs,
border_kwargs=padding_kwargs,
)


Expand All @@ -424,7 +413,6 @@ def __init__(
named_axis: dict[int, str] = None,
map_kind: MapKind = "vmap",
map_kwargs: dict = None,
padding_kwargs: dict = None,
):
"""Apply a function to a sliding window of the input array in parallel.
Expand All @@ -444,9 +432,6 @@ def __init__(
map_kwargs: optional kwargs to be passed to the map function.
for example, `map_kwargs={'axis_name': 'i'}` will apply the
function along the axis named `i` for `pmap`.
padding_kwargs: optional kwargs to be passed to the padding function.
for example, `padding_kwargs=dict(constant_values=10)` will pad
the input array with 10 for same padding.
Returns:
A function that takes an array as input and applies the kernel
Expand Down Expand Up @@ -479,5 +464,4 @@ def __init__(
named_axis=named_axis,
transform_kind=map_kind,
transform_kwargs=map_kwargs,
border_kwargs=padding_kwargs,
)
18 changes: 0 additions & 18 deletions tests/test_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,21 +346,3 @@ def kex_conv2d(x, w):
pred_grad = jax.grad(lambda w: jnp.sum(kex_conv2d(x, w)))(w)

np.testing.assert_allclose(true_grad[0], pred_grad, atol=1e-3)


def test_padding_kwargs():
@kex.kmap(
kernel_size=(3,),
padding=("same"),
relative=False,
padding_kwargs=dict(constant_values=10),
)
def f(x):
return x

x = jnp.array([1, 2, 3, 4, 5])

np.testing.assert_allclose(
f(x),
np.array([[10, 1, 2], [1, 2, 3], [2, 3, 4], [3, 4, 5], [4, 5, 10]]),
)

0 comments on commit c54dde0

Please sign in to comment.