Add TorchHubModel model wrapper #721

Merged · 7 commits · Nov 29, 2024
4 changes: 3 additions & 1 deletion src/eva/core/models/wrappers/__init__.py
@@ -2,12 +2,14 @@

from eva.core.models.wrappers.base import BaseModel
from eva.core.models.wrappers.from_function import ModelFromFunction
from eva.core.models.wrappers.from_torchhub import TorchHubModel
from eva.core.models.wrappers.huggingface import HuggingFaceModel
from eva.core.models.wrappers.onnx import ONNXModel

__all__ = [
"BaseModel",
"ModelFromFunction",
"HuggingFaceModel",
"ModelFromFunction",
"ONNXModel",
"TorchHubModel",
]
87 changes: 87 additions & 0 deletions src/eva/core/models/wrappers/from_torchhub.py
@@ -0,0 +1,87 @@
"""Model wrapper for torch.hub models."""

from typing import Any, Callable, Dict, Tuple

import torch
import torch.nn as nn
from typing_extensions import override

from eva.core.models import wrappers
from eva.core.models.wrappers import _utils


class TorchHubModel(wrappers.BaseModel):
"""Model wrapper for `torch.hub` models."""

def __init__(
self,
model_name: str,
repo_or_dir: str,
pretrained: bool = True,
checkpoint_path: str = "",
out_indices: int | Tuple[int, ...] | None = None,
norm: bool = False,
trust_repo: bool = True,
model_kwargs: Dict[str, Any] | None = None,
tensor_transforms: Callable | None = None,
) -> None:
"""Initializes the encoder.

Args:
model_name: Name of model to instantiate.
repo_or_dir: The torch.hub repository or local directory to load the model from.
pretrained: If set to `True`, load pretrained ImageNet-1k weights.
checkpoint_path: Path of checkpoint to load.
out_indices: Returns the last `n` blocks if `int`, or the blocks at the
matching indices if a sequence. If `None`, a regular forward pass is
performed and the final output is returned.
norm: Whether to apply a norm layer to all intermediate features. Only
used when `out_indices` is not `None`.
trust_repo: If set to `False`, a prompt will ask the user whether the
repo should be trusted.
model_kwargs: Extra model arguments.
tensor_transforms: The transforms to apply to the output tensor
produced by the model.
"""
super().__init__(tensor_transforms=tensor_transforms)

self._model_name = model_name
self._repo_or_dir = repo_or_dir
self._pretrained = pretrained
self._checkpoint_path = checkpoint_path
self._out_indices = out_indices
self._norm = norm
self._trust_repo = trust_repo
self._model_kwargs = model_kwargs or {}

self.load_model()

@override
def load_model(self) -> None:
"""Builds and loads the torch.hub model."""
self._model: nn.Module = torch.hub.load(
repo_or_dir=self._repo_or_dir,
model=self._model_name,
trust_repo=self._trust_repo,
pretrained=self._pretrained,
**self._model_kwargs,
) # type: ignore

if self._checkpoint_path:
_utils.load_model_weights(self._model, self._checkpoint_path)

TorchHubModel.__name__ = self._model_name

@override
def model_forward(self, tensor: torch.Tensor) -> torch.Tensor:
if self._out_indices is not None:
if not hasattr(self._model, "get_intermediate_layers"):
raise ValueError(
"Only models with `get_intermediate_layers` are supported "
"when using `out_indices`."
)

return self._model.get_intermediate_layers(
tensor, self._out_indices, reshape=True, return_class_token=False, norm=self._norm
)

return self._model(tensor)
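
A minimal usage sketch of the new wrapper, assuming the `dinov2_vits14` entry point from `facebookresearch/dinov2:main` (the same model exercised by the tests below); loading still fetches the repository definition via torch.hub:

import torch

from eva.core.models.wrappers import TorchHubModel

# Sketch only: wrap a torch.hub entry point.
model = TorchHubModel(
    model_name="dinov2_vits14",
    repo_or_dir="facebookresearch/dinov2:main",
    pretrained=False,  # skip downloading pretrained weights for this sketch
    out_indices=1,     # return the last intermediate feature map
)

inputs = torch.rand(2, 3, 224, 224)
features = model(inputs)  # per the tests below: a tuple with one (2, 384, 16, 16) tensor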
4 changes: 2 additions & 2 deletions src/eva/vision/models/networks/backbones/__init__.py
@@ -1,6 +1,6 @@
"""Vision Model Backbones API."""

from eva.vision.models.networks.backbones import pathology, timm, universal
from eva.vision.models.networks.backbones import pathology, timm, torchhub, universal
from eva.vision.models.networks.backbones.registry import BackboneModelRegistry, register_model

__all__ = ["pathology", "timm", "universal", "BackboneModelRegistry", "register_model"]
__all__ = ["pathology", "timm", "torchhub", "universal", "BackboneModelRegistry", "register_model"]
5 changes: 5 additions & 0 deletions src/eva/vision/models/networks/backbones/torchhub/__init__.py
@@ -0,0 +1,5 @@
"""torch.hub backbones API."""

from eva.vision.models.networks.backbones.torchhub.backbones import torch_hub_model

__all__ = ["torch_hub_model"]
61 changes: 61 additions & 0 deletions src/eva/vision/models/networks/backbones/torchhub/backbones.py
@@ -0,0 +1,61 @@
"""torch.hub backbones."""

import functools
from typing import Tuple

import torch
from loguru import logger
from torch import nn

from eva.core.models import wrappers
from eva.vision.models.networks.backbones.registry import BackboneModelRegistry

HUB_REPOS = ["facebookresearch/dinov2:main", "kaiko-ai/towards_large_pathology_fms"]
"""List of torch.hub repositories for which to add the models to the registry."""


def torch_hub_model(
model_name: str,
repo_or_dir: str,
checkpoint_path: str | None = None,
pretrained: bool = False,
out_indices: int | Tuple[int, ...] | None = None,
**kwargs,
) -> nn.Module:
"""Initializes any ViT model from torch.hub with weights from a specified checkpoint.

Args:
model_name: The name of the model to load.
repo_or_dir: The torch.hub repository or local directory to load the model from.
checkpoint_path: The path to the checkpoint file.
pretrained: If set to `True`, load pretrained model weights if available.
out_indices: Whether and which multi-level patch embeddings to return.
**kwargs: Additional arguments to pass to the model.

Returns:
The ViT model instance.
"""
logger.info(
f"Loading torch.hub model {model_name} from {repo_or_dir}"
+ (f"using checkpoint {checkpoint_path}" if checkpoint_path else "")
)

return wrappers.TorchHubModel(
model_name=model_name,
repo_or_dir=repo_or_dir,
pretrained=pretrained,
checkpoint_path=checkpoint_path or "",
out_indices=out_indices,
model_kwargs=kwargs,
)


BackboneModelRegistry._registry.update(
{
f"torchhub/{repo}:{model_name}": functools.partial(
torch_hub_model, model_name=model_name, repo_or_dir=repo
)
for repo in HUB_REPOS
for model_name in torch.hub.list(repo, verbose=False)
}
)
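
A short sketch of calling the helper above directly; the registry entries it adds follow the `torchhub/<repo>:<model_name>` pattern used in the dict comprehension, e.g. `torchhub/facebookresearch/dinov2:main:dinov2_vits14` (assuming that repository is reachable from torch.hub):

from eva.vision.models.networks.backbones.torchhub import torch_hub_model

# Sketch only: equivalent to the partials registered above for the DINOv2 repo.
backbone = torch_hub_model(
    model_name="dinov2_vits14",
    repo_or_dir="facebookresearch/dinov2:main",
    pretrained=False,
    out_indices=1,
)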
@@ -52,7 +52,7 @@ def _forward_features(self, features: torch.Tensor | List[torch.Tensor]) -> torc
"""
if isinstance(features, torch.Tensor):
features = [features]
if not isinstance(features, list) or features[0].ndim != 4:
if not isinstance(features, (list, tuple)) or features[0].ndim != 4:
raise ValueError(
"Input features should be a list of four (4) dimensional inputs of "
"shape (batch_size, hidden_size, n_patches_height, n_patches_width)."
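
For context, the check above now accepts tuples because `get_intermediate_layers`, as called in `TorchHubModel.model_forward`, returns a tuple of 4D feature maps rather than a list. A minimal illustration of the two accepted container types, assuming ViT-S/14 features at 224x224 input:

import torch

# Each element must be 4D: (batch_size, hidden_size, n_patches_height, n_patches_width).
features_as_tuple = (torch.rand(2, 384, 16, 16),)  # e.g. the output of get_intermediate_layers
features_as_list = [torch.rand(2, 384, 16, 16)]    # e.g. manually collected feature maps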
2 changes: 1 addition & 1 deletion src/eva/vision/models/wrappers/__init__.py
@@ -3,4 +3,4 @@
from eva.vision.models.wrappers.from_registry import ModelFromRegistry
from eva.vision.models.wrappers.from_timm import TimmModel

__all__ = ["TimmModel", "ModelFromRegistry"]
__all__ = ["ModelFromRegistry", "TimmModel"]
76 changes: 76 additions & 0 deletions tests/eva/core/models/wrappers/test_from_torchub.py
@@ -0,0 +1,76 @@
"""TorchHubModel tests."""

from typing import Any, Dict, Tuple

import pytest
import torch

from eva.core.models import wrappers


@pytest.mark.parametrize(
"model_name, repo_or_dir, out_indices, model_kwargs, "
"input_tensor, expected_len, expected_shape",
[
(
"dinov2_vits14",
"facebookresearch/dinov2:main",
None,
None,
torch.Tensor(2, 3, 224, 224),
None,
torch.Size([2, 384]),
),
(
"dinov2_vits14",
"facebookresearch/dinov2:main",
1,
None,
torch.Tensor(2, 3, 224, 224),
1,
torch.Size([2, 384, 16, 16]),
),
(
"dinov2_vits14",
"facebookresearch/dinov2:main",
3,
None,
torch.Tensor(2, 3, 224, 224),
3,
torch.Size([2, 384, 16, 16]),
),
],
)
def test_torchhub_model(
torchhub_model: wrappers.TorchHubModel,
input_tensor: torch.Tensor,
expected_len: int | None,
expected_shape: torch.Size,
) -> None:
"""Tests the torch.hub model wrapper."""
outputs = torchhub_model(input_tensor)
if torchhub_model._out_indices is not None:
assert isinstance(outputs, (list, tuple))
assert len(outputs) == expected_len
assert isinstance(outputs[0], torch.Tensor)
assert outputs[0].shape == expected_shape
else:
assert isinstance(outputs, torch.Tensor)
assert outputs.shape == expected_shape


@pytest.fixture(scope="function")
def torchhub_model(
model_name: str,
repo_or_dir: str,
out_indices: int | Tuple[int, ...] | None,
model_kwargs: Dict[str, Any] | None,
) -> wrappers.TorchHubModel:
"""TorchHubModel fixture."""
return wrappers.TorchHubModel(
model_name=model_name,
repo_or_dir=repo_or_dir,
out_indices=out_indices,
model_kwargs=model_kwargs,
pretrained=False,
)