Continual Distillation Learning

PyTorch code for the paper:
Continual Distillation Learning: Knowledge Distillation in Prompt-based Continual Learning
Qifan Zhang, Yunhui Guo, Yu Xiang

arXiv, Project

Abstract

We introduce the problem of continual distillation learning (CDL) in order to use knowledge distillation (KD) to improve prompt-based continual learning (CL) models. The CDL problem is valuable to study since the use of a larger vision transformer (ViT) leads to better performance in prompt-based continual learning. The distillation of knowledge from a large ViT to a small ViT can improve the inference efficiency for prompt-based CL models. We empirically found that existing KD methods such as logit distillation and feature distillation cannot effectively improve the student model in the CDL setup. To this end, we introduce a novel method named Knowledge Distillation based on Prompts (KDP), in which globally accessible prompts specifically designed for knowledge distillation are inserted into the frozen ViT backbone of the student model. We demonstrate that our KDP method effectively enhances the distillation performance in comparison to existing KD methods in the CDL setup.
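
For intuition, here is a minimal, hypothetical sketch of the core idea (not the repository's implementation): a small set of learnable KD prompt tokens, shared globally rather than tied to a task, is prepended to the token sequence at selected layers of the frozen student backbone, so in this toy example only the prompts receive gradients during distillation. All class and parameter names are illustrative; the default sizes simply echo the --kd_prompt_param 12 6 setting shown in the training command below.

import torch
import torch.nn as nn

class KDPromptedStudent(nn.Module):
    # Illustrative stand-in for a frozen ViT student with KD prompts.
    def __init__(self, embed_dim=384, depth=12, kd_layers=12, kd_prompt_length=6):
        super().__init__()
        # Frozen transformer blocks play the role of the pretrained backbone.
        self.blocks = nn.ModuleList(
            nn.TransformerEncoderLayer(embed_dim, nhead=6, batch_first=True)
            for _ in range(depth)
        )
        for p in self.blocks.parameters():
            p.requires_grad = False
        # Learnable KD prompts: one set of tokens per prompted layer,
        # globally accessible (not conditioned on the task or the input).
        self.kd_prompts = nn.Parameter(
            0.02 * torch.randn(kd_layers, kd_prompt_length, embed_dim)
        )

    def forward(self, tokens):  # tokens: (batch, seq_len, embed_dim)
        batch = tokens.size(0)
        for i, block in enumerate(self.blocks):
            if i < self.kd_prompts.size(0):
                prompts = self.kd_prompts[i].expand(batch, -1, -1)
                # Prepend the prompts, run the block, then drop the prompt
                # positions so the sequence length stays fixed across layers
                # (a design choice for this sketch, not taken from the paper).
                tokens = block(torch.cat([prompts, tokens], dim=1))[:, prompts.size(1):]
            else:
                tokens = block(tokens)
        return tokens

student = KDPromptedStudent()
features = student(torch.randn(2, 197, 384))  # 197 = 1 CLS token + 14x14 patches
print(features.shape)  # torch.Size([2, 197, 384])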

Setup

  • Set up a conda environment with Python 3.8, e.g., conda create --name CDL python=3.8
  • Activate the environment: conda activate CDL
  • Install the dependencies: sh install_requirements.sh

Datasets

  ./data
  ├── cifar-100-python
  ├── imagenet-r
  │   ├── n01443537
  │   │   ├── art_0.jpg
  │   │   ├── cartoon_0.jpg
  │   │   ├── graffiti_0.jpg
  │   │   └── ...
  │   ├── n01833805
  │   │   ├── art_0.jpg
  │   │   ├── cartoon_0.jpg
  │   │   ├── graffiti_0.jpg
  │   │   └── ...
  │   └── ...
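
A quick, hypothetical sanity check (not part of the repository) for the layout above:

from pathlib import Path

data_root = Path("./data")
cifar_dir = data_root / "cifar-100-python"
imnet_r_dir = data_root / "imagenet-r"

print("CIFAR-100 present:", cifar_dir.is_dir())
if imnet_r_dir.is_dir():
    class_dirs = [d for d in imnet_r_dir.iterdir() if d.is_dir()]
    # ImageNet-R contains 200 classes, one folder per WordNet ID.
    print(f"ImageNet-R class folders: {len(class_dirs)} (expected 200)")
else:
    print("ImageNet-R not found at", imnet_r_dir)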

Training

The scripts are set up for 2 GPUs but can be modified for your hardware. You can run run.py directly to train and evaluate on the ImageNet-R dataset:

python -u run.py --config configs/imnet-r_prompt.yaml --gpuid 0 1 \
    --learner_type prompt --learner_name CODAPrompt \
    --prompt_param 100 8 0.0 \
    --log_dir ImageNet_R/coda-p \
    --t_model 'vit_base_patch16_224' \
    --s_model 'vit_small_patch16_224' \
    --KD_method 'KD_Token' \
    --kd_prompt_param 12 6
  • See experiments/imagenet-r.sh and experiments/cifar-100.sh for the full settings.
  • Change --learner_name to DualPrompt or L2P to use those learners instead of CODAPrompt.
  • Change --prompt_param to match the chosen learner (CODA-Prompt, DualPrompt, or L2P).
  • Set the teacher and student backbones with --t_model and --s_model.
  • Change --KD_method to select the knowledge distillation method: ['KD_Token', 'KD', 'DKD', 'FitNets', 'ReviewKD']. Use 'KD_Token' for our KDP model; a sketch of the plain logit-KD baseline is shown after this list.
  • Change --kd_prompt_param for our KDP model (number of KD-prompted layers, KD prompt length).
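
As a reference point for the 'KD' baseline listed above, a minimal logit-distillation loss (standard Hinton-style KD; the temperature value here is illustrative, not necessarily the repository's setting) looks like this:

import torch
import torch.nn.functional as F

def logit_kd_loss(student_logits, teacher_logits, temperature=2.0):
    # KL divergence between temperature-softened teacher and student
    # distributions, scaled by T^2 as in standard logit distillation.
    log_p_student = F.log_softmax(student_logits / temperature, dim=-1)
    p_teacher = F.softmax(teacher_logits / temperature, dim=-1)
    return F.kl_div(log_p_student, p_teacher, reduction="batchmean") * temperature ** 2

# Example with random logits for a 100-class problem.
loss = logit_kd_loss(torch.randn(8, 100), torch.randn(8, 100))
print(loss.item())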

Results

The results are saved in the --log_dir folder you specify, including the teacher and student model checkpoints and the final average accuracy for both.

Acknowledgments

This project is based on the following repositories:
