This repository implements a channel pruning method based on the discrete cosine transform (DCT). Our paper, Discrete Cosine Transform for Filter Pruning, has been accepted by Applied Intelligence. Part of the code is adapted from HRank: Filter Pruning using High-Rank Feature Map; we thank the authors for their great work.
The following is the framework of this method:
The first step is to obtain the output feature maps of each filter of the model and use the discrete cosine transform to convert them from the spatial domain to the frequency domain. The second step is to compute an importance score for each channel from the DCT coefficients of its feature maps; this score represents the relative importance of that channel. The third step is to remove unimportant channels as needed, based on the importance scores. Finally, fine-tune the pruned model to recover the lost accuracy.
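As an illustration of the first two steps, below is a minimal sketch of how a DCT-based channel importance score might be computed (the function and the exact scoring rule here are assumptions for illustration; the actual implementation is in importance_generation.py):

```python
# Minimal sketch, not the exact scoring used in this repository.
import numpy as np
from scipy.fft import dctn  # n-dimensional type-II DCT

def channel_importance(feature_maps: np.ndarray) -> np.ndarray:
    """feature_maps: (batch, channels, H, W) outputs of one layer."""
    num_channels = feature_maps.shape[1]
    scores = np.empty(num_channels)
    for i in range(num_channels):
        # Transform each channel's feature maps to the frequency domain.
        coeffs = dctn(feature_maps[:, i], axes=(-2, -1), norm='ortho')
        # Illustrative score: mean magnitude of the DCT coefficients;
        # channels with lower scores are treated as less important.
        scores[i] = np.abs(coeffs).mean()
    return scores
```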
The software versions used are as follows:
- Ubuntu 16.04
- Python 3.5
- PyTorch 1.4.0
- CUDA 10.0
The pre-trained models used in our paper have been uploaded here: https://pan.baidu.com/s/1a1vUhsmdK_5GvGdfo5M1QQ (code: mddg).
Dataset | Models |
---|---|
Cifar10 | VGG16 |
Cifar10 | ResNet56 |
Cifar10 | ResNet110 |
Cifar10 | DenseNet40 |
Cifar10 | GoogLeNet |
ImageNet | ResNet50 |
DUTS | U2Net |
The experiments use the Cifar10, ImageNet, and DUTS datasets. Please store them in the following directory structure:
DCT_Pruning/
└─data/
  ├─cifar-10-batches-py
  ├─ImageNet
  │ ├─ILSVRC2012_img_train
  │ └─val
  └─DUTS
    ├─DUTS-TE
    └─DUTS-TR
Using this code involves two steps: first, generate the model's importance scores and save them locally; then use the scores for pruning and fine-tuning.
python importance_generation.py \
--dataset [dataset name] \
--data_dir [dataset dir] \
--batch_size [batch size] \
--pretrain_dir [pretrain_model dir] \
--limit [batch numbers] \
--net [model name]
Here, '--limit' is the number of batches used to generate the importance scores; a slightly larger value makes the scores more accurate (we use 5 in our experiments).
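For example, to generate importance scores for VGG16 on Cifar10 (the batch size here is illustrative):

python importance_generation.py \
--dataset 'cifar10' \
--data_dir './data' \
--batch_size 128 \
--pretrain_dir './checkpoints/vgg_16_bn.pt' \
--limit 5 \
--net 'vgg_16_bn'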
According to the target dataset, select one of the following commands to prune the model. Note that the compression rate significantly affects the final results of pruning; if the accuracy after pruning does not meet your requirements, lower the compression rate appropriately.
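The '--compress_rate' argument is a Python-style list expression giving the pruning ratio of each layer (or layer stage, depending on the network). As a rough illustration of how such a string expands (the pruning scripts may parse it differently):

```python
# Illustrative only; the pruning scripts may parse this string differently.
rate_str = '[0.50]*7+[0.95]*5'
rates = eval(rate_str)  # -> [0.5, ..., 0.5, 0.95, ..., 0.95] (12 entries)
# i.e. seven layers pruned at 50%, followed by five layers pruned at 95%.
print(len(rates), rates)
```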
The following are the compression ratios used in our experiment and the corresponding final results:
Parameters,M(PR,%) | FLOPs,M(PR,%) | Top-1,% |
---|---|---|
0.87M(94.2%) | 49.86M(84.1%) | 92.81 |
python prune_cifar10.py \
--dataset 'cifar10' \
--data_dir './data' \
--job_dir './save_models' \
--batch_size 256 \
--epochs 150 \
--snapshot 20 \
--learning_rate 0.01 \
--lr_decay_step '50,100' \
--momentum 0.9 \
--weight_decay 0.005 \
--pretrain_dir './checkpoints/vgg_16_bn.pt' \
--imp_score './importance_score/vgg_16_bn_limit5' \
--compress_rate '[0.50]*7+[0.95]*5' \
--net 'vgg_16_bn'
Parameters,M(PR,%) | FLOPs,M(PR,%) | Top-1,% |
---|---|---|
0.66M(22.3%) | 90.35M(28.0%) | 93.96 |
python prune_cifar10.py \
--dataset 'cifar10' \
--data_dir './data' \
--job_dir './save_models' \
--batch_size 256 \
--epochs 300 \
--snapshot 20 \
--learning_rate 0.01 \
--lr_decay_step '150,225' \
--momentum 0.9 \
--weight_decay 0.005 \
--pretrain_dir './checkpoints/resnet_56.pt' \
--imp_score './importance_score/resnet_56_limit5' \
--compress_rate '[0.]+[0.18]*29' \
--net 'resnet_56'
Parameters,M(PR,%) | FLOPs,M(PR,%) | Top-1,% |
---|---|---|
1.00M(41.9%) | 135.88M(46.3%) | 94.26 |
python prune_cifar10.py \
--dataset 'cifar10' \
--data_dir './data' \
--job_dir './save_models' \
--batch_size 256 \
--epochs 300 \
--snapshot 20 \
--learning_rate 0.01 \
--lr_decay_step '150,225' \
--momentum 0.9 \
--weight_decay 0.005 \
--pretrain_dir './checkpoints/resnet_110.pt' \
--imp_score './importance_score/resnet_110_limit5' \
--compress_rate '[0.]+[0.2]*2+[0.3]*18+[0.40]*18+[0.39]*19' \
--net 'resnet_110'
Parameters,M(PR,%) | FLOPs,M(PR,%) | Top-1,% |
---|---|---|
0.62M(40.4%) | 173.39M(38.5%) | 94.32 |
python prune_cifar10.py \
--dataset 'cifar10' \
--data_dir './data' \
--job_dir './save_models' \
--batch_size 256 \
--epochs 300 \
--snapshot 20 \
--learning_rate 0.01 \
--lr_decay_step '150,225' \
--momentum 0.9 \
--weight_decay 0.002 \
--pretrain_dir './checkpoints/densenet_40.pt' \
--imp_score './importance_score/densenet_40_limit5' \
--compress_rate '[0.]+[0.2]*12+[0.]+[0.2]*12+[0.]+[0.2]*12' \
--net 'densenet_40'
Parameters,M(PR,%) | FLOPs,B(PR,%) | Top-1,% |
---|---|---|
2.10M(66.0%) | 0.40B(74.1%) | 94.67 |
python prune_cifar10.py \
--dataset 'cifar10' \
--data_dir './data' \
--job_dir './save_models' \
--batch_size 128 \
--epochs 300 \
--snapshot 20 \
--learning_rate 0.01 \
--lr_decay_step '150,225' \
--momentum 0.9 \
--weight_decay 0.005 \
--pretrain_dir './checkpoints/googlenet.pt' \
--imp_score './importance_score/googlenet_limit5' \
--compress_rate '[0.4]+[0.85]*2+[0.9]*5+[0.9]*2' \
--net 'googlenet'
Parameters,M(PR,%) | FLOPs,B(PR,%) | Top-1,% |
---|---|---|
7.45M(70.8%) | 1.06B(74.1%) | 72.32 |
python prune_imagenet.py \
--dataset 'imagenet' \
--data_dir './data/ImageNet' \
--job_dir './save_models' \
--batch_size 256 \
--epochs 180 \
--snapshot 20 \
--learning_rate 5e-06 \
--lr_type 'cos' \
--momentum 0.99 \
--weight_decay 0.0001 \
--label_smooth 0.1 \
--pretrain_dir './checkpoints/resnet_50.pth' \
--imp_score './importance_score/resnet_50_limit5' \
--compress_rate '[0.]+[0.1]*3+[0.4]*7+[0.4]*9' \
--net 'resnet_50'
python prune_u2netp.py \
--dataset 'DUTS' \
--data_dir './data/DUTS' \
--job_dir './save_models' \
--batch_size 12 \
--epochs 1000 \
--learning_rate 0.001 \
--eps 1e-08 \
--weight_decay 0 \
--pretrain_dir './checkpoints/u2netp.pt' \
--imp_score './importance_score/u2netp_limit5' \
--compress_rate '[0.4]+[0.85]*2+[0.9]*5+[0.9]*2' \
--net 'u2netp'
Use the following code to verify the accuracy after pruning:
python test.py \
--dataset 'cifar10' \
--data_dir './data' \
--batch_size [batch size] \
--test_model_dir [test model dir] \
--compress_rate [compression ratio of the test model] \
--net [model name]
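For example, to test a pruned VGG16 model on Cifar10 (the model path here is hypothetical):

python test.py \
--dataset 'cifar10' \
--data_dir './data' \
--batch_size 128 \
--test_model_dir './save_models/vgg_16_bn.pt' \
--compress_rate '[0.50]*7+[0.95]*5' \
--net 'vgg_16_bn'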
python test.py \
--dataset 'imagenet' \
--data_dir './data/ImageNet' \
--batch_size [batch size] \
--test_model_dir [test model dir] \
--compress_rate [compression ratio of the test model] \
--net 'resnet_50'
python test.py \
--dataset 'DUTS' \
--data_dir './data/DUTS' \
--batch_size [batch size] \
--test_model_dir [test model dir] \
--compress_rate [compression ratio of the test model] \
--net 'u2netp'
The test results will be saved in './data/DUTS/u2netp_DUTS-TE_results'. Then use the Binary-Segmentation-Evaluation-Tool to evaluate model accuracy:
python Binary-Segmentation-Evaluation-Tool/quan_eval_demo.py
To quickly verify the experimental results, the pruned models are provided here: https://pan.baidu.com/s/1Pr1xWPPBXS4cNJk5X1mn4g (code: icbf).
Architecture | Params | Flops | Compress_rate | Accuracy | Link |
---|---|---|---|---|---|
vgg16 | | | | | link |
resnet56 | | | | | link |
resnet110 | | | | | link |
resnet50 | | | | | link |
densenet40 | | | | | link |
googlenet | | | | | link |
u2net | | | | | link |
@article{DCTPruning,
  title={Discrete Cosine Transform for Filter Pruning},
  author={Chen, Yaosen and Zhou, Renshuang and Guo, Bing and Shen, Yan and Wang, Wei and Wen, Xuming and Suo, Xinhua},
  journal={Applied Intelligence},
  year={2022},
  publisher={Springer}
}