This repository contains the implementation of RankSim on the AgeDB-DIR dataset.
The imbalanced regression framework and LDS+FDS are based on the public repository of Yang et al., ICML 2021.
The blackbox combinatorial solver is based on the public repository of Vlastelica et al., ICLR 2020.
- Download the AgeDB dataset from here and extract the zip file to the folder ./data (you may need to contact the authors of the AgeDB dataset for the zip password).
- We use the standard train/val/test split file agedb.csv in the folder ./data, provided by Yang et al. (ICML 2021), which sets up a balanced val/test set. To reproduce the results in the paper, please use this file directly. You can also generate it by running
python data/create_agedb.py
python data/preprocess_agedb.py
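Either way, a quick sanity check of the meta file can catch path or split issues early. A minimal sketch, assuming agedb.csv has the age and split columns used in the Yang et al. (ICML 2021) setup (adjust the column names if they differ):

```python
# Sanity-check the AgeDB-DIR meta file. The column names ('split', 'age') follow
# the Yang et al. (ICML 2021) setup and are an assumption here.
import pandas as pd

df = pd.read_csv("./data/agedb.csv")
print(df["split"].value_counts())                               # rows per train/val/test split
print(df.groupby("split")["age"].agg(["min", "max", "count"]))  # label range per split
```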
- PyTorch (>= 1.2, tested on 1.6)
- tensorboard_logger
- numpy, pandas, scipy, tqdm, matplotlib, PIL, wget
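Apart from PyTorch, which is best installed following the official instructions for your CUDA setup, the remaining dependencies can usually be installed with pip. The PyPI package names below (e.g. Pillow for PIL) are our assumption:

pip install tensorboard_logger numpy pandas scipy tqdm matplotlib Pillow wget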
- train.py: main training and evaluation script
- create_agedb.py: create AgeDB raw meta data
- preprocess_agedb.py: create AgeDB-DIR meta file agedb.csv with balanced val/test set
- --data_dir: data directory to place data and meta file
- --reweight: cost-sensitive re-weighting scheme to use
- --loss: training loss type
- --regularization_weight: gamma, weight of the regularization term (default 100.0)
- --interpolation_lambda: lambda, interpolation strength parameter (default 2.0); see the sketch below for how gamma and lambda enter training
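The two RankSim-specific arguments work together as follows: gamma (--regularization_weight) scales a batchwise ranking-similarity penalty that is added to the task loss, and lambda (--interpolation_lambda) is the interpolation strength used to obtain gradients through the non-differentiable ranking operation, following the blackbox-solver scheme of Vlastelica et al. (ICLR 2020). Below is a minimal, illustrative sketch of this idea; the function and variable names are ours, it is not the repository's implementation, and the negative absolute label difference used as the label-space similarity is only a natural default for regression labels.

```python
# Illustrative sketch (not the repository's code) of where the RankSim
# hyperparameters enter training.
import torch
import torch.nn.functional as F

def hard_rank(seq):
    # non-differentiable rank of each entry along the last dimension
    return torch.argsort(torch.argsort(seq, dim=-1), dim=-1).float()

class BlackboxRank(torch.autograd.Function):
    """Ranking with interpolated gradients (Vlastelica et al., ICLR 2020 style)."""

    @staticmethod
    def forward(ctx, seq, lambda_val):
        r = hard_rank(seq)
        ctx.lambda_val = lambda_val
        ctx.save_for_backward(seq, r)
        return r

    @staticmethod
    def backward(ctx, grad_output):
        seq, r = ctx.saved_tensors
        # perturb the input along the incoming gradient, re-rank, and use the
        # finite difference of the two rankings as an informative gradient
        seq_prime = seq + ctx.lambda_val * grad_output
        r_prime = hard_rank(seq_prime)
        return -(r - r_prime) / (ctx.lambda_val + 1e-8), None

def ranksim_style_loss(features, labels, interpolation_lambda):
    # feature-space similarity: cosine; label-space similarity: negative absolute difference
    f = F.normalize(features.view(features.size(0), -1), dim=1)
    feat_sim = f @ f.t()
    label_sim = -torch.abs(labels.view(-1, 1).float() - labels.view(1, -1).float())
    feat_rank = BlackboxRank.apply(feat_sim, interpolation_lambda)
    label_rank = hard_rank(label_sim)  # target ranks, no gradient needed
    return F.mse_loss(feat_rank, label_rank)

# Total objective, with gamma = --regularization_weight:
#   loss = task_loss + gamma * ranksim_style_loss(features, labels, interpolation_lambda)
```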
To use Vanilla model
python train.py --batch_size 256 --lr 1e-3
To use square-root inverse frequency re-weighting (SQINV)
python train.py --batch_size 256 --lr 1e-3 --reweight sqrt_inv
To use LDS (Yang et al., ICML 2021) with originally reported hyperparameters
python train.py --batch_size 256 --lr 1e-3 --reweight sqrt_inv --lds --lds_kernel gaussian --lds_ks 5 --lds_sigma 2
To use FDS (Yang et al., ICML 2021) with originally reported hyperparameters
python train.py --batch_size 256 --lr 1e-3 --fds --fds_kernel gaussian --fds_ks 5 --fds_sigma 2
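For intuition on the kernel hyperparameters above: LDS smooths the empirical label histogram with the chosen kernel to estimate an effective label density, and the cost-sensitive re-weighting then uses the inverse of that smoothed density. A minimal, hedged sketch of this idea (not the repository's implementation), using scipy's Gaussian filter to mimic a Gaussian kernel of size 5 with sigma 2:

```python
# Sketch of LDS-style label-distribution smoothing and inverse re-weighting,
# assuming ages are binned per integer year. Mirrors the idea behind
# --lds_kernel gaussian --lds_ks 5 --lds_sigma 2, not the exact implementation.
import numpy as np
from scipy.ndimage import gaussian_filter1d

ages = np.random.randint(0, 102, size=1000)                # toy labels
hist = np.bincount(ages, minlength=102).astype(float)      # empirical counts per age bin
smoothed = gaussian_filter1d(hist, sigma=2, truncate=1.0)  # truncate=1.0 -> radius 2, i.e. window of 5
weights = 1.0 / np.clip(smoothed[ages], 1e-8, None)        # inverse of smoothed density per sample
weights = weights / weights.mean()                         # normalize to mean 1
```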
To use RankSim (gamma: 100.0, lambda: 2.0)
python train.py --batch_size 64 --lr 2.5e-4 --regularization_weight=100.0 --interpolation_lambda=2.0
To use RankSim (gamma: 100.0, lambda: 2.0) with square-root inverse frequency re-weighting (SQINV)
python train.py --batch_size 64 --lr 2.5e-4 --reweight sqrt_inv --regularization_weight=100.0 --interpolation_lambda=2.0
To use RankSim with Focal-R loss
python train.py --loss focal_l1 --batch_size 64 --lr 2.5e-4 --regularization_weight=100.0 --interpolation_lambda=2.0
To use RankSim (gamma: 100.0, lambda: 2.0) with LDS (Gaussian kernel, kernel size: 5, sigma: 2)
python train.py --batch_size 64 --lr 2.5e-4 --reweight sqrt_inv --lds --lds_kernel gaussian --lds_ks 5 --lds_sigma 2 --regularization_weight=100.0 --interpolation_lambda=2.0
To use RankSim (gamma: 100.0, lambda: 2.0) with FDS (Gaussian kernel, kernel size: 5, sigma: 2)
python train.py --batch_size 64 --lr 2.5e-4 --fds --fds_kernel gaussian --fds_ks 5 --fds_sigma 2 --regularization_weight=100.0 --interpolation_lambda=2.0
To use RankSim (gamma: 100.0, lambda: 2.0) with LDS (Gaussian kernel, kernel size: 5, sigma: 2) and FDS (Gaussian kernel, kernel size: 5, sigma: 2)
python train.py --batch_size 64 --lr 2.5e-4 --reweight sqrt_inv --lds --lds_kernel gaussian --lds_ks 5 --lds_sigma 2 --fds --fds_kernel gaussian --fds_ks 5 --fds_sigma 2 --regularization_weight=100.0 --interpolation_lambda=2.0
NOTE: We report results with a batch size of 64 and a learning rate of 2.5e-4. You can also try the batch size and learning rate reported by Yang et al. (ICML 2021) by changing the arguments, e.g. run SQINV + RankSim with batch size 256 and learning rate 1e-3
python train.py --batch_size 256 --lr 1e-3 --reweight sqrt_inv --regularization_weight=100.0 --interpolation_lambda=2.0
If you do not want to train the model yourself, you can evaluate it and reproduce our results directly using the pretrained weights from the anonymous links below.
python train.py --evaluate [...evaluation model arguments...] --resume <path_to_evaluation_ckpt>
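If you only want to inspect a downloaded checkpoint before wiring up the evaluation command, a minimal sketch (no assumption is made about the checkpoint's internal keys; it just prints whatever is stored):

```python
# Peek inside a downloaded checkpoint before passing it to --resume.
# "path/to/checkpoint.pth" is a placeholder for the downloaded file.
import torch

ckpt = torch.load("path/to/checkpoint.pth", map_location="cpu")
print(type(ckpt))
if isinstance(ckpt, dict):
    print(list(ckpt.keys()))
```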
- SQINV + RankSim, MAE All 6.91 (best MAE All-shot) (weights)
- SQINV + FDS + RankSim, MAE Few-shot 9.68 (best MAE Few-shot) (weights)
- Focal-R + LDS + FDS + RankSim, MAE Many-shot 6.17 (weights)
- Focal-R + FDS + RankSim, GM Med-shot 4.84 (weights)
- RRT + LDS + RankSim, MAE Med-shot 7.54 (weights)