Add DenseNet, InceptionV3 and refactor BaseModel #11

Merged · 10 commits · Jan 17, 2024
28 changes: 27 additions & 1 deletion conftest.py
@@ -1,7 +1,18 @@
import os

import pytest

def pytest_configure():

def pytest_addoption(parser):
parser.addoption(
"--run_serialization",
action="store_true",
default=False,
help="run serialization tests",
)


def pytest_configure(config):
import tensorflow as tf

# disable tensorflow gpu memory preallocation
@@ -12,3 +23,18 @@ def pytest_configure():
# disable jax gpu memory preallocation
# https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

config.addinivalue_line(
"markers", "serialization: mark test as a serialization test"
)


def pytest_collection_modifyitems(config, items):
run_serialization_tests = config.getoption("--run_serialization")
skip_serialization = pytest.mark.skipif(
not run_serialization_tests,
reason="need --run_serialization option to run",
)
for item in items:
if "serialization" in item.name:
item.add_marker(skip_serialization)
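
Note that the gate keys off the test name ("serialization" in item.name) rather than the registered marker, so any test whose name contains "serialization" is skipped unless the flag is passed. A minimal sketch of a test that picks up the gate (the test name and body here are hypothetical):

import pytest

@pytest.mark.serialization
def test_serialization_roundtrip():
    # Skipped by default because "serialization" is in the name;
    # collected and run only with: pytest --run_serialization
    assert True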
6 changes: 5 additions & 1 deletion kimm/blocks/base_block.py
@@ -34,6 +34,8 @@ def apply_conv2d_block(
raise ValueError(
f"kernel_size must be passed. Received: kernel_size={kernel_size}"
)
if isinstance(kernel_size, int):
kernel_size = [kernel_size, kernel_size]
input_channels = inputs.shape[-1]
has_skip = add_skip and strides == 1 and input_channels == filters
x = inputs
@@ -42,7 +44,9 @@
padding = "same"
if strides > 1:
padding = "valid"
x = layers.ZeroPadding2D(kernel_size // 2, name=f"{name}_pad")(x)
x = layers.ZeroPadding2D(
(kernel_size[0] // 2, kernel_size[1] // 2), name=f"{name}_pad"
)(x)

if not use_depthwise:
x = layers.Conv2D(
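
The int-to-pair normalization above matters once asymmetric kernels appear (InceptionV3 uses 1x7 and 7x1 convolutions), since each spatial dimension then needs its own padding. A standalone sketch of the arithmetic, assuming nothing beyond the diff (not kimm's actual helper):

def explicit_padding(kernel_size):
    # Mirror apply_conv2d_block: normalize an int to a (kh, kw) pair,
    # then zero-pad by half the kernel in each spatial dimension.
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size)
    return (kernel_size[0] // 2, kernel_size[1] // 2)

assert explicit_padding(3) == (1, 1)
assert explicit_padding((1, 7)) == (0, 3)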
4 changes: 2 additions & 2 deletions kimm/blocks/inverted_residual_block.py
@@ -15,7 +15,7 @@ def apply_inverted_residual_block(
expansion_ratio=1.0,
se_ratio=0.0,
activation="swish",
se_input_channels=None,
se_channels=None,
se_activation=None,
se_gate_activation="sigmoid",
se_make_divisible_number=None,
@@ -57,7 +57,7 @@
se_ratio,
activation=se_activation or activation,
gate_activation=se_gate_activation,
se_input_channels=se_input_channels,
se_input_channels=se_channels,
make_divisible_number=se_make_divisible_number,
name=f"{name}_se",
)
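
The rename keeps the block-level argument distinct from the inner helper's se_input_channels, which it now feeds. To illustrate what that override controls, a self-contained squeeze-and-excitation sketch (a hypothetical helper, not kimm's actual SE block):

from keras import layers

def apply_se_sketch(x, se_ratio=0.25, se_channels=None):
    # `se_channels` overrides the channel count used to size the SE
    # bottleneck; it falls back to the input channels when unset.
    input_channels = x.shape[-1]
    base = se_channels if se_channels is not None else input_channels
    hidden = max(1, int(base * se_ratio))
    gate = layers.GlobalAveragePooling2D(keepdims=True)(x)
    gate = layers.Conv2D(hidden, 1, activation="relu")(gate)
    gate = layers.Conv2D(input_channels, 1, activation="sigmoid")(gate)
    return layers.Multiply()([x, gate])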
15 changes: 0 additions & 15 deletions kimm/layers/attention.py
@@ -118,18 +118,3 @@ def get_config(self):
}
)
return config


if __name__ == "__main__":
from keras import models
from keras import random

inputs = layers.Input(shape=[197, 768])
outputs = Attention(768)(inputs)

model = models.Model(inputs, outputs)
model.summary()

inputs = random.uniform([1, 197, 768])
outputs = model(inputs)
print(outputs.shape)
15 changes: 0 additions & 15 deletions kimm/layers/layer_scale.py
@@ -35,18 +35,3 @@ def get_config(self):
}
)
return config


if __name__ == "__main__":
from keras import models
from keras import random

inputs = layers.Input(shape=[197, 768])
outputs = LayerScale(768)(inputs)

model = models.Model(inputs, outputs)
model.summary()

inputs = random.uniform([1, 197, 768])
outputs = model(inputs)
print(outputs.shape)
22 changes: 0 additions & 22 deletions kimm/layers/position_embedding.py
@@ -38,25 +38,3 @@ def compute_output_shape(self, input_shape):

def get_config(self):
return super().get_config()


if __name__ == "__main__":
from keras import models
from keras import random

inputs = layers.Input([224, 224, 3])
x = layers.Conv2D(
768,
16,
16,
use_bias=True,
)(inputs)
x = layers.Reshape((-1, 768))(x)
outputs = PositionEmbedding()(x)

model = models.Model(inputs, outputs)
model.summary()

inputs = random.uniform([1, 224, 224, 3])
outputs = model(inputs)
print(outputs.shape)
2 changes: 1 addition & 1 deletion kimm/models/__init__.py
@@ -1,5 +1,5 @@
from kimm.models.base_model import BaseModel
from kimm.models.efficientnet import * # noqa:F403
from kimm.models.feature_extractor import FeatureExtractor
from kimm.models.ghostnet import * # noqa:F403
from kimm.models.mobilenet_v2 import * # noqa:F403
from kimm.models.mobilenet_v3 import * # noqa:F403
157 changes: 157 additions & 0 deletions kimm/models/base_model.py
@@ -0,0 +1,157 @@
import abc
import typing

from keras import KerasTensor
from keras import backend
from keras import layers
from keras import models
from keras.src.applications import imagenet_utils


class BaseModel(models.Model):
def __init__(
self,
inputs,
outputs,
features: typing.Optional[typing.Dict[str, KerasTensor]] = None,
feature_keys: typing.Optional[typing.List[str]] = None,
**kwargs,
):
self.as_feature_extractor = kwargs.pop("as_feature_extractor", False)
self.feature_keys = feature_keys
if self.as_feature_extractor:
if features is None:
raise ValueError(
"`features` must be set when "
f"`as_feature_extractor=True`. Got features={features}"
)
if self.feature_keys is None:
self.feature_keys = list(features.keys())
filtered_features = {}
for k in self.feature_keys:
if k not in features:
raise KeyError(
f"'{k}' is not a key of `features`. Available keys "
f"are: {list(features.keys())}"
)
filtered_features[k] = features[k]
super().__init__(inputs=inputs, outputs=filtered_features, **kwargs)
else:
del features
super().__init__(inputs=inputs, outputs=outputs, **kwargs)

def parse_kwargs(
self, kwargs: typing.Dict[str, typing.Any], default_size: int = 224
):
result = {
"input_tensor": kwargs.pop("input_tensor", None),
"input_shape": kwargs.pop("input_shape", None),
"include_preprocessing": kwargs.pop("include_preprocessing", True),
"include_top": kwargs.pop("include_top", True),
"pooling": kwargs.pop("pooling", None),
"dropout_rate": kwargs.pop("dropout_rate", 0.0),
"classes": kwargs.pop("classes", 1000),
"classifier_activation": kwargs.pop(
"classifier_activation", "softmax"
),
"weights": kwargs.pop("weights", "imagenet"),
"default_size": kwargs.pop("default_size", default_size),
}
return result

def determine_input_tensor(
self,
input_tensor=None,
input_shape=None,
default_size=224,
min_size=32,
require_flatten=False,
static_shape=False,
):
"""Determine the input tensor by the arguments."""
input_shape = imagenet_utils.obtain_input_shape(
input_shape,
default_size=default_size,
min_size=min_size,
data_format="channels_last", # always channels_last
require_flatten=require_flatten or static_shape,
weights=None,
)

if input_tensor is None:
x = layers.Input(shape=input_shape)
else:
if not backend.is_keras_tensor(input_tensor):
x = layers.Input(tensor=input_tensor, shape=input_shape)
else:
x = input_tensor
return x

def build_preprocessing(self, inputs, mode="imagenet"):
if mode == "imagenet":
            # [0, 255] to [0, 1], then normalize with the ImageNet mean/std
            x = layers.Rescaling(scale=1.0 / 255.0)(inputs)
            x = layers.Normalization(
                mean=[0.485, 0.456, 0.406],
                # `Normalization` expects variance, not std, so square the
                # ImageNet std values
                variance=[0.229**2, 0.224**2, 0.225**2],
            )(x)
elif mode == "0_1":
            # [0, 255] to [0, 1]
x = layers.Rescaling(scale=1.0 / 255.0)(inputs)
elif mode == "-1_1":
# [0, 255] to [-1, 1]
x = layers.Rescaling(scale=1.0 / 127.5, offset=-1.0)(inputs)
else:
raise ValueError(
"`mode` must be one of ('imagenet', '0_1', '-1_1'). "
f"Received: mode={mode}"
)
return x

def build_top(self, inputs, classes, classifier_activation, dropout_rate):
x = layers.GlobalAveragePooling2D(name="avg_pool")(inputs)
x = layers.Dropout(rate=dropout_rate, name="head_dropout")(x)
x = layers.Dense(
classes, activation=classifier_activation, name="classifier"
)(x)
return x

def add_references(self, parsed_kwargs: typing.Dict[str, typing.Any]):
self.include_preprocessing = parsed_kwargs["include_preprocessing"]
self.include_top = parsed_kwargs["include_top"]
self.pooling = parsed_kwargs["pooling"]
self.dropout_rate = parsed_kwargs["dropout_rate"]
self.classes = parsed_kwargs["classes"]
self.classifier_activation = parsed_kwargs["classifier_activation"]
        # `self.weights` is already used internally by Keras, so keep ours
        # in `_weights`
self._weights = parsed_kwargs["weights"]

@staticmethod
@abc.abstractmethod
def available_feature_keys():
# TODO: add docstring
raise NotImplementedError

def get_config(self):
# Don't chain to super here. The default `get_config()` for functional
# models is nested and cannot be passed to BaseModel.
config = {
# models.Model
"name": self.name,
"trainable": self.trainable,
# feature extractor
"as_feature_extractor": self.as_feature_extractor,
"feature_keys": self.feature_keys,
# common
"input_shape": self.input_shape[1:],
"include_preprocessing": self.include_preprocessing,
"include_top": self.include_top,
"pooling": self.pooling,
"dropout_rate": self.dropout_rate,
"classes": self.classes,
"classifier_activation": self.classifier_activation,
"weights": self._weights,
}
return config

def fix_config(self, config: typing.Dict):
return config
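
A usage sketch of the feature-extractor path, with a hypothetical minimal subclass mirroring the SampleModel in the test diff below:

from keras import layers
from keras import random
from kimm.models.base_model import BaseModel

class TinyModel(BaseModel):
    def __init__(self, **kwargs):
        inputs = layers.Input(shape=[32, 32, 3])
        x = layers.Conv2D(8, 3, 2, padding="same", name="s2")(inputs)
        features = {"S2": x}
        outputs = layers.GlobalAveragePooling2D()(x)
        super().__init__(inputs, outputs, features=features, **kwargs)

    @staticmethod
    def available_feature_keys():
        return ["S2"]

model = TinyModel(as_feature_extractor=True, feature_keys=["S2"])
# Outputs become a dict keyed by `feature_keys` instead of the logits.
features = model(random.uniform([1, 32, 32, 3]))
print(features["S2"].shape)  # (1, 16, 16, 8)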
@@ -3,10 +3,10 @@
from keras import random
from keras.src import testing

from kimm.models.feature_extractor import FeatureExtractor
from kimm.models.base_model import BaseModel


class SampleModel(FeatureExtractor):
class SampleModel(BaseModel):
def __init__(self, **kwargs):
inputs = layers.Input(shape=[224, 224, 3])

@@ -34,7 +34,7 @@ def get_config(self):
return super().get_config()


class GhostNetTest(testing.TestCase, parameterized.TestCase):
class BaseModelTest(testing.TestCase, parameterized.TestCase):
def test_feature_extractor(self):
x = random.uniform([1, 224, 224, 3])
