Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow loading remote model weights with the ModelFromFunction wrapper #187

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions src/eva/models/networks/_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
"""Utilities and helper functions for models."""

from lightning_fabric.utilities import cloud_io
from loguru import logger
from torch import nn


def load_model_weights(model: nn.Module, checkpoint_path: str) -> None:
    """Populates the model in-place with weights read from a checkpoint file.

    The checkpoint is opened through the filesystem object returned by
    `cloud_io.get_filesystem` for the given path and deserialized with
    `map_location="cpu"`. When the loaded object is a dictionary holding
    a "state_dict" entry, that entry is used as the weights; otherwise
    the loaded object itself is handed to `load_state_dict`.

    Args:
        model: The model to load the weights to.
        checkpoint_path: The (local or remote) path to the model
            weights/checkpoint.
    """
    logger.info(f"Loading '{model.__class__.__name__}' model from checkpoint '{checkpoint_path}'")

    filesystem = cloud_io.get_filesystem(checkpoint_path)
    with filesystem.open(checkpoint_path, "rb") as stream:
        loaded = cloud_io._load(stream, map_location="cpu")  # type: ignore

    # Full checkpoints nest the weights under "state_dict"; plain weight
    # files are used as they are.
    wraps_state_dict = isinstance(loaded, dict) and "state_dict" in loaded
    weights = loaded["state_dict"] if wraps_state_dict else loaded

    # strict=True: the keys of the weights must match the model exactly.
    model.load_state_dict(weights, strict=True)

    logger.info(f"Loading weights from '{checkpoint_path}' completed successfully.")
41 changes: 3 additions & 38 deletions src/eva/models/networks/from_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@

import jsonargparse
import torch
from loguru import logger
from torch import nn
from typing_extensions import override

from eva.models.networks import _utils


class ModelFromFunction(nn.Module):
"""Wrapper class for models which are initialized from functions.
Expand All @@ -27,15 +28,6 @@ def __init__(
path: The path to the callable object (class or function).
arguments: The extra callable function / class arguments.
checkpoint_path: The path to the checkpoint to load the model weights from.

Example:
>>> import torchvision
>>> network = ModelFromFunction(
>>> path=torchvision.models.resnet18,
>>> arguments={
>>> "weights": torchvision.models.ResNet18_Weights.DEFAULT,
>>> },
>>> )
"""
super().__init__()

Expand All @@ -50,34 +42,7 @@ def build_model(self) -> nn.Module:
class_path = jsonargparse.class_from_function(self._path, func_return=nn.Module)
model = class_path(**self._arguments or {})
if self._checkpoint_path is not None:
model = self.load_model_checkpoint(model, self._checkpoint_path)
return model

def load_model_checkpoint(
self,
model: torch.nn.Module,
checkpoint_path: str,
) -> torch.nn.Module:
"""Initializes the model with the weights.

Args:
model: model to initialize.
checkpoint_path: The path to the checkpoint to load the model weights from.

Returns:
the model initialized with the checkpoint.
"""
logger.info(f"Loading {model.__class__.__name__} from checkpoint {checkpoint_path}")

with open(checkpoint_path, "rb") as f:
checkpoint = torch.load(f, map_location="cpu") # type: ignore[arg-type]
if "state_dict" in checkpoint:
checkpoint = checkpoint["state_dict"]
model.load_state_dict(checkpoint, strict=True)
logger.info(
f"Loaded modules for {model.__class__.__name__} from checkpoint "
f"{checkpoint_path}"
)
_utils.load_model_weights(model, self._checkpoint_path)
return model

@override
Expand Down
2 changes: 1 addition & 1 deletion src/eva/models/networks/mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def __init__(
hidden_activation_fn: Type[torch.nn.Module] | None = nn.ReLU,
output_activation_fn: Type[torch.nn.Module] | None = None,
dropout: float = 0.0,
):
) -> None:
"""Initializes the MLP.

Args:
Expand Down
24 changes: 23 additions & 1 deletion src/eva/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,27 @@
import sys
import warnings

from loguru import logger


def _initialize_logger() -> None:
    """Resets the `loguru` logger and attaches a customized stderr sink.

    `logger.remove()` is called first, and then a single sink writing to
    `sys.stderr` is added with a colorized "[time] LEVEL | message"
    layout and `level="INFO"`.

    Once configured, the logger can be used from everywhere by just
    importing `loguru`:
        >>> from loguru import logger
        >>> logger.info(...)
    """
    record_format = (
        "<magenta>[{time:HH:mm:ss}]</magenta>"
        " <bold><level>{level}</level></bold> "
        " | {message}"
    )
    sink_options = {
        "format": record_format,
        "colorize": True,
        "level": "INFO",
    }

    # Drop the handlers registered so far before adding the custom one.
    logger.remove()
    logger.add(sys.stderr, **sink_options)


def _suppress_warnings() -> None:
"""Suppress all warnings from all subprocesses."""
Expand All @@ -13,7 +34,7 @@ def _suppress_warnings() -> None:


def _enable_mps_fallback() -> None:
"""If not set, it enables the MPS fallback in torch.
"""It enables the MPS fallback in torch.

Note that this action has to take place before importing torch.
"""
Expand All @@ -23,6 +44,7 @@ def _enable_mps_fallback() -> None:

def setup() -> None:
    """Prepares the environment before the module is imported.

    Runs, in this order, the logger initialization, the suppression of
    warnings and the activation of the MPS fallback.
    """
    setup_steps = (
        _initialize_logger,
        _suppress_warnings,
        _enable_mps_fallback,
    )
    for run_step in setup_steps:
        run_step()

Expand Down
6 changes: 1 addition & 5 deletions src/eva/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1 @@
"""Utilities and helper functionalities."""

from eva.utils.logger import logger

__all__ = ["logger"]
"""Utilities and library level helper functionalities."""
27 changes: 0 additions & 27 deletions src/eva/utils/logger.py

This file was deleted.

Loading