
Allow multiple fit runs #170

Closed
31 commits
- b5fe069 adjust log dir (roman807, Feb 26, 2024)
- a6945b4 some clean up (roman807, Feb 27, 2024)
- 8258ad7 some clean up (roman807, Feb 27, 2024)
- 612ab9f updated configs (roman807, Feb 27, 2024)
- b2087da added tests (roman807, Feb 27, 2024)
- dcd98dc added tests (roman807, Feb 27, 2024)
- 027c52b clean up interface (roman807, Feb 27, 2024)
- 245254b formatting (roman807, Feb 27, 2024)
- 86e2cbe formatting (roman807, Feb 27, 2024)
- 8a9b361 Merge branch 'main' into 158-allow-multiple-fit-runs-and-report-avera… (roman807, Feb 27, 2024)
- 50e2478 clean up (roman807, Feb 27, 2024)
- 84af65d Merge remote-tracking branch 'origin/158-allow-multiple-fit-runs-and-… (roman807, Feb 27, 2024)
- 60145d5 clean up (roman807, Feb 27, 2024)
- 0f53ec0 clean up (roman807, Feb 27, 2024)
- 68e94e6 pyright (roman807, Feb 27, 2024)
- 2e47ecf lint (roman807, Feb 27, 2024)
- d1e99e2 addressed comments (roman807, Feb 27, 2024)
- e67badc addressed comments (roman807, Feb 27, 2024)
- 130b2b7 recorder -> recording (roman807, Feb 28, 2024)
- 57e2d95 added example to recording (roman807, Feb 28, 2024)
- 68ec810 Merge branch 'main' into 158-allow-multiple-fit-runs-and-report-avera… (roman807, Feb 28, 2024)
- 07380c7 merge conflicts (roman807, Feb 28, 2024)
- 5dcf963 refactored loggers (roman807, Feb 28, 2024)
- 9f1cc9d Merge branch 'main' into 158-allow-multiple-fit-runs-and-report-avera… (roman807, Feb 28, 2024)
- 680be6d formatting (roman807, Feb 28, 2024)
- d9a2496 updated lightning logger (roman807, Feb 29, 2024)
- ff58a55 update checkpoint dir (roman807, Feb 29, 2024)
- 6c2ca9a Merge branch 'main' into 158-allow-multiple-fit-runs-and-report-avera… (roman807, Feb 29, 2024)
- 78af6ca adjusted configs (roman807, Feb 29, 2024)
- a572aee type check (roman807, Feb 29, 2024)
- 208aa82 type check (roman807, Feb 29, 2024)
5 changes: 3 additions & 2 deletions configs/vision/dino_vits16/offline/bach.yaml
```diff
@@ -1,4 +1,5 @@
 ---
+n_runs: 5
 trainer:
 class_path: eva.Trainer
 init_args:
@@ -18,7 +19,7 @@ trainer:
 - class_path: pytorch_lightning.callbacks.EarlyStopping
 init_args:
 min_delta: 0
-patience: 500
+patience: 800
 monitor: *MONITOR_METRIC
 mode: *MONITOR_METRIC_MODE
 - class_path: eva.callbacks.EmbeddingsWriter
@@ -36,7 +37,7 @@ trainer:
 model: dino_vits16
 pretrained: ${oc.env:PRETRAINED, true}
 logger:
-- class_path: pytorch_lightning.loggers.TensorBoardLogger
+- class_path: eva.loggers.TensorBoardLogger
 init_args:
 save_dir: *LIGHTNING_ROOT
 name: ""
```
5 changes: 3 additions & 2 deletions configs/vision/dino_vits16/offline/crc_he.yaml
```diff
@@ -1,4 +1,5 @@
 ---
+n_runs: 5
 trainer:
 class_path: eva.Trainer
 init_args:
@@ -18,7 +19,7 @@ trainer:
 - class_path: pytorch_lightning.callbacks.EarlyStopping
 init_args:
 min_delta: 0
-patience: 100
+patience: 48
 monitor: *MONITOR_METRIC
 mode: *MONITOR_METRIC_MODE
 - class_path: eva.callbacks.EmbeddingsWriter
@@ -36,7 +37,7 @@ trainer:
 model: dino_vits16
 pretrained: ${oc.env:PRETRAINED, true}
 logger:
-- class_path: pytorch_lightning.loggers.TensorBoardLogger
+- class_path: eva.loggers.TensorBoardLogger
 init_args:
 save_dir: *LIGHTNING_ROOT
 name: ""
```
5 changes: 3 additions & 2 deletions configs/vision/dino_vits16/offline/crc_he_nonorm.yaml
```diff
@@ -1,4 +1,5 @@
 ---
+n_runs: 5
 trainer:
 class_path: eva.Trainer
 init_args:
@@ -18,7 +19,7 @@ trainer:
 - class_path: pytorch_lightning.callbacks.EarlyStopping
 init_args:
 min_delta: 0
-patience: 100
+patience: 48
 monitor: *MONITOR_METRIC
 mode: *MONITOR_METRIC_MODE
 - class_path: eva.callbacks.EmbeddingsWriter
@@ -36,7 +37,7 @@ trainer:
 model: dino_vits16
 pretrained: ${oc.env:PRETRAINED, true}
 logger:
-- class_path: pytorch_lightning.loggers.TensorBoardLogger
+- class_path: eva.loggers.TensorBoardLogger
 init_args:
 save_dir: *LIGHTNING_ROOT
 name: ""
```
5 changes: 3 additions & 2 deletions configs/vision/dino_vits16/offline/patch_camelyon.yaml
```diff
@@ -1,4 +1,5 @@
 ---
+n_runs: 5
 trainer:
 class_path: eva.Trainer
 init_args:
@@ -18,7 +19,7 @@ trainer:
 - class_path: pytorch_lightning.callbacks.EarlyStopping
 init_args:
 min_delta: 0
-patience: 100
+patience: 25
 monitor: *MONITOR_METRIC
 mode: *MONITOR_METRIC_MODE
 - class_path: eva.callbacks.EmbeddingsWriter
@@ -37,7 +38,7 @@ trainer:
 model: dino_vits16
 pretrained: ${oc.env:PRETRAINED, true}
 logger:
-- class_path: pytorch_lightning.loggers.TensorBoardLogger
+- class_path: eva.loggers.TensorBoardLogger
 init_args:
 save_dir: *LIGHTNING_ROOT
 name: ""
```
2 changes: 1 addition & 1 deletion configs/vision/dino_vits16/online/bach.yaml
```diff
@@ -22,7 +22,7 @@ trainer:
 monitor: *MONITOR_METRIC
 mode: *MONITOR_METRIC_MODE
 logger:
-- class_path: pytorch_lightning.loggers.TensorBoardLogger
+- class_path: eva.loggers.TensorBoardLogger
 init_args:
 save_dir: *LIGHTNING_ROOT
 name: ""
```
2 changes: 1 addition & 1 deletion configs/vision/dino_vits16/online/patch_camelyon.yaml
```diff
@@ -22,7 +22,7 @@ trainer:
 monitor: *MONITOR_METRIC
 mode: *MONITOR_METRIC_MODE
 logger:
-- class_path: pytorch_lightning.loggers.TensorBoardLogger
+- class_path: eva.loggers.TensorBoardLogger
 init_args:
 save_dir: *LIGHTNING_ROOT
 name: ""
```
2 changes: 1 addition & 1 deletion configs/vision/tests/offline/patch_camelyon.yaml
```diff
@@ -32,7 +32,7 @@ trainer:
 monitor: &MONITOR_METRIC val/BinaryAccuracy
 mode: &MONITOR_METRIC_MODE max
 logger:
-- class_path: pytorch_lightning.loggers.TensorBoardLogger
+- class_path: eva.loggers.TensorBoardLogger
 init_args:
 save_dir: *LIGHTNING_ROOT
 name: ""
```
2 changes: 1 addition & 1 deletion configs/vision/tests/offline/patches.yaml
```diff
@@ -9,7 +9,7 @@ trainer:
 init_args:
 logging_interval: epoch
 logger:
-- class_path: pytorch_lightning.loggers.TensorBoardLogger
+- class_path: eva.loggers.TensorBoardLogger
 init_args:
 save_dir: *OUTPUT_DIR
 name: ""
```
2 changes: 1 addition & 1 deletion configs/vision/tests/offline/slides.yaml
```diff
@@ -9,7 +9,7 @@ trainer:
 init_args:
 logging_interval: epoch
 logger:
-- class_path: pytorch_lightning.loggers.TensorBoardLogger
+- class_path: eva.loggers.TensorBoardLogger
 init_args:
 save_dir: *OUTPUT_DIR
 name: ""
```
8 changes: 8 additions & 0 deletions configs/vision/tests/online/patch_camelyon.yaml
```diff
@@ -1,11 +1,19 @@
 ---
+n_runs: 2
 trainer:
 class_path: eva.Trainer
 init_args:
 default_root_dir: &OUTPUT_DIR ${oc.env:OUTPUT_DIR, logs/dino_vits16/patch_camelyon}
 max_epochs: &MAX_EPOCHS 1
 limit_train_batches: 2
 limit_val_batches: 2
+logger:
+- class_path: eva.loggers.TensorBoardLogger
+init_args:
+save_dir: *OUTPUT_DIR
+- class_path: eva.loggers.CSVLogger
+init_args:
+save_dir: *OUTPUT_DIR
 model:
 class_path: eva.HeadModule
 init_args:
```
2 changes: 2 additions & 0 deletions src/eva/callbacks/writers/embeddings.py
```diff
@@ -98,6 +98,8 @@ def on_predict_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) ->
         self._write_queue.put(None)
         self._write_process.join()
         logger.info(f"Predictions and manifest saved to {self._output_dir}")
+        self._write_process = None  # type: ignore
+        self._write_queue = None  # type: ignore

     def _initialize_write_process(self) -> None:
         self._write_queue = multiprocessing.Queue()
```
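Clearing these handles plausibly matters for the multi-run flow introduced below: `multiprocessing` primitives cannot be pickled, so a callback still holding a live queue would break the per-run `copy.deepcopy` of the trainer in `Interface.fit`. A standalone illustration of that failure mode (not code from this PR):

```python
# Illustration only: deep-copying an object that holds a live
# multiprocessing.Queue raises, which is why on_predict_end clears it.
import copy
import multiprocessing


class Writer:
    def __init__(self):
        self._write_queue = multiprocessing.Queue()


writer = Writer()
try:
    copy.deepcopy(writer)
except RuntimeError as error:
    print(error)  # Queue objects should only be shared between processes ...

writer._write_queue = None
copy.deepcopy(writer)  # succeeds once the handle is cleared
```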
79 changes: 75 additions & 4 deletions src/eva/interface/interface.py
```diff
@@ -1,9 +1,18 @@
 """Main interface class."""

+import copy
+import os
+from datetime import datetime
+
+import pytorch_lightning as pl
+from loguru import logger
+from pytorch_lightning.callbacks import ModelCheckpoint
+
 from eva import trainers
 from eva.data import datamodules
 from eva.data.datamodules import schemas
 from eva.models import modules
+from eva.utils.recording import get_evaluation_id, record_results


 class Interface:
@@ -18,6 +27,7 @@ def fit(
         model: modules.ModelModule,
         data: datamodules.DataModule,
         trainer: trainers.Trainer,
+        n_runs: int = 1,
     ) -> None:
         """Perform model training and evaluation in place.

@@ -34,11 +44,24 @@ def fit(
             model: The model module.
             data: The data module.
             trainer: The trainer which processes the model and data.
+            n_runs: The number of runs to perform.
         """
-        trainer.fit(model=model, datamodule=data)
-        trainer.validate(datamodule=data)
-        if data.datasets.test is not None:
-            trainer.test(datamodule=data)
+        evaluation_id = get_evaluation_id()
+
+        for run_id in range(n_runs):
+            _trainer = copy.deepcopy(trainer)
+            _model = copy.deepcopy(model)
+            log_dir = os.path.join(_trainer.default_root_dir, evaluation_id, f"run_{run_id}")
+            _adapt_log_dirs(_trainer, log_dir)
+
+            start_time = datetime.now()
+            pl.seed_everything(run_id + 3, workers=True)
+
+            evaluation_results = _fit_validate_test(_trainer, _model, data)
+
+            end_time = datetime.now()
+            results_path = os.path.join(log_dir, "results.json")
+            record_results(evaluation_results, results_path, start_time, end_time)

     def predict(
         self,
@@ -80,3 +103,51 @@ def predict_fit(
         """
         self.predict(model=model, data=data, trainer=trainer)
         self.fit(model=model, data=data, trainer=trainer)
+
+
+def _fit_validate_test(
+    trainer: trainers.Trainer,
+    model: modules.ModelModule,
+    data: datamodules.DataModule,
+) -> dict:
+    """Combines the fit and validate commands in one method.
+
+    Helper method to perform the following three steps:
+    1. fit: training the model using the provided data.
+    2. validate: evaluating the model using the validation data.
+    3. test: evaluating the model using the test data. (if available)
+
+    Args:
+        model: The model module.
+        data: The data module.
+        trainer: The trainer which processes the model and data.
+    """
+    trainer.fit(model=model, datamodule=data)
+    evaluation_results = {"val": trainer.validate(datamodule=data)}
+    if data.datasets.test is not None:
+        evaluation_results["test"] = trainer.test(datamodule=data)
+    return evaluation_results
+
+
+def _adapt_log_dirs(trainer, log_dir: str) -> None:
+    """Sets the log directory for the logger, trainer and callbacks.
+
+    Args:
+        trainer: The trainer instance.
+        log_dir: The log directory.
+    """
+    for train_logger in trainer.loggers:
+        try:
+            train_logger.log_dir = log_dir
+        except Exception:
+            logger.warning(f"Could not set log_dir for logger {train_logger}")
+
+    trainer.log_dir = log_dir
+    if len(trainer.callbacks) > 0:
+        model_checkpoint_callbacks = [
+            c for c in trainer.callbacks if isinstance(c, ModelCheckpoint)
+        ]
+        if len(model_checkpoint_callbacks) > 0:
+            model_checkpoint_callbacks[0].dirpath = os.path.join(log_dir, "checkpoints")
+        else:
+            logger.warning("No ModelCheckpoint callback found in trainer.callbacks")
```
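Each run now lands in `<default_root_dir>/<evaluation_id>/run_<run_id>/results.json`, which makes averaging across runs straightforward. A sketch of such post-processing, assuming `record_results` serializes the `{"val": [...], "test": [...]}` dict returned by `_fit_validate_test`; the exact JSON layout is an assumption, and the averaging itself (the goal named in the linked issue title) is not implemented in this file:

```python
# Hypothetical post-processing, not part of this diff: average one metric
# over the per-run results.json files written by Interface.fit.
import json
import os
import statistics


def average_metric(root: str, evaluation_id: str, n_runs: int,
                   split: str = "val", metric: str = "val/BinaryAccuracy") -> float:
    """Mean of `metric` across runs; assumes a single dataloader per split."""
    values = []
    for run_id in range(n_runs):
        path = os.path.join(root, evaluation_id, f"run_{run_id}", "results.json")
        with open(path) as f:
            results = json.load(f)
        values.append(results[split][0][metric])  # assumed results.json layout
    return statistics.mean(values)
```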
5 changes: 5 additions & 0 deletions src/eva/loggers/__init__.py
```diff
@@ -0,0 +1,5 @@
+"""Loggers API."""
+
+from eva.loggers.lightning import CSVLogger, TensorBoardLogger
+
+__all__ = ["CSVLogger", "TensorBoardLogger"]
```
44 changes: 44 additions & 0 deletions src/eva/loggers/lightning.py
```diff
@@ -0,0 +1,44 @@
+"""Custom logger classes for PyTorch Lightning."""
+
+from pytorch_lightning import loggers
+from typing_extensions import override
+
+
+class BaseLogger(loggers.Logger):
+    """Base logger class."""
+
+    def __init__(self, *args, **kwargs):
+        """Initializes the BaseLogger instance.
+
+        Overwrites the parent class to allow for custom log_dir setting.
+        """
+        super().__init__(*args, **kwargs)
+        self._log_dir = None
+
+    @property
+    @override
+    def log_dir(self) -> str:
+        if self._log_dir is not None:
+            return self._log_dir
+        else:
+            return super().log_dir  # type: ignore
+
+    @log_dir.setter
+    def log_dir(self, value):
+        self._log_dir = value
+
+
+class TensorBoardLogger(BaseLogger, loggers.TensorBoardLogger):
+    """TensorBoard logger class."""
+
+    def __init__(self, *args, **kwargs):
+        """Initializes the TensorBoardLogger instance."""
+        super().__init__(*args, **kwargs)
+
+
+class CSVLogger(BaseLogger, loggers.CSVLogger):
+    """CSV logger class."""
+
+    def __init__(self, *args, **kwargs):
+        """Initializes the CSVLogger instance."""
+        super().__init__(*args, **kwargs)
```
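The point of `BaseLogger` is that Lightning's stock loggers expose `log_dir` as a read-only property; the mixin shadows it with a settable one so `_adapt_log_dirs` can redirect each run's output. A usage sketch with illustrative paths:

```python
# Sketch: the overridden log_dir falls back to Lightning's computed path
# until it is explicitly set, after which the override wins.
from eva.loggers import TensorBoardLogger

tb_logger = TensorBoardLogger(save_dir="logs")
print(tb_logger.log_dir)  # Lightning's default, e.g. logs/lightning_logs/version_0

tb_logger.log_dir = "logs/eval_abc/run_0"  # what _adapt_log_dirs does per run
print(tb_logger.log_dir)  # logs/eval_abc/run_0
```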
30 changes: 28 additions & 2 deletions src/eva/trainers/trainer.py
```diff
@@ -1,6 +1,32 @@
 """Core trainer module."""

+from typing import Optional
+
 from pytorch_lightning import trainer
+from typing_extensions import override
+
+from eva.utils.recording import get_evaluation_id

-Trainer = trainer.Trainer
-"""Core trainer class."""
+
+class Trainer(trainer.Trainer):
+    """Core Trainer class."""
+
+    def __init__(self, **kwargs):
+        """Initializes a new Trainer instance."""
+        super(Trainer, self).__init__(**kwargs)
+        self.evaluation_id = get_evaluation_id()
+        self._log_dir = None
+        self.i = 0
+
+    @property
+    @override
+    def log_dir(self) -> Optional[str]:
+        """Overrides the log_dir getter from parent class."""
+        if self._log_dir is not None:
+            return self._log_dir
+        else:
+            return super().log_dir
+
+    @log_dir.setter
+    def log_dir(self, value):
+        self._log_dir = value
```
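Taken together, the pieces in this PR compose as follows. A minimal driver sketch, assuming only the import paths shown in the diffs above and leaving model and datamodule construction to the caller:

```python
# Sketch of the new entry point, assuming the modules added in this PR.
from eva import trainers
from eva.data import datamodules
from eva.interface.interface import Interface
from eva.models import modules


def run_evaluation(model: modules.ModelModule, data: datamodules.DataModule) -> None:
    """Fit/validate(/test) the model five times, one seeded run per iteration."""
    trainer = trainers.Trainer(max_epochs=1, default_root_dir="logs/demo")
    # Each run deep-copies trainer and model, reseeds, redirects logs to
    # logs/demo/<evaluation_id>/run_<i>/, and records results.json there.
    Interface().fit(model=model, data=data, trainer=trainer, n_runs=5)
```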