From 7a657fda0981c4aa24f139c87785adfa4d6c1016 Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Mon, 15 Jan 2024 17:55:46 +0800 Subject: [PATCH 01/10] Fix export name --- tools/convert_efficientnet_from_timm.py | 3 +++ tools/convert_ghostnet_from_timm.py | 3 +++ tools/convert_mobilenet_v2_from_timm.py | 3 +++ tools/convert_mobilenet_v3_from_timm.py | 3 +++ tools/convert_mobilevit_from_timm.py | 3 +++ tools/convert_resnet_from_timm.py | 3 +++ tools/convert_vit_from_timm.py | 3 +++ 7 files changed, 21 insertions(+) diff --git a/tools/convert_efficientnet_from_timm.py b/tools/convert_efficientnet_from_timm.py index ce014ac..21b64eb 100644 --- a/tools/convert_efficientnet_from_timm.py +++ b/tools/convert_efficientnet_from_timm.py @@ -2,6 +2,8 @@ pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu pip install timm """ +import os + import keras import numpy as np import timm @@ -188,6 +190,7 @@ """ Save converted model """ + os.makedirs("exported", exist_ok=True) export_path = f"exported/{keras_model.name.lower()}_imagenet.keras" keras_model.save(export_path) print(f"Export to {export_path}") diff --git a/tools/convert_ghostnet_from_timm.py b/tools/convert_ghostnet_from_timm.py index 05aaddf..320e287 100644 --- a/tools/convert_ghostnet_from_timm.py +++ b/tools/convert_ghostnet_from_timm.py @@ -2,6 +2,8 @@ pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu pip install timm """ +import os + import keras import numpy as np import timm @@ -146,6 +148,7 @@ """ Save converted model """ + os.makedirs("exported", exist_ok=True) export_path = f"exported/{keras_model.name.lower()}_imagenet.keras" keras_model.save(export_path) print(f"Export to {export_path}") diff --git a/tools/convert_mobilenet_v2_from_timm.py b/tools/convert_mobilenet_v2_from_timm.py index d0f702f..0b01ce1 100644 --- a/tools/convert_mobilenet_v2_from_timm.py +++ b/tools/convert_mobilenet_v2_from_timm.py @@ -2,6 +2,8 @@ pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu pip install timm """ +import os + import keras import numpy as np import timm @@ -132,6 +134,7 @@ """ Save converted model """ + os.makedirs("exported", exist_ok=True) export_path = f"exported/{keras_model.name.lower()}_imagenet.keras" keras_model.save(export_path) print(f"Export to {export_path}") diff --git a/tools/convert_mobilenet_v3_from_timm.py b/tools/convert_mobilenet_v3_from_timm.py index 4420913..0c3e677 100644 --- a/tools/convert_mobilenet_v3_from_timm.py +++ b/tools/convert_mobilenet_v3_from_timm.py @@ -2,6 +2,8 @@ pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu pip install timm """ +import os + import keras import numpy as np import timm @@ -160,6 +162,7 @@ """ Save converted model """ + os.makedirs("exported", exist_ok=True) export_path = f"exported/{keras_model.name.lower()}_imagenet.keras" keras_model.save(export_path) print(f"Export to {export_path}") diff --git a/tools/convert_mobilevit_from_timm.py b/tools/convert_mobilevit_from_timm.py index 445d2a7..4c53d5c 100644 --- a/tools/convert_mobilevit_from_timm.py +++ b/tools/convert_mobilevit_from_timm.py @@ -2,6 +2,8 @@ pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu pip install timm """ +import os + import keras import numpy as np import timm @@ -136,6 +138,7 @@ """ Save converted model """ + os.makedirs("exported", exist_ok=True) export_path = f"exported/{keras_model.name.lower()}_imagenet.keras" 
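# NOTE: `keras_model.save()` does not create missing parent directories when writing a `.keras` archive, so the save below would fail on a fresh checkout; that is why this patch adds `os.makedirs("exported", exist_ok=True)` above each save.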
keras_model.save(export_path) print(f"Export to {export_path}") diff --git a/tools/convert_resnet_from_timm.py b/tools/convert_resnet_from_timm.py index eadd28c..3a8823f 100644 --- a/tools/convert_resnet_from_timm.py +++ b/tools/convert_resnet_from_timm.py @@ -2,6 +2,8 @@ pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu pip install timm """ +import os + import keras import numpy as np import timm @@ -123,6 +125,7 @@ """ Save converted model """ + os.makedirs("exported", exist_ok=True) export_path = f"exported/{keras_model.name.lower()}_imagenet.keras" keras_model.save(export_path) print(f"Export to {export_path}") diff --git a/tools/convert_vit_from_timm.py b/tools/convert_vit_from_timm.py index eeb0289..6da3175 100644 --- a/tools/convert_vit_from_timm.py +++ b/tools/convert_vit_from_timm.py @@ -2,6 +2,8 @@ pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu pip install timm """ +import os + import keras import numpy as np import timm @@ -133,6 +135,7 @@ """ Save converted model """ + os.makedirs("exported", exist_ok=True) export_path = f"exported/{keras_model.name.lower()}_imagenet_384.keras" keras_model.save(export_path) print(f"Export to {export_path}") From 25f1aabbf4327f39a55eb085eaed7125b693dfca Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Mon, 15 Jan 2024 17:56:03 +0800 Subject: [PATCH 02/10] Add `DenseNet` --- kimm/models/densenet.py | 367 ++++++++++++++++++++++++++++ kimm/models/densenet_test.py | 50 ++++ tools/convert_densenet_from_timm.py | 123 ++++++++++ 3 files changed, 540 insertions(+) create mode 100644 kimm/models/densenet.py create mode 100644 kimm/models/densenet_test.py create mode 100644 tools/convert_densenet_from_timm.py diff --git a/kimm/models/densenet.py b/kimm/models/densenet.py new file mode 100644 index 0000000..f00959f --- /dev/null +++ b/kimm/models/densenet.py @@ -0,0 +1,367 @@ +import typing + +import keras +from keras import backend +from keras import layers +from keras import utils +from keras.src.applications import imagenet_utils + +from kimm.blocks import apply_conv2d_block +from kimm.models.feature_extractor import FeatureExtractor +from kimm.utils import add_model_to_registry + + +def apply_dense_layer( + inputs, growth_rate, expansion_ratio=4.0, name="dense_layer" +): + x = inputs + x = layers.BatchNormalization( + momentum=0.9, epsilon=1e-5, name=f"{name}_norm1" + )(x) + x = layers.ReLU()(x) + x = apply_conv2d_block( + x, + int(growth_rate * expansion_ratio), + 1, + 1, + activation="relu", + name=f"{name}_conv1", + ) + x = layers.Conv2D( + growth_rate, 3, 1, padding="same", use_bias=False, name=f"{name}_conv2" + )(x) + return x + + +def apply_dense_block( + inputs, num_layers, growth_rate, expansion_ratio=4.0, name="dense_block" +): + x = inputs + + features = [x] + for i in range(num_layers): + new_features = layers.Concatenate()(features) + new_features = apply_dense_layer( + new_features, + growth_rate, + expansion_ratio, + name=f"{name}_denselayer{i + 1}", + ) + features.append(new_features) + x = layers.Concatenate()(features) + return x + + +def apply_dense_transition_block( + inputs, output_channels, name="dense_transition_block" +): + x = inputs + x = layers.BatchNormalization( + momentum=0.9, epsilon=1e-5, name=f"{name}_norm" + )(x) + x = layers.ReLU()(x) + x = layers.Conv2D( + output_channels, 1, 1, "same", use_bias=False, name=f"{name}_conv" + )(x) + x = layers.AveragePooling2D(2, 2, name=f"{name}_pool")(x) + return x + + +class 
DenseNet(FeatureExtractor): + def __init__( + self, + growth_rate: float = 32, + num_blocks: typing.Sequence[int] = [6, 12, 24, 16], + input_tensor: keras.KerasTensor = None, + input_shape: typing.Optional[typing.Sequence[int]] = None, + include_preprocessing: bool = True, + include_top: bool = True, + pooling: typing.Optional[str] = None, + dropout_rate: float = 0.0, + classes: int = 1000, + classifier_activation: str = "softmax", + weights: typing.Optional[str] = None, # TODO: imagenet + **kwargs, + ): + # default_size + default_size = kwargs.pop("default_size", 224) + + # Prepare feature extraction + features = {} + + # Determine proper input shape + input_shape = imagenet_utils.obtain_input_shape( + input_shape, + default_size=default_size, + min_size=32, + data_format=backend.image_data_format(), + require_flatten=include_top, + weights=weights, + ) + + if input_tensor is None: + img_input = layers.Input(shape=input_shape) + else: + if not backend.is_keras_tensor(input_tensor): + img_input = layers.Input(tensor=input_tensor, shape=input_shape) + else: + img_input = input_tensor + + x = img_input + + # [0, 255] to [0, 1] and apply ImageNet mean and variance + if include_preprocessing: + x = layers.Rescaling(scale=1.0 / 255.0)(x) + x = layers.Normalization( + mean=[0.485, 0.456, 0.406], variance=[0.229, 0.224, 0.225] + )(x) + + # Stem block + stem_channel = growth_rate * 2 + x = apply_conv2d_block( + x, stem_channel, 7, 2, activation="relu", name="features_conv0" + ) + x = layers.ZeroPadding2D(1, name="features_pad0")(x) + x = layers.MaxPooling2D(3, 2, name="features_pool0")(x) + features["STEM_S4"] = x + + # Blocks + current_stride = 4 + input_channels = stem_channel + for current_block_idx, num_layers in enumerate(num_blocks): + x = apply_dense_block( + x, + num_layers, + growth_rate, + expansion_ratio=4.0, + name=f"features_denseblock{current_block_idx + 1}", + ) + input_channels = input_channels + num_layers * growth_rate + if current_block_idx != len(num_blocks) - 1: + current_stride *= 2 + x = apply_dense_transition_block( + x, + input_channels // 2, + name=f"features_transition{current_block_idx + 1}", + ) + input_channels = input_channels // 2 + + features[f"BLOCK{current_block_idx}_S{current_stride}"] = x + + # Final batch norm + x = layers.BatchNormalization( + momentum=0.9, epsilon=1e-5, name="features_norm5" + )(x) + x = layers.ReLU()(x) + + # Head + if include_top: + x = layers.GlobalAveragePooling2D(name="avg_pool")(x) + x = layers.Dropout(rate=dropout_rate, name="head_dropout")(x) + x = layers.Dense( + classes, activation=classifier_activation, name="classifier" + )(x) + else: + if pooling == "avg": + x = layers.GlobalAveragePooling2D(name="avg_pool")(x) + elif pooling == "max": + x = layers.GlobalMaxPooling2D(name="max_pool")(x) + + # Ensure that the model takes into account + # any potential predecessors of `input_tensor`. 
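+        # (`utils.get_source_inputs` traces `input_tensor` back to its originating `keras.Input` node(s), so any layers already applied to `input_tensor` are captured by this model as well.)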
+ if input_tensor is not None: + inputs = utils.get_source_inputs(input_tensor) + else: + inputs = img_input + + super().__init__(inputs=inputs, outputs=x, features=features, **kwargs) + + # All references to `self` below this line + self.growth_rate = growth_rate + self.num_blocks = num_blocks + self.include_preprocessing = include_preprocessing + self.include_top = include_top + self.pooling = pooling + self.dropout_rate = dropout_rate + self.classes = classes + self.classifier_activation = classifier_activation + self._weights = weights # `self.weights` is used internally + + @staticmethod + def available_feature_keys(): + feature_keys = ["STEM_S4"] + feature_keys.extend( + [f"BLOCK{i}_S{j}" for i, j in zip(range(4), [8, 16, 32, 32])] + ) + return feature_keys + + def get_config(self): + config = super().get_config() + config.update( + { + "growth_rate": self.growth_rate, + "num_blocks": self.num_blocks, + "input_shape": self.input_shape[1:], + "include_preprocessing": self.include_preprocessing, + "include_top": self.include_top, + "pooling": self.pooling, + "dropout_rate": self.dropout_rate, + "classes": self.classes, + "classifier_activation": self.classifier_activation, + "weights": self._weights, + } + ) + return config + + def fix_config(self, config: typing.Dict): + unused_kwargs = ["growth_rate", "num_blocks"] + for k in unused_kwargs: + config.pop(k, None) + return config + + +""" +Model Definition +""" + + +class DenseNet121(DenseNet): + def __init__( + self, + input_tensor: keras.KerasTensor = None, + input_shape: typing.Optional[typing.Sequence[int]] = None, + include_preprocessing: bool = True, + include_top: bool = True, + pooling: typing.Optional[str] = None, + dropout_rate: float = 0.0, + classes: int = 1000, + classifier_activation: str = "softmax", + weights: typing.Optional[str] = None, # TODO: imagenet + name: str = "DenseNet121", + **kwargs, + ): + kwargs = self.fix_config(kwargs) + super().__init__( + 32, + [6, 12, 24, 16], + input_tensor, + input_shape, + include_preprocessing, + include_top, + pooling, + dropout_rate, + classes, + classifier_activation, + weights, + name=name, + default_size=288, + **kwargs, + ) + + +class DenseNet161(DenseNet): + def __init__( + self, + input_tensor: keras.KerasTensor = None, + input_shape: typing.Optional[typing.Sequence[int]] = None, + include_preprocessing: bool = True, + include_top: bool = True, + pooling: typing.Optional[str] = None, + dropout_rate: float = 0.0, + classes: int = 1000, + classifier_activation: str = "softmax", + weights: typing.Optional[str] = None, # TODO: imagenet + name: str = "DenseNet161", + **kwargs, + ): + kwargs = self.fix_config(kwargs) + super().__init__( + 48, + [6, 12, 36, 24], + input_tensor, + input_shape, + include_preprocessing, + include_top, + pooling, + dropout_rate, + classes, + classifier_activation, + weights, + name=name, + default_size=224, + **kwargs, + ) + + +class DenseNet169(DenseNet): + def __init__( + self, + input_tensor: keras.KerasTensor = None, + input_shape: typing.Optional[typing.Sequence[int]] = None, + include_preprocessing: bool = True, + include_top: bool = True, + pooling: typing.Optional[str] = None, + dropout_rate: float = 0.0, + classes: int = 1000, + classifier_activation: str = "softmax", + weights: typing.Optional[str] = None, # TODO: imagenet + name: str = "DenseNet169", + **kwargs, + ): + kwargs = self.fix_config(kwargs) + super().__init__( + 32, + [6, 12, 32, 32], + input_tensor, + input_shape, + include_preprocessing, + include_top, + pooling, +
dropout_rate, + classes, + classifier_activation, + weights, + name=name, + default_size=224, + **kwargs, + ) + + +class DenseNet201(DenseNet): + def __init__( + self, + input_tensor: keras.KerasTensor = None, + input_shape: typing.Optional[typing.Sequence[int]] = None, + include_preprocessing: bool = True, + include_top: bool = True, + pooling: typing.Optional[str] = None, + dropout_rate: float = 0.0, + classes: int = 1000, + classifier_activation: str = "softmax", + weights: typing.Optional[str] = None, # TODO: imagenet + name: str = "DenseNet201", + **kwargs, + ): + kwargs = self.fix_config(kwargs) + super().__init__( + 32, + [6, 12, 48, 32], + input_tensor, + input_shape, + include_preprocessing, + include_top, + pooling, + dropout_rate, + classes, + classifier_activation, + weights, + name=name, + default_size=224, + **kwargs, + ) + + +add_model_to_registry(DenseNet121, True) +add_model_to_registry(DenseNet161, True) +add_model_to_registry(DenseNet169, True) +add_model_to_registry(DenseNet201, True) diff --git a/kimm/models/densenet_test.py b/kimm/models/densenet_test.py new file mode 100644 index 0000000..31a3ab9 --- /dev/null +++ b/kimm/models/densenet_test.py @@ -0,0 +1,50 @@ +from absl.testing import parameterized +from keras import models +from keras import random +from keras.src import testing + +from kimm.models.densenet import DenseNet121 + + +class DenseNetTest(testing.TestCase, parameterized.TestCase): + @parameterized.named_parameters([(DenseNet121.__name__, DenseNet121)]) + def test_densenet_base(self, model_class): + # TODO: test the correctness of the real image + x = random.uniform([1, 224, 224, 3]) * 255.0 + model = model_class(input_shape=[224, 224, 3]) + + y = model(x, training=False) + + self.assertEqual(y.shape, (1, 1000)) + + @parameterized.named_parameters([(DenseNet121.__name__, DenseNet121)]) + def test_densenet_feature_extractor(self, model_class): + x = random.uniform([1, 224, 224, 3]) * 255.0 + model = model_class( + input_shape=[224, 224, 3], as_feature_extractor=True + ) + + y = model(x, training=False) + + self.assertIsInstance(y, dict) + self.assertAllEqual( + list(y.keys()), model_class.available_feature_keys() + ) + self.assertEqual(list(y["STEM_S4"].shape), [1, 56, 56, 64]) + self.assertEqual(list(y["BLOCK0_S8"].shape), [1, 28, 28, 128]) + self.assertEqual(list(y["BLOCK1_S16"].shape), [1, 14, 14, 256]) + self.assertEqual(list(y["BLOCK2_S32"].shape), [1, 7, 7, 512]) + self.assertEqual(list(y["BLOCK3_S32"].shape), [1, 7, 7, 1024]) + + @parameterized.named_parameters([(DenseNet121.__name__, DenseNet121, 224)]) + def test_densenet_serialization(self, model_class, image_size): + x = random.uniform([1, image_size, image_size, 3]) * 255.0 + temp_dir = self.get_temp_dir() + model1 = model_class(input_shape=[224, 224, 3]) + y1 = model1(x, training=False) + model1.save(temp_dir + "/model.keras") + + model2 = models.load_model(temp_dir + "/model.keras") + y2 = model2(x, training=False) + + self.assertAllClose(y1, y2) diff --git a/tools/convert_densenet_from_timm.py b/tools/convert_densenet_from_timm.py new file mode 100644 index 0000000..5d882a7 --- /dev/null +++ b/tools/convert_densenet_from_timm.py @@ -0,0 +1,123 @@ +""" +pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu +pip install timm +""" +import os + +import keras +import numpy as np +import timm +import torch + +from kimm.models import densenet +from kimm.utils.timm_utils import assign_weights +from kimm.utils.timm_utils import is_same_weights +from 
kimm.utils.timm_utils import separate_keras_weights +from kimm.utils.timm_utils import separate_torch_state_dict + +timm_model_names = [ + "densenet121.ra_in1k", + "densenet161.tv_in1k", + "densenet169.tv_in1k", + "densenet201.tv_in1k", +] +keras_model_classes = [ + densenet.DenseNet121, + densenet.DenseNet161, + densenet.DenseNet169, + densenet.DenseNet201, +] + +for timm_model_name, keras_model_class in zip( + timm_model_names, keras_model_classes +): + """ + Prepare timm model and keras model + """ + input_shape = [224, 224, 3] + torch_model = timm.create_model(timm_model_name, pretrained=True) + torch_model = torch_model.eval() + trainable_state_dict, non_trainable_state_dict = separate_torch_state_dict( + torch_model.state_dict() + ) + keras_model = keras_model_class( + input_shape=input_shape, + include_preprocessing=False, + classifier_activation="linear", + ) + trainable_weights, non_trainable_weights = separate_keras_weights( + keras_model + ) + + # for torch_name, (_, keras_name) in zip( + # trainable_state_dict.keys(), trainable_weights + # ): + # print(f"{torch_name} {keras_name}") + + # print(len(trainable_state_dict.keys())) + # print(len(trainable_weights)) + + # exit() + + """ + Assign weights + """ + for keras_weight, keras_name in trainable_weights + non_trainable_weights: + keras_name: str + torch_name = keras_name + torch_name = torch_name.replace("_", ".") + # stem + torch_name = torch_name.replace("conv0.conv2d", "conv0") + torch_name = torch_name.replace("conv0.bn", "norm0") + # blocks + torch_name = torch_name.replace("conv1.conv2d", "conv1") + torch_name = torch_name.replace("conv1.bn", "norm2") + + # weights naming mapping + torch_name = torch_name.replace("kernel", "weight") # conv2d + torch_name = torch_name.replace("gamma", "weight") # bn + torch_name = torch_name.replace("beta", "bias") # bn + torch_name = torch_name.replace("moving.mean", "running_mean") # bn + torch_name = torch_name.replace("moving.variance", "running_var") # bn + + # assign weights + if torch_name in trainable_state_dict: + torch_weights = trainable_state_dict[torch_name].numpy() + elif torch_name in non_trainable_state_dict: + torch_weights = non_trainable_state_dict[torch_name].numpy() + else: + raise ValueError( + "Can't find the corresponding torch weights. " + f"Got keras_name={keras_name}, torch_name={torch_name}" + ) + if is_same_weights(keras_name, keras_weight, torch_name, torch_weights): + assign_weights(keras_name, keras_weight, torch_weights) + else: + raise ValueError( + "Found the corresponding torch weights, but the shape is " + f"mismatched. 
Got keras_name={keras_name}, " + f"keras_weight shape={keras_weight.shape}, " + f"torch_name={torch_name}, " + f"torch_weights shape={torch_weights.shape}" + ) + + """ + Verify model outputs + """ + np.random.seed(2023) + keras_data = np.random.uniform(size=[1] + input_shape).astype("float32") + torch_data = torch.from_numpy(np.transpose(keras_data, [0, 3, 1, 2])) + torch_y = torch_model(torch_data) + keras_y = keras_model(keras_data, training=False) + torch_y = torch_y.detach().cpu().numpy() + keras_y = keras.ops.convert_to_numpy(keras_y) + np.testing.assert_allclose(torch_y, keras_y, atol=1e-5) + print(f"{keras_model_class.__name__}: output matched!") + + """ + Save converted model + """ + os.makedirs("exported", exist_ok=True) + export_path = f"exported/{keras_model.name.lower()}_{timm_model_name}.keras" + keras_model.save(export_path) + print(f"Export to {export_path}") From 7d57a0865b6e9f556ce98abf6486c9841b55ee37 Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Mon, 15 Jan 2024 18:04:32 +0800 Subject: [PATCH 03/10] Cleanup --- kimm/layers/attention.py | 15 --------------- kimm/layers/layer_scale.py | 15 --------------- kimm/layers/position_embedding.py | 22 ---------------------- 3 files changed, 52 deletions(-) diff --git a/kimm/layers/attention.py b/kimm/layers/attention.py index 1797020..7eb90af 100644 --- a/kimm/layers/attention.py +++ b/kimm/layers/attention.py @@ -118,18 +118,3 @@ def get_config(self): } ) return config - - -if __name__ == "__main__": - from keras import models - from keras import random - - inputs = layers.Input(shape=[197, 768]) - outputs = Attention(768)(inputs) - - model = models.Model(inputs, outputs) - model.summary() - - inputs = random.uniform([1, 197, 768]) - outputs = model(inputs) - print(outputs.shape) diff --git a/kimm/layers/layer_scale.py b/kimm/layers/layer_scale.py index 9f030c7..0afce2d 100644 --- a/kimm/layers/layer_scale.py +++ b/kimm/layers/layer_scale.py @@ -35,18 +35,3 @@ def get_config(self): } ) return config - - -if __name__ == "__main__": - from keras import models - from keras import random - - inputs = layers.Input(shape=[197, 768]) - outputs = LayerScale(768)(inputs) - - model = models.Model(inputs, outputs) - model.summary() - - inputs = random.uniform([1, 197, 768]) - outputs = model(inputs) - print(outputs.shape) diff --git a/kimm/layers/position_embedding.py b/kimm/layers/position_embedding.py index 167b9ae..82670f3 100644 --- a/kimm/layers/position_embedding.py +++ b/kimm/layers/position_embedding.py @@ -38,25 +38,3 @@ def compute_output_shape(self, input_shape): def get_config(self): return super().get_config() - - -if __name__ == "__main__": - from keras import models - from keras import random - - inputs = layers.Input([224, 224, 3]) - x = layers.Conv2D( - 768, - 16, - 16, - use_bias=True, - )(inputs) - x = layers.Reshape((-1, 768))(x) - outputs = PositionEmbedding()(x) - - model = models.Model(inputs, outputs) - model.summary() - - inputs = random.uniform([1, 224, 224, 3]) - outputs = model(inputs) - print(outputs.shape) From 181ea9840b2093a66cc666329de431e8a74c601d Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Tue, 16 Jan 2024 10:34:23 +0800 Subject: [PATCH 04/10] Add `InceptionV3` --- kimm/blocks/base_block.py | 6 +- kimm/models/inception_v3.py | 384 ++++++++++++++++++++++++ kimm/models/inception_v3_test.py | 48 +++ tools/convert_inception_v3_from_timm.py | 150 +++++++++ 4 files changed, 587 insertions(+), 1 deletion(-) 
create mode 100644 kimm/models/inception_v3.py create mode 100644 kimm/models/inception_v3_test.py create mode 100644 tools/convert_inception_v3_from_timm.py diff --git a/kimm/blocks/base_block.py b/kimm/blocks/base_block.py index 2a428c6..54e3672 100644 --- a/kimm/blocks/base_block.py +++ b/kimm/blocks/base_block.py @@ -34,6 +34,8 @@ def apply_conv2d_block( raise ValueError( f"kernel_size must be passed. Received: kernel_size={kernel_size}" ) + if isinstance(kernel_size, int): + kernel_size = [kernel_size, kernel_size] input_channels = inputs.shape[-1] has_skip = add_skip and strides == 1 and input_channels == filters x = inputs @@ -42,7 +44,9 @@ padding = "same" if strides > 1: padding = "valid" - x = layers.ZeroPadding2D(kernel_size // 2, name=f"{name}_pad")(x) + x = layers.ZeroPadding2D( + (kernel_size[0] // 2, kernel_size[1] // 2), name=f"{name}_pad" + )(x) if not use_depthwise: x = layers.Conv2D( diff --git a/kimm/models/inception_v3.py b/kimm/models/inception_v3.py new file mode 100644 index 0000000..e30fd02 --- /dev/null +++ b/kimm/models/inception_v3.py @@ -0,0 +1,384 @@ +import functools +import typing + +import keras +from keras import backend +from keras import layers +from keras import utils +from keras.src.applications import imagenet_utils + +from kimm.blocks import apply_conv2d_block +from kimm.models.feature_extractor import FeatureExtractor +from kimm.utils import add_model_to_registry + +_apply_conv2d_block = functools.partial( + apply_conv2d_block, activation="relu", bn_epsilon=1e-3, padding="valid" +) + + +def apply_inception_a_block(inputs, pool_channels, name="inception_a_block"): + x = inputs + + branch1x1 = _apply_conv2d_block(x, 64, 1, 1, name=f"{name}_branch1x1") + + branch5x5 = _apply_conv2d_block(x, 48, 1, 1, name=f"{name}_branch5x5_1") + branch5x5 = _apply_conv2d_block( + branch5x5, 64, 5, 1, padding=None, name=f"{name}_branch5x5_2" + ) + + branch3x3dbl = _apply_conv2d_block( + x, 64, 1, 1, name=f"{name}_branch3x3dbl_1" + ) + branch3x3dbl = _apply_conv2d_block( + branch3x3dbl, 96, 3, 1, padding=None, name=f"{name}_branch3x3dbl_2" + ) + branch3x3dbl = _apply_conv2d_block( + branch3x3dbl, 96, 3, 1, padding=None, name=f"{name}_branch3x3dbl_3" + ) + + branch_pool = layers.ZeroPadding2D(1)(x) + branch_pool = layers.AveragePooling2D(3, 1)(branch_pool) + branch_pool = _apply_conv2d_block( + branch_pool, + pool_channels, + 1, + 1, + activation="relu", + name=f"{name}_branch_pool", + ) + x = layers.Concatenate()([branch1x1, branch5x5, branch3x3dbl, branch_pool]) + return x + + +def apply_inception_b_block(inputs, name="inception_b_block"): + x = inputs + + branch3x3 = _apply_conv2d_block(x, 384, 3, 2, name=f"{name}_branch3x3") + + branch3x3dbl = _apply_conv2d_block( + x, 64, 1, 1, name=f"{name}_branch3x3dbl_1" + ) + branch3x3dbl = _apply_conv2d_block( + branch3x3dbl, 96, 3, 1, padding=None, name=f"{name}_branch3x3dbl_2" + ) + branch3x3dbl = _apply_conv2d_block( + branch3x3dbl, 96, 3, 2, name=f"{name}_branch3x3dbl_3" + ) + + branch_pool = layers.MaxPooling2D(3, 2, name=f"{name}_branch_pool")(x) + x = layers.Concatenate()([branch3x3, branch3x3dbl, branch_pool]) + return x + + +def apply_inception_c_block( + inputs, branch7x7_channels, name="inception_c_block" +): + c7 = branch7x7_channels + x = inputs + + branch1x1 = _apply_conv2d_block(x, 192, 1, 1, name=f"{name}_branch1x1") + + branch7x7 = _apply_conv2d_block(x, c7, 1, 1, name=f"{name}_branch7x7_1") + branch7x7 = _apply_conv2d_block( + branch7x7, c7, (1, 7), 1, padding=None, 
name=f"{name}_branch7x7_2" + ) + branch7x7 = _apply_conv2d_block( + branch7x7, 192, (7, 1), 1, padding=None, name=f"{name}_branch7x7_3" + ) + + branch7x7dbl = _apply_conv2d_block( + x, c7, 1, 1, name=f"{name}_branch7x7dbl_1" + ) + branch7x7dbl = _apply_conv2d_block( + branch7x7dbl, c7, (7, 1), 1, padding=None, name=f"{name}_branch7x7dbl_2" + ) + branch7x7dbl = _apply_conv2d_block( + branch7x7dbl, c7, (1, 7), 1, padding=None, name=f"{name}_branch7x7dbl_3" + ) + branch7x7dbl = _apply_conv2d_block( + branch7x7dbl, c7, (7, 1), 1, padding=None, name=f"{name}_branch7x7dbl_4" + ) + branch7x7dbl = _apply_conv2d_block( + branch7x7dbl, + 192, + (1, 7), + 1, + padding=None, + name=f"{name}_branch7x7dbl_5", + ) + + branch_pool = layers.ZeroPadding2D(1)(x) + branch_pool = layers.AveragePooling2D(3, 1)(branch_pool) + branch_pool = _apply_conv2d_block( + branch_pool, 192, 1, 1, name=f"{name}_branch_pool" + ) + x = layers.Concatenate()([branch1x1, branch7x7, branch7x7dbl, branch_pool]) + return x + + +def apply_inception_d_block(inputs, name="inception_d_block"): + x = inputs + + branch3x3 = _apply_conv2d_block(x, 192, 1, 1, name=f"{name}_branch3x3_1") + branch3x3 = _apply_conv2d_block( + branch3x3, 320, 3, 2, name=f"{name}_branch3x3_2" + ) + + branch7x7x3 = _apply_conv2d_block( + x, 192, 1, 1, name=f"{name}_branch7x7x3_1" + ) + branch7x7x3 = _apply_conv2d_block( + branch7x7x3, 192, (1, 7), 1, padding=None, name=f"{name}_branch7x7x3_2" + ) + branch7x7x3 = _apply_conv2d_block( + branch7x7x3, 192, (7, 1), 1, padding=None, name=f"{name}_branch7x7x3_3" + ) + branch7x7x3 = _apply_conv2d_block( + branch7x7x3, 192, 3, 2, name=f"{name}_branch7x7x3_4" + ) + + branch_pool = layers.MaxPooling2D(3, 2)(x) + x = layers.Concatenate()([branch3x3, branch7x7x3, branch_pool]) + return x + + +def apply_inception_e_block(inputs, name="inception_e_block"): + x = inputs + + branch1x1 = _apply_conv2d_block(x, 320, 1, 1, name=f"{name}_branch1x1") + + branch3x3 = _apply_conv2d_block(x, 384, 1, 1, name=f"{name}_branch3x3_1") + branch3x3 = [ + _apply_conv2d_block( + branch3x3, 384, (1, 3), 1, padding=None, name=f"{name}_branch3x3_2a" + ), + _apply_conv2d_block( + branch3x3, 384, (3, 1), 1, padding=None, name=f"{name}_branch3x3_2b" + ), + ] + branch3x3 = layers.Concatenate()(branch3x3) + + branch3x3dbl = _apply_conv2d_block( + x, 448, 1, 1, name=f"{name}_branch3x3dbl_1" + ) + branch3x3dbl = _apply_conv2d_block( + branch3x3dbl, 384, 3, 1, padding=None, name=f"{name}_branch3x3dbl_2" + ) + branch3x3dbl = [ + _apply_conv2d_block( + branch3x3dbl, + 384, + (1, 3), + 1, + padding=None, + name=f"{name}_branch3x3dbl_3a", + ), + _apply_conv2d_block( + branch3x3dbl, + 384, + (3, 1), + 1, + padding=None, + name=f"{name}_branch3x3dbl_3b", + ), + ] + branch3x3dbl = layers.Concatenate()(branch3x3dbl) + + branch_pool = layers.ZeroPadding2D(1)(x) + branch_pool = layers.AveragePooling2D(3, 1)(branch_pool) + branch_pool = _apply_conv2d_block( + branch_pool, 192, 1, 1, name=f"{name}_branch_pool" + ) + x = layers.Concatenate()([branch1x1, branch3x3, branch3x3dbl, branch_pool]) + return x + + +def apply_inception_aux_block(inputs, classes, name="inception_aux_block"): + x = inputs + + x = layers.AveragePooling2D(5, 3)(x) + x = _apply_conv2d_block(x, 128, 1, 1, name=f"{name}_conv0") + x = _apply_conv2d_block(x, 768, 5, 1, name=f"{name}_conv1") + x = layers.GlobalAveragePooling2D()(x) + x = layers.Dense(classes, use_bias=True, name=f"{name}_fc")(x) + return x + + +class InceptionV3Base(FeatureExtractor): + def __init__( + self, + has_aux_logits=False, + 
input_tensor: keras.KerasTensor = None, + input_shape: typing.Optional[typing.Sequence[int]] = None, + include_preprocessing: bool = True, + include_top: bool = True, + pooling: typing.Optional[str] = None, + dropout_rate: float = 0.0, + classes: int = 1000, + classifier_activation: str = "softmax", + weights: typing.Optional[str] = None, # TODO: imagenet + **kwargs, + ): + # default_size + default_size = kwargs.pop("default_size", 299) + + # Prepare feature extraction + features = {} + + # Determine proper input shape + input_shape = imagenet_utils.obtain_input_shape( + input_shape, + default_size=default_size, + min_size=32, + data_format=backend.image_data_format(), + require_flatten=include_top, + weights=weights, + ) + + if input_tensor is None: + img_input = layers.Input(shape=input_shape) + else: + if not backend.is_keras_tensor(input_tensor): + img_input = layers.Input(tensor=input_tensor, shape=input_shape) + else: + img_input = input_tensor + + x = img_input + + # [0, 255] to [0, 1] and apply ImageNet mean and variance + if include_preprocessing: + x = layers.Rescaling(scale=1.0 / 255.0)(x) + x = layers.Normalization( + mean=[0.485, 0.456, 0.406], variance=[0.229, 0.224, 0.225] + )(x) + + # Stem block + x = _apply_conv2d_block(x, 32, 3, 2, name="Conv2d_1a_3x3") + x = _apply_conv2d_block(x, 32, 3, 1, name="Conv2d_2a_3x3") + x = _apply_conv2d_block(x, 64, 3, 1, padding=None, name="Conv2d_2b_3x3") + features["STEM_S2"] = x + + # Blocks + x = layers.MaxPooling2D(3, 2, name="Pool1")(x) + x = _apply_conv2d_block(x, 80, 1, 1, name="Conv2d_3b_1x1") + x = _apply_conv2d_block(x, 192, 3, 1, name="Conv2d_4a_3x3") + features["BLOCK0_S4"] = x + x = layers.MaxPooling2D(3, 2, name="Pool2")(x) + x = apply_inception_a_block(x, 32, "Mixed_5b") + x = apply_inception_a_block(x, 64, "Mixed_5c") + x = apply_inception_a_block(x, 64, "Mixed_5d") + features["BLOCK1_S8"] = x + + x = apply_inception_b_block(x, "Mixed_6a") + + x = apply_inception_c_block(x, 128, "Mixed_6b") + x = apply_inception_c_block(x, 160, "Mixed_6c") + x = apply_inception_c_block(x, 160, "Mixed_6d") + x = apply_inception_c_block(x, 192, "Mixed_6e") + features["BLOCK2_S16"] = x + + if has_aux_logits: + aux_logits = apply_inception_aux_block(x, classes, "AuxLogits") + + x = apply_inception_d_block(x, "Mixed_7a") + x = apply_inception_e_block(x, "Mixed_7b") + x = apply_inception_e_block(x, "Mixed_7c") + features["BLOCK3_S32"] = x + + # Head + if include_top: + x = layers.GlobalAveragePooling2D(name="avg_pool")(x) + x = layers.Dropout(rate=dropout_rate, name="head_dropout")(x) + x = layers.Dense( + classes, activation=classifier_activation, name="classifier" + )(x) + else: + if pooling == "avg": + x = layers.GlobalAveragePooling2D(name="avg_pool")(x) + elif pooling == "max": + x = layers.GlobalMaxPooling2D(name="max_pool")(x) + + # Ensure that the model takes into account + # any potential predecessors of `input_tensor`. 
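+        # (With `has_aux_logits=True`, the model is assembled below with two outputs, `[x, aux_logits]`, mirroring a timm InceptionV3 created with `aux_logits=True`.)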
+ if input_tensor is not None: + inputs = utils.get_source_inputs(input_tensor) + else: + inputs = img_input + + if has_aux_logits: + x = [x, aux_logits] + super().__init__(inputs=inputs, outputs=x, features=features, **kwargs) + + # All references to `self` below this line + self.has_aux_logits = has_aux_logits + self.include_preprocessing = include_preprocessing + self.include_top = include_top + self.pooling = pooling + self.dropout_rate = dropout_rate + self.classes = classes + self.classifier_activation = classifier_activation + self._weights = weights # `self.weights` is used internally + + @staticmethod + def available_feature_keys(): + feature_keys = ["STEM_S2"] + feature_keys.extend( + [f"BLOCK{i}_S{j}" for i, j in zip(range(4), [4, 8, 16, 32])] + ) + return feature_keys + + def get_config(self): + config = super().get_config() + config.update( + { + "has_aux_logits": self.has_aux_logits, + "input_shape": self.input_shape[1:], + "include_preprocessing": self.include_preprocessing, + "include_top": self.include_top, + "pooling": self.pooling, + "dropout_rate": self.dropout_rate, + "classes": self.classes, + "classifier_activation": self.classifier_activation, + "weights": self._weights, + } + ) + return config + + def fix_config(self, config: typing.Dict): + return config + + +class InceptionV3(InceptionV3Base): + def __init__( + self, + has_aux_logits: bool = False, + input_tensor: keras.KerasTensor = None, + input_shape: typing.Optional[typing.Sequence[int]] = None, + include_preprocessing: bool = True, + include_top: bool = True, + pooling: typing.Optional[str] = None, + dropout_rate: float = 0.0, + classes: int = 1000, + classifier_activation: str = "softmax", + weights: typing.Optional[str] = None, # TODO: imagenet + name: str = "InceptionV3", + **kwargs, + ): + kwargs = self.fix_config(kwargs) + super().__init__( + has_aux_logits, + input_tensor, + input_shape, + include_preprocessing, + include_top, + pooling, + dropout_rate, + classes, + classifier_activation, + weights, + name=name, + **kwargs, + ) + + +add_model_to_registry(InceptionV3, True) diff --git a/kimm/models/inception_v3_test.py b/kimm/models/inception_v3_test.py new file mode 100644 index 0000000..3899e01 --- /dev/null +++ b/kimm/models/inception_v3_test.py @@ -0,0 +1,48 @@ +from absl.testing import parameterized +from keras import models +from keras import random +from keras.src import testing + +from kimm.models.inception_v3 import InceptionV3 + + +class InceptionV3Test(testing.TestCase, parameterized.TestCase): + @parameterized.named_parameters([(InceptionV3.__name__, InceptionV3)]) + def test_inception_v3_base(self, model_class): + # TODO: test the correctness of the real image + x = random.uniform([1, 299, 299, 3]) * 255.0 + model = model_class() + + y = model(x, training=False) + + self.assertEqual(y.shape, (1, 1000)) + + @parameterized.named_parameters([(InceptionV3.__name__, InceptionV3)]) + def test_inception_v3_feature_extractor(self, model_class): + x = random.uniform([1, 299, 299, 3]) * 255.0 + model = model_class(as_feature_extractor=True) + + y = model(x, training=False) + + self.assertIsInstance(y, dict) + self.assertAllEqual( + list(y.keys()), model_class.available_feature_keys() + ) + self.assertEqual(list(y["STEM_S2"].shape), [1, 147, 147, 64]) + self.assertEqual(list(y["BLOCK0_S4"].shape), [1, 71, 71, 192]) + self.assertEqual(list(y["BLOCK1_S8"].shape), [1, 35, 35, 288]) + self.assertEqual(list(y["BLOCK2_S16"].shape), [1, 17, 17, 768]) + self.assertEqual(list(y["BLOCK3_S32"].shape), [1, 
8, 8, 2048]) + + @parameterized.named_parameters([(InceptionV3.__name__, InceptionV3, 299)]) + def test_inception_v3_serialization(self, model_class, image_size): + x = random.uniform([1, image_size, image_size, 3]) * 255.0 + temp_dir = self.get_temp_dir() + model1 = model_class() + y1 = model1(x, training=False) + model1.save(temp_dir + "/model.keras") + + model2 = models.load_model(temp_dir + "/model.keras") + y2 = model2(x, training=False) + + self.assertAllClose(y1, y2) diff --git a/tools/convert_inception_v3_from_timm.py b/tools/convert_inception_v3_from_timm.py new file mode 100644 index 0000000..f73d277 --- /dev/null +++ b/tools/convert_inception_v3_from_timm.py @@ -0,0 +1,150 @@ +""" +pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu +pip install timm +""" +import os + +import keras +import numpy as np +import timm +import torch + +from kimm.models import inception_v3 +from kimm.utils.timm_utils import assign_weights +from kimm.utils.timm_utils import is_same_weights +from kimm.utils.timm_utils import separate_keras_weights +from kimm.utils.timm_utils import separate_torch_state_dict + +timm_model_names = [ + "inception_v3.gluon_in1k", +] +keras_model_classes = [ + inception_v3.InceptionV3, +] + +for timm_model_name, keras_model_class in zip( + timm_model_names, keras_model_classes +): + """ + Prepare timm model and keras model + """ + input_shape = [299, 299, 3] + torch_model = timm.create_model( + timm_model_name, pretrained=True, aux_logits=True + ) + torch_model = torch_model.eval() + trainable_state_dict, non_trainable_state_dict = separate_torch_state_dict( + torch_model.state_dict() + ) + keras_model = keras_model_class( + has_aux_logits=True, + input_shape=input_shape, + include_preprocessing=False, + classifier_activation="linear", + ) + trainable_weights, non_trainable_weights = separate_keras_weights( + keras_model + ) + + # for torch_name, (_, keras_name) in zip( + # trainable_state_dict.keys(), trainable_weights + # ): + # print(f"{torch_name} {keras_name}") + + # print(len(trainable_state_dict.keys())) + # print(len(trainable_weights)) + + # exit() + + """ + Preprocess + """ + new_dict = {} + old_keys = trainable_state_dict.keys() + new_keys = [] + for k in old_keys: + new_key = k.replace("_", ".") + new_key = new_key.replace("running.mean", "running_mean") + new_key = new_key.replace("running.var", "running_var") + new_keys.append(new_key) + for k1, k2 in zip(trainable_state_dict.keys(), new_keys): + new_dict[k2] = trainable_state_dict[k1] + trainable_state_dict = new_dict + + new_dict = {} + old_keys = non_trainable_state_dict.keys() + new_keys = [] + for k in old_keys: + new_key = k.replace("_", ".") + new_key = new_key.replace("running.mean", "running_mean") + new_key = new_key.replace("running.var", "running_var") + new_keys.append(new_key) + for k1, k2 in zip(non_trainable_state_dict.keys(), new_keys): + new_dict[k2] = non_trainable_state_dict[k1] + non_trainable_state_dict = new_dict + + """ + Assign weights + """ + for keras_weight, keras_name in trainable_weights + non_trainable_weights: + keras_name: str + torch_name = keras_name + torch_name = torch_name.replace("_", ".") + # general + torch_name = torch_name.replace("conv2d", "conv") + # head + torch_name = torch_name.replace("classifier", "fc") + + # weights naming mapping + torch_name = torch_name.replace("kernel", "weight") # conv2d + torch_name = torch_name.replace("gamma", "weight") # bn + torch_name = torch_name.replace("beta", "bias") # bn + torch_name = 
torch_name.replace("moving.mean", "running_mean") # bn + torch_name = torch_name.replace("moving.variance", "running_var") # bn + + # assign weights + if torch_name in trainable_state_dict: + torch_weights = trainable_state_dict[torch_name].numpy() + elif torch_name in non_trainable_state_dict: + torch_weights = non_trainable_state_dict[torch_name].numpy() + else: + raise ValueError( + "Can't find the corresponding torch weights. " + f"Got keras_name={keras_name}, torch_name={torch_name}" + ) + if is_same_weights(keras_name, keras_weight, torch_name, torch_weights): + assign_weights(keras_name, keras_weight, torch_weights) + else: + raise ValueError( + "Can't find the corresponding torch weights. The shape is " + f"mismatched. Got keras_name={keras_name}, " + f"keras_weight shape={keras_weight.shape}, " + f"torch_name={torch_name}, " + f"torch_weights shape={torch_weights.shape}" + ) + + """ + Verify model outputs + """ + np.random.seed(2023) + keras_data = np.random.uniform(size=[1] + input_shape).astype("float32") + torch_data = torch.from_numpy(np.transpose(keras_data, [0, 3, 1, 2])) + torch_y = torch_model(torch_data) + torch_y, torch_y_aux = torch_y[0], torch_y[1] + keras_y = keras_model(keras_data, training=False) + keras_y, keras_y_aux = keras_y[0], keras_y[1] + torch_y = torch_y.detach().cpu().numpy() + torch_y_aux = torch_y_aux.detach().cpu().numpy() + keras_y = keras.ops.convert_to_numpy(keras_y) + keras_y_aux = keras.ops.convert_to_numpy(keras_y_aux) + np.testing.assert_allclose(torch_y, keras_y, atol=1e-5) + np.testing.assert_allclose(torch_y_aux, keras_y_aux, atol=1e-5) + print(f"{keras_model_class.__name__}: output matched!") + + """ + Save converted model + """ + os.makedirs("exported", exist_ok=True) + export_path = f"exported/{keras_model.name.lower()}_{timm_model_name}.keras" + keras_model.save(export_path) + print(f"Export to {export_path}") From 5eb46b60afb2fac820fa2b565546cfc62e4a047f Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Tue, 16 Jan 2024 23:49:06 +0800 Subject: [PATCH 05/10] Refactor `BaseModel` --- kimm/models/__init__.py | 2 +- kimm/models/base_model.py | 135 +++++ ...e_extractor_test.py => base_model_test.py} | 6 +- kimm/models/densenet.py | 184 +++--- kimm/models/efficientnet.py | 573 +++++++++--------- kimm/models/feature_extractor.py | 54 -- kimm/models/ghostnet.py | 226 ++++--- kimm/models/inception_v3.py | 142 ++--- kimm/models/mobilenet_v2.py | 202 +++--- kimm/models/mobilenet_v3.py | 333 +++++----- kimm/models/mobilevit.py | 168 +++-- kimm/models/resnet.py | 4 +- kimm/models/vision_transformer.py | 4 +- kimm/utils/model_registry.py | 4 +- kimm/utils/model_registry_test.py | 4 +- shell/export.sh | 11 + tools/convert_efficientnet_from_timm.py | 2 +- tools/convert_ghostnet_from_timm.py | 2 +- tools/convert_mobilenet_v2_from_timm.py | 2 +- tools/convert_mobilenet_v3_from_timm.py | 2 +- tools/convert_mobilevit_from_timm.py | 2 +- tools/convert_resnet_from_timm.py | 2 +- tools/convert_vit_from_timm.py | 2 +- 23 files changed, 1004 insertions(+), 1062 deletions(-) create mode 100644 kimm/models/base_model.py rename kimm/models/{feature_extractor_test.py => base_model_test.py} (93%) delete mode 100644 kimm/models/feature_extractor.py create mode 100755 shell/export.sh diff --git a/kimm/models/__init__.py b/kimm/models/__init__.py index 2b60641..6be22a2 100644 --- a/kimm/models/__init__.py +++ b/kimm/models/__init__.py @@ -1,5 +1,5 @@ +from kimm.models.base_model import BaseModel from 
kimm.models.efficientnet import * # noqa:F403 -from kimm.models.feature_extractor import FeatureExtractor from kimm.models.ghostnet import * # noqa:F403 from kimm.models.mobilenet_v2 import * # noqa:F403 from kimm.models.mobilenet_v3 import * # noqa:F403 diff --git a/kimm/models/base_model.py b/kimm/models/base_model.py new file mode 100644 index 0000000..ee1f7e2 --- /dev/null +++ b/kimm/models/base_model.py @@ -0,0 +1,135 @@ +import abc +import typing + +from keras import KerasTensor +from keras import backend +from keras import layers +from keras import models +from keras.src.applications import imagenet_utils + + +class BaseModel(models.Model): + def __init__( + self, + inputs, + outputs, + features: typing.Optional[typing.Dict[str, KerasTensor]] = None, + feature_keys: typing.Optional[typing.List[str]] = None, + **kwargs, + ): + self.as_feature_extractor = kwargs.pop("as_feature_extractor", False) + self.feature_keys = feature_keys + if self.as_feature_extractor: + if features is None: + raise ValueError( + "`features` must be set when " + f"`as_feature_extractor=True`. Got features={features}" + ) + if self.feature_keys is None: + self.feature_keys = list(features.keys()) + filtered_features = {} + for k in self.feature_keys: + if k not in features: + raise KeyError( + f"'{k}' is not a key of `features`. Available keys " + f"are: {list(features.keys())}" + ) + filtered_features[k] = features[k] + super().__init__(inputs=inputs, outputs=filtered_features, **kwargs) + else: + del features + super().__init__(inputs=inputs, outputs=outputs, **kwargs) + + def parse_kwargs( + self, kwargs: typing.Dict[str, typing.Any], default_size: int = 224 + ): + result = { + "input_tensor": kwargs.pop("input_tensor", None), + "input_shape": kwargs.pop("input_shape", None), + "include_preprocessing": kwargs.pop("include_preprocessing", True), + "include_top": kwargs.pop("include_top", True), + "pooling": kwargs.pop("pooling", None), + "dropout_rate": kwargs.pop("dropout_rate", 0.0), + "classes": kwargs.pop("classes", 1000), + "classifier_activation": kwargs.pop( + "classifier_activation", "softmax" + ), + "weights": kwargs.pop("weights", "imagenet"), + "default_size": kwargs.pop("default_size", default_size), + } + return result + + def determine_input_tensor( + self, + input_tensor=None, + input_shape=None, + default_size=224, + min_size=32, + require_flatten=False, + static_shape=False, + ): + """Determine the input tensor by the arguments.""" + input_shape = imagenet_utils.obtain_input_shape( + input_shape, + default_size=default_size, + min_size=min_size, + data_format="channels_last", # always channels_last + require_flatten=require_flatten or static_shape, + weights=None, + ) + + if input_tensor is None: + x = layers.Input(shape=input_shape) + else: + if not backend.is_keras_tensor(input_tensor): + x = layers.Input(tensor=input_tensor, shape=input_shape) + else: + x = input_tensor + return x + + def build_preprocessing(self, inputs): + # TODO: add docstring + raise NotImplementedError + + def build_top(self, inputs, classes, classifier_activation, dropout_rate): + # TODO: add docstring + raise NotImplementedError + + def add_references(self, parsed_kwargs: typing.Dict[str, typing.Any]): + self.include_preprocessing = parsed_kwargs["include_preprocessing"] + self.include_top = parsed_kwargs["include_top"] + self.pooling = parsed_kwargs["pooling"] + self.dropout_rate = parsed_kwargs["dropout_rate"] + self.classes = parsed_kwargs["classes"] + self.classifier_activation = 
parsed_kwargs["classifier_activation"] + # `self.weights` is been used internally + self._weights = parsed_kwargs["weights"] + + @staticmethod + @abc.abstractmethod + def available_feature_keys(): + # TODO: add docstring + raise NotImplementedError + + def get_config(self): + # Don't chain to super here. The default `get_config()` for functional + # models is nested and cannot be passed to BaseModel. + config = { + "name": self.name, + "trainable": self.trainable, + "as_feature_extractor": self.as_feature_extractor, + "feature_keys": self.feature_keys, + # common + "input_shape": self.input_shape[1:], + "include_preprocessing": self.include_preprocessing, + "include_top": self.include_top, + "pooling": self.pooling, + "dropout_rate": self.dropout_rate, + "classes": self.classes, + "classifier_activation": self.classifier_activation, + "weights": self._weights, + } + return config + + def fix_config(self, config: typing.Dict): + return config diff --git a/kimm/models/feature_extractor_test.py b/kimm/models/base_model_test.py similarity index 93% rename from kimm/models/feature_extractor_test.py rename to kimm/models/base_model_test.py index 0987340..a0977ba 100644 --- a/kimm/models/feature_extractor_test.py +++ b/kimm/models/base_model_test.py @@ -3,10 +3,10 @@ from keras import random from keras.src import testing -from kimm.models.feature_extractor import FeatureExtractor +from kimm.models.base_model import BaseModel -class SampleModel(FeatureExtractor): +class SampleModel(BaseModel): def __init__(self, **kwargs): inputs = layers.Input(shape=[224, 224, 3]) @@ -34,7 +34,7 @@ def get_config(self): return super().get_config() -class GhostNetTest(testing.TestCase, parameterized.TestCase): +class BaseModelTest(testing.TestCase, parameterized.TestCase): def test_feature_extractor(self): x = random.uniform([1, 224, 224, 3]) diff --git a/kimm/models/densenet.py b/kimm/models/densenet.py index f00959f..60668b5 100644 --- a/kimm/models/densenet.py +++ b/kimm/models/densenet.py @@ -1,13 +1,11 @@ import typing import keras -from keras import backend from keras import layers from keras import utils -from keras.src.applications import imagenet_utils from kimm.blocks import apply_conv2d_block -from kimm.models.feature_extractor import FeatureExtractor +from kimm.models import BaseModel from kimm.utils import add_model_to_registry @@ -67,54 +65,26 @@ def apply_dense_transition_block( return x -class DenseNet(FeatureExtractor): +class DenseNet(BaseModel): def __init__( self, growth_rate: float = 32, num_blocks: typing.Sequence[int] = [6, 12, 24, 16], - input_tensor: keras.KerasTensor = None, - input_shape: typing.Optional[typing.Sequence[int]] = None, - include_preprocessing: bool = True, - include_top: bool = True, - pooling: typing.Optional[str] = None, - dropout_rate: float = 0.0, - classes: int = 1000, - classifier_activation: str = "softmax", - weights: typing.Optional[str] = None, # TODO: imagenet **kwargs, ): - # default_size - default_size = kwargs.pop("default_size", 224) - - # Prepare feature extraction - features = {} - - # Determine proper input shape - input_shape = imagenet_utils.obtain_input_shape( - input_shape, - default_size=default_size, - min_size=32, - data_format=backend.image_data_format(), - require_flatten=include_top, - weights=weights, + parsed_kwargs = self.parse_kwargs(kwargs) + img_input = self.determine_input_tensor( + parsed_kwargs["input_tensor"], + parsed_kwargs["input_shape"], + parsed_kwargs["default_size"], ) - - if input_tensor is None: - img_input = 
layers.Input(shape=input_shape) - else: - if not backend.is_keras_tensor(input_tensor): - img_input = layers.Input(tensor=input_tensor, shape=input_shape) - else: - img_input = input_tensor - x = img_input - # [0, 255] to [0, 1] and apply ImageNet mean and variance - if include_preprocessing: - x = layers.Rescaling(scale=1.0 / 255.0)(x) - x = layers.Normalization( - mean=[0.485, 0.456, 0.406], variance=[0.229, 0.224, 0.225] - )(x) + if parsed_kwargs["include_preprocessing"]: + x = self.build_preprocessing(x) + + # Prepare feature extraction + features = {} # Stem block stem_channel = growth_rate * 2 @@ -155,37 +125,48 @@ def __init__( x = layers.ReLU()(x) # Head - if include_top: - x = layers.GlobalAveragePooling2D(name="avg_pool")(x) - x = layers.Dropout(rate=dropout_rate, name="head_dropout")(x) - x = layers.Dense( - classes, activation=classifier_activation, name="classifier" - )(x) + if parsed_kwargs["include_top"]: + x = self.build_top( + x, + parsed_kwargs["classes"], + parsed_kwargs["classifier_activation"], + parsed_kwargs["dropout_rate"], + ) else: - if pooling == "avg": + if parsed_kwargs["pooling"] == "avg": x = layers.GlobalAveragePooling2D(name="avg_pool")(x) - elif pooling == "max": + elif parsed_kwargs["pooling"] == "max": x = layers.GlobalMaxPooling2D(name="max_pool")(x) # Ensure that the model takes into account # any potential predecessors of `input_tensor`. - if input_tensor is not None: - inputs = utils.get_source_inputs(input_tensor) + if parsed_kwargs["input_tensor"] is not None: + inputs = utils.get_source_inputs(parsed_kwargs["input_tensor"]) else: inputs = img_input super().__init__(inputs=inputs, outputs=x, features=features, **kwargs) # All references to `self` below this line + self.add_references(parsed_kwargs) self.growth_rate = growth_rate self.num_blocks = num_blocks - self.include_preprocessing = include_preprocessing - self.include_top = include_top - self.pooling = pooling - self.dropout_rate = dropout_rate - self.classes = classes - self.classifier_activation = classifier_activation - self._weights = weights # `self.weights` is been used internally + + def build_preprocessing(self, inputs): + # [0, 255] to [0, 1] and apply ImageNet mean and variance + x = layers.Rescaling(scale=1.0 / 255.0)(inputs) + x = layers.Normalization( + mean=[0.485, 0.456, 0.406], variance=[0.229, 0.224, 0.225] + )(x) + return x + + def build_top(self, inputs, classes, classifier_activation, dropout_rate): + x = layers.GlobalAveragePooling2D(name="avg_pool")(inputs) + x = layers.Dropout(rate=dropout_rate, name="head_dropout")(x) + x = layers.Dense( + classes, activation=classifier_activation, name="classifier" + )(x) + return x @staticmethod def available_feature_keys(): @@ -198,18 +179,7 @@ def available_feature_keys(): def get_config(self): config = super().get_config() config.update( - { - "growth_rate": self.growth_rate, - "num_blocks": self.num_blocks, - "input_shape": self.input_shape[1:], - "include_preprocessing": self.include_preprocessing, - "include_top": self.include_top, - "pooling": self.pooling, - "dropout_rate": self.dropout_rate, - "classes": self.classes, - "classifier_activation": self.classifier_activation, - "weights": self._weights, - } + {"growth_rate": self.growth_rate, "num_blocks": self.num_blocks} ) return config @@ -244,15 +214,15 @@ def __init__( super().__init__( 32, [6, 12, 24, 16], - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, + 
input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, default_size=288, **kwargs, @@ -278,15 +248,15 @@ def __init__( super().__init__( 48, [6, 12, 36, 24], - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, default_size=224, **kwargs, @@ -312,15 +282,15 @@ def __init__( super().__init__( 32, [6, 12, 32, 32], - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, default_size=224, **kwargs, @@ -346,15 +316,15 @@ def __init__( super().__init__( 32, [6, 12, 48, 32], - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, default_size=224, **kwargs, diff --git a/kimm/models/efficientnet.py b/kimm/models/efficientnet.py index 562c71b..03d8de3 100644 --- a/kimm/models/efficientnet.py +++ b/kimm/models/efficientnet.py @@ -2,16 +2,14 @@ import typing import keras -from keras import backend from keras import layers from keras import utils -from keras.src.applications import imagenet_utils from kimm.blocks import apply_conv2d_block from kimm.blocks import apply_depthwise_separation_block from kimm.blocks import apply_inverted_residual_block from kimm.blocks import apply_se_block -from kimm.models.feature_extractor import FeatureExtractor +from kimm.models import BaseModel from kimm.utils import add_model_to_registry from kimm.utils import make_divisible @@ -138,7 +136,7 @@ def apply_edge_residual_block( return x -class EfficientNet(FeatureExtractor): +class EfficientNet(BaseModel): def __init__( self, width: float = 1.0, @@ -148,15 +146,6 @@ def __init__( fix_stem_and_head_channels: bool = False, fix_first_and_last_blocks: bool = False, activation="swish", - input_tensor: keras.KerasTensor = None, - input_shape: typing.Optional[typing.Sequence[int]] = None, - include_preprocessing: bool = True, - include_top: bool = True, - pooling: typing.Optional[str] = None, - dropout_rate: float = 0.0, - classes: int = 1000, - classifier_activation: str = "softmax", - weights: typing.Optional[str] = None, # TODO: imagenet config: typing.Union[str, typing.List] = "v1", **kwargs, ): @@ -189,7 +178,6 @@ def __init__( f"Received: config={config}" ) # TF default config - default_size = kwargs.pop("default_size", 224) bn_epsilon = kwargs.pop("bn_epsilon", 1e-5) padding = kwargs.pop("padding", None) # EfficientNetV2Base config @@ 
-197,35 +185,19 @@ def __init__( # TinyNet config round_fn = kwargs.pop("round_fn", math.ceil) - # Prepare feature extraction - features = {} - - # Determine proper input shape - input_shape = imagenet_utils.obtain_input_shape( - input_shape, - default_size=default_size, - min_size=32, - data_format=backend.image_data_format(), - require_flatten=include_top, - weights=weights, + parsed_kwargs = self.parse_kwargs(kwargs) + img_input = self.determine_input_tensor( + parsed_kwargs["input_tensor"], + parsed_kwargs["input_shape"], + parsed_kwargs["default_size"], ) - - if input_tensor is None: - img_input = layers.Input(shape=input_shape) - else: - if not backend.is_keras_tensor(input_tensor): - img_input = layers.Input(tensor=input_tensor, shape=input_shape) - else: - img_input = input_tensor - x = img_input - # [0, 255] to [0, 1] and apply ImageNet mean and variance - if include_preprocessing: - x = layers.Rescaling(scale=1.0 / 255.0)(x) - x = layers.Normalization( - mean=[0.485, 0.456, 0.406], variance=[0.229, 0.224, 0.225] - )(x) + if parsed_kwargs["include_preprocessing"]: + x = self.build_preprocessing(x) + + # Prepare feature extraction + features = {} # Stem block stem_channel = ( @@ -322,28 +294,30 @@ def __init__( ) # Head - if include_top: - x = layers.GlobalAveragePooling2D(name="avg_pool")(x) - x = layers.Dropout(rate=dropout_rate, name="conv_head_dropout")(x) - x = layers.Dense( - classes, activation=classifier_activation, name="classifier" - )(x) + if parsed_kwargs["include_top"]: + x = self.build_top( + x, + parsed_kwargs["classes"], + parsed_kwargs["classifier_activation"], + parsed_kwargs["dropout_rate"], + ) else: - if pooling == "avg": + if parsed_kwargs["pooling"] == "avg": x = layers.GlobalAveragePooling2D(name="avg_pool")(x) - elif pooling == "max": + elif parsed_kwargs["pooling"] == "max": x = layers.GlobalMaxPooling2D(name="max_pool")(x) # Ensure that the model takes into account # any potential predecessors of `input_tensor`. 
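The hunks above show the pattern this patch repeats for every architecture: the per-model boilerplate (obtain_input_shape, Input creation, preprocessing, classifier head) collapses into helpers on a shared BaseModel. kimm/models/base_model.py is not part of this excerpt, so the sketch below only reconstructs the contract implied by the call sites; every signature here (parse_kwargs, determine_input_tensor, add_references) is an assumption inferred from usage, not the shipped implementation. The feature-filtering __init__ from the deleted feature_extractor.py (further down) presumably moves here as well.

    import typing

    from keras import backend
    from keras import layers
    from keras import models
    from keras.src.applications import imagenet_utils


    class BaseModel(models.Model):
        def parse_kwargs(
            self, kwargs: typing.Dict, default_size: int = 224
        ) -> typing.Dict:
            # Pop the arguments shared by every architecture so the
            # leftovers can be forwarded safely to keras.models.Model.
            return {
                "input_tensor": kwargs.pop("input_tensor", None),
                "input_shape": kwargs.pop("input_shape", None),
                "include_preprocessing": kwargs.pop(
                    "include_preprocessing", True
                ),
                "include_top": kwargs.pop("include_top", True),
                "pooling": kwargs.pop("pooling", None),
                "dropout_rate": kwargs.pop("dropout_rate", 0.0),
                "classes": kwargs.pop("classes", 1000),
                "classifier_activation": kwargs.pop(
                    "classifier_activation", "softmax"
                ),
                "weights": kwargs.pop("weights", None),
                "default_size": kwargs.pop("default_size", default_size),
            }

        def determine_input_tensor(
            self,
            input_tensor=None,
            input_shape=None,
            default_size=224,
            min_size=32,
            require_flatten=False,
            static_shape=False,
        ):
            # Same job the deleted inline code did with obtain_input_shape.
            input_shape = imagenet_utils.obtain_input_shape(
                input_shape,
                default_size=default_size,
                min_size=min_size,
                data_format=backend.image_data_format(),
                # Assumption: static_shape forces a fully-defined shape.
                require_flatten=require_flatten or static_shape,
                weights=None,
            )
            if input_tensor is None:
                return layers.Input(shape=input_shape)
            if not backend.is_keras_tensor(input_tensor):
                return layers.Input(tensor=input_tensor, shape=input_shape)
            return input_tensor

        def add_references(self, parsed_kwargs: typing.Dict):
            # Keep the parsed arguments around for get_config().
            self.include_preprocessing = parsed_kwargs["include_preprocessing"]
            self.include_top = parsed_kwargs["include_top"]
            self.pooling = parsed_kwargs["pooling"]
            self.dropout_rate = parsed_kwargs["dropout_rate"]
            self.classes = parsed_kwargs["classes"]
            self.classifier_activation = parsed_kwargs["classifier_activation"]
            # `self.weights` is used internally by Keras; store privately.
            self._weights = parsed_kwargs["weights"]

        def build_preprocessing(self, inputs):
            raise NotImplementedError  # each model overrides this

        def build_top(self, inputs, classes, classifier_activation, dropout_rate):
            raise NotImplementedError  # each model overrides this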
- if input_tensor is not None: - inputs = utils.get_source_inputs(input_tensor) + if parsed_kwargs["input_tensor"] is not None: + inputs = utils.get_source_inputs(parsed_kwargs["input_tensor"]) else: inputs = img_input super().__init__(inputs=inputs, outputs=x, features=features, **kwargs) # All references to `self` below this line + self.add_references(parsed_kwargs) self.width = width self.depth = depth self.stem_channels = stem_channels @@ -351,15 +325,24 @@ def __init__( self.fix_stem_and_head_channels = fix_stem_and_head_channels self.fix_first_and_last_blocks = fix_first_and_last_blocks self.activation = activation - self.include_preprocessing = include_preprocessing - self.include_top = include_top - self.pooling = pooling - self.dropout_rate = dropout_rate - self.classes = classes - self.classifier_activation = classifier_activation - self._weights = weights # `self.weights` is been used internally self.config = config + def build_preprocessing(self, inputs): + # [0, 255] to [0, 1] and apply ImageNet mean and variance + x = layers.Rescaling(scale=1.0 / 255.0)(inputs) + x = layers.Normalization( + mean=[0.485, 0.456, 0.406], variance=[0.229, 0.224, 0.225] + )(x) + return x + + def build_top(self, inputs, classes, classifier_activation, dropout_rate): + x = layers.GlobalAveragePooling2D(name="avg_pool")(inputs) + x = layers.Dropout(rate=dropout_rate, name="head_dropout")(x) + x = layers.Dense( + classes, activation=classifier_activation, name="classifier" + )(x) + return x + @staticmethod def available_feature_keys(): # for: v1, v1_lite, v2_m, v2_l, v2_xl, tinynet @@ -384,14 +367,6 @@ def get_config(self): "fix_stem_and_head_channels": self.fix_stem_and_head_channels, "fix_first_and_last_blocks": self.fix_first_and_last_blocks, "activation": self.activation, - "input_shape": self.input_shape[1:], - "include_preprocessing": self.include_preprocessing, - "include_top": self.include_top, - "pooling": self.pooling, - "dropout_rate": self.dropout_rate, - "classes": self.classes, - "classifier_activation": self.classifier_activation, - "weights": self._weights, "config": self.config, } ) @@ -443,16 +418,16 @@ def __init__( False, False, "swish", - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, default_size=224, bn_epsilon=1e-3, @@ -487,16 +462,16 @@ def __init__( False, False, "swish", - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, default_size=240, bn_epsilon=1e-3, @@ -531,16 +506,16 @@ def __init__( False, False, "swish", - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + 
dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, default_size=260, bn_epsilon=1e-3, @@ -575,16 +550,16 @@ def __init__( False, False, "swish", - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, default_size=300, bn_epsilon=1e-3, @@ -619,16 +594,16 @@ def __init__( False, False, "swish", - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, default_size=380, bn_epsilon=1e-3, @@ -663,16 +638,16 @@ def __init__( False, False, "swish", - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, default_size=456, bn_epsilon=1e-3, @@ -707,16 +682,16 @@ def __init__( False, False, "swish", - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, default_size=528, bn_epsilon=1e-3, @@ -751,16 +726,16 @@ def __init__( False, False, "swish", - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, default_size=600, bn_epsilon=1e-3, @@ -795,16 +770,16 @@ def __init__( True, True, "relu6", - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, default_size=224, bn_epsilon=1e-3, @@ -839,16 +814,16 @@ def __init__( True, True, "relu6", - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, + input_tensor=input_tensor, + input_shape=input_shape, + 
include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, default_size=240, bn_epsilon=1e-3, @@ -883,16 +858,16 @@ def __init__( True, True, "relu6", - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, default_size=260, bn_epsilon=1e-3, @@ -927,16 +902,16 @@ def __init__( True, True, "relu6", - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, default_size=300, bn_epsilon=1e-3, @@ -971,16 +946,16 @@ def __init__( True, True, "relu6", - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, default_size=380, bn_epsilon=1e-3, @@ -1015,16 +990,16 @@ def __init__( False, False, "swish", - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, default_size=300, bn_epsilon=1e-3, @@ -1067,16 +1042,16 @@ def __init__( False, False, "swish", - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, default_size=384, bn_epsilon=1e-3, @@ -1111,16 +1086,16 @@ def __init__( False, False, "swish", - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, default_size=384, bn_epsilon=1e-3, @@ -1155,16 +1130,16 @@ def __init__( False, False, "swish", - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - 
classifier_activation, - weights, config, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, default_size=384, bn_epsilon=1e-3, @@ -1199,16 +1174,16 @@ def __init__( True, False, "swish", - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, default_size=192, bn_epsilon=1e-3, @@ -1251,16 +1226,16 @@ def __init__( True, False, "swish", - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, default_size=192, bn_epsilon=1e-3, @@ -1303,16 +1278,16 @@ def __init__( True, False, "swish", - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, default_size=208, bn_epsilon=1e-3, @@ -1356,16 +1331,16 @@ def __init__( True, False, "swish", - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, default_size=240, bn_epsilon=1e-3, @@ -1408,16 +1383,16 @@ def __init__( False, False, "swish", - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, default_size=192, round_fn=round, # tinynet config @@ -1450,16 +1425,16 @@ def __init__( True, False, "swish", - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, default_size=192, round_fn=round, # tinynet config @@ -1492,16 +1467,16 @@ def __init__( 
True, False, "swish", - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, default_size=188, round_fn=round, # tinynet config @@ -1534,16 +1509,16 @@ def __init__( True, False, "swish", - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, default_size=152, round_fn=round, # tinynet config @@ -1576,16 +1551,16 @@ def __init__( True, False, "swish", - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, default_size=106, round_fn=round, # tinynet config diff --git a/kimm/models/feature_extractor.py b/kimm/models/feature_extractor.py deleted file mode 100644 index 827a5d6..0000000 --- a/kimm/models/feature_extractor.py +++ /dev/null @@ -1,54 +0,0 @@ -import abc -import typing - -from keras import KerasTensor -from keras import models - - -class FeatureExtractor(models.Model): - @staticmethod - @abc.abstractmethod - def available_feature_keys(): - return [] - - def __init__( - self, - inputs, - outputs, - features: typing.Optional[typing.Dict[str, KerasTensor]] = None, - feature_keys: typing.Optional[typing.List[str]] = None, - **kwargs, - ): - self.as_feature_extractor = kwargs.pop("as_feature_extractor", False) - self.feature_keys = feature_keys - if self.as_feature_extractor: - if features is None: - raise ValueError( - "`features` must be set when " - f"`as_feature_extractor=True`. Got features={features}" - ) - if self.feature_keys is None: - self.feature_keys = list(features.keys()) - filtered_features = {} - for k in self.feature_keys: - if k not in features: - raise KeyError( - f"'{k}' is not a key of `features`. Available keys " - f"are: {list(features.keys())}" - ) - filtered_features[k] = features[k] - super().__init__(inputs=inputs, outputs=filtered_features, **kwargs) - else: - del features - super().__init__(inputs=inputs, outputs=outputs, **kwargs) - - def get_config(self): - # Don't chain to super here. The default `get_config()` for functional - # models is nested and cannot be passed to FeatureExtractor. 
- config = { - "name": self.name, - "trainable": self.trainable, - "as_feature_extractor": self.as_feature_extractor, - "feature_keys": self.feature_keys, - } - return config diff --git a/kimm/models/ghostnet.py b/kimm/models/ghostnet.py index 072b23e..4d28d9a 100644 --- a/kimm/models/ghostnet.py +++ b/kimm/models/ghostnet.py @@ -1,15 +1,13 @@ import typing import keras -from keras import backend from keras import layers from keras import ops from keras import utils -from keras.src.applications import imagenet_utils from kimm.blocks import apply_conv2d_block from kimm.blocks import apply_se_block -from kimm.models.feature_extractor import FeatureExtractor +from kimm.models import BaseModel from kimm.utils import add_model_to_registry from kimm.utils import make_divisible @@ -231,19 +229,10 @@ def apply_ghost_bottleneck( return out -class GhostNet(FeatureExtractor): +class GhostNet(BaseModel): def __init__( self, width: float = 1.0, - input_tensor: keras.KerasTensor = None, - input_shape: typing.Optional[typing.Sequence[int]] = None, - include_preprocessing: bool = True, - include_top: bool = True, - pooling: typing.Optional[str] = None, - dropout_rate: float = 0.2, - classes: int = 1000, - classifier_activation: str = "softmax", - weights: typing.Optional[str] = None, # TODO: imagenet config: typing.Union[str, typing.List] = "default", version: str = "v1", **kwargs, @@ -262,35 +251,21 @@ def __init__( f"Received version={version}" ) - # Prepare feature extraction - features = {} - - # Determine proper input shape - input_shape = imagenet_utils.obtain_input_shape( - input_shape, - default_size=224, - min_size=32, - data_format=backend.image_data_format(), - require_flatten=include_top, - weights=weights, + parsed_kwargs = self.parse_kwargs(kwargs) + img_input = self.determine_input_tensor( + parsed_kwargs["input_tensor"], + parsed_kwargs["input_shape"], + parsed_kwargs["default_size"], + require_flatten=parsed_kwargs["include_top"], + static_shape=True if version == "v2" else False, ) - - if input_tensor is None: - img_input = layers.Input(shape=input_shape) - else: - if not backend.is_keras_tensor(input_tensor): - img_input = layers.Input(tensor=input_tensor, shape=input_shape) - else: - img_input = input_tensor - x = img_input - # [0, 255] to [0, 1] and apply ImageNet mean and variance - if include_preprocessing: - x = layers.Rescaling(scale=1.0 / 255.0)(x) - x = layers.Normalization( - mean=[0.485, 0.456, 0.406], variance=[0.229, 0.224, 0.225] - )(x) + if parsed_kwargs["include_preprocessing"]: + x = self.build_preprocessing(x) + + # Prepare feature extraction + features = {} # stem stem_channels = make_divisible(16 * width, 4) @@ -333,42 +308,57 @@ def __init__( name=f"blocks_{current_block_idx+1}", ) - if include_top: - x = layers.GlobalAveragePooling2D(name="avg_pool", keepdims=True)(x) - x = layers.Conv2D(1280, 1, 1, use_bias=True, name="conv_head")(x) - x = layers.ReLU(name="conv_head_relu")(x) - x = layers.Flatten()(x) - x = layers.Dropout(rate=dropout_rate, name="conv_head_dropout")(x) - x = layers.Dense( - classes, activation=classifier_activation, name="classifier" - )(x) + # Head + if parsed_kwargs["include_top"]: + x = self.build_top( + x, + parsed_kwargs["classes"], + parsed_kwargs["classifier_activation"], + parsed_kwargs["dropout_rate"], + ) else: - if pooling == "avg": + if parsed_kwargs["pooling"] == "avg": x = layers.GlobalAveragePooling2D(name="avg_pool")(x) - elif pooling == "max": + elif parsed_kwargs["pooling"] == "max": x = 
layers.GlobalMaxPooling2D(name="max_pool")(x) # Ensure that the model takes into account # any potential predecessors of `input_tensor`. - if input_tensor is not None: - inputs = utils.get_source_inputs(input_tensor) + if parsed_kwargs["input_tensor"] is not None: + inputs = utils.get_source_inputs(parsed_kwargs["input_tensor"]) else: inputs = img_input super().__init__(inputs=inputs, outputs=x, features=features, **kwargs) # All references to `self` below this line + self.add_references(parsed_kwargs) self.width = width - self.include_preprocessing = include_preprocessing - self.include_top = include_top - self.pooling = pooling - self.dropout_rate = dropout_rate - self.classes = classes - self.classifier_activation = classifier_activation - self._weights = weights # `self.weights` is been used internally self.config = config self.version = version + def build_preprocessing(self, inputs): + # [0, 255] to [0, 1] and apply ImageNet mean and variance + x = layers.Rescaling(scale=1.0 / 255.0)(inputs) + x = layers.Normalization( + mean=[0.485, 0.456, 0.406], variance=[0.229, 0.224, 0.225] + )(x) + return x + + def build_top(self, inputs, classes, classifier_activation, dropout_rate): + x = layers.GlobalAveragePooling2D(name="avg_pool", keepdims=True)( + inputs + ) + x = layers.Conv2D( + 1280, 1, 1, use_bias=True, activation="relu", name="conv_head" + )(x) + x = layers.Flatten()(x) + x = layers.Dropout(rate=dropout_rate, name="conv_head_dropout")(x) + x = layers.Dense( + classes, activation=classifier_activation, name="classifier" + )(x) + return x + @staticmethod def available_feature_keys(): feature_keys = ["STEM_S2"] @@ -385,14 +375,6 @@ def get_config(self): config.update( { "width": self.width, - "input_shape": self.input_shape[1:], - "include_preprocessing": self.include_preprocessing, - "include_top": self.include_top, - "pooling": self.pooling, - "dropout_rate": self.dropout_rate, - "classes": self.classes, - "classifier_activation": self.classifier_activation, - "weights": self._weights, "config": self.config, "version": self.version, } @@ -400,7 +382,7 @@ def get_config(self): return config def fix_config(self, config): - unused_kwargs = ["width", "version"] + unused_kwargs = ["width", "config", "version"] for k in unused_kwargs: config.pop(k, None) return config @@ -430,17 +412,17 @@ def __init__( kwargs = self.fix_config(kwargs) super().__init__( 0.5, - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, "v1", + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, **kwargs, ) @@ -465,17 +447,17 @@ def __init__( kwargs = self.fix_config(kwargs) super().__init__( 1.0, - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, "v1", + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, **kwargs, ) @@ -500,17 +482,17 @@ def __init__( kwargs = self.fix_config(kwargs) super().__init__( 1.3, - input_tensor, - input_shape, - include_preprocessing, - include_top, - 
pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, "v1", + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, **kwargs, ) @@ -535,17 +517,17 @@ def __init__( kwargs = self.fix_config(kwargs) super().__init__( 1.0, - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, "v2", + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, **kwargs, ) @@ -570,17 +552,17 @@ def __init__( kwargs = self.fix_config(kwargs) super().__init__( 1.3, - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, "v2", + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, **kwargs, ) @@ -605,17 +587,17 @@ def __init__( kwargs = self.fix_config(kwargs) super().__init__( 1.6, - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, "v2", + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, **kwargs, ) diff --git a/kimm/models/inception_v3.py b/kimm/models/inception_v3.py index e30fd02..d12a001 100644 --- a/kimm/models/inception_v3.py +++ b/kimm/models/inception_v3.py @@ -2,13 +2,11 @@ import typing import keras -from keras import backend from keras import layers from keras import utils -from keras.src.applications import imagenet_utils from kimm.blocks import apply_conv2d_block -from kimm.models.feature_extractor import FeatureExtractor +from kimm.models import BaseModel from kimm.utils import add_model_to_registry _apply_conv2d_block = functools.partial( @@ -204,53 +202,22 @@ def apply_inception_aux_block(inputs, classes, name="inception_aux_block"): return x -class InceptionV3Base(FeatureExtractor): - def __init__( - self, - has_aux_logits=False, - input_tensor: keras.KerasTensor = None, - input_shape: typing.Optional[typing.Sequence[int]] = None, - include_preprocessing: bool = True, - include_top: bool = True, - pooling: typing.Optional[str] = None, - dropout_rate: float = 0.0, - classes: int = 1000, - classifier_activation: str = "softmax", - weights: typing.Optional[str] = None, # TODO: imagenet - **kwargs, - ): - # default_size - default_size = kwargs.pop("default_size", 299) - - # Prepare feature extraction - features = {} - - # Determine proper input shape - input_shape = imagenet_utils.obtain_input_shape( - input_shape, - default_size=default_size, - min_size=32, - data_format=backend.image_data_format(), - require_flatten=include_top, - weights=weights, +class InceptionV3Base(BaseModel): + def __init__(self, has_aux_logits=False, 
**kwargs): + parsed_kwargs = self.parse_kwargs(kwargs, default_size=299) + img_input = self.determine_input_tensor( + parsed_kwargs["input_tensor"], + parsed_kwargs["input_shape"], + parsed_kwargs["default_size"], + require_flatten=parsed_kwargs["include_top"], ) - - if input_tensor is None: - img_input = layers.Input(shape=input_shape) - else: - if not backend.is_keras_tensor(input_tensor): - img_input = layers.Input(tensor=input_tensor, shape=input_shape) - else: - img_input = input_tensor - x = img_input - # [0, 255] to [0, 1] and apply ImageNet mean and variance - if include_preprocessing: - x = layers.Rescaling(scale=1.0 / 255.0)(x) - x = layers.Normalization( - mean=[0.485, 0.456, 0.406], variance=[0.229, 0.224, 0.225] - )(x) + if parsed_kwargs["include_preprocessing"]: + x = self.build_preprocessing(x) + + # Prepare feature extraction + features = {} # Stem block x = _apply_conv2d_block(x, 32, 3, 2, name="Conv2d_1a_3x3") @@ -278,7 +245,9 @@ def __init__( features["BLOCK2_S16"] = x if has_aux_logits: - aux_logits = apply_inception_aux_block(x, classes, "AuxLogits") + aux_logits = apply_inception_aux_block( + x, parsed_kwargs["classes"], "AuxLogits" + ) x = apply_inception_d_block(x, "Mixed_7a") x = apply_inception_e_block(x, "Mixed_7b") @@ -286,22 +255,23 @@ def __init__( features["BLOCK3_S32"] = x # Head - if include_top: - x = layers.GlobalAveragePooling2D(name="avg_pool")(x) - x = layers.Dropout(rate=dropout_rate, name="head_dropout")(x) - x = layers.Dense( - classes, activation=classifier_activation, name="classifier" - )(x) + if parsed_kwargs["include_top"]: + x = self.build_top( + x, + parsed_kwargs["classes"], + parsed_kwargs["classifier_activation"], + parsed_kwargs["dropout_rate"], + ) else: - if pooling == "avg": + if parsed_kwargs["pooling"] == "avg": x = layers.GlobalAveragePooling2D(name="avg_pool")(x) - elif pooling == "max": + elif parsed_kwargs["pooling"] == "max": x = layers.GlobalMaxPooling2D(name="max_pool")(x) # Ensure that the model takes into account # any potential predecessors of `input_tensor`. 
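Note how default_size now flows through parse_kwargs(kwargs, default_size=299) instead of a per-constructor kwargs.pop. A quick check of what that buys, as a sketch (the concrete class name InceptionV3 is assumed from the file and registry conventions):

    from kimm.models.inception_v3 import InceptionV3  # name assumed

    model = InceptionV3()
    # With channels_last and no explicit input_shape, obtain_input_shape
    # falls back to the family default of 299.
    assert model.input_shape == (None, 299, 299, 3)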
- if input_tensor is not None: - inputs = utils.get_source_inputs(input_tensor) + if parsed_kwargs["input_tensor"] is not None: + inputs = utils.get_source_inputs(parsed_kwargs["input_tensor"]) else: inputs = img_input @@ -310,14 +280,24 @@ def __init__( super().__init__(inputs=inputs, outputs=x, features=features, **kwargs) # All references to `self` below this line + self.add_references(parsed_kwargs) self.has_aux_logits = has_aux_logits - self.include_preprocessing = include_preprocessing - self.include_top = include_top - self.pooling = pooling - self.dropout_rate = dropout_rate - self.classes = classes - self.classifier_activation = classifier_activation - self._weights = weights # `self.weights` is been used internally + + def build_preprocessing(self, inputs): + # [0, 255] to [0, 1] and apply ImageNet mean and variance + x = layers.Rescaling(scale=1.0 / 255.0)(inputs) + x = layers.Normalization( + mean=[0.485, 0.456, 0.406], variance=[0.229, 0.224, 0.225] + )(x) + return x + + def build_top(self, inputs, classes, classifier_activation, dropout_rate): + x = layers.GlobalAveragePooling2D(name="avg_pool")(inputs) + x = layers.Dropout(rate=dropout_rate, name="head_dropout")(x) + x = layers.Dense( + classes, activation=classifier_activation, name="classifier" + )(x) + return x @staticmethod def available_feature_keys(): @@ -329,19 +309,7 @@ def available_feature_keys(): def get_config(self): config = super().get_config() - config.update( - { - "has_aux_logits": self.has_aux_logits, - "input_shape": self.input_shape[1:], - "include_preprocessing": self.include_preprocessing, - "include_top": self.include_top, - "pooling": self.pooling, - "dropout_rate": self.dropout_rate, - "classes": self.classes, - "classifier_activation": self.classifier_activation, - "weights": self._weights, - } - ) + config.update({"has_aux_logits": self.has_aux_logits}) return config def fix_config(self, config: typing.Dict): @@ -367,15 +335,15 @@ def __init__( kwargs = self.fix_config(kwargs) super().__init__( has_aux_logits, - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, **kwargs, ) diff --git a/kimm/models/mobilenet_v2.py b/kimm/models/mobilenet_v2.py index 3da5f1c..2bf5c0f 100644 --- a/kimm/models/mobilenet_v2.py +++ b/kimm/models/mobilenet_v2.py @@ -2,15 +2,13 @@ import typing import keras -from keras import backend from keras import layers from keras import utils -from keras.src.applications import imagenet_utils from kimm.blocks import apply_conv2d_block from kimm.blocks import apply_depthwise_separation_block from kimm.blocks import apply_inverted_residual_block -from kimm.models.feature_extractor import FeatureExtractor +from kimm.models.base_model import BaseModel from kimm.utils import add_model_to_registry from kimm.utils import make_divisible @@ -26,21 +24,12 @@ ] -class MobileNetV2(FeatureExtractor): +class MobileNetV2(BaseModel): def __init__( self, width: float = 1.0, depth: float = 1.0, fix_stem_and_head_channels: bool = False, - input_tensor: keras.KerasTensor = None, - input_shape: typing.Optional[typing.Sequence[int]] = None, - include_preprocessing: bool = True, - include_top: bool = True, - pooling: typing.Optional[str] = None, 
- dropout_rate: float = 0.0, - classes: int = 1000, - classifier_activation: str = "softmax", - weights: typing.Optional[str] = None, # TODO: imagenet config: typing.Union[str, typing.List] = "default", **kwargs, ): @@ -53,35 +42,19 @@ def __init__( f"Received: config={config}" ) - # Prepare feature extraction - features = {} - - # Determine proper input shape - input_shape = imagenet_utils.obtain_input_shape( - input_shape, - default_size=224, - min_size=32, - data_format=backend.image_data_format(), - require_flatten=include_top, - weights=weights, + parsed_kwargs = self.parse_kwargs(kwargs) + img_input = self.determine_input_tensor( + parsed_kwargs["input_tensor"], + parsed_kwargs["input_shape"], + parsed_kwargs["default_size"], ) - - if input_tensor is None: - img_input = layers.Input(shape=input_shape) - else: - if not backend.is_keras_tensor(input_tensor): - img_input = layers.Input(tensor=input_tensor, shape=input_shape) - else: - img_input = input_tensor - x = img_input - # [0, 255] to [0, 1] and apply ImageNet mean and variance - if include_preprocessing: - x = layers.Rescaling(scale=1.0 / 255.0)(x) - x = layers.Normalization( - mean=[0.485, 0.456, 0.406], variance=[0.229, 0.224, 0.225] - )(x) + if parsed_kwargs["include_preprocessing"]: + x = self.build_preprocessing(x) + + # Prepare feature extraction + features = {} # stem stem_channel = ( @@ -134,40 +107,52 @@ def __init__( x, head_channels, 1, 1, activation="relu6", name="conv_head" ) - if include_top: - x = layers.GlobalAveragePooling2D(name="avg_pool")(x) - x = layers.Dropout(rate=dropout_rate, name="conv_head_dropout")(x) - x = layers.Dense( - classes, activation=classifier_activation, name="classifier" - )(x) + # Head + if parsed_kwargs["include_top"]: + x = self.build_top( + x, + parsed_kwargs["classes"], + parsed_kwargs["classifier_activation"], + parsed_kwargs["dropout_rate"], + ) else: - if pooling == "avg": + if parsed_kwargs["pooling"] == "avg": x = layers.GlobalAveragePooling2D(name="avg_pool")(x) - elif pooling == "max": + elif parsed_kwargs["pooling"] == "max": x = layers.GlobalMaxPooling2D(name="max_pool")(x) # Ensure that the model takes into account # any potential predecessors of `input_tensor`. 
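The fix_config hunk just below adds "config" to the pop list, and the same fix recurs for GhostNet, MobileNetV3, and MobileViT in this patch. The reason shows up when a model is round-tripped through Keras serialization, as in this sketch (the variant name MobileNet100V2 is assumed from the registry naming):

    from kimm.models.mobilenet_v2 import MobileNet100V2  # name assumed

    model = MobileNet100V2()
    config = model.get_config()  # carries "width", "depth", ..., "config"
    # from_config() calls cls(**config); the variant forwards **kwargs to
    # MobileNetV2.__init__ while also passing `config` (and width, depth,
    # fix_stem_and_head_channels) positionally, so any hard-coded key left
    # in the dict arrives twice. fix_config() popping them prevents the
    # "got multiple values for argument" TypeError on reload.
    restored = MobileNet100V2.from_config(config)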
- if input_tensor is not None: - inputs = utils.get_source_inputs(input_tensor) + if parsed_kwargs["input_tensor"] is not None: + inputs = utils.get_source_inputs(parsed_kwargs["input_tensor"]) else: inputs = img_input super().__init__(inputs=inputs, outputs=x, features=features, **kwargs) # All references to `self` below this line + self.add_references(parsed_kwargs) self.width = width self.depth = depth self.fix_stem_and_head_channels = fix_stem_and_head_channels - self.include_preprocessing = include_preprocessing - self.include_top = include_top - self.pooling = pooling - self.dropout_rate = dropout_rate - self.classes = classes - self.classifier_activation = classifier_activation - self._weights = weights # `self.weights` is been used internally self.config = config + def build_preprocessing(self, inputs): + # [0, 255] to [0, 1] and apply ImageNet mean and variance + x = layers.Rescaling(scale=1.0 / 255.0)(inputs) + x = layers.Normalization( + mean=[0.485, 0.456, 0.406], variance=[0.229, 0.224, 0.225] + )(x) + return x + + def build_top(self, inputs, classes, classifier_activation, dropout_rate): + x = layers.GlobalAveragePooling2D(name="avg_pool")(inputs) + x = layers.Dropout(rate=dropout_rate, name="head_dropout")(x) + x = layers.Dense( + classes, activation=classifier_activation, name="classifier" + )(x) + return x + @staticmethod def available_feature_keys(): feature_keys = ["STEM_S2"] @@ -186,21 +171,18 @@ def get_config(self): "width": self.width, "depth": self.depth, "fix_stem_and_head_channels": self.fix_stem_and_head_channels, - "input_shape": self.input_shape[1:], - "include_preprocessing": self.include_preprocessing, - "include_top": self.include_top, - "pooling": self.pooling, - "dropout_rate": self.dropout_rate, - "classes": self.classes, - "classifier_activation": self.classifier_activation, - "weights": self._weights, "config": self.config, } ) return config def fix_config(self, config): - unused_kwargs = ["width", "depth", "fix_stem_and_head_channels"] + unused_kwargs = [ + "width", + "depth", + "fix_stem_and_head_channels", + "config", + ] for k in unused_kwargs: config.pop(k, None) return config @@ -232,16 +214,16 @@ def __init__( 0.5, 1.0, False, - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, **kwargs, ) @@ -268,16 +250,16 @@ def __init__( 1.0, 1.0, False, - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, **kwargs, ) @@ -304,16 +286,16 @@ def __init__( 1.1, 1.2, True, - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + 
classifier_activation=classifier_activation, + weights=weights, name=name, **kwargs, ) @@ -340,16 +322,16 @@ def __init__( 1.2, 1.4, True, - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, **kwargs, ) @@ -376,16 +358,16 @@ def __init__( 1.4, 1.0, False, - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, **kwargs, ) diff --git a/kimm/models/mobilenet_v3.py b/kimm/models/mobilenet_v3.py index dfbef04..2e17da3 100644 --- a/kimm/models/mobilenet_v3.py +++ b/kimm/models/mobilenet_v3.py @@ -2,15 +2,13 @@ import typing import keras -from keras import backend from keras import layers from keras import utils -from keras.src.applications import imagenet_utils from kimm.blocks import apply_conv2d_block from kimm.blocks import apply_depthwise_separation_block from kimm.blocks import apply_inverted_residual_block -from kimm.models.feature_extractor import FeatureExtractor +from kimm.models.base_model import BaseModel from kimm.utils import add_model_to_registry from kimm.utils import make_divisible @@ -82,21 +80,12 @@ ] -class MobileNetV3(FeatureExtractor): +class MobileNetV3(BaseModel): def __init__( self, width: float = 1.0, depth: float = 1.0, fix_stem_and_head_channels: bool = False, - input_tensor: keras.KerasTensor = None, - input_shape: typing.Optional[typing.Sequence[int]] = None, - include_preprocessing: bool = True, - include_top: bool = True, - pooling: typing.Optional[str] = None, - dropout_rate: float = 0.0, - classes: int = 1000, - classifier_activation: str = "softmax", - weights: typing.Optional[str] = None, # TODO: imagenet config: typing.Union[str, typing.List] = "large", minimal: bool = False, **kwargs, @@ -128,35 +117,19 @@ def __init__( bn_epsilon = kwargs.pop("bn_epsilon", 1e-5) padding = kwargs.pop("padding", None) - # Prepare feature extraction - features = {} - - # Determine proper input shape - input_shape = imagenet_utils.obtain_input_shape( - input_shape, - default_size=224, - min_size=32, - data_format=backend.image_data_format(), - require_flatten=include_top, - weights=weights, + parsed_kwargs = self.parse_kwargs(kwargs) + img_input = self.determine_input_tensor( + parsed_kwargs["input_tensor"], + parsed_kwargs["input_shape"], + parsed_kwargs["default_size"], ) - - if input_tensor is None: - img_input = layers.Input(shape=input_shape) - else: - if not backend.is_keras_tensor(input_tensor): - img_input = layers.Input(tensor=input_tensor, shape=input_shape) - else: - img_input = input_tensor - x = img_input - # [0, 255] to [0, 1] and apply ImageNet mean and variance - if include_preprocessing: - x = layers.Rescaling(scale=1.0 / 255.0)(x) - x = layers.Normalization( - mean=[0.485, 0.456, 0.406], variance=[0.229, 0.224, 0.225] - )(x) + if parsed_kwargs["include_preprocessing"]: + x = self.build_preprocessing(x) + + # Prepare feature 
extraction + features = {} # stem stem_channel = ( @@ -247,8 +220,8 @@ def __init__( current_stride *= s features[f"BLOCK{current_stage_idx}_S{current_stride}"] = x - if include_top: - x = layers.GlobalAveragePooling2D(name="avg_pool", keepdims=True)(x) + # Head + if parsed_kwargs["include_top"]: if fix_stem_and_head_channels: conv_head_channels = conv_head_channels else: @@ -256,46 +229,69 @@ def __init__( conv_head_channels, make_divisible(conv_head_channels * width), ) - x = layers.Conv2D( - conv_head_channels, 1, 1, use_bias=True, name="conv_head" - )(x) - x = layers.Activation( - force_activation or "hard_swish", name="act2" - )(x) - x = layers.Flatten()(x) - x = layers.Dropout(rate=dropout_rate, name="conv_head_dropout")(x) - x = layers.Dense( - classes, activation=classifier_activation, name="classifier" - )(x) + head_activation = force_activation or "hard_swish" + x = self.build_top( + x, + parsed_kwargs["classes"], + parsed_kwargs["classifier_activation"], + parsed_kwargs["dropout_rate"], + conv_head_channels=conv_head_channels, + head_activation=head_activation, + ) else: - if pooling == "avg": + if parsed_kwargs["pooling"] == "avg": x = layers.GlobalAveragePooling2D(name="avg_pool")(x) - elif pooling == "max": + elif parsed_kwargs["pooling"] == "max": x = layers.GlobalMaxPooling2D(name="max_pool")(x) # Ensure that the model takes into account # any potential predecessors of `input_tensor`. - if input_tensor is not None: - inputs = utils.get_source_inputs(input_tensor) + if parsed_kwargs["input_tensor"] is not None: + inputs = utils.get_source_inputs(parsed_kwargs["input_tensor"]) else: inputs = img_input super().__init__(inputs=inputs, outputs=x, features=features, **kwargs) # All references to `self` below this line + self.add_references(parsed_kwargs) self.width = width self.depth = depth self.fix_stem_and_head_channels = fix_stem_and_head_channels - self.include_preprocessing = include_preprocessing - self.include_top = include_top - self.pooling = pooling - self.dropout_rate = dropout_rate - self.classes = classes - self.classifier_activation = classifier_activation - self._weights = weights # `self.weights` is been used internally self.config = config self.minimal = minimal + def build_preprocessing(self, inputs): + # [0, 255] to [0, 1] and apply ImageNet mean and variance + x = layers.Rescaling(scale=1.0 / 255.0)(inputs) + x = layers.Normalization( + mean=[0.485, 0.456, 0.406], variance=[0.229, 0.224, 0.225] + )(x) + return x + + def build_top( + self, + inputs, + classes, + classifier_activation, + dropout_rate, + conv_head_channels, + head_activation, + ): + x = layers.GlobalAveragePooling2D(name="avg_pool", keepdims=True)( + inputs + ) + x = layers.Conv2D( + conv_head_channels, 1, 1, use_bias=True, name="conv_head" + )(x) + x = layers.Activation(head_activation, name="act2")(x) + x = layers.Flatten()(x) + x = layers.Dropout(rate=dropout_rate, name="conv_head_dropout")(x) + x = layers.Dense( + classes, activation=classifier_activation, name="classifier" + )(x) + return x + @staticmethod def available_feature_keys(): raise NotImplementedError() @@ -307,14 +303,6 @@ def get_config(self): "width": self.width, "depth": self.depth, "fix_stem_and_head_channels": self.fix_stem_and_head_channels, - "input_shape": self.input_shape[1:], - "include_preprocessing": self.include_preprocessing, - "include_top": self.include_top, - "pooling": self.pooling, - "dropout_rate": self.dropout_rate, - "classes": self.classes, - "classifier_activation": self.classifier_activation, - "weights": 
self._weights, "config": self.config, "minimal": self.minimal, } @@ -326,6 +314,7 @@ def fix_config(self, config): "width", "depth", "fix_stem_and_head_channels", + "config", "minimal", ] for k in unused_kwargs: @@ -359,16 +348,16 @@ def __init__( 0.5, 1.0, True, - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, **kwargs, ) @@ -403,16 +392,16 @@ def __init__( 0.75, 1.0, False, - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, **kwargs, ) @@ -447,16 +436,16 @@ def __init__( 1.0, 1.0, False, - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, **kwargs, ) @@ -492,17 +481,17 @@ def __init__( 1.0, 1.0, False, - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, - minimal=True, + True, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, bn_epsilon=1e-3, padding="same", @@ -539,16 +528,16 @@ def __init__( 1.0, 1.0, False, - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, **kwargs, ) @@ -587,17 +576,17 @@ def __init__( 1.0, 1.0, False, - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, - minimal=True, + True, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, bn_epsilon=1e-3, padding="same", @@ -638,16 +627,16 @@ def __init__( 0.35, 1.0, False, - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, + input_tensor=input_tensor, + input_shape=input_shape, + 
include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, **kwargs, ) @@ -683,16 +672,16 @@ def __init__( 0.5, 1.0, False, - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, **kwargs, ) @@ -728,16 +717,16 @@ def __init__( 0.75, 1.0, False, - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, **kwargs, ) @@ -773,16 +762,16 @@ def __init__( 1.0, 1.0, False, - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, **kwargs, ) @@ -818,16 +807,16 @@ def __init__( 1.5, 1.0, False, - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, **kwargs, ) diff --git a/kimm/models/mobilevit.py b/kimm/models/mobilevit.py index 813a019..d2c601f 100644 --- a/kimm/models/mobilevit.py +++ b/kimm/models/mobilevit.py @@ -2,16 +2,14 @@ import typing import keras -from keras import backend from keras import layers from keras import ops from keras import utils -from keras.src.applications import imagenet_utils from kimm.blocks import apply_conv2d_block from kimm.blocks import apply_inverted_residual_block from kimm.blocks import apply_transformer_block -from kimm.models.feature_extractor import FeatureExtractor +from kimm.models.base_model import BaseModel from kimm.utils import add_model_to_registry from kimm.utils import make_divisible @@ -163,21 +161,12 @@ def apply_mobilevit_block( return x -class MobileViT(FeatureExtractor): +class MobileViT(BaseModel): def __init__( self, stem_channels: int = 16, head_channels: int = 640, activation="swish", - input_tensor: keras.KerasTensor = None, - input_shape: typing.Optional[typing.Sequence[int]] = None, - include_preprocessing: bool = True, - include_top: bool = True, - pooling: typing.Optional[str] = None, - dropout_rate: float = 0.1, - classes: int = 1000, - classifier_activation: str = "softmax", - weights: typing.Optional[str] = None, # TODO: imagenet config: str = "v1_s", **kwargs, ): @@ -194,35 +183,20 @@ def __init__( f"Received: 
config={config}" ) - # Prepare feature extraction - features = {} - - # Determine proper input shape - input_shape = imagenet_utils.obtain_input_shape( - input_shape, - default_size=256, - min_size=32, - data_format=backend.image_data_format(), - require_flatten=include_top, - weights=weights, + parsed_kwargs = self.parse_kwargs(kwargs, 256) + img_input = self.determine_input_tensor( + parsed_kwargs["input_tensor"], + parsed_kwargs["input_shape"], + parsed_kwargs["default_size"], + static_shape=True, ) - - if input_tensor is None: - img_input = layers.Input(shape=input_shape) - else: - if not backend.is_keras_tensor(input_tensor): - img_input = layers.Input(tensor=input_tensor, shape=input_shape) - else: - img_input = input_tensor - x = img_input - # [0, 255] to [0, 1] and apply ImageNet mean and variance - if include_preprocessing: - x = layers.Rescaling(scale=1.0 / 255.0)(x) - x = layers.Normalization( - mean=[0.485, 0.456, 0.406], variance=[0.229, 0.224, 0.225] - )(x) + if parsed_kwargs["include_preprocessing"]: + x = self.build_preprocessing(x) + + # Prepare feature extraction + features = {} # stem x = apply_conv2d_block( @@ -281,39 +255,52 @@ def __init__( x, head_channels, 1, 1, activation=activation, name="final_conv" ) - if include_top: - x = layers.GlobalAveragePooling2D(name="avg_pool")(x) - x = layers.Dropout(dropout_rate, name="head_drop")(x) - x = layers.Dense( - classes, activation=classifier_activation, name="head_fc" - )(x) + # Head + if parsed_kwargs["include_top"]: + x = self.build_top( + x, + parsed_kwargs["classes"], + parsed_kwargs["classifier_activation"], + parsed_kwargs["dropout_rate"], + ) else: - if pooling == "avg": + if parsed_kwargs["pooling"] == "avg": x = layers.GlobalAveragePooling2D(name="avg_pool")(x) - elif pooling == "max": + elif parsed_kwargs["pooling"] == "max": x = layers.GlobalMaxPooling2D(name="max_pool")(x) # Ensure that the model takes into account # any potential predecessors of `input_tensor`. 
- if input_tensor is not None: - inputs = utils.get_source_inputs(input_tensor) + if parsed_kwargs["input_tensor"] is not None: + inputs = utils.get_source_inputs(parsed_kwargs["input_tensor"]) else: inputs = img_input super().__init__(inputs=inputs, outputs=x, features=features, **kwargs) + # All references to `self` below this line + self.add_references(parsed_kwargs) self.stem_channels = stem_channels self.head_channels = head_channels self.activation = activation - self.include_preprocessing = include_preprocessing - self.include_top = include_top - self.pooling = pooling - self.dropout_rate = dropout_rate - self.classes = classes - self.classifier_activation = classifier_activation - self._weights = weights # `self.weights` is been used internally self.config = config + def build_preprocessing(self, inputs): + # [0, 255] to [0, 1] and apply ImageNet mean and variance + x = layers.Rescaling(scale=1.0 / 255.0)(inputs) + x = layers.Normalization( + mean=[0.485, 0.456, 0.406], variance=[0.229, 0.224, 0.225] + )(x) + return x + + def build_top(self, inputs, classes, classifier_activation, dropout_rate): + x = layers.GlobalAveragePooling2D(name="avg_pool")(inputs) + x = layers.Dropout(dropout_rate, name="head_drop")(x) + x = layers.Dense( + classes, activation=classifier_activation, name="head_fc" + )(x) + return x + @staticmethod def available_feature_keys(): feature_keys = ["STEM_S2"] @@ -329,21 +316,18 @@ def get_config(self): "stem_channels": self.stem_channels, "head_channels": self.head_channels, "activation": self.activation, - "input_shape": self.input_shape[1:], - "include_preprocessing": self.include_preprocessing, - "include_top": self.include_top, - "pooling": self.pooling, - "dropout_rate": self.dropout_rate, - "classes": self.classes, - "classifier_activation": self.classifier_activation, - "weights": self._weights, "config": self.config, } ) return config def fix_config(self, config): - unused_kwargs = ["stem_channels", "head_channels", "activation"] + unused_kwargs = [ + "stem_channels", + "head_channels", + "activation", + "config", + ] for k in unused_kwargs: config.pop(k, None) return config @@ -370,16 +354,16 @@ def __init__( 16, 640, "swish", - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, **kwargs, ) @@ -406,16 +390,16 @@ def __init__( 16, 384, "swish", - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, **kwargs, ) @@ -442,16 +426,16 @@ def __init__( 16, 320, "swish", - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + 
classifier_activation=classifier_activation, + weights=weights, name=name, **kwargs, ) diff --git a/kimm/models/resnet.py b/kimm/models/resnet.py index 8354412..b729da7 100644 --- a/kimm/models/resnet.py +++ b/kimm/models/resnet.py @@ -7,7 +7,7 @@ from keras.src.applications import imagenet_utils from kimm.blocks import apply_conv2d_block -from kimm.models.feature_extractor import FeatureExtractor +from kimm.models.base_model import BaseModel from kimm.utils import add_model_to_registry @@ -106,7 +106,7 @@ def apply_bottleneck_block( return x -class ResNet(FeatureExtractor): +class ResNet(BaseModel): def __init__( self, block_fn: str, diff --git a/kimm/models/vision_transformer.py b/kimm/models/vision_transformer.py index ad40fa9..2abb21c 100644 --- a/kimm/models/vision_transformer.py +++ b/kimm/models/vision_transformer.py @@ -8,11 +8,11 @@ from kimm import layers as kimm_layers from kimm.blocks import apply_transformer_block -from kimm.models.feature_extractor import FeatureExtractor +from kimm.models.base_model import BaseModel from kimm.utils import add_model_to_registry -class VisionTransformer(FeatureExtractor): +class VisionTransformer(BaseModel): def __init__( self, patch_size: int, diff --git a/kimm/utils/model_registry.py b/kimm/utils/model_registry.py index 36ba7d5..15a2517 100644 --- a/kimm/utils/model_registry.py +++ b/kimm/utils/model_registry.py @@ -32,11 +32,11 @@ def clear_registry(): def add_model_to_registry(model_cls, has_pretrained=False): - from kimm.models.feature_extractor import FeatureExtractor + from kimm.models.base_model import BaseModel support_feature = False available_feature_keys = [] - if issubclass(model_cls, FeatureExtractor): + if issubclass(model_cls, BaseModel): support_feature = True available_feature_keys = model_cls.available_feature_keys() for info in MODEL_REGISTRY: diff --git a/kimm/utils/model_registry_test.py b/kimm/utils/model_registry_test.py index 73c5961..f979811 100644 --- a/kimm/utils/model_registry_test.py +++ b/kimm/utils/model_registry_test.py @@ -1,7 +1,7 @@ from keras import models from keras.src import testing -from kimm.models.feature_extractor import FeatureExtractor +from kimm.models.base_model import BaseModel from kimm.utils.model_registry import MODEL_REGISTRY from kimm.utils.model_registry import add_model_to_registry from kimm.utils.model_registry import clear_registry @@ -12,7 +12,7 @@ class DummyModel(models.Model): pass -class DummyFeatureExtractor(FeatureExtractor): +class DummyFeatureExtractor(BaseModel): @staticmethod def available_feature_keys(): return ["A", "B", "C"] diff --git a/shell/export.sh b/shell/export.sh new file mode 100755 index 0000000..27c209a --- /dev/null +++ b/shell/export.sh @@ -0,0 +1,11 @@ +#!/bin/bash +export CUDA_VISIBLE_DEVICES= +export TF_CPP_MIN_LOG_LEVEL=3 +python3 -m tools.convert_densenet_from_timm && +python3 -m tools.convert_efficientnet_from_timm && +python3 -m tools.convert_ghostnet_from_timm && +python3 -m tools.convert_inception_v3_from_timm && +python3 -m tools.convert_mobilenet_v2_from_timm && +python3 -m tools.convert_mobilenet_v3_from_timm && +python3 -m tools.convert_mobilevit_from_timm && +echo "Export finished successfully!" 
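A usage sketch for the new shell/export.sh above, assuming it is invoked from the repository root so that the `tools.*` modules resolve and that the torch/timm CPU dependencies named in the converters' docstrings are installed:

    bash shell/export.sh
    # equivalently, since the file is created with mode 100755:
    ./shell/export.sh

`CUDA_VISIBLE_DEVICES=` forces CPU execution, `TF_CPP_MIN_LOG_LEVEL=3` silences TensorFlow logging, and the `&&` chain aborts at the first converter that fails, so the success message only prints when every export succeeded.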
diff --git a/tools/convert_efficientnet_from_timm.py b/tools/convert_efficientnet_from_timm.py index 21b64eb..3af0f24 100644 --- a/tools/convert_efficientnet_from_timm.py +++ b/tools/convert_efficientnet_from_timm.py @@ -191,6 +191,6 @@ Save converted model """ os.makedirs("exported", exist_ok=True) - export_path = f"exported/{keras_model.name.lower()}_imagenet.keras" + export_path = f"exported/{keras_model.name.lower()}_{timm_model_name}.keras" keras_model.save(export_path) print(f"Export to {export_path}") diff --git a/tools/convert_ghostnet_from_timm.py b/tools/convert_ghostnet_from_timm.py index 320e287..7ffd688 100644 --- a/tools/convert_ghostnet_from_timm.py +++ b/tools/convert_ghostnet_from_timm.py @@ -149,6 +149,6 @@ Save converted model """ os.makedirs("exported", exist_ok=True) - export_path = f"exported/{keras_model.name.lower()}_imagenet.keras" + export_path = f"exported/{keras_model.name.lower()}_{timm_model_name}.keras" keras_model.save(export_path) print(f"Export to {export_path}") diff --git a/tools/convert_mobilenet_v2_from_timm.py b/tools/convert_mobilenet_v2_from_timm.py index 0b01ce1..ec55782 100644 --- a/tools/convert_mobilenet_v2_from_timm.py +++ b/tools/convert_mobilenet_v2_from_timm.py @@ -135,6 +135,6 @@ Save converted model """ os.makedirs("exported", exist_ok=True) - export_path = f"exported/{keras_model.name.lower()}_imagenet.keras" + export_path = f"exported/{keras_model.name.lower()}_{timm_model_name}.keras" keras_model.save(export_path) print(f"Export to {export_path}") diff --git a/tools/convert_mobilenet_v3_from_timm.py b/tools/convert_mobilenet_v3_from_timm.py index 0c3e677..0ba712c 100644 --- a/tools/convert_mobilenet_v3_from_timm.py +++ b/tools/convert_mobilenet_v3_from_timm.py @@ -163,6 +163,6 @@ Save converted model """ os.makedirs("exported", exist_ok=True) - export_path = f"exported/{keras_model.name.lower()}_imagenet.keras" + export_path = f"exported/{keras_model.name.lower()}_{timm_model_name}.keras" keras_model.save(export_path) print(f"Export to {export_path}") diff --git a/tools/convert_mobilevit_from_timm.py b/tools/convert_mobilevit_from_timm.py index 4c53d5c..42aac54 100644 --- a/tools/convert_mobilevit_from_timm.py +++ b/tools/convert_mobilevit_from_timm.py @@ -139,6 +139,6 @@ Save converted model """ os.makedirs("exported", exist_ok=True) - export_path = f"exported/{keras_model.name.lower()}_imagenet.keras" + export_path = f"exported/{keras_model.name.lower()}_{timm_model_name}.keras" keras_model.save(export_path) print(f"Export to {export_path}") diff --git a/tools/convert_resnet_from_timm.py b/tools/convert_resnet_from_timm.py index 3a8823f..ad91cad 100644 --- a/tools/convert_resnet_from_timm.py +++ b/tools/convert_resnet_from_timm.py @@ -126,6 +126,6 @@ Save converted model """ os.makedirs("exported", exist_ok=True) - export_path = f"exported/{keras_model.name.lower()}_imagenet.keras" + export_path = f"exported/{keras_model.name.lower()}_{timm_model_name}.keras" keras_model.save(export_path) print(f"Export to {export_path}") diff --git a/tools/convert_vit_from_timm.py b/tools/convert_vit_from_timm.py index 6da3175..f0fd93d 100644 --- a/tools/convert_vit_from_timm.py +++ b/tools/convert_vit_from_timm.py @@ -136,6 +136,6 @@ Save converted model """ os.makedirs("exported", exist_ok=True) - export_path = f"exported/{keras_model.name.lower()}_imagenet_384.keras" + export_path = f"exported/{keras_model.name.lower()}_{timm_model_name}.keras" keras_model.save(export_path) print(f"Export to {export_path}") From 
3c709b2abb43320c8c6cdd06fbeb279a630a09c0 Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Wed, 17 Jan 2024 09:01:17 +0800 Subject: [PATCH 06/10] Refactor `BaseModel` --- kimm/models/resnet.py | 201 ++++++++++-------------- kimm/models/vision_transformer.py | 250 ++++++++++++++---------- 2 files changed, 202 insertions(+), 249 deletions(-) diff --git a/kimm/models/resnet.py b/kimm/models/resnet.py index b729da7..2d07735 100644 --- a/kimm/models/resnet.py +++ b/kimm/models/resnet.py @@ -1,10 +1,8 @@ import typing import keras -from keras import backend from keras import layers from keras import utils -from keras.src.applications import imagenet_utils from kimm.blocks import apply_conv2d_block from kimm.models.base_model import BaseModel @@ -108,54 +106,27 @@ class ResNet(BaseModel): def __init__( - self, - block_fn: str, - num_blocks: typing.Sequence[int], - input_tensor: keras.KerasTensor = None, - input_shape: typing.Optional[typing.Sequence[int]] = None, - include_preprocessing: bool = True, - include_top: bool = True, - pooling: typing.Optional[str] = None, - dropout_rate: float = 0.0, - classes: int = 1000, - classifier_activation: str = "softmax", - weights: typing.Optional[str] = None, # TODO: imagenet - **kwargs, + self, block_fn: str, num_blocks: typing.Sequence[int], **kwargs ): if block_fn not in ("basic", "bottleneck"): raise ValueError( "`block_fn` must be one of ('basic', 'bottleneck'). " f"Received: block_fn={block_fn}" ) - # Prepare feature extraction - features = {} - # Determine proper input shape - input_shape = imagenet_utils.obtain_input_shape( - input_shape, - default_size=224, - min_size=32, - data_format=backend.image_data_format(), - require_flatten=include_top, - weights=weights, + parsed_kwargs = self.parse_kwargs(kwargs) + img_input = self.determine_input_tensor( + parsed_kwargs["input_tensor"], + parsed_kwargs["input_shape"], + parsed_kwargs["default_size"], ) - - if input_tensor is None: - img_input = layers.Input(shape=input_shape) - else: - if not backend.is_keras_tensor(input_tensor): - img_input = layers.Input(tensor=input_tensor, shape=input_shape) - else: - img_input = input_tensor x = img_input - # [0, 255] to [0, 1] and apply ImageNet mean and variance - if include_preprocessing: - x = layers.Rescaling(scale=1.0 / 255.0)(x) - x = layers.Normalization( - mean=[0.485, 0.456, 0.406], variance=[0.229, 0.224, 0.225] - )(x) + if parsed_kwargs["include_preprocessing"]: + x = self.build_preprocessing(x) + + # Prepare feature extraction + features = {} # stem stem_channels = 64 @@ -189,38 +160,49 @@ def __init__( # add feature features[f"BLOCK{current_stage_idx}_S{current_stride}"] = x - if include_top: - x = layers.GlobalAveragePooling2D(name="avg_pool", keepdims=True)(x) - x = layers.Flatten()(x) - x = layers.Dropout(rate=dropout_rate, name="head_dropout")(x) - x = layers.Dense( - classes, activation=classifier_activation, name="fc" - )(x) + # Head + if parsed_kwargs["include_top"]: + x = self.build_top( + x, + parsed_kwargs["classes"], + parsed_kwargs["classifier_activation"], + parsed_kwargs["dropout_rate"], + ) else: - if pooling == "avg": + if parsed_kwargs["pooling"] == "avg": x = layers.GlobalAveragePooling2D(name="avg_pool")(x) - elif pooling == "max": + elif parsed_kwargs["pooling"] == "max": x = layers.GlobalMaxPooling2D(name="max_pool")(x) # Ensure that the model takes into account # any potential predecessors of `input_tensor`. 
- if input_tensor is not None: - inputs = utils.get_source_inputs(input_tensor) + if parsed_kwargs["input_tensor"] is not None: + inputs = utils.get_source_inputs(parsed_kwargs["input_tensor"]) else: inputs = img_input super().__init__(inputs=inputs, outputs=x, features=features, **kwargs) # All references to `self` below this line + self.add_references(parsed_kwargs) self.block_fn = block_fn self.num_blocks = num_blocks - self.include_preprocessing = include_preprocessing - self.include_top = include_top - self.pooling = pooling - self.dropout_rate = dropout_rate - self.classes = classes - self.classifier_activation = classifier_activation - self._weights = weights # `self.weights` is been used internally + + def build_preprocessing(self, inputs): + # [0, 255] to [0, 1] and apply ImageNet mean and variance + x = layers.Rescaling(scale=1.0 / 255.0)(inputs) + x = layers.Normalization( + mean=[0.485, 0.456, 0.406], variance=[0.229, 0.224, 0.225] + )(x) + return x + + def build_top(self, inputs, classes, classifier_activation, dropout_rate): + x = layers.GlobalAveragePooling2D(name="avg_pool")(inputs) + x = layers.Dropout(rate=dropout_rate, name="head_dropout")(x) + x = layers.Dense(classes, activation=classifier_activation, name="fc")( + x + ) + return x @staticmethod def available_feature_keys(): @@ -233,18 +215,7 @@ def available_feature_keys(): def get_config(self): config = super().get_config() config.update( - { - "block_fn": self.block_fn, - "num_blocks": self.num_blocks, - "input_shape": self.input_shape[1:], - "include_preprocessing": self.include_preprocessing, - "include_top": self.include_top, - "pooling": self.pooling, - "dropout_rate": self.dropout_rate, - "classes": self.classes, - "classifier_activation": self.classifier_activation, - "weights": self._weights, - } + {"block_fn": self.block_fn, "num_blocks": self.num_blocks} ) return config @@ -279,15 +250,15 @@ def __init__( super().__init__( "basic", [2, 2, 2, 2], - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, **kwargs, ) @@ -312,15 +283,15 @@ def __init__( super().__init__( "basic", [3, 4, 6, 3], - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, **kwargs, ) @@ -345,15 +316,15 @@ def __init__( super().__init__( "bottleneck", [3, 4, 6, 3], - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, **kwargs, ) @@ -378,15 +349,15 @@ def __init__( super().__init__( "bottleneck", [3, 4, 23, 3], - input_tensor, - input_shape, - include_preprocessing, - 
include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, **kwargs, ) @@ -411,15 +382,15 @@ def __init__( super().__init__( "bottleneck", [3, 8, 36, 3], - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, **kwargs, ) diff --git a/kimm/models/vision_transformer.py b/kimm/models/vision_transformer.py index 2abb21c..d2f0b4a 100644 --- a/kimm/models/vision_transformer.py +++ b/kimm/models/vision_transformer.py @@ -1,10 +1,8 @@ import typing import keras -from keras import backend from keras import layers from keras import utils -from keras.src.applications import imagenet_utils from kimm import layers as kimm_layers from kimm.blocks import apply_transformer_block @@ -22,44 +20,28 @@ def __init__( mlp_ratio: float = 4.0, use_qkv_bias: bool = True, use_qk_norm: bool = False, - input_tensor: keras.KerasTensor = None, - input_shape: typing.Optional[typing.Sequence[int]] = None, - include_preprocessing: bool = True, - include_top: bool = True, - pooling: typing.Optional[str] = None, pos_dropout_rate: float = 0.0, - dropout_rate: float = 0.1, - classes: int = 1000, - classifier_activation: str = "softmax", - weights: typing.Optional[str] = None, # TODO: imagenet **kwargs, ): - # Prepare feature extraction - features = {} - - # Determine proper input shape - input_shape = imagenet_utils.obtain_input_shape( - input_shape, - default_size=384, - min_size=32, - data_format=backend.image_data_format(), - require_flatten=include_top, - weights=weights, + parsed_kwargs = self.parse_kwargs(kwargs, 384) + if parsed_kwargs["pooling"] is not None: + raise ValueError( + "`VisionTransformer` doesn't support `pooling`. 
" + f"Received: pooling={parsed_kwargs['pooling']}" + ) + img_input = self.determine_input_tensor( + parsed_kwargs["input_tensor"], + parsed_kwargs["input_shape"], + parsed_kwargs["default_size"], + static_shape=True, ) - - if input_tensor is None: - img_input = layers.Input(shape=input_shape) - else: - if not backend.is_keras_tensor(input_tensor): - img_input = layers.Input(tensor=input_tensor, shape=input_shape) - else: - img_input = input_tensor - x = img_input - # [0, 255] to [-1, 1] - if include_preprocessing: - x = layers.Rescaling(scale=1.0 / 127.5, offset=-1.0)(x) + if parsed_kwargs["include_preprocessing"]: + x = self.build_preprocessing(x) + + # Prepare feature extraction + features = {} # patch embedding x = layers.Conv2D( @@ -89,27 +71,26 @@ def __init__( features[f"BLOCK{i}"] = x x = layers.LayerNormalization(epsilon=1e-6, name="norm")(x) - if include_top: - x = x[:, 0] # class token - x = layers.Dropout(dropout_rate, name="head_drop")(x) - x = layers.Dense( - classes, activation=classifier_activation, name="head" - )(x) - else: - if pooling == "avg": - x = layers.GlobalAveragePooling2D(name="avg_pool")(x) - elif pooling == "max": - x = layers.GlobalMaxPooling2D(name="max_pool")(x) + # Head + if parsed_kwargs["include_top"]: + x = self.build_top( + x, + parsed_kwargs["classes"], + parsed_kwargs["classifier_activation"], + parsed_kwargs["dropout_rate"], + ) # Ensure that the model takes into account # any potential predecessors of `input_tensor`. - if input_tensor is not None: - inputs = utils.get_source_inputs(input_tensor) + if parsed_kwargs["input_tensor"] is not None: + inputs = utils.get_source_inputs(parsed_kwargs["input_tensor"]) else: inputs = img_input super().__init__(inputs=inputs, outputs=x, features=features, **kwargs) + # All references to `self` below this line + self.add_references(parsed_kwargs) self.patch_size = patch_size self.embed_dim = embed_dim self.depth = depth @@ -117,13 +98,20 @@ def __init__( self.mlp_ratio = mlp_ratio self.use_qkv_bias = use_qkv_bias self.use_qk_norm = use_qk_norm - self.include_preprocessing = include_preprocessing - self.include_top = include_top - self.pooling = pooling - self.dropout_rate = dropout_rate - self.classes = classes - self.classifier_activation = classifier_activation - self._weights = weights # `self.weights` is been used internally + self.pos_dropout_rate = pos_dropout_rate + + def build_preprocessing(self, inputs): + # [0, 255] to [-1, 1] + x = layers.Rescaling(scale=1.0 / 127.5, offset=-1.0)(inputs) + return x + + def build_top(self, inputs, classes, classifier_activation, dropout_rate): + x = inputs[:, 0] # class token + x = layers.Dropout(dropout_rate, name="head_drop")(x) + x = layers.Dense( + classes, activation=classifier_activation, name="head" + )(x) + return x @staticmethod def available_feature_keys(): @@ -140,14 +128,7 @@ def get_config(self): "mlp_ratio": self.mlp_ratio, "use_qkv_bias": self.use_qkv_bias, "use_qk_norm": self.use_qk_norm, - "input_shape": self.input_shape[1:], - "include_preprocessing": self.include_preprocessing, - "include_top": self.include_top, - "pooling": self.pooling, - "dropout_rate": self.dropout_rate, - "classes": self.classes, - "classifier_activation": self.classifier_activation, - "weights": self._weights, + "pos_dropout_rate": self.pos_dropout_rate, } ) return config @@ -161,6 +142,7 @@ def fix_config(self, config): "mlp_ratio", "use_qkv_bias", "use_qk_norm", + "pos_dropout_rate", ] for k in unused_kwargs: config.pop(k, None) @@ -200,16 +182,16 @@ def __init__( mlp_ratio, 
use_qkv_bias, use_qk_norm, - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, pos_dropout_rate, - dropout_rate, - classes, - classifier_activation, - weights, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, **kwargs, ) @@ -249,16 +231,16 @@ def __init__( mlp_ratio, use_qkv_bias, use_qk_norm, - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, pos_dropout_rate, - dropout_rate, - classes, - classifier_activation, - weights, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, **kwargs, ) @@ -298,16 +280,16 @@ def __init__( mlp_ratio, use_qkv_bias, use_qk_norm, - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, pos_dropout_rate, - dropout_rate, - classes, - classifier_activation, - weights, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, **kwargs, ) @@ -347,16 +329,16 @@ def __init__( mlp_ratio, use_qkv_bias, use_qk_norm, - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, pos_dropout_rate, - dropout_rate, - classes, - classifier_activation, - weights, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, **kwargs, ) @@ -396,16 +378,16 @@ def __init__( mlp_ratio, use_qkv_bias, use_qk_norm, - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, pos_dropout_rate, - dropout_rate, - classes, - classifier_activation, - weights, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, **kwargs, ) @@ -445,16 +427,16 @@ def __init__( mlp_ratio, use_qkv_bias, use_qk_norm, - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, pos_dropout_rate, - dropout_rate, - classes, - classifier_activation, - weights, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, **kwargs, ) @@ -494,16 +476,16 @@ def __init__( mlp_ratio, use_qkv_bias, use_qk_norm, - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, pos_dropout_rate, - dropout_rate, - classes, - classifier_activation, - weights, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + 
classifier_activation=classifier_activation, + weights=weights, name=name, **kwargs, ) @@ -543,16 +525,16 @@ def __init__( mlp_ratio, use_qkv_bias, use_qk_norm, - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, pos_dropout_rate, - dropout_rate, - classes, - classifier_activation, - weights, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, **kwargs, ) From 576c712da9311389855e4e6de14da324a857a462 Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Wed, 17 Jan 2024 09:24:51 +0800 Subject: [PATCH 07/10] Simplify `build_preprocessing` and `build_top` --- kimm/models/base_model.py | 32 +++++++++++++++++++++++----- kimm/models/densenet.py | 18 +--------------- kimm/models/efficientnet.py | 18 +--------------- kimm/models/ghostnet.py | 10 +-------- kimm/models/inception_v3.py | 18 +--------------- kimm/models/mobilenet_v2.py | 18 +--------------- kimm/models/mobilenet_v3.py | 10 +-------- kimm/models/mobilevit.py | 18 +--------------- kimm/models/resnet.py | 18 +--------------- kimm/models/vision_transformer.py | 7 +----- tools/convert_mobilevit_from_timm.py | 2 ++ tools/convert_resnet_from_timm.py | 2 ++ 12 files changed, 40 insertions(+), 131 deletions(-) diff --git a/kimm/models/base_model.py b/kimm/models/base_model.py index ee1f7e2..45204fd 100644 --- a/kimm/models/base_model.py +++ b/kimm/models/base_model.py @@ -87,13 +87,33 @@ def determine_input_tensor( x = input_tensor return x - def build_preprocessing(self, inputs): - # TODO: add docstring - raise NotImplementedError + def build_preprocessing(self, inputs, mode="imagenet"): + if mode == "imagenet": + # [0, 255] to [0, 1] and apply ImageNet mean and variance + x = layers.Rescaling(scale=1.0 / 255.0)(inputs) + x = layers.Normalization( + mean=[0.485, 0.456, 0.406], variance=[0.229, 0.224, 0.225] + )(x) + elif mode == "0_1": + # [0, 255] to [0, 1] + x = layers.Rescaling(scale=1.0 / 255.0)(inputs) + elif mode == "-1_1": + # [0, 255] to [-1, 1] + x = layers.Rescaling(scale=1.0 / 127.5, offset=-1.0)(inputs) + else: + raise ValueError( + "`mode` must be one of ('imagenet', '0_1', '-1_1'). " + f"Received: mode={mode}" + ) + return x def build_top(self, inputs, classes, classifier_activation, dropout_rate): - # TODO: add docstring - raise NotImplementedError + x = layers.GlobalAveragePooling2D(name="avg_pool")(inputs) + x = layers.Dropout(rate=dropout_rate, name="head_dropout")(x) + x = layers.Dense( + classes, activation=classifier_activation, name="classifier" + )(x) + return x def add_references(self, parsed_kwargs: typing.Dict[str, typing.Any]): self.include_preprocessing = parsed_kwargs["include_preprocessing"] @@ -115,8 +135,10 @@ def get_config(self): # Don't chain to super here. The default `get_config()` for functional # models is nested and cannot be passed to BaseModel. 
config = { + # models.Model "name": self.name, "trainable": self.trainable, + # feature extractor "as_feature_extractor": self.as_feature_extractor, "feature_keys": self.feature_keys, # common diff --git a/kimm/models/densenet.py b/kimm/models/densenet.py index 60668b5..479a110 100644 --- a/kimm/models/densenet.py +++ b/kimm/models/densenet.py @@ -81,7 +81,7 @@ def __init__( x = img_input if parsed_kwargs["include_preprocessing"]: - x = self.build_preprocessing(x) + x = self.build_preprocessing(x, "imagenet") # Prepare feature extraction features = {} @@ -152,22 +152,6 @@ def __init__( self.growth_rate = growth_rate self.num_blocks = num_blocks - def build_preprocessing(self, inputs): - # [0, 255] to [0, 1] and apply ImageNet mean and variance - x = layers.Rescaling(scale=1.0 / 255.0)(inputs) - x = layers.Normalization( - mean=[0.485, 0.456, 0.406], variance=[0.229, 0.224, 0.225] - )(x) - return x - - def build_top(self, inputs, classes, classifier_activation, dropout_rate): - x = layers.GlobalAveragePooling2D(name="avg_pool")(inputs) - x = layers.Dropout(rate=dropout_rate, name="head_dropout")(x) - x = layers.Dense( - classes, activation=classifier_activation, name="classifier" - )(x) - return x - @staticmethod def available_feature_keys(): feature_keys = ["STEM_S4"] diff --git a/kimm/models/efficientnet.py b/kimm/models/efficientnet.py index 03d8de3..1ec1221 100644 --- a/kimm/models/efficientnet.py +++ b/kimm/models/efficientnet.py @@ -194,7 +194,7 @@ def __init__( x = img_input if parsed_kwargs["include_preprocessing"]: - x = self.build_preprocessing(x) + x = self.build_preprocessing(x, "imagenet") # Prepare feature extraction features = {} @@ -327,22 +327,6 @@ def __init__( self.activation = activation self.config = config - def build_preprocessing(self, inputs): - # [0, 255] to [0, 1] and apply ImageNet mean and variance - x = layers.Rescaling(scale=1.0 / 255.0)(inputs) - x = layers.Normalization( - mean=[0.485, 0.456, 0.406], variance=[0.229, 0.224, 0.225] - )(x) - return x - - def build_top(self, inputs, classes, classifier_activation, dropout_rate): - x = layers.GlobalAveragePooling2D(name="avg_pool")(inputs) - x = layers.Dropout(rate=dropout_rate, name="head_dropout")(x) - x = layers.Dense( - classes, activation=classifier_activation, name="classifier" - )(x) - return x - @staticmethod def available_feature_keys(): # for: v1, v1_lite, v2_m, v2_l, v2_xl, tinynet diff --git a/kimm/models/ghostnet.py b/kimm/models/ghostnet.py index 4d28d9a..4b0b89b 100644 --- a/kimm/models/ghostnet.py +++ b/kimm/models/ghostnet.py @@ -262,7 +262,7 @@ def __init__( x = img_input if parsed_kwargs["include_preprocessing"]: - x = self.build_preprocessing(x) + x = self.build_preprocessing(x, "imagenet") # Prepare feature extraction features = {} @@ -337,14 +337,6 @@ def __init__( self.config = config self.version = version - def build_preprocessing(self, inputs): - # [0, 255] to [0, 1] and apply ImageNet mean and variance - x = layers.Rescaling(scale=1.0 / 255.0)(inputs) - x = layers.Normalization( - mean=[0.485, 0.456, 0.406], variance=[0.229, 0.224, 0.225] - )(x) - return x - def build_top(self, inputs, classes, classifier_activation, dropout_rate): x = layers.GlobalAveragePooling2D(name="avg_pool", keepdims=True)( inputs diff --git a/kimm/models/inception_v3.py b/kimm/models/inception_v3.py index d12a001..b37428d 100644 --- a/kimm/models/inception_v3.py +++ b/kimm/models/inception_v3.py @@ -214,7 +214,7 @@ def __init__(self, has_aux_logits=False, **kwargs): x = img_input if 
parsed_kwargs["include_preprocessing"]: - x = self.build_preprocessing(x) + x = self.build_preprocessing(x, "imagenet") # Prepare feature extraction features = {} @@ -283,22 +283,6 @@ def __init__(self, has_aux_logits=False, **kwargs): self.add_references(parsed_kwargs) self.has_aux_logits = has_aux_logits - def build_preprocessing(self, inputs): - # [0, 255] to [0, 1] and apply ImageNet mean and variance - x = layers.Rescaling(scale=1.0 / 255.0)(inputs) - x = layers.Normalization( - mean=[0.485, 0.456, 0.406], variance=[0.229, 0.224, 0.225] - )(x) - return x - - def build_top(self, inputs, classes, classifier_activation, dropout_rate): - x = layers.GlobalAveragePooling2D(name="avg_pool")(inputs) - x = layers.Dropout(rate=dropout_rate, name="head_dropout")(x) - x = layers.Dense( - classes, activation=classifier_activation, name="classifier" - )(x) - return x - @staticmethod def available_feature_keys(): feature_keys = ["STEM_S2"] diff --git a/kimm/models/mobilenet_v2.py b/kimm/models/mobilenet_v2.py index 2bf5c0f..1350f9b 100644 --- a/kimm/models/mobilenet_v2.py +++ b/kimm/models/mobilenet_v2.py @@ -51,7 +51,7 @@ def __init__( x = img_input if parsed_kwargs["include_preprocessing"]: - x = self.build_preprocessing(x) + x = self.build_preprocessing(x, "imagenet") # Prepare feature extraction features = {} @@ -137,22 +137,6 @@ def __init__( self.fix_stem_and_head_channels = fix_stem_and_head_channels self.config = config - def build_preprocessing(self, inputs): - # [0, 255] to [0, 1] and apply ImageNet mean and variance - x = layers.Rescaling(scale=1.0 / 255.0)(inputs) - x = layers.Normalization( - mean=[0.485, 0.456, 0.406], variance=[0.229, 0.224, 0.225] - )(x) - return x - - def build_top(self, inputs, classes, classifier_activation, dropout_rate): - x = layers.GlobalAveragePooling2D(name="avg_pool")(inputs) - x = layers.Dropout(rate=dropout_rate, name="head_dropout")(x) - x = layers.Dense( - classes, activation=classifier_activation, name="classifier" - )(x) - return x - @staticmethod def available_feature_keys(): feature_keys = ["STEM_S2"] diff --git a/kimm/models/mobilenet_v3.py b/kimm/models/mobilenet_v3.py index 2e17da3..76aaa2c 100644 --- a/kimm/models/mobilenet_v3.py +++ b/kimm/models/mobilenet_v3.py @@ -126,7 +126,7 @@ def __init__( x = img_input if parsed_kwargs["include_preprocessing"]: - x = self.build_preprocessing(x) + x = self.build_preprocessing(x, "imagenet") # Prepare feature extraction features = {} @@ -261,14 +261,6 @@ def __init__( self.config = config self.minimal = minimal - def build_preprocessing(self, inputs): - # [0, 255] to [0, 1] and apply ImageNet mean and variance - x = layers.Rescaling(scale=1.0 / 255.0)(inputs) - x = layers.Normalization( - mean=[0.485, 0.456, 0.406], variance=[0.229, 0.224, 0.225] - )(x) - return x - def build_top( self, inputs, diff --git a/kimm/models/mobilevit.py b/kimm/models/mobilevit.py index d2c601f..d3c6079 100644 --- a/kimm/models/mobilevit.py +++ b/kimm/models/mobilevit.py @@ -193,7 +193,7 @@ def __init__( x = img_input if parsed_kwargs["include_preprocessing"]: - x = self.build_preprocessing(x) + x = self.build_preprocessing(x, "imagenet") # Prepare feature extraction features = {} @@ -285,22 +285,6 @@ def __init__( self.activation = activation self.config = config - def build_preprocessing(self, inputs): - # [0, 255] to [0, 1] and apply ImageNet mean and variance - x = layers.Rescaling(scale=1.0 / 255.0)(inputs) - x = layers.Normalization( - mean=[0.485, 0.456, 0.406], variance=[0.229, 0.224, 0.225] - )(x) - return x - - def 
build_top(self, inputs, classes, classifier_activation, dropout_rate): - x = layers.GlobalAveragePooling2D(name="avg_pool")(inputs) - x = layers.Dropout(dropout_rate, name="head_drop")(x) - x = layers.Dense( - classes, activation=classifier_activation, name="head_fc" - )(x) - return x - @staticmethod def available_feature_keys(): feature_keys = ["STEM_S2"] diff --git a/kimm/models/resnet.py b/kimm/models/resnet.py index 2d07735..7d882d3 100644 --- a/kimm/models/resnet.py +++ b/kimm/models/resnet.py @@ -123,7 +123,7 @@ def __init__( x = img_input if parsed_kwargs["include_preprocessing"]: - x = self.build_preprocessing(x) + x = self.build_preprocessing(x, "imagenet") # Prepare feature extraction features = {} @@ -188,22 +188,6 @@ def __init__( self.block_fn = block_fn self.num_blocks = num_blocks - def build_preprocessing(self, inputs): - # [0, 255] to [0, 1] and apply ImageNet mean and variance - x = layers.Rescaling(scale=1.0 / 255.0)(inputs) - x = layers.Normalization( - mean=[0.485, 0.456, 0.406], variance=[0.229, 0.224, 0.225] - )(x) - return x - - def build_top(self, inputs, classes, classifier_activation, dropout_rate): - x = layers.GlobalAveragePooling2D(name="avg_pool")(inputs) - x = layers.Dropout(rate=dropout_rate, name="head_dropout")(x) - x = layers.Dense(classes, activation=classifier_activation, name="fc")( - x - ) - return x - @staticmethod def available_feature_keys(): feature_keys = ["STEM_S2"] diff --git a/kimm/models/vision_transformer.py b/kimm/models/vision_transformer.py index d2f0b4a..0fae02c 100644 --- a/kimm/models/vision_transformer.py +++ b/kimm/models/vision_transformer.py @@ -38,7 +38,7 @@ def __init__( x = img_input if parsed_kwargs["include_preprocessing"]: - x = self.build_preprocessing(x) + x = self.build_preprocessing(x, "-1_1") # Prepare feature extraction features = {} @@ -100,11 +100,6 @@ def __init__( self.use_qk_norm = use_qk_norm self.pos_dropout_rate = pos_dropout_rate - def build_preprocessing(self, inputs): - # [0, 255] to [-1, 1] - x = layers.Rescaling(scale=1.0 / 127.5, offset=-1.0)(inputs) - return x - def build_top(self, inputs, classes, classifier_activation, dropout_rate): x = inputs[:, 0] # class token x = layers.Dropout(dropout_rate, name="head_drop")(x) diff --git a/tools/convert_mobilevit_from_timm.py b/tools/convert_mobilevit_from_timm.py index 42aac54..5000e74 100644 --- a/tools/convert_mobilevit_from_timm.py +++ b/tools/convert_mobilevit_from_timm.py @@ -88,6 +88,8 @@ # final block torch_name = torch_name.replace("final.conv.conv2d", "final_conv.conv") torch_name = torch_name.replace("final.conv.bn", "final_conv.bn") + # head + torch_name = torch_name.replace("classifier", "head.fc") # weights naming mapping torch_name = torch_name.replace("kernel", "weight") # conv2d diff --git a/tools/convert_resnet_from_timm.py b/tools/convert_resnet_from_timm.py index ad91cad..4dde8bf 100644 --- a/tools/convert_resnet_from_timm.py +++ b/tools/convert_resnet_from_timm.py @@ -80,6 +80,8 @@ torch_name = torch_name.replace("conv3.bn", "bn3") torch_name = torch_name.replace("downsample.conv2d", "downsample.0") torch_name = torch_name.replace("downsample.bn", "downsample.1") + # head + torch_name = torch_name.replace("classifier", "fc") # weights naming mapping torch_name = torch_name.replace("kernel", "weight") # conv2d From cf0d34a8428e363c50c3047546e3e654360275af Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Wed, 17 Jan 2024 09:43:33 +0800 Subject: [PATCH 08/10] Simplify code --- 
kimm/blocks/inverted_residual_block.py | 4 +-- kimm/models/efficientnet.py | 50 ++++---------------------- kimm/models/mobilenet_v2.py | 8 +---- kimm/models/mobilenet_v3.py | 14 +++----- kimm/models/mobilenet_v3_test.py | 12 +++++++ kimm/models/mobilevit.py | 12 ++----- 6 files changed, 28 insertions(+), 72 deletions(-) diff --git a/kimm/blocks/inverted_residual_block.py b/kimm/blocks/inverted_residual_block.py index 1d58267..1043a90 100644 --- a/kimm/blocks/inverted_residual_block.py +++ b/kimm/blocks/inverted_residual_block.py @@ -15,7 +15,7 @@ def apply_inverted_residual_block( expansion_ratio=1.0, se_ratio=0.0, activation="swish", - se_input_channels=None, + se_channels=None, se_activation=None, se_gate_activation="sigmoid", se_make_divisible_number=None, @@ -57,7 +57,7 @@ def apply_inverted_residual_block( se_ratio, activation=se_activation or activation, gate_activation=se_gate_activation, - se_input_channels=se_input_channels, + se_input_channels=se_channels, make_divisible_number=se_make_divisible_number, name=f"{name}_se", ) diff --git a/kimm/models/efficientnet.py b/kimm/models/efficientnet.py index 1ec1221..969012b 100644 --- a/kimm/models/efficientnet.py +++ b/kimm/models/efficientnet.py @@ -8,7 +8,6 @@ from kimm.blocks import apply_conv2d_block from kimm.blocks import apply_depthwise_separation_block from kimm.blocks import apply_inverted_residual_block -from kimm.blocks import apply_se_block from kimm.models import BaseModel from kimm.utils import add_model_to_registry from kimm.utils import make_divisible @@ -88,7 +87,6 @@ def apply_edge_residual_block( pointwise_kernel_size=1, strides=1, expansion_ratio=1.0, - se_ratio=0.0, activation="swish", bn_epsilon=1e-5, padding=None, @@ -110,16 +108,6 @@ def apply_edge_residual_block( padding=padding, name=f"{name}_conv_exp", ) - # Squeeze-and-excitation - if se_ratio > 0: - x = apply_se_block( - x, - se_ratio, - activation=activation, - gate_activation="sigmoid", - se_input_channels=input_channels, - name=f"{name}_se", - ) # Point-wise linear projection x = apply_conv2d_block( x, @@ -230,50 +218,26 @@ def __init__( r = int(round_fn(r * depth)) for current_layer_idx in range(r): s = s if current_layer_idx == 0 else 1 - common_kwargs = { + _kwargs = { "bn_epsilon": bn_epsilon, "padding": padding, "name": f"blocks_{current_block_idx}_{current_layer_idx}", + "activation": activation, } if block_type == "ds": x = apply_depthwise_separation_block( - x, - c, - k, - 1, - s, - se, - activation=activation, - se_activation=activation, - **common_kwargs, + x, c, k, 1, s, se, se_activation=activation, **_kwargs ) elif block_type == "ir": + se_c = x.shape[-1] x = apply_inverted_residual_block( - x, - c, - k, - 1, - 1, - s, - e, - se, - activation, - se_input_channels=x.shape[-1], - **common_kwargs, + x, c, k, 1, 1, s, e, se, se_channels=se_c, **_kwargs ) elif block_type == "cn": - x = apply_conv2d_block( - x, - filters=c, - kernel_size=k, - strides=s, - activation=activation, - add_skip=True, - **common_kwargs, - ) + x = apply_conv2d_block(x, c, k, s, add_skip=True, **_kwargs) elif block_type == "er": x = apply_edge_residual_block( - x, c, k, 1, s, e, se, activation, **common_kwargs + x, c, k, 1, s, e, **_kwargs ) current_stride *= s features[f"BLOCK{current_block_idx}_S{current_stride}"] = x diff --git a/kimm/models/mobilenet_v2.py b/kimm/models/mobilenet_v2.py index 1350f9b..a9ee436 100644 --- a/kimm/models/mobilenet_v2.py +++ b/kimm/models/mobilenet_v2.py @@ -83,13 +83,7 @@ def __init__( name = f"blocks_{current_block_idx}_{current_layer_idx}" 
if block_type == "ds": x = apply_depthwise_separation_block( - x, - c, - k, - 1, - s, - activation="relu6", - name=name, + x, c, k, 1, s, activation="relu6", name=name ) elif block_type == "ir": x = apply_inverted_residual_block( diff --git a/kimm/models/mobilenet_v3.py b/kimm/models/mobilenet_v3.py index 76aaa2c..b55759d 100644 --- a/kimm/models/mobilenet_v3.py +++ b/kimm/models/mobilenet_v3.py @@ -167,7 +167,7 @@ def __init__( r = int(math.ceil(r * depth)) for current_layer_idx in range(r): s = s if current_layer_idx == 0 else 1 - common_kwargs = { + _kwargs = { "bn_epsilon": bn_epsilon, "padding": padding, "name": ( @@ -189,7 +189,7 @@ def __init__( se_make_divisible_number=8, pw_activation=act if block_type == "dsa" else None, skip=False if block_type == "dsa" else True, - **common_kwargs, + **_kwargs, ) elif block_type == "ir": x = apply_inverted_residual_block( @@ -202,20 +202,14 @@ def __init__( e, se, act, - se_input_channels=None, se_activation="relu", se_gate_activation="hard_sigmoid", se_make_divisible_number=8, - **common_kwargs, + **_kwargs, ) elif block_type == "cn": x = apply_conv2d_block( - x, - filters=c, - kernel_size=k, - strides=s, - activation=act, - **common_kwargs, + x, c, k, s, activation=act, **_kwargs ) current_stride *= s features[f"BLOCK{current_stage_idx}_S{current_stride}"] = x diff --git a/kimm/models/mobilenet_v3_test.py b/kimm/models/mobilenet_v3_test.py index fea61a5..128764c 100644 --- a/kimm/models/mobilenet_v3_test.py +++ b/kimm/models/mobilenet_v3_test.py @@ -6,6 +6,7 @@ from kimm.models.mobilenet_v3 import LCNet100 from kimm.models.mobilenet_v3 import MobileNet100V3Large from kimm.models.mobilenet_v3 import MobileNet100V3Small +from kimm.models.mobilenet_v3 import MobileNet100V3SmallMinimal from kimm.utils import make_divisible @@ -13,6 +14,7 @@ class MobileNetV3Test(testing.TestCase, parameterized.TestCase): @parameterized.named_parameters( [ (MobileNet100V3Small.__name__, MobileNet100V3Small), + (MobileNet100V3SmallMinimal.__name__, MobileNet100V3SmallMinimal), (MobileNet100V3Large.__name__, MobileNet100V3Large), (LCNet100.__name__, LCNet100), ] @@ -29,6 +31,11 @@ def test_mobilenet_v3_base(self, model_class): @parameterized.named_parameters( [ (MobileNet100V3Small.__name__, MobileNet100V3Small, 1.0), + ( + MobileNet100V3SmallMinimal.__name__, + MobileNet100V3SmallMinimal, + 1.0, + ), (MobileNet100V3Large.__name__, MobileNet100V3Large, 1.0), ] ) @@ -120,6 +127,11 @@ def test_lcnet_feature_extractor(self, model_class, width): @parameterized.named_parameters( [ (MobileNet100V3Small.__name__, MobileNet100V3Small, 224), + ( + MobileNet100V3SmallMinimal.__name__, + MobileNet100V3SmallMinimal, + 224, + ), (MobileNet100V3Large.__name__, MobileNet100V3Large, 224), (LCNet100.__name__, LCNet100, 224), ] diff --git a/kimm/models/mobilevit.py b/kimm/models/mobilevit.py index d3c6079..4649462 100644 --- a/kimm/models/mobilevit.py +++ b/kimm/models/mobilevit.py @@ -213,7 +213,7 @@ def __init__( k, c, s, - expansion_ratio, + e, transformer_dim, transformer_depth, patch_size, @@ -223,15 +223,7 @@ def __init__( s = s if current_layer_idx == 0 else 1 name = f"stages_{current_block_idx}_{current_layer_idx}" x = apply_inverted_residual_block( - x, - c, - k, - 1, - 1, - s, - expansion_ratio, - activation=activation, - name=name, + x, c, k, 1, 1, s, e, activation=activation, name=name ) current_stride *= s if block_type == "mobilevit": From 6a8fd3066d6b2a45a73faeed08fdb8eb20d0acb9 Mon Sep 17 00:00:00 2001 From: Hongyu Chiu 
<20734616+james77777778@users.noreply.github.com> Date: Wed, 17 Jan 2024 09:48:32 +0800 Subject: [PATCH 09/10] Format --- kimm/models/efficientnet.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/kimm/models/efficientnet.py b/kimm/models/efficientnet.py index 969012b..bc5c7dc 100644 --- a/kimm/models/efficientnet.py +++ b/kimm/models/efficientnet.py @@ -236,9 +236,7 @@ def __init__( elif block_type == "cn": x = apply_conv2d_block(x, c, k, s, add_skip=True, **_kwargs) elif block_type == "er": - x = apply_edge_residual_block( - x, c, k, 1, s, e, **_kwargs - ) + x = apply_edge_residual_block(x, c, k, 1, s, e, **_kwargs) current_stride *= s features[f"BLOCK{current_block_idx}_S{current_stride}"] = x From 532d86f30c66452d056ed86b5844c70b91457fba Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Wed, 17 Jan 2024 10:02:29 +0800 Subject: [PATCH 10/10] Mark serialization and skip them by default --- conftest.py | 28 +++++++++++++++++++++++++- kimm/models/densenet_test.py | 6 ++++-- kimm/models/efficientnet_test.py | 2 ++ kimm/models/ghostnet_test.py | 2 ++ kimm/models/inception_v3_test.py | 2 ++ kimm/models/mobilenet_v2_test.py | 2 ++ kimm/models/mobilenet_v3_test.py | 2 ++ kimm/models/mobilevit_test.py | 2 ++ kimm/models/resnet_test.py | 2 ++ kimm/models/vision_transformer_test.py | 2 ++ 10 files changed, 47 insertions(+), 3 deletions(-) diff --git a/conftest.py b/conftest.py index 0127de4..fce612b 100644 --- a/conftest.py +++ b/conftest.py @@ -1,7 +1,18 @@ import os +import pytest -def pytest_configure(): + +def pytest_addoption(parser): + parser.addoption( + "--run_serialization", + action="store_true", + default=False, + help="run serialization tests", + ) + + +def pytest_configure(config): import tensorflow as tf # disable tensorflow gpu memory preallocation @@ -12,3 +23,18 @@ def pytest_configure(): # disable jax gpu memory preallocation # https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false" + + config.addinivalue_line( + "markers", "serialization: mark test as a serialization test" + ) + + +def pytest_collection_modifyitems(config, items): + run_serialization_tests = config.getoption("--run_serialization") + skip_serialization = pytest.mark.skipif( + not run_serialization_tests, + reason="need --run_serialization option to run", + ) + for item in items: + if "serialization" in item.name: + item.add_marker(skip_serialization) diff --git a/kimm/models/densenet_test.py b/kimm/models/densenet_test.py index 31a3ab9..95c50b3 100644 --- a/kimm/models/densenet_test.py +++ b/kimm/models/densenet_test.py @@ -1,3 +1,4 @@ +import pytest from absl.testing import parameterized from keras import models from keras import random @@ -18,7 +19,7 @@ def test_densenet_base(self, model_class): self.assertEqual(y.shape, (1, 1000)) @parameterized.named_parameters([(DenseNet121.__name__, DenseNet121)]) - def test_mobilenet_v2_feature_extractor(self, model_class): + def test_densenet_feature_extractor(self, model_class): x = random.uniform([1, 224, 224, 3]) * 255.0 model = model_class( input_shape=[224, 224, 3], as_feature_extractor=True @@ -36,8 +37,9 @@ def test_mobilenet_v2_feature_extractor(self, model_class): self.assertEqual(list(y["BLOCK2_S32"].shape), [1, 7, 7, 512]) self.assertEqual(list(y["BLOCK3_S32"].shape), [1, 7, 7, 1024]) + @pytest.mark.serialization @parameterized.named_parameters([(DenseNet121.__name__, DenseNet121, 224)]) - def test_mobilenet_v2_serialization(self, 
model_class, image_size): + def test_densenet_serialization(self, model_class, image_size): x = random.uniform([1, image_size, image_size, 3]) * 255.0 temp_dir = self.get_temp_dir() model1 = model_class(input_shape=[224, 224, 3]) diff --git a/kimm/models/efficientnet_test.py b/kimm/models/efficientnet_test.py index 10837a9..570c4ed 100644 --- a/kimm/models/efficientnet_test.py +++ b/kimm/models/efficientnet_test.py @@ -1,3 +1,4 @@ +import pytest from absl.testing import parameterized from keras import models from keras import random @@ -123,6 +124,7 @@ def test_efficentnet_v2_feature_extractor(self, model_class, width): [1, 7, 7, make_divisible(192 * width)], ) + @pytest.mark.serialization @parameterized.named_parameters( [ (EfficientNetB0.__name__, EfficientNetB0, 224), diff --git a/kimm/models/ghostnet_test.py b/kimm/models/ghostnet_test.py index c800c8e..6b3881f 100644 --- a/kimm/models/ghostnet_test.py +++ b/kimm/models/ghostnet_test.py @@ -1,3 +1,4 @@ +import pytest from absl.testing import parameterized from keras import models from keras import random @@ -62,6 +63,7 @@ def test_ghostnetv2_feature_extractor(self, model_class): self.assertEqual(list(y["BLOCK5_S16"].shape), [1, 14, 14, 80]) self.assertEqual(list(y["BLOCK7_S32"].shape), [1, 7, 7, 160]) + @pytest.mark.serialization @parameterized.named_parameters( [ (GhostNet100.__name__, GhostNet100, 224), diff --git a/kimm/models/inception_v3_test.py b/kimm/models/inception_v3_test.py index 3899e01..af13e2f 100644 --- a/kimm/models/inception_v3_test.py +++ b/kimm/models/inception_v3_test.py @@ -1,3 +1,4 @@ +import pytest from absl.testing import parameterized from keras import models from keras import random @@ -34,6 +35,7 @@ def test_inception_v3_feature_extractor(self, model_class): self.assertEqual(list(y["BLOCK2_S16"].shape), [1, 17, 17, 768]) self.assertEqual(list(y["BLOCK3_S32"].shape), [1, 8, 8, 2048]) + @pytest.mark.serialization @parameterized.named_parameters([(InceptionV3.__name__, InceptionV3, 299)]) def test_inception_v3_serialization(self, model_class, image_size): x = random.uniform([1, image_size, image_size, 3]) * 255.0 diff --git a/kimm/models/mobilenet_v2_test.py b/kimm/models/mobilenet_v2_test.py index 859f7e9..05454c0 100644 --- a/kimm/models/mobilenet_v2_test.py +++ b/kimm/models/mobilenet_v2_test.py @@ -1,3 +1,4 @@ +import pytest from absl.testing import parameterized from keras import models from keras import random @@ -56,6 +57,7 @@ def test_mobilenet_v2_feature_extractor(self, model_class, width): list(y["BLOCK5_S32"].shape), [1, 7, 7, make_divisible(160 * width)] ) + @pytest.mark.serialization @parameterized.named_parameters( [(MobileNet050V2.__name__, MobileNet050V2, 224)] ) diff --git a/kimm/models/mobilenet_v3_test.py b/kimm/models/mobilenet_v3_test.py index 128764c..f7a31a9 100644 --- a/kimm/models/mobilenet_v3_test.py +++ b/kimm/models/mobilenet_v3_test.py @@ -1,3 +1,4 @@ +import pytest from absl.testing import parameterized from keras import models from keras import random @@ -124,6 +125,7 @@ def test_lcnet_feature_extractor(self, model_class, width): [1, 7, 7, make_divisible(512 * width)], ) + @pytest.mark.serialization @parameterized.named_parameters( [ (MobileNet100V3Small.__name__, MobileNet100V3Small, 224), diff --git a/kimm/models/mobilevit_test.py b/kimm/models/mobilevit_test.py index fbd2939..1738fe9 100644 --- a/kimm/models/mobilevit_test.py +++ b/kimm/models/mobilevit_test.py @@ -1,3 +1,4 @@ +import pytest from absl.testing import parameterized from keras import models from keras import random 
@@ -46,6 +47,7 @@ def test_mobilevit_feature_extractor(self, model_class): self.assertEqual(list(y["BLOCK3_S16"].shape), [1, 16, 16, 80]) self.assertEqual(list(y["BLOCK4_S32"].shape), [1, 8, 8, 96]) + @pytest.mark.serialization @parameterized.named_parameters( [ (MobileViTS.__name__, MobileViTS, 256), diff --git a/kimm/models/resnet_test.py b/kimm/models/resnet_test.py index 478ef6b..47bf01d 100644 --- a/kimm/models/resnet_test.py +++ b/kimm/models/resnet_test.py @@ -1,3 +1,4 @@ +import pytest from absl.testing import parameterized from keras import models from keras import random @@ -47,6 +48,7 @@ def test_resnet_feature_extractor(self, model_class, expansion): list(y["BLOCK3_S32"].shape), [1, 7, 7, 512 * expansion] ) + @pytest.mark.serialization @parameterized.named_parameters( [(ResNet18.__name__, ResNet18, 224), (ResNet50.__name__, ResNet50, 224)] ) diff --git a/kimm/models/vision_transformer_test.py b/kimm/models/vision_transformer_test.py index 7d8ba97..74bdab7 100644 --- a/kimm/models/vision_transformer_test.py +++ b/kimm/models/vision_transformer_test.py @@ -1,3 +1,4 @@ +import pytest from absl.testing import parameterized from keras import models from keras import random @@ -50,6 +51,7 @@ def test_vision_transformer_feature_extractor( elif patch_size == 32: self.assertEqual(list(y["BLOCK5"].shape), [1, 145, 192]) + @pytest.mark.serialization @parameterized.named_parameters( [ (VisionTransformerTiny16.__name__, VisionTransformerTiny16, 384),
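A usage sketch for the serialization marker added in conftest.py, assuming pytest is run from the repository root: tests whose names contain "serialization" are skipped by default and only run when the new flag is passed.

    pytest --run_serialization   # include the slower serialization tests
    pytest                       # default run; serialization tests are skipped

This keeps the default suite fast while the save/load round-trips can still be exercised on demand.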