diff --git a/configs/callbacks/accuracy.yaml b/configs/callbacks/accuracy.yaml new file mode 100644 index 0000000..2138067 --- /dev/null +++ b/configs/callbacks/accuracy.yaml @@ -0,0 +1,3 @@ +accuracy: + _target_: src.callbacks.accuracy.Accuracy_Callback + n_classes: ${model.net.n_classes} \ No newline at end of file diff --git a/configs/callbacks/accuracy_domains.yaml b/configs/callbacks/accuracy_domains.yaml new file mode 100644 index 0000000..54b8920 --- /dev/null +++ b/configs/callbacks/accuracy_domains.yaml @@ -0,0 +1,6 @@ +accuracy_domains: + _target_: src.callbacks.accuracy_domains.AccuracyDomains_Callback + n_classes: ${model.net.n_classes} + n_domains_train: ${len:'${data.train_config}'} + n_domains_val: ${len:'${data.val_config}'} + n_domains_test: ${len:'${data.test_config}'} \ No newline at end of file diff --git a/configs/callbacks/components/pa_diagvib_trainval.yaml b/configs/callbacks/components/pa_diagvib_trainval.yaml new file mode 100644 index 0000000..e36966b --- /dev/null +++ b/configs/callbacks/components/pa_diagvib_trainval.yaml @@ -0,0 +1,9 @@ +# We turn a (train,val,test)_config into a single environment configuration. +dset_list: + - _target_: src.data.components.diagvib_dataset.DiagVib6DatasetPA + mnist_preprocessed_path: ${paths.data_dir}/dg/mnist_processed.npz + cache_filepath: data/dg/dg_datasets/test_data_pipeline/train_singlevar0.pkl + + - _target_: src.data.components.diagvib_dataset.DiagVib6DatasetPA + mnist_preprocessed_path: ${paths.data_dir}/dg/mnist_processed.npz + cache_filepath: data/dg/dg_datasets/test_data_pipeline/val_singlevar0.pkl \ No newline at end of file diff --git a/configs/callbacks/components/pa_wilds_trainval.yaml b/configs/callbacks/components/pa_wilds_trainval.yaml new file mode 100644 index 0000000..e806d47 --- /dev/null +++ b/configs/callbacks/components/pa_wilds_trainval.yaml @@ -0,0 +1,28 @@ +# We turn a (train,val,test)_config into a single environment configuration. +dset_list: + - _target_: src.data.components.wilds_dataset.WILDSDatasetEnv + dataset: + _target_: wilds.get_dataset + dataset: ${data.dataset_name} + download: false + unlabeled: false + root_dir: ${data.dataset_dir} + transform: ${data.transform} + + env_config: + _target_: src.data.components.wilds_dataset.WILDS_multiple_to_single + multiple_env_config: ${data.train_config} # training data + + + - _target_: src.data.components.wilds_dataset.WILDSDatasetEnv + dataset: + _target_: wilds.get_dataset + dataset: ${data.dataset_name} + download: false + unlabeled: false + root_dir: ${data.dataset_dir} + transform: ${data.transform} + + env_config: + _target_: src.data.components.wilds_dataset.WILDS_multiple_to_single + multiple_env_config: ${data.val_config} # training data \ No newline at end of file diff --git a/configs/callbacks/debugging.yaml b/configs/callbacks/debugging.yaml new file mode 100644 index 0000000..d5172a2 --- /dev/null +++ b/configs/callbacks/debugging.yaml @@ -0,0 +1,2 @@ +debugging: + _target_: src.callbacks.debugging.Debugging_Callback \ No newline at end of file diff --git a/configs/callbacks/default.yaml b/configs/callbacks/default.yaml index 5df27bf..4ef5678 100644 --- a/configs/callbacks/default.yaml +++ b/configs/callbacks/default.yaml @@ -1,10 +1,33 @@ defaults: + - accuracy_domains.yaml + - accuracy.yaml + - debugging.yaml + - posterioragreement.yaml - model_checkpoint.yaml - early_stopping.yaml - model_summary.yaml - rich_progress_bar.yaml - _self_ +# model_checkpoint: +# dirpath: ${paths.output_dir}/checkpoints +# filename: "acc_epoch_{epoch:03d}" +# monitor: "val/acc_pa" +# mode: "max" +# save_last: True +# auto_insert_metric_name: False +# every_n_epochs: ${callbacks.pa_lightning.log_every_n_epochs} + +# early_stopping: +# monitor: "val/acc_pa" +# patience: 50 +# mode: "max" +# strict: False + + # - accuracy.yaml + # - debugging.yaml + + model_checkpoint: dirpath: ${paths.output_dir}/checkpoints filename: "epoch_{epoch:03d}" diff --git a/configs/callbacks/pa_early_stopping.yaml b/configs/callbacks/pa_early_stopping.yaml new file mode 100644 index 0000000..49724b0 --- /dev/null +++ b/configs/callbacks/pa_early_stopping.yaml @@ -0,0 +1,16 @@ +# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.EarlyStopping.html + +early_stopping_PA: + _target_: pytorch_lightning.callbacks.EarlyStopping + monitor: "val/logPA" + min_delta: 0. # minimum change in the monitored quantity to qualify as an improvement + patience: 100 #${trainer.max_epochs}/${callbacks.pa_lightning.log_every_n_epochs} # patience in epochs + mode: "max" + strict: False # whether to crash the training if monitor is not found in the validation metrics + + verbose: False # verbosity mode + check_finite: True # when set True, stops training when the monitor becomes NaN or infinite + stopping_threshold: null # stop training immediately once the monitored quantity reaches this threshold + divergence_threshold: null # stop training as soon as the monitored quantity becomes worse than this threshold + check_on_train_epoch_end: null # whether to run early stopping at the end of the training epoch + # log_rank_zero_only: False # this keyword argument isn't available in stable version \ No newline at end of file diff --git a/configs/callbacks/pa_model_checkpoint.yaml b/configs/callbacks/pa_model_checkpoint.yaml new file mode 100644 index 0000000..6072458 --- /dev/null +++ b/configs/callbacks/pa_model_checkpoint.yaml @@ -0,0 +1,18 @@ +# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html + +model_checkpoint_PA: + _target_: pytorch_lightning.callbacks.ModelCheckpoint + dirpath: ${paths.output_dir}/checkpoints + filename: "pa_epoch_{epoch:03d}" + monitor: "val/logPA" + mode: "max" + save_last: False + auto_insert_metric_name: False + every_n_epochs: ${callbacks.pa_lightning.log_every_n_epochs} # check every n training epochs + + verbose: False # verbosity mode + save_top_k: 1 # save k best models (determined by above metric) + save_weights_only: False # if True, then only the model’s weights will be saved + every_n_train_steps: null # number of training steps between checkpoints + train_time_interval: null # checkpoints are monitored at the specified time interval + save_on_train_epoch_end: null # whether to run checkpointing at the end of the training epoch or the end of validation diff --git a/configs/callbacks/posterioragreement.yaml b/configs/callbacks/posterioragreement.yaml new file mode 100644 index 0000000..91035eb --- /dev/null +++ b/configs/callbacks/posterioragreement.yaml @@ -0,0 +1,23 @@ +# @package _global_ + +# We pass the dset_list from a specific configuration +defaults: + - components@callbacks.posterioragreement.dataset: pa_diagvib_trainval.yaml + +callbacks: + posterioragreement: + _target_: posterioragreement.lightning_callback.PA_Callback + + log_every_n_epochs: 1 + pa_epochs: 50 + beta0: 1.0 + pairing_strategy: label + pairing_csv: ${data.dataset_dir}${data.dataset_name}_${callbacks.posterioragreement.pairing_strategy}_pairing.csv + # optimizer: Optional[torch.optim.Optimizer] = None + cuda_devices: 4 + batch_size: ${data.batch_size} + num_workers: ${data.num_workers} + + dataset: + _target_: posterioragreement.datautils.MultienvDataset + dset_list: ??? \ No newline at end of file diff --git a/configs/data/adv/cifar10.yaml b/configs/data/adv/cifar10.yaml index 7de33ea..43c3ff7 100644 --- a/configs/data/adv/cifar10.yaml +++ b/configs/data/adv/cifar10.yaml @@ -2,7 +2,7 @@ defaults: - /model/adv/cifar10@_here_ - attack: null -_target_: src.data.cifar10_datamodule.CIFAR10DataModule +_target_: src.data.cifar10_datamodules.CIFAR10DataModule batch_size: ??? adversarial_ratio: 1. data_dir: ${paths.data_dir}/adv/adv_datasets diff --git a/configs/data/dg/wilds/camelyon17_idval.yaml b/configs/data/dg/wilds/camelyon17_idval.yaml new file mode 100644 index 0000000..7c37ce6 --- /dev/null +++ b/configs/data/dg/wilds/camelyon17_idval.yaml @@ -0,0 +1,75 @@ +# We train with 3 hospitals, validate either with same (ID) or with fourth hospital (OOD), and test with fifth. + +dataset_name: camelyon17 +n_classes: 2 + +# train: 302436 +# val (OOD): 34904 +# id_val (ID): 33560 +# test: 85054 + +# HOSPITALS: +# train: 0, 3, 4 (53425, 0, 0, 116959, 132052) +# val (OOD): 1 (0,34904,0,0,0) +# id_val (ID): (6011, 0, 0, 12879, 14670) +# test: 2 (0,0,85054,0,0) + +# SLIDES: +# test: 20-30 (28 is the most numerous, with 32k observations) + +transform: + # Important to load the right transform for the data. + # SOURCE: WILDS code + # https://github.com/p-lambda/wilds/blob/472677590de351857197a9bf24958838c39c272b/examples/configs/datasets.py + _target_: src.data.components.wilds_transforms.initialize_transform + dataset: + original_resolution: [96, 96] + transform_name: image_base + config: + target_resolution: [96, 96] + is_training: ??? + additional_transform_name: null + +train_config: + env1: + split_name: train + group_by_fields: ["hospital"] + values: + hospital: [0] + env2: + split_name: train + group_by_fields: ["hospital"] + values: + hospital: [3] + env3: + split_name: train + group_by_fields: ["hospital"] + values: + hospital: [4] + +# For ID validation +val_config: + env1: + split_name: id_val + group_by_fields: ["hospital"] + values: + hospital: [0] + env2: + split_name: id_val + group_by_fields: ["hospital"] + values: + hospital: [3] + + env3: + split_name: id_val + group_by_fields: ["hospital"] + values: + hospital: [4] + +# Always OOD testing +test_config: + env1: + split_name: test + group_by_fields: ["hospital"] + values: + hospital: [2] \ No newline at end of file diff --git a/configs/data/dg/wilds/camelyon17_oodval.yaml b/configs/data/dg/wilds/camelyon17_oodval.yaml new file mode 100644 index 0000000..6affa59 --- /dev/null +++ b/configs/data/dg/wilds/camelyon17_oodval.yaml @@ -0,0 +1,64 @@ +# We train with 3 hospitals, validate either with same (ID) or with fourth hospital (OOD), and test with fifth. + +dataset_name: camelyon17 +n_classes: 2 + +# train: 302436 +# val (OOD): 34904 +# id_val (ID): 33560 +# test: 85054 + +# HOSPITALS: +# train: 0, 3, 4 (53425, 0, 0, 116959, 132052) +# val (OOD): 1 (0,34904,0,0,0) +# id_val (ID): (6011, 0, 0, 12879, 14670) +# test: 2 (0,0,85054,0,0) + +# SLIDES: +# test: 20-30 (28 is the most numerous, with 32k observations) + +transform: + # Important to load the right transform for the data. + # SOURCE: WILDS code + # https://github.com/p-lambda/wilds/blob/472677590de351857197a9bf24958838c39c272b/examples/configs/datasets.py + _target_: src.data.components.wilds_transforms.initialize_transform + dataset: + original_resolution: [96, 96] + transform_name: image_base + config: + target_resolution: [96, 96] + is_training: ??? + additional_transform_name: null + +train_config: + env1: + split_name: train + group_by_fields: ["hospital"] + values: + hospital: [0] + env2: + split_name: train + group_by_fields: ["hospital"] + values: + hospital: [3] + env3: + split_name: train + group_by_fields: ["hospital"] + values: + hospital: [4] + +# For OOD validation +val_config: + env1: + split_name: val + group_by_fields: ["hospital"] + values: + hospital: [1] + +# Always OOD testing +test_config: + env1: + split_name: test + group_by_fields: ["hospital"] + values: + hospital: [2] diff --git a/configs/data/dg/wilds/camelyon17_oracle.yaml b/configs/data/dg/wilds/camelyon17_oracle.yaml new file mode 100644 index 0000000..2fc9467 --- /dev/null +++ b/configs/data/dg/wilds/camelyon17_oracle.yaml @@ -0,0 +1,66 @@ +# In contrast with camelyion17.yaml, we train with an additional slide from the testing hospital. + +dataset_name: camelyon17 +n_classes: 2 + +# HOSPITALS: +# train & id_val: 0, 3, 4 +# val (OOD): 1 +# test: 2 + +# SLIDES: +# test: 20-30 (28 is the most numerous, with 32k observations) + +transform: + # Important to load the right transform for the data. + # SOURCE: WILDS code + # https://github.com/p-lambda/wilds/blob/472677590de351857197a9bf24958838c39c272b/examples/configs/datasets.py + _target_: src.data.components.wilds_transforms.initialize_transform + dataset: + original_resolution: [96, 96] + transform_name: image_base + config: + target_resolution: [96, 96] + is_training: ??? + additional_transform_name: null + +train_config: + env1: + split_name: train + group_by_fields: ["hospital"] + values: + hospital: [0] + env2: + split_name: train + group_by_fields: ["hospital"] + values: + hospital: [3] + env3: + split_name: train + group_by_fields: ["hospital"] + values: + hospital: [4] + + # Take a slide from the test hospital for training + env4: + split_name: test + group_by_fields: ["slide"] + values: + slide: [28] + +# For OOD validation +val_config: + env1: + split_name: val + group_by_fields: ["hospital"] + values: + hospital: [1] + + +# Always OOD testing +test_config: + env1: + split_name: test + group_by_fields: ["hospital"] + values: + hospital: [2] \ No newline at end of file diff --git a/configs/data/dg/wilds/celebA.yaml b/configs/data/dg/wilds/celebA.yaml new file mode 100644 index 0000000..dbc5dce --- /dev/null +++ b/configs/data/dg/wilds/celebA.yaml @@ -0,0 +1,64 @@ +dataset_name: celebA +n_classes: 2 + +# train: 162770 +# val: 19867 +# test: 19962 + +# MALE: 0-1 +# train: 0-1 (94509, 68261) +# val: 0-1 (11409, 8458) +# test: 0-1 (12247, 7715) + +# In the labels, we have spurious correlations: (MALE, Y), where Y: blonde/no blonde. +# (0,0): 89931 +# (0,1): 28234 +# (1,0): 82685 +# (1,1): 1749 + +transform: + # Important to load the right transform for the data. + # SOURCE: WILDS code + # https://github.com/p-lambda/wilds/blob/472677590de351857197a9bf24958838c39c272b/examples/configs/datasets.py + _target_: src.data.components.wilds_transforms.initialize_transform + dataset: + original_resolution: [178, 218] + transform_name: image_base + config: + target_resolution: null + is_training: ??? + additional_transform_name: null + + +# The domains are the labels (two hair colours) +train_config: + env1: + split_name: train + group_by_fields: ["y"] + values: + y: [0] + env2: + split_name: train + group_by_fields: ["y"] + values: + y: [1] + +val_config: + env1: + split_name: val + group_by_fields: ["y"] + values: + male: [0] + env2: + split_name: val + group_by_fields: ["y"] + values: + male: [1] + +test_config: + env1: + split_name: test + group_by_fields: ["male", "y"] + values: + male: [0, 1] + y: [0, 1] \ No newline at end of file diff --git a/configs/data/dg/wilds/fmow_idtest.yaml b/configs/data/dg/wilds/fmow_idtest.yaml new file mode 100644 index 0000000..e6d0dff --- /dev/null +++ b/configs/data/dg/wilds/fmow_idtest.yaml @@ -0,0 +1,548 @@ +# ID (in distribution) validation and testing + +dataset_name: fmow +n_classes: 62 + +# In general, the features of the dataset are the following: +# REGION: (Africa, Americas, Oceania, Asia, Europe) +# 0-5 (103299, 162333, 33239, 157711, 13253, 251) + +# YEAR: 0-15 (2002-2017) + +# SPLITS BY YEAR, REGION: +# train: 0-10 (2002-2012), 0-5 (17809, 34816, 1582, 20973, 1641, 42) +# val: 11-13 (2013-2015), 0-5 (4121, 7732, 803, 6562, 693, 4) +# id_val: 0-10 (2002-2012), 0-5 (2693, 5268, 1990, 3076, 251, 5) +# test: 14-15 (2016-2017), 0-5 (4963, 5858, 2593, 8024, 666, 4) +# id_test: 0-10 (2002-2012), 0-5 (2615, 7765, 7974, 11104, 11322, 11327) + +transform: + # Important to load the right transform for the data. + # SOURCE: WILDS code + # https://github.com/p-lambda/wilds/blob/472677590de351857197a9bf24958838c39c272b/examples/configs/datasets.py + _target_: src.data.components.wilds_transforms.initialize_transform + dataset: + original_resolution: [224, 224] + transform_name: image_base + config: + target_resolution: null + is_training: ??? + additional_transform_name: null + +train_config: + env1: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [0] + year: [0] + env2: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [0] + year: [1] + env3: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [0] + year: [2] + env4: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [0] + year: [3] + env5: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [0] + year: [4] + env6: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [0] + year: [5] + env7: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [0] + year: [6] + env8: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [0] + year: [7] + env9: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [0] + year: [8] + env10: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [0] + year: [9] + env11: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [0] + year: [10] + env12: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [1] + year: [0] + env13: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [1] + year: [1] + env14: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [1] + year: [2] + env15: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [1] + year: [3] + env16: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [1] + year: [4] + env17: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [1] + year: [5] + env18: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [1] + year: [6] + env19: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [1] + year: [7] + env20: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [1] + year: [8] + env21: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [1] + year: [9] + env22: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [1] + year: [10] + env23: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [2] + year: [0] + env24: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [2] + year: [1] + env25: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [2] + year: [2] + env26: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [2] + year: [3] + env27: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [2] + year: [4] + env28: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [2] + year: [5] + env29: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [2] + year: [6] + env30: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [2] + year: [7] + env31: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [2] + year: [8] + env32: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [2] + year: [9] + env33: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [2] + year: [10] + env34: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [3] + year: [0] + env35: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [3] + year: [1] + env36: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [3] + year: [2] + env37: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [3] + year: [3] + env38: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [3] + year: [4] + env39: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [3] + year: [5] + env40: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [3] + year: [6] + env41: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [3] + year: [7] + env42: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [3] + year: [8] + env43: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [3] + year: [9] + env44: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [3] + year: [10] + env45: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [4] + year: [0] + env46: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [4] + year: [1] + env47: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [4] + year: [2] + env48: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [4] + year: [3] + env49: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [4] + year: [4] + env50: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [4] + year: [5] + env51: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [4] + year: [6] + env52: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [4] + year: [7] + env53: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [4] + year: [8] + env54: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [4] + year: [9] + env55: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [4] + year: [10] + env56: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [5] + year: [0] + env57: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [5] + year: [1] + env58: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [5] + year: [2] + env59: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [5] + year: [3] + env60: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [5] + year: [4] + env61: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [5] + year: [5] + env62: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [5] + year: [6] + env63: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [5] + year: [7] + env64: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [5] + year: [8] + env65: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [5] + year: [9] + env66: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [5] + year: [10] + +# For OOD validation +val_config: + env1: + split_name: val + group_by_fields: ["region", "year"] + values: + region: [0] + year: [11] + env2: + split_name: val + group_by_fields: ["region", "year"] + values: + region: [0] + year: [12] + env3: + split_name: val + group_by_fields: ["region", "year"] + values: + region: [0] + year: [13] + env4: + split_name: val + group_by_fields: ["region", "year"] + values: + region: [1] + year: [11] + env5: + split_name: val + group_by_fields: ["region", "year"] + values: + region: [1] + year: [12] + env6: + split_name: val + group_by_fields: ["region", "year"] + values: + region: [1] + year: [13] + env7: + split_name: val + group_by_fields: ["region", "year"] + values: + region: [2] + year: [11] + env8: + split_name: val + group_by_fields: ["region", "year"] + values: + region: [2] + year: [12] + env9: + split_name: val + group_by_fields: ["region", "year"] + values: + region: [2] + year: [13] + env10: + split_name: val + group_by_fields: ["region", "year"] + values: + region: [3] + year: [11] + env11: + split_name: val + group_by_fields: ["region", "year"] + values: + region: [3] + year: [12] + env12: + split_name: val + group_by_fields: ["region", "year"] + values: + region: [3] + year: [13] + env13: + split_name: val + group_by_fields: ["region", "year"] + values: + region: [4] + year: [11] + env14: + split_name: val + group_by_fields: ["region", "year"] + values: + region: [4] + year: [12] + env15: + split_name: val + group_by_fields: ["region", "year"] + values: + region: [4] + year: [13] + env16: + split_name: val + group_by_fields: ["region", "year"] + values: + region: [5] + year: [11] + env17: + split_name: val + group_by_fields: ["region", "year"] + values: + region: [5] + year: [12] + env18: + split_name: val + group_by_fields: ["region", "year"] + values: + region: [5] + year: [13] + +# ID testing, in a single environment +test_config: + env1: + split_name: id_test + group_by_fields: ["region", "year"] + values: + region: [0, 1, 2, 3, 4, 5] + year: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] \ No newline at end of file diff --git a/configs/data/dg/wilds/fmow_oodtest.yaml b/configs/data/dg/wilds/fmow_oodtest.yaml new file mode 100644 index 0000000..82933e8 --- /dev/null +++ b/configs/data/dg/wilds/fmow_oodtest.yaml @@ -0,0 +1,548 @@ +# ID (in distribution) validation and testing + +dataset_name: fmow +n_classes: 62 + +# In general, the features of the dataset are the following: +# REGION: (Africa, Americas, Oceania, Asia, Europe) +# 0-5 (103299, 162333, 33239, 157711, 13253, 251) + +# YEAR: 0-15 (2002-2017) + +# SPLITS BY YEAR, REGION: +# train: 0-10 (2002-2012), 0-5 (17809, 34816, 1582, 20973, 1641, 42) +# val: 11-13 (2013-2015), 0-5 (4121, 7732, 803, 6562, 693, 4) +# id_val: 0-10 (2002-2012), 0-5 (2693, 5268, 1990, 3076, 251, 5) +# test: 14-15 (2016-2017), 0-5 (4963, 5858, 2593, 8024, 666, 4) +# id_test: 0-10 (2002-2012), 0-5 (2615, 7765, 7974, 11104, 11322, 11327) + +transform: + # Important to load the right transform for the data. + # SOURCE: WILDS code + # https://github.com/p-lambda/wilds/blob/472677590de351857197a9bf24958838c39c272b/examples/configs/datasets.py + _target_: src.data.components.wilds_transforms.initialize_transform + dataset: + original_resolution: [224, 224] + transform_name: image_base + config: + target_resolution: null + is_training: ??? + additional_transform_name: null + +train_config: + env1: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [0] + year: [0] + env2: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [0] + year: [1] + env3: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [0] + year: [2] + env4: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [0] + year: [3] + env5: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [0] + year: [4] + env6: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [0] + year: [5] + env7: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [0] + year: [6] + env8: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [0] + year: [7] + env9: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [0] + year: [8] + env10: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [0] + year: [9] + env11: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [0] + year: [10] + env12: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [1] + year: [0] + env13: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [1] + year: [1] + env14: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [1] + year: [2] + env15: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [1] + year: [3] + env16: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [1] + year: [4] + env17: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [1] + year: [5] + env18: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [1] + year: [6] + env19: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [1] + year: [7] + env20: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [1] + year: [8] + env21: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [1] + year: [9] + env22: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [1] + year: [10] + env23: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [2] + year: [0] + env24: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [2] + year: [1] + env25: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [2] + year: [2] + env26: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [2] + year: [3] + env27: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [2] + year: [4] + env28: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [2] + year: [5] + env29: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [2] + year: [6] + env30: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [2] + year: [7] + env31: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [2] + year: [8] + env32: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [2] + year: [9] + env33: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [2] + year: [10] + env34: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [3] + year: [0] + env35: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [3] + year: [1] + env36: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [3] + year: [2] + env37: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [3] + year: [3] + env38: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [3] + year: [4] + env39: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [3] + year: [5] + env40: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [3] + year: [6] + env41: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [3] + year: [7] + env42: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [3] + year: [8] + env43: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [3] + year: [9] + env44: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [3] + year: [10] + env45: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [4] + year: [0] + env46: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [4] + year: [1] + env47: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [4] + year: [2] + env48: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [4] + year: [3] + env49: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [4] + year: [4] + env50: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [4] + year: [5] + env51: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [4] + year: [6] + env52: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [4] + year: [7] + env53: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [4] + year: [8] + env54: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [4] + year: [9] + env55: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [4] + year: [10] + env56: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [5] + year: [0] + env57: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [5] + year: [1] + env58: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [5] + year: [2] + env59: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [5] + year: [3] + env60: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [5] + year: [4] + env61: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [5] + year: [5] + env62: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [5] + year: [6] + env63: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [5] + year: [7] + env64: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [5] + year: [8] + env65: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [5] + year: [9] + env66: + split_name: train + group_by_fields: ["region", "year"] + values: + region: [5] + year: [10] + +# For OOD validation +val_config: + env1: + split_name: val + group_by_fields: ["region", "year"] + values: + region: [0] + year: [11] + env2: + split_name: val + group_by_fields: ["region", "year"] + values: + region: [0] + year: [12] + env3: + split_name: val + group_by_fields: ["region", "year"] + values: + region: [0] + year: [13] + env4: + split_name: val + group_by_fields: ["region", "year"] + values: + region: [1] + year: [11] + env5: + split_name: val + group_by_fields: ["region", "year"] + values: + region: [1] + year: [12] + env6: + split_name: val + group_by_fields: ["region", "year"] + values: + region: [1] + year: [13] + env7: + split_name: val + group_by_fields: ["region", "year"] + values: + region: [2] + year: [11] + env8: + split_name: val + group_by_fields: ["region", "year"] + values: + region: [2] + year: [12] + env9: + split_name: val + group_by_fields: ["region", "year"] + values: + region: [2] + year: [13] + env10: + split_name: val + group_by_fields: ["region", "year"] + values: + region: [3] + year: [11] + env11: + split_name: val + group_by_fields: ["region", "year"] + values: + region: [3] + year: [12] + env12: + split_name: val + group_by_fields: ["region", "year"] + values: + region: [3] + year: [13] + env13: + split_name: val + group_by_fields: ["region", "year"] + values: + region: [4] + year: [11] + env14: + split_name: val + group_by_fields: ["region", "year"] + values: + region: [4] + year: [12] + env15: + split_name: val + group_by_fields: ["region", "year"] + values: + region: [4] + year: [13] + env16: + split_name: val + group_by_fields: ["region", "year"] + values: + region: [5] + year: [11] + env17: + split_name: val + group_by_fields: ["region", "year"] + values: + region: [5] + year: [12] + env18: + split_name: val + group_by_fields: ["region", "year"] + values: + region: [5] + year: [13] + +# OOD testing, in a single environment +test_config: + env1: + split_name: test + group_by_fields: ["region", "year"] + values: + region: [0, 1, 2, 3, 4, 5] + year: [14, 15] \ No newline at end of file diff --git a/configs/data/dg/wilds/rxrx1_idtest.yaml b/configs/data/dg/wilds/rxrx1_idtest.yaml new file mode 100644 index 0000000..b3be6b3 --- /dev/null +++ b/configs/data/dg/wilds/rxrx1_idtest.yaml @@ -0,0 +1,294 @@ +dataset_name: rxrx1 +n_classes: 1139 + +# EVERY EXPERIMENT SHOULD CONTAIN ONLY ONE SAMPLE FROM EVERY CLASS. + +# train: 40612 +# val (ood): 9854 +# id_test (id): 40612 +# test (ood): 34432 + +# CELL_TYPE: 0-3 +# train: 0-3 (8622, 19671, 8623, 3696) +# val (ood): 0-3 (2462, 2464, 2464, 2464) +# id_test (id): 0-3 (8622, 19671, 8623, 12319) +# test (ood): 0-3 (7388, 17244, 7360, 2440) + +# EXPERIMENT: 0-50 +# train: 0-48 +# val (ood): 7, 27, 42, 49 +# id_test (id): 0-48 +# test (ood): 8, 9, 10, 28, 29, 30, 31, 32, 33, 34, 43, 44, 45, 50 + +# PLATE: 1-4 +# train: 1-4 (10153, 10153, 10153, 10153) +# val (ood): 1-4 (2464, 2464, 2464, 2464) +# id_test (id): 1-4 (10153, 10153, 10153, 10153) +# test (ood): 1-4 (8610, 8608, 8610, 8604) + +# WELL: 0-307 +# train: 0-307 +# val (ood): 0-307 +# id_test (id): 0-307 +# test (ood): 0-307 + +# SITE: 1-2 +# train: 1 (40612, 0) +# val (ood): 1-2 (4927, 4927) +# id_test (id): 2 (0, 40612) +# test (ood): 1-2 (17216, 17216) + + +transform: + # Important to load the right transform for the data. + # SOURCE: WILDS code + # https://github.com/p-lambda/wilds/blob/472677590de351857197a9bf24958838c39c272b/examples/configs/datasets.py + _target_: src.data.components.wilds_transforms.initialize_transform + dataset: + original_resolution: null + transform_name: rxrx1 + config: + target_resolution: [256, 256] + is_training: ??? + additional_transform_name: null + + +# 33 environments: {experiment} 0, 1, 2, 3, 4, 5, 6, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 35, 36, 37, 38, 39, 40, 41, 46, 47, 48 +# Also only from site 1 +train_config: + env1: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [0] + site: [1] + env2: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [1] + site: [1] + env3: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [2] + site: [1] + env4: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [3] + site: [1] + env5: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [4] + site: [1] + env6: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [5] + site: [1] + env7: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [6] + site: [1] + env8: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [11] + site: [1] + env9: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [12] + site: [1] + env10: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [13] + site: [1] + env11: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [14] + site: [1] + env12: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [15] + site: [1] + env13: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [16] + site: [1] + env14: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [17] + site: [1] + env15: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [18] + site: [1] + env16: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [19] + site: [1] + env17: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [20] + site: [1] + env18: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [21] + site: [1] + env19: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [22] + site: [1] + env20: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [23] + site: [1] + env21: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [24] + site: [1] + env22: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [25] + site: [1] + env23: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [26] + site: [1] + env24: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [35] + site: [1] + env25: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [36] + site: [1] + env26: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [37] + site: [1] + env27: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [38] + site: [1] + env28: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [39] + site: [1] + env29: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [40] + site: [1] + env30: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [41] + site: [1] + env31: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [46] + site: [1] + env32: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [47] + site: [1] + env33: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [48] + site: [1] + +# 4 environments: {experiment} 7, 27, 42, 49 +# Both sites. I specify them in the config dict to avoid confusion, although it's not necessary. +val_config: + env1: + split_name: val + group_by_fields: ["experiment", "site"] + values: + experiment: [7] + site: [1, 2] + env2: + split_name: val + group_by_fields: ["experiment", "site"] + values: + experiment: [27] + site: [1, 2] + env3: + split_name: val + group_by_fields: ["experiment", "site"] + values: + experiment: [42] + site: [1, 2] + env4: + split_name: val + group_by_fields: ["experiment", "site"] + values: + experiment: [49] + site: [1, 2] + +# For ID testing, same 33 environments as in training, but as a single dataset. +# Site 2 only +test_config: + env1: + split_name: id_test + group_by_fields: ["experiment", "site"] + values: + experiment: [0, 1, 2, 3, 4, 5, 6, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 35, 36, 37, 38, 39, 40, 41, 46, 47, 48] + site: [2] \ No newline at end of file diff --git a/configs/data/dg/wilds/rxrx1_mixed_to_test.yaml b/configs/data/dg/wilds/rxrx1_mixed_to_test.yaml new file mode 100644 index 0000000..ef3bb81 --- /dev/null +++ b/configs/data/dg/wilds/rxrx1_mixed_to_test.yaml @@ -0,0 +1,259 @@ +dataset_name: rxrx1 +n_classes: 1139 + +# EVERY EXPERIMENT SHOULD CONTAIN ONLY ONE SAMPLE FROM EVERY CLASS. + +# Variation of the OODtest for ERM baseline, in which 14 of the training experiments are replaced by +# the 14 test environment experiments but only for site 1 as well. +# The test size is thus reduced in half as we only test in site 2 now. +# I choose to substitute the training experiments 0-13 for convention. + +transform: + # Important to load the right transform for the data. + # SOURCE: WILDS code + # https://github.com/p-lambda/wilds/blob/472677590de351857197a9bf24958838c39c272b/examples/configs/datasets.py + _target_: src.data.components.wilds_transforms.initialize_transform + dataset: + original_resolution: null + transform_name: rxrx1 + config: + target_resolution: [256, 256] + is_training: ??? + additional_transform_name: null + +# We substitute 0-13 by [8, 9, 10, 28, 29, 30, 31, 32, 33, 34, 43, 44, 45, 50] +train_config: + env1: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [8] + site: [1] + env2: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [9] + site: [1] + env3: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [10] + site: [1] + env4: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [28] + site: [1] + env5: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [29] + site: [1] + env6: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [30] + site: [1] + env7: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [31] + site: [1] + env8: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [32] + site: [1] + env9: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [33] + site: [1] + env10: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [34] + site: [1] + env11: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [43] + site: [1] + env12: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [44] + site: [1] + env13: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [45] + site: [1] + env14: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [50] + site: [1] + env15: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [18] + site: [1] + env16: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [19] + site: [1] + env17: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [20] + site: [1] + env18: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [21] + site: [1] + env19: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [22] + site: [1] + env20: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [23] + site: [1] + env21: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [24] + site: [1] + env22: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [25] + site: [1] + env23: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [26] + site: [1] + env24: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [35] + site: [1] + env25: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [36] + site: [1] + env26: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [37] + site: [1] + env27: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [38] + site: [1] + env28: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [39] + site: [1] + env29: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [40] + site: [1] + env30: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [41] + site: [1] + env31: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [46] + site: [1] + env32: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [47] + site: [1] + env33: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [48] + site: [1] + +# Same as before. +val_config: + env1: + split_name: val + group_by_fields: ["experiment", "site"] + values: + experiment: [7] + site: [1, 2] + env2: + split_name: val + group_by_fields: ["experiment", "site"] + values: + experiment: [27] + site: [1, 2] + env3: + split_name: val + group_by_fields: ["experiment", "site"] + values: + experiment: [42] + site: [1, 2] + env4: + split_name: val + group_by_fields: ["experiment", "site"] + values: + experiment: [49] + site: [1, 2] + +# We remove site 1. +test_config: + env1: + split_name: test + group_by_fields: ["experiment", "site"] + values: + experiment: [8, 9, 10, 28, 29, 30, 31, 32, 33, 34, 43, 44, 45, 50] + site: [2] \ No newline at end of file diff --git a/configs/data/dg/wilds/rxrx1_oodtest.yaml b/configs/data/dg/wilds/rxrx1_oodtest.yaml new file mode 100644 index 0000000..159fae1 --- /dev/null +++ b/configs/data/dg/wilds/rxrx1_oodtest.yaml @@ -0,0 +1,293 @@ +dataset_name: rxrx1 +n_classes: 1139 + +# EVERY EXPERIMENT SHOULD CONTAIN ONLY ONE SAMPLE FROM EVERY CLASS. + +# train: 40612 +# val (ood): 9854 +# id_test (id): 40612 +# test (ood): 34432 + +# CELL_TYPE: 0-3 +# train: 0-3 (8622, 19671, 8623, 3696) +# val (ood): 0-3 (2462, 2464, 2464, 2464) +# id_test (id): 0-3 (8622, 19671, 8623, 12319) +# test (ood): 0-3 (7388, 17244, 7360, 2440) + +# EXPERIMENT: 0-50 +# train: 0-48 +# val (ood): 7, 27, 42, 49 +# id_test (id): 0-48 +# test (ood): 8, 9, 10, 28, 29, 30, 31, 32, 33, 34, 43, 44, 45, 50 + +# PLATE: 1-4 +# train: 1-4 (10153, 10153, 10153, 10153) +# val (ood): 1-4 (2464, 2464, 2464, 2464) +# id_test (id): 1-4 (10153, 10153, 10153, 10153) +# test (ood): 1-4 (8610, 8608, 8610, 8604) + +# WELL: 0-307 +# train: 0-307 +# val (ood): 0-307 +# id_test (id): 0-307 +# test (ood): 0-307 + +# SITE: 1-2 +# train: 1 (40612, 0) +# val (ood): 1-2 (4927, 4927) +# id_test (id): 2 (0, 40612) +# test (ood): 1-2 (17216, 17216) + +transform: + # Important to load the right transform for the data. + # SOURCE: WILDS code + # https://github.com/p-lambda/wilds/blob/472677590de351857197a9bf24958838c39c272b/examples/configs/datasets.py + _target_: src.data.components.wilds_transforms.initialize_transform + dataset: + original_resolution: null + transform_name: rxrx1 + config: + target_resolution: [256, 256] + is_training: ??? + additional_transform_name: null + + +# 33 environments: {experiment} 0, 1, 2, 3, 4, 5, 6, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 35, 36, 37, 38, 39, 40, 41, 46, 47, 48 +# Also only from site 1 +train_config: + env1: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [0] + site: [1] + env2: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [1] + site: [1] + env3: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [2] + site: [1] + env4: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [3] + site: [1] + env5: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [4] + site: [1] + env6: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [5] + site: [1] + env7: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [6] + site: [1] + env8: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [11] + site: [1] + env9: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [12] + site: [1] + env10: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [13] + site: [1] + env11: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [14] + site: [1] + env12: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [15] + site: [1] + env13: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [16] + site: [1] + env14: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [17] + site: [1] + env15: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [18] + site: [1] + env16: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [19] + site: [1] + env17: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [20] + site: [1] + env18: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [21] + site: [1] + env19: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [22] + site: [1] + env20: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [23] + site: [1] + env21: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [24] + site: [1] + env22: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [25] + site: [1] + env23: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [26] + site: [1] + env24: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [35] + site: [1] + env25: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [36] + site: [1] + env26: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [37] + site: [1] + env27: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [38] + site: [1] + env28: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [39] + site: [1] + env29: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [40] + site: [1] + env30: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [41] + site: [1] + env31: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [46] + site: [1] + env32: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [47] + site: [1] + env33: + split_name: train + group_by_fields: ["experiment", "site"] + values: + experiment: [48] + site: [1] + +# 4 environments: {experiment} 7, 27, 42, 49 +# Both sites. I specify them in the config dict to avoid confusion, although it's not necessary. +val_config: + env1: + split_name: val + group_by_fields: ["experiment", "site"] + values: + experiment: [7] + site: [1, 2] + env2: + split_name: val + group_by_fields: ["experiment", "site"] + values: + experiment: [27] + site: [1, 2] + env3: + split_name: val + group_by_fields: ["experiment", "site"] + values: + experiment: [42] + site: [1, 2] + env4: + split_name: val + group_by_fields: ["experiment", "site"] + values: + experiment: [49] + site: [1, 2] + +# For OOD testing: 8, 9, 10, 28, 29, 30, 31, 32, 33, 34, 43, 44, 45, 50 +# Sites 1 and 2. I specify them in the config dict to avoid confusion, although it's not necessary. +test_config: + env1: + split_name: test + group_by_fields: ["experiment", "site"] + values: + experiment: [8, 9, 10, 28, 29, 30, 31, 32, 33, 34, 43, 44, 45, 50] + site: [1, 2] \ No newline at end of file diff --git a/configs/data/dg/wilds/waterbirds.yaml b/configs/data/dg/wilds/waterbirds.yaml new file mode 100644 index 0000000..efb821a --- /dev/null +++ b/configs/data/dg/wilds/waterbirds.yaml @@ -0,0 +1,55 @@ +dataset_name: waterbirds +n_classes: 2 + +# train: 4795 +# val: 1199 +# test: 5794 + +# BACKGROUND 0-1 +# train: 0-1 (3554, 1241) # there is a clear subpopulation shift in the training set +# val: 0-1 (600, 599) +# test: 0-1 (2897, 2897) + +transform: + # Important to load the right transform for the data. + # SOURCE: WILDS code + # https://github.com/p-lambda/wilds/blob/472677590de351857197a9bf24958838c39c272b/examples/configs/datasets.py + _target_: src.data.components.wilds_transforms.initialize_transform + dataset: + original_resolution: [224, 224] + transform_name: image_resize_and_center_crop + config: + resize_scale: 1.142857 # 256.0/224.0 + is_training: ??? + additional_transform_name: null + +train_config: + env1: + split_name: train + group_by_fields: ["background"] + values: + background: [0] + env2: + split_name: train + group_by_fields: ["background"] + values: + background: [1] + +val_config: + env1: + split_name: val + group_by_fields: ["background"] + values: + background: [0] + env2: + split_name: val + group_by_fields: ["background"] + values: + background: [1] + +test_config: + env1: + split_name: test + group_by_fields: ["background"] + values: + background: [0, 1] \ No newline at end of file diff --git a/configs/data/dg/wilds_multienv.yaml b/configs/data/dg/wilds_multienv.yaml new file mode 100644 index 0000000..744c0e7 --- /dev/null +++ b/configs/data/dg/wilds_multienv.yaml @@ -0,0 +1,18 @@ +_target_: src.data.wilds_datamodules.WILDSDataModule + +dataset_name: ??? +dataset_dir: ${paths.data_dir}/dg/dg_datasets/wilds/ +cache: true + +# Configurations for the multienvironment +train_config: ??? +val_config: ??? +test_config: ??? + +# Collate function and transform depends on the model: +transform: ??? + +# This is for the dataloader, depends on the architecture of the model: +batch_size: ??? +num_workers: 0 +pin_memory: False \ No newline at end of file diff --git a/configs/experiment/adv/optimize_beta.yaml b/configs/experiment/adv/optimize_beta.yaml index d287fb1..3210658 100644 --- a/configs/experiment/adv/optimize_beta.yaml +++ b/configs/experiment/adv/optimize_beta.yaml @@ -14,6 +14,10 @@ defaults: # all parameters below will be merged with parameters from default configurations set above # this allows you to overwrite only specified parameters +data: + adv: + _target_: src.data.cifar10_datamodules.CIFAR10DataModulePA + model: adv: classifier: ${data.adv.classifier} diff --git a/configs/experiment/adv/optimize_beta_logits.yaml b/configs/experiment/adv/optimize_beta_logits.yaml index 7daef79..a4b6c9e 100644 --- a/configs/experiment/adv/optimize_beta_logits.yaml +++ b/configs/experiment/adv/optimize_beta_logits.yaml @@ -6,7 +6,7 @@ defaults: # Switch to logits dataloader and pass the classifier originally defined for the model data: adv: - _target_: src.data.cifar10_datamodule.CIFAR10DataModulelogits # use logits dataloader instead + _target_: src.data.cifar10_datamodule.CIFAR10DataModulePAlogits # use logits dataloader instead # Remove the classifier from the model instantiation, so that nn.Identity() is used model: diff --git a/configs/experiment/dg/erm_irm_lisa_pa.yaml b/configs/experiment/dg/erm_irm_lisa_pa.yaml index 78ac43e..aa68c9a 100644 --- a/configs/experiment/dg/erm_irm_lisa_pa.yaml +++ b/configs/experiment/dg/erm_irm_lisa_pa.yaml @@ -27,13 +27,12 @@ seed: 12345 trainer: min_epochs: 50 max_epochs: 50 - limit_val_batches: 0.0 # no validation ofc #log_every_n_steps: 10 devices: 1 data: dg: - datasets_dir: ${paths.data_dir}/dg/dg_datasets/rebuttal/ + datasets_dir: ${paths.data_dir}/dg/dg_datasets/rebuttaltf/ collate_fn: _target_: hydra.utils.get_method path: src.data.components.collate_functions.MultiEnv_collate_fn diff --git a/configs/experiment/dg/pa_training.yaml b/configs/experiment/dg/pa_training.yaml new file mode 100644 index 0000000..1b522c9 --- /dev/null +++ b/configs/experiment/dg/pa_training.yaml @@ -0,0 +1,20 @@ +# @package _global_ + +defaults: + - ./dg/train_dg_erm.yaml # same configuration as in the ERM baseline + - /data/dg/@trainer.pa_datamodule: diagvib_PA.yaml + +trainer: + _target_: src.models.components.PA_trainer.PosteriorAgreementTrainer + pa_optimizer: + _target_: torch.optim.Adam + _partial_: true + lr: 0.001 + weight_decay: 0.0 + pa_epochs: 50 + beta0: 1.0 + pa_datamodule: + # Mandatory values in diagvib_PA.yaml, therefore not yet overriden: + collate_fn: ${data.dg.collate_fn} + batch_size: ${data.dg.batch_size} + # envs_index, envs_name and shift_ratio should also be specified, either here or via flags/.sh \ No newline at end of file diff --git a/configs/experiment/dg/train_dg_erm copy.yaml b/configs/experiment/dg/train_dg_erm copy.yaml new file mode 100644 index 0000000..5586f1e --- /dev/null +++ b/configs/experiment/dg/train_dg_erm copy.yaml @@ -0,0 +1,55 @@ +# @package _global_ + +# to execute this experiment run: +# python train.py experiment=example + +defaults: + - override /data/dg@data: wilds_multienv.yaml + - override /model/dg@model: erm.yaml + - override /trainer: ddp.yaml + +# all parameters below will be merged with parameters from default configurations set above +# this allows you to overwrite only specified parameters + +name: ??? +tags: ["wilds", "erm", "${name}"] + +seed: 123 + +trainer: + min_epochs: 100 + max_epochs: 100 + gradient_clip_val: 0.0 + +model: + optimizer: + _target_: torch.optim.SGD + lr: 0.0001 + net: + pretrained: false + net: resnet18 + n_classes: ${data.dg.n_classes} + + scheduler: null + +data: + # datasets_dir: ${paths.data_dir}/dg/dg_datasets/rebuttalft/ + dataset_dir: ${paths.data_dir}/dg/dg_datasets/wilds/ + # collate_fn: + # _target_: hydra.utils.get_method # to avoid instantiation of a callable + # path: src.data.components.collate_functions.SingleEnv_collate_fn + batch_size: 64 + num_workers: 2 + pin_memory: true + + +logger: + wandb: + tags: ${tags} + entity: entity-name + project: project-name + group: group-name + save_dir: ${paths.output_dir}/dg + name: ${name} + + diff --git a/configs/experiment/dg/train_dg_erm.yaml b/configs/experiment/dg/train_dg_erm.yaml index 7190c02..61d951d 100644 --- a/configs/experiment/dg/train_dg_erm.yaml +++ b/configs/experiment/dg/train_dg_erm.yaml @@ -4,39 +4,39 @@ # python train.py experiment=example defaults: - - override /data/dg: diagvib_multienv.yaml - - override /model/dg: erm.yaml + - override /data/dg@data: wilds_multienv.yaml + - override /model/dg@model: erm.yaml - override /trainer: ddp.yaml # all parameters below will be merged with parameters from default configurations set above # this allows you to overwrite only specified parameters name: ??? -tags: ["diagvib", "erm", "${name}"] +tags: ["wilds", "erm", "${name}"] seed: 123 trainer: min_epochs: 100 max_epochs: 100 - gradient_clip_val: 0.5 + gradient_clip_val: 0.0 model: optimizer: - lr: 0.002 + _target_: torch.optim.SGD + lr: 0.0001 net: - pretrained: true + pretrained: false net: resnet18 + n_classes: ${data.dg.n_classes} + + scheduler: null data: - dg: - datasets_dir: ${paths.data_dir}/dg/dg_datasets/rebuttal/ - collate_fn: - _target_: hydra.utils.get_method # to avoid instantiation of a callable - path: src.data.components.collate_functions.SingleEnv_collate_fn - batch_size: 64 - num_workers: 2 - pin_memory: true + dataset_dir: ${paths.data_dir}/dg/dg_datasets/wilds/ + batch_size: 64 + num_workers: 2 + pin_memory: true logger: diff --git a/configs/experiment/dg/train_dg_erm_old.yaml b/configs/experiment/dg/train_dg_erm_old.yaml new file mode 100644 index 0000000..da57302 --- /dev/null +++ b/configs/experiment/dg/train_dg_erm_old.yaml @@ -0,0 +1,53 @@ +# @package _global_ + +# to execute this experiment run: +# python train.py experiment=example + +defaults: + - override /data/dg: diagvib_multienv.yaml + - override /model/dg: erm.yaml + - override /trainer: ddp.yaml + +# all parameters below will be merged with parameters from default configurations set above +# this allows you to overwrite only specified parameters + +name: ??? +tags: ["diagvib", "erm", "${name}"] + +seed: 123 + +trainer: + min_epochs: 100 + max_epochs: 100 + gradient_clip_val: 0.5 + +model: + dg: + optimizer: + _target_: torch.optim.SGD + lr: 0.002 + net: + pretrained: false + net: resnet18 + +data: + dg: + datasets_dir: ${paths.data_dir}/dg/dg_datasets/test_data_pipeline/ + collate_fn: + _target_: hydra.utils.get_method # to avoid instantiation of a callable + path: src.data.components.collate_functions.SingleEnv_collate_fn + batch_size: 4 + num_workers: 0 + pin_memory: true + + +logger: + wandb: + tags: ${tags} + entity: entity-name + project: project-name + group: group-name + save_dir: ${paths.output_dir}/dg + name: ${name} + + diff --git a/configs/experiment/dg/train_dg_irm.yaml b/configs/experiment/dg/train_dg_irm.yaml index 25a9dcb..208a272 100644 --- a/configs/experiment/dg/train_dg_irm.yaml +++ b/configs/experiment/dg/train_dg_irm.yaml @@ -4,8 +4,8 @@ # python train.py experiment=example defaults: - - override /data/dg: diagvib_multienv.yaml - - override /model/dg: irm.yaml + - override /data/dg@data: wilds_multienv.yaml + - override /model/dg@model: irm.yaml - override /trainer: ddp.yaml @@ -13,32 +13,34 @@ defaults: # this allows you to overwrite only specified parameters name: ??? -tags: ["diagvib", "irm", "${name}"] +tags: ["wilds", "irm", "${name}"] seed: 123 trainer: min_epochs: 100 max_epochs: 100 - gradient_clip_val: 0.5 + gradient_clip_val: 0.0 model: optimizer: - lr: 0.002 + _target_: torch.optim.SGD + lr: 0.0001 net: - pretrained: true + pretrained: false net: resnet18 + n_classes: ${data.dg.n_classes} data: - dg: - datasets_dir: ${paths.data_dir}/dg/dg_datasets/rebuttal/ - collate_fn: - _target_: hydra.utils.get_method # to avoid instantiation of a callable - path: src.data.components.collate_functions.MultiEnv_collate_fn - batch_size: 64 - num_workers: 2 - pin_memory: true - # root_dir: '/cluster/project/jbuhmann/posterior_agreement/adv_pa/data/diagvib/domain_shift' # data_dir is specified in config.yaml + # datasets_dir: ${paths.data_dir}/dg/dg_datasets/rebuttal/ + dataset_dir: ${paths.data_dir}/dg/dg_datasets/wilds/ + # collate_fn: + # _target_: hydra.utils.get_method # to avoid instantiation of a callable + # path: src.data.components.collate_functions.MultiEnv_collate_fn + batch_size: 64 + num_workers: 2 + pin_memory: true + # root_dir: '/cluster/project/jbuhmann/posterior_agreement/adv_pa/data/diagvib/domain_shift' # data_dir is specified in config.yaml logger: diff --git a/configs/experiment/dg/train_dg_lisa.yaml b/configs/experiment/dg/train_dg_lisa.yaml index f75ab78..325902b 100644 --- a/configs/experiment/dg/train_dg_lisa.yaml +++ b/configs/experiment/dg/train_dg_lisa.yaml @@ -4,8 +4,8 @@ # python train.py experiment=example defaults: - - override /data/dg: diagvib_multienv.yaml - - override /model/dg: lisa.yaml + - override /data/dg@data: wilds_multienv.yaml + - override /model/dg@model: lisa.yaml - override /trainer: ddp.yaml @@ -13,32 +13,37 @@ defaults: # this allows you to overwrite only specified parameters name: ??? -tags: ["diagvib", "lisa", "${name}"] +tags: ["wilds", "lisa", "${name}"] seed: 123 trainer: min_epochs: 100 max_epochs: 100 - gradient_clip_val: 0.5 + gradient_clip_val: 0.0 model: + mixup_strategy: cutmix optimizer: - lr: 0.002 + _target_: torch.optim.SGD + lr: 0.0001 net: - pretrained: true + pretrained: false net: resnet18 + n_classes: ${data.dg.n_classes} + + scheduler: null data: - dg: - datasets_dir: ${paths.data_dir}/dg/dg_datasets/rebuttal/ - collate_fn: - _target_: hydra.utils.get_method # to avoid instantiation of a callable - path: src.data.components.collate_functions.MultiEnv_collate_fn - batch_size: 64 - num_workers: 2 - pin_memory: true - # root_dir: '/cluster/project/jbuhmann/posterior_agreement/adv_pa/data/diagvib/domain_shift' # data_dir is specified in config.yaml + # datasets_dir: ${paths.data_dir}/dg/dg_datasets/rebuttal/ + dataset_dir: ${paths.data_dir}/dg/dg_datasets/wilds/ + # collate_fn: + # _target_: hydra.utils.get_method # to avoid instantiation of a callable + # path: src.data.components.collate_functions.MultiEnv_collate_fn + batch_size: 64 + num_workers: 2 + pin_memory: true + # root_dir: '/cluster/project/jbuhmann/posterior_agreement/adv_pa/data/diagvib/domain_shift' # data_dir is specified in config.yaml logger: diff --git a/configs/experiment/dg/wilds/camelyon17_erm.yaml b/configs/experiment/dg/wilds/camelyon17_erm.yaml new file mode 100644 index 0000000..e12182a --- /dev/null +++ b/configs/experiment/dg/wilds/camelyon17_erm.yaml @@ -0,0 +1,36 @@ +# @package _global_ + +# Source: WILDS paper & code +# https://arxiv.org/pdf/2012.07421.pdf +# https://github.com/p-lambda/wilds/blob/472677590de351857197a9bf24958838c39c272b/examples/configs/datasets.py + +defaults: + - /experiment/dg/wilds/wilds.yaml + - /model/dg@model: erm.yaml + +# Batch size = 32 +trainer: + accumulate_grad_batches: 2 + max_epochs: 10 + +model: + optimizer: + _target_: torch.optim.SGD + lr: 0.001 + weight_decay: 0.01 + momentum: 0.9 + net: + pretrained: true + net: densenet121 + n_classes: ${data.n_classes} + + scheduler: null + +tags: + - "dg" + - "wilds" + - "${data.dataset_name}" + - "erm" + - "${model.net.net}" + - "${name_logger}" + - "${classname: ${model.optimizer._target_}}" \ No newline at end of file diff --git a/configs/experiment/dg/wilds/camelyon17_irm.yaml b/configs/experiment/dg/wilds/camelyon17_irm.yaml new file mode 100644 index 0000000..bd21fdf --- /dev/null +++ b/configs/experiment/dg/wilds/camelyon17_irm.yaml @@ -0,0 +1,31 @@ +# @package _global_ + +# Source: WILDS paper & code +# https://arxiv.org/pdf/2012.07421.pdf +# https://github.com/p-lambda/wilds/blob/472677590de351857197a9bf24958838c39c272b/examples/configs/datasets.py + +defaults: + - /experiment/dg/wilds/wilds.yaml + - /model/dg@model: irm.yaml + +# Batch size is 32, we will use 16 with 2 gradient accumulation steps +trainer: + accumulate_grad_batches: 2 + max_epochs: 10 + +model: + lamb: 1.0 + + optimizer: + _target_: torch.optim.SGD + lr: 0.001 + weight_decay: 0.01 + momentum: 0.9 + net: + pretrained: true + net: densenet121 + n_classes: ${data.n_classes} + + scheduler: null + +tags: ["dg", "wilds", "${data.dataset_name}", "irm", "${model.net.net}", "${name_logger}"] \ No newline at end of file diff --git a/configs/experiment/dg/wilds/camelyon17_lisa.yaml b/configs/experiment/dg/wilds/camelyon17_lisa.yaml new file mode 100644 index 0000000..ec32b2c --- /dev/null +++ b/configs/experiment/dg/wilds/camelyon17_lisa.yaml @@ -0,0 +1,33 @@ +# @package _global_ + +# SOURCE: LISA paper & code +# https://arxiv.org/pdf/2201.00299.pdf +# https://github.com/huaxiuyao/LISA/blob/c857a6b296b5d130898f0d51a6d693411c39e651/domain_shifts/config.py + +defaults: + - /experiment/dg/wilds/wilds.yaml + - override /model/dg@model: lisa.yaml + +# Batch size is 32, we will use 16 with 2 gradient accumulation steps +trainer: + max_epochs: 1 + accumulate_grad_batches: 2 + +model: + ppred: 1.0 + mix_alpha: 2.0 + mixup_strategy: cutmix + + optimizer: + _target_: torch.optim.SGD + lr: 0.0001 + weight_decay: 0.0 + momentum: 0.9 + net: + pretrained: true + net: densenet121 + n_classes: ${data.n_classes} + + scheduler: null + +tags: ["dg", "wilds", "${data.dataset_name}", "lisa", "${model.net.net}", "${name_logger}"] \ No newline at end of file diff --git a/configs/experiment/dg/wilds/celebA_erm.yaml b/configs/experiment/dg/wilds/celebA_erm.yaml new file mode 100644 index 0000000..7da68ee --- /dev/null +++ b/configs/experiment/dg/wilds/celebA_erm.yaml @@ -0,0 +1,29 @@ +# @package _global_ + +# Source: LISA paper & WILDS code +# https://arxiv.org/pdf/2201.00299.pdf +# https://github.com/p-lambda/wilds/blob/472677590de351857197a9bf24958838c39c272b/examples/configs/datasets.py + +defaults: + - /experiment/dg/wilds/wilds.yaml + - override /model/dg@model: erm.yaml + +# The original batch size is 64, we will use 16 with 4 gradient accumulation steps +trainer: + max_epochs: 200 + accumulate_grad_batches: 4 + +model: + optimizer: + _target_: torch.optim.SGD + lr: 0.001 + weight_decay: 0.0 + momentum: 0.9 + net: + pretrained: true + net: resnet50 + n_classes: ${data.n_classes} + + scheduler: null + +tags: ["dg", "wilds", "${data.dataset_name}", "erm", "${model.net.net}", "${name_logger}"] \ No newline at end of file diff --git a/configs/experiment/dg/wilds/celebA_irm.yaml b/configs/experiment/dg/wilds/celebA_irm.yaml new file mode 100644 index 0000000..70fd668 --- /dev/null +++ b/configs/experiment/dg/wilds/celebA_irm.yaml @@ -0,0 +1,30 @@ +# @package _global_ + +# Source: LISA paper & WILDS code +# https://arxiv.org/pdf/2201.00299.pdf +# https://github.com/p-lambda/wilds/blob/472677590de351857197a9bf24958838c39c272b/examples/configs/datasets.py + +defaults: + - /experiment/dg/wilds/wilds.yaml + - override /model/dg@model: irm.yaml + +# The original batch size is 64, we will use 16 with 4 gradient accumulation steps +trainer: + max_epochs: 200 + accumulate_grad_batches: 4 + +model: + lamb: 1.0 # hyperparameter search needed + optimizer: + _target_: torch.optim.SGD + lr: 0.001 + weight_decay: 0.0 + momentum: 0.9 + net: + pretrained: true + net: resnet50 + n_classes: ${data.n_classes} + + scheduler: null + +tags: ["dg", "wilds", "${data.dataset_name}", "irm", "${model.net.net}", "${name_logger}"] \ No newline at end of file diff --git a/configs/experiment/dg/wilds/celebA_lisa.yaml b/configs/experiment/dg/wilds/celebA_lisa.yaml new file mode 100644 index 0000000..3a1d576 --- /dev/null +++ b/configs/experiment/dg/wilds/celebA_lisa.yaml @@ -0,0 +1,33 @@ +# @package _global_ + +# SOURCE: LISA paper & code +# https://arxiv.org/pdf/2201.00299.pdf +# https://github.com/huaxiuyao/LISA/blob/c857a6b296b5d130898f0d51a6d693411c39e651/domain_shifts/config.py + +defaults: + - /experiment/dg/wilds/wilds.yaml + - override /model/dg@model: lisa.yaml + +# The original batch size is 16 +trainer: + max_epochs: 50 + accumulate_grad_batches: null + +model: + ppred: 0.5 + mix_alpha: 2.0 + mixup_strategy: cutmix + + optimizer: + _target_: torch.optim.SGD + lr: 0.0001 + weight_decay: 0.0001 + momentum: 0.0 + net: + pretrained: true + net: resnet50 + n_classes: ${data.n_classes} + + scheduler: null + +tags: ["dg", "wilds", "${data.dataset_name}", "lisa", "${model.net.net}", "${name_logger}"] \ No newline at end of file diff --git a/configs/experiment/dg/wilds/fmow_erm.yaml b/configs/experiment/dg/wilds/fmow_erm.yaml new file mode 100644 index 0000000..b5d4ce5 --- /dev/null +++ b/configs/experiment/dg/wilds/fmow_erm.yaml @@ -0,0 +1,44 @@ +# @package _global_ + +# Source: WILDS paper & code +# https://arxiv.org/pdf/2012.07421.pdf +# https://github.com/p-lambda/wilds/blob/472677590de351857197a9bf24958838c39c272b/examples/configs/datasets.py + +defaults: + - /experiment/dg/wilds/wilds.yaml + - override /model/dg@model: erm.yaml + +# The original batch size is 32, we will use 16 with 2 gradient accumulation steps +trainer: + max_epochs: 60 + accumulate_grad_batches: 2 + +model: + optimizer: + _target_: torch.optim.Adam + lr: 0.0001 + weight_decay: 0.0 + momentum: 0.0 + net: + pretrained: true + net: densenet121 + n_classes: ${data.n_classes} + + scheduler: + interval: epoch + frequency: 1 + + scheduler: + _target_: torch.optim.lr_scheduler.StepLR + _partial_: true + gamma: 0.96 + +tags: + - "dg" + - "wilds" + - "${data.dataset_name}" + - "erm" + - "${model.net.net}" + - "${name_logger}" + - "${classname: ${model.optimizer._target_}}" + - "${classname: ${model.scheduler.scheduler._target_}}" \ No newline at end of file diff --git a/configs/experiment/dg/wilds/fmow_irm.yaml b/configs/experiment/dg/wilds/fmow_irm.yaml new file mode 100644 index 0000000..2ddd4c1 --- /dev/null +++ b/configs/experiment/dg/wilds/fmow_irm.yaml @@ -0,0 +1,40 @@ +# @package _global_ + +# Source: WILDS paper & code +# https://arxiv.org/pdf/2012.07421.pdf +# https://github.com/p-lambda/wilds/blob/472677590de351857197a9bf24958838c39c272b/examples/configs/datasets.py + +defaults: + - /experiment/dg/wilds/wilds.yaml + - override /model/dg@model: irm.yaml + +# The original batch size is 32, we will use 16 with 2 gradient accumulation steps +trainer: + max_epochs: 60 + accumulate_grad_batches: 2 + +model: + lamb: 1.0 + + optimizer: + _target_: torch.optim.Adam + lr: 0.0001 + weight_decay: 0.0 + momentum: 0.0 + net: + pretrained: true + net: densenet121 + n_classes: ${data.n_classes} + + scheduler: + interval: epoch + frequency: 1 + + scheduler: + _target_: torch.optim.lr_scheduler.StepLR + _partial_: true + gamma: 0.96 + + + +tags: ["dg", "wilds", "${data.dataset_name}", "irm", "${model.net.net}", "${name_logger}"] \ No newline at end of file diff --git a/configs/experiment/dg/wilds/fmow_lisa.yaml b/configs/experiment/dg/wilds/fmow_lisa.yaml new file mode 100644 index 0000000..9c6ea26 --- /dev/null +++ b/configs/experiment/dg/wilds/fmow_lisa.yaml @@ -0,0 +1,34 @@ +# @package _global_ + +# SOURCE: LISA paper & code +# https://arxiv.org/pdf/2201.00299.pdf +# https://github.com/huaxiuyao/LISA/blob/c857a6b296b5d130898f0d51a6d693411c39e651/domain_shifts/config.py + +defaults: + - /experiment/dg/wilds/wilds.yaml + - override /model/dg@model: lisa.yaml + +# The original batch size is 32, we will use 16 with 2 gradient accumulation steps +trainer: + max_epochs: 5 + accumulate_grad_batches: 2 + +model: + ppred: 1.0 + mix_alpha: 2.0 + mixup_strategy: cutmix + + optimizer: + _target_: torch.optim.Adam + lr: 0.0001 + weight_decay: 0.0 + amsgrad: true + + net: + pretrained: true + net: densenet121 + n_classes: ${data.n_classes} + + scheduler: null + +tags: ["dg", "wilds", "${data.dataset_name}", "lisa", "${model.net.net}", "${name_logger}"] \ No newline at end of file diff --git a/configs/experiment/dg/wilds/rxrx1_erm.yaml b/configs/experiment/dg/wilds/rxrx1_erm.yaml new file mode 100644 index 0000000..4278d12 --- /dev/null +++ b/configs/experiment/dg/wilds/rxrx1_erm.yaml @@ -0,0 +1,39 @@ +# @package _global_ + +# Source: WILDS paper & code +# https://arxiv.org/pdf/2012.07421.pdf +# https://github.com/p-lambda/wilds/blob/472677590de351857197a9bf24958838c39c272b/examples/configs/datasets.py + +defaults: + - /experiment/dg/wilds/wilds.yaml + - override /model/dg@model: erm.yaml + +# The original batch size is 72, we will use 18 with 4 gradient accumulation steps +data: + batch_size: 18 +trainer: + max_epochs: 90 + accumulate_grad_batches: 4 + +model: + optimizer: + _target_: torch.optim.Adam + lr: 0.001 + weight_decay: 0.00001 + momentum: 0.0 + net: + pretrained: true + net: resnet50 + n_classes: ${data.n_classes} + + scheduler: + interval: step + frequency: 1 + + scheduler: + _target_: transformers.get_cosine_schedule_with_warmup + _partial_: true + num_warmup_steps: ${eval:'${trainer.accumulate_grad_batches} * 5415'} # 5415 is the number of warmup steps in the original code + num_training_steps: ${eval:'0000 * ${trainer.max_epochs}'} # number_steps * number_of_epochs + +tags: ["dg", "wilds", "${data.dataset_name}", "erm", "${model.net.net}", "${name_logger}"] \ No newline at end of file diff --git a/configs/experiment/dg/wilds/rxrx1_irm.yaml b/configs/experiment/dg/wilds/rxrx1_irm.yaml new file mode 100644 index 0000000..c867bb7 --- /dev/null +++ b/configs/experiment/dg/wilds/rxrx1_irm.yaml @@ -0,0 +1,41 @@ +# @package _global_ + +# Source: WILDS paper & code +# https://arxiv.org/pdf/2012.07421.pdf +# https://github.com/p-lambda/wilds/blob/472677590de351857197a9bf24958838c39c272b/examples/configs/datasets.py + +defaults: + - /experiment/dg/wilds/wilds.yaml + - override /model/dg@model: irm.yaml + +# The original batch size is 72, we will use 18 with 4 gradient accumulation steps +data: + batch_size: 18 +trainer: + max_epochs: 90 + accumulate_grad_batches: 4 + +model: + lamb: 1.0 + + optimizer: + _target_: torch.optim.Adam + lr: 0.001 + weight_decay: 0.00001 + momentum: 0.0 + net: + pretrained: true + net: resnet50 + n_classes: ${data.n_classes} + + scheduler: + interval: step + frequency: 1 + + scheduler: + _target_: transformers.get_cosine_schedule_with_warmup + _partial_: true + num_warmup_steps: ${eval:'${trainer.accumulate_grad_batches} * 5415'} # 5415 is the number of warmup steps in the original code + num_training_steps: ${eval:'0000 * ${trainer.max_epochs}'} # number_steps * number_of_epochs + +tags: ["dg", "wilds", "${data.dataset_name}", "irm", "${model.net.net}", "${name_logger}"] \ No newline at end of file diff --git a/configs/experiment/dg/wilds/rxrx1_lisa.yaml b/configs/experiment/dg/wilds/rxrx1_lisa.yaml new file mode 100644 index 0000000..53f02de --- /dev/null +++ b/configs/experiment/dg/wilds/rxrx1_lisa.yaml @@ -0,0 +1,44 @@ +# @package _global_ + +# SOURCE: LISA paper & code +# https://arxiv.org/pdf/2201.00299.pdf +# https://github.com/huaxiuyao/LISA/blob/c857a6b296b5d130898f0d51a6d693411c39e651/domain_shifts/config.py + +defaults: + - /experiment/dg/wilds/wilds.yaml + - override /model/dg@model: lisa.yaml + +# The original batch size is 72, we will use 18 with 4 gradient accumulation steps +data: + batch_size: 18 +trainer: + max_epochs: 90 + accumulate_grad_batches: 4 + +model: + ppred: 1.0 + mix_alpha: 2.0 + mixup_strategy: cutmix + + optimizer: + _target_: torch.optim.Adam + lr: 0.001 + weight_decay: 0.00001 + amsgrad: true + + net: + pretrained: true + net: resnet50 + n_classes: ${data.n_classes} + + scheduler: + interval: step + frequency: 1 + + scheduler: + _target_: transformers.get_cosine_schedule_with_warmup + _partial_: true + num_warmup_steps: ${eval:'${trainer.accumulate_grad_batches} * 5415'} # 5415 is the number of warmup steps in the original code + num_training_steps: ${eval:'0000 * ${trainer.max_epochs}'} # number_steps * number_of_epochs + +tags: ["dg", "wilds", "${data.dataset_name}", "lisa", "${model.net.net}", "${name_logger}"] \ No newline at end of file diff --git a/configs/experiment/dg/wilds/waterbirds_erm.yaml b/configs/experiment/dg/wilds/waterbirds_erm.yaml new file mode 100644 index 0000000..3d80132 --- /dev/null +++ b/configs/experiment/dg/wilds/waterbirds_erm.yaml @@ -0,0 +1,29 @@ +# @package _global_ + +# Source: LISA paper & WILDS code +# https://arxiv.org/pdf/2201.00299.pdf +# https://github.com/p-lambda/wilds/blob/472677590de351857197a9bf24958838c39c272b/examples/configs/datasets.py + +defaults: + - /experiment/dg/wilds/wilds.yaml + - override /model/dg@model: erm.yaml + +# The original batch size is 128, we will use 16 with 8 gradient accumulation steps +trainer: + max_epochs: 300 + accumulate_grad_batches: 8 + +model: + optimizer: + _target_: torch.optim.SGD + lr: 0.00001 + weight_decay: 1.0 + momentum: 0.9 + net: + pretrained: true + net: resnet50 + n_classes: ${data.n_classes} + + scheduler: null + +tags: ["dg", "wilds", "${data.dataset_name}", "erm", "${model.net.net}", "${name_logger}"] \ No newline at end of file diff --git a/configs/experiment/dg/wilds/waterbirds_irm.yaml b/configs/experiment/dg/wilds/waterbirds_irm.yaml new file mode 100644 index 0000000..f4d2739 --- /dev/null +++ b/configs/experiment/dg/wilds/waterbirds_irm.yaml @@ -0,0 +1,29 @@ +# @package _global_ + +# Source: LISA paper & WILDS code +# https://arxiv.org/pdf/2201.00299.pdf +# https://github.com/p-lambda/wilds/blob/472677590de351857197a9bf24958838c39c272b/examples/configs/datasets.py + +defaults: + - /experiment/dg/wilds/wilds.yaml + - override /model/dg@model: irm.yaml + +# The original batch size is 128, we will use 16 with 8 gradient accumulation steps +trainer: + max_epochs: 300 + accumulate_grad_batches: 8 + +model: + optimizer: + _target_: torch.optim.SGD + lr: 0.00001 + weight_decay: 1.0 + momentum: 0.9 + net: + pretrained: true + net: resnet50 + n_classes: ${data.n_classes} + + scheduler: null + +tags: ["dg", "wilds", "${data.dataset_name}", "irm", "${model.net.net}", "${name_logger}"] \ No newline at end of file diff --git a/configs/experiment/dg/wilds/waterbirds_lisa.yaml b/configs/experiment/dg/wilds/waterbirds_lisa.yaml new file mode 100644 index 0000000..236c710 --- /dev/null +++ b/configs/experiment/dg/wilds/waterbirds_lisa.yaml @@ -0,0 +1,33 @@ +# @package _global_ + +# SOURCE: LISA paper & code +# https://arxiv.org/pdf/2201.00299.pdf +# https://github.com/huaxiuyao/LISA/blob/c857a6b296b5d130898f0d51a6d693411c39e651/domain_shifts/config.py + +defaults: + - /experiment/dg/wilds/wilds.yaml + - override /model/dg@model: lisa.yaml + +# The original batch size is 16 +trainer: + max_epochs: 300 + accumulate_grad_batches: 0 + +model: + ppred: 0.5 + mix_alpha: 2.0 + mixup_strategy: mixup + + optimizer: + _target_: torch.optim.SGD + lr: 0.001 + weight_decay: 0.0001 + momentum: 0.0 + net: + pretrained: true + net: resnet50 + n_classes: ${data.n_classes} + + scheduler: null + +tags: ["dg", "wilds", "${data.dataset_name}", "lisa", "${model.net.net}", "${name_logger}"] \ No newline at end of file diff --git a/configs/experiment/dg/wilds/wilds.yaml b/configs/experiment/dg/wilds/wilds.yaml new file mode 100644 index 0000000..5b70e3e --- /dev/null +++ b/configs/experiment/dg/wilds/wilds.yaml @@ -0,0 +1,36 @@ +# @package _global_ + +defaults: + - /data/dg@data: wilds_multienv.yaml + - /trainer: ddp.yaml + +name_logger: ??? + +trainer: + min_epochs: 1 + max_epochs: 10 + gradient_clip_val: 0.0 + + multiple_trainloader_mode: ${data.multiple_trainloader_mode} + accumulate_grad_batches: 0 + replace_sampler_ddp: true + +data: + dataset_dir: ${paths.data_dir}/dg/dg_datasets/wilds/ + batch_size: 16 + num_workers: 2 + pin_memory: true + multiple_trainloader_mode: max_size_cycle # vs min_size, depends on how to deal with spurious datasets. + +logger: + wandb: + tags: ${tags} + entity: malvai + project: cov_pa + group: group-name + save_dir: ${paths.output_dir}/dg + name: ${name_logger} + +task_name: dg_wilds + + diff --git a/configs/experiment/example.yaml b/configs/experiment/example.yaml deleted file mode 100644 index b68c913..0000000 --- a/configs/experiment/example.yaml +++ /dev/null @@ -1,40 +0,0 @@ -# @package _global_ - -# to execute this experiment run: -# python train.py experiment=example - -defaults: - - override /data: mnist.yaml - - override /model: mnist.yaml - - override /callbacks: default.yaml - - override /trainer: default.yaml - -# all parameters below will be merged with parameters from default configurations set above -# this allows you to overwrite only specified parameters - -tags: ["mnist", "simple_dense_net"] - -seed: 12345 - -trainer: - min_epochs: 10 - max_epochs: 10 - gradient_clip_val: 0.5 - -model: - optimizer: - lr: 0.002 - net: - lin1_size: 128 - lin2_size: 256 - lin3_size: 64 - -data: - batch_size: 64 - -logger: - wandb: - tags: ${tags} - group: "mnist" - aim: - experiment: "mnist" diff --git a/configs/generate_dg_data.yaml b/configs/generate_dg_data.yaml index ea15088..4c971b4 100644 --- a/configs/generate_dg_data.yaml +++ b/configs/generate_dg_data.yaml @@ -9,13 +9,18 @@ extras: ignore_warnings: True BATCH_SIZE: 64 -DATASETS_DIR: ${paths.data_dir}/dg/dg_datasets/${envs_name}/ +DATASETS_DIR: ${paths.data_dir}/dg/dg_datasets/${folder_name}/ # Desired sizes of the datasets # If the datasets are meant to have unique observations, do not surpass the sizes (4500, 1000, 900). -SIZE_TRAIN: 4500 -SIZE_VAL: 1000 -SIZE_TEST: 900 +SIZE_TRAIN: 128 +SIZE_VAL: 128 +SIZE_TEST: 128 + +# SIZE_TRAIN: 4500 +# SIZE_VAL: 1000 +# SIZE_TEST: 900 + # SIZE_TRAIN: 40000 # SIZE_VAL: 20000 # SIZE_TEST: 10000 @@ -23,12 +28,26 @@ SIZE_TEST: 900 ## FACTORS: hue, lightness, texture, position, scale BALANCE: False train_val_randperm: True # If each training and validation environment is to be shuffled randomly. In such case there will be no label correspondence between the two environments. -envs_name: randnobal +folder_name: test_data_pipeline +file_name: singlevar # 2 train/val environments +# train_val_especs: +# train_val_1: [0,0,2,8,4] +# train_val_2: [4,3,1,6,3] + +# # 6 test environments +# test_especs: +# test_0: [0,0,2,8,4] +# test_1: [2,0,2,8,4] +# test_2: [2,1,2,8,4] +# test_3: [2,1,2,1,4] +# test_4: [2,1,2,1,1] +# test_5: [2,1,0,1,1] + train_val_especs: train_val_1: [0,0,2,8,4] - train_val_2: [4,3,1,6,3] + train_val_2: [4,0,2,8,4] # 6 test environments test_especs: diff --git a/configs/model/dg/erm.yaml b/configs/model/dg/erm.yaml index 425efff..d141403 100644 --- a/configs/model/dg/erm.yaml +++ b/configs/model/dg/erm.yaml @@ -1,10 +1,13 @@ defaults: - net.yaml -_target_: src.models.erm_module.ERMMnist +_target_: src.models.erm.ERM +loss: + _target_: torch.nn.CrossEntropyLoss + +n_classes: ${model.net.n_classes} -n_classes: ${model.dg.net.n_classes} optimizer: _target_: torch.optim.Adam _partial_: true @@ -12,8 +15,15 @@ optimizer: weight_decay: 0.0 scheduler: - _target_: torch.optim.lr_scheduler.ReduceLROnPlateau - _partial_: true - mode: min - factor: 0.1 - patience: 10 + monitor: val/loss + interval: epoch + frequency: 1 # ${trainer.check_val_every_n_epoch} + + scheduler: + _target_: torch.optim.lr_scheduler.ReduceLROnPlateau + _partial_: true + mode: min + factor: 0.1 + patience: 10 + + \ No newline at end of file diff --git a/configs/model/dg/irm.yaml b/configs/model/dg/irm.yaml index 1173352..9fce481 100644 --- a/configs/model/dg/irm.yaml +++ b/configs/model/dg/irm.yaml @@ -1,11 +1,15 @@ defaults: - net.yaml -_target_: src.models.irm_module.IRMMnist +_target_: src.models.irm.IRM + lamb: 1.0 +loss: + _target_: torch.nn.CrossEntropyLoss n_classes: ${model.dg.net.n_classes} + optimizer: _target_: torch.optim.Adam _partial_: true @@ -13,8 +17,13 @@ optimizer: weight_decay: 0.0 scheduler: - _target_: torch.optim.lr_scheduler.ReduceLROnPlateau - _partial_: true - mode: min - factor: 0.1 - patience: 10 \ No newline at end of file + monitor: val/loss + interval: epoch + frequency: 1 # ${trainer.check_val_every_n_epoch} + + scheduler: + _target_: torch.optim.lr_scheduler.ReduceLROnPlateau + _partial_: true + mode: min + factor: 0.1 + patience: 10 \ No newline at end of file diff --git a/configs/model/dg/lisa.yaml b/configs/model/dg/lisa.yaml index 88bdb87..6ba3e36 100644 --- a/configs/model/dg/lisa.yaml +++ b/configs/model/dg/lisa.yaml @@ -2,11 +2,17 @@ defaults: - net.yaml -_target_: src.models.lisa_module.LISAMnist +_target_: src.models.lisa.LISA + +mixup_strategy: "mixup" ppred: 1.0 # probability of LISA-L mix_alpha: 0.4 # mixup weight +loss: + _target_: torch.nn.CrossEntropyLoss + n_classes: ${model.dg.net.n_classes} + optimizer: _target_: torch.optim.Adam _partial_: true @@ -14,8 +20,13 @@ optimizer: weight_decay: 0.0 scheduler: - _target_: torch.optim.lr_scheduler.ReduceLROnPlateau - _partial_: true - mode: min - factor: 0.1 - patience: 10 \ No newline at end of file + monitor: val/loss + interval: epoch + frequency: 1 # ${trainer.check_val_every_n_epoch} + + scheduler: + _target_: torch.optim.lr_scheduler.ReduceLROnPlateau + _partial_: true + mode: min + factor: 0.1 + patience: 10 \ No newline at end of file diff --git a/configs/train_dg_pa.yaml b/configs/test.yaml similarity index 63% rename from configs/train_dg_pa.yaml rename to configs/test.yaml index 909ce5e..b14e981 100644 --- a/configs/train_dg_pa.yaml +++ b/configs/test.yaml @@ -1,13 +1,9 @@ # @package _global_ -# specify here default configuration -# order of defaults determines the order in which configs override each other defaults: - _self_ - - data/dg: diagvib_multienv.yaml - - model/dg: optimize.yaml + - logger: none.yaml # TODO: see if passing null from CLI also works - callbacks: default.yaml - - logger: null # set logger here or use command line (e.g. `python train.py logger=tensorboard`) - trainer: default.yaml - paths: default.yaml - extras: default.yaml @@ -27,27 +23,22 @@ defaults: # debugging config (enable through command line, e.g. `python train.py debug=default) - debug: null +data: ??? +model: ??? + +trainer: + limit_test_batches: 1.0 + # task name, determines output directory path -task_name: "train_pa_dg" +task_name: ??? # tags to help you identify your experiments # you can overwrite this in experiment configs # overwrite from command line with `python train.py tags="[first_tag, second_tag]"` -tags: ["dev"] - -# set False to skip model training -train: True - -# evaluate on test set, using best model weights achieved during training -# lightning chooses best weights based on the metric specified in checkpoint callback -# test: True +tags: [test] # compile model for faster training with pytorch 2.0 compile: False -# simply provide checkpoint path to resume training -ckpt_path: null - # seed for random number generators in pytorch, numpy and python.random -seed: null - +seed: 123 \ No newline at end of file diff --git a/configs/test_pa.yaml b/configs/test_pa.yaml new file mode 100644 index 0000000..d679599 --- /dev/null +++ b/configs/test_pa.yaml @@ -0,0 +1,34 @@ +# @package _global_ + +defaults: + - _self_ + - paths: default.yaml + + # Tests on the DataModules + # - tests@data: /data_pipeline/test_sampler.yaml # passed cpu + # - tests@data: /data_pipeline/test_dataloaders_cifar10.yaml # passed cpu + # - tests@data: /data_pipeline/test_dataloaders_diagvib.yaml # passed cpu + + # Tests on the Trainers for DDP + # - tests@ddp: /ddp/test_ddp.yaml + + # Tests on the PA module + # - tests@pa_module: /pa_module/test_pa_module.yaml # passed cpu + + # Tests on the PA metric + # - tests@pa_metric: /pa_metric/test_basemetric.yaml # passed cpu + # - tests@pa_metric: /pa_metric/test_pametric_cpu.yaml # passed cpu + # - tests@pa_metric: /pa_metric/test_pametric_ddp.yaml + # - tests@pa_metric: /pa_metric/test_pametric_logits.yaml # passed cpu + # - tests@pa_metric: /pa_metric/test_accuracymetric.yaml # STILL NOT TESTED + + # Tests on the PA callback + - tests@pa_callback: /pa_callback/test_callback.yaml # passed cpu + + + +paths: + results_tests: ${paths.root_dir}/results/tests/ + +seed: 1234 + diff --git a/configs/tests/data_pipeline/test_dataloaders_cifar10.yaml b/configs/tests/data_pipeline/test_dataloaders_cifar10.yaml new file mode 100644 index 0000000..83adf8f --- /dev/null +++ b/configs/tests/data_pipeline/test_dataloaders_cifar10.yaml @@ -0,0 +1,54 @@ +# @package _global_ + +defaults: + # Data source + - /data/adv@datamodules.main: cifar10.yaml + - /data/adv@datamodules.pa: cifar10.yaml + - /data/adv@datamodules.pa_logits: cifar10.yaml + + # Attack + - override /data/adv/attack@datamodules.main.attack: FMN.yaml + - override /data/adv/attack@datamodules.pa.attack: FMN.yaml + - override /data/adv/attack@datamodules.pa_logits.attack: FMN.yaml + + # Defense + - override /model/adv/classifier@datamodules.main.classifier: bpda.yaml + - override /model/adv/classifier@datamodules.pa.classifier: bpda.yaml + - override /model/adv/classifier@datamodules.pa_logits.classifier: bpda.yaml # logits uses model twice + +expected_results: + main: + corresponding_labels: True # whether main dataset has corresponding labels + same_model_logits: True # wether main dataset is generated using the same classifier as the one used for logits + pa: + same_model_logits: True + +datamodules: + data_name: cifar10 + main: + cache: True # IMPORTANT + batch_size: 1000 + num_workers: 40 + pin_memory: True + + attack: + steps: 1000 + + pa: + _target_: src.data.cifar10_datamodules.CIFAR10DataModulePA + cache: ${data.datamodules.main.cache} + batch_size: ${data.datamodules.main.batch_size} + num_workers: ${data.datamodules.main.num_workers} + pin_memory: ${data.datamodules.main.pin_memory} + attack: + steps: ${data.datamodules.main.attack.steps} + + pa_logits: + _target_: src.data.cifar10_datamodules.CIFAR10DataModulePAlogits + cache: ${data.datamodules.main.cache} + batch_size: ${data.datamodules.main.batch_size} + num_workers: ${data.datamodules.main.num_workers} + pin_memory: ${data.datamodules.main.pin_memory} + attack: + steps: ${data.datamodules.main.attack.steps} + \ No newline at end of file diff --git a/configs/tests/data_pipeline/test_dataloaders_diagvib.yaml b/configs/tests/data_pipeline/test_dataloaders_diagvib.yaml new file mode 100644 index 0000000..b606c4e --- /dev/null +++ b/configs/tests/data_pipeline/test_dataloaders_diagvib.yaml @@ -0,0 +1,70 @@ +# @package _global_ + +defaults: + # Data source + - /data/dg@datamodules.main: diagvib_multienv.yaml + - /data/dg@datamodules.pa: diagvib_PA.yaml + - /data/dg@datamodules.pa_logits: diagvib_PA.yaml # !! needs classifier + + # Classifier for PA_logits. I write confis here to be easily accessible + - /model/dg/net@datamodules.pa_logits.classifier +exp_name: lisa_rebuttalL +shift_ratio: 1.0 + +expected_results: + main: + corresponding_labels: False # whether main dataset has corresponding labels + + +datamodules: + data_name: diagvib + main: + envs_index: [0,1] + envs_name: singlevar + datasets_dir: ${paths.data_dir}/dg/dg_datasets/test_data_pipeline/ + disjoint_envs: True + train_val_sequential: True + + collate_fn: + _target_: hydra.utils.get_method + path: src.data.components.collate_functions.MultiEnv_collate_fn + batch_size: 64 # Same as the one that was passed to the model. + num_workers: 2 + pin_memory: True + + + pa: + _target_: src.data.diagvib_datamodules.DiagVibDataModulePA + shift_ratio: ${data.shift_ratio} + + envs_index: ${data.datamodules.main.envs_index} + envs_name: train_${data.datamodules.main.envs_name} # remember it is necessary to write 'train', 'test' or 'val' + datasets_dir: ${data.datamodules.main.datasets_dir} + disjoint_envs: ${data.datamodules.main.disjoint_envs} + train_val_sequential: ${data.datamodules.main.train_val_sequential} + + collate_fn: ${data.datamodules.main.collate_fn} + batch_size: ${data.datamodules.main.batch_size} + num_workers: ${data.datamodules.main.num_workers} + pin_memory: ${data.datamodules.main.pin_memory} + + + pa_logits: + _target_: src.data.diagvib_datamodules.DiagVibDataModulePAlogits + shift_ratio: ${data.shift_ratio} + classifier: + _target_: src.models.components.dg_backbone.get_lm_model + exp_name: ${data.exp_name} + log_dir: ${paths.log_dir} + + envs_index: ${data.datamodules.main.envs_index} + envs_name: train_${data.datamodules.main.envs_name} # remember it is necessary to write 'train', 'test' or 'val' + datasets_dir: ${data.datamodules.main.datasets_dir} + disjoint_envs: ${data.datamodules.main.disjoint_envs} + train_val_sequential: ${data.datamodules.main.train_val_sequential} + + collate_fn: ${data.datamodules.main.collate_fn} + batch_size: ${data.datamodules.main.batch_size} + num_workers: ${data.datamodules.main.num_workers} + pin_memory: ${data.datamodules.main.pin_memory} + \ No newline at end of file diff --git a/configs/tests/data_pipeline/test_sampler.yaml b/configs/tests/data_pipeline/test_sampler.yaml new file mode 100644 index 0000000..b9c0692 --- /dev/null +++ b/configs/tests/data_pipeline/test_sampler.yaml @@ -0,0 +1,15 @@ +# @package _global_ + +defaults: + - /data/adv@datamodule: cifar10.yaml + - override /model/adv/classifier@datamodule.classifier: bpda.yaml + - override /data/adv/attack@datamodule.attack: FMN.yaml + +datamodule: + cache: True # IMPORTANT + batch_size: 1000 + num_workers: 40 + pin_memory: True + + attack: + steps: 1000 \ No newline at end of file diff --git a/configs/tests/ddp/test_ddp.yaml b/configs/tests/ddp/test_ddp.yaml new file mode 100644 index 0000000..3f5f670 --- /dev/null +++ b/configs/tests/ddp/test_ddp.yaml @@ -0,0 +1,79 @@ +# @package _global_ + +defaults: + # Data source + - /data/dg@datamodules.main: diagvib_multienv.yaml + - /data/dg@datamodules.pa: diagvib_PA.yaml + - /data/dg@datamodules.pa_logits: diagvib_PA.yaml # !! needs classifier + + # Trainers + - /trainer@trainer.cpu: cpu.yaml + - /trainer@trainer.ddp: ddp.yaml + + # Classifier for PA_logits. I write confis here to be easily accessible + - /model/dg/net@datamodules.pa_logits.classifier +exp_name: lisa_rebuttalL +shift_ratio: 1.0 + +expected_results: + main: + corresponding_labels: False # whether main dataset has corresponding labels + +trainer: + cpu: + max_epochs: 10 + ddp: + max_epochs: 10 + +datamodules: + data_name: diagvib + main: + envs_index: [0,1] + envs_name: singlevar + datasets_dir: ${paths.data_dir}/dg/dg_datasets/test_data_pipeline/ + disjoint_envs: True + train_val_sequential: True + + collate_fn: + _target_: hydra.utils.get_method + path: src.data.components.collate_functions.MultiEnv_collate_fn + batch_size: 64 # Same as the one that was passed to the model. + num_workers: 2 + pin_memory: True + + + pa: + _target_: src.data.diagvib_datamodules.DiagVibDataModulePA + shift_ratio: ${ddp.shift_ratio} + + envs_index: ${ddp.datamodules.main.envs_index} + envs_name: train_${ddp.datamodules.main.envs_name} # remember it is necessary to write 'train', 'test' or 'val' + datasets_dir: ${ddp.datamodules.main.datasets_dir} + disjoint_envs: ${ddp.datamodules.main.disjoint_envs} + train_val_sequential: ${ddp.datamodules.main.train_val_sequential} + + collate_fn: ${ddp.datamodules.main.collate_fn} + batch_size: ${ddp.datamodules.main.batch_size} + num_workers: ${ddp.datamodules.main.num_workers} + pin_memory: ${ddp.datamodules.main.pin_memory} + + + pa_logits: + _target_: src.data.diagvib_datamodules.DiagVibDataModulePAlogits + shift_ratio: ${ddp.shift_ratio} + classifier: + _target_: src.models.components.dg_backbone.get_lm_model + exp_name: ${ddp.exp_name} + log_dir: ${paths.log_dir} + + envs_index: ${ddp.datamodules.main.envs_index} + envs_name: train_${ddp.datamodules.main.envs_name} # remember it is necessary to write 'train', 'test' or 'val' + datasets_dir: ${ddp.datamodules.main.datasets_dir} + disjoint_envs: ${ddp.datamodules.main.disjoint_envs} + train_val_sequential: ${ddp.datamodules.main.train_val_sequential} + + collate_fn: ${ddp.datamodules.main.collate_fn} + batch_size: ${ddp.datamodules.main.batch_size} + num_workers: ${ddp.datamodules.main.num_workers} + pin_memory: ${ddp.datamodules.main.pin_memory} + \ No newline at end of file diff --git a/configs/tests/pa_callback/test_callback.yaml b/configs/tests/pa_callback/test_callback.yaml new file mode 100644 index 0000000..d9e8c83 --- /dev/null +++ b/configs/tests/pa_callback/test_callback.yaml @@ -0,0 +1,118 @@ +# @package _global_ + +defaults: + # Datamodules + - /data/dg@datamodules.main: diagvib_multienv.yaml + - /data/dg@datamodules.pa: diagvib_PA.yaml + + # Trainers + - /trainer@trainers.vanilla.cpu: cpu.yaml + - /trainer@trainers.vanilla.ddp: ddp.yaml + - /trainer@trainers.pa_module.cpu: cpu.yaml + - /trainer@trainers.pa_module.ddp: ddp.yaml + + # PA module + - /model/dg@pa_module: optimize.yaml + + # The net for the Vanilla ERM + - /model/dg@vanilla_model.classifier: net.yaml +exp_name: lisa_rebuttalL +shift_ratio: 1.0 + +trainers: + vanilla: + cpu: + _partial_: true # because it needs the callback + min_epochs: 3 + max_epochs: 3 + logger: false + ddp: + _partial_: true # because it needs the callback + min_epochs: 3 + max_epochs: 3 + logger: false + + # To compute the PA after the model has been trained + pa_module: + cpu: + min_epochs: 10 + max_epochs: 10 + logger: false + ddp: + min_epochs: 10 + max_epochs: 10 + logger: false + +vanilla_model: + _target_: tests.test_pa.pa_callback.Vanilla + # the metric will be sent from the script + _partial_: true + log_every_n_epochs: 2 + num_classes: 2 + classifier: + _target_: src.models.components.dg_backbone.get_lm_model + net: + net: resnet18 + pretrained: false + exp_name: ${pa_callback.exp_name} + log_dir: ${paths.log_dir} + +pa_module: + _target_: src.models.PA_module.PosteriorAgreementModule + # Classifier will be passed from the script + _partial_: true + optimizer: + _target_: torch.optim.Adam + _partial_: true + lr: 0.1 + classifier: null + beta0: 1.0 + num_classes: ${pa_callback.vanilla_model.num_classes} + +pa_metric: + _target_: src.pa_metric.metric.PosteriorAgreement + # dataset will be passed from the script + _partial_: true + pa_epochs: ${pa_callback.trainers.pa_module.cpu.max_epochs} # same as the PA module in all cases + beta0: ${pa_callback.pa_module.beta0} + processing_strategy: cuda + optimizer: null + +pa_callback: + _target_: src.pa_metric.callback.PA_Callback + # dataset will be passed from the script + _partial_: true + pa_epochs: ${pa_callback.trainers.pa_module.cpu.max_epochs} # same as PA module in all cases + log_every_n_epochs: ${pa_callback.vanilla_model.log_every_n_epochs} + beta0: ${pa_callback.pa_module.beta0} + +datamodules: + data_name: diagvib + + main: + envs_index: [0,1] + envs_name: singlevar # remember it is necessary to write 'train', 'test' or 'val' + datasets_dir: ${paths.data_dir}/dg/dg_datasets/test_data_pipeline/ + disjoint_envs: True + train_val_sequential: True + + collate_fn: + _target_: hydra.utils.get_method + path: src.data.components.collate_functions.MultiEnv_collate_fn + batch_size: 16 # same as in the metric + num_workers: 2 + pin_memory: True + + pa: + shift_ratio: ${pa_callback.shift_ratio} + + envs_index: ${pa_callback.datamodules.main.envs_index} + envs_name: train_${pa_callback.datamodules.main.envs_name} # remember it is necessary to write 'train', 'test' or 'val' + datasets_dir: ${pa_callback.datamodules.main.datasets_dir} + disjoint_envs: ${pa_callback.datamodules.main.disjoint_envs} + train_val_sequential: ${pa_callback.datamodules.main.train_val_sequential} + + collate_fn: ${pa_callback.datamodules.main.collate_fn} + batch_size: ${pa_callback.datamodules.main.batch_size} + num_workers: ${pa_callback.datamodules.main.num_workers} + pin_memory: ${pa_callback.datamodules.main.pin_memory} \ No newline at end of file diff --git a/configs/tests/pa_metric/test_accuracymetric.yaml b/configs/tests/pa_metric/test_accuracymetric.yaml new file mode 100644 index 0000000..01ebffd --- /dev/null +++ b/configs/tests/pa_metric/test_accuracymetric.yaml @@ -0,0 +1,40 @@ +# @package _global_ + +defaults: + # Datamodules + - /data/dg@datamodule: diagvib_multienv.yaml + + # Net for the logits PA datamodule + - /model/dg/net@classifier +exp_name: lisa_rebuttalL +shift_ratio: 1.0 + +metric: + _target_: src.pa_metric.metric.PosteriorAccuracy + # dataset, sharpness_factor, processing_strategy and cuda_devices will be passed from the script + _partial_: true + pa_epochs: 10 + beta0: 1.0 + optimizer: null + +classifier: + _target_: src.models.components.dg_backbone.get_lm_model + net: + net: resnet18 + pretrained: false + exp_name: ${pa_metric.exp_name} + log_dir: ${paths.log_dir} + +datamodule: + envs_index: [0,1] + envs_name: singlevar + datasets_dir: ${paths.data_dir}/dg/dg_datasets/test_data_pipeline/ + disjoint_envs: True + train_val_sequential: True + + collate_fn: + _target_: hydra.utils.get_method + path: src.data.components.collate_functions.MultiEnv_collate_fn + batch_size: 64 # Same as the one that was passed to the model. + num_workers: 2 + pin_memory: True \ No newline at end of file diff --git a/configs/tests/pa_metric/test_basemetric.yaml b/configs/tests/pa_metric/test_basemetric.yaml new file mode 100644 index 0000000..7d0d8f2 --- /dev/null +++ b/configs/tests/pa_metric/test_basemetric.yaml @@ -0,0 +1,82 @@ +# @package _global_ + +defaults: + # Datamodules + - /data/dg@logits_datamodule: diagvib_PA.yaml + + # Trainers + - /trainer@trainer.cpu: cpu.yaml + + # PA module: only logits, as it is faster + - /model/dg@pa_module: optimize.yaml + + # Classifier goes to the logits datamodule + - /model/dg/net@datamodules.logits.classifier +exp_name: lisa_rebuttalL +shift_ratio: 1.0 + +# We assume that the PA module passed the test, and therefore the results +# in the CPU are the same as in DDP. +trainer: + cpu: + min_epochs: 1 + max_epochs: 1 + +# We will compare the metric to the PA module +# Make sure we are using the same optimizer and the same LR. +pa_module: + optimizer: + _target_: torch.optim.Adam + classifier: null # because we are using logits + num_classes: 2 + +pa_basemetric: + _target_: src.pa_metric.basemetric.PosteriorAgreementBase + _partial_: true # dataset passed within the script + pa_epochs: ${pa_metric.trainer.cpu.max_epochs} + beta0: ${pa_metric.pa_module.beta0} # same as in the PA module + +datamodules: + main: + _target_: src.data.diagvib_datamodules.DiagVibDataModulePA + shift_ratio: ${pa_metric.shift_ratio} + + envs_index: [0,1] + envs_name: train_singlevar # remember it is necessary to write 'train', 'test' or 'val' + datasets_dir: ${paths.data_dir}/dg/dg_datasets/test_data_pipeline/ + disjoint_envs: True + train_val_sequential: True + + collate_fn: + _target_: hydra.utils.get_method + path: src.data.components.collate_functions.MultiEnv_collate_fn + batch_size: 16 # the same as in the basemetric + num_workers: 2 + pin_memory: True + + + logits: + _target_: src.data.diagvib_datamodules.DiagVibDataModulePAlogits + shift_ratio: ${pa_metric.shift_ratio} + + # This very same classifier will be used in basemetric call + classifier: + _target_: src.models.components.dg_backbone.get_lm_model + log_dir: ${paths.log_dir} + exp_name: ${pa_metric.exp_name} + net: + net: resnet18 + pretrained: false + + # Same as in the previous one + envs_index: ${pa_metric.datamodules.main.envs_index} + envs_name: ${pa_metric.datamodules.main.envs_name} + datasets_dir: ${pa_metric.datamodules.main.datasets_dir} + disjoint_envs: ${pa_metric.datamodules.main.disjoint_envs} + train_val_sequential: ${pa_metric.datamodules.main.train_val_sequential} + + collate_fn: ${pa_metric.datamodules.main.collate_fn} + batch_size: ${pa_metric.datamodules.main.batch_size} + num_workers: ${pa_metric.datamodules.main.num_workers} + pin_memory: ${pa_metric.datamodules.main.pin_memory} + \ No newline at end of file diff --git a/configs/tests/pa_metric/test_pametric_cpu.yaml b/configs/tests/pa_metric/test_pametric_cpu.yaml new file mode 100644 index 0000000..d0fa5dd --- /dev/null +++ b/configs/tests/pa_metric/test_pametric_cpu.yaml @@ -0,0 +1,56 @@ +# @package _global_ + +defaults: + # Datamodules + - /data/dg@datamodules.images: diagvib_multienv.yaml + - /data/dg@datamodules.logits: diagvib_PA.yaml + + # Net for the logits PA datamodule + - /model/dg/net@classifier +exp_name: lisa_rebuttalL +shift_ratio: 1.0 + +classifier: + _target_: src.models.components.dg_backbone.get_lm_model + log_dir: ${paths.log_dir} + exp_name: ${pa_metric.exp_name} + net: + net: resnet18 + pretrained: false + +metrics: + basemetric: + _target_: src.pa_metric.basemetric.PosteriorAgreementBase + _partial_: true # because the dataset will be passed from the script + pa_epochs: 10 + beta0: 1.0 + + fullmetric: + _target_: src.pa_metric.metric.PosteriorAgreement + _partial_: true # because the dataset will be passed from the script + pa_epochs: ${pa_metric.metrics.basemetric.pa_epochs} + beta0: ${pa_metric.metrics.basemetric.beta0} # same as in the basemetric + optimizer: null # use the default, which is the same as in the basemetric + processing_strategy: cpu + cuda_devices: null + +# The dataset for the metric will be initialized by the datamodule +datamodules: + images: + _target_: src.data.diagvib_datamodules.DiagVibDataModulePA + shift_ratio: ${pa_metric.shift_ratio} + + envs_index: [0,1] + envs_name: train_singlevar # remember it is necessary to write 'train', 'test' or 'val' + datasets_dir: ${paths.data_dir}/dg/dg_datasets/test_data_pipeline/ + disjoint_envs: True + train_val_sequential: True + + collate_fn: + _target_: hydra.utils.get_method + path: src.data.components.collate_functions.MultiEnv_collate_fn + batch_size: 16 # the same as in the basemetric + num_workers: 2 + pin_memory: True + + \ No newline at end of file diff --git a/configs/tests/pa_metric/test_pametric_ddp.yaml b/configs/tests/pa_metric/test_pametric_ddp.yaml new file mode 100644 index 0000000..623b178 --- /dev/null +++ b/configs/tests/pa_metric/test_pametric_ddp.yaml @@ -0,0 +1,44 @@ +# @package _global_ + +defaults: + # Datamodules + - /data/dg@logits_datamodule: diagvib_PA.yaml + + # Net for the logits PA datamodule + - /model/dg/net@classifier +exp_name: lisa_rebuttalL +shift_ratio: 1.0 + +metric: + _target_: src.pa_metric.metric.PosteriorAgreement + # dataset, strategy and cuda_devices will be passed from the script + _partial_: true + pa_epochs: 1 + beta0: 1.0 + optimizer: null + +classifier: + _target_: src.models.components.dg_backbone.get_lm_model + log_dir: ${paths.log_dir} + exp_name: ${pa_metric.exp_name} + net: + net: resnet18 + pretrained: false + +datamodules: + images: + _target_: src.data.diagvib_datamodules.DiagVibDataModulePA + shift_ratio: ${pa_metric.shift_ratio} + + envs_index: [0,1] + envs_name: train_singlevar # remember it is necessary to write 'train', 'test' or 'val' + datasets_dir: ${paths.data_dir}/dg/dg_datasets/test_data_pipeline/ + disjoint_envs: True + train_val_sequential: True + + collate_fn: + _target_: hydra.utils.get_method + path: src.data.components.collate_functions.MultiEnv_collate_fn + batch_size: 16 # the same as in the basemetric + num_workers: 2 + pin_memory: True \ No newline at end of file diff --git a/configs/tests/pa_metric/test_pametric_logits.yaml b/configs/tests/pa_metric/test_pametric_logits.yaml new file mode 100644 index 0000000..3cc7894 --- /dev/null +++ b/configs/tests/pa_metric/test_pametric_logits.yaml @@ -0,0 +1,51 @@ +defaults: + # Datamodules + - /data/dg@datamodules.logits: diagvib_PA.yaml + - /data/dg@datamodules.main: diagvib_multienv.yaml + + # Net for the logits PA datamodule + - /model/dg/net@datamodules.pa_logits.classifier +exp_name: lisa_rebuttalL +shift_ratio: 1.0 + +metric: + _target_: src.pa_metric.metric.PosteriorAgreement + # dataset, strategy and cuda_devices will be passed from the script + _partial_: true + pa_epochs: 10 + beta0: 1.0 + optimizer: null + +datamodules: + main: + envs_index: [0,1] + envs_name: singlevar + datasets_dir: ${paths.data_dir}/dg/dg_datasets/test_data_pipeline/ + disjoint_envs: True + train_val_sequential: True + + collate_fn: + _target_: hydra.utils.get_method + path: src.data.components.collate_functions.MultiEnv_collate_fn + batch_size: 64 # Same as the one that was passed to the model. + num_workers: 2 + pin_memory: True + + pa_logits: + _target_: src.data.diagvib_datamodules.DiagVibDataModulePAlogits + shift_ratio: ${pa_metric.shift_ratio} + classifier: + _target_: src.models.components.dg_backbone.get_lm_model + exp_name: ${pa_metric.exp_name} + log_dir: ${paths.log_dir} + + envs_index: ${pa_metric.datamodules.main.envs_index} + envs_name: train_${pa_metric.datamodules.main.envs_name} # remember it is necessary to write 'train', 'test' or 'val' + datasets_dir: ${pa_metric.datamodules.main.datasets_dir} + disjoint_envs: ${pa_metric.datamodules.main.disjoint_envs} + train_val_sequential: ${pa_metric.datamodules.main.train_val_sequential} + + collate_fn: ${pa_metric.datamodules.main.collate_fn} + batch_size: ${pa_metric.datamodules.main.batch_size} + num_workers: ${pa_metric.datamodules.main.num_workers} + pin_memory: ${pa_metric.datamodules.main.pin_memory} \ No newline at end of file diff --git a/configs/tests/pa_module/test_pa_module.yaml b/configs/tests/pa_module/test_pa_module.yaml new file mode 100644 index 0000000..fabbd67 --- /dev/null +++ b/configs/tests/pa_module/test_pa_module.yaml @@ -0,0 +1,85 @@ +# @package _global_ + +defaults: + # Datamodules + - /data/dg@datamodules.pa: diagvib_PA.yaml + - /data/dg@datamodules.pa_logits: diagvib_PA.yaml + + # Trainers + - /trainer@trainer.cpu: cpu.yaml + - /trainer@trainer.ddp: ddp.yaml + + # PA module + - /model/dg@pa_lightningmodule.pa: optimize.yaml + - /model/dg@pa_lightningmodule.pa_logits: optimize.yaml + + # Classifier for logits is the same going to the PA when needed. + - /model/dg/net@pa_lightningmodule.pa.classifier +exp_name: lisa_rebuttalL +shift_ratio: 1.0 + +trainer: + cpu: + min_epochs: 10 + max_epochs: 10 + # Use more epochs in DDP and compare with the truncated vector. + ddp: + min_epochs: 40 + max_epochs: 40 + +pa_lightningmodule: + # When using the PA datamodule + pa: + optimizer: + _target_: torch.optim.Adam + classifier: + exp_name: ${pa_module.exp_name} + net: + net: resnet18 + pretrained: false + num_classes: 2 + + # When using the PA_logits datamodule + pa_logits: + optimizer: ${pa_module.pa_lightningmodule.pa.optimizer} + classifier: null # because we are using logits + num_classes: ${pa_module.pa_lightningmodule.pa.num_classes} + +datamodules: + data_name: diagvib + + pa: + _target_: src.data.diagvib_datamodules.DiagVibDataModulePA + shift_ratio: ${pa_module.shift_ratio} + + envs_index: [0,1] + envs_name: train_singlevar # remember it is necessary to write 'train', 'test' or 'val' + datasets_dir: ${paths.data_dir}/dg/dg_datasets/test_data_pipeline/ + disjoint_envs: True + train_val_sequential: True + + collate_fn: + _target_: hydra.utils.get_method + path: src.data.components.collate_functions.MultiEnv_collate_fn + batch_size: 64 # Same as the one that was passed to the model. + num_workers: 2 + pin_memory: True + + + pa_logits: + _target_: src.data.diagvib_datamodules.DiagVibDataModulePAlogits + shift_ratio: ${pa_module.shift_ratio} + + # Same classifier as the one sent to the PA module (when using the PA datamodule) + classifier: ${pa_module.pa_lightningmodule.pa.classifier} + + envs_index: ${pa_module.datamodules.pa.envs_index} + envs_name: ${pa_module.datamodules.pa.envs_name} + datasets_dir: ${pa_module.datamodules.pa.datasets_dir} + disjoint_envs: ${pa_module.datamodules.pa.disjoint_envs} + train_val_sequential: ${pa_module.datamodules.pa.train_val_sequential} + + collate_fn: ${pa_module.datamodules.pa.collate_fn} + batch_size: ${pa_module.datamodules.pa.batch_size} + num_workers: ${pa_module.datamodules.pa.num_workers} + pin_memory: ${pa_module.datamodules.pa.pin_memory} \ No newline at end of file diff --git a/configs/train_dg.yaml b/configs/train.yaml similarity index 73% rename from configs/train_dg.yaml rename to configs/train.yaml index 2d951f2..d4c9a8f 100644 --- a/configs/train_dg.yaml +++ b/configs/train.yaml @@ -1,20 +1,14 @@ # @package _global_ -# to execute this experiment run: -# python train.py experiment=example - defaults: - _self_ - - data/dg: diagvib_multienv.yaml - - model/dg: erm.yaml + - logger: none.yaml # TODO: See if passing null from the CLI also works - callbacks: default.yaml - - logger: null - trainer: default.yaml - paths: default.yaml - extras: default.yaml - hydra: default.yaml - # experiment configs allow for version control of specific hyperparameters # e.g. best hyperparameters for given model and datamodule - experiment: null @@ -29,20 +23,20 @@ defaults: # debugging config (enable through command line, e.g. `python train.py debug=default) - debug: null +data: ??? +model: ??? + +trainer: + limit_train_batches: 1.0 + limit_val_batches: 1.0 + # task name, determines output directory path -task_name: train_dg +task_name: ??? # tags to help you identify your experiments # you can overwrite this in experiment configs # overwrite from command line with `python train.py tags="[first_tag, second_tag]"` -tags: [dev] - -# set False to skip model training -train: True - -# evaluate on test set, using best model weights achieved during training -# lightning chooses best weights based on the metric specified in checkpoint callback -# test: True +tags: [train] # compile model for faster training with pytorch 2.0 compile: False @@ -51,4 +45,4 @@ compile: False ckpt_path: null # seed for random number generators in pytorch, numpy and python.random -seed: null \ No newline at end of file +seed: 123 \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 7766a78..96652dd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,50 +1,153 @@ -# --------- general --------- # -setuptools - -# ---------- numpy ---------- # +posterioragreement @ git+https://github.com/viictorjimenezzz/posterioragreement.git@main +diagvibsix @ git+https://github.com/viictorjimenezzz/diagvibsix.git@59cb3760c097a1068621c015e688543d2c854f98 +aiohttp==3.8.5 +aiosignal==1.3.1 +alembic==1.11.1 +annotated-types==0.6.0 +ansible-cmdb==1.31 +antlr4-python3-runtime==4.9.3 +appdirs==1.4.4 +asttokens==2.4.1 +async-timeout==4.0.2 +attrs==23.1.0 +autoattack==0.1 +autocommand==2.2.2 +autopage==0.5.1 +certifi==2023.5.7 +chardet==4.0.0 +charset-normalizer==3.2.0 +cheroot==10.0.0 +CherryPy==18.9.0 +click==8.1.6 +cliff==4.3.0 +cloudpickle==2.2.1 +cmaes==0.10.0 +cmd2==2.4.3 +colorlog==6.7.0 +contourpy==1.1.0 +cycler==0.11.0 +decorator==5.1.1 +docker-pycreds==0.4.0 +eagerpy==0.30.0 +exceptiongroup==1.2.0 +executing==2.0.1 +filelock==3.12.2 +fonttools==4.41.0 +foolbox==3.3.3 +frozenlist==1.4.0 +fsspec==2023.6.0 +geotorch==0.3.0 +gevent==23.7.0 +gitdb==4.0.10 +GitPython==3.1.32 +greenlet==2.0.2 +huggingface-hub==0.21.4 +hydra-colorlog==1.2.0 +hydra-core==1.3.1 +hydra-optuna-sweeper==1.2.0 +hydra-submitit-launcher==1.2.0 +idna==2.10 +imageio==2.31.1 +importlib-metadata==6.8.0 +importlib-resources==6.0.0 +inflect==7.0.0 +ipdb==0.13.13 +ipython==8.18.1 +jaraco.collections==5.0.0 +jaraco.context==4.3.0 +jaraco.functools==4.0.0 +jaraco.text==3.12.0 +jedi==0.19.1 +Jinja2==3.1.2 +joblib==1.3.1 +jsonxs==0.6 +kiwisolver==1.4.4 +lazy_loader==0.3 +lightning-fabric==1.9.1 +lightning-utilities==0.9.0 +littleutils==0.2.2 +Mako==1.2.4 +markdown-it-py==3.0.0 +MarkupSafe==2.1.3 +matplotlib==3.7.2 +matplotlib-inline==0.1.6 +mdurl==0.1.2 +more-itertools==10.2.0 +multidict==6.0.4 +networkx==3.1 numpy==1.22.4 - -# ---------- scikit --------- # +ogb==1.3.6 +omegaconf==2.3.0 +optuna==2.10.1 +outdated==0.2.2 +packaging==23.1 +pandas==1.3.5 +parso==0.8.3 +pathtools==0.1.2 +pbr==5.11.1 +pexpect==4.9.0 +Pillow==10.0.0 +portend==3.2.0 +prettytable==3.8.0 +prompt-toolkit==3.0.43 +protobuf==3.20.0 +psutil==5.9.5 +ptyprocess==0.7.0 +pure-eval==0.2.2 +pydantic==2.6.3 +pydantic_core==2.16.3 +Pygments==2.15.1 +pyparsing==3.0.9 +pyperclip==1.8.2 +pyrootutils==1.0.4 +python-dateutil==2.8.2 +python-dotenv==1.0.0 +pytorch-lightning==1.9.1 +pytz==2023.3 +PyWavelets==1.4.1 +PyYAML==6.0.1 +regex==2023.12.25 +requests==2.25.1 +rich==13.4.2 +robustbench==1.1 +safetensors==0.4.2 +SciencePlots==2.1.0 +scikit-image==0.22.0 scikit-learn==1.1.3 -scikit-image - -# ---------- plots ---------- # -matplotlib -seaborn -# SciencePlots - -# --------- pytorch --------- # +scipy==1.11.1 +seaborn==0.12.2 +secml==0.15.6 +sentry-sdk==1.28.1 +setproctitle==1.3.2 +six==1.16.0 +smmap==5.0.0 +SQLAlchemy==2.0.19 +stack-data==0.6.3 +stevedore==5.1.0 +submitit==1.4.5 +tempora==5.5.1 +termcolor==2.4.0 +threadpoolctl==3.2.0 +tifffile==2023.8.12 +timm==0.6.13 +tokenizers==0.15.2 +tomli==2.0.1 torch==1.10.0 -torchvision>=0.11.0 -pytorch-lightning==1.9.1 +torchdiffeq==0.2.3 torchmetrics==0.11.0 - -# --------- hydra --------- # -hydra-core==1.3.2 -hydra-colorlog==1.2.0 -hydra-submitit-launcher==1.2.0 -# hydra-optuna-sweeper==1.2.0 - -# --------- loggers --------- # -wandb -# neptune-client -# mlflow -# comet-ml -# aim>=3.16.2 # no lower than 3.16.2, see https://github.com/aimhubio/aim/issues/2550 - -# --------- others --------- # -pyrootutils # standardizing the project root setup -pre-commit # hooks for applying linters on commit -rich # beautiful text formatting in terminal -pytest # tests -# sh # for running bash commands in some tests (linux/macos only) - -# ------ adversarial ------- # -secml -foolbox -diagvibsix @ git+https://github.com/viictorjimenezzz/diagvibsix@librarization -robustbench @ git+https://github.com/RobustBench/robustbench@ec26a6cd0b0812135270c3659caabcab80701b15 - -# ---------- fix ---------- # -gevent # needed to fix a problem with wandb loading -protobuf==3.20.0 # fix pytorch_lightning imports \ No newline at end of file +torchvision==0.11.1 +tqdm==4.65.0 +traitlets==5.14.1 +transformers==4.38.2 +typing==3.7.4.3 +typing_extensions==4.7.1 +urllib3==1.26.16 +ushlex==0.99.1 +wandb==0.15.5 +wcwidth==0.2.6 +wilds==2.0.0 +yarl==1.9.2 +zc.lockfile==3.0.post1 +zipp==3.16.2 +zope.event==5.0 +zope.interface==6.0 diff --git a/scripts/debug_slurm.sh b/scripts/debug_slurm.sh new file mode 100755 index 0000000..6a10ef1 --- /dev/null +++ b/scripts/debug_slurm.sh @@ -0,0 +1,31 @@ +#!/bin/bash +#SBATCH --ntasks=4 +#SBATCH --ntasks-per-node=4 +#SBATCH --nodes=1 +#SBATCH --gpus-per-node=4 +#SBATCH --mem-per-cpu=10G + +# activate conda env +source activate $1 + +srun python3 src/train_dg.py \ + --run \ + experiment=dg/train_dg_irm \ + test=true \ + +data/dg/wilds@data.dg=camelyon17_oracle \ + seed=123 \ + model.dg.net.net=densenet121 \ + data.dg.batch_size=16 \ + data.dg.num_workers=2 \ + data.dg.pin_memory=true \ + name=wilds_test_irm_oracle_full \ + trainer=ddp \ + +trainer.accumulate_grad_batches=2 \ + +trainer.replace_sampler_ddp=True \ + +trainer.fast_dev_run=False \ + trainer.min_epochs=1 \ + trainer.max_epochs=1 \ + logger=wandb \ + logger.wandb.entity=malvai \ + logger.wandb.project=cov_pa \ + logger.wandb.group=test_wilds \ \ No newline at end of file diff --git a/scripts/dg_test_wilds.sh b/scripts/dg_test_wilds.sh new file mode 100755 index 0000000..e9a5b64 --- /dev/null +++ b/scripts/dg_test_wilds.sh @@ -0,0 +1,22 @@ +#!/bin/bash +#SBATCH --ntasks=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --nodes=1 +#SBATCH --gpus-per-node=1 +#SBATCH --mem-per-cpu=10G + +# activate conda env +source activate $1 + +# python3 src/test.py \ +srun python3 src/train.py \ + --cfg=job \ + experiment=dg/camelyon17_erm \ + +data/dg/wilds@data=camelyon17_oracle \ + data.transform.is_training=false \ + seed=123 \ + name_logger=prova_configs_train_newconf \ + trainer=cpu \ + +trainer.fast_dev_run=true \ + logger=none \ + # logger.wandb.group=test_wilds \ \ No newline at end of file diff --git a/scripts/dg_train_wilds.sh b/scripts/dg_train_wilds.sh new file mode 100755 index 0000000..c122a9f --- /dev/null +++ b/scripts/dg_train_wilds.sh @@ -0,0 +1,25 @@ +#!/bin/bash +#SBATCH --ntasks=4 +#SBATCH --ntasks-per-node=4 +#SBATCH --nodes=1 +#SBATCH --gpus-per-node=4 +#SBATCH --mem-per-cpu=10G +#SBATCH --time=4:00:00 + +# activate conda env +source activate $1 + +# python3 src/train.py \ +srun python3 src/train.py \ + --multirun \ + experiment=dg/wilds/camelyon17_erm \ + +data/dg/wilds@data=camelyon17_oracle \ + data.transform.is_training=true \ + seed=123 \ + name_logger=prova_configs_train_newconf_call \ + trainer=ddp \ + trainer.max_epochs=3 \ + +trainer.fast_dev_run=false \ + logger=wandb \ + logger.wandb.group=test_wilds \ + # +trainer.deterministic=true \ \ No newline at end of file diff --git a/scripts/train_adv_newrobust.sh b/scripts/eval_adv_pa.sh similarity index 100% rename from scripts/train_adv_newrobust.sh rename to scripts/eval_adv_pa.sh diff --git a/scripts/eval_adv_pa_logits.sh b/scripts/eval_adv_pa_logits.sh new file mode 100755 index 0000000..8cf3a92 --- /dev/null +++ b/scripts/eval_adv_pa_logits.sh @@ -0,0 +1,26 @@ +#!/bin/bash + +python src/train_pa.py \ + --multirun \ + experiment=adv/optimize_beta_logits \ + model/adv/classifier@data.adv.classifier=weak \ + data/adv/attack=GAUSSIAN \ + data.adv.adversarial_ratio=1.0 \ + data.adv.attack.epsilons=1.0 \ + data.adv.batch_size=1000 \ + trainer=ddp \ + logger=wandb \ + logger.wandb.entity=malvai \ + logger.wandb.project=cov_pa \ + logger.wandb.group=foo \ + hydra/launcher=submitit_slurm \ + hydra.launcher.tasks_per_node=1 \ + hydra.launcher.mem_per_cpu=50000 \ + +hydra.launcher.time=4:00:00 \ + +hydra.launcher.num_gpus=4 + # data/adv/attack=PGD,GAUSSIAN,FMN \ + # model/adv/classifier@data.adv.classifier=weak,wong2020,addepalli2021,robust,peng2023,bpda \ + # data.adv.adversarial_ratio=0.0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.0 \ + # data.adv.attack.epsilons=0.0314,0.0627,0.1255 + # data.adv.attack.epsilons=0.0,0.01960784,0.03921569,0.05882353,0.07843137,0.09803922,0.11764706,0.1372549,0.15686275,0.17647059,0.19607843,0.21568627,0.23529412,0.25490196,0.2745098,0.29411765,0.31372549,0.33333333,0.35294118,0.37254902,0.39215686,0.41176471,0.43137255,0.45098039,0.47058824,0.49019608,0.50980392,0.52941176,0.54901961,0.56862745,0.58823529,0.60784314,0.62745098,0.64705882,0.66666667,0.68627451,0.70588235,0.7254902,0.74509804,0.76470588,0.78431373,0.80392157,0.82352941,0.84313725,0.8627451,0.88235294,0.90196078,0.92156863,0.94117647,0.96078431,0.98039216,1.0 \ + # data.adv.attack.steps=1000 \ \ No newline at end of file diff --git a/scripts/train_dg_PA.sh b/scripts/eval_dg_pa.sh similarity index 50% rename from scripts/train_dg_PA.sh rename to scripts/eval_dg_pa.sh index 31214da..2139c49 100755 --- a/scripts/train_dg_PA.sh +++ b/scripts/eval_dg_pa.sh @@ -3,18 +3,17 @@ python3 src/train_dg_pa.py \ --multirun \ experiment=dg/erm_irm_lisa_pa_logits \ - exp_name=lisa_rebuttalD \ - data.dg.envs_index=[0,4] \ - data.dg.shift_ratio=0.6 \ - data.dg.envs_name=test_rebuttal \ - data.dg.disjoint_envs=False \ - data.dg.train_val_sequential=False \ + exp_name=erm_rebuttalTFFF \ + logger.wandb.group=dg_pa_diagvib_TFFF \ + data.dg.envs_index=[0,1],[0,2],[0,3],[0,4],[0,5] \ + data.dg.shift_ratio=0.2,0.4,0.6,0.8,1.0 \ + data.dg.envs_name=test_rebuttaltf \ trainer=ddp \ +trainer.fast_dev_run=False \ +trainer.limit_train_batches=1.0\ trainer.limit_val_batches=1.0 \ - trainer.min_epochs=200 \ - trainer.max_epochs=200 \ + trainer.min_epochs=100 \ + trainer.max_epochs=100 \ logger=wandb \ hydra/launcher=submitit_slurm \ hydra.launcher.tasks_per_node=1 \ @@ -23,5 +22,8 @@ python3 src/train_dg_pa.py \ +hydra.launcher.num_gpus=4 \ logger.wandb.entity=malvai \ logger.wandb.project=cov_pa \ - logger.wandb.group=pa_debugnew + # logger.wandb.group=dg_pa_diagvib_new + # exp_name=erm_rebuttal,irm_rebuttal,lisa_rebuttalD,lisa_rebuttalL \ + # data.dg.envs_index=[0,1],[0,2],[0,3],[0,4],[0,5] \ + # data.dg.shift_ratio=0.2,0.4,0.6,0.8,1.0 \ \ No newline at end of file diff --git a/scripts/pa_trainer.sh b/scripts/pa_trainer.sh new file mode 100755 index 0000000..45f7213 --- /dev/null +++ b/scripts/pa_trainer.sh @@ -0,0 +1,27 @@ +#!/bin/bash + +python3 src/train_dg_pa.py \ + --multirun \ + experiment=dg/pa_training \ + name=erm_trainerdebug \ + data.dg.envs_index=[0,1] \ + data.dg.envs_name=rebuttal \ + data.dg.disjoint_envs=True \ + data.dg.train_val_sequential=True \ + trainer.pa_datamodule.envs_index=[0,4] \ + trainer.pa_datamodule.shift_ratio=0.6 \ + trainer.pa_datamodule.envs_name=test_rebuttal \ + trainer=cpu \ + +trainer.limit_train_batches=0.015 \ + +trainer.limit_val_batches=0.07 \ + trainer.min_epochs=5 \ + trainer.max_epochs=5 \ + logger=none \ + # hydra/launcher=submitit_slurm \ + # hydra.launcher.tasks_per_node=1 \ + # hydra.launcher.mem_per_cpu=100000 \ + # +hydra.launcher.time=4:00:00 \ + # +hydra.launcher.num_gpus=4 \ + # logger.wandb.entity=malvai \ + # logger.wandb.project=cov_pa \ + # logger.wandb.group=pa_debugnew \ No newline at end of file diff --git a/scripts/proves.sh b/scripts/proves.sh new file mode 100755 index 0000000..444075b --- /dev/null +++ b/scripts/proves.sh @@ -0,0 +1,9 @@ +#!/bin/bash + +#SBATCH --nodes=1 +#SBATCH --time=4:00:00 +#SBATCH --mem-per-cpu=10G +#SBATCH --ntasks-per-node=4 +#SBATCH --gpus-per-node=4 + +srun python ./proves.py \ No newline at end of file diff --git a/scripts/run_create_dataset.sh b/scripts/run_create_dataset.sh deleted file mode 100755 index 3207966..0000000 --- a/scripts/run_create_dataset.sh +++ /dev/null @@ -1,49 +0,0 @@ -#!/bin/bash -# run this script from the root folder -set -euo pipefail - -cd "$(dirname "$0")/.." - -attack="PGD" -classifier="robust" -epsilons="0.0314" - -# if [ $attack = "PGD" ]; then -# epsilons=$(IFS=","; printf '%s' "0.0314 0.0627 0.1255") -# fi - -cmd="python3 src/generate_adv_data.py \ -experiment=adv/generate_adv_data \ -data/adv/attack=${attack} \ -model/adv/classifier=${classifier} \ -data.adv.attack.params.epsilons=${epsilons};" - -sbatch \ - -J adv_pa \ - -o outputs/generate_adv_data_att=${attack}_clf=${classifier}_eps=${epsilons} \ - --ntasks-per-node=1 \ - --time=120:00:00 \ - --mem-per-cpu=10000 \ - --gpus=1 \ - --wrap "$cmd" - -# for i in $epsilons; do -# exp_dir="exp_adv_${attack}_pr_20230512/exp_adv_${attack}_pr_esp${i}" -# exp_name="exp_${attack}_model${model}_eps${i}_pr0" - -# exp_path="$exp_dir/$exp_name" -# outputs_path="/cluster/project/jbuhmann/posterior_agreement/adv_pa/outputs/$exp_name" - -# # Run DG -# # cmd="python3 evaluate_dg.py experiment=$exp_path;" -# # # echo $cmd -# # sbatch -J dg_pa -o "$outputs_path" --ntasks-per-node=1 --time=4:00:00 --mem-per-cpu=10000 --gpus=1 --wrap "$cmd" -# # echo sbatch -J dg_pa -o "$outputs_path" --ntasks-per-node=1 --time=4:00:00 --mem-per-cpu=10000 --gpus=1 --wrap "$cmd" - -# # Run Adv -# cmd="python3 evaluate_adv.py experiment=$exp_path;" -# # echo $cmd -# sbatch -J adv_pa -o "$outputs_path" --ntasks-per-node=1 --time=120:00:00 --mem-per-cpu=10000 --gpus=1 --wrap "$cmd" -# echo sbatch -J adv_pa -o "$outputs_path" --ntasks-per-node=1 --time=120:00:00 --mem-per-cpu=10000 --gpus=1 --wrap "$cmd" -# done - diff --git a/scripts/schedule.sh b/scripts/schedule.sh deleted file mode 100755 index 44b3da1..0000000 --- a/scripts/schedule.sh +++ /dev/null @@ -1,7 +0,0 @@ -#!/bin/bash -# Schedule execution of many runs -# Run from root folder with: bash scripts/schedule.sh - -python src/train.py trainer.max_epochs=5 logger=csv - -python src/train.py trainer.max_epochs=10 logger=csv diff --git a/scripts/submit.sh b/scripts/submit.sh new file mode 100755 index 0000000..eb94c0f --- /dev/null +++ b/scripts/submit.sh @@ -0,0 +1,5 @@ +#!/bin/bash +#SBATCH --nodes=1 +#SBATCH --ntasks=4 +#SBATCH --gpus-per-node=4 +#SBATCH --gpus-per-task=1 \ No newline at end of file diff --git a/scripts/test_pa_cuda.sh b/scripts/test_pa_cuda.sh new file mode 100755 index 0000000..a230432 --- /dev/null +++ b/scripts/test_pa_cuda.sh @@ -0,0 +1,10 @@ +#!/bin/bash + +#SBATCH --mem-per-cpu=50G +#SBATCH --gpus=4 + +python tests/test_pa.py\ + # pa_module.trainer.ddp.devices=4 \ + # ddp.trainer.ddp.devices=4 \ + + # must match with the number of gpus \ No newline at end of file diff --git a/scripts/test_pa_lightning.sh b/scripts/test_pa_lightning.sh new file mode 100755 index 0000000..dd60f5b --- /dev/null +++ b/scripts/test_pa_lightning.sh @@ -0,0 +1,11 @@ +#!/bin/bash +#SBATCH --ntasks=4 +#SBATCH --ntasks-per-node=4 +#SBATCH --nodes=1 +#SBATCH --gpus-per-node=4 +#SBATCH --mem-per-cpu=10G + +# activate conda env +source activate $1 + +srun python3 tests/test_pa.py\ \ No newline at end of file diff --git a/scripts/train_adv_gaussian.sh b/scripts/train_adv_gaussian.sh deleted file mode 100755 index 019d3ea..0000000 --- a/scripts/train_adv_gaussian.sh +++ /dev/null @@ -1,26 +0,0 @@ -#!/bin/bash - -python src/train_pa.py \ - --multirun \ - experiment=adv/optimize_beta \ - model/adv/classifier@data.classifier=weak \ - data/adv/attack@data.attack=GAUSSIAN \ - data.attack.epsilons=0.0314,0.0627,0.1255 \ - data.adversarial_ratio=1.0 \ - data.batch_size=1000 \ - logger=wandb \ - trainer=ddp \ - +trainer.fast_dev_run=False \ - +trainer.limit_train_batches=1.0 \ - +trainer.limit_val_batches=1.0 \ - hydra/launcher=submitit_slurm \ - hydra.launcher.tasks_per_node=1 \ - hydra.launcher.mem_per_cpu=10000 \ - +hydra.launcher.time=4:00:00 \ - +hydra.launcher.num_gpus=4 \ - logger.wandb.entity=malvai \ - logger.wandb.project=cov_pa \ - logger.wandb.group=gaussian_vs_PGD #\ - - # --cfg job \ - #data.attack.noise_std=0.01,0.025,0.05,0.075,0.1,0.2,0.3,0.4,0.5 \ \ No newline at end of file diff --git a/scripts/train_adv_newrobust_logits.sh b/scripts/train_adv_newrobust_logits.sh deleted file mode 100755 index 5e10c03..0000000 --- a/scripts/train_adv_newrobust_logits.sh +++ /dev/null @@ -1,25 +0,0 @@ -#!/bin/bash - -python src/train_pa.py \ - --multirun \ - experiment=adv/optimize_beta_logits \ - model/adv/classifier@data.adv.classifier=bpda \ - data/adv/attack=FMN \ - data.adv.adversarial_ratio=0.0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.0 \ - data.adv.attack.steps=1000 \ - data.adv.batch_size=1000 \ - trainer=ddp \ - logger=wandb \ - logger.wandb.entity=malvai \ - logger.wandb.project=cov_pa \ - logger.wandb.group=adversarial_pa \ - hydra/launcher=submitit_slurm \ - hydra.launcher.tasks_per_node=1 \ - hydra.launcher.mem_per_cpu=50000 \ - +hydra.launcher.time=4:00:00 \ - +hydra.launcher.num_gpus=4 - # data/adv/attack=PGD,GAUSSIAN,FMN \ - # model/adv/classifier@data.adv.classifier=weak,wong2020,addepalli2021,robust,wang2023,peng2023,bpda \ - # data.adv.adversarial_ratio=0.0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.0 \ - # data.adv.attack.epsilons=0.0314,0.0627,0.1255 - # data.adv.attack.steps=1000 \ \ No newline at end of file diff --git a/scripts/train_dg_erm_2.sh b/scripts/train_dg_erm_2.sh new file mode 100755 index 0000000..910f2fc --- /dev/null +++ b/scripts/train_dg_erm_2.sh @@ -0,0 +1,34 @@ +#!/bin/bash + +python src/train_dg.py \ + experiment=dg/train_dg_erm \ + test=true \ + +data/dg/wilds@data.dg=camelyon17_oracle \ + seed=123 \ + model.dg.net.net=densenet121 \ + data.dg.batch_size=16 \ + data.dg.num_workers=2 \ + data.dg.pin_memory=true \ + name=wilds_test_erm_oracle_conf \ + trainer=ddp \ + +trainer.accumulate_grad_batches=2 \ + +trainer.replace_sampler_ddp=True \ + +trainer.fast_dev_run=False \ + +trainer.limit_train_batches=0.001 \ + +trainer.limit_val_batches=0.01 \ + +trainer.limit_test_batches=0.01 \ + trainer.min_epochs=1 \ + trainer.max_epochs=1 \ + logger=wandb \ + logger.wandb.entity=malvai \ + logger.wandb.project=cov_pa \ + logger.wandb.group=test_wilds \ + hydra/launcher=submitit_slurm \ + hydra.launcher.tasks_per_node=4 \ + hydra.launcher.gpus_per_node=4 \ + hydra.launcher.mem_per_cpu=10G \ + +hydra.launcher.time=4:00:00 \ + logger.wandb.entity=malvai \ + logger.wandb.project=cov_pa \ + logger.wandb.group=test_wilds \ + \ No newline at end of file diff --git a/scripts/train_dg_erm.sh b/scripts/train_dg_erm_old.sh similarity index 50% rename from scripts/train_dg_erm.sh rename to scripts/train_dg_erm_old.sh index 1bf789d..f21ce36 100755 --- a/scripts/train_dg_erm.sh +++ b/scripts/train_dg_erm_old.sh @@ -2,24 +2,23 @@ python src/train_dg.py \ --multirun \ - experiment=dg/train_dg_erm \ - name=erm_rebuttal \ + experiment=dg/train_dg_erm_old \ data.dg.envs_index=[0,1] \ - data.dg.envs_name=rebuttal \ - data.dg.disjoint_envs=True \ - data.dg.train_val_sequential=True \ + data.dg.envs_name=singlevar \ + name=diagvib_pa \ trainer=ddp \ +trainer.fast_dev_run=False \ +trainer.limit_train_batches=1.0 \ +trainer.limit_val_batches=1.0 \ - trainer.min_epochs=100 \ - trainer.max_epochs=100 \ + trainer.min_epochs=10 \ + trainer.max_epochs=10 \ logger=wandb \ - hydra/launcher=submitit_slurm \ - hydra.launcher.tasks_per_node=1 \ - hydra.launcher.mem_per_cpu=100000 \ - +hydra.launcher.time=4:00:00 \ - +hydra.launcher.num_gpus=4 \ logger.wandb.entity=malvai \ logger.wandb.project=cov_pa \ - logger.wandb.group=pa_meeting \ \ No newline at end of file + logger.wandb.group=test_pametric \ + +hydra.launcher.additional_parameters.gpus=gtx_1080_ti:8 \ + hydra/launcher=submitit_slurm \ + hydra.launcher.tasks_per_node=4 \ + hydra.launcher.mem_per_cpu=10G \ + +hydra.launcher.time=4:00:00 \ + +hydra.launcher.num_gpus=4 \ \ No newline at end of file diff --git a/scripts/train_dg_irm.sh b/scripts/train_dg_irm.sh index c8eec91..f929bde 100755 --- a/scripts/train_dg_irm.sh +++ b/scripts/train_dg_irm.sh @@ -3,23 +3,24 @@ python src/train_dg.py \ --multirun \ experiment=dg/train_dg_irm \ - name=irm_rebuttal \ - data.dg.envs_index=[0,1] \ - data.dg.envs_name=rebuttal \ - data.dg.disjoint_envs=True \ - data.dg.train_val_sequential=True \ + test=true \ + +data/dg/wilds@data.dg=camelyon17_oracle \ + seed=123 \ + model.dg.net.net=densenet121 \ + data.dg.batch_size=16 \ + data.dg.num_workers=2 \ + data.dg.pin_memory=true \ + name=wilds_test_irm_oracle_full \ trainer=ddp \ + +trainer.accumulate_grad_batches=2 \ + +trainer.replace_sampler_ddp=True \ +trainer.fast_dev_run=False \ - +trainer.limit_train_batches=1.0 \ - +trainer.limit_val_batches=1.0 \ - trainer.min_epochs=100 \ - trainer.max_epochs=100 \ + +trainer.limit_train_batches=0.001 \ + +trainer.limit_val_batches=0.001 \ + +trainer.limit_test_batches=0.001 \ + trainer.min_epochs=1 \ + trainer.max_epochs=1 \ logger=wandb \ - hydra/launcher=submitit_slurm \ - hydra.launcher.tasks_per_node=1 \ - hydra.launcher.mem_per_cpu=100000 \ - +hydra.launcher.time=4:00:00 \ - +hydra.launcher.num_gpus=4 \ logger.wandb.entity=malvai \ logger.wandb.project=cov_pa \ - logger.wandb.group=pa_meeting \ No newline at end of file + logger.wandb.group=test_wilds \ \ No newline at end of file diff --git a/scripts/train_dg_lisa.sh b/scripts/train_dg_lisa.sh index 49232ff..d992932 100755 --- a/scripts/train_dg_lisa.sh +++ b/scripts/train_dg_lisa.sh @@ -3,24 +3,30 @@ python src/train_dg.py \ --multirun \ experiment=dg/train_dg_lisa \ - name=lisa_rebuttalL \ + test=true \ + +data/dg/wilds@data.dg=camelyon17_oodval \ + seed=123 \ + model.dg.net.net=densenet121 \ + data.dg.batch_size=16 \ + data.dg.num_workers=2 \ + data.dg.pin_memory=true \ + name=wilds_test_lisa_oodval \ model.dg.ppred=1.0 \ - data.dg.envs_index=[0,1] \ - data.dg.envs_name=rebuttal \ - data.dg.disjoint_envs=True \ - data.dg.train_val_sequential=True \ + model.dg.mix_alpha=2.0 \ trainer=ddp \ +trainer.fast_dev_run=False \ +trainer.limit_train_batches=1.0 \ +trainer.limit_val_batches=1.0 \ - trainer.min_epochs=100 \ - trainer.max_epochs=100 \ + +trainer.limit_test_batches=1.0 \ + trainer.min_epochs=1 \ + trainer.max_epochs=1 \ logger=wandb \ hydra/launcher=submitit_slurm \ hydra.launcher.tasks_per_node=1 \ - hydra.launcher.mem_per_cpu=100000 \ + hydra.launcher.mem_per_cpu=10G \ +hydra.launcher.time=4:00:00 \ +hydra.launcher.num_gpus=4 \ logger.wandb.entity=malvai \ logger.wandb.project=cov_pa \ - logger.wandb.group=pa_meeting \ \ No newline at end of file + logger.wandb.group=test_wilds \ + # +hydra.launcher.additional_parameters.gpus=gtx_1080_ti:8 \ diff --git a/scripts/victor_proves.sh b/scripts/victor_proves.sh new file mode 100755 index 0000000..0eff5c5 --- /dev/null +++ b/scripts/victor_proves.sh @@ -0,0 +1,10 @@ +#!/bin/bash + +#SBATCH --mem-per-cpu=10G + +python ./victor_proves.py \ + --multirun \ + hydra/launcher=submitit_slurm \ + hydra.launcher.tasks_per_node=1 \ + hydra.launcher.mem_per_cpu=50000 \ + +hydra.launcher.time=4:00:00 \ No newline at end of file diff --git a/slurm-52217894.out b/slurm-52217894.out new file mode 100644 index 0000000..0d34bb5 --- /dev/null +++ b/slurm-52217894.out @@ -0,0 +1,986 @@ +[2024-03-22 11:19:54,864][HYDRA] Launching 1 jobs locally +[2024-03-22 11:19:54,864][HYDRA] #0 : experiment=dg/wilds/camelyon17_erm +data/dg/wilds@data=camelyon17_oracle data.transform.is_training=True seed=123 name_logger=prova_configs_train_newconf_call trainer=ddp trainer.max_epochs=1 +trainer.fast_dev_run=False logger=wandb logger.wandb.group=test_wilds +[2024-03-22 11:19:54,889][HYDRA] Launching 1 jobs locally +[2024-03-22 11:19:54,889][HYDRA] #0 : experiment=dg/wilds/camelyon17_erm +data/dg/wilds@data=camelyon17_oracle data.transform.is_training=True seed=123 name_logger=prova_configs_train_newconf_call trainer=ddp trainer.max_epochs=1 +trainer.fast_dev_run=False logger=wandb logger.wandb.group=test_wilds +[2024-03-22 11:19:54,909][HYDRA] Launching 1 jobs locally +[2024-03-22 11:19:54,909][HYDRA] #0 : experiment=dg/wilds/camelyon17_erm +data/dg/wilds@data=camelyon17_oracle data.transform.is_training=True seed=123 name_logger=prova_configs_train_newconf_call trainer=ddp trainer.max_epochs=1 +trainer.fast_dev_run=False logger=wandb logger.wandb.group=test_wilds +[2024-03-22 11:19:54,932][HYDRA] Launching 1 jobs locally +[2024-03-22 11:19:54,933][HYDRA] #0 : experiment=dg/wilds/camelyon17_erm +data/dg/wilds@data=camelyon17_oracle data.transform.is_training=True seed=123 name_logger=prova_configs_train_newconf_call trainer=ddp trainer.max_epochs=1 +trainer.fast_dev_run=False logger=wandb logger.wandb.group=test_wilds +[rank: 0] Global seed set to 123 +[rank: 3] Global seed set to 123 +[rank: 2] Global seed set to 123 +[rank: 1] Global seed set to 123 +[2024-03-22 11:19:55,401][__main__][INFO] - Instantiating datamodule  +[2024-03-22 11:19:55,655][__main__][INFO] - Instantiating model  +/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/pytorch_lightning/utilities/parsing.py:263: UserWarning: Attribute 'loss' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['loss'])`. + rank_zero_warn( +[2024-03-22 11:19:56,175][__main__][INFO] - Instantiating callbacks... +[2024-03-22 11:19:56,176][src.utils.instantiators][INFO] - Instantiating callback  +[2024-03-22 11:19:56,415][src.utils.instantiators][INFO] - Instantiating callback  +[2024-03-22 11:19:56,428][src.utils.instantiators][INFO] - Instantiating callback  +[2024-03-22 11:19:56,475][src.utils.instantiators][INFO] - Instantiating callback  +[2024-03-22 11:20:06,129][src.utils.instantiators][INFO] - Instantiating callback  +[2024-03-22 11:20:06,136][src.utils.instantiators][INFO] - Instantiating callback  +[2024-03-22 11:20:06,137][src.utils.instantiators][INFO] - Instantiating callback  +[2024-03-22 11:20:06,137][src.utils.instantiators][INFO] - Instantiating callback  +[2024-03-22 11:20:06,138][__main__][INFO] - Instantiating loggers... +[2024-03-22 11:20:06,138][src.utils.instantiators][INFO] - Instantiating logger  +node rank 0, local rank 1, num processes 4 +Pero ara que ho hem canviat, tenim: 4 +node rank 0, local rank 3, num processes 4 +Pero ara que ho hem canviat, tenim: 4 +node rank 0, local rank 2, num processes 4 +Pero ara que ho hem canviat, tenim: 4 +[rank: 1] Global seed set to 123 +[rank: 2] Global seed set to 123 +[rank: 3] Global seed set to 123 +node rank 0, local rank 2, num processes 4 +Pero ara que ho hem canviat, tenim: 4 +node rank 0, local rank 1, num processes 4 +node rank 0, local rank 3, num processes 4 +Pero ara que ho hem canviat, tenim: 4 +Pero ara que ho hem canviat, tenim: 4 +Initializing distributed: GLOBAL_RANK: 2, MEMBER: 3/4 +Now we have: eu-lo-s4-080, 22894 +AQUI ES DONDE SE PARA Y FALLA: torch_distributed_backend nccl, rank 2, world_size 4, kwargs {'timeout': datetime.timedelta(seconds=1800)} +Initializing distributed: GLOBAL_RANK: 3, MEMBER: 4/4 +Now we have: eu-lo-s4-080, 22894 +AQUI ES DONDE SE PARA Y FALLA: torch_distributed_backend nccl, rank 3, world_size 4, kwargs {'timeout': datetime.timedelta(seconds=1800)} +Initializing distributed: GLOBAL_RANK: 1, MEMBER: 2/4 +Now we have: eu-lo-s4-080, 22894 +AQUI ES DONDE SE PARA Y FALLA: torch_distributed_backend nccl, rank 1, world_size 4, kwargs {'timeout': datetime.timedelta(seconds=1800)} +wandb: Currently logged in as: victor-jimenez-rodriguez (malvai). Use `wandb login --relogin` to force relogin +wandb: WARNING Path /cluster/project/jbuhmann/posterior_agreement/adv_pa/logs/victor/dg_wilds/multiruns/2024-03-22_11-19-54/0/dg/wandb/ wasn't writable, using system temp directory. +wandb: WARNING Path /cluster/project/jbuhmann/posterior_agreement/adv_pa/logs/victor/dg_wilds/multiruns/2024-03-22_11-19-54/0/dg/wandb/ wasn't writable, using system temp directory +wandb: - Waiting for wandb.init()... wandb: \ Waiting for wandb.init()... wandb: wandb version 0.16.4 is available! To upgrade, please run: +wandb: $ pip install wandb --upgrade +wandb: Tracking run with wandb version 0.15.5 +wandb: Run data is saved locally in /scratch/tmp.52217894.vjimenez/wandb/run-20240322_112022-s3lg9025 +wandb: Run `wandb offline` to turn off syncing. +wandb: Syncing run prova_configs_train_newconf_call +wandb: ⭐️ View project at https://wandb.ai/malvai/cov_pa +wandb: 🚀 View run at https://wandb.ai/malvai/cov_pa/runs/s3lg9025 +[2024-03-22 11:20:25,124][__main__][INFO] - Instantiating trainer  +node rank 0, local rank 0, num processes 4 +Pero ara que ho hem canviat, tenim: 4 +Trainer already configured with model summary callbacks: []. Skipping setting a default `ModelSummary` callback. +GPU available: True (cuda), used: True +TPU available: False, using: 0 TPU cores +IPU available: False, using: 0 IPUs +HPU available: False, using: 0 HPUs +`Trainer(limit_train_batches=1.0)` was configured so 100% of the batches per epoch will be used.. +`Trainer(limit_val_batches=1.0)` was configured so 100% of the batches will be used.. +INSIDE ACCELERATOR CONNECTOR: parallel devices [device(type='cuda', index=0), device(type='cuda', index=1), device(type='cuda', index=2), device(type='cuda', index=3)] +[2024-03-22 11:20:25,386][__main__][INFO] - Logging hyperparameters! +[2024-03-22 11:20:25,398][__main__][INFO] - Starting training! +[rank: 0] Global seed set to 123 +node rank 0, local rank 0, num processes 4 +Pero ara que ho hem canviat, tenim: 4 +Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/4 +Now we have: eu-lo-s4-080, 22894 +AQUI ES DONDE SE PARA Y FALLA: torch_distributed_backend nccl, rank 0, world_size 4, kwargs {'timeout': datetime.timedelta(seconds=1800)} +INSIDE ACCELERATOR CONNECTOR: parallel devices [device(type='cuda', index=0), device(type='cuda', index=1), device(type='cuda', index=2), device(type='cuda', index=3)] +[2024-03-22 11:20:25,681][torch.distributed.distributed_c10d][INFO] - Added key: store_based_barrier_key:1 to store for rank: 1 +INSIDE ACCELERATOR CONNECTOR: parallel devices [device(type='cuda', index=0), device(type='cuda', index=1), device(type='cuda', index=2), device(type='cuda', index=3)] +[2024-03-22 11:20:25,681][torch.distributed.distributed_c10d][INFO] - Added key: store_based_barrier_key:1 to store for rank: 2 +INSIDE ACCELERATOR CONNECTOR: parallel devices [device(type='cuda', index=0), device(type='cuda', index=1), device(type='cuda', index=2), device(type='cuda', index=3)] +[2024-03-22 11:20:25,681][torch.distributed.distributed_c10d][INFO] - Added key: store_based_barrier_key:1 to store for rank: 3 +[2024-03-22 11:20:25,683][torch.distributed.distributed_c10d][INFO] - Added key: store_based_barrier_key:1 to store for rank: 0 +[2024-03-22 11:20:25,683][torch.distributed.distributed_c10d][INFO] - Rank 0: Completed store-based barrier for key:store_based_barrier_key:1 with 4 nodes. +---------------------------------------------------------------------------------------------------- +distributed_backend=nccl +All distributed processes registered. Starting with 4 processes +---------------------------------------------------------------------------------------------------- + +[2024-03-22 11:20:25,691][torch.distributed.distributed_c10d][INFO] - Rank 1: Completed store-based barrier for key:store_based_barrier_key:1 with 4 nodes. +[2024-03-22 11:20:25,692][torch.distributed.distributed_c10d][INFO] - Rank 3: Completed store-based barrier for key:store_based_barrier_key:1 with 4 nodes. +[2024-03-22 11:20:25,692][torch.distributed.distributed_c10d][INFO] - Rank 2: Completed store-based barrier for key:store_based_barrier_key:1 with 4 nodes. +CUDA_VISIBLE_DEVICES: 0,1,2,3 +CUDA_VISIBLE_DEVICES: 0,1,2,3 +CUDA_VISIBLE_DEVICES: 0,1,2,3 +LOCAL_RANK: 2 - CUDA_VISIBLE_DEVICES: [0,1,2,3] +LOCAL_RANK: 1 - CUDA_VISIBLE_DEVICES: [0,1,2,3] +LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3] +CUDA_VISIBLE_DEVICES: 0,1,2,3 +LOCAL_RANK: 3 - CUDA_VISIBLE_DEVICES: [0,1,2,3] +┏━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━┓ +┃ ┃ Name ┃ Type ┃ Params ┃ +┡━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━┩ +│ 0 │ model │ DGBackbo… │ 7.6 M │ +│ 1 │ model.net │ DenseNet │ 7.6 M │ +│ 2 │ model.net.features │ Sequenti… │ 7.0 M │ +│ 3 │ model.net.features.conv0 │ Conv2d │ 9.4 K │ +│ 4 │ model.net.features.norm0 │ SyncBatc… │ 128 │ +│ 5 │ model.net.features.relu0 │ ReLU │ 0 │ +│ 6 │ model.net.features.pool0 │ MaxPool2d │ 0 │ +│ 7 │ model.net.features.denseblock1 │ _DenseBl… │ 335 K │ +│ 8 │ model.net.features.denseblock1.denselayer1 │ _DenseLa… │ 45.4 K │ +│ 9 │ model.net.features.denseblock1.denselayer1.norm1 │ SyncBatc… │ 128 │ +│ 10 │ model.net.features.denseblock1.denselayer1.relu1 │ ReLU │ 0 │ +│ 11 │ model.net.features.denseblock1.denselayer1.conv1 │ Conv2d │ 8.2 K │ +│ 12 │ model.net.features.denseblock1.denselayer1.norm2 │ SyncBatc… │ 256 │ +│ 13 │ model.net.features.denseblock1.denselayer1.relu2 │ ReLU │ 0 │ +│ 14 │ model.net.features.denseblock1.denselayer1.conv2 │ Conv2d │ 36.9 K │ +│ 15 │ model.net.features.denseblock1.denselayer2 │ _DenseLa… │ 49.6 K │ +│ 16 │ model.net.features.denseblock1.denselayer2.norm1 │ SyncBatc… │ 192 │ +│ 17 │ model.net.features.denseblock1.denselayer2.relu1 │ ReLU │ 0 │ +│ 18 │ model.net.features.denseblock1.denselayer2.conv1 │ Conv2d │ 12.3 K │ +│ 19 │ model.net.features.denseblock1.denselayer2.norm2 │ SyncBatc… │ 256 │ +│ 20 │ model.net.features.denseblock1.denselayer2.relu2 │ ReLU │ 0 │ +│ 21 │ model.net.features.denseblock1.denselayer2.conv2 │ Conv2d │ 36.9 K │ +│ 22 │ model.net.features.denseblock1.denselayer3 │ _DenseLa… │ 53.8 K │ +│ 23 │ model.net.features.denseblock1.denselayer3.norm1 │ SyncBatc… │ 256 │ +│ 24 │ model.net.features.denseblock1.denselayer3.relu1 │ ReLU │ 0 │ +│ 25 │ model.net.features.denseblock1.denselayer3.conv1 │ Conv2d │ 16.4 K │ +│ 26 │ model.net.features.denseblock1.denselayer3.norm2 │ SyncBatc… │ 256 │ +│ 27 │ model.net.features.denseblock1.denselayer3.relu2 │ ReLU │ 0 │ +│ 28 │ model.net.features.denseblock1.denselayer3.conv2 │ Conv2d │ 36.9 K │ +│ 29 │ model.net.features.denseblock1.denselayer4 │ _DenseLa… │ 57.9 K │ +│ 30 │ model.net.features.denseblock1.denselayer4.norm1 │ SyncBatc… │ 320 │ +│ 31 │ model.net.features.denseblock1.denselayer4.relu1 │ ReLU │ 0 │ +│ 32 │ model.net.features.denseblock1.denselayer4.conv1 │ Conv2d │ 20.5 K │ +│ 33 │ model.net.features.denseblock1.denselayer4.norm2 │ SyncBatc… │ 256 │ +│ 34 │ model.net.features.denseblock1.denselayer4.relu2 │ ReLU │ 0 │ +│ 35 │ model.net.features.denseblock1.denselayer4.conv2 │ Conv2d │ 36.9 K │ +│ 36 │ model.net.features.denseblock1.denselayer5 │ _DenseLa… │ 62.1 K │ +│ 37 │ model.net.features.denseblock1.denselayer5.norm1 │ SyncBatc… │ 384 │ +│ 38 │ model.net.features.denseblock1.denselayer5.relu1 │ ReLU │ 0 │ +│ 39 │ model.net.features.denseblock1.denselayer5.conv1 │ Conv2d │ 24.6 K │ +│ 40 │ model.net.features.denseblock1.denselayer5.norm2 │ SyncBatc… │ 256 │ +│ 41 │ model.net.features.denseblock1.denselayer5.relu2 │ ReLU │ 0 │ +│ 42 │ model.net.features.denseblock1.denselayer5.conv2 │ Conv2d │ 36.9 K │ +│ 43 │ model.net.features.denseblock1.denselayer6 │ _DenseLa… │ 66.2 K │ +│ 44 │ model.net.features.denseblock1.denselayer6.norm1 │ SyncBatc… │ 448 │ +│ 45 │ model.net.features.denseblock1.denselayer6.relu1 │ ReLU │ 0 │ +│ 46 │ model.net.features.denseblock1.denselayer6.conv1 │ Conv2d │ 28.7 K │ +│ 47 │ model.net.features.denseblock1.denselayer6.norm2 │ SyncBatc… │ 256 │ +│ 48 │ model.net.features.denseblock1.denselayer6.relu2 │ ReLU │ 0 │ +│ 49 │ model.net.features.denseblock1.denselayer6.conv2 │ Conv2d │ 36.9 K │ +│ 50 │ model.net.features.transition1 │ _Transit… │ 33.3 K │ +│ 51 │ model.net.features.transition1.norm │ SyncBatc… │ 512 │ +│ 52 │ model.net.features.transition1.relu │ ReLU │ 0 │ +│ 53 │ model.net.features.transition1.conv │ Conv2d │ 32.8 K │ +│ 54 │ model.net.features.transition1.pool │ AvgPool2d │ 0 │ +│ 55 │ model.net.features.denseblock2 │ _DenseBl… │ 919 K │ +│ 56 │ model.net.features.denseblock2.denselayer1 │ _DenseLa… │ 53.8 K │ +│ 57 │ model.net.features.denseblock2.denselayer1.norm1 │ SyncBatc… │ 256 │ +│ 58 │ model.net.features.denseblock2.denselayer1.relu1 │ ReLU │ 0 │ +│ 59 │ model.net.features.denseblock2.denselayer1.conv1 │ Conv2d │ 16.4 K │ +│ 60 │ model.net.features.denseblock2.denselayer1.norm2 │ SyncBatc… │ 256 │ +│ 61 │ model.net.features.denseblock2.denselayer1.relu2 │ ReLU │ 0 │ +│ 62 │ model.net.features.denseblock2.denselayer1.conv2 │ Conv2d │ 36.9 K │ +│ 63 │ model.net.features.denseblock2.denselayer2 │ _DenseLa… │ 57.9 K │ +│ 64 │ model.net.features.denseblock2.denselayer2.norm1 │ SyncBatc… │ 320 │ +│ 65 │ model.net.features.denseblock2.denselayer2.relu1 │ ReLU │ 0 │ +│ 66 │ model.net.features.denseblock2.denselayer2.conv1 │ Conv2d │ 20.5 K │ +│ 67 │ model.net.features.denseblock2.denselayer2.norm2 │ SyncBatc… │ 256 │ +│ 68 │ model.net.features.denseblock2.denselayer2.relu2 │ ReLU │ 0 │ +│ 69 │ model.net.features.denseblock2.denselayer2.conv2 │ Conv2d │ 36.9 K │ +│ 70 │ model.net.features.denseblock2.denselayer3 │ _DenseLa… │ 62.1 K │ +│ 71 │ model.net.features.denseblock2.denselayer3.norm1 │ SyncBatc… │ 384 │ +│ 72 │ model.net.features.denseblock2.denselayer3.relu1 │ ReLU │ 0 │ +│ 73 │ model.net.features.denseblock2.denselayer3.conv1 │ Conv2d │ 24.6 K │ +│ 74 │ model.net.features.denseblock2.denselayer3.norm2 │ SyncBatc… │ 256 │ +│ 75 │ model.net.features.denseblock2.denselayer3.relu2 │ ReLU │ 0 │ +│ 76 │ model.net.features.denseblock2.denselayer3.conv2 │ Conv2d │ 36.9 K │ +│ 77 │ model.net.features.denseblock2.denselayer4 │ _DenseLa… │ 66.2 K │ +│ 78 │ model.net.features.denseblock2.denselayer4.norm1 │ SyncBatc… │ 448 │ +│ 79 │ model.net.features.denseblock2.denselayer4.relu1 │ ReLU │ 0 │ +│ 80 │ model.net.features.denseblock2.denselayer4.conv1 │ Conv2d │ 28.7 K │ +│ 81 │ model.net.features.denseblock2.denselayer4.norm2 │ SyncBatc… │ 256 │ +│ 82 │ model.net.features.denseblock2.denselayer4.relu2 │ ReLU │ 0 │ +│ 83 │ model.net.features.denseblock2.denselayer4.conv2 │ Conv2d │ 36.9 K │ +│ 84 │ model.net.features.denseblock2.denselayer5 │ _DenseLa… │ 70.4 K │ +│ 85 │ model.net.features.denseblock2.denselayer5.norm1 │ SyncBatc… │ 512 │ +│ 86 │ model.net.features.denseblock2.denselayer5.relu1 │ ReLU │ 0 │ +│ 87 │ model.net.features.denseblock2.denselayer5.conv1 │ Conv2d │ 32.8 K │ +│ 88 │ model.net.features.denseblock2.denselayer5.norm2 │ SyncBatc… │ 256 │ +│ 89 │ model.net.features.denseblock2.denselayer5.relu2 │ ReLU │ 0 │ +│ 90 │ model.net.features.denseblock2.denselayer5.conv2 │ Conv2d │ 36.9 K │ +│ 91 │ model.net.features.denseblock2.denselayer6 │ _DenseLa… │ 74.6 K │ +│ 92 │ model.net.features.denseblock2.denselayer6.norm1 │ SyncBatc… │ 576 │ +│ 93 │ model.net.features.denseblock2.denselayer6.relu1 │ ReLU │ 0 │ +│ 94 │ model.net.features.denseblock2.denselayer6.conv1 │ Conv2d │ 36.9 K │ +│ 95 │ model.net.features.denseblock2.denselayer6.norm2 │ SyncBatc… │ 256 │ +│ 96 │ model.net.features.denseblock2.denselayer6.relu2 │ ReLU │ 0 │ +│ 97 │ model.net.features.denseblock2.denselayer6.conv2 │ Conv2d │ 36.9 K │ +│ 98 │ model.net.features.denseblock2.denselayer7 │ _DenseLa… │ 78.7 K │ +│ 99 │ model.net.features.denseblock2.denselayer7.norm1 │ SyncBatc… │ 640 │ +│ 100 │ model.net.features.denseblock2.denselayer7.relu1 │ ReLU │ 0 │ +│ 101 │ model.net.features.denseblock2.denselayer7.conv1 │ Conv2d │ 41.0 K │ +│ 102 │ model.net.features.denseblock2.denselayer7.norm2 │ SyncBatc… │ 256 │ +│ 103 │ model.net.features.denseblock2.denselayer7.relu2 │ ReLU │ 0 │ +│ 104 │ model.net.features.denseblock2.denselayer7.conv2 │ Conv2d │ 36.9 K │ +│ 105 │ model.net.features.denseblock2.denselayer8 │ _DenseLa… │ 82.9 K │ +│ 106 │ model.net.features.denseblock2.denselayer8.norm1 │ SyncBatc… │ 704 │ +│ 107 │ model.net.features.denseblock2.denselayer8.relu1 │ ReLU │ 0 │ +│ 108 │ model.net.features.denseblock2.denselayer8.conv1 │ Conv2d │ 45.1 K │ +│ 109 │ model.net.features.denseblock2.denselayer8.norm2 │ SyncBatc… │ 256 │ +│ 110 │ model.net.features.denseblock2.denselayer8.relu2 │ ReLU │ 0 │ +│ 111 │ model.net.features.denseblock2.denselayer8.conv2 │ Conv2d │ 36.9 K │ +│ 112 │ model.net.features.denseblock2.denselayer9 │ _DenseLa… │ 87.0 K │ +│ 113 │ model.net.features.denseblock2.denselayer9.norm1 │ SyncBatc… │ 768 │ +│ 114 │ model.net.features.denseblock2.denselayer9.relu1 │ ReLU │ 0 │ +│ 115 │ model.net.features.denseblock2.denselayer9.conv1 │ Conv2d │ 49.2 K │ +│ 116 │ model.net.features.denseblock2.denselayer9.norm2 │ SyncBatc… │ 256 │ +│ 117 │ model.net.features.denseblock2.denselayer9.relu2 │ ReLU │ 0 │ +│ 118 │ model.net.features.denseblock2.denselayer9.conv2 │ Conv2d │ 36.9 K │ +│ 119 │ model.net.features.denseblock2.denselayer10 │ _DenseLa… │ 91.2 K │ +│ 120 │ model.net.features.denseblock2.denselayer10.norm1 │ SyncBatc… │ 832 │ +│ 121 │ model.net.features.denseblock2.denselayer10.relu1 │ ReLU │ 0 │ +│ 122 │ model.net.features.denseblock2.denselayer10.conv1 │ Conv2d │ 53.2 K │ +│ 123 │ model.net.features.denseblock2.denselayer10.norm2 │ SyncBatc… │ 256 │ +│ 124 │ model.net.features.denseblock2.denselayer10.relu2 │ ReLU │ 0 │ +│ 125 │ model.net.features.denseblock2.denselayer10.conv2 │ Conv2d │ 36.9 K │ +│ 126 │ model.net.features.denseblock2.denselayer11 │ _DenseLa… │ 95.4 K │ +│ 127 │ model.net.features.denseblock2.denselayer11.norm1 │ SyncBatc… │ 896 │ +│ 128 │ model.net.features.denseblock2.denselayer11.relu1 │ ReLU │ 0 │ +│ 129 │ model.net.features.denseblock2.denselayer11.conv1 │ Conv2d │ 57.3 K │ +│ 130 │ model.net.features.denseblock2.denselayer11.norm2 │ SyncBatc… │ 256 │ +│ 131 │ model.net.features.denseblock2.denselayer11.relu2 │ ReLU │ 0 │ +│ 132 │ model.net.features.denseblock2.denselayer11.conv2 │ Conv2d │ 36.9 K │ +│ 133 │ model.net.features.denseblock2.denselayer12 │ _DenseLa… │ 99.5 K │ +│ 134 │ model.net.features.denseblock2.denselayer12.norm1 │ SyncBatc… │ 960 │ +│ 135 │ model.net.features.denseblock2.denselayer12.relu1 │ ReLU │ 0 │ +│ 136 │ model.net.features.denseblock2.denselayer12.conv1 │ Conv2d │ 61.4 K │ +│ 137 │ model.net.features.denseblock2.denselayer12.norm2 │ SyncBatc… │ 256 │ +│ 138 │ model.net.features.denseblock2.denselayer12.relu2 │ ReLU │ 0 │ +│ 139 │ model.net.features.denseblock2.denselayer12.conv2 │ Conv2d │ 36.9 K │ +│ 140 │ model.net.features.transition2 │ _Transit… │ 132 K │ +│ 141 │ model.net.features.transition2.norm │ SyncBatc… │ 1.0 K │ +│ 142 │ model.net.features.transition2.relu │ ReLU │ 0 │ +│ 143 │ model.net.features.transition2.conv │ Conv2d │ 131 K │ +│ 144 │ model.net.features.transition2.pool │ AvgPool2d │ 0 │ +│ 145 │ model.net.features.denseblock3 │ _DenseBl… │ 2.8 M │ +│ 146 │ model.net.features.denseblock3.denselayer1 │ _DenseLa… │ 70.4 K │ +│ 147 │ model.net.features.denseblock3.denselayer1.norm1 │ SyncBatc… │ 512 │ +│ 148 │ model.net.features.denseblock3.denselayer1.relu1 │ ReLU │ 0 │ +│ 149 │ model.net.features.denseblock3.denselayer1.conv1 │ Conv2d │ 32.8 K │ +│ 150 │ model.net.features.denseblock3.denselayer1.norm2 │ SyncBatc… │ 256 │ +│ 151 │ model.net.features.denseblock3.denselayer1.relu2 │ ReLU │ 0 │ +│ 152 │ model.net.features.denseblock3.denselayer1.conv2 │ Conv2d │ 36.9 K │ +│ 153 │ model.net.features.denseblock3.denselayer2 │ _DenseLa… │ 74.6 K │ +│ 154 │ model.net.features.denseblock3.denselayer2.norm1 │ SyncBatc… │ 576 │ +│ 155 │ model.net.features.denseblock3.denselayer2.relu1 │ ReLU │ 0 │ +│ 156 │ model.net.features.denseblock3.denselayer2.conv1 │ Conv2d │ 36.9 K │ +│ 157 │ model.net.features.denseblock3.denselayer2.norm2 │ SyncBatc… │ 256 │ +│ 158 │ model.net.features.denseblock3.denselayer2.relu2 │ ReLU │ 0 │ +│ 159 │ model.net.features.denseblock3.denselayer2.conv2 │ Conv2d │ 36.9 K │ +│ 160 │ model.net.features.denseblock3.denselayer3 │ _DenseLa… │ 78.7 K │ +│ 161 │ model.net.features.denseblock3.denselayer3.norm1 │ SyncBatc… │ 640 │ +│ 162 │ model.net.features.denseblock3.denselayer3.relu1 │ ReLU │ 0 │ +│ 163 │ model.net.features.denseblock3.denselayer3.conv1 │ Conv2d │ 41.0 K │ +│ 164 │ model.net.features.denseblock3.denselayer3.norm2 │ SyncBatc… │ 256 │ +│ 165 │ model.net.features.denseblock3.denselayer3.relu2 │ ReLU │ 0 │ +│ 166 │ model.net.features.denseblock3.denselayer3.conv2 │ Conv2d │ 36.9 K │ +│ 167 │ model.net.features.denseblock3.denselayer4 │ _DenseLa… │ 82.9 K │ +│ 168 │ model.net.features.denseblock3.denselayer4.norm1 │ SyncBatc… │ 704 │ +│ 169 │ model.net.features.denseblock3.denselayer4.relu1 │ ReLU │ 0 │ +│ 170 │ model.net.features.denseblock3.denselayer4.conv1 │ Conv2d │ 45.1 K │ +│ 171 │ model.net.features.denseblock3.denselayer4.norm2 │ SyncBatc… │ 256 │ +│ 172 │ model.net.features.denseblock3.denselayer4.relu2 │ ReLU │ 0 │ +│ 173 │ model.net.features.denseblock3.denselayer4.conv2 │ Conv2d │ 36.9 K │ +│ 174 │ model.net.features.denseblock3.denselayer5 │ _DenseLa… │ 87.0 K │ +│ 175 │ model.net.features.denseblock3.denselayer5.norm1 │ SyncBatc… │ 768 │ +│ 176 │ model.net.features.denseblock3.denselayer5.relu1 │ ReLU │ 0 │ +│ 177 │ model.net.features.denseblock3.denselayer5.conv1 │ Conv2d │ 49.2 K │ +│ 178 │ model.net.features.denseblock3.denselayer5.norm2 │ SyncBatc… │ 256 │ +│ 179 │ model.net.features.denseblock3.denselayer5.relu2 │ ReLU │ 0 │ +│ 180 │ model.net.features.denseblock3.denselayer5.conv2 │ Conv2d │ 36.9 K │ +│ 181 │ model.net.features.denseblock3.denselayer6 │ _DenseLa… │ 91.2 K │ +│ 182 │ model.net.features.denseblock3.denselayer6.norm1 │ SyncBatc… │ 832 │ +│ 183 │ model.net.features.denseblock3.denselayer6.relu1 │ ReLU │ 0 │ +│ 184 │ model.net.features.denseblock3.denselayer6.conv1 │ Conv2d │ 53.2 K │ +│ 185 │ model.net.features.denseblock3.denselayer6.norm2 │ SyncBatc… │ 256 │ +│ 186 │ model.net.features.denseblock3.denselayer6.relu2 │ ReLU │ 0 │ +│ 187 │ model.net.features.denseblock3.denselayer6.conv2 │ Conv2d │ 36.9 K │ +│ 188 │ model.net.features.denseblock3.denselayer7 │ _DenseLa… │ 95.4 K │ +│ 189 │ model.net.features.denseblock3.denselayer7.norm1 │ SyncBatc… │ 896 │ +│ 190 │ model.net.features.denseblock3.denselayer7.relu1 │ ReLU │ 0 │ +│ 191 │ model.net.features.denseblock3.denselayer7.conv1 │ Conv2d │ 57.3 K │ +│ 192 │ model.net.features.denseblock3.denselayer7.norm2 │ SyncBatc… │ 256 │ +│ 193 │ model.net.features.denseblock3.denselayer7.relu2 │ ReLU │ 0 │ +│ 194 │ model.net.features.denseblock3.denselayer7.conv2 │ Conv2d │ 36.9 K │ +│ 195 │ model.net.features.denseblock3.denselayer8 │ _DenseLa… │ 99.5 K │ +│ 196 │ model.net.features.denseblock3.denselayer8.norm1 │ SyncBatc… │ 960 │ +│ 197 │ model.net.features.denseblock3.denselayer8.relu1 │ ReLU │ 0 │ +│ 198 │ model.net.features.denseblock3.denselayer8.conv1 │ Conv2d │ 61.4 K │ +│ 199 │ model.net.features.denseblock3.denselayer8.norm2 │ SyncBatc… │ 256 │ +│ 200 │ model.net.features.denseblock3.denselayer8.relu2 │ ReLU │ 0 │ +│ 201 │ model.net.features.denseblock3.denselayer8.conv2 │ Conv2d │ 36.9 K │ +│ 202 │ model.net.features.denseblock3.denselayer9 │ _DenseLa… │ 103 K │ +│ 203 │ model.net.features.denseblock3.denselayer9.norm1 │ SyncBatc… │ 1.0 K │ +│ 204 │ model.net.features.denseblock3.denselayer9.relu1 │ ReLU │ 0 │ +│ 205 │ model.net.features.denseblock3.denselayer9.conv1 │ Conv2d │ 65.5 K │ +│ 206 │ model.net.features.denseblock3.denselayer9.norm2 │ SyncBatc… │ 256 │ +│ 207 │ model.net.features.denseblock3.denselayer9.relu2 │ ReLU │ 0 │ +│ 208 │ model.net.features.denseblock3.denselayer9.conv2 │ Conv2d │ 36.9 K │ +│ 209 │ model.net.features.denseblock3.denselayer10 │ _DenseLa… │ 107 K │ +│ 210 │ model.net.features.denseblock3.denselayer10.norm1 │ SyncBatc… │ 1.1 K │ +│ 211 │ model.net.features.denseblock3.denselayer10.relu1 │ ReLU │ 0 │ +│ 212 │ model.net.features.denseblock3.denselayer10.conv1 │ Conv2d │ 69.6 K │ +│ 213 │ model.net.features.denseblock3.denselayer10.norm2 │ SyncBatc… │ 256 │ +│ 214 │ model.net.features.denseblock3.denselayer10.relu2 │ ReLU │ 0 │ +│ 215 │ model.net.features.denseblock3.denselayer10.conv2 │ Conv2d │ 36.9 K │ +│ 216 │ model.net.features.denseblock3.denselayer11 │ _DenseLa… │ 112 K │ +│ 217 │ model.net.features.denseblock3.denselayer11.norm1 │ SyncBatc… │ 1.2 K │ +│ 218 │ model.net.features.denseblock3.denselayer11.relu1 │ ReLU │ 0 │ +│ 219 │ model.net.features.denseblock3.denselayer11.conv1 │ Conv2d │ 73.7 K │ +│ 220 │ model.net.features.denseblock3.denselayer11.norm2 │ SyncBatc… │ 256 │ +│ 221 │ model.net.features.denseblock3.denselayer11.relu2 │ ReLU │ 0 │ +│ 222 │ model.net.features.denseblock3.denselayer11.conv2 │ Conv2d │ 36.9 K │ +│ 223 │ model.net.features.denseblock3.denselayer12 │ _DenseLa… │ 116 K │ +│ 224 │ model.net.features.denseblock3.denselayer12.norm1 │ SyncBatc… │ 1.2 K │ +│ 225 │ model.net.features.denseblock3.denselayer12.relu1 │ ReLU │ 0 │ +│ 226 │ model.net.features.denseblock3.denselayer12.conv1 │ Conv2d │ 77.8 K │ +│ 227 │ model.net.features.denseblock3.denselayer12.norm2 │ SyncBatc… │ 256 │ +│ 228 │ model.net.features.denseblock3.denselayer12.relu2 │ ReLU │ 0 │ +│ 229 │ model.net.features.denseblock3.denselayer12.conv2 │ Conv2d │ 36.9 K │ +│ 230 │ model.net.features.denseblock3.denselayer13 │ _DenseLa… │ 120 K │ +│ 231 │ model.net.features.denseblock3.denselayer13.norm1 │ SyncBatc… │ 1.3 K │ +│ 232 │ model.net.features.denseblock3.denselayer13.relu1 │ ReLU │ 0 │ +│ 233 │ model.net.features.denseblock3.denselayer13.conv1 │ Conv2d │ 81.9 K │ +│ 234 │ model.net.features.denseblock3.denselayer13.norm2 │ SyncBatc… │ 256 │ +│ 235 │ model.net.features.denseblock3.denselayer13.relu2 │ ReLU │ 0 │ +│ 236 │ model.net.features.denseblock3.denselayer13.conv2 │ Conv2d │ 36.9 K │ +│ 237 │ model.net.features.denseblock3.denselayer14 │ _DenseLa… │ 124 K │ +│ 238 │ model.net.features.denseblock3.denselayer14.norm1 │ SyncBatc… │ 1.3 K │ +│ 239 │ model.net.features.denseblock3.denselayer14.relu1 │ ReLU │ 0 │ +│ 240 │ model.net.features.denseblock3.denselayer14.conv1 │ Conv2d │ 86.0 K │ +│ 241 │ model.net.features.denseblock3.denselayer14.norm2 │ SyncBatc… │ 256 │ +│ 242 │ model.net.features.denseblock3.denselayer14.relu2 │ ReLU │ 0 │ +│ 243 │ model.net.features.denseblock3.denselayer14.conv2 │ Conv2d │ 36.9 K │ +│ 244 │ model.net.features.denseblock3.denselayer15 │ _DenseLa… │ 128 K │ +│ 245 │ model.net.features.denseblock3.denselayer15.norm1 │ SyncBatc… │ 1.4 K │ +│ 246 │ model.net.features.denseblock3.denselayer15.relu1 │ ReLU │ 0 │ +│ 247 │ model.net.features.denseblock3.denselayer15.conv1 │ Conv2d │ 90.1 K │ +│ 248 │ model.net.features.denseblock3.denselayer15.norm2 │ SyncBatc… │ 256 │ +│ 249 │ model.net.features.denseblock3.denselayer15.relu2 │ ReLU │ 0 │ +│ 250 │ model.net.features.denseblock3.denselayer15.conv2 │ Conv2d │ 36.9 K │ +│ 251 │ model.net.features.denseblock3.denselayer16 │ _DenseLa… │ 132 K │ +│ 252 │ model.net.features.denseblock3.denselayer16.norm1 │ SyncBatc… │ 1.5 K │ +│ 253 │ model.net.features.denseblock3.denselayer16.relu1 │ ReLU │ 0 │ +│ 254 │ model.net.features.denseblock3.denselayer16.conv1 │ Conv2d │ 94.2 K │ +│ 255 │ model.net.features.denseblock3.denselayer16.norm2 │ SyncBatc… │ 256 │ +│ 256 │ model.net.features.denseblock3.denselayer16.relu2 │ ReLU │ 0 │ +│ 257 │ model.net.features.denseblock3.denselayer16.conv2 │ Conv2d │ 36.9 K │ +│ 258 │ model.net.features.denseblock3.denselayer17 │ _DenseLa… │ 136 K │ +│ 259 │ model.net.features.denseblock3.denselayer17.norm1 │ SyncBatc… │ 1.5 K │ +│ 260 │ model.net.features.denseblock3.denselayer17.relu1 │ ReLU │ 0 │ +│ 261 │ model.net.features.denseblock3.denselayer17.conv1 │ Conv2d │ 98.3 K │ +│ 262 │ model.net.features.denseblock3.denselayer17.norm2 │ SyncBatc… │ 256 │ +│ 263 │ model.net.features.denseblock3.denselayer17.relu2 │ ReLU │ 0 │ +│ 264 │ model.net.features.denseblock3.denselayer17.conv2 │ Conv2d │ 36.9 K │ +│ 265 │ model.net.features.denseblock3.denselayer18 │ _DenseLa… │ 141 K │ +│ 266 │ model.net.features.denseblock3.denselayer18.norm1 │ SyncBatc… │ 1.6 K │ +│ 267 │ model.net.features.denseblock3.denselayer18.relu1 │ ReLU │ 0 │ +│ 268 │ model.net.features.denseblock3.denselayer18.conv1 │ Conv2d │ 102 K │ +│ 269 │ model.net.features.denseblock3.denselayer18.norm2 │ SyncBatc… │ 256 │ +│ 270 │ model.net.features.denseblock3.denselayer18.relu2 │ ReLU │ 0 │ +│ 271 │ model.net.features.denseblock3.denselayer18.conv2 │ Conv2d │ 36.9 K │ +│ 272 │ model.net.features.denseblock3.denselayer19 │ _DenseLa… │ 145 K │ +│ 273 │ model.net.features.denseblock3.denselayer19.norm1 │ SyncBatc… │ 1.7 K │ +│ 274 │ model.net.features.denseblock3.denselayer19.relu1 │ ReLU │ 0 │ +│ 275 │ model.net.features.denseblock3.denselayer19.conv1 │ Conv2d │ 106 K │ +│ 276 │ model.net.features.denseblock3.denselayer19.norm2 │ SyncBatc… │ 256 │ +│ 277 │ model.net.features.denseblock3.denselayer19.relu2 │ ReLU │ 0 │ +│ 278 │ model.net.features.denseblock3.denselayer19.conv2 │ Conv2d │ 36.9 K │ +│ 279 │ model.net.features.denseblock3.denselayer20 │ _DenseLa… │ 149 K │ +│ 280 │ model.net.features.denseblock3.denselayer20.norm1 │ SyncBatc… │ 1.7 K │ +│ 281 │ model.net.features.denseblock3.denselayer20.relu1 │ ReLU │ 0 │ +│ 282 │ model.net.features.denseblock3.denselayer20.conv1 │ Conv2d │ 110 K │ +│ 283 │ model.net.features.denseblock3.denselayer20.norm2 │ SyncBatc… │ 256 │ +│ 284 │ model.net.features.denseblock3.denselayer20.relu2 │ ReLU │ 0 │ +│ 285 │ model.net.features.denseblock3.denselayer20.conv2 │ Conv2d │ 36.9 K │ +│ 286 │ model.net.features.denseblock3.denselayer21 │ _DenseLa… │ 153 K │ +│ 287 │ model.net.features.denseblock3.denselayer21.norm1 │ SyncBatc… │ 1.8 K │ +│ 288 │ model.net.features.denseblock3.denselayer21.relu1 │ ReLU │ 0 │ +│ 289 │ model.net.features.denseblock3.denselayer21.conv1 │ Conv2d │ 114 K │ +│ 290 │ model.net.features.denseblock3.denselayer21.norm2 │ SyncBatc… │ 256 │ +│ 291 │ model.net.features.denseblock3.denselayer21.relu2 │ ReLU │ 0 │ +│ 292 │ model.net.features.denseblock3.denselayer21.conv2 │ Conv2d │ 36.9 K │ +│ 293 │ model.net.features.denseblock3.denselayer22 │ _DenseLa… │ 157 K │ +│ 294 │ model.net.features.denseblock3.denselayer22.norm1 │ SyncBatc… │ 1.9 K │ +│ 295 │ model.net.features.denseblock3.denselayer22.relu1 │ ReLU │ 0 │ +│ 296 │ model.net.features.denseblock3.denselayer22.conv1 │ Conv2d │ 118 K │ +│ 297 │ model.net.features.denseblock3.denselayer22.norm2 │ SyncBatc… │ 256 │ +│ 298 │ model.net.features.denseblock3.denselayer22.relu2 │ ReLU │ 0 │ +│ 299 │ model.net.features.denseblock3.denselayer22.conv2 │ Conv2d │ 36.9 K │ +│ 300 │ model.net.features.denseblock3.denselayer23 │ _DenseLa… │ 161 K │ +│ 301 │ model.net.features.denseblock3.denselayer23.norm1 │ SyncBatc… │ 1.9 K │ +│ 302 │ model.net.features.denseblock3.denselayer23.relu1 │ ReLU │ 0 │ +│ 303 │ model.net.features.denseblock3.denselayer23.conv1 │ Conv2d │ 122 K │ +│ 304 │ model.net.features.denseblock3.denselayer23.norm2 │ SyncBatc… │ 256 │ +│ 305 │ model.net.features.denseblock3.denselayer23.relu2 │ ReLU │ 0 │ +│ 306 │ model.net.features.denseblock3.denselayer23.conv2 │ Conv2d │ 36.9 K │ +│ 307 │ model.net.features.denseblock3.denselayer24 │ _DenseLa… │ 166 K │ +│ 308 │ model.net.features.denseblock3.denselayer24.norm1 │ SyncBatc… │ 2.0 K │ +│ 309 │ model.net.features.denseblock3.denselayer24.relu1 │ ReLU │ 0 │ +│ 310 │ model.net.features.denseblock3.denselayer24.conv1 │ Conv2d │ 126 K │ +│ 311 │ model.net.features.denseblock3.denselayer24.norm2 │ SyncBatc… │ 256 │ +│ 312 │ model.net.features.denseblock3.denselayer24.relu2 │ ReLU │ 0 │ +│ 313 │ model.net.features.denseblock3.denselayer24.conv2 │ Conv2d │ 36.9 K │ +│ 314 │ model.net.features.transition3 │ _Transit… │ 526 K │ +│ 315 │ model.net.features.transition3.norm │ SyncBatc… │ 2.0 K │ +│ 316 │ model.net.features.transition3.relu │ ReLU │ 0 │ +│ 317 │ model.net.features.transition3.conv │ Conv2d │ 524 K │ +│ 318 │ model.net.features.transition3.pool │ AvgPool2d │ 0 │ +│ 319 │ model.net.features.denseblock4 │ _DenseBl… │ 2.2 M │ +│ 320 │ model.net.features.denseblock4.denselayer1 │ _DenseLa… │ 103 K │ +│ 321 │ model.net.features.denseblock4.denselayer1.norm1 │ SyncBatc… │ 1.0 K │ +│ 322 │ model.net.features.denseblock4.denselayer1.relu1 │ ReLU │ 0 │ +│ 323 │ model.net.features.denseblock4.denselayer1.conv1 │ Conv2d │ 65.5 K │ +│ 324 │ model.net.features.denseblock4.denselayer1.norm2 │ SyncBatc… │ 256 │ +│ 325 │ model.net.features.denseblock4.denselayer1.relu2 │ ReLU │ 0 │ +│ 326 │ model.net.features.denseblock4.denselayer1.conv2 │ Conv2d │ 36.9 K │ +│ 327 │ model.net.features.denseblock4.denselayer2 │ _DenseLa… │ 107 K │ +│ 328 │ model.net.features.denseblock4.denselayer2.norm1 │ SyncBatc… │ 1.1 K │ +│ 329 │ model.net.features.denseblock4.denselayer2.relu1 │ ReLU │ 0 │ +│ 330 │ model.net.features.denseblock4.denselayer2.conv1 │ Conv2d │ 69.6 K │ +│ 331 │ model.net.features.denseblock4.denselayer2.norm2 │ SyncBatc… │ 256 │ +│ 332 │ model.net.features.denseblock4.denselayer2.relu2 │ ReLU │ 0 │ +│ 333 │ model.net.features.denseblock4.denselayer2.conv2 │ Conv2d │ 36.9 K │ +│ 334 │ model.net.features.denseblock4.denselayer3 │ _DenseLa… │ 112 K │ +│ 335 │ model.net.features.denseblock4.denselayer3.norm1 │ SyncBatc… │ 1.2 K │ +│ 336 │ model.net.features.denseblock4.denselayer3.relu1 │ ReLU │ 0 │ +│ 337 │ model.net.features.denseblock4.denselayer3.conv1 │ Conv2d │ 73.7 K │ +│ 338 │ model.net.features.denseblock4.denselayer3.norm2 │ SyncBatc… │ 256 │ +│ 339 │ model.net.features.denseblock4.denselayer3.relu2 │ ReLU │ 0 │ +│ 340 │ model.net.features.denseblock4.denselayer3.conv2 │ Conv2d │ 36.9 K │ +│ 341 │ model.net.features.denseblock4.denselayer4 │ _DenseLa… │ 116 K │ +│ 342 │ model.net.features.denseblock4.denselayer4.norm1 │ SyncBatc… │ 1.2 K │ +│ 343 │ model.net.features.denseblock4.denselayer4.relu1 │ ReLU │ 0 │ +│ 344 │ model.net.features.denseblock4.denselayer4.conv1 │ Conv2d │ 77.8 K │ +│ 345 │ model.net.features.denseblock4.denselayer4.norm2 │ SyncBatc… │ 256 │ +│ 346 │ model.net.features.denseblock4.denselayer4.relu2 │ ReLU │ 0 │ +│ 347 │ model.net.features.denseblock4.denselayer4.conv2 │ Conv2d │ 36.9 K │ +│ 348 │ model.net.features.denseblock4.denselayer5 │ _DenseLa… │ 120 K │ +│ 349 │ model.net.features.denseblock4.denselayer5.norm1 │ SyncBatc… │ 1.3 K │ +│ 350 │ model.net.features.denseblock4.denselayer5.relu1 │ ReLU │ 0 │ +│ 351 │ model.net.features.denseblock4.denselayer5.conv1 │ Conv2d │ 81.9 K │ +│ 352 │ model.net.features.denseblock4.denselayer5.norm2 │ SyncBatc… │ 256 │ +│ 353 │ model.net.features.denseblock4.denselayer5.relu2 │ ReLU │ 0 │ +│ 354 │ model.net.features.denseblock4.denselayer5.conv2 │ Conv2d │ 36.9 K │ +│ 355 │ model.net.features.denseblock4.denselayer6 │ _DenseLa… │ 124 K │ +│ 356 │ model.net.features.denseblock4.denselayer6.norm1 │ SyncBatc… │ 1.3 K │ +│ 357 │ model.net.features.denseblock4.denselayer6.relu1 │ ReLU │ 0 │ +│ 358 │ model.net.features.denseblock4.denselayer6.conv1 │ Conv2d │ 86.0 K │ +│ 359 │ model.net.features.denseblock4.denselayer6.norm2 │ SyncBatc… │ 256 │ +│ 360 │ model.net.features.denseblock4.denselayer6.relu2 │ ReLU │ 0 │ +│ 361 │ model.net.features.denseblock4.denselayer6.conv2 │ Conv2d │ 36.9 K │ +│ 362 │ model.net.features.denseblock4.denselayer7 │ _DenseLa… │ 128 K │ +│ 363 │ model.net.features.denseblock4.denselayer7.norm1 │ SyncBatc… │ 1.4 K │ +│ 364 │ model.net.features.denseblock4.denselayer7.relu1 │ ReLU │ 0 │ +│ 365 │ model.net.features.denseblock4.denselayer7.conv1 │ Conv2d │ 90.1 K │ +│ 366 │ model.net.features.denseblock4.denselayer7.norm2 │ SyncBatc… │ 256 │ +│ 367 │ model.net.features.denseblock4.denselayer7.relu2 │ ReLU │ 0 │ +│ 368 │ model.net.features.denseblock4.denselayer7.conv2 │ Conv2d │ 36.9 K │ +│ 369 │ model.net.features.denseblock4.denselayer8 │ _DenseLa… │ 132 K │ +│ 370 │ model.net.features.denseblock4.denselayer8.norm1 │ SyncBatc… │ 1.5 K │ +│ 371 │ model.net.features.denseblock4.denselayer8.relu1 │ ReLU │ 0 │ +│ 372 │ model.net.features.denseblock4.denselayer8.conv1 │ Conv2d │ 94.2 K │ +│ 373 │ model.net.features.denseblock4.denselayer8.norm2 │ SyncBatc… │ 256 │ +│ 374 │ model.net.features.denseblock4.denselayer8.relu2 │ ReLU │ 0 │ +│ 375 │ model.net.features.denseblock4.denselayer8.conv2 │ Conv2d │ 36.9 K │ +│ 376 │ model.net.features.denseblock4.denselayer9 │ _DenseLa… │ 136 K │ +│ 377 │ model.net.features.denseblock4.denselayer9.norm1 │ SyncBatc… │ 1.5 K │ +│ 378 │ model.net.features.denseblock4.denselayer9.relu1 │ ReLU │ 0 │ +│ 379 │ model.net.features.denseblock4.denselayer9.conv1 │ Conv2d │ 98.3 K │ +│ 380 │ model.net.features.denseblock4.denselayer9.norm2 │ SyncBatc… │ 256 │ +│ 381 │ model.net.features.denseblock4.denselayer9.relu2 │ ReLU │ 0 │ +│ 382 │ model.net.features.denseblock4.denselayer9.conv2 │ Conv2d │ 36.9 K │ +│ 383 │ model.net.features.denseblock4.denselayer10 │ _DenseLa… │ 141 K │ +│ 384 │ model.net.features.denseblock4.denselayer10.norm1 │ SyncBatc… │ 1.6 K │ +│ 385 │ model.net.features.denseblock4.denselayer10.relu1 │ ReLU │ 0 │ +│ 386 │ model.net.features.denseblock4.denselayer10.conv1 │ Conv2d │ 102 K │ +│ 387 │ model.net.features.denseblock4.denselayer10.norm2 │ SyncBatc… │ 256 │ +│ 388 │ model.net.features.denseblock4.denselayer10.relu2 │ ReLU │ 0 │ +│ 389 │ model.net.features.denseblock4.denselayer10.conv2 │ Conv2d │ 36.9 K │ +│ 390 │ model.net.features.denseblock4.denselayer11 │ _DenseLa… │ 145 K │ +│ 391 │ model.net.features.denseblock4.denselayer11.norm1 │ SyncBatc… │ 1.7 K │ +│ 392 │ model.net.features.denseblock4.denselayer11.relu1 │ ReLU │ 0 │ +│ 393 │ model.net.features.denseblock4.denselayer11.conv1 │ Conv2d │ 106 K │ +│ 394 │ model.net.features.denseblock4.denselayer11.norm2 │ SyncBatc… │ 256 │ +│ 395 │ model.net.features.denseblock4.denselayer11.relu2 │ ReLU │ 0 │ +│ 396 │ model.net.features.denseblock4.denselayer11.conv2 │ Conv2d │ 36.9 K │ +│ 397 │ model.net.features.denseblock4.denselayer12 │ _DenseLa… │ 149 K │ +│ 398 │ model.net.features.denseblock4.denselayer12.norm1 │ SyncBatc… │ 1.7 K │ +│ 399 │ model.net.features.denseblock4.denselayer12.relu1 │ ReLU │ 0 │ +│ 400 │ model.net.features.denseblock4.denselayer12.conv1 │ Conv2d │ 110 K │ +│ 401 │ model.net.features.denseblock4.denselayer12.norm2 │ SyncBatc… │ 256 │ +│ 402 │ model.net.features.denseblock4.denselayer12.relu2 │ ReLU │ 0 │ +│ 403 │ model.net.features.denseblock4.denselayer12.conv2 │ Conv2d │ 36.9 K │ +│ 404 │ model.net.features.denseblock4.denselayer13 │ _DenseLa… │ 153 K │ +│ 405 │ model.net.features.denseblock4.denselayer13.norm1 │ SyncBatc… │ 1.8 K │ +│ 406 │ model.net.features.denseblock4.denselayer13.relu1 │ ReLU │ 0 │ +│ 407 │ model.net.features.denseblock4.denselayer13.conv1 │ Conv2d │ 114 K │ +│ 408 │ model.net.features.denseblock4.denselayer13.norm2 │ SyncBatc… │ 256 │ +│ 409 │ model.net.features.denseblock4.denselayer13.relu2 │ ReLU │ 0 │ +│ 410 │ model.net.features.denseblock4.denselayer13.conv2 │ Conv2d │ 36.9 K │ +│ 411 │ model.net.features.denseblock4.denselayer14 │ _DenseLa… │ 157 K │ +│ 412 │ model.net.features.denseblock4.denselayer14.norm1 │ SyncBatc… │ 1.9 K │ +│ 413 │ model.net.features.denseblock4.denselayer14.relu1 │ ReLU │ 0 │ +│ 414 │ model.net.features.denseblock4.denselayer14.conv1 │ Conv2d │ 118 K │ +│ 415 │ model.net.features.denseblock4.denselayer14.norm2 │ SyncBatc… │ 256 │ +│ 416 │ model.net.features.denseblock4.denselayer14.relu2 │ ReLU │ 0 │ +│ 417 │ model.net.features.denseblock4.denselayer14.conv2 │ Conv2d │ 36.9 K │ +│ 418 │ model.net.features.denseblock4.denselayer15 │ _DenseLa… │ 161 K │ +│ 419 │ model.net.features.denseblock4.denselayer15.norm1 │ SyncBatc… │ 1.9 K │ +│ 420 │ model.net.features.denseblock4.denselayer15.relu1 │ ReLU │ 0 │ +│ 421 │ model.net.features.denseblock4.denselayer15.conv1 │ Conv2d │ 122 K │ +│ 422 │ model.net.features.denseblock4.denselayer15.norm2 │ SyncBatc… │ 256 │ +│ 423 │ model.net.features.denseblock4.denselayer15.relu2 │ ReLU │ 0 │ +│ 424 │ model.net.features.denseblock4.denselayer15.conv2 │ Conv2d │ 36.9 K │ +│ 425 │ model.net.features.denseblock4.denselayer16 │ _DenseLa… │ 166 K │ +│ 426 │ model.net.features.denseblock4.denselayer16.norm1 │ SyncBatc… │ 2.0 K │ +│ 427 │ model.net.features.denseblock4.denselayer16.relu1 │ ReLU │ 0 │ +│ 428 │ model.net.features.denseblock4.denselayer16.conv1 │ Conv2d │ 126 K │ +│ 429 │ model.net.features.denseblock4.denselayer16.norm2 │ SyncBatc… │ 256 │ +│ 430 │ model.net.features.denseblock4.denselayer16.relu2 │ ReLU │ 0 │ +│ 431 │ model.net.features.denseblock4.denselayer16.conv2 │ Conv2d │ 36.9 K │ +│ 432 │ model.net.features.norm5 │ SyncBatc… │ 2.0 K │ +│ 433 │ model.net.classifier │ Sequenti… │ 642 K │ +│ 434 │ model.net.classifier.0 │ Linear │ 512 K │ +│ 435 │ model.net.classifier.1 │ SyncBatc… │ 1.0 K │ +│ 436 │ model.net.classifier.2 │ Dropout │ 0 │ +│ 437 │ model.net.classifier.3 │ Linear │ 128 K │ +│ 438 │ model.net.classifier.4 │ Linear │ 514 │ +│ 439 │ loss │ CrossEnt… │ 0 │ +└─────┴───────────────────────────────────────────────────┴───────────┴────────┘ +Trainable params: 7.6 M +Non-trainable params: 0 +Total params: 7.6 M +Total estimated model params size (MB): 30 +SLURM auto-requeueing enabled. Setting signal handlers. +SLURM auto-requeueing enabled. Setting signal handlers. +SLURM auto-requeueing enabled. Setting signal handlers. +SLURM auto-requeueing enabled. Setting signal handlers. +/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/pytorch_lightning/trainer/connectors/logger_connector/result.py:232: UserWarning: You called `self.log('debug/val_len', ...)` in your `on_validation_batch_end` but the value needs to be floating point. Converting it to torch.float32. + warning_cache.warn( +/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/pytorch_lightning/trainer/connectors/logger_connector/result.py:232: UserWarning: You called `self.log('debug/val_right', ...)` in your `on_validation_batch_end` but the value needs to be floating point. Converting it to torch.float32. + warning_cache.warn( +/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 72 which is the number of cpus on this machine) in the `DataLoader` init to improve performance. + rank_zero_warn( +/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/pytorch_lightning/trainer/connectors/logger_connector/result.py:232: UserWarning: You called `self.log('debug/train_len', ...)` in your `on_train_batch_end` but the value needs to be floating point. Converting it to torch.float32. + warning_cache.warn( +/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/pytorch_lightning/trainer/connectors/logger_connector/result.py:232: UserWarning: You called `self.log('debug/train_right', ...)` in your `on_train_batch_end` but the value needs to be floating point. Converting it to torch.float32. + warning_cache.warn( +[W reducer.cpp:1303] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator()) +[W reducer.cpp:1303] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator()) +[W reducer.cpp:1303] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator()) +[W reducer.cpp:1303] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator()) +/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/pytorch_lightning/utilities/data.py:84: UserWarning: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 16. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`. + warning_cache.warn( +/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/pytorch_lightning/utilities/data.py:84: UserWarning: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 12. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`. + warning_cache.warn( +Epoch 0/0 ━━━━━━━━━━━━━━━━ 2610/2610 0:13:02 • 7.52it/s loss: 0.0348 + 0:00:00 v_num: 9025 +[2024-03-22 11:33:41,913][src.utils.utils][ERROR] -  +Traceback (most recent call last): + File "/cluster/home/vjimenez/adv_pa_new/src/utils/utils.py", line 65, in wrap + metric_dict, object_dict = task_func(cfg=cfg) + File "/cluster/home/vjimenez/adv_pa_new/src/train.py", line 84, in train + trainer.fit( + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 608, in fit + call._call_and_handle_interrupt( + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/pytorch_lightning/trainer/call.py", line 38, in _call_and_handle_interrupt + return trainer_fn(*args, **kwargs) + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 650, in _fit_impl + self._run(model, ckpt_path=self.ckpt_path) + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1112, in _run + results = self._run_stage() + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1191, in _run_stage + self._run_train() + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1214, in _run_train + self.fit_loop.run() + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/pytorch_lightning/loops/loop.py", line 200, in run + self.on_advance_end() + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/pytorch_lightning/loops/fit_loop.py", line 295, in on_advance_end + self.trainer._call_callback_hooks("on_train_epoch_end") + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1394, in _call_callback_hooks + fn(self, self.lightning_module, *args, **kwargs) + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/posterioragreement/lightning_callback.py", line 52, in on_train_epoch_end + pa_dict = self.pa_metric( + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl + return forward_call(*input, **kwargs) + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/posterioragreement/metrics/basemetric.py", line 293, in forward + self.update(*args, **kwargs) + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/torchmetrics/metric.py", line 388, in wrapped_func + update(*args, **kwargs) + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/posterioragreement/metrics/metric.py", line 442, in update + self.pa_update(dist.get_rank() if self.processing_strategy == "cuda" else local_rank) + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/posterioragreement/metrics/metric.py", line 376, in pa_update + logits_dataset = self._compute_logits_dataset(rank) + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/posterioragreement/metrics/metric.py", line 350, in _compute_logits_dataset + for bidx, batch in enumerate(dataloader): + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/torch/utils/data/dataloader.py", line 521, in __next__ + data = self._next_data() + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/torch/utils/data/dataloader.py", line 1203, in _next_data + return self._process_data(data) + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/torch/utils/data/dataloader.py", line 1229, in _process_data + data.reraise() + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/torch/_utils.py", line 434, in reraise + raise exception +IndexError: Caught IndexError in DataLoader worker process 0. +Original Traceback (most recent call last): + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/torch/utils/data/_utils/worker.py", line 287, in _worker_loop + data = fetcher.fetch(index) + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 49, in fetch + data = [self.dataset[idx] for idx in possibly_batched_index] + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 49, in + data = [self.dataset[idx] for idx in possibly_batched_index] + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/posterioragreement/datautils.py", line 50, in __getitem__ + return {str(i): dset[self.permutation[i][idx]] for i, dset in enumerate(self.dset_list)} + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/posterioragreement/datautils.py", line 50, in + return {str(i): dset[self.permutation[i][idx]] for i, dset in enumerate(self.dset_list)} + File "/cluster/home/vjimenez/adv_pa_new/src/data/components/wilds_dataset.py", line 96, in __getitem__ + selected_idx = self.inds_to_select[idx] +IndexError: list index out of range + +[2024-03-22 11:33:41,930][src.utils.utils][INFO] - Output dir: /cluster/project/jbuhmann/posterior_agreement/adv_pa/logs/victor/dg_wilds/multiruns/2024-03-22_11-19-54/0 +[2024-03-22 11:33:41,931][src.utils.utils][INFO] - Closing wandb! +Error executing job with overrides: ['experiment=dg/wilds/camelyon17_erm', '+data/dg/wilds@data=camelyon17_oracle', 'data.transform.is_training=True', 'seed=123', 'name_logger=prova_configs_train_newconf_call', 'trainer=ddp', 'trainer.max_epochs=1', '+trainer.fast_dev_run=False', 'logger=wandb', 'logger.wandb.group=test_wilds'] +Error executing job with overrides: ['experiment=dg/wilds/camelyon17_erm', '+data/dg/wilds@data=camelyon17_oracle', 'data.transform.is_training=True', 'seed=123', 'name_logger=prova_configs_train_newconf_call', 'trainer=ddp', 'trainer.max_epochs=1', '+trainer.fast_dev_run=False', 'logger=wandb', 'logger.wandb.group=test_wilds'] +wandb: Waiting for W&B process to finish... (success). +Error executing job with overrides: ['experiment=dg/wilds/camelyon17_erm', '+data/dg/wilds@data=camelyon17_oracle', 'data.transform.is_training=True', 'seed=123', 'name_logger=prova_configs_train_newconf_call', 'trainer=ddp', 'trainer.max_epochs=1', '+trainer.fast_dev_run=False', 'logger=wandb', 'logger.wandb.group=test_wilds'] +wandb: +wandb: Run history: +wandb: debug/val_len ▁ +wandb: debug/val_right ▁ +wandb: epoch ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ +wandb: train/acc_0_step ▂▇▁▂▂▇▄▅▄█▄█▄▄▄▅▅▅▇█ +wandb: train/acc_1_step ▃▅▄▆▆▃▁▅▄▅▅▆▆▇▇▇▇▇▃█ +wandb: train/acc_2_step ▃▅▃▆▇▃▁▇▇▇▁▆▅▃▇██▇▇▆ +wandb: train/acc_3_step ▁▅▅▇▆█▇█▅▇██▇▇█████▇ +wandb: train/acc_average_step ▁▅▂▅▅▅▂▇▅▇▄█▅▅▇███▆█ +wandb: train/acc_step ▁▅▂▅▅▅▂▇▅▇▄█▅▅▇███▆█ +wandb: train/f1_step ▁▅▃▅▆▆▃▇▅▇▅█▆▅▇██▇▆█ +wandb: train/loss_step █▄▆▃▃▄▅▃▃▂▃▂▃▄▂▁▁▁▂▁ +wandb: train/precision_step ▄▅▄▅▃▆▁▅▅▇▆█▅▅█▇█▇▄▇ +wandb: train/sensitivity_step ▁▆▃▆▇▅▅█▅▇▅▇▆▆▇█▇▇██ +wandb: train/specificity_step ▅▅▄▅▄▆▁▅▅▇▆█▅▅█▇█▇▅▇ +wandb: trainer/global_step ▄█▁▁▁▁▁▂▂▂▂▂▂▂▂▃▃▃▃▃▃▃▃▄▄▄▄▄▄▄▅▅▅▅▅▅▅▅▆▆ +wandb: val/acc_0_epoch ▁ +wandb: val/acc_0_step ▅▇██▃▇█▇█▇▇▆█▅▇▇█▇▅▄▅▂▃▅▃▅▂▅▅▁▅▆▂▅▃▂▂▂▄▄ +wandb: val/acc_average_epoch ▁ +wandb: val/acc_average_step ▅▇██▃▇█▇█▇▇▆█▅▇▇█▇▅▄▅▂▃▅▃▅▂▅▅▁▅▆▂▅▃▂▂▂▄▄ +wandb: val/acc_epoch ▁ +wandb: val/acc_step ▅▇██▃▇█▇█▇▇▆█▅▇▇█▇▅▄▅▂▃▅▃▅▂▅▅▁▅▆▂▅▃▂▂▂▄▄ +wandb: val/f1_epoch ▁ +wandb: val/f1_step ████████████████████▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ +wandb: val/loss_epoch ▁ +wandb: val/loss_step ▅▂▁▂▇▄▃▁▂▂▃▄▂▄▃▃▁▂▆▅▆▆▇▆▆▄▅▄▅█▄▅▇▆▇▆█▆▇▇ +wandb: val/precision_epoch ▁ +wandb: val/precision_step ████████████████████▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ +wandb: val/sensitivity_epoch ▁ +wandb: val/sensitivity_step ▇███▇█████████████▇▇▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ +wandb: val/specificity_epoch ▁ +wandb: val/specificity_step ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▇████▇██▇██▇██▇▇▇██ +wandb: +wandb: Run summary: +wandb: debug/val_len 4402.99658 +wandb: debug/val_right 4089.56323 +wandb: epoch 0 +wandb: train/acc_0_step 1.0 +wandb: train/acc_1_step 1.0 +wandb: train/acc_2_step 0.96875 +wandb: train/acc_3_step 0.98438 +wandb: train/acc_average_step 0.98828 +wandb: train/acc_step 0.98828 +wandb: train/f1_step 0.99071 +wandb: train/loss_step 0.02109 +wandb: train/precision_step 0.99419 +wandb: train/sensitivity_step 0.98747 +wandb: train/specificity_step 0.98864 +wandb: trainer/global_step 1031 +wandb: val/acc_0_epoch 0.91869 +wandb: val/acc_0_step 0.79167 +wandb: val/acc_average_epoch 0.91869 +wandb: val/acc_average_step 0.79167 +wandb: val/acc_epoch 0.91869 +wandb: val/acc_step 0.79167 +wandb: val/f1_epoch 0.48514 +wandb: val/f1_step 0.0 +wandb: val/loss_epoch 0.2321 +wandb: val/loss_step 0.38541 +wandb: val/precision_epoch 0.50005 +wandb: val/precision_step 0.0 +wandb: val/sensitivity_epoch 0.4744 +wandb: val/sensitivity_step 0.0 +wandb: val/specificity_epoch 0.44596 +wandb: val/specificity_step 0.79167 +wandb: +wandb: 🚀 View run prova_configs_train_newconf_call at: https://wandb.ai/malvai/cov_pa/runs/s3lg9025 +wandb: Synced 5 W&B file(s), 0 media file(s), 2 artifact file(s) and 0 other file(s) +wandb: Find logs at: /scratch/tmp.52217894.vjimenez/wandb/run-20240322_112022-s3lg9025/logs +Error executing job with overrides: ['experiment=dg/wilds/camelyon17_erm', '+data/dg/wilds@data=camelyon17_oracle', 'data.transform.is_training=True', 'seed=123', 'name_logger=prova_configs_train_newconf_call', 'trainer=ddp', 'trainer.max_epochs=1', '+trainer.fast_dev_run=False', 'logger=wandb', 'logger.wandb.group=test_wilds'] +Traceback (most recent call last): + File "/cluster/home/vjimenez/adv_pa_new/src/train.py", line 120, in main + metric_dict, _ = train(cfg) + File "/cluster/home/vjimenez/adv_pa_new/src/utils/utils.py", line 75, in wrap + raise ex + File "/cluster/home/vjimenez/adv_pa_new/src/utils/utils.py", line 65, in wrap + metric_dict, object_dict = task_func(cfg=cfg) + File "/cluster/home/vjimenez/adv_pa_new/src/train.py", line 84, in train + trainer.fit( + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 608, in fit + call._call_and_handle_interrupt( + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/pytorch_lightning/trainer/call.py", line 38, in _call_and_handle_interrupt + return trainer_fn(*args, **kwargs) + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 650, in _fit_impl + self._run(model, ckpt_path=self.ckpt_path) + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1112, in _run + results = self._run_stage() + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1191, in _run_stage + self._run_train() + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1214, in _run_train + self.fit_loop.run() + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/pytorch_lightning/loops/loop.py", line 200, in run + self.on_advance_end() + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/pytorch_lightning/loops/fit_loop.py", line 295, in on_advance_end + self.trainer._call_callback_hooks("on_train_epoch_end") + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1394, in _call_callback_hooks + fn(self, self.lightning_module, *args, **kwargs) + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/posterioragreement/lightning_callback.py", line 52, in on_train_epoch_end + pa_dict = self.pa_metric( + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl + return forward_call(*input, **kwargs) + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/posterioragreement/metrics/basemetric.py", line 293, in forward + self.update(*args, **kwargs) + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/torchmetrics/metric.py", line 388, in wrapped_func + update(*args, **kwargs) + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/posterioragreement/metrics/metric.py", line 442, in update + self.pa_update(dist.get_rank() if self.processing_strategy == "cuda" else local_rank) + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/posterioragreement/metrics/metric.py", line 376, in pa_update + logits_dataset = self._compute_logits_dataset(rank) + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/posterioragreement/metrics/metric.py", line 350, in _compute_logits_dataset + for bidx, batch in enumerate(dataloader): + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/torch/utils/data/dataloader.py", line 521, in __next__ + data = self._next_data() + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/torch/utils/data/dataloader.py", line 1203, in _next_data + return self._process_data(data) + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/torch/utils/data/dataloader.py", line 1229, in _process_data + data.reraise() + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/torch/_utils.py", line 434, in reraise + raise exception +IndexError: Caught IndexError in DataLoader worker process 0. +Original Traceback (most recent call last): + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/torch/utils/data/_utils/worker.py", line 287, in _worker_loop + data = fetcher.fetch(index) + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 49, in fetch + data = [self.dataset[idx] for idx in possibly_batched_index] + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 49, in + data = [self.dataset[idx] for idx in possibly_batched_index] + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/posterioragreement/datautils.py", line 50, in __getitem__ + return {str(i): dset[self.permutation[i][idx]] for i, dset in enumerate(self.dset_list)} + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/posterioragreement/datautils.py", line 50, in + return {str(i): dset[self.permutation[i][idx]] for i, dset in enumerate(self.dset_list)} + File "/cluster/home/vjimenez/adv_pa_new/src/data/components/wilds_dataset.py", line 96, in __getitem__ + selected_idx = self.inds_to_select[idx] +IndexError: list index out of range + + +Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace. +Traceback (most recent call last): + File "/cluster/home/vjimenez/adv_pa_new/src/train.py", line 120, in main + metric_dict, _ = train(cfg) + File "/cluster/home/vjimenez/adv_pa_new/src/utils/utils.py", line 75, in wrap + raise ex + File "/cluster/home/vjimenez/adv_pa_new/src/utils/utils.py", line 65, in wrap + metric_dict, object_dict = task_func(cfg=cfg) + File "/cluster/home/vjimenez/adv_pa_new/src/train.py", line 84, in train + trainer.fit( + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 608, in fit + call._call_and_handle_interrupt( + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/pytorch_lightning/trainer/call.py", line 38, in _call_and_handle_interrupt + return trainer_fn(*args, **kwargs) + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 650, in _fit_impl + self._run(model, ckpt_path=self.ckpt_path) + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1112, in _run + results = self._run_stage() + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1191, in _run_stage + self._run_train() + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1214, in _run_train + self.fit_loop.run() + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/pytorch_lightning/loops/loop.py", line 200, in run + self.on_advance_end() + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/pytorch_lightning/loops/fit_loop.py", line 295, in on_advance_end + self.trainer._call_callback_hooks("on_train_epoch_end") + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1394, in _call_callback_hooks + fn(self, self.lightning_module, *args, **kwargs) + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/posterioragreement/lightning_callback.py", line 52, in on_train_epoch_end + pa_dict = self.pa_metric( + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl + return forward_call(*input, **kwargs) + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/posterioragreement/metrics/basemetric.py", line 293, in forward + self.update(*args, **kwargs) + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/torchmetrics/metric.py", line 388, in wrapped_func + update(*args, **kwargs) + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/posterioragreement/metrics/metric.py", line 442, in update + self.pa_update(dist.get_rank() if self.processing_strategy == "cuda" else local_rank) + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/posterioragreement/metrics/metric.py", line 376, in pa_update + logits_dataset = self._compute_logits_dataset(rank) + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/posterioragreement/metrics/metric.py", line 350, in _compute_logits_dataset + for bidx, batch in enumerate(dataloader): + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/torch/utils/data/dataloader.py", line 521, in __next__ + data = self._next_data() + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/torch/utils/data/dataloader.py", line 1203, in _next_data + return self._process_data(data) + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/torch/utils/data/dataloader.py", line 1229, in _process_data + data.reraise() + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/torch/_utils.py", line 434, in reraise + raise exception +IndexError: Caught IndexError in DataLoader worker process 0. +Original Traceback (most recent call last): + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/torch/utils/data/_utils/worker.py", line 287, in _worker_loop + data = fetcher.fetch(index) + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 49, in fetch + data = [self.dataset[idx] for idx in possibly_batched_index] + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 49, in + data = [self.dataset[idx] for idx in possibly_batched_index] + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/posterioragreement/datautils.py", line 50, in __getitem__ + return {str(i): dset[self.permutation[i][idx]] for i, dset in enumerate(self.dset_list)} + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/posterioragreement/datautils.py", line 50, in + return {str(i): dset[self.permutation[i][idx]] for i, dset in enumerate(self.dset_list)} + File "/cluster/home/vjimenez/adv_pa_new/src/data/components/wilds_dataset.py", line 96, in __getitem__ + selected_idx = self.inds_to_select[idx] +IndexError: list index out of range + + +Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace. +Traceback (most recent call last): + File "/cluster/home/vjimenez/adv_pa_new/src/train.py", line 120, in main + metric_dict, _ = train(cfg) + File "/cluster/home/vjimenez/adv_pa_new/src/utils/utils.py", line 75, in wrap + raise ex + File "/cluster/home/vjimenez/adv_pa_new/src/utils/utils.py", line 65, in wrap + metric_dict, object_dict = task_func(cfg=cfg) + File "/cluster/home/vjimenez/adv_pa_new/src/train.py", line 84, in train + trainer.fit( + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 608, in fit + call._call_and_handle_interrupt( + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/pytorch_lightning/trainer/call.py", line 38, in _call_and_handle_interrupt + return trainer_fn(*args, **kwargs) + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 650, in _fit_impl + self._run(model, ckpt_path=self.ckpt_path) + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1112, in _run + results = self._run_stage() + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1191, in _run_stage + self._run_train() + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1214, in _run_train + self.fit_loop.run() + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/pytorch_lightning/loops/loop.py", line 200, in run + self.on_advance_end() + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/pytorch_lightning/loops/fit_loop.py", line 295, in on_advance_end + self.trainer._call_callback_hooks("on_train_epoch_end") + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1394, in _call_callback_hooks + fn(self, self.lightning_module, *args, **kwargs) + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/posterioragreement/lightning_callback.py", line 52, in on_train_epoch_end + pa_dict = self.pa_metric( + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl + return forward_call(*input, **kwargs) + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/posterioragreement/metrics/basemetric.py", line 293, in forward + self.update(*args, **kwargs) + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/torchmetrics/metric.py", line 388, in wrapped_func + update(*args, **kwargs) + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/posterioragreement/metrics/metric.py", line 442, in update + self.pa_update(dist.get_rank() if self.processing_strategy == "cuda" else local_rank) + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/posterioragreement/metrics/metric.py", line 376, in pa_update + logits_dataset = self._compute_logits_dataset(rank) + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/posterioragreement/metrics/metric.py", line 350, in _compute_logits_dataset + for bidx, batch in enumerate(dataloader): + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/torch/utils/data/dataloader.py", line 521, in __next__ + data = self._next_data() + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/torch/utils/data/dataloader.py", line 1203, in _next_data + return self._process_data(data) + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/torch/utils/data/dataloader.py", line 1229, in _process_data + data.reraise() + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/torch/_utils.py", line 434, in reraise + raise exception +IndexError: Caught IndexError in DataLoader worker process 0. +Original Traceback (most recent call last): + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/torch/utils/data/_utils/worker.py", line 287, in _worker_loop + data = fetcher.fetch(index) + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 49, in fetch + data = [self.dataset[idx] for idx in possibly_batched_index] + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 49, in + data = [self.dataset[idx] for idx in possibly_batched_index] + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/posterioragreement/datautils.py", line 50, in __getitem__ + return {str(i): dset[self.permutation[i][idx]] for i, dset in enumerate(self.dset_list)} + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/posterioragreement/datautils.py", line 50, in + return {str(i): dset[self.permutation[i][idx]] for i, dset in enumerate(self.dset_list)} + File "/cluster/home/vjimenez/adv_pa_new/src/data/components/wilds_dataset.py", line 96, in __getitem__ + selected_idx = self.inds_to_select[idx] +IndexError: list index out of range + + +Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace. +srun: error: eu-lo-s4-080: tasks 1-3: Exited with exit code 1 +Traceback (most recent call last): + File "/cluster/home/vjimenez/adv_pa_new/src/train.py", line 120, in main + metric_dict, _ = train(cfg) + File "/cluster/home/vjimenez/adv_pa_new/src/utils/utils.py", line 75, in wrap + raise ex + File "/cluster/home/vjimenez/adv_pa_new/src/utils/utils.py", line 65, in wrap + metric_dict, object_dict = task_func(cfg=cfg) + File "/cluster/home/vjimenez/adv_pa_new/src/train.py", line 84, in train + trainer.fit( + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 608, in fit + call._call_and_handle_interrupt( + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/pytorch_lightning/trainer/call.py", line 38, in _call_and_handle_interrupt + return trainer_fn(*args, **kwargs) + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 650, in _fit_impl + self._run(model, ckpt_path=self.ckpt_path) + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1112, in _run + results = self._run_stage() + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1191, in _run_stage + self._run_train() + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1214, in _run_train + self.fit_loop.run() + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/pytorch_lightning/loops/loop.py", line 200, in run + self.on_advance_end() + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/pytorch_lightning/loops/fit_loop.py", line 295, in on_advance_end + self.trainer._call_callback_hooks("on_train_epoch_end") + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1394, in _call_callback_hooks + fn(self, self.lightning_module, *args, **kwargs) + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/posterioragreement/lightning_callback.py", line 52, in on_train_epoch_end + pa_dict = self.pa_metric( + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl + return forward_call(*input, **kwargs) + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/posterioragreement/metrics/basemetric.py", line 293, in forward + self.update(*args, **kwargs) + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/torchmetrics/metric.py", line 388, in wrapped_func + update(*args, **kwargs) + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/posterioragreement/metrics/metric.py", line 442, in update + self.pa_update(dist.get_rank() if self.processing_strategy == "cuda" else local_rank) + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/posterioragreement/metrics/metric.py", line 376, in pa_update + logits_dataset = self._compute_logits_dataset(rank) + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/posterioragreement/metrics/metric.py", line 350, in _compute_logits_dataset + for bidx, batch in enumerate(dataloader): + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/torch/utils/data/dataloader.py", line 521, in __next__ + data = self._next_data() + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/torch/utils/data/dataloader.py", line 1203, in _next_data + return self._process_data(data) + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/torch/utils/data/dataloader.py", line 1229, in _process_data + data.reraise() + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/torch/_utils.py", line 434, in reraise + raise exception +IndexError: Caught IndexError in DataLoader worker process 0. +Original Traceback (most recent call last): + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/torch/utils/data/_utils/worker.py", line 287, in _worker_loop + data = fetcher.fetch(index) + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 49, in fetch + data = [self.dataset[idx] for idx in possibly_batched_index] + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 49, in + data = [self.dataset[idx] for idx in possibly_batched_index] + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/posterioragreement/datautils.py", line 50, in __getitem__ + return {str(i): dset[self.permutation[i][idx]] for i, dset in enumerate(self.dset_list)} + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/posterioragreement/datautils.py", line 50, in + return {str(i): dset[self.permutation[i][idx]] for i, dset in enumerate(self.dset_list)} + File "/cluster/home/vjimenez/adv_pa_new/src/data/components/wilds_dataset.py", line 96, in __getitem__ + selected_idx = self.inds_to_select[idx] +IndexError: list index out of range + + +Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace. +srun: error: eu-lo-s4-080: task 0: Exited with exit code 1 diff --git a/slurm-52225078.out b/slurm-52225078.out new file mode 100644 index 0000000..0b66777 --- /dev/null +++ b/slurm-52225078.out @@ -0,0 +1,86 @@ +[2024-03-22 12:01:38,521][HYDRA] Launching 1 jobs locally +[2024-03-22 12:01:38,521][HYDRA] #0 : experiment=dg/wilds/camelyon17_erm +data/dg/wilds@data=camelyon17_oracle data.transform.is_training=True seed=123 name_logger=prova_configs_train_newconf_call trainer=ddp trainer.max_epochs=3 +trainer.fast_dev_run=False logger=wandb logger.wandb.group=test_wilds +[2024-03-22 12:01:38,524][HYDRA] Launching 1 jobs locally +[2024-03-22 12:01:38,524][HYDRA] #0 : experiment=dg/wilds/camelyon17_erm +data/dg/wilds@data=camelyon17_oracle data.transform.is_training=True seed=123 name_logger=prova_configs_train_newconf_call trainer=ddp trainer.max_epochs=3 +trainer.fast_dev_run=False logger=wandb logger.wandb.group=test_wilds +[2024-03-22 12:01:38,526][HYDRA] Launching 1 jobs locally +[2024-03-22 12:01:38,526][HYDRA] #0 : experiment=dg/wilds/camelyon17_erm +data/dg/wilds@data=camelyon17_oracle data.transform.is_training=True seed=123 name_logger=prova_configs_train_newconf_call trainer=ddp trainer.max_epochs=3 +trainer.fast_dev_run=False logger=wandb logger.wandb.group=test_wilds +[2024-03-22 12:01:38,533][HYDRA] Launching 1 jobs locally +[2024-03-22 12:01:38,533][HYDRA] #0 : experiment=dg/wilds/camelyon17_erm +data/dg/wilds@data=camelyon17_oracle data.transform.is_training=True seed=123 name_logger=prova_configs_train_newconf_call trainer=ddp trainer.max_epochs=3 +trainer.fast_dev_run=False logger=wandb logger.wandb.group=test_wilds +[rank: 1] Global seed set to 123 +[rank: 3] Global seed set to 123 +[rank: 2] Global seed set to 123 +[rank: 0] Global seed set to 123 +[2024-03-22 12:01:38,953][__main__][INFO] - Instantiating datamodule  +[2024-03-22 12:01:39,076][__main__][INFO] - Instantiating model  +Error executing job with overrides: ['experiment=dg/wilds/camelyon17_erm', '+data/dg/wilds@data=camelyon17_oracle', 'data.transform.is_training=True', 'seed=123', 'name_logger=prova_configs_train_newconf_call', 'trainer=ddp', 'trainer.max_epochs=3', '+trainer.fast_dev_run=False', 'logger=wandb', 'logger.wandb.group=test_wilds'] +Error executing job with overrides: ['experiment=dg/wilds/camelyon17_erm', '+data/dg/wilds@data=camelyon17_oracle', 'data.transform.is_training=True', 'seed=123', 'name_logger=prova_configs_train_newconf_call', 'trainer=ddp', 'trainer.max_epochs=3', '+trainer.fast_dev_run=False', 'logger=wandb', 'logger.wandb.group=test_wilds'] +Error executing job with overrides: ['experiment=dg/wilds/camelyon17_erm', '+data/dg/wilds@data=camelyon17_oracle', 'data.transform.is_training=True', 'seed=123', 'name_logger=prova_configs_train_newconf_call', 'trainer=ddp', 'trainer.max_epochs=3', '+trainer.fast_dev_run=False', 'logger=wandb', 'logger.wandb.group=test_wilds'] +[2024-03-22 12:01:39,085][src.utils.utils][ERROR] -  +Traceback (most recent call last): + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/hydra/_internal/utils.py", line 644, in _locate + obj = getattr(obj, part) +AttributeError: module 'src.models' has no attribute 'erm' + +During handling of the above exception, another exception occurred: + +Traceback (most recent call last): + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/hydra/_internal/utils.py", line 650, in _locate + obj = import_module(mod) + File "/cluster/apps/nss/gcc-8.2.0/python/3.9.9/x86_64/lib64/python3.9/importlib/__init__.py", line 127, in import_module + return _bootstrap._gcd_import(name[level:], package, level) + File "", line 1030, in _gcd_import + File "", line 1007, in _find_and_load + File "", line 986, in _find_and_load_unlocked + File "", line 680, in _load_unlocked + File "", line 850, in exec_module + File "", line 228, in _call_with_frames_removed + File "/cluster/home/vjimenez/adv_pa_new/src/models/erm.py", line 8, in + class ERM(LightningModule): + File "/cluster/home/vjimenez/adv_pa_new/src/models/erm.py", line 25, in ERM + def _extract_batch(self, batch: Union[dict, tuple]): +NameError: name 'Union' is not defined + +The above exception was the direct cause of the following exception: + +Traceback (most recent call last): + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/hydra/_internal/instantiate/_instantiate2.py", line 134, in _resolve_target + target = _locate(target) + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/hydra/_internal/utils.py", line 658, in _locate + raise ImportError( +ImportError: Error loading 'src.models.erm.ERM': +NameError("name 'Union' is not defined") + +The above exception was the direct cause of the following exception: + +Traceback (most recent call last): + File "/cluster/home/vjimenez/adv_pa_new/src/utils/utils.py", line 65, in wrap + metric_dict, object_dict = task_func(cfg=cfg) + File "/cluster/home/vjimenez/adv_pa_new/src/train.py", line 54, in train + model: LightningModule = hydra.utils.instantiate(cfg.model) + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/hydra/_internal/instantiate/_instantiate2.py", line 226, in instantiate + return instantiate_node( + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/hydra/_internal/instantiate/_instantiate2.py", line 333, in instantiate_node + _target_ = _resolve_target(node.get(_Keys.TARGET), full_key) + File "/cluster/project/jbuhmann/posterior_agreement/.venvs/adv_pa/lib64/python3.9/site-packages/hydra/_internal/instantiate/_instantiate2.py", line 139, in _resolve_target + raise InstantiationException(msg) from e +hydra.errors.InstantiationException: Error locating target 'src.models.erm.ERM', set env var HYDRA_FULL_ERROR=1 to see chained exception. +full_key: model +[2024-03-22 12:01:39,122][src.utils.utils][INFO] - Output dir: /cluster/project/jbuhmann/posterior_agreement/adv_pa/logs/victor/dg_wilds/multiruns/2024-03-22_12-01-38/0 +Error executing job with overrides: ['experiment=dg/wilds/camelyon17_erm', '+data/dg/wilds@data=camelyon17_oracle', 'data.transform.is_training=True', 'seed=123', 'name_logger=prova_configs_train_newconf_call', 'trainer=ddp', 'trainer.max_epochs=3', '+trainer.fast_dev_run=False', 'logger=wandb', 'logger.wandb.group=test_wilds'] +Error locating target 'src.models.erm.ERM', set env var HYDRA_FULL_ERROR=1 to see chained exception. +full_key: model + +Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace. +Error locating target 'src.models.erm.ERM', set env var HYDRA_FULL_ERROR=1 to see chained exception. +full_key: model + +Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace. +Error locating target 'src.models.erm.ERM', set env var HYDRA_FULL_ERROR=1 to see chained exception. +full_key: model + +Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace. +Error locating target 'src.models.erm.ERM', set env var HYDRA_FULL_ERROR=1 to see chained exception. +full_key: model + +Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace. +srun: error: eu-g3-044: tasks 0-3: Exited with exit code 1 diff --git a/src/callbacks/accuracy.py b/src/callbacks/accuracy.py new file mode 100644 index 0000000..f63d646 --- /dev/null +++ b/src/callbacks/accuracy.py @@ -0,0 +1,74 @@ +from pytorch_lightning.callbacks import Callback +from torchmetrics import Accuracy, F1Score, Recall, Specificity, Precision + +class Accuracy_Callback(Callback): + """ + Computes and logs general accuracy metrics during training and/or testing. + """ + def __init__(self, n_classes: int): + super().__init__() + + self.n_classes = n_classes + _task = "multiclass" if n_classes > 2 else "binary" + + # Training metrics + self.train_acc = Accuracy(task=_task, num_classes=self.n_classes, average="macro") + self.train_f1 = F1Score(task=_task, num_classes=self.n_classes, average="macro") + self.train_specificity = Specificity(task=_task, num_classes=self.n_classes, average="macro") + self.train_sensitivity = Recall(task=_task, num_classes=self.n_classes, average="macro") + self.train_precision = Precision(task=_task, num_classes=self.n_classes, average="macro") + + # Validation metrics + self.val_acc = Accuracy(task=_task, num_classes=self.n_classes, average="macro") + self.val_f1 = F1Score(task=_task, num_classes=self.n_classes, average="macro") + self.val_specificity = Specificity(task=_task, num_classes=self.n_classes, average="macro") + self.val_sensitivity = Recall(task=_task, num_classes=self.n_classes, average="macro") + self.val_precision = Precision(task=_task, num_classes=self.n_classes, average="macro") + + # Test metrics + self.test_acc = Accuracy(task=_task, num_classes=self.n_classes, average="macro") + self.test_f1 = F1Score(task=_task, num_classes=self.n_classes, average="macro") + self.test_specificity = Specificity(task=_task, num_classes=self.n_classes, average="macro") + self.test_sensitivity = Recall(task=_task, num_classes=self.n_classes, average="macro") + self.test_precision = Precision(task=_task, num_classes=self.n_classes, average="macro") + + def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): + y, preds = outputs["targets"], outputs["preds"] + + metrics_dict = { + "train/loss": outputs["loss"], + "train/acc": self.train_acc.to(pl_module.device)(preds, y), + "train/f1": self.train_f1.to(pl_module.device)(preds, y), + "train/sensitivity": self.train_sensitivity.to(pl_module.device)(preds, y), + "train/specificity": self.train_specificity.to(pl_module.device)(preds, y), + "train/precision": self.train_precision.to(pl_module.device)(preds, y), + } + + pl_module.log_dict(metrics_dict, prog_bar=False, on_step=True, on_epoch=True, logger=True, sync_dist=True) + + + def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): + y, preds = outputs["targets"], outputs["preds"] + + metrics_dict = { + "val/loss": outputs["loss"], + "val/acc": self.val_acc.to(pl_module.device)(preds, y), + "val/f1": self.val_f1.to(pl_module.device)(preds, y), + "val/sensitivity": self.val_sensitivity.to(pl_module.device)(preds, y), + "val/specificity": self.val_specificity.to(pl_module.device)(preds, y), + "val/precision": self.val_precision.to(pl_module.device)(preds, y), + } + pl_module.log_dict(metrics_dict, prog_bar=False, on_step=True, on_epoch=True, logger=True, sync_dist=True) + + def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): + y, preds = outputs["targets"], outputs["preds"] + + metrics_dict = { + "test/loss": outputs["loss"], + "test/acc": self.test_acc.to(pl_module.device)(preds, y), + "test/f1": self.test_f1.to(pl_module.device)(preds, y), + "test/sensitivity": self.test_sensitivity.to(pl_module.device)(preds, y), + "test/specificity": self.test_specificity.to(pl_module.device)(preds, y), + "test/precision": self.test_precision.to(pl_module.device)(preds, y), + } + pl_module.log_dict(metrics_dict, prog_bar=False, on_step=True, on_epoch=True, logger=True, sync_dist=False) # SINGLE DEVICE \ No newline at end of file diff --git a/src/callbacks/accuracy_domains.py b/src/callbacks/accuracy_domains.py new file mode 100644 index 0000000..3ea8079 --- /dev/null +++ b/src/callbacks/accuracy_domains.py @@ -0,0 +1,95 @@ +from pytorch_lightning.callbacks import Callback +from torchmetrics import Accuracy +import torch + +class AccuracyDomains_Callback(Callback): + """ + Computes and logs general accuracy metrics specific for domain shifts (OOD and subpopulation shifts): + - Average accuracy across all domains. + - Worst domain accuracy. + """ + def __init__( + self, + n_classes: int, + n_domains_train: int, + n_domains_val: int, + n_domains_test: int + ): + super().__init__() + + self.n_classes = n_classes + self.n_domains_train, self.n_domains_val, self.n_domains_test = n_domains_train, n_domains_val, n_domains_test + + _task = "multiclass" if n_classes > 2 else "binary" + + self.train_acc_average = Accuracy(task=_task, num_classes=n_classes, average="macro") + self.train_acc = { + f'acc_{i}': Accuracy(task=_task, num_classes=n_classes, average="macro") + for i in range(n_domains_train) + } + + self.val_acc_average = Accuracy(task=_task, num_classes=n_classes, average="macro") + self.val_acc = { + f'acc_{i}': Accuracy(task=_task, num_classes=n_classes, average="macro") + for i in range(n_domains_val) + } + + self.test_acc_average = Accuracy(task=_task, num_classes=n_classes, average="macro") + self.test_acc = { + f'acc_{i}': Accuracy(task=_task, num_classes=n_classes, average="macro") + for i in range(n_domains_test) + } + + def _mask_each_domain(self, batch: dict, env_index: int): + """Returns a mask for the complete target and preds vectors corresponding to the current domain.""" + + intervals = torch.cat(( + torch.tensor([0]), + torch.cumsum( + torch.tensor([len(batch[env][1]) for env in batch.keys()]), + dim=0 + ) + )) + + len_total = intervals[-1] + return (torch.arange(len_total) >= intervals[env_index]) & (torch.arange(len_total) < intervals[env_index+1]) + + + def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): + y, preds = outputs["targets"], outputs["preds"] + + metrics_dict = { + 'train/acc_average': self.train_acc_average.to(pl_module.device)(preds, y) + } + for env in batch.keys(): + assert int(env) in range(self.n_domains_train), f"Environment {env} not in range {self.n_domains_train}." + mask = self._mask_each_domain(batch, int(env)) + metrics_dict[f'train/acc_{env}'] = self.train_acc[f'acc_{env}'].to(pl_module.device)(preds[mask], y[mask]) + + pl_module.log_dict(metrics_dict, prog_bar=False, on_step=True, on_epoch=True, logger=True, sync_dist=True) + + def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): + y, preds = outputs["targets"], outputs["preds"] + + metrics_dict = { + 'val/acc_average': self.val_acc_average.to(pl_module.device)(preds, y) + } + for env in batch.keys(): + assert int(env) in range(self.n_domains_val), f"Environment {env} not in range {self.n_domains_val}." + mask = self._mask_each_domain(batch, int(env)) + metrics_dict[f'val/acc_{env}'] = self.val_acc[f'acc_{env}'].to(pl_module.device)(preds[mask], y[mask]) + + pl_module.log_dict(metrics_dict, prog_bar=False, on_step=True, on_epoch=True, logger=True, sync_dist=True) + + def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): + y, preds = outputs["targets"], outputs["preds"] + + metrics_dict = { + 'test/acc_average': self.test_acc_average.to(pl_module.device)(preds, y) + } + for env in batch.keys(): + assert int(env) in range(self.n_domains_test), f"Environment {env} not in range {self.n_domains_test}." + mask = self._mask_each_domain(batch, int(env)) + metrics_dict[f'test/acc_{env}'] = self.test_acc[f'acc_{env}'].to(pl_module.device)(preds[mask], y[mask]) + + pl_module.log_dict(metrics_dict, prog_bar=False, on_step=True, on_epoch=True, logger=True, sync_dist=False) # SINGLE DEVICE \ No newline at end of file diff --git a/src/callbacks/debugging.py b/src/callbacks/debugging.py new file mode 100644 index 0000000..92b1c9f --- /dev/null +++ b/src/callbacks/debugging.py @@ -0,0 +1,70 @@ +from pytorch_lightning.callbacks import Callback +from torchmetrics import Accuracy +import torch + +class Debugging_Callback(Callback): + """ + Callback used to accumulate and print metrics along the training pipeline for debugging purposes. + """ + def __init__(self): + super().__init__() + + self.len_train = torch.tensor(0) + self.right_train = torch.tensor(0) + + self.len_val = torch.tensor(0) + self.right_val = torch.tensor(0) + + self.len_test = torch.tensor(0) + self.right_test = torch.tensor(0) + + def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): + y, preds = outputs["targets"], outputs["preds"] + + # Update the metrics + self.len_train = torch.tensor([self.len_train.item() + len(y)]).to(pl_module.device) + self.right_train = torch.tensor([self.right_train.item() + torch.sum(torch.eq(y, preds)).item()]).to(pl_module.device) + + metrics_dict = { + 'debug/train_len': self.len_train, + 'debug/train_right': self.right_train + } + pl_module.log_dict(metrics_dict, prog_bar=False, on_step=False, on_epoch=True, logger=True, sync_dist=True) + + def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): + y, preds = outputs["targets"], outputs["preds"] + + # Update the metrics + self.len_val = torch.tensor([self.len_val.item() + len(y)]).to(pl_module.device) + self.right_val = torch.tensor([self.right_val.item() + torch.sum(torch.eq(y, preds)).item()]).to(pl_module.device) + + metrics_dict = { + 'debug/val_len': self.len_val, + 'debug/val_right': self.right_val + } + + pl_module.log_dict(metrics_dict, prog_bar=False, on_step=False, on_epoch=True, logger=True, sync_dist=True) + + def on_fit_end(self, trainer, pl_module): + print("\nTraining is concluded, showing debugging metrics:") + print("Length train, val: ", self.len_train.item(), self.len_val.item()) + print("Accuracy train, val: ", self.right_train.item() / self.len_train.item(), self.right_val.item() / self.len_val.item()) + + def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): + y, preds = outputs["targets"], outputs["preds"] + + # Update the metrics + self.len_test = torch.tensor([self.len_test.item() + len(y)]).to(pl_module.device) + self.right_test = torch.tensor([self.right_test.item() + torch.sum(torch.eq(y, preds)).item()]).to(pl_module.device) + + metrics_dict = { + 'debug/train_len': self.len_test, + 'debug/train_right': self.right_test + } + + pl_module.log_dict(metrics_dict, prog_bar=False, on_step=False, on_epoch=True, logger=True, sync_dist=False) # SINGLE DEVICE + + def on_test_end(self, trainer, pl_module): + print("\nTesting is is concluded, showing debugging metrics:") + print("Length test: ", self.len_test.item()) + print("Accuracy test: ", self.right_test.item() / self.len_test.item()) \ No newline at end of file diff --git a/src/data/cifar10_datamodule.py b/src/data/cifar10_datamodules.py similarity index 83% rename from src/data/cifar10_datamodule.py rename to src/data/cifar10_datamodules.py index 4dea6fd..a85c7ac 100644 --- a/src/data/cifar10_datamodule.py +++ b/src/data/cifar10_datamodules.py @@ -12,15 +12,11 @@ from secml.adv.attacks import CAttack -from src.data.components import MultienvDataset, LogitsDataset -from src.data.components.adv import AdversarialCIFAR10Dataset +from src.data.components import MultienvDataset +from src.data.components.cifar10_dataset import AdversarialCIFAR10Dataset from src.data.utils import carray2tensor from src.data.components.collate_functions import MultiEnv_collate_fn -from torch.utils.data.distributed import DistributedSampler -from src.pa_metric_torch import PosteriorAgreementSampler - - class CIFAR10DataModule(LightningDataModule): """Example of LightningDataModule for MNIST dataset. @@ -131,13 +127,13 @@ def setup(self, stage: Optional[str] = None): self.hparams.cache, ) - self.paired_dset = MultienvDataset( + self.train_ds = MultienvDataset( [self.original_dset, self.adversarial_dset] ) def train_dataloader(self): return DataLoader( - dataset=self.paired_dset, + dataset=self.train_ds, batch_size=self.hparams.batch_size, collate_fn=MultiEnv_collate_fn, num_workers=self.hparams.num_workers, @@ -160,9 +156,30 @@ def load_state_dict(self, state_dict: Dict[str, Any]): """Things to do when loading checkpoint.""" pass + from src.data.components.logits_pa import LogitsPA +from src.pa_metric.pairing import PosteriorAgreementDatasetPairing + +class CIFAR10DataModulePA(CIFAR10DataModule): + """ + DataModule to use for PA optimization. + """ + def setup(self, stage: Optional[str] = None): + super().setup(stage) + self.train_ds = PosteriorAgreementDatasetPairing(self.train_ds) + + def train_dataloader(self): + return DataLoader( + dataset=self.train_ds, + batch_size=self.hparams.batch_size, + collate_fn=MultiEnv_collate_fn, + num_workers=self.hparams.num_workers, + pin_memory=self.hparams.pin_memory, + shuffle=True, # We change shuffle to True, because we intend to optimize the PA. + ) + -class CIFAR10DataModulelogits(LogitsPA, CIFAR10DataModule): +class CIFAR10DataModulePAlogits(LogitsPA, CIFAR10DataModulePA): def __init__(self, classifier: torch.nn.Module, *args, **kwargs): super().__init__(classifier) - CIFAR10DataModule.__init__(self, classifier, *args, **kwargs) \ No newline at end of file + CIFAR10DataModulePA.__init__(self, classifier, *args, **kwargs) \ No newline at end of file diff --git a/src/data/components/__init__.py b/src/data/components/__init__.py index 3f4190d..353957d 100644 --- a/src/data/components/__init__.py +++ b/src/data/components/__init__.py @@ -1,7 +1,7 @@ import torch from torch import Tensor -from torch.utils.data import Dataset, TensorDataset -from typing import List, Tuple +from torch.utils.data import Dataset, TensorDataset, Subset +from typing import List, Optional class MultienvDataset(Dataset): """ @@ -9,17 +9,20 @@ class MultienvDataset(Dataset): """ def __init__(self, dset_list: List[Dataset]): - len_ds = len(dset_list[0]) + len_ds = min([len(ds) for ds in dset_list]) + + same_size = True for ds in dset_list: if len(ds) != len_ds: - raise ValueError("All datasets must have the same size.") - + same_size = False + break + self.dset_list = dset_list self.num_envs = len(dset_list) - self.permutation = [torch.arange(len(self.dset_list[0])).tolist()]*self.num_envs + self.permutation = [torch.arange(len(ds)).tolist() for ds in dset_list] def __len__(self): - return len(self.permutation[0]) + return min([len(perm) for perm in self.permutation]) def __getitem__(self, idx: int): return {str(i): dset[self.permutation[i][idx]] for i, dset in enumerate(self.dset_list)} @@ -31,18 +34,11 @@ def __getitems__(self, indices: List[int]): # Is there a way to do it without multiplicating the calls to __getitem__? output_list = [None]*self.num_envs for i, dset in enumerate(self.dset_list): - output_list[i] = tuple([torch.stack([dset.__getitem__(self.permutation[i][idx])[0] for idx in indices]), - torch.tensor([dset.__getitem__(self.permutation[i][idx])[1] for idx in indices])]) - + output_list[i] = tuple([torch.stack([self.__getitem__(idx)[str(i)][0] for idx in indices]), + torch.tensor([self.__getitem__(idx)[str(i)][1] for idx in indices])]) + return output_list - def Subset(self, indices: List[int]): - """ - Returns a new MultienvDataset object with the subset of the original dataset. - """ - subset_items = self.__getitems__(indices) - return MultienvDataset([TensorDataset(*env_subset) for env_subset in subset_items]) - def __getlabels__(self, indices: List[int]): """ Useful method to retrieve only the labels associated with a specific index. This will help with the pairing of samples for the metric. @@ -50,9 +46,16 @@ def __getlabels__(self, indices: List[int]): output_list = [None]*self.num_envs for i, dset in enumerate(self.dset_list): - output_list[i] = torch.tensor([dset.__getitem__(self.permutation[i][idx])[1] for idx in indices]) - + output_list[i] = torch.tensor([self.__getitem__(idx)[str(i)][1] for idx in indices]) + return output_list + + def Subset(self, indices: List[int]): + """ + Returns a new MultienvDataset object with the subset of the original dataset. + """ + subset_items = self.__getitems__(indices) + return MultienvDataset([TensorDataset(*env_subset) for env_subset in subset_items]) class LogitsDataset(Dataset): """ @@ -69,6 +72,9 @@ def _check_input(self, logits: List[Tensor], y: Tensor) -> None: assert all(logits[0].size(0) == logit.size(0) for logit in logits), "Size mismatch between logits" assert all(y.size(0) == logit.size(0) for logit in logits), "Size mismatch between y and logits" + def __len__(self): + return self.logits[0].size(0) + def __additem__(self, logits: List[Tensor], y: Tensor) -> None: """ This method is slow, because it's concatenating tensors, so it should be avoided whenever possible. @@ -81,6 +87,9 @@ def __additem__(self, logits: List[Tensor], y: Tensor) -> None: def __getitem__(self, index: int): return {str(i): tuple([self.logits[i][index], self.y[index]]) for i in range(self.num_envs)} - - def __len__(self): - return self.logits[0].size(0) \ No newline at end of file + + def __getitems__(self, indices: List[int]): + """ + When I request several items, I prefer to get a tensor for each dataset. + """ + return [tuple([self.logits[i][indices], self.y[indices]]) for i in range(self.num_envs)] \ No newline at end of file diff --git a/src/data/components/adv.py b/src/data/components/adv.py index e67d2f9..a65ebda 100644 --- a/src/data/components/adv.py +++ b/src/data/components/adv.py @@ -1,14 +1,5 @@ -from abc import ABCMeta - -import os -import os.path as osp - -import torch -from torch.utils.data import TensorDataset, random_split - -from secml.array import CArray from secml.ml.classifiers import CClassifierPyTorch -from secml.adv.attacks.evasion import CAttackEvasion + from secml.adv.attacks.evasion.foolbox.fb_attacks.fb_ddn_attack import ( CFoolboxL2DDN, ) @@ -19,91 +10,8 @@ from secml.adv.attacks.evasion import CAttackEvasionFoolbox from foolbox.attacks import LInfFMNAttack from foolbox.attacks.basic_iterative_method import LinfBasicIterativeAttack -from secml.data.loader import CDataLoaderCIFAR10 -from secml.ml.peval.metrics import CMetricAccuracy - -from src.data.utils import carray2tensor - from src.data.components.gaussian_attack import GaussianAttack -class AdversarialCIFAR10Dataset(TensorDataset): - """Generate adversarially crafted CIFAR10 data, for an image - classification problem. - """ - - dset_name: str = "cifar10" - dset: ABCMeta = CDataLoaderCIFAR10 - dset_shape: tuple = (3, 32, 32) - - def __init__( - self, - attack: CAttackEvasion, - classifier: CClassifierPyTorch, - data_dir: str = osp.join(".", "data", "datasets"), - checkpoint_fname: str = "checkpoint.pt", - adversarial_ratio: float = 1.0, - verbose: bool = False, - cache: bool = False, - ): - _, ts = self.dset().load(val_size=0) - X, Y = ts.X / 255.0, ts.Y - - self.attacked_classifier = classifier - - fname = osp.join(data_dir, checkpoint_fname) - if cache and osp.exists(fname): - if verbose: - print( - f"Loaded found Adversarial {self.dset_name} dataset " - f"in {fname}" - ) - adv_X = torch.load(fname) - else: - if verbose: - print("Attack started...") - - adv_Y_pred, adv_scores, adv_ds, adv_f_obj = attack.run(X, Y) - - if verbose: - print( - f"Attack complete! Adversarial {self.dset_name} dataset " - "stored in ", - fname, - ) - adv_X = carray2tensor(adv_ds.X, torch.float32) - if cache: - os.makedirs(data_dir, exist_ok=True) - torch.save(adv_X.to("cpu"), fname) - - if adversarial_ratio != 1.0: # TODO: specify samples to be corrupted - X = carray2tensor(X, torch.float32) - if adversarial_ratio == 0.0: - adv_X = X - - dset_size = X.shape[0] - - split = int(adversarial_ratio * dset_size) - attack_norms = (adv_X - X).norm(p=float("inf"), dim=1) - - _, unpoison_ids = attack_norms.topk(dset_size - split) - - # remove poison for the largest 1 - adversarial_ratio attacked ones - adv_X[unpoison_ids] = X[unpoison_ids] - - adv_X = adv_X.reshape(-1, *self.dset_shape) - Y = carray2tensor(Y, torch.long) - - super().__init__(adv_X, Y) - - def performance_adversarial(self): - X = CArray(self.tensors[0].to(torch.float64).numpy()) - Y = CArray(self.tensors[1].to(torch.float64).numpy()) - metric = CMetricAccuracy() - y_pred = self.classifier.predict(X) - acc = metric.performance_score(y_true=Y, y_pred=y_pred) - return acc - - def get_attack(attack_name: str, classifier: CClassifierPyTorch, **kwargs): """Retrieve the attack and store its name.""" if attack_name == "PGD": @@ -148,4 +56,4 @@ def get_attack(attack_name: str, classifier: CClassifierPyTorch, **kwargs): if config else "" ) - return attack + return attack \ No newline at end of file diff --git a/src/data/components/cifar10_dataset.py b/src/data/components/cifar10_dataset.py new file mode 100644 index 0000000..cb9761c --- /dev/null +++ b/src/data/components/cifar10_dataset.py @@ -0,0 +1,93 @@ +from abc import ABCMeta + +import os +import os.path as osp + +import torch +from torch.utils.data import TensorDataset, random_split + +from secml.array import CArray +from secml.ml.classifiers import CClassifierPyTorch +from secml.adv.attacks.evasion import CAttackEvasion + +from secml.data.loader import CDataLoaderCIFAR10 +from secml.ml.peval.metrics import CMetricAccuracy + +from src.data.utils import carray2tensor + +class AdversarialCIFAR10Dataset(TensorDataset): + """Generate adversarially crafted CIFAR10 data, for an image + classification problem. + """ + + dset_name: str = "cifar10" + dset: ABCMeta = CDataLoaderCIFAR10 + dset_shape: tuple = (3, 32, 32) + + def __init__( + self, + attack: CAttackEvasion, + classifier: CClassifierPyTorch, + data_dir: str = osp.join(".", "data", "datasets"), + checkpoint_fname: str = "checkpoint.pt", + adversarial_ratio: float = 1.0, + verbose: bool = False, + cache: bool = False, + ): + _, ts = self.dset().load(val_size=0) + X, Y = ts.X / 255.0, ts.Y + + self.attacked_classifier = classifier + + fname = osp.join(data_dir, checkpoint_fname) + if cache and osp.exists(fname): + if verbose: + print( + f"Loaded found Adversarial {self.dset_name} dataset " + f"in {fname}" + ) + adv_X = torch.load(fname) + else: + if verbose: + print("Attack started...") + + adv_Y_pred, adv_scores, adv_ds, adv_f_obj = attack.run(X, Y) + + if verbose: + print( + f"Attack complete! Adversarial {self.dset_name} dataset " + "stored in ", + fname, + ) + adv_X = carray2tensor(adv_ds.X, torch.float32) + if cache: + os.makedirs(data_dir, exist_ok=True) + torch.save(adv_X.to("cpu"), fname) + + if adversarial_ratio != 1.0: # TODO: specify samples to be corrupted + X = carray2tensor(X, torch.float32) + if adversarial_ratio == 0.0: + adv_X = X + + dset_size = X.shape[0] + + split = int(adversarial_ratio * dset_size) + attack_norms = (adv_X - X).norm(p=float("inf"), dim=1) + + _, unpoison_ids = attack_norms.topk(dset_size - split) + + # remove poison for the largest 1 - adversarial_ratio attacked ones + adv_X[unpoison_ids] = X[unpoison_ids] + + adv_X = adv_X.reshape(-1, *self.dset_shape) + Y = carray2tensor(Y, torch.long) + + super().__init__(adv_X, Y) + + def performance_adversarial(self): + X = CArray(self.tensors[0].to(torch.float64).numpy()) + Y = CArray(self.tensors[1].to(torch.float64).numpy()) + metric = CMetricAccuracy() + y_pred = self.classifier.predict(X) + acc = metric.performance_score(y_true=Y, y_pred=y_pred) + return acc diff --git a/src/data/components/collate_functions.py b/src/data/components/collate_functions.py index 84905c9..c2d3535 100644 --- a/src/data/components/collate_functions.py +++ b/src/data/components/collate_functions.py @@ -10,7 +10,6 @@ def MultiEnv_collate_fn(batch: List): The output is of the form: batch_dict = { - "envs": [env1_name, env2_name, ...], "env1_name": [x1, y1], "env2_name": [x2, y2], ... @@ -23,7 +22,6 @@ def MultiEnv_collate_fn(batch: List): torch.tensor([b[env][1] for b in batch]), ] - batch_dict["envs"] = [env for env in batch[0].keys()] return batch_dict @@ -32,7 +30,7 @@ def SingleEnv_collate_fn(batch: List): The output is of the form: (Tensor[x_0_env1, ..., x_n_env1, x_0_env2, ..., x_n_env2, ...], - Tensor[y_0_env1, ..., y_n_env1, y_0_env2, ..., y_n_env2, ...]) + Tensor[y_0_env1, ..., y_n_env1, y_0_env2, ..., y_n_env2, ...]) """ x = torch.stack([b[env][0] for env in batch[0] for b in batch]) y = torch.tensor([b[env][1] for env in batch[0] for b in batch]) diff --git a/src/data/components/gaussian_attack.py b/src/data/components/gaussian_attack.py index 854b889..2b3ce5d 100644 --- a/src/data/components/gaussian_attack.py +++ b/src/data/components/gaussian_attack.py @@ -25,7 +25,7 @@ def __init__(self, super(CAttackEvasion, self).__init__(classifier) self.attacked_classifier = classifier - self.noise_std = epsilons/3 + self.epsilon = epsilons def _run(self): return @@ -44,8 +44,8 @@ def objective_function_gradient(self): def run(self, X, Y): X = CArray(X).atleast_2d() - noise = CArray(randn(X.shape[0], X.shape[1])*self.noise_std) - adv_X = X + noise + noise = CArray(randn(X.shape[0], X.shape[1])*self.epsilon/3).clip(-self.epsilon,self.epsilon) + adv_X = X + noise.abs() Y = CArray(Y).atleast_2d() adv_ds = CDataset(adv_X.deepcopy(), Y.deepcopy()) diff --git a/src/data/components/logits_pa.py b/src/data/components/logits_pa.py index b67442f..87b84be 100644 --- a/src/data/components/logits_pa.py +++ b/src/data/components/logits_pa.py @@ -1,8 +1,11 @@ import torch -from src.data.components import LogitsDataset +from src.data.components import Dataset, LogitsDataset from src.data.components.collate_functions import MultiEnv_collate_fn from torch.utils.data import DataLoader, DistributedSampler -from src.pa_metric_torch import PosteriorAgreementSampler +from src.pa_metric.pairing import PosteriorAgreementDatasetPairing +from typing import Optional + +import gc # garbage collector for the dataset class LogitsPA: """ @@ -11,15 +14,29 @@ class LogitsPA: def __init__(self, classifier: torch.nn.Module): self.classifier = classifier - self.dev = "cuda" if torch.cuda.is_available() else "cpu" # modify to set device - def _logits_dataset(self, image_dataloader: DataLoader): + def _logits_dataset(self, image_dataset: Dataset): self.classifier.eval() + self.dev = next(self.classifier.parameters()).device # wherever the model is + + # TODO: Check if this is the best option + # Batch size is for training, not for evaluating, so it should be higher. + # Num workers might give us problems when we are on lightning DDP. + image_dataloader = DataLoader( + dataset=image_dataset, + batch_size=self.hparams.batch_size, + num_workers=self.hparams.num_workers, + pin_memory=self.hparams.pin_memory, + collate_fn=MultiEnv_collate_fn, + shuffle=False, + drop_last = False + ) + with torch.no_grad(): logits = [] ys = [] for bidx, batch in enumerate(image_dataloader): - envs = batch["envs"] + envs = list(batch.keys()) X_list = [batch[envs[e]][0] for e in range(self.num_envs)] Y_list = [batch[envs[e]][1] for e in range(self.num_envs)] if not all([torch.equal(Y_list[0], Y_list[i]) for i in range(1, len(Y_list))]): # all labels must be equal @@ -28,25 +45,27 @@ def _logits_dataset(self, image_dataloader: DataLoader): logits.append([self.classifier(X.to(self.dev)) for X in X_list]) ys.append(Y_list[0]) - return LogitsDataset( - [torch.cat([logits[bidx][e] for bidx in range(len(ys))]) for e in range(self.num_envs)], + lds = LogitsDataset( + [torch.cat([logits[i][e] for i in range(len(ys))]) for e in range(self.num_envs)], torch.cat(ys) ) + return lds - def _set_sampler(self): - """For the super().train_dataloader()""" - # Because I operate within the GPU, but still want pairing - return PosteriorAgreementSampler(self.test_pairedds, shuffle=False, drop_last = True, num_replicas=1, rank=0) + def setup(self, stage: Optional[str] = None): + super().setup(stage) + self.logits_ds = self._logits_dataset(self.train_ds) + + # Free up memory asap + self.train_ds = None + gc.collect() def train_dataloader(self): - logits_dataset = self._logits_dataset(super().train_dataloader()) # unconventional here but still GPU return DataLoader( - dataset=logits_dataset, + dataset=self.logits_ds, batch_size=self.hparams.batch_size, num_workers=self.hparams.num_workers, pin_memory=self.hparams.pin_memory, collate_fn=MultiEnv_collate_fn, - sampler=DistributedSampler(logits_dataset, shuffle=False, drop_last = False) # samples already paired with PosteriorAgreementSampler - ) - - # val_dataloader should be the same as train_dataloader in the main class \ No newline at end of file + shuffle=True, + drop_last = False + ) \ No newline at end of file diff --git a/src/data/components/wilds_dataset.py b/src/data/components/wilds_dataset.py new file mode 100644 index 0000000..f2b11bc --- /dev/null +++ b/src/data/components/wilds_dataset.py @@ -0,0 +1,129 @@ +from typing import Optional, Callable, List +from wilds.datasets.wilds_dataset import WILDSDataset + +import torch +import torchvision.transforms as transforms +from torch.utils.data import Dataset + +from omegaconf import DictConfig + +def WILDS_multiple_to_single(multiple_env_config: DictConfig) -> dict: + """ + Converts a list of environment configurations to a single dictionary so that it can be interpreted by + the WILDSDatasetEnv as a single environment. This will allow us to perform configuration interpolations instead + of having to specify the PA datasets for every experiment. + """ + + combined_values = {} + all_group_by_fields = set() + + # Collect all possible group_by_fields and initialize combined_values + for _, envconf in multiple_env_config.items(): + for field in envconf['group_by_fields']: + all_group_by_fields.add(field) + if field not in combined_values: + combined_values[field] = [] + + # Aggregate values for each field from all environments + for _, envconf in multiple_env_config.items(): + for field in all_group_by_fields: + if field in envconf['group_by_fields']: # Only add if the field is used in this env + combined_values[field].extend(envconf['values'].get(field, [])) + + # Remove duplicates and sort + for field in combined_values: + combined_values[field] = sorted(list(set(combined_values[field]))) + + # Construct new dictionary assuming template of combined fields + env1_dict = { + 'split_name': 'train', + 'group_by_fields': list(all_group_by_fields), + 'values': combined_values + } + return env1_dict + +class WILDSDatasetEnv(Dataset): + """ + Provides a dataset for a specific environment. + """ + def __init__( + self, + dataset: WILDSDataset, + env_config: dict, + transform: Optional[Callable] = None + ): + + # Initial checks: + assert isinstance(dataset, WILDSDataset), "The dataset must be an instance of WILDSDataset." + assert list(env_config.keys()) == ["split_name", "group_by_fields", "values"], "The env_config must have the keys 'group_by_fields' and 'values'." + assert env_config["split_name"] in dataset.split_dict.keys(), f"The split_name must be one of the splits of the dataset: {list(dataset.split_dict.keys())}." + assert set(env_config["group_by_fields"]) <= set(dataset.metadata_fields), "The fields to be selected are not in the metadata of this dataset." + + # Mask for the split + split_index = dataset.split_dict[env_config["split_name"]] + split_mask = torch.tensor((dataset.split_array == split_index)) + + inds_to_select = [] + for field in env_config['group_by_fields']: + ind_field_in_metadata = dataset.metadata_fields.index(field) + unique_values = torch.unique(dataset.metadata_array[:, ind_field_in_metadata]).numpy() + # env_config["values"][field] is a list, it comes from the configuration dictionary. + assert set(env_config["values"][field]) <= set(unique_values), f"The values for the field {field} are not in the metadata of this dataset." + + # Mask for the values + value_mask = torch.zeros(len(dataset.metadata_array), dtype=torch.bool) + for value in env_config["values"][field]: + value_mask |= (dataset.metadata_array[:, ind_field_in_metadata] == value) + + # Combine masks and select index + combined_mask = value_mask & split_mask + inds_to_select.append(torch.where(combined_mask)[0]) + + self.inds_to_select = torch.sort(torch.cat(inds_to_select))[0].tolist() + # The WILDS dataset yields: (, tensor(1), tensor([0, 0, 1, 1])) + self.dataset = dataset + if transform is None: + self.transform = transforms.Compose( + [transforms.Resize((448, 448)), transforms.ToTensor()] + ) + else: + self.transform = transform + + def __len__(self): + return len(self.inds_to_select) + + def __getitem__(self, idx): + selected_idx = self.inds_to_select[idx] + image, label = self.dataset[selected_idx][0], self.dataset[selected_idx][1] + + if self.transform: + image = self.transform(image) + + return image, label + + +# import pyrootutils +# pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) +# from wilds import get_dataset +# # from src.data.components.wilds_transforms import get_transform + +# # waterbirds missing +# dataset = get_dataset( +# dataset="waterbirds", +# download=False, +# unlabeled=False, +# root_dir="data/dg/dg_datasets/wilds" +# ) + +# import ipdb; ipdb.set_trace() + +# wilds_dataset = WILDSDatasetEnv( +# dataset=dataset, +# env_config={ +# "split_name": "val", +# "group_by_fields": ["hospital"], +# "values": {"hospital": [0]} +# }, +# transform=get_transform("camelyon17", 224) +# ) + diff --git a/src/data/components/wilds_transforms.py b/src/data/components/wilds_transforms.py new file mode 100644 index 0000000..2984aea --- /dev/null +++ b/src/data/components/wilds_transforms.py @@ -0,0 +1,451 @@ +# SOURCE: WILDS code +# https://github.com/p-lambda/wilds/blob/472677590de351857197a9bf24958838c39c272b/examples/transforms.py + +import copy +from typing import List + +import numpy as np +import torch +import torchvision.transforms as transforms +import torchvision.transforms.functional as TF +# from transformers import BertTokenizerFast, DistilBertTokenizerFast # won't be needed + +import torch +from PIL import Image, ImageOps, ImageEnhance, ImageDraw + +# -------------------------------------------------------------------------------------------------- +# Adapted from https://github.com/YBZh/Bridging_UDA_SSL + +def AutoContrast(img, _): + return ImageOps.autocontrast(img) + +def Brightness(img, v): + assert v >= 0.0 + return ImageEnhance.Brightness(img).enhance(v) + +def Color(img, v): + assert v >= 0.0 + return ImageEnhance.Color(img).enhance(v) + +def Contrast(img, v): + assert v >= 0.0 + return ImageEnhance.Contrast(img).enhance(v) + +def Equalize(img, _): + return ImageOps.equalize(img) + +def Invert(img, _): + return ImageOps.invert(img) + +def Identity(img, v): + return img + +def Posterize(img, v): # [4, 8] + v = int(v) + v = max(1, v) + return ImageOps.posterize(img, v) + +def Rotate(img, v): # [-30, 30] + return img.rotate(v) + +def Sharpness(img, v): # [0.1,1.9] + assert v >= 0.0 + return ImageEnhance.Sharpness(img).enhance(v) + +def ShearX(img, v): # [-0.3, 0.3] + return img.transform(img.size, Image.AFFINE, (1, v, 0, 0, 1, 0)) + +def ShearY(img, v): # [-0.3, 0.3] + return img.transform(img.size, Image.AFFINE, (1, 0, 0, v, 1, 0)) + +def TranslateX(img, v): # [-150, 150] => percentage: [-0.45, 0.45] + v = v * img.size[0] + return img.transform(img.size, Image.AFFINE, (1, 0, v, 0, 1, 0)) + +def TranslateXabs(img, v): # [-150, 150] => percentage: [-0.45, 0.45] + return img.transform(img.size, Image.AFFINE, (1, 0, v, 0, 1, 0)) + +def TranslateY(img, v): # [-150, 150] => percentage: [-0.45, 0.45] + v = v * img.size[1] + return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, v)) + +def TranslateYabs(img, v): # [-150, 150] => percentage: [-0.45, 0.45] + return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, v)) + +def Solarize(img, v): # [0, 256] + assert 0 <= v <= 256 + return ImageOps.solarize(img, v) + +def Cutout(img, v): # [0, 60] => percentage: [0, 0.2] => change to [0, 0.5] + assert 0.0 <= v <= 0.5 + + v = v * img.size[0] + return CutoutAbs(img, v) + +def CutoutAbs(img, v): # [0, 60] => percentage: [0, 0.2] + if v < 0: + return img + w, h = img.size + x_center = _sample_uniform(0, w) + y_center = _sample_uniform(0, h) + + x0 = int(max(0, x_center - v / 2.0)) + y0 = int(max(0, y_center - v / 2.0)) + x1 = min(w, x0 + v) + y1 = min(h, y0 + v) + + xy = (x0, y0, x1, y1) + color = (125, 123, 114) + img = img.copy() + ImageDraw.Draw(img).rectangle(xy, color) + return img + +FIX_MATCH_AUGMENTATION_POOL = [ + (AutoContrast, 0, 1), + (Brightness, 0.05, 0.95), + (Color, 0.05, 0.95), + (Contrast, 0.05, 0.95), + (Equalize, 0, 1), + (Identity, 0, 1), + (Posterize, 4, 8), + (Rotate, -30, 30), + (Sharpness, 0.05, 0.95), + (ShearX, -0.3, 0.3), + (ShearY, -0.3, 0.3), + (Solarize, 0, 256), + (TranslateX, -0.3, 0.3), + (TranslateY, -0.3, 0.3), +] + +def _sample_uniform(a, b): + return torch.empty(1).uniform_(a, b).item() + +class RandAugment: + def __init__(self, n, augmentation_pool): + assert n >= 1, "RandAugment N has to be a value greater than or equal to 1." + self.n = n + self.augmentation_pool = augmentation_pool + + def __call__(self, img): + ops = [ + self.augmentation_pool[torch.randint(len(self.augmentation_pool), (1,))] + for _ in range(self.n) + ] + for op, min_val, max_val in ops: + val = min_val + float(max_val - min_val) * _sample_uniform(0, 1) + img = op(img, val) + cutout_val = _sample_uniform(0, 1) * 0.5 + img = Cutout(img, cutout_val) + return img +# -------------------------------------------------------------------------------------------------- + +_DEFAULT_IMAGE_TENSOR_NORMALIZATION_MEAN = [0.485, 0.456, 0.406] +_DEFAULT_IMAGE_TENSOR_NORMALIZATION_STD = [0.229, 0.224, 0.225] + +def initialize_transform( + transform_name, config, dataset, is_training, additional_transform_name=None +): + """ + By default, transforms should take in `x` and return `transformed_x`. + For transforms that take in `(x, y)` and return `(transformed_x, transformed_y)`, + set `do_transform_y` to True when initializing the WILDSSubset. + """ + if transform_name is None: + return None + elif transform_name == "bert": + return initialize_bert_transform(config) + elif transform_name == 'rxrx1': + return initialize_rxrx1_transform(is_training) + + # For images + normalize = True + if transform_name == "image_base": + transform_steps = get_image_base_transform_steps(config, dataset) + elif transform_name == "image_resize": + transform_steps = get_image_resize_transform_steps( + config, dataset + ) + elif transform_name == "image_resize_and_center_crop": + transform_steps = get_image_resize_and_center_crop_transform_steps( + config, dataset + ) + elif transform_name == "poverty": + if not is_training: + return None + transform_steps = [] + normalize = False + else: + raise ValueError(f"{transform_name} not recognized") + + default_normalization = transforms.Normalize( + _DEFAULT_IMAGE_TENSOR_NORMALIZATION_MEAN, + _DEFAULT_IMAGE_TENSOR_NORMALIZATION_STD, + ) + if additional_transform_name == "fixmatch": + if transform_name == 'poverty': + transformations = add_poverty_fixmatch_transform(config, dataset, transform_steps) + else: + transformations = add_fixmatch_transform( + config, dataset, transform_steps, default_normalization + ) + transform = MultipleTransforms(transformations) + elif additional_transform_name == "randaugment": + if transform_name == 'poverty': + transform = add_poverty_rand_augment_transform( + config, dataset, transform_steps + ) + else: + transform = add_rand_augment_transform( + config, dataset, transform_steps, default_normalization + ) + elif additional_transform_name == "weak": + transform = add_weak_transform( + config, dataset, transform_steps, normalize, default_normalization + ) + else: + if transform_name != "poverty": + # The poverty data is already a tensor at this point + transform_steps.append(transforms.ToTensor()) + if normalize: + transform_steps.append(default_normalization) + transform = transforms.Compose(transform_steps) + + return transform + + +def initialize_bert_transform(config): + def get_bert_tokenizer(model): + if model == "bert-base-uncased": + return BertTokenizerFast.from_pretrained(model) + elif model == "distilbert-base-uncased": + return DistilBertTokenizerFast.from_pretrained(model) + else: + raise ValueError(f"Model: {model} not recognized.") + + assert "bert" in config.model + assert config.max_token_length is not None + + tokenizer = get_bert_tokenizer(config.model) + + def transform(text): + tokens = tokenizer( + text, + padding="max_length", + truncation=True, + max_length=config.max_token_length, + return_tensors="pt", + ) + if config.model == "bert-base-uncased": + x = torch.stack( + ( + tokens["input_ids"], + tokens["attention_mask"], + tokens["token_type_ids"], + ), + dim=2, + ) + elif config.model == "distilbert-base-uncased": + x = torch.stack((tokens["input_ids"], tokens["attention_mask"]), dim=2) + x = torch.squeeze(x, dim=0) # First shape dim is always 1 + return x + + return transform + +def initialize_rxrx1_transform(is_training): + def standardize(x: torch.Tensor) -> torch.Tensor: + mean = x.mean(dim=(1, 2)) + std = x.std(dim=(1, 2)) + std[std == 0.] = 1. + return TF.normalize(x, mean, std) + t_standardize = transforms.Lambda(lambda x: standardize(x)) + + angles = [0, 90, 180, 270] + def random_rotation(x: torch.Tensor) -> torch.Tensor: + angle = angles[torch.randint(low=0, high=len(angles), size=(1,))] + if angle > 0: + x = TF.rotate(x, angle) + return x + t_random_rotation = transforms.Lambda(lambda x: random_rotation(x)) + + if is_training: + transforms_ls = [ + t_random_rotation, + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + t_standardize, + ] + else: + transforms_ls = [ + transforms.ToTensor(), + t_standardize, + ] + transform = transforms.Compose(transforms_ls) + return transform + +def get_image_base_transform_steps(config, dataset) -> List: + transform_steps = [] + + if dataset.original_resolution is not None and min( + dataset.original_resolution + ) != max(dataset.original_resolution): + crop_size = min(dataset.original_resolution) + transform_steps.append(transforms.CenterCrop(crop_size)) + + if config.target_resolution is not None: + transform_steps.append(transforms.Resize(config.target_resolution)) + + return transform_steps + + +def get_image_resize_and_center_crop_transform_steps(config, dataset) -> List: + """ + Resizes the image to a slightly larger square then crops the center. + """ + transform_steps = get_image_resize_transform_steps(config, dataset) + target_resolution = _get_target_resolution(config, dataset) + transform_steps.append( + transforms.CenterCrop(target_resolution), + ) + return transform_steps + + +def get_image_resize_transform_steps(config, dataset) -> List: + """ + Resizes the image to a slightly larger square. + """ + assert dataset.original_resolution is not None + assert config.resize_scale is not None + scaled_resolution = tuple( + int(res * config.resize_scale) for res in dataset.original_resolution + ) + return [ + transforms.Resize(scaled_resolution) + ] + +def add_fixmatch_transform(config, dataset, base_transform_steps, normalization): + return ( + add_weak_transform(config, dataset, base_transform_steps, True, normalization), + add_rand_augment_transform(config, dataset, base_transform_steps, normalization) + ) + +def add_poverty_fixmatch_transform(config, dataset, base_transform_steps): + return ( + add_weak_transform(config, dataset, base_transform_steps, False, None), + add_poverty_rand_augment_transform(config, dataset, base_transform_steps) + ) + +def add_weak_transform(config, dataset, base_transform_steps, should_normalize, normalization): + # Adapted from https://github.com/YBZh/Bridging_UDA_SSL + target_resolution = _get_target_resolution(config, dataset) + weak_transform_steps = copy.deepcopy(base_transform_steps) + weak_transform_steps.extend( + [ + transforms.RandomHorizontalFlip(), + transforms.RandomCrop( + size=target_resolution, + ), + ] + ) + if should_normalize: + weak_transform_steps.append(transforms.ToTensor()) + weak_transform_steps.append(normalization) + return transforms.Compose(weak_transform_steps) + +def add_rand_augment_transform(config, dataset, base_transform_steps, normalization): + # Adapted from https://github.com/YBZh/Bridging_UDA_SSL + target_resolution = _get_target_resolution(config, dataset) + strong_transform_steps = copy.deepcopy(base_transform_steps) + strong_transform_steps.extend( + [ + transforms.RandomHorizontalFlip(), + transforms.RandomCrop( + size=target_resolution + ), + RandAugment( + n=config.randaugment_n, + augmentation_pool=FIX_MATCH_AUGMENTATION_POOL, + ), + transforms.ToTensor(), + normalization, + ] + ) + return transforms.Compose(strong_transform_steps) + +def poverty_rgb_color_transform(ms_img, transform): + from wilds.datasets.poverty_dataset import _MEANS_2009_17, _STD_DEVS_2009_17 + poverty_rgb_means = np.array([_MEANS_2009_17[c] for c in ['RED', 'GREEN', 'BLUE']]).reshape((-1, 1, 1)) + poverty_rgb_stds = np.array([_STD_DEVS_2009_17[c] for c in ['RED', 'GREEN', 'BLUE']]).reshape((-1, 1, 1)) + + def unnormalize_rgb_in_poverty_ms_img(ms_img): + result = ms_img.detach().clone() + result[:3] = (result[:3] * poverty_rgb_stds) + poverty_rgb_means + return result + + def normalize_rgb_in_poverty_ms_img(ms_img): + result = ms_img.detach().clone() + result[:3] = (result[:3] - poverty_rgb_means) / poverty_rgb_stds + return ms_img + + color_transform = transforms.Compose([ + transforms.Lambda(lambda ms_img: unnormalize_rgb_in_poverty_ms_img(ms_img)), + transform, + transforms.Lambda(lambda ms_img: normalize_rgb_in_poverty_ms_img(ms_img)), + ]) + # The first three channels of the Poverty MS images are BGR + # So we shuffle them to the standard RGB to do the ColorJitter + # Before shuffling them back + ms_img[:3] = color_transform(ms_img[[2,1,0]])[[2,1,0]] # bgr to rgb to bgr + return ms_img + +def add_poverty_rand_augment_transform(config, dataset, base_transform_steps): + def poverty_color_jitter(ms_img): + return poverty_rgb_color_transform( + ms_img, + transforms.ColorJitter(brightness=0.8, contrast=0.8, saturation=0.8, hue=0.1)) + + def ms_cutout(ms_img): + def _sample_uniform(a, b): + return torch.empty(1).uniform_(a, b).item() + + assert ms_img.shape[1] == ms_img.shape[2] + img_width = ms_img.shape[1] + cutout_width = _sample_uniform(0, img_width/2) + cutout_center_x = _sample_uniform(0, img_width) + cutout_center_y = _sample_uniform(0, img_width) + x0 = int(max(0, cutout_center_x - cutout_width/2)) + y0 = int(max(0, cutout_center_y - cutout_width/2)) + x1 = int(min(img_width, cutout_center_x + cutout_width/2)) + y1 = int(min(img_width, cutout_center_y + cutout_width/2)) + + # Fill with 0 because the data is already normalized to mean zero + ms_img[:, x0:x1, y0:y1] = 0 + return ms_img + + target_resolution = _get_target_resolution(config, dataset) + strong_transform_steps = copy.deepcopy(base_transform_steps) + strong_transform_steps.extend([ + transforms.RandomHorizontalFlip(), + transforms.RandomVerticalFlip(), + transforms.RandomAffine(degrees=10, translate=(0.1, 0.1), shear=0.1, scale=(0.9, 1.1)), + transforms.Lambda(lambda ms_img: poverty_color_jitter(ms_img)), + transforms.Lambda(lambda ms_img: ms_cutout(ms_img)), + # transforms.Lambda(lambda ms_img: viz(ms_img)), + ]) + + return transforms.Compose(strong_transform_steps) + +def _get_target_resolution(config, dataset): + if config.target_resolution is not None: + return config.target_resolution + else: + return dataset.original_resolution + + +class MultipleTransforms(object): + """When multiple transformations of the same data need to be returned.""" + + def __init__(self, transformations): + self.transformations = transformations + + def __call__(self, x): + return tuple(transform(x) for transform in self.transformations) \ No newline at end of file diff --git a/src/data/diagvib_datamodules.py b/src/data/diagvib_datamodules.py index 501add9..12bae15 100644 --- a/src/data/diagvib_datamodules.py +++ b/src/data/diagvib_datamodules.py @@ -3,7 +3,7 @@ import os import os.path as osp -from torch.utils.data import DataLoader, Subset, ConcatDataset, SequentialSampler +from torch.utils.data import DataLoader, Subset, ConcatDataset from pytorch_lightning import LightningDataModule from diagvibsix.data.dataset.preprocess_mnist import get_processed_mnist @@ -11,10 +11,9 @@ from src.data.components import MultienvDataset from src.data.components.collate_functions import MultiEnv_collate_fn from src.data.components.diagvib_dataset import DiagVib6DatasetPA, select_dataset_spec -from src.pa_metric_torch import PosteriorAgreementSampler +from src.pa_metric.pairing import PosteriorAgreementDatasetPairing import torch.distributed -from torch.utils.data.distributed import DistributedSampler class DiagVibDataModuleMultienv(LightningDataModule): """ @@ -133,7 +132,7 @@ def train_dataloader(self): batch_size=self.hparams.batch_size, num_workers=self.hparams.num_workers, pin_memory=self.hparams.pin_memory, - drop_last=True, + drop_last=True, # for LISA shuffle=True, collate_fn=self.collate_fn, ) @@ -145,7 +144,7 @@ def val_dataloader(self): batch_size=self.hparams.batch_size, num_workers=self.hparams.num_workers, pin_memory=self.hparams.pin_memory, - drop_last=True, + drop_last=False, # for LISA shuffle=False, collate_fn=self.collate_fn ) @@ -165,46 +164,35 @@ def __init__(self, def setup(self, stage: Optional[str] = None): dataset_specs_path, cache_filepath = select_dataset_spec(dataset_dir=self.datasets_dir, dataset_name= self.envs_name + str(self.envs_index[0])) - self.test_ds1 = DiagVib6DatasetPA( + self.ds1 = DiagVib6DatasetPA( mnist_preprocessed_path = self.mnist_preprocessed_path, dataset_specs_path=dataset_specs_path, cache_filepath=cache_filepath, t='test') dataset_specs_path, cache_filepath = select_dataset_spec(dataset_dir=self.datasets_dir, dataset_name= self.envs_name + str(self.envs_index[1])) - self.test_ds2 = DiagVib6DatasetPA( + self.ds2 = DiagVib6DatasetPA( mnist_preprocessed_path = self.mnist_preprocessed_path, dataset_specs_path=dataset_specs_path, cache_filepath=cache_filepath, t='test') - self.test_ds2_shifted = self._apply_shift_ratio(self.test_ds1, self.test_ds2) - self.test_pairedds = MultienvDataset([self.test_ds1, self.test_ds2_shifted]) - - def _set_sampler(self): - """ - I don't need to disable the shuffling in the DistributedSampler to get corresponding observations X and X', as these are paired in the - collate function. Nevertheless, I want to control strictly the data that is used for the PA optimization so that I can compare with the metric. - """ - - ddp_init = torch.distributed.is_available() and torch.distributed.is_initialized() - if ddp_init: - return PosteriorAgreementSampler(self.test_pairedds, shuffle=False, drop_last = True) - else: - return PosteriorAgreementSampler(self.test_pairedds, shuffle=False, drop_last = True, num_replicas=1, rank=0) + self.ds2_shifted = self._apply_shift_ratio(self.ds1, self.ds2) + self.train_ds = PosteriorAgreementDatasetPairing(MultienvDataset([self.ds1, self.ds2_shifted])) def train_dataloader(self): return DataLoader( - dataset=self.test_pairedds, + dataset=self.train_ds, batch_size=self.hparams.batch_size, num_workers=self.hparams.num_workers, pin_memory=self.hparams.pin_memory, collate_fn=MultiEnv_collate_fn, - sampler=self._set_sampler() + shuffle=True, + drop_last=False ) def val_dataloader(self): - return self.train_dataloader() + return self.train_dataloader() # I don't need to shuffle but it's irrelevant if I do. def _apply_shift_ratio(self, ds1, ds2): """Generates the two-environment dataset adding (1-shift_ratio)*len(ds2) samples of ds1 to ds2. @@ -232,74 +220,8 @@ def _apply_shift_ratio(self, ds1, ds2): class DiagVibDataModulePAlogits(LogitsPA, DiagVibDataModulePA): def __init__(self, classifier: torch.nn.Module, *args, **kwargs): - super().__init__(*args, **kwargs) - self.classifier = classifier - + super().__init__(classifier) + DiagVibDataModulePA.__init__(self, *args, **kwargs) if __name__ == "__main__": - LightningDataModule() - - -# EXAMPLE OF USE: -""" -import numpy as np -import matplotlib.pyplot as plt - -dm = DiagVibDataModuleMultienv( - envs_index = [0, 1], - envs_name = 'bal', - datasets_dir = "/cluster/home/vjimenez/adv_pa_new/data/dg/dg_datasets/submission/toy_dataset/", - disjoint_envs = True, - train_val_sequential = True, - mnist_preprocessed_path="data/dg/mnist_processed.npz", - collate_fn = MultiEnv_collate_fn, - batch_size = 5 -) - -dm.prepare_data() -dm.setup() -train_dl = iter(dm.train_dataloader()) -output = train_dl.__next__() -print(output.keys()) - -for i in range(2): - images = output[str(i)][0] - targets = output[str(i)][1] - for j in range(2): - im = np.transpose(images[j], (1, 2, 0)) - target = targets[j] - plt.imshow(im) - plt.title(str(target)) - plt.savefig("/cluster/home/vjimenez/adv_pa_new/" + f"train_{i}_{j}.png") - - -dm = DiagVibDataModuleTestPA( - envs_index = [0, 1], - envs_name = 'bal', - datasets_dir = "/cluster/home/vjimenez/adv_pa_new/data/dg/dg_datasets/submission/toy_dataset/", - mnist_preprocessed_path="data/dg/mnist_processed.npz", - collate_fn = MultiEnv_collate_fn, - batch_size = 5 -) - -dm.prepare_data() -dm.setup() -test_dl = iter(dm.train_dataloader()) -output = test_dl.__next__() - -print(type(output)) -print(output.keys()) - - -import matplotlib.pyplot as plt - -for i in range(2): - images = output[str(i)][0] - targets = output[str(i)][1] - for j in range(2): - im = np.transpose(images[j], (1, 2, 0)) - target = targets[j] - plt.imshow(im) - plt.title(str(target)) - plt.savefig("/cluster/home/vjimenez/adv_pa_new/" + f"test_{i}_{j}.png") -""" \ No newline at end of file + LightningDataModule() \ No newline at end of file diff --git a/src/data/wilds_datamodules.py b/src/data/wilds_datamodules.py new file mode 100644 index 0000000..8f62fb6 --- /dev/null +++ b/src/data/wilds_datamodules.py @@ -0,0 +1,140 @@ +from typing import Callable, Optional, Union, List + +from omegaconf import DictConfig, OmegaConf + +import os.path as osp +import torch.distributed as dist +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler +from pytorch_lightning import LightningDataModule +from pytorch_lightning.trainer.supporters import CombinedLoader + +from wilds import get_dataset +from src.data.components.wilds_dataset import WILDSDatasetEnv + +class WILDSDataModule(LightningDataModule): + def __init__( + self, + dataset_name: str, + n_classes: int, + train_config: Union[dict, DictConfig], + val_config: Optional[Union[dict, DictConfig]] = None, + test_config: Optional[Union[dict, DictConfig]] = None, + transform: Optional[Callable] = None, # TODO: Figure out the transform thing + dataset_dir: str = osp.join(".", "data", "datasets"), + cache: bool = True, + batch_size: int = 64, + pin_memory: bool = True, + num_workers: int = 0, + multiple_trainloader_mode='max_size_cycle', + ): + + super().__init__() + self.save_hyperparameters(logger=False) + + # Select train, val and test configurations + self.hparams.train_config = OmegaConf.to_container(train_config, resolve=True) if isinstance(train_config, DictConfig) else train_config + if val_config is not None: + self.hparams.val_config = OmegaConf.to_container(val_config, resolve=True) if isinstance(val_config, DictConfig) else val_config + if test_config is not None: + self.hparams.test_config = OmegaConf.to_container(test_config, resolve=True) if isinstance(test_config, DictConfig) else test_config + + self.train_dset_list, self.val_dset_list, self.test_dset_list = None, None, None + + @property + def num_classes(self): + return self.hparams.n_classes + + def prepare_data(self): + # If the dataset does not exist or cache is set to False, download data + if osp.exists(self.hparams.dataset_dir) == False or self.hparams.cache == False: + get_dataset( + dataset=self.hparams.dataset_name, + download=True, + unlabeled=False, + root_dir=self.hparams.dataset_dir + ) + + def setup(self, stage: Optional[str] = None): + self.dataset = get_dataset( + dataset=self.hparams.dataset_name, + download=False, + unlabeled=False, + root_dir=self.hparams.dataset_dir + ) + + if stage == "fit": + self.num_train_envs = len(self.hparams.train_config.keys()) + self.train_dset_list = [] + for env in self.hparams.train_config.keys(): + env_dset = WILDSDatasetEnv( + dataset=self.dataset, + env_config=self.hparams.train_config[env], + transform=self.hparams.transform + ) + self.train_dset_list.append(env_dset) + + if self.hparams.val_config is not None: + self.num_val_envs = len(self.hparams.val_config.keys()) + self.val_dset_list = [] + for env in self.hparams.val_config.keys(): + env_dset = WILDSDatasetEnv( + dataset=self.dataset, + env_config=self.hparams.val_config[env], + transform=self.hparams.transform + ) + self.val_dset_list.append(env_dset) + + if stage == "test": + if self.hparams.test_config is not None: + self.num_test_envs = len(self.hparams.test_config.keys()) + self.test_dset_list = [] + for env in self.hparams.test_config.keys(): + env_dset = WILDSDatasetEnv( + dataset=self.dataset, + env_config=self.hparams.test_config[env], + transform=self.hparams.transform + ) + self.test_dset_list.append(env_dset) + + def train_dataloader(self): + # Dictionary of dataloaders for the training. + return { + str(i): DataLoader( + dataset=ds, + batch_size=self.hparams.batch_size, + num_workers=self.hparams.num_workers, + pin_memory=self.hparams.pin_memory, + sampler = DistributedSampler(ds, drop_last=True, shuffle=True) if dist.is_initialized() else DistributedSampler(ds, drop_last=True, shuffle=True, num_replicas=1, rank=0), + ) for i, ds in enumerate(self.train_dset_list) + } + + def val_dataloader(self): + """ + We set `shuffle=True` because each dataset has different size, and we want that the probability of each sample to be selected + within each dataset is the same. + """ + if self.val_dset_list is not None: + return CombinedLoader({ + str(i): DataLoader( + dataset=ds, + batch_size=self.hparams.batch_size, + num_workers=self.hparams.num_workers, + pin_memory=self.hparams.pin_memory, + sampler = DistributedSampler(ds, drop_last=True, shuffle=False) if dist.is_initialized() else DistributedSampler(ds, drop_last=True, shuffle=False, num_replicas=1, rank=0), + ) for i, ds in enumerate(self.val_dset_list) + }, self.hparams.multiple_trainloader_mode) + + def test_dataloader(self): + """ + Here we can set `shuffle=False`, because only one dataset/environment is used. + """ + if self.test_dset_list is not None: + assert len(self.test_dset_list) == 1, "The test dataloader must only contain one environment." + return DataLoader( + dataset=self.test_dset_list[0], + batch_size=self.hparams.batch_size, + num_workers=self.hparams.num_workers, + pin_memory=self.hparams.pin_memory, + sampler = DistributedSampler(self.test_dset_list[0], drop_last=False, shuffle=False, num_replicas=1, rank=0) + ) \ No newline at end of file diff --git a/src/generate_dg_data.py b/src/generate_dg_data.py index 771f6ff..00d0861 100644 --- a/src/generate_dg_data.py +++ b/src/generate_dg_data.py @@ -1,10 +1,11 @@ import hydra -from omegaconf import DictConfig +from omegaconf import DictConfig, OmegaConf from pytorch_lightning import seed_everything import os import numpy as np import pandas as pd +import json import pyrootutils pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) @@ -105,7 +106,7 @@ def main(cfg: DictConfig): os.makedirs(DATASETS_DIR, exist_ok=True) TO_BALANCE = cfg.get("BALANCE") - envs_name = cfg.get("envs_name") + file_name = cfg.get("file_name") SIZE_TRAIN = cfg.get("SIZE_TRAIN") SIZE_VAL = cfg.get("SIZE_VAL") @@ -117,7 +118,7 @@ def main(cfg: DictConfig): ## TRAINING & VALIDATION --------------------------------------------- # hue, lightness, texture, position, scale train_val_especs = cfg.get("train_val_especs") - train_val_envs = list(train_val_especs.keys()) + train_val_envs = list(train_val_especs.keys()) # number of environments tran_val_randperm = cfg.get("train_val_randperm") sizes = [SIZE_TRAIN, SIZE_VAL] @@ -152,7 +153,7 @@ def main(cfg: DictConfig): } ) - path = DATASETS_DIR + task + "_" + envs_name + str(t) + ".csv" + path = DATASETS_DIR + task + "_" + file_name + str(t) + ".csv" df.to_csv(path, index=False) if TO_BALANCE: balance_dataset(path, path, batch_size = cfg.get("BATCH_SIZE")) @@ -183,7 +184,7 @@ def main(cfg: DictConfig): } ) - path = DATASETS_DIR + "test_" + envs_name + str(t) + ".csv" + path = DATASETS_DIR + "test_" + file_name + str(t) + ".csv" df.to_csv(path, index=False) if t == 0 and TO_BALANCE: # we balance the first, and then copy the permutation to the rest balance_dataset(path, path, batch_size = cfg.get("BATCH_SIZE")) @@ -195,6 +196,11 @@ def main(cfg: DictConfig): df['permutation'] = master_permutation df.to_csv(path, index=False) + # Finally, we also store the configuration that generated such dataset: + dict_cfg = OmegaConf.to_container(cfg, resolve=True) + with open(DATASETS_DIR + "config.json", 'w') as json_file: + json.dump(dict_cfg, json_file, indent=4) + if __name__ == "__main__": main() \ No newline at end of file diff --git a/src/models/PA_module.py b/src/models/PA_module.py index 52bac76..350bd3b 100644 --- a/src/models/PA_module.py +++ b/src/models/PA_module.py @@ -43,7 +43,7 @@ def model_step(self, batch: dict): self.kernel.beta.data.clamp_(min=0.0) self.kernel.reset() - env_names = batch["envs"] + env_names = list(batch.keys()) x1, x2 = batch[env_names[0]][0], batch[env_names[1]][0] with torch.no_grad(): @@ -56,12 +56,17 @@ def model_step(self, batch: dict): def training_step(self, train_batch: Any, batch_idx: int): o1, o2, loss = self.model_step(train_batch) + env_names = list(train_batch.keys()) + + import ipdb; ipdb.set_trace() + print("METRIC sums env0", [torch.sum(train_batch['0'][0][0]) for i in range(16)]) + print(f"\nTRAIN MODULE: batch idx {batch_idx}", torch.sum(train_batch['0'][0][0]), len(o1), o1, o2) if self.current_epoch == 0: # AFR does not change during the epochs y_pred = torch.argmax(o1.data, 1) y_pred_adv = torch.argmax(o2.data, 1) - y_true = train_batch[train_batch["envs"][0]][1] - assert torch.equal(y_true, train_batch[train_batch["envs"][1]][1]), "The true label tensors are not equal." + y_true = train_batch[env_names[0]][1] + assert torch.equal(y_true, train_batch[env_names[1]][1]), "The true label tensors are not equal." # Second, compute the AFR values = { @@ -78,13 +83,13 @@ def on_train_batch_end(self, out, batch, bidx): self.betas.append(self.kernel.beta.item()) def on_validation_start(self): - if self.trainer.is_last_batch: + if self.trainer.is_last_batch: self.model.eval() self.kernel.reset() def validation_step(self, batch: Any, bidx: int): if self.trainer.is_last_batch: # last batch for the trainer - env_names = batch["envs"] + env_names = list(batch.keys()) x1, x2 = batch[env_names[0]][0], batch[env_names[1]][0] o1, o2 = self.model(x1), self.model(x2) self.kernel.evaluate(self.betas[-1], o1, o2) diff --git a/src/models/erm.py b/src/models/erm.py new file mode 100644 index 0000000..f26f1af --- /dev/null +++ b/src/models/erm.py @@ -0,0 +1,86 @@ +import torch +from torch import nn, optim +from omegaconf import OmegaConf, DictConfig +from pytorch_lightning import LightningModule +from pytorch_lightning.core.optimizer import LightningOptimizer +from torchmetrics import Accuracy, F1Score, Recall, Specificity, Precision + +class ERM(LightningModule): + """Vanilla ERM traning scheme for fitting a NN to the training data""" + + def __init__( + self, + n_classes: int, + net: nn.Module, + loss: nn.Module, + optimizer: optim.Optimizer, + scheduler: DictConfig, + ): + super().__init__() + + self.model = net + self.loss = loss + self.save_hyperparameters(ignore=["net"]) + + def _extract_batch(self, batch: Union[dict, tuple]): + """ + The batch can come from either a multienvironment CombinedLoader or from a single environment DataLoader. That is + equivalent to using a `MultiEnv_collate_fn` or a `SingleEnv_collate_fn`. + """ + if isinstance(batch, dict): + x = torch.cat([batch[env][0] for env in batch.keys()], dim=0) + y = torch.cat([batch[env][1] for env in batch.keys()]) + return x, y + else: + return batch + + def training_step(self, batch, batch_idx): + x, y = self._extract_batch(batch) + + logits = self.model(x) + return { + "loss": self.loss(input=logits, target=y), + "logits": logits, + "targets": y, + "preds": torch.argmax(logits, dim=1) + } + + def validation_step(self, batch, batch_idx): + x, y = self._combined_loader_to_single(batch) if isinstance(batch, dict) else batch + + logits = self.model(x) + return { + "loss": self.loss(input=logits, target=y), + "logits": logits, + "targets": y, + "preds": torch.argmax(logits, dim=1) + } + + def test_step(self, batch, batch_idx): + x, y = batch + + logits = self.model(x) + return { + "loss": self.loss(input=logits, target=y), + "logits": logits, + "targets": y, + "preds": torch.argmax(logits, dim=1) + } + + def configure_optimizers(self): + optimizer = LightningOptimizer(self.hparams.optimizer(params=self.parameters())) + + if self.hparams.scheduler: + scheduler = self.hparams.scheduler.scheduler(optimizer=optimizer) + + scheduler_dict = OmegaConf.to_container(self.hparams.scheduler, resolve=True) # convert to normal dict + scheduler_dict.update({ + "scheduler": scheduler, + }) + + return { + "optimizer": optimizer, + "lr_scheduler": scheduler_dict + } + + return {"optimizer": optimizer} diff --git a/src/models/erm_module.py b/src/models/erm_module.py deleted file mode 100644 index d2912ac..0000000 --- a/src/models/erm_module.py +++ /dev/null @@ -1,170 +0,0 @@ -from pytorch_lightning import LightningModule -from torch import nn, argmax, optim -from torchmetrics import Accuracy, F1Score, Recall, Specificity, Precision - -# For the PA metric -from src.pa_metric_torch import PosteriorAgreement -from src.data.diagvib_datamodules import DiagVibDataModulePA -from src.data.components.collate_functions import MultiEnv_collate_fn -from copy import deepcopy - -class ERM(LightningModule): - """Vanilla ERM traning scheme for fitting a NN to the training data""" - - def __init__( - self, - n_classes: int, - net: nn.Module, - optimizer: optim.Optimizer, - scheduler: optim.lr_scheduler, - ): - super().__init__() - - self.model = None - self.loss = None - self.save_hyperparameters(ignore=["net"]) # for easier retrieval from w&b and sanity checks - - self.n_classes = int(n_classes) - _task = "multiclass" if self.n_classes > 2 else "binary" - - # Training metrics - self.train_acc = Accuracy(task=_task, num_classes=self.n_classes, average="macro") - self.train_f1 = F1Score(task=_task, num_classes=self.n_classes, average="macro") - - self.train_specificity = Specificity(task=_task, num_classes=self.n_classes, average="macro") - self.train_sensitivity = Recall(task=_task, num_classes=self.n_classes, average="macro") - self.train_precision = Precision(task=_task, num_classes=self.n_classes, average="macro") - - self.val_acc = Accuracy(task=_task, num_classes=self.n_classes, average="macro") - self.val_f1 = F1Score(task=_task, num_classes=self.n_classes, average="macro") - self.val_specificity = Specificity(task=_task, num_classes=self.n_classes, average="macro") - self.val_sensitivity = Recall(task=_task, num_classes=self.n_classes, average="macro") - self.val_precision = Precision(task=_task, num_classes=self.n_classes, average="macro") - - # TO DELETE - # dm = DiagVibDataModuleTestPA( - # envs_index = [0, 1], - # shift_ratio = 1.0, - # envs_name = "val_repbal", # here the full name not only the environment, as we may want to use test_ or val_ or even a custom name - # datasets_dir = "./data/dg/dg_datasets/replicate/", - # mnist_preprocessed_path = "./data/dg/mnist_processed.npz", - # batch_size = 64, - # num_workers = 2, - # pin_memory = True, - # collate_fn = MultiEnv_collate_fn) - - # dm.prepare_data() - # dm.setup() - - # # DEBUGGING PA - # self.PA = PosteriorAgreement( - # dataset = dm.test_pairedds, - # pa_epochs = 100, - # early_stopping=[0.001, 5, 10], - # strategy = "lightning") - - def training_step(self, batch, batch_idx): - x, y = batch - logits = self.model(x) - - return {"logits": logits, "targets": y} - - def training_step_end(self, outputs): - logits, y = outputs["logits"], outputs["targets"] - preds = argmax(logits, dim=1) - - # Log training metrics - metrics_dict = { - "train/loss": self.loss(input=logits, target=y), - "train/acc": self.train_acc(preds, y), - "train/f1": self.train_f1(preds, y), - "train/sensitivity": self.train_sensitivity(preds, y), - "train/specificity": self.train_specificity(preds, y), - "train/precision": self.train_precision(preds, y), - } - self.log_dict(metrics_dict, prog_bar=False, on_step=True, on_epoch=True, logger=True, sync_dist=True) - - # Return loss for optimization - return {"loss": metrics_dict["train/loss"]} - - def validation_step(self, batch, batch_idx): - x, y = batch - logits = self.model(x) - - return {"logits": logits, "targets": y} - - def validation_step_end(self, outputs): - logits, y = outputs["logits"], outputs["targets"] - preds = argmax(logits, dim=1) - - #Log PA in the last batch of the epoch. Log every n epochs. - # if self.trainer.is_last_batch and self.current_epoch % 2 == 0: - # self.PA.update(deepcopy(self.model)) - # pa_dict = self.PA.compute() - # metrics_dict = { - # "val/logPA": pa_dict["logPA"], - # "val/beta": pa_dict["beta"], - # "val/PA": pa_dict["PA"], - # "val/AFR pred": pa_dict["AFR pred"], - # "val/AFR true": pa_dict["AFR true"], - # "val/acc_pa": pa_dict["acc_pa"], - # } - # self.log_dict(metrics_dict, prog_bar=False, on_step=False, on_epoch=True, logger=True, sync_dist=True) - - # Log validation metrics - metrics_dict = { - "val/loss": self.loss(input=logits, target=y), - "val/acc": self.val_acc(preds, y), - "val/f1": self.val_f1(preds, y), - "val/sensitivity": self.val_sensitivity(preds, y), - "val/specificity": self.val_specificity(preds, y), - "val/precision": self.val_precision(preds, y), - } - self.log_dict(metrics_dict, prog_bar=False, on_step=True, on_epoch=True, logger=True, sync_dist=True) - - # Return loss for scheduler - return {"loss": metrics_dict["val/loss"]} - - def predict_step(self, predic_batch, batch_idx): - return self.model(predic_batch) - - def configure_optimizers(self): - """Choose what optimizers and learning-rate schedulers to use in your optimization. - Examples: - https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#configure-optimizers - """ - optimizer = self.hparams.optimizer(params=self.parameters()) - if (self.hparams.scheduler is not None) and self.trainer.datamodule.val_dataloader(): - scheduler = self.hparams.scheduler(optimizer=optimizer) - return { - "optimizer": optimizer, - "lr_scheduler": { - "scheduler": scheduler, - "monitor": "val/loss", - "interval": "epoch", - "frequency": 1, - }, - } - return {"optimizer": optimizer} - - -class ERMPerceptron(ERM): - def __init__(self, lr, weight_decay, n_classes, optimizer, momentum): - super().__init__(lr, weight_decay, n_classes, optimizer, momentum) - - self.model = nn.Linear(2, 2) - self.n_classes = n_classes - - -class ERMMnist(ERM): - def __init__( - self, - n_classes: int, - net: nn.Module, - optimizer: optim.Optimizer, - scheduler: optim.lr_scheduler, - ): - super().__init__(n_classes, net, optimizer, scheduler) - - self.model = net - self.loss = nn.CrossEntropyLoss() diff --git a/src/models/irm.py b/src/models/irm.py new file mode 100644 index 0000000..6462cbf --- /dev/null +++ b/src/models/irm.py @@ -0,0 +1,113 @@ +from pytorch_lightning import LightningModule, LightningDataModule +from pytorch_lightning.core.optimizer import LightningOptimizer +import torch +import os.path as osp +from torch import nn, argmax, optim +from torch.autograd import grad +from torchmetrics import Accuracy, F1Score, Recall, Specificity, Precision + +# For the PA metric +from src.pa_metric_torch import PosteriorAgreement +from src.data.diagvib_datamodules import DiagVibDataModulePA +from src.data.components.collate_functions import MultiEnv_collate_fn +from copy import deepcopy + +class IRM(LightningModule): + """Invariant Risk Minimization (IRM) module.""" + + def __init__( + self, + n_classes: int, + net: nn.Module, + loss: nn.Module, + optimizer: optim.Optimizer, + scheduler: DictConfig, + + lamb: float = 1.0 + ): + super().__init__() + + self.model = net + self.loss = loss + self.save_hyperparameters(ignore=["net"]) + + def compute_penalty(self, logits, y): + """ + Computes the additional penalty term to achieve invariant representation. + """ + dummy_w = torch.tensor(1.).to(self.device).requires_grad_() + with torch.enable_grad(): + loss = self.loss(logits*dummy_w, y).to(self.device) + gradient = grad(loss, [dummy_w], create_graph=True)[0] + return gradient**2 + + def training_step(self, batch, batch_idx): + loss = 0 + ys, preds = [], [] + for env in list(batch.keys()): + x, y = batch[env] + + logits = self.model(x) + penalty = self.compute_penalty(logits, y) + loss += self.loss(logits, y) + self.lamb*penalty + + ys.append(y) + preds.append(argmax(logits, dim=1)) + + return { + "loss": loss, + "logits": logits, + "targets": torch.cat(ys, dim=0), + "preds": torch.cat(preds, dim=0) + } + + def validation_step(self, batch, batch_idx): + loss = 0 + ys, preds = [], [] + for env in list(batch.keys()): + x, y = batch[env] + + logits = self.model(x) + penalty = self.compute_penalty(logits, y) + loss += self.loss(logits, y) + self.lamb*penalty + + ys.append(y) + preds.append(argmax(logits, dim=1)) + + return { + "loss": loss, + "logits": logits, + "targets": torch.cat(ys, dim=0), + "preds": torch.cat(preds, dim=0) + } + + def test_step(self, batch, batch_idx): + assert len(batch.keys()) == 1, "The test batch should have only one environment." + x, y = batch + + logits = self.model(x) + penalty = self.compute_penalty(logits, y) + return { + "loss": self.loss(logits, y) + self.lamb*penalty, + "logits": logits, + "targets": y, + "preds": argmax(logits, dim=1) + } + + def configure_optimizers(self): + optimizer = LightningOptimizer(self.hparams.optimizer(params=self.parameters())) + + if self.hparams.scheduler: + scheduler = self.hparams.scheduler.scheduler(optimizer=optimizer) + + scheduler_dict = OmegaConf.to_container(self.hparams.scheduler, resolve=True) # convert to normal dict + scheduler_dict.update({ + "scheduler": scheduler, + }) + + return { + "optimizer": optimizer, + "lr_scheduler": scheduler_dict + } + + return {"optimizer": optimizer} diff --git a/src/models/irm_module.py b/src/models/irm_module.py deleted file mode 100644 index 500dbb8..0000000 --- a/src/models/irm_module.py +++ /dev/null @@ -1,226 +0,0 @@ -from pytorch_lightning import LightningModule, LightningDataModule -import torch -import os.path as osp -from torch import nn, argmax, optim -from torch.autograd import grad -from torchmetrics import Accuracy, F1Score, Recall, Specificity, Precision - -# For the PA metric -from src.pa_metric_torch import PosteriorAgreement -from src.data.diagvib_datamodules import DiagVibDataModulePA -from src.data.components.collate_functions import MultiEnv_collate_fn -from copy import deepcopy - -class IRM(LightningModule): - """Invariant Risk Minimization (IRM) module.""" - - def __init__( - self, - n_classes: int, - net: nn.Module, - optimizer: optim.Optimizer, - scheduler: optim.lr_scheduler, - lamb: float = 1.0 - ): - super().__init__() - - self.model = None - self.loss = None - self.lamb = lamb - self.save_hyperparameters(ignore=["net"]) # for easier retrieval from w&b and sanity checks - - self.n_classes = int(n_classes) - _task = "multiclass" if self.n_classes > 2 else "binary" - - # Training metrics - self.train_acc = Accuracy(task=_task, num_classes=self.n_classes, average="macro") - self.train_f1 = F1Score(task=_task, num_classes=self.n_classes, average="macro") - self.train_specificity = Specificity(task=_task, num_classes=self.n_classes, average="macro") - self.train_sensitivity = Recall(task=_task, num_classes=self.n_classes, average="macro") - self.train_precision = Precision(task=_task, num_classes=self.n_classes, average="macro") - - # Validation metrics - self.val_acc = Accuracy(task=_task, num_classes=self.n_classes, average="macro") - self.val_f1 = F1Score(task=_task, num_classes=self.n_classes, average="macro") - self.val_specificity = Specificity(task=_task, num_classes=self.n_classes, average="macro") - self.val_sensitivity = Recall(task=_task, num_classes=self.n_classes, average="macro") - self.val_precision = Precision(task=_task, num_classes=self.n_classes, average="macro") - - # PA metric - # TO DELETE - # dm = DiagVibDataModuleTestPA( - # envs_index = [0, 1], - # shift_ratio = 1.0, - # envs_name = "val_randnobal", # here the full name not only the environment, as we may want to use test_ or val_ or even a custom name - # datasets_dir = "./data/dg/dg_datasets/randnobal/", - # mnist_preprocessed_path = "./data/dg/mnist_processed.npz", - # batch_size = 64, - # num_workers = 2, - # pin_memory = True, - # collate_fn = MultiEnv_collate_fn) - - # dm.prepare_data() - # dm.setup() - - # self.PA = PosteriorAgreement( - # dataset = dm.test_pairedds, - # pa_epochs = 10, - # strategy = "lightning") - - def compute_penalty(self, logits, y): - dummy_w = torch.tensor(1.).to(self.device).requires_grad_() - with torch.enable_grad(): - loss = self.loss(logits*dummy_w, y).to(self.device) - gradient = grad(loss, [dummy_w], create_graph=True)[0] - return gradient**2 - - def training_step(self, batch, batch_idx): - envs = batch["envs"] - outputs = {} - for env in envs: - x, y = batch[env] - logits = self.model(x) - - outputs[env] = { - "logits": logits, - "targets": y, - "penalty": self.compute_penalty(logits, y) - } - - return outputs - - def training_step_end(self, outputs): - # Separate the training_step bc we want to sum losses from different GPUs here. - loss = 0 - preds = [] - ys = [] - for env, output in outputs.items(): - logits = output["logits"] - y = output["targets"] - penalty = output["penalty"] - - env_loss = self.loss(logits, y) - loss += env_loss + self.lamb*penalty - - ys.append(y) - preds.append(argmax(logits, dim=1)) - y = torch.cat(ys, dim=0) - preds = torch.cat(preds, dim=0) - - # Log training metrics - metrics_dict = { - "train/loss": loss, - "train/acc": self.train_acc(preds, y), - "train/f1": self.train_f1(preds, y), - "train/sensitivity": self.train_sensitivity(preds, y), - "train/specificity": self.train_specificity(preds, y), - "train/precision": self.train_precision(preds, y), - } - self.log_dict(metrics_dict, prog_bar=False, on_step=True, on_epoch=True, logger=True, sync_dist=True) - - # Return loss for optimization - return {"loss": loss} - - - def validation_step(self, batch, batch_idx): - envs = batch["envs"] - - outputs = {} - for env in envs: - x, y = batch[env] - - with torch.enable_grad(): - logits = self.model(x) - - outputs[env] = { - "logits": logits, - "targets": y, - "penalty": self.compute_penalty(logits, y) - } - - return outputs - - - def validation_step_end(self, outputs): - loss = 0 - preds = [] - ys = [] - for env, output in outputs.items(): - logits = output["logits"] - y = output["targets"] - penalty = output["penalty"] - - env_loss = self.loss(logits, y) - loss += env_loss + self.lamb*penalty - - ys.append(y) - preds.append(argmax(logits, dim=1)) - y = torch.cat(ys, dim=0) - preds = torch.cat(preds, dim=0) - - # Log PA in the last batch of the epoch. Log every n epochs. - # if self.trainer.is_last_batch and self.current_epoch % 2 == 0: - # self.PA.update(deepcopy(self.model)) - # pa_dict = self.PA.compute() - # metrics_dict = { - # "val/logPA": pa_dict["logPA"], - # "val/beta": pa_dict["beta"], - # "val/PA": pa_dict["PA"], - # "val/AFR pred": pa_dict["AFR pred"], - # "val/AFR true": pa_dict["AFR true"], - # "val/acc_pa": pa_dict["acc_pa"], - # } - # self.log_dict(metrics_dict, prog_bar=False, on_step=False, on_epoch=True, logger=True, sync_dist=True) - - # Log validation metrics - metrics_dict = { - "val/loss": loss, - "val/acc": self.val_acc(preds, y), - "val/f1": self.val_f1(preds, y), - "val/sensitivity": self.val_sensitivity(preds, y), - "val/specificity": self.val_specificity(preds, y), - "val/precision": self.val_precision(preds, y), - } - self.log_dict(metrics_dict, prog_bar=False, on_step=True, on_epoch=True, logger=True, sync_dist=True) - - # Return loss for scheduler - return {"loss": loss} - - def configure_optimizers(self): - optimizer = self.hparams.optimizer(params=self.parameters()) - if (self.hparams.scheduler is not None) and self.trainer.datamodule.val_dataloader(): - scheduler = self.hparams.scheduler(optimizer=optimizer) - return { - "optimizer": optimizer, - "lr_scheduler": { - "scheduler": scheduler, - "monitor": "val/loss", - "interval": "epoch", - "frequency": 1, - }, - } - - return {"optimizer": optimizer} - - -class IRMMnist(IRM): - def __init__( - self, - n_classes: int, - net: nn.Module, - optimizer: optim.Optimizer, - scheduler: optim.lr_scheduler, - lamb: float = 1.0 # Penalty weight for IRM - ): - super().__init__(n_classes, net, optimizer, scheduler, lamb) - - self.model = net - self.loss = nn.CrossEntropyLoss() - -class IRMPerceptron(IRM): - def __init__(self, lr, weight_decay, n_classes, optimizer, momentum, lamb: float = 1.0): - super().__init__(lr, weight_decay, n_classes, optimizer, momentum, lamb) - - self.model = nn.Linear(2, 2) - self.n_classes = n_classes - diff --git a/src/models/lisa_module.py b/src/models/lisa.py similarity index 64% rename from src/models/lisa_module.py rename to src/models/lisa.py index 4cac6ef..87f81d0 100644 --- a/src/models/lisa_module.py +++ b/src/models/lisa.py @@ -1,4 +1,4 @@ -from src.models.erm_module import ERM +from src.models.erm import ERM import torch from torch import nn, Tensor, optim @@ -7,6 +7,8 @@ from copy import deepcopy +from src.models.utils import + class LISA(ERM): """ Implements selective augmentation on top of ERM. @@ -29,12 +31,17 @@ def __init__( self, n_classes: int, net: nn.Module, + loss: nn.Module, optimizer: optim.Optimizer, - scheduler: optim.lr_scheduler, + scheduler: DictConfig, + + mixup_strategy: str = "mixup", ppred: float = 0.5, # probability of LISA-L mix_alpha: float = 0.5 # mixup weight ): - super().__init__(n_classes, net, optimizer, scheduler) + super().__init__(n_classes, net, loss, optimizer, scheduler) + assert mixup_strategy in ["mixup", "cutmix"], "The mixup strategy must be either 'mixup' or 'cutmix'." + self.save_hyperparameters(ignore=['net']) def to_one_hot(self, target, C): @@ -46,7 +53,7 @@ def to_one_hot(self, target, C): def from_one_hot(self, one_hot): """Converts one-hot-encoded tensor to a tensor of labels.""" - + _, indices = torch.max(one_hot, 1) return indices.to(self.device) @@ -80,6 +87,35 @@ def mix_up(self, mix_alpha: float, x: Tensor, y: Tensor, x2: Tensor = None, y2: mixed_y = l_y * y1 + (1 - l_y) * y2 return mixed_x, mixed_y + def cut_mix(self, mix_alpha: float, x, y): + def _rand_bbox(size, lam): + W = size[2] + H = size[3] + cut_rat = torch.sqrt(1. - lam).to(self.device) + cut_w = (W * cut_rat).to(torch.int32).to(self.device) + cut_h = (H * cut_rat).to(torch.int32).to(self.device) + + # uniform + cx = torch.randint(0, W, (1,)).item() + cy = torch.randint(0, H, (1,)).item() + + bbx1 = torch.clamp(cx - torch.div(cut_w, 2, rounding_mode='trunc'), 0, W).to(self.device) + bby1 = torch.clamp(cy - torch.div(cut_h, 2, rounding_mode='trunc'), 0, H).to(self.device) + bbx2 = torch.clamp(cx + torch.div(cut_w, 2, rounding_mode='trunc'), 0, W).to(self.device) + bby2 = torch.clamp(cy + torch.div(cut_h, 2, rounding_mode='trunc'), 0, H).to(self.device) + return bbx1, bby1, bbx2, bby2 + + rand_index = torch.randperm(len(y)).to(self.device) + lam = Beta(mix_alpha, mix_alpha).sample().to(self.device) + target_a = y + target_b = y[rand_index] + bbx1, bby1, bbx2, bby2 = _rand_bbox(x.size(), lam) + x[:, :, bbx1:bbx2, bby1:bby2] = x[rand_index, :, bbx1:bbx2, bby1:bby2] + # adjust lambda to exactly match pixel ratio + lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (x.size()[-1] * x.size()[-2])) + + return x, lam*target_a + (1-lam)*target_b + def pair_lisa(self, cat: Tensor): """ It pairs observations with different attributes. It is important to note that the @@ -116,15 +152,18 @@ def pair_lisa(self, cat: Tensor): return B_1, B_2 # indexes - def model_step(self, batch: dict): + def selective_augmentation(self, batch: dict): """ Implements mixup and selective augmentation. """ - # get data and convert env to tensor - x = torch.cat([batch[env][0] for env in batch["envs"]]) - y = torch.cat([batch[env][1] for env in batch["envs"]]) - envs = [env for env in batch["envs"] for _ in range(len(batch[batch["envs"][0]][1]))] + x = torch.cat([batch[env][0] for env in batch.keys()]) + y = torch.cat([batch[env][1] for env in batch.keys()]) + envs = [ + env + for env in batch.keys() + for _ in range(len(batch[env][1])) + ] all_inds = torch.arange(len(envs)).to(self.device) env_to_int = {item: i for i, item in enumerate(set(envs))} @@ -157,44 +196,30 @@ def model_step(self, batch: dict): B1_lab, B2_lab = self.pair_lisa(y[mask]) # indexes wrt mask #print("LISA-D, samples with the same domain:", len(y[mask])) - # accumulate indexes wrt all observations + # accumulate indexes wrt all observations B1 = torch.cat((B1, torch.index_select(all_inds[mask], 0, B1_lab))) B2 = torch.cat((B2, torch.index_select(all_inds[mask], 0, B2_lab))) - # mixup - mixed_x, mixed_y = self.mix_up(self.hparams.mix_alpha, - torch.index_select(input=x, dim=0, index=B1.sort()[0]).to(self.device), - self.to_one_hot(torch.index_select(input=y, dim=0, index=B1.sort()[0]), self.hparams.n_classes), - torch.index_select(input=x, dim=0, index=B2.sort()[0]).to(self.device), - self.to_one_hot(torch.index_select(input=y, dim=0, index=B2.sort()[0]), self.hparams.n_classes)) - return mixed_x, self.from_one_hot(mixed_y) + if self.hparams.mixup_strategy == "cutmix": + joined_indexes = torch.cat([B1, B2]).sort()[0] + mixed_x, mixed_y = self.cut_mix( + self.hparams.mix_alpha, + torch.index_select(input=x, dim=0, index=joined_indexes).to(self.device), + torch.index_select(input=y, dim=0, index=joined_indexes).to(self.device) + ) + mixed_y = mixed_y.long() + else: # mixup + mixed_x, mixed_y = self.mix_up( + self.hparams.mix_alpha, + torch.index_select(input=x, dim=0, index=B1.sort()[0]).to(self.device), + self.to_one_hot(torch.index_select(input=y, dim=0, index=B1.sort()[0]), self.hparams.n_classes), + torch.index_select(input=x, dim=0, index=B2.sort()[0]).to(self.device), + self.to_one_hot(torch.index_select(input=y, dim=0, index=B2.sort()[0]), self.hparams.n_classes) + ) + mixed_y = self.from_one_hot(mixed_y) + return mixed_x, mixed_y - def training_step(self, batch, batch_idx): - mixed_x, mixed_y = self.model_step(batch) - logits = self.model(mixed_x) - - return {"logits": logits, "targets": mixed_y} - - def validation_step(self, batch, batch_idx): - mixed_x, mixed_y = self.model_step(batch) - logits = self.model(mixed_x) - - return {"logits": logits, "targets": mixed_y} - - -class LISAMnist(LISA): - def __init__( - self, - n_classes: int, - net: nn.Module, - optimizer: optim.Optimizer, - scheduler: optim.lr_scheduler, - ppred: float = 0.5, # probability of LISA-L - mix_alpha: float = 0.5 - ): - super().__init__(n_classes, net, optimizer, scheduler, ppred, mix_alpha) - - self.model = net - self.loss = nn.CrossEntropyLoss() + def _extract_batch(self, batch: dict): + return self.selective_augmentation(batch) \ No newline at end of file diff --git a/src/models/utils.py b/src/models/utils.py index b3eaeb8..72253db 100644 --- a/src/models/utils.py +++ b/src/models/utils.py @@ -1,5 +1,4 @@ from torch import Tensor - def AFR(y_pred_adv: Tensor, y_pred_clean: Tensor): return 1.0 - (y_pred_adv != y_pred_clean).sum() / len(y_pred_adv) diff --git a/src/pa_metric/callback.py b/src/pa_metric/callback.py deleted file mode 100644 index b1a25d3..0000000 --- a/src/pa_metric/callback.py +++ /dev/null @@ -1,53 +0,0 @@ -from typing import Optional, List -import torch.nn.functional as F - -# To implement it in a LightningModule without receiving errors -from pytorch_lightning.callbacks import Callback -from pytorch_lightning import Trainer, LightningModule, LightningDataModule -from copy import deepcopy - -from .metric import PosteriorAgreement - -class PA_Callback(Callback): - def __init__(self, - log_every_n_epochs: int, - pa_epochs: int, - datamodule: LightningDataModule, - beta0: Optional[float], - early_stopping: Optional[List] = None): - - """ - Incorporation of the PA Metric to the Lightning training procedure. A LightningDataModule is required to generate the logits, and this - is either provided during initialization or else the validation DataLoader is used. - """ - super().__init__() - - self.beta0 = beta0 - self.pa_epochs = pa_epochs - self.log_every_n_epochs = log_every_n_epochs - self.early_stopping = early_stopping - - # TODO: Check that the dataset is not stored twice if its the same as the validation dataset. - datamodule.prepare_data() - datamodule.setup() - self.PosteriorAgreement = PosteriorAgreement( - dataset = datamodule.test_pairedds, - beta0 = self.beta0, - pa_epochs = self.pa_epochs, - early_stopping = self.early_stopping, - strategy = "lightning" - ) - - def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule): - if (pl_module.current_epoch + 1) % self.log_every_n_epochs == 0: - self.PosteriorAgreement.update(classifier=deepcopy(pl_module.model)) - pa_dict = self.PosteriorAgreement.compute() - metrics_dict = { - "val/logPA": pa_dict["logPA"], - "val/beta": pa_dict["beta"], - "val/PA": pa_dict["PA"], - "val/AFR pred": pa_dict["AFR pred"], - "val/AFR true": pa_dict["AFR true"], - "val/acc_pa": pa_dict["acc_pa"], - } - self.log_dict(metrics_dict, prog_bar=False, on_step=False, on_epoch=True, logger=True, sync_dist=True) \ No newline at end of file diff --git a/src/pa_metric/crossvalidator.py b/src/pa_metric/crossvalidator.py deleted file mode 100644 index 77a1030..0000000 --- a/src/pa_metric/crossvalidator.py +++ /dev/null @@ -1,212 +0,0 @@ -import torch -from torch.utils.data import Dataset, Subset, ConcatDataset -from torchvision import datasets -from torchvision.transforms import ToTensor - -from torch.utils.data import DataLoader - - -#-------------------------------------------------- -#print(training_data.targets) -from abc import ABCMeta, abstractmethod -from sklearn.model_selection._split import BaseCrossValidator - - -from typing import Optional, List, Literal -import csv - -class Custom_CV(BaseCrossValidator): - """Custom CV implementation using pytorch and designed to substitute sklearn method. - - Args: - train_ds (Dataset): Dataset with the whole training data. - config_csv (str): Path to csv file with the splitting configuration. - test_ds (Optional[Dataset]): Dataset with the whole test data. If provided, all splits will be use for training. - shuffle (Optional[Literal["random", "sequential", "paired"]]): Shuffle strategy for the test split, either "random", "sequential" (i.e. not shuffled) and "paired", by which samples of different splits are label-correspondent. - random_state (Optional[int], optional): Random state. Defaults to 123. - - Returns: - fold_counter: Informs of the fold index. Every fold can contain multiple training groups. - group_info: A tuple indicating (group index, total number of groups). - train_ind_tensor: A tensor with the indexes for the training dataset, that will constitute the train split. - test_ind_list: A list containing two tensors of indexes for the test dataset (if provided) or the train dataset, that will constitute the two environments of the test split. If finally only test dataset wants to be used, use ConcatDataset. - - Example of use: No test dataset is provided, and the same split is used for testing each fold. - ds0,ds1,ds2,ds3,ds4 - 0,1,1,2,2 - 0,1,2,1,2 - 0,2,1,1,2 - 0,1,2,2,1 - 0,2,1,2,1 - 0,2,2,1,1 - - cv = Custom_CV(train_ds=BigDS, config_csv=..., shuffle="paired") - for i, (fold, group_info, train_idx, test_idx) in enumerate(cv.split()): - print(f"\nFold {fold}") - group_idx, num_groups = group_info # useful to average the metrics for example. - - train_ds = Subset(BigDS, train_idx) - test_ds1 = Subset(BigDS, test_idx[0]) - test_ds2 = Subset(BigDS, test_idx[1]) - - # train and evaluate (metric) - - It is important to notice that the training data is always the same size, but the size of the test data varies slightly in the "paired" setup due to the non-uniformity of the target distribution. - The test indexes for every element of the group are the same, only the training varies. - - """ - - def __init__(self, - train_ds: Dataset, - config_csv: str, - test_ds: Optional[Dataset] = None, - shuffle: Optional[Literal["random", "sequential", "paired"]] = "random", - random_state: Optional[int] = 123): - - self.train_dataset = train_ds - if test_ds: - self.test_dataset = test_ds - else: - self.test_dataset = None - - # Turn CSV into a list of lists - self.configlist = [] - with open(config_csv, newline='') as csvfile: - reader = csv.reader(csvfile) - next(reader) # Skip the header row - for row in reader: - self.configlist.append([int(item) for item in row]) - - self.n_folds = self.get_n_folds() - self.n_splits = self.get_n_splits() - - if shuffle not in ["random", "sequential", "paired"]: - raise ValueError("shuffle must be 'random', 'sequential' or 'paired'; got {0}".format(shuffle)) - - if shuffle != "sequential" and type(random_state) is not int: - raise ValueError("random_state must be an integer when shuffle is not 'sequential'; got {0}".format(random_state)) - - self.shuffle = shuffle - - def get_n_folds(self): - """Returns the number of folds in the cross-validator.""" - if len(self.configlist) < 1: - raise ValueError("Configuration list is empty.") - - return len(self.configlist) - - def get_n_splits(self): - """Returns the number of splits in the cross-validator.""" - n_splits_0 = len(self.configlist[0]) - for f in range(self.n_folds): # loop to check before training - n_splits = len(self.configlist[f]) - if n_splits != n_splits_0: - raise ValueError("Configuration must specify the same number of splits for each fold.") - if n_splits < 1: - raise ValueError("Configuration must include at least one split for each fold.") - - return n_splits_0 - - def _next_split_config(self, bool_test_ds: bool = False): - """Generates the configuration dictionary for the next split.""" - - for f in range(self.n_folds): # loop to check before training - if not bool_test_ds and self.n_splits < 2: - raise ValueError("Since no test set is provided, configuration must include at least two splits for each fold.") - - for f in range(self.n_folds): - # Get to same numbers for consistency - if not bool_test_ds: # If we dont specify a test dataset, take smaller group index as test - ind_zero = 0 - else: - ind_zero = 1 - group_inds = sorted(list(set(self.configlist[f]))) - map_to_inds = {el: i + ind_zero for i, el in enumerate(group_inds)} - group_inds = [map_to_inds[el] for el in self.configlist[f]] - - # Get group dictionary - group_dict = {} - for i, el in enumerate(group_inds): - if el not in group_dict: - group_dict[el] = [] - group_dict[el].append(i) - if bool_test_ds: - group_dict[0] = None - - yield group_dict - - def get_n_samples(self, dataset: Dataset): - """Returns the number of samples in the dataset.""" - return len(dataset) - - def paired_indexes(self, dataset:Dataset, pair_splits: int): - """Computes a list of indexes for every split in a way that targets match.""" - - n_samples = self.get_n_samples(dataset) - inds = torch.arange(n_samples) - - try: - labs = dataset.targets - except AttributeError: # when its not a full Dataset, but a Subset or ConcatDataset - labs = torch.tensor([dataset.__getitem__(i)[1] for i in range(len(dataset))]) - - unique_labs = labs.unique() - inds_mask = [inds[labs.eq(unique_lab.item())] for unique_lab in unique_labs] # indexes for every label - inds_mask = [mask[torch.randperm(mask.size(0))] for mask in inds_mask] # randomly permute the indexes for every label (unnecessary) - n_split_lab = min([mask.size(0) // pair_splits for mask in inds_mask]) - split_permutation = torch.randperm(n_split_lab*len(unique_labs)) - - indexes = [ - torch.cat([mask[n*n_split_lab:(n+1)*n_split_lab] for mask in inds_mask] - )[split_permutation] # same permutation to all the splits (to mix labels but keep correspondence) - for n in range(pair_splits)] - - return indexes - - - def split(self): - """Generate indices to split data.""" - - # Get train indexes list. These are only training if no test_dataset is provided. - n_train_samples = self.get_n_samples(self.train_dataset) - n_samples_split = n_train_samples // self.n_splits - train_inds = torch.arange(n_train_samples) - if self.shuffle != "sequential": # if sequential, both training and test are produced in a sequential way - train_inds = train_inds[torch.randperm(n_train_samples)] - train_indexes = [train_inds[n*n_samples_split:(n+1)*n_samples_split] for n in range(self.n_splits)] - - # Get test indexes list. These are only generated if there is a test_dataset. - if self.test_dataset: # the specified splits only concern the training dataset - n_test_samples = self.get_n_samples(self.test_dataset) - n_samples_test_split = n_test_samples // 2 # because we want the test to be divived in two - test_inds = torch.arange(n_test_samples) - - if self.shuffle == "paired": # we pair the test indexes, train indexes are randomly permuted - test_ind_list = self.paired_indexes(self.test_dataset, pair_splits = 2) - - else: # either random or sequential - if self.shuffle == "random": # test is random, train is random - test_inds = test_inds[torch.randperm(n_test_samples)] - test_ind_list = [test_inds[n*n_samples_test_split:(n+1)*n_samples_test_split] for n in range(2)] - - fold_counter = -1 - for sconfig in self._next_split_config(True if self.test_dataset else False): - fold_counter += 1 - group_names = sorted(list(sconfig.keys()))[1:] # all but 0, which is for the test split - num_groups = len(group_names) - - if not self.test_dataset: # the test_ind_list with 2 elements must be generated - test_inds = torch.cat([train_indexes[i] for i in sconfig[0]]) # join train splits specified for testing - if self.shuffle == "paired": - test_ind_list = self.paired_indexes(Subset(self.train_dataset, test_inds), pair_splits = 2) # indexes wrt Subset - test_ind_list = [test_inds[test_ind_list[i]] for i in range(2)] # indexes wrt self.train_dataset - else: - if self.shuffle == "random": - test_inds = test_inds[torch.randperm(len(test_inds))] - - n_samples_test_split = len(test_inds) // 2 # impose two groups - test_ind_list = [test_inds[n*n_samples_test_split:(n+1)*n_samples_test_split] for n in range(2)] - - for group in group_names: - train = torch.cat([train_indexes[i] for i in sconfig[group]]) - yield fold_counter, (group-1, num_groups), train, test_ind_list \ No newline at end of file diff --git a/src/pa_metric/kernel.py b/src/pa_metric/kernel.py deleted file mode 100644 index 0cd5741..0000000 --- a/src/pa_metric/kernel.py +++ /dev/null @@ -1,58 +0,0 @@ -import torch -from torch import nn -from typing import Optional -import torch.nn.functional as F - -class PosteriorAgreementKernel(nn.Module): - def __init__(self, beta0: Optional[float] = None, device: str = "cpu"): - super().__init__() - beta0 = beta0 if beta0 else 1.0 - if beta0 < 0.0: - raise ValueError("'beta' must be non-negative.") - - self.dev = device - self.beta = torch.nn.Parameter(torch.tensor([beta0], dtype=torch.float), requires_grad=True).to(self.dev) - self.log_post = torch.tensor([0.0], requires_grad=True).to(self.dev) - - def forward(self, preds1, preds2): - self.beta.requires_grad_(True) - self.beta.data.clamp_(min=0.0) - self.reset() - - with torch.set_grad_enabled(True): - probs1 = F.softmax(self.beta * preds1, dim=1).to(self.dev) - probs2 = F.softmax(self.beta * preds2, dim=1).to(self.dev) - - probs_sum = (probs1 * probs2).sum(dim=1).to(self.dev) - - # log correction for numerical stability: replace values less than eps - # with eps, in a gradient compliant way. Replace nans in gradients - # deriving from 0 * inf - probs_sum = probs_sum + (probs_sum < 1e-44) * (1e-44 - probs_sum) - if probs_sum.requires_grad: - probs_sum.register_hook(torch.nan_to_num) - - #self.log_post += torch.log(probs_sum).sum(dim=0) - self.log_post = self.log_post + torch.log(probs_sum).sum(dim=0).to(self.dev) - return -self.log_post - - def evaluate(self, beta_opt, preds1, preds2): - with torch.set_grad_enabled(False): - probs1 = F.softmax(beta_opt * preds1, dim=1).to(self.dev) - probs2 = F.softmax(beta_opt * preds2, dim=1).to(self.dev) - probs_sum = (probs1 * probs2).sum(dim=1).to(self.dev) - self.log_post = self.log_post + torch.log(probs_sum).sum(dim=0).to(self.dev) - - def reset(self): - self.log_post = torch.tensor([0.0], requires_grad=True).to(self.dev) - - def log_posterior(self): - return self.log_post.clone().to(self.dev) - - def posterior(self): - return torch.exp(self.log_post).to(self.dev) - - @property - def module(self): - """Returns the kernel itself. It helps the kernel be accessed in both DDP and non-DDP mode.""" - return self \ No newline at end of file diff --git a/src/pa_metric/metric.py b/src/pa_metric/metric.py deleted file mode 100644 index 276140b..0000000 --- a/src/pa_metric/metric.py +++ /dev/null @@ -1,378 +0,0 @@ -import os -import warnings - -import torch -from torchmetrics import Metric -from typing import Optional, List, Union -from torch.utils.data import DataLoader - -import torch.distributed as dist -import torch.multiprocessing as mp -from torch.utils.data import SequentialSampler -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.distributed import init_process_group -from src.data.components import MultienvDataset, LogitsDataset -from src.data.components.collate_functions import MultiEnv_collate_fn - -from .sampler import PosteriorAgreementSampler -from .kernel import PosteriorAgreementKernel - -#TODO: Check this out - # # Set to True if the metric is differentiable else set to False - # is_differentiable: Optional[bool] = None - - # # Set to True if the metric reaches it optimal value when the metric is maximized. - # # Set to False if it when the metric is minimized. - # higher_is_better: Optional[bool] = True - - # # Set to True if the metric during 'update' requires access to the global metric - # # state for its calculations. If not, setting this to False indicates that all - # # batch states are independent and we will optimize the runtime of 'forward' - # full_state_update: bool = True - -class PosteriorAgreementSimple(Metric): - def __init__(self, - pa_epochs: int, - beta0: Optional[float] = None): - super().__init__() - - self.dev = "cuda" if torch.cuda.is_available() else "cpu" - self.beta0 = beta0 - self.pa_epochs = pa_epochs - - # Preallocate metrics to track - self.afr_true = torch.zeros(self.pa_epochs).to(self.dev) # metrics live in master process device - self.afr_pred = torch.zeros_like(self.afr_true) - self.accuracy = torch.zeros_like(self.afr_true) - - # Kernel and optimizer are initialized right away - self.kernel = PosteriorAgreementKernel(beta0=self.beta0).to(self.dev) - self.optimizer = torch.optim.Adam([self.kernel.module.beta], lr=0.01) - - def update(self, logits_dataloader: DataLoader): - """ - For this simple version, the logits will be passed and the kernel optimized. - """ - - self.betas = torch.zeros_like(self.afr_true) - self.logPAs = torch.full_like(self.afr_true, -float('inf')) - for epoch in range(self.pa_epochs): - for bidx, batch in enumerate(logits_dataloader): - self.kernel.module.beta.data.clamp_(min=0.0) - self.kernel.module.reset() - - envs = batch["envs"] - logits0, logits1 = batch[envs[0]][0], batch[envs[1]][0] - with torch.set_grad_enabled(True): - loss = self.kernel.module.forward(logits0, logits1) - loss.backward() - self.optimizer.step() - self.optimizer.zero_grad() - - self.kernel.module.beta.data.clamp_(min=0.0) # project to >=0 one last time - beta_last = self.kernel.module.beta.item() - self.betas[epoch] = beta_last - - # Compute logPA with the last beta per epoch - self.kernel.module.reset() - correct, correct_pred, correct_true = 0, 0, 0 - for bidx, batch in enumerate(logits_dataloader): - envs = batch["envs"] - logits0, logits1, y = batch[envs[0]][0], batch[envs[1]][0], batch[envs[0]][1] - - # Compute accuracy metrics if desired - if y and epoch == 0: - y_pred = torch.argmax(logits0.to(self.dev), 1) # env 1 - y_pred_adv = torch.argmax(logits1.to(self.dev), 1) # env 2 - correct_pred += (y_pred_adv == y_pred).sum().item() - correct_true += (y_pred_adv == y).sum().item() - correct += (torch.cat([y_pred, y_pred_adv]).to(self.dev) == torch.cat([y, y]).to(self.dev)).sum().item() - - # Update logPA - self.kernel.module.evaluate(beta_last, logits0, logits1) - - # Retrieve final logPA - self.logPAs[epoch] = self.kernel.module.log_posterior().item() - - if y and epoch == 0: # retrieve accuracy metrics - self.afr_pred[epoch] = correct_pred/len(logits0) - self.afr_true[epoch] = correct_true/len(logits0) - self.accuracy[epoch] = correct/(2*len(logits0)) - - # Locate the highest PA. - self.selected_index = torch.argmax(self.logPAs).item() - - def compute(self): - """ - Only meant to be used at the end of the PA optimization. - """ - return { - "beta": self.betas[self.selected_index], - "logPA": self.logPAs[self.selected_index], - "PA": torch.exp(self.logPAs[self.selected_index]), # TODO: Fix this small error - "AFR pred": self.afr_pred[self.selected_index], - "AFR true": self.afr_true[self.selected_index], - "acc_pa": self.accuracy[self.selected_index] - } - - -class PosteriorAgreement(PosteriorAgreementSimple): - def __init__(self, - dataset: MultienvDataset, - early_stopping: Optional[List] = None, - strategy: Optional[str] = "cuda", - cuda_devices: Optional[Union[List[str], int]] = None, - *args, **kwargs): - - if strategy not in ["cuda", "cpu", "lightning"]: - raise ValueError("The strategy must be either 'cuda', 'cpu' or 'lightning'.") - - super().__init__(*args, **kwargs) - - self.strategy = strategy - - # Initialize multiprocessing configuration - self.ddp_init = None - if self.strategy != "lightning": # cuda or cpu - if dist.is_initialized(): # ongoing cuda - self.device_list = [f"cuda:{i}" for i in range(dist.get_world_size())] - else: # non initialized cuda or cpu - if cuda_devices: - if isinstance(cuda_devices, int): - cuda_devices = [f"cuda:{i}" for i in range(cuda_devices)] - self.device_list = cuda_devices if (torch.cuda.is_available() and self.strategy == "cuda") else ["cpu"] - else: - self.device_list = ["cuda"] if (torch.cuda.is_available() and self.strategy == "cuda") else ["cpu"] - self.ddp_init = [False]*len(self.device_list) if "cuda" in self.device_list[0] else None - self.dev = self.device_list[0] - else: # Depending where the metric is initialized we will have to update it or not, but the accuracy tensors are already here - self.dev = "cuda" if dist.is_initialized() else "cpu" - self.device_list = [self.dev] - - print("IS CUDA AVAILABLE: ", torch.cuda.is_available()) - print("device list: ", self.device_list) - print("dev: ", self.dev) - print("ddp_init: ", self.ddp_init) - - # Check dataloader conditions - if not isinstance(dataset, MultienvDataset): - raise ValueError("The dataloader must be wrapped using a MultienvDataset.") - - self.dataloader = DataLoader( - dataset=dataset, - sampler=PosteriorAgreementSampler(dataset, shuffle=False, drop_last=True, num_replicas=len(self.device_list), rank=0), - collate_fn=MultiEnv_collate_fn, - shuffle=False, # we use custom sampler - - # Decide whether this has to be set as config input or not - batch_size = 64, - num_workers = 0, # 4*len(self.device_list) if "cuda" in self.dev else max(2, min(8, os.cpu_count())), - pin_memory = ("cuda" in self.dev), - ) - self.num_envs = self.dataloader.sampler.dataset.num_envs # get the modified version - self.batch_size = self.dataloader.batch_size - self.num_batches = self.dataloader.sampler.num_samples // self.batch_size - - # Define early stopping parameters - self.tol, self.itertol, self.patience = None, None, 0 - if early_stopping: - self.tol = early_stopping[0]*torch.ones(early_stopping[1]).to(self.dev) # tensor([tol, tol, tol, ...]) - self.itertol = float('inf')*torch.ones(early_stopping[1]).to(self.dev) # tensor([inf, inf, inf, ...]) to be filled with relative variations of beta - self.patience = early_stopping[2] - - # Checking inputs - if self.patience > self.pa_epochs: - warnings.warn("The patience is greater than the number of epochs. Early stopping will not be applied.") - self.tol, self.itertol, self.patience = None, None, 0 - if early_stopping[1] > self.pa_epochs: - warnings.warn("The number of iterations to consider for early stopping is greater than the number of epochs. Early stopping will not be applied.") - self.tol, self.itertol, self.patience = None, None, 0 - - def _logits_dataset(self, dev, classifier: torch.nn.Module, classifier2: Optional[torch.nn.Module] = None): - classifier.to(dev) - if classifier2: - classifier2.to(dev) - - y_totensor = [None]*len(self.dataloader) - X_totensor = [None]*len(self.dataloader) - for bidx, batch in enumerate(self.dataloader): - if bidx == 0: # initialize logits dataset - envs = batch["envs"] - if len(envs) != self.num_envs: - raise ValueError("There is a problem with the configuration of the Dataset and/or the DataLoader collate function.") - - X_list = [batch[envs[e]][0].to(dev) for e in range(self.num_envs)] - Y_list = [batch[envs[e]][1].to(dev) for e in range(self.num_envs)] - if not all([torch.equal(Y_list[0], Y_list[i]) for i in range(1, len(Y_list))]): # all labels must be equal - raise ValueError("The labels of the two environments must be the same.") - - y_totensor[bidx] = Y_list[0] - if classifier2: # then the validation with additional datasets uses the second classifier - X_totensor[bidx] = [classifier(X_list[0])] + [classifier2(X_list[i]) for i in range(1, len(X_list))] - else: # subset has two elements, each with the same labels - X_totensor[bidx] = [classifier(X) for X in X_list] - - logits_list = [torch.cat([X_totensor[j][i] for j in range(len(self.dataloader))]) for i in range(len(X_list))] - y = torch.cat(y_totensor) - - return LogitsDataset(logits_list, y) - - def _pa_validation(self, dev, kernel, fixed_beta, env, logits_dataloader: DataLoader): - kernel.module.reset() - total_samples = 0 - correct, correct_pred, correct_true = 0, 0, 0 - with torch.no_grad(): - for bidx, batch in enumerate(logits_dataloader): - # This is the return of the __getitem__ method. No need of a collate function bc this will not generalize. - #{str(i): tuple([self.logits[i][index], self.y[index]]) for i in range(self.num_envs)} - - logit0, y = batch['0'][0], batch['0'][1] - logit1 = batch[str(env)][0] - - # Compute accuracy metrics - y_pred = torch.argmax(logit0.to(dev), 1) # env 1 - y_pred_adv = torch.argmax(logit1.to(dev), 1) # env 2 - correct_pred += (y_pred_adv == y_pred).sum().item() - correct_true += (y_pred_adv == y).sum().item() - correct += (torch.cat([y_pred, y_pred_adv]).to(dev) == torch.cat([y, y]).to(dev)).sum().item() - total_samples += len(y) - - # Update logPA - kernel.module.evaluate(fixed_beta, logit0, logit1) - - # Retrieve final logPA for the (subset) batches - logPA = kernel.module.log_posterior().to(dev) - - # Retrieve logPA and accuracy metrics for the epoch and log - if "cuda" in dev and self.strategy == "cuda": - dist.all_reduce(logPA, op=dist.ReduceOp.SUM) - dist.all_reduce(torch.tensor(total_samples).to(dev), op=dist.ReduceOp.SUM) - dist.all_reduce(torch.tensor(correct_pred).to(dev), op=dist.ReduceOp.SUM) - dist.all_reduce(torch.tensor(correct_true).to(dev), op=dist.ReduceOp.SUM) - dist.all_reduce(torch.tensor(correct).to(dev), op=dist.ReduceOp.SUM) - - return { - "logPA": logPA.item(), - "AFR pred": correct_pred/total_samples, - "AFR true": correct_true/total_samples, - "accuracy": correct/(2*total_samples) - } - - def _optimize_beta(self, rank: int, classifier: torch.nn.Module, classifier2: Optional[torch.nn.Module] = None): - if self.strategy == "lightning": - dev = "cuda" if torch.cuda.is_available() else "cpu" - else: - dev = str(self.device_list[rank]) - - self.dataloader.sampler.rank = rank # adjust to device to be used - logits_dataset = self._logits_dataset(dev, classifier, classifier2) - logits_dataloader = DataLoader(dataset=logits_dataset, - batch_size=self.batch_size, # same as the data - num_workers=0, # we won't create subprocesses inside a subprocess, and data is very light - pin_memory=False, # only dense CPU tensors can be pinned - - # Important so that it matches with the input data. - shuffle=False, - drop_last = False, - sampler=SequentialSampler(logits_dataset)) - - # load training objects every time - kernel = PosteriorAgreementKernel(beta0=self.beta0).to(dev) - if "cuda" in dev and self.strategy == "cuda": - kernel = DDP(kernel, device_ids=[dev]) - optimizer = torch.optim.Adam([kernel.module.beta], lr=0.01) - - # Optimize beta for every batch within an epoch, for every epoch - for epoch in range(self.pa_epochs): - beta_e = 0.0 - for bidx, batch in enumerate(logits_dataloader): - kernel.module.beta.data.clamp_(min=0.0) - kernel.module.reset() - - logits, _ = batch - with torch.set_grad_enabled(True): - loss = kernel.module.forward(logits[0].to(dev), logits[1].to(dev)) - beta_e += kernel.module.beta.item() - - loss.backward() - optimizer.step() - optimizer.zero_grad() - - # Retrieve betas and compute the mean over the epoch - if "cuda" in dev and self.strategy == "cuda": - dist.all_reduce(torch.tensor(beta_e).to(dev), op=dist.ReduceOp.SUM) # sum betas from all processes for the same epoch - beta_mean = beta_e / self.num_batches - self.betas[epoch] = beta_mean - - # Compute logPA with the mean beta for the epoch and validate - for i in range(1, self.num_envs): - metric_dict = self._pa_validation(dev, kernel, beta_mean, i, logits_dataloader) - if i == 1: # the ones for the first environment must be stored - self.logPAs[epoch] = metric_dict["logPA"] - self.afr_pred[epoch] = metric_dict["AFR pred"] - self.afr_true[epoch] = metric_dict["AFR true"] - self.accuracy[epoch] = metric_dict["accuracy"] - - else: - if epoch == self.pa_epochs-1: # TODO: Decide what to do with this - print(metric_dict) - with open('logs_pa_metric.txt', 'a') as log_file: - log_file.writelines([f"metric dict for 0-{i}" + str(metric_dict) + "\n"]) - - # Check for beta relative variation and implement early stopping - if self.tol != None and epoch > self.patience: - relvar = torch.tensor([abs(beta_mean - self.betas[epoch-1])/beta_mean]).to(self.dev) - self.itertol = torch.cat([self.itertol[1:], relvar]).to(self.dev) - if torch.le(self.itertol, self.tol).all().item(): - print(f"PA optimization stopped at epoch {epoch}.") - break - - def _init_DDP_wrapper(self, rank: int, classifier: torch.nn.Module, classifier2: Optional[torch.nn.Module] = None): - """ - Implements optimization after initializing the corresponding subprocesses. - """ - if self.ddp_init[rank] == False: # Initialize the process only once, even if the .update() is called several times during a training procedure. - os.environ['MASTER_ADDR'] = 'localhost' - os.environ["MASTER_PORT"] = "50000" - init_process_group(backend="nccl", rank=rank, world_size=len(self.device_list)) - torch.cuda.set_device(rank) - self.ddp_init[rank] = True - - self._optimize_beta(rank, classifier, classifier2) - - def update(self, classifier: torch.nn.Module, classifier2: Optional[torch.nn.Module] = None, destroy_process_group: Optional[bool] = False): - """ - The goal is to make the Metric as versatile as possible. The Metric can be called in two ways: - - During a training procedure. In such case, it will use the training strategy already in place (e.g DDP). - - With a trained model, for evaluation. In such case, the training strategy can be selected: CPU or (multi)-GPU with DDP. - - Important: If used during training, pass a copy.deepcopy() of the classifier(s) to avoid errors. - """ - - # Set to eval mode and freeze parameters - classifier.eval() - for param in classifier.parameters(): - param.requires_grad = False - if classifier2: - classifier2.eval() - for param in classifier2.parameters(): - param.requires_grad = False - - # Optimize beta depending on the strategy and the devices available - if dist.is_initialized(): # ongoing cuda or ddp lightning - self._optimize_beta(dist.get_rank(), classifier, classifier2) - else: - if "cuda" in self.dev: # cuda for the metric - mp.spawn(self._init_DDP_wrapper, - args=(classifier, classifier2,), - nprocs=len(self.device_list), - join=True) # this gave error - - # Set to True when it's the last call to .update() - if destroy_process_group and dist.is_initialized(): - dist.destroy_process_group() - else: # "cpu", either lightning or not - self._optimize_beta(0, classifier, classifier2) - - # Get the epoch achieving maximum logPA - self.selected_index = torch.argmax(self.logPAs).item() \ No newline at end of file diff --git a/src/pa_metric/sampler.py b/src/pa_metric/sampler.py deleted file mode 100644 index 1515317..0000000 --- a/src/pa_metric/sampler.py +++ /dev/null @@ -1,105 +0,0 @@ -import torch -from typing import Union -import warnings -from torch.utils.data.distributed import DistributedSampler -from src.data.components import MultienvDataset, LogitsDataset - -class PosteriorAgreementSampler(DistributedSampler): - def __init__(self, dataset: Union[MultienvDataset, LogitsDataset], *args, **kwargs): - """ - - If the dataset contains only one environment, the metric will expect two classifiers in the .update() method. - - If the dataset contains more than one environment, those not used (at least one if two classifiers are provided) will be used for validation. - """ - - if not (isinstance(dataset, MultienvDataset) or isinstance(dataset, LogitsDataset)): - warnings.warn("The dataset must be a MultienvDataset to work with the PA metric.") - - self.num_envs = dataset.num_envs - original_dset_list = dataset.dset_list - if dataset.num_envs == 1: - warnings.warn("Only one environment was found in the dataset. The PA metric will expect two classifiers in the .update() method.") - - # Build two first environments to be PA pairs. - if self.num_envs >= 2: - dataset.num_envs = 2 - dataset.dset_list = dataset.dset_list[:2] - dataset.permutation = self._pair_optimize(dataset) - - # Add additional environments adjusted to the first two. - if self.num_envs > 2: - new_permutations = [None]*(self.num_envs-2) - new_nsamples = dataset.__len__() # samples after pairing (new permutation has been applied) - new_labels = dataset.__getlabels__(list(range(new_nsamples)))[0] # labels of first environment (idem second) - add_dataset = MultienvDataset(original_dset_list[2:]) - - add_labels = add_dataset.__getlabels__(list(range(add_dataset.__len__()))) # labels of the rest of environments - for i in range(self.num_envs-2): - new_permutations[i] = self._pair_validate(new_labels, add_labels[i]) - - filtered = torch.tensor([new_permutations[i] != None for i in range(self.num_envs-2)]) - dataset.num_envs = 2 + filtered.sum().item() - dataset.dset_list = original_dset_list[:2] + [original_dset_list[2+i] for i in range(len(filtered)) if filtered[i].item()] - dataset.permutation = dataset.permutation + [newperm for newperm, flag in zip(new_permutations, filtered) if flag] - - super().__init__(dataset, *args, **kwargs) - - def _pair_optimize(self, dataset: MultienvDataset): - """ - Generates permutations for the first pair of environments so that their labels are correspondent. - """ - n_samples = dataset.__len__() - inds = torch.arange(n_samples) - labels_list = dataset.__getlabels__(inds.tolist())[:2] # only the first two environments - - # IMPORTANT: If the data is already paired, it could mean that not only the labels are paired but also the samples. - # In such case, we don't want to touch it. - if torch.equal(labels_list[0], labels_list[1]): - return [inds, inds] - - unique_labs = [labels.unique() for labels in labels_list] - common_labs = unique_labs[0][torch.isin(unique_labs[0], unique_labs[1])] # labels that are common to both environments - - final_inds = [[], []] - for lab in list(common_labs): - inds_mask = [inds[labels_list[i].eq(lab)] for i in range(2)] # indexes for every label - if len(inds_mask[0]) >= len(inds_mask[1]): - final_inds[0].append(inds_mask[0][:len(inds_mask[1])]) - final_inds[1].append(inds_mask[1]) - else: - final_inds[0].append(inds_mask[0]) - final_inds[1].append(inds_mask[1][:len(inds_mask[0])]) - - return [torch.cat(final_inds[i]).tolist() for i in range(2)] - - def _pair_validate(self, labels: torch.Tensor, labels_add: torch.Tensor): - """ - Generates permutations for additional validation environments so that their labels are correspondent to the PA pair. - If the number of observations for certain labels is not enough, the samples are repeated. - If there are not observations associated with specific reference labels, the environment will be discarded. - """ - if torch.equal(labels, labels_add): - return torch.arange(len(labels)).tolist() # do not rearrange if labels are already equal - - unique, counts = labels.unique(return_counts=True) - sorted_values, sorted_indices = torch.sort(labels) - unique_add, counts_add = labels_add.unique(return_counts=True) - sorted_values_add, sorted_indices_add = torch.sort(labels_add) - - permuted = [] - for i in range(len(unique)): - pos_add = (unique_add==unique[i].item()).nonzero(as_tuple=True)[0] - if len(pos_add) == 0: # it means that that the label is not present in the second tensor - warnings.warn("The label " + str(unique[i].item()) + " is not present in the tensor. Pairig is impossible, so the environment will not be used.") - return None - else: - num = counts[i] # elements in the reference - num_add = counts_add[pos_add.item()] # elements in the second tensor - diff = num_add - num - vals_add = sorted_indices_add[counts_add[:pos_add].sum(): counts_add[:pos_add+1].sum()] # indexes of the second tensor - if diff >= 0: # if there are enough in the second tensor, we sample without replacement - permuted.append(vals_add[torch.randperm(num_add)[:num]]) - else: # if there are not enough, we sample with replacement (some samples will be repeated) - permuted.append(vals_add[torch.randint(0, num_add, (num,))]) - - perm = torch.cat(permuted) - return perm[torch.argsort(sorted_indices)].tolist() # b => sorted_b' = sorted_a <= a \ No newline at end of file diff --git a/src/plot/adv/adv_plot.py b/src/plot/adv/adv_plot.py index 5ed3108..df6e9d1 100644 --- a/src/plot/adv/adv_plot.py +++ b/src/plot/adv/adv_plot.py @@ -4,7 +4,7 @@ from src.plot.adv.utils import create_dataframe_from_wandb_runs attack = "GAUSSIAN" # "PGD", "FMN -date = "2024-01-12" +date = "2024-01-15" tags = [ "cifar10", attack, @@ -21,7 +21,7 @@ date=date, filters={ "state": "finished", - "group": "adversarial_pa", + "group": "adv_pa_cifar", # "tags": {"$all": ["cifar10", attack]}, # for some reason this does not work "$and": [{"tags": tag} for tag in tags], "created_at": {"$gte": date}, diff --git a/src/plot/adv/fmn.py b/src/plot/adv/fmn.py index 403f247..c649c39 100644 --- a/src/plot/adv/fmn.py +++ b/src/plot/adv/fmn.py @@ -51,7 +51,7 @@ def curves(df: pd.DataFrame, metric: str = "logPA") -> None: sns.set_style("ticks") # Divide by the cardinality of cifar10 - subset["logPA"] = subset["logPA"]/10000.0 + #subset["logPA"] = subset["logPA"]/10000.0 sns.lineplot( data=subset, @@ -77,16 +77,17 @@ def curves(df: pd.DataFrame, metric: str = "logPA") -> None: ax.grid(linestyle="--") ax.set_xlabel(x_label, fontname=fontname) - ax.set_ylabel(r"$10^{4} \cdot $ PA" if metric == "logPA" else "AFR", fontname=fontname) + #ax.set_ylabel(r"$10^{4} \cdot $ PA" if metric == "logPA" else "AFR", fontname=fontname) + ax.set_ylabel("PA" if metric == "logPA" else "AFR", fontname=fontname) # Legend handles, labels = ax.get_legend_handles_labels() labels = [LABEL_DICT[label] for label in labels] # sort labels and handles - #ids = sorted(range(len(labels)), key=YEARS.__getitem__) + ids = sorted(range(len(labels)), key=YEARS.__getitem__) - ids = [5, 0, 2, 4, 3, 1] + #ids = [5, 0, 2, 4, 3, 1] labels = [labels[i] for i in ids] handles = [handles[i] for i in ids] diff --git a/src/plot/adv/medianeps.py b/src/plot/adv/medianeps.py new file mode 100644 index 0000000..5a885b5 --- /dev/null +++ b/src/plot/adv/medianeps.py @@ -0,0 +1,50 @@ +import pyrootutils +pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) + +import torch +from secml.data.loader import CDataLoaderCIFAR10 +from src.data.utils import carray2tensor + +model_var = ["Standard", "BPDA", "Engstrom2019Robustness", "Wong2020Fast", "Addepalli2021Towards_RN18", "Wang2023Better_WRN-28-10"] +file_path = '/cluster/home/vjimenez/adv_pa_new/results/plots/adv/model_linfs.txt' +f = open(file_path, 'w') +for model in model_var: + datapath = f"/cluster/home/vjimenez/adv_pa_new/data/adv/adv_datasets/model={model}_attack=FMN_steps=1000.pt" + + dset = CDataLoaderCIFAR10 + _, ts = dset().load(val_size=0) + X, Y = ts.X / 255.0, ts.Y + X = carray2tensor(X, torch.float32) + dset_size = X.shape[0] + + linf_data = [] + for adversarial_ratio in torch.arange(0, 1.1, 0.1): + adv_X = torch.load(datapath) + + if adversarial_ratio == 0.0: + adv_X = X + + split = int(adversarial_ratio * dset_size) + attack_norms = (adv_X - X).norm(p=float("inf"), dim=1) + + _, unpoison_ids = attack_norms.topk(dset_size - split) + + # remove poison for the largest 1 - adversarial_ratio attacked ones + adv_X[unpoison_ids] = X[unpoison_ids] + + linf = torch.norm(adv_X - X, p=float("inf"), dim=1) + #import ipdb; ipdb.set_trace() + #print(linf.median().item()*255, linf.max().item()*255) + linf_data.append((linf.median().item()*255, linf.max().item()*255)) + + table_header = "\nAR\tMedian linf\tMax linf\n" + table_rows = [f"{i/10:.1f}\t{median:.2f}\t{max_linf:.2f}" for i, (median, max_linf) in enumerate(linf_data)] + table = table_header + "\n".join(table_rows) + f.write(f"\n\nModel: {model}") + f.write(table) + + + + + + diff --git a/src/plot/adv/pgd.py b/src/plot/adv/pgd.py index 6ca1130..2b12f46 100644 --- a/src/plot/adv/pgd.py +++ b/src/plot/adv/pgd.py @@ -5,6 +5,7 @@ import matplotlib.pyplot as plt import matplotlib.font_manager as fm import seaborn as sns +from numpy import asarray from src.plot.adv import DASHES_DICT, COLORS_DICT, LABEL_DICT, YEARS from src.plot.adv.utils import create_dataframe_from_wandb_runs @@ -58,8 +59,11 @@ def curves(df: pd.DataFrame, metric: str = "logPA", attack_name = None) -> None: plt.rcParams["font.serif"] = fontname sns.set_style("ticks") - # Divide by the cardinality of cifar10 - subset["logPA"] = subset["logPA"]/10000.0 + + + df.loc[df["linf"].eq(0.0), "logPA"] = 0.0 + subset = df.loc[df["linf"].lt(50.01/255.0)] # because no error yields PA very very low + #subset["logPA"] = subset["logPA"]/10000.0 # Divide by the cardinality of cifar10 sns.lineplot( data=subset, @@ -72,13 +76,21 @@ def curves(df: pd.DataFrame, metric: str = "logPA", attack_name = None) -> None: dashes=False, marker="o", linewidth=3, + legend=False ) ax.minorticks_on() if level == "linf": ax.set_xticks([0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]) else: # level == "adversarial_ratio" - ax.set_xticks([0.0314, 0.0627, 0.1255]) + # ax.set_xticks([0.0314, 0.0627, 0.1255]) + # linfs = asarray(range(5,256, 5)) + # ax.set_xticks(list(linfs/255.0)[::5]) + # ax.set_xticklabels(list(linfs)[::5]) + + linfs = asarray(range(0,51, 5)) + ax.set_xticks(list(linfs/255.0)) + ax.set_xticklabels(list(linfs)) ax.tick_params(axis="both", which="both", direction="in") plt.xticks(rotation=45) @@ -88,15 +100,19 @@ def curves(df: pd.DataFrame, metric: str = "logPA", attack_name = None) -> None: ax.grid(linestyle="--") + #x_label = r"$255 \times$ " + x_label ax.set_xlabel(x_label, fontname=fontname) - ax.set_ylabel(r"$10^{4} \cdot $ PA" if metric == "logPA" else "AFR", fontname=fontname) + #ax.set_ylabel(r"$10^{-4} \times $ PA" if metric == "logPA" else "AFR", fontname=fontname) + ax.set_ylabel("PA" if metric == "logPA" else "AFR", fontname=fontname) # Modified y axis - # if level == "linf": - # ax.set_ylim(min(subset[metric])*10, 10 if attack_name == 'PGD' else 100) - # else: - # ax.set_ylim(min(subset[metric])*10, None) - # ax.set_yscale('symlog') + if level == "linf": + #ax.set_ylim(min(subset[metric])*10, 10 if attack_name == 'PGD' else 100) + ax.set_ylim(None, None) + else: + ax.set_ylim(min(subset[metric])*2, 0.5) + #ax.set_ylim(min(subset[metric])*10, None) + ax.set_yscale('symlog') # Legend handles, labels = ax.get_legend_handles_labels() @@ -106,17 +122,17 @@ def curves(df: pd.DataFrame, metric: str = "logPA", attack_name = None) -> None: ids = sorted(range(len(labels)), key=YEARS.__getitem__) # ids[0], ids[1] = ids[1], ids[0] # ids = [2,1,0] - ids = [5, 0, 2, 4, 3, 1] + # ids = [5, 0, 2, 4, 3, 1] labels = [labels[i] for i in ids] handles = [handles[i] for i in ids] - ax.legend( - handles, - labels, - handlelength=0.5, - prop={"family": fontname, "size": 16}, - ) - # sns.move_legend(ax2, "upper right") + # ax.legend( + # handles, + # labels, + # handlelength=0.5, + # prop={"family": fontname, "size": 16}, + # ) + #sns.move_legend(ax2, "upper right") if attack_name == 'GAUSSIAN': title = "Gaussian noise" diff --git a/src/plot/adv/utils.py b/src/plot/adv/utils.py index 5e9203d..e09bb5a 100644 --- a/src/plot/adv/utils.py +++ b/src/plot/adv/utils.py @@ -78,11 +78,17 @@ def create_dataframe_from_wandb_runs( ) ) - if "data/attack/epsilons" in config or "data/adv/attack/epsilons" in config: + # if attack == "PGD": + # varx = "epsilons" + # else: # gaussian + # varx = "noise_std" + varx = "epsilons" + + if "data/attack/" + varx in config or "data/adv/attack/" + varx in config: data["linf"].append( config.get( - "data/attack/epsilons", - config.get("data/adv/attack/epsilons"), + "data/attack/" + varx, + config.get("data/adv/attack/" + varx), ) ) # pause because wandb sometimes is not able to retrieve the results @@ -91,14 +97,15 @@ def create_dataframe_from_wandb_runs( data["AFR"].append( max( [ - row["val/AFR pred"] + row["val/AFR true"] for row in run.scan_history() - if row["val/AFR pred"] is not None + if row["val/AFR true"] is not None ] ) ) except: continue + data["logPA"].append(history["val/logPA"].max()) df = pd.DataFrame(data) diff --git a/src/plot/dg/dg.py b/src/plot/dg/dg.py index e9cea8c..79d36c1 100644 --- a/src/plot/dg/dg.py +++ b/src/plot/dg/dg.py @@ -8,6 +8,40 @@ import numpy as np +def _check_correlations(subset: pd.DataFrame): + sr = subset["shift_ratio"].values[0] # all the same + model_names = subset["model_name"].unique() + shift_factors = np.sort(subset["shift_factor"].unique()) + + def _compute_differences(v1, v2 = None): + if type(v2) == type(None): + v2 = v1.copy() + return np.asarray(v1[:-1]) - np.asarray(v2[1:]) + + cor_sf_true_pred = np.zeros(len(shift_factors)) + cor_sf_true_true, cor_sf_pred = cor_sf_true_pred.copy(), cor_sf_true_pred.copy() + for i in range(len(shift_factors)): + condition = subset["shift_factor"] == shift_factors[i] + columns = subset.loc[condition, ["logPA", "AFR_true", "AFR_pred"]] + cor_sf_true_pred[i] = np.corrcoef(_compute_differences(columns["logPA"]), + _compute_differences(columns["AFR_true"], columns["AFR_pred"]))[0, 1] + cor_sf_true_true[i] = np.corrcoef(_compute_differences(columns["logPA"]), + _compute_differences(columns["AFR_true"]))[0, 1] + cor_sf_pred[i] = np.corrcoef(_compute_differences(columns["logPA"]), + _compute_differences(columns["AFR_pred"]))[0, 1] + + cor_mod_true = np.zeros(len(model_names)) + cor_mod_pred = np.zeros(len(model_names)) + for i in range(len(model_names)): + condition = subset["model_name"] == model_names[i] + columns = subset.loc[condition, ["logPA", "AFR_true", "AFR_pred"]] + cor_mod_true[i] = np.corrcoef(_compute_differences(columns["logPA"]), + _compute_differences(columns["AFR_true"]))[0, 1] + cor_mod_pred[i] = np.corrcoef(_compute_differences(columns["logPA"]), + _compute_differences(columns["AFR_pred"]))[0, 1] + + return np.mean(cor_sf_true_true), np.mean(cor_sf_true_pred), np.mean(cor_sf_pred), np.mean(cor_mod_true), np.mean(cor_mod_pred) + def logpa(df: pd.DataFrame, dirname: str, @@ -19,6 +53,11 @@ def logpa(df: pd.DataFrame, dirname = osp.join(dirname, "PA" if metric == "logPA" else "AFR") os.makedirs(dirname, exist_ok=True) + # To compute the correlations. + list_shift_ratios = np.sort(df["shift_ratio"].unique()) + corr_sf_pred = np.zeros(len(list_shift_ratios)) + corr_sf_true_true, corr_sf_true_pred, corr_mod_true, corr_mod_pred = corr_sf_pred.copy(), corr_sf_pred.copy(), corr_sf_pred.copy(), corr_sf_pred.copy() + pairs = [("shift_ratio", "shift_factor")] for levels in tqdm(pairs, total=len(pairs)): level, x_level = levels @@ -53,7 +92,8 @@ def logpa(df: pd.DataFrame, "shift_ratio", "shift_factor", "logPA", - "AFR", + "AFR_true", + "AFR_pred", "acc_pa" ], ].sort_values(by='model_name') @@ -73,6 +113,7 @@ def logpa(df: pd.DataFrame, plt.rcParams["font.serif"] = fontname sns.set_style("ticks") + #subset["logPA"] = subset["logPA"]/1000.0 sns.lineplot( data=subset, ax=ax1, @@ -91,13 +132,17 @@ def logpa(df: pd.DataFrame, ax1.text( point['shift_factor'], point['logPA'], - f"{point['acc_pa']:.2f}", + f"{point['AFR_true']:.2f}", color='black', ha='center', va='bottom', fontsize=9 ) + if level == "shift_ratio": + ind = int(np.where(list_shift_ratios == value)[0]) + corr_sf_true_true[ind], corr_sf_true_pred[ind], corr_sf_pred[ind], corr_mod_true[ind], corr_mod_pred[ind] = _check_correlations(subset) + ax1.minorticks_on() ax1.tick_params(axis="both", which="both", direction="in") xticks_font = fm.FontProperties(family=fontname) @@ -109,6 +154,9 @@ def logpa(df: pd.DataFrame, ax1.set_xlabel(x_name, fontname=fontname) ax1.set_ylabel("PA", fontname=fontname) + #ax1.set_ylabel(r"$2 \times 10^3 \cdot$" + " PA", fontname=fontname) + # ax1.set_ylim(min(subset["logPA"])*2, -5) + # ax1.set_yscale('symlog') # Legend handles, labels = ax1.get_legend_handles_labels() @@ -149,6 +197,8 @@ def logpa(df: pd.DataFrame, plt.savefig(fname) plt.clf() plt.close() + + import ipdb; ipdb.set_trace() def afr_vs_logpa(df: pd.DataFrame, comparison_metric: str = "ASR"): diff --git a/src/plot/dg/dg_plot.py b/src/plot/dg/dg_plot.py index ce29545..366dec5 100644 --- a/src/plot/dg/dg_plot.py +++ b/src/plot/dg/dg_plot.py @@ -8,7 +8,7 @@ entity='malvai' project='cov_pa' -group = 'pa_lightningdebug' +group = 'dg_pa_diagvib' df_dir = 'results/dataframes/dg/' + group + '/' pic_dir = 'results/plots/dg/' + group + '/' @@ -27,7 +27,8 @@ df_list.append(df) df = pd.concat(df_list, axis=0) -logpa(df, pic_dir, show_acc=False, picformat="pdf") +df["model_name"] = df["name"].apply(lambda x: "_".join(x.split("=")[1].split("_")[:2])) # for logits, as I get "$model_name" from wandb +logpa(df, pic_dir, show_acc=True, picformat="png") diff --git a/src/plot/dg/utils.py b/src/plot/dg/utils.py index cc8a375..ff910fb 100644 --- a/src/plot/dg/utils.py +++ b/src/plot/dg/utils.py @@ -42,9 +42,11 @@ def dg_pa_dataframe( data["name"].append(run.name) data["shift_ratio"].append(config["data/dg/shift_ratio"]) - data["model_name"].append(config["model/dg/classifier/exp_name"]) + #data["model_name"].append(config["model/dg/classifier/exp_name"]) # no logits + data["model_name"].append(config["data/dg/classifier/exp_name"]) # logits data["shift_factor"].append(config["data/dg/envs_index"][1]) # for original - data["AFR"].append(max(retrieve_from_history(run, f"val/AFR {afr}"))) + data["AFR_true"].append(max(retrieve_from_history(run, f"val/AFR true"))) + data["AFR_pred"].append(max(retrieve_from_history(run, f"val/AFR pred"))) data["acc_pa"].append(max(retrieve_from_history(run, "val/acc_pa"))) logpa_epoch = retrieve_from_history(run, "val/logPA") data["logPA"].append(max(logpa_epoch)) diff --git a/src/test.py b/src/test.py new file mode 100644 index 0000000..fa7bca3 --- /dev/null +++ b/src/test.py @@ -0,0 +1,120 @@ +from typing import List, Optional, Tuple + +import hydra +import os +import pandas as pd +import csv +import pytorch_lightning as pl +from omegaconf import OmegaConf, DictConfig +from pytorch_lightning import ( + Callback, + LightningDataModule, + LightningModule, + Trainer, +) +from pytorch_lightning.loggers import Logger + +# Add resolvers to evaluate operations in the .yaml configuration files +OmegaConf.register_new_resolver("eval", eval) +OmegaConf.register_new_resolver("len", len) +OmegaConf.register_new_resolver("classname", lambda classpath: classpath.split(".")[-1]) + +import pyrootutils +pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) + + +from src import utils + +log = utils.get_pylogger(__name__) + +# TO DELETE AFTER IT'S BEEN DEBUGGED +import torch +import os + +@utils.task_wrapper +def test(cfg: DictConfig) -> Tuple[dict, dict]: + """Tests the model. + + This method is wrapped in optional @task_wrapper decorator which applies extra utilities + before and after the call. + + Args: + cfg (DictConfig): Configuration composed by Hydra. + + Returns: + Tuple[dict, dict]: Dict with metrics and dict with all instantiated objects. + """ + + # set seed for random number generators in pytorch, numpy and python.random + pl.seed_everything(cfg.seed, workers=True) + + log.info(f"Instantiating datamodule <{cfg.data._target_}>") + datamodule: LightningDataModule = hydra.utils.instantiate(cfg.datamodule) + + log.info(f"Instantiating model <{cfg.model._target_}>") + model: LightningModule = hydra.utils.instantiate(cfg.model) + + log.info("Instantiating callbacks...") + callbacks: List[Callback] = utils.instantiate_callbacks( + cfg.get("callbacks") + ) + + log.info("Instantiating loggers...") + logger: List[Logger] = utils.instantiate_loggers(cfg.get("logger")) + + log.info(f"Instantiating trainer <{cfg.trainer._target_}>") + trainer: Trainer = hydra.utils.instantiate( + cfg.trainer, callbacks=callbacks, logger=logger + ) + + object_dict = { + "cfg": cfg, + "datamodule": datamodule, + "model": model, + "callbacks": callbacks, + "logger": logger, + "trainer": trainer, + } + + if logger: + log.info("Logging hyperparameters!") + utils.log_hyperparameters(object_dict) + + # Test model + log.info("\nStarting test!") + trainer = hydra.utils.instantiate( + cfg.trainer, logger=logger, strategy='gpu', devices = 1 + ) + + model: LightningModule = hydra.utils.instantiate(cfg.model) + trainer.test( + model = model, # model.load_from_checkpoint(cfg.ckpt_path, net=hydra.utils.instantiate(cfg.model.net)), # because 'net' is not stored in the checkpoint + datamodule=datamodule, + ) + + print("TESTING IS DONE, SHOW LENGTH TO DEBUG:") + print("lengths: ", torch.sum(model.len_test)) + + test_metrics = trainer.callback_metrics + metric_dict = {**test_metrics} + return metric_dict, object_dict + + +@hydra.main( + version_base="1.3", config_path="../configs", config_name="test.yaml" +) +def main(cfg: DictConfig) -> Optional[float]: + # train the model + metric_dict, _ = test(cfg) + + # safely retrieve metric value for hydra-based hyperparameter optimization + metric_value = utils.get_metric_value( + metric_dict=metric_dict, metric_name=cfg.get("optimized_metric") + ) + + # return optimized metric + return metric_value + + +if __name__ == "__main__": + main() diff --git a/src/train_dg.py b/src/train.py similarity index 60% rename from src/train_dg.py rename to src/train.py index 6ce26cf..3f53ede 100644 --- a/src/train_dg.py +++ b/src/train.py @@ -5,8 +5,7 @@ import pandas as pd import csv import pytorch_lightning as pl -from omegaconf import DictConfig -import pyrootutils +from omegaconf import OmegaConf, DictConfig from pytorch_lightning import ( Callback, LightningDataModule, @@ -15,34 +14,25 @@ ) from pytorch_lightning.loggers import Logger +# Add resolvers to evaluate operations in the .yaml configuration files +OmegaConf.register_new_resolver("eval", eval) +OmegaConf.register_new_resolver("len", len) +OmegaConf.register_new_resolver("classname", lambda classpath: classpath.split(".")[-1]) + +import pyrootutils pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) -# ------------------------------------------------------------------------------------ # -# the setup_root above is equivalent to: -# - adding project root dir to PYTHONPATH -# (so you don't need to force user to install project as a package) -# (necessary before importing any local modules e.g. `from src import utils`) -# - setting up PROJECT_ROOT environment variable -# (which is used as a base for paths in "configs/paths/default.yaml") -# (this way all filepaths are the same no matter where you run the code) -# - loading environment variables from ".env" in root dir -# -# you can remove it if you: -# 1. either install project as a package or move entry files to project root dir -# 2. set `root_dir` to "." in "configs/paths/default.yaml" -# -# more info: https://github.com/ashleve/pyrootutils -# ------------------------------------------------------------------------------------ # from src import utils - log = utils.get_pylogger(__name__) +# TODO: Remove after debugging +import torch +import os @utils.task_wrapper def train(cfg: DictConfig) -> Tuple[dict, dict]: - """Trains the model. Can additionally evaluate on a testset, using best weights obtained during - training. + """Trains and optionally evaluates the model. This method is wrapped in optional @task_wrapper decorator which applies extra utilities before and after the call. @@ -51,18 +41,17 @@ def train(cfg: DictConfig) -> Tuple[dict, dict]: cfg (DictConfig): Configuration composed by Hydra. Returns: - Tuple[dict, dict]: Dict with metrics and dict with all instantiated objects. + Tuple[str, dict, dict]: Best model path, dict with metrics and dict with all instantiated objects. """ # set seed for random number generators in pytorch, numpy and python.random - if cfg.get("seed"): - pl.seed_everything(cfg.seed, workers=True) + pl.seed_everything(cfg.seed, workers=True) - log.info(f"Instantiating datamodule <{cfg.data.dg._target_}>") - datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data.dg) + log.info(f"Instantiating datamodule <{cfg.data._target_}>") + datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data) - log.info(f"Instantiating model <{cfg.model.dg._target_}>") - model: LightningModule = hydra.utils.instantiate(cfg.model.dg) + log.info(f"Instantiating model <{cfg.model._target_}>") + model: LightningModule = hydra.utils.instantiate(cfg.model) log.info("Instantiating callbacks...") callbacks: List[Callback] = utils.instantiate_callbacks( @@ -91,11 +80,10 @@ def train(cfg: DictConfig) -> Tuple[dict, dict]: utils.log_hyperparameters(object_dict) # Train model - if cfg.get("train"): - log.info("Starting training!") - trainer.fit( - model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path") - ) + log.info("Starting training!") + trainer.fit( + model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path") + ) train_metrics = trainer.callback_metrics metric_dict = {**train_metrics} @@ -107,21 +95,25 @@ def train(cfg: DictConfig) -> Tuple[dict, dict]: if os.path.exists(path_ckpt_csv) == False: with open(path_ckpt_csv, "w", newline="") as file: writer = csv.writer(file) - writer.writerow(["experiment_name", "ckpt_path"]) - writer.writerow(["place_holder", "place_holder"]) + writer.writerow(["experiment_name", "experiment_id", "seed", "ckpt_path"]) + writer.writerow(["place_holder", "place_holder", "place_holder", "place_holder"]) pd_ckpt = pd.read_csv(path_ckpt_csv) if logger: if logger[0].experiment.name not in pd_ckpt["experiment_name"].tolist(): with open(path_ckpt_csv, "a+", newline="") as file: writer = csv.writer(file) - writer.writerow([logger[0].experiment.name, ckpt_path]) - + writer.writerow([logger[0].experiment.name, logger[0].experiment.id, cfg.seed, ckpt_path]) + + # Print model checkpoint and experiment to resume for testing. + print("\nBest model checkpoint path: ", ckpt_path) + if logger: + print(f"\nExperiment id: {logger[0].experiment.id}") return metric_dict, object_dict @hydra.main( - version_base="1.3", config_path="../configs", config_name="train_dg.yaml" + version_base="1.3", config_path="../configs", config_name="train.yaml" ) def main(cfg: DictConfig) -> Optional[float]: # train the model diff --git a/src/train_dg_pa.py b/src/train_dg_pa.py deleted file mode 100644 index dcb7aaf..0000000 --- a/src/train_dg_pa.py +++ /dev/null @@ -1,143 +0,0 @@ -from typing import List, Optional, Tuple - -import hydra -import pytorch_lightning as pl -import pyrootutils -import torch -from pytorch_lightning import ( - Callback, - LightningDataModule, - LightningModule, - Trainer, -) -from pytorch_lightning.loggers import Logger -from omegaconf import DictConfig - -pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) -# ------------------------------------------------------------------------------------ # -# the setup_root above is equivalent to: -# - adding project root dir to PYTHONPATH -# (so you don't need to force user to install project as a package) -# (necessary before importing any local modules e.g. `from src import utils`) -# - setting up PROJECT_ROOT environment variable -# (which is used as a base for paths in "configs/paths/default.yaml") -# (this way all filepaths are the same no matter where you run the code) -# - loading environment variables from ".env" in root dir -# -# you can remove it if you: -# 1. either install project as a package or move entry files to project root dir -# 2. set `root_dir` to "." in "configs/paths/default.yaml" -# -# more info: https://github.com/ashleve/pyrootutils -# ------------------------------------------------------------------------------------ # - -from src import utils - -log = utils.get_pylogger(__name__) - - -@utils.task_wrapper -def train(cfg: DictConfig) -> Tuple[dict, dict]: - """Trains the model. Can additionally evaluate on a testset, using best weights obtained during - training. - - This method is wrapped in optional @task_wrapper decorator, that controls the behavior during - failure. Useful for multiruns, saving info about the crash, etc. - - Args: - cfg (DictConfig): Configuration composed by Hydra. - - Returns: - Tuple[dict, dict]: Dict with metrics and dict with all instantiated objects. - """ - - # set seed for random number generators in pytorch, numpy and python.random - - if cfg.get("seed"): - pl.seed_everything(cfg.seed, workers=True) - - log.info(f"Instantiating datamodule <{cfg.data.dg._target_}>") - datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data.dg) - - log.info(f"Instantiating model <{cfg.model.dg._target_}>") - model: LightningModule = hydra.utils.instantiate(cfg.model.dg) - - log.info("Instantiating callbacks...") - callbacks: List[Callback] = utils.instantiate_callbacks( - cfg.get("callbacks") - ) - - log.info("Instantiating loggers...") - logger: List[Logger] = utils.instantiate_loggers(cfg.get("logger")) - - log.info(f"Instantiating trainer <{cfg.trainer._target_}>") - trainer: Trainer = hydra.utils.instantiate( - cfg.trainer, callbacks=callbacks, logger=logger - ) - - object_dict = { - "cfg": cfg, - "datamodule": datamodule, - "model": model, - "callbacks": callbacks, - "logger": logger, - "trainer": trainer, - } - - if logger: - log.info("Logging hyperparameters!") - utils.log_hyperparameters(object_dict) - - if cfg.get("compile"): - log.info("Compiling model!") - model = torch.compile(model) - - if cfg.get("train"): - log.info("Starting training!") - trainer.fit( - model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path") - ) - - train_metrics = trainer.callback_metrics - - if cfg.get("test"): - log.info("Starting testing!") - ckpt_path = trainer.checkpoint_callback.best_model_path - if ckpt_path == "": - log.warning( - "Best ckpt not found! Using current weights for testing..." - ) - ckpt_path = None - trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path) - log.info(f"Best ckpt path: {ckpt_path}") - - test_metrics = trainer.callback_metrics - - # merge train and test metrics - metric_dict = {**train_metrics, **test_metrics} - - return metric_dict, object_dict - - -@hydra.main( - version_base="1.3", config_path="../configs", config_name="train_dg_pa.yaml" -) -def main(cfg: DictConfig) -> Optional[float]: - # apply extra utilities - # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.) - utils.extras(cfg) - - # train the model - metric_dict, _ = train(cfg) - - # safely retrieve metric value for hydra-based hyperparameter optimization - metric_value = utils.get_metric_value( - metric_dict=metric_dict, metric_name=cfg.get("optimized_metric") - ) - - # return optimized metric - return metric_value - - -if __name__ == "__main__": - main() diff --git a/src/train_pa.py b/src/train_pa.py deleted file mode 100644 index cc9364c..0000000 --- a/src/train_pa.py +++ /dev/null @@ -1,143 +0,0 @@ -from typing import List, Optional, Tuple - -import hydra -import pytorch_lightning as pl -import pyrootutils -import torch -from pytorch_lightning import ( - Callback, - LightningDataModule, - LightningModule, - Trainer, -) -from pytorch_lightning.loggers import Logger -from omegaconf import DictConfig - -pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) -# ------------------------------------------------------------------------------------ # -# the setup_root above is equivalent to: -# - adding project root dir to PYTHONPATH -# (so you don't need to force user to install project as a package) -# (necessary before importing any local modules e.g. `from src import utils`) -# - setting up PROJECT_ROOT environment variable -# (which is used as a base for paths in "configs/paths/default.yaml") -# (this way all filepaths are the same no matter where you run the code) -# - loading environment variables from ".env" in root dir -# -# you can remove it if you: -# 1. either install project as a package or move entry files to project root dir -# 2. set `root_dir` to "." in "configs/paths/default.yaml" -# -# more info: https://github.com/ashleve/pyrootutils -# ------------------------------------------------------------------------------------ # - -from src import utils - -log = utils.get_pylogger(__name__) - - -@utils.task_wrapper -def train(cfg: DictConfig) -> Tuple[dict, dict]: - """Trains the model. Can additionally evaluate on a testset, using best weights obtained during - training. - - This method is wrapped in optional @task_wrapper decorator, that controls the behavior during - failure. Useful for multiruns, saving info about the crash, etc. - - Args: - cfg (DictConfig): Configuration composed by Hydra. - - Returns: - Tuple[dict, dict]: Dict with metrics and dict with all instantiated objects. - """ - - # set seed for random number generators in pytorch, numpy and python.random - - if cfg.get("seed"): - pl.seed_everything(cfg.seed, workers=True) - - log.info(f"Instantiating datamodule <{cfg.data.adv._target_}>") - datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data.adv) - - log.info(f"Instantiating model <{cfg.model.adv._target_}>") - model: LightningModule = hydra.utils.instantiate(cfg.model.adv) - - log.info("Instantiating callbacks...") - callbacks: List[Callback] = utils.instantiate_callbacks( - cfg.get("callbacks") - ) - - log.info("Instantiating loggers...") - logger: List[Logger] = utils.instantiate_loggers(cfg.get("logger")) - - log.info(f"Instantiating trainer <{cfg.trainer._target_}>") - trainer: Trainer = hydra.utils.instantiate( - cfg.trainer, callbacks=callbacks, logger=logger - ) - - object_dict = { - "cfg": cfg, - "datamodule": datamodule, - "model": model, - "callbacks": callbacks, - "logger": logger, - "trainer": trainer, - } - - if logger: - log.info("Logging hyperparameters!") - utils.log_hyperparameters(object_dict) - - if cfg.get("compile"): - log.info("Compiling model!") - model = torch.compile(model) - - if cfg.get("train"): - log.info("Starting training!") - trainer.fit( - model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path") - ) - - train_metrics = trainer.callback_metrics - - if cfg.get("test"): - log.info("Starting testing!") - ckpt_path = trainer.checkpoint_callback.best_model_path - if ckpt_path == "": - log.warning( - "Best ckpt not found! Using current weights for testing..." - ) - ckpt_path = None - trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path) - log.info(f"Best ckpt path: {ckpt_path}") - - test_metrics = trainer.callback_metrics - - # merge train and test metrics - metric_dict = {**train_metrics, **test_metrics} - - return metric_dict, object_dict - - -@hydra.main( - version_base="1.3", config_path="../configs", config_name="train_adv.yaml" -) -def main(cfg: DictConfig) -> Optional[float]: - # apply extra utilities - # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.) - utils.extras(cfg) - - # train the model - metric_dict, _ = train(cfg) - - # safely retrieve metric value for hydra-based hyperparameter optimization - metric_value = utils.get_metric_value( - metric_dict=metric_dict, metric_name=cfg.get("optimized_metric") - ) - - # return optimized metric - return metric_value - - -if __name__ == "__main__": - main() diff --git a/tests/test_pa.py b/tests/test_pa.py new file mode 100644 index 0000000..a98a44b --- /dev/null +++ b/tests/test_pa.py @@ -0,0 +1,55 @@ +import pyrootutils +pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) + +import hydra +from omegaconf import DictConfig +from typing import Optional +import warnings +warnings.simplefilter("ignore") + +# Tests to perform +from tests.test_pa.data_pipeline import * +from tests.test_pa.ddp import * +from tests.test_pa.pa_module import * +from tests.test_pa.pa_metric import * +from tests.test_pa.pa_callback import * + +@hydra.main( + version_base="1.3", config_path="../configs", config_name="test_pa.yaml" +) +def main(cfg: DictConfig) -> Optional[float]: + """ + Tests of the data pipeline. + """ + # test_sampler(cfg) + # test_dataloaders(cfg) + + """ + Tests of the parallelization strategy. + """ + # test_ddp(cfg) + + """ + Tests of the PA module. + """ + # test_pa_module(cfg) + + """ + Tests of the PA metric. + """ + # test_basemetric(cfg) + # test_pametric_cpu(cfg) + # test_pametric_ddp(cfg) + # test_pametric_logits(cfg) + # test_accuracymetric(cfg) + + print("-----------------------------------------------") + """ + Tests of the PA callback. + """ + test_pa_callback(cfg) + + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/src/pa_metric/__init__.py b/tests/test_pa/__init__.py similarity index 100% rename from src/pa_metric/__init__.py rename to tests/test_pa/__init__.py diff --git a/tests/test_pa/data_pipeline.py b/tests/test_pa/data_pipeline.py new file mode 100644 index 0000000..0671def --- /dev/null +++ b/tests/test_pa/data_pipeline.py @@ -0,0 +1,258 @@ +""" +This test will check that the data generated in the MultiEnv, PA and PA_logits datamodules is consistent with the +requirements of the different optimization procedures. +""" + +import hydra +from omegaconf import DictConfig + +import os +import torch +from torch.utils.data import TensorDataset, DataLoader +import numpy as np +import warnings + +from src.pa_metric.pairing import PosteriorAgreementDatasetPairing +from pytorch_lightning import LightningDataModule + +from src.data.components import MultienvDataset, LogitsDataset +from src.data.components.collate_functions import MultiEnv_collate_fn + +from .utils import plot_multienv, plot_images_multienv + +__all__ = ["test_dataloaders", "test_sampler"] +PLOT_INDS = [0, 9] # must be smaller than batch_size +CHECK_LOGITS = 10 # will be passed to the model, must be smaller than the length of the paired dataset + +""" +- These tests are designed to run locally, as subprocess tests are performed separately. Reduce CHECK_LOGITS and batch_size if you run out of memory. +- It is recommended to keep batch_size and CHECK_LOGITS as big as possible, otherwise randomness could lead to some false negatives. +- Important to keeep batch_size <= len(dataset)/2, as I run two epochs here. +""" + +def _dl_plot(dataloader: DataLoader, batch_size: int, img_dir: str, expected: str = "equal"): + assert expected in ["equal", "not_equal"], "The expected behaviour must be either 'equal' or 'not_equal'." + + for epoch in range(2): + for batch_ind, batch in enumerate(dataloader): + env_names = list(batch.keys()) + x1, x2, y1, y2 = batch[env_names[0]][0], batch[env_names[1]][0], batch[env_names[0]][1], batch[env_names[1]][1] + + """ + Labels will be the same when the dataset is correspondent. PA and PA_logits must be, but main can be different. + """ + if expected == "equal": + assert torch.equal(y1, y2), "The labels are not the same, and they should." + else: + assert not torch.equal(y1, y2), "The labels are the same, and they shouldn't." + + assert torch.equal(torch.tensor(x1.size()), torch.tensor(x2.size())), "The images have different sizes." + assert x1.size(0) == batch_size, "The batch size is not the same as the one specified in the config file." + + # Visual check: + for ind in PLOT_INDS: + plot_images_multienv([x1[ind], x2[ind]], [y1[ind], y2[ind]], + os.path.join(img_dir, f"epoch_{epoch}_batch_{batch_ind}_ind_{ind}")) + + break # only one batch per epoch + +def _dl_logits(dataloader: DataLoader, classifier: torch.nn.Module, logits_dataloader: DataLoader, expected="equal"): + assert expected in ["equal", "not_equal"], "The expected behaviour must be either 'equal' or 'not_equal'." + + assert len(dataloader.dataset) == len(logits_dataloader.dataset), "The number of samples in the dataset and the logits dataset is not the same." + #inds_to_compare = torch.randint(0, len(dataloader.dataset), (CHECK_LOGITS,)).tolist() + inds_to_compare = list(range(5)) + subset = dataloader.dataset.__getitems__(inds_to_compare) + subset_logits = logits_dataloader.dataset.__getitems__(inds_to_compare) + + classifier.eval() + with torch.no_grad(): + if expected == "equal": + for i in range(2): # two environments + assert torch.allclose(classifier(subset[i][0]), subset_logits[i][0]), "The logits generated and stored are not the same, and they should be." + else: + for i in range(2): + assert not torch.allclose(classifier(subset[i][0]), subset_logits[i][0]), "The logits generated and stored are the same, and they shouldn't be." + +def test_dataloaders(cfg: DictConfig): + """ + EXPLANATION + """ + torch.manual_seed(cfg.seed) # so that the dataloader shuffle yields the same results + + # Main_dataloader + dm_main: LightningDataModule = hydra.utils.instantiate(cfg.data.datamodules.main) + dm_main.prepare_data() + dm_main.setup("fit") + dl_main = dm_main.train_dataloader() + assert type(dl_main.dataset) in [MultienvDataset, LogitsDataset], "All datasets must belong to class MultienvDataset or LogitsDataset." + + _dl_plot(dl_main, + cfg.data.datamodules.main.batch_size, + os.path.join(cfg.paths.results_tests, cfg.data.datamodules.data_name + "_main"), + expected = "equal" if cfg.data.expected_results.main.corresponding_labels else "not_equal") + + """ + We must check whether the main dataset contains corresponding labels, which usually is not the case. + """ + if cfg.data.expected_results.main.corresponding_labels: + for e in range(1, dm_main.train_ds.num_envs): # Main dataset can have more than two environments + assert torch.equal(dm_main.train_ds.__getlabels__(list(range(len(dm_main.train_ds))))[0], + dm_main.train_ds.__getlabels__(list(range(len(dm_main.train_ds))))[e]), "The labels in the main dataset are not corresponding, and they should be." + else: + for e in range(1, dm_main.train_ds.num_envs): + assert not torch.equal(dm_main.train_ds.__getlabels__(list(range(len(dm_main.train_ds))))[0], + dm_main.train_ds.__getlabels__(list(range(len(dm_main.train_ds))))[e]), "The labels in the main dataset are corresponding, and they shouldn't be." + + + # PA_dataloader + dm_pa: LightningDataModule = hydra.utils.instantiate(cfg.data.datamodules.pa) + dm_pa.prepare_data() + dm_pa.setup("fit") + dl_pa = dm_pa.train_dataloader() + assert type(dl_pa.dataset) in [MultienvDataset, LogitsDataset], "All datasets must belong to class MultienvDataset or LogitsDataset." + + _dl_plot(dl_pa, + cfg.data.datamodules.pa.batch_size, + os.path.join(cfg.paths.results_tests, cfg.data.datamodules.data_name + "_pa"), + expected = "equal") # this must be always the case for PA + + """ + We must check that the PA dataset contains corresponding labels, as it has been passed through PosteriorAgreementDatasetPairing + """ + assert torch.equal(dm_pa.train_ds.__getlabels__(list(range(len(dm_pa.train_ds))))[0], + dm_pa.train_ds.__getlabels__(list(range(len(dm_pa.train_ds))))[1]), "The labels in the PA dataset are not corresponding." + + # PAlogits_dataloader + dm_palogits: LightningDataModule = hydra.utils.instantiate(cfg.data.datamodules.pa_logits) + dm_palogits.prepare_data() + dm_palogits.setup("fit") + dl_palogits = dm_palogits.train_dataloader() # Has corresponding labels by definition + assert type(dl_palogits.dataset) in [MultienvDataset, LogitsDataset], "All datasets must belong to class MultienvDataset or LogitsDataset." + + """ + We must check that the labels and logits are the same when `shuffle=False`, for CHECK_LOGITS_INDS. + Set `expected="equal"` when the model passed to the PA_logits dataloader is the same as the one passed to the main and PA dataloaders (if they require so). + """ + if "classifier" in list(cfg.data.datamodules.main.keys()): # If the main DL has a classifier (only in adversarial case) + model_main: torch.nn.Module = hydra.utils.instantiate(cfg.data.datamodules.main.classifier) + _dl_logits(dl_main, model_main, dl_palogits, + # Expected equal when labels are correspondent (so PAPairing won't affect it) and they have the same input model (so the logits are the same) + expected="equal" if cfg.data.expected_results.main.corresponding_labels and cfg.data.expected_results.main.same_model_logits else "not_equal") + + model_pa: torch.nn.Module = hydra.utils.instantiate(cfg.data.datamodules.pa.classifier) + _dl_logits(dl_pa, model_pa, dl_palogits, + expected="equal" if cfg.data.expected_results.pa.same_model_logits else "not_equal") # Must have corresponding labels + + model_palogits: torch.nn.Module = hydra.utils.instantiate(cfg.data.datamodules.pa_logits.classifier) + # _dl_logits(dl_main, model_palogits, dl_palogits, + # expected="equal" if cfg.data.expected_results.main.corresponding_labels else "not_equal") # so pairing won't affect it + _dl_logits(dl_pa, model_palogits, dl_palogits, expected="equal") # dl_pa paired, and pairing is maintained in dl_palogits + + """ + Check length of the datasets. + """ + assert len(dm_main.train_ds) == len(dm_pa.train_ds) == len(dm_palogits.logits_ds), "The length of the datasets is not the same." + + """ + Finally, the PA and PA_logits datasets should have the same labels. Checking for the first environment is enough. + """ + assert torch.equal(dm_pa.train_ds.__getlabels__(list(range(len(dm_pa.train_ds))))[0], + dm_palogits.logits_ds.y), "The labels in the PA and PA_logits datasets are not the same." + + # print("\n Dataloader retrieval (first 5 samples or first environment): ") + """ + To compare the dataloader retrieval, we should see that the samples of the main and PA dataloaders are not the same. + Only the first batch is enough + """ + for bidx, (b_main, b_pa, b_palog) in enumerate(zip(dl_main, dl_pa, dl_palogits)): + env_names = list(b_main.keys()) + for env in env_names: + Xe_main, Xe_pa, Xe_palog = b_main[env][0], b_pa[env][0], b_palog[env][0] + ye_main, ye_pa, ye_palog = b_main[env][1], b_pa[env][1], b_palog[env][1] + + sum_main = torch.tensor([torch.sum(X).item() for X in Xe_main]) + sum_pa = torch.tensor([torch.sum(X).item() for X in Xe_pa]) + assert not torch.equal(sum_main, sum_pa), "The PA dataloader doesn't shuffle observations properly." + break # only the first batch + + print("\n\nTest passed.") + + +def test_sampler(cfg: DictConfig): + """ + The goal is to see whether the PosteriorAgreementDatasetPairing function works as expected. + 1. Compare observations from the dataset given by the original permutations and the ones given by the sampler. + 2. Analyze observations provided by the train_dataloader for different epochs. + """ + + np.random.seed(cfg.seed); torch.manual_seed(cfg.seed) + #warnings.simplefilter("ignore") # generation of images will yield warning + + dm: LightningDataModule = hydra.utils.instantiate(cfg.data.datamodule) + dm.prepare_data() + dm.setup("fit") + + # In this way we can keep both datasets in memory: + dataset = MultienvDataset(dm.train_ds.dset_list) + dataset_sampler = PosteriorAgreementDatasetPairing(dm.train_ds) + + print("Class of dataset: ", type(dataset), "\n") + print("Length original vs sampled: {} vs {}".format(len(dataset), len(dataset_sampler))) + + envs, envs_sampled = list(dataset.__getitem__(0).keys()), list(dataset.__getitem__(0).keys()) + print("Number of environments original vs sampled: {} vs {}".format(len(dataset), len(dataset_sampler))) + assert len(envs) == len(envs_sampled), "The number of environments is not the same." + + # Since datasets are MiltienvDataset objects, we can get their labels straight away: + print("\nOriginal labels: ", ) + for e in range(len(envs)): + print("Environment " + str(envs[e]) + ": ", list(dataset.__getlabels__(list(range(len(dataset)))))[e].tolist()[:10]) + print("\nSampled labels: ", ) + for e in range(len(envs_sampled)): + print("Environment " + str(envs_sampled[e]) + ": ", list(dataset_sampler.__getlabels__(list(range(len(dataset_sampler)))))[e].tolist()[:10]) + + #inds_to_plot = [0, 11, 10, 3] # 0-0, 0-1, 1-0, 1-1 + inds_to_plot = [1, 2, 4, 5] # 8, 8, 6, 6 + plot_multienv(dataset, cfg.paths.results_tests + "dataset/test_dataset", random=inds_to_plot) + plot_multienv(dataset_sampler, cfg.paths.results_tests + "dataset/test_sampler", random=inds_to_plot) + + # Now I will generate a mismatch of 100 samples between the two environments in the original dataset: + subset_inds = torch.randint(0, len(dataset), (20,)) + envs_subset = dataset.__getitems__(subset_inds) + + print("\nBefore modification: ") + print("Subset env 0: ", envs_subset[0][1]) + print("Subset env 1: ", envs_subset[1][1]) + + inds_1 = torch.where(envs_subset[1][1].eq(envs_subset[1][1][0]))[0] # position of observations in env2 of label equal to first label + inds_0 = torch.where(envs_subset[0][1].ne(envs_subset[1][1][0]))[0] # position of observations in env1 of label different to such label + envs_subset[0][0][inds_0[0:len(inds_1)]] = envs_subset[1][0][inds_1] # change observations + envs_subset[0][1][inds_0[0:len(inds_1)]] = envs_subset[1][1][inds_1] # change labels associated + + print("\nAfter modification: ") + print("Subset env 0: ", envs_subset[0][1]) + print("Subset env 1: ", envs_subset[1][1]) + + subset = MultienvDataset([TensorDataset(*subs_env) for subs_env in envs_subset]) # create subset + + print("\nSubset labels for modified indexes: ", inds_0[0:len(inds_1)].tolist()) + envs = list(subset.__getitem__(0).keys()) + for e in range(len(envs)): + print("Environment " + str(envs[e]) + ": ", list(subset.__getlabels__(list(range(len(subset)))))[e].tolist()) + + subset_sampler = PosteriorAgreementDatasetPairing(MultienvDataset(subset.dset_list)) + + print("\nSampled labels for modified: ", ) + envs_sampled = list(subset_sampler.__getitem__(0).keys()) + for e in range(len(envs_sampled)): + print("Environment " + str(envs_sampled[e]) + ": ", list(subset_sampler.__getlabels__(list(range(len(subset_sampler)))))[e].tolist()) + + print("\nLength original vs sampled: {} vs {}".format(len(subset), len(subset_sampler))) + print("Number of environments original vs sampled: {} vs {}".format(len(envs), len(envs_sampled))) + + inds_to_plot = [1, 2, 4, 5] # 8, 8, 6, 6 + plot_multienv(subset, cfg.paths.results_tests + "subset/test_subset", random=inds_to_plot) + plot_multienv(subset_sampler, cfg.paths.results_tests + "subset/test_subsetsampler", random=inds_to_plot) + + print("\n\nTest passed.") \ No newline at end of file diff --git a/tests/test_pa/ddp.py b/tests/test_pa/ddp.py new file mode 100644 index 0000000..ad2acb1 --- /dev/null +++ b/tests/test_pa/ddp.py @@ -0,0 +1,243 @@ +""" +This test will check that the data passed to the model is the same in the CPU and DDP configurations. +""" + +import hydra +from omegaconf import DictConfig +from typing import Optional + +import torch +from torch.utils.data.sampler import RandomSampler +from torch.utils.data import DataLoader, DistributedSampler +from pytorch_lightning import LightningDataModule, LightningModule, Trainer +from pytorch_lightning.trainer.supporters import CombinedLoader + +from .utils import plot_images_multienv + +class TestingModule(LightningModule): + def __init__( + self, + classifier: torch.nn.Module, + expected_results: DictConfig, + ): + super().__init__() + + # Retrieve classifier to deduce the logits + self.model = classifier + first_param = True + for param in self.model.parameters(): + param.requires_grad = False + if first_param: + param.requires_grad = True + first_param = False + + self.loss = torch.nn.CrossEntropyLoss() + + # Save expected results as a DictConfig to be used in the testing step + self.expected_results = expected_results + self.size_main, self.size_pa, self.size_palogits = 0, 0, 0 + self.plot_images = None + + def _testing_step(self, batch_dict: dict, bidx: int): + b_main, b_pa, b_palogits = batch_dict["dl_main"], batch_dict["dl_pa"], batch_dict["dl_palogits"] + env_names = list(b_main.keys()) + + if bidx == 0 and self.trainer.local_rank == 0: + # Save some images to plot later and have a visual inspection. + self.plot_images = [b_main[env][0][0] for env in env_names] + + """ + We will use the model only for one epoch, so we are interested in checking the size of the data being passed + wrt the same procedure in DDP. + """ + self.size_main += len(b_main[env_names[0]][1]) + self.size_pa += len(b_pa[env_names[0]][1]) + self.size_palogits += len(b_palogits[env_names[0]][1]) + # In the configuration of the test, all dataloaders must have the same data source. + # Eventually some samples will be discarded for PA and PA_logits wrt the main dataloader, but this shoudln't be + # reflected here, as mode="min_size" is used in the CombinedLoader. + assert self.size_main == self.size_pa == self.size_palogits, "The batches must have the same length." + + """ + Observations must be shuffled by the dataloaders in different ways. + """ + + labels_main = torch.cat([b_main[env][1] for env in env_names]) + labels_pa = torch.cat([b_pa[env][1] for env in env_names]) + labels_palogits = torch.cat([b_palogits[env][1] for env in env_names]) + assert not torch.equal(labels_main, labels_pa), "Observations in the PA dataloader are not being shuffled properly." + assert not torch.equal(labels_pa, labels_palogits), "Observations in the PA_logits dataloader are not being shuffled properly." + + plot_images_multienv + + for env in env_names: + if env != env_names[0]: + """ + Observations across environments must be different when shift_factor > 0. + Use large shift_factor to ensure that the samples are not repeated and it's a false positive. + """ + assert not torch.equal(Xe_main, b_main[env][0]), "The samples across environments could be repeated in the main dataset." + assert not torch.equal(Xe_pa, b_pa[env][0]), "The samples across environments could be repeated in the PA dataset." + assert not torch.equal(Xe_palogits, b_palogits[env][0]), "The samples across environments could be repeated in the PA_logits dataset." + + """ + Labels across environments must be the same in any case, at least for the PA and PA_logits dataloaders. + """ + if self.expected_results.main.corresponding_labels: + assert torch.equal(ye_main, b_main[env][1]), "Labels in the main dataloader are not corresponding, and they should." + else: + assert not torch.equal(ye_main, b_main[env][1]), "Labels in the main dataloader are corresponding, and they shouldn't be." + assert torch.equal(ye_pa, b_pa[env][1]), "Labels in the PA dataloader are not corresponding, and they should be." + assert torch.equal(ye_palogits, b_palogits[env][1]), "Labels in the PA_logits dataloader are not corresponding, and they should be." + + Xe_main, Xe_pa, Xe_palogits = b_main[env][0], b_pa[env][0], b_palogits[env][0] + ye_main, ye_pa, ye_palogits = b_main[env][1], b_pa[env][1], b_palogits[env][1] + + + def _model_step(self, batch_dict: dict, grad_enabled: bool = True): + b_main = batch_dict["dl_main"] + env_names = list(b_main.keys()) + x, y = b_main[env_names[0]] + with torch.set_grad_enabled(grad_enabled): + logits = self.model(x).to(self.device) + loss = self.loss(input=logits, target=y).to(self.device) + + return loss + + def training_step(self, batch_dict: dict, bidx: int): + if self.trainer.accelerator == "gpu": + print(f"This is happening: {bidx}") + print(self.trainer.local_rank) + self._testing_step(batch_dict, bidx) + + loss = self._model_step(batch_dict) + return {"loss": loss} + + def validation_step(self, batch_dict: dict, bidx: int): + self._testing_step(batch_dict, bidx) + loss = self._model_step(batch_dict, False) + return {"loss": loss} + + def configure_optimizers(self): + optimizer = torch.optim.Adam(self.parameters(), lr=0.1) + return {"optimizer": optimizer} + + +class TestingDataModule(LightningDataModule): + def __init__( + self, + cfg: DictConfig, + ): + super().__init__() + + self.dm_main = hydra.utils.instantiate(cfg.ddp.datamodules.main) + self.dm_pa = hydra.utils.instantiate(cfg.ddp.datamodules.pa) + self.dm_palogits = hydra.utils.instantiate(cfg.ddp.datamodules.pa_logits) + + + def prepare_data(self): + self.dm_main.prepare_data() + self.dm_pa.prepare_data() + self.dm_palogits.prepare_data() + + def setup(self, stage: Optional[str] = None): + self.dm_main.setup("fit") + self.dm_pa.setup("fit") + self.dm_palogits.setup("fit") + + self.ds_pa = self.dm_pa.train_ds + self.ds_palogits = self.dm_palogits.logits_ds + + def train_dataloader(self): + return CombinedLoader( + { + "dl_main": self.dm_main.train_dataloader(), + "dl_pa": self.dm_pa.train_dataloader(), + "dl_palogits": self.dm_palogits.train_dataloader() + }, + mode="min_size" + ) + + def val_dataloader(self): + return CombinedLoader( + { + "dl_main": self.dm_main.val_dataloader(), + "dl_pa": self.dm_pa.val_dataloader(), + "dl_palogits": self.dm_palogits.val_dataloader() + }, + mode="min_size" + ) + + +def test_ddp(cfg: DictConfig): + """ + The goal is to evaluate the data retrieved by the Trainer when using DDP. + """ + + """ + We run the TestingModule with both CPU and DDP trainers. + """ + # If the DataModule requires a classifier, we will assume it's the same as the one used in the logits + # otherwise the results wouldn't make any sense. + model = TestingModule( + classifier=hydra.utils.instantiate(cfg.ddp.datamodules.pa_logits.classifier), + expected_results=cfg.ddp.expected_results + ) + + # We initialize the datamodule yielding a CombinedLoader + dm = TestingDataModule(cfg) + dm.prepare_data() + dm.setup() + + trainer = hydra.utils.instantiate(cfg.ddp.trainer.cpu) + trainer.fit(model, datamodule=dm) + size_main, size_pa, size_palogits = model.size_main, model.size_pa, model.size_palogits # store the sizes + devices_cpu = trainer.device_ids + + plot_images_multienv( + model.plot_images, + [str(i) for i in range(len(model.plot_images))], + cfg.paths.results_tests + "/cpu" + ) + + model = TestingModule( + classifier=hydra.utils.instantiate(cfg.ddp.datamodules.pa_logits.classifier), + expected_results=cfg.ddp.expected_results + ) + + trainer = hydra.utils.instantiate(cfg.ddp.trainer.ddp) + trainer.fit(model, datamodule=dm) + devices_ddp = trainer.device_ids + + print("\nDevices used by CPU: ", devices_cpu) + print("Devices used by DDP: ", devices_ddp) + + + plot_images_multienv( + model.plot_images, + [str(i) for i in range(len(model.plot_images))], + cfg.paths.results_tests + "/ddp" + ) + + + """ + The size of the data passed through the model should be the same in CPU and DDP configurations. + """ + print("HERE IS WHERE THE CONFLICT ARISES", size_main, model.size_main) + assert size_main == model.size_main, "The size of the main dataloader is different when using DDP." + assert size_pa == model.size_pa, "The size of the PA dataloader is different when using DDP." + assert size_palogits == model.size_palogits, "The size of the PA_logits dataloader is different when using DDP." + print("\nSize CPU vs DDP: ") + print("Main: ", size_main, model.size_main) + print("PA: ", size_pa, model.size_pa) + print("PA_logits: ", size_palogits, model.size_palogits) + + """ + The size of the data passed through the model in the PA and PA_logits case should be equal than the size of the dataset. + No possible drop_last=True because the adjustment has already been made. + """ + assert size_pa == len(dm.ds_pa), "The size of the PA dataloader is different than the size of the dataset. Check drop_last=False." + assert size_palogits == len(dm.ds_palogits), "The size of the PA_logits dataloader is different than the size of its dataset. Check drop_last=False." + assert size_pa == size_palogits, "The size of the PA dataset is different than the size of the PA_logits dataset." + + print("\n\nTest passed.") \ No newline at end of file diff --git a/tests/test_pa/pa_callback.py b/tests/test_pa/pa_callback.py new file mode 100644 index 0000000..3ce3f5d --- /dev/null +++ b/tests/test_pa/pa_callback.py @@ -0,0 +1,383 @@ +""" +This test will check if the PosteriorAgreement metric works properly when processing_strategy == "lightning"; that is, when +the metric is called within an ongoing parallelized process. + +The results provided the metric in the Callback should be the same as the ones provided by the metric within the +LightningModule, and the results for the last epoch should as well coincide with those obtained with the PA module. +""" + +import hydra +from omegaconf import DictConfig + +import torch +import torch.distributed as dist +from torch.utils.data import DataLoader, SequentialSampler +from pytorch_lightning import LightningModule, LightningDataModule + +from .utils import get_acc_metrics, get_pa_metrics +from copy import deepcopy + +from torch import nn, argmax, optim +from src.pa_metric.callback import PA_Callback +from src.pa_metric.pairing import PosteriorAgreementDatasetPairing +from src.pa_metric.metric import PosteriorAgreement + +class Vanilla(LightningModule): + def __init__( + self, + num_classes: int, + classifier: nn.Module, + metric, + log_every_n_epochs:int + ): + super().__init__() + + self.model = deepcopy(classifier).eval() + self.model_to_train = classifier.train() + self.loss = nn.CrossEntropyLoss() + self.n_classes = int(num_classes) + + self.metric = metric + self.log_every_n_epochs = log_every_n_epochs + + def training_step(self, batch, batch_idx): + # Adapt to Multienv_collate_fn so that I dont have to instantiate main datamodule twice + env_names = list(batch.keys()) + x = torch.cat([batch[env][0] for env in env_names]) + y = torch.cat([batch[env][1] for env in env_names]) + + with torch.set_grad_enabled(True): + logits = self.model_to_train(x) + loss = self.loss(input=logits, target=y) + assert loss.requires_grad + return {"loss": loss} + + def on_train_epoch_start(self): + # If we already computed PA in the previous iteration, then we can compare the results: + if self.current_epoch + 1 > self.log_every_n_epochs: + print(f"We are checking the results at epoch {self.current_epoch}") + assert torch.allclose(self.metric.betas, self.trainer.callbacks[0].pa_metric.betas) + assert torch.allclose(self.metric.logPAs[0], self.trainer.callbacks[0].pa_metric.logPAs[0]) + assert torch.allclose(self.metric.afr_pred, self.trainer.callbacks[0].pa_metric.afr_pred) + assert torch.allclose(self.metric.afr_true, self.trainer.callbacks[0].pa_metric.afr_true) + assert torch.allclose(self.metric.accuracy, self.trainer.callbacks[0].pa_metric.accuracy) + + def training_epoch_end(self, outputs): + if (self.current_epoch + 1) % self.log_every_n_epochs == 0: + print(f"\nWE ARE COMPUTING THE METRIC at epoch {self.current_epoch}\n") + self.metric.update( + deepcopy(self.model).eval(), + local_rank=self.trainer.local_rank, + ) + + def configure_optimizers(self): + optimizer = torch.optim.Adam(self.parameters(), lr=0.1) + return {"optimizer": optimizer} + + def optimizer_step( + self, + epoch=None, + batch_idx=None, + optimizer=None, + optimizer_idx=None, + optimizer_closure=None, + on_tpu=None, + using_native_amp=None, + using_lbfgs=None, + **kwargs, + ): + """ + We will set gradients to zero after the optimizer step, so that the model is not optimized. + This is done so that I can pass an already trained classifier and obtain the same results for the + PA over and over again. + """ + + # call the 'backward' function on the current loss + optimizer_closure() + + # Set gradients to zero to avoid the model from being optimized + for param in self.parameters(): + if param.grad is not None: + param.grad *= 0 # This zeros the gradients + optimizer.step() + +""" +- Only in DDP as it's where it makes sense to test a model. +- Train Vanilla with the callback in the Trainer and with the metric inside. The result should be the same. +- The metric outside and the PA_module outside should also give the same result. +""" + +def test_pa_callback(cfg: DictConfig): + + # Main DataModule ----------------------------------------------------------------------- + # This is the datamodule that will train the model + datamodule_main: LightningDataModule = hydra.utils.instantiate(cfg.pa_callback.datamodules.main) + datamodule_main.prepare_data() + datamodule_main.setup("fit") + + # PA DataModule ------------------------------------------------------------------------- + datamodule_pa: LightningDataModule = hydra.utils.instantiate(cfg.pa_callback.datamodules.pa) + datamodule_pa.prepare_data() + datamodule_pa.setup("fit") + + # I disable shuffling so that the results can be compared + pa_traindl = datamodule_pa.train_dataloader() + pa_traindl = DataLoader( + pa_traindl.dataset, + batch_size=pa_traindl.batch_size, + sampler=SequentialSampler(pa_traindl.dataset), + num_workers=0,#pa_traindl.num_workers, + collate_fn=pa_traindl.collate_fn, + drop_last=False + ) + + #______________________________________ CPU ______________________________________ + + # Instantiation of the callback + pa_callback_partial = hydra.utils.instantiate(cfg.pa_callback.pa_callback) + pa_callback = pa_callback_partial( + dataset = pa_traindl.dataset, + cuda_devices = 0 + ) + + # Instantiation of the metric + pa_metric_partial = hydra.utils.instantiate(cfg.pa_callback.pa_metric) + pa_metric = pa_metric_partial( + dataset = pa_traindl.dataset, + processing_strategy = "lightning", + cuda_devices=0 + ) + + # We train the model with the callback and the metric inside: + trainer_vanilla_partial = hydra.utils.instantiate(cfg.pa_callback.trainers.vanilla.cpu) + trainer_vanilla = trainer_vanilla_partial( + callbacks=[pa_callback] + ) + vanilla_model_partial = hydra.utils.instantiate(cfg.pa_callback.vanilla_model) + vanilla_model = vanilla_model_partial( + metric = pa_metric + ) + trainer_vanilla.fit(vanilla_model, datamodule_main) + + """ + The overall results of the callback and the metric should be exactly the same. + """ + + # Check PA results from the last epoch it was computed. For other epochs it has been checked from within the LightnignModule + assert torch.allclose(vanilla_model.metric.betas, vanilla_model.trainer.callbacks[0].pa_metric.betas) + assert torch.allclose(vanilla_model.metric.logPAs, vanilla_model.trainer.callbacks[0].pa_metric.logPAs) + assert torch.allclose(vanilla_model.metric.afr_pred, vanilla_model.trainer.callbacks[0].pa_metric.afr_pred) + assert torch.allclose(vanilla_model.metric.afr_true, vanilla_model.trainer.callbacks[0].pa_metric.afr_true) + assert torch.allclose(vanilla_model.metric.accuracy, vanilla_model.trainer.callbacks[0].pa_metric.accuracy) + + # Check that the values are different + assert not torch.equal(vanilla_model.metric.betas, vanilla_model.metric.betas[torch.randperm(len(vanilla_model.metric.betas))]) + + # We also check the final results of the whole PA optimization (i.e. the PA selected for each call) + assert torch.allclose(torch.tensor(vanilla_model.metric.log_beta), torch.tensor(vanilla_model.trainer.callbacks[0].pa_metric.log_beta)) + assert torch.allclose(torch.tensor(vanilla_model.metric.log_logPA), torch.tensor(vanilla_model.trainer.callbacks[0].pa_metric.log_logPA)) + assert torch.allclose(torch.tensor(vanilla_model.metric.log_AFR_pred), torch.tensor(vanilla_model.trainer.callbacks[0].pa_metric.log_AFR_pred)) + assert torch.allclose(torch.tensor(vanilla_model.metric.log_AFR_pred), torch.tensor(vanilla_model.trainer.callbacks[0].pa_metric.log_AFR_pred)) + assert torch.allclose(torch.tensor(vanilla_model.metric.log_accuracy), torch.tensor(vanilla_model.trainer.callbacks[0].pa_metric.log_accuracy)) + + # """ + # Now we will check that the same results are obtained when using the MAIN datamodule also for instantiating the callback, etc. + # """ + # main_ds = PosteriorAgreementDatasetPairing(datamodule_main.train_ds) + + # # Instantiation of the callback + # pa_callback_partial = hydra.utils.instantiate(cfg.pa_callback.pa_callback) + # pa_callback = pa_callback_partial( + # dataset = main_ds, + # cuda_devices = 0 + # ) + + # # Instantiation of the metric + # pa_metric_partial = hydra.utils.instantiate(cfg.pa_callback.pa_metric) + # pa_metric = pa_metric_partial( + # dataset = main_ds, + # processing_strategy = "lightning", + # cuda_devices=0 + # ) + + # # We train the model with the callback and the metric inside: + # trainer_vanilla_partial = hydra.utils.instantiate(cfg.pa_callback.trainers.vanilla.cpu) + # trainer_vanilla = trainer_vanilla_partial( + # callbacks=[pa_callback] + # ) + # vanilla_model_partial = hydra.utils.instantiate(cfg.pa_callback.vanilla_model) + # vanilla_model2 = vanilla_model_partial( + # metric = pa_metric + # ) + # trainer_vanilla.fit(vanilla_model2, datamodule_main) + + # # First we check that the results are the same wrt the previous implementation. + # assert vanilla_model.metric.afr_pred.item() == vanilla_model2.metric.afr_pred.item(), "The AFR_pred is not the same." + # assert vanilla_model.metric.afr_true.item() == vanilla_model2.metric.afr_true.item(), "The AFR_true is not the same." + # assert vanilla_model.metric.accuracy.item() == vanilla_model2.metric.accuracy.item(), "The accuracy is not the same." + + # # Check PA results from the last epoch it was computed. For other epochs it has been checked from within the LightnignModule + # assert torch.allclose(vanilla_model2.metric.betas, vanilla_model2.trainer.callbacks[0].pa_metric.betas) + # assert torch.allclose(vanilla_model2.metric.logPAs, vanilla_model2.trainer.callbacks[0].pa_metric.logPAs) + # assert torch.allclose(vanilla_model2.metric.afr_pred, vanilla_model2.trainer.callbacks[0].pa_metric.afr_pred) + # assert torch.allclose(vanilla_model2.metric.afr_true, vanilla_model2.trainer.callbacks[0].pa_metric.afr_true) + # assert torch.allclose(vanilla_model2.metric.accuracy, vanilla_model2.trainer.callbacks[0].pa_metric.accuracy) + + # # Check that the values are different + # assert not torch.equal(vanilla_model2.metric.betas, vanilla_model2.metric.betas[torch.randperm(len(vanilla_model2.metric.betas))]) + + # # We also check the final results of the whole PA optimization (i.e. the PA selected for each call) + # assert torch.allclose(torch.tensor(vanilla_model2.metric.log_beta), torch.tensor(vanilla_model2.trainer.callbacks[0].pa_metric.log_beta)) + # assert torch.allclose(torch.tensor(vanilla_model2.metric.log_logPA), torch.tensor(vanilla_model2.trainer.callbacks[0].pa_metric.log_logPA)) + # assert torch.allclose(torch.tensor(vanilla_model2.metric.log_AFR_pred), torch.tensor(vanilla_model2.trainer.callbacks[0].pa_metric.log_AFR_pred)) + # assert torch.allclose(torch.tensor(vanilla_model2.metric.log_AFR_pred), torch.tensor(vanilla_model2.trainer.callbacks[0].pa_metric.log_AFR_pred)) + # assert torch.allclose(torch.tensor(vanilla_model2.metric.log_accuracy), torch.tensor(vanilla_model2.trainer.callbacks[0].pa_metric.log_accuracy)) + + # """ + # Now we will call the metric from outside again, and the PA module. + # """ + + # # Results should be the same (comparison within the model) + # pa_metric = pa_metric_partial( + # dataset = pa_traindl.dataset, + # processing_strategy = "cpu" + # ) + # pa_metric.update(deepcopy(vanilla_model.model).eval()) + + # pa_module_partial: LightningModule = hydra.utils.instantiate(cfg.pa_callback.pa_module) + # pa_module = pa_module_partial(classifier=deepcopy(vanilla_model.model).eval()) + # trainer_pa = hydra.utils.instantiate(cfg.pa_callback.trainers.pa_module.cpu) + # trainer_pa.fit( + # model=pa_module, + # train_dataloaders=pa_traindl, + # val_dataloaders=pa_traindl + # ) + + # assert torch.equal(pa_metric.betas, torch.tensor(pa_module.betas, dtype=float)) + # assert torch.allclose(pa_metric.logPAs, torch.tensor(pa_module.logPAs, dtype=float)) + + + # print("\nCheck that the PA and beta values are meaningful in the CPU:") + # print("beta: ", pa_metric.betas) + # print("logPA: ", pa_metric.logPAs) + + # print("\nCPU test passed.\n") + # exit() + + # # ______________________________________ LIGHTNING ______________________________________ + # print("1") + # pa_callback_partial = hydra.utils.instantiate(cfg.pa_callback.pa_callback) + # pa_callback = pa_callback_partial( + # dataset = pa_traindl.dataset, + # cuda_devices = cfg.pa_callback.trainers.vanilla.ddp.devices + # ) + # print("2") + # pa_metric_partial = hydra.utils.instantiate(cfg.pa_callback.pa_metric) + # pa_metric = pa_metric_partial( + # dataset = pa_traindl.dataset, + # processing_strategy = "lightning", + # cuda_devices = cfg.pa_callback.trainers.vanilla.ddp.devices + # ) + # print("3") + # # We train the model with the callback and the metric inside: + # trainer_vanilla_partial = hydra.utils.instantiate(cfg.pa_callback.trainers.vanilla.ddp) + # trainer_vanilla = trainer_vanilla_partial( + # callbacks=[pa_callback] + # ) + # print("4") + # vanilla_model_partial = hydra.utils.instantiate(cfg.pa_callback.vanilla_model) + # vanilla_model = vanilla_model_partial( + # metric = pa_metric + # ) + # print("5") + # trainer_vanilla.fit( + # vanilla_model, + # # I initialize the datamodule again because it's in DDP now + # datamodule=hydra.utils.instantiate(cfg.pa_callback.datamodules.main) + # ) + # print("6") + + # """ + # The overall results of the callback and the metric should be exactly the same. + # """ + # # Check PA results from the last epoch it was computed. For other epochs it has been checked from within the LightnignModule + # assert torch.allclose(vanilla_model.metric.betas, vanilla_model.trainer.callbacks[0].pa_metric.betas) + # assert torch.allclose(vanilla_model.metric.logPAs, vanilla_model.trainer.callbacks[0].pa_metric.logPAs) + # assert torch.allclose(vanilla_model.metric.afr_pred, vanilla_model.trainer.callbacks[0].pa_metric.afr_pred) + # assert torch.allclose(vanilla_model.metric.afr_true, vanilla_model.trainer.callbacks[0].pa_metric.afr_true) + # assert torch.allclose(vanilla_model.metric.accuracy, vanilla_model.trainer.callbacks[0].pa_metric.accuracy) + + # # Check that the values are different + # assert not torch.equal(vanilla_model.metric.betas, vanilla_model.metric.betas[torch.randperm(len(vanilla_model.metric.betas))]) + + # # We also check the final results of the whole PA optimization (i.e. the PA selected for each call) + # assert torch.allclose(torch.tensor(vanilla_model.metric.log_beta), torch.tensor(vanilla_model.trainer.callbacks[0].pa_metric.log_beta)) + # assert torch.allclose(torch.tensor(vanilla_model.metric.log_logPA), torch.tensor(vanilla_model.trainer.callbacks[0].pa_metric.log_logPA)) + # assert torch.allclose(torch.tensor(vanilla_model.metric.log_AFR_pred), torch.tensor(vanilla_model.trainer.callbacks[0].pa_metric.log_AFR_pred)) + # assert torch.allclose(torch.tensor(vanilla_model.metric.log_AFR_pred), torch.tensor(vanilla_model.trainer.callbacks[0].pa_metric.log_AFR_pred)) + # assert torch.allclose(torch.tensor(vanilla_model.metric.log_accuracy), torch.tensor(vanilla_model.trainer.callbacks[0].pa_metric.log_accuracy)) + + # print("\nCheck that the PA and beta values are meaningful for the lightning strategy:") + # print("beta: ", vanilla_model.metric.betas) + # print("logPA: ", vanilla_model.metric.logPAs) + # exit() + + # # ______________________________________ CUDA ______________________________________ + # """ + # Now we will call the metric from outside again, and the PA module. + # """ + # print("7") + # if dist.is_initialized(): + # dist.destroy_process_group() + + # # Results should be the same (comparison within the model) + # vanilla_model_partial = hydra.utils.instantiate(cfg.pa_callback.vanilla_model) + # vanilla_model = vanilla_model_partial(metric=None) + # model_to_eval = deepcopy(vanilla_model.model).eval() + # model_to_eval2 = deepcopy(vanilla_model.model).eval() + + # # Initialize from here just in case + # pa_metric_cuda = PosteriorAgreement( + # dataset = pa_traindl.dataset, + # beta0 = cfg.pa_callback.pa_module.beta0, + # pa_epochs = cfg.pa_callback.trainers.pa_module.ddp.max_epochs, + # processing_strategy = "cuda", + # cuda_devices = cfg.pa_callback.trainers.vanilla.ddp.devices + # ) + # pa_metric_cuda.update(model_to_eval, destroy_process_group=True) + # print("8") + + # # Now compare with the PA module implementation + # pa_module_partial: LightningModule = hydra.utils.instantiate(cfg.pa_callback.pa_module) + # pa_module = pa_module_partial(classifier=model_to_eval2) + # trainer_pa = hydra.utils.instantiate( + # cfg.pa_callback.trainers.pa_module.ddp + # ) + # trainer_pa.fit( + # model=pa_module, + # train_dataloaders=pa_traindl, + # val_dataloaders=pa_traindl + # ) + # print("9") + + # print("\nCheck that the PA and beta values are meaningful in DDP:") + # print("betas metric: ", pa_metric_cuda.betas) + # print("betas module: ", torch.tensor(pa_module.betas, dtype=float)) + # print("logPA metric: ", pa_metric_cuda.logPAs[0, :]) + # print("logPA module: ", torch.tensor(pa_module.logPAs, dtype=float)) + + # try: + # assert torch.equal(pa_metric_cuda.betas, torch.tensor(pa_module.betas, dtype=float)), "The betas are not the same between metric and module." + # assert torch.equal(pa_metric_cuda.logPAs[0, :], torch.tensor(pa_module.logPAs, dtype=float)) + # assert pa_metric_cuda.afr_pred == pa_module.afr_pred.item(), "The AFR_pred is not the same between metric and module." + # assert pa_metric_cuda.afr_true == pa_module.afr_true.item(), "The AFR_true is not the same between metric and module." + # assert pa_metric_cuda.accuracy == pa_module.acc_pa.item(), "The accuracy is not the same between metric and module." + # except: + # print("THEY ARE NOT EQUAL!!") + + # # exit() + # print("\nTest passed.") + + + + diff --git a/tests/test_pa/pa_metric.py b/tests/test_pa/pa_metric.py new file mode 100644 index 0000000..840a31e --- /dev/null +++ b/tests/test_pa/pa_metric.py @@ -0,0 +1,405 @@ +""" +This test will check that the PA metric provides the same results that the existing PA module when the data is controlled. +""" + +import hydra +from omegaconf import DictConfig + +import torch +import torch.distributed as dist +from torch.utils.data import DataLoader, SequentialSampler +from pytorch_lightning import seed_everything, LightningModule, LightningDataModule + +from .utils import get_pa_metrics, get_acc_metrics +from src.pa_metric.metric import PosteriorAgreement +from src.pa_metric.datautils import LogitsDataset, MultienvDataset +from src.pa_metric.metric import PosteriorAccuracy + +from copy import deepcopy + +def test_basemetric(cfg: DictConfig): + """ + This test checks that the optimization over beta yields the same results in the PA module + and the basemetric. + """ + seed_everything(42, workers=True) + + ## DATA INITIALIZATION: + # Main datamodule: + datamodule_main: LightningDataModule = hydra.utils.instantiate(cfg.pa_metric.datamodules.main) + datamodule_main.prepare_data() + datamodule_main.setup("fit") + + # Impose sequential sampler + main_traindl = datamodule_main.train_dataloader() + main_traindl = DataLoader( + main_traindl.dataset, + batch_size=main_traindl.batch_size, + sampler=SequentialSampler(main_traindl.dataset), + num_workers=main_traindl.num_workers, + collate_fn=main_traindl.collate_fn + ) + + # Logits datamodule + datamodule_palogs: LightningDataModule = hydra.utils.instantiate(cfg.pa_metric.datamodules.logits) + datamodule_palogs.prepare_data() + datamodule_palogs.setup("fit") + + # Impose sequential sampler + palogs_traindl = datamodule_palogs.train_dataloader() + palogs_traindl = DataLoader( + palogs_traindl.dataset, + batch_size=palogs_traindl.batch_size, + sampler=SequentialSampler(palogs_traindl.dataset), + num_workers=palogs_traindl.num_workers, + collate_fn=palogs_traindl.collate_fn + ) + + # USING THE PA MODULE + pamodule_palogs: LightningModule = hydra.utils.instantiate(cfg.pa_metric.pa_module) + trainer = hydra.utils.instantiate(cfg.pa_metric.trainer.cpu) + trainer.fit( + model=pamodule_palogs, + # Same data for training and validation, as in the metric. + train_dataloaders=palogs_traindl, + val_dataloaders=palogs_traindl + ) + acc_metrics = get_acc_metrics(pamodule_palogs) + AFR_pred = acc_metrics[0] + AFR_true = acc_metrics[1] + acc_pa = acc_metrics[2] + + beta_epoch_palogs, logPA_epoch_palogs = get_pa_metrics(pamodule_palogs) + assert len(beta_epoch_palogs) == cfg.pa_metric.trainer.cpu.max_epochs, "Some beta values are not being stored properly." + assert len(logPA_epoch_palogs) == cfg.pa_metric.trainer.cpu.max_epochs, "Some logPA values are not being stored properly." + + ## USING THE PA BASEMETRIC + partial_basemetric = hydra.utils.instantiate(cfg.pa_metric.pa_basemetric) + pabasemetric = partial_basemetric(main_traindl.dataset) + pabasemetric.update( + hydra.utils.instantiate(cfg.pa_metric.datamodules.logits.classifier) + ) + + """ + Since palogs_traindl and palogs_valdl are the same, the base metric should give the same results as the PA module. + """ + assert torch.equal(beta_epoch_palogs, pabasemetric.betas), "The beta values do not coincide." + assert torch.equal(logPA_epoch_palogs, pabasemetric.logPAs), "The logPA values do not coincide." + assert AFR_true == pabasemetric.afr_true, "The AFR_true values do not coincide." + assert AFR_pred == pabasemetric.afr_pred, "The AFR_pred values do not coincide." + assert acc_pa == pabasemetric.accuracy, "The accuracy values do not coincide." + + """ + This would be the dictionary we would be getting if we called simply basemetric(logits), instead of basemetric.update(logits). + This is how the metric will be called in any implementation. + """ + results_dict = pabasemetric.compute() + print("\nResults from the base metric:") + print(results_dict) + + assert results_dict["logPA"] == max(pabasemetric.logPAs), "The logPA value is not the maximum of the logPAs." + + print("\nTest passed.") + + +def test_pametric_cpu(cfg: DictConfig): + """ + Check that the results of the PA metric are the same as the PA base metric when using the CPU. The basemetric has already + been tested against the PA module, and passing the test would mean that the PA metric is also working properly in the CPU. + """ + + # INITIALIZE IMAGES DATALOADER + datamodule_images: LightningDataModule = hydra.utils.instantiate(cfg.pa_metric.datamodules.images) + datamodule_images.prepare_data() + datamodule_images.setup("fit") + + # I disable shuffling so that the results can be compared + images_traindl = datamodule_images.train_dataloader() + images_traindl = DataLoader( + images_traindl.dataset, + batch_size=images_traindl.batch_size, + sampler=SequentialSampler(images_traindl.dataset), + num_workers=images_traindl.num_workers, + collate_fn=images_traindl.collate_fn, + drop_last=False + ) + + # Using the PA basemetric --------------------------------------------------------------- + pabasemetric_partial = hydra.utils.instantiate(cfg.pa_metric.metrics.basemetric) + pabasemetric = pabasemetric_partial(dataset = images_traindl.dataset) + pabasemetric.update(hydra.utils.instantiate(cfg.pa_metric.classifier)) + + # Using the PosteriorAgreement metric ---------------------------------------------------- + pa_metric_partial = hydra.utils.instantiate(cfg.pa_metric.metrics.fullmetric) # initialize the metric for CPU + pa_metric = pa_metric_partial(dataset = images_traindl.dataset) + pa_metric.update(hydra.utils.instantiate(cfg.pa_metric.classifier)) + + assert torch.equal(pa_metric.betas, pabasemetric.betas), "The beta values do not coincide with the logits results." + assert torch.equal(pa_metric.logPAs[0, :], pabasemetric.logPAs), "The logPA values do not coincide with the logits results." + assert pa_metric.afr_true[0].item() == pabasemetric.afr_true, "The AFR_true values do not coincide with the logits results." + assert pa_metric.afr_pred[0].item() == pabasemetric.afr_pred, "The AFR_pred values do not coincide with the logits results." + assert pa_metric.accuracy[0].item() == pabasemetric.accuracy, "The accuracy values do not coincide with the logits results." + + print("\nTest passed.") + +import time +def test_pametric_ddp(cfg: DictConfig): + """ + Check that the results of the PA metric are the same regardless of the data partitioning strategy and the + number of GPUs used. This is important as the PA metric bears a custom DDP implementation. + """ + + # We will only work with the logits dataloader, as it makes it easier. + datamodule_main: LightningDataModule = hydra.utils.instantiate(cfg.pa_metric.datamodules.images) + datamodule_main.prepare_data() + datamodule_main.setup("fit") + + # I disable shuffling so that the results can be compared + main_traindl = datamodule_main.train_dataloader() + main_traindl = DataLoader( + main_traindl.dataset, + batch_size=main_traindl.batch_size, + sampler=SequentialSampler(main_traindl.dataset), + num_workers=main_traindl.num_workers, + collate_fn=main_traindl.collate_fn, + drop_last=False + ) + + # Initialization of the metric + pa_ds = main_traindl.dataset # get the dataset + pa_metric_partial = hydra.utils.instantiate(cfg.pa_metric.metric) + + # Metric running in the CPU vs CUDA --------------------------------------------------------------- + pa_metric_cpu = pa_metric_partial( + dataset = pa_ds, + processing_strategy = "cpu" + ) + pa_metric_cpu.update(hydra.utils.instantiate(cfg.pa_metric.classifier)) + + pa_metric_cuda = pa_metric_partial( + dataset = pa_ds, + processing_strategy = "cuda", + cuda_devices = 4 + ) + pa_metric_cuda.update(hydra.utils.instantiate(cfg.pa_metric.classifier), destroy_process_group=True) # because it's the last time it will be called + + assert torch.allclose(pa_metric_cpu.betas, pa_metric_cuda.betas), "The beta values do not coincide." + assert torch.allclose(pa_metric_cpu.logPAs, pa_metric_cuda.logPAs), "The logPA values do not coincide." + assert torch.allclose(pa_metric_cpu.afr_true, pa_metric_cuda.afr_true), "The AFR_true values do not coincide." + assert torch.allclose(pa_metric_cpu.afr_pred, pa_metric_cuda.afr_pred), "The AFR_pred values do not coincide." + assert torch.allclose(pa_metric_cpu.accuracy, pa_metric_cuda.accuracy), "The accuracy values do not coincide." + + # Metric running in less number of GPUs than available ----------------------------------------- + if dist.is_initialized(): + dist.destroy_process_group() + pa_metric_cuda2 = pa_metric_partial( + dataset = pa_ds, + processing_strategy = "cuda", + cuda_devices = 2 # change number of devices + ) + pa_metric_cuda2.update(hydra.utils.instantiate(cfg.pa_metric.classifier), destroy_process_group=True) + + assert torch.allclose(pa_metric_cuda.betas, pa_metric_cuda2.betas), "The beta values do not coincide with different number of GPUs." + assert torch.allclose(pa_metric_cuda.logPAs[0], pa_metric_cuda2.logPAs[0]), "The logPA values do not coincide with different number of GPUs." + assert torch.allclose(pa_metric_cuda.afr_true, pa_metric_cuda2.afr_true), "The AFR_true values do not coincide with different number of GPUs." + assert torch.allclose(pa_metric_cuda.afr_pred, pa_metric_cuda2.afr_pred), "The AFR_pred values do not coincide with different number of GPUs." + assert torch.allclose(pa_metric_cuda.accuracy, pa_metric_cuda2.accuracy), "The accuracy values do not coincide with different number of GPUs." + + # Metric performing several calls in a row with the same data ----------------------------------- + if dist.is_initialized(): + dist.destroy_process_group() + pa_metric_cuda = pa_metric_partial( + dataset = pa_ds, + cuda_devices = 4, + processing_strategy = "cuda") + pa_metric_cuda(hydra.utils.instantiate(cfg.pa_metric.classifier), destroy_process_group=False) + pa_metric_cuda(hydra.utils.instantiate(cfg.pa_metric.classifier), destroy_process_group=False) + pa_metric_cuda(hydra.utils.instantiate(cfg.pa_metric.classifier), destroy_process_group=True) + + print("\nThe log of results for multiple calls (3) is working: ") + print("Logged logPA: ", pa_metric_cuda.log_logPA) + print("Logged beta: ", pa_metric_cuda.log_beta) + print("Logged AFR_true: ", pa_metric_cuda.log_AFR_true) + print("Logged AFR_pred: ", pa_metric_cuda.log_AFR_pred) + print("Logged accuracy: ", pa_metric_cuda.log_accuracy) + + assert torch.allclose(pa_metric_cpu.betas, pa_metric_cuda.betas), "The beta values do not coincide." + assert torch.allclose(pa_metric_cpu.logPAs[0], pa_metric_cuda.logPAs[0]), "The logPA values do not coincide." + assert torch.allclose(pa_metric_cpu.afr_true, pa_metric_cuda.afr_true), "The AFR_true values do not coincide." + assert torch.allclose(pa_metric_cpu.afr_pred, pa_metric_cuda.afr_pred), "The AFR_pred values do not coincide." + assert torch.allclose(pa_metric_cpu.accuracy, pa_metric_cuda.accuracy), "The accuracy values do not coincide." + + # Performing many epochs --------------------------------------------------------------------- + start = time.time() + if dist.is_initialized(): + dist.destroy_process_group() + pa_metric_cuda_long = pa_metric_partial( + dataset = pa_ds, + pa_epochs = 1000, + cuda_devices = 4, + processing_strategy = "cuda") + metric_dict_long = pa_metric_cuda_long(destroy_process_group=True) + print("\nTime for 1000 epochs: ", time.time() - start) + + print("\nTest passed.") + + +def test_pametric_logits(cfg: DictConfig): + """ + We will test if the logits generated within the PosteriorAgreement metric in DDP mode are the same + as the ones provided by the LogitsDatamodule. + """ + + # Non-paired dataset of images: + datamodule_main: LightningDataModule = hydra.utils.instantiate(cfg.pa_metric.datamodules.main) + datamodule_main.prepare_data() + datamodule_main.setup("fit") + main_traindl = datamodule_main.train_dataloader() + main_dataset = main_traindl.dataset + + # Paired dataset of logits: + datamodule_palogs: LightningDataModule = hydra.utils.instantiate(cfg.pa_metric.datamodules.pa_logits) + datamodule_palogs.prepare_data() + datamodule_palogs.setup("fit") + palogs_traindl = datamodule_palogs.train_dataloader() + palogs_dataset = palogs_traindl.dataset + assert palogs_dataset.__class__.__name__ == "LogitsDataset", "The PA_logits dataset is not being properly loaded." + + pa_metric_partial = hydra.utils.instantiate(cfg.pa_metric.metric) + pa_metric_cpu = pa_metric_partial( + dataset = main_dataset, + processing_strategy = "cpu" + ) + + classifier = hydra.utils.instantiate(cfg.pa_metric.datamodules.pa_logits.classifier) + pa_metric_cpu.classifier, pa_metric_cpu.classifier_val = pa_metric_cpu._initialize_classifiers(classifier) + + # Initialize classifier: Same as the one used in the PA logits. + cpu_dataset = pa_metric_cpu._compute_logits_dataset(0) + assert cpu_dataset.__class__.__name__ == "LogitsDataset", "The logits dataset is not being properly loaded from the metric." + + """ + The LogitsDataset provided by the method in the PosteriorAgreement metric should be the same as the one + generated from the same dataset in the PA_logits datamodule. + """ + assert palogs_dataset.num_envs == cpu_dataset.num_envs , "The number of environments does not coincide." + assert torch.equal(palogs_dataset.y, cpu_dataset.y), "The labels are not the same." + for e in range(palogs_dataset.num_envs): + assert torch.allclose(palogs_dataset.logits[e], cpu_dataset.logits[e]), f"The logits in environment {e} are not the same." + + print("\nCPU test passed.\n") + exit() + + """ + Now we must check that the PA optimization results are the same in the CPU and in DDP mode using the same dataset. + + IMPORTANT: We must instantiate the dataset again, as it has already been paired. A second pairing would mess it up. + """ + pa_metric_cpu.update(classifier) + pa_metric_cuda = pa_metric_partial( + dataset = main_dataset, + cuda_devices = 4, + processing_strategy = "cuda" + ) + pa_metric_cuda.update(classifier, destroy_process_group=True) + + assert torch.allclose(pa_metric_cpu.betas, pa_metric_cuda.betas), "The beta values do not coincide." + assert torch.allclose(pa_metric_cpu.logPAs, pa_metric_cuda.logPAs), "The logPA values do not coincide." + assert torch.allclose(pa_metric_cpu.afr_true, pa_metric_cuda.afr_true), "The AFR_true values do not coincide." + assert torch.allclose(pa_metric_cpu.afr_pred, pa_metric_cuda.afr_pred), "The AFR_pred values do not coincide." + assert torch.allclose(pa_metric_cpu.accuracy, pa_metric_cuda.accuracy), "The accuracy values do not coincide." + + print("\nTest passed.") + + +def test_accuracymetric(cfg): + """ + Test the subclass of posterior accuracy, where one of the distributions is fixed. + """ + + # We will only work with the logits dataloader, as it makes it easier. + datamodule_main: LightningDataModule = hydra.utils.instantiate(cfg.pa_metric.datamodule) + datamodule_main.prepare_data() + datamodule_main.setup("fit") + + # I select the first dataset from the MultienvDataset, as the Accuracy metric only requires one. + ds = MultienvDataset([datamodule_main.train_ds.dset_list[0]]) + + # Initialization of the metric + pa_metric_partial = hydra.utils.instantiate(cfg.pa_metric.metric) + + # Initialization of the fixed classifier used to update the metric + classifier = hydra.utils.instantiate(cfg.pa_metric.classifier) + + # # Metric running in the CPU vs CUDA --------------------------------------------------------------- + pa_metric_cpu = pa_metric_partial( + dataset = ds, + sharpness_factor=1.5, + processing_strategy = "cpu") + pa_metric_cpu.update(classifier) + + pa_metric_cuda = pa_metric_partial( + dataset = ds, + sharpness_factor=1.5, + cuda_devices = 4, + processing_strategy = "cuda") + pa_metric_cuda.update(classifier, destroy_process_group=True) # because it's the last time it will be called + + assert torch.allclose(pa_metric_cpu.betas, pa_metric_cuda.betas), "The beta values do not coincide." + assert torch.allclose(pa_metric_cpu.logPAs, pa_metric_cuda.logPAs), "The logPA values do not coincide." + assert torch.allclose(pa_metric_cpu.afr_true, pa_metric_cuda.afr_true), "The AFR_true values do not coincide." + assert torch.allclose(pa_metric_cpu.afr_pred, pa_metric_cuda.afr_pred), "The AFR_pred values do not coincide." + assert torch.allclose(pa_metric_cpu.accuracy, pa_metric_cuda.accuracy), "The accuracy values do not coincide." + + # Metric running in less number of GPUs than available ----------------------------------------- + pa_metric_cuda2 = pa_metric_partial( + dataset = ds, + sharpness_factor=1.5, + cuda_devices = 2, # change number of devices + processing_strategy = "cuda") + pa_metric_cuda2.update(classifier, destroy_process_group=True) + + assert torch.allclose(pa_metric_cuda.betas, pa_metric_cuda2.betas), "The beta values do not coincide with different number of GPUs." + assert torch.allclose(pa_metric_cuda.logPAs, pa_metric_cuda2.logPAs), "The logPA values do not coincide with different number of GPUs." + assert torch.allclose(pa_metric_cuda.afr_true, pa_metric_cuda2.afr_true), "The AFR_true values do not coincide with different number of GPUs." + assert torch.allclose(pa_metric_cuda.afr_pred, pa_metric_cuda2.afr_pred), "The AFR_pred values do not coincide with different number of GPUs." + assert torch.allclose(pa_metric_cuda.accuracy, pa_metric_cuda2.accuracy), "The accuracy values do not coincide with different number of GPUs." + + # Metric performing several calls woth different sharpness ----------------------------------- + for sharpness in [1.001, 1.01, 1.1, 1.5, 2, 10, 50]: + pa_accuracymetric = pa_metric_partial( + dataset = ds, + sharpness_factor=sharpness, + cuda_devices = 0, + processing_strategy = "cpu") + pa_accuracymetric(classifier) + + pa_accuracymetric2 = pa_metric_partial( + dataset = ds, + sharpness_factor=sharpness, + cuda_devices = 4, + processing_strategy = "cuda") + pa_accuracymetric2(classifier) + + print("\nFor sharpness factor: ", sharpness) + print("beta: ", pa_accuracymetric.betas, pa_accuracymetric2.betas) + print("logPA: ", pa_accuracymetric.logPAs, pa_accuracymetric2.logPAs) + + # Performing many epochs --------------------------------------------------------------------- + start = time.time() + pa_metric_cuda_long = pa_metric_partial( + dataset = ds, + sharpness_factor=1.5, + pa_epochs = 1000, + cuda_devices = 4, + processing_strategy = "cuda") + metric_dict_long = pa_metric_cuda_long(classifier, destroy_process_group=True) + print("\nTime for 1000 epochs: ", time.time() - start) + + print("\nTest passed.") + + + + + + diff --git a/tests/test_pa/pa_module.py b/tests/test_pa/pa_module.py new file mode 100644 index 0000000..59f2ea9 --- /dev/null +++ b/tests/test_pa/pa_module.py @@ -0,0 +1,174 @@ +""" +This test will check that the PA module provides the same results using PA and PA_logits, for both CPU and DDP configurations. +""" + +import hydra +from omegaconf import DictConfig + +import torch +from torch.utils.data import DataLoader, SequentialSampler +from pytorch_lightning import LightningModule, LightningDataModule + +import os +import pandas as pd + +from .utils import get_acc_metrics, get_pa_metrics + +def test_pa_module(cfg: DictConfig): + AFR_pred, AFR_true, acc_pa = [], [], [] # store for comparison at the end + + # ________________________________________ CPU ________________________________________ + # PA DataModule ----------------------------------------------------------------------- + datamodule_pa: LightningDataModule = hydra.utils.instantiate(cfg.pa_module.datamodules.pa) + datamodule_pa.prepare_data() + datamodule_pa.setup("fit") + + # We will use the same datalaoder but with a SequentialSampler, so that results are the same in all cases + pa_traindl = datamodule_pa.train_dataloader() + pa_traindl = DataLoader( + pa_traindl.dataset, + batch_size=pa_traindl.batch_size, + sampler=SequentialSampler(pa_traindl.dataset), + num_workers=pa_traindl.num_workers, + collate_fn=pa_traindl.collate_fn + ) + + # They are the same dataloader, but still wanna check that everything is fine + pa_valdl = datamodule_pa.val_dataloader() + pa_valdl = DataLoader( + pa_valdl.dataset, + batch_size=pa_valdl.batch_size, + sampler=SequentialSampler(pa_valdl.dataset), + num_workers=pa_valdl.num_workers, + collate_fn=pa_valdl.collate_fn + ) + + # Training the model: + pamodule_pa: LightningModule = hydra.utils.instantiate(cfg.pa_module.pa_lightningmodule.pa) + trainer = hydra.utils.instantiate(cfg.pa_module.trainer.cpu) + trainer.fit( + model=pamodule_pa, + train_dataloaders=pa_traindl, + val_dataloaders=pa_valdl + ) + acc_metrics = get_acc_metrics(pamodule_pa) + AFR_pred.append(acc_metrics[0]) + AFR_true.append(acc_metrics[1]) + acc_pa.append(acc_metrics[2]) + + beta_epoch_pa, logPA_epoch_pa = get_pa_metrics(pamodule_pa) + assert len(beta_epoch_pa) == cfg.pa_module.trainer.cpu.max_epochs, "Some beta values are not being stored properly." + assert len(logPA_epoch_pa) == cfg.pa_module.trainer.cpu.max_epochs, "Some logPA values are not being stored properly." + + # PA_logits DataModule ------------------------------------------------------------------ + datamodule_palogs: LightningDataModule = hydra.utils.instantiate(cfg.pa_module.datamodules.pa_logits) + datamodule_palogs.prepare_data() + datamodule_palogs.setup("fit") + + palogs_traindl = datamodule_palogs.train_dataloader() + palogs_traindl = DataLoader( + palogs_traindl.dataset, + batch_size=palogs_traindl.batch_size, + sampler=SequentialSampler(palogs_traindl.dataset), + num_workers=palogs_traindl.num_workers, + collate_fn=palogs_traindl.collate_fn + ) + + # They are the same dataloader, but still wanna check that everything is fine + palogs_valdl = datamodule_palogs.val_dataloader() + palogs_valdl = DataLoader( + palogs_valdl.dataset, + batch_size=palogs_valdl.batch_size, + sampler=SequentialSampler(palogs_valdl.dataset), + num_workers=palogs_valdl.num_workers, + collate_fn=palogs_valdl.collate_fn + ) + + # Training the model: + pamodule_palogs: LightningModule = hydra.utils.instantiate(cfg.pa_module.pa_lightningmodule.pa_logits) + trainer = hydra.utils.instantiate(cfg.pa_module.trainer.cpu) + trainer.fit( + model=pamodule_palogs, + train_dataloaders=palogs_traindl, + val_dataloaders=palogs_valdl + ) + acc_metrics = get_acc_metrics(pamodule_palogs) + AFR_pred.append(acc_metrics[0]) + AFR_true.append(acc_metrics[1]) + acc_pa.append(acc_metrics[2]) + + beta_epoch_palogs, logPA_epoch_palogs = get_pa_metrics(pamodule_palogs) + assert len(beta_epoch_palogs) == cfg.pa_module.trainer.cpu.max_epochs, "Some beta values are not being stored properly." + assert len(logPA_epoch_palogs) == cfg.pa_module.trainer.cpu.max_epochs, "Some logPA values are not being stored properly." + + + """ + The results from PA and PA_logits implementations should be the same. + """ + assert torch.equal(beta_epoch_pa, beta_epoch_palogs), "The beta values are not equal between PA and PA_logits methods." + assert torch.equal(logPA_epoch_pa, logPA_epoch_palogs), "The logPA values are not equal between PA and PA_logits methods." + + + print("\nCPU tests passed.\n") + # ________________________________________ DDP ________________________________________ + # Now I don't need to initialize so many things. + + # PA model ---------------------------------------------------------------------------- + pamodule_pa: LightningModule = hydra.utils.instantiate(cfg.pa_module.pa_lightningmodule.pa) + trainer = hydra.utils.instantiate(cfg.pa_module.trainer.ddp) # DDP trainer + trainer.fit( + model=pamodule_pa, + train_dataloaders=pa_traindl, + val_dataloaders=pa_valdl + ) + acc_metrics = get_acc_metrics(pamodule_pa) + AFR_pred.append(acc_metrics[0]) + AFR_true.append(acc_metrics[1]) + acc_pa.append(acc_metrics[2]) + + beta_epoch_pa, logPA_epoch_pa = get_pa_metrics(pamodule_pa) + assert len(beta_epoch_pa) == cfg.pa_module.trainer.ddp.max_epochs, "Some beta values are not being stored properly." + assert len(logPA_epoch_pa) == cfg.pa_module.trainer.ddp.max_epochs, "Some logPA values are not being stored properly." + + """ + The results in DDP should be the same as in CPU. We compare with the last ones saved (PA_logits). + """ + # We use more epochs on DDP, and thus we compare the first cpu.max_epochs epochs. + + assert torch.equal(beta_epoch_pa[:cfg.pa_module.trainer.cpu.max_epochs], beta_epoch_palogs), "The beta values are not equal between CPU and DDP implementations." + assert torch.equal(logPA_epoch_pa[:cfg.pa_module.trainer.cpu.max_epochs], logPA_epoch_palogs), "The logPA values are not equal between CPU and DDP implementations." + + + # PA_logits model ------------------------------------------------------------------------- + pamodule_palogs: LightningModule = hydra.utils.instantiate(cfg.pa_module.pa_lightningmodule.pa_logits) + trainer = hydra.utils.instantiate(cfg.pa_module.trainer.ddp) # DDP trainer + trainer.fit( + model=pamodule_palogs, + train_dataloaders=palogs_traindl, + val_dataloaders=palogs_valdl + ) + acc_metrics = get_acc_metrics(pamodule_palogs) + AFR_pred.append(acc_metrics[0]) + AFR_true.append(acc_metrics[1]) + acc_pa.append(acc_metrics[2]) + + beta_epoch_palogs, logPA_epoch_palogs = get_pa_metrics(pamodule_palogs) + assert len(beta_epoch_palogs) == cfg.pa_module.trainer.ddp.max_epochs, "Some beta values are not being stored properly." + assert len(logPA_epoch_palogs) == cfg.pa_module.trainer.ddp.max_epochs, "Some logPA values are not being stored properly." + + """ + The results for PA and PA_logits should be the same also in DDP setting. + The comparison with CPU has already been done. + """ + # We use more epochs on DDP, and thus we compare the first cpu.max_epochs epochs. + assert torch.equal(beta_epoch_pa, beta_epoch_palogs), "The beta values are not equal between CPU and DDP implementations." + assert torch.equal(logPA_epoch_pa, logPA_epoch_palogs), "The logPA values are not equal between CPU and DDP implementations." + + # ______________________________________________________________________________________ + + """ + Finally we compare the results of the AFR and acc_pa metrics in the 4 settings. + """ + assert AFR_pred[0] == AFR_pred[1] == AFR_pred[2] == AFR_pred[3], "The AFR_pred values are not equal." + assert AFR_true[0] == AFR_true[1] == AFR_true[2] == AFR_true[3], "The AFR_true values are not equal." + assert acc_pa[0] == acc_pa[1] == acc_pa[2] == acc_pa[3], "The acc_pa values are not equal." \ No newline at end of file diff --git a/tests/test_pa/utils.py b/tests/test_pa/utils.py new file mode 100644 index 0000000..1edb2c3 --- /dev/null +++ b/tests/test_pa/utils.py @@ -0,0 +1,90 @@ + + +import numpy as np +import matplotlib.pyplot as plt +from typing import List +import os + +__all__ = ["plot_multienv", "plot_images_multienv"] + +def _adjust_image(im): + """ + Adjusts the image to be plotted. + """ + if im.dtype == np.float32 or im.dtype == np.float64: + # Normalize to [0, 1] range + im_min, im_max = im.min(), im.max() + if im_max > 1 or im_min < 0: + im = (im - im_min) / (im_max - im_min) + else: + # Clip and convert to integers if needed + im = np.clip(im, 0, 255).astype(np.uint8) + return im + +def plot_multienv(dataset, filename, random=True, extension = ".png"): + if type(random) == bool: + iterate = np.random.randint(0, len(dataset), 2) if random else range(2) + else: # random is an iterable + iterate = random + + for i in iterate: + multienv_dict = dataset.__getitem__(i) + + envs = list(multienv_dict.keys()) + fig, axs = plt.subplots(1, len(envs), figsize=(len(envs)*5, 5)) + for e in range(len(envs)): + env = envs[e] + + # The normalization is with respect all the data, and it's not necessary. + #im_tens = multienv_dict[env][0] + #rescaled_im = (im_tens - im_tens.min()) / (im_tens.max() - im_tens) + #im = np.transpose(rescaled_im.numpy(), (1, 2, 0)) + + im = np.transpose(multienv_dict[env][0].numpy(), (1, 2, 0)) + lab = multienv_dict[env][1] + axs[e].imshow(_adjust_image(im)) + axs[e].set_title(str(lab)) + plt.suptitle(f"Comparison: {envs}") + + savepath = filename + "_" + str(i) + extension + if not os.path.exists(os.path.dirname(savepath)): + os.makedirs(os.path.dirname(savepath)) + plt.savefig(savepath) + + +def plot_images_multienv(images: List, subtitles: List, filename, extension = ".png"): + assert len(images) == len(subtitles), "The number of images and labels is not the same." + + envs = len(images) + fig, axs = plt.subplots(1, envs, figsize=(envs*5, 5)) + for e in range(envs): + im = np.transpose(images[e].numpy(), (1, 2, 0)) + axs[e].imshow(_adjust_image(im)) + axs[e].set_title(str(subtitles[e])) + + savepath = filename + extension + if not os.path.exists(os.path.dirname(savepath)): + os.makedirs(os.path.dirname(savepath)) + plt.savefig(savepath) + +import pandas as pd +import torch +from pytorch_lightning import LightningModule + +def get_pa_metrics(module: LightningModule): + df = pd.read_csv(os.path.join(module.trainer.logger.log_dir, "metrics.csv")) + beta_epoch = df["val/beta"].dropna().values + logPA_epoch = df["val/logPA"].dropna().values + return torch.tensor(beta_epoch), torch.tensor(logPA_epoch) + + +def get_acc_metrics(module: LightningModule): + df = pd.read_csv(os.path.join(module.trainer.logger.log_dir, "metrics.csv")) + AFR_pred = df["val/AFR pred"].dropna().values + AFR_true = df["val/AFR true"].dropna().values + acc_pa = df["val/acc_pa"].dropna().values + + assert len(AFR_pred) == len(AFR_true) == len(acc_pa), "Some AFR or acc_pa values are not being stored properly." + assert len(AFR_pred) == 1, "There should be only one value for AFR and acc_pa." + + return AFR_pred[0], AFR_true[0], acc_pa[0] \ No newline at end of file