From 81cc2246e8b04b6b929a2d67d50b67343e23b787 Mon Sep 17 00:00:00 2001
From: Yoshitomo Matsubara
Date: Sun, 21 Apr 2024 17:31:46 -0700
Subject: [PATCH] Add KD with logits standardization

---
 .../kd_w_ls/resnet18_from_resnet34.yaml | 153 ++++++++++++++++++
 torchdistill/losses/mid_level.py        |  55 +++++++
 2 files changed, 208 insertions(+)
 create mode 100644 configs/sample/ilsvrc2012/kd_w_ls/resnet18_from_resnet34.yaml

diff --git a/configs/sample/ilsvrc2012/kd_w_ls/resnet18_from_resnet34.yaml b/configs/sample/ilsvrc2012/kd_w_ls/resnet18_from_resnet34.yaml
new file mode 100644
index 00000000..96fb97dc
--- /dev/null
+++ b/configs/sample/ilsvrc2012/kd_w_ls/resnet18_from_resnet34.yaml
@@ -0,0 +1,153 @@
+datasets:
+  &imagenet_train ilsvrc2012/train: !import_call
+    _name: &dataset_name 'ilsvrc2012'
+    _root: &root_dir !join ['~/datasets/', *dataset_name]
+    key: 'torchvision.datasets.ImageFolder'
+    init:
+      kwargs:
+        root: !join [*root_dir, '/train']
+        transform: !import_call
+          key: 'torchvision.transforms.Compose'
+          init:
+            kwargs:
+              transforms:
+                - !import_call
+                  key: 'torchvision.transforms.RandomResizedCrop'
+                  init:
+                    kwargs:
+                      size: &input_size [224, 224]
+                - !import_call
+                  key: 'torchvision.transforms.RandomHorizontalFlip'
+                  init:
+                    kwargs:
+                      p: 0.5
+                - !import_call
+                  key: 'torchvision.transforms.ToTensor'
+                  init:
+                - !import_call
+                  key: 'torchvision.transforms.Normalize'
+                  init:
+                    kwargs: &normalize_kwargs
+                      mean: [0.485, 0.456, 0.406]
+                      std: [0.229, 0.224, 0.225]
+  &imagenet_val ilsvrc2012/val: !import_call
+    key: 'torchvision.datasets.ImageFolder'
+    init:
+      kwargs:
+        root: !join [*root_dir, '/val']
+        transform: !import_call
+          key: 'torchvision.transforms.Compose'
+          init:
+            kwargs:
+              transforms:
+                - !import_call
+                  key: 'torchvision.transforms.Resize'
+                  init:
+                    kwargs:
+                      size: 256
+                - !import_call
+                  key: 'torchvision.transforms.CenterCrop'
+                  init:
+                    kwargs:
+                      size: *input_size
+                - !import_call
+                  key: 'torchvision.transforms.ToTensor'
+                  init:
+                - !import_call
+                  key: 'torchvision.transforms.Normalize'
+                  init:
+                    kwargs: *normalize_kwargs
+
+models:
+  teacher_model:
+    key: &teacher_model_key 'resnet34'
+    _weights: &teacher_weights !import_get
+      key: 'torchvision.models.resnet.ResNet34_Weights'
+    kwargs:
+      num_classes: 1000
+      weights: !getattr [*teacher_weights, 'IMAGENET1K_V1']
+    src_ckpt:
+  student_model:
+    key: &student_model_key 'resnet18'
+    kwargs:
+      num_classes: 1000
+    _experiment: &student_experiment !join [*dataset_name, '-', *student_model_key, '_from_', *teacher_model_key]
+    src_ckpt:
+    dst_ckpt: !join ['./resource/ckpt/ilsvrc2012/kd_w_ls/', *student_experiment, '.pt']
+
+train:
+  log_freq: 1000
+  num_epochs: 100
+  train_data_loader:
+    dataset_id: *imagenet_train
+    sampler:
+      class_or_func: !import_get
+        key: 'torch.utils.data.RandomSampler'
+      kwargs:
+    kwargs:
+      batch_size: 512
+      num_workers: 16
+      pin_memory: True
+      drop_last: False
+    cache_output:
+  val_data_loader:
+    dataset_id: *imagenet_val
+    sampler: &val_sampler
+      class_or_func: !import_get
+        key: 'torch.utils.data.SequentialSampler'
+      kwargs:
+    kwargs:
+      batch_size: 32
+      num_workers: 16
+      pin_memory: True
+      drop_last: False
+  teacher:
+    forward_proc: 'forward_batch_only'
+    sequential: []
+    wrapper: 'DataParallel'
+    requires_grad: False
+  student:
+    forward_proc: 'forward_batch_only'
+    adaptations:
+    sequential: []
+    wrapper: 'DistributedDataParallel'
+    requires_grad: True
+    frozen_modules: []
+  optimizer:
+    key: 'SGD'
+    kwargs:
+      lr: 0.2
+      momentum: 0.9
+      weight_decay: 0.0001
+  scheduler:
+    key: 'MultiStepLR'
+    kwargs:
+      milestones: [30, 60, 90]
+      gamma: 0.1
+  criterion:
+    key: 'WeightedSumLoss'
+    kwargs:
+      sub_terms:
+        kd:
+          criterion:
+            key: 'LogitStdKDLoss'
+            kwargs:
+              student_module_path: '.'
+              student_module_io: 'output'
+              teacher_module_path: '.'
+              teacher_module_io: 'output'
+              temperature: 2.0
+              alpha: 0.5
+              beta: 9
+              reduction: 'batchmean'
+          weight: 1.0
+
+test:
+  test_data_loader:
+    dataset_id: *imagenet_val
+    sampler: *val_sampler
+    kwargs:
+      batch_size: 1
+      num_workers: 16
+      pin_memory: True
+      drop_last: False
diff --git a/torchdistill/losses/mid_level.py b/torchdistill/losses/mid_level.py
index 46565d6b..466b23ab
--- a/torchdistill/losses/mid_level.py
+++ b/torchdistill/losses/mid_level.py
@@ -1677,3 +1677,58 @@ def forward(self, student_io_dict, teacher_io_dict, *args, **kwargs):
                                     torch.softmax(teacher_logits / self.temperature, dim=1))
         loss = 2 * feat_distill_loss + kl_loss
         return loss
+
+
+@register_mid_level_loss
+class LogitStdKDLoss(nn.KLDivLoss):
+    """
+    A standard knowledge distillation (KD) loss module with logit standardization.
+
+    Shangquan Sun, Wenqi Ren, Jingzhi Li, Rui Wang, Xiaochun Cao: `"Logit Standardization in Knowledge Distillation" <https://arxiv.org/abs/2403.01427>`_ @ CVPR 2024 (2024)
+
+    :param student_module_path: student model's logit module path.
+    :type student_module_path: str
+    :param student_module_io: 'input' or 'output' of the module in the student model.
+    :type student_module_io: str
+    :param teacher_module_path: teacher model's logit module path.
+    :type teacher_module_path: str
+    :param teacher_module_io: 'input' or 'output' of the module in the teacher model.
+    :type teacher_module_io: str
+    :param temperature: hyperparameter :math:`\\tau` to soften class-probability distributions.
+    :type temperature: float
+    :param eps: value added to the denominator for numerical stability.
+    :type eps: float
+    :param alpha: balancing factor for :math:`L_{CE}`, cross-entropy.
+    :type alpha: float or None
+    :param beta: balancing factor (default: :math:`1 - \\alpha`) for :math:`L_{KL}`, KL divergence between class-probability distributions softened by :math:`\\tau`.
+    :type beta: float or None
+    :param reduction: ``reduction`` for KLDivLoss. If ``reduction`` = 'batchmean', CrossEntropyLoss's ``reduction`` will be 'mean'.
+    :type reduction: str or None
+    """
+    def __init__(self, student_module_path, student_module_io, teacher_module_path, teacher_module_io,
+                 temperature, eps=1e-7, alpha=None, beta=None, reduction='batchmean', **kwargs):
+        super().__init__(reduction=reduction)
+        self.student_module_path = student_module_path
+        self.student_module_io = student_module_io
+        self.teacher_module_path = teacher_module_path
+        self.teacher_module_io = teacher_module_io
+        self.temperature = temperature
+        self.eps = eps
+        self.alpha = alpha
+        self.beta = 1 - alpha if beta is None and alpha is not None else beta
+        cel_reduction = 'mean' if reduction == 'batchmean' else reduction
+        self.cross_entropy_loss = nn.CrossEntropyLoss(reduction=cel_reduction, **kwargs)
+
+    def standardize(self, logits):
+        return (logits - logits.mean(dim=-1, keepdims=True)) / (self.eps + logits.std(dim=-1, keepdims=True))
+
+    def forward(self, student_io_dict, teacher_io_dict, targets=None, *args, **kwargs):
+        student_logits = student_io_dict[self.student_module_path][self.student_module_io]
+        teacher_logits = teacher_io_dict[self.teacher_module_path][self.teacher_module_io]
+        soft_loss = super().forward(torch.log_softmax(self.standardize(student_logits) / self.temperature, dim=1),
+                                    torch.softmax(self.standardize(teacher_logits) / self.temperature, dim=1))
+        if self.alpha is None or self.alpha == 0 or targets is None:
+            return soft_loss
+
+        hard_loss = self.cross_entropy_loss(student_logits, targets)
+        return self.alpha * hard_loss + self.beta * (self.temperature ** 2) * soft_loss
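A minimal usage sketch of the new LogitStdKDLoss follows (not part of the diff above; it assumes the patch has been applied so the class is importable from torchdistill.losses.mid_level). It mirrors the module paths ('.', 'output') and hyperparameters from the sample config, and hands the loss the same io-dict layout that torchdistill's forward hooks populate during training.

    import torch
    from torchdistill.losses.mid_level import LogitStdKDLoss

    # Same settings as the sample config: logits are read from each model's final
    # output, addressed by module path '.' and io key 'output'.
    criterion = LogitStdKDLoss(student_module_path='.', student_module_io='output',
                               teacher_module_path='.', teacher_module_io='output',
                               temperature=2.0, alpha=0.5, beta=9, reduction='batchmean')

    # Dummy logits and labels standing in for hooked model outputs and targets.
    batch_size, num_classes = 8, 1000
    student_logits = torch.randn(batch_size, num_classes)
    teacher_logits = torch.randn(batch_size, num_classes)
    targets = torch.randint(0, num_classes, (batch_size,))

    # forward() standardizes both sets of logits per sample (zero mean, unit std),
    # softens them with the temperature, and returns
    # alpha * CE(student_logits, targets) + beta * tau^2 * KL(teacher || student).
    loss = criterion({'.': {'output': student_logits}},
                     {'.': {'output': teacher_logits}},
                     targets=targets)
    print(loss.item())

Because the loss only needs module paths and io keys, the same class also covers cases where logits are taken from an inner module rather than the model's final output.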