If you found this code/work useful in your own research, please consider citing the following:
@inproceedings{wu23disc,
title={Discover and Cure: Concept-aware Mitigation of Spurious Correlation},
author={Shirley Wu and Mert Yuksekgonul and Linjun Zhang and James Zou},
booktitle={ICML},
year={2023},
}
DISC is an algorithm for image classification that adaptively discovers and removes spurious correlations during model training, using a concept bank generated by Stable Diffusion.
- 🔑 Effectively remove strong spurious correlations and help models generalize better! Go for the green decision boundary!
- 🔎 No more ambiguous interpretations! DISC tells you exactly which attributes contribute to the spurious correlation and how significant their contributions are.
- 🌱 Monitor how models learn spurious correlations!
In brief, DISC works as follows:
- Build a concept bank with multiple concept categories.
- In each iteration, discover spurious concepts by computing concept sensitivity.
- In each iteration, mix concept images into the training data, guided by the concept sensitivity, and update the model parameters on the resulting balanced dataset.
See our paper for details!
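To make the discover/cure loop more concrete, here is a toy NumPy sketch. It is not the DISC implementation: the input layout (a classes-by-concepts matrix of mean losses), the variance-across-classes score, the softmax sampling with a temperature, and the use of plain mixup are all simplifying assumptions chosen for illustration. The paper gives the actual definition of concept sensitivity.

```python
import numpy as np

# ASSUMPTION: illustrative stand-ins only, not the DISC codebase.

def concept_sensitivity(class_losses: np.ndarray) -> np.ndarray:
    """class_losses[k, c] = mean loss of class k on images mixed with concept c.

    Toy score: a concept whose effect on the loss differs a lot across
    classes (high variance over the class axis) is flagged as likely spurious.
    """
    return class_losses.var(axis=0)


def concept_sampling_probs(sensitivity: np.ndarray, temperature: float = 1.0) -> np.ndarray:
    """Softmax over sensitivities: more sensitive concepts are sampled more often."""
    z = sensitivity / temperature
    z = z - z.max()  # subtract the max for numerical stability
    w = np.exp(z)
    return w / w.sum()


def mix_up(images: np.ndarray, concept_images: np.ndarray, lam: float) -> np.ndarray:
    """Plain mixup: convex combination of training images and concept images."""
    return lam * images + (1.0 - lam) * concept_images


rng = np.random.default_rng(0)
class_losses = np.array([[0.2, 0.9, 0.5],
                         [0.2, 0.1, 0.5]])  # 2 classes x 3 concepts (made-up numbers)
sens = concept_sensitivity(class_losses)
probs = concept_sampling_probs(sens, temperature=0.1)
chosen = rng.choice(len(probs), size=4, p=probs)  # one concept index per training image
batch = rng.random((4, 3, 8, 8))     # 4 fake training images (N, C, H, W)
concepts = rng.random((3, 3, 8, 8))  # one fake image per concept
mixed = mix_up(batch, concepts[chosen], lam=0.7)
```

In this toy example only concept 1 affects the two classes differently, so it receives the highest sensitivity and is sampled most often for mixing.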
See requirements.txt, or install the environment via:
conda create -n disc python=3.9
conda activate disc
conda install pytorch==1.11.0 torchvision==0.12.0 torchaudio==0.11.0 cudatoolkit=11.3 -c pytorch
pip install scikit-learn transformers wilds umap-learn diffusers nltk
pip install gdown # Used for data download (tarfile and zipfile are part of the Python standard library)
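After installing, a quick way to confirm that the dependencies listed above are importable is a small check script like the one below. It is a minimal sketch, not part of the repo; the pip-name-to-module-name mapping is the only assumption (for example, scikit-learn is imported as sklearn and umap-learn as umap).

```python
import importlib.util

# pip package name -> top-level module name (these differ for some packages)
REQUIRED = {
    "torch": "torch",
    "torchvision": "torchvision",
    "torchaudio": "torchaudio",
    "scikit-learn": "sklearn",
    "transformers": "transformers",
    "wilds": "wilds",
    "umap-learn": "umap",
    "diffusers": "diffusers",
    "nltk": "nltk",
    "gdown": "gdown",
}


def missing_packages(required: dict[str, str] = REQUIRED) -> list[str]:
    """Return the pip names of packages whose module cannot be found."""
    return [pip_name for pip_name, module in required.items()
            if importlib.util.find_spec(module) is None]


if __name__ == "__main__":
    missing = missing_packages()
    if missing:
        print("Missing packages:", " ".join(missing))
    else:
        print("All dependencies found.")
```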
(Recommended) Download all the datasets via the commands below!
cd disc
python download_datasets.py
Manual download (if the automatic download fails):
- MetaShift: Download the dataset from here. Unzipping it should produce a folder metashifts, which should be moved to $ROOT/data/metashifts, where $ROOT is your code root.
- Waterbirds: Download the dataset from here. Unzipping it should produce a folder waterbird_complete95_forest2water2. Place this folder under $ROOT/data/cub/.
- FMoW: The dataset is downloaded automatically and can be found in $ROOT/data/fmow/fmow_v1.1. We recommend following the setup instructions provided by the official WILDS website.
- ISIC: Download the dataset from here. Unzipping it should produce a folder isic, which should be moved to $ROOT/data/isic, where $ROOT is your code root.
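To check that the datasets ended up where the training commands expect them, the sketch below tests for each folder named in the manual download steps. It is an illustrative helper, not part of the repo; reading the root from a ROOT environment variable is an assumption that mirrors the ROOT shell variable used in the training commands.

```python
import os
from pathlib import Path

# Expected dataset folders, relative to the code root, as listed in the manual download steps.
EXPECTED_DIRS = {
    "MetaShift": "data/metashifts",
    "Waterbirds": "data/cub/waterbird_complete95_forest2water2",
    "FMoW": "data/fmow/fmow_v1.1",
    "ISIC": "data/isic",
}


def check_data_layout(root) -> dict:
    """Map each dataset name to True if its expected folder exists under root."""
    root = Path(root)
    return {name: (root / rel).is_dir() for name, rel in EXPECTED_DIRS.items()}


if __name__ == "__main__":
    # ASSUMPTION: ROOT is exported in the shell; defaults to the current directory.
    root = os.environ.get("ROOT", ".")
    for name, ok in check_data_layout(root).items():
        print(f"{name:<10} {'ok' if ok else 'MISSING'}")
```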
(Recommended) Download the concept bank we have already generated via the commands below!
cd concept_bank
python download.py
Manual generation (this can also be used to customize your own concept bank):
- Define the concept bank in synthetic_concepts/metadata.json.
- Run the generation using Stable Diffusion v1-4:
cd concept_bank
python generate_concept_bank.py --n_samples 200
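For readers customizing the concept bank, the sketch below shows one way a multi-category bank could be planned and generated with the diffusers library. It is not generate_concept_bank.py: the {category: [concept, ...]} metadata schema, the "a photo of ..." prompt template, and the output layout are hypothetical. Only StableDiffusionPipeline.from_pretrained and the CompVis/stable-diffusion-v1-4 model id come from diffusers and Hugging Face.

```python
from dataclasses import dataclass
from pathlib import Path


@dataclass(frozen=True)
class GenerationJob:
    category: str
    concept: str
    prompt: str
    path: Path


def plan_jobs(metadata: dict, out_dir, n_samples: int) -> list:
    """Expand a HYPOTHETICAL {category: [concept, ...]} schema into one job per image."""
    out_dir = Path(out_dir)
    jobs = []
    for category, concepts in metadata.items():
        for concept in concepts:
            for i in range(n_samples):
                jobs.append(GenerationJob(
                    category=category,
                    concept=concept,
                    prompt=f"a photo of {concept}",  # hypothetical prompt template
                    path=out_dir / category / concept / f"{i}.png",
                ))
    return jobs


def run_jobs(jobs, model_id: str = "CompVis/stable-diffusion-v1-4") -> None:
    """Generate and save one image per job (needs diffusers; a GPU is strongly recommended)."""
    from diffusers import StableDiffusionPipeline  # lazy import: heavy dependency

    pipe = StableDiffusionPipeline.from_pretrained(model_id).to("cuda")
    for job in jobs:
        image = pipe(job.prompt).images[0]  # a PIL image
        job.path.parent.mkdir(parents=True, exist_ok=True)
        image.save(job.path)
```

Separating planning from generation makes it cheap to inspect how many images a given category list will produce (categories × concepts × n_samples) before committing GPU time.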
We provide commands in the scripts folder. For example, to train an ERM model on MetaShift:
SEED=0
ROOT=./DISC # Set your code root here
python run_expt.py \
-s confounder -d MetaDatasetCatDog -t cat -c background --lr 0.001 --batch_size 16 \
--weight_decay 0.0001 --model resnet50 --n_epochs 100 --log_dir $ROOT/output/ \
--root_dir $ROOT/data/metashifts/MetaDatasetCatDog --save_best --save_last --seed $SEED
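Results are often reported over several seeds. The sketch below builds the same ERM command shown above for each seed and, when dry_run is False, runs it with subprocess. It only reuses the flags from that command; the seed list and the dry-run switch are illustrative assumptions, not part of the repo.

```python
import shlex
import subprocess


def erm_command(root: str, seed: int) -> list[str]:
    """Build the MetaShift ERM command from the README for a given seed."""
    return [
        "python", "run_expt.py",
        "-s", "confounder", "-d", "MetaDatasetCatDog", "-t", "cat", "-c", "background",
        "--lr", "0.001", "--batch_size", "16", "--weight_decay", "0.0001",
        "--model", "resnet50", "--n_epochs", "100",
        "--log_dir", f"{root}/output/",
        "--root_dir", f"{root}/data/metashifts/MetaDatasetCatDog",
        "--save_best", "--save_last", "--seed", str(seed),
    ]


def run_seeds(root: str, seeds, dry_run: bool = True) -> None:
    """Print (and optionally execute) one training run per seed."""
    for seed in seeds:
        cmd = erm_command(root, seed)
        print(shlex.join(cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)  # stop on the first failing run


if __name__ == "__main__":
    run_seeds("./DISC", seeds=[0, 1, 2])  # dry run: only prints the commands
```

Passing the arguments as a list (rather than a shell string) avoids quoting problems when the root path contains spaces.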
The scripts folder also provides commands for DISC. For example, given an ERM model trained on MetaShift, you can train the DISC model via:
SEED=0
N_CLUSTERS=2
ROOT=./DISC # Set your code root here
python run_expt.py \
-s confounder -d MetaDatasetCatDog -t cat -c background --lr 0.0005 --batch_size 16 \
--weight_decay 0.0001 --model resnet50 --n_epochs 100 --log_dir $ROOT/output/ \
--root_dir $ROOT/data/metashifts/MetaDatasetCatDog \
--erm_path <the erm model path ends with .pth> \
--concept_img_folder $ROOT/synthetic_concepts --concept_categories everything \
--n_clusters $N_CLUSTERS --augment_data --save_last --save_best --seed $SEED --disc
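The --erm_path argument expects the ERM checkpoint file ending in .pth. The README does not specify the checkpoint file names written by --save_best and --save_last, so the helper below simply returns the most recently modified .pth file under a log directory. It is a hypothetical convenience, not part of the repo. Note that when both a best and a last checkpoint exist, the newest file may be the last one, so check which checkpoint you actually want to pass.

```python
from pathlib import Path
from typing import Optional


def latest_checkpoint(log_dir, pattern: str = "*.pth") -> Optional[Path]:
    """Return the most recently modified file matching pattern under log_dir, or None."""
    root = Path(log_dir)
    if not root.is_dir():
        return None
    candidates = [p for p in root.rglob(pattern) if p.is_file()]
    if not candidates:
        return None
    return max(candidates, key=lambda p: p.stat().st_mtime)


if __name__ == "__main__":
    print(latest_checkpoint("./DISC/output"))  # ASSUMPTION: the --log_dir used for ERM
```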
Feel free to create an issue in this repo or contact [email protected] if you have any questions!