Skip to content

Commit

Permalink
add padding _kwargs
Browse files Browse the repository at this point in the history
  • Loading branch information
ASEM000 committed Jul 11, 2023
1 parent 5d798c5 commit c4824c4
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 9 deletions.
2 changes: 1 addition & 1 deletion kernex/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,4 @@
"offset_kernel_scan",
)

__version__ = "0.2.0"
__version__ = "0.2.1"
10 changes: 8 additions & 2 deletions kernex/_src/map.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,13 @@ def kernel_map(
relative: bool = False,
map_kind: MapKind = "vmap",
map_kwargs: dict[str, Any] | None = None,
padding_kwargs: dict[str, Any] | None = None,
) -> Callable:

map_kwargs = map_kwargs or {}
padding_kwargs = padding_kwargs or {}
padding_kwargs.pop("pad_width", None) # handled by border

map_tranform = transform_func_map[map_kind]
pad_width = _calculate_pad_width(border)
args = (shape, kernel_size, strides, border)
Expand All @@ -82,7 +86,7 @@ def kernel_map(
slices = tuple(func_map.values())

def single_call_wrapper(array: jax.Array, *a, **k):
padded_array = jnp.pad(array, pad_width)
padded_array = jnp.pad(array, pad_width=pad_width, **padding_kwargs)

# convert the function to a callable that takes a view and an array
# and returns the result of the function applied to the view
Expand All @@ -98,7 +102,7 @@ def map_func(view):
return result.reshape(*output_shape, *result.shape[1:])

def multi_call_wrapper(array: jax.Array, *a, **k):
padded_array = jnp.pad(array, pad_width)
padded_array = jnp.pad(array, pad_width=pad_width, **padding_kwargs)
# convert the functions to a callable that takes a view and an array
# and returns the result of the function applied to the view
# the result is a 1D array of the same length as the number of views
Expand Down Expand Up @@ -133,6 +137,7 @@ def offset_kernel_map(
relative: bool = False,
map_kind: MapKind = "vmap",
map_kwargs: dict[str, Any] = None,
offset_kwargs: dict[str, Any] = None,
):

func = kernel_map(
Expand All @@ -144,6 +149,7 @@ def offset_kernel_map(
relative=relative,
map_kind=map_kind,
map_kwargs=map_kwargs,
padding_kwargs=offset_kwargs,
)
set_indices = _get_set_indices(shape, strides, offset)

Expand Down
9 changes: 7 additions & 2 deletions kernex/_src/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,12 @@ def kernel_scan(
relative: bool = False,
scan_kind: ScanKind = "scan", # dummy to make signature consistent with kernel_map
scan_kwargs: dict[str, Any] | None = None,
padding_kwargs: dict[str, Any] | None = None,
):

scan_kwargs = scan_kwargs or {}
padding_kwargs = padding_kwargs or {}
padding_kwargs.pop("pad_width", None)
scan_transform = transform_func_map[scan_kind]
pad_width = _calculate_pad_width(border)
args = (shape, kernel_size, strides, border)
Expand All @@ -82,7 +85,7 @@ def kernel_scan(
slices = tuple(func_map.values())

def single_call_wrapper(array: jax.Array, *a, **k):
padded_array = jnp.pad(array, pad_width)
padded_array = jnp.pad(array, pad_width=pad_width, **padding_kwargs)
func0 = next(iter(func_map))
reduced_func = _transform_scan_func(func0, kernel_size, relative)(*a, **k)

Expand All @@ -95,7 +98,7 @@ def scan_body(padded_array: jax.Array, view: jax.Array):
return result.reshape(output_shape)

def multi_call_wrapper(array: jax.Array, *a, **k):
padded_array = jnp.pad(array, pad_width)
padded_array = jnp.pad(array, pad_width=pad_width, **padding_kwargs)

reduced_funcs = tuple(
_transform_scan_func(func, kernel_size, relative)(*a, **k)
Expand Down Expand Up @@ -124,6 +127,7 @@ def offset_kernel_scan(
relative: bool = False,
scan_kind: ScanKind = "scan",
scan_kwargs: dict[str, Any] | None = None,
offset_kwargs: dict[str, Any] | None = None,
):

func = kernel_scan(
Expand All @@ -135,6 +139,7 @@ def offset_kernel_scan(
relative=relative,
scan_kind=scan_kind,
scan_kwargs=scan_kwargs,
padding_kwargs=offset_kwargs,
)
set_indices = _get_set_indices(shape, strides, offset)

Expand Down
24 changes: 20 additions & 4 deletions kernex/interface/kernel_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@
)

BorderType = Union[
int,
Tuple[int, ...],
Tuple[Tuple[int, int], ...],
Literal["valid", "same", "SAME", "VALID"],
int, # single int to pad all axes before and after the array
Tuple[int, ...], # tuple of ints to pad before and after each axis
Tuple[Tuple[int, int], ...], # tuple of tuples to pad before and after each axis
Literal["valid", "same", "SAME", "VALID"], # string to use a predefined padding
]

StridesType = Union[Tuple[int, ...], int]
Expand All @@ -56,6 +56,7 @@ def __init__(
container: dict[Callable, slice | int] | None = None,
transform_kind: MapKind | ScanKind | None = None,
transform_kwargs: dict[str, Any] | None = None,
border_kwargs: dict[str, Any] | None = None,
):
self.kernel_size = kernel_size
self.strides = strides
Expand All @@ -71,6 +72,7 @@ def __init__(
)
self.transform_kind = transform_kind
self.transform_kwargs = transform_kwargs
self.border_kwargs = border_kwargs

def __repr__(self) -> str:
return (
Expand Down Expand Up @@ -124,6 +126,7 @@ def _wrap_mesh(self, array: jax.Array, *a, **k):
self.relative,
self.transform_kind,
self.transform_kwargs,
self.border_kwargs,
)(array, *a, **k)

def _wrap_decorator(self, func):
Expand Down Expand Up @@ -154,6 +157,7 @@ def call(array, *args, **kwargs):
self.relative,
self.transform_kind,
self.transform_kwargs,
self.border_kwargs,
)(array, *args, **kwargs)

return call
Expand Down Expand Up @@ -252,6 +256,7 @@ def __init__(
named_axis=named_axis,
transform_kind=scan_kind,
transform_kwargs=scan_kwargs,
border_kwargs=None,
)


Expand Down Expand Up @@ -337,6 +342,7 @@ def __init__(
named_axis=named_axis,
transform_kind=map_kind,
transform_kwargs=map_kwargs,
border_kwargs=None,
)


Expand All @@ -350,6 +356,7 @@ def __init__(
named_axis: dict[int, str] = None,
scan_kind: ScanKind = "scan",
scan_kwargs: dict[str, Any] | None = None,
padding_kwargs: dict[str, Any] | None = None,
):
"""Apply a function to a sliding window of the input array sequentially.
Expand All @@ -368,6 +375,9 @@ def __init__(
scan_kwargs: optional kwargs to be passed to the scan function.
for example, `scan_kwargs={'reverse': True}` will reverse the
application of the function.
padding_kwargs: optional kwargs to be passed to the padding function.
for example, `padding_kwargs=dict(constant_values=10)` will pad
the input array with 10 for same padding.
Returns:
A function that takes an array as input and returns the result of
Expand Down Expand Up @@ -400,6 +410,7 @@ def __init__(
named_axis=named_axis,
transform_kind=scan_kind,
transform_kwargs=scan_kwargs,
border_kwargs=padding_kwargs,
)


Expand All @@ -413,6 +424,7 @@ def __init__(
named_axis: dict[int, str] = None,
map_kind: MapKind = "vmap",
map_kwargs: dict = None,
padding_kwargs: dict = None,
):
"""Apply a function to a sliding window of the input array in parallel.
Expand All @@ -432,6 +444,9 @@ def __init__(
map_kwargs: optional kwargs to be passed to the map function.
for example, `map_kwargs={'axis_name': 'i'}` will apply the
function along the axis named `i` for `pmap`.
padding_kwargs: optional kwargs to be passed to the padding function.
for example, `padding_kwargs=dict(constant_values=10)` will pad
the input array with 10 for same padding.
Returns:
A function that takes an array as input and applies the kernel
Expand Down Expand Up @@ -464,4 +479,5 @@ def __init__(
named_axis=named_axis,
transform_kind=map_kind,
transform_kwargs=map_kwargs,
border_kwargs=padding_kwargs,
)
18 changes: 18 additions & 0 deletions tests/test_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,3 +346,21 @@ def kex_conv2d(x, w):
pred_grad = jax.grad(lambda w: jnp.sum(kex_conv2d(x, w)))(w)

np.testing.assert_allclose(true_grad[0], pred_grad, atol=1e-3)


def test_padding_kwargs():
@kex.kmap(
kernel_size=(3,),
padding=("same"),
relative=False,
padding_kwargs=dict(constant_values=10),
)
def f(x):
return x

x = jnp.array([1, 2, 3, 4, 5])

np.testing.assert_allclose(
f(x),
np.array([[10, 1, 2], [1, 2, 3], [2, 3, 4], [3, 4, 5], [4, 5, 10]]),
)

0 comments on commit c4824c4

Please sign in to comment.