support custom configure_optimizers method and avoid saving a random model by default #367

Merged
merged 5 commits into from
Sep 15, 2024
11 changes: 6 additions & 5 deletions fuse/dl/lightning/pl_module.py
@@ -17,7 +17,7 @@
 """

 import pytorch_lightning as pl
-from typing import Optional, Union, Tuple
+from typing import Optional, Union, Tuple, Callable
 from collections import OrderedDict
 import os

@@ -50,11 +50,11 @@ def __init__(
                 List[Tuple[str, OrderedDict[str, MetricBase]]],
             ]
         ] = None,
-        optimizers_and_lr_schs: Any = None,
+        optimizers_and_lr_schs: Union[dict, Callable] = None,
         callbacks: Optional[Sequence[pl.Callback]] = None,
         best_epoch_source: Optional[Union[Dict, List[Dict]]] = None,
         save_hyperparameters_kwargs: Optional[dict] = None,
-        save_model: bool = True,
+        save_model: bool = False,
         save_arguments: bool = False,
         tensorboard_sep: str = ".",
         log_unit: str = None,
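This hunk carries both behavioral changes of the PR: optimizers_and_lr_schs may now be a callable instead of only a ready-made configure_optimizers result, and save_model now defaults to False, so a freshly constructed (still randomly initialized) model is no longer written to disk unless explicitly requested. A minimal usage sketch of the callable form; it assumes the class defined in pl_module.py is LightningModuleDefault and that the remaining constructor arguments are optional, as their defaults in this diff suggest:

import torch
from fuse.dl.lightning.pl_module import LightningModuleDefault

def my_configure_optimizers(pl_module):
    # Follows the pl.LightningModule.configure_optimizers prototype:
    # receives the lightning module, returns the optimizer configuration.
    return torch.optim.AdamW(pl_module.parameters(), lr=1e-4)

module = LightningModuleDefault(
    model_dir="/tmp/model_dir",  # hypothetical location
    model=torch.nn.Linear(4, 2),  # placeholder model
    optimizers_and_lr_schs=my_configure_optimizers,  # callable form (new in this PR)
    # save_model now defaults to False; pass save_model=True to restore
    # the old behavior of saving the model at construction time.
)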
@@ -63,7 +63,6 @@
         """
         :param model_dir: location for checkpoints and logs
         :param model: Pytorch model to use
-        :param optimizers_and_lr_schs: see pl.LightningModule.configure_optimizers for details and relevant options
         :param losses: dict of FuseMedML style losses
                        Will be used for both train and validation unless validation_losses is specified.
         :param validation_losses: Optional, typically used when there are multiple validation dataloaders - each with a different loss
@@ -75,7 +74,7 @@
         :param test_metrics: ordereddict of FuseMedML style metrics - used for the test set (must be different metric instances from train_metrics).
                        In case of multiple test dataloaders, test_metrics should be a list of tuples (keeping the same dataloaders list order),
                        each tuple built from a test dataloader name and the corresponding metrics dict.
-        :param optimizers_and_lr_schs: see pl.LightningModule.configure_optimizers return value for all options
+        :param optimizers_and_lr_schs: either a callable that follows the pl.LightningModule.configure_optimizers prototype or simply the output of such a function.
         :param callbacks: see pl.LightningModule.configure_callbacks return value for details
         :param best_epoch_source: creates a list of pl.Callbacks that save checkpoints (pl.callbacks.ModelCheckpoint) and print a per-epoch summary (fuse.dl.lightning.pl_epoch_summary.ModelEpochSummary).
                        Either a dict with arguments to pass to ModelCheckpoint or a list of dicts for multiple ModelCheckpoint callbacks (to monitor and save checkpoints for more than one metric).
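For reference, a callable passed here can return any of the forms pl.LightningModule.configure_optimizers supports. A sketch of the common optimizer-plus-scheduler dict form (the learning rate and step size are illustrative):

import torch

def configure_optimizers_fn(pl_module):
    # Built lazily, once the lightning module invokes it.
    optimizer = torch.optim.Adam(pl_module.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
    return {
        "optimizer": optimizer,
        "lr_scheduler": {"scheduler": scheduler, "interval": "epoch"},
    }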
@@ -329,6 +328,8 @@ def configure_callbacks(self) -> Sequence[pl.Callback]:

     def configure_optimizers(self) -> torch.optim.Optimizer:
         """See pl.LightningModule.configure_optimizers return value for all options"""
+        if isinstance(self._optimizers_and_lr_schs, Callable):
+            return self._optimizers_and_lr_schs(self)
         return self._optimizers_and_lr_schs

     def set_predictions_keys(self, keys: List[str]) -> None:
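The added branch makes configure_optimizers dispatch on the stored value: a callable is invoked with the module itself, so optimizer construction is deferred until the module is fully built, while any other value is returned unchanged, preserving the previous behavior. For contrast, a sketch of the precomputed form (names and values hypothetical, constructor arguments as above):

import torch
from fuse.dl.lightning.pl_module import LightningModuleDefault

model = torch.nn.Linear(4, 2)
precomputed = {"optimizer": torch.optim.SGD(model.parameters(), lr=0.1)}
module = LightningModuleDefault(
    model_dir="/tmp/model_dir",
    model=model,
    optimizers_and_lr_schs=precomputed,  # returned as-is by configure_optimizers
)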