Skip to content

Commit

Permalink
Trainers: skip weights and augmentations when saving hparams (#1670)
Browse files Browse the repository at this point in the history
* Update base.py to fix handling of custom augmentations

* Allow subclasses to ignore specific arguments

* Fix typing

* Save to self.weights

* pyupgrade

* Add test

* Save weights

---------

Co-authored-by: Adam J. Stewart <[email protected]>
  • Loading branch information
2 people authored and nilsleh committed Nov 10, 2023
1 parent a84dce7 commit 4d2e1cd
Show file tree
Hide file tree
Showing 8 changed files with 34 additions and 17 deletions.
6 changes: 6 additions & 0 deletions tests/conf/ssl4eo_l_moco_1.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@ model:
temperature: 0.07
memory_bank_size: 10
moco_momentum: 0.999
augmentation1:
class_path: kornia.augmentation.RandomResizedCrop
init_args:
size:
- 224
- 224
data:
class_path: SSL4EOLDataModule
init_args:
Expand Down
13 changes: 9 additions & 4 deletions torchgeo/trainers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
"""Base classes for all :mod:`torchgeo` trainers."""

from abc import ABC, abstractmethod
from typing import Any
from collections.abc import Sequence
from typing import Any, Optional, Union

import lightning
from lightning.pytorch import LightningModule
Expand All @@ -27,10 +28,14 @@ class BaseTask(LightningModule, ABC):
#: Whether the goal is to minimize or maximize the performance metric to monitor.
mode = "min"

def __init__(self) -> None:
"""Initialize a new BaseTask instance."""
def __init__(self, ignore: Optional[Union[Sequence[str], str]] = None) -> None:
    """Initialize a new BaseTask instance.

    Args:
        ignore: Arguments to skip when saving hyperparameters.
    """
    super().__init__()
    # Exclude non-serializable constructor arguments (e.g. pretrained weight
    # enums, augmentation callables) from the saved hyperparameters so that
    # checkpointing and hparams logging do not fail on them.
    self.save_hyperparameters(ignore=ignore)
    self.configure_losses()
    self.configure_metrics()
    self.configure_models()
Expand Down
5 changes: 3 additions & 2 deletions torchgeo/trainers/byol.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,11 +322,12 @@ def __init__(
*backbone*, *learning_rate*, and *learning_rate_schedule_patience* were
renamed to *model*, *lr*, and *patience*.
"""
super().__init__()
self.weights = weights
super().__init__(ignore="weights")

def configure_models(self) -> None:
"""Initialize the model."""
weights: Optional[Union[WeightsEnum, str, bool]] = self.hparams["weights"]
weights = self.weights
in_channels: int = self.hparams["in_channels"]

# Create backbone
Expand Down
5 changes: 3 additions & 2 deletions torchgeo/trainers/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,8 @@ class and used with 'ce' loss.
*learning_rate* and *learning_rate_schedule_patience* were renamed to
*lr* and *patience*.
"""
super().__init__()
self.weights = weights
super().__init__(ignore="weights")

def configure_losses(self) -> None:
"""Initialize the loss criterion.
Expand Down Expand Up @@ -117,7 +118,7 @@ def configure_metrics(self) -> None:

def configure_models(self) -> None:
"""Initialize the model."""
weights: Optional[Union[WeightsEnum, str, bool]] = self.hparams["weights"]
weights = self.weights

# Create model
self.model = timm.create_model(
Expand Down
5 changes: 3 additions & 2 deletions torchgeo/trainers/moco.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,8 @@ def __init__(
if memory_bank_size > 0:
warnings.warn("MoCo v3 does not use a memory bank")

super().__init__()
self.weights = weights
super().__init__(ignore=["weights", "augmentation1", "augmentation2"])

grayscale_weights = grayscale_weights or torch.ones(in_channels)
aug1, aug2 = moco_augmentations(version, size, grayscale_weights)
Expand All @@ -236,7 +237,7 @@ def configure_losses(self) -> None:
def configure_models(self) -> None:
"""Initialize the model."""
model: str = self.hparams["model"]
weights: Optional[Union[WeightsEnum, str, bool]] = self.hparams["weights"]
weights = self.weights
in_channels: int = self.hparams["in_channels"]
version: int = self.hparams["version"]
layers: int = self.hparams["layers"]
Expand Down
7 changes: 4 additions & 3 deletions torchgeo/trainers/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,8 @@ def __init__(
*learning_rate* and *learning_rate_schedule_patience* were renamed to
*lr* and *patience*.
"""
super().__init__()
self.weights = weights
super().__init__(ignore="weights")

def configure_losses(self) -> None:
"""Initialize the loss criterion.
Expand Down Expand Up @@ -110,7 +111,7 @@ def configure_metrics(self) -> None:
def configure_models(self) -> None:
"""Initialize the model."""
# Create model
weights: Optional[Union[WeightsEnum, str, bool]] = self.hparams["weights"]
weights = self.weights
self.model = timm.create_model(
self.hparams["model"],
num_classes=self.hparams["num_outputs"],
Expand Down Expand Up @@ -256,7 +257,7 @@ class PixelwiseRegressionTask(RegressionTask):

def configure_models(self) -> None:
"""Initialize the model."""
weights: Optional[Union[WeightsEnum, str, bool]] = self.hparams["weights"]
weights = self.weights

if self.hparams["model"] == "unet":
self.model = smp.Unet(
Expand Down
5 changes: 3 additions & 2 deletions torchgeo/trainers/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,8 @@ class and used with 'ce' loss.
UserWarning,
)

super().__init__()
self.weights = weights
super().__init__(ignore="weights")

def configure_losses(self) -> None:
"""Initialize the loss criterion.
Expand Down Expand Up @@ -151,7 +152,7 @@ def configure_models(self) -> None:
"""
model: str = self.hparams["model"]
backbone: str = self.hparams["backbone"]
weights: Optional[Union[WeightsEnum, str, bool]] = self.hparams["weights"]
weights = self.weights
in_channels: int = self.hparams["in_channels"]
num_classes: int = self.hparams["num_classes"]
num_filters: int = self.hparams["num_filters"]
Expand Down
5 changes: 3 additions & 2 deletions torchgeo/trainers/simclr.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,8 @@ def __init__(
if memory_bank_size == 0:
warnings.warn("SimCLR v2 uses a memory bank")

super().__init__()
self.weights = weights
super().__init__(ignore=["weights", "augmentations"])

grayscale_weights = grayscale_weights or torch.ones(in_channels)
self.augmentations = augmentations or simclr_augmentations(
Expand All @@ -151,7 +152,7 @@ def configure_losses(self) -> None:

def configure_models(self) -> None:
"""Initialize the model."""
weights: Optional[Union[WeightsEnum, str, bool]] = self.hparams["weights"]
weights = self.weights
hidden_dim: int = self.hparams["hidden_dim"]
output_dim: int = self.hparams["output_dim"]

Expand Down

0 comments on commit 4d2e1cd

Please sign in to comment.