Skip to content

Commit

Permalink
ND subclassing
Browse files Browse the repository at this point in the history
  • Loading branch information
ASEM000 committed Jul 25, 2023
1 parent 99ac525 commit 1b2d206
Show file tree
Hide file tree
Showing 6 changed files with 303 additions and 457 deletions.
211 changes: 85 additions & 126 deletions serket/nn/convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -864,7 +864,65 @@ def spatial_ndim(self) -> int:
return 3


class SeparableConv1D(sk.TreeClass):
class SeparableConvND(sk.TreeClass):
def __init__(
self,
in_features: int,
out_features: int,
kernel_size: KernelSizeType,
*,
depth_multiplier: int = 1,
strides: StridesType = 1,
padding: PaddingType = "SAME",
depthwise_weight_init_func: InitType = "glorot_uniform",
pointwise_weight_init_func: InitType = "glorot_uniform",
pointwise_bias_init_func: InitType = "zeros",
key: jr.KeyArray = jr.PRNGKey(0),
):
self.depthwise_conv = self.depthwise_convolution_layer(
in_features=in_features,
depth_multiplier=depth_multiplier,
kernel_size=kernel_size,
strides=strides,
padding=padding,
weight_init_func=depthwise_weight_init_func,
bias_init_func=None, # no bias for lhs
key=key,
)

self.pointwise_conv = self.pointwise_convolution_layer(
in_features=in_features * depth_multiplier,
out_features=out_features,
kernel_size=1,
strides=strides,
padding=padding,
weight_init_func=pointwise_weight_init_func,
bias_init_func=pointwise_bias_init_func,
key=key,
)

def __call__(self, x: jax.Array, **k) -> jax.Array:
x = self.depthwise_conv(x)
x = self.pointwise_conv(x)
return x

@property
@abc.abstractmethod
def spatial_ndim(self) -> int:
...

@property
@abc.abstractmethod
def pointwise_convolution_layer(self):
...

@property
@abc.abstractmethod
def depthwise_convolution_layer(self):
...


class SeparableConv1D(SeparableConvND):
"""1D Separable convolution layer.
Separable convolution is a depthwise convolution followed by a pointwise
Expand Down Expand Up @@ -924,53 +982,20 @@ class SeparableConv1D(sk.TreeClass):
- https://github.com/google/flax/blob/main/flax/linen/linear.py
"""

def __init__(
self,
in_features: int,
out_features: int,
kernel_size: KernelSizeType,
*,
depth_multiplier: int = 1,
strides: StridesType = 1,
padding: PaddingType = "SAME",
depthwise_weight_init_func: InitType = "glorot_uniform",
pointwise_weight_init_func: InitType = "glorot_uniform",
pointwise_bias_init_func: InitType = "zeros",
key: jr.KeyArray = jr.PRNGKey(0),
):
self.depthwise_conv = DepthwiseConv1D(
in_features=in_features,
depth_multiplier=depth_multiplier,
kernel_size=kernel_size,
strides=strides,
padding=padding,
weight_init_func=depthwise_weight_init_func,
bias_init_func=None, # no bias for lhs
key=key,
)

self.pointwise_conv = Conv1D(
in_features=in_features * depth_multiplier,
out_features=out_features,
kernel_size=1,
strides=strides,
padding=padding,
weight_init_func=pointwise_weight_init_func,
bias_init_func=pointwise_bias_init_func,
key=key,
)

def __call__(self, x: jax.Array, **k) -> jax.Array:
x = self.depthwise_conv(x)
x = self.pointwise_conv(x)
return x

@property
def spatial_ndim(self) -> int:
return 1

@property
def pointwise_convolution_layer(self):
return Conv1D

@property
def depthwise_convolution_layer(self):
return DepthwiseConv1D


class SeparableConv2D(sk.TreeClass):
class SeparableConv2D(SeparableConvND):
"""2D Separable convolution layer.
Separable convolution is a depthwise convolution followed by a pointwise
Expand Down Expand Up @@ -1029,53 +1054,20 @@ class SeparableConv2D(sk.TreeClass):
- https://github.com/google/flax/blob/main/flax/linen/linear.py
"""

def __init__(
self,
in_features: int,
out_features: int,
kernel_size: KernelSizeType,
*,
depth_multiplier: int = 1,
strides: StridesType = 1,
padding: PaddingType = "SAME",
depthwise_weight_init_func: InitType = "glorot_uniform",
pointwise_weight_init_func: InitType = "glorot_uniform",
pointwise_bias_init_func: InitType = "zeros",
key: jr.KeyArray = jr.PRNGKey(0),
):
self.depthwise_conv = DepthwiseConv2D(
in_features=in_features,
depth_multiplier=depth_multiplier,
kernel_size=kernel_size,
strides=strides,
padding=padding,
weight_init_func=depthwise_weight_init_func,
bias_init_func=None, # no bias for lhs
key=key,
)

self.pointwise_conv = Conv2D(
in_features=in_features * depth_multiplier,
out_features=out_features,
kernel_size=1,
strides=strides,
padding=padding,
weight_init_func=pointwise_weight_init_func,
bias_init_func=pointwise_bias_init_func,
key=key,
)

def __call__(self, x: jax.Array, **k) -> jax.Array:
x = self.depthwise_conv(x)
x = self.pointwise_conv(x)
return x

@property
def spatial_ndim(self) -> int:
return 2

@property
def pointwise_convolution_layer(self):
return Conv2D

class SeparableConv3D(sk.TreeClass):
@property
def depthwise_convolution_layer(self):
return DepthwiseConv2D


class SeparableConv3D(SeparableConvND):
"""3D Separable convolution layer.
Separable convolution is a depthwise convolution followed by a pointwise
Expand Down Expand Up @@ -1134,51 +1126,18 @@ class SeparableConv3D(sk.TreeClass):
- https://github.com/google/flax/blob/main/flax/linen/linear.py
"""

