diff --git a/.buildinfo b/.buildinfo
new file mode 100644
index 00000000..2ea8b5c4
--- /dev/null
+++ b/.buildinfo
@@ -0,0 +1,4 @@
+# Sphinx build info version 1
+# This file records the configuration used when building these files. When it is not found, a full rebuild will be done.
+config: 304214780642431526784af45404aef0
+tags: 645f666f9bcd5a90fca523b33c5a78b7
diff --git a/.doctrees/benchmarks.doctree b/.doctrees/benchmarks.doctree
new file mode 100644
index 00000000..2b70ef41
Binary files /dev/null and b/.doctrees/benchmarks.doctree differ
diff --git a/.doctrees/environment.pickle b/.doctrees/environment.pickle
new file mode 100644
index 00000000..87b3a7bf
Binary files /dev/null and b/.doctrees/environment.pickle differ
diff --git a/.doctrees/index.doctree b/.doctrees/index.doctree
new file mode 100644
index 00000000..42fa79ed
Binary files /dev/null and b/.doctrees/index.doctree differ
diff --git a/.doctrees/package.doctree b/.doctrees/package.doctree
new file mode 100644
index 00000000..8249a179
Binary files /dev/null and b/.doctrees/package.doctree differ
diff --git a/.doctrees/projects.doctree b/.doctrees/projects.doctree
new file mode 100644
index 00000000..ac62e026
Binary files /dev/null and b/.doctrees/projects.doctree differ
diff --git a/.doctrees/subpkgs/common.doctree b/.doctrees/subpkgs/common.doctree
new file mode 100644
index 00000000..b6eb1f2b
Binary files /dev/null and b/.doctrees/subpkgs/common.doctree differ
diff --git a/.doctrees/subpkgs/core.doctree b/.doctrees/subpkgs/core.doctree
new file mode 100644
index 00000000..6e239d80
Binary files /dev/null and b/.doctrees/subpkgs/core.doctree differ
diff --git a/.doctrees/subpkgs/datasets.doctree b/.doctrees/subpkgs/datasets.doctree
new file mode 100644
index 00000000..92eb93e6
Binary files /dev/null and b/.doctrees/subpkgs/datasets.doctree differ
diff --git a/.doctrees/subpkgs/losses.doctree b/.doctrees/subpkgs/losses.doctree
new file mode 100644
index 00000000..cb866d7b
Binary files /dev/null and b/.doctrees/subpkgs/losses.doctree differ
diff --git a/.doctrees/subpkgs/misc.doctree b/.doctrees/subpkgs/misc.doctree
new file mode 100644
index 00000000..99f3e822
Binary files /dev/null and b/.doctrees/subpkgs/misc.doctree differ
diff --git a/.doctrees/subpkgs/models.doctree b/.doctrees/subpkgs/models.doctree
new file mode 100644
index 00000000..0a25165a
Binary files /dev/null and b/.doctrees/subpkgs/models.doctree differ
diff --git a/.doctrees/subpkgs/optim.doctree b/.doctrees/subpkgs/optim.doctree
new file mode 100644
index 00000000..53134710
Binary files /dev/null and b/.doctrees/subpkgs/optim.doctree differ
diff --git a/.doctrees/usage.doctree b/.doctrees/usage.doctree
new file mode 100644
index 00000000..69152e90
Binary files /dev/null and b/.doctrees/usage.doctree differ
diff --git a/.nojekyll b/.nojekyll
new file mode 100644
index 00000000..e69de29b
diff --git a/_images/logo-color.png b/_images/logo-color.png
new file mode 100644
index 00000000..79cd852d
Binary files /dev/null and b/_images/logo-color.png differ
diff --git a/_modules/index.html b/_modules/index.html
new file mode 100644
index 00000000..aa0daf75
--- /dev/null
+++ b/_modules/index.html
@@ -0,0 +1,152 @@
+
+
+
+
+
+def check_if_exists(file_path):
+"""
+ Checks if a file/dir exists.
+
+ :param file_path: file/dir path
+ :type file_path: str
+ :return: True if the given file exists
+ :rtype: bool
+ """
+    return file_path is not None and os.path.exists(file_path)
+
+
+
+
+def get_file_path_list(dir_path, is_recursive=False, is_sorted=False):
+"""
+ Gets file paths for a given dir path.
+
+ :param dir_path: dir path
+ :type dir_path: str
+ :param is_recursive: if True, get file paths recursively
+ :type is_recursive: bool
+ :param is_sorted: if True, sort file paths in ascending order
+ :type is_sorted: bool
+ :return: list of file paths
+ :rtype: list[str]
+ """
+    file_list = list()
+    for file in os.listdir(dir_path):
+        path = os.path.join(dir_path, file)
+        if os.path.isfile(path):
+            file_list.append(path)
+        elif is_recursive:
+            file_list.extend(get_file_path_list(path, is_recursive))
+    return sorted(file_list) if is_sorted else file_list
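+
+# Usage sketch (editor's addition, not part of the original source): build a small throwaway
+# directory tree and list its files recursively in sorted order with the helper above.
+import os
+import tempfile
+from pathlib import Path
+
+_example_root = tempfile.mkdtemp()
+os.makedirs(os.path.join(_example_root, 'sub'), exist_ok=True)
+Path(os.path.join(_example_root, 'a.txt')).touch()
+Path(os.path.join(_example_root, 'sub', 'b.txt')).touch()
+print(get_file_path_list(_example_root, is_recursive=True, is_sorted=True))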
+
+
+
+
+def get_dir_path_list(dir_path, is_recursive=False, is_sorted=False):
+"""
+ Gets dir paths for a given dir path.
+
+ :param dir_path: dir path
+ :type dir_path: str
+ :param is_recursive: if True, get dir paths recursively
+ :type is_recursive: bool
+ :param is_sorted: if True, sort dir paths in ascending order
+ :type is_sorted: bool
+ :return: list of dir paths
+ :rtype: list[str]
+ """
+    dir_list = list()
+    for file in os.listdir(dir_path):
+        path = os.path.join(dir_path, file)
+        if os.path.isdir(path):
+            dir_list.append(path)
+        elif is_recursive:
+            dir_list.extend(get_dir_path_list(path, is_recursive))
+    return sorted(dir_list) if is_sorted else dir_list
+
+
+
+
+def make_dirs(dir_path):
+"""
+ Makes a directory and its parent directories.
+
+ :param dir_path: dir path
+ :type dir_path: str
+ """
+    Path(dir_path).mkdir(parents=True, exist_ok=True)
+def save_pickle(obj, file_path):
+"""
+ Saves a serialized object as a file.
+
+ :param obj: object to be serialized
+ :type obj: Any
+ :param file_path: output file path
+ :type file_path: str
+ """
+    make_parent_dirs(file_path)
+    with open(file_path, 'wb') as fp:
+        pickle.dump(obj, fp)
+
+
+
+
+def load_pickle(file_path):
+"""
+ Loads and deserializes an object from a file.
+
+ :param file_path: serialized file path
+ :type file_path: str
+ :return: deserialized object
+ :rtype: Any
+ """
+    with open(file_path, 'rb') as fp:
+        return pickle.load(fp)
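+
+# Usage sketch (editor's addition, not part of the original source): a save/load round trip
+# through a temporary file using the two helpers above.
+import os
+import tempfile
+
+_tmp_path = os.path.join(tempfile.mkdtemp(), 'stats.pkl')
+save_pickle({'epoch': 1, 'accuracy': 0.78}, _tmp_path)
+print(load_pickle(_tmp_path))  # {'epoch': 1, 'accuracy': 0.78}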
+
+
+
+
+def get_binary_object_size(obj, unit_size=1024):
+"""
+ Computes the size of object in bytes after serialization.
+
+ :param obj: object
+ :type obj: Any
+ :param unit_size: unit file size
+ :type unit_size: int or float
+ :return: size of object in bytes, divided by the ``unit_size``
+ :rtype: float
+ """
+    return sys.getsizeof(pickle.dumps(obj)) / unit_size
+def import_get(key, package=None, **kwargs):
+"""
+ Imports a module and gets its attribute.
+
+ :param key: attribute name or package path separated by period(.).
+ :type key: str
+ :param package: package path if ``key`` is just an attribute name.
+ :type package: str or None
+ :return: attribute of the imported module.
+ :rtype: Any
+ """
+    if package is None:
+        names = key.split('.')
+        key = names[-1]
+        package = '.'.join(names[:-1])
+
+    logger.info(f'Getting `{key}` from `{package}`')
+    module = import_module(package)
+    return getattr(module, key)
+
+
+
+
+def import_call(key, package=None, init=None, **kwargs):
+"""
+ Imports a module and calls the specified module/function, e.g., for instantiation.
+
+ :param key: module name or package path separated by period(.).
+ :type key: str
+ :param package: package path if ``key`` is just an attribute name.
+ :type package: str or None
+ :param init: dict of arguments and/or keyword arguments to instantiate the imported module.
+ :type init: dict
+ :return: object imported and called.
+ :rtype: Any
+ """
+    if package is None:
+        names = key.split('.')
+        key = names[-1]
+        package = '.'.join(names[:-1])
+
+    obj = import_get(key, package)
+    if init is None:
+        init = dict()
+
+    logger.info(f'Calling `{key}` from `{package}` with {init}')
+    args = init.get('args', list())
+    kwargs = init.get('kwargs', dict())
+    return obj(*args, **kwargs)
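+
+# Usage sketch (editor's addition, not part of the original source): instantiate a class from
+# its dotted path, the way configuration entries typically reference it.
+example_od = import_call('collections.OrderedDict', init={'args': [[('lr', 0.1)]]})
+print(example_od)  # OrderedDict([('lr', 0.1)])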
+
+
+
+
+def import_call_method(package, class_name=None, method_name=None, init=None, **kwargs):
+"""
+ Imports a class from a module and calls its method.
+
+ :param package: package path.
+ :type package: str
+ :param class_name: class name under ``package``.
+ :type class_name: str
+ :param method_name: method name of ``class_name`` class under ``package``.
+ :type method_name: str
+ :param init: dict of arguments and/or keyword arguments to instantiate the imported module.
+ :type init: dict
+ :return: object imported and called.
+ :rtype: Any
+ """
+    if class_name is None or method_name is None:
+        names = package.split('.')
+        class_name = names[-2]
+        method_name = names[-1]
+        package = '.'.join(names[:-2])
+
+    cls = import_get(class_name, package)
+    if init is None:
+        init = dict()
+
+    logger.info(f'Calling `{class_name}.{method_name}` from `{package}` with {init}')
+    args = init.get('args', list())
+    kwargs = init.get('kwargs', dict())
+    method = getattr(cls, method_name)
+    return method(*args, **kwargs)
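+
+# Usage sketch (editor's addition, not part of the original source): call a classmethod by its
+# dotted path. `builtins.dict.fromkeys` is used here purely as a safe, importable target.
+example_keys = import_call_method('builtins.dict.fromkeys', init={'args': [['teacher', 'student']]})
+print(example_keys)  # {'teacher': None, 'student': None}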
+
+
+
+
+def setup_for_distributed(is_master):
+"""
+ Disables logging when not in master process.
+
+ :param is_master: True if it is the master process.
+ :type is_master: bool
+ """
+    def_logger.setLevel(logging.INFO if is_master else logging.WARN)
+    builtin_print = __builtin__.print
+
+    def print(*args, **kwargs):
+        force = kwargs.pop('force', False)
+        if is_master or force:
+            builtin_print(*args, **kwargs)
+
+    __builtin__.print = print
+
+
+
+
+def set_seed(seed):
+"""
+ Sets a random seed for `random`, `numpy`, and `torch` (torch.manual_seed, torch.cuda.manual_seed_all).
+
+ :param seed: random seed.
+ :type seed: int
+ """
+    if not isinstance(seed, int):
+        return
+
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+
+
+
+
+def is_dist_avail_and_initialized():
+"""
+ Checks if distributed mode is available and initialized.
+
+ :return: True if distributed mode is available and initialized.
+ :rtype: bool
+ """
+    if not dist.is_available():
+        return False
+    if not dist.is_initialized():
+        return False
+    return True
+
+
+
+
+def get_world_size():
+"""
+ Gets world size.
+
+ :return: world size.
+ :rtype: int
+ """
+    if not is_dist_avail_and_initialized():
+        return 1
+    return dist.get_world_size()
+
+
+
+
+def get_rank():
+"""
+ Gets the rank of the current process in the default process group, or 0 if distributed mode is unavailable or uninitialized.
+
+ :return: rank of the current process.
+ :rtype: int
+ """
+    if not is_dist_avail_and_initialized():
+        return 0
+    return dist.get_rank()
+
+
+
+
+def is_main_process():
+"""
+ Checks if this is the main process.
+
+ :return: True if this is the main process.
+ :rtype: bool
+ """
+    return get_rank() == 0
+
+
+
+
+def save_on_master(*args, **kwargs):
+"""
+ Calls ``torch.save`` with the given positional and keyword arguments only if this is the main process.
+ """
+    if is_main_process():
+        torch.save(*args, **kwargs)
+
+
+
+
+def init_distributed_mode(world_size=1, dist_url='env://'):
+"""
+ Initializes distributed mode.
+
+ :param world_size: world size.
+ :type world_size: int
+ :param dist_url: URL specifying how to initialize the process group.
+ :type dist_url: str
+ :return: tuple of 1) whether distributed mode is initialized, 2) world size, and 3) list of device IDs.
+ :rtype: (bool, int, list[int] or None)
+ """
+    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
+        rank = int(os.environ['RANK'])
+        world_size = int(os.environ['WORLD_SIZE'])
+        device_id = int(os.environ['LOCAL_RANK'])
+    elif 'SLURM_PROCID' in os.environ:
+        rank = int(os.environ['SLURM_PROCID'])
+        device_id = rank % torch.cuda.device_count()
+    else:
+        logger.info('Not using distributed mode')
+        return False, world_size, None
+
+    torch.cuda.set_device(device_id)
+    dist_backend = 'nccl'
+    logger.info('| distributed init (rank {}): {}'.format(rank, dist_url))
+    torch.distributed.init_process_group(backend=dist_backend, init_method=dist_url,
+                                         world_size=world_size, rank=rank)
+    torch.distributed.barrier()
+    setup_for_distributed(rank == 0)
+    return True, world_size, [device_id]
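+
+# Usage sketch (editor's addition, not part of the original source): typical use at the top of
+# a training script (e.g., launched with torchrun); it falls back to non-distributed mode when
+# the relevant environment variables are absent.
+distributed, world_size, device_ids = init_distributed_mode(world_size=1, dist_url='env://')
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+print(distributed, world_size, device_ids, device)
+
+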
+def check_if_plottable():
+"""
+ Checks if the DISPLAY environment variable is valid.
+
+ :return: True if DISPLAY variable is valid.
+ :rtype: bool
+ """
+    return os.environ.get('DISPLAY', '') != ''
+
+
+
+
+def get_classes(package_name, require_names=False):
+"""
+ Gets classes in a given package.
+
+ :param package_name: package name.
+ :type package_name: str
+ :param require_names: whether to preserve member names.
+ :type require_names: bool
+ :return: collection of classes defined in the given package.
+ :rtype: list[(str, class)] or list[class]
+ """
+    members = inspect.getmembers(sys.modules[package_name], inspect.isclass)
+    if require_names:
+        return members
+    return [obj for _, obj in members]
+
+
+
+
+def get_classes_as_dict(package_name, is_lower=False):
+"""
+ Gets classes in a given package as dict.
+
+ :param package_name: package name.
+ :type package_name: str
+ :param is_lower: if True, use lowercase class names as keys.
+ :type is_lower: bool
+ :return: dict of classes defined in the given package.
+ :rtype: dict
+ """
+    members = get_classes(package_name, require_names=True)
+    class_dict = dict()
+    for name, obj in members:
+        class_dict[name.lower() if is_lower else name] = obj
+    return class_dict
+
+
+
+
+def get_functions(package_name, require_names=False):
+"""
+ Gets functions in a given package.
+
+ :param package_name: package name.
+ :type package_name: str
+ :param require_names: whether to preserve function names.
+ :type require_names: bool
+ :return: collection of functions defined in the given package.
+ :rtype: list[(str, typing.Callable)] or list[typing.Callable]
+ """
+    members = inspect.getmembers(sys.modules[package_name], inspect.isfunction)
+    if require_names:
+        return members
+    return [obj for _, obj in members]
+
+
+
+
+def get_functions_as_dict(package_name, is_lower=False):
+"""
+ Gets functions in a given package as a dict.
+
+ :param package_name: package name.
+ :type package_name: str
+ :param is_lower: if True, use lowercase function names as keys.
+ :type is_lower: bool
+ :return: dict of functions defined in the given package.
+ :rtype: dict
+ """
+    members = get_functions(package_name, require_names=True)
+    func_dict = dict()
+    for name, obj in members:
+        func_dict[name.lower() if is_lower else name] = obj
+    return func_dict
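+
+# Usage sketch (editor's addition, not part of the original source): list the classes visible
+# in an already-imported module, keyed by lowercase class name.
+import collections
+
+print(sorted(get_classes_as_dict('collections', is_lower=True).keys()))
+
+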
+def check_if_wrapped(model):
+"""
+ Checks if a given model is wrapped by DataParallel or DistributedDataParallel.
+
+ :param model: model.
+ :type model: nn.Module
+ :return: True if `model` is wrapped by either DataParallel or DistributedDataParallel.
+ :rtype: bool
+ """
+    return isinstance(model, (DataParallel, DistributedDataParallel))
+
+
+
+
+def count_params(module):
+"""
+ Returns the number of module parameters.
+
+ :param module: module.
+ :type module: nn.Module
+ :return: number of model parameters.
+ :rtype: int
+ """
+    return sum(param.numel() for param in module.parameters())
+
+
+
+
+def freeze_module_params(module):
+"""
+ Freezes parameters by setting requires_grad=False for all the parameters.
+
+ :param module: module.
+ :type module: nn.Module
+ """
+    if isinstance(module, Module):
+        for param in module.parameters():
+            param.requires_grad = False
+    elif isinstance(module, Parameter):
+        module.requires_grad = False
+
+
+
+
+def unfreeze_module_params(module):
+"""
+ Unfreezes parameters by setting requires_grad=True for all the parameters.
+
+ :param module: module.
+ :type module: nn.Module
+ """
+    if isinstance(module, Module):
+        for param in module.parameters():
+            param.requires_grad = True
+    elif isinstance(module, Parameter):
+        module.requires_grad = True
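+
+# Usage sketch (editor's addition, not part of the original source): freeze a small block and
+# confirm it no longer exposes trainable parameters.
+from torch import nn
+
+_example_block = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))
+print(count_params(_example_block))  # 8*4+4 + 4*2+2 = 46 parameters in total
+freeze_module_params(_example_block)
+print(sum(p.numel() for p in _example_block.parameters() if p.requires_grad))  # 0
+
+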
+def decompose(ordered_dict):
+"""
+ Converts an ordered dict into a list of key-value pairs.
+
+ :param ordered_dict: ordered dict.
+ :type ordered_dict: collections.OrderedDict
+ :return: list of key-value pairs.
+ :rtype: list[(str, Any)]
+ """
+    component_list = list()
+    for key, value in ordered_dict.items():
+        if isinstance(value, OrderedDict):
+            component_list.append((key, decompose(value)))
+        elif isinstance(value, list):
+            component_list.append((key, value))
+        else:
+            component_list.append(key)
+    return component_list
+
+
+
+
+def get_components(module_paths):
+"""
+ Converts module paths into a list of pairs of parent module and child module names.
+
+ :param module_paths: module paths.
+ :type module_paths: list[str]
+ :return: list of pairs of parent module and child module names.
+ :rtype: list[(str, str)]
+ """
+    ordered_dict = get_hierarchized_dict(module_paths)
+    return decompose(ordered_dict)
+
+
+
+
+def extract_target_modules(parent_module, target_class, module_list):
+"""
+ Extracts modules that are instances of ``target_class`` and updates ``module_list`` with the extracted modules.
+
+ :param parent_module: parent module.
+ :type parent_module: nn.Module
+ :param target_class: target class.
+ :type target_class: class
+ :param module_list: (empty) list to be filled with modules that are instances of ``target_class``.
+ :type module_list: list[nn.Module]
+ """
+    if isinstance(parent_module, target_class):
+        module_list.append(parent_module)
+
+    child_modules = list(parent_module.children())
+    for child_module in child_modules:
+        extract_target_modules(child_module, target_class, module_list)
+
+
+
+
+def extract_all_child_modules(parent_module, module_list):
+"""
+ Extracts all the child modules and updates ``module_list`` with the extracted modules.
+
+ :param parent_module: parent module.
+ :type parent_module: nn.Module
+ :param module_list: (empty) list to be filled with child modules.
+ :type module_list: list[nn.Module]
+ """
+    child_modules = list(parent_module.children())
+    if not child_modules:
+        module_list.append(parent_module)
+        return
+
+    for child_module in child_modules:
+        extract_all_child_modules(child_module, module_list)
+from collections import namedtuple
+
+QuantizedTensor = namedtuple('QuantizedTensor', ['tensor', 'scale', 'zero_point'])
+
+
+# Referred to https://github.com/eladhoffer/utils.pytorch/blob/master/quantize.py
+# and http://openaccess.thecvf.com/content_cvpr_2018/papers/Jacob_Quantization_and_Training_CVPR_2018_paper.pdf
+
+def quantize_tensor(x, num_bits=8):
+"""
+ Quantizes a tensor into a ``num_bits``-bit integer representation with a float scale and zero point.
+
+ Benoit Jacob, Skirmantas Kligys, Bo Chen, Menglong Zhu, Matthew Tang, Andrew Howard, Hartwig Adam, Dmitry Kalenichenko: `"Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference" <https://openaccess.thecvf.com/content_cvpr_2018/html/Jacob_Quantization_and_Training_CVPR_2018_paper.html>`_ @ CVPR 2018 (2018)
+
+ :param x: tensor to be quantized.
+ :type x: torch.Tensor
+ :param num_bits: the number of bits for quantization.
+ :type num_bits: int
+ :return: quantized tensor.
+ :rtype: QuantizedTensor
+ """
+    qmin = 0.0
+    qmax = 2.0 ** num_bits - 1.0
+    min_val, max_val = x.min(), x.max()
+    scale = (max_val - min_val) / (qmax - qmin)
+    initial_zero_point = qmin - min_val / scale
+    zero_point = qmin if initial_zero_point < qmin else qmax if initial_zero_point > qmax else initial_zero_point
+    zero_point = int(zero_point)
+    qx = zero_point + x / scale
+    qx = qx.clamp(qmin, qmax).round().byte()
+    return QuantizedTensor(tensor=qx, scale=scale, zero_point=zero_point)
+
+
+
+
+def dequantize_tensor(q_x):
+"""
+ Dequantizes a quantized tensor.
+
+ Benoit Jacob, Skirmantas Kligys, Bo Chen, Menglong Zhu, Matthew Tang, Andrew Howard, Hartwig Adam, Dmitry Kalenichenko: `"Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference" <https://openaccess.thecvf.com/content_cvpr_2018/html/Jacob_Quantization_and_Training_CVPR_2018_paper.html>`_ @ CVPR 2018 (2018)
+
+ :param q_x: quantized tensor to be dequantized.
+ :type q_x: QuantizedTensor
+ :return: dequantized tensor.
+ :rtype: torch.Tensor
+ """
+    return q_x.scale * (q_x.tensor.float() - q_x.zero_point)
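+
+# Usage sketch (editor's addition, not part of the original source): an 8-bit quantization
+# round trip; the reconstruction error should be small relative to the tensor's value range.
+import torch
+
+example_x = torch.randn(4, 4)
+example_q = quantize_tensor(example_x, num_bits=8)
+example_recon = dequantize_tensor(example_q)
+print((example_x - example_recon).abs().max())
+
+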
+class DistillationBox(object):
+"""
+ A single-stage knowledge distillation framework.
+
+ :param teacher_model: teacher model.
+ :type teacher_model: nn.Module
+ :param student_model: student model.
+ :type student_model: nn.Module
+ :param dataset_dict: dict that contains datasets with IDs of your choice.
+ :type dataset_dict: dict
+ :param train_config: training configuration.
+ :type train_config: dict
+ :param device: target device.
+ :type device: torch.device
+ :param device_ids: target device IDs.
+ :type device_ids: list[int]
+ :param distributed: whether to be in distributed training mode.
+ :type distributed: bool
+ :param lr_factor: multiplier for learning rate.
+ :type lr_factor: float or int
+ :param accelerator: Hugging Face accelerator.
+ :type accelerator: accelerate.Accelerator or None
+ """
+
+    def setup_data_loaders(self, train_config):
+"""
+ Sets up training and validation data loaders for the current training stage.
+ This method will be internally called when instantiating this class and when calling
+ :meth:`MultiStagesDistillationBox.advance_to_next_stage`.
+
+ :param train_config: training configuration.
+ :type train_config: dict
+ """
+        train_data_loader_config = train_config.get('train_data_loader', dict())
+        if 'requires_supp' not in train_data_loader_config:
+            train_data_loader_config['requires_supp'] = True
+
+        val_data_loader_config = train_config.get('val_data_loader', dict())
+        train_data_loader, val_data_loader = \
+            build_data_loaders(self.dataset_dict, [train_data_loader_config, val_data_loader_config],
+                               self.distributed, self.accelerator)
+        if train_data_loader is not None:
+            self.train_data_loader = train_data_loader
+        if val_data_loader is not None:
+            self.val_data_loader = val_data_loader
+    def setup_loss(self, train_config):
+"""
+ Sets up a training loss module for the current training stage.
+ This method will be internally called when instantiating this class and when calling
+ :meth:`MultiStagesDistillationBox.advance_to_next_stage`.
+
+ :param train_config: training configuration.
+ :type train_config: dict
+ """
+        criterion_config = train_config['criterion']
+        self.criterion = get_high_level_loss(criterion_config)
+        logger.info(self.criterion)
+        self.extract_model_loss = get_func2extract_model_output(criterion_config.get('func2extract_model_loss', None))
+
+
+
+    def setup_pre_post_processes(self, train_config):
+"""
+ Sets up pre/post-epoch/forward processes for the current training stage.
+ This method will be internally called when instantiating this class and when calling
+ :meth:`MultiStagesDistillationBox.advance_to_next_stage`.
+
+ :param train_config: training configuration.
+ :type train_config: dict
+ """
+        pre_epoch_process = default_pre_epoch_process_with_teacher
+        if 'pre_epoch_process' in train_config:
+            pre_epoch_process = get_pre_epoch_proc_func(train_config['pre_epoch_process'])
+        setattr(DistillationBox, 'pre_epoch_process', pre_epoch_process)
+        pre_forward_process = default_pre_forward_process
+        if 'pre_forward_process' in train_config:
+            pre_forward_process = get_pre_forward_proc_func(train_config['pre_forward_process'])
+        setattr(DistillationBox, 'pre_forward_process', pre_forward_process)
+        post_forward_process = default_post_forward_process
+        if 'post_forward_process' in train_config:
+            post_forward_process = get_post_forward_proc_func(train_config['post_forward_process'])
+
+        setattr(DistillationBox, 'post_forward_process', post_forward_process)
+        post_epoch_process = default_post_epoch_process_with_teacher
+        if 'post_epoch_process' in train_config:
+            post_epoch_process = get_post_epoch_proc_func(train_config['post_epoch_process'])
+        setattr(DistillationBox, 'post_epoch_process', post_epoch_process)
+
+
+
+    def setup(self, train_config):
+"""
+ Configures a :class:`DistillationBox`/:class:`MultiStagesDistillationBox` for the current training stage.
+ This method will be internally called when instantiating this class and when calling
+ :meth:`MultiStagesDistillationBox.advance_to_next_stage`.
+
+ :param train_config: training configuration.
+ :type train_config: dict
+ """
+        # Set up train and val data loaders
+        self.setup_data_loaders(train_config)
+
+        # Define teacher and student models used in this stage
+        teacher_config = train_config.get('teacher', dict())
+        student_config = train_config.get('student', dict())
+        self.setup_teacher_student_models(teacher_config, student_config)
+
+        # Define loss function used in this stage
+        self.setup_loss(train_config)
+
+        # Freeze parameters if specified
+        self.teacher_updatable = True
+        if not teacher_config.get('requires_grad', True):
+            logger.info('Freezing the whole teacher model')
+            freeze_module_params(self.teacher_model)
+            self.teacher_updatable = False
+
+        if not student_config.get('requires_grad', True):
+            logger.info('Freezing the whole student model')
+            freeze_module_params(self.student_model)
+
+        # Wrap models if necessary
+        teacher_any_updatable = len(get_updatable_param_names(self.teacher_model)) > 0
+        self.teacher_model = \
+            wrap_model(self.teacher_model, teacher_config, self.device, self.device_ids, self.distributed,
+                       self.teacher_any_frozen, teacher_any_updatable)
+        student_any_updatable = len(get_updatable_param_names(self.student_model)) > 0
+        self.student_model = \
+            wrap_model(self.student_model, student_config, self.device, self.device_ids, self.distributed,
+                       self.student_any_frozen, student_any_updatable)
+
+        # Set up optimizer and scheduler
+        optim_config = train_config.get('optimizer', dict())
+        optimizer_reset = False
+        if len(optim_config) > 0:
+            optim_kwargs = optim_config['kwargs']
+            if 'lr' in optim_kwargs:
+                optim_kwargs['lr'] *= self.lr_factor
+
+            module_wise_configs = optim_config.get('module_wise_configs', list())
+            if len(module_wise_configs) > 0:
+                trainable_module_list = list()
+                for module_wise_config in module_wise_configs:
+                    module_wise_kwargs = dict()
+                    if isinstance(module_wise_config.get('kwargs', None), dict):
+                        module_wise_kwargs.update(module_wise_config['kwargs'])
+
+                    if 'lr' in module_wise_kwargs:
+                        module_wise_kwargs['lr'] *= self.lr_factor
+
+                    target_model = \
+                        self.teacher_model if module_wise_config.get('is_teacher', False) else self.student_model
+                    module = get_module(target_model, module_wise_config['module'])
+                    module_wise_kwargs['params'] = module.parameters() if isinstance(module, nn.Module) else [module]
+                    trainable_module_list.append(module_wise_kwargs)
+            else:
+                trainable_module_list = nn.ModuleList([self.student_model])
+                if self.teacher_updatable:
+                    logger.info('Note that you are training some/all of the modules in the teacher model')
+                    trainable_module_list.append(self.teacher_model)
+
+            filters_params = optim_config.get('filters_params', True)
+            self.optimizer = \
+                get_optimizer(trainable_module_list, optim_config['key'],
+                              **optim_kwargs, filters_params=filters_params)
+
+            self.optimizer.zero_grad()
+            self.max_grad_norm = optim_config.get('max_grad_norm', None)
+            self.grad_accum_step = optim_config.get('grad_accum_step', 1)
+            optimizer_reset = True
+
+        scheduler_config = train_config.get('scheduler', None)
+        if scheduler_config is not None and len(scheduler_config) > 0:
+            self.lr_scheduler = get_scheduler(self.optimizer, scheduler_config['key'], **scheduler_config['kwargs'])
+            self.scheduling_step = scheduler_config.get('scheduling_step', 0)
+        elif optimizer_reset:
+            self.lr_scheduler = None
+            self.scheduling_step = None
+
+        # Set up accelerator if necessary
+        if self.accelerator is not None:
+            if self.teacher_updatable:
+                self.teacher_model, self.student_model, self.optimizer, self.train_data_loader, self.val_data_loader = \
+                    self.accelerator.prepare(self.teacher_model, self.student_model, self.optimizer,
+                                             self.train_data_loader, self.val_data_loader)
+            else:
+                self.teacher_model = self.teacher_model.to(self.accelerator.device)
+                if self.accelerator.state.use_fp16:
+                    self.teacher_model = self.teacher_model.half()
+
+                self.student_model, self.optimizer, self.train_data_loader, self.val_data_loader = \
+                    self.accelerator.prepare(self.student_model, self.optimizer,
+                                             self.train_data_loader, self.val_data_loader)
+
+        # Set up {pre,post}-{epoch,forward} processes
+        self.setup_pre_post_processes(train_config)
+    def pre_epoch_process(self, *args, **kwargs):
+        """
+        Performs a pre-epoch process.
+
+        This should be overridden by all subclasses or defined through :meth:`setup_pre_post_processes`.
+        """
+        raise NotImplementedError()
+
+
+
+    def pre_forward_process(self, *args, **kwargs):
+        """
+        Performs a pre-forward process.
+
+        This should be overridden by all subclasses or defined through :meth:`setup_pre_post_processes`.
+        """
+        raise NotImplementedError()
+    def post_forward_process(self, *args, **kwargs):
+        """
+        Performs a post-forward process.
+
+        This should be overridden by all subclasses or defined through :meth:`setup_pre_post_processes`.
+        """
+        raise NotImplementedError()
+
+
+
+    def post_epoch_process(self, *args, **kwargs):
+        """
+        Performs a post-epoch process.
+
+        This should be overridden by all subclasses or defined through :meth:`setup_pre_post_processes`.
+        """
+        raise NotImplementedError()
+
+
+
+    def clean_modules(self):
+"""
+ Unfreezes all the teacher and student modules, clears I/O dicts, unregisters forward hook handles,
+ and clears the handle lists.
+ """
+        unfreeze_module_params(self.org_teacher_model)
+        unfreeze_module_params(self.org_student_model)
+        self.teacher_io_dict.clear()
+        self.student_io_dict.clear()
+        for _, module_handle in self.target_teacher_pairs + self.target_student_pairs:
+            module_handle.remove()
+
+        self.target_teacher_pairs.clear()
+        self.target_student_pairs.clear()
+
+
+
+
+
+class MultiStagesDistillationBox(DistillationBox):
+"""
+ A multi-stage knowledge distillation framework. This is a subclass of :class:`DistillationBox`.
+
+ :param teacher_model: teacher model.
+ :type teacher_model: nn.Module
+ :param student_model: student model.
+ :type student_model: nn.Module
+ :param dataset_dict: dict that contains datasets with IDs of your choice.
+ :type dataset_dict: dict
+ :param train_config: training configuration.
+ :type train_config: dict
+ :param device: target device.
+ :type device: torch.device
+ :param device_ids: target device IDs.
+ :type device_ids: list[int]
+ :param distributed: whether to be in distributed training mode.
+ :type distributed: bool
+ :param lr_factor: multiplier for learning rate.
+ :type lr_factor: float or int
+ :param accelerator: Hugging Face accelerator.
+ :type accelerator: accelerate.Accelerator or None
+ """
+    def __init__(self, teacher_model, student_model, dataset_dict,
+                 train_config, device, device_ids, distributed, lr_factor, accelerator=None):
+        stage1_config = train_config['stage1']
+        super().__init__(teacher_model, student_model, dataset_dict,
+                         stage1_config, device, device_ids, distributed, lr_factor, accelerator)
+        self.train_config = train_config
+        self.stage_number = 1
+        self.stage_end_epoch = stage1_config['num_epochs']
+        self.num_epochs = sum(train_config[key]['num_epochs'] for key in train_config.keys() if key.startswith('stage'))
+        self.current_epoch = 0
+        logger.info('Started stage {}'.format(self.stage_number))
+
+
+    def save_stage_ckpt(self, model, local_model_config):
+"""
+ Saves the checkpoint of ``model`` for the current training stage.
+
+ :param model: model to be saved.
+ :type model: nn.Module
+ :param local_model_config: model configuration at the current training stage.
+ :type local_model_config: dict
+ """
+        dst_ckpt_file_path = local_model_config.get('dst_ckpt', None)
+        if dst_ckpt_file_path is not None:
+            model_state_dict = model.module.state_dict() if check_if_wrapped(model) else model.state_dict()
+            make_parent_dirs(dst_ckpt_file_path)
+            save_on_master(model_state_dict, dst_ckpt_file_path)
+
+
+
+    def advance_to_next_stage(self):
+"""
+ Reads the next training stage's configuration in ``train_config`` and advances to the next training stage.
+ """
+        self.save_stage_ckpt(self.teacher_model, self.train_config.get('teacher', dict()))
+        self.save_stage_ckpt(self.student_model, self.train_config.get('student', dict()))
+        self.clean_modules()
+        self.stage_grad_count = 0
+        self.stage_number += 1
+        next_stage_config = self.train_config['stage{}'.format(self.stage_number)]
+        self.setup(next_stage_config)
+        self.stage_end_epoch += next_stage_config['num_epochs']
+        logger.info('Advanced to stage {}'.format(self.stage_number))
+
+
+
+    def post_epoch_process(self, *args, **kwargs):
+"""
+ Performs a post-epoch process.
+
+ The superclass's post_epoch_process should be overridden by all subclasses or
+ defined through :meth:`DistillationBox.setup_pre_post_processes`.
+ """
+        super().post_epoch_process(*args, **kwargs)
+        self.current_epoch += 1
+        if self.current_epoch == self.stage_end_epoch and self.current_epoch < self.num_epochs:
+            self.advance_to_next_stage()
+
+
+
+
+
+def get_distillation_box(teacher_model, student_model, dataset_dict,
+                         train_config, device, device_ids, distributed, lr_factor, accelerator=None):
+"""
+ Gets a distillation box.
+
+ :param teacher_model: teacher model.
+ :type teacher_model: nn.Module
+ :param student_model: student model.
+ :type student_model: nn.Module
+ :param dataset_dict: dict that contains datasets with IDs of your choice.
+ :type dataset_dict: dict
+ :param train_config: training configuration.
+ :type train_config: dict
+ :param device: target device.
+ :type device: torch.device
+ :param device_ids: target device IDs.
+ :type device_ids: list[int]
+ :param distributed: whether to be in distributed training mode.
+ :type distributed: bool
+ :param lr_factor: multiplier for learning rate.
+ :type lr_factor: float or int
+ :param accelerator: Hugging Face accelerator.
+ :type accelerator: accelerate.Accelerator or None
+ :return: distillation box.
+ :rtype: DistillationBox or MultiStagesDistillationBox
+ """
+    if 'stage1' in train_config:
+        return MultiStagesDistillationBox(teacher_model, student_model, dataset_dict,
+                                          train_config, device, device_ids, distributed, lr_factor, accelerator)
+    return DistillationBox(teacher_model, student_model, dataset_dict, train_config,
+                           device, device_ids, distributed, lr_factor, accelerator)
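+
+# Editor's sketch (not part of the original source): the dispatch above only checks whether
+# `train_config` contains a 'stage1' key. In a multi-stage run, each 'stageN' entry is a
+# complete stage configuration (consumed later by `setup`) with its own 'num_epochs'.
+_stage_config = {'num_epochs': 20}  # placeholder for a full per-stage training config
+_multi_stage_train_config = {'stage1': dict(_stage_config), 'stage2': dict(_stage_config)}
+print('stage1' in _multi_stage_train_config)  # True -> MultiStagesDistillationBox is used
+
+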
+def register_pre_epoch_proc_func(arg=None, **kwargs):
+"""
+ Registers a pre-epoch process function for :class:`torchdistill.core.distillation.DistillationBox` and
+ :class:`torchdistill.core.training.TrainingBox`.
+
+ :param arg: function to be registered as a pre-epoch process function.
+ :type arg: typing.Callable or None
+ :return: registered pre-epoch process function.
+ :rtype: typing.Callable
+
+ .. note::
+ The function will be registered as an option of the pre-epoch process function.
+ You can choose the registered function by specifying the name of the function or ``key``
+ you used for the registration, in a training configuration used for
+ :class:`torchdistill.core.distillation.DistillationBox` or :class:`torchdistill.core.training.TrainingBox`.
+
+ If you want to register the function with a key of your choice, add ``key`` to the decorator as below:
+
+ >>> from torchdistill.core.interfaces.registry import register_pre_epoch_proc_func
+ >>> @register_pre_epoch_proc_func(key='my_custom_pre_epoch_proc_func')
+ >>> def new_pre_epoch_proc(self, epoch=None, **kwargs):
+ >>> print('This is my custom pre-epoch process function')
+
+ In the example, ``new_pre_epoch_proc`` function is registered with a key "my_custom_pre_epoch_proc_func".
+ When you configure :class:`torchdistill.core.distillation.DistillationBox` or
+ :class:`torchdistill.core.training.TrainingBox`, you can choose the ``new_pre_epoch_proc`` function by
+ "my_custom_pre_epoch_proc_func".
+ """
+    def _register_pre_epoch_proc_func(func):
+        key = kwargs.get('key')
+        if key is None:
+            key = func.__name__
+
+        PRE_EPOCH_PROC_FUNC_DICT[key] = func
+        return func
+
+    if callable(arg):
+        return _register_pre_epoch_proc_func(arg)
+    return _register_pre_epoch_proc_func
+
+
+
+
+def register_pre_forward_proc_func(arg=None, **kwargs):
+"""
+ Registers a pre-forward process function for :class:`torchdistill.core.distillation.DistillationBox` and
+ :class:`torchdistill.core.training.TrainingBox`.
+
+ :param arg: function to be registered as a pre-forward process function.
+ :type arg: typing.Callable or None
+ :return: registered pre-forward process function.
+ :rtype: typing.Callable
+
+ .. note::
+ The function will be registered as an option of the pre-forward process function.
+ You can choose the registered function by specifying the name of the function or ``key``
+ you used for the registration, in a training configuration used for
+ :class:`torchdistill.core.distillation.DistillationBox` or :class:`torchdistill.core.training.TrainingBox`.
+
+ If you want to register the function with a key of your choice, add ``key`` to the decorator as below:
+
+ >>> from torchdistill.core.interfaces.registry import register_pre_forward_proc_func
+ >>> @register_pre_forward_proc_func(key='my_custom_pre_forward_proc_func')
+ >>> def new_pre_forward_proc(self, *args, **kwargs):
+ >>> print('This is my custom pre-forward process function')
+
+ In the example, ``new_pre_forward_proc`` function is registered with a key "my_custom_pre_forward_proc_func".
+ When you configure :class:`torchdistill.core.distillation.DistillationBox` or
+ :class:`torchdistill.core.training.TrainingBox`, you can choose the ``new_pre_forward_proc`` function by
+ "my_custom_pre_forward_proc_func".
+ """
+    def _register_pre_forward_proc_func(func):
+        key = kwargs.get('key')
+        if key is None:
+            key = func.__name__
+
+        PRE_FORWARD_PROC_FUNC_DICT[key] = func
+        return func
+
+    if callable(arg):
+        return _register_pre_forward_proc_func(arg)
+    return _register_pre_forward_proc_func
+
+
+
+
+def register_forward_proc_func(arg=None, **kwargs):
+"""
+ Registers a forward process function for :class:`torchdistill.core.distillation.DistillationBox` and
+ :class:`torchdistill.core.training.TrainingBox`.
+
+ :param arg: function to be registered as a forward process function.
+ :type arg: typing.Callable or None
+ :return: registered forward process function.
+ :rtype: typing.Callable
+
+ .. note::
+ The function will be registered as an option of the forward process function.
+ You can choose the registered function by specifying the name of the function or ``key``
+ you used for the registration, in a training configuration used for
+ :class:`torchdistill.core.distillation.DistillationBox` or :class:`torchdistill.core.training.TrainingBox`.
+
+ If you want to register the function with a key of your choice, add ``key`` to the decorator as below:
+
+ >>> from torchdistill.core.interfaces.registry import register_forward_proc_func
+ >>> @register_forward_proc_func(key='my_custom_forward_proc_func')
+ >>> def new_forward_proc(model, sample_batch, targets=None, supp_dict=None, **kwargs):
+ >>> print('This is my custom forward process function')
+
+ In the example, ``new_forward_proc`` function is registered with a key "my_custom_forward_proc_func".
+ When you configure :class:`torchdistill.core.distillation.DistillationBox` or
+ :class:`torchdistill.core.training.TrainingBox`, you can choose the ``new_forward_proc`` function by
+ "my_custom_forward_proc_func".
+ """
+    def _register_forward_proc_func(func):
+        key = kwargs.get('key')
+        if key is None:
+            key = func.__name__
+
+        FORWARD_PROC_FUNC_DICT[key] = func
+        return func
+
+    if callable(arg):
+        return _register_forward_proc_func(arg)
+    return _register_forward_proc_func
+
+
+
+
+def register_post_forward_proc_func(arg=None, **kwargs):
+"""
+ Registers a post-forward process function for :class:`torchdistill.core.distillation.DistillationBox` and
+ :class:`torchdistill.core.training.TrainingBox`.
+
+ :param arg: function to be registered as a post-forward process function.
+ :type arg: typing.Callable or None
+ :return: registered post-forward process function.
+ :rtype: typing.Callable
+
+ .. note::
+ The function will be registered as an option of the post-forward process function.
+ You can choose the registered function by specifying the name of the function or ``key``
+ you used for the registration, in a training configuration used for
+ :class:`torchdistill.core.distillation.DistillationBox` or :class:`torchdistill.core.training.TrainingBox`.
+
+ If you want to register the function with a key of your choice, add ``key`` to the decorator as below:
+
+ >>> from torchdistill.core.interfaces.registry import register_post_forward_proc_func
+ >>> @register_post_forward_proc_func(key='my_custom_post_forward_proc_func')
+ >>> def new_post_forward_proc(self, loss, metrics=None, **kwargs):
+ >>> print('This is my custom post-forward process function')
+
+ In the example, ``new_post_forward_proc`` function is registered with a key "my_custom_post_forward_proc_func".
+ When you configure :class:`torchdistill.core.distillation.DistillationBox` or
+ :class:`torchdistill.core.training.TrainingBox`, you can choose the ``new_post_forward_proc`` function by
+ "my_custom_post_forward_proc_func".
+ """
+    def _register_post_forward_proc_func(func):
+        key = kwargs.get('key')
+        if key is None:
+            key = func.__name__
+
+        POST_FORWARD_PROC_FUNC_DICT[key] = func
+        return func
+
+    if callable(arg):
+        return _register_post_forward_proc_func(arg)
+    return _register_post_forward_proc_func
+
+
+
+
+def register_post_epoch_proc_func(arg=None, **kwargs):
+"""
+ Registers a post-epoch process function for :class:`torchdistill.core.distillation.DistillationBox` and
+ :class:`torchdistill.core.training.TrainingBox`.
+
+ :param arg: function to be registered as a post-epoch process function.
+ :type arg: typing.Callable or None
+ :return: registered post-epoch process function.
+ :rtype: typing.Callable
+
+ .. note::
+ The function will be registered as an option of the post-epoch process function.
+ You can choose the registered function by specifying the name of the function or ``key``
+ you used for the registration, in a training configuration used for
+ :class:`torchdistill.core.distillation.DistillationBox` or :class:`torchdistill.core.training.TrainingBox`.
+
+ If you want to register the function with a key of your choice, add ``key`` to the decorator as below:
+
+ >>> from torchdistill.core.interfaces.registry import register_post_epoch_proc_func
+ >>> @register_post_epoch_proc_func(key='my_custom_post_epoch_proc_func')
+ >>> def new_post_epoch_proc(self, metrics=None, **kwargs):
+ >>> print('This is my custom post-epoch process function')
+
+ In the example, ``new_post_epoch_proc`` function is registered with a key "my_custom_post_epoch_proc_func".
+ When you configure :class:`torchdistill.core.distillation.DistillationBox` or
+ :class:`torchdistill.core.training.TrainingBox`, you can choose the ``new_post_epoch_proc`` function by
+ "my_custom_post_epoch_proc_func".
+ """
+    def _register_post_epoch_proc_func(func):
+        key = kwargs.get('key')
+        if key is None:
+            key = func.__name__
+
+        POST_EPOCH_PROC_FUNC_DICT[key] = func
+        return func
+
+    if callable(arg):
+        return _register_post_epoch_proc_func(arg)
+    return _register_post_epoch_proc_func
+
+
+
+
+def get_pre_epoch_proc_func(key):
+"""
+ Gets a registered pre-epoch process function.
+
+ :param key: unique key to identify the registered pre-epoch process function.
+ :type key: str
+ :return: registered pre-epoch process function.
+ :rtype: typing.Callable
+ """
+    if key in PRE_EPOCH_PROC_FUNC_DICT:
+        return PRE_EPOCH_PROC_FUNC_DICT[key]
+    raise ValueError('No pre-epoch process function `{}` registered'.format(key))
+
+
+
+
+def get_pre_forward_proc_func(key):
+"""
+ Gets a registered pre-forward process function.
+
+ :param key: unique key to identify the registered pre-forward process function.
+ :type key: str
+ :return: registered pre-forward process function.
+ :rtype: typing.Callable
+ """
+    if key in PRE_FORWARD_PROC_FUNC_DICT:
+        return PRE_FORWARD_PROC_FUNC_DICT[key]
+    raise ValueError('No pre-forward process function `{}` registered'.format(key))
+
+
+
+
+def get_forward_proc_func(key):
+"""
+ Gets a registered forward process function.
+
+ :param key: unique key to identify the registered forward process function.
+ :type key: str
+ :return: registered forward process function.
+ :rtype: typing.Callable
+ """
+    if key in FORWARD_PROC_FUNC_DICT:
+        return FORWARD_PROC_FUNC_DICT[key]
+    raise ValueError('No forward process function `{}` registered'.format(key))
+
+
+
+
+def get_post_forward_proc_func(key):
+"""
+ Gets a registered post-forward process function.
+
+ :param key: unique key to identify the registered post-forward process function.
+ :type key: str
+ :return: registered post-forward process function.
+ :rtype: typing.Callable
+ """
+    if key in POST_FORWARD_PROC_FUNC_DICT:
+        return POST_FORWARD_PROC_FUNC_DICT[key]
+    raise ValueError('No post-forward process function `{}` registered'.format(key))
+
+
+
+
+def get_post_epoch_proc_func(key):
+"""
+ Gets a registered post-epoch process function.
+
+ :param key: unique key to identify the registered post-epoch process function.
+ :type key: str
+ :return: registered post-epoch process function.
+ :rtype: typing.Callable
+ """
+    if key in POST_EPOCH_PROC_FUNC_DICT:
+        return POST_EPOCH_PROC_FUNC_DICT[key]
+    raise ValueError('No post-epoch process function `{}` registered'.format(key))
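+
+# Usage sketch (editor's addition, not part of the original source): register a custom
+# post-epoch process under an explicit key and retrieve it again through the registry.
+@register_post_epoch_proc_func(key='print_epoch_summary')
+def _print_epoch_summary(self, **kwargs):
+    print('end of epoch')
+
+assert get_post_epoch_proc_func('print_epoch_summary') is _print_epoch_summary
+
+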
+class TrainingBox(object):
+"""
+ A single-stage training framework.
+
+ :param model: model.
+ :type model: nn.Module
+ :param dataset_dict: dict that contains datasets with IDs of your choice.
+ :type dataset_dict: dict
+ :param train_config: training configuration.
+ :type train_config: dict
+ :param device: target device.
+ :type device: torch.device
+ :param device_ids: target device IDs.
+ :type device_ids: list[int]
+ :param distributed: whether to be in distributed training mode.
+ :type distributed: bool
+ :param lr_factor: multiplier for learning rate.
+ :type lr_factor: float or int
+ :param accelerator: Hugging Face accelerator.
+ :type accelerator: accelerate.Accelerator or None
+ """
+
+    def setup_data_loaders(self, train_config):
+"""
+ Sets up training and validation data loaders for the current training stage.
+ This method will be internally called when instantiating this class and when calling
+ :meth:`MultiStagesTrainingBox.advance_to_next_stage`.
+
+ :param train_config: training configuration.
+ :type train_config: dict
+ """
+        train_data_loader_config = train_config.get('train_data_loader', dict())
+        if 'requires_supp' not in train_data_loader_config:
+            train_data_loader_config['requires_supp'] = True
+
+        val_data_loader_config = train_config.get('val_data_loader', dict())
+        train_data_loader, val_data_loader = \
+            build_data_loaders(self.dataset_dict, [train_data_loader_config, val_data_loader_config],
+                               self.distributed, self.accelerator)
+        if train_data_loader is not None:
+            self.train_data_loader = train_data_loader
+        if val_data_loader is not None:
+            self.val_data_loader = val_data_loader
+
+
+
+    def setup_model(self, model_config):
+"""
+ Sets up a model for the current training stage.
+ This method will be internally called when instantiating this class and when calling
+ :meth:`MultiStagesTrainingBox.advance_to_next_stage`.
+
+ :param model_config: model configuration.
+ :type model_config: dict
+ """
+        unwrapped_org_model = \
+            self.org_model.module if check_if_wrapped(self.org_model) else self.org_model
+        self.target_model_pairs.clear()
+        ref_model = unwrapped_org_model
+        if len(model_config) > 0 or (len(model_config) == 0 and self.model is None):
+            logger.info('[student model]')
+            model_type = 'original'
+            auxiliary_model_wrapper = \
+                build_auxiliary_model_wrapper(model_config, student_model=unwrapped_org_model, device=self.device,
+                                              device_ids=self.device_ids, distributed=self.distributed)
+            if auxiliary_model_wrapper is not None:
+                ref_model = auxiliary_model_wrapper
+                model_type = type(ref_model).__name__
+
+            self.model = redesign_model(ref_model, model_config, 'student', model_type)
+            src_ckpt_file_path = model_config.get('src_ckpt', None)
+            if src_ckpt_file_path is not None:
+                load_ckpt(src_ckpt_file_path, self.model)
+
+        self.model_any_frozen = \
+            len(model_config.get('frozen_modules', list())) > 0 or not model_config.get('requires_grad', True)
+        self.target_model_pairs.extend(set_hooks(self.model, ref_model, model_config, self.model_io_dict))
+        self.model_forward_proc = get_forward_proc_func(model_config.get('forward_proc', None))
+
+
+
+    def setup_loss(self, train_config):
+"""
+ Sets up a training loss module for the current training stage.
+ This method will be internally called when instantiating this class and when calling
+ :meth:`MultiStagesTrainingBox.advance_to_next_stage`.
+
+ :param train_config: training configuration.
+ :type train_config: dict
+ """
+        criterion_config = train_config['criterion']
+        self.criterion = get_high_level_loss(criterion_config)
+        logger.info(self.criterion)
+        self.extract_model_loss = get_func2extract_model_output(criterion_config.get('func2extract_model_loss', None))
+
+
+
+    def setup_pre_post_processes(self, train_config):
+"""
+ Sets up pre/post-epoch/forward processes for the current training stage.
+ This method will be internally called when instantiating this class and when calling
+ :meth:`MultiStagesTrainingBox.advance_to_next_stage`.
+
+ :param train_config: training configuration.
+ :type train_config: dict
+ """
+        pre_epoch_process = default_pre_epoch_process_without_teacher
+        if 'pre_epoch_process' in train_config:
+            pre_epoch_process = get_pre_epoch_proc_func(train_config['pre_epoch_process'])
+        setattr(TrainingBox, 'pre_epoch_process', pre_epoch_process)
+        pre_forward_process = default_pre_forward_process
+        if 'pre_forward_process' in train_config:
+            pre_forward_process = get_pre_forward_proc_func(train_config['pre_forward_process'])
+        setattr(TrainingBox, 'pre_forward_process', pre_forward_process)
+        post_forward_process = default_post_forward_process
+        if 'post_forward_process' in train_config:
+            post_forward_process = get_post_forward_proc_func(train_config['post_forward_process'])
+
+        setattr(TrainingBox, 'post_forward_process', post_forward_process)
+        post_epoch_process = default_post_epoch_process_without_teacher
+        if 'post_epoch_process' in train_config:
+            post_epoch_process = get_post_epoch_proc_func(train_config['post_epoch_process'])
+        setattr(TrainingBox, 'post_epoch_process', post_epoch_process)
+
+
+
+    def setup(self, train_config):
+"""
+ Configures a :class:`TrainingBox`/:class:`MultiStagesTrainingBox` for the current training stage.
+ This method will be internally called when instantiating this class and when calling
+ :meth:`MultiStagesTrainingBox.advance_to_next_stage`.
+
+ :param train_config: training configuration.
+ :type train_config: dict
+ """
+        # Set up train and val data loaders
+        self.setup_data_loaders(train_config)
+
+        # Define model used in this stage
+        model_config = train_config.get('model', dict())
+        self.setup_model(model_config)
+
+        # Define loss function used in this stage
+        self.setup_loss(train_config)
+
+        # Freeze parameters if specified
+        if not model_config.get('requires_grad', True):
+            logger.info('Freezing the whole model')
+            freeze_module_params(self.model)
+
+        # Wrap models if necessary
+        any_updatable = len(get_updatable_param_names(self.model)) > 0
+        self.model = \
+            wrap_model(self.model, model_config, self.device, self.device_ids, self.distributed,
+                       self.model_any_frozen, any_updatable)
+
+        # Set up optimizer and scheduler
+        optim_config = train_config.get('optimizer', dict())
+        optimizer_reset = False
+        if len(optim_config) > 0:
+            optim_kwargs = optim_config['kwargs']
+            if 'lr' in optim_kwargs:
+                optim_kwargs['lr'] *= self.lr_factor
+
+            module_wise_configs = optim_config.get('module_wise_kwargs', list())
+            if len(module_wise_configs) > 0:
+                trainable_module_list = list()
+                for module_wise_config in module_wise_configs:
+                    module_wise_kwargs = dict()
+                    if isinstance(module_wise_config.get('kwargs', None), dict):
+                        module_wise_kwargs.update(module_wise_config['kwargs'])
+
+                    if 'lr' in module_wise_kwargs:
+                        module_wise_kwargs['lr'] *= self.lr_factor
+
+                    module = get_module(self.model, module_wise_config['module'])
+                    module_wise_kwargs['params'] = module.parameters() if isinstance(module, nn.Module) else [module]
+                    trainable_module_list.append(module_wise_kwargs)
+            else:
+                trainable_module_list = nn.ModuleList([self.model])
+
+            filters_params = optim_config.get('filters_params', True)
+            self.optimizer = \
+                get_optimizer(trainable_module_list, optim_config['key'],
+                              **optim_kwargs, filters_params=filters_params)
+            self.optimizer.zero_grad()
+            self.max_grad_norm = optim_config.get('max_grad_norm', None)
+            self.grad_accum_step = optim_config.get('grad_accum_step', 1)
+            optimizer_reset = True
+
+        scheduler_config = train_config.get('scheduler', None)
+        if scheduler_config is not None and len(scheduler_config) > 0:
+            self.lr_scheduler = get_scheduler(self.optimizer, scheduler_config['key'], **scheduler_config['kwargs'])
+            self.scheduling_step = scheduler_config.get('scheduling_step', 0)
+        elif optimizer_reset:
+            self.lr_scheduler = None
+            self.scheduling_step = None
+
+        # Set up accelerator if necessary
+        if self.accelerator is not None:
+            self.model, self.optimizer, self.train_data_loader, self.val_data_loader = \
+                self.accelerator.prepare(self.model, self.optimizer, self.train_data_loader, self.val_data_loader)
+
+        # Set up {pre,post}-{epoch,forward} processes
+        self.setup_pre_post_processes(train_config)
+    def pre_epoch_process(self, *args, **kwargs):
+        """
+        Performs a pre-epoch process.
+
+        This should be overridden by all subclasses or defined through :meth:`setup_pre_post_processes`.
+        """
+        raise NotImplementedError()
+
+
+
+    def pre_forward_process(self, *args, **kwargs):
+        """
+        Performs a pre-forward process.
+
+        This should be overridden by all subclasses or defined through :meth:`setup_pre_post_processes`.
+        """
+        raise NotImplementedError()
+    def post_forward_process(self, *args, **kwargs):
+        """
+        Performs a post-forward process.
+
+        This should be overridden by all subclasses or defined through :meth:`setup_pre_post_processes`.
+        """
+        raise NotImplementedError()
+
+
+
+    def post_epoch_process(self, *args, **kwargs):
+        """
+        Performs a post-epoch process.
+
+        This should be overridden by all subclasses or defined through :meth:`setup_pre_post_processes`.
+        """
+        raise NotImplementedError()
+
+
+
+    def clean_modules(self):
+"""
+ Unfreezes all the modules, clears an I/O dict, unregisters forward hook handles,
+ and clears the handle lists.
+ """
+        unfreeze_module_params(self.org_model)
+        self.model_io_dict.clear()
+        for _, module_handle in self.target_model_pairs:
+            module_handle.remove()
+        self.target_model_pairs.clear()
+
+
+
+
+
+class MultiStagesTrainingBox(TrainingBox):
+"""
+ A multi-stage training framework. This is a subclass of :class:`TrainingBox`.
+
+ :param model: model.
+ :type model: nn.Module
+ :param dataset_dict: dict that contains datasets with IDs of your choice.
+ :type dataset_dict: dict
+ :param train_config: training configuration.
+ :type train_config: dict
+ :param device: target device.
+ :type device: torch.device
+ :param device_ids: target device IDs.
+ :type device_ids: list[int]
+ :param distributed: whether to be in distributed training mode.
+ :type distributed: bool
+ :param lr_factor: multiplier for learning rate.
+ :type lr_factor: float or int
+ :param accelerator: Hugging Face accelerator.
+ :type accelerator: accelerate.Accelerator or None
+ """
+    def __init__(self, model, dataset_dict, train_config,
+                 device, device_ids, distributed, lr_factor, accelerator=None):
+        stage1_config = train_config['stage1']
+        super().__init__(model, dataset_dict,
+                         stage1_config, device, device_ids, distributed, lr_factor, accelerator)
+        self.train_config = train_config
+        self.stage_number = 1
+        self.stage_end_epoch = stage1_config['num_epochs']
+        self.num_epochs = sum(train_config[key]['num_epochs'] for key in train_config.keys() if key.startswith('stage'))
+        self.current_epoch = 0
+        logger.info('Started stage {}'.format(self.stage_number))
+
+
+    def save_stage_ckpt(self, model, local_model_config):
+"""
+ Saves the checkpoint of ``model`` for the current training stage.
+
+ :param model: model to be saved.
+ :type model: nn.Module
+ :param local_model_config: model configuration at the current training stage.
+ :type local_model_config: dict
+ """
+        dst_ckpt_file_path = local_model_config.get('dst_ckpt', None)
+        if dst_ckpt_file_path is not None:
+            model_state_dict = model.module.state_dict() if check_if_wrapped(model) else model.state_dict()
+            make_parent_dirs(dst_ckpt_file_path)
+            save_on_master(model_state_dict, dst_ckpt_file_path)
+
+
+
+    def advance_to_next_stage(self):
+"""
+ Reads the next training stage's configuration in ``train_config`` and advances to the next training stage.
+ """
+        self.save_stage_ckpt(self.model, self.train_config.get('model', dict()))
+        self.clean_modules()
+        self.stage_grad_count = 0
+        self.stage_number += 1
+        next_stage_config = self.train_config['stage{}'.format(self.stage_number)]
+        self.setup(next_stage_config)
+        self.stage_end_epoch += next_stage_config['num_epochs']
+        logger.info('Advanced to stage {}'.format(self.stage_number))
+
+
+
+    def post_epoch_process(self, *args, **kwargs):
+"""
+ Performs a post-epoch process.
+
+ The superclass's post_epoch_process should be overridden by all subclasses or
+ defined through :meth:`TrainingBox.setup_pre_post_processes`.
+ """
+        super().post_epoch_process(*args, **kwargs)
+        self.current_epoch += 1
+        if self.current_epoch == self.stage_end_epoch and self.current_epoch < self.num_epochs:
+            self.advance_to_next_stage()
+
+
+
+
+
+def get_training_box(model, dataset_dict, train_config, device, device_ids, distributed,
+                     lr_factor, accelerator=None):
+"""
+ Gets a training box.
+
+ :param model: model.
+ :type model: nn.Module
+ :param dataset_dict: dict that contains datasets with IDs of your choice.
+ :type dataset_dict: dict
+ :param train_config: training configuration.
+ :type train_config: dict
+ :param device: target device.
+ :type device: torch.device
+ :param device_ids: target device IDs.
+ :type device_ids: list[int]
+ :param distributed: whether to be in distributed training mode.
+ :type distributed: bool
+ :param lr_factor: multiplier for learning rate.
+ :type lr_factor: float or int
+ :param accelerator: Hugging Face accelerator.
+ :type accelerator: accelerate.Accelerator or None
+ :return: training box.
+ :rtype: TrainingBox or MultiStagesTrainingBox
+ """
+    if 'stage1' in train_config:
+        return MultiStagesTrainingBox(model, dataset_dict,
+                                      train_config, device, device_ids, distributed, lr_factor, accelerator)
+    return TrainingBox(model, dataset_dict, train_config, device, device_ids, distributed, lr_factor, accelerator)
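+
+ A minimal sketch (not taken from the package itself) of how a multi-stage configuration might be passed to
+ ``get_training_box``: the ``stage1``/``stage2`` keys and per-stage ``num_epochs`` mirror the keys read by
+ ``MultiStagesTrainingBox`` above, while ``model``, ``dataset_dict`` and the remaining per-stage entries
+ (data loaders, optimizer, etc.) are placeholders you would fill in with your own setup.
+
+ >>> import torch
+ >>> # hypothetical two-stage configuration; each stage is a TrainingBox-style config (details omitted)
+ >>> train_config = {
+ >>>     'stage1': {'num_epochs': 5},
+ >>>     'stage2': {'num_epochs': 10}
+ >>> }
+ >>> # model and dataset_dict are assumed to be defined elsewhere
+ >>> training_box = get_training_box(model, dataset_dict, train_config,
+ >>>                                 device=torch.device('cpu'), device_ids=None,
+ >>>                                 distributed=False, lr_factor=1)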
+[docs]
+def register_dataset(arg=None, **kwargs):
+"""
+ Registers a dataset class or function to instantiate it.
+
+ :param arg: class or function to be registered as a dataset.
+ :type arg: class or typing.Callable or None
+ :return: registered dataset class or function to instantiate it.
+ :rtype: class or typing.Callable
+
+ .. note::
+ The dataset will be registered as an option.
+ You can choose the registered class/function by specifying the name of the class/function or ``key``
+ you used for the registration, in a training configuration used for
+ :class:`torchdistill.core.distillation.DistillationBox` or :class:`torchdistill.core.training.TrainingBox`.
+
+ If you want to register the class/function with a key of your choice, add ``key`` to the decorator as below:
+
+ >>> from torch.utils.data import Dataset
+ >>> from torchdistill.datasets.registry import register_dataset
+ >>> @register_dataset(key='my_custom_dataset')
+ >>> class CustomDataset(Dataset):
+ >>> def __init__(self, **kwargs):
+ >>> print('This is my custom dataset class')
+
+ In the example, ``CustomDataset`` class is registered with a key "my_custom_dataset".
+ When you configure :class:`torchdistill.core.distillation.DistillationBox` or
+ :class:`torchdistill.core.training.TrainingBox`, you can choose the ``CustomDataset`` class by
+ "my_custom_dataset".
+ """
+    def _register_dataset(cls_or_func):
+        key = kwargs.get('key')
+        if key is None:
+            key = cls_or_func.__name__
+
+        DATASET_DICT[key] = cls_or_func
+        return cls_or_func
+
+    if callable(arg):
+        return _register_dataset(arg)
+    return _register_dataset
+
+
+
+
+[docs]
+def register_collate_func(arg=None, **kwargs):
+"""
+ Registers a collate function.
+
+ :param arg: function to be registered as a collate function.
+ :type arg: typing.Callable or None
+ :return: registered function.
+ :rtype: typing.Callable
+
+ .. note::
+ The collate function will be registered as an option.
+ You can choose the registered function by specifying the name of the function or ``key``
+ you used for the registration, in a training configuration used for
+ :class:`torchdistill.core.distillation.DistillationBox` or :class:`torchdistill.core.training.TrainingBox`.
+
+ If you want to register the function with a key of your choice, add ``key`` to the decorator as below:
+
+ >>> from torchdistill.datasets.registry import register_collate_func
+ >>>
+ >>> @register_collate_func(key='my_custom_collate')
+ >>> def custom_collate(batch, label):
+ >>> print('This is my custom collate function')
+ >>> return batch, label
+
+ In the example, ``custom_collate`` function is registered with a key "my_custom_collate".
+ When you configure :class:`torchdistill.core.distillation.DistillationBox` or
+ :class:`torchdistill.core.training.TrainingBox`, you can choose the ``custom_collate`` function by
+ "my_custom_collate".
+ """
+    def _register_collate_func(func):
+        key = kwargs.get('key')
+        if key is None:
+            key = func.__name__ if isinstance(func, (BuiltinMethodType, BuiltinFunctionType, FunctionType)) \
+                else type(func).__name__
+
+        COLLATE_FUNC_DICT[key] = func
+        return func
+
+    if callable(arg):
+        return _register_collate_func(arg)
+    return _register_collate_func
+
+
+
+
+[docs]
+def register_sample_loader(arg=None, **kwargs):
+"""
+ Registers a sample loader class or function to instantiate it.
+
+ :param arg: class or function to be registered as a sample loader.
+ :type arg: class or typing.Callable or None
+ :return: registered sample loader class or function to instantiate it.
+ :rtype: class
+
+ .. note::
+ The sample loader will be registered as an option.
+ You can choose the registered class/function by specifying the name of the class/function or ``key``
+ you used for the registration, in a training configuration used for
+ :class:`torchdistill.core.distillation.DistillationBox` or :class:`torchdistill.core.training.TrainingBox`.
+
+ If you want to register the class with a key of your choice, add ``key`` to the decorator as below:
+
+ >>> from torch.utils.data import Sampler
+ >>> from torchdistill.datasets.registry import register_sample_loader
+ >>> @register_sample_loader(key='my_custom_sample_loader')
+ >>> class CustomSampleLoader(Sampler):
+ >>> def __init__(self, **kwargs):
+ >>> print('This is my custom sample loader class')
+
+ In the example, ``CustomSampleLoader`` class is registered with a key "my_custom_sample_loader".
+ When you configure :class:`torchdistill.core.distillation.DistillationBox` or
+ :class:`torchdistill.core.training.TrainingBox`, you can choose the ``CustomSampleLoader`` class by
+ "my_custom_sample_loader".
+ """
+    def _register_sample_loader_class(cls):
+        key = kwargs.get('key')
+        if key is None:
+            key = cls.__name__
+
+        SAMPLE_LOADER_DICT[key] = cls
+        return cls
+
+    if callable(arg):
+        return _register_sample_loader_class(arg)
+    return _register_sample_loader_class
+
+
+
+
+[docs]
+def register_batch_sampler(arg=None, **kwargs):
+"""
+ Registers a batch sampler class or function to instantiate it.
+
+ :param arg: class or function to be registered as a batch sampler.
+ :type arg: class or typing.Callable or None
+ :return: registered batch sampler class or function to instantiate it.
+ :rtype: class or typing.Callable
+
+ .. note::
+ The batch sampler will be registered as an option.
+ You can choose the registered class/function by specifying the name of the class/function or ``key``
+ you used for the registration, in a training configuration used for
+ :class:`torchdistill.core.distillation.DistillationBox` or :class:`torchdistill.core.training.TrainingBox`.
+
+ If you want to register the class with a key of your choice, add ``key`` to the decorator as below:
+
+ >>> from torch.utils.data import Sampler
+ >>> from torchdistill.datasets.registry import register_batch_sampler
+ >>> @register_batch_sampler(key='my_custom_batch_sampler')
+ >>> class CustomSampleLoader(Sampler):
+ >>> def __init__(self, **kwargs):
+ >>> print('This is my custom batch sampler class')
+
+ In the example, ``CustomSampleLoader`` class is registered with a key "my_custom_batch_sampler".
+ When you configure :class:`torchdistill.core.distillation.DistillationBox` or
+ :class:`torchdistill.core.training.TrainingBox`, you can choose the ``CustomSampleLoader`` class by
+ "my_custom_batch_sampler".
+ """
+    def _register_batch_sampler(cls_or_func):
+        key = kwargs.get('key')
+        if key is None:
+            key = cls_or_func.__name__
+
+        BATCH_SAMPLER_DICT[key] = cls_or_func
+        return cls_or_func
+
+    if callable(arg):
+        return _register_batch_sampler(arg)
+    return _register_batch_sampler
+
+
+
+
+[docs]
+def register_transform(arg=None, **kwargs):
+"""
+ Registers a transform class or function to instantiate it.
+
+ :param arg: class/function to be registered as a transform.
+ :type arg: class or typing.Callable or None
+ :return: registered transform class/function.
+ :rtype: typing.Callable
+
+ .. note::
+ The transform will be registered as an option.
+ You can choose the registered class/function by specifying the name of the class/function or ``key``
+ you used for the registration, in a training configuration used for
+ :class:`torchdistill.core.distillation.DistillationBox` or :class:`torchdistill.core.training.TrainingBox`.
+
+ If you want to register the class with a key of your choice, add ``key`` to the decorator as below:
+
+ >>> from torch import nn
+ >>> from torchdistill.datasets.registry import register_transform
+ >>> @register_transform(key='my_custom_transform')
+ >>> class CustomTransform(nn.Module):
+ >>> def __init__(self, **kwargs):
+ >>> print('This is my custom transform class')
+
+ In the example, ``CustomTransform`` class is registered with a key "my_custom_transform".
+ When you configure :class:`torchdistill.core.distillation.DistillationBox` or
+ :class:`torchdistill.core.training.TrainingBox`, you can choose the ``CustomTransform`` class by
+ "my_custom_transform".
+ """
+    def _register_transform(cls_or_func):
+        key = kwargs.get('key')
+        if key is None:
+            key = cls_or_func.__name__
+
+        TRANSFORM_DICT[key] = cls_or_func
+        return cls_or_func
+
+    if callable(arg):
+        return _register_transform(arg)
+    return _register_transform
+
+
+
+
+[docs]
+def register_dataset_wrapper(arg=None, **kwargs):
+"""
+ Registers a dataset wrapper class or function to instantiate it.
+
+ :param arg: class/function to be registered as a dataset wrapper.
+ :type arg: class or typing.Callable or None
+ :return: registered dataset wrapper class/function.
+ :rtype: typing.Callable
+
+ .. note::
+ The dataset wrapper will be registered as an option.
+ You can choose the registered class/function by specifying the name of the class/function or ``key``
+ you used for the registration, in a training configuration used for
+ :class:`torchdistill.core.distillation.DistillationBox` or :class:`torchdistill.core.training.TrainingBox`.
+
+ If you want to register the class with a key of your choice, add ``key`` to the decorator as below:
+
+ >>> from torch.utils.data import Dataset
+ >>> from torchdistill.datasets.registry import register_dataset_wrapper
+ >>> @register_dataset_wrapper(key='my_custom_dataset_wrapper')
+ >>> class CustomDatasetWrapper(Dataset):
+ >>> def __init__(self, **kwargs):
+ >>> print('This is my custom dataset wrapper class')
+
+ In the example, ``CustomDatasetWrapper`` class is registered with a key "my_custom_dataset_wrapper".
+ When you configure :class:`torchdistill.core.distillation.DistillationBox` or
+ :class:`torchdistill.core.training.TrainingBox`, you can choose the ``CustomDatasetWrapper`` class by
+ "my_custom_dataset_wrapper".
+ """
+    def _register_dataset_wrapper(cls_or_func):
+        key = kwargs.get('key')
+        if key is None:
+            key = cls_or_func.__name__
+
+        DATASET_WRAPPER_DICT[key] = cls_or_func
+        return cls_or_func
+
+    if callable(arg):
+        return _register_dataset_wrapper(arg)
+    return _register_dataset_wrapper
+
+
+
+
+[docs]
+def get_dataset(key):
+"""
+ Gets a registered dataset class or function to instantiate it.
+
+ :param key: unique key to identify the registered dataset class/function.
+ :type key: str
+ :return: registered dataset class or function to instantiate it.
+ :rtype: class or typing.Callable
+ """
+    if key is None:
+        return None
+    elif key in DATASET_DICT:
+        return DATASET_DICT[key]
+    raise ValueError('No dataset `{}` registered'.format(key))
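+
+ A short, hypothetical sketch combining ``register_dataset`` and ``get_dataset``; the class name and key are
+ placeholders, and only the registry behavior of the two functions shown here is assumed.
+
+ >>> from torch.utils.data import Dataset
+ >>> @register_dataset(key='toy_dataset')
+ >>> class ToyDataset(Dataset):
+ >>>     def __init__(self, size=10):
+ >>>         self.size = size
+ >>>     def __len__(self):
+ >>>         return self.size
+ >>>     def __getitem__(self, index):
+ >>>         return index, index % 2
+ >>> dataset_cls = get_dataset('toy_dataset')
+ >>> dataset = dataset_cls(size=100)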
+[docs]
+def get_sample_loader(key):
+"""
+ Gets a registered sample loader class or function to instantiate it.
+
+ :param key: unique key to identify the registered sample loader class or function to instantiate it.
+ :type key: str
+ :return: registered sample loader class or function to instantiate it.
+ :rtype: class or typing.Callable
+ """
+    if key is None:
+        return None
+    elif key in SAMPLE_LOADER_DICT:
+        return SAMPLE_LOADER_DICT[key]
+    raise ValueError('No sample loader `{}` registered.'.format(key))
+
+
+
+
+[docs]
+def get_batch_sampler(key):
+"""
+ Gets a registered batch sampler class or function to instantiate it.
+
+ :param key: unique key to identify the registered batch sampler class or function to instantiate it.
+ :type key: str
+ :return: registered batch sampler class or function to instantiate it.
+ :rtype: class or typing.Callable
+ """
+    if key is None:
+        return None
+
+    if key not in BATCH_SAMPLER_DICT and key != 'BatchSampler':
+        raise ValueError('No batch sampler `{}` registered.'.format(key))
+    return BATCH_SAMPLER_DICT[key]
+
+
+
+
+[docs]
+def get_transform(key):
+"""
+ Gets a registered transform class or function to instantiate it.
+
+ :param key: unique key to identify the registered transform class or function to instantiate it.
+ :type key: str
+ :return: registered transform class or function to instantiate it.
+ :rtype: class or typing.Callable
+ """
+    if key in TRANSFORM_DICT:
+        return TRANSFORM_DICT[key]
+    raise ValueError('No transform `{}` registered.'.format(key))
+
+
+
+
+[docs]
+def get_dataset_wrapper(key):
+"""
+ Gets a registered dataset wrapper class or function to instantiate it.
+
+ :param key: unique key to identify the registered dataset wrapper class or function to instantiate it.
+ :type key: str
+ :return: registered dataset wrapper class or function to instantiate it.
+ :rtype: class or typing.Callable
+ """
+    if key in DATASET_WRAPPER_DICT:
+        return DATASET_WRAPPER_DICT[key]
+    raise ValueError('No dataset wrapper `{}` registered.'.format(key))
+[docs]
+def default_idx2subpath(index):
+"""
+ Converts index to a file path including a parent dir name, which consists of the last four digits of the index.
+
+ :param index: index.
+ :type index: int
+ :return: file path with a parent directory.
+ :rtype: str
+ """
+    digits_str = '{:04d}'.format(index)
+    return os.path.join(digits_str[-4:], digits_str)
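+
+ For instance, ``default_idx2subpath`` maps indices into zero-padded buckets (outputs shown assuming a POSIX
+ path separator):
+
+ >>> default_idx2subpath(7)
+ '0007/0007'
+ >>> default_idx2subpath(123456)
+ '3456/123456'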
+
+
+
+
+[docs]
+class BaseDatasetWrapper(Dataset):
+"""
+ A base dataset wrapper. This is a subclass of :class:`torch.utils.data.Dataset`.
+
+ :param org_dataset: original dataset to be wrapped.
+ :type org_dataset: torch.utils.data.Dataset
+ """
+ def __init__(self, org_dataset):
+     self.org_dataset = org_dataset
+
+ def __getitem__(self, index):
+     sample, target = self.org_dataset.__getitem__(index)
+     return sample, target, dict()
+
+ def __len__(self):
+     return len(self.org_dataset)
+
+
+
+
+[docs]
+class CacheableDataset(BaseDatasetWrapper):
+"""
+ A dataset wrapper that additionally loads cached files in ``cache_dir_path`` if they exist.
+
+ :param org_dataset: original dataset to be wrapped.
+ :type org_dataset: torch.utils.data.Dataset
+ :param cache_dir_path: cache directory path.
+ :type cache_dir_path: str
+ :param idx2subpath_func: function to convert a sample index to a file path.
+ :type idx2subpath_func: typing.Callable or None
+ :param ext: cache file extension.
+ :type ext: str
+ """
+ def __init__(self, org_dataset, cache_dir_path, idx2subpath_func=None, ext='.pt'):
+     super().__init__(org_dataset)
+     self.cache_dir_path = cache_dir_path
+     self.idx2subath_func = str if idx2subpath_func is None else idx2subpath_func
+     self.ext = ext
+
+ def __getitem__(self, index):
+     sample, target, supp_dict = super().__getitem__(index)
+     cache_file_path = os.path.join(self.cache_dir_path, self.idx2subath_func(index) + self.ext)
+     if file_util.check_if_exists(cache_file_path):
+         cached_data = torch.load(cache_file_path)
+         supp_dict['cached_data'] = cached_data
+
+     supp_dict['cache_file_path'] = cache_file_path
+     return sample, target, supp_dict
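+
+ A hypothetical wrapping example; ``org_dataset`` and the cache directory are placeholders, only the constructor
+ arguments shown above are assumed, and the output path assumes a POSIX separator.
+
+ >>> # org_dataset is any torch.utils.data.Dataset returning (sample, target) pairs
+ >>> cacheable = CacheableDataset(org_dataset, './cache', idx2subpath_func=default_idx2subpath, ext='.pt')
+ >>> sample, target, supp_dict = cacheable[0]
+ >>> supp_dict['cache_file_path']
+ './cache/0000/0000.pt'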
+[docs]
+def register_low_level_loss(arg=None, **kwargs):
+"""
+ Registers a low-level loss class or function to instantiate it.
+
+ :param arg: class or function to be registered as a low-level loss.
+ :type arg: class or typing.Callable or None
+ :return: registered low-level loss class or function to instantiate it.
+ :rtype: class or typing.Callable
+
+ .. note::
+ The low-level loss will be registered as an option.
+ You can choose the registered class/function by specifying the name of the class/function or ``key``
+ you used for the registration, in a training configuration used for
+ :class:`torchdistill.core.distillation.DistillationBox` or :class:`torchdistill.core.training.TrainingBox`.
+
+ If you want to register the class/function with a key of your choice, add ``key`` to the decorator as below:
+
+ >>> from torch import nn
+ >>> from torchdistill.losses.registry import register_low_level_loss
+ >>>
+ >>> @register_low_level_loss(key='my_custom_low_level_loss')
+ >>> class CustomLowLevelLoss(nn.Module):
+ >>> def __init__(self, **kwargs):
+ >>> print('This is my custom low-level loss class')
+
+ In the example, ``CustomLowLevelLoss`` class is registered with a key "my_custom_low_level_loss".
+ When you configure :class:`torchdistill.core.distillation.DistillationBox` or
+ :class:`torchdistill.core.training.TrainingBox`, you can choose the ``CustomLowLevelLoss`` class by
+ "my_custom_low_level_loss".
+ """
+    def _register_low_level_loss(cls_or_func):
+        key = kwargs.get('key')
+        if key is None:
+            key = cls_or_func.__name__
+
+        LOW_LEVEL_LOSS_DICT[key] = cls_or_func
+        return cls_or_func
+
+    if callable(arg):
+        return _register_low_level_loss(arg)
+    return _register_low_level_loss
+
+
+
+
+[docs]
+def register_mid_level_loss(arg=None, **kwargs):
+"""
+ Registers a middle-level loss class or function to instantiate it.
+
+ :param arg: class or function to be registered as a middle-level loss.
+ :type arg: class or typing.Callable or None
+ :return: registered middle-level loss class or function to instantiate it.
+ :rtype: class or typing.Callable
+
+ .. note::
+ The middle-level loss will be registered as an option.
+ You can choose the registered class/function by specifying the name of the class/function or ``key``
+ you used for the registration, in a training configuration used for
+ :class:`torchdistill.core.distillation.DistillationBox` or :class:`torchdistill.core.training.TrainingBox`.
+
+ If you want to register the class/function with a key of your choice, add ``key`` to the decorator as below:
+
+ >>> from torch import nn
+ >>> from torchdistill.losses.registry import register_mid_level_loss
+ >>>
+ >>> @register_mid_level_loss(key='my_custom_mid_level_loss')
+ >>> class CustomMidLevelLoss(nn.Module):
+ >>> def __init__(self, **kwargs):
+ >>> print('This is my custom middle-level loss class')
+
+ In the example, ``CustomMidLevelLoss`` class is registered with a key "my_custom_mid_level_loss".
+ When you configure :class:`torchdistill.core.distillation.DistillationBox` or
+ :class:`torchdistill.core.training.TrainingBox`, you can choose the ``CustomMidLevelLoss`` class by
+ "my_custom_mid_level_loss".
+ """
+    def _register_mid_level_loss(cls_or_func):
+        key = kwargs.get('key')
+        if key is None:
+            key = cls_or_func.__name__
+
+        MIDDLE_LEVEL_LOSS_DICT[key] = cls_or_func
+        return cls_or_func
+
+    if callable(arg):
+        return _register_mid_level_loss(arg)
+    return _register_mid_level_loss
+
+
+
+
+[docs]
+def register_high_level_loss(arg=None, **kwargs):
+"""
+ Registers a high-level loss class or function to instantiate it.
+
+ :param arg: class or function to be registered as a high-level loss.
+ :type arg: class or typing.Callable or None
+ :return: registered high-level loss class or function to instantiate it.
+ :rtype: class or typing.Callable
+
+ .. note::
+ The high-level loss will be registered as an option.
+ You can choose the registered class/function by specifying the name of the class/function or ``key``
+ you used for the registration, in a training configuration used for
+ :class:`torchdistill.core.distillation.DistillationBox` or :class:`torchdistill.core.training.TrainingBox`.
+
+ If you want to register the class/function with a key of your choice, add ``key`` to the decorator as below:
+
+ >>> from torch import nn
+ >>> from torchdistill.losses.registry import register_high_level_loss
+ >>>
+ >>> @register_high_level_loss(key='my_custom_high_level_loss')
+ >>> class CustomHighLevelLoss(nn.Module):
+ >>> def __init__(self, **kwargs):
+ >>> print('This is my custom high-level loss class')
+
+ In the example, ``CustomHighLevelLoss`` class is registered with a key "my_custom_high_level_loss".
+ When you configure :class:`torchdistill.core.distillation.DistillationBox` or
+ :class:`torchdistill.core.training.TrainingBox`, you can choose the ``CustomHighLevelLoss`` class by
+ "my_custom_high_level_loss".
+ """
+    def _register_high_level_loss(cls_or_func):
+        key = kwargs.get('key')
+        if key is None:
+            key = cls_or_func.__name__
+
+        HIGH_LEVEL_LOSS_DICT[key] = cls_or_func
+        return cls_or_func
+
+    if callable(arg):
+        return _register_high_level_loss(arg)
+    return _register_high_level_loss
+
+
+
+
+[docs]
+def register_loss_wrapper(arg=None, **kwargs):
+"""
+ Registers a loss wrapper class or function to instantiate it.
+
+ :param arg: class or function to be registered as a loss wrapper.
+ :type arg: class or typing.Callable or None
+ :return: registered loss wrapper class or function to instantiate it.
+ :rtype: class or typing.Callable
+
+ .. note::
+ The loss wrapper will be registered as an option.
+ You can choose the registered class/function by specifying the name of the class/function or ``key``
+ you used for the registration, in a training configuration used for
+ :class:`torchdistill.core.distillation.DistillationBox` or :class:`torchdistill.core.training.TrainingBox`.
+
+ If you want to register the class/function with a key of your choice, add ``key`` to the decorator as below:
+
+ >>> from torch import nn
+ >>> from torchdistill.losses.registry import register_loss_wrapper
+ >>>
+ >>> @register_loss_wrapper(key='my_custom_loss_wrapper')
+ >>> class CustomLossWrapper(nn.Module):
+ >>> def __init__(self, **kwargs):
+ >>> print('This is my custom loss wrapper class')
+
+ In the example, ``CustomLossWrapper`` class is registered with a key "my_custom_loss_wrapper".
+ When you configure :class:`torchdistill.core.distillation.DistillationBox` or
+ :class:`torchdistill.core.training.TrainingBox`, you can choose the ``CustomLossWrapper`` class by
+ "my_custom_loss_wrapper".
+ """
+    def _register_loss_wrapper(cls_or_func):
+        key = kwargs.get('key')
+        if key is None:
+            key = cls_or_func.__name__
+
+        LOSS_WRAPPER_DICT[key] = cls_or_func
+        return cls_or_func
+
+    if callable(arg):
+        return _register_loss_wrapper(arg)
+    return _register_loss_wrapper
+
+
+
+
+[docs]
+def register_func2extract_model_output(arg=None, **kwargs):
+"""
+ Registers a function to extract model output.
+
+ :param arg: function to be registered for extracting model output.
+ :type arg: typing.Callable or None
+ :return: registered function.
+ :rtype: typing.Callable
+
+ .. note::
+ The function to extract model output will be registered as an option.
+ You can choose the registered function by specifying the name of the function or ``key``
+ you used for the registration, in a training configuration used for
+ :class:`torchdistill.core.distillation.DistillationBox` or :class:`torchdistill.core.training.TrainingBox`.
+
+ If you want to register the function with a key of your choice, add ``key`` to the decorator as below:
+
+ >>> from torchdistill.losses.registry import register_func2extract_model_output
+ >>>
+ >>> @register_func2extract_model_output(key='my_custom_function2extract_model_output')
+ >>> def custom_func2extract_model_output(batch, label):
+ >>> print('This is my custom function to extract model output')
+ >>> return batch, label
+
+ In the example, ``custom_func2extract_model_output`` function is registered with a key "my_custom_function2extract_model_output".
+ When you configure :class:`torchdistill.core.distillation.DistillationBox` or
+ :class:`torchdistill.core.training.TrainingBox`, you can choose the ``custom_func2extract_model_output`` function by
+ "my_custom_function2extract_model_output".
+ """
+    def _register_func2extract_model_output(func):
+        key = kwargs.get('key')
+        if key is None:
+            key = func.__name__
+
+        FUNC2EXTRACT_MODEL_OUTPUT_DICT[key] = func
+        return func
+
+    if callable(arg):
+        return _register_func2extract_model_output(arg)
+    return _register_func2extract_model_output
+
+
+
+
+[docs]
+def get_low_level_loss(key, **kwargs):
+"""
+ Gets a registered (low-level) loss module.
+
+ :param key: unique key to identify the registered loss class/function.
+ :type key: str
+ :return: registered loss class or function to instantiate it.
+ :rtype: nn.Module
+ """
+    if key in LOSS_DICT:
+        return LOSS_DICT[key](**kwargs)
+    elif key in LOW_LEVEL_LOSS_DICT:
+        return LOW_LEVEL_LOSS_DICT[key](**kwargs)
+    raise ValueError('No loss `{}` registered'.format(key))
+
+
+
+
+[docs]
+def get_mid_level_loss(mid_level_criterion_config, criterion_wrapper_config=None):
+"""
+ Gets a registered middle-level loss module.
+
+ :param mid_level_criterion_config: middle-level loss configuration to identify and instantiate the registered middle-level loss class.
+ :type mid_level_criterion_config: dict
+ :param criterion_wrapper_config: loss wrapper configuration to identify and instantiate the registered loss wrapper class.
+ :type criterion_wrapper_config: dict
+ :return: registered middle-level loss class or function to instantiate it.
+ :rtype: nn.Module
+ """
+    loss_key = mid_level_criterion_config['key']
+    mid_level_loss = MIDDLE_LEVEL_LOSS_DICT[loss_key](**mid_level_criterion_config['kwargs']) \
+        if loss_key in MIDDLE_LEVEL_LOSS_DICT else get_low_level_loss(loss_key, **mid_level_criterion_config['kwargs'])
+    if criterion_wrapper_config is None or len(criterion_wrapper_config) == 0:
+        return mid_level_loss
+    return get_loss_wrapper(mid_level_loss, criterion_wrapper_config)
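+
+ A minimal sketch of the configuration dict read above, assuming a loss such as ``MSELoss`` is available in
+ ``LOSS_DICT``/``LOW_LEVEL_LOSS_DICT``; an optional ``criterion_wrapper_config`` dict with ``'key'``/``'args'``/``'kwargs'``
+ entries may additionally be passed to wrap the returned loss.
+
+ >>> mid_level_criterion_config = {'key': 'MSELoss', 'kwargs': {'reduction': 'sum'}}
+ >>> criterion = get_mid_level_loss(mid_level_criterion_config)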
+
+
+
+
+[docs]
+def get_high_level_loss(criterion_config):
+"""
+ Gets a registered high-level loss module.
+
+ :param criterion_config: high-level loss configuration to identify and instantiate the registered high-level loss class.
+ :type criterion_config: dict
+ :return: registered high-level loss class or function to instantiate it.
+ :rtype: nn.Module
+ """
+    criterion_key = criterion_config['key']
+    args = criterion_config.get('args', None)
+    kwargs = criterion_config.get('kwargs', None)
+    if args is None:
+        args = list()
+    if kwargs is None:
+        kwargs = dict()
+    if criterion_key in HIGH_LEVEL_LOSS_DICT:
+        return HIGH_LEVEL_LOSS_DICT[criterion_key](*args, **kwargs)
+    raise ValueError('No high-level loss `{}` registered'.format(criterion_key))
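+
+ The expected shape of ``criterion_config`` follows directly from the lookup above: a registered high-level loss
+ key plus optional ``args``/``kwargs``. The key and kwargs below are placeholders for whichever class you
+ registered with :meth:`register_high_level_loss`.
+
+ >>> criterion_config = {
+ >>>     'key': 'my_custom_high_level_loss',
+ >>>     'kwargs': {'weight': 1.0}
+ >>> }
+ >>> criterion = get_high_level_loss(criterion_config)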
+
+
+
+
+[docs]
+def get_loss_wrapper(mid_level_loss, criterion_wrapper_config):
+"""
+ Gets a registered loss wrapper module.
+
+ :param mid_level_loss: middle-level loss module.
+ :type mid_level_loss: nn.Module
+ :param criterion_wrapper_config: loss wrapper configuration to identify and instantiate the registered loss wrapper class.
+ :type criterion_wrapper_config: dict
+ :return: registered loss wrapper class or function to instantiate it.
+ :rtype: nn.Module
+ """
+    wrapper_key = criterion_wrapper_config['key']
+    args = criterion_wrapper_config.get('args', None)
+    kwargs = criterion_wrapper_config.get('kwargs', None)
+    if args is None:
+        args = list()
+    if kwargs is None:
+        kwargs = dict()
+    if wrapper_key in LOSS_WRAPPER_DICT:
+        return LOSS_WRAPPER_DICT[wrapper_key](mid_level_loss, *args, **kwargs)
+    raise ValueError('No loss wrapper `{}` registered'.format(wrapper_key))
+
+
+
+
+[docs]
+def get_func2extract_model_output(key):
+"""
+ Gets a registered function to extract model output.
+
+ :param key: unique key to identify the registered function to extract model output.
+ :type key: str
+ :return: registered function to extract model output.
+ :rtype: typing.Callable
+ """
+    if key is None:
+        key = 'extract_model_loss_dict'
+    if key in FUNC2EXTRACT_MODEL_OUTPUT_DICT:
+        return FUNC2EXTRACT_MODEL_OUTPUT_DICT[key]
+    raise ValueError('No function to extract original output `{}` registered'.format(key))
+[docs]
+def setup_log_file(log_file_path):
+"""
+ Sets a file handler with ``log_file_path`` to write a log file.
+
+ :param log_file_path: log file path.
+ :type log_file_path: str
+ """
+    make_parent_dirs(log_file_path)
+    fh = FileHandler(filename=log_file_path, mode='w')
+    fh.setFormatter(Formatter(LOGGING_FORMAT))
+    def_logger.addHandler(fh)
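+
+ A one-line usage sketch; the log file path is a placeholder.
+
+ >>> setup_log_file('./log/experiment1.log')
+ >>> # subsequent messages emitted via def_logger are also written to ./log/experiment1.log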
+
+
+
+
+[docs]
+class SmoothedValue(object):
+"""
+ A deque-based value object that tracks a series of values and provides access to smoothed values over a
+ window or the global series average. The original implementation can be found at https://github.com/pytorch/vision/blob/main/references/classification/utils.py
+
+ :param window_size: window size.
+ :type window_size: int
+ :param fmt: text format.
+ :type fmt: str or None
+ """
+
+ def __init__(self, window_size=20, fmt=None):
+     if fmt is None:
+         fmt = "{median:.4f} ({global_avg:.4f})"
+     self.deque = deque(maxlen=window_size)
+     self.total = 0.0
+     self.count = 0
+     self.fmt = fmt
+
+
+[docs]
+ def update(self, value, n=1):
+"""
+ Appends ``value``.
+
+ :param value: value to be added.
+ :type value: float or int
+ :param n: sample count.
+ :type n: int
+ """
+     self.deque.append(value)
+     self.count += n
+     self.total += value * n
+
+
+
+[docs]
+ def synchronize_between_processes(self):
+"""
+ Synchronizes between processes.
+
+ .. warning::
+ It does not synchronize the deque.
+ """
+     if not is_dist_avail_and_initialized():
+         return
+
+     t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
+     dist.barrier()
+     dist.all_reduce(t)
+     t = t.tolist()
+     self.count = int(t[0])
+     self.total = t[1]
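+
+ A small, self-contained usage sketch of ``SmoothedValue`` as a running loss meter; only the constructor and
+ :meth:`update` shown above are assumed, and the loss values are synthetic.
+
+ >>> loss_meter = SmoothedValue(window_size=10)
+ >>> for step in range(100):
+ >>>     loss_value = 1.0 / (step + 1)  # stand-in for a real training loss
+ >>>     loss_meter.update(loss_value, n=1)
+ >>> round(loss_meter.total / loss_meter.count, 4)
+ 0.0519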
+[docs]
+def register_model(arg=None, **kwargs):
+"""
+ Registers a model class or function to instantiate it.
+
+ :param arg: class or function to be registered as a model.
+ :type arg: class or typing.Callable or None
+ :return: registered model class or function to instantiate it.
+ :rtype: class or typing.Callable
+
+ .. note::
+ The model will be registered as an option.
+ You can choose the registered class/function by specifying the name of the class/function or ``key``
+ you used for the registration, in a training configuration used for
+ :class:`torchdistill.core.distillation.DistillationBox` or :class:`torchdistill.core.training.TrainingBox`.
+
+ If you want to register the class/function with a key of your choice, add ``key`` to the decorator as below:
+
+ >>> from torch import nn
+ >>> from torchdistill.models.registry import register_model
+ >>>
+ >>> @register_model(key='my_custom_model')
+ >>> class CustomModel(nn.Module):
+ >>> def __init__(self, **kwargs):
+ >>> print('This is my custom model class')
+
+ In the example, ``CustomModel`` class is registered with a key "my_custom_model".
+ When you configure :class:`torchdistill.core.distillation.DistillationBox` or
+ :class:`torchdistill.core.training.TrainingBox`, you can choose the ``CustomModel`` class by
+ "my_custom_model".
+ """
+    def _register_model(cls):
+        key = kwargs.get('key')
+        if key is None:
+            key = cls.__name__
+
+        MODEL_DICT[key] = cls
+        return cls
+
+    if callable(arg):
+        return _register_model(arg)
+    return _register_model
+
+
+
+
+[docs]
+def register_adaptation_module(arg=None, **kwargs):
+"""
+ Registers an adaptation module class or function to instantiate it.
+
+ :param arg: class or function to be registered as an adaptation module.
+ :type arg: class or typing.Callable or None
+ :return: registered adaptation module class or function to instantiate it.
+ :rtype: class or typing.Callable
+
+ .. note::
+ The adaptation module will be registered as an option.
+ You can choose the registered class/function by specifying the name of the class/function or ``key``
+ you used for the registration, in a training configuration used for
+ :class:`torchdistill.core.distillation.DistillationBox` or :class:`torchdistill.core.training.TrainingBox`.
+
+ If you want to register the class/function with a key of your choice, add ``key`` to the decorator as below:
+
+ >>> from torch import nn
+ >>> from torchdistill.models.registry import register_adaptation_module
+ >>>
+ >>> @register_adaptation_module(key='my_custom_adaptation_module')
+ >>> class CustomAdaptationModule(nn.Module):
+ >>> def __init__(self, **kwargs):
+ >>> print('This is my custom adaptation module class')
+
+ In the example, ``CustomAdaptationModule`` class is registered with a key "my_custom_adaptation_module".
+ When you configure :class:`torchdistill.core.distillation.DistillationBox` or
+ :class:`torchdistill.core.training.TrainingBox`, you can choose the ``CustomAdaptationModule`` class by
+ "my_custom_adaptation_module".
+ """
+    def _register_adaptation_module(cls_or_func):
+        key = kwargs.get('key')
+        if key is None:
+            key = cls_or_func.__name__
+
+        ADAPTATION_MODULE_DICT[key] = cls_or_func
+        return cls_or_func
+
+    if callable(arg):
+        return _register_adaptation_module(arg)
+    return _register_adaptation_module
+
+
+
+
+[docs]
+def register_auxiliary_model_wrapper(arg=None, **kwargs):
+"""
+ Registers an auxiliary model wrapper class or function to instantiate it.
+
+ :param arg: class or function to be registered as an auxiliary model wrapper.
+ :type arg: class or typing.Callable or None
+ :return: registered auxiliary model wrapper class or function to instantiate it.
+ :rtype: class or typing.Callable
+
+ .. note::
+ The auxiliary model wrapper will be registered as an option.
+ You can choose the registered class/function by specifying the name of the class/function or ``key``
+ you used for the registration, in a training configuration used for
+ :class:`torchdistill.core.distillation.DistillationBox` or :class:`torchdistill.core.training.TrainingBox`.
+
+ If you want to register the class/function with a key of your choice, add ``key`` to the decorator as below:
+
+ >>> from torch import nn
+ >>> from torchdistill.models.registry import register_auxiliary_model_wrapper
+ >>>
+ >>> @register_auxiliary_model_wrapper(key='my_custom_auxiliary_model_wrapper')
+ >>> class CustomAuxiliaryModelWrapper(nn.Module):
+ >>> def __init__(self, **kwargs):
+ >>> print('This is my custom auxiliary model wrapper class')
+
+ In the example, ``CustomAuxiliaryModelWrapper`` class is registered with a key "my_custom_auxiliary_model_wrapper".
+ When you configure :class:`torchdistill.core.distillation.DistillationBox` or
+ :class:`torchdistill.core.training.TrainingBox`, you can choose the ``CustomAuxiliaryModelWrapper`` class by
+ "my_custom_auxiliary_model_wrapper".
+ """
+    def _register_auxiliary_model_wrapper(cls_or_func):
+        key = kwargs.get('key')
+        if key is None:
+            key = cls_or_func.__name__
+
+        AUXILIARY_MODEL_WRAPPER_DICT[key] = cls_or_func
+        return cls_or_func
+
+    if callable(arg):
+        return _register_auxiliary_model_wrapper(arg)
+    return _register_auxiliary_model_wrapper
+
+
+
+
+[docs]
+def get_model(key, repo_or_dir=None, *args, **kwargs):
+"""
+ Gets a model from the model registry.
+
+ :param key: model key.
+ :type key: str
+ :param repo_or_dir: ``repo_or_dir`` for torch.hub.load.
+ :type repo_or_dir: str or None
+ :return: model.
+ :rtype: nn.Module
+ """
+    if key in MODEL_DICT and repo_or_dir is None:
+        return MODEL_DICT[key](*args, **kwargs)
+    elif repo_or_dir is not None:
+        return torch.hub.load(repo_or_dir, key, *args, **kwargs)
+    raise ValueError('model_name `{}` is not expected'.format(key))
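+
+ Two hypothetical calls matching the two branches above: one resolving a locally registered model, one falling
+ back to ``torch.hub.load``. The keys and hub arguments are placeholders.
+
+ >>> # 1) a model class previously registered via register_model(key='my_custom_model')
+ >>> model = get_model('my_custom_model')
+ >>> # 2) a model pulled from PyTorch Hub; extra args/kwargs are forwarded to torch.hub.load
+ >>> model = get_model('resnet18', repo_or_dir='pytorch/vision', weights=None)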
+[docs]
+def build_sequential_container(module_dict):
+"""
+ Builds sequential container (nn.Sequential) from ``module_dict``.
+
+ :param module_dict: module dict used to build a sequential container.
+ :type module_dict: nn.ModuleDict or collections.OrderedDict
+ :return: sequential container.
+ :rtype: nn.Sequential
+ """
+    for key in module_dict.keys():
+        value = module_dict[key]
+        if isinstance(value, OrderedDict):
+            value = build_sequential_container(value)
+            module_dict[key] = value
+        elif not isinstance(value, Module):
+            raise ValueError('module type `{}` is not expected'.format(type(value)))
+    return Sequential(module_dict)
+
+
+
+
+[docs]
+def redesign_model(org_model, model_config, model_label, model_type='original'):
+"""
+ Redesigns ``org_model`` and returns a new separate model e.g.,
+
+ * prunes some modules from ``org_model``,
+ * freezes parameters of some modules in ``org_model``, and
+ * adds adaptation module(s) to ``org_model`` as a new separate model.
+
+ .. note::
+ The parameters and states of modules in ``org_model`` will be kept in a new redesigned model.
+
+ :param org_model: original model to be redesigned.
+ :type org_model: nn.Module
+ :param model_config: configuration to redesign ``org_model``.
+ :type model_config: dict
+ :param model_label: model label (e.g., 'teacher', 'student') to be printed just for debugging purposes.
+ :type model_label: str
+ :param model_type: model type (e.g., 'original', name of model class, etc.) to be printed just for debugging purposes.
+ :type model_type: str
+ :return: redesigned model.
+ :rtype: nn.Module
+ """
+    frozen_module_path_set = set(model_config.get('frozen_modules', list()))
+    module_paths = model_config.get('sequential', list())
+    if not isinstance(module_paths, list) or len(module_paths) == 0:
+        logger.info('Using the {} model'.format(model_type))
+        if len(frozen_module_path_set) > 0:
+            logger.info('Frozen module(s): {}'.format(frozen_module_path_set))
+
+        isinstance_str = 'instance('
+        for frozen_module_path in frozen_module_path_set:
+            if frozen_module_path.startswith(isinstance_str) and frozen_module_path.endswith(')'):
+                target_cls = nn.__dict__[frozen_module_path[len(isinstance_str):-1]]
+                for m in org_model.modules():
+                    if isinstance(m, target_cls):
+                        freeze_module_params(m)
+            else:
+                module = get_module(org_model, frozen_module_path)
+                freeze_module_params(module)
+        return org_model
+
+    logger.info('Redesigning the {} model with {}'.format(model_label, module_paths))
+    if len(frozen_module_path_set) > 0:
+        logger.info('Frozen module(s): {}'.format(frozen_module_path_set))
+
+    module_dict = OrderedDict()
+    adaptation_dict = model_config.get('adaptations', dict())
+
+    for frozen_module_path in frozen_module_path_set:
+        module = get_module(org_model, frozen_module_path)
+        freeze_module_params(module)
+
+    for module_path in module_paths:
+        if module_path.startswith('+'):
+            module_path = module_path[1:]
+            adaptation_config = adaptation_dict[module_path]
+            module = get_adaptation_module(adaptation_config['key'], **adaptation_config['kwargs'])
+        else:
+            module = get_module(org_model, module_path)
+
+        if module_path in frozen_module_path_set:
+            freeze_module_params(module)
+
+        add_submodule(module, module_path, module_dict)
+    return build_sequential_container(module_dict)
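+
+ A hypothetical ``model_config`` sketch using only the keys read above (``frozen_modules``, ``sequential``,
+ ``adaptations``); the module paths refer to a torchvision-style ResNet and are placeholders for your own model,
+ and the adaptation key assumes a module registered via :meth:`register_adaptation_module`.
+
+ >>> model_config = {
+ >>>     'frozen_modules': ['layer1'],
+ >>>     'sequential': ['conv1', 'bn1', 'relu', 'maxpool', 'layer1', 'layer2', '+bottleneck'],
+ >>>     'adaptations': {
+ >>>         'bottleneck': {'key': 'my_custom_adaptation_module', 'kwargs': {}}
+ >>>     }
+ >>> }
+ >>> student = redesign_model(org_model, model_config, model_label='student')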
+[docs]
+class AuxiliaryModelWrapper(nn.Module):
+"""
+ An abstract auxiliary model wrapper.
+
+ :meth:`forward`, :meth:`secondary_forward`, and :meth:`post_epoch_process` should be overridden by all subclasses.
+ """
+ def __init__(self):
+     super().__init__()
+
+ def secondary_forward(self, *args, **kwargs):
+     pass
+
+ def post_epoch_process(self, *args, **kwargs):
+     pass
+
+
+
+
+[docs]
+@register_auxiliary_model_wrapper
+class EmptyModule(AuxiliaryModelWrapper):
+"""
+ An empty auxiliary model wrapper. This module returns its input as output and is useful when you want to replace
+ your teacher/student model with an empty model to save inference time,
+ e.g., multi-stage knowledge distillation may have some stages that do not require either the teacher or the student model.
+ """
+ def __init__(self, **kwargs):
+     super().__init__()
+
+ def forward(self, *args, **kwargs):
+     return args[0] if isinstance(args, tuple) and len(args) == 1 else args
+
+
+
+
+[docs]
+class Paraphraser4FactorTransfer(nn.Module):
+"""
+ Paraphraser for factor transfer (FT). This module is used at the 1st and 2nd stages of FT method.
+
+ Jangho Kim, Seonguk Park, Nojun Kwak: `"Paraphrasing Complex Network: Network Compression via Factor Transfer" <https://papers.neurips.cc/paper_files/paper/2018/hash/6d9cb7de5e8ac30bd5e8734bc96a35c1-Abstract.html>`_ @ NeurIPS 2018 (2018)
+
+ :param k: paraphrase rate.
+ :type k: float
+ :param num_input_channels: number of input channels.
+ :type num_input_channels: int
+ :param kernel_size: ``kernel_size`` for Conv2d.
+ :type kernel_size: int
+ :param stride: ``stride`` for Conv2d.
+ :type stride: int
+ :param padding: ``padding`` for Conv2d.
+ :type padding: int
+ :param uses_bn: if True, uses BatchNorm2d.
+ :type uses_bn: bool
+ :param uses_decoder: if True, uses decoder in :meth:`forward`.
+ :type uses_decoder: bool
+ """
+ @staticmethod
+ def make_tail_modules(num_output_channels, uses_bn):
+     leaky_relu = nn.LeakyReLU(0.1)
+     if uses_bn:
+         return [nn.BatchNorm2d(num_output_channels), leaky_relu]
+     return [leaky_relu]
+
+ @classmethod
+ def make_enc_modules(cls, num_input_channels, num_output_channels, kernel_size, stride, padding, uses_bn):
+     return [
+         nn.Conv2d(num_input_channels, num_output_channels, kernel_size, stride=stride, padding=padding),
+         *cls.make_tail_modules(num_output_channels, uses_bn)
+     ]
+
+ @classmethod
+ def make_dec_modules(cls, num_input_channels, num_output_channels, kernel_size, stride, padding, uses_bn):
+     return [
+         nn.ConvTranspose2d(num_input_channels, num_output_channels, kernel_size, stride=stride, padding=padding),
+         *cls.make_tail_modules(num_output_channels, uses_bn)
+     ]
+
+ def __init__(self, k, num_input_channels, kernel_size=3, stride=1, padding=1, uses_bn=True, uses_decoder=True):
+     super().__init__()
+     self.paraphrase_rate = k
+     num_enc_output_channels = int(num_input_channels * k)
+     self.encoder = nn.Sequential(
+         *self.make_enc_modules(num_input_channels, num_input_channels,
+                                kernel_size, stride, padding, uses_bn),
+         *self.make_enc_modules(num_input_channels, num_enc_output_channels,
+                                kernel_size, stride, padding, uses_bn),
+         *self.make_enc_modules(num_enc_output_channels, num_enc_output_channels,
+                                kernel_size, stride, padding, uses_bn)
+     )
+     self.decoder = nn.Sequential(
+         *self.make_dec_modules(num_enc_output_channels, num_enc_output_channels,
+                                kernel_size, stride, padding, uses_bn),
+         *self.make_dec_modules(num_enc_output_channels, num_input_channels,
+                                kernel_size, stride, padding, uses_bn),
+         *self.make_dec_modules(num_input_channels, num_input_channels,
+                                kernel_size, stride, padding, uses_bn)
+     )
+     self.uses_decoder = uses_decoder
+
+ def forward(self, z):
+     if self.uses_decoder:
+         return self.decoder(self.encoder(z))
+     return self.encoder(z)
+
+
+
+
+[docs]
+class Translator4FactorTransfer(nn.Sequential):
+"""
+ Translator for factor transfer (FT). This module is used at the 2nd stage of FT method.
+ Note that "the student translator has the same three convolution layers as the paraphraser".
+
+ Jangho Kim, Seonguk Park, Nojun Kwak: `"Paraphrasing Complex Network: Network Compression via Factor Transfer" <https://papers.neurips.cc/paper_files/paper/2018/hash/6d9cb7de5e8ac30bd5e8734bc96a35c1-Abstract.html>`_ @ NeurIPS 2018 (2018)
+
+ :param num_input_channels: number of input channels.
+ :type num_input_channels: int
+ :param num_output_channels: number of output channels.
+ :type num_output_channels: int
+ :param kernel_size: ``kernel_size`` for Conv2d.
+ :type kernel_size: int
+ :param stride: ``stride`` for Conv2d.
+ :type stride: int
+ :param padding: ``padding`` for Conv2d.
+ :type padding: int
+ :param uses_bn: if True, uses BatchNorm2d.
+ :type uses_bn: bool
+ """
+ def __init__(self, num_input_channels, num_output_channels, kernel_size=3, stride=1, padding=1, uses_bn=True):
+     super().__init__(
+         *Paraphraser4FactorTransfer.make_enc_modules(num_input_channels, num_input_channels,
+                                                      kernel_size, stride, padding, uses_bn),
+         *Paraphraser4FactorTransfer.make_enc_modules(num_input_channels, num_output_channels,
+                                                      kernel_size, stride, padding, uses_bn),
+         *Paraphraser4FactorTransfer.make_enc_modules(num_output_channels, num_output_channels,
+                                                      kernel_size, stride, padding, uses_bn)
+     )
+
+
+
+
+[docs]
+@register_auxiliary_model_wrapper
+class Teacher4FactorTransfer(AuxiliaryModelWrapper):
+"""
+ An auxiliary teacher model wrapper for factor transfer (FT), including paraphraser :class:`Paraphraser4FactorTransfer`.
+
+ Jangho Kim, Seonguk Park, Nojun Kwak: `"Paraphrasing Complex Network: Network Compression via Factor Transfer" <https://papers.neurips.cc/paper_files/paper/2018/hash/6d9cb7de5e8ac30bd5e8734bc96a35c1-Abstract.html>`_ @ NeurIPS 2018 (2018)
+
+ :param teacher_model: teacher model.
+ :type teacher_model: nn.Module
+ :param minimal: ``model_config`` for :meth:`build_auxiliary_model_wrapper` if you want to build a minimal version of the teacher model.
+ :type minimal: dict or None
+ :param input_module_path: path of module whose output is used as input to paraphraser.
+ :type input_module_path: str
+ :param paraphraser_kwargs: kwargs to instantiate :class:`Paraphraser4FactorTransfer`.
+ :type paraphraser_kwargs: dict
+ :param paraphraser_ckpt: paraphraser checkpoint file path.
+ :type paraphraser_ckpt: str
+ :param uses_decoder: ``uses_decoder`` for :class:`Paraphraser4FactorTransfer`.
+ :type uses_decoder: bool
+ :param device: target device.
+ :type device: torch.device
+ :param device_ids: target device IDs.
+ :type device_ids: list[int]
+ :param distributed: whether to be in distributed training mode.
+ :type distributed: bool
+ :param find_unused_parameters: ``find_unused_parameters`` for DistributedDataParallel.
+ :type find_unused_parameters: bool or None
+ """
+ def __init__(self, teacher_model, minimal, input_module_path,
+              paraphraser_kwargs, paraphraser_ckpt, uses_decoder, device, device_ids, distributed,
+              find_unused_parameters=None, **kwargs):
+     super().__init__()
+     if minimal is None:
+         minimal = dict()
+
+     auxiliary_teacher_model_wrapper = build_auxiliary_model_wrapper(minimal, teacher_model=teacher_model)
+     model_type = 'original'
+     teacher_ref_model = teacher_model
+     if auxiliary_teacher_model_wrapper is not None:
+         teacher_ref_model = auxiliary_teacher_model_wrapper
+         model_type = type(teacher_ref_model).__name__
+
+     self.teacher_model = redesign_model(teacher_ref_model, minimal, 'teacher', model_type)
+     self.input_module_path = input_module_path
+     paraphraser = Paraphraser4FactorTransfer(uses_decoder=uses_decoder, **paraphraser_kwargs)
+     self.paraphraser = wrap_if_distributed(paraphraser, device, device_ids, distributed,
+                                            find_unused_parameters=find_unused_parameters)
+     self.ckpt_file_path = paraphraser_ckpt
+     if os.path.isfile(self.ckpt_file_path):
+         map_location = {'cuda:0': 'cuda:{}'.format(device_ids[0])} if distributed else device
+         load_module_ckpt(self.paraphraser, map_location, self.ckpt_file_path)
+     self.uses_decoder = uses_decoder
+
+ def forward(self, *args):
+     with torch.no_grad():
+         return self.teacher_model(*args)
+
+ def secondary_forward(self, io_dict):
+     if self.uses_decoder and not self.paraphraser.training:
+         self.paraphraser.train()
+     self.paraphraser(io_dict[self.input_module_path]['output'])
+
+ def post_epoch_process(self, *args, **kwargs):
+     save_module_ckpt(self.paraphraser, self.ckpt_file_path)
+
+
+
+
+[docs]
+@register_auxiliary_model_wrapper
+class Student4FactorTransfer(AuxiliaryModelWrapper):
+"""
+ An auxiliary student model wrapper for factor transfer (FT), including translator :class:`Translator4FactorTransfer`.
+
+ Jangho Kim, Seonguk Park, Nojun Kwak: `"Paraphrasing Complex Network: Network Compression via Factor Transfer" <https://papers.neurips.cc/paper_files/paper/2018/hash/6d9cb7de5e8ac30bd5e8734bc96a35c1-Abstract.html>`_ @ NeurIPS 2018 (2018)
+
+ :param student_model: student model.
+ :type student_model: nn.Module
+ :param input_module_path: path of module whose output is used as input to translator.
+ :type input_module_path: str
+ :param translator_kwargs: kwargs to instantiate :class:`Translator4FactorTransfer`.
+ :type translator_kwargs: dict
+ :param device: target device.
+ :type device: torch.device
+ :param device_ids: target device IDs.
+ :type device_ids: list[int]
+ :param distributed: whether to be in distributed training mode.
+ :type distributed: bool
+ :param find_unused_parameters: ``find_unused_parameters`` for DistributedDataParallel.
+ :type find_unused_parameters: bool or None
+ """
+ def __init__(self, student_model, input_module_path, translator_kwargs, device, device_ids, distributed,
+              find_unused_parameters=None, **kwargs):
+     super().__init__()
+     self.student_model = wrap_if_distributed(student_model, device, device_ids, distributed,
+                                              find_unused_parameters=find_unused_parameters)
+     self.input_module_path = input_module_path
+     self.translator = \
+         wrap_if_distributed(Translator4FactorTransfer(**translator_kwargs), device, device_ids, distributed,
+                             find_unused_parameters=find_unused_parameters)
+
+ def forward(self, *args):
+     return self.student_model(*args)
+
+ def secondary_forward(self, io_dict):
+     self.translator(io_dict[self.input_module_path]['output'])
+
+
+
+
+[docs]
+@register_auxiliary_model_wrapper
+class Connector4DAB(AuxiliaryModelWrapper):
+"""
+ An auxiliary student model wrapper with connector for distillation of activation boundaries (DAB).
+
+ Byeongho Heo, Minsik Lee, Sangdoo Yun, Jin Young Choi: `"Knowledge Transfer via Distillation of Activation Boundaries Formed by Hidden Neurons" <https://ojs.aaai.org/index.php/AAAI/article/view/4264>`_ @ AAAI 2019 (2019)
+
+ :param student_model: student model.
+ :type student_model: nn.Module
+ :param connectors: connector keys and configurations.
+ :type connectors: dict
+ :param device: target device.
+ :type device: torch.device
+ :param device_ids: target device IDs.
+ :type device_ids: list[int]
+ :param distributed: whether to be in distributed training mode.
+ :type distributed: bool
+ :param find_unused_parameters: ``find_unused_parameters`` for DistributedDataParallel.
+ :type find_unused_parameters: bool or None
+ """
+ @staticmethod
+ def build_connector(conv2d_kwargs, bn2d_kwargs=None):
+     module_list = [nn.Conv2d(**conv2d_kwargs)]
+     if bn2d_kwargs is not None and len(bn2d_kwargs) > 0:
+         module_list.append(nn.BatchNorm2d(**bn2d_kwargs))
+     return nn.Sequential(*module_list)
+
+ def __init__(self, student_model, connectors, device, device_ids, distributed, find_unused_parameters=None,
+              **kwargs):
+     super().__init__()
+     self.student_model = wrap_if_distributed(student_model, device, device_ids, distributed, find_unused_parameters)
+     io_path_pairs = list()
+     self.connector_dict = nn.ModuleDict()
+     for connector_key, connector_config in connectors.items():
+         connector = \
+             self.build_connector(connector_config['conv2d_kwargs'], connector_config.get('bn2d_kwargs', None))
+         self.connector_dict[connector_key] = \
+             wrap_if_distributed(connector, device, device_ids, distributed, find_unused_parameters)
+         io_path_pairs.append((connector_key, connector_config['io'], connector_config['path']))
+     self.io_path_pairs = io_path_pairs
+
+ def forward(self, x):
+     return self.student_model(x)
+
+ def secondary_forward(self, io_dict):
+     for connector_key, io_type, module_path in self.io_path_pairs:
+         self.connector_dict[connector_key](io_dict[module_path][io_type])
+
+
+
+
+[docs]
+class Regressor4VID(nn.Module):
+"""
+ An auxiliary module for variational information distillation (VID).
+
+ Sungsoo Ahn, Shell Xu Hu, Andreas Damianou, Neil D. Lawrence, Zhenwen Dai: `"Variational Information Distillation for Knowledge Transfer" <https://openaccess.thecvf.com/content_CVPR_2019/html/Ahn_Variational_Information_Distillation_for_Knowledge_Transfer_CVPR_2019_paper.html>`_ @ CVPR 2019 (2019)
+
+ :param in_channels: number of input channels for the first convolution layer.
+ :type in_channels: int
+ :param middle_channels: number of output/input channels for the first/second convolution layer.
+ :type middle_channels: int
+ :param out_channels: number of output channels for the third convolution layer.
+ :type out_channels: int
+ :param eps: small value added to the predicted variance for numerical stability.
+ :type eps: float
+ :param init_pred_var: initial predicted variance.
+ :type init_pred_var: float
+ """
+ def __init__(self, in_channels, middle_channels, out_channels, eps, init_pred_var, **kwargs):
+     super().__init__()
+     self.regressor = nn.Sequential(
+         nn.Conv2d(in_channels, middle_channels, kernel_size=1, stride=1, padding=0, bias=False),
+         nn.ReLU(inplace=True),
+         nn.Conv2d(middle_channels, middle_channels, kernel_size=1, stride=1, padding=0, bias=False),
+         nn.ReLU(inplace=True),
+         nn.Conv2d(middle_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False),
+     )
+     self.soft_plus_param = \
+         nn.Parameter(np.log(np.exp(init_pred_var - eps) - 1.0) * torch.ones(out_channels))
+     self.eps = eps
+     self.init_pred_var = init_pred_var
+
+ def forward(self, student_feature_map):
+     pred_mean = self.regressor(student_feature_map)
+     pred_var = torch.log(1.0 + torch.exp(self.soft_plus_param)) + self.eps
+     pred_var = pred_var.view(1, -1, 1, 1)
+     return pred_mean, pred_var
+
+
+
+
+[docs]
+@register_auxiliary_model_wrapper
+class VariationalDistributor4VID(AuxiliaryModelWrapper):
+"""
+ An auxiliary student model wrapper for variational information distillation (VID), including regressor :class:`Regressor4VID`.
+
+ Sungsoo Ahn, Shell Xu Hu, Andreas Damianou, Neil D. Lawrence, Zhenwen Dai: `"Variational Information Distillation for Knowledge Transfer" <https://openaccess.thecvf.com/content_CVPR_2019/html/Ahn_Variational_Information_Distillation_for_Knowledge_Transfer_CVPR_2019_paper.html>`_ @ CVPR 2019 (2019)
+
+ :param student_model: student model.
+ :type student_model: nn.Module
+ :param regressors: regressor keys and configurations.
+ :type regressors: dict
+ :param device: target device.
+ :type device: torch.device
+ :param device_ids: target device IDs.
+ :type device_ids: list[int]
+ :param distributed: whether to be in distributed training mode.
+ :type distributed: bool
+ :param find_unused_parameters: ``find_unused_parameters`` for DistributedDataParallel.
+ :type find_unused_parameters: bool or None
+ """
+ def __init__(self, student_model, regressors, device, device_ids, distributed, find_unused_parameters=None,
+              **kwargs):
+     super().__init__()
+     self.student_model = wrap_if_distributed(student_model, device, device_ids, distributed, find_unused_parameters)
+     io_path_pairs = list()
+     self.regressor_dict = nn.ModuleDict()
+     for regressor_key, regressor_config in regressors.items():
+         regressor = Regressor4VID(**regressor_config['kwargs'])
+         self.regressor_dict[regressor_key] = \
+             wrap_if_distributed(regressor, device, device_ids, distributed, find_unused_parameters)
+         io_path_pairs.append((regressor_key, regressor_config['io'], regressor_config['path']))
+     self.io_path_pairs = io_path_pairs
+
+ def forward(self, x):
+     return self.student_model(x)
+
+ def secondary_forward(self, io_dict):
+     for regressor_key, io_type, module_path in self.io_path_pairs:
+         self.regressor_dict[regressor_key](io_dict[module_path][io_type])
+[docs]
+@register_auxiliary_model_wrapper
+class Linear4CRD(AuxiliaryModelWrapper):
+"""
+ An auxiliary teacher/student model wrapper for contrastive representation distillation (CRD), including normalizer :class:`Normalizer4CRD`.
+ Refactored https://github.com/HobbitLong/RepDistiller/blob/master/crd/memory.py
+
+ Yonglong Tian, Dilip Krishnan, Phillip Isola: `"Contrastive Representation Distillation" <https://openreview.net/forum?id=SkgpBJrtvS>`_ @ ICLR 2020 (2020)
+
+ :param input_module_path: path of module whose output will be flattened and then used as input to normalizer.
+ :type input_module_path: str
+ :param linear_kwargs: kwargs for Linear.
+ :type linear_kwargs: dict
+ :param device: target device.
+ :type device: torch.device
+ :param device_ids: target device IDs.
+ :type device_ids: list[int]
+ :param distributed: whether to be in distributed training mode.
+ :type distributed: bool
+ :param power: ``power`` for :class:`Normalizer4CRD`.
+ :type power: int
+ :param teacher_model: teacher model.
+ :type teacher_model: nn.Module or None
+ :param student_model: student model.
+ :type student_model: nn.Module or None
+ :param find_unused_parameters: ``find_unused_parameters`` for DistributedDataParallel.
+ :type find_unused_parameters: bool or None
+ """
+ def __init__(self, input_module_path, linear_kwargs, device, device_ids, distributed, power=2,
+              teacher_model=None, student_model=None, find_unused_parameters=None, **kwargs):
+     super().__init__()
+     is_teacher = teacher_model is not None
+     if not is_teacher:
+         student_model = wrap_if_distributed(student_model, device, device_ids, distributed, find_unused_parameters)
+
+     self.model = teacher_model if is_teacher else student_model
+     self.is_teacher = is_teacher
+     self.empty = nn.Sequential()
+     self.input_module_path = input_module_path
+     linear = nn.Linear(**linear_kwargs)
+     self.normalizer = wrap_if_distributed(Normalizer4CRD(linear, power=power), device, device_ids, distributed,
+                                           find_unused_parameters)
+
+ def forward(self, x, supp_dict):
+     # supp_dict is given to be hooked and stored in io_dict
+     self.empty(supp_dict)
+     if self.is_teacher:
+         with torch.no_grad():
+             return self.model(x)
+     return self.model(x)
+
+ def secondary_forward(self, io_dict):
+     flat_outputs = torch.flatten(io_dict[self.input_module_path]['output'], 1)
+     self.normalizer(flat_outputs)
+
+
+
+
+[docs]
+@register_auxiliary_model_wrapper
+class HeadRCNN(AuxiliaryModelWrapper):
+"""
+ An auxiliary teacher/student model wrapper for head network distillation (HND) and generalized head network distillation (GHND).
+
+ * Yoshitomo Matsubara, Sabur Baidya, Davide Callegaro, Marco Levorato, Sameer Singh: `"Distilled Split Deep Neural Networks for Edge-Assisted Real-Time Systems" <https://dl.acm.org/doi/10.1145/3349614.3356022>`_ @ MobiCom 2019 Workshop on Hot Topics in Video Analytics and Intelligent Edges (2019)
+ * Yoshitomo Matsubara, Marco Levorato: `"Neural Compression and Filtering for Edge-assisted Real-time Object Detection in Challenged Networks" <https://arxiv.org/abs/2007.15818>`_ @ ICPR 2020 (2021)
+
+ :param head_rcnn: head R-CNN configuration as ``model_config`` in :meth:`torchdistill.models.util.redesign_model`.
+ :type head_rcnn: dict
+ :param kwargs: ``teacher_model`` or ``student_model`` keys must be included. If both ``teacher_model`` and ``student_model`` are provided, ``student_model`` will be prioritized.
+ :type kwargs: dict
+ """
+    def __init__(self, head_rcnn, **kwargs):
+        super().__init__()
+        tmp_ref_model = kwargs.get('teacher_model', None)
+        ref_model = kwargs.get('student_model', tmp_ref_model)
+        if ref_model is None:
+            raise ValueError('Either student_model or teacher_model has to be given.')
+
+        self.transform = ref_model.transform
+        self.seq = redesign_model(ref_model, head_rcnn, 'R-CNN', 'HeadRCNN')
+
+    def forward(self, images, targets=None):
+        original_image_sizes = torch.jit.annotate(List[Tuple[int, int]], [])
+        for img in images:
+            val = img.shape[-2:]
+            assert len(val) == 2
+            original_image_sizes.append((val[0], val[1]))
+
+        images, targets = self.transform(images, targets)
+        return self.seq(images.tensors)
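+
+ A hypothetical instantiation sketch (not part of the original source): building a HeadRCNN from a torchvision Faster R-CNN teacher. The head_rcnn dict is meant to follow the model_config format of torchdistill.models.util.redesign_model; the 'sequential' key and the submodule kept here are assumptions made purely for illustration.
+
+ >>> from torchvision.models.detection import fasterrcnn_resnet50_fpn
+ >>> from torchdistill.models.wrapper import HeadRCNN
+ >>>
+ >>> teacher = fasterrcnn_resnet50_fpn(weights=None)
+ >>> # 'sequential' and the kept submodule are illustrative assumptions
+ >>> head_rcnn_config = {'sequential': ['backbone.body']}
+ >>> head = HeadRCNN(head_rcnn_config, teacher_model=teacher)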
+[docs]
+def register_optimizer(arg=None, **kwargs):
+"""
+ Registers an optimizer class or function to instantiate it.
+
+ :param arg: class or function to be registered as an optimizer.
+ :type arg: class or typing.Callable or None
+ :return: registered optimizer class or function to instantiate it.
+ :rtype: class or typing.Callable
+
+ .. note::
+ The optimizer will be registered as an option.
+ You can choose the registered class/function by specifying the name of the class/function or ``key``
+ you used for the registration, in a training configuration used for
+ :class:`torchdistill.core.distillation.DistillationBox` or :class:`torchdistill.core.training.TrainingBox`.
+
+ If you want to register the class/function with a key of your choice, add ``key`` to the decorator as below:
+
+ >>> from torch.optim import Optimizer
+ >>> from torchdistill.optim.registry import register_optimizer
+ >>>
+ >>> @register_optimizer(key='my_custom_optimizer')
+ >>> class CustomOptimizer(Optimizer):
+ >>> def __init__(self, **kwargs):
+ >>> print('This is my custom optimizer class')
+
+        In the example, the ``CustomOptimizer`` class is registered with the key "my_custom_optimizer".
+        When you configure :class:`torchdistill.core.distillation.DistillationBox` or
+        :class:`torchdistill.core.training.TrainingBox`, you can choose the ``CustomOptimizer`` class by
+        specifying "my_custom_optimizer".
+ """
+    def _register_optimizer(cls_or_func):
+        key = kwargs.get('key')
+        if key is None:
+            key = cls_or_func.__name__
+
+        OPTIM_DICT[key] = cls_or_func
+        return cls_or_func
+
+    if callable(arg):
+        return _register_optimizer(arg)
+    return _register_optimizer
+
+
+
+
+[docs]
+def register_scheduler(arg=None, **kwargs):
+"""
+ Registers a scheduler class or function to instantiate it.
+
+ :param arg: class or function to be registered as a scheduler.
+ :type arg: class or typing.Callable or None
+ :return: registered scheduler class or function to instantiate it.
+ :rtype: class or typing.Callable
+
+ .. note::
+ The scheduler will be registered as an option.
+ You can choose the registered class/function by specifying the name of the class/function or ``key``
+ you used for the registration, in a training configuration used for
+ :class:`torchdistill.core.distillation.DistillationBox` or :class:`torchdistill.core.training.TrainingBox`.
+
+ If you want to register the class/function with a key of your choice, add ``key`` to the decorator as below:
+
+ >>> from torch.optim.lr_scheduler import LRScheduler
+ >>> from torchdistill.optim.registry import register_scheduler
+ >>>
+ >>> @register_scheduler(key='my_custom_scheduler')
+ >>> class CustomScheduler(LRScheduler):
+ >>> def __init__(self, **kwargs):
+ >>> print('This is my custom scheduler class')
+
+        In the example, the ``CustomScheduler`` class is registered with the key "my_custom_scheduler".
+        When you configure :class:`torchdistill.core.distillation.DistillationBox` or
+        :class:`torchdistill.core.training.TrainingBox`, you can choose the ``CustomScheduler`` class by
+        specifying "my_custom_scheduler".
+ """
+    def _register_scheduler(cls_or_func):
+        key = kwargs.get('key')
+        if key is None:
+            key = cls_or_func.__name__
+
+        SCHEDULER_DICT[key] = cls_or_func
+        return cls_or_func
+
+    if callable(arg):
+        return _register_scheduler(arg)
+    return _register_scheduler
+
+
+
+
+[docs]
+def get_optimizer(module, key, filters_params=True, *args, **kwargs):
+"""
+ Gets an optimizer from the optimizer registry.
+
+ :param module: module to be added to optimizer.
+ :type module: nn.Module
+ :param key: optimizer key.
+ :type key: str
+    :param filters_params: if True, filters out parameters whose ``requires_grad`` is False.
+ :type filters_params: bool
+ :return: optimizer.
+ :rtype: Optimizer
+ """
+    is_module = isinstance(module, nn.Module)
+    if key in OPTIM_DICT:
+        optim_cls_or_func = OPTIM_DICT[key]
+        if is_module and filters_params:
+            params = module.parameters() if is_module else module
+            updatable_params = [p for p in params if p.requires_grad]
+            return optim_cls_or_func(updatable_params, *args, **kwargs)
+        return optim_cls_or_func(module, *args, **kwargs)
+    raise ValueError('No optimizer `{}` registered'.format(key))
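+
+ A minimal usage sketch (not part of the original docstring): torch.optim.SGD is registered explicitly under a custom key purely for illustration (torchdistill may already register PyTorch optimizers by default) and then retrieved with get_optimizer.
+
+ >>> import torch.nn as nn
+ >>> from torch.optim import SGD
+ >>> from torchdistill.optim.registry import register_optimizer, get_optimizer
+ >>>
+ >>> # register torch.optim.SGD under a custom key (illustration only)
+ >>> register_optimizer(key='my_sgd')(SGD)
+ >>> model = nn.Linear(10, 2)
+ >>> # filters_params=True keeps only parameters with requires_grad=True
+ >>> optimizer = get_optimizer(model, 'my_sgd', lr=0.1, momentum=0.9)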
+
+
+
+
+[docs]
+def get_scheduler(optimizer, key, *args, **kwargs):
+"""
+ Gets a scheduler from the scheduler registry.
+
+ :param optimizer: optimizer to be added to scheduler.
+ :type optimizer: Optimizer
+ :param key: scheduler key.
+ :type key: str
+ :return: scheduler.
+ :rtype: LRScheduler
+ """
+    if key in SCHEDULER_DICT:
+        return SCHEDULER_DICT[key](optimizer, *args, **kwargs)
+    raise ValueError('No scheduler `{}` registered'.format(key))
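+
+ A companion sketch (not part of the original docstring): torch.optim.lr_scheduler.StepLR is registered under a custom key for illustration and then attached to an optimizer via get_scheduler.
+
+ >>> import torch.nn as nn
+ >>> from torch.optim import SGD
+ >>> from torch.optim.lr_scheduler import StepLR
+ >>> from torchdistill.optim.registry import register_scheduler, get_scheduler
+ >>>
+ >>> # register StepLR under a custom key (illustration only)
+ >>> register_scheduler(key='my_step_lr')(StepLR)
+ >>> optimizer = SGD(nn.Linear(10, 2).parameters(), lr=0.1)
+ >>> scheduler = get_scheduler(optimizer, 'my_step_lr', step_size=30, gamma=0.1)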
torchdistill (formerly kdkit) offers various state-of-the-art knowledge distillation methods
+and enables you to design (new) experiments simply by editing a declarative yaml config file instead of Python code.
+Even when you need to extract intermediate representations in teacher/student models,
+you will NOT need to reimplement the models, which often requires changing the interface of their forward methods;
+instead, you simply specify the module path(s) in the yaml file.
+
In addition to knowledge distillation, this framework helps you design and perform general deep learning experiments
+(WITHOUT coding) for reproducible deep learning studies, i.e., it enables you to train models without teachers
+simply by excluding teacher entries from a declarative yaml config file.
+You can find such examples in configs/sample/ of the official repository.
+
When you refer to torchdistill in your paper, please cite these papers
+instead of this GitHub repository.
+If you use torchdistill as part of your work, your citation is appreciated and motivates me to maintain and upgrade this framework!
@inproceedings{matsubara2021torchdistill,
+title={{torchdistill: A Modular, Configuration-Driven Framework for Knowledge Distillation}},
+author={Matsubara, Yoshitomo},
+booktitle={International Workshop on Reproducible Research in Pattern Recognition},
+pages={24--44},
+year={2021},
+organization={Springer}
+}
+
+@inproceedings{matsubara2023torchdistill,
+title={{torchdistill Meets Hugging Face Libraries for Reproducible, Coding-Free Deep Learning Studies: A Case Study on NLP}},
+author={Matsubara, Yoshitomo},
+booktitle={Proceedings of the 3rd Workshop for Natural Language Processing Open Source Software (NLP-OSS 2023)},
+publisher={Empirical Methods in Natural Language Processing},
+pages={153--164},
+year={2023}
+}
+
If you have a question or feature request, start a new “Q&A” discussion on GitHub instead of opening a GitHub issue.
+Please make sure your question/request has not already been addressed by searching through the open/closed issues and discussions.
This page is a showcase of OSS (open source software) and papers that have used torchdistill in their projects.
+If your work is built on torchdistill, start a “Show and tell” discussion on GitHub.
This framework was built on PyTorch and designed to benchmark SC2 methods, Supervised Compression for Split Computing.
+It is pip-installable and published as a PyPI package, i.e., you can install it with pip3 install sc2bench
Abstract: In this paper we revisit the efficacy of knowledge distillation as a function matching and metric learning problem. In doing so we verify three important design decisions, namely the normalisation, soft maximum function, and projection layers as key ingredients. We theoretically show that the projector implicitly encodes information on past examples, enabling relational gradients for the student. We then show that the normalisation of representations is tightly coupled with the training dynamics of this projector, which can have a large impact on the students performance. Finally, we show that a simple soft maximum function can be used to address any significant capacity gap problems. Experimental results on various benchmark datasets demonstrate that using these insights can lead to superior or comparable performance to state-of-the-art knowledge distillation techniques, despite being much more computationally efficient. In particular, we obtain these results across image classification (CIFAR100 and ImageNet), object detection (COCO2017), and on more difficult distillation objectives, such as training data efficient transformers, whereby we attain a 77.2% top-1 accuracy with DeiT-Ti on ImageNet. Code and models are publicly available.
+
+
+
FrankenSplit: Efficient Neural Feature Compression With Shallow Variational Bottleneck Injection for Mobile Edge Computing
+
+
Author(s): Alireza Furutanpey, Philipp Raith, Schahram Dustdar
Abstract: The rise of mobile AI accelerators allows latency-sensitive applications to execute lightweight Deep Neural Networks (DNNs) on the client side. However, critical applications require powerful models that edge devices cannot host and must therefore offload requests, where the high-dimensional data will compete for limited bandwidth. Split Computing (SC) alleviates resource inefficiency by partitioning DNN layers across devices, but current methods are overly specific and only marginally reduce bandwidth consumption. This work proposes shifting away from focusing on executing shallow layers of partitioned DNNs. Instead, it advocates concentrating the local resources on variational compression optimized for machine interpretability. We introduce a novel framework for resource-conscious compression models and extensively evaluate our method in an environment reflecting the asymmetric resource distribution between edge devices and servers. Our method achieves 60% lower bitrate than a state-of-the-art SC method without decreasing accuracy and is up to 16x faster than offloading with existing codec standards.
+
+
+
torchdistill Meets Hugging Face Libraries for Reproducible, Coding-Free Deep Learning Studies: A Case Study on NLP
+
+
Author(s): Yoshitomo Matsubara
+
Venue: EMNLP 2023 Workshop for Natural Language Processing Open Source Software (NLP-OSS)
Abstract: Reproducibility in scientific work has been becoming increasingly important in research communities
+such as machine learning, natural language processing, and computer vision communities due to the rapid development of
+the research domains supported by recent advances in deep learning. In this work, we present a significantly upgraded
+version of torchdistill, a modular-driven coding-free deep learning framework significantly upgraded from the initial
+release, which supports only image classification and object detection tasks for reproducible knowledge distillation
+experiments. To demonstrate that the upgraded framework can support more tasks with third-party libraries, we reproduce
+the GLUE benchmark results of BERT models using a script based on the upgraded torchdistill, harmonizing with various
+Hugging Face libraries. All the 27 fine-tuned BERT models and configurations to reproduce the results are published at
+Hugging Face, and the model weights have already been widely used in research communities. We also reimplement popular
+small-sized models and new knowledge distillation methods and perform additional experiments for computer vision tasks.
+
+
+
SC2 Benchmark: Supervised Compression for Split Computing
+
+
Author(s): Yoshitomo Matsubara, Ruihan Yang, Marco Levorato, Stephan Mandt
Abstract: With the increasing demand for deep learning models on mobile devices, splitting neural network
+computation between the device and a more powerful edge server has become an attractive solution. However, existing
+split computing approaches often underperform compared to a naive baseline of remote computation on compressed data.
+Recent studies propose learning compressed representations that contain more relevant information for supervised
+downstream tasks, showing improved tradeoffs between compressed data size and supervised performance. However, existing
+evaluation metrics only provide an incomplete picture of split computing. This study introduces supervised compression
+for split computing (SC2) and proposes new evaluation criteria: minimizing computation on the mobile device, minimizing
+transmitted data size, and maximizing model accuracy. We conduct a comprehensive benchmark study using 10 baseline
+methods, three computer vision tasks, and over 180 trained models, and discuss various aspects of SC2. We also release
+our code and sc2bench, a Python package for future research on SC2. Our proposed metrics and package will help
+researchers better understand the tradeoffs of supervised compression in split computing.
+
+
+
Supervised Compression for Resource-Constrained Edge Computing Systems
+
+
Author(s): Yoshitomo Matsubara, Ruihan Yang, Marco Levorato, Stephan Mandt
Abstract: There has been much interest in deploying deep learning algorithms on low-powered devices, including
+smartphones, drones, and medical sensors. However, full-scale deep neural networks are often too resource-intensive
+in terms of energy and storage. As a result, the bulk part of the machine learning operation is therefore often
+carried out on an edge server, where the data is compressed and transmitted. However, compressing data (such as images)
+leads to transmitting information irrelevant to the supervised task. Another popular approach is to split the deep
+network between the device and the server while compressing intermediate features. To date, however, such split
+computing strategies have barely outperformed the aforementioned naive data compression baselines due to their
+inefficient approaches to feature compression. This paper adopts ideas from knowledge distillation and neural image
+compression to compress intermediate feature representations more efficiently. Our supervised compression approach
+uses a teacher model and a student model with a stochastic bottleneck and learnable prior for entropy coding
+(Entropic Student). We compare our approach to various neural image and feature compression baselines in three vision
+tasks and found that it achieves better supervised rate-distortion performance while maintaining smaller end-to-end
+latency. We furthermore show that the learned feature representations can be tuned to serve multiple downstream tasks.
+
+
+
torchdistill: A Modular, Configuration-Driven Framework for Knowledge Distillation
+
+
Author(s): Yoshitomo Matsubara
+
Venue: ICPR 2020 International Workshop on Reproducible Research in Pattern Recognition
Abstract: While knowledge distillation (transfer) has been attracting attentions from the research community,
+the recent development in the fields has heightened the need for reproducible studies and highly generalized frameworks
+to lower barriers to such high-quality, reproducible deep learning research. Several researchers voluntarily published
+frameworks used in their knowledge distillation studies to help other interested researchers reproduce their original
+work. Such frameworks, however, are usually neither well generalized nor maintained, thus researchers are still
+required to write a lot of code to refactor/build on the frameworks for introducing new methods, models, datasets and
+designing experiments. In this paper, we present our developed open-source framework built on PyTorch and dedicated for
+knowledge distillation studies. The framework is designed to enable users to design experiments by declarative PyYAML
+configuration files, and helps researchers complete the recently proposed ML Code Completeness Checklist. Using the
+developed framework, we demonstrate its various efficient training strategies, and implement a variety of knowledge
+distillation methods. We also reproduce some of their original experimental results on the ImageNet and COCO datasets
+presented at major machine learning conferences such as ICLR, NeurIPS, CVPR and ECCV, including recent state-of-the-art
+methods. All the source code, configurations, log files and trained model weights are publicly available at
+https://github.com/yoshitomo-matsubara/torchdistill.
+
+
+
\ No newline at end of file
diff --git a/py-modindex.html b/py-modindex.html
new file mode 100644
index 00000000..9823d6c3
--- /dev/null
+++ b/py-modindex.html
@@ -0,0 +1,318 @@
+ Python Module Index — torchdistill v1.1.2 documentation
If you want to register the function with a key of your choice, add key to the decorator as below:
+
>>> from torchdistill.core.interfaces.registry import register_pre_epoch_proc_func
+>>> @register_pre_epoch_proc_func(key='my_custom_pre_epoch_proc_func')
+>>> def new_pre_epoch_proc(self, epoch=None, **kwargs):
+>>>     print('This is my custom pre-epoch process function')
+
If you want to register the function with a key of your choice, add key to the decorator as below:
+
>>> from torchdistill.core.interfaces.registry import register_pre_forward_proc_func
+>>> @register_pre_forward_proc_func(key='my_custom_pre_forward_proc_func')
+>>> def new_pre_forward_proc(self, *args, **kwargs):
+>>>     print('This is my custom pre-forward process function')
+
If you want to register the function with a key of your choice, add key to the decorator as below:
+
>>> from torchdistill.core.interfaces.registry import register_forward_proc_func
+>>> @register_forward_proc_func(key='my_custom_forward_proc_func')
+>>> def new_forward_proc(model, sample_batch, targets=None, supp_dict=None, **kwargs):
+>>>     print('This is my custom forward process function')
+
If you want to register the function with a key of your choice, add key to the decorator as below:
+
>>> from torchdistill.core.interfaces.registry import register_post_forward_proc_func
+>>> @register_post_forward_proc_func(key='my_custom_post_forward_proc_func')
+>>> def new_post_forward_proc(self, loss, metrics=None, **kwargs):
+>>>     print('This is my custom post-forward process function')
+
If you want to register the function with a key of your choice, add key to the decorator as below:
+
>>> from torchdistill.core.interfaces.registry import register_post_epoch_proc_func
+>>> @register_post_epoch_proc_func(key='my_custom_post_epoch_proc_func')
+>>> def new_post_epoch_proc(self, metrics=None, **kwargs):
+>>>     print('This is my custom post-epoch process function')
+
Sets up training and validation data loaders for the current training stage.
+This method will be internally called when instantiating this class and when calling
+MultiStagesTrainingBox.advance_to_next_stage().
Sets up a training loss module for the current training stage.
+This method will be internally called when instantiating this class and when calling
+MultiStagesTrainingBox.advance_to_next_stage().
Sets up pre/post-epoch/forward processes for the current training stage.
+This method will be internally called when instantiating this class and when calling
+MultiStagesTrainingBox.advance_to_next_stage().
Sets up training and validation data loaders for the current training stage.
+This method will be internally called when instantiating this class and when calling
+MultiStagesDistillationBox.advance_to_next_stage().
Sets up teacher and student models for the current training stage.
+This method will be internally called when instantiating this class and when calling
+MultiStagesDistillationBox.advance_to_next_stage().
Sets up a training loss module for the current training stage.
+This method will be internally called when instantiating this class and when calling
+MultiStagesDistillationBox.advance_to_next_stage().
Sets up pre/post-epoch/forward processes for the current training stage.
+This method will be internally called when instantiating this class and when calling
+MultiStagesDistillationBox.advance_to_next_stage().
If you want to register the function with a key of your choice, add key to the decorator as below:
+
>>> from torchdistill.losses.registry import register_func2extract_model_output
+>>>
+>>> @register_func2extract_model_output(key='my_custom_function2extract_model_output')
+>>> def custom_func2extract_model_output(batch, label):
+>>>     print('This is my custom function to extract model output')
+>>>     return batch, label
+
+
+
In the example, the custom_func2extract_model_output function is registered with the key “my_custom_function2extract_model_output”.
+When you configure torchdistill.core.distillation.DistillationBox or
+torchdistill.core.training.TrainingBox, you can choose the custom_func2extract_model_output function by
+specifying “my_custom_function2extract_model_output”.
A weighted sum (linear combination) of mid-/low-level loss modules.
+
If model_term contains a numerical value under its weight key, that value is used as a multiplier \(W_{model}\)
+for the sum of model-driven loss values \(\sum_{i} L_{model, i}\).
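+
+ Writing the description above as a formula (a sketch inferred from the text, where \(w_{k}\) denotes the weight configured for the \(k\)-th sub-loss \(L_{k}\)):
+
+ \[ L_{total} = \sum_{k} w_{k} L_{k} + W_{model} \sum_{i} L_{model, i} \]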
A dict-based wrapper module designed to use low-level loss modules (e.g., loss modules in PyTorch)
+in torchdistill’s pipelines. This is a subclass of SimpleLossWrapper and useful for models whose forward
+output is a dict.
+
+
Parameters:
+
+
low_level_loss (nn.Module) – low-level loss module e.g., torch.nn.CrossEntropyLoss.
+
weights (dict) – dict whose keys match the model’s output dict keys and whose values are the corresponding loss weights.
+
kwargs (dict or None) – kwargs to configure what the wrapper passes to low_level_loss.
+
+
+
+
+
An example YAML to instantiate DictLossWrapper for deeplabv3_resnet50 in torchvision, whose default output is a dict of outputs from its main and auxiliary branches with keys ‘out’ and ‘aux’ respectively.
kernel_config (dict) – kernel (‘gaussian’ or ‘bilinear’) configuration.
+
reduction (str) – loss reduction type.
+
+
+
+
+
An example YAML to instantiate CCKDLoss for a teacher-student pair of ResNet-50 and ResNet-18 in torchvision, using an auxiliary module torchdistill.models.wrapper.Linear4CCKD for the teacher and student models.
num_negative_samples (int) – number of negative samples.
+
num_samples (int) – number of samples.
+
temperature (float) – temperature to adjust concentration level (not the temperature for KDLoss).
+
momentum (float) – momentum.
+
eps (float) – eps.
+
+
+
+
+
An example YAML to instantiate CRDLoss for a teacher-student pair of ResNet-34 and ResNet-18 in torchvision, using an auxiliary module torchdistill.models.wrapper.Linear4CRD for the teacher and student models.
A loss module for self-supervision knowledge distillation (SSKD) that treats contrastive prediction as
+a self-supervision task (auxiliary task). This loss module is used at the 1st stage of SSKD method.
+Refactored https://github.com/xuguodong03/SSKD/blob/master/student.py
kl_temp (float) – temperature to soften teacher and student’s class-probability distributions for KL divergence given original data.
+
ss_temp (float) – temperature to soften teacher and student’s self-supervision cosine similarities for KL divergence.
+
tf_temp (float) – temperature to soften teacher and student’s class-probability distributions for KL divergence given augmented data by transform.
+
ss_ratio (float) – ratio of samples with the smallest error levels used for self-supervision.
+
tf_ratio (float) – ratio of samples with the smallest error levels used for transform.
+
student_linear_module_io (str) – ‘input’ or ‘output’ of the linear module in the student model.
+
teacher_linear_module_io (str) – ‘input’ or ‘output’ of the linear module in the teacher model.
+
student_ss_module_io (str) – ‘input’ or ‘output’ of the self-supervision module in the student model.
+
teacher_ss_module_io (str) – ‘input’ or ‘output’ of the self-supervision module in the teacher model.
+
loss_weights (list[float] or None) – weights for 1) cross-entropy, 2) KL divergence for the original data, 3) KL divergence for self-supervision cosine similarities, and 4) KL divergence for the augmented data by transform.
+
reduction (str or None) – reduction for KLDivLoss. If reduction = ‘batchmean’, CrossEntropyLoss’s reduction will be ‘mean’.
+
+
+
+
+
An example YAML to instantiate SSKDLoss for a teacher-student pair of ResNet-34 and ResNet-18 in torchvision, using an auxiliary module torchdistill.models.wrapper.SSWrapper4SSKD for the teacher and student models.
feature_pairs (dict) – configuration of teacher-student module pairs to compute the L2 distance between the inter-channel correlation matrices of the student and the teacher.
An empty auxiliary model wrapper. This module returns input as output and is useful when you want to replace
+your teacher/student model with an empty model for saving inference time.
+e.g., Multi-stage knowledge distillation may have some stages that do not require either teacher or student models.
Translator for factor transfer (FT). This module is used at the 2nd stage of FT method.
+Note that “the student translator has the same three convolution layers as the paraphraser”.
An auxiliary teacher/student model wrapper for correlation congruence for knowledge distillation (CCKD).
+Fully-connected layers cope with a mismatch of feature representations of teacher and student models.
kwargs (dict) – teacher_model or student_model keys must be included. If both teacher_model and student_model are provided, student_model will be prioritized.
An auxiliary teacher/student model wrapper for self-supervision knowledge distillation (SSKD).
+If both teacher_model and student_model are provided, student_model will be prioritized
Step 2: Register the module, e.g., apply a registry function to the module as a Python decorator
+
Step 3: Run your script with a yaml file containing the module name (key) and parameters; when the module is imported, the Python decorator is called and your module becomes available in the registry
+
+
Steps 1 and 2: Create a Python file (e.g., my_module.py) containing your own module (e.g., “MyNewCoolModel”) with a Python decorator “register_model” (see the sketch after these steps)
Step 3: Run your script (e.g., example/torchvision/image_classification.py) with a yaml containing the registered module name (“MyNewCoolModel”) and parameters (“some_value”, “some_list”, “some_dict”)
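+
+ A minimal sketch of what the my_module.py described in Steps 1 and 2 could look like, assuming register_model from torchdistill.models.registry is the decorator meant above; the constructor parameters mirror the illustrative “some_value”, “some_list”, and “some_dict” arguments mentioned in Step 3, and the model body itself is a placeholder.
+
+ >>> from torch import nn
+ >>> from torchdistill.models.registry import register_model
+ >>>
+ >>> @register_model
+ >>> class MyNewCoolModel(nn.Module):
+ >>>     def __init__(self, some_value, some_list, some_dict, **kwargs):
+ >>>         super().__init__()
+ >>>         # these parameters come from the yaml config described in Step 3
+ >>>         print('some_value:', some_value)
+ >>>         print('some_list:', some_list)
+ >>>         print('some_dict:', some_dict)
+ >>>
+ >>>     def forward(self, x):
+ >>>         return x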