-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
0312d4b
commit f2fe80f
Showing
135 changed files
with
8,842 additions
and
2,205 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
accuracy: | ||
_target_: src.callbacks.accuracy.Accuracy_Callback | ||
n_classes: ${model.net.n_classes} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
accuracy_domains: | ||
_target_: src.callbacks.accuracy_domains.AccuracyDomains_Callback | ||
n_classes: ${model.net.n_classes} | ||
n_domains_train: ${len:'${data.train_config}'} | ||
n_domains_val: ${len:'${data.val_config}'} | ||
n_domains_test: ${len:'${data.test_config}'} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
# We turn a (train,val,test)_config into a single environment configuration. | ||
dset_list: | ||
- _target_: src.data.components.diagvib_dataset.DiagVib6DatasetPA | ||
mnist_preprocessed_path: ${paths.data_dir}/dg/mnist_processed.npz | ||
cache_filepath: data/dg/dg_datasets/test_data_pipeline/train_singlevar0.pkl | ||
|
||
- _target_: src.data.components.diagvib_dataset.DiagVib6DatasetPA | ||
mnist_preprocessed_path: ${paths.data_dir}/dg/mnist_processed.npz | ||
cache_filepath: data/dg/dg_datasets/test_data_pipeline/val_singlevar0.pkl |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
# We turn a (train,val,test)_config into a single environment configuration. | ||
dset_list: | ||
- _target_: src.data.components.wilds_dataset.WILDSDatasetEnv | ||
dataset: | ||
_target_: wilds.get_dataset | ||
dataset: ${data.dataset_name} | ||
download: false | ||
unlabeled: false | ||
root_dir: ${data.dataset_dir} | ||
transform: ${data.transform} | ||
|
||
env_config: | ||
_target_: src.data.components.wilds_dataset.WILDS_multiple_to_single | ||
multiple_env_config: ${data.train_config} # training data | ||
|
||
|
||
- _target_: src.data.components.wilds_dataset.WILDSDatasetEnv | ||
dataset: | ||
_target_: wilds.get_dataset | ||
dataset: ${data.dataset_name} | ||
download: false | ||
unlabeled: false | ||
root_dir: ${data.dataset_dir} | ||
transform: ${data.transform} | ||
|
||
env_config: | ||
_target_: src.data.components.wilds_dataset.WILDS_multiple_to_single | ||
multiple_env_config: ${data.val_config} # training data |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
debugging: | ||
_target_: src.callbacks.debugging.Debugging_Callback |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.EarlyStopping.html | ||
|
||
early_stopping_PA: | ||
_target_: pytorch_lightning.callbacks.EarlyStopping | ||
monitor: "val/logPA" | ||
min_delta: 0. # minimum change in the monitored quantity to qualify as an improvement | ||
patience: 100 #${trainer.max_epochs}/${callbacks.pa_lightning.log_every_n_epochs} # patience in epochs | ||
mode: "max" | ||
strict: False # whether to crash the training if monitor is not found in the validation metrics | ||
|
||
verbose: False # verbosity mode | ||
check_finite: True # when set True, stops training when the monitor becomes NaN or infinite | ||
stopping_threshold: null # stop training immediately once the monitored quantity reaches this threshold | ||
divergence_threshold: null # stop training as soon as the monitored quantity becomes worse than this threshold | ||
check_on_train_epoch_end: null # whether to run early stopping at the end of the training epoch | ||
# log_rank_zero_only: False # this keyword argument isn't available in stable version |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html | ||
|
||
model_checkpoint_PA: | ||
_target_: pytorch_lightning.callbacks.ModelCheckpoint | ||
dirpath: ${paths.output_dir}/checkpoints | ||
filename: "pa_epoch_{epoch:03d}" | ||
monitor: "val/logPA" | ||
mode: "max" | ||
save_last: False | ||
auto_insert_metric_name: False | ||
every_n_epochs: ${callbacks.pa_lightning.log_every_n_epochs} # check every n training epochs | ||
|
||
verbose: False # verbosity mode | ||
save_top_k: 1 # save k best models (determined by above metric) | ||
save_weights_only: False # if True, then only the model’s weights will be saved | ||
every_n_train_steps: null # number of training steps between checkpoints | ||
train_time_interval: null # checkpoints are monitored at the specified time interval | ||
save_on_train_epoch_end: null # whether to run checkpointing at the end of the training epoch or the end of validation |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
# @package _global_ | ||
|
||
# We pass the dset_list from a specific configuration | ||
defaults: | ||
- [email protected]: pa_diagvib_trainval.yaml | ||
|
||
callbacks: | ||
posterioragreement: | ||
_target_: posterioragreement.lightning_callback.PA_Callback | ||
|
||
log_every_n_epochs: 1 | ||
pa_epochs: 50 | ||
beta0: 1.0 | ||
pairing_strategy: label | ||
pairing_csv: ${data.dataset_dir}${data.dataset_name}_${callbacks.posterioragreement.pairing_strategy}_pairing.csv | ||
# optimizer: Optional[torch.optim.Optimizer] = None | ||
cuda_devices: 4 | ||
batch_size: ${data.batch_size} | ||
num_workers: ${data.num_workers} | ||
|
||
dataset: | ||
_target_: posterioragreement.datautils.MultienvDataset | ||
dset_list: ??? |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
# We train with 3 hospitals, validate either with same (ID) or with fourth hospital (OOD), and test with fifth. | ||
|
||
dataset_name: camelyon17 | ||
n_classes: 2 | ||
|
||
# train: 302436 | ||
# val (OOD): 34904 | ||
# id_val (ID): 33560 | ||
# test: 85054 | ||
|
||
# HOSPITALS: | ||
# train: 0, 3, 4 (53425, 0, 0, 116959, 132052) | ||
# val (OOD): 1 (0,34904,0,0,0) | ||
# id_val (ID): (6011, 0, 0, 12879, 14670) | ||
# test: 2 (0,0,85054,0,0) | ||
|
||
# SLIDES: | ||
# test: 20-30 (28 is the most numerous, with 32k observations) | ||
|
||
transform: | ||
# Important to load the right transform for the data. | ||
# SOURCE: WILDS code | ||
# https://github.com/p-lambda/wilds/blob/472677590de351857197a9bf24958838c39c272b/examples/configs/datasets.py | ||
_target_: src.data.components.wilds_transforms.initialize_transform | ||
dataset: | ||
original_resolution: [96, 96] | ||
transform_name: image_base | ||
config: | ||
target_resolution: [96, 96] | ||
is_training: ??? | ||
additional_transform_name: null | ||
|
||
train_config: | ||
env1: | ||
split_name: train | ||
group_by_fields: ["hospital"] | ||
values: | ||
hospital: [0] | ||
env2: | ||
split_name: train | ||
group_by_fields: ["hospital"] | ||
values: | ||
hospital: [3] | ||
env3: | ||
split_name: train | ||
group_by_fields: ["hospital"] | ||
values: | ||
hospital: [4] | ||
|
||
# For ID validation | ||
val_config: | ||
env1: | ||
split_name: id_val | ||
group_by_fields: ["hospital"] | ||
values: | ||
hospital: [0] | ||
env2: | ||
split_name: id_val | ||
group_by_fields: ["hospital"] | ||
values: | ||
hospital: [3] | ||
|
||
env3: | ||
split_name: id_val | ||
group_by_fields: ["hospital"] | ||
values: | ||
hospital: [4] | ||
|
||
# Always OOD testing | ||
test_config: | ||
env1: | ||
split_name: test | ||
group_by_fields: ["hospital"] | ||
values: | ||
hospital: [2] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
# We train with 3 hospitals, validate either with same (ID) or with fourth hospital (OOD), and test with fifth. | ||
|
||
dataset_name: camelyon17 | ||
n_classes: 2 | ||
|
||
# train: 302436 | ||
# val (OOD): 34904 | ||
# id_val (ID): 33560 | ||
# test: 85054 | ||
|
||
# HOSPITALS: | ||
# train: 0, 3, 4 (53425, 0, 0, 116959, 132052) | ||
# val (OOD): 1 (0,34904,0,0,0) | ||
# id_val (ID): (6011, 0, 0, 12879, 14670) | ||
# test: 2 (0,0,85054,0,0) | ||
|
||
# SLIDES: | ||
# test: 20-30 (28 is the most numerous, with 32k observations) | ||
|
||
transform: | ||
# Important to load the right transform for the data. | ||
# SOURCE: WILDS code | ||
# https://github.com/p-lambda/wilds/blob/472677590de351857197a9bf24958838c39c272b/examples/configs/datasets.py | ||
_target_: src.data.components.wilds_transforms.initialize_transform | ||
dataset: | ||
original_resolution: [96, 96] | ||
transform_name: image_base | ||
config: | ||
target_resolution: [96, 96] | ||
is_training: ??? | ||
additional_transform_name: null | ||
|
||
train_config: | ||
env1: | ||
split_name: train | ||
group_by_fields: ["hospital"] | ||
values: | ||
hospital: [0] | ||
env2: | ||
split_name: train | ||
group_by_fields: ["hospital"] | ||
values: | ||
hospital: [3] | ||
env3: | ||
split_name: train | ||
group_by_fields: ["hospital"] | ||
values: | ||
hospital: [4] | ||
|
||
# For OOD validation | ||
val_config: | ||
env1: | ||
split_name: val | ||
group_by_fields: ["hospital"] | ||
values: | ||
hospital: [1] | ||
|
||
# Always OOD testing | ||
test_config: | ||
env1: | ||
split_name: test | ||
group_by_fields: ["hospital"] | ||
values: | ||
hospital: [2] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
# In contrast with camelyion17.yaml, we train with an additional slide from the testing hospital. | ||
|
||
dataset_name: camelyon17 | ||
n_classes: 2 | ||
|
||
# HOSPITALS: | ||
# train & id_val: 0, 3, 4 | ||
# val (OOD): 1 | ||
# test: 2 | ||
|
||
# SLIDES: | ||
# test: 20-30 (28 is the most numerous, with 32k observations) | ||
|
||
transform: | ||
# Important to load the right transform for the data. | ||
# SOURCE: WILDS code | ||
# https://github.com/p-lambda/wilds/blob/472677590de351857197a9bf24958838c39c272b/examples/configs/datasets.py | ||
_target_: src.data.components.wilds_transforms.initialize_transform | ||
dataset: | ||
original_resolution: [96, 96] | ||
transform_name: image_base | ||
config: | ||
target_resolution: [96, 96] | ||
is_training: ??? | ||
additional_transform_name: null | ||
|
||
train_config: | ||
env1: | ||
split_name: train | ||
group_by_fields: ["hospital"] | ||
values: | ||
hospital: [0] | ||
env2: | ||
split_name: train | ||
group_by_fields: ["hospital"] | ||
values: | ||
hospital: [3] | ||
env3: | ||
split_name: train | ||
group_by_fields: ["hospital"] | ||
values: | ||
hospital: [4] | ||
|
||
# Take a slide from the test hospital for training | ||
env4: | ||
split_name: test | ||
group_by_fields: ["slide"] | ||
values: | ||
slide: [28] | ||
|
||
# For OOD validation | ||
val_config: | ||
env1: | ||
split_name: val | ||
group_by_fields: ["hospital"] | ||
values: | ||
hospital: [1] | ||
|
||
|
||
# Always OOD testing | ||
test_config: | ||
env1: | ||
split_name: test | ||
group_by_fields: ["hospital"] | ||
values: | ||
hospital: [2] |
Oops, something went wrong.