def __init__(
self,
in_features: int,
out_features: int,
kernel_size: KernelSizeType,
*,
depth_multiplier: int = 1,
strides: StridesType = 1,
padding: PaddingType = "SAME",
depthwise_weight_init_func: InitType = "glorot_uniform",
pointwise_weight_init_func: InitType = "glorot_uniform",
pointwise_bias_init_func: InitType = "zeros",
key: jr.KeyArray = jr.PRNGKey(0),
):
self.depthwise_conv = DepthwiseConv3D(
in_features=in_features,
depth_multiplier=depth_multiplier,
kernel_size=kernel_size,
strides=strides,
padding=padding,
weight_init_func=depthwise_weight_init_func,
bias_init_func=None, # no bias for lhs
key=key,
)

self.pointwise_conv = Conv3D(
in_features=in_features * depth_multiplier,
out_features=out_features,
kernel_size=1,
strides=strides,
padding=padding,
weight_init_func=pointwise_weight_init_func,
bias_init_func=pointwise_bias_init_func,
key=key,
)

def __call__(self, x: jax.Array, **k) -> jax.Array:
x = self.depthwise_conv(x)
x = self.pointwise_conv(x)
return x

@property
def spatial_ndim(self) -> int:
return 3

@property
def pointwise_convolution_layer(self):
return Conv3D

@property
def depthwise_convolution_layer(self):
return DepthwiseConv3D


class ConvNDLocal(sk.TreeClass):
def __init__(
Expand Down
127 changes: 76 additions & 51 deletions serket/nn/fft_convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -951,7 +951,74 @@ def spatial_ndim(self) -> int:
return 3


class SeparableFFTConv1D(sk.TreeClass):
class SeparableFFTConvND(sk.TreeClass):
def __init__(
self,
in_features: int,
out_features: int,
kernel_size: KernelSizeType,
*,
depth_multiplier: int = 1,
strides: StridesType = 1,
padding: PaddingType = "SAME",
depthwise_weight_init_func: InitType = "glorot_uniform",
pointwise_weight_init_func: InitType = "glorot_uniform",
pointwise_bias_init_func: InitType = "zeros",
key: jr.KeyArray = jr.PRNGKey(0),
):
self.in_features = in_features
self.depth_multiplier = canonicalize(
depth_multiplier,
self.in_features,
name="depth_multiplier",
)

self.depthwise_conv = DepthwiseFFTConv1D(
in_features=in_features,
depth_multiplier=depth_multiplier,
kernel_size=kernel_size,
strides=strides,
padding=padding,
weight_init_func=depthwise_weight_init_func,
bias_init_func=None, # no bias for lhs
key=key,
)

self.pointwise_conv = FFTConv1D(
in_features=in_features * depth_multiplier,
out_features=out_features,
kernel_size=1,
strides=strides,
padding=padding,
weight_init_func=pointwise_weight_init_func,
bias_init_func=pointwise_bias_init_func,
key=key,
)

@ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim")
@ft.partial(validate_axis_shape, attribute_name="in_features", axis=0)
def __call__(self, x: jax.Array, **k) -> jax.Array:
x = self.depthwise_conv(x)
x = self.pointwise_conv(x)
return x

@property
@abc.abstractclassmethod
def spatial_ndim(self) -> int:
...

@property
@abc.abstractclassmethod
def pointwise_convolution_layer(self):
...

@property
@abc.abstractclassmethod
def depthwise_convolution_layer(self):
...


class SeparableFFTConv1D(SeparableFFTConvND):
"""1D Separable FFT convolution layer.
Separable convolution is a depthwise convolution followed by a pointwise
Expand Down Expand Up @@ -1011,60 +1078,18 @@ class SeparableFFTConv1D(sk.TreeClass):
- https://github.com/google/flax/blob/main/flax/linen/linear.py
"""

def __init__(
self,
in_features: int,
out_features: int,
kernel_size: KernelSizeType,
*,
depth_multiplier: int = 1,
strides: StridesType = 1,
padding: PaddingType = "SAME",
depthwise_weight_init_func: InitType = "glorot_uniform",
pointwise_weight_init_func: InitType = "glorot_uniform",
pointwise_bias_init_func: InitType = "zeros",
key: jr.KeyArray = jr.PRNGKey(0),
):
self.in_features = in_features
self.depth_multiplier = canonicalize(
depth_multiplier,
self.in_features,
name="depth_multiplier",
)

self.depthwise_conv = DepthwiseFFTConv1D(
in_features=in_features,
depth_multiplier=depth_multiplier,
kernel_size=kernel_size,
strides=strides,
padding=padding,
weight_init_func=depthwise_weight_init_func,
bias_init_func=None, # no bias for lhs
key=key,
)

self.pointwise_conv = FFTConv1D(
in_features=in_features * depth_multiplier,
out_features=out_features,
kernel_size=1,
strides=strides,
padding=padding,
weight_init_func=pointwise_weight_init_func,
bias_init_func=pointwise_bias_init_func,
key=key,
)

@ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim")
@ft.partial(validate_axis_shape, attribute_name="in_features", axis=0)
def __call__(self, x: jax.Array, **k) -> jax.Array:
x = self.depthwise_conv(x)
x = self.pointwise_conv(x)
return x

@property
def spatial_ndim(self) -> int:
return 1

@property
def pointwise_convolution_layer(self):
return FFTConv1D

@property
def depthwise_convolution_layer(self):
return DepthwiseFFTConv1D


class SeparableFFTConv2D(sk.TreeClass):
"""2D Separable FFT convolution layer.
Expand Down
Loading

0 comments on commit 1b2d206

Please sign in to comment.