-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
11 changed files
with
112 additions
and
85 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
# Advanced user guide | ||
|
||
### [Replicate Evaluations](replicate_evaluations.md) | ||
Find the instructions to completely replicate the evaluation results presented [on the ***eva*** main page](../../index.md) | ||
|
||
### [Model Wrappers](model_wrappers.md) | ||
Explains how to use **eva**'s Model Wrapper API to load models from different formats and sources. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
# Model Wrappers | ||
|
||
|
||
This document shows how to use **eva**'s [Model Wrapper API](../../../reference/models/networks/#wrappers) (`eva.models.networks.wrappers`) to load different model formats from a series of sources such as PyTorch Hub, HuggingFace Model Hub and ONNX. | ||
|
||
## Loading PyTorch models | ||
The **eva** framework is built on top of PyTorch Lightning and thus naturally supports loading PyTorch models. | ||
You just need to specify the class path of your model in the backbone section of the `.yaml` config file. | ||
|
||
``` | ||
backbone: | ||
class_path: path.to.your.ModelClass | ||
init_args: | ||
arg_1: ... | ||
arg_2: ... | ||
``` | ||
|
||
Note that your `ModelClass` should subclass `torch.nn.Module` and implement the `forward()` method to return embedding tensors of shape `[embedding_dim]`. | ||
|
||
### PyTorch Hub | ||
To load models from PyTorch Hub or other torch model providers, the easiest way is to use the `ModelFromFunction` wrapper class: | ||
|
||
``` | ||
backbone: | ||
class_path: eva.models.networks.wrappers.ModelFromFunction | ||
init_args: | ||
path: torch.hub.load | ||
arguments: | ||
repo_or_dir: facebookresearch/dino:main | ||
model: dino_vits16 | ||
pretrained: false | ||
checkpoint_path: path/to/your/checkpoint.torch | ||
``` | ||
|
||
|
||
Note that if a `checkpoint_path` is provided, `ModelFromFunction` will automatically initialize the specified model using the provided weights from that checkpoint file. | ||
|
||
|
||
### timm | ||
Similar to the above example, we can easily load models using the common vision library `timm`: | ||
``` | ||
backbone: | ||
class_path: eva.models.networks.wrappers.ModelFromFunction | ||
init_args: | ||
path: timm.create_model | ||
arguments: | ||
model_name: resnet18 | ||
pretrained: true | ||
``` | ||
|
||
|
||
## Loading models from HuggingFace Hub | ||
For loading models from HuggingFace Hub, **eva** provides a custom wrapper class `HuggingFaceModel` which can be used as follows: | ||
|
||
``` | ||
backbone: | ||
class_path: eva.models.networks.wrappers.HuggingFaceModel | ||
init_args: | ||
model_name_or_path: owkin/phikon | ||
tensor_transforms: | ||
class_path: eva.vision.data.transforms.model_output.ExtractCLSFeatures | ||
``` | ||
|
||
In the above example, the forward pass implemented by the `owkin/phikon` model returns an output tensor containing the hidden states of all input tokens. In order to extract the state corresponding to the CLS token only, we can specify a transformation via the `tensor_transforms` argument which will be applied to the model output. | ||
|
||
## Loading ONNX models | ||
`.onnx` model checkpoints can be loaded using the `ONNXModel` wrapper class as follows: | ||
|
||
``` | ||
class_path: eva.models.networks.wrappers.ONNXModel | ||
init_args: | ||
path: path/to/model.onnx | ||
device: cuda | ||
``` | ||
|
||
## Implementing custom model wrappers | ||
|
||
You can also implement your own model wrapper classes, in case your model format is not supported by the wrapper classes that **eva** already provides. To do so, you need to subclass `eva.models.networks.wrappers.BaseModel` and implement the following abstract methods: | ||
|
||
- `load_model`: Returns an instantiated model object & loads pre-trained model weights from a checkpoint if available. | ||
- `model_forward`: Implements the forward pass of the model and returns the output as a `torch.Tensor` of shape `[embedding_dim]` | ||
|
||
You can take the implementations of `ModelFromFunction`, `HuggingFaceModel` and `ONNXModel` wrappers as a reference. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,8 @@ | ||
"""Model Wrappers API.""" | ||
|
||
from eva.models.networks.wrappers.base import BaseModel | ||
from eva.models.networks.wrappers.from_function import ModelFromFunction | ||
from eva.models.networks.wrappers.huggingface import HuggingFaceModel | ||
from eva.models.networks.wrappers.onnx import ONNXModel | ||
|
||
__all__ = ["ModelFromFunction", "HuggingFaceModel", "ONNXModel"] | ||
__all__ = ["BaseModel", "ModelFromFunction", "HuggingFaceModel", "ONNXModel"] |