Skip to content

Commit

Permalink
Merge pull request #183 from candemircan/tf-keras-fix
Browse files Browse the repository at this point in the history
fix scaling of RGBs for tensorflow/keras
LukasMut authored Dec 4, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
2 parents aac5fe0 + 4aa351e commit ee8b576
Showing 2 changed files with 46 additions and 1 deletion.
45 changes: 45 additions & 0 deletions thingsvision/core/extraction/extractors.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import re
import warnings
from typing import Any, Callable, Dict, List, Optional, Union

@@ -9,6 +10,7 @@
import torch
import torchvision

from tensorflow.keras.layers import Lambda
from torch.hub import load_state_dict_from_url

from thingsvision.utils.checkpointing import get_torch_home
@@ -195,10 +197,53 @@ def load_model_from_source(self) -> None:
else:
weights = None
self.model = model(weights=weights)
preproc_fun_name = self.get_keras_preprocessing(self.model_name)
if isinstance(preproc_fun_name, str):
# get preprocessing function for a specific model
preproc_fun = self.get_preproc_fun(preproc_fun_name)
# different models take differently sized inputs. this has to be accounted for.
resize_dim = self.model.layers[0].input_shape[0][-2] # -2 and -3 are the H and W channel dims.
self.preprocess = tf.keras.Sequential([Lambda(preproc_fun), tf.keras.layers.experimental.preprocessing.Resizing(resize_dim, resize_dim)])
else:
raise ValueError(
f"\nCould not find {self.model_name} among TensorFlow models.\n"
)


@staticmethod
def get_preproc_fun(preproc_fun_name: str) -> Callable:
"""Get the preprocessing function associated with a specific model."""
return getattr(getattr(tensorflow_models, preproc_fun_name), "preprocess_input")


def get_keras_preprocessing(self, model_name:str) -> Union[str, None]:
"""Get the preprocessing function for the corresponding model from `tensorflow.keras.applications.*`"""

patterns = [
(r'^ConvNeXt(Base|Large|Small|Tiny|XLarge)$', 'convnext'),
(r'^DenseNet\d+$', 'densenet'),
(r'^EfficientNetB[0-7]$', 'efficientnet'),
(r'^EfficientNetV2(B[0-3]|[LMS])$', 'efficientnet_v2'),
(r'^InceptionResNetV2$', 'inception_resnet_v2'),
(r'^InceptionV3$', 'inception_v3'),
(r'^MobileNet$', 'mobilenet'),
(r'^MobileNetV2$', 'mobilenet_v2'),
(r'^MobileNetV3(Large|Small)$', 'mobilenet_v3'),
(r'^NasNet(Large|Mobile)$', 'nasnet'),
(r'^ResNet\d+$', 'resnet'),
(r'^ResNet\d+V2$', 'resnet_v2'),
(r'^VGG16$', 'vgg16'),
(r'^VGG19$', 'vgg19'),
(r'^Xception$', 'xception')
]
# Try each pattern
for pattern, preproc_val in patterns:
if re.search(pattern, model_name):
return preproc_val

# If no match is found, print a warning message
warnings.warn(f"No preprocessing function found for model {model_name}, so falling back to default preprocessing.\nOften, models that come from Keras Applications have their own preprocessing functions.\nThus, this may create inaccurate results. If you need to manually specify a preprocessing function, please do so under the `transforms` argument when creating your Dataset")
return None


class SSLExtractor(PyTorchExtractor):
2 changes: 1 addition & 1 deletion thingsvision/core/extraction/tensorflow.py
Original file line number Diff line number Diff line change
@@ -88,7 +88,7 @@ def get_default_transformation(
apply_center_crop: bool = True,
) -> Any:
resize_dim = crop_dim
composes = [layers.experimental.preprocessing.Resizing(resize_dim, resize_dim)]
composes = [layers.experimental.preprocessing.Resizing(resize_dim, resize_dim), layers.experimental.preprocessing.Rescaling(1./255.)]
if apply_center_crop:
pass
# TODO: fix center crop problem with Keras

0 comments on commit ee8b576

Please sign in to comment.