Skip to content

Commit

Permalink
more edits
Browse files Browse the repository at this point in the history
  • Loading branch information
ASEM000 committed Aug 1, 2023
1 parent e438def commit e0f25f3
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 46 deletions.
67 changes: 23 additions & 44 deletions serket/nn/convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def infer_key(instance, *_, **__) -> jr.KeyArray:
class BaseConvND(sk.TreeClass):
def __init__(
self,
in_features: int,
in_features: int | None,
out_features: int,
kernel_size: KernelSizeType,
*,
Expand All @@ -190,18 +190,12 @@ def __init__(
):
self.in_features = positive_int_or_none_cb(in_features)
self.out_features = positive_int_cb(out_features)
self.kernel_size = canonicalize(
kernel_size,
self.spatial_ndim,
name="kernel_size",
)
self.strides = canonicalize(strides, self.spatial_ndim, name="strides")
self.kernel_size = canonicalize(kernel_size, self.spatial_ndim, "kernel_size")
self.strides = canonicalize(strides, self.spatial_ndim, "strides")
self.padding = padding
self.dilation = canonicalize(dilation, self.spatial_ndim, name="dilation")

self.dilation = canonicalize(dilation, self.spatial_ndim, "dilation")
self.weight_init = weight_init
self.bias_init = bias_init

self.groups = positive_int_cb(groups)

if in_features is None:
Expand Down Expand Up @@ -676,32 +670,26 @@ def spatial_ndim(self) -> int:
class BaseConvNDTranspose(sk.TreeClass):
def __init__(
self,
in_features: int,
in_features: int | None,
out_features: int,
kernel_size: KernelSizeType,
*,
strides: StridesType = 1,
padding: PaddingType = "same",
output_padding: int = 0,
out_padding: int = 0,
dilation: DilationType = 1,
weight_init: InitType = "glorot_uniform",
bias_init: InitType = "zeros",
groups: int = 1,
key: jr.KeyArray = jr.PRNGKey(0),
):
self.in_features = positive_int_or_none_cb(in_features)
self.in_features = in_features
self.out_features = positive_int_cb(out_features)
self.kernel_size = canonicalize(
kernel_size, self.spatial_ndim, name="kernel_size"
)
self.strides = canonicalize(strides, self.spatial_ndim, name="strides")
self.kernel_size = canonicalize(kernel_size, self.spatial_ndim, "kernel_size")
self.strides = canonicalize(strides, self.spatial_ndim, "strides")
self.padding = padding # delayed canonicalization
self.output_padding = canonicalize(
output_padding,
self.spatial_ndim,
name="output_padding",
)
self.dilation = canonicalize(dilation, self.spatial_ndim, name="dilation")
self.out_padding = canonicalize(out_padding, self.spatial_ndim, "out_padding")
self.dilation = canonicalize(dilation, self.spatial_ndim, "dilation")
self.weight_init = weight_init
self.bias_init = bias_init
self.groups = positive_int_cb(groups)
Expand All @@ -713,6 +701,7 @@ def __init__(
if self.out_features % self.groups != 0:
raise ValueError(f"{(self.out_features % self.groups ==0)=}")

in_features = positive_int_cb(self.in_features)
weight_shape = (out_features, in_features // groups, *self.kernel_size) # OIHW
self.weight = resolve_init_func(self.weight_init)(key, weight_shape)

Expand Down Expand Up @@ -740,7 +729,7 @@ def __call__(self, x: jax.Array, **k) -> jax.Array:

transposed_padding = calculate_transpose_padding(
padding=padding,
extra_padding=self.output_padding,
extra_padding=self.out_padding,
kernel_size=self.kernel_size,
input_dilation=self.dilation,
)
Expand Down Expand Up @@ -983,7 +972,7 @@ def __call__(self, x: jax.Array, **k) -> jax.Array:

transposed_padding = calculate_transpose_padding(
padding=padding,
extra_padding=self.output_padding,
extra_padding=self.out_padding,
kernel_size=self.kernel_size,
input_dilation=self.dilation,
)
Expand Down Expand Up @@ -1214,7 +1203,7 @@ def spatial_ndim(self) -> int:
class BaseDepthwiseConvND(sk.TreeClass):
def __init__(
self,
in_features: int,
in_features: int | None,
kernel_size: KernelSizeType,
*,
depth_multiplier: int = 1,
Expand All @@ -1225,15 +1214,11 @@ def __init__(
key: jr.KeyArray = jr.PRNGKey(0),
):
self.in_features = positive_int_or_none_cb(in_features)
self.kernel_size = canonicalize(
kernel_size,
self.spatial_ndim,
name="kernel_size",
)
self.kernel_size = canonicalize(kernel_size, self.spatial_ndim, "kernel_size")
self.depth_multiplier = positive_int_cb(depth_multiplier)
self.strides = canonicalize(strides, self.spatial_ndim, name="strides")
self.strides = canonicalize(strides, self.spatial_ndim, "strides")
self.padding = padding # delayed canonicalization
self.dilation = canonicalize(1, self.spatial_ndim, name="dilation")
self.dilation = canonicalize(1, self.spatial_ndim, "dilation")
self.weight_init = weight_init
self.bias_init = bias_init

Expand Down Expand Up @@ -1635,7 +1620,7 @@ def spatial_ndim(self) -> int:
class SeparableConvND(sk.TreeClass):
def __init__(
self,
in_features: int,
in_features: int | None,
out_features: int,
kernel_size: KernelSizeType,
*,
Expand Down Expand Up @@ -2144,7 +2129,7 @@ def _depthwise_convolution_layer(self):
class ConvNDLocal(sk.TreeClass):
def __init__(
self,
in_features: int,
in_features: int | None,
out_features: int,
kernel_size: KernelSizeType,
*,
Expand All @@ -2158,19 +2143,13 @@ def __init__(
):
self.in_features = positive_int_or_none_cb(in_features)
self.out_features = positive_int_cb(out_features)
self.kernel_size = canonicalize(
kernel_size,
self.spatial_ndim,
name="kernel_size",
)

self.kernel_size = canonicalize(kernel_size, self.spatial_ndim, "kernel_size")
self.in_size = (
canonicalize(in_size, self.spatial_ndim, name="in_size")
if in_size is not None
else None
)

self.strides = canonicalize(strides, self.spatial_ndim, name="strides")
self.strides = canonicalize(strides, self.spatial_ndim, "strides")

if in_size is None:
self.padding = padding
Expand All @@ -2183,7 +2162,7 @@ def __init__(
self.strides,
)

self.dilation = canonicalize(dilation, self.spatial_ndim, name="dilation")
self.dilation = canonicalize(dilation, self.spatial_ndim, "dilation")
self.weight_init = weight_init
self.bias_init = bias_init

Expand Down
4 changes: 2 additions & 2 deletions serket/nn/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def delayed_canonicalize_padding(
)


def canonicalize(value, ndim, *, name: str | None = None):
def canonicalize(value, ndim, name: str | None = None):
if isinstance(value, int):
return (value,) * ndim
if isinstance(value, jax.Array):
Expand Down Expand Up @@ -275,7 +275,7 @@ def check_spatial_in_shape(x, spatial_ndim: int) -> None:
spatial = {", ".join(("rows", "cols", "depths")[:spatial_ndim])}
raise ValueError(
f"Input should satisfy:\n"
f"- {spatial_ndim+1=} dimension, got {x.ndim=}.\n"
f"- {(spatial_ndim + 1)=} dimension, got {x.ndim=}.\n"
f"- shape of (in_features, {spatial}), got {x.shape=}.\n"
+ (
# maybe the user apply the layer on a batched input
Expand Down

0 comments on commit e0f25f3

Please sign in to comment.