diff --git a/kernex/__init__.py b/kernex/__init__.py index 133d5c8..30d6008 100644 --- a/kernex/__init__.py +++ b/kernex/__init__.py @@ -27,4 +27,4 @@ "offset_kernel_scan", ) -__version__ = "0.2.0" +__version__ = "0.2.1" diff --git a/kernex/_src/map.py b/kernex/_src/map.py index 67c8238..4cbbf43 100644 --- a/kernex/_src/map.py +++ b/kernex/_src/map.py @@ -67,9 +67,13 @@ def kernel_map( relative: bool = False, map_kind: MapKind = "vmap", map_kwargs: dict[str, Any] | None = None, + padding_kwargs: dict[str, Any] | None = None, ) -> Callable: map_kwargs = map_kwargs or {} + padding_kwargs = dict(padding_kwargs or {}) # copy: never mutate the caller's dict + padding_kwargs.pop("pad_width", None) # handled by border + map_tranform = transform_func_map[map_kind] pad_width = _calculate_pad_width(border) args = (shape, kernel_size, strides, border) @@ -82,7 +86,7 @@ def kernel_map( slices = tuple(func_map.values()) def single_call_wrapper(array: jax.Array, *a, **k): - padded_array = jnp.pad(array, pad_width) + padded_array = jnp.pad(array, pad_width=pad_width, **padding_kwargs) # convert the function to a callable that takes a view and an array # and returns the result of the function applied to the view @@ -98,7 +102,7 @@ def map_func(view): return result.reshape(*output_shape, *result.shape[1:]) def multi_call_wrapper(array: jax.Array, *a, **k): - padded_array = jnp.pad(array, pad_width) + padded_array = jnp.pad(array, pad_width=pad_width, **padding_kwargs) # convert the functions to a callable that takes a view and an array # and returns the result of the function applied to the view # the result is a 1D array of the same length as the number of views @@ -133,6 +137,7 @@ def offset_kernel_map( relative: bool = False, map_kind: MapKind = "vmap", map_kwargs: dict[str, Any] = None, + offset_kwargs: dict[str, Any] | None = None, ): func = kernel_map( @@ -144,6 +149,7 @@ def offset_kernel_map( relative=relative, map_kind=map_kind, map_kwargs=map_kwargs, + padding_kwargs=offset_kwargs, ) set_indices = _get_set_indices(shape, 
strides, offset) diff --git a/kernex/_src/scan.py b/kernex/_src/scan.py index b8b924e..719056e 100644 --- a/kernex/_src/scan.py +++ b/kernex/_src/scan.py @@ -71,9 +71,12 @@ def kernel_scan( relative: bool = False, scan_kind: ScanKind = "scan", # dummy to make signature consistent with kernel_map scan_kwargs: dict[str, Any] | None = None, + padding_kwargs: dict[str, Any] | None = None, ): scan_kwargs = scan_kwargs or {} + padding_kwargs = dict(padding_kwargs or {}) # copy: never mutate the caller's dict + padding_kwargs.pop("pad_width", None) # handled by border scan_transform = transform_func_map[scan_kind] pad_width = _calculate_pad_width(border) args = (shape, kernel_size, strides, border) @@ -82,7 +85,7 @@ def kernel_scan( slices = tuple(func_map.values()) def single_call_wrapper(array: jax.Array, *a, **k): - padded_array = jnp.pad(array, pad_width) + padded_array = jnp.pad(array, pad_width=pad_width, **padding_kwargs) func0 = next(iter(func_map)) reduced_func = _transform_scan_func(func0, kernel_size, relative)(*a, **k) @@ -95,7 +98,7 @@ def scan_body(padded_array: jax.Array, view: jax.Array): return result.reshape(output_shape) def multi_call_wrapper(array: jax.Array, *a, **k): - padded_array = jnp.pad(array, pad_width) + padded_array = jnp.pad(array, pad_width=pad_width, **padding_kwargs) reduced_funcs = tuple( _transform_scan_func(func, kernel_size, relative)(*a, **k) @@ -124,6 +127,7 @@ def offset_kernel_scan( relative: bool = False, scan_kind: ScanKind = "scan", scan_kwargs: dict[str, Any] | None = None, + offset_kwargs: dict[str, Any] | None = None, ): func = kernel_scan( @@ -135,6 +139,7 @@ def offset_kernel_scan( relative=relative, scan_kind=scan_kind, scan_kwargs=scan_kwargs, + padding_kwargs=offset_kwargs, ) set_indices = _get_set_indices(shape, strides, offset) diff --git a/kernex/interface/kernel_interface.py b/kernex/interface/kernel_interface.py index 04195a2..cb5a190 100644 --- a/kernex/interface/kernel_interface.py +++ b/kernex/interface/kernel_interface.py @@ -32,10 +32,10 @@ ) BorderType = Union[ - int, - 
Tuple[int, ...], - Tuple[Tuple[int, int], ...], - Literal["valid", "same", "SAME", "VALID"], + int, # single int to pad all axes before and after the array + Tuple[int, ...], # tuple of ints to pad before and after each axis + Tuple[Tuple[int, int], ...], # tuple of tuples to pad before and after each axis + Literal["valid", "same", "SAME", "VALID"], # string to use a predefined padding ] StridesType = Union[Tuple[int, ...], int] @@ -56,6 +56,7 @@ def __init__( container: dict[Callable, slice | int] | None = None, transform_kind: MapKind | ScanKind | None = None, transform_kwargs: dict[str, Any] | None = None, + border_kwargs: dict[str, Any] | None = None, ): self.kernel_size = kernel_size self.strides = strides @@ -71,6 +72,7 @@ def __init__( ) self.transform_kind = transform_kind self.transform_kwargs = transform_kwargs + self.border_kwargs = border_kwargs def __repr__(self) -> str: return ( @@ -124,6 +126,7 @@ def _wrap_mesh(self, array: jax.Array, *a, **k): self.relative, self.transform_kind, self.transform_kwargs, + self.border_kwargs, )(array, *a, **k) def _wrap_decorator(self, func): @@ -154,6 +157,7 @@ def call(array, *args, **kwargs): self.relative, self.transform_kind, self.transform_kwargs, + self.border_kwargs, )(array, *args, **kwargs) return call @@ -252,6 +256,7 @@ def __init__( named_axis=named_axis, transform_kind=scan_kind, transform_kwargs=scan_kwargs, + border_kwargs=None, ) @@ -337,6 +342,7 @@ def __init__( named_axis=named_axis, transform_kind=map_kind, transform_kwargs=map_kwargs, + border_kwargs=None, ) @@ -350,6 +356,7 @@ def __init__( named_axis: dict[int, str] = None, scan_kind: ScanKind = "scan", scan_kwargs: dict[str, Any] | None = None, + padding_kwargs: dict[str, Any] | None = None, ): """Apply a function to a sliding window of the input array sequentially. @@ -368,6 +375,9 @@ def __init__( scan_kwargs: optional kwargs to be passed to the scan function. 
for example, `scan_kwargs={'reverse': True}` will reverse the application of the function. + padding_kwargs: optional kwargs to be passed to the padding function. + for example, `padding_kwargs=dict(constant_values=10)` will pad + the input array with 10 for same padding. Returns: A function that takes an array as input and returns the result of @@ -400,6 +410,7 @@ def __init__( named_axis=named_axis, transform_kind=scan_kind, transform_kwargs=scan_kwargs, + border_kwargs=padding_kwargs, ) @@ -413,6 +424,7 @@ def __init__( named_axis: dict[int, str] = None, map_kind: MapKind = "vmap", map_kwargs: dict = None, + padding_kwargs: dict[str, Any] | None = None, ): """Apply a function to a sliding window of the input array in parallel. @@ -432,6 +444,9 @@ def __init__( map_kwargs: optional kwargs to be passed to the map function. for example, `map_kwargs={'axis_name': 'i'}` will apply the function along the axis named `i` for `pmap`. + padding_kwargs: optional kwargs to be passed to the padding function. + for example, `padding_kwargs=dict(constant_values=10)` will pad + the input array with 10 for same padding. Returns: A function that takes an array as input and applies the kernel @@ -464,4 +479,5 @@ def __init__( named_axis=named_axis, transform_kind=map_kind, transform_kwargs=map_kwargs, + border_kwargs=padding_kwargs, ) diff --git a/tests/test_interface.py b/tests/test_interface.py index d3a475a..56f141c 100644 --- a/tests/test_interface.py +++ b/tests/test_interface.py @@ -346,3 +346,21 @@ def kex_conv2d(x, w): pred_grad = jax.grad(lambda w: jnp.sum(kex_conv2d(x, w)))(w) np.testing.assert_allclose(true_grad[0], pred_grad, atol=1e-3) + + +def test_padding_kwargs(): + @kex.kmap( + kernel_size=(3,), + padding="same", + relative=False, + padding_kwargs=dict(constant_values=10), + ) + def f(x): + return x + + x = jnp.array([1, 2, 3, 4, 5]) + + np.testing.assert_allclose( + f(x), + np.array([[10, 1, 2], [1, 2, 3], [2, 3, 4], [3, 4, 5], [4, 5, 10]]), + )