Invariant-feature Subspace Recovery (ISR)

Code for the paper Provable Domain Generalization via Invariant-Feature Subspace Recovery (ICML 2022) by Haoxiang Wang, Haozhe Si, Bo Li, and Han Zhao from UIUC.

If you find this repo useful for your research, please consider citing our paper:

@inproceedings{ISR,
  title = {Provable Domain Generalization via Invariant-Feature Subspace Recovery},
  author = {Wang, Haoxiang and Si, Haozhe and Li, Bo and Zhao, Han},
  booktitle = {International Conference on Machine Learning},
  pages = {23018--23033},
  year = {2022},
  publisher = {PMLR},
  url = {https://proceedings.mlr.press/v162/wang22x.html},
}

ISR-Mean illustration

Installation

This repo was tested with Ubuntu 20.04, Python 3.8/3.9 (Anaconda version), and PyTorch 1.9/1.10/1.11 with CUDA 11. The experiments on the real datasets were tested on a single GPU with 16GB of memory, but 11GB may also suffice.
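To compare a local setup against the tested configuration above, a small script along the following lines may help. It is only a sketch and not part of this repo: `describe_environment` is a hypothetical helper name, and it assumes that GPU index 0 is the device you intend to use.

```python
import platform
import sys


def describe_environment():
    """Collect version info to compare against the tested configuration."""
    info = {
        "os": platform.platform(),
        "python": "{}.{}".format(*sys.version_info[:2]),
        "torch": None,
        "cuda": None,
        "gpu_memory_gb": None,
    }
    try:
        import torch
    except ImportError:
        return info  # PyTorch is not installed; leave the torch fields as None

    info["torch"] = torch.__version__
    info["cuda"] = torch.version.cuda  # None for CPU-only builds
    if torch.cuda.is_available():
        # Assumption: the first visible GPU is the one used for experiments.
        props = torch.cuda.get_device_properties(0)
        info["gpu_memory_gb"] = round(props.total_memory / 1024**3, 1)
    return info


env = describe_environment()
for key, value in env.items():
    print(f"{key}: {value}")
```

A reported GPU memory below roughly 11GB suggests the real-dataset experiments may not fit without changes.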

Our code is divided into two parts, linear_unit_tests/ and real_datasets/.

Datasets

The synthetic datasets in linear_unit_tests/ are generated by the code. The three real datasets used in real_datasets/ (Waterbirds, CelebA and MultiNLI) should be downloaded in advance following this instruction.

Waterbirds and CelebA are image datasets, and MultiNLI is a text dataset.

Code

linear_unit_tests/

Run python launch_exp.py to reproduce the experiments in the paper, then use python plot_results.py to plot the results. The experiments run on CPU (in parallel across all CPU cores by default).

real_datasets/

To run experiments on the three real-world datasets, please see real_datasets/README.md for detailed instructions.

Notably, we implement our ISR algorithms in a sklearn-style classifier class, which can be used as follows:

from isr import ISRClassifier

classifier = ISRClassifier(
    version="mean",  # "mean": ISR-Mean. "cov": ISR-Cov.
    d_spu=1,         # the number of spurious features to remove
)
# xs: training samples
# ys: class labels
# es: environment labels
classifier.fit(
    xs, ys, es,
    chosen_class=0,  # need to condition on a class
)
predictions = classifier.predict(test_xs)  # test_xs: test samples
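For intuition about what version="mean", d_spu and chosen_class control, here is a minimal NumPy sketch of the idea behind ISR-Mean as we understand it from the paper. It computes the mean of one class in each environment, runs PCA on those means, and discards the d_spu directions along which the means vary most. This is not the code in isr.py: `isr_mean_subspace` is a hypothetical name, the toy data is made up for illustration, and the sketch assumes every environment contains samples of the chosen class.

```python
import numpy as np


def isr_mean_subspace(xs, ys, es, chosen_class=0, d_spu=1):
    """Return a (d, d - d_spu) orthonormal basis of the estimated invariant subspace."""
    xs = np.asarray(xs, dtype=float)
    ys, es = np.asarray(ys), np.asarray(es)
    # One mean of the chosen class per environment
    # (assumes each environment has at least one such sample).
    means = np.stack([
        xs[(es == e) & (ys == chosen_class)].mean(axis=0)
        for e in np.unique(es)
    ])
    # PCA on the environment means: directions along which the means move
    # across environments are attributed to spurious features.
    centered = means - means.mean(axis=0, keepdims=True)
    cov = centered.T @ centered / len(means)
    _, eigvecs = np.linalg.eigh(cov)  # eigenvalues in ascending order
    # Keep the d - d_spu directions with the least variation across environments.
    return eigvecs[:, : xs.shape[1] - d_spu]


# Toy data: feature 0 is invariant, feature 1 shifts with the environment.
rng = np.random.default_rng(0)
xs_parts, ys_parts, es_parts = [], [], []
for e in range(4):
    y = rng.integers(0, 2, size=500)
    sign = 2 * y - 1
    invariant = sign + 0.1 * rng.standard_normal(500)
    spurious = 2 * (e - 1.5) * sign + 0.1 * rng.standard_normal(500)
    xs_parts.append(np.column_stack([invariant, spurious]))
    ys_parts.append(y)
    es_parts.append(np.full(500, e))
xs, ys, es = np.concatenate(xs_parts), np.concatenate(ys_parts), np.concatenate(es_parts)

basis = isr_mean_subspace(xs, ys, es, chosen_class=0, d_spu=1)
projected = xs @ basis  # features with the estimated spurious direction removed
print(basis.shape, projected.shape)
```

A linear classifier would then be trained on `projected` rather than on `xs`. In this toy example the retained direction lines up with the invariant feature, up to sign.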

Acknowledgement

In this repo, we adopt some code from the following codebases, and we sincerely thank their authors: