Code for generating class representations from a pretrained ArcFace model, for explainability
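- Pull the latest PyTorch Docker image from Docker Hub: `docker pull pytorch/pytorch:latest`
- Start a container with GPU access, mounting the repository with an absolute host path: `docker run -p 8080:8080 --rm -v "$(pwd)/arcface_inversion:/arcface_inversion" -it --gpus=all pytorch/pytorch:latest bash`
- Inside the container, install the Python dependencies: `cd /arcface_inversion && pip3 install -r requirements.txt`
- Install the required system libraries: `apt-get update && apt-get install ffmpeg libsm6 libxext6 -y`
- Start Jupyter on the forwarded port: `jupyter notebook --port=8080 --ip=0.0.0.0 --allow-root --no-browser`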
- Train an embedding model with ArcFace loss
- I'm using the high-level API of the pytorch-metric-learning library (https://kevinmusgrave.github.io/pytorch-metric-learning/)
- I'm using a ResNet-18 architecture with SE blocks (https://github.com/ronghuaiyang/arcface-pytorch/blob/master/models/resnet.py)
- The deep metric learning task can be trained on any dataset; this example uses MNIST
- Save the embedding model weights along with the ArcFace loss weights (the learned class centers)
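The ArcFace loss used for training can be sketched in plain Python to show what the saved loss weights are. This is a minimal illustration of the additive angular margin formula from the ArcFace paper, not the pytorch-metric-learning implementation; the function names and the toy vectors are hypothetical, and `s` (scale) and `m` (margin in radians) use the paper's values 64 and 0.5 as defaults. Each row of `class_weights` is a learned class center, so the loss cannot be evaluated (or reused for inversion) without those weights. As far as I know, PML's `ArcFaceLoss` keeps this matrix as a learnable parameter inside the loss object and takes its margin in degrees rather than radians, so check its documentation before porting values.

```python
import math


def _normalize(v):
    """Scale a vector to unit L2 norm."""
    norm = math.sqrt(sum(a * a for a in v))
    return [a / norm for a in v]


def arcface_logits(embedding, class_weights, target, s=64.0, m=0.5):
    """ArcFace logits for one sample (illustrative sketch).

    embedding:     the sample's embedding (list of floats)
    class_weights: one weight vector (class center) per class
    target:        index of the ground-truth class
    s:             scale applied to every cosine
    m:             additive angular margin in radians, applied to the target class only
    """
    x = _normalize(embedding)
    logits = []
    for j, w in enumerate(class_weights):
        # cos(theta_j) between the unit embedding and the unit class center
        cos_t = sum(a * b for a, b in zip(x, _normalize(w)))
        cos_t = max(-1.0, min(1.0, cos_t))  # keep acos in its domain despite rounding
        if j == target:
            # add the margin to the angle of the ground-truth class
            cos_t = math.cos(math.acos(cos_t) + m)
        logits.append(s * cos_t)
    return logits


def cross_entropy(logits, target):
    """Softmax cross-entropy computed with the log-sum-exp trick."""
    top = max(logits)
    log_z = top + math.log(sum(math.exp(z - top) for z in logits))
    return log_z - logits[target]


W = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]  # hypothetical class centers for 3 classes
x = [0.9, 0.1]                             # hypothetical embedding closest to class 0
print(cross_entropy(arcface_logits(x, W, target=0, s=4.0, m=0.0), 0))
print(cross_entropy(arcface_logits(x, W, target=0, s=4.0, m=0.5), 0))  # larger: the margin penalizes class 0
```

The margin makes the loss larger even for a correctly classified sample, which forces tighter, better separated classes. Full implementations also handle the case where the target angle plus `m` exceeds pi; this sketch leaves that out.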
- Extract the pretrained Batch Normalization priors (the running mean and variance of every BN layer) and store them as variables
- Initialize random, zero-centered Gaussian images once
- Run the optimization loop for 20k iterations, updating the input images (the pretrained model stays fixed) at every iteration with a combination of statistic loss, ArcFace loss, and regularization loss
- The statistic loss compares the per-channel mean and variance of the features entering each batch norm layer, computed on the current images, with that layer's running mean and variance
- You can tune weight decay, alpha, learning rate, number of epochs, and batch size for different or better results
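The inversion steps can be illustrated with a toy example in plain Python. This is a sketch under simplifying assumptions, not the notebook's code: the "network" is a single identity layer followed by one BN layer with one channel, so the features are the pixels themselves; the statistic loss uses squared L2 distances; `alpha` is assumed to weight the statistic loss and `weight_decay` to be an L2 penalty on the pixels; the ArcFace term is omitted; gradients are derived by hand instead of by autograd; and the prior values 0.5 and 0.04 are made up. All function names are hypothetical.

```python
import random


def channel_stats(values):
    """Mean and biased (population) variance of one channel, as BN uses to normalize a batch."""
    n = len(values)
    mu = sum(values) / n
    var = sum((v - mu) ** 2 for v in values) / n
    return mu, var


def statistic_loss(layer_features, bn_priors):
    """Sum over BN layers and channels of squared distances between the batch
    statistics of the current images and the stored BN running statistics.

    layer_features: one entry per BN layer; each entry is a list of channels,
                    each channel a flat list of its N*H*W input activations.
    bn_priors:      one (running_mean, running_var) pair of per-channel lists per layer.
    """
    loss = 0.0
    for channels, (run_mean, run_var) in zip(layer_features, bn_priors):
        for values, mu_bn, var_bn in zip(channels, run_mean, run_var):
            mu, var = channel_stats(values)
            loss += (mu - mu_bn) ** 2 + (var - var_bn) ** 2
    return loss


def invert_toy(prior_mean=0.5, prior_var=0.04, n_pixels=100, alpha=1.0,
               weight_decay=1e-5, lr=10.0, iterations=2000, seed=0):
    """Optimize the pixels (no weights are touched) so their statistics match the BN prior.

    Objective: alpha * statistic_loss + weight_decay * sum(x_i ** 2).
    The statistic gradient carries a 1/n factor, hence the large learning rate.
    """
    rng = random.Random(seed)
    # Random zero-centered Gaussian "image", created once and then updated every iteration.
    x = [rng.gauss(0.0, 1.0) for _ in range(n_pixels)]
    n = n_pixels
    for _ in range(iterations):
        mu, var = channel_stats(x)
        # d/dx_i of (mu - prior_mean)^2 is 2 * (mu - prior_mean) / n
        g_mu = 2.0 * (mu - prior_mean) / n
        # d/dx_i of (var - prior_var)^2 is 2 * (var - prior_var) * 2 * (x_i - mu) / n
        g_var = 4.0 * (var - prior_var) / n
        x = [xi - lr * (alpha * (g_mu + g_var * (xi - mu)) + 2.0 * weight_decay * xi)
             for xi in x]
    return x


pixels = invert_toy()
print(channel_stats(pixels))  # ends near the prior (0.5, 0.04); weight decay pulls both slightly lower
print(statistic_loss([[pixels]], [([0.5], [0.04])]))
```

In the real setup the same idea would presumably run in PyTorch: forward hooks on the BN layers collect the per-channel batch statistics across all layers, the ArcFace loss for the target class is added, and autograd computes the gradient with respect to the image tensor, which is created with `requires_grad=True` and passed to an optimizer.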
- More complex datasets can produce better class representations
- The pretrained embedding model alone is not enough to reproduce the results; you also need the ArcFace class weights
- The paper uses ResNet-50, whereas this repo uses a ResNet-18 variant
- Image background matters (quoted from the ArcFace paper):
"... data-free method employing the BN priors to restore ImageNet images for distillation, quantization and pruning. Their model inversion results contain obvious artifact in the background due to the translation augmentation during training. By contrast, our ArcFace model is trained on normalized face crops without background, thus the restored faces exhibit less artifact."
- ArcFace PyTorch implementation (source of the ResNet-18 SE model): https://github.com/ronghuaiyang/arcface-pytorch
- MNIST training example using PML: https://github.com/KevinMusgrave/pytorch-metric-learning/tree/master/examples
- ArcFace paper: https://arxiv.org/abs/1801.07698