Skip to content
This repository has been archived by the owner on Jul 1, 2024. It is now read-only.

expose dilation argument in Kinetics400Dataset class #557

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions classy_vision/dataset/classy_kinetics400.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def __init__(
audio_channels: int,
step_between_clips: int,
frame_rate: Optional[int],
dilation: int,
clips_per_video: int,
video_dir: str,
extensions: List[str],
Expand Down Expand Up @@ -72,6 +73,7 @@ def __init__(
step_between_clips: Number of frames between each clip.
frame_rate: desired video frame rate. If None, keep
orignal video frame rate.
dilation: the spacing between adjacent sampled frames
clips_per_video: Number of clips to sample from each video
video_dir: path to video folder
extensions: A list of file extensions, such as "avi" and "mp4". Only
Expand Down Expand Up @@ -104,6 +106,7 @@ def __init__(
_video_min_dimension=video_min_dimension,
_audio_samples=audio_samples,
_audio_channels=audio_channels,
dilation=dilation,
)
metadata = dataset.metadata
if metadata and not os.path.exists(metadata_filepath):
Expand Down Expand Up @@ -149,6 +152,7 @@ def from_config(cls, config: Dict[str, Any]) -> "Kinetics400Dataset":
audio_samples,
step_between_clips,
frame_rate,
dilation,
clips_per_video,
) = cls.parse_config(config)
extensions = config.get("extensions", ("mp4"))
Expand All @@ -169,6 +173,7 @@ def from_config(cls, config: Dict[str, Any]) -> "Kinetics400Dataset":
audio_channels,
step_between_clips,
frame_rate,
dilation,
clips_per_video,
config["video_dir"],
extensions,
Expand Down
2 changes: 2 additions & 0 deletions classy_vision/dataset/classy_video_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ def parse_config(cls, config: Dict[str, Any]):
audio_samples = config.get("audio_samples", 0)
step_between_clips = config.get("step_between_clips", 1)
frame_rate = config.get("frame_rate", None)
dilation = config.get("dilation", 1)
clips_per_video = config.get("clips_per_video", 1)

(
Expand Down Expand Up @@ -157,6 +158,7 @@ def parse_config(cls, config: Dict[str, Any]):
audio_samples,
step_between_clips,
frame_rate,
dilation,
clips_per_video,
)

Expand Down
211 changes: 77 additions & 134 deletions classy_vision/generic/profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def _get_batchsize_per_replica(x: Union[Tuple, List, Dict]) -> int:
return x.size()[0]


def _layer_flops(layer: nn.Module, x: Any, y: Any) -> int:
def _layer_flops(layer: nn.Module, x: Any, y: Any, verbose: bool=False) -> int:
"""
Computes the number of FLOPs required for a single layer.

Expand Down Expand Up @@ -162,6 +162,36 @@ def flops(self, x):
/ layer.groups
)

# 3D convolution
elif layer_type in ["Conv3d"]:
out_t = int(
(x.size()[2] + 2 * layer.padding[0] - layer.kernel_size[0])
// layer.stride[0]
+ 1
)
out_h = int(
(x.size()[3] + 2 * layer.padding[1] - layer.kernel_size[1])
// layer.stride[1]
+ 1
)
out_w = int(
(x.size()[4] + 2 * layer.padding[2] - layer.kernel_size[2])
// layer.stride[2]
+ 1
)
flops = (
batchsize_per_replica
* layer.in_channels
* layer.out_channels
* layer.kernel_size[0]
* layer.kernel_size[1]
* layer.kernel_size[2]
* out_t
* out_h
* out_w
/ layer.groups
)

# learned group convolution:
elif layer_type in ["LearnedGroupConv"]:
conv = layer.conv
Expand All @@ -186,51 +216,36 @@ def flops(self, x):
)
flops = count1 + count2

# non-linearities:
# non-linearities are not considered in MAC counting
elif layer_type in ["ReLU", "ReLU6", "Tanh", "Sigmoid", "Softmax"]:
flops = x.numel()

# 2D pooling layers:
elif layer_type in ["AvgPool2d", "MaxPool2d"]:
in_h = x.size()[2]
in_w = x.size()[3]
if isinstance(layer.kernel_size, int):
layer.kernel_size = (layer.kernel_size, layer.kernel_size)
kernel_ops = layer.kernel_size[0] * layer.kernel_size[1]
out_h = 1 + int(
(in_h + 2 * layer.padding - layer.kernel_size[0]) / layer.stride
)
out_w = 1 + int(
(in_w + 2 * layer.padding - layer.kernel_size[1]) / layer.stride
flops = 0

elif layer_type in [
"MaxPool1d",
"MaxPool2d",
"MaxPool3d",
"AdaptiveMaxPool1d",
"AdaptiveMaxPool2d",
"AdaptiveMaxPool3d",
]:
flops = 0

elif layer_type in ["AvgPool1d", "AvgPool2d", "AvgPool3d"]:
kernel_ops = 1
flops = kernel_ops * y.numel()

elif layer_type in ["AdaptiveAvgPool1d", "AdaptiveAvgPool2d", "AdaptiveAvgPool3d"]:
assert isinstance(layer.output_size, (list, tuple))
kernel = torch.Tensor(list(x.shape[2:])) // torch.Tensor(
[list(layer.output_size)]
)
flops = x.size()[0] * x.size()[1] * out_w * out_h * kernel_ops

# adaptive avg pool2d
# This is approximate and works only for downsampling without padding
# based on aten/src/ATen/native/AdaptiveAveragePooling.cpp
elif layer_type in ["AdaptiveAvgPool2d"]:
in_h = x.size()[2]
in_w = x.size()[3]
if isinstance(layer.output_size, int):
out_h, out_w = layer.output_size, layer.output_size
elif len(layer.output_size) == 1:
out_h, out_w = layer.output_size[0], layer.output_size[0]
else:
out_h, out_w = layer.output_size
if out_h > in_h or out_w > in_w:
raise ClassyProfilerNotImplementedError(layer)
batchsize_per_replica = x.size()[0]
num_channels = x.size()[1]
kh = in_h - out_h + 1
kw = in_w - out_w + 1
kernel_ops = kh * kw
flops = batchsize_per_replica * num_channels * out_h * out_w * kernel_ops
kernel_ops = torch.prod(kernel)
flops = kernel_ops * y.numel()

# linear layer:
elif layer_type in ["Linear"]:
weight_ops = layer.weight.numel()
bias_ops = layer.bias.numel() if layer.bias is not None else 0
flops = x.size()[0] * (weight_ops + bias_ops)
flops = x.size()[0] * weight_ops

# batch normalization / layer normalization:
elif layer_type in [
Expand All @@ -240,94 +255,12 @@ def flops(self, x):
"SyncBatchNorm",
"LayerNorm",
]:
flops = 2 * x.numel()

# 3D convolution
elif layer_type in ["Conv3d"]:
out_t = int(
(x.size()[2] + 2 * layer.padding[0] - layer.kernel_size[0])
// layer.stride[0]
+ 1
)
out_h = int(
(x.size()[3] + 2 * layer.padding[1] - layer.kernel_size[1])
// layer.stride[1]
+ 1
)
out_w = int(
(x.size()[4] + 2 * layer.padding[2] - layer.kernel_size[2])
// layer.stride[2]
+ 1
)
flops = (
batchsize_per_replica
* layer.in_channels
* layer.out_channels
* layer.kernel_size[0]
* layer.kernel_size[1]
* layer.kernel_size[2]
* out_t
* out_h
* out_w
/ layer.groups
)

# 3D pooling layers
elif layer_type in ["AvgPool3d", "MaxPool3d"]:
in_t = x.size()[2]
in_h = x.size()[3]
in_w = x.size()[4]
if isinstance(layer.kernel_size, int):
layer.kernel_size = (
layer.kernel_size,
layer.kernel_size,
layer.kernel_size,
)
if isinstance(layer.padding, int):
layer.padding = (layer.padding, layer.padding, layer.padding)
if isinstance(layer.stride, int):
layer.stride = (layer.stride, layer.stride, layer.stride)
kernel_ops = layer.kernel_size[0] * layer.kernel_size[1] * layer.kernel_size[2]
out_t = 1 + int(
(in_t + 2 * layer.padding[0] - layer.kernel_size[0]) / layer.stride[0]
)
out_h = 1 + int(
(in_h + 2 * layer.padding[1] - layer.kernel_size[1]) / layer.stride[1]
)
out_w = 1 + int(
(in_w + 2 * layer.padding[2] - layer.kernel_size[2]) / layer.stride[2]
)
flops = batchsize_per_replica * x.size()[1] * out_t * out_h * out_w * kernel_ops

# adaptive avg pool3d
# This is approximate and works only for downsampling without padding
# based on aten/src/ATen/native/AdaptiveAveragePooling3d.cpp
elif layer_type in ["AdaptiveAvgPool3d"]:
in_t = x.size()[2]
in_h = x.size()[3]
in_w = x.size()[4]
out_t = layer.output_size[0]
out_h = layer.output_size[1]
out_w = layer.output_size[2]
if out_t > in_t or out_h > in_h or out_w > in_w:
raise ClassyProfilerNotImplementedError(layer)
batchsize_per_replica = x.size()[0]
num_channels = x.size()[1]
kt = in_t - out_t + 1
kh = in_h - out_h + 1
kw = in_w - out_w + 1
kernel_ops = kt * kh * kw
flops = (
batchsize_per_replica * num_channels * out_t * out_w * out_h * kernel_ops
)
# batchnorm can be merged into conv op. Thus, count 0 FLOPS
flops = 0

# dropout layer
elif layer_type in ["Dropout"]:
# At test time, we do not drop values but scale the feature map by the
# dropout ratio
flops = 1
for dim_size in x.size():
flops *= dim_size
flops = 0

elif layer_type == "Identity":
flops = 0
Expand All @@ -351,11 +284,12 @@ def flops(self, x):
f"params(M): {count_params(layer) / 1e6}",
f"flops(M): {int(flops) / 1e6}",
]
logging.debug("\t".join(message))
if verbose:
logging.info("\t".join(message))
return flops


def _layer_activations(layer: nn.Module, x: Any, out: Any) -> int:
def _layer_activations(layer: nn.Module, x: Any, out: Any, verbose: bool=False) -> int:
"""
Computes the number of activations produced by a single layer.

Expand All @@ -376,7 +310,8 @@ def activations(self, x, out):
return 0

message = [f"module: {typestr}", f"activations: {activations}"]
logging.debug("\t".join(message))
if verbose:
logging.info("\t".join(message))
return activations


Expand All @@ -402,17 +337,19 @@ def summarize_profiler_info(prof: torch.autograd.profiler.profile) -> str:


class ComplexityComputer:
def __init__(self, compute_fn: Callable, count_unique: bool):
def __init__(self, compute_fn: Callable, count_unique: bool, verbose: bool = False):
self.compute_fn = compute_fn
self.count_unique = count_unique
self.count = 0
self.verbose = verbose
self.seen_modules = set()

def compute(self, layer: nn.Module, x: Any, out: Any, module_name: str):
if self.count_unique and module_name in self.seen_modules:
return
logging.debug(f"module name: {module_name}")
self.count += self.compute_fn(layer, x, out)
if self.verbose:
logging.info(f"module name: {module_name}")
self.count += self.compute_fn(layer, x, out, self.verbose)
self.seen_modules.add(module_name)

def reset(self):
Expand Down Expand Up @@ -498,6 +435,7 @@ def compute_complexity(
input_key: Optional[Union[str, List[str]]] = None,
patch_attr: str = None,
compute_unique: bool = False,
verbose: bool = False,
) -> int:
"""
Compute the complexity of a forward pass.
Expand All @@ -517,7 +455,7 @@ def compute_complexity(
else:
input = get_model_dummy_input(model, input_shape, input_key)

complexity_computer = ComplexityComputer(compute_fn, compute_unique)
complexity_computer = ComplexityComputer(compute_fn, compute_unique, verbose)

# measure FLOPs:
modify_forward(model, complexity_computer, patch_attr=patch_attr)
Expand All @@ -540,7 +478,7 @@ def compute_flops(
Compute the number of FLOPs needed for a forward pass.
"""
return compute_complexity(
model, _layer_flops, input_shape, input_key, patch_attr="flops"
model, _layer_flops, input_shape, input_key, patch_attr="flops", verbose=True
)


Expand All @@ -553,7 +491,12 @@ def compute_activations(
Compute the number of activations created in a forward pass.
"""
return compute_complexity(
model, _layer_activations, input_shape, input_key, patch_attr="activations"
model,
_layer_activations,
input_shape,
input_key,
patch_attr="activations",
verbose=True,
)


Expand Down