diff --git a/config.py b/config.py index b78e5a9b8..3c38d2644 100644 --- a/config.py +++ b/config.py @@ -256,6 +256,22 @@ def create_parser(): group.add_argument('--drop_overflow_update', type=bool, default=False, help='Whether to execute optimizer if there is an overflow (default=False)') + # distillation + group = parser.add_argument_group('Distillation parameters') + group.add_argument('--distillation_type', type=str, default=None, + choices=['hard', 'soft'], + help='The type of distillation (default=None)') + group.add_argument('--teacher_model', type=str, default=None, + help='Name of teacher model (default=None)') + group.add_argument('--teacher_ckpt_path', type=str, default='', + help='Initialize teacher model from this checkpoint. ' + 'If use distillation, specify the checkpoint path (default="").') + group.add_argument('--teacher_ema', type=str2bool, nargs='?', const=True, default=False, + help='Whether teacher model training with ema (default=False)') + group.add_argument('--distillation_alpha', type=float, default=0.5, + help='The coefficient to balance the distillation loss and base loss. ' + '(default=0.5)') + # modelarts group = parser.add_argument_group('modelarts') group.add_argument('--enable_modelarts', type=str2bool, nargs='?', const=True, default=False, diff --git a/mindcv/utils/__init__.py b/mindcv/utils/__init__.py index 39b346e04..12d174dc1 100644 --- a/mindcv/utils/__init__.py +++ b/mindcv/utils/__init__.py @@ -2,6 +2,7 @@ from .amp import * from .callbacks import * from .checkpoint_manager import * +from .distillation import * from .download import * from .logger import * from .path import * diff --git a/mindcv/utils/distillation.py b/mindcv/utils/distillation.py new file mode 100644 index 000000000..4ac24c139 --- /dev/null +++ b/mindcv/utils/distillation.py @@ -0,0 +1,87 @@ +""" distillation related functions """ +from types import MethodType + +import mindspore as ms +from mindspore import nn +from mindspore.ops import functional as F + + +class DistillLossCell(nn.WithLossCell): + """ + Wraps the network with hard distillation loss function. + + Get the loss of student network and an extra knowledge distillation loss by + taking a teacher model prediction and using it as additional supervision. + + Args: + backbone (Cell): The student network to train and calculate base loss. + loss_fn (Cell): The loss function used to compute loss of student network. + distillation_type (str): The type of distillation. + teacher_model (Cell): The teacher network to calculate distillation loss. + alpha (float): The coefficient to balance the distillation loss and base loss. Default: 0.5. + tau (float): Distillation temperature. The higher the temperature, the lower the + dispersion of the loss calculated by Kullback-Leibler divergence loss. Default: 1.0. + """ + + def __init__(self, backbone, loss_fn, distillation_type, teacher_model, alpha=0.5, tau=1.0): + super().__init__(backbone, loss_fn) + if distillation_type == "hard": + self.hard_type = True + elif distillation_type == "soft": + self.hard_type = False + else: + raise ValueError(f"Distillation type only support ['hard', 'soft'], but got {distillation_type}.") + self.teacher_model = teacher_model + self.alpha = alpha + self.tau = tau + + def construct(self, data, label): + out = self._backbone(data) + + out, out_kd = out + base_loss = self._loss_fn(out, label) + + teacher_out = F.stop_gradient(self.teacher_model(data)) + + if self.hard_type: + distillation_loss = F.cross_entropy(out_kd, teacher_out.argmax(axis=1)) + else: + T = self.tau + out_kd = F.cast(out_kd, ms.float32) + distillation_loss = ( + F.kl_div( + F.log_softmax(out_kd / T, axis=1), + F.log_softmax(teacher_out / T, axis=1), + reduction="sum", + ) + * (T * T) + / F.size(out_kd) + ) + + loss = base_loss * (1 - self.alpha) + distillation_loss * self.alpha + + return loss + + +def bn_infer_only(self, x): + return self.bn_infer(x, self.gamma, self.beta, self.moving_mean, self.moving_variance)[0] + + +def dropout_infer_only(self, x): + return x + + +def set_validation(network): + """ + Since MindSpore cannot automatically set some cells to validation mode + during training in the teacher network, we need to manually set these + cells to validation mode in this function. + """ + + for _, cell in network.cells_and_names(): + if isinstance(cell, (nn.BatchNorm2d, nn.BatchNorm1d, nn.BatchNorm3d)): + cell.construct = MethodType(bn_infer_only, cell) + elif isinstance(cell, (nn.Dropout, nn.Dropout2d, nn.Dropout3d, nn.InstanceNorm2d)): + cell.construct = MethodType(dropout_infer_only, cell) + else: + cell.set_train(False) diff --git a/mindcv/utils/trainer_factory.py b/mindcv/utils/trainer_factory.py index db47a48e6..c471488cc 100644 --- a/mindcv/utils/trainer_factory.py +++ b/mindcv/utils/trainer_factory.py @@ -9,6 +9,7 @@ from mindspore.train import DynamicLossScaleManager, FixedLossScaleManager, Model from .amp import auto_mixed_precision +from .distillation import DistillLossCell from .train_step import TrainStep __all__ = [ @@ -38,6 +39,7 @@ def require_customized_train_step( clip_grad: bool = False, gradient_accumulation_steps: int = 1, amp_cast_list: Optional[str] = None, + distillation_type: Optional[str] = None, ): if ema: return True @@ -47,6 +49,8 @@ def require_customized_train_step( return True if amp_cast_list: return True + if distillation_type: + return True return False @@ -88,6 +92,9 @@ def create_trainer( clip_grad: bool = False, clip_value: float = 15.0, gradient_accumulation_steps: int = 1, + distillation_type: Optional[str] = None, + teacher_network: Optional[nn.Cell] = None, + distillation_alpha: float = 0.5, ): """Create Trainer. @@ -106,6 +113,9 @@ def create_trainer( clip_grad: whether to gradient clip. clip_value: The value at which to clip gradients. gradient_accumulation_steps: Accumulate the gradients of n batches before update. + distillation_type: The type of distillation. + teacher_network: The teacher network for distillation. + distillation_alpha: The coefficient to balance the distillation loss and base loss. Returns: mindspore.Model @@ -120,7 +130,7 @@ def create_trainer( if gradient_accumulation_steps < 1: raise ValueError("`gradient_accumulation_steps` must be >= 1!") - if not require_customized_train_step(ema, clip_grad, gradient_accumulation_steps, amp_cast_list): + if not require_customized_train_step(ema, clip_grad, gradient_accumulation_steps, amp_cast_list, distillation_type): mindspore_kwargs = dict( network=network, loss_fn=loss, @@ -149,7 +159,10 @@ def create_trainer( else: # require customized train step eval_network = nn.WithEvalCell(network, loss, amp_level in ["O2", "O3", "auto"]) auto_mixed_precision(network, amp_level, amp_cast_list) - net_with_loss = add_loss_network(network, loss, amp_level) + if distillation_type: + net_with_loss = DistillLossCell(network, loss, distillation_type, teacher_network, distillation_alpha) + else: + net_with_loss = add_loss_network(network, loss, amp_level) train_step_kwargs = dict( network=net_with_loss, optimizer=optimizer, diff --git a/train.py b/train.py index 644948a54..08fa8dd44 100644 --- a/train.py +++ b/train.py @@ -19,6 +19,7 @@ require_customized_train_step, set_logger, set_seed, + set_validation, ) from config import parse_args, save_args # isort: skip @@ -180,6 +181,19 @@ def train(args): aux_factor=args.aux_factor, ) + # create teacher model + teacher_network = None + if args.distillation_type: + if not args.teacher_ckpt_path: + logger.warning("You are using distillation, but your teacher model has not loaded weights.") + teacher_network = create_model( + model_name=args.teacher_model, + num_classes=num_classes, + checkpoint_path=args.teacher_ckpt_path, + ema=args.teacher_ema, + ) + set_validation(teacher_network) + # create learning rate schedule lr_scheduler = create_scheduler( num_batches, @@ -213,6 +227,7 @@ def train(args): args.clip_grad, args.gradient_accumulation_steps, args.amp_cast_list, + args.distillation_type, ) ): optimizer_loss_scale = args.loss_scale @@ -250,6 +265,9 @@ def train(args): clip_grad=args.clip_grad, clip_value=args.clip_value, gradient_accumulation_steps=args.gradient_accumulation_steps, + distillation_type=args.distillation_type, + teacher_network=teacher_network, + distillation_alpha=args.distillation_alpha, ) # callback