diff --git a/classy_vision/dataset/classy_kinetics400.py b/classy_vision/dataset/classy_kinetics400.py
index 9ce1d042d0..5484c81cc1 100644
--- a/classy_vision/dataset/classy_kinetics400.py
+++ b/classy_vision/dataset/classy_kinetics400.py
@@ -45,6 +45,7 @@ def __init__(
         audio_channels: int,
         step_between_clips: int,
         frame_rate: Optional[int],
+        dilation: int,
         clips_per_video: int,
         video_dir: str,
         extensions: List[str],
@@ -72,6 +73,7 @@ def __init__(
             step_between_clips: Number of frames between each clip.
             frame_rate: desired video frame rate. If None, keep orignal video
                 frame rate.
+            dilation: the spacing between adjacent sampled frames
             clips_per_video: Number of clips to sample from each video
             video_dir: path to video folder
             extensions: A list of file extensions, such as "avi" and "mp4". Only
@@ -104,6 +106,7 @@ def __init__(
             _video_min_dimension=video_min_dimension,
             _audio_samples=audio_samples,
             _audio_channels=audio_channels,
+            dilation=dilation,
         )
         metadata = dataset.metadata
         if metadata and not os.path.exists(metadata_filepath):
@@ -149,6 +152,7 @@ def from_config(cls, config: Dict[str, Any]) -> "Kinetics400Dataset":
             audio_samples,
             step_between_clips,
             frame_rate,
+            dilation,
             clips_per_video,
         ) = cls.parse_config(config)
         extensions = config.get("extensions", ("mp4"))
@@ -169,6 +173,7 @@ def from_config(cls, config: Dict[str, Any]) -> "Kinetics400Dataset":
             audio_channels,
             step_between_clips,
             frame_rate,
+            dilation,
             clips_per_video,
             config["video_dir"],
             extensions,
diff --git a/classy_vision/dataset/classy_video_dataset.py b/classy_vision/dataset/classy_video_dataset.py
index 9e8385696d..a1e3e02c09 100644
--- a/classy_vision/dataset/classy_video_dataset.py
+++ b/classy_vision/dataset/classy_video_dataset.py
@@ -125,6 +125,7 @@ def parse_config(cls, config: Dict[str, Any]):
         audio_samples = config.get("audio_samples", 0)
         step_between_clips = config.get("step_between_clips", 1)
         frame_rate = config.get("frame_rate", None)
+        dilation = config.get("dilation", 1)
         clips_per_video = config.get("clips_per_video", 1)

         (
@@ -157,6 +158,7 @@ def parse_config(cls, config: Dict[str, Any]):
         audio_samples,
         step_between_clips,
         frame_rate,
+        dilation,
         clips_per_video,
     )
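The `dilation` key defaults to 1 in `parse_config`, so existing configs are unaffected. A minimal sketch of what the option means for frame sampling, assuming a clip of `frames_per_clip` frames starting at source frame `start`; the helper below is illustrative only and not part of the Classy Vision API:

def clip_frame_indices(start: int, frames_per_clip: int, dilation: int):
    # A clip still contains `frames_per_clip` frames, but adjacent sampled
    # frames are `dilation` apart, so the clip covers a longer time window.
    return [start + i * dilation for i in range(frames_per_clip)]

print(clip_frame_indices(0, 8, dilation=1))  # [0, 1, 2, 3, 4, 5, 6, 7]
print(clip_frame_indices(0, 8, dilation=2))  # [0, 2, 4, 6, 8, 10, 12, 14]

In a dataset config this sits next to the other keys read above, e.g. "dilation": 2.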
diff --git a/classy_vision/generic/profiler.py b/classy_vision/generic/profiler.py
index 4b1044e54a..d6b19802fc 100644
--- a/classy_vision/generic/profiler.py
+++ b/classy_vision/generic/profiler.py
@@ -102,7 +102,7 @@ def _get_batchsize_per_replica(x: Union[Tuple, List, Dict]) -> int:
     return x.size()[0]


-def _layer_flops(layer: nn.Module, x: Any, y: Any) -> int:
+def _layer_flops(layer: nn.Module, x: Any, y: Any, verbose: bool = False) -> int:
     """
     Computes the number of FLOPs required for a single layer.

@@ -162,6 +162,36 @@ def flops(self, x):
             / layer.groups
         )

+    # 3D convolution
+    elif layer_type in ["Conv3d"]:
+        out_t = int(
+            (x.size()[2] + 2 * layer.padding[0] - layer.kernel_size[0])
+            // layer.stride[0]
+            + 1
+        )
+        out_h = int(
+            (x.size()[3] + 2 * layer.padding[1] - layer.kernel_size[1])
+            // layer.stride[1]
+            + 1
+        )
+        out_w = int(
+            (x.size()[4] + 2 * layer.padding[2] - layer.kernel_size[2])
+            // layer.stride[2]
+            + 1
+        )
+        flops = (
+            batchsize_per_replica
+            * layer.in_channels
+            * layer.out_channels
+            * layer.kernel_size[0]
+            * layer.kernel_size[1]
+            * layer.kernel_size[2]
+            * out_t
+            * out_h
+            * out_w
+            / layer.groups
+        )
+
     # learned group convolution:
     elif layer_type in ["LearnedGroupConv"]:
         conv = layer.conv
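A quick sanity check of the Conv3d branch just added: the same arithmetic run on a toy layer (multiply-accumulates, with bias ignored and groups divided out as in the formula above). This is a stand-alone sketch, not part of the patch; the layer and input shape are made up:

import torch.nn as nn

x_shape = (2, 3, 8, 32, 32)  # (N, C_in, T, H, W)
conv = nn.Conv3d(3, 16, kernel_size=3, stride=1, padding=1)

out_t = (x_shape[2] + 2 * conv.padding[0] - conv.kernel_size[0]) // conv.stride[0] + 1
out_h = (x_shape[3] + 2 * conv.padding[1] - conv.kernel_size[1]) // conv.stride[1] + 1
out_w = (x_shape[4] + 2 * conv.padding[2] - conv.kernel_size[2]) // conv.stride[2] + 1

macs = (
    x_shape[0]  # batchsize_per_replica
    * conv.in_channels
    * conv.out_channels
    * conv.kernel_size[0]
    * conv.kernel_size[1]
    * conv.kernel_size[2]
    * out_t
    * out_h
    * out_w
    // conv.groups
)
print(out_t, out_h, out_w, macs)  # 8 32 32 21233664 (~21.2M MACs)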
@@ -186,51 +216,36 @@ def flops(self, x):
         )
         flops = count1 + count2

-    # non-linearities:
+    # non-linearities are not considered in MAC counting
     elif layer_type in ["ReLU", "ReLU6", "Tanh", "Sigmoid", "Softmax"]:
-        flops = x.numel()
-
-    # 2D pooling layers:
-    elif layer_type in ["AvgPool2d", "MaxPool2d"]:
-        in_h = x.size()[2]
-        in_w = x.size()[3]
-        if isinstance(layer.kernel_size, int):
-            layer.kernel_size = (layer.kernel_size, layer.kernel_size)
-        kernel_ops = layer.kernel_size[0] * layer.kernel_size[1]
-        out_h = 1 + int(
-            (in_h + 2 * layer.padding - layer.kernel_size[0]) / layer.stride
-        )
-        out_w = 1 + int(
-            (in_w + 2 * layer.padding - layer.kernel_size[1]) / layer.stride
+        flops = 0
+
+    elif layer_type in [
+        "MaxPool1d",
+        "MaxPool2d",
+        "MaxPool3d",
+        "AdaptiveMaxPool1d",
+        "AdaptiveMaxPool2d",
+        "AdaptiveMaxPool3d",
+    ]:
+        flops = 0
+
+    elif layer_type in ["AvgPool1d", "AvgPool2d", "AvgPool3d"]:
+        kernel_ops = 1
+        flops = kernel_ops * y.numel()
+
+    elif layer_type in ["AdaptiveAvgPool1d", "AdaptiveAvgPool2d", "AdaptiveAvgPool3d"]:
+        assert isinstance(layer.output_size, (list, tuple))
+        kernel = torch.Tensor(list(x.shape[2:])) // torch.Tensor(
+            [list(layer.output_size)]
         )
-        flops = x.size()[0] * x.size()[1] * out_w * out_h * kernel_ops
-
-    # adaptive avg pool2d
-    # This is approximate and works only for downsampling without padding
-    # based on aten/src/ATen/native/AdaptiveAveragePooling.cpp
-    elif layer_type in ["AdaptiveAvgPool2d"]:
-        in_h = x.size()[2]
-        in_w = x.size()[3]
-        if isinstance(layer.output_size, int):
-            out_h, out_w = layer.output_size, layer.output_size
-        elif len(layer.output_size) == 1:
-            out_h, out_w = layer.output_size[0], layer.output_size[0]
-        else:
-            out_h, out_w = layer.output_size
-        if out_h > in_h or out_w > in_w:
-            raise ClassyProfilerNotImplementedError(layer)
-        batchsize_per_replica = x.size()[0]
-        num_channels = x.size()[1]
-        kh = in_h - out_h + 1
-        kw = in_w - out_w + 1
-        kernel_ops = kh * kw
-        flops = batchsize_per_replica * num_channels * out_h * out_w * kernel_ops
+        kernel_ops = torch.prod(kernel)
+        flops = kernel_ops * y.numel()

     # linear layer:
     elif layer_type in ["Linear"]:
         weight_ops = layer.weight.numel()
-        bias_ops = layer.bias.numel() if layer.bias is not None else 0
-        flops = x.size()[0] * (weight_ops + bias_ops)
+        flops = x.size()[0] * weight_ops

     # batch normalization / layer normalization:
     elif layer_type in [
@@ -240,94 +255,12 @@ def flops(self, x):
         "BatchNorm1d",
         "BatchNorm2d",
         "BatchNorm3d",
         "SyncBatchNorm",
         "LayerNorm",
     ]:
-        flops = 2 * x.numel()
-
-    # 3D convolution
-    elif layer_type in ["Conv3d"]:
-        out_t = int(
-            (x.size()[2] + 2 * layer.padding[0] - layer.kernel_size[0])
-            // layer.stride[0]
-            + 1
-        )
-        out_h = int(
-            (x.size()[3] + 2 * layer.padding[1] - layer.kernel_size[1])
-            // layer.stride[1]
-            + 1
-        )
-        out_w = int(
-            (x.size()[4] + 2 * layer.padding[2] - layer.kernel_size[2])
-            // layer.stride[2]
-            + 1
-        )
-        flops = (
-            batchsize_per_replica
-            * layer.in_channels
-            * layer.out_channels
-            * layer.kernel_size[0]
-            * layer.kernel_size[1]
-            * layer.kernel_size[2]
-            * out_t
-            * out_h
-            * out_w
-            / layer.groups
-        )
-
-    # 3D pooling layers
-    elif layer_type in ["AvgPool3d", "MaxPool3d"]:
-        in_t = x.size()[2]
-        in_h = x.size()[3]
-        in_w = x.size()[4]
-        if isinstance(layer.kernel_size, int):
-            layer.kernel_size = (
-                layer.kernel_size,
-                layer.kernel_size,
-                layer.kernel_size,
-            )
-        if isinstance(layer.padding, int):
-            layer.padding = (layer.padding, layer.padding, layer.padding)
-        if isinstance(layer.stride, int):
-            layer.stride = (layer.stride, layer.stride, layer.stride)
-        kernel_ops = layer.kernel_size[0] * layer.kernel_size[1] * layer.kernel_size[2]
-        out_t = 1 + int(
-            (in_t + 2 * layer.padding[0] - layer.kernel_size[0]) / layer.stride[0]
-        )
-        out_h = 1 + int(
-            (in_h + 2 * layer.padding[1] - layer.kernel_size[1]) / layer.stride[1]
-        )
-        out_w = 1 + int(
-            (in_w + 2 * layer.padding[2] - layer.kernel_size[2]) / layer.stride[2]
-        )
-        flops = batchsize_per_replica * x.size()[1] * out_t * out_h * out_w * kernel_ops
-
-    # adaptive avg pool3d
-    # This is approximate and works only for downsampling without padding
-    # based on aten/src/ATen/native/AdaptiveAveragePooling3d.cpp
-    elif layer_type in ["AdaptiveAvgPool3d"]:
-        in_t = x.size()[2]
-        in_h = x.size()[3]
-        in_w = x.size()[4]
-        out_t = layer.output_size[0]
-        out_h = layer.output_size[1]
-        out_w = layer.output_size[2]
-        if out_t > in_t or out_h > in_h or out_w > in_w:
-            raise ClassyProfilerNotImplementedError(layer)
-        batchsize_per_replica = x.size()[0]
-        num_channels = x.size()[1]
-        kt = in_t - out_t + 1
-        kh = in_h - out_h + 1
-        kw = in_w - out_w + 1
-        kernel_ops = kt * kh * kw
-        flops = (
-            batchsize_per_replica * num_channels * out_t * out_w * out_h * kernel_ops
-        )
+        # batchnorm can be merged into conv op. Thus, count 0 FLOPS
+        flops = 0

     # dropout layer
     elif layer_type in ["Dropout"]:
-        # At test time, we do not drop values but scale the feature map by the
-        # dropout ratio
-        flops = 1
-        for dim_size in x.size():
-            flops *= dim_size
+        flops = 0

     elif layer_type == "Identity":
         flops = 0
@@ -351,11 +284,14 @@ def flops(self, x):
         f"params(M): {count_params(layer) / 1e6}",
         f"flops(M): {int(flops) / 1e6}",
     ]
-    logging.debug("\t".join(message))
+    if verbose:
+        logging.info("\t".join(message))

     return flops


-def _layer_activations(layer: nn.Module, x: Any, out: Any) -> int:
+def _layer_activations(
+    layer: nn.Module, x: Any, out: Any, verbose: bool = False
+) -> int:
     """
     Computes the number of activations produced by a single layer.
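The adaptive-average-pool branches above estimate the per-output kernel as input_size // output_size in each spatial dimension, which is exact only when the output evenly divides the input. A small stand-alone illustration of that estimate (shapes chosen arbitrarily; not part of the patch):

import torch
import torch.nn as nn

x = torch.randn(1, 64, 56, 56)
pool = nn.AdaptiveAvgPool2d((7, 7))
y = pool(x)

kernel = torch.Tensor(list(x.shape[2:])) // torch.Tensor([list(pool.output_size)])
kernel_ops = torch.prod(kernel)  # 8 * 8 = 64 multiply-adds per output element
flops = kernel_ops * y.numel()   # 64 * (1 * 64 * 7 * 7) = 200704
print(int(flops))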
@@ -376,7 +312,8 @@ def activations(self, x, out):
         return 0

     message = [f"module: {typestr}", f"activations: {activations}"]
-    logging.debug("\t".join(message))
+    if verbose:
+        logging.info("\t".join(message))

     return activations
@@ -402,17 +339,19 @@ def summarize_profiler_info(prof: torch.autograd.profiler.profile) -> str:


 class ComplexityComputer:
-    def __init__(self, compute_fn: Callable, count_unique: bool):
+    def __init__(self, compute_fn: Callable, count_unique: bool, verbose: bool = False):
         self.compute_fn = compute_fn
         self.count_unique = count_unique
         self.count = 0
+        self.verbose = verbose
         self.seen_modules = set()

     def compute(self, layer: nn.Module, x: Any, out: Any, module_name: str):
         if self.count_unique and module_name in self.seen_modules:
             return
-        logging.debug(f"module name: {module_name}")
-        self.count += self.compute_fn(layer, x, out)
+        if self.verbose:
+            logging.info(f"module name: {module_name}")
+        self.count += self.compute_fn(layer, x, out, self.verbose)
         self.seen_modules.add(module_name)

     def reset(self):
@@ -498,6 +437,7 @@ def compute_complexity(
     input_key: Optional[Union[str, List[str]]] = None,
     patch_attr: str = None,
     compute_unique: bool = False,
+    verbose: bool = False,
 ) -> int:
     """
     Compute the complexity of a forward pass.
@@ -517,7 +457,7 @@
     else:
         input = get_model_dummy_input(model, input_shape, input_key)

-    complexity_computer = ComplexityComputer(compute_fn, compute_unique)
+    complexity_computer = ComplexityComputer(compute_fn, compute_unique, verbose)

     # measure FLOPs:
     modify_forward(model, complexity_computer, patch_attr=patch_attr)
@@ -540,7 +480,7 @@ def compute_flops(
     Compute the number of FLOPs needed for a forward pass.
     """
     return compute_complexity(
-        model, _layer_flops, input_shape, input_key, patch_attr="flops"
+        model, _layer_flops, input_shape, input_key, patch_attr="flops", verbose=True
     )


@@ -553,7 +493,12 @@ def compute_activations(
     Compute the number of activations created in a forward pass.
     """
     return compute_complexity(
-        model, _layer_activations, input_shape, input_key, patch_attr="activations"
+        model,
+        _layer_activations,
+        input_shape,
+        input_key,
+        patch_attr="activations",
+        verbose=True,
     )
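The `verbose` flag is threaded from `compute_complexity` down into the per-layer `compute_fn`, and `compute_flops` / `compute_activations` now turn it on, so the per-module breakdown is logged at INFO level. A minimal stand-alone sketch of that plumbing, calling `ComplexityComputer` directly on a single layer; the layer, shapes, and module name are made up, and normal callers would go through `compute_flops` instead:

import torch
import torch.nn as nn

from classy_vision.generic.profiler import ComplexityComputer, _layer_flops

layer = nn.Linear(128, 64)
x = torch.randn(4, 128)

computer = ComplexityComputer(_layer_flops, count_unique=True, verbose=True)
computer.compute(layer, x, layer(x), "fc")
print(computer.count)  # 4 * 128 * 64 = 32768 MACs (bias is no longer counted)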
diff --git a/classy_vision/optim/classy_optimizer.py b/classy_vision/optim/classy_optimizer.py
index 7be454bff7..9b61190c86 100644
--- a/classy_vision/optim/classy_optimizer.py
+++ b/classy_vision/optim/classy_optimizer.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.

+import copy
 from abc import ABC, abstractmethod
 from typing import Any, Callable, Dict, Optional, Union

@@ -221,7 +222,10 @@ def update_schedule_on_step(self, where: float) -> None:
     def _update_schedule(self) -> None:
         """Update the optimizer's parameters based on self.parameters."""
         for group in self.optimizer.param_groups:
-            group.update(self.parameters)
+            lr_scale = group.get("lr_scale", 1.0)
+            parameters = copy.deepcopy(self.parameters)
+            parameters["lr"] *= lr_scale
+            group.update(parameters)

         # Here there's an assumption that pytorch optimizer maintain the order
         # of param_groups and that frozen_param_groups were added before the
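What the change to `_update_schedule` does in effect: the scheduled hyperparameters are deep-copied per group and only the copy's `lr` is multiplied by that group's `lr_scale`, so `self.parameters` itself is never mutated. A pure-Python sketch of one update with two groups (the 0.1 scale is an arbitrary stand-in for `head_lr_scale`):

import copy

parameters = {"lr": 0.01, "weight_decay": 1e-4}  # stand-in for self.parameters
param_groups = [
    {"params": [], "lr_scale": 1.0},  # trunk
    {"params": [], "lr_scale": 0.1},  # heads
]
for group in param_groups:
    lr_scale = group.get("lr_scale", 1.0)
    scaled = copy.deepcopy(parameters)
    scaled["lr"] *= lr_scale
    group.update(scaled)

print([group["lr"] for group in param_groups])  # [0.01, 0.001]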
diff --git a/classy_vision/tasks/fine_tuning_task.py b/classy_vision/tasks/fine_tuning_task.py
index 04b15b5223..f97a4e18b4 100644
--- a/classy_vision/tasks/fine_tuning_task.py
+++ b/classy_vision/tasks/fine_tuning_task.py
@@ -8,11 +8,18 @@
 from classy_vision.generic.util import (
     load_and_broadcast_checkpoint,
+    split_batchnorm_params,
     update_classy_model,
 )
 from classy_vision.tasks import ClassificationTask, register_task


+def _exclude_parameters(params, params_to_exclude):
+    params_to_exclude_ids = [id(p) for p in params_to_exclude]
+    params = [p for p in params if id(p) not in params_to_exclude_ids]
+    return params
+
+
 @register_task("fine_tuning")
 class FineTuningTask(ClassificationTask):
     def __init__(self, *args, **kwargs):
@@ -41,6 +48,7 @@ def from_config(cls, config: Dict[str, Any]) -> "FineTuningTask":

         task.set_reset_heads(config.get("reset_heads", False))
         task.set_freeze_trunk(config.get("freeze_trunk", False))
+        task.set_head_lr_scale(config.get("head_lr_scale", 1.0))
         return task

     def set_pretrained_checkpoint(self, checkpoint_path: str) -> "FineTuningTask":
@@ -61,6 +69,10 @@ def set_freeze_trunk(self, freeze_trunk: bool) -> "FineTuningTask":
         self.freeze_trunk = freeze_trunk
         return self

+    def set_head_lr_scale(self, head_lr_scale: float) -> "FineTuningTask":
+        self.head_lr_scale = head_lr_scale
+        return self
+
     def _set_model_train_mode(self):
         phase = self.phases[self.phase_idx]
         self.loss.train(phase["train"])
@@ -74,6 +86,61 @@ def _set_model_train_mode(self):
         else:
             self.base_model.train(phase["train"])

+    def prepare_optimizer(self, optimizer, model, loss=None):
+        if not self.bn_weight_decay:
+            model_bn_params, model_params = split_batchnorm_params(model)
+            heads = model.get_heads()
+            heads_bn_params, heads_params = [], []
+            for _, block_heads in heads.items():
+                for head in block_heads:
+                    head_bn_params, head_params = split_batchnorm_params(head)
+                    heads_bn_params += head_bn_params
+                    heads_params += head_params
+
+            model_bn_params = _exclude_parameters(model_bn_params, heads_bn_params)
+            model_params = _exclude_parameters(model_params, heads_params)
+            if loss is not None:
+                bn_params_loss, params_loss = split_batchnorm_params(loss)
+                heads_bn_params += bn_params_loss
+                heads_params += params_loss
+            frozen_param_groups = []
+            if len(model_bn_params) > 0:
+                frozen_param_groups.append(
+                    {"params": model_bn_params, "weight_decay": 0}
+                )
+            if len(heads_bn_params) > 0:
+                frozen_param_groups.append(
+                    {
+                        "params": heads_bn_params,
+                        "weight_decay": 0,
+                        "lr_scale": self.head_lr_scale,
+                    }
+                )
+            param_groups = [
+                {"params": model_params},
+                {"params": heads_params, "lr_scale": self.head_lr_scale},
+            ]
+        else:
+            frozen_param_groups = None
+            model_params = model.parameters()
+            heads_params = []
+            heads = model.get_heads()
+            for _, block_heads in heads.items():
+                for head in block_heads:
+                    heads_params += head.parameters()
+            model_params = _exclude_parameters(model_params, heads_params)
+            if loss is not None:
+                heads_params += loss.parameters()
+
+            param_groups = [
+                {"params": model_params},
+                {"params": heads_params, "lr_scale": self.head_lr_scale},
+            ]
+
+        self.optimizer.set_param_groups(
+            param_groups=param_groups, frozen_param_groups=frozen_param_groups
+        )
+
     def prepare(self) -> None:
         super().prepare()
         if self.checkpoint_dict is None:
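A hedged sketch of how the new task option is wired: `head_lr_scale` is read from the task config, stored on the task, and attached as `lr_scale` to the head (and head batch-norm) parameter groups built in `prepare_optimizer`, so heads can be trained at a different rate than the trunk. The config fragment below is illustrative only; the keys shown are the ones read in `from_config` above and every value is made up:

fine_tuning_task_fragment = {
    "name": "fine_tuning",
    "reset_heads": True,
    "freeze_trunk": False,
    "head_lr_scale": 10.0,  # heads get 10x the scheduled trunk learning rate
    # ... model / loss / optimizer / dataset sections as in a regular task config
}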