From 9b3070838f6088340dc8c4f9e94ac11b3a448bf6 Mon Sep 17 00:00:00 2001 From: ASEM000 Date: Mon, 28 Aug 2023 20:23:11 +0900 Subject: [PATCH] add missing functional form for conv --- docs/API/convolution.rst | 2 +- docs/{API => _static}/fft_bench.svg | 0 serket/nn/activation.py | 117 ++++++++- serket/nn/clustering.py | 3 +- serket/nn/convolution.py | 392 ++++++++++++++++++++++------ tests/test_clustering.py | 2 - 6 files changed, 415 insertions(+), 101 deletions(-) rename docs/{API => _static}/fft_bench.svg (100%) diff --git a/docs/API/convolution.rst b/docs/API/convolution.rst index 3849e4f..180100d 100644 --- a/docs/API/convolution.rst +++ b/docs/API/convolution.rst @@ -26,7 +26,7 @@ Convolution .. note:: The ``fft`` convolution variant is useful in a myriad of cases; in particular, the ``fft`` variant can be faster for larger kernel sizes. The following figure compares the speed of both implementations for different kernel sizes on a Mac ``m1`` CPU setup. - .. image:: fft_bench.svg + .. image:: ../_static/fft_bench.svg :width: 600 :align: center diff --git a/docs/API/fft_bench.svg b/docs/_static/fft_bench.svg similarity index 100% rename from docs/API/fft_bench.svg rename to docs/_static/fft_bench.svg diff --git a/serket/nn/activation.py b/serket/nn/activation.py index 017ec4e..55a1480 100644 --- a/serket/nn/activation.py +++ b/serket/nn/activation.py @@ -27,6 +27,19 @@ T = TypeVar("T") +def adaptive_leaky_relu( + x: jax.typing.ArrayLike, + a: float = 1.0, + v: float = 1.0, +) -> jax.Array: + """Adaptive Leaky ReLU activation function + + Reference: + https://arxiv.org/pdf/1906.01170.pdf. + """ + return jnp.maximum(0, a * x) - v * jnp.maximum(0, -a * x) + + @sk.autoinit class AdaptiveLeakyReLU(sk.TreeClass): """Leaky ReLU activation function @@ -40,7 +53,16 @@ class AdaptiveLeakyReLU(sk.TreeClass): def __call__(self, x: jax.Array) -> jax.Array: v = jax.lax.stop_gradient(self.v) - return jnp.maximum(0, self.a * x) - v * jnp.maximum(0, -self.a * x) + return adaptive_leaky_relu(x, self.a, v) + + +def adaptive_relu(x: jax.typing.ArrayLike, a: float = 1.0) -> jax.Array: + """Adaptive ReLU activation function + + Reference: + https://arxiv.org/pdf/1906.01170.pdf. + """ + return jnp.maximum(0, a * x) @sk.autoinit @@ -54,7 +76,16 @@ class AdaptiveReLU(sk.TreeClass): a: float = sk.field(default=1.0, callbacks=[Range(0), ScalarLike()]) def __call__(self, x: jax.Array) -> jax.Array: - return jnp.maximum(0, self.a * x) + return adaptive_relu(x, self.a) + + +def adaptive_sigmoid(x: jax.typing.ArrayLike, a: float = 1.0) -> jax.Array: + """Adaptive sigmoid activation function + + Reference: + https://arxiv.org/pdf/1906.01170.pdf. + """ + return 1 / (1 + jnp.exp(-a * x)) @sk.autoinit @@ -68,7 +99,16 @@ class AdaptiveSigmoid(sk.TreeClass): a: float = sk.field(default=1.0, callbacks=[Range(0), ScalarLike()]) def __call__(self, x: jax.Array) -> jax.Array: - return 1 / (1 + jnp.exp(-self.a * x)) + return adaptive_sigmoid(x, self.a) + + +def adaptive_tanh(x: jax.typing.ArrayLike, a: float = 1.0) -> jax.Array: + """Adaptive tanh activation function + + Reference: + https://arxiv.org/pdf/1906.01170.pdf.
+ """ + return (jnp.exp(a * x) - jnp.exp(-a * x)) / (jnp.exp(a * x) + jnp.exp(-a * x)) @sk.autoinit @@ -83,7 +123,7 @@ class AdaptiveTanh(sk.TreeClass): def __call__(self, x: jax.Array) -> jax.Array: a = self.a - return (jnp.exp(a * x) - jnp.exp(-a * x)) / (jnp.exp(a * x) + jnp.exp(-a * x)) + return adaptive_tanh(x, a) @sk.autoinit @@ -222,6 +262,19 @@ def __call__(self, x: jax.Array) -> jax.Array: return x / (1 + jnp.abs(x)) +def softshrink(x: jax.typing.ArrayLike, alpha: float = 0.5) -> jax.Array: + """Soft shrink activation function + + Reference: + https://arxiv.org/pdf/1702.00783.pdf. + """ + return jnp.where( + x < -alpha, + x + alpha, + jnp.where(x > alpha, x - alpha, 0.0), + ) + + @sk.autoinit class SoftShrink(sk.TreeClass): """SoftShrink activation function""" @@ -230,18 +283,23 @@ class SoftShrink(sk.TreeClass): def __call__(self, x: jax.Array) -> jax.Array: alpha = lax.stop_gradient(self.alpha) - return jnp.where( - x < -alpha, - x + alpha, - jnp.where(x > alpha, x - alpha, 0.0), - ) + return softshrink(x, alpha) + + +def squareplus(x: jax.typing.ArrayLike) -> jax.Array: + """SquarePlus activation function + + Reference: + https://arxiv.org/pdf/1908.08681.pdf. + """ + return 0.5 * (x + jnp.sqrt(x * x + 4)) class SquarePlus(sk.TreeClass): """SquarePlus activation function""" def __call__(self, x: jax.Array) -> jax.Array: - return 0.5 * (x + jnp.sqrt(x * x + 4)) + return squareplus(x) class Swish(sk.TreeClass): @@ -265,6 +323,15 @@ def __call__(self, x: jax.Array) -> jax.Array: return x - jax.nn.tanh(x) +def thresholded_relu(x: jax.typing.ArrayLike, theta: float = 1.0) -> jax.Array: + """Thresholded ReLU activation function + + Reference: + https://arxiv.org/pdf/1911.09737.pdf. + """ + return jnp.where(x > theta, x, 0) + + @sk.autoinit class ThresholdedReLU(sk.TreeClass): """Thresholded ReLU activation function.""" @@ -273,14 +340,24 @@ class ThresholdedReLU(sk.TreeClass): def __call__(self, x: jax.Array) -> jax.Array: theta = lax.stop_gradient(self.theta) - return jnp.where(x > theta, x, 0) + return thresholded_relu(x, theta) + + +def mish(x: jax.typing.ArrayLike) -> jax.Array: + """Mish activation function https://arxiv.org/pdf/1908.08681.pdf.""" + return x * jax.nn.tanh(jax.nn.softplus(x)) class Mish(sk.TreeClass): """Mish activation function https://arxiv.org/pdf/1908.08681.pdf.""" def __call__(self, x: jax.Array) -> jax.Array: - return x * jax.nn.tanh(jax.nn.softplus(x)) + return mish(x) + + +def prelu(x: jax.typing.ArrayLike, a: float = 0.25) -> jax.Array: + """Parametric ReLU activation function""" + return jnp.where(x >= 0, x, x * a) @sk.autoinit @@ -290,7 +367,19 @@ class PReLU(sk.TreeClass): a: float = sk.field(default=0.25, callbacks=[Range(0), ScalarLike()]) def __call__(self, x: jax.Array) -> jax.Array: - return jnp.where(x >= 0, x, x * self.a) + return prelu(x, self.a) + + +def snake(x: jax.typing.ArrayLike, a: float = 1.0) -> jax.Array: + """Snake activation function + + Args: + a: scalar (frequency) parameter of the activation function, default is 1.0. + + Reference: + https://arxiv.org/pdf/2006.08195.pdf. 
+ """ + return x + (1 - jnp.cos(2 * a * x)) / (2 * a) @sk.autoinit @@ -308,7 +397,7 @@ class Snake(sk.TreeClass): def __call__(self, x: jax.Array) -> jax.Array: a = lax.stop_gradient(self.a) - return x + (1 - jnp.cos(2 * a * x)) / (2 * a) + return snake(x, a) # useful for building layers from configuration text diff --git a/serket/nn/clustering.py b/serket/nn/clustering.py index d88aaec..9c1299f 100644 --- a/serket/nn/clustering.py +++ b/serket/nn/clustering.py @@ -144,7 +144,7 @@ class KMeans(sk.TreeClass): >>> labels, state = layer(x) >>> plt.scatter(x[:, 0], x[:, 1], c=labels[:, 0], cmap="jet_r") # doctest: +SKIP >>> plt.scatter(state.centers[:, 0], state.centers[:, 1], c="r", marker="o", linewidths=4) # doctest: +SKIP - + .. image:: ../_static/kmeans.svg :width: 600 :align: center @@ -227,7 +227,6 @@ def __call__( ) -> tuple[jax.Array, KMeansState]: distances = distances_from_centers(x, state.centers) labels = labels_from_distances(distances) - state = state._replace(iters=None, error=None) return labels, state diff --git a/serket/nn/convolution.py b/serket/nn/convolution.py index 3fec1f3..e8624a5 100644 --- a/serket/nn/convolution.py +++ b/serket/nn/convolution.py @@ -25,6 +25,7 @@ import jax.numpy as jnp import jax.random as jr from jax.lax import ConvDimensionNumbers +from typing_extensions import Annotated import serket as sk from serket.nn.initialization import DType, InitType, resolve_init @@ -212,6 +213,44 @@ def spatial_ndim(self) -> int: ... +def convolution_ndim( + array: jax.Array, + weight: Annotated[jax.Array, "OIHW"], + bias: jax.Array | None, + strides: tuple[int, ...], + padding: tuple[tuple[int, int], ...], + dilation: tuple[int, ...], + groups: int, +) -> jax.Array: + """Convolution function wrapping ``jax.lax.conv_general_dilated``. + + Args: + array: input array. shape is (in_features, *spatial). + weight: convolutional kernel. shape is (out_features, in_features, *kernel). + bias: bias. shape is (out_features, (1,)*spatial). + strides: stride of the convolution accepts tuple of integers for different + strides in each dimension. + padding: padding of the input before convolution accepts tuple of integers + for different padding in each dimension. + dilation: dilation of the convolutional kernel accepts tuple of integers + for different dilation in each dimension. + groups: number of groups to use for grouped convolution. + """ + x = jax.lax.conv_general_dilated( + lhs=jnp.expand_dims(array, 0), + rhs=weight, + window_strides=strides, + padding=padding, + rhs_dilation=dilation, + dimension_numbers=generate_conv_dim_numbers(array.ndim - 1), + feature_group_count=groups, + ) + + if bias is None: + return jnp.squeeze(x, 0) + return jnp.squeeze((x + bias), 0) + + class ConvND(BaseConvND): @ft.partial(maybe_lazy_call, is_lazy=is_lazy_call, updates=updates) @ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim") @@ -224,20 +263,16 @@ def __call__(self, x: jax.Array) -> jax.Array: strides=self.strides, ) - x = jax.lax.conv_general_dilated( - lhs=jnp.expand_dims(x, 0), - rhs=self.weight, - window_strides=self.strides, + return convolution_ndim( + array=x, + weight=self.weight, + bias=self.bias, + strides=self.strides, padding=padding, - rhs_dilation=self.dilation, - dimension_numbers=generate_conv_dim_numbers(self.spatial_ndim), - feature_group_count=self.groups, + dilation=self.dilation, + groups=self.groups, ) - if self.bias is None: - return jnp.squeeze(x, 0) - return jnp.squeeze((x + self.bias), 0) - class Conv1D(ConvND): """1D Convolutional layer. 
@@ -524,6 +559,43 @@ def spatial_ndim(self) -> int: return 3 +def fft_convolution_ndim( + array: jax.Array, + weight: Annotated[jax.Array, "OIHW"], + bias: jax.Array | None, + strides: tuple[int, ...], + padding: tuple[tuple[int, int], ...], + dilation: tuple[int, ...], + groups: int, +) -> jax.Array: + """Convolution function wrapping ``fft_conv_general_dilated``. + + Args: + array: input array. shape is (in_features, *spatial). + weight: convolutional kernel. shape is (out_features, in_features, *kernel). + bias: bias. shape is (out_features, (1,)*spatial). + strides: stride of the convolution accepts tuple of integers for different + strides in each dimension. + padding: padding of the input before convolution accepts tuple of integers + for different padding in each dimension. + dilation: dilation of the convolutional kernel accepts tuple of integers + for different dilation in each dimension. + groups: number of groups to use for grouped convolution. + """ + x = fft_conv_general_dilated( + lhs=jnp.expand_dims(array, 0), + rhs=weight, + strides=strides, + padding=padding, + dilation=dilation, + groups=groups, + ) + + if bias is None: + return jnp.squeeze(x, 0) + return jnp.squeeze((x + bias), 0) + + class FFTConvND(BaseConvND): @ft.partial(maybe_lazy_call, is_lazy=is_lazy_call, updates=updates) @ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim") @@ -536,19 +608,16 @@ def __call__(self, x: jax.Array) -> jax.Array: strides=self.strides, ) - x = fft_conv_general_dilated( - lhs=jnp.expand_dims(x, 0), - rhs=self.weight, + return fft_convolution_ndim( + array=x, + weight=self.weight, + bias=self.bias, strides=self.strides, padding=padding, - groups=self.groups, dilation=self.dilation, + groups=self.groups, ) - if self.bias is None: - return jnp.squeeze(x, 0) - return jnp.squeeze((x + self.bias), 0) - class FFTConv1D(FFTConvND): """1D Convolutional layer. @@ -881,6 +950,49 @@ def spatial_ndim(self) -> int: ... +def transposed_convolution_ndim( + array: jax.Array, + weight: Annotated[jax.Array, "OIHW"], + bias: jax.Array | None, + strides: tuple[int, ...], + padding: tuple[tuple[int, int], ...], + dilation: tuple[int, ...], + out_padding: int, +) -> jax.Array: + """Transposed convolution function wrapping ``jax.lax.conv_general_dilated``. + + Args: + array: input array. shape is (in_features, *spatial). + weight: convolutional kernel. shape is (out_features, in_features, *kernel). + bias: bias. shape is (out_features, (1,)*spatial). + strides: stride of the convolution accepts tuple of integers for different + strides in each dimension. + padding: padding of the input before convolution accepts tuple of integers + for different padding in each dimension. + dilation: dilation of the convolutional kernel accepts tuple of integers + for different dilation in each dimension. + out_padding: padding of the output after convolution. 
+ """ + transposed_padding = calculate_transpose_padding( + padding=padding, + extra_padding=out_padding, + kernel_size=weight.shape[2:], + input_dilation=dilation, + ) + x = jax.lax.conv_transpose( + lhs=jnp.expand_dims(array, 0), + rhs=weight, + strides=strides, + padding=transposed_padding, + rhs_dilation=dilation, + dimension_numbers=generate_conv_dim_numbers(array.ndim - 1), + ) + + if bias is None: + return jnp.squeeze(x, 0) + return jnp.squeeze(x + bias, 0) + + class ConvNDTranspose(BaseConvNDTranspose): @ft.partial(maybe_lazy_call, is_lazy=is_lazy_call, updates=updates) @ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim") @@ -893,26 +1005,16 @@ def __call__(self, x: jax.Array) -> jax.Array: strides=self.strides, ) - transposed_padding = calculate_transpose_padding( - padding=padding, - extra_padding=self.out_padding, - kernel_size=self.kernel_size, - input_dilation=self.dilation, - ) - - x = jax.lax.conv_transpose( - lhs=jnp.expand_dims(x, 0), - rhs=self.weight, + return transposed_convolution_ndim( + array=x, + weight=self.weight, + bias=self.bias, strides=self.strides, - padding=transposed_padding, - rhs_dilation=self.dilation, - dimension_numbers=generate_conv_dim_numbers(self.spatial_ndim), + padding=padding, + dilation=self.dilation, + out_padding=self.out_padding, ) - if self.bias is None: - return jnp.squeeze(x, 0) - return jnp.squeeze(x + self.bias, 0) - class Conv1DTranspose(ConvNDTranspose): """1D Convolution transpose layer. @@ -1211,6 +1313,49 @@ def spatial_ndim(self) -> int: return 3 +def transposed_fft_convolution_ndim( + array: jax.Array, + weight: Annotated[jax.Array, "OIHW"], + bias: jax.Array | None, + strides: tuple[int, ...], + padding: tuple[tuple[int, int], ...], + dilation: tuple[int, ...], + out_padding: int, +) -> jax.Array: + """Transposed convolution function wrapping ``fft_conv_general_dilated``. + + Args: + array: input array. shape is (in_features, *spatial). + weight: convolutional kernel. shape is (out_features, in_features, *kernel). + bias: bias. shape is (out_features, (1,)*spatial). + strides: stride of the convolution accepts tuple of integers for different + strides in each dimension. + padding: padding of the input before convolution accepts tuple of integers + for different padding in each dimension. + dilation: dilation of the convolutional kernel accepts tuple of integers + for different dilation in each dimension. + out_padding: padding of the output after convolution. 
+ """ + transposed_padding = calculate_transpose_padding( + padding=padding, + extra_padding=out_padding, + kernel_size=weight.shape[2:], + input_dilation=dilation, + ) + x = fft_conv_general_dilated( + lhs=jnp.expand_dims(array, 0), + rhs=weight, + strides=strides, + padding=transposed_padding, + dilation=dilation, + groups=1, + ) + + if bias is None: + return jnp.squeeze(x, 0) + return jnp.squeeze(x + bias, 0) + + class FFTConvNDTranspose(BaseConvNDTranspose): @ft.partial(maybe_lazy_call, is_lazy=is_lazy_call, updates=updates) @ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim") @@ -1223,24 +1368,15 @@ def __call__(self, x: jax.Array) -> jax.Array: strides=self.strides, ) - transposed_padding = calculate_transpose_padding( - padding=padding, - extra_padding=self.out_padding, - kernel_size=self.kernel_size, - input_dilation=self.dilation, - ) - - x = fft_conv_general_dilated( - lhs=jnp.expand_dims(x, 0), - rhs=self.weight, + return transposed_fft_convolution_ndim( + array=x, + weight=self.weight, + bias=self.bias, strides=self.strides, - padding=transposed_padding, + padding=padding, dilation=self.dilation, - groups=1, + out_padding=self.out_padding, ) - if self.bias is None: - return jnp.squeeze(x, 0) - return jnp.squeeze(x + self.bias, 0) class FFTConv1DTranspose(FFTConvNDTranspose): @@ -1576,6 +1712,40 @@ def spatial_ndim(self) -> int: ... +def depthwise_convolution_ndim( + array: jax.Array, + weight: Annotated[jax.Array, "OIHW"], + bias: jax.Array | None, + strides: tuple[int, ...], + padding: tuple[tuple[int, int], ...], +) -> jax.Array: + """Depthwise convolution function wrapping ``jax.lax.conv_general_dilated``. + + Args: + array: input array. shape is (in_features, *spatial). + weight: convolutional kernel. shape is (out_features, in_features, *kernel). + bias: bias. shape is (out_features, (1,)*spatial). + strides: stride of the convolution accepts tuple of integers for different + strides in each dimension. + padding: padding of the input before convolution accepts tuple of integers + for different padding in each dimension. + """ + + x = jax.lax.conv_general_dilated( + lhs=jnp.expand_dims(array, 0), + rhs=weight, + window_strides=strides, + padding=padding, + rhs_dilation=(1,) * (array.ndim - 1), + dimension_numbers=generate_conv_dim_numbers(array.ndim - 1), + feature_group_count=array.shape[0], # in_features + ) + + if bias is None: + return jnp.squeeze(x, 0) + return jnp.squeeze(x + bias, 0) + + class DepthwiseConvND(BaseDepthwiseConvND): @ft.partial(maybe_lazy_call, is_lazy=is_lazy_call, updates=updates) @ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim") @@ -1588,20 +1758,14 @@ def __call__(self, x: jax.Array) -> jax.Array: strides=self.strides, ) - x = jax.lax.conv_general_dilated( - lhs=jnp.expand_dims(x, 0), - rhs=self.weight, - window_strides=self.strides, + return depthwise_convolution_ndim( + array=x, + weight=self.weight, + bias=self.bias, + strides=self.strides, padding=padding, - rhs_dilation=canonicalize(1, self.spatial_ndim, "dilation"), - dimension_numbers=generate_conv_dim_numbers(self.spatial_ndim), - feature_group_count=self.in_features, ) - if self.bias is None: - return jnp.squeeze(x, 0) - return jnp.squeeze((x + self.bias), 0) - class DepthwiseConv1D(DepthwiseConvND): """1D Depthwise convolution layer. 
@@ -1850,6 +2014,39 @@ def spatial_ndim(self) -> int: return 3 + +def depthwise_fft_convolution_ndim( + array: jax.Array, + weight: Annotated[jax.Array, "OIHW"], + bias: jax.Array | None, + strides: tuple[int, ...], + padding: tuple[tuple[int, int], ...], +) -> jax.Array: + """Depthwise convolution function wrapping ``fft_conv_general_dilated``. + + Args: + array: input array. shape is (in_features, *spatial). + weight: convolutional kernel. shape is (out_features, in_features, *kernel). + bias: bias. shape is (out_features, (1,)*spatial). + strides: stride of the convolution accepts tuple of integers for different + strides in each dimension. + padding: padding of the input before convolution accepts tuple of integers + for different padding in each dimension. + """ + + x = fft_conv_general_dilated( + lhs=jnp.expand_dims(array, 0), + rhs=weight, + strides=strides, + padding=padding, + dilation=(1,) * (array.ndim - 1), + groups=array.shape[0], # in_features + ) + + if bias is None: + return jnp.squeeze(x, 0) + return jnp.squeeze(x + bias, 0) + + class DepthwiseFFTConvND(BaseDepthwiseConvND): @ft.partial(maybe_lazy_call, is_lazy=is_lazy_call, updates=updates) @ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim") @@ -1862,19 +2059,14 @@ def __call__(self, x: jax.Array) -> jax.Array: strides=self.strides, ) - x = fft_conv_general_dilated( - lhs=jnp.expand_dims(x, 0), - rhs=self.weight, + return depthwise_fft_convolution_ndim( + array=x, + weight=self.weight, + bias=self.bias, strides=self.strides, padding=padding, - dilation=canonicalize(1, self.spatial_ndim, "dilation"), - groups=self.in_features, ) - if self.bias is None: - return jnp.squeeze(x, 0) - return jnp.squeeze((x + self.bias), 0) - class DepthwiseFFTConv1D(DepthwiseFFTConvND): """1D Depthwise FFT convolution layer. @@ -2770,6 +2962,46 @@ def infer_in_size(_, x, *__, **___) -> tuple[int, ...]: updates = {**dict(in_size=infer_in_size), **updates} +def local_convolution_ndim( + array: jax.Array, + weight: Annotated[jax.Array, "OIHW"], + bias: jax.Array | None, + strides: tuple[int, ...], + padding: tuple[tuple[int, int], ...], + dilation: tuple[int, ...], + kernel_size: tuple[int, ...], +) -> jax.Array: + """Local convolution function wrapping ``jax.lax.conv_general_dilated_local``. + + Args: + array: input array. shape is (in_features, *spatial). + weight: convolutional kernel. shape is (out_features, in_features, *kernel). + bias: bias. shape is (out_features, (1,)*spatial). + strides: stride of the convolution accepts tuple of integers for different + strides in each dimension. + padding: padding of the input before convolution accepts tuple of integers + for different padding in each dimension. + dilation: dilation of the convolution accepts tuple of integers for different + dilation in each dimension. + kernel_size: size of the convolutional kernel accepts tuple of integers for + different kernel sizes in each dimension.
+ """ + + x = jax.lax.conv_general_dilated_local( + lhs=jnp.expand_dims(array, 0), + rhs=weight, + window_strides=strides, + padding=padding, + filter_shape=kernel_size, + rhs_dilation=dilation, + dimension_numbers=generate_conv_dim_numbers(array.ndim - 1), + ) + + if bias is None: + return jnp.squeeze(x, 0) + return jnp.squeeze(x + bias, 0) + + class ConvNDLocal(sk.TreeClass): @ft.partial(maybe_lazy_init, is_lazy=is_lazy_init) def __init__( @@ -2827,20 +3059,16 @@ def __init__( @ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim") @ft.partial(validate_axis_shape, attribute_name="in_features", axis=0) def __call__(self, x: jax.Array) -> jax.Array: - x = jax.lax.conv_general_dilated_local( - lhs=jnp.expand_dims(x, 0), - rhs=self.weight, - window_strides=self.strides, + return local_convolution_ndim( + array=x, + weight=self.weight, + bias=self.bias, + strides=self.strides, padding=self.padding, - filter_shape=self.kernel_size, - rhs_dilation=self.dilation, - dimension_numbers=generate_conv_dim_numbers(self.spatial_ndim), + dilation=self.dilation, + kernel_size=self.kernel_size, ) - if self.bias is None: - return jnp.squeeze(x, 0) - return jnp.squeeze((x + self.bias), 0) - @property @abc.abstractmethod def spatial_ndim(self) -> int: diff --git a/tests/test_clustering.py b/tests/test_clustering.py index 1177661..90611bf 100644 --- a/tests/test_clustering.py +++ b/tests/test_clustering.py @@ -42,8 +42,6 @@ def test_kmeans(): # pick a point near one of the centers xx = jnp.array([[0.5, 0.2]]) labels, eval_state = sk.tree_eval(layer)(xx, state) - assert eval_state.iters is None - assert eval_state.error is None # centers should not change npt.assert_allclose(state.centers, eval_state.centers, atol=1e-6) assert labels[0] == 0