From 42c27fc021b216b1da1969ea9c5de2f7b365cb6f Mon Sep 17 00:00:00 2001
From: Nicolas Kaenzig
Date: Fri, 29 Nov 2024 15:17:23 +0100
Subject: [PATCH 1/7] added wrapper to eva.core

---
 pdm.lock                                      |  2 +-
 src/eva/core/models/wrappers/__init__.py      |  4 +-
 src/eva/core/models/wrappers/from_torchhub.py | 87 +++++++++++++++++++
 src/eva/vision/models/wrappers/__init__.py    |  2 +-
 4 files changed, 92 insertions(+), 3 deletions(-)
 create mode 100644 src/eva/core/models/wrappers/from_torchhub.py

diff --git a/pdm.lock b/pdm.lock
index 21faa341..826a205e 100644
--- a/pdm.lock
+++ b/pdm.lock
@@ -5,7 +5,7 @@
 groups = ["default", "all", "dev", "docs", "lint", "test", "typecheck", "vision"]
 strategy = ["inherit_metadata"]
 lock_version = "4.5.0"
-content_hash = "sha256:b8df35bf60e5573e36c31c4ad4f324d7693f16b31cadcd27e48b352ae6c0235b"
+content_hash = "sha256:f1cb189ed1b12e6aa22ba372e4b934e29afefdd17894792c86863bb62d15217f"
 
 [[metadata.targets]]
 requires_python = ">=3.10"
diff --git a/src/eva/core/models/wrappers/__init__.py b/src/eva/core/models/wrappers/__init__.py
index 95ab6101..4b8559ff 100644
--- a/src/eva/core/models/wrappers/__init__.py
+++ b/src/eva/core/models/wrappers/__init__.py
@@ -4,10 +4,12 @@
 from eva.core.models.wrappers.from_function import ModelFromFunction
 from eva.core.models.wrappers.huggingface import HuggingFaceModel
 from eva.core.models.wrappers.onnx import ONNXModel
+from eva.core.models.wrappers.from_torchhub import TorchHubModel
 
 __all__ = [
     "BaseModel",
-    "ModelFromFunction",
     "HuggingFaceModel",
+    "ModelFromFunction",
     "ONNXModel",
+    "TorchHubModel",
 ]
diff --git a/src/eva/core/models/wrappers/from_torchhub.py b/src/eva/core/models/wrappers/from_torchhub.py
new file mode 100644
index 00000000..0fdb0c07
--- /dev/null
+++ b/src/eva/core/models/wrappers/from_torchhub.py
@@ -0,0 +1,87 @@
+"""Model wrapper for torch.hub models."""
+
+from typing import Any, Callable, Dict, Tuple
+
+import torch
+import torch.nn as nn
+from typing_extensions import override
+
+from eva.core.models import wrappers
+from eva.core.models.wrappers import _utils
+
+
+class TorchHubModel(wrappers.BaseModel):
+    """Model wrapper for `torch.hub` models."""
+
+    def __init__(
+        self,
+        model_name: str,
+        repo_or_dir: str,
+        pretrained: bool = True,
+        checkpoint_path: str = "",
+        out_indices: int | Tuple[int, ...] | None = None,
+        norm: bool = False,
+        trust_repo: bool = True,
+        model_kwargs: Dict[str, Any] | None = None,
+        tensor_transforms: Callable | None = None,
+    ) -> None:
+        """Initializes the encoder.
+
+        Args:
+            model_name: Name of model to instantiate.
+            repo_or_dir: The torch.hub repository or local directory to load the model from.
+            pretrained: If set to `True`, load pretrained model weights if available.
+            checkpoint_path: Path of checkpoint to load.
+            out_indices: Returns the last n blocks if `int`, the blocks at the given
+                indices if a sequence, or the model's default output if `None`.
+            norm: Whether to apply a norm layer to all intermediate features. Only
+                used when `out_indices` is not `None`.
+            trust_repo: If set to `False`, a prompt will ask the user whether the
+                repo should be trusted.
+            model_kwargs: Extra model arguments.
+            tensor_transforms: The transforms to apply to the output tensor
+                produced by the model.
+        """
+        super().__init__(tensor_transforms=tensor_transforms)
+
+        self._model_name = model_name
+        self._repo_or_dir = repo_or_dir
+        self._pretrained = pretrained
+        self._checkpoint_path = checkpoint_path
+        self._out_indices = out_indices
+        self._norm = norm
+        self._trust_repo = trust_repo
+        self._model_kwargs = model_kwargs or {}
+
+        self.load_model()
+
+    @override
+    def load_model(self) -> None:
+        """Builds and loads the timm model as feature extractor."""
+        self._model: nn.Module = torch.hub.load(
+            repo_or_dir=self._repo_or_dir,
+            model=self._model_name,
+            trust_repo=self._trust_repo,
+            pretrained=self._pretrained,
+            **self._model_kwargs,
+        )  # type: ignore
+
+        if self._checkpoint_path:
+            _utils.load_model_weights(self._model, self._checkpoint_path)
+
+        TorchHubModel.__name__ = self._model_name
+
+    @override
+    def model_forward(self, tensor: torch.Tensor) -> torch.Tensor:
+        if self._out_indices is not None:
+            if not hasattr(self._model, "get_intermediate_layers"):
+                raise ValueError(
+                    "Only models with `get_intermediate_layers` are supported "
+                    "when using `out_indices`."
+                )
+
+            return self._model.get_intermediate_layers(
+                tensor, self._out_indices, reshape=True, return_class_token=False, norm=self._norm
+            )
+
+        return self._model(tensor)
diff --git a/src/eva/vision/models/wrappers/__init__.py b/src/eva/vision/models/wrappers/__init__.py
index 14d63b68..d2f84de4 100644
--- a/src/eva/vision/models/wrappers/__init__.py
+++ b/src/eva/vision/models/wrappers/__init__.py
@@ -3,4 +3,4 @@
 from eva.vision.models.wrappers.from_registry import ModelFromRegistry
 from eva.vision.models.wrappers.from_timm import TimmModel
 
-__all__ = ["TimmModel", "ModelFromRegistry"]
+__all__ = ["ModelFromRegistry", "TimmModel"]

From a32340fa929433a521f75da216cdd72c4a6971b8 Mon Sep 17 00:00:00 2001
From: Nicolas Kaenzig
Date: Fri, 29 Nov 2024 15:35:15 +0100
Subject: [PATCH 2/7] add torch.hub models to model registry

---
 src/eva/core/models/wrappers/__init__.py      |  2 +-
 .../models/networks/backbones/__init__.py     |  4 +-
 .../networks/backbones/torchhub/__init__.py   |  5 ++
 .../networks/backbones/torchhub/backbones.py  | 61 +++++++++++++++++++
 4 files changed, 69 insertions(+), 3 deletions(-)
 create mode 100644 src/eva/vision/models/networks/backbones/torchhub/__init__.py
 create mode 100644 src/eva/vision/models/networks/backbones/torchhub/backbones.py

diff --git a/src/eva/core/models/wrappers/__init__.py b/src/eva/core/models/wrappers/__init__.py
index 4b8559ff..979577bd 100644
--- a/src/eva/core/models/wrappers/__init__.py
+++ b/src/eva/core/models/wrappers/__init__.py
@@ -2,9 +2,9 @@
 
 from eva.core.models.wrappers.base import BaseModel
 from eva.core.models.wrappers.from_function import ModelFromFunction
+from eva.core.models.wrappers.from_torchhub import TorchHubModel
 from eva.core.models.wrappers.huggingface import HuggingFaceModel
 from eva.core.models.wrappers.onnx import ONNXModel
-from eva.core.models.wrappers.from_torchhub import TorchHubModel
 
 __all__ = [
     "BaseModel",
diff --git a/src/eva/vision/models/networks/backbones/__init__.py b/src/eva/vision/models/networks/backbones/__init__.py
index 0fdf2963..1ef7bc85 100644
--- a/src/eva/vision/models/networks/backbones/__init__.py
+++ b/src/eva/vision/models/networks/backbones/__init__.py
@@ -1,6 +1,6 @@
 """Vision Model Backbones API."""
 
-from eva.vision.models.networks.backbones import pathology, timm, universal
+from eva.vision.models.networks.backbones import pathology, timm, torchhub, universal
 from eva.vision.models.networks.backbones.registry import BackboneModelRegistry, register_model
 
-__all__ = ["pathology", "timm", "universal", "BackboneModelRegistry", "register_model"]
+__all__ = ["pathology", "timm", "torchhub", "universal", "BackboneModelRegistry", "register_model"]
diff --git a/src/eva/vision/models/networks/backbones/torchhub/__init__.py b/src/eva/vision/models/networks/backbones/torchhub/__init__.py
new file mode 100644
index 00000000..6acd9797
--- /dev/null
+++ b/src/eva/vision/models/networks/backbones/torchhub/__init__.py
@@ -0,0 +1,5 @@
+"""torch.hub backbones API."""
+
+from eva.vision.models.networks.backbones.torchhub.backbones import torch_hub_model
+
+__all__ = ["torch_hub_model"]
diff --git a/src/eva/vision/models/networks/backbones/torchhub/backbones.py b/src/eva/vision/models/networks/backbones/torchhub/backbones.py
new file mode 100644
index 00000000..a8162530
--- /dev/null
+++ b/src/eva/vision/models/networks/backbones/torchhub/backbones.py
@@ -0,0 +1,61 @@
+"""torch.hub backbones."""
+
+import functools
+from typing import Tuple
+
+import torch
+from loguru import logger
+from torch import nn
+
+from eva.core.models import wrappers
+from eva.vision.models.networks.backbones.registry import BackboneModelRegistry
+
+HUB_REPOS = ["facebookresearch/dinov2:main", "kaiko-ai/towards_large_pathology_fms"]
+"""List of torch.hub repositories for which to add the models to the registry."""
+
+
+def torch_hub_model(
+    model_name: str,
+    repo_or_dir: str,
+    checkpoint_path: str | None = None,
+    pretrained: bool = False,
+    out_indices: int | Tuple[int, ...] | None = None,
+    **kwargs,
+) -> nn.Module:
+    """Initializes any ViT model from torch.hub with weights from a specified checkpoint.
+
+    Args:
+        model_name: The name of the model to load.
+        repo_or_dir: The torch.hub repository or local directory to load the model from.
+        checkpoint_path: The path to the checkpoint file.
+        pretrained: If set to `True`, load pretrained ImageNet-1k weights.
+        out_indices: Whether and which multi-level patch embeddings to return.
+        **kwargs: Additional arguments to pass to the model
+
+    Returns:
+        The ViT model instance.
+    """
+    logger.info(
+        f"Loading torch.hub model {model_name} from {repo_or_dir}"
+        + (f" using checkpoint {checkpoint_path}" if checkpoint_path else "")
+    )
+
+    return wrappers.TorchHubModel(
+        model_name=model_name,
+        repo_or_dir=repo_or_dir,
+        pretrained=pretrained,
+        checkpoint_path=checkpoint_path or "",
+        out_indices=out_indices,
+        model_kwargs=kwargs,
+    )
+
+
+BackboneModelRegistry._registry.update(
+    {
+        f"torchhub/{repo}:{model_name}": functools.partial(
+            torch_hub_model, model_name=model_name, repo_or_dir=repo
+        )
+        for repo in HUB_REPOS
+        for model_name in torch.hub.list(repo, verbose=False)
+    }
+)

From 5b30e014c0152d9626e872fa4391310507cf7e9e Mon Sep 17 00:00:00 2001
From: Nicolas Kaenzig
Date: Fri, 29 Nov 2024 15:45:45 +0100
Subject: [PATCH 3/7] added unit test

---
 .../core/models/wrappers/test_from_torchub.py | 76 +++++++++++++++++++
 1 file changed, 76 insertions(+)
 create mode 100644 tests/eva/core/models/wrappers/test_from_torchub.py

diff --git a/tests/eva/core/models/wrappers/test_from_torchub.py b/tests/eva/core/models/wrappers/test_from_torchub.py
new file mode 100644
index 00000000..cfbfabc9
--- /dev/null
+++ b/tests/eva/core/models/wrappers/test_from_torchub.py
@@ -0,0 +1,76 @@
+"""TimmModel tests."""
+
+from typing import Any, Dict, Tuple
+
+import pytest
+import torch
+
+from eva.core.models import wrappers
+
+
+@pytest.mark.parametrize(
+    "model_name, repo_or_dir, out_indices, model_kwargs, "
+    "input_tensor, expected_len, expected_shape",
+    [
+        (
+            "dinov2_vits14",
+            "facebookresearch/dinov2:main",
+            None,
+            None,
+            torch.Tensor(2, 3, 224, 224),
+            None,
+            torch.Size([2, 384]),
+        ),
+        (
+            "dinov2_vits14",
+            "facebookresearch/dinov2:main",
+            1,
+            None,
+            torch.Tensor(2, 3, 224, 224),
+            1,
+            torch.Size([2, 384, 16, 16]),
+        ),
+        (
+            "dinov2_vits14",
+            "facebookresearch/dinov2:main",
+            3,
+            None,
+            torch.Tensor(2, 3, 224, 224),
+            3,
+            torch.Size([2, 384, 16, 16]),
+        ),
+    ],
+)
+def test_torchhub_model(
+    torchhub_model: wrappers.TorchHubModel,
+    input_tensor: torch.Tensor,
+    expected_len: int | None,
+    expected_shape: torch.Size,
+) -> None:
+    """Tests the torch.hub model wrapper."""
+    outputs = torchhub_model(input_tensor)
+    if torchhub_model._out_indices is not None:
+        assert isinstance(outputs, list) or isinstance(outputs, tuple)
+        assert len(outputs) == expected_len
+        assert isinstance(outputs[0], torch.Tensor)
+        assert outputs[0].shape == expected_shape
+    else:
+        assert isinstance(outputs, torch.Tensor)
+        assert outputs.shape == expected_shape
+
+
+@pytest.fixture(scope="function")
+def torchhub_model(
+    model_name: str,
+    repo_or_dir: str,
+    out_indices: int | Tuple[int, ...] | None,
+    model_kwargs: Dict[str, Any] | None,
+) -> wrappers.TorchHubModel:
+    """TimmModel fixture."""
+    return wrappers.TorchHubModel(
+        model_name=model_name,
+        repo_or_dir=repo_or_dir,
+        out_indices=out_indices,
+        model_kwargs=model_kwargs,
+        pretrained=False,
+    )

From 9e69f86c01cbe64bf06f2c471331807a727a0eb4 Mon Sep 17 00:00:00 2001
From: Nicolas Kaenzig
Date: Fri, 29 Nov 2024 15:46:49 +0100
Subject: [PATCH 4/7] reverted change to lock file

---
 pdm.lock | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pdm.lock b/pdm.lock
index 826a205e..21faa341 100644
--- a/pdm.lock
+++ b/pdm.lock
@@ -5,7 +5,7 @@
 groups = ["default", "all", "dev", "docs", "lint", "test", "typecheck", "vision"]
 strategy = ["inherit_metadata"]
 lock_version = "4.5.0"
-content_hash = "sha256:f1cb189ed1b12e6aa22ba372e4b934e29afefdd17894792c86863bb62d15217f"
+content_hash = "sha256:b8df35bf60e5573e36c31c4ad4f324d7693f16b31cadcd27e48b352ae6c0235b"
 
 [[metadata.targets]]
 requires_python = ">=3.10"

From a92f59061a988280e87b0eefc8ed213d2bc96e43 Mon Sep 17 00:00:00 2001
From: Nicolas Kaenzig
Date: Fri, 29 Nov 2024 15:47:53 +0100
Subject: [PATCH 5/7] update docstring

---
 src/eva/vision/models/networks/backbones/torchhub/backbones.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/eva/vision/models/networks/backbones/torchhub/backbones.py b/src/eva/vision/models/networks/backbones/torchhub/backbones.py
index a8162530..d1503a80 100644
--- a/src/eva/vision/models/networks/backbones/torchhub/backbones.py
+++ b/src/eva/vision/models/networks/backbones/torchhub/backbones.py
@@ -28,7 +28,7 @@ def torch_hub_model(
         model_name: The name of the model to load.
         repo_or_dir: The torch.hub repository or local directory to load the model from.
        checkpoint_path: The path to the checkpoint file.
-        pretrained: If set to `True`, load pretrained ImageNet-1k weights.
+        pretrained: If set to `True`, load pretrained model weights if available.
         out_indices: Whether and which multi-level patch embeddings to return.
         **kwargs: Additional arguments to pass to the model
 

From 840bf4deb77d1e28f3f5519ec5756db9512c667f Mon Sep 17 00:00:00 2001
From: Nicolas Kaenzig
Date: Fri, 29 Nov 2024 16:06:32 +0100
Subject: [PATCH 6/7] update docstrings

---
 src/eva/core/models/wrappers/from_torchhub.py       | 2 +-
 tests/eva/core/models/wrappers/test_from_torchub.py | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/eva/core/models/wrappers/from_torchhub.py b/src/eva/core/models/wrappers/from_torchhub.py
index 0fdb0c07..cb424d01 100644
--- a/src/eva/core/models/wrappers/from_torchhub.py
+++ b/src/eva/core/models/wrappers/from_torchhub.py
@@ -57,7 +57,7 @@ def __init__(
 
     @override
     def load_model(self) -> None:
-        """Builds and loads the timm model as feature extractor."""
+        """Builds and loads the torch.hub model."""
         self._model: nn.Module = torch.hub.load(
             repo_or_dir=self._repo_or_dir,
             model=self._model_name,
diff --git a/tests/eva/core/models/wrappers/test_from_torchub.py b/tests/eva/core/models/wrappers/test_from_torchub.py
index cfbfabc9..bf275234 100644
--- a/tests/eva/core/models/wrappers/test_from_torchub.py
+++ b/tests/eva/core/models/wrappers/test_from_torchub.py
@@ -1,4 +1,4 @@
-"""TimmModel tests."""
+"""TorchHubModel tests."""
 
 from typing import Any, Dict, Tuple
 
@@ -66,7 +66,7 @@ def torchhub_model(
     out_indices: int | Tuple[int, ...] | None,
     model_kwargs: Dict[str, Any] | None,
 ) -> wrappers.TorchHubModel:
-    """TimmModel fixture."""
+    """TorchHubModel fixture."""
     return wrappers.TorchHubModel(
         model_name=model_name,
         repo_or_dir=repo_or_dir,

From 8d659e1c74d928556efebe96804c630a8a01a859 Mon Sep 17 00:00:00 2001
From: Nicolas Kaenzig
Date: Fri, 29 Nov 2024 16:08:01 +0100
Subject: [PATCH 7/7] fixed isinstance check

---
 .../vision/models/networks/decoders/segmentation/decoder2d.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/eva/vision/models/networks/decoders/segmentation/decoder2d.py b/src/eva/vision/models/networks/decoders/segmentation/decoder2d.py
index c43b351c..ce242713 100644
--- a/src/eva/vision/models/networks/decoders/segmentation/decoder2d.py
+++ b/src/eva/vision/models/networks/decoders/segmentation/decoder2d.py
@@ -52,7 +52,7 @@ def _forward_features(self, features: torch.Tensor | List[torch.Tensor]) -> torc
         """
         if isinstance(features, torch.Tensor):
             features = [features]
-        if not isinstance(features, list) or features[0].ndim != 4:
+        if not isinstance(features, (list, tuple)) or features[0].ndim != 4:
             raise ValueError(
                 "Input features should be a list of four (4) dimensional inputs of "
                 "shape (batch_size, hidden_size, n_patches_height, n_patches_width)."
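
Usage sketch (illustrative only, not part of the patch series): the snippet below shows how the
new TorchHubModel wrapper and the torch_hub_model helper introduced above could be exercised.
It assumes network access for the torch.hub download; the model name, repository, and expected
output shapes are taken from HUB_REPOS and the unit test in this series.

    import torch

    from eva.core.models import wrappers
    from eva.vision.models.networks.backbones.torchhub import torch_hub_model

    # DINOv2 ViT-S/14 through the wrapper; the default forward path returns
    # the model's standard output, torch.Size([2, 384]) per the unit test.
    model = wrappers.TorchHubModel(
        model_name="dinov2_vits14",
        repo_or_dir="facebookresearch/dinov2:main",
        pretrained=True,
    )
    embeddings = model(torch.randn(2, 3, 224, 224))

    # With out_indices set, the wrapper calls get_intermediate_layers and returns
    # reshaped feature maps instead: here a tuple with one tensor of shape [2, 384, 16, 16].
    backbone = wrappers.TorchHubModel(
        model_name="dinov2_vits14",
        repo_or_dir="facebookresearch/dinov2:main",
        pretrained=True,
        out_indices=1,
    )
    feature_maps = backbone(torch.randn(2, 3, 224, 224))

    # Equivalent construction through the helper added in backbones.py.
    registered = torch_hub_model(
        model_name="dinov2_vits14",
        repo_or_dir="facebookresearch/dinov2:main",
        pretrained=True,
    )

The same backbones are added to BackboneModelRegistry under keys of the form
"torchhub/{repo}:{model_name}", e.g. "torchhub/facebookresearch/dinov2:main:dinov2_vits14",
for every repository listed in HUB_REPOS.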