From 4b09f9668658d7ca085a6448e6b0a66e41326b05 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicolas=20K=C3=A4nzig?= <36882833+nkaenzig@users.noreply.github.com> Date: Fri, 6 Dec 2024 10:07:14 +0100 Subject: [PATCH] Add `CHECKPOINT_TYPE` env variable to configs (#726) --- .../pathology/offline/classification/bach.yaml | 3 ++- .../pathology/offline/classification/camelyon16.yaml | 1 + .../offline/classification/camelyon16_small.yaml | 1 + .../vision/pathology/offline/classification/crc.yaml | 3 ++- .../pathology/offline/classification/mhist.yaml | 3 ++- .../pathology/offline/classification/panda.yaml | 1 + .../offline/classification/panda_small.yaml | 1 + .../offline/classification/patch_camelyon.yaml | 3 ++- .../vision/pathology/offline/segmentation/bcss.yaml | 5 +++-- .../pathology/offline/segmentation/consep.yaml | 5 +++-- .../pathology/offline/segmentation/monusac.yaml | 3 ++- .../offline/segmentation/total_segmentator_2d.yaml | 3 ++- .../vision/pathology/online/classification/bach.yaml | 4 +++- .../vision/pathology/online/classification/crc.yaml | 4 +++- .../pathology/online/classification/mhist.yaml | 4 +++- .../online/classification/patch_camelyon.yaml | 4 +++- .../vision/pathology/online/segmentation/bcss.yaml | 5 +++-- .../vision/pathology/online/segmentation/consep.yaml | 5 +++-- .../pathology/online/segmentation/monusac.yaml | 5 +++-- .../online/segmentation/total_segmentator_2d.yaml | 5 +++-- .../vision/radiology/offline/segmentation/lits.yaml | 4 ++-- .../offline/segmentation/lits_balanced.yaml | 5 +++-- .../vision/radiology/online/segmentation/lits.yaml | 3 ++- .../radiology/online/segmentation/lits_balanced.yaml | 3 ++- configs/vision/tests/offline/panda.yaml | 1 + configs/vision/tests/offline/patch_camelyon.yaml | 1 + configs/vision/tests/online/patch_camelyon.yaml | 1 + docs/user-guide/getting-started/how_to_use.md | 6 ++++-- src/eva/core/trainers/functional.py | 6 ++++-- src/eva/core/trainers/trainer.py | 12 ++++++++---- 30 files changed, 74 insertions(+), 36 deletions(-) diff --git a/configs/vision/pathology/offline/classification/bach.yaml b/configs/vision/pathology/offline/classification/bach.yaml index bf9494ba5..98ac45cc1 100644 --- a/configs/vision/pathology/offline/classification/bach.yaml +++ b/configs/vision/pathology/offline/classification/bach.yaml @@ -5,6 +5,7 @@ trainer: n_runs: &N_RUNS ${oc.env:N_RUNS, 5} default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, dino_vits16}/offline/bach} max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 12500} + checkpoint_type: ${oc.env:CHECKPOINT_TYPE, best} callbacks: - class_path: eva.callbacks.ConfigurationLogger - class_path: lightning.pytorch.callbacks.TQDMProgressBar @@ -23,7 +24,7 @@ trainer: - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: min_delta: 0 - patience: 400 + patience: ${oc.env:PATIENCE, 400} monitor: *MONITOR_METRIC mode: *MONITOR_METRIC_MODE - class_path: eva.callbacks.ClassificationEmbeddingsWriter diff --git a/configs/vision/pathology/offline/classification/camelyon16.yaml b/configs/vision/pathology/offline/classification/camelyon16.yaml index 5a5f9a5c4..97ef1f6e5 100644 --- a/configs/vision/pathology/offline/classification/camelyon16.yaml +++ b/configs/vision/pathology/offline/classification/camelyon16.yaml @@ -5,6 +5,7 @@ trainer: n_runs: &N_RUNS ${oc.env:N_RUNS, 5} default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, dino_vits16}/offline/camelyon16} max_epochs: &MAX_EPOCHS ${oc.env:MAX_EPOCHS, 100} + checkpoint_type: ${oc.env:CHECKPOINT_TYPE, best} callbacks: - class_path: eva.callbacks.ConfigurationLogger - class_path: lightning.pytorch.callbacks.TQDMProgressBar diff --git a/configs/vision/pathology/offline/classification/camelyon16_small.yaml b/configs/vision/pathology/offline/classification/camelyon16_small.yaml index dd4be2af6..d3346dee4 100644 --- a/configs/vision/pathology/offline/classification/camelyon16_small.yaml +++ b/configs/vision/pathology/offline/classification/camelyon16_small.yaml @@ -5,6 +5,7 @@ trainer: n_runs: &N_RUNS ${oc.env:N_RUNS, 5} default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, dino_vits16}/offline/camelyon16} max_epochs: &MAX_EPOCHS ${oc.env:MAX_EPOCHS, 100} + checkpoint_type: ${oc.env:CHECKPOINT_TYPE, best} callbacks: - class_path: eva.callbacks.ConfigurationLogger - class_path: lightning.pytorch.callbacks.TQDMProgressBar diff --git a/configs/vision/pathology/offline/classification/crc.yaml b/configs/vision/pathology/offline/classification/crc.yaml index feca261e2..e54b095b3 100644 --- a/configs/vision/pathology/offline/classification/crc.yaml +++ b/configs/vision/pathology/offline/classification/crc.yaml @@ -5,6 +5,7 @@ trainer: n_runs: &N_RUNS ${oc.env:N_RUNS, 5} default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, dino_vits16}/offline/crc} max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 12500} + checkpoint_type: ${oc.env:CHECKPOINT_TYPE, best} callbacks: - class_path: eva.callbacks.ConfigurationLogger - class_path: lightning.pytorch.callbacks.TQDMProgressBar @@ -23,7 +24,7 @@ trainer: - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: min_delta: 0 - patience: 24 + patience: ${oc.env:PATIENCE, 24} monitor: *MONITOR_METRIC mode: *MONITOR_METRIC_MODE - class_path: eva.callbacks.ClassificationEmbeddingsWriter diff --git a/configs/vision/pathology/offline/classification/mhist.yaml b/configs/vision/pathology/offline/classification/mhist.yaml index f96c1f151..ad7b5c36e 100644 --- a/configs/vision/pathology/offline/classification/mhist.yaml +++ b/configs/vision/pathology/offline/classification/mhist.yaml @@ -5,6 +5,7 @@ trainer: n_runs: &N_RUNS ${oc.env:N_RUNS, 5} default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, dino_vits16}/offline/mhist} max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 12500} + checkpoint_type: ${oc.env:CHECKPOINT_TYPE, best} callbacks: - class_path: eva.callbacks.ConfigurationLogger - class_path: lightning.pytorch.callbacks.TQDMProgressBar @@ -23,7 +24,7 @@ trainer: - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: min_delta: 0 - patience: 70 + patience: ${oc.env:PATIENCE, 70} monitor: *MONITOR_METRIC mode: *MONITOR_METRIC_MODE - class_path: eva.callbacks.ClassificationEmbeddingsWriter diff --git a/configs/vision/pathology/offline/classification/panda.yaml b/configs/vision/pathology/offline/classification/panda.yaml index b88138c58..0ef3c3a4f 100644 --- a/configs/vision/pathology/offline/classification/panda.yaml +++ b/configs/vision/pathology/offline/classification/panda.yaml @@ -5,6 +5,7 @@ trainer: n_runs: &N_RUNS ${oc.env:N_RUNS, 5} default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, dino_vits16}/offline/panda} max_epochs: &MAX_EPOCHS ${oc.env:MAX_EPOCHS, 49} + checkpoint_type: ${oc.env:CHECKPOINT_TYPE, best} callbacks: - class_path: eva.callbacks.ConfigurationLogger - class_path: lightning.pytorch.callbacks.TQDMProgressBar diff --git a/configs/vision/pathology/offline/classification/panda_small.yaml b/configs/vision/pathology/offline/classification/panda_small.yaml index 53735a7cc..e4a4980ef 100644 --- a/configs/vision/pathology/offline/classification/panda_small.yaml +++ b/configs/vision/pathology/offline/classification/panda_small.yaml @@ -5,6 +5,7 @@ trainer: n_runs: &N_RUNS ${oc.env:N_RUNS, 5} default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, dino_vits16}/offline/panda} max_epochs: &MAX_EPOCHS ${oc.env:MAX_EPOCHS, 49} + checkpoint_type: ${oc.env:CHECKPOINT_TYPE, best} callbacks: - class_path: eva.callbacks.ConfigurationLogger - class_path: lightning.pytorch.callbacks.TQDMProgressBar diff --git a/configs/vision/pathology/offline/classification/patch_camelyon.yaml b/configs/vision/pathology/offline/classification/patch_camelyon.yaml index fc8450e79..4dfbd34fd 100644 --- a/configs/vision/pathology/offline/classification/patch_camelyon.yaml +++ b/configs/vision/pathology/offline/classification/patch_camelyon.yaml @@ -5,6 +5,7 @@ trainer: n_runs: &N_RUNS ${oc.env:N_RUNS, 5} default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, dino_vits16}/offline/patch_camelyon} max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 12500} + checkpoint_type: ${oc.env:CHECKPOINT_TYPE, best} callbacks: - class_path: eva.callbacks.ConfigurationLogger - class_path: lightning.pytorch.callbacks.TQDMProgressBar @@ -23,7 +24,7 @@ trainer: - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: min_delta: 0 - patience: 9 + patience: ${oc.env:PATIENCE, 9} monitor: *MONITOR_METRIC mode: *MONITOR_METRIC_MODE - class_path: eva.callbacks.ClassificationEmbeddingsWriter diff --git a/configs/vision/pathology/offline/segmentation/bcss.yaml b/configs/vision/pathology/offline/segmentation/bcss.yaml index b7c0f6165..8265441d5 100644 --- a/configs/vision/pathology/offline/segmentation/bcss.yaml +++ b/configs/vision/pathology/offline/segmentation/bcss.yaml @@ -2,10 +2,11 @@ trainer: class_path: eva.Trainer init_args: - n_runs: &N_RUNS ${oc.env:N_RUNS, 1} + n_runs: &N_RUNS ${oc.env:N_RUNS, 5} default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, vit_small_patch16_224}/bcss} max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 2000} log_every_n_steps: 6 + checkpoint_type: ${oc.env:CHECKPOINT_TYPE, best} callbacks: - class_path: eva.callbacks.ConfigurationLogger - class_path: lightning.pytorch.callbacks.TQDMProgressBar @@ -25,7 +26,7 @@ trainer: - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: min_delta: 0 - patience: 8 + patience: ${oc.env:PATIENCE, 8} monitor: *MONITOR_METRIC mode: *MONITOR_METRIC_MODE - class_path: eva.callbacks.SegmentationEmbeddingsWriter diff --git a/configs/vision/pathology/offline/segmentation/consep.yaml b/configs/vision/pathology/offline/segmentation/consep.yaml index 79af29627..68d30a594 100644 --- a/configs/vision/pathology/offline/segmentation/consep.yaml +++ b/configs/vision/pathology/offline/segmentation/consep.yaml @@ -2,10 +2,11 @@ trainer: class_path: eva.Trainer init_args: - n_runs: &N_RUNS ${oc.env:N_RUNS, 1} + n_runs: &N_RUNS ${oc.env:N_RUNS, 5} default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, vit_small_patch16_224_dino}/consep} max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 2000} log_every_n_steps: 6 + checkpoint_type: ${oc.env:CHECKPOINT_TYPE, best} callbacks: - class_path: eva.callbacks.ConfigurationLogger - class_path: lightning.pytorch.callbacks.TQDMProgressBar @@ -25,7 +26,7 @@ trainer: - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: min_delta: 0 - patience: 34 + patience: ${oc.env:PATIENCE, 34} monitor: *MONITOR_METRIC mode: *MONITOR_METRIC_MODE - class_path: eva.callbacks.SegmentationEmbeddingsWriter diff --git a/configs/vision/pathology/offline/segmentation/monusac.yaml b/configs/vision/pathology/offline/segmentation/monusac.yaml index 587f99846..26547aebc 100644 --- a/configs/vision/pathology/offline/segmentation/monusac.yaml +++ b/configs/vision/pathology/offline/segmentation/monusac.yaml @@ -6,6 +6,7 @@ trainer: default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, vit_small_patch16_224_dino}/monusac} max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 2000} log_every_n_steps: 6 + checkpoint_type: ${oc.env:CHECKPOINT_TYPE, best} callbacks: - class_path: eva.callbacks.ConfigurationLogger - class_path: lightning.pytorch.callbacks.TQDMProgressBar @@ -28,7 +29,7 @@ trainer: - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: min_delta: 0 - patience: 50 + patience: ${oc.env:PATIENCE, 50} monitor: *MONITOR_METRIC mode: *MONITOR_METRIC_MODE - class_path: eva.callbacks.SegmentationEmbeddingsWriter diff --git a/configs/vision/pathology/offline/segmentation/total_segmentator_2d.yaml b/configs/vision/pathology/offline/segmentation/total_segmentator_2d.yaml index 38080f1ae..36261851c 100644 --- a/configs/vision/pathology/offline/segmentation/total_segmentator_2d.yaml +++ b/configs/vision/pathology/offline/segmentation/total_segmentator_2d.yaml @@ -5,6 +5,7 @@ trainer: n_runs: &N_RUNS ${oc.env:N_RUNS, 1} default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, vit_small_patch16_224_dino}/total_segmentator_2d} max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 20000} + checkpoint_type: ${oc.env:CHECKPOINT_TYPE, best} callbacks: - class_path: eva.callbacks.ConfigurationLogger - class_path: lightning.pytorch.callbacks.TQDMProgressBar @@ -25,7 +26,7 @@ trainer: - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: min_delta: 0 - patience: 5 + patience: ${oc.env:PATIENCE, 5} monitor: *MONITOR_METRIC mode: *MONITOR_METRIC_MODE - class_path: eva.callbacks.SegmentationEmbeddingsWriter diff --git a/configs/vision/pathology/online/classification/bach.yaml b/configs/vision/pathology/online/classification/bach.yaml index 1719d821e..6e6f9bb88 100644 --- a/configs/vision/pathology/online/classification/bach.yaml +++ b/configs/vision/pathology/online/classification/bach.yaml @@ -2,8 +2,10 @@ trainer: class_path: eva.Trainer init_args: + n_runs: &N_RUNS ${oc.env:N_RUNS, 5} default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, dino_vits16}/online/bach} max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 12500} + checkpoint_type: ${oc.env:CHECKPOINT_TYPE, best} callbacks: - class_path: eva.callbacks.ConfigurationLogger - class_path: lightning.pytorch.callbacks.TQDMProgressBar @@ -22,7 +24,7 @@ trainer: - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: min_delta: 0 - patience: 400 + patience: ${oc.env:PATIENCE, 400} monitor: *MONITOR_METRIC mode: *MONITOR_METRIC_MODE logger: diff --git a/configs/vision/pathology/online/classification/crc.yaml b/configs/vision/pathology/online/classification/crc.yaml index 5abe659e0..37bb52c61 100644 --- a/configs/vision/pathology/online/classification/crc.yaml +++ b/configs/vision/pathology/online/classification/crc.yaml @@ -2,8 +2,10 @@ trainer: class_path: eva.Trainer init_args: + n_runs: &N_RUNS ${oc.env:N_RUNS, 5} default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, dino_vits16}/online/crc} max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 12500} + checkpoint_type: ${oc.env:CHECKPOINT_TYPE, best} callbacks: - class_path: eva.callbacks.ConfigurationLogger - class_path: lightning.pytorch.callbacks.TQDMProgressBar @@ -22,7 +24,7 @@ trainer: - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: min_delta: 0 - patience: 24 + patience: ${oc.env:PATIENCE, 24} monitor: *MONITOR_METRIC mode: *MONITOR_METRIC_MODE logger: diff --git a/configs/vision/pathology/online/classification/mhist.yaml b/configs/vision/pathology/online/classification/mhist.yaml index 25dcbc509..b2a23f13b 100644 --- a/configs/vision/pathology/online/classification/mhist.yaml +++ b/configs/vision/pathology/online/classification/mhist.yaml @@ -2,8 +2,10 @@ trainer: class_path: eva.Trainer init_args: + n_runs: &N_RUNS ${oc.env:N_RUNS, 5} default_root_dir: &LIGHTNING_ROOT ${oc.env:LIGHTNING_ROOT, logs/dino_vits16/online/mhist} max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 12500} + checkpoint_type: ${oc.env:CHECKPOINT_TYPE, best} callbacks: - class_path: eva.callbacks.ConfigurationLogger - class_path: lightning.pytorch.callbacks.TQDMProgressBar @@ -22,7 +24,7 @@ trainer: - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: min_delta: 0 - patience: 70 + patience: ${oc.env:PATIENCE, 70} monitor: *MONITOR_METRIC mode: *MONITOR_METRIC_MODE logger: diff --git a/configs/vision/pathology/online/classification/patch_camelyon.yaml b/configs/vision/pathology/online/classification/patch_camelyon.yaml index 13817a718..60800129c 100644 --- a/configs/vision/pathology/online/classification/patch_camelyon.yaml +++ b/configs/vision/pathology/online/classification/patch_camelyon.yaml @@ -2,8 +2,10 @@ trainer: class_path: eva.Trainer init_args: + n_runs: &N_RUNS ${oc.env:N_RUNS, 5} default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, dino_vits16}/online/patch_camelyon} max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 12500} + checkpoint_type: ${oc.env:CHECKPOINT_TYPE, best} callbacks: - class_path: eva.callbacks.ConfigurationLogger - class_path: lightning.pytorch.callbacks.TQDMProgressBar @@ -22,7 +24,7 @@ trainer: - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: min_delta: 0 - patience: 9 + patience: ${oc.env:PATIENCE, 9} monitor: *MONITOR_METRIC mode: *MONITOR_METRIC_MODE logger: diff --git a/configs/vision/pathology/online/segmentation/bcss.yaml b/configs/vision/pathology/online/segmentation/bcss.yaml index 2c343f134..694936df7 100644 --- a/configs/vision/pathology/online/segmentation/bcss.yaml +++ b/configs/vision/pathology/online/segmentation/bcss.yaml @@ -2,10 +2,11 @@ trainer: class_path: eva.Trainer init_args: - n_runs: &N_RUNS ${oc.env:N_RUNS, 1} + n_runs: &N_RUNS ${oc.env:N_RUNS, 5} default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, vit_small_patch16_224}/bcss} max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 513} log_every_n_steps: 6 + checkpoint_type: ${oc.env:CHECKPOINT_TYPE, best} callbacks: - class_path: eva.callbacks.ConfigurationLogger - class_path: lightning.pytorch.callbacks.TQDMProgressBar @@ -26,7 +27,7 @@ trainer: - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: min_delta: 0 - patience: 100 + patience: ${oc.env:PATIENCE, 100} monitor: *MONITOR_METRIC mode: *MONITOR_METRIC_MODE logger: diff --git a/configs/vision/pathology/online/segmentation/consep.yaml b/configs/vision/pathology/online/segmentation/consep.yaml index 4935515b2..e5b6de151 100644 --- a/configs/vision/pathology/online/segmentation/consep.yaml +++ b/configs/vision/pathology/online/segmentation/consep.yaml @@ -2,10 +2,11 @@ trainer: class_path: eva.Trainer init_args: - n_runs: &N_RUNS ${oc.env:N_RUNS, 1} + n_runs: &N_RUNS ${oc.env:N_RUNS, 5} default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, vit_small_patch16_224_dino}/consep} max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 2000} log_every_n_steps: 6 + checkpoint_type: ${oc.env:CHECKPOINT_TYPE, best} callbacks: - class_path: eva.callbacks.ConfigurationLogger - class_path: lightning.pytorch.callbacks.TQDMProgressBar @@ -26,7 +27,7 @@ trainer: - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: min_delta: 0 - patience: 34 + patience: ${oc.env:PATIENCE, 34} monitor: *MONITOR_METRIC mode: *MONITOR_METRIC_MODE logger: diff --git a/configs/vision/pathology/online/segmentation/monusac.yaml b/configs/vision/pathology/online/segmentation/monusac.yaml index acf8d9e1c..08b05644e 100644 --- a/configs/vision/pathology/online/segmentation/monusac.yaml +++ b/configs/vision/pathology/online/segmentation/monusac.yaml @@ -2,10 +2,11 @@ trainer: class_path: eva.Trainer init_args: - n_runs: &N_RUNS ${oc.env:N_RUNS, 1} + n_runs: &N_RUNS ${oc.env:N_RUNS, 5} default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, vit_small_patch16_224_dino}/monusac} max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 2000} log_every_n_steps: 6 + checkpoint_type: ${oc.env:CHECKPOINT_TYPE, best} callbacks: - class_path: eva.callbacks.ConfigurationLogger - class_path: lightning.pytorch.callbacks.TQDMProgressBar @@ -26,7 +27,7 @@ trainer: - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: min_delta: 0 - patience: 50 + patience: ${oc.env:PATIENCE, 50} monitor: *MONITOR_METRIC mode: *MONITOR_METRIC_MODE logger: diff --git a/configs/vision/pathology/online/segmentation/total_segmentator_2d.yaml b/configs/vision/pathology/online/segmentation/total_segmentator_2d.yaml index 2671ec402..aa48bf77d 100644 --- a/configs/vision/pathology/online/segmentation/total_segmentator_2d.yaml +++ b/configs/vision/pathology/online/segmentation/total_segmentator_2d.yaml @@ -2,9 +2,10 @@ trainer: class_path: eva.Trainer init_args: - n_runs: &N_RUNS ${oc.env:N_RUNS, 1} + n_runs: &N_RUNS ${oc.env:N_RUNS, 5} default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, vit_small_patch16_224_dino}/total_segmentator_2d} max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 20000} + checkpoint_type: ${oc.env:CHECKPOINT_TYPE, best} callbacks: - class_path: eva.callbacks.ConfigurationLogger - class_path: lightning.pytorch.callbacks.TQDMProgressBar @@ -25,7 +26,7 @@ trainer: - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: min_delta: 0 - patience: 5 + patience: ${oc.env:PATIENCE, 5} monitor: *MONITOR_METRIC mode: *MONITOR_METRIC_MODE logger: diff --git a/configs/vision/radiology/offline/segmentation/lits.yaml b/configs/vision/radiology/offline/segmentation/lits.yaml index d9e0c4903..1c4cfe498 100644 --- a/configs/vision/radiology/offline/segmentation/lits.yaml +++ b/configs/vision/radiology/offline/segmentation/lits.yaml @@ -2,7 +2,7 @@ trainer: class_path: eva.Trainer init_args: - n_runs: &N_RUNS ${oc.env:N_RUNS, 1} + n_runs: &N_RUNS ${oc.env:N_RUNS, 5} default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, vit_small_patch16_224_dino}/lits} max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 500000} callbacks: @@ -24,7 +24,7 @@ trainer: - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: min_delta: 0 - patience: 100 + patience: ${oc.env:PATIENCE, 100} monitor: *MONITOR_METRIC mode: *MONITOR_METRIC_MODE - class_path: eva.callbacks.SegmentationEmbeddingsWriter diff --git a/configs/vision/radiology/offline/segmentation/lits_balanced.yaml b/configs/vision/radiology/offline/segmentation/lits_balanced.yaml index a0059e34b..866a70333 100644 --- a/configs/vision/radiology/offline/segmentation/lits_balanced.yaml +++ b/configs/vision/radiology/offline/segmentation/lits_balanced.yaml @@ -2,9 +2,10 @@ trainer: class_path: eva.Trainer init_args: - n_runs: &N_RUNS ${oc.env:N_RUNS, 1} + n_runs: &N_RUNS ${oc.env:N_RUNS, 5} default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, vit_small_patch16_224_dino}/lits_balanced} max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 500000} + checkpoint_type: ${oc.env:CHECKPOINT_TYPE, best} callbacks: - class_path: eva.callbacks.ConfigurationLogger - class_path: lightning.pytorch.callbacks.TQDMProgressBar @@ -24,7 +25,7 @@ trainer: - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: min_delta: 0 - patience: 100 + patience: ${oc.env:PATIENCE, 100} monitor: *MONITOR_METRIC mode: *MONITOR_METRIC_MODE - class_path: eva.callbacks.SegmentationEmbeddingsWriter diff --git a/configs/vision/radiology/online/segmentation/lits.yaml b/configs/vision/radiology/online/segmentation/lits.yaml index 3d8d2fc57..81ed0e2f7 100644 --- a/configs/vision/radiology/online/segmentation/lits.yaml +++ b/configs/vision/radiology/online/segmentation/lits.yaml @@ -6,6 +6,7 @@ trainer: default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, vit_small_patch16_224_dino}/lits} max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 500000} log_every_n_steps: 6 + checkpoint_type: ${oc.env:CHECKPOINT_TYPE, best} callbacks: - class_path: eva.callbacks.ConfigurationLogger - class_path: lightning.pytorch.callbacks.TQDMProgressBar @@ -26,7 +27,7 @@ trainer: - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: min_delta: 0 - patience: 100 + patience: ${oc.env:PATIENCE, 100} monitor: *MONITOR_METRIC mode: *MONITOR_METRIC_MODE logger: diff --git a/configs/vision/radiology/online/segmentation/lits_balanced.yaml b/configs/vision/radiology/online/segmentation/lits_balanced.yaml index cff4c88e8..b5224d6cb 100644 --- a/configs/vision/radiology/online/segmentation/lits_balanced.yaml +++ b/configs/vision/radiology/online/segmentation/lits_balanced.yaml @@ -6,6 +6,7 @@ trainer: default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, vit_small_patch16_224_dino}/lits_balanced} max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 500000} log_every_n_steps: 6 + checkpoint_type: ${oc.env:CHECKPOINT_TYPE, best} callbacks: - class_path: eva.callbacks.ConfigurationLogger - class_path: lightning.pytorch.callbacks.TQDMProgressBar @@ -26,7 +27,7 @@ trainer: - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: min_delta: 0 - patience: 100 + patience: ${oc.env:PATIENCE, 100} monitor: *MONITOR_METRIC mode: *MONITOR_METRIC_MODE logger: diff --git a/configs/vision/tests/offline/panda.yaml b/configs/vision/tests/offline/panda.yaml index 4051b4edf..6bd0e958e 100644 --- a/configs/vision/tests/offline/panda.yaml +++ b/configs/vision/tests/offline/panda.yaml @@ -6,6 +6,7 @@ trainer: max_epochs: &MAX_EPOCHS 1 limit_train_batches: 2 limit_val_batches: 2 + checkpoint_type: ${oc.env:CHECKPOINT_TYPE, best} callbacks: - class_path: eva.callbacks.ClassificationEmbeddingsWriter init_args: diff --git a/configs/vision/tests/offline/patch_camelyon.yaml b/configs/vision/tests/offline/patch_camelyon.yaml index e09a44c6b..f17a24ad9 100644 --- a/configs/vision/tests/offline/patch_camelyon.yaml +++ b/configs/vision/tests/offline/patch_camelyon.yaml @@ -6,6 +6,7 @@ trainer: max_epochs: &MAX_EPOCHS 1 limit_train_batches: 2 limit_val_batches: 2 + checkpoint_type: ${oc.env:CHECKPOINT_TYPE, last} callbacks: - class_path: eva.callbacks.ConfigurationLogger - class_path: eva.callbacks.ClassificationEmbeddingsWriter diff --git a/configs/vision/tests/online/patch_camelyon.yaml b/configs/vision/tests/online/patch_camelyon.yaml index 4b8709415..cf75e1888 100644 --- a/configs/vision/tests/online/patch_camelyon.yaml +++ b/configs/vision/tests/online/patch_camelyon.yaml @@ -6,6 +6,7 @@ trainer: max_epochs: &MAX_EPOCHS 1 limit_train_batches: 2 limit_val_batches: 2 + checkpoint_type: ${oc.env:CHECKPOINT_TYPE, best} model: class_path: eva.HeadModule init_args: diff --git a/docs/user-guide/getting-started/how_to_use.md b/docs/user-guide/getting-started/how_to_use.md index cea6952d5..34288f416 100644 --- a/docs/user-guide/getting-started/how_to_use.md +++ b/docs/user-guide/getting-started/how_to_use.md @@ -59,5 +59,7 @@ To customize runs, without the need of creating custom config-files, you can ove | `MONITOR_METRIC_MODE` | `str` | "min" or "max", depending on the `MONITOR_METRIC` used | | `REPO_OR_DIR` | `str` | GitHub repo with format containing model implementation, e.g. "facebookresearch/dino:main" | | `TQDM_REFRESH_RATE` | `str` | Determines at which rate (in number of batches) the progress bars get updated. Set it to 0 to disable the progress bar. | -| `N_DATA_WORKERS` | `str` | How many subprocesses to use for the torch dataloaders. Set to `null` to use the number of cpu cores. | -| `METRICS_DEVICE` | `str` | Specifies the device on which to compute the metrics. If not set, will use the same device as used for training. | \ No newline at end of file +| `N_DATA_WORKERS` | `str` | How many subprocesses to use for the torch dataloaders. Set to `null` to use the number of cpu cores. | +| `METRICS_DEVICE` | `str` | Specifies the device on which to compute the metrics. If not set, will use the same device as used for training. | +| `CHECKPOINT_TYPE` | `str` | Set to "best" or "last", to select which checkpoint to load for evaluations on validation & test sets after training. | +| `PATIENCE` | `int` | Number of checks with no improvement after which training will be stopped (early stopping). | \ No newline at end of file diff --git a/src/eva/core/trainers/functional.py b/src/eva/core/trainers/functional.py index 62229bf81..4d8bd5346 100644 --- a/src/eva/core/trainers/functional.py +++ b/src/eva/core/trainers/functional.py @@ -96,11 +96,13 @@ def fit_and_validate( A tuple of with the validation and the test metrics (if exists). """ trainer.fit(model, datamodule=datamodule) - validation_scores = trainer.validate(datamodule=datamodule, verbose=verbose) + validation_scores = trainer.validate( + datamodule=datamodule, verbose=verbose, ckpt_path=trainer.checkpoint_type + ) test_scores = ( None if datamodule.datasets.test is None - else trainer.test(datamodule=datamodule, verbose=verbose) + else trainer.test(datamodule=datamodule, verbose=verbose, ckpt_path=trainer.checkpoint_type) ) return validation_scores, test_scores diff --git a/src/eva/core/trainers/trainer.py b/src/eva/core/trainers/trainer.py index 006470339..beace9db3 100644 --- a/src/eva/core/trainers/trainer.py +++ b/src/eva/core/trainers/trainer.py @@ -1,7 +1,7 @@ """Core trainer module.""" import os -from typing import Any +from typing import Any, Literal import loguru from lightning.pytorch import loggers as pl_loggers @@ -28,6 +28,7 @@ def __init__( *args: Any, default_root_dir: str = "logs", n_runs: int = 1, + checkpoint_type: Literal["best", "last"] = "best", **kwargs: Any, ) -> None: """Initializes the trainer. @@ -40,11 +41,14 @@ def __init__( Unlike in ::class::`lightning.pytorch.Trainer`, this path would be the prioritized destination point. n_runs: The amount of runs (fit and evaluate) to perform in an evaluation session. + checkpoint_type: Wether to load the "best" or "last" checkpoint saved by the checkpoint + callback for evaluations on validation & test sets. kwargs: Kew-word arguments of ::class::`lightning.pytorch.Trainer`. """ super().__init__(*args, default_root_dir=default_root_dir, **kwargs) - self._n_runs = n_runs + self.checkpoint_type = checkpoint_type + self.n_runs = n_runs self._session_id: str = _logging.generate_session_id() self._log_dir: str = self.default_log_dir @@ -106,6 +110,6 @@ def run_evaluation_session( base_trainer=self, base_model=model, datamodule=datamodule, - n_runs=self._n_runs, - verbose=self._n_runs > 1, + n_runs=self.n_runs, + verbose=self.n_runs > 1, )