pranavjadhav001/DML-for-Classification

DML-for-Classification [🚧 WIP]

This series of experiments is meant to satisfy my curiosity about the following questions:

  1. How do Deep Metric Learning (DML) or distance-based methods perform on classification tasks?
  2. Does generalization to unseen classes hold true?
  3. Why is Deep Metric Learning, which performs exceedingly well for biometrics (face recognition), not used for other image recognition problems?
  4. What are the trade-offs, and are they worth choosing a distance-based approach over a classical classification approach?

Motivation

It all started with a company research project, when my manager asked me to take up a classification problem but approach it in a more class-agnostic way: in a general sense, try embedding / DML / distance-based methods. Developing a model this way has certain advantages; for example, we can add more data and classes without retraining every time. This approach seemed promising, considering that established face recognition technology demonstrates its feasibility, yet its widespread adoption remains limited. With the advent of vector databases, I believe we will see more of a shift towards this.

Experiment Comparison & Tracking

I'm using the Weights & Biases MLOps tool to track and compare all experiments. The goal is to stick to predefined metrics and not change them over the course of experimentation, so that the comparison stays apples to apples.
Link to W&B experimentation page

Dataset

A decent amount of time was spent on choosing the right dataset. I was looking at datasets used for fine-grained image classification problems, with decent intra-class and inter-class variance. I also wanted a dataset with a decent number of classes, since the unseen-class validation split reduces the number of training classes. Having a large number of classes during training is known to improve a model's generalization capabilities.
The CUB-200-2011 dataset ticked all the above requirements. It has 200 classes with approximately 60 images per class, and it is commonly used to benchmark DML tasks. Other datasets considered were: Cars196, Stanford Online Products, In-shop Clothes Retrieval, and Hotels-50K.
Dataset Link

Download Dataset using CLI

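# Quote the URL (it contains "?") and name the output file explicitly
wget -O CUB_200_2011.tgz "https://data.caltech.edu/records/65de6-vp158/files/CUB_200_2011.tgz?download=1"
# Extract the archive that was just downloaded
tar -xvzf CUB_200_2011.tgz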

How to Quantify

There are various metrics to track and compare experiments involving Deep Metric Learning:

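  • Normalized Mutual Information (NMI)
  • Recall @ k
  • Mean Average Precision (MAP)

Since the goal of this project is to check whether DML can replace traditional classification methods, I've chosen precision @ 1, which translates directly to classification accuracy: a test sample counts as correct when its single closest training sample has the same class.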
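As a rough illustration of how precision @ 1 reduces to nearest-neighbour classification accuracy, here is a minimal sketch in plain Python. The function names, the toy embeddings and the labels are all made up for illustration; the actual experiments use learned embeddings and a Faiss index rather than a brute-force loop.

```python
import math


def cosine_distance(a, b):
    """Return 1 - cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return 1.0 - dot / (norm_a * norm_b)


def precision_at_1(train_embs, train_labels, test_embs, test_labels):
    """Fraction of test samples whose closest training sample shares their label."""
    hits = 0
    for emb, label in zip(test_embs, test_labels):
        # Brute-force nearest neighbour search over the training embeddings.
        nearest = min(
            range(len(train_embs)),
            key=lambda i: cosine_distance(emb, train_embs[i]),
        )
        hits += train_labels[nearest] == label
    return hits / len(test_embs)


# Toy 2-D embeddings (hypothetical values, not from any experiment).
train_embs = [[1.0, 0.0], [0.0, 1.0], [0.9, 0.1]]
train_labels = ["gull", "tern", "gull"]
test_embs = [[0.8, 0.2], [0.1, 0.9], [0.6, 0.4]]
test_labels = ["gull", "tern", "tern"]

score = precision_at_1(train_embs, train_labels, test_embs, test_labels)
print(f"precision@1 = {score:.3f}")
```

In this toy setup the third test sample sits closer to a "gull" training sample than to the "tern" one, so two of the three predictions are correct.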

Common Experiment Details

All experiments followed these details, unless explicitly mentioned:

  • Epochs 100
  • Batch size 128
  • fixed seed 42
  • image shape 224x224x3
  • image-level train/test split ratio of 0.2
  • unseen-class split: the last 20 classes were kept aside for unseen-class metrics
  • metric for test accuracy: precision @ 1
  • Augmentation, wherever mentioned, is the same across experiments:
    • Train Transform
      • Resize 224x224x3
      • Random Crop 224x224x3
      • Random Horizontal Flip
    • Test Transform
      • Resize 224x224x3
      • Center Crop 224x224x3
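The two splits listed above can be sketched as follows. This is a minimal standard-library illustration, not the repository's code: the helper names are hypothetical, the toy label list assumes exactly 60 images per class, and 0.2 is assumed to be the test fraction.

```python
import random
from collections import defaultdict


def split_seen_unseen(class_ids, num_unseen=20):
    """Hold out the last `num_unseen` class ids as unseen classes."""
    ordered = sorted(class_ids)
    return ordered[:-num_unseen], ordered[-num_unseen:]


def stratified_split(labels, test_ratio=0.2, seed=42):
    """Split sample indices into train/test, keeping the ratio within each class."""
    rng = random.Random(seed)  # fixed seed so every experiment sees the same split
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    train_idx, test_idx = [], []
    for label in sorted(by_class):
        indices = by_class[label]
        rng.shuffle(indices)
        n_test = max(1, round(len(indices) * test_ratio))
        test_idx.extend(indices[:n_test])
        train_idx.extend(indices[n_test:])
    return sorted(train_idx), sorted(test_idx)


# Toy labels: 200 classes with 60 images each (the real dataset is only roughly this).
labels = [class_id for class_id in range(1, 201) for _ in range(60)]

seen_classes, unseen_classes = split_seen_unseen(set(labels))
train_idx, test_idx = stratified_split(labels)
print(len(seen_classes), len(unseen_classes), len(train_idx), len(test_idx))
```

Stratifying per class keeps every class represented in the test set, which matters here because precision @ 1 is later reported per class.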

Experiment Table

Experiment Name Experiment Details Deductions
EPSHN_euclidean_resnet18 - used Euclidean Distance metric
- used EPSHN as loss
- Architecture chosen as Resnet 18 with Imagenet1k Pretrained weights
- Used Adam optimizer
- Embedding Dimension of 128

- Test Accuracy of 50 % was achieved
EPSHN_cosine_resnet18 - same as EPSHN_euclidean_resnet18
- used cosine as distance metric
- Loss and accuracy curves followed those of EPSHN_euclidean_resnet18
- slight test accuracy gain of 2% w.r.t. EPSHN_euclidean_resnet18
EPSHN_cosine_resnet18_sgd - same as EPSHN_cosine_resnet18
- used SGD optimizer instead of adam

- around 25% gain in test accuracy w.r.t. EPSHN_cosine_resnet18
- in this setup SGD generalized much better than Adam, making the model more robust to unknown data distributions and classes
EPSHN_cosine_resnet18_sgd_aug - same as EPSHN_cosine_resnet18_sgd
- added train and test augmentations for generalization
- around 3% test accuracy gains
- This will be treated as Baseline
- a drop of 20% accuracy was noted b/w
EPSHN_cosine_frozenBN_resnet18_sgd_aug - same as EPSHN_cosine_resnet18_sgd_aug
- all batch normalization layers are frozen, and the pretrained imagenet1k parameters are retained
- BN layers are usually frozen to reduce overfitting and generalize better to unknown data and classes
- this step led to a decrease in test accuracy of around 14% w.r.t. the baseline
EPSHN_cosine_resnet18_scratch_sgd_aug - same as EPSHN_cosine_resnet18_sgd_aug
- instead of pretrained weights, model was trained from scratch
- model couldn’t even reach a test accuracy of 20%
- This suggests that weight initialization plays a crucial role in the convergence of the model
EPSHN_cosine_resnet18_classifier_sgd_aug - same as EPSHN_cosine_resnet18_sgd_aug
- instead of using pretrained weights directly, the model was first trained as a classifier and later finetuned using metric learning
- the loss and accuracy curves oscillate around the point the model inherits from the classifier's weights
- test accuracy depends on how well the classifier is trained and how well it generalizes
EPSHN_cosine_resnet18x512_sgd_aug - same as EPSHN_cosine_resnet18_sgd_aug
- Embedding Dimension of 512 was chosen instead of 128
- Increasing the embedding dimension had no noticeable effect on test accuracy
- only a drop of 10% accuracy was noted b/w models trained with and without unseen classes
- performs better than the baseline on unseen classes
EPSHN_cosine_resnet18+skipConnHead_sgd_aug - same as EPSHN_cosine_resnet18_sgd_aug
- Added a skip connection head on top of the model
- Adding the skip connection head had no noticeable effect on test accuracy
EPSHN_cosine_resnet18+1dBN_sgd_aug - same as EPSHN_cosine_resnet18_sgd_aug
- Added a 1D Batch Normalization layer on top of the model
- Adding the 1D Batch Norm layer had no noticeable effect on test accuracy
- only a drop of 10% accuracy was noted b/w models trained with and without unseen classes
- performs better than the baseline on unseen classes
EPSHN_cosine_resnet50_sgd_aug - same as EPSHN_cosine_resnet18_sgd_aug
- instead of Resnet 18, Resnet 50 was chosen with pretrained imagenet1k weights
- the deeper model, with more layers / parameters, directly improved performance
- an increase of 11% in test accuracy was noted w.r.t. the baseline
- only a drop of 10% accuracy was noted b/w models trained with and without unseen classes
ProxyNCA_cosine_resnet18_sgd_aug - same as EPSHN_cosine_resnet18_sgd_aug
- instead of EPSHN triplet loss, a proxy-based loss was used
- the advantage of this method is that it doesn't require any mining
- but you are also training additional parameters in the form of proxies
- a large drop in test accuracy of around 16% w.r.t. the baseline
ArcFace_cosine_resnet18_sgd_aug - same as EPSHN_cosine_resnet18_sgd_aug
- instead of EPSHN triplet loss, ArcFace loss was used
- the advantage of this method is that it doesn't require any mining
- but you are also training additional parameters in the form of class weights
- an increase of 11% in test accuracy was noted w.r.t. the baseline
- ArcFace shows the highest drop in precision (32%) b/w models trained with and without unseen classes
- Model doesn’t generalize well to unseen classes
SubCenterArcFace_cosine_resnet18_sgd_aug - same as ArcFace_cosine_resnet18_sgd_aug
- a variation of ArcFace loss, designed to help with datasets that have high intra-class variance
- an increase of 1% in test accuracy w.r.t. the baseline
- this experiment takes after ArcFace_cosine_resnet18_sgd_aug in its generalization to unseen classes
EPSHN_cosine_simclr_resnet18_sgd_aug - same as EPSHN_cosine_resnet18_sgd_aug
- instead of pretrained imagenet1k weights, the model used weights learned by a SimCLR model trained on the CUB-200-2011 dataset
- model couldn’t even reach a test accuracy of 20%
- weight initialization plays a crucial role in the convergence of the model
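
Most experiments in the table rely on EPSHN (Easy Positive, Semi-Hard Negative) mining. The sketch below shows one common reading of that strategy: for each anchor, take the most similar same-class sample as the positive, then the most similar other-class sample that is still less similar than that positive. This is an assumption about the mining rule, written in plain Python with made-up embeddings; the experiments themselves may use a library implementation with different details.

```python
import math


def cosine_similarity(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def mine_epshn_triplets(embeddings, labels):
    """Return (anchor, positive, negative) index triplets for one batch."""
    triplets = []
    n = len(embeddings)
    for a in range(n):
        sims = [cosine_similarity(embeddings[a], embeddings[j]) for j in range(n)]

        # Easy positive: the most similar sample of the same class (not the anchor itself).
        positives = [j for j in range(n) if j != a and labels[j] == labels[a]]
        if not positives:
            continue
        pos = max(positives, key=lambda j: sims[j])

        # Semi-hard negative: the most similar other-class sample that is
        # still less similar to the anchor than the chosen positive.
        negatives = [j for j in range(n) if labels[j] != labels[a] and sims[j] < sims[pos]]
        if not negatives:
            continue
        neg = max(negatives, key=lambda j: sims[j])

        triplets.append((a, pos, neg))
    return triplets


# Toy batch: two classes with two samples each (hypothetical values).
embeddings = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]]
labels = [0, 0, 1, 1]

triplets = mine_epshn_triplets(embeddings, labels)
print(triplets)
```

Proxy-based and ArcFace-style losses skip this step entirely, which is the "doesn't require any mining" trade-off noted in the table.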

Plots & Curves

Test accuracy (only 10 runs can be shown simultaneously)

Test Accuracy Curves

Precision drop for models trained with & without unseen classes

Precision Drop Bar Plot

How an experiment has been conducted

Each experiment is divided into two training phases: one with all classes included, the other excluding the unseen classes. The steps are as follows:

  1. Create the experiment config and initialize the wandb run
  2. Initialize the random seed to a fixed value that is common to all experiments
  3. Load the image dataset for all 200 classes and split it into train and test datasets with a 0.2 ratio, stratified at label level
  4. Initialize the model, dataloaders, optimizer, distance metric, loss function and validation function
  5. Train for 100 epochs while monitoring the train loss and test accuracy curves, and save the weights at the end of training
  6. For the test dataset, get precision @ 1 scores for all classes: for each test sample, look for the closest sample in the training dataset; this gives the "all_classes_metrics" table
  7. Using the Faiss library, get the predicted class for each test sample by searching for the closest sample in the training dataset, and generate a classification report; this gives the "all_classes_classification_report" table
  8. Train a new model, excluding the last 20 classes from the dataset
  9. Steps 4, 5, 6 and 7 remain the same as in the first training phase; during evaluation the entire dataset is loaded, including the last 20 classes
  10. Get the precision drop between the two models for the last 20 unseen classes; call it the "comparison_unseen_classes_metrics" table; its average gives the "precision_drop_unseen_classes" scalar
  11. Get the precision drop between the two models for all classes; call it the "comparison_seen_classes_metrics" table; its average gives the "precision_drop_seen_classes" scalar
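
The precision-drop comparison in the last two steps can be sketched as below. The class names and precision values are hypothetical placeholders, and the helper name is made up for illustration; only the formula (precision @ 1 of the 200-class model minus precision @ 1 of the 180-class model, averaged over the chosen classes) comes from the description above.

```python
def precision_drop(prec_all_classes_model, prec_limited_model, class_names):
    """Per-class drop in precision@1 and its mean over `class_names`.

    Drop = precision@1 of the model trained on all classes
           minus precision@1 of the model trained without the unseen classes.
    """
    drops = {
        name: prec_all_classes_model[name] - prec_limited_model[name]
        for name in class_names
    }
    return drops, sum(drops.values()) / len(drops)


# Hypothetical per-class precision@1 values for the two training phases.
prec_200 = {"class_001": 0.70, "class_199": 0.80, "class_200": 0.60}
prec_180 = {"class_001": 0.70, "class_199": 0.50, "class_200": 0.40}

unseen = ["class_199", "class_200"]
per_class_drop, mean_drop_unseen = precision_drop(prec_200, prec_180, unseen)
_, mean_drop_all = precision_drop(prec_200, prec_180, list(prec_200))

print(f"unseen classes: {mean_drop_unseen:.2f}, all classes: {mean_drop_all:.2f}")
```

A small drop on the unseen classes means the embedding space transfers to classes the model never trained on, which is the property these experiments are trying to measure.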

Explainability

  • To understand where the model is looking, I implemented a Grad-CAM equivalent for embedding networks from this paper. You can see the results here at this repo. The model from the EPSHN_cosine_resnet50_sgd_aug experiment has been used.
[Figure panels: Grad-CAM, Vanilla Backpropagation, Guided Backpropagation, Guided Grad-CAM]
  • To visualize embeddings of all classes for the entire dataset, I've used FiftyOne to plot them with UMAP and inspect the associated images. It shows how visually similar classes cluster around one another. It also tells us about the spread (variance) of test embeddings vs train embeddings. Embedding Visualization, along with samples
  • Used FAISS to find the best and closest matches for unseen classes. This shows that the single closest sample is not always the best match, so the class is better determined by majority voting over the closest n samples.
true label: 005.Crested_Auklet
predicted label: 024.Red_faced_Cormorant
Classes of Top 10 closest samples : 24, 24, 5, 24, 5, 5, 5, 5, 24, 5
[Figure panels: closest sample from the true class (005.Crested_Auklet) and closest sample from the predicted class (024.Red_faced_Cormorant)]
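
The example above can be replayed with a few lines of Python. The neighbour list is the one reported above; the helper name and the choice of `collections.Counter` are illustrative assumptions, not the repository's implementation.

```python
from collections import Counter


def predict_by_majority(neighbor_classes, k=10):
    """Predict the class that occurs most often among the k closest samples."""
    votes = Counter(neighbor_classes[:k])
    return votes.most_common(1)[0][0]


# Classes of the 10 closest training samples, ordered from closest to farthest.
neighbor_classes = [24, 24, 5, 24, 5, 5, 5, 5, 24, 5]

nearest_only = predict_by_majority(neighbor_classes, k=1)
majority_of_10 = predict_by_majority(neighbor_classes, k=10)
print(nearest_only, majority_of_10)
```

Using only the closest sample predicts class 24 (Red_faced_Cormorant), while voting over all 10 neighbours recovers the true class 5 (Crested_Auklet), since it appears 6 times against 4.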

What to look for

More comprehensive plots and numbers were created for each experiment so that they can be compared against each other. These plots, tables and scalars are only present for experiments that showed potential (decent test accuracy).

  • "train_loss" curve : tracks the loss for the 200-class training dataset
  • "test_accuracy" curve : tracks precision @ 1 for the 200 known classes on the test dataset
  • "all_classes_metrics" table : contains rows of class name, test sample no., training sample no. and precision @ 1 for that class. Determined by looping over all classes and, for each test sample, finding the class of the closest sample in the training dataset.
  • "all_classes_classification_report" table : classification report for the test dataset
  • "train_loss2" curve : tracks the loss for the 180-class training dataset
  • "test_accuracy2" curve : tracks precision @ 1 for the 180 known classes on the test dataset
  • "limited_classes_metrics" table : contains rows of class name, test sample no., training sample no. and precision @ 1 for that class. Determined by looping over all 200 classes and, for each test sample, finding the class of the closest sample in the training dataset.
  • "limited_classes_classification_report" table : classification report for the test dataset over all 200 classes
  • "comparison_unseen_classes_metrics" table : table which contains rows of class name, precision @ 1 for 200 class model, precision @ 1 for 180 class model
  • "precision_drop_unseen_classes" scalar : average value of unseen 20 classes(precision @ 1 for 200 class model - precision @ 1 for 180 class model)
  • "precision_drop_all_classes" scalar : average value of 200 classes(precision @ 1 for 200 class model - precision @ 1 for 180 class model)

References
