From 584cf53e5d35567a19a61ec69bb5e7c4e558625f Mon Sep 17 00:00:00 2001 From: ASEM000 Date: Mon, 28 Aug 2023 21:02:14 +0900 Subject: [PATCH] remove blocks --- CHANEGLOG.md | 1 + docs/API/api.rst | 1 - docs/API/containers.rst | 3 +- docs/API/misc.rst | 9 -- serket/nn/__init__.py | 6 - serket/nn/blocks/__init__.py | 18 --- serket/nn/blocks/unet.py | 176 ---------------------------- serket/nn/blocks/vgg.py | 214 ----------------------------------- serket/nn/containers.py | 4 +- tests/test_blocks.py | 52 --------- 10 files changed, 5 insertions(+), 479 deletions(-) delete mode 100644 docs/API/misc.rst delete mode 100644 serket/nn/blocks/__init__.py delete mode 100644 serket/nn/blocks/unet.py delete mode 100644 serket/nn/blocks/vgg.py delete mode 100644 tests/test_blocks.py diff --git a/CHANEGLOG.md b/CHANEGLOG.md index ae720e4..3ee5e26 100644 --- a/CHANEGLOG.md +++ b/CHANEGLOG.md @@ -55,3 +55,4 @@ - `Bilinear` is deprecated, use `Multilinear((in1_features, in2_features), out_features)` - `HistogramEqualization2D` +- Remove `.blocks`, and will move it to examples \ No newline at end of file diff --git a/docs/API/api.rst b/docs/API/api.rst index ec98a6f..79ffbb7 100644 --- a/docs/API/api.rst +++ b/docs/API/api.rst @@ -24,6 +24,5 @@ pooling recurrent reshaping - misc diff --git a/docs/API/containers.rst b/docs/API/containers.rst index 9474069..4350f8c 100644 --- a/docs/API/containers.rst +++ b/docs/API/containers.rst @@ -4,4 +4,5 @@ Containers .. autoclass:: Sequential -.. autoclass:: RandomApply \ No newline at end of file +.. autoclass:: RandomApply +.. autoclass:: RandomChoice \ No newline at end of file diff --git a/docs/API/misc.rst b/docs/API/misc.rst deleted file mode 100644 index 995e821..0000000 --- a/docs/API/misc.rst +++ /dev/null @@ -1,9 +0,0 @@ -Misc ---------------------------------- -.. currentmodule:: serket.nn - -.. autoclass:: VGG16Block -.. autoclass:: VGG19Block -.. autoclass:: UNetBlock - - diff --git a/serket/nn/__init__.py b/serket/nn/__init__.py index 095fba6..aa98357 100644 --- a/serket/nn/__init__.py +++ b/serket/nn/__init__.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from . import blocks from .activation import ( ELU, GELU, @@ -46,7 +45,6 @@ ThresholdedReLU, ) from .attention import MultiHeadAttention -from .blocks import UNetBlock, VGG16Block, VGG19Block from .clustering import KMeans from .containers import RandomApply, RandomChoice, Sequential from .convolution import ( @@ -214,10 +212,6 @@ "ThresholdedReLU", # attention "MultiHeadAttention", - # blocks - "UNetBlock", - "VGG16Block", - "VGG19Block", # container "RandomApply", "RandomChoice", diff --git a/serket/nn/blocks/__init__.py b/serket/nn/blocks/__init__.py deleted file mode 100644 index c950706..0000000 --- a/serket/nn/blocks/__init__.py +++ /dev/null @@ -1,18 +0,0 @@ -# Copyright 2023 Serket authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from .unet import UNetBlock -from .vgg import VGG16Block, VGG19Block - -__all__ = ["VGG16Block", "VGG19Block", "UNetBlock"] diff --git a/serket/nn/blocks/unet.py b/serket/nn/blocks/unet.py deleted file mode 100644 index 02a0d95..0000000 --- a/serket/nn/blocks/unet.py +++ /dev/null @@ -1,176 +0,0 @@ -# Copyright 2023 Serket authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# see : https://pytorch.org/hub/mateuszbuda_brain-segmentation-pytorch_unet/ -# current implementation is based on the above link - - -from __future__ import annotations - -import jax -import jax.numpy as jnp - -import serket as sk -from serket.nn.utils import positive_int_cb - - -class ResizeAndCat(sk.TreeClass): - def __call__(self, x1: jax.Array, x2: jax.Array) -> jax.Array: - """resize a tensor to the same size as another tensor and concatenate x2 to x1 along the channel axis""" - x1 = jax.image.resize(x1, shape=x2.shape, method="nearest") - x1 = jnp.concatenate([x2, x1], axis=0) - return x1 - - -class DoubleConvBlock(sk.TreeClass): - def __init__(self, in_features: int, out_features: int): - self.conv1 = sk.nn.Conv2D( - in_features=in_features, - out_features=out_features, - kernel_size=3, - padding=1, - bias_init=None, - ) - self.conv2 = sk.nn.Conv2D( - in_features=out_features, - out_features=out_features, - kernel_size=3, - padding=1, - bias_init=None, - ) - - def __call__(self, x: jax.Array, **k) -> jax.Array: - x = self.conv1(x) - x = jax.nn.relu(x) - x = self.conv2(x) - x = jax.nn.relu(x) - return x - - -class UpscaleBlock(sk.TreeClass): - def __init__(self, in_features: int, out_features: int): - self.conv = sk.nn.Conv2DTranspose( - in_features=in_features, - out_features=out_features, - kernel_size=2, - strides=2, - ) - - def __call__(self, x: jax.Array, **k) -> jax.Array: - return self.conv(x) - - -@sk.autoinit -class UNetBlock(sk.TreeClass): - """Vanilla UNet - - Args: - in_features : number of input channels. This is the number of channels in the input image. - out_features : number of output channels. This is the number of classes - blocks : number of blocks in the UNet architecture . Default is 4 - init_features : number of features in the first block. Default is 64 - """ - - in_features: int = sk.field(callbacks=[positive_int_cb]) - out_features: int = sk.field(callbacks=[positive_int_cb]) - blocks: int = sk.field(callbacks=[positive_int_cb], default=4) - init_features: int = sk.field(callbacks=[positive_int_cb], default=64) - - def __post_init__(self): - """ - Note: - d0_1 : - block_number = 0 , operation = (conv->relu) x2 - d0_2 : - block_number = 0 , operation = maxpool previous output - u0_1 : - expansive block corresponding to block 0 in contractive path , - operation = doubling row,col size and halving channels size of previous layer - u0_2 : - expansive block corresponding to block 0 in contractive path , - operation = pad the previous layer from expansive path (u0_1) and concatenate with corresponding - layer from contractive path (d0_1) - u0_3 : - expansive block corresponding to block 0 in contractive path , - operation = (conv->relu) x2 of previous layer (u0_2) - b0_1 : - bottleneck layer - f0_1 : - final output layer - - """ - self.d0_1 = DoubleConvBlock(self.in_features, self.init_features) - self.d0_2 = sk.nn.MaxPool2D(kernel_size=2, strides=2) - - for i in range(1, self.blocks): - in_dim = self.init_features * (2 ** (i - 1)) - out_dim = self.init_features * (2**i) - - layer = DoubleConvBlock(in_dim, out_dim) - setattr(self, f"d{i}_1", layer) - setattr(self, f"d{i}_2", sk.nn.MaxPool2D(kernel_size=2, strides=2)) - - self.b0_1 = DoubleConvBlock( - self.init_features * (2 ** (self.blocks - 1)), - self.init_features * (2 ** (self.blocks)), - ) - - for i in range(self.blocks, 0, -1): - # upscale and conv to halves channels size and double row,col size - in_dim = self.init_features * (2 ** (i - 1)) - out_dim = self.init_features * (2**i) - - layer = UpscaleBlock(out_dim, in_dim) - setattr(self, f"u{i-1}_1", layer) - - layer = ResizeAndCat() - setattr(self, f"u{i-1}_2", layer) - - layer = DoubleConvBlock(out_dim, in_dim) - setattr(self, f"u{i-1}_3", layer) - - self.f0_1 = sk.nn.Conv2D(self.init_features, self.out_features, kernel_size=1) - - def __call__(self, x: jax.Array, **k) -> jax.Array: - # TODO: fix to not record intermediate results - result = dict() - blocks = self.blocks - - # contractive path - result["d0_1"] = self.d0_1(x) - result["d0_2"] = self.d0_2(result["d0_1"]) - - for i in range(1, blocks): - result[f"d{i}_1"] = getattr(self, f"d{i}_1")(result[f"d{i-1}_2"]) - result[f"d{i}_2"] = getattr(self, f"d{i}_2")(result[f"d{i}_1"]) - - result["b0_1"] = self.b0_1(result[f"d{blocks-1}_2"]) - - result[f"u{blocks-1}_1"] = getattr(self, f"u{blocks-1}_1")(result["b0_1"]) - lhs_key, rhs_key = f"u{blocks-1}_1", f"d{blocks-1}_1" - result[f"u{blocks-1}_2"] = getattr(self, f"u{blocks-1}_2")( - result[lhs_key], result[rhs_key] - ) - result[f"u{blocks-1}_3"] = getattr(self, f"u{blocks-1}_3")( - result[f"u{blocks-1}_2"] - ) - - for i in range(blocks - 1, 0, -1): - result[f"u{i-1}_1"] = getattr(self, f"u{i-1}_1")(result[f"u{i}_3"]) - result[f"u{i-1}_2"] = getattr(self, f"u{blocks-1}_2")( - result[f"u{i-1}_1"], result[f"d{i-1}_1"] - ) - result[f"u{i-1}_3"] = getattr(self, f"u{i-1}_3")(result[f"u{i-1}_2"]) - - return self.f0_1(result["u0_3"]) diff --git a/serket/nn/blocks/vgg.py b/serket/nn/blocks/vgg.py deleted file mode 100644 index f1661b2..0000000 --- a/serket/nn/blocks/vgg.py +++ /dev/null @@ -1,214 +0,0 @@ -# Copyright 2023 Serket authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -import jax -import jax.random as jr - -import serket as sk - - -class VGG16Block(sk.TreeClass): - def __init__( - self, - in_features: int, - *, - pooling: str = "max", - key: jr.KeyArray = jr.PRNGKey(0), - ): - """ - Args: - in_features: number of input features - pooling: pooling method to use. GlobalMaxPool2D(`max`) or GlobalAvgPool2D(`avg`). - - Note: - if num_classes is None, then the classifier is not added. - see: - https://github.com/keras-team/keras/blob/v2.10.0/keras/applications/vgg16.py - https://arxiv.org/abs/1409.1556 - - """ - keys = jr.split(key, 13) - - # block 1 - self.conv_1_1 = sk.nn.Conv2D(in_features, 64, 3, padding="same", key=keys[0]) - self.conv_1_2 = sk.nn.Conv2D(64, 64, 3, padding="same", key=keys[1]) - self.maxpool_1 = sk.nn.MaxPool2D(2, strides=2) - - # block 2 - self.conv_2_1 = sk.nn.Conv2D(64, 128, 3, padding="same", key=keys[2]) - self.conv_2_2 = sk.nn.Conv2D(128, 128, 3, padding="same", key=keys[3]) - self.maxpool_2 = sk.nn.MaxPool2D(2, strides=2) - - # block 3 - self.conv_3_1 = sk.nn.Conv2D(128, 256, 3, padding="same", key=keys[4]) - self.conv_3_2 = sk.nn.Conv2D(256, 256, 3, padding="same", key=keys[5]) - self.conv_3_3 = sk.nn.Conv2D(256, 256, 3, padding="same", key=keys[6]) - self.maxpool_3 = sk.nn.MaxPool2D(2, strides=2) - - # block 4 - self.conv_4_1 = sk.nn.Conv2D(256, 512, 3, padding="same", key=keys[7]) - self.conv_4_2 = sk.nn.Conv2D(512, 512, 3, padding="same", key=keys[8]) - self.conv_4_3 = sk.nn.Conv2D(512, 512, 3, padding="same", key=keys[9]) - self.maxpool_4 = sk.nn.MaxPool2D(2, strides=2) - - # block 5 - self.conv_5_1 = sk.nn.Conv2D(512, 512, 3, padding="same", key=keys[10]) - self.conv_5_2 = sk.nn.Conv2D(512, 512, 3, padding="same", key=keys[11]) - self.conv_5_3 = sk.nn.Conv2D(512, 512, 3, padding="same", key=keys[12]) - self.maxpool_5 = sk.nn.MaxPool2D(2, strides=2) - - self.pooling = sk.nn.GlobalMaxPool2D() if pooling == "max" else sk.nn.GlobalAvgPool2D() # fmt: skip - - def __call__(self, x: jax.Array, **kwargs) -> jax.Array: - x = self.conv_1_1(x) - x = jax.nn.relu(x) - x = self.conv_1_2(x) - x = jax.nn.relu(x) - x = self.maxpool_1(x) - - x = self.conv_2_1(x) - x = jax.nn.relu(x) - x = self.conv_2_2(x) - x = jax.nn.relu(x) - x = self.maxpool_2(x) - - x = self.conv_3_1(x) - x = jax.nn.relu(x) - x = self.conv_3_2(x) - x = jax.nn.relu(x) - x = self.conv_3_3(x) - x = jax.nn.relu(x) - x = self.maxpool_3(x) - - x = self.conv_4_1(x) - x = jax.nn.relu(x) - x = self.conv_4_2(x) - x = jax.nn.relu(x) - x = self.conv_4_3(x) - x = jax.nn.relu(x) - x = self.maxpool_4(x) - - x = self.conv_5_1(x) - x = jax.nn.relu(x) - x = self.conv_5_2(x) - x = jax.nn.relu(x) - x = self.conv_5_3(x) - x = jax.nn.relu(x) - x = self.maxpool_5(x) - x = self.pooling(x) - return x - - -class VGG19Block(sk.TreeClass): - def __init__( - self, - in_feautres: int, - *, - pooling: str = "max", - key: jr.KeyArray = jr.PRNGKey(0), - ): - """ - Args: - in_features: number of input features - pooling: pooling method to use. GlobalMaxPool2D(`max`) or GlobalAvgPool2D(`avg`). - - Note: - if num_classes is None, then the classifier is not added. - see: - https://github.com/keras-team/keras/blob/v2.10.0/keras/applications/vgg19.py - https://arxiv.org/abs/1409.1556 - """ - keys = jr.split(jr.PRNGKey(0), 16) - - # block 1 - self.conv_1_1 = sk.nn.Conv2D(in_feautres, 64, 3, padding="same", key=keys[0]) - self.conv_1_2 = sk.nn.Conv2D(64, 64, 3, padding="same", key=keys[1]) - self.maxpool_1 = sk.nn.MaxPool2D(2, strides=2) - - # block 2 - self.conv_2_1 = sk.nn.Conv2D(64, 128, 3, padding="same", key=keys[2]) - self.conv_2_2 = sk.nn.Conv2D(128, 128, 3, padding="same", key=keys[3]) - self.maxpool_2 = sk.nn.MaxPool2D(2, strides=2) - - # block 3 - self.conv_3_1 = sk.nn.Conv2D(128, 256, 3, padding="same", key=keys[4]) - self.conv_3_2 = sk.nn.Conv2D(256, 256, 3, padding="same", key=keys[5]) - self.conv_3_3 = sk.nn.Conv2D(256, 256, 3, padding="same", key=keys[6]) - self.conv_3_4 = sk.nn.Conv2D(256, 256, 3, padding="same", key=keys[7]) - self.maxpool_3 = sk.nn.MaxPool2D(2, strides=2) - - # block 4 - self.conv_4_1 = sk.nn.Conv2D(256, 512, 3, padding="same", key=keys[8]) - self.conv_4_2 = sk.nn.Conv2D(512, 512, 3, padding="same", key=keys[9]) - self.conv_4_3 = sk.nn.Conv2D(512, 512, 3, padding="same", key=keys[10]) - self.conv_4_4 = sk.nn.Conv2D(512, 512, 3, padding="same", key=keys[11]) - self.maxpool_4 = sk.nn.MaxPool2D(2, strides=2) - - # block 5 - self.conv_5_1 = sk.nn.Conv2D(512, 512, 3, padding="same", key=keys[12]) - self.conv_5_2 = sk.nn.Conv2D(512, 512, 3, padding="same", key=keys[13]) - self.conv_5_3 = sk.nn.Conv2D(512, 512, 3, padding="same", key=keys[14]) - self.conv_5_4 = sk.nn.Conv2D(512, 512, 3, padding="same", key=keys[15]) - self.maxpool_5 = sk.nn.MaxPool2D(2, strides=2) - - self.pooling = ( - sk.nn.GlobalMaxPool2D() if pooling == "max" else sk.nn.GlobalAvgPool2D() - ) - - def __call__(self, x: jax.Array, **kwargs) -> jax.Array: - x = self.conv_1_1(x) - x = jax.nn.relu(x) - x = self.conv_1_2(x) - x = jax.nn.relu(x) - x = self.maxpool_1(x) - - x = self.conv_2_1(x) - x = jax.nn.relu(x) - x = self.conv_2_2(x) - x = jax.nn.relu(x) - x = self.maxpool_2(x) - - x = self.conv_3_1(x) - x = jax.nn.relu(x) - x = self.conv_3_2(x) - x = jax.nn.relu(x) - x = self.conv_3_3(x) - x = jax.nn.relu(x) - x = self.conv_3_4(x) - x = jax.nn.relu(x) - x = self.maxpool_3(x) - - x = self.conv_4_1(x) - x = jax.nn.relu(x) - x = self.conv_4_2(x) - x = jax.nn.relu(x) - x = self.conv_4_3(x) - x = jax.nn.relu(x) - x = self.conv_4_4(x) - x = jax.nn.relu(x) - - x = self.conv_5_1(x) - x = jax.nn.relu(x) - x = self.conv_5_2(x) - x = jax.nn.relu(x) - x = self.conv_5_3(x) - x = jax.nn.relu(x) - x = self.conv_5_4(x) - x = jax.nn.relu(x) - x = self.maxpool_5(x) - - x = self.pooling(x) - return x diff --git a/serket/nn/containers.py b/serket/nn/containers.py index 95ebd80..092e6d1 100644 --- a/serket/nn/containers.py +++ b/serket/nn/containers.py @@ -123,10 +123,10 @@ class RandomChoice(sk.TreeClass): >>> import serket as sk >>> import jax.random as jr >>> key = jr.PRNGKey(0) - >>> print(sk.nn.RandomChoice(lambda x: x + 2, lambda x: x * 2)(1.0, key)) + >>> print(sk.nn.RandomChoice(lambda x: x + 2, lambda x: x * 2)(1.0, key=key)) 3.0 >>> key = jr.PRNGKey(10) - >>> print(sk.nn.RandomChoice(lambda x: x + 2, lambda x: x * 2)(1.0, key)) + >>> print(sk.nn.RandomChoice(lambda x: x + 2, lambda x: x * 2)(1.0, key=key)) 2.0 Note: diff --git a/tests/test_blocks.py b/tests/test_blocks.py deleted file mode 100644 index 0dbf5ae..0000000 --- a/tests/test_blocks.py +++ /dev/null @@ -1,52 +0,0 @@ -# Copyright 2023 Serket authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -import jax -import jax.numpy as jnp - -from serket.nn.blocks import UNetBlock, VGG16Block, VGG19Block - - -def count_parameters(model): - is_array = lambda x: isinstance(x, jax.Array) - count = 0 - - def map_func(leaf): - nonlocal count - if is_array(leaf): - count += leaf.size - return leaf - - jax.tree_map(map_func, model) - return count - - -def test_vgg16_block(): - model = VGG16Block(3) - assert count_parameters(model) == 14_714_688 - assert model(jnp.ones([3, 224, 224])).shape == (512, 1, 1) - - -def test_vgg19_block(): - model = VGG19Block(3) - assert count_parameters(model) == 20_024_384 - assert model(jnp.ones([3, 224, 224])).shape == (512, 1, 1) - - -def test_unet_block(): - # assert count_parameters(UNetBlock(3, 1, 32)) == 7_757_153 - model = UNetBlock(3, 1, 2) - assert model(jnp.ones((3, 320, 320))).shape == (1, 320, 320)