diff --git a/kernex/__init__.py b/kernex/__init__.py index 30d6008..133d5c8 100644 --- a/kernex/__init__.py +++ b/kernex/__init__.py @@ -27,4 +27,4 @@ "offset_kernel_scan", ) -__version__ = "0.2.1" +__version__ = "0.2.0" diff --git a/kernex/_src/map.py b/kernex/_src/map.py index 4cbbf43..67c8238 100644 --- a/kernex/_src/map.py +++ b/kernex/_src/map.py @@ -67,13 +67,9 @@ def kernel_map( relative: bool = False, map_kind: MapKind = "vmap", map_kwargs: dict[str, Any] | None = None, - padding_kwargs: dict[str, Any] | None = None, ) -> Callable: map_kwargs = map_kwargs or {} - padding_kwargs = padding_kwargs or {} - padding_kwargs.pop("pad_width", None) # handled by border - map_tranform = transform_func_map[map_kind] pad_width = _calculate_pad_width(border) args = (shape, kernel_size, strides, border) @@ -86,7 +82,7 @@ def kernel_map( slices = tuple(func_map.values()) def single_call_wrapper(array: jax.Array, *a, **k): - padded_array = jnp.pad(array, pad_width=pad_width, **padding_kwargs) + padded_array = jnp.pad(array, pad_width) # convert the function to a callable that takes a view and an array # and returns the result of the function applied to the view @@ -102,7 +98,7 @@ def map_func(view): return result.reshape(*output_shape, *result.shape[1:]) def multi_call_wrapper(array: jax.Array, *a, **k): - padded_array = jnp.pad(array, pad_width=pad_width, **padding_kwargs) + padded_array = jnp.pad(array, pad_width) # convert the functions to a callable that takes a view and an array # and returns the result of the function applied to the view # the result is a 1D array of the same length as the number of views @@ -137,7 +133,6 @@ def offset_kernel_map( relative: bool = False, map_kind: MapKind = "vmap", map_kwargs: dict[str, Any] = None, - offset_kwargs: dict[str, Any] = None, ): func = kernel_map( @@ -149,7 +144,6 @@ def offset_kernel_map( relative=relative, map_kind=map_kind, map_kwargs=map_kwargs, - padding_kwargs=offset_kwargs, ) set_indices = _get_set_indices(shape, strides, offset) diff --git a/kernex/_src/scan.py b/kernex/_src/scan.py index 719056e..b8b924e 100644 --- a/kernex/_src/scan.py +++ b/kernex/_src/scan.py @@ -71,12 +71,9 @@ def kernel_scan( relative: bool = False, scan_kind: ScanKind = "scan", # dummy to make signature consistent with kernel_map scan_kwargs: dict[str, Any] | None = None, - padding_kwargs: dict[str, Any] | None = None, ): scan_kwargs = scan_kwargs or {} - padding_kwargs = padding_kwargs or {} - padding_kwargs.pop("pad_width", None) scan_transform = transform_func_map[scan_kind] pad_width = _calculate_pad_width(border) args = (shape, kernel_size, strides, border) @@ -85,7 +82,7 @@ def kernel_scan( slices = tuple(func_map.values()) def single_call_wrapper(array: jax.Array, *a, **k): - padded_array = jnp.pad(array, pad_width=pad_width, **padding_kwargs) + padded_array = jnp.pad(array, pad_width) func0 = next(iter(func_map)) reduced_func = _transform_scan_func(func0, kernel_size, relative)(*a, **k) @@ -98,7 +95,7 @@ def scan_body(padded_array: jax.Array, view: jax.Array): return result.reshape(output_shape) def multi_call_wrapper(array: jax.Array, *a, **k): - padded_array = jnp.pad(array, pad_width=pad_width, **padding_kwargs) + padded_array = jnp.pad(array, pad_width) reduced_funcs = tuple( _transform_scan_func(func, kernel_size, relative)(*a, **k) @@ -127,7 +124,6 @@ def offset_kernel_scan( relative: bool = False, scan_kind: ScanKind = "scan", scan_kwargs: dict[str, Any] | None = None, - offset_kwargs: dict[str, Any] | None = None, ): func = kernel_scan( @@ -139,7 +135,6 @@ def offset_kernel_scan( relative=relative, scan_kind=scan_kind, scan_kwargs=scan_kwargs, - padding_kwargs=offset_kwargs, ) set_indices = _get_set_indices(shape, strides, offset) diff --git a/kernex/interface/kernel_interface.py b/kernex/interface/kernel_interface.py index cb5a190..04195a2 100644 --- a/kernex/interface/kernel_interface.py +++ b/kernex/interface/kernel_interface.py @@ -32,10 +32,10 @@ ) BorderType = Union[ - int, # single int to pad all axes before and after the array - Tuple[int, ...], # tuple of ints to pad before and after each axis - Tuple[Tuple[int, int], ...], # tuple of tuples to pad before and after each axis - Literal["valid", "same", "SAME", "VALID"], # string to use a predefined padding + int, + Tuple[int, ...], + Tuple[Tuple[int, int], ...], + Literal["valid", "same", "SAME", "VALID"], ] StridesType = Union[Tuple[int, ...], int] @@ -56,7 +56,6 @@ def __init__( container: dict[Callable, slice | int] | None = None, transform_kind: MapKind | ScanKind | None = None, transform_kwargs: dict[str, Any] | None = None, - border_kwargs: dict[str, Any] | None = None, ): self.kernel_size = kernel_size self.strides = strides @@ -72,7 +71,6 @@ def __init__( ) self.transform_kind = transform_kind self.transform_kwargs = transform_kwargs - self.border_kwargs = border_kwargs def __repr__(self) -> str: return ( @@ -126,7 +124,6 @@ def _wrap_mesh(self, array: jax.Array, *a, **k): self.relative, self.transform_kind, self.transform_kwargs, - self.border_kwargs, )(array, *a, **k) def _wrap_decorator(self, func): @@ -157,7 +154,6 @@ def call(array, *args, **kwargs): self.relative, self.transform_kind, self.transform_kwargs, - self.border_kwargs, )(array, *args, **kwargs) return call @@ -256,7 +252,6 @@ def __init__( named_axis=named_axis, transform_kind=scan_kind, transform_kwargs=scan_kwargs, - border_kwargs=None, ) @@ -342,7 +337,6 @@ def __init__( named_axis=named_axis, transform_kind=map_kind, transform_kwargs=map_kwargs, - border_kwargs=None, ) @@ -356,7 +350,6 @@ def __init__( named_axis: dict[int, str] = None, scan_kind: ScanKind = "scan", scan_kwargs: dict[str, Any] | None = None, - padding_kwargs: dict[str, Any] | None = None, ): """Apply a function to a sliding window of the input array sequentially. @@ -375,9 +368,6 @@ def __init__( scan_kwargs: optional kwargs to be passed to the scan function. for example, `scan_kwargs={'reverse': True}` will reverse the application of the function. - padding_kwargs: optional kwargs to be passed to the padding function. - for example, `padding_kwargs=dict(constant_values=10)` will pad - the input array with 10 for same padding. Returns: A function that takes an array as input and returns the result of @@ -410,7 +400,6 @@ def __init__( named_axis=named_axis, transform_kind=scan_kind, transform_kwargs=scan_kwargs, - border_kwargs=padding_kwargs, ) @@ -424,7 +413,6 @@ def __init__( named_axis: dict[int, str] = None, map_kind: MapKind = "vmap", map_kwargs: dict = None, - padding_kwargs: dict = None, ): """Apply a function to a sliding window of the input array in parallel. @@ -444,9 +432,6 @@ def __init__( map_kwargs: optional kwargs to be passed to the map function. for example, `map_kwargs={'axis_name': 'i'}` will apply the function along the axis named `i` for `pmap`. - padding_kwargs: optional kwargs to be passed to the padding function. - for example, `padding_kwargs=dict(constant_values=10)` will pad - the input array with 10 for same padding. Returns: A function that takes an array as input and applies the kernel @@ -479,5 +464,4 @@ def __init__( named_axis=named_axis, transform_kind=map_kind, transform_kwargs=map_kwargs, - border_kwargs=padding_kwargs, ) diff --git a/tests/test_interface.py b/tests/test_interface.py index 56f141c..d3a475a 100644 --- a/tests/test_interface.py +++ b/tests/test_interface.py @@ -346,21 +346,3 @@ def kex_conv2d(x, w): pred_grad = jax.grad(lambda w: jnp.sum(kex_conv2d(x, w)))(w) np.testing.assert_allclose(true_grad[0], pred_grad, atol=1e-3) - - -def test_padding_kwargs(): - @kex.kmap( - kernel_size=(3,), - padding=("same"), - relative=False, - padding_kwargs=dict(constant_values=10), - ) - def f(x): - return x - - x = jnp.array([1, 2, 3, 4, 5]) - - np.testing.assert_allclose( - f(x), - np.array([[10, 1, 2], [1, 2, 3], [2, 3, 4], [3, 4, 5], [4, 5, 10]]), - )