Skip to content

Commit

Permalink
Initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
viictorjimenezzz committed Mar 22, 2024
1 parent 0312d4b commit f2fe80f
Show file tree
Hide file tree
Showing 135 changed files with 8,842 additions and 2,205 deletions.
3 changes: 3 additions & 0 deletions configs/callbacks/accuracy.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Accuracy callback: instantiated by Hydra via _target_.
accuracy:
  _target_: src.callbacks.accuracy.Accuracy_Callback
  n_classes: ${model.net.n_classes}  # resolved from the model config at runtime
6 changes: 6 additions & 0 deletions configs/callbacks/accuracy_domains.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Per-domain accuracy callback: one accuracy figure per environment/domain
# in each split.
accuracy_domains:
  _target_: src.callbacks.accuracy_domains.AccuracyDomains_Callback
  n_classes: ${model.net.n_classes}
  # Domain counts derived from the length of each split's environment config
  # (uses a custom `len` resolver — note the quoted nested interpolation).
  n_domains_train: ${len:'${data.train_config}'}
  n_domains_val: ${len:'${data.val_config}'}
  n_domains_test: ${len:'${data.test_config}'}
9 changes: 9 additions & 0 deletions configs/callbacks/components/pa_diagvib_trainval.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# We turn a (train,val,test)_config into a single environment configuration.
# Each list entry instantiates one cached DiagVib-6 dataset for the PA callback.
dset_list:
  # Training environment.
  - _target_: src.data.components.diagvib_dataset.DiagVib6DatasetPA
    mnist_preprocessed_path: ${paths.data_dir}/dg/mnist_processed.npz
    # NOTE(review): relative path, unlike the ${paths.data_dir} interpolation
    # above — confirm it resolves from the intended working directory.
    cache_filepath: data/dg/dg_datasets/test_data_pipeline/train_singlevar0.pkl

  # Validation environment.
  - _target_: src.data.components.diagvib_dataset.DiagVib6DatasetPA
    mnist_preprocessed_path: ${paths.data_dir}/dg/mnist_processed.npz
    cache_filepath: data/dg/dg_datasets/test_data_pipeline/val_singlevar0.pkl
28 changes: 28 additions & 0 deletions configs/callbacks/components/pa_wilds_trainval.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# We turn a (train,val,test)_config into a single environment configuration.
# Each list entry collapses a multi-environment WILDS config into one dataset.
dset_list:
  # Training environments -> single WILDS environment dataset.
  - _target_: src.data.components.wilds_dataset.WILDSDatasetEnv
    dataset:
      _target_: wilds.get_dataset
      dataset: ${data.dataset_name}
      download: false
      unlabeled: false
      root_dir: ${data.dataset_dir}
      transform: ${data.transform}

    env_config:
      _target_: src.data.components.wilds_dataset.WILDS_multiple_to_single
      multiple_env_config: ${data.train_config}  # training data


  # Validation environments -> single WILDS environment dataset.
  - _target_: src.data.components.wilds_dataset.WILDSDatasetEnv
    dataset:
      _target_: wilds.get_dataset
      dataset: ${data.dataset_name}
      download: false
      unlabeled: false
      root_dir: ${data.dataset_dir}
      transform: ${data.transform}

    env_config:
      _target_: src.data.components.wilds_dataset.WILDS_multiple_to_single
      multiple_env_config: ${data.val_config}  # validation data (original comment said "training data" — copy-paste error)
2 changes: 2 additions & 0 deletions configs/callbacks/debugging.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Debugging callback (see src.callbacks.debugging); takes no config parameters.
debugging:
  _target_: src.callbacks.debugging.Debugging_Callback
23 changes: 23 additions & 0 deletions configs/callbacks/default.yaml
Original file line number Diff line number Diff line change
@@ -1,10 +1,33 @@
defaults:
- accuracy_domains.yaml
- accuracy.yaml
- debugging.yaml
- posterioragreement.yaml
- model_checkpoint.yaml
- early_stopping.yaml
- model_summary.yaml
- rich_progress_bar.yaml
- _self_

# model_checkpoint:
# dirpath: ${paths.output_dir}/checkpoints
# filename: "acc_epoch_{epoch:03d}"
# monitor: "val/acc_pa"
# mode: "max"
# save_last: True
# auto_insert_metric_name: False
# every_n_epochs: ${callbacks.pa_lightning.log_every_n_epochs}

# early_stopping:
# monitor: "val/acc_pa"
# patience: 50
# mode: "max"
# strict: False

# - accuracy.yaml
# - debugging.yaml


model_checkpoint:
dirpath: ${paths.output_dir}/checkpoints
filename: "epoch_{epoch:03d}"
Expand Down
16 changes: 16 additions & 0 deletions configs/callbacks/pa_early_stopping.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.EarlyStopping.html
# Early stopping driven by the posterior-agreement metric (log PA), maximized.

