From 108b2a3f2ac6b36b483133c22a7755cdd836394a Mon Sep 17 00:00:00 2001
From: ioangatop
Date: Thu, 29 Feb 2024 12:05:52 +0100
Subject: [PATCH 1/3] Allow loading remote model weights with the ModelFromFunction wrapper

---
 configs/vision/bach.yaml                 | 90 ++++++++++++++++++++++++
 src/eva/models/networks/_utils.py        | 28 ++++++++
 src/eva/models/networks/from_function.py | 41 +----------
 src/eva/models/networks/mlp.py           |  2 +-
 src/eva/setup.py                         | 24 ++++++-
 src/eva/utils/__init__.py                |  6 +-
 src/eva/utils/logger.py                  | 27 -------
 7 files changed, 146 insertions(+), 72 deletions(-)
 create mode 100644 configs/vision/bach.yaml
 create mode 100644 src/eva/models/networks/_utils.py
 delete mode 100644 src/eva/utils/logger.py

diff --git a/configs/vision/bach.yaml b/configs/vision/bach.yaml
new file mode 100644
index 00000000..74702472
--- /dev/null
+++ b/configs/vision/bach.yaml
@@ -0,0 +1,90 @@
+---
+trainer:
+  class_path: eva.Trainer
+  init_args:
+    default_root_dir: &LIGHTNING_ROOT ${oc.env:LIGHTNING_ROOT, logs/dino_vits16/online/bach}
+    max_steps: &MAX_STEPS 12500
+    callbacks:
+      - class_path: pytorch_lightning.callbacks.LearningRateMonitor
+        init_args:
+          logging_interval: epoch
+      - class_path: pytorch_lightning.callbacks.ModelCheckpoint
+        init_args:
+          filename: best
+          save_last: true
+          save_top_k: 1
+          monitor: &MONITOR_METRIC ${oc.env:MONITOR_METRIC, val/MulticlassAccuracy}
+          mode: &MONITOR_METRIC_MODE ${oc.env:MONITOR_METRIC_MODE, max}
+      - class_path: pytorch_lightning.callbacks.EarlyStopping
+        init_args:
+          min_delta: 0
+          patience: 500
+          monitor: *MONITOR_METRIC
+          mode: *MONITOR_METRIC_MODE
+    logger:
+      - class_path: pytorch_lightning.loggers.TensorBoardLogger
+        init_args:
+          save_dir: *LIGHTNING_ROOT
+          name: ""
+model:
+  class_path: eva.HeadModule
+  init_args:
+    backbone:
+      class_path: eva.models.ModelFromFunction
+      init_args:
+        path: timm.create_model
+        arguments:
+          model_name: vit_small_patch16_224
+          num_classes: 0
+          pretrained: false
+        checkpoint_path: https://github.com/lunit-io/benchmark-ssl-pathology/releases/download/pretrained-weights/dino_vit_small_patch8_ep200.torch
+    head:
+      class_path: torch.nn.Linear
+      init_args:
+        in_features: 384
+        out_features: &NUM_CLASSES 4
+    criterion: torch.nn.CrossEntropyLoss
+    optimizer:
+      class_path: torch.optim.SGD
+      init_args:
+        lr: &LR_VALUE 0.00064
+        momentum: 0.9
+        weight_decay: 0.0
+    lr_scheduler:
+      class_path: torch.optim.lr_scheduler.CosineAnnealingLR
+      init_args:
+        T_max: *MAX_STEPS
+        eta_min: 0.0
+    metrics:
+      common:
+        - class_path: eva.metrics.AverageLoss
+        - class_path: eva.metrics.MulticlassClassificationMetrics
+          init_args:
+            num_classes: *NUM_CLASSES
+data:
+  class_path: eva.DataModule
+  init_args:
+    datasets:
+      train:
+        class_path: eva.vision.datasets.BACH
+        init_args: &DATASET_ARGS
+          root: ${oc.env:DATA_ROOT, ./data}/bach
+          split: train
+          download: ${oc.env:DOWNLOAD_DATA, true}
+          image_transforms:
+            class_path: eva.vision.data.transforms.common.ResizeAndCrop
+            init_args:
+              size: ${oc.env:RESIZE_DIM, 224}
+              mean: ${oc.env:NORMALIZE_MEAN, [0.485, 0.456, 0.406]}
+              std: ${oc.env:NORMALIZE_STD, [0.229, 0.224, 0.225]}
+      val:
+        class_path: eva.vision.datasets.BACH
+        init_args:
+          <<: *DATASET_ARGS
+          split: val
+    dataloaders:
+      train:
+        batch_size: &BATCH_SIZE 256
+        shuffle: true
+      val:
+        batch_size: *BATCH_SIZE
diff --git a/src/eva/models/networks/_utils.py b/src/eva/models/networks/_utils.py
new file mode 100644
index 00000000..1b8f5b32
--- /dev/null
+++ b/src/eva/models/networks/_utils.py
@@ -0,0 +1,28 @@
+"""Utilities and helper functions for models."""
+
+from lightning_fabric.utilities import cloud_io
+from loguru import logger
+from torch import nn
+
+
+def load_model_weights(
+    model: nn.Module,
+    checkpoint_path: str,
+) -> None:
+    """Loads (local or remote) weights to the model in-place.
+
+    Args:
+        model: The model to load the weights to.
+        checkpoint_path: The path to the model weights/checkpoint.
+    """
+    logger.info(f"Loading '{model.__class__.__name__}' model from checkpoint '{checkpoint_path}'")
+
+    fs = cloud_io.get_filesystem(checkpoint_path)
+    with fs.open(checkpoint_path, "rb") as file:
+        checkpoint = cloud_io._load(file, map_location="cpu")  # type: ignore
+    if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
+        checkpoint = checkpoint["state_dict"]
+
+    model.load_state_dict(checkpoint, strict=True)
+
+    logger.info(f"Loading weights from '{checkpoint_path}' completed successfully.")
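
A quick usage sketch of the new helper for reviewers (illustrative only, not part
of the diff): the module and the checkpoint URL below are hypothetical stand-ins,
but since `cloud_io.get_filesystem` resolves paths through fsspec, local files,
`s3://` buckets and `https://` URLs should all load through the same call:

    >>> from torch import nn
    >>> from eva.models.networks._utils import load_model_weights
    >>> model = nn.Linear(384, 4)  # stand-in module whose state_dict matches the file
    >>> load_model_weights(model, "https://example.com/checkpoints/encoder.pth")
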
diff --git a/src/eva/models/networks/from_function.py b/src/eva/models/networks/from_function.py
index 78707234..ddaca38a 100644
--- a/src/eva/models/networks/from_function.py
+++ b/src/eva/models/networks/from_function.py
@@ -4,10 +4,11 @@
 
 import jsonargparse
 import torch
-from loguru import logger
 from torch import nn
 from typing_extensions import override
 
+from eva.models.networks import _utils
+
 
 class ModelFromFunction(nn.Module):
     """Wrapper class for models which are initialized from functions.
@@ -27,15 +28,6 @@ def __init__(
             path: The path to the callable object (class or function).
             arguments: The extra callable function / class arguments.
             checkpoint_path: The path to the checkpoint to load the model weights from.
-
-        Example:
-            >>> import torchvision
-            >>> network = ModelFromFunction(
-            >>>     path=torchvision.models.resnet18,
-            >>>     arguments={
-            >>>         "weights": torchvision.models.ResNet18_Weights.DEFAULT,
-            >>>     },
-            >>> )
         """
 
         super().__init__()
@@ -50,34 +42,7 @@ def build_model(self) -> nn.Module:
         class_path = jsonargparse.class_from_function(self._path, func_return=nn.Module)
         model = class_path(**self._arguments or {})
         if self._checkpoint_path is not None:
-            model = self.load_model_checkpoint(model, self._checkpoint_path)
-        return model
-
-    def load_model_checkpoint(
-        self,
-        model: torch.nn.Module,
-        checkpoint_path: str,
-    ) -> torch.nn.Module:
-        """Initializes the model with the weights.
-
-        Args:
-            model: model to initialize.
-            checkpoint_path: The path to the checkpoint to load the model weights from.
-
-        Returns:
-            the model initialized with the checkpoint.
-        """
-        logger.info(f"Loading {model.__class__.__name__} from checkpoint {checkpoint_path}")
-
-        with open(checkpoint_path, "rb") as f:
-            checkpoint = torch.load(f, map_location="cpu")  # type: ignore[arg-type]
-        if "state_dict" in checkpoint:
-            checkpoint = checkpoint["state_dict"]
-        model.load_state_dict(checkpoint, strict=True)
-        logger.info(
-            f"Loaded modules for {model.__class__.__name__} from checkpoint "
-            f"{checkpoint_path}"
-        )
+            _utils.load_model_weights(model, self._checkpoint_path)
         return model
 
     @override
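
The docstring example removed above is worth preserving somewhere; here is an
updated sketch against the new API (the timm arguments mirror
configs/vision/bach.yaml above, while the URL is a placeholder, not a published
artifact):

    >>> import timm
    >>> from eva.models import ModelFromFunction
    >>> backbone = ModelFromFunction(
    ...     path=timm.create_model,
    ...     arguments={
    ...         "model_name": "vit_small_patch16_224",
    ...         "num_classes": 0,
    ...         "pretrained": False,
    ...     },
    ...     checkpoint_path="https://example.com/checkpoints/dino_vits16.pth",
    ... )
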
diff --git a/src/eva/models/networks/mlp.py b/src/eva/models/networks/mlp.py
index 7f1a3adf..4decad2a 100644
--- a/src/eva/models/networks/mlp.py
+++ b/src/eva/models/networks/mlp.py
@@ -17,7 +17,7 @@ def __init__(
         hidden_activation_fn: Type[torch.nn.Module] | None = nn.ReLU,
         output_activation_fn: Type[torch.nn.Module] | None = None,
         dropout: float = 0.0,
-    ):
+    ) -> None:
         """Initializes the MLP.
 
         Args:
diff --git a/src/eva/setup.py b/src/eva/setup.py
index 92b0bf55..7a2e5810 100644
--- a/src/eva/setup.py
+++ b/src/eva/setup.py
@@ -4,6 +4,27 @@
 import sys
 import warnings
 
+from loguru import logger
+
+
+def _initialize_logger() -> None:
+    """Initializes and customizes the library logger.
+
+    This customizable logger can be used by just importing `loguru`
+    from everywhere as follows:
+    >>> from loguru import logger
+    >>> logger.info(...)
+    """
+    logger.remove()
+    logger.add(
+        sys.stderr,
+        format="[{time:HH:mm:ss}]"
+        " {level} "
+        " | {message}",
+        colorize=True,
+        level="INFO",
+    )
+
 
 def _suppress_warnings() -> None:
     """Suppress all warnings from all subprocesses."""
@@ -13,7 +34,7 @@
 
 
 def _enable_mps_fallback() -> None:
-    """If not set, it enables the MPS fallback in torch.
+    """Enables the MPS fallback in torch.
 
     Note that this action has to take place before importing torch.
     """
@@ -23,6 +44,7 @@
 
 
 def setup() -> None:
     """Sets up the environment before the module is imported."""
+    _initialize_logger()
     _suppress_warnings()
     _enable_mps_fallback()
diff --git a/src/eva/utils/__init__.py b/src/eva/utils/__init__.py
index 27c0f808..f99e16fd 100644
--- a/src/eva/utils/__init__.py
+++ b/src/eva/utils/__init__.py
@@ -1,5 +1 @@
-"""Utilities and helper functionalities."""
-
-from eva.utils.logger import logger
-
-__all__ = ["logger"]
+"""Utilities and library level helper functionalities."""
diff --git a/src/eva/utils/logger.py b/src/eva/utils/logger.py
deleted file mode 100644
index 1b245c86..00000000
--- a/src/eva/utils/logger.py
+++ /dev/null
@@ -1,27 +0,0 @@
-"""Initializes the logger of the library.
-
-This customizable logger can be used by just importing
-`loguru` from everywhere as follows:
->>> from loguru import logger
->>> logger.info(...)
-"""
-
-import sys
-
-from loguru import logger
-
-
-def _initialize_logger() -> None:
-    """Manipulates and customizes the logger."""
-    logger.remove()
-    logger.add(
-        sys.stderr,
-        format="[{time:HH:mm:ss}]"
-        " {level} "
-        " | {message}",
-        colorize=True,
-        level="INFO",
-    )
-
-
-_initialize_logger()
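
As a sanity check on the relocated logger setup, a sketch of what any module
should see once `eva.setup.setup()` has run (the timestamp is illustrative; the
adjacent string literals above concatenate to "[{time:HH:mm:ss}] {level}  | {message}"):

    >>> from loguru import logger
    >>> logger.info("Loading weights completed successfully.")
    [12:05:52] INFO  | Loading weights completed successfully.
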
-""" - -import sys - -from loguru import logger - - -def _initialize_logger() -> None: - """Manipulates and customizes the logger.""" - logger.remove() - logger.add( - sys.stderr, - format="[{time:HH:mm:ss}]" - " {level} " - " | {message}", - colorize=True, - level="INFO", - ) - - -_initialize_logger() From 9f3a792de9defdc670825c6544f0dc6ff79cb581 Mon Sep 17 00:00:00 2001 From: ioangatop Date: Thu, 29 Feb 2024 12:08:50 +0100 Subject: [PATCH 2/3] remove dev files --- configs/vision/bach.yaml | 90 ---------------------------------------- 1 file changed, 90 deletions(-) delete mode 100644 configs/vision/bach.yaml diff --git a/configs/vision/bach.yaml b/configs/vision/bach.yaml deleted file mode 100644 index 74702472..00000000 --- a/configs/vision/bach.yaml +++ /dev/null @@ -1,90 +0,0 @@ ---- -trainer: - class_path: eva.Trainer - init_args: - default_root_dir: &LIGHTNING_ROOT ${oc.env:LIGHTNING_ROOT, logs/dino_vits16/online/bach} - max_steps: &MAX_STEPS 12500 - callbacks: - - class_path: pytorch_lightning.callbacks.LearningRateMonitor - init_args: - logging_interval: epoch - - class_path: pytorch_lightning.callbacks.ModelCheckpoint - init_args: - filename: best - save_last: true - save_top_k: 1 - monitor: &MONITOR_METRIC ${oc.env:MONITOR_METRIC, val/MulticlassAccuracy} - mode: &MONITOR_METRIC_MODE ${oc.env:MONITOR_METRIC_MODE, max} - - class_path: pytorch_lightning.callbacks.EarlyStopping - init_args: - min_delta: 0 - patience: 500 - monitor: *MONITOR_METRIC - mode: *MONITOR_METRIC_MODE - logger: - - class_path: pytorch_lightning.loggers.TensorBoardLogger - init_args: - save_dir: *LIGHTNING_ROOT - name: "" -model: - class_path: eva.HeadModule - init_args: - backbone: - class_path: eva.models.ModelFromFunction - init_args: - path: timm.create_model - arguments: - model_name: vit_small_patch16_224 - num_classes: 0 - pretrained: false - checkpoint_path: https://github.com/lunit-io/benchmark-ssl-pathology/releases/download/pretrained-weights/dino_vit_small_patch8_ep200.torch - head: - class_path: torch.nn.Linear - init_args: - in_features: 384 - out_features: &NUM_CLASSES 4 - criterion: torch.nn.CrossEntropyLoss - optimizer: - class_path: torch.optim.SGD - init_args: - lr: &LR_VALUE 0.00064 - momentum: 0.9 - weight_decay: 0.0 - lr_scheduler: - class_path: torch.optim.lr_scheduler.CosineAnnealingLR - init_args: - T_max: *MAX_STEPS - eta_min: 0.0 - metrics: - common: - - class_path: eva.metrics.AverageLoss - - class_path: eva.metrics.MulticlassClassificationMetrics - init_args: - num_classes: *NUM_CLASSES -data: - class_path: eva.DataModule - init_args: - datasets: - train: - class_path: eva.vision.datasets.BACH - init_args: &DATASET_ARGS - root: ${oc.env:DATA_ROOT, ./data}/bach - split: train - download: ${oc.env:DOWNLOAD_DATA, true} - image_transforms: - class_path: eva.vision.data.transforms.common.ResizeAndCrop - init_args: - size: ${oc.env:RESIZE_DIM, 224} - mean: ${oc.env:NORMALIZE_MEAN, [0.485, 0.456, 0.406]} - std: ${oc.env:NORMALIZE_STD, [0.229, 0.224, 0.225]} - val: - class_path: eva.vision.datasets.BACH - init_args: - <<: *DATASET_ARGS - split: val - dataloaders: - train: - batch_size: &BATCH_SIZE 256 - shuffle: true - val: - batch_size: *BATCH_SIZE From 7d66d04857025c24442b0cf75dae9a1f2a3bc261 Mon Sep 17 00:00:00 2001 From: ioangatop Date: Thu, 29 Feb 2024 12:39:13 +0100 Subject: [PATCH 3/3] minor style update --- src/eva/models/networks/_utils.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/eva/models/networks/_utils.py b/src/eva/models/networks/_utils.py 
index 1b8f5b32..8ea5707a 100644
--- a/src/eva/models/networks/_utils.py
+++ b/src/eva/models/networks/_utils.py
@@ -5,10 +5,7 @@
 from torch import nn
 
 
-def load_model_weights(
-    model: nn.Module,
-    checkpoint_path: str,
-) -> None:
+def load_model_weights(model: nn.Module, checkpoint_path: str) -> None:
     """Loads (local or remote) weights to the model in-place.
 
     Args:
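
One last reviewer-facing sketch, with hypothetical file names: because the loader
unwraps any dict carrying a "state_dict" key, plain state dicts and
Lightning-style checkpoints go through the same call (both are still loaded with
strict=True, so the keys must match the target module):

    >>> import torch
    >>> torch.save(model.state_dict(), "plain.pth")                      # raw state_dict
    >>> torch.save({"state_dict": model.state_dict()}, "wrapped.ckpt")   # Lightning-style
    >>> load_model_weights(model, "plain.pth")
    >>> load_model_weights(model, "wrapped.ckpt")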