Merge pull request #460 from yoshitomo-matsubara/dev
Add KD with logits standardization
yoshitomo-matsubara authored Apr 22, 2024
2 parents d960c30 + 81cc224 commit 182b11d
Showing 2 changed files with 208 additions and 0 deletions.
configs/sample/ilsvrc2012/kd_w_ls/resnet18_from_resnet34.yaml (153 additions, 0 deletions)
@@ -0,0 +1,153 @@
datasets:
  &imagenet_train ilsvrc2012/train: !import_call
    _name: &dataset_name 'ilsvrc2012'
    _root: &root_dir !join ['~/datasets/', *dataset_name]
    key: 'torchvision.datasets.ImageFolder'
    init:
      kwargs:
        root: !join [*root_dir, '/train']
        transform: !import_call
          key: 'torchvision.transforms.Compose'
          init:
            kwargs:
              transforms:
                - !import_call
                  key: 'torchvision.transforms.RandomResizedCrop'
                  init:
                    kwargs:
                      size: &input_size [224, 224]
                - !import_call
                  key: 'torchvision.transforms.RandomHorizontalFlip'
                  init:
                    kwargs:
                      p: 0.5
                - !import_call
                  key: 'torchvision.transforms.ToTensor'
                  init:
                - !import_call
                  key: 'torchvision.transforms.Normalize'
                  init:
                    kwargs: &normalize_kwargs
                      mean: [0.485, 0.456, 0.406]
                      std: [0.229, 0.224, 0.225]
  &imagenet_val ilsvrc2012/val: !import_call
    key: 'torchvision.datasets.ImageFolder'
    init:
      kwargs:
        root: !join [*root_dir, '/val']
        transform: !import_call
          key: 'torchvision.transforms.Compose'
          init:
            kwargs:
              transforms:
                - !import_call
                  key: 'torchvision.transforms.Resize'
                  init:
                    kwargs:
                      size: 256
                - !import_call
                  key: 'torchvision.transforms.CenterCrop'
                  init:
                    kwargs:
                      size: *input_size
                - !import_call
                  key: 'torchvision.transforms.ToTensor'
                  init:
                - !import_call
                  key: 'torchvision.transforms.Normalize'
                  init:
                    kwargs: *normalize_kwargs

models:
  teacher_model:
    key: &teacher_model_key 'resnet34'
    _weights: &teacher_weights !import_get
      key: 'torchvision.models.resnet.ResNet34_Weights'
    kwargs:
      num_classes: 1000
      weights: !getattr [*teacher_weights, 'IMAGENET1K_V1']
    src_ckpt:
  student_model:
    key: &student_model_key 'resnet18'
    kwargs:
      num_classes: 1000
    _experiment: &student_experiment !join [*dataset_name, '-', *student_model_key, '_from_', *teacher_model_key]
    src_ckpt:
    dst_ckpt: !join ['./resource/ckpt/ilsvrc2012/kd_w_ls/', *student_experiment, '.pt']

train:
  log_freq: 1000
  num_epochs: 100
  train_data_loader:
    dataset_id: *imagenet_train
    sampler:
      class_or_func: !import_get
        key: 'torch.utils.data.RandomSampler'
      kwargs:
    kwargs:
      batch_size: 512
      num_workers: 16
      pin_memory: True
      drop_last: False
    cache_output:
  val_data_loader:
    dataset_id: *imagenet_val
    sampler: &val_sampler
      class_or_func: !import_get
        key: 'torch.utils.data.SequentialSampler'
      kwargs:
    kwargs:
      batch_size: 32
      num_workers: 16
      pin_memory: True
      drop_last: False
  teacher:
    forward_proc: 'forward_batch_only'
    sequential: []
    wrapper: 'DataParallel'
    requires_grad: False
  student:
    forward_proc: 'forward_batch_only'
    adaptations:
    sequential: []
    wrapper: 'DistributedDataParallel'
    requires_grad: True
    frozen_modules: []
  optimizer:
    key: 'SGD'
    kwargs:
      lr: 0.2
      momentum: 0.9
      weight_decay: 0.0001
  scheduler:
    key: 'MultiStepLR'
    kwargs:
      milestones: [30, 60, 90]
      gamma: 0.1
  criterion:
    key: 'WeightedSumLoss'
    kwargs:
      sub_terms:
        kd:
          criterion:
            key: 'LogitStdKDLoss'
            kwargs:
              student_module_path: '.'
              student_module_io: 'output'
              teacher_module_path: '.'
              teacher_module_io: 'output'
              temperature: 2.0
              alpha: 0.5
              beta: 9
              reduction: 'batchmean'
          weight: 1.0

test:
  test_data_loader:
    dataset_id: *imagenet_val
    sampler: *val_sampler
    kwargs:
      batch_size: 1
      num_workers: 16
      pin_memory: True
      drop_last: False
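
For reference, here is a minimal sketch (not part of this commit) of loading the sample config above with torchdistill's YAML loader, which registers the !import_call / !import_get / !join / !getattr constructors used in the file. It assumes torchdistill.common.yaml_util.load_yaml_file and that the ILSVRC2012 train/val folders exist at the paths in the config, since !import_call instantiates the ImageFolder datasets while the file is parsed.

# Hypothetical usage sketch (not part of this commit).
from torchdistill.common import yaml_util

config = yaml_util.load_yaml_file('configs/sample/ilsvrc2012/kd_w_ls/resnet18_from_resnet34.yaml')
criterion_config = config['train']['criterion']['kwargs']['sub_terms']['kd']['criterion']
print(criterion_config['key'])     # 'LogitStdKDLoss'
print(criterion_config['kwargs'])  # temperature, alpha, beta, etc. passed to the loss added below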
torchdistill/losses/mid_level.py (55 additions, 0 deletions)
@@ -1677,3 +1677,58 @@ def forward(self, student_io_dict, teacher_io_dict, *args, **kwargs):
                                   torch.softmax(teacher_logits / self.temperature, dim=1))
        loss = 2 * feat_distill_loss + kl_loss
        return loss


@register_mid_level_loss
class LogitStdKDLoss(nn.KLDivLoss):
    """
    A standard knowledge distillation (KD) loss module with logits standardization.

    Shangquan Sun, Wenqi Ren, Jingzhi Li, Rui Wang, Xiaochun Cao: `"Logit Standardization in Knowledge Distillation" <https://arxiv.org/abs/2403.01427>`_ @ CVPR 2024 (2024)

    :param student_module_path: student model's logit module path.
    :type student_module_path: str
    :param student_module_io: 'input' or 'output' of the module in the student model.
    :type student_module_io: str
    :param teacher_module_path: teacher model's logit module path.
    :type teacher_module_path: str
    :param teacher_module_io: 'input' or 'output' of the module in the teacher model.
    :type teacher_module_io: str
    :param temperature: hyperparameter :math:`\\tau` to soften class-probability distributions.
    :type temperature: float
    :param eps: value added to the denominator for numerical stability.
    :type eps: float
    :param alpha: balancing factor for :math:`L_{CE}`, cross-entropy.
    :type alpha: float
    :param beta: balancing factor (default: :math:`1 - \\alpha`) for :math:`L_{KL}`, KL divergence between class-probability distributions softened by :math:`\\tau`.
    :type beta: float or None
    :param reduction: ``reduction`` for KLDivLoss. If ``reduction`` = 'batchmean', CrossEntropyLoss's ``reduction`` will be 'mean'.
    :type reduction: str or None
    """
    def __init__(self, student_module_path, student_module_io, teacher_module_path, teacher_module_io,
                 temperature, eps=1e-7, alpha=None, beta=None, reduction='batchmean', **kwargs):
        super().__init__(reduction=reduction)
        self.student_module_path = student_module_path
        self.student_module_io = student_module_io
        self.teacher_module_path = teacher_module_path
        self.teacher_module_io = teacher_module_io
        self.temperature = temperature
        self.eps = eps
        self.alpha = alpha
        self.beta = 1 - alpha if beta is None else beta
        cel_reduction = 'mean' if reduction == 'batchmean' else reduction
        self.cross_entropy_loss = nn.CrossEntropyLoss(reduction=cel_reduction, **kwargs)

    def standardize(self, logits):
        # Z-score standardization of logits along the class dimension; eps keeps the division stable.
        return (logits - logits.mean(dim=-1, keepdims=True)) / (self.eps + logits.std(dim=-1, keepdims=True))

    def forward(self, student_io_dict, teacher_io_dict, targets=None, *args, **kwargs):
        student_logits = student_io_dict[self.student_module_path][self.student_module_io]
        teacher_logits = teacher_io_dict[self.teacher_module_path][self.teacher_module_io]
        # KL divergence between temperature-softened distributions of the standardized logits.
        soft_loss = super().forward(torch.log_softmax(self.standardize(student_logits) / self.temperature, dim=1),
                                    torch.softmax(self.standardize(teacher_logits) / self.temperature, dim=1))
        if self.alpha is None or self.alpha == 0 or targets is None:
            return soft_loss

        # Weighted sum of hard-label cross-entropy and the temperature-scaled soft (KD) term.
        hard_loss = self.cross_entropy_loss(student_logits, targets)
        return self.alpha * hard_loss + self.beta * (self.temperature ** 2) * soft_loss
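
The new loss standardizes the student and teacher logits to zero mean and unit variance before the usual temperature-softened KL term, and, when both alpha and targets are given, returns alpha * CE + beta * tau^2 * KL. Below is a minimal, self-contained usage sketch (assumed, not part of this commit) that feeds dummy logits through the class, mimicking the I/O dicts torchdistill builds for module path '.' and IO 'output' as configured above.

# Hypothetical usage sketch (not part of this commit).
import torch
from torchdistill.losses.mid_level import LogitStdKDLoss

criterion = LogitStdKDLoss(student_module_path='.', student_module_io='output',
                           teacher_module_path='.', teacher_module_io='output',
                           temperature=2.0, alpha=0.5, beta=9, reduction='batchmean')

batch_size, num_classes = 4, 1000
student_io_dict = {'.': {'output': torch.randn(batch_size, num_classes)}}
teacher_io_dict = {'.': {'output': torch.randn(batch_size, num_classes)}}
targets = torch.randint(0, num_classes, (batch_size,))

loss = criterion(student_io_dict, teacher_io_dict, targets)  # alpha * CE + beta * tau^2 * KL
soft_only = criterion(student_io_dict, teacher_io_dict)      # targets omitted -> KL term only
print(loss.item(), soft_only.item())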