early_stopping_PA:
  _target_: pytorch_lightning.callbacks.EarlyStopping
  monitor: "val/logPA"
  min_delta: 0.0  # minimum change in the monitored quantity to qualify as an improvement
  patience: 100  #${trainer.max_epochs}/${callbacks.pa_lightning.log_every_n_epochs} # patience in epochs
  mode: "max"
  strict: false  # whether to crash the training if monitor is not found in the validation metrics

  verbose: false  # verbosity mode
  check_finite: true  # when set true, stops training when the monitor becomes NaN or infinite
  stopping_threshold: null  # stop training immediately once the monitored quantity reaches this threshold
  divergence_threshold: null  # stop training as soon as the monitored quantity becomes worse than this threshold
  check_on_train_epoch_end: null  # whether to run early stopping at the end of the training epoch
  # log_rank_zero_only: false  # this keyword argument isn't available in stable version
18 changes: 18 additions & 0 deletions configs/callbacks/pa_model_checkpoint.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html
# Checkpointing driven by the posterior-agreement metric (log PA), maximized.

model_checkpoint_PA:
  _target_: pytorch_lightning.callbacks.ModelCheckpoint
  dirpath: ${paths.output_dir}/checkpoints
  filename: "pa_epoch_{epoch:03d}"
  monitor: "val/logPA"
  mode: "max"
  save_last: false
  auto_insert_metric_name: false
  # NOTE(review): interpolation targets callbacks.pa_lightning, but the PA
  # callback in this commit is registered as callbacks.posterioragreement —
  # confirm the key exists or the interpolation will fail to resolve.
  every_n_epochs: ${callbacks.pa_lightning.log_every_n_epochs}  # check every n training epochs

  verbose: false  # verbosity mode
  save_top_k: 1  # save k best models (determined by above metric)
  save_weights_only: false  # if true, then only the model's weights will be saved
  every_n_train_steps: null  # number of training steps between checkpoints
  train_time_interval: null  # checkpoints are monitored at the specified time interval
  save_on_train_epoch_end: null  # whether to run checkpointing at the end of the training epoch or the end of validation
23 changes: 23 additions & 0 deletions configs/callbacks/posterioragreement.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# @package _global_

# We pass the dset_list from a specific configuration.
# NOTE(review): the group@package entry below was reconstructed — the scraped
# source showed it email-obfuscated ("[email protected]"); confirm the group and
# package path against the repository before relying on it.
defaults:
  - components@callbacks.posterioragreement.dataset: pa_diagvib_trainval.yaml

callbacks:
  posterioragreement:
    _target_: posterioragreement.lightning_callback.PA_Callback

    log_every_n_epochs: 1
    pa_epochs: 50
    beta0: 1.0
    pairing_strategy: label
    pairing_csv: ${data.dataset_dir}${data.dataset_name}_${callbacks.posterioragreement.pairing_strategy}_pairing.csv
    # optimizer: Optional[torch.optim.Optimizer] = None
    cuda_devices: 4
    batch_size: ${data.batch_size}
    num_workers: ${data.num_workers}

    dataset:
      _target_: posterioragreement.datautils.MultienvDataset
      dset_list: ???  # mandatory; filled by the defaults entry above
2 changes: 1 addition & 1 deletion configs/data/adv/cifar10.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ defaults:
- /model/adv/cifar10@_here_
- attack: null

_target_: src.data.cifar10_datamodule.CIFAR10DataModule
_target_: src.data.cifar10_datamodules.CIFAR10DataModule
batch_size: ???
adversarial_ratio: 1.
data_dir: ${paths.data_dir}/adv/adv_datasets
Expand Down
75 changes: 75 additions & 0 deletions configs/data/dg/wilds/camelyon17_idval.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# We train with 3 hospitals, validate either with same (ID) or with fourth hospital (OOD), and test with fifth.
# This variant validates on the in-distribution (id_val) split.

dataset_name: camelyon17
n_classes: 2

# Split sizes:
# train: 302436
# val (OOD): 34904
# id_val (ID): 33560
# test: 85054

# HOSPITALS:
# train: 0, 3, 4 (53425, 0, 0, 116959, 132052)
# val (OOD): 1 (0,34904,0,0,0)
# id_val (ID): (6011, 0, 0, 12879, 14670)
# test: 2 (0,0,85054,0,0)

# SLIDES:
# test: 20-30 (28 is the most numerous, with 32k observations)

transform:
  # Important to load the right transform for the data.
  # SOURCE: WILDS code
  # https://github.com/p-lambda/wilds/blob/472677590de351857197a9bf24958838c39c272b/examples/configs/datasets.py
  _target_: src.data.components.wilds_transforms.initialize_transform
  dataset:
    original_resolution: [96, 96]
  transform_name: image_base
  config:
    target_resolution: [96, 96]
  is_training: ???  # mandatory; set by the consumer (train vs eval transform)
  additional_transform_name: null

# One environment per training hospital.
train_config:
  env1:
    split_name: train
    group_by_fields: ["hospital"]
    values:
      hospital: [0]
  env2:
    split_name: train
    group_by_fields: ["hospital"]
    values:
      hospital: [3]
  env3:
    split_name: train
    group_by_fields: ["hospital"]
    values:
      hospital: [4]

# For ID validation
val_config:
  env1:
    split_name: id_val
    group_by_fields: ["hospital"]
    values:
      hospital: [0]
  env2:
    split_name: id_val
    group_by_fields: ["hospital"]
    values:
      hospital: [3]
  env3:
    split_name: id_val
    group_by_fields: ["hospital"]
    values:
      hospital: [4]

# Always OOD testing
test_config:
  env1:
    split_name: test
    group_by_fields: ["hospital"]
    values:
      hospital: [2]
64 changes: 64 additions & 0 deletions configs/data/dg/wilds/camelyon17_oodval.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# We train with 3 hospitals, validate either with same (ID) or with fourth hospital (OOD), and test with fifth.
# This variant validates on the out-of-distribution (val) split.

dataset_name: camelyon17
n_classes: 2

# Split sizes:
# train: 302436
# val (OOD): 34904
# id_val (ID): 33560
# test: 85054

# HOSPITALS:
# train: 0, 3, 4 (53425, 0, 0, 116959, 132052)
# val (OOD): 1 (0,34904,0,0,0)
# id_val (ID): (6011, 0, 0, 12879, 14670)
# test: 2 (0,0,85054,0,0)

# SLIDES:
# test: 20-30 (28 is the most numerous, with 32k observations)

transform:
  # Important to load the right transform for the data.
  # SOURCE: WILDS code
  # https://github.com/p-lambda/wilds/blob/472677590de351857197a9bf24958838c39c272b/examples/configs/datasets.py
  _target_: src.data.components.wilds_transforms.initialize_transform
  dataset:
    original_resolution: [96, 96]
  transform_name: image_base
  config:
    target_resolution: [96, 96]
  is_training: ???  # mandatory; set by the consumer (train vs eval transform)
  additional_transform_name: null

# One environment per training hospital.
train_config:
  env1:
    split_name: train
    group_by_fields: ["hospital"]
    values:
      hospital: [0]
  env2:
    split_name: train
    group_by_fields: ["hospital"]
    values:
      hospital: [3]
  env3:
    split_name: train
    group_by_fields: ["hospital"]
    values:
      hospital: [4]

# For OOD validation
val_config:
  env1:
    split_name: val
    group_by_fields: ["hospital"]
    values:
      hospital: [1]

# Always OOD testing
test_config:
  env1:
    split_name: test
    group_by_fields: ["hospital"]
    values:
      hospital: [2]
66 changes: 66 additions & 0 deletions configs/data/dg/wilds/camelyon17_oracle.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# In contrast with camelyon17.yaml, we train with an additional slide from the testing hospital.

dataset_name: camelyon17
n_classes: 2

# HOSPITALS:
# train & id_val: 0, 3, 4
# val (OOD): 1
# test: 2

# SLIDES:
# test: 20-30 (28 is the most numerous, with 32k observations)

transform:
  # Important to load the right transform for the data.
  # SOURCE: WILDS code
  # https://github.com/p-lambda/wilds/blob/472677590de351857197a9bf24958838c39c272b/examples/configs/datasets.py
  _target_: src.data.components.wilds_transforms.initialize_transform
  dataset:
    original_resolution: [96, 96]
  transform_name: image_base
  config:
    target_resolution: [96, 96]
  is_training: ???  # mandatory; set by the consumer (train vs eval transform)
  additional_transform_name: null

# One environment per training hospital, plus one oracle environment.
train_config:
  env1:
    split_name: train
    group_by_fields: ["hospital"]
    values:
      hospital: [0]
  env2:
    split_name: train
    group_by_fields: ["hospital"]
    values:
      hospital: [3]
  env3:
    split_name: train
    group_by_fields: ["hospital"]
    values:
      hospital: [4]

  # Take a slide from the test hospital for training
  env4:
    split_name: test
    group_by_fields: ["slide"]
    values:
      slide: [28]

# For OOD validation
val_config:
  env1:
    split_name: val
    group_by_fields: ["hospital"]
    values:
      hospital: [1]


# Always OOD testing
test_config:
  env1:
    split_name: test
    group_by_fields: ["hospital"]
    values:
      hospital: [2]
Loading

0 comments on commit f2fe80f

Please sign in to comment.