# Selective Mixup Fine-Tuning for Optimizing Non-Decomposable Objectives [ICLR 2024 Spotlight]

Authors: Shrinivas Ramasubramanian\*, Harsh Rangwani\*, Sho Takemori\*, Kunal Samanta, Yuhei Umeda, Venkatesh Babu Radhakrishnan

This repository contains the code for our ICLR'24 spotlight [paper] "Selective Mixup Fine-Tuning for Optimizing Non-Decomposable Objectives".

*Fig. 1: Effect of mixup variants on feature representations (a). With vanilla mixup (b), a class's feature representation receives equal contributions from the directions of all other classes. In SelMix (c), specific class mixups are selected at timestep t so that they optimize the desired metric. The figure also gives an overview of how the SelMix distribution is obtained at timestep t.*

```bibtex
@inproceedings{
  ramasubramanian2024selective,
  title={Selective Mixup Fine-Tuning for Optimizing Non-Decomposable Metrics},
  author={Shrinivas Ramasubramanian and Harsh Rangwani and Sho Takemori and Kunal Samanta and Yuhei Umeda and Venkatesh Babu Radhakrishnan},
  booktitle={The Twelfth International Conference on Learning Representations},
  year={2024},
}
```

## TL;DR

SelMix is a novel method for optimizing non-decomposable objectives in long-tail semi-supervised and supervised learning tasks. It formulates the problem as a multi-armed bandit, where each arm represents a pair of classes to perform mixup on. SelMix introduces a selection policy that assigns probabilities to class pairs based on their estimated gain in the objective, updated using validation feedback.
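As a rough illustration of the selection policy (a minimal sketch, not this repo's API: the gain matrix, `temperature`, and `selmix_distribution` below are all illustrative), the sampling distribution over class pairs can be obtained as a tempered softmax over the validation-estimated gains:

```python
import numpy as np

def selmix_distribution(gains: np.ndarray, temperature: float = 1.0) -> np.ndarray:
    """Turn a K x K matrix of estimated per-pair gains into a sampling
    distribution over (i, j) mixup pairs via a tempered softmax.
    `gains[i, j]` is the estimated improvement in the objective from
    mixing class i with class j, estimated on the validation set."""
    flat = gains.flatten() / temperature
    flat = flat - flat.max()  # subtract max for numerical stability
    probs = np.exp(flat) / np.exp(flat).sum()
    return probs.reshape(gains.shape)

# Example: with 3 classes, the pair (0, 2) has the largest estimated gain,
# so it receives the largest share of the mixup sampling probability.
gains = np.array([[0.0, 0.1, 0.5],
                  [0.1, 0.0, 0.2],
                  [0.3, 0.2, 0.0]])
P = selmix_distribution(gains, temperature=0.1)
```

Lower temperatures concentrate sampling on the highest-gain pairs, while higher temperatures keep exploring all pairs; the exact policy used by SelMix is given in the paper.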

## Installation

To install the latest version of PyTorch, please refer to the official docs. After installing the PyTorch build matching your GPU, create and activate your conda environment:

```bash
conda create --name your_env_name --file requirements.txt
conda activate your_env_name
```

Install the necessary libraries using

```bash
pip install -r requirements.txt
```

### Weights & Biases (wandb)

Since we extensively log various detailed performance measures for the model, we strongly recommend installing wandb before proceeding further. You can find the instructions here.
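For reference, a typical wandb setup looks like the following; the project, run, and metric names are placeholders, not values this repo expects:

```python
import wandb

# Placeholder project/run names; choose whatever suits your setup.
wandb.init(project="selmix", name="cifar10-minrecall")

# Metrics are logged as plain dictionaries; keys and values here are examples.
wandb.log({"val/min_recall": 0.72, "epoch": 5})
```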

## Usage

Usage guidelines for custom models are provided in the Model Readme.

### Training loop

Overview of the training pipeline:

```python
import torch.nn.functional as F

# Define your datasets for mixup; for the supervised case,
# they are assumed to be the same dataset.
dataset1 = None
dataset2 = None

# Placeholder for the optimizer
optimizer = None

# Placeholder for the Lagrange multipliers
lagrange_multipliers = None

# Loop over epochs
for epoch in range(num_epochs):
    # Run validation to obtain the confusion matrix and class prototypes
    confusion_matrix, prototypes = validation(valset, model)

    # Compute the MinRecall objective and update the Lagrange multipliers
    objective = MinRecall(confusion_matrix, prototypes, lagrange_multipliers)
    lagrange_multipliers = objective.lambdas

    # Obtain the SelMix distribution P and build a FastJointSampler from it
    P_selmix = objective.P
    SelMix_dataloader = FastJointSampler(dataset1, dataset2, model, P_selmix)

    # Loop over steps within the epoch
    for step in range(num_steps_per_epoch):
        # Sample a pair of batches from the SelMix dataloader
        (x1, y1), (x2, y2) = SelMix_dataloader.get_batch()

        # Forward pass through the model with the mixed inputs
        logits = model(x1, x2)

        # Cross-entropy loss against the labels of the first batch
        loss = F.cross_entropy(logits, y1)

        # Backward pass and optimization step
        loss.backward()
        optimizer.step()

        # Reset gradients for the next iteration
        optimizer.zero_grad()
```
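Note that `model(x1, x2)` above consumes two batches at once and mixes them internally; the repo's models handle this. Purely as an illustration of the idea, a wrapper of this shape could mix the feature representations of the two inputs before the classifier head (the fixed coefficient and feature-level mixing below are assumptions of this sketch, not the exact SelMix formulation):

```python
import torch
import torch.nn as nn

class MixupForward(nn.Module):
    """Illustrative two-input forward pass: mixes the features of two
    batches before classification. `lam` is a fixed coefficient here,
    chosen only to keep the sketch minimal."""

    def __init__(self, backbone: nn.Module, classifier: nn.Module, lam: float = 0.5):
        super().__init__()
        self.backbone = backbone      # maps images to feature vectors
        self.classifier = classifier  # maps features to class logits
        self.lam = lam

    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        # Convex combination of the two feature representations
        z = self.lam * self.backbone(x1) + (1.0 - self.lam) * self.backbone(x2)
        return self.classifier(z)
```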

## How to run

### Pre-training

We provide code for pre-training your model with FixMatch and FixMatch w/ LA on various dataset configurations. The complete set of config files lives in the `./configs` directory with the structure below. For each dataset, there is a pre-training config per data distribution and pre-training method, along with the corresponding fine-tuning config for each objective.

```
.
├── cifar10/
│   ├── DataDistribution/
│   │   ├── GeometricMean.yaml
│   │   ├── ...
│   │   ├── OtherObjectives.yaml
│   │   ├── ...
│   │   └── MinRecall.yaml
│   ├── ...
│   └── pretraining/
│       ├── fixmatchLA.yaml
│       └── fixmatchOriginal.yaml
├── cifar100/
├── imagenet1k/
└── stl10/
```
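The configs are plain YAML files, so they can be inspected programmatically; a minimal sketch using PyYAML (the path mirrors the commands below, and `yaml` is assumed to be installed; the repo's scripts may parse configs differently):

```python
import yaml

# Example path following the commands below; adjust to your local tree.
path = "configs/cifar10/$N_1$-1500_$M_1$-3000_IBRL-100_IBRU-100/pretraining/fixmatchOriginal.yaml"
with open(path) as f:
    cfg = yaml.safe_load(f)

print(sorted(cfg.keys()))  # inspect the available hyperparameters
```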

To start pre-training with vanilla FixMatch on CIFAR-10 with $\rho_l = 100$, $\rho_u = 100$, $N_1 = 1500$, and $M_1 = 3000$, run the following:

```bash
python pretrain.py --config_file configs/cifar10/$N_1$-1500_$M_1$-3000_IBRL-100_IBRU-100/pretraining/fixmatchOriginal.yaml
```

We also provide the config file for pre-training with FixMatch w/ LA on the same dataset:

```bash
python pretrain.py --config_file configs/cifar10/$N_1$-1500_$M_1$-3000_IBRL-100_IBRU-100/pretraining/fixmatchLA.yaml
```

### Pretraining Checkpoints

We provide the seed-0 pre-training checkpoints for FixMatch and FixMatch w/ LA loss.

| Dataset | CIFAR-10 | CIFAR-10 | CIFAR-10 |
| --- | --- | --- | --- |
| | $N_1 = 1500, M_1 = 3000$ | $N_1 = 1500, M_1 = 3000$ | $N_1 = 1500, M_1 = 30$ |
| | $\rho_l = 100, \rho_u = 100$ | $\rho_l = 100, \rho_u = 1$ | $\rho_l = 100, \rho_u = 0.01$ |
| FixMatch | Google Drive Link | Google Drive Link | Google Drive Link |
| w/ LA | Google Drive Link | Google Drive Link | Google Drive Link |

| Dataset | STL-10 | Imagenet-100 | CIFAR-100 |
| --- | --- | --- | --- |
| | $N_1 = 450, M_1 = \text{Unk}$ | $N_1 = 433, M_1 = 866$ | $N_1 = 150, M_1 = 300$ |
| | $\rho_l = 10, \rho_u = \text{Unk}$ | $\rho_l = 10, \rho_u = 10$ | $\rho_l = 10, \rho_u = 10$ |
| FixMatch | Google Drive Link | Google Drive Link | Google Drive Link |
| w/ LA | Google Drive Link | Google Drive Link | Google Drive Link |
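To warm-start from one of these checkpoints, the standard PyTorch loading pattern applies. This is a sketch under assumptions: the filename and architecture below are placeholders (the repo's checkpoints may use a different backbone and may bundle extra keys such as the optimizer state):

```python
import torch
import torchvision

# Placeholder architecture and filename; match them to the downloaded checkpoint.
model = torchvision.models.resnet18(num_classes=10)
ckpt = torch.load("fixmatch_cifar10_seed0.pth", map_location="cpu")

# Some checkpoints nest the weights under a key such as "model" or "state_dict".
state_dict = ckpt.get("model", ckpt) if isinstance(ckpt, dict) else ckpt
model.load_state_dict(state_dict, strict=False)
```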

### Fine-tuning

To start fine-tuning, set the hyperparameters for the specific objective and dataset in the corresponding config file and run the following:

```bash
python trainMetricOpt.py --config_file configs/cifar10/$N_1$-1500_$M_1$-3000_IBRL-100_IBRU-100/MinRecall.yaml
```

## Results

We obtain the following results when running the fine-tuning code. The metrics reported below are the mean recall and the min (or min head-tail) recall for the fine-tuned FixMatch w/ LA checkpoints.

### Mean Recall

| Dataset | CIFAR-10 | CIFAR-10 | CIFAR-10 | CIFAR-100 | STL-10 | Imagenet-100 |
| --- | --- | --- | --- | --- | --- | --- |
| | $N_1 = 1500, M_1 = 3000$ | $N_1 = 1500, M_1 = 3000$ | $N_1 = 1500, M_1 = 30$ | $N_1 = 150, M_1 = 300$ | $N_1 = 450, M_1 = \text{Unk}$ | $N_1 = 433, M_1 = 866$ |
| | $\rho_l = 100, \rho_u = 100$ | $\rho_l = 100, \rho_u = 1$ | $\rho_l = 100, \rho_u = 0.01$ | $\rho_l = 10, \rho_u = 10$ | $\rho_l = 10, \rho_u = \text{Unk}$ | $\rho_l = 10, \rho_u = 10$ |
| FixMatch w/ LA | 80.1 | 93.5 | 80.6 | 55.7 | | |
| w/ SelMix | 85.3 | 93.8 | 81.4 | 56.1 | | |

### Min/Min HT Recall

| Dataset | CIFAR-10 | CIFAR-10 | CIFAR-10 | CIFAR-100 (Min HT) | STL-10 | Imagenet-100 (Min HT) |
| --- | --- | --- | --- | --- | --- | --- |
| | $N_1 = 1500, M_1 = 3000$ | $N_1 = 1500, M_1 = 3000$ | $N_1 = 1500, M_1 = 30$ | $N_1 = 150, M_1 = 300$ | $N_1 = 450, M_1 = \text{Unk}$ | $N_1 = 433, M_1 = 866$ |
| | $\rho_l = 100, \rho_u = 100$ | $\rho_l = 100, \rho_u = 1$ | $\rho_l = 100, \rho_u = 0.01$ | $\rho_l = 10, \rho_u = 10$ | $\rho_l = 10, \rho_u = \text{Unk}$ | $\rho_l = 10, \rho_u = 10$ |
| FixMatch w/ LA | 69.3 | 83.3 | 63.1 | 32.2 | | |
| w/ SelMix | 79.7 | 88.4 | 72.7 | 56.4 | | |

## Acknowledgements

Our pre-training code is based on the FixMatch-pytorch implementation.