-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add
TorchHubModel
model wrapper (#721)
- Loading branch information
Showing
8 changed files
with
236 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
"""Model wrapper for torch.hub models.""" | ||
|
||
from typing import Any, Callable, Dict, Tuple | ||
|
||
import torch | ||
import torch.nn as nn | ||
from typing_extensions import override | ||
|
||
from eva.core.models import wrappers | ||
from eva.core.models.wrappers import _utils | ||
|
||
|
||
class TorchHubModel(wrappers.BaseModel): | ||
"""Model wrapper for `torch.hub` models.""" | ||
|
||
def __init__( | ||
self, | ||
model_name: str, | ||
repo_or_dir: str, | ||
pretrained: bool = True, | ||
checkpoint_path: str = "", | ||
out_indices: int | Tuple[int, ...] | None = None, | ||
norm: bool = False, | ||
trust_repo: bool = True, | ||
model_kwargs: Dict[str, Any] | None = None, | ||
tensor_transforms: Callable | None = None, | ||
) -> None: | ||
"""Initializes the encoder. | ||
Args: | ||
model_name: Name of model to instantiate. | ||
repo_or_dir: The torch.hub repository or local directory to load the model from. | ||
pretrained: If set to `True`, load pretrained ImageNet-1k weights. | ||
checkpoint_path: Path of checkpoint to load. | ||
out_indices: Returns last n blocks if `int`, all if `None`, select | ||
matching indices if sequence. | ||
norm: Wether to apply norm layer to all intermediate features. Only | ||
used when `out_indices` is not `None`. | ||
trust_repo: If set to `False`, a prompt will ask the user whether the | ||
repo should be trusted. | ||
model_kwargs: Extra model arguments. | ||
tensor_transforms: The transforms to apply to the output tensor | ||
produced by the model. | ||
""" | ||
super().__init__(tensor_transforms=tensor_transforms) | ||
|
||
self._model_name = model_name | ||
self._repo_or_dir = repo_or_dir | ||
self._pretrained = pretrained | ||
self._checkpoint_path = checkpoint_path | ||
self._out_indices = out_indices | ||
self._norm = norm | ||
self._trust_repo = trust_repo | ||
self._model_kwargs = model_kwargs or {} | ||
|
||
self.load_model() | ||
|
||
@override | ||
def load_model(self) -> None: | ||
"""Builds and loads the torch.hub model.""" | ||
self._model: nn.Module = torch.hub.load( | ||
repo_or_dir=self._repo_or_dir, | ||
model=self._model_name, | ||
trust_repo=self._trust_repo, | ||
pretrained=self._pretrained, | ||
**self._model_kwargs, | ||
) # type: ignore | ||
|
||
if self._checkpoint_path: | ||
_utils.load_model_weights(self._model, self._checkpoint_path) | ||
|
||
TorchHubModel.__name__ = self._model_name | ||
|
||
@override | ||
def model_forward(self, tensor: torch.Tensor) -> torch.Tensor: | ||
if self._out_indices is not None: | ||
if not hasattr(self._model, "get_intermediate_layers"): | ||
raise ValueError( | ||
"Only models with `get_intermediate_layers` are supported " | ||
"when using `out_indices`." | ||
) | ||
|
||
return self._model.get_intermediate_layers( | ||
tensor, self._out_indices, reshape=True, return_class_token=False, norm=self._norm | ||
) | ||
|
||
return self._model(tensor) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
"""Vision Model Backbones API.""" | ||
|
||
from eva.vision.models.networks.backbones import pathology, timm, universal | ||
from eva.vision.models.networks.backbones import pathology, timm, torchhub, universal | ||
from eva.vision.models.networks.backbones.registry import BackboneModelRegistry, register_model | ||
|
||
__all__ = ["pathology", "timm", "universal", "BackboneModelRegistry", "register_model"] | ||
__all__ = ["pathology", "timm", "torchhub", "universal", "BackboneModelRegistry", "register_model"] |
5 changes: 5 additions & 0 deletions
5
src/eva/vision/models/networks/backbones/torchhub/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
"""torch.hub backbones API.""" | ||
|
||
from eva.vision.models.networks.backbones.torchhub.backbones import torch_hub_model | ||
|
||
__all__ = ["torch_hub_model"] |
61 changes: 61 additions & 0 deletions
61
src/eva/vision/models/networks/backbones/torchhub/backbones.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
"""torch.hub backbones.""" | ||
|
||
import functools | ||
from typing import Tuple | ||
|
||
import torch | ||
from loguru import logger | ||
from torch import nn | ||
|
||
from eva.core.models import wrappers | ||
from eva.vision.models.networks.backbones.registry import BackboneModelRegistry | ||
|
||
HUB_REPOS = ["facebookresearch/dinov2:main", "kaiko-ai/towards_large_pathology_fms"] | ||
"""List of torch.hub repositories for which to add the models to the registry.""" | ||
|
||
|
||
def torch_hub_model( | ||
model_name: str, | ||
repo_or_dir: str, | ||
checkpoint_path: str | None = None, | ||
pretrained: bool = False, | ||
out_indices: int | Tuple[int, ...] | None = None, | ||
**kwargs, | ||
) -> nn.Module: | ||
"""Initializes any ViT model from torch.hub with weights from a specified checkpoint. | ||
Args: | ||
model_name: The name of the model to load. | ||
repo_or_dir: The torch.hub repository or local directory to load the model from. | ||
checkpoint_path: The path to the checkpoint file. | ||
pretrained: If set to `True`, load pretrained model weights if available. | ||
out_indices: Whether and which multi-level patch embeddings to return. | ||
**kwargs: Additional arguments to pass to the model | ||
Returns: | ||
The VIT model instance. | ||
""" | ||
logger.info( | ||
f"Loading torch.hub model {model_name} from {repo_or_dir}" | ||
+ (f"using checkpoint {checkpoint_path}" if checkpoint_path else "") | ||
) | ||
|
||
return wrappers.TorchHubModel( | ||
model_name=model_name, | ||
repo_or_dir=repo_or_dir, | ||
pretrained=pretrained, | ||
checkpoint_path=checkpoint_path or "", | ||
out_indices=out_indices, | ||
model_kwargs=kwargs, | ||
) | ||
|
||
|
||
BackboneModelRegistry._registry.update( | ||
{ | ||
f"torchhub/{repo}:{model_name}": functools.partial( | ||
torch_hub_model, model_name=model_name, repo_or_dir=repo | ||
) | ||
for repo in HUB_REPOS | ||
for model_name in torch.hub.list(repo, verbose=False) | ||
} | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
"""TorchHubModel tests.""" | ||
|
||
from typing import Any, Dict, Tuple | ||
|
||
import pytest | ||
import torch | ||
|
||
from eva.core.models import wrappers | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"model_name, repo_or_dir, out_indices, model_kwargs, " | ||
"input_tensor, expected_len, expected_shape", | ||
[ | ||
( | ||
"dinov2_vits14", | ||
"facebookresearch/dinov2:main", | ||
None, | ||
None, | ||
torch.Tensor(2, 3, 224, 224), | ||
None, | ||
torch.Size([2, 384]), | ||
), | ||
( | ||
"dinov2_vits14", | ||
"facebookresearch/dinov2:main", | ||
1, | ||
None, | ||
torch.Tensor(2, 3, 224, 224), | ||
1, | ||
torch.Size([2, 384, 16, 16]), | ||
), | ||
( | ||
"dinov2_vits14", | ||
"facebookresearch/dinov2:main", | ||
3, | ||
None, | ||
torch.Tensor(2, 3, 224, 224), | ||
3, | ||
torch.Size([2, 384, 16, 16]), | ||
), | ||
], | ||
) | ||
def test_torchhub_model( | ||
torchhub_model: wrappers.TorchHubModel, | ||
input_tensor: torch.Tensor, | ||
expected_len: int | None, | ||
expected_shape: torch.Size, | ||
) -> None: | ||
"""Tests the torch.hub model wrapper.""" | ||
outputs = torchhub_model(input_tensor) | ||
if torchhub_model._out_indices is not None: | ||
assert isinstance(outputs, list) or isinstance(outputs, tuple) | ||
assert len(outputs) == expected_len | ||
assert isinstance(outputs[0], torch.Tensor) | ||
assert outputs[0].shape == expected_shape | ||
else: | ||
assert isinstance(outputs, torch.Tensor) | ||
assert outputs.shape == expected_shape | ||
|
||
|
||
@pytest.fixture(scope="function") | ||
def torchhub_model( | ||
model_name: str, | ||
repo_or_dir: str, | ||
out_indices: int | Tuple[int, ...] | None, | ||
model_kwargs: Dict[str, Any] | None, | ||
) -> wrappers.TorchHubModel: | ||
"""TorchHubModel fixture.""" | ||
return wrappers.TorchHubModel( | ||
model_name=model_name, | ||
repo_or_dir=repo_or_dir, | ||
out_indices=out_indices, | ||
model_kwargs=model_kwargs, | ||
pretrained=False, | ||
) |