近年来,深度神经网络在计算机视觉、自然语言处理等领域被验证是一种极其有效的解决问题的方法。通过构建合适的神经网络,加以训练,最终网络模型的性能指标基本上都会超过传统算法。
在数据量足够大的情况下,通过合理构建网络模型的方式增加其参数量,可以显著改善模型性能,但是这又带来了模型复杂度急剧提升的问题。大模型在实际场景中使用的成本较高。
深度神经网络一般有较多的参数冗余,目前有几种主要的方法对模型进行压缩,减小其参数量。如裁剪、量化、知识蒸馏等,其中知识蒸馏是指使用教师模型(teacher model)去指导学生模型(student model)学习特定任务,保证小模型在参数量不变的情况下,得到比较大的性能提升,甚至获得与大模型相似的精度指标 [1]。
根据蒸馏方式的不同,可以将知识蒸馏方法分为3个不同的类别:Response based distillation、Feature based distillation、Relation based distillation。下面进行详细介绍。
最早的知识蒸馏算法 KD,由 Hinton 提出,训练的损失函数中除了 gt loss 之外,还引入了学生模型与教师模型输出的 KL 散度,最终精度超过单纯使用 gt loss 训练的精度。这里需要注意的是,在训练的时候,需要首先训练得到一个更大的教师模型,来指导学生模型的训练过程。
PaddleClas 中提出了一种简单使用的 SSLD 知识蒸馏算法 [6],在训练的时候去除了对 gt label 的依赖,结合大量无标注数据,最终蒸馏训练得到的预训练模型在 15 个模型上的精度提升平均高达 3%。
上述标准的蒸馏方法是通过一个大模型作为教师模型来指导学生模型提升效果,而后来又发展出 DML(Deep Mutual Learning)互学习蒸馏方法 [7],即通过两个结构相同的模型互相学习。具体的。相比于 KD 等依赖于大的教师模型的知识蒸馏算法,DML 脱离了对大的教师模型的依赖,蒸馏训练的流程更加简单,模型产出效率也要更高一些。
Heo 等人提出了 OverHaul [8], 计算学生模型与教师模型的 feature map distance,作为蒸馏的 loss,在这里使用了学生模型、教师模型的转移,来保证二者的 feature map 可以正常地进行 distance 的计算。
基于 feature map distance 的知识蒸馏方法也能够和 3.1 章节
中的基于 response 的知识蒸馏算法融合在一起,同时对学生模型的输出结果和中间层 feature map 进行监督。而对于 DML 方法来说,这种融合过程更为简单,因为不需要对学生和教师模型的 feature map 进行转换,便可以完成对齐(alignment)过程。PP-OCRv2 系统中便使用了这种方法,最终大幅提升了 OCR 文字识别模型的精度。
1.1.1 和 1.1.2 章节中的论文中主要是考虑到学生模型与教师模型的输出或者中间层 feature map,这些知识蒸馏算法只关注个体的输出结果,没有考虑到个体之间的输出关系。
Park 等人提出了 RKD [10],基于关系的知识蒸馏算法,RKD 中进一步考虑个体输出之间的关系,使用 2 种损失函数,二阶的距离损失(distance-wise)和三阶的角度损失(angle-wise)
本论文提出的算法关系知识蒸馏(RKD)迁移教师模型得到的输出结果间的结构化关系给学生模型,不同于之前的只关注个体输出结果,RKD 算法使用两种损失函数:二阶的距离损失(distance-wise)和三阶的角度损失(angle-wise)。在最终计算蒸馏损失函数的时候,同时考虑 KD loss 和 RKD loss。最终精度优于单独使用 KD loss 蒸馏得到的模型精度。
更多关于知识蒸馏的算法简介以及应用介绍,请参考:知识蒸馏算法简介。
论文信息:
Cheng Cui, Ruoyu Guo, Yuning Du, Dongliang He, Fu Li, Zewu Wu, Qiwen Liu, Shilei Wen, Jizhou Huang, Xiaoguang Hu, Dianhai Yu, Errui Ding, Yanjun Ma
arxiv, 2021
SSLD是百度于2021年提出的一种简单的半监督知识蒸馏方案,通过设计一种改进的JS散度作为损失函数,结合基于ImageNet22k数据集的数据挖掘策略,最终帮助15个骨干网络模型的精度平均提升超过3%。
更多关于SSLD的原理、模型库与使用介绍,请参考:SSLD知识蒸馏算法介绍。
SSLD配置如下所示。在模型构建Arch字段中,需要同时定义学生模型与教师模型,教师模型固定梯度,并且加载预训练参数。在损失函数Loss字段中,需要定义DistillationDMLLoss
,作为训练的损失函数。
# model architecture
Arch:
name: "DistillationModel" # 模型名称,这里使用的是蒸馏模型,
class_num: &class_num 1000 # 类别数量,对于ImageNet1k数据集来说,类别数为1000
pretrained_list: # 预训练模型列表,因为在下面的子网络中指定了预训练模型,这里无需指定
freeze_params_list: # 固定网络参数列表,为True时,表示固定该index对应的网络
- True
- False
infer_model_name: "Student" # 在模型导出的时候,会导出Student子网络
models: # 子网络列表
- Teacher: # 教师模型
name: ResNet50_vd # 模型名称
class_num: *class_num # 类别数
pretrained: True # 预训练模型路径,如果为True,则会从官网下载默认的预训练模型
use_ssld: True # 是否使用SSLD蒸馏得到的预训练模型,精度会更高一些
- Student: # 学生模型
name: PPLCNet_x2_5 # 模型名称
class_num: *class_num # 类别数
pretrained: False # 预训练模型路径,可以指定为bool值或者字符串,这里为False,表示学生模型默认不加载预训练模型
# loss function config for traing/eval process
Loss: # 定义损失函数
Train: # 定义训练的损失函数,为列表形式
- DistillationDMLLoss: # 蒸馏的DMLLoss,对DMLLoss进行封装,支持蒸馏结果(dict形式)的损失函数计算
weight: 1.0 # loss权重
model_name_pairs: # 用于计算的模型对,这里表示计算Student和Teacher输出的损失函数
- ["Student", "Teacher"]
Eval: # 定义评估时的损失函数
- CELoss:
weight: 1.0
论文信息:
Ying Zhang, Tao Xiang, Timothy M. Hospedales, Huchuan Lu
CVPR, 2018
DML论文中,在蒸馏的过程中,不依赖于教师模型,两个结构相同的模型互相学习,计算彼此输出(logits)的KL散度,最终完成训练过程。
在ImageNet1k公开数据集上,效果如下所示。
策略 | 骨干网络 | 配置文件 | Top-1 acc | 下载链接 |
---|---|---|---|---|
baseline | PPLCNet_x2_5 | PPLCNet_x2_5.yaml | 74.93% | - |
DML | PPLCNet_x2_5 | PPLCNet_x2_5_dml.yaml | 76.68%(+1.75%) | - |
- 注:完整的PPLCNet_x2_5模型训练了360epoch,这里为了方便对比,baseline和DML均训练了100epoch,因此指标比官网最终开源出来的模型精度(76.60%)低一些。
DML配置如下所示。在模型构建Arch字段中,需要同时定义学生模型与教师模型,教师模型与学生模型均保持梯度更新状态。在损失函数Loss字段中,需要定义DistillationDMLLoss
(学生与教师之间的JS-Div loss)以及DistillationGTCELoss
(学生与教师关于真值标签的CE loss),作为训练的损失函数。
Arch:
name: "DistillationModel"
class_num: &class_num 1000
pretrained_list:
freeze_params_list: # 两个模型互相学习,因此这里两个子网络的参数均不能固定
- False
- False
models:
- Teacher:
name: PPLCNet_x2_5 # 两个模型互学习,因此均没有加载预训练模型
class_num: *class_num
pretrained: False
- Student:
name: PPLCNet_x2_5
class_num: *class_num
pretrained: False
Loss:
Train:
- DistillationGTCELoss: # 因为2个子网络均没有加载预训练模型,这里需要同时计算不同子网络的输出与真值标签之间的CE loss
weight: 1.0
model_names: ["Student", "Teacher"]
- DistillationDMLLoss:
weight: 1.0
model_name_pairs:
- ["Student", "Teacher"]
Eval:
- CELoss:
weight: 1.0
论文信息:
UDML 是百度飞桨视觉团队提出的无需依赖教师模型的知识蒸馏算法,它基于DML进行改进,在蒸馏的过程中,除了考虑两个模型的输出信息,也考虑两个模型的中间层特征信息,从而进一步提升知识蒸馏的精度。更多关于UDML的说明与应用,请参考PP-ShiTu论文以及PP-OCRv3论文。
在ImageNet1k公开数据集上,效果如下所示。
策略 | 骨干网络 | 配置文件 | Top-1 acc | 下载链接 |
---|---|---|---|---|
baseline | PPLCNet_x2_5 | PPLCNet_x2_5.yaml | 74.93% | - |
UDML | PPLCNet_x2_5 | PPLCNet_x2_5_dml.yaml | 76.74%(+1.81%) | - |
Arch:
name: "DistillationModel"
class_num: &class_num 1000
# if not null, its lengths should be same as models
pretrained_list:
# if not null, its lengths should be same as models
freeze_params_list:
- False
- False
models:
- Teacher:
name: PPLCNet_x2_5
class_num: *class_num
pretrained: False
# return_patterns表示除了返回输出的logits,也会返回对应名称的中间层feature map
return_patterns: ["blocks3", "blocks4", "blocks5", "blocks6"]
- Student:
name: PPLCNet_x2_5
class_num: *class_num
pretrained: False
return_patterns: ["blocks3", "blocks4", "blocks5", "blocks6"]
# loss function config for traing/eval process
Loss:
Train:
- DistillationGTCELoss:
weight: 1.0
key: logits
model_names: ["Student", "Teacher"]
- DistillationDMLLoss:
weight: 1.0
key: logits
model_name_pairs:
- ["Student", "Teacher"]
- DistillationDistanceLoss: # 基于蒸馏结果的距离loss,这里默认使用l2 loss计算block5之间的损失函数
weight: 1.0
key: "blocks5"
model_name_pairs:
- ["Student", "Teacher"]
Eval:
- CELoss:
weight: 1.0
注意(: 上述在网络中指定return_patterns
,返回中间层特征的功能是基于TheseusLayer,更多关于TheseusLayer的使用说明,请参考:TheseusLayer 使用说明。
论文信息:
Show, attend and distill: Knowledge distillation via attention-based feature matching
Mingi Ji, Byeongho Heo, Sungrae Park
AAAI, 2018
AFD提出在蒸馏的过程中,利用基于注意力的元网络学习特征之间的相对相似性,并应用识别的相似关系来控制所有可能的特征图pair的蒸馏强度。
在ImageNet1k公开数据集上,效果如下所示。
策略 | 骨干网络 | 配置文件 | Top-1 acc | 下载链接 |
---|---|---|---|---|
baseline | ResNet18 | ResNet18.yaml | 70.8% | - |
AFD | ResNet18 | resnet34_distill_resnet18_afd.yaml | 71.68%(+0.88%) | - |
注意:这里为了与论文的训练配置保持对齐,设置训练的迭代轮数为100epoch,因此baseline精度低于PaddleClas中开源出的模型精度(71.0%)
AFD配置如下所示。在模型构建Arch字段中,需要同时定义学生模型与教师模型,固定教师模型的权重。这里需要对从教师模型获取的特征进行变换,进而与学生模型进行损失函数的计算。在损失函数Loss字段中,需要定义DistillationKLDivLoss
(学生与教师之间的KL-Div loss)、AFDLoss
(学生与教师之间的AFD loss)以及DistillationGTCELoss
(学生与教师关于真值标签的CE loss),作为训练的损失函数。
Arch:
name: "DistillationModel"
pretrained_list:
freeze_params_list:
models:
- Teacher:
name: AttentionModel # 包含若干个串行的网络,后面的网络会将前面的网络输出作为输入并进行处理
pretrained_list:
freeze_params_list:
- True
- False
models:
# AttentionModel 的基础网络
- ResNet34:
name: ResNet34
pretrained: True
# return_patterns表示除了返回输出的logits,也会返回对应名称的中间层feature map
return_patterns: &t_keys ["blocks[0]", "blocks[1]", "blocks[2]", "blocks[3]",
"blocks[4]", "blocks[5]", "blocks[6]", "blocks[7]",
"blocks[8]", "blocks[9]", "blocks[10]", "blocks[11]",
"blocks[12]", "blocks[13]", "blocks[14]", "blocks[15]"]
# AttentionModel的变换网络,会对基础子网络的特征进行变换
- LinearTransformTeacher:
name: LinearTransformTeacher
qk_dim: 128
keys: *t_keys
t_shapes: &t_shapes [[64, 56, 56], [64, 56, 56], [64, 56, 56], [128, 28, 28],
[128, 28, 28], [128, 28, 28], [128, 28, 28], [256, 14, 14],
[256, 14, 14], [256, 14, 14], [256, 14, 14], [256, 14, 14],
[256, 14, 14], [512, 7, 7], [512, 7, 7], [512, 7, 7]]
- Student:
name: AttentionModel
pretrained_list:
freeze_params_list:
- False
- False
models:
- ResNet18:
name: ResNet18
pretrained: False
return_patterns: &s_keys ["blocks[0]", "blocks[1]", "blocks[2]", "blocks[3]",
"blocks[4]", "blocks[5]", "blocks[6]", "blocks[7]"]
- LinearTransformStudent:
name: LinearTransformStudent
qk_dim: 128
keys: *s_keys
s_shapes: &s_shapes [[64, 56, 56], [64, 56, 56], [128, 28, 28], [128, 28, 28],
[256, 14, 14], [256, 14, 14], [512, 7, 7], [512, 7, 7]]
t_shapes: *t_shapes
infer_model_name: "Student"
# loss function config for traing/eval process
Loss:
Train:
- DistillationGTCELoss:
weight: 1.0
model_names: ["Student"]
key: logits
- DistillationKLDivLoss: # 蒸馏的KL-Div loss,会根据model_name_pairs中的模型名称去提取对应模型的输出特征,计算loss
weight: 0.9 # 该loss的权重
model_name_pairs: [["Student", "Teacher"]]
temperature: 4
key: logits
- AFDLoss: # AFD loss
weight: 50.0
model_name_pair: ["Student", "Teacher"]
student_keys: ["bilinear_key", "value"]
teacher_keys: ["query", "value"]
s_shapes: *s_shapes
t_shapes: *t_shapes
Eval:
- CELoss:
weight: 1.0
注意(: 上述在网络中指定return_patterns
,返回中间层特征的功能是基于TheseusLayer,更多关于TheseusLayer的使用说明,请参考:TheseusLayer 使用说明。
论文信息:
Decoupled Knowledge Distillation
Borui Zhao, Quan Cui, Renjie Song, Yiyu Qiu, Jiajun Liang
CVPR, 2022
DKD将蒸馏中常用的 KD Loss 进行了解耦成为Target Class Knowledge Distillation(TCKD,目标类知识蒸馏)以及Non-target Class Knowledge Distillation(NCKD,非目标类知识蒸馏)两个部分,对两个部分的作用分别研究,并使它们各自的权重可以独立调节,提升了蒸馏的精度和灵活性。
在ImageNet1k公开数据集上,效果如下所示。
策略 | 骨干网络 | 配置文件 | Top-1 acc | 下载链接 |
---|---|---|---|---|
baseline | ResNet18 | ResNet18.yaml | 70.8% | - |
DKD | ResNet18 | resnet34_distill_resnet18_dkd.yaml | 72.59%(+1.79%) | - |
DKD 配置如下所示。在模型构建Arch字段中,需要同时定义学生模型与教师模型,教师模型固定参数,且需要加载预训练模型。在损失函数Loss字段中,需要定义DistillationDKDLoss
(学生与教师之间的DKD loss)以及DistillationGTCELoss
(学生与教师关于真值标签的CE loss),作为训练的损失函数。
Arch:
name: "DistillationModel"
# if not null, its lengths should be same as models
pretrained_list:
# if not null, its lengths should be same as models
freeze_params_list:
- True
- False
models:
- Teacher:
name: ResNet34
pretrained: True
- Student:
name: ResNet18
pretrained: False
infer_model_name: "Student"
# loss function config for traing/eval process
Loss:
Train:
- DistillationGTCELoss:
weight: 1.0
model_names: ["Student"]
- DistillationDKDLoss:
weight: 1.0
model_name_pairs: [["Student", "Teacher"]]
temperature: 1
alpha: 1.0
beta: 1.0
Eval:
- CELoss:
weight: 1.0
论文信息:
Knowledge Distillation from A Stronger Teacher
Tao Huang, Shan You, Fei Wang, Chen Qian, Chang Xu
2022, under review
使用KD方法进行模型蒸馏时,教师模型精度提升时,蒸馏的效果往往难以同步提升。本文提出DIST方法,使用皮尔逊相关系数(Pearson correlation coefficient)去表征学生模型与教师模型之间的差异,替代蒸馏过程中默认的KL散度,从而保证模型可以学到更加准确的相关性信息。
在ImageNet1k公开数据集上,效果如下所示。
策略 | 骨干网络 | 配置文件 | Top-1 acc | 下载链接 |
---|---|---|---|---|
baseline | ResNet18 | ResNet18.yaml | 70.8% | - |
DIST | ResNet18 | resnet34_distill_resnet18_dist.yaml | 71.99%(+1.19%) | - |
DIST 配置如下所示。在模型构建Arch字段中,需要同时定义学生模型与教师模型,教师模型固定参数,且需要加载预训练模型。在损失函数Loss字段中,需要定义DistillationDISTLoss
(学生与教师之间的DIST loss)以及DistillationGTCELoss
(学生与教师关于真值标签的CE loss),作为训练的损失函数。
Arch:
name: "DistillationModel"
# if not null, its lengths should be same as models
pretrained_list:
# if not null, its lengths should be same as models
freeze_params_list:
- True
- False
models:
- Teacher:
name: ResNet34
pretrained: True
- Student:
name: ResNet18
pretrained: False
infer_model_name: "Student"
# loss function config for traing/eval process
Loss:
Train:
- DistillationGTCELoss:
weight: 1.0
model_names: ["Student"]
- DistillationDISTLoss:
weight: 2.0
model_name_pairs:
- ["Student", "Teacher"]
Eval:
- CELoss:
weight: 1.0
论文信息:
Masked Generative Distillation
Zhendong Yang, Zhe Li, Mingqi Shao, Dachuan Shi, Zehuan Yuan, Chun Yuan
ECCV 2022
该方法针对特征图展开蒸馏,在蒸馏的过程中,对特征进行随机mask,强制学生用部分特征去生成教师模型的所有特征,以提升学生模型的表征能力,最终在特征蒸馏任务上达到了SOTA,并在检测、分割等任务中广泛验证有效。
在ImageNet1k公开数据集上,效果如下所示。
策略 | 骨干网络 | 配置文件 | Top-1 acc | 下载链接 |
---|---|---|---|---|
baseline | ResNet18 | ResNet18.yaml | 70.8% | - |
MGD | ResNet18 | resnet34_distill_resnet18_mgd.yaml | 71.86%(+1.06%) | - |
MGD 配置如下所示。在模型构建Arch字段中,需要同时定义学生模型与教师模型,教师模型固定参数,且需要加载预训练模型。在损失函数Loss字段中,需要定义DistillationPairLoss
(学生与教师模型之间的MGDLoss)以及DistillationGTCELoss
(学生与教师关于真值标签的CE loss),作为训练的损失函数。
Arch:
name: "DistillationModel"
class_num: &class_num 1000
# if not null, its lengths should be same as models
pretrained_list:
# if not null, its lengths should be same as models
freeze_params_list:
- True
- False
infer_model_name: "Student"
models:
- Teacher:
name: ResNet34
class_num: *class_num
pretrained: True
return_patterns: &t_stages ["blocks[2]", "blocks[6]", "blocks[12]", "blocks[15]"]
- Student:
name: ResNet18
class_num: *class_num
pretrained: False
return_patterns: &s_stages ["blocks[1]", "blocks[3]", "blocks[5]", "blocks[7]"]
# loss function config for traing/eval process
Loss:
Train:
- DistillationGTCELoss:
weight: 1.0
model_names: ["Student"]
- DistillationPairLoss:
weight: 1.0
model_name_pairs: [["Student", "Teacher"]] # calculate mgdloss for Student and Teacher
name: "loss_mgd"
base_loss_name: MGDLoss # MGD loss,the following are parameters of 'MGD loss'
s_key: "blocks[7]" # feature map used to calculate MGD loss in student model
t_key: "blocks[15]" # feature map used to calculate MGD loss in teacher model
student_channels: 512 # channel num for stduent feature map
teacher_channels: 512 # channel num for teacher feature map
Eval:
- CELoss:
weight: 1.0
论文信息:
Rethinking Soft Labels For Knowledge Distillation: A Bias-variance Tradeoff Perspective
Helong Zhou, Liangchen Song, Jiajie Chen, Ye Zhou, Guoli Wang, Junsong Yuan, Qian Zhang
ICLR, 2021
WSL (Weighted Soft Labels) 损失函数根据教师模型与学生模型关于真值标签的 CE Loss 比值,对每个样本的 KD Loss 分别赋予权重。若学生模型相对教师模型在某个样本上预测结果更好,则对该样本赋予较小的权重。该方法简单、有效,使各个样本的权重可自适应调节,提升了蒸馏精度。
在ImageNet1k公开数据集上,效果如下所示。
策略 | 骨干网络 | 配置文件 | Top-1 acc | 下载链接 |
---|---|---|---|---|
baseline | ResNet18 | ResNet18.yaml | 70.8% | - |
WSL | ResNet18 | resnet34_distill_resnet18_wsl.yaml | 72.23%(+1.43%) | - |
WSL 配置如下所示。在模型构建Arch字段中,需要同时定义学生模型与教师模型,教师模型固定参数,且需要加载预训练模型。在损失函数Loss字段中,需要定义DistillationGTCELoss
(学生与真值标签之间的CE loss)以及DistillationWSLLoss
(学生与教师之间的WSL loss),作为训练的损失函数。
# model architecture
Arch:
name: "DistillationModel"
# if not null, its lengths should be same as models
pretrained_list:
# if not null, its lengths should be same as models
freeze_params_list:
- True
- False
models:
- Teacher:
name: ResNet34
pretrained: True
- Student:
name: ResNet18
pretrained: False
infer_model_name: "Student"
# loss function config for traing/eval process
Loss:
Train:
- DistillationGTCELoss:
weight: 1.0
model_names: ["Student"]
- DistillationWSLLoss:
weight: 2.5
model_name_pairs: [["Student", "Teacher"]]
temperature: 2
Eval:
- CELoss:
weight: 1.0
论文信息:
Reducing the Teacher-Student Gap via Spherical Knowledge Disitllation
Jia Guo, Minghao Chen, Yao Hu, Chen Zhu, Xiaofei He, Deng Cai
2022, under review
使用更大、精度更高的教师模型蒸馏学生模型,学生模型的精度往往反而降低。SKD (Spherical Knowledge Disitllation) 方法显式地消除了教师与学生之间的置信度差距,缓解了教师与学生之间的容量差距问题。SKD在ImageNet1k上蒸馏ResNet18的任务上显著超越了SOTA。
在ImageNet1k公开数据集上,效果如下所示。
策略 | 骨干网络 | 配置文件 | Top-1 acc | 下载链接 |
---|---|---|---|---|
baseline | ResNet18 | ResNet18.yaml | 70.8% | - |
SKD | ResNet18 | resnet34_distill_resnet18_skd.yaml | 72.84%(+2.04%) | - |
SKD 配置如下所示。在模型构建Arch字段中,需要同时定义学生模型与教师模型,教师模型固定参数,且需要加载预训练模型。在损失函数Loss字段中,需要定义DistillationSKDLoss
(学生与教师之间的SKD loss),作为训练的损失函数。需要注意的是,SKD loss包含了学生与教师模型之间的KL div loss和学生模型与真值标签之间的CE loss,因此无需定义DistillationGTCELoss
。
# model architecture
Arch:
name: "DistillationModel"
# if not null, its lengths should be same as models
pretrained_list:
# if not null, its lengths should be same as models
freeze_params_list:
- True
- False
models:
- Teacher:
name: ResNet34
pretrained: True
- Student:
name: ResNet18
pretrained: False
infer_model_name: "Student"
# loss function config for traing/eval process
Loss:
Train:
- DistillationSKDLoss:
weight: 1.0
model_name_pairs: [["Student", "Teacher"]]
temperature: 1.0
multiplier: 2.0
alpha: 0.9
Eval:
- CELoss:
weight: 1.0
论文信息:
Improved Feature Distillation via Projector Ensemble
Yudong Chen, Sen Wang, Jiajun Liu, Xuwei Xu, Frank de Hoog, Zi Huang
NeurIPS 2022
PEFD使用多个projector对学生特征图进行投影并ensemble,来拟合教师的特征图。与不使用projector或使用单个projector相比,该方法可以避免学生模型对教师特征的过拟合,进一步提高特征蒸馏的性能。
在ImageNet1k公开数据集上,效果如下所示。
策略 | 骨干网络 | 配置文件 | Top-1 acc | 下载链接 |
---|---|---|---|---|
baseline | ResNet18 | ResNet18.yaml | 70.8% | - |
PEFD | ResNet18 | resnet34_distill_resnet18_pefd.yaml | 72.23%(+1.43%) | - |
PEFD 配置如下所示。在模型构建Arch字段中,需要同时定义学生模型与教师模型,教师模型固定参数,且需要加载预训练模型。在损失函数Loss字段中,需要定义DistillationPairLoss
(学生与教师模型之间的PEFDLoss)以及DistillationGTCELoss
(学生与教师关于真值标签的CE loss),作为训练的损失函数。
# model architecture
Arch:
name: "DistillationModel"
class_num: &class_num 1000
# if not null, its lengths should be same as models
pretrained_list:
# if not null, its lengths should be same as models
freeze_params_list:
- True
- False
infer_model_name: "Student"
models:
- Teacher:
name: ResNet34
class_num: *class_num
pretrained: True
return_patterns: &t_stages ["avg_pool"]
- Student:
name: ResNet18
class_num: *class_num
pretrained: False
return_patterns: &s_stages ["avg_pool"]
# loss function config for traing/eval process
Loss:
Train:
- DistillationGTCELoss:
weight: 1.0
model_names: ["Student"]
- DistillationPairLoss:
weight: 25.0
base_loss_name: PEFDLoss
model_name_pairs: [["Student", "Teacher"]]
s_key: "avg_pool"
t_key: "avg_pool"
name: "loss_pefd"
student_channel: 512
teacher_channel: 512
Eval:
- CELoss:
weight: 1.0
- 安装:请先参考 Paddle 安装教程 以及 PaddleClas 安装教程 配置 PaddleClas 运行环境。
请在ImageNet 官网准备 ImageNet-1k 相关的数据。
进入 PaddleClas 目录。
cd path_to_PaddleClas
进入 dataset/
目录,将下载好的数据命名为 ILSVRC2012
,存放于此。 ILSVRC2012
目录中具有以下数据:
├── train
│ ├── n01440764
│ │ ├── n01440764_10026.JPEG
│ │ ├── n01440764_10027.JPEG
├── train_list.txt
...
├── val
│ ├── ILSVRC2012_val_00000001.JPEG
│ ├── ILSVRC2012_val_00000002.JPEG
├── val_list.txt
其中 train/
和 val/
分别为训练集和验证集。train_list.txt
和 val_list.txt
分别为训练集和验证集的标签文件。
如果包含与训练集场景相似的无标注数据,则也可以按照与训练集标注完全相同的方式进行整理,将文件与当前有标注的数据集放在相同目录下,将其标签值记为0,假设整理的标签文件名为train_list_unlabel.txt
,则可以通过下面的命令生成用于SSLD训练的标签文件。
cat train_list.txt train_list_unlabel.txt > train_list_all.txt
备注:
- 关于
train_list.txt
、val_list.txt
的格式说明,可以参考PaddleClas分类数据集格式说明 。
以SSLD知识蒸馏算法为例,介绍知识蒸馏算法的模型训练、评估、预测等过程。配置文件为 PPLCNet_x2_5_ssld.yaml ,使用下面的命令可以完成模型训练。
export CUDA_VISIBLE_DEVICES=0,1,2,3
python3 -m paddle.distributed.launch \
--gpus="0,1,2,3" \
tools/train.py \
-c ppcls/configs/ImageNet/Distillation/PPLCNet_x2_5_ssld.yaml
训练好模型之后,可以通过以下命令实现对模型指标的评估。
python3 tools/eval.py \
-c ppcls/configs/ImageNet/Distillation/PPLCNet_x2_5_ssld.yaml \
-o Global.pretrained_model=output/DistillationModel/best_model
其中 -o Global.pretrained_model="output/DistillationModel/best_model"
指定了当前最佳权重所在的路径,如果指定其他权重,只需替换对应的路径即可。
模型训练完成之后,可以加载训练得到的预训练模型,进行模型预测。在模型库的 tools/infer.py
中提供了完整的示例,只需执行下述命令即可完成模型预测:
python3 tools/infer.py \
-c ppcls/configs/ImageNet/Distillation/PPLCNet_x2_5_ssld.yaml \
-o Global.pretrained_model=output/DistillationModel/best_model
输出结果如下:
[{'class_ids': [8, 7, 86, 82, 21], 'scores': [0.87908, 0.12091, 0.0, 0.0, 0.0], 'file_name': 'docs/images/inference_deployment/whl_demo.jpg', 'label_names': ['hen', 'cock', 'partridge', 'ruffed grouse, partridge, Bonasa umbellus', 'kite']}]
备注:
-
这里
-o Global.pretrained_model="output/ResNet50/best_model"
指定了当前最佳权重所在的路径,如果指定其他权重,只需替换对应的路径即可。 -
默认是对
docs/images/inference_deployment/whl_demo.jpg
进行预测,此处也可以通过增加字段-o Infer.infer_imgs=xxx
对其他图片预测。
Paddle Inference 是飞桨的原生推理库, 作用于服务器端和云端,提供高性能的推理能力。相比于直接基于预训练模型进行预测,Paddle Inference可使用MKLDNN、CUDNN、TensorRT 进行预测加速,从而实现更优的推理性能。更多关于Paddle Inference推理引擎的介绍,可以参考Paddle Inference官网教程。
在模型推理之前需要先导出模型。对于知识蒸馏训练得到的模型,在导出时需要指定-o Global.infer_model_name=Student
,来表示导出的模型为学生模型。具体命令如下所示。
python3 tools/export_model.py \
-c ppcls/configs/ImageNet/Distillation/PPLCNet_x2_5_ssld.yaml \
-o Global.pretrained_model=./output/DistillationModel/best_model \
-o Arch.infer_model_name=Student
最终在inference
目录下会产生inference.pdiparams
、inference.pdiparams.info
、inference.pdmodel
3个文件。
关于更多模型推理相关的教程,请参考:Python 预测推理。
[1] Hinton G, Vinyals O, Dean J. Distilling the knowledge in a neural network[J]. arXiv preprint arXiv:1503.02531, 2015.
[2] Bagherinezhad H, Horton M, Rastegari M, et al. Label refinery: Improving imagenet classification through label progression[J]. arXiv preprint arXiv:1805.02641, 2018.
[3] Yalniz I Z, Jégou H, Chen K, et al. Billion-scale semi-supervised learning for image classification[J]. arXiv preprint arXiv:1905.00546, 2019.
[4] Cubuk E D, Zoph B, Mane D, et al. Autoaugment: Learning augmentation strategies from data[C]//Proceedings of the IEEE conference on computer vision and pattern recognition. 2019: 113-123.
[5] Touvron H, Vedaldi A, Douze M, et al. Fixing the train-test resolution discrepancy[C]//Advances in Neural Information Processing Systems. 2019: 8250-8260.
[6] Cui C, Guo R, Du Y, et al. Beyond Self-Supervision: A Simple Yet Effective Network Distillation Alternative to Improve Backbones[J]. arXiv preprint arXiv:2103.05959, 2021.
[7] Zhang Y, Xiang T, Hospedales T M, et al. Deep mutual learning[C]//Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2018: 4320-4328.
[8] Heo B, Kim J, Yun S, et al. A comprehensive overhaul of feature distillation[C]//Proceedings of the IEEE/CVF International Conference on Computer Vision. 2019: 1921-1930.
[9] Du Y, Li C, Guo R, et al. PP-OCRv2: Bag of Tricks for Ultra Lightweight OCR System[J]. arXiv preprint arXiv:2109.03144, 2021.
[10] Park W, Kim D, Lu Y, et al. Relational knowledge distillation[C]//Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2019: 3967-3976.
[11] Zhao B, Cui Q, Song R, et al. Decoupled Knowledge Distillation[J]. arXiv preprint arXiv:2203.08679, 2022.
[12] Ji M, Heo B, Park S. Show, attend and distill: Knowledge distillation via attention-based feature matching[C]//Proceedings of the AAAI Conference on Artificial Intelligence. 2021, 35(9): 7945-7952.
[13] Huang T, You S, Wang F, et al. Knowledge Distillation from A Stronger Teacher[J]. arXiv preprint arXiv:2205.10536, 2022.