This series of experiments grew out of my curiosity about the following questions:
- How do Deep Metric Learning (DML) and other distance-based methods perform on classification tasks?
- Does generalization to unseen classes hold true?
- Why is Deep Metric Learning, which performs exceedingly well for biometrics (face recognition), rarely used for other image recognition problems?
- What are the trade-offs, and are they worth choosing a distance-based approach over a classical classification approach?
It all started with a company research project, when my manager asked me to take up a classification problem but approach it in a class-agnostic way, i.e., using embedding / DML / distance-based methods. Developing a model this way has certain advantages: we can add more data and classes without retraining every time. The approach seemed promising, considering that established face recognition technology demonstrates its feasibility, yet its widespread adoption remains limited. With the advent of vector databases, I believe we will see more of a shift towards it.
I'm using the Weights & Biases MLOps tool to track and compare all experiments. The goal is to stick to predefined metrics and not change them over the course of experimentation, keeping the comparison apples to apples. A minimal logging sketch follows the link below.
Link to W&B experimentation page
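For reference, here is a minimal sketch of how each run might be tracked with wandb; the project name, config keys, and logged values are illustrative stand-ins, not the actual experiment code:

```python
import wandb

# Illustrative config; the real experiments define these per run.
config = {"epochs": 100, "batch_size": 128, "seed": 42,
          "embedding_dim": 128, "optimizer": "sgd", "distance": "cosine"}

# "dml-vs-classification" is a hypothetical project name.
run = wandb.init(project="dml-vs-classification", config=config)

for epoch in range(config["epochs"]):
    # Stand-in values; in the real loop these come from training and evaluation.
    train_loss, test_accuracy = 0.0, 0.0
    wandb.log({"train_loss": train_loss, "test_accuracy": test_accuracy}, step=epoch)

run.finish()
```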
A decent amount of time was spent on choosing the right dataset. I was looking at datasets typically used for fine-grained image classification, with reasonable intra- and inter-class variance. I also wanted a dataset with a large number of classes, since holding out an unseen-class validation split reduces the number of training classes, and it is known that training on many classes improves a model's generalization capabilities.
The CUB-200-2011 dataset ticked all the above requirements. It has 200 classes with approximately 60 images per class, and it is commonly used to benchmark DML methods. Other datasets considered were: Cars196, Stanford Online Products, In-shop Clothes Retrieval, and Hotels-50K.
Dataset Link
wget -O CUB_200_2011.tgz "https://data.caltech.edu/records/65de6-vp158/files/CUB_200_2011.tgz?download=1"
tar -xvzf CUB_200_2011.tgz
There are various metrics to track and compare experiments involving Deep Metric Learning:
- NMI (Normalized Mutual Information)
- Recall@k
- Mean Average Precision (mAP)
Since the goal of this project is to check whether DML can replace traditional classification methods, I've chosen precision@1, which translates directly to classification accuracy.
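Concretely, precision@1 checks whether each test sample's single nearest training sample shares its label. A minimal NumPy sketch, assuming cosine similarity (the distance used by most experiments below):

```python
import numpy as np

def precision_at_1(train_emb, train_labels, test_emb, test_labels):
    """Fraction of test samples whose nearest training embedding has the same label."""
    # L2-normalize so the inner product equals cosine similarity.
    train_emb = train_emb / np.linalg.norm(train_emb, axis=1, keepdims=True)
    test_emb = test_emb / np.linalg.norm(test_emb, axis=1, keepdims=True)
    sims = test_emb @ train_emb.T      # (n_test, n_train) cosine similarities
    nearest = sims.argmax(axis=1)      # index of the closest training sample
    return (train_labels[nearest] == test_labels).mean()
```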
All experiments followed these details, unless explicitly mentioned:
- Epochs: 100
- Batch size: 128
- Fixed seed: 42
- Image shape: 224x224x3
- Train/test split ratio: 0.2
- Unseen-class split: the last 20 classes were held out for unseen-class metrics
- Metric for test accuracy: precision@1
- Augmentation is consistent wherever mentioned (a torchvision sketch follows this list):
  - Train transform
    - Resize 224x224x3
    - Random Crop 224x224x3
    - Random Horizontal Flip
  - Test transform
    - Resize 224x224x3
    - Center Crop 224x224x3
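For reference, a torchvision sketch of the pipeline as listed; the ImageNet normalization constants are my assumption, since the post doesn't state them:

```python
from torchvision import transforms

IMAGENET_MEAN, IMAGENET_STD = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]

# Mirrors the train transform list above.
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# Mirrors the test transform list above.
test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])
```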
Experiment Name | Experiment Details | Deductions |
---|---|---|
EPSHN_euclidean_resnet18 | - used Euclidean as the distance metric - used EPSHN as the loss - ResNet-18 architecture with ImageNet-1k pretrained weights - used the Adam optimizer - embedding dimension of 128 | - test accuracy of 50% was achieved |
EPSHN_cosine_resnet18 | - same as EPSHN_euclidean_resnet18 - used cosine as the distance metric | - loss and accuracy curves followed EPSHN_euclidean_resnet18 - slight test accuracy gain of 2% w.r.t. EPSHN_euclidean_resnet18 |
EPSHN_cosine_resnet18_sgd | - same as EPSHN_cosine_resnet18 - used the SGD optimizer instead of Adam | - around 25% gain in test accuracy w.r.t. EPSHN_cosine_resnet18 - SGD generalizes really well, making the model more robust to unknown data distributions and classes |
EPSHN_cosine_resnet18_sgd_aug | - same as EPSHN_cosine_resnet18_sgd - added train and test augmentations for generalization | - around 3% test accuracy gain - this will be treated as the baseline - a drop of 20% accuracy was noted between the models trained with and without unseen classes |
EPSHN_cosine_frozenBN_resnet18_sgd_aug | - same as EPSHN_cosine_resnet18_sgd_aug - all batch normalization layers are frozen, retaining the pretrained ImageNet-1k parameters | - freezing BN is usually done to reduce overfitting and generalize better to unknown data and classes - this step led to a decrease in test accuracy of around 14% w.r.t. the baseline |
EPSHN_cosine_resnet18_scratch_sgd_aug | - same as EPSHN_cosine_resnet18_sgd_aug - the model was trained from scratch instead of using pretrained weights | - the model couldn't even reach a test accuracy of 20% - this tells us weight initialization plays a crucial role in the convergence of the model |
EPSHN_cosine_resnet18_classifier_sgd_aug | - same as EPSHN_cosine_resnet18_sgd_aug - instead of pretrained weights, the model was first trained as a classifier and later fine-tuned using metric learning | - the model oscillates around the point in the loss and accuracy curves that it inherits from the classifier's pretrained weights - test accuracy will depend on how well the classifier is trained and how well it generalizes |
EPSHN_cosine_resnet18x512_sgd_aug | - same as EPSHN_cosine_resnet18_sgd_aug - an embedding dimension of 512 was chosen instead of 128 | - increasing the embedding dimension has no effect on test accuracy - only a drop of 10% accuracy was noted between the models trained with and without unseen classes - performs better than the baseline for unseen classes |
EPSHN_cosine_resnet18+skipConnHead_sgd_aug | - same as EPSHN_cosine_resnet18_sgd_aug - added a skip-connection head on top of the model | - adding a skip-connection head has no effect on test accuracy |
EPSHN_cosine_resnet18+1dBN_sgd_aug | - same as EPSHN_cosine_resnet18_sgd_aug - added a 1D batch normalization layer on top of the model | - adding a 1D batch norm layer has no effect on test accuracy - only a drop of 10% accuracy was noted between the models trained with and without unseen classes - performs better than the baseline for unseen classes |
EPSHN_cosine_resnet50_sgd_aug | - same as EPSHN_cosine_resnet18_sgd_aug - ResNet-50 with pretrained ImageNet-1k weights was chosen instead of ResNet-18 | - adding more layers/parameters correlates directly with performance - an increase of 11% test accuracy was noted w.r.t. the baseline - only a drop of 10% accuracy was noted between the models trained with and without unseen classes |
ProxyNCA_cosine_resnet18_sgd_aug | - same as EPSHN_cosine_resnet18_sgd_aug - a proxy-based loss was used instead of the EPSHN triplet loss - the nice part of this method is that it doesn't require any mining - but you are also training additional parameters in the form of proxies | - a huge drop in test accuracy, around 16% w.r.t. the baseline |
ArcFace_cosine_resnet18_sgd_aug | - same as EPSHN_cosine_resnet18_sgd_aug - the ArcFace loss was used instead of the EPSHN triplet loss - the nice part of this method is that it doesn't require any mining - but you are also training additional parameters in the form of class weights | - an increase of 11% test accuracy was noted w.r.t. the baseline - ArcFace shows the highest drop (32%) in precision between the models trained with and without unseen classes - the model doesn't generalize well to unseen classes |
SubCenterArcFace_cosine_resnet18_sgd_aug | - same as ArcFace_cosine_resnet18_sgd_aug - a variation of the ArcFace loss used to help with datasets with high intra-class variance | - an increase of 1% test accuracy compared to the baseline - follows the same pattern as ArcFace_cosine_resnet18_sgd_aug for unseen-class generalization |
EPSHN_cosine_simclr_resnet18_sgd_aug | - same as EPSHN_cosine_resnet18_sgd_aug - instead of pretrained ImageNet-1k weights, the model used weights learned from a SimCLR model trained on the CUB-200-2011 dataset | - the model couldn't even reach a test accuracy of 20% - weight initialization plays a crucial role in the convergence of the model |
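To make the recurring EPSHN + cosine + SGD recipe concrete, here is a minimal sketch using the PyTorch Metric Learning library (listed in the references); the margin, learning rate, and miner settings are my assumptions, not the exact values used in these runs:

```python
import torch
import torchvision
from pytorch_metric_learning import distances, losses, miners

# ResNet-18 trunk with ImageNet-1k weights; the final fc layer becomes a 128-d embedding head.
model = torchvision.models.resnet18(weights="IMAGENET1K_V1")
model.fc = torch.nn.Linear(model.fc.in_features, 128)

# EPSHN-style mining: easiest positive and semi-hard negative per anchor.
miner = miners.BatchEasyHardMiner(
    pos_strategy=miners.BatchEasyHardMiner.EASY,
    neg_strategy=miners.BatchEasyHardMiner.SEMIHARD,
)
loss_fn = losses.TripletMarginLoss(margin=0.1, distance=distances.CosineSimilarity())
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

def training_step(images, labels):
    embeddings = model(images)
    triplets = miner(embeddings, labels)          # mined (anchor, positive, negative) indices
    loss = loss_fn(embeddings, labels, triplets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```

For the ArcFace variants, `loss_fn` would be swapped for something like `losses.ArcFaceLoss(num_classes=200, embedding_size=128)`, which needs no miner but carries its own learnable class-weight parameters (and hence needs its own optimizer).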
Each experiment was divided into two training phases: one with all classes included, the other excluding the unseen classes. The steps are as follows:
- Create the experiment config and initialize the wandb run
- Initialize the random seed to a fixed value common to all experiments
- Load the image dataset for all 200 classes and split it into train and test datasets with a 0.2 ratio, stratified at the label level
- Initialize the model, dataloaders, optimizer, distance metric, loss function, and validation function
- Train for 100 epochs while monitoring the train loss and test accuracy curves; save the weights at the end of training
- For the test dataset, get precision@1 scores for all classes: for each test sample, look for the closest sample in the training dataset; this yields the "all_classes_metrics" table
- Using the Faiss library, get the predicted class for each test sample by searching for the closest sample in the training dataset, and generate a classification report; this yields the "all_classes_classification_report" table (a minimal sketch of this evaluation follows the list)
- Train a new model, excluding the last 20 classes from the dataset
- All the steps (4, 5, 6, 7) remain the same as in the first training phase; even during evaluation, the entire dataset was loaded, including the last 20 classes
- Get the precision drop for the last 20 unseen classes for both models; call it the "comparison_unseen_classes_metrics" table, whose average gives the "precision_drop_unseen_classes" scalar
- Get the precision drop for all classes for both models; call it the "comparison_seen_classes_metrics" table, whose average gives the "precision_drop_all_classes" scalar
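A minimal sketch of this evaluation, assuming precomputed embedding arrays; the function and variable names are mine, not the project's:

```python
import faiss
import numpy as np
from sklearn.metrics import classification_report

def evaluate(train_emb, train_labels, test_emb, test_labels):
    # Normalize so that inner-product search is equivalent to cosine similarity.
    train_emb = np.ascontiguousarray(train_emb, dtype="float32")
    test_emb = np.ascontiguousarray(test_emb, dtype="float32")
    faiss.normalize_L2(train_emb)
    faiss.normalize_L2(test_emb)

    index = faiss.IndexFlatIP(train_emb.shape[1])  # exact inner-product index
    index.add(train_emb)

    _, nn_idx = index.search(test_emb, 1)          # closest training sample per test sample
    preds = train_labels[nn_idx[:, 0]]

    print("precision@1:", (preds == test_labels).mean())
    print(classification_report(test_labels, preds))
    return index                                   # reused in the majority-voting sketch below
```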
- To understand and see where the model is looking, I implemented a Grad-CAM equivalent for embedding networks from this paper. You can see the results at this repo. The model from the EPSHN_cosine_resnet50_sgd_aug experiment was used.
Grad-CAM | Vanilla BackProp | Guided BackProp | Guided Grad-CAM |
---|---|---|---|
- To visualize embeddings for all classes over the entire dataset, I used FiftyOne to plot a UMAP projection and inspect the corresponding images (a sketch follows below). It shows how visually similar classes cluster around one another. It also tells us about the spread (variance) of the test embeddings vs. the train embeddings.
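A sketch of that workflow, assuming the embedding array is precomputed and aligned with the dataset's sample order; the paths and brain key are illustrative:

```python
import numpy as np
import fiftyone as fo
import fiftyone.brain as fob

# Hypothetical path to precomputed embeddings for every image in the dataset.
embeddings = np.load("cub_embeddings.npy")

dataset = fo.Dataset.from_images_dir("CUB_200_2011/images")
results = fob.compute_visualization(
    dataset, embeddings=embeddings, method="umap", brain_key="dml_umap"
)

session = fo.launch_app(dataset)  # browse images alongside the plot
plot = results.visualize()
plot.show()
```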
- Used FAISS to find the best and closest matches for unseen classes. This offers the insight that the closest sample is not always the right one, so the predicted class should be determined by majority voting over the closest n samples, as in the example and sketch below.
true label: 005.Crested_Auklet
predicted label: 024.Red_faced_Cormorant
Classes of top 10 closest samples: 24, 24, 5, 24, 5, 5, 5, 5, 24, 5
True Class (005.Crested_Auklet) | Predicted Class (024.Red_faced_Cormorant) |
---|---|
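A sketch of that majority-voting step, reusing the cosine-normalized Faiss index built in the evaluation sketch above; the names are illustrative:

```python
from collections import Counter

import faiss
import numpy as np

def knn_majority_vote(index, train_labels, query_emb, k=10):
    """Predict a class by majority vote over the k closest training samples."""
    query = np.ascontiguousarray(query_emb[None, :], dtype="float32")
    faiss.normalize_L2(query)                 # match the cosine-normalized index
    _, nn_idx = index.search(query, k)
    neighbor_classes = train_labels[nn_idx[0]]
    # For the sample above: [24, 24, 5, 24, 5, 5, 5, 5, 24, 5] -> class 5 wins 6 to 4.
    return Counter(neighbor_classes.tolist()).most_common(1)[0][0]
```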
More comprehensive plots and numbers were created for each experiment to make comparisons easier. These plots, tables, and scalars are only present for experiments that showed potential (decent test accuracy):
- "train_loss" curve: tracks the loss curve for the 200-class training dataset
- "test_accuracy" curve: tracks the precision@1 curve on the test dataset for the 200 known classes
- "all_classes_metrics" table: rows of class name, number of test samples, number of training samples, and precision@1 for that class; determined by looping over all classes and, for each test sample, finding the class of the closest sample in the training dataset
- "all_classes_classification_report" table: classification report table for the test dataset
- "train_loss2" curve: tracks the loss curve for the 180-class training dataset
- "test_accuracy2" curve: tracks the precision@1 curve on the test dataset for the 180 known classes
- "limited_classes_metrics" table: same structure as "all_classes_metrics"; determined by looping over all 200 classes and, for each test sample, finding the class of the closest sample in the training dataset
- "limited_classes_classification_report" table: classification report table for the test dataset over all 200 classes
- "comparison_unseen_classes_metrics" table: rows of class name, precision@1 for the 200-class model, and precision@1 for the 180-class model
- "precision_drop_unseen_classes" scalar: average over the 20 unseen classes of (precision@1 for the 200-class model - precision@1 for the 180-class model)
- "precision_drop_all_classes" scalar: average over all 200 classes of (precision@1 for the 200-class model - precision@1 for the 180-class model); a minimal sketch of both scalars follows
- ResNet-18 1D BatchNorm ArcFace architecture
- ArcFace paper
- EPSHN paper
- Evaluation measures in information retrieval
- PyTorch Metric Learning library
- What is the Mean Average Precision in information retrieval?
- Precision and recall in information retrieval
- Quaterion similarity learning tips & tricks
- ProxyNCA loss paper
- SimCLR PyTorch implementation by AI Summer
- A Metric Learning Reality Check paper
- Deep Metric Learning survey blog post