This repository has been archived by the owner on Jul 1, 2024. It is now read-only.

Profiler bug-fixes and improvements (#482)
Summary:
Pull Request resolved: #482

Made the following changes to the profiler code -
- The parameter calculation skipped parameters registered directly on modules and had unnecessary complexity - updated the logic to simply sum over `model.parameters()`
- The `AdaptiveAvgPool2d` FLOPs computation handles an `output_size` given as a single number as well
- Added FLOPs handling for the `Identity` module (it contributes 0 FLOPs)
- Added support for modules to define an `activations` method that the profiler uses to fetch their activation count (similar to the existing `flops` method); see the usage sketch below this list
- Replaced the hacky list append logic with a class, `_ComplexityComputer`
- Implemented test cases which verify that the fixes work
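
For reference, here is a minimal usage sketch (not part of this commit) of the `flops` and `activations` hooks described above. `CustomBlock` and the chosen sizes are illustrative assumptions; only the `compute_flops` / `compute_activations` / `count_params` helpers come from `classy_vision.generic.profiler`.

```python
# Sketch only: a hypothetical module defining its own profiler hooks.
import torch.nn as nn

from classy_vision.generic.profiler import (
    compute_activations,
    compute_flops,
    count_params,
)


class CustomBlock(nn.Module):  # hypothetical example module
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(300, 300, bias=False)

    def forward(self, x):
        return self.proj(x.flatten(1))

    def flops(self, x):
        # used by compute_flops instead of the built-in per-layer rules
        return x.numel() * 300

    def activations(self, x, out):
        # used by compute_activations instead of the default conv-only count
        return out.numel()


model = CustomBlock()
print(compute_flops(model, input_shape=(3, 10, 10)))        # calls CustomBlock.flops
print(compute_activations(model, input_shape=(3, 10, 10)))  # calls CustomBlock.activations
print(count_params(model))                                   # 300 * 300 weights
```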

Reviewed By: vreis

Differential Revision: D21009734

fbshipit-source-id: 926d93164c13c6c98eb88f9131d295b61d6acda7
mannatsingh authored and facebook-github-bot committed Apr 21, 2020
1 parent 5643849 commit e5b9873
Showing 2 changed files with 153 additions and 35 deletions.
113 changes: 79 additions & 34 deletions classy_vision/generic/profiler.py
@@ -7,6 +7,7 @@
import collections.abc as abc
import logging
import operator
from typing import Callable

import torch
import torch.nn as nn
@@ -183,8 +184,12 @@ def flops(self, x):
elif layer_type in ["AdaptiveAvgPool2d"]:
in_h = x.size()[2]
in_w = x.size()[3]
out_h = layer.output_size[0]
out_w = layer.output_size[1]
if isinstance(layer.output_size, int):
out_h, out_w = layer.output_size, layer.output_size
elif len(layer.output_size) == 1:
out_h, out_w = layer.output_size[0], layer.output_size[0]
else:
out_h, out_w = layer.output_size
if out_h > in_h or out_w > in_w:
raise NotImplementedError()
batchsize_per_replica = x.size()[0]
@@ -295,6 +300,10 @@ def flops(self, x):
for dim_size in x.size():
flops *= dim_size
return flops

elif layer_type == "Identity":
return 0

elif hasattr(layer, "flops"):
# If the module already defines a method to compute flops with the signature
# below, we use it to compute flops
@@ -312,8 +321,16 @@ def _layer_activations(layer, x, out):
"""
Computes the number of activations produced by a single layer.
Activations are counted only for convolutional layers.
Activations are counted only for convolutional layers. To override this behavior, a
layer can define a method to compute activations with the signature below, which
will be used to compute the activations instead.
Class MyModule(nn.Module):
def activations(self, x, out):
...
"""
if hasattr(layer, "activations"):
return layer.activations(x, out)
return out.numel() if isinstance(layer, (nn.Conv1d, nn.Conv2d, nn.Conv3d)) else 0


@@ -338,11 +355,25 @@ def summarize_profiler_info(prof):
return str


def _patched_computation_module(module, compute_list, compute_fn):
class _ComplexityComputer:
def __init__(self, compute_fn: Callable, count_unique: bool):
self.compute_fn = compute_fn
self.count_unique = count_unique
self.count = 0
self.seen_modules = set()

def compute(self, layer, x, out, module_name):
if self.count_unique and module_name in self.seen_modules:
return
self.count += self.compute_fn(layer, x, out)
self.seen_modules.add(module_name)


def _patched_computation_module(module, complexity_computer, module_name):
"""
Patch the module to compute a module's parameters, like FLOPs.
Calls compute_fn and appends the results to compute_list.
Calls compute_fn and passes the results to the complexity computer.
"""
ty = type(module)
typestring = module.__repr__()
@@ -355,7 +386,7 @@ def _original_forward(self, *args, **kwargs):

def forward(self, *args, **kwargs):
out = self._original_forward(*args, **kwargs)
compute_list.append(compute_fn(self, args[0], out))
complexity_computer.compute(self, args[0], out, module_name)
return out

def __repr__(self):
@@ -364,37 +395,58 @@ def __repr__(self):
return ComputeModule


def modify_forward(model, compute_list, compute_fn):
def modify_forward(model, complexity_computer, prefix="", patch_attr=None):
"""
Modify forward pass to measure a module's parameters, like FLOPs.
"""
if is_leaf(model) or hasattr(model, "flops"):
model.__class__ = _patched_computation_module(model, compute_list, compute_fn)
if is_leaf(model) or (patch_attr is not None and hasattr(model, patch_attr)):
model.__class__ = _patched_computation_module(
model, complexity_computer, prefix
)

else:
for child in model.children():
modify_forward(child, compute_list, compute_fn)
for name, child in model.named_children():
modify_forward(
child,
complexity_computer,
prefix=f"{prefix}.{name}",
patch_attr=patch_attr,
)

return model


def restore_forward(model):
def restore_forward(model, patch_attr=None):
"""
Restore original forward in model:
Restore original forward in model.
"""
if is_leaf(model) or hasattr(model, "flops"):
if is_leaf(model) or (patch_attr is not None and hasattr(model, patch_attr)):
model.__class__ = model.orig_type

else:
for child in model.children():
restore_forward(child)
restore_forward(child, patch_attr=patch_attr)

return model


def compute_complexity(model, compute_fn, input_shape, input_key=None):
def compute_complexity(
model,
compute_fn,
input_shape,
input_key=None,
patch_attr=None,
compute_unique=False,
):
"""
Compute the complexity of a forward pass.
Args:
compute_unique: If True, the complexity for a given module is only calculated
once. Otherwise, it is counted every time the module is called.
TODO(@mannatsingh): We assume that only leaf modules, or modules which define
patch_attr, need to be patched. This should be fixed and generalized if possible.
"""
# assertions, input, and upvalue in which we will perform the count:
assert isinstance(model, nn.Module)
@@ -404,50 +456,43 @@ def compute_complexity(model, compute_fn, input_shape, input_key=None):
else:
input = get_model_dummy_input(model, input_shape, input_key)

compute_list = []
complexity_computer = _ComplexityComputer(compute_fn, compute_unique)

# measure FLOPs:
modify_forward(model, compute_list, compute_fn)
modify_forward(model, complexity_computer, patch_attr=patch_attr)
try:
# compute complexity in eval mode
with eval_model(model), torch.no_grad():
model.forward(input)
except NotImplementedError as err:
raise err
finally:
restore_forward(model)
restore_forward(model, patch_attr=patch_attr)

return sum(compute_list)
return complexity_computer.count


def compute_flops(model, input_shape=(3, 224, 224), input_key=None):
"""
Compute the number of FLOPs needed for a forward pass.
"""
return compute_complexity(model, _layer_flops, input_shape, input_key)
return compute_complexity(
model, _layer_flops, input_shape, input_key, patch_attr="flops"
)


def compute_activations(model, input_shape=(3, 224, 224), input_key=None):
"""
Compute the number of activations created in a forward pass.
"""
return compute_complexity(model, _layer_activations, input_shape, input_key)
return compute_complexity(
model, _layer_activations, input_shape, input_key, patch_attr="activations"
)


def count_params(model):
"""
Count the number of parameters in a model.
"""
assert isinstance(model, nn.Module)
count = 0
for child in model.children():
if is_leaf(child):
if hasattr(child, "_mask"): # for masked modules (like LGC)
count += child._mask.long().sum().item()
# FIXME: BatchNorm parameters in LGC are not counted.
else: # for regular modules
for p in child.parameters():
count += p.nelement()
else:
count += count_params(child)
return count
return sum((parameter.nelement() for parameter in model.parameters()))
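
Aside (not from the diff): the simplification above relies on standard PyTorch behaviour, where `nn.Module.parameters()` recurses through every submodule and also yields parameters registered directly on a module — exactly what the old child-by-child loop missed. A small sketch with hypothetical names:

```python
# Assumes only standard PyTorch semantics; Outer is a made-up example.
import torch
import torch.nn as nn


class Outer(nn.Module):
    def __init__(self):
        super().__init__()
        self.inner = nn.Linear(2, 3, bias=False)       # 6 weights in a child module
        self.extra = nn.Parameter(torch.randn(4, 4))   # 16 weights registered directly


assert sum(p.nelement() for p in Outer().parameters()) == 6 + 16
```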
75 changes: 74 additions & 1 deletion test/generic_profiler_test.py
@@ -7,6 +7,8 @@
import unittest
from test.generic.config_utils import get_test_model_configs

import torch
import torch.nn as nn
from classy_vision.generic.profiler import (
compute_activations,
compute_flops,
@@ -15,8 +17,61 @@
from classy_vision.models import build_model


class TestModule(nn.Module):
def __init__(self):
super().__init__()
# add parameters to the module to affect the parameter count
self.linear = nn.Linear(2, 3, bias=False)

def forward(self, x):
return x + 1

def flops(self, x):
# TODO: this should raise an exception if this function is not defined
# since the FLOPs are indeterminable

# need to define flops since this is an unknown class
return x.numel()


class TestConvModule(nn.Conv2d):
def __init__(self):
super().__init__(2, 3, (4, 4), bias=False)
# add another (unused) layer for added complexity and to test parameters
self.linear = nn.Linear(4, 5, bias=False)

def forward(self, x):
return x

def activations(self, x, out):
# TODO: this should ideally work without this function being defined
return out.numel()

def flops(self, x):
# need to define flops since this is an unknown class
return 0


class TestModel(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(300, 300, bias=False)
self.mod = TestModule()
self.conv = TestConvModule()
# we should be able to pick up user defined parameters as well
self.extra_params = nn.Parameter(torch.randn(10, 10))
# we shouldn't count flops for an unused layer
self.unused_linear = nn.Linear(2, 2, bias=False)

def forward(self, x):
out = self.conv(x)
out = out.view(out.shape[0], -1)
out = self.mod(out)
return self.linear(out)


class TestProfilerFunctions(unittest.TestCase):
def test_complexity_calculation(self) -> None:
def test_complexity_calculation_resnext(self) -> None:
model_configs = get_test_model_configs()
# make sure there are three configs returned
self.assertEqual(len(model_configs), 3)
@@ -34,3 +89,21 @@ def test_complexity_calculation(self) -> None:
self.assertEqual(compute_activations(model) // 10 ** 6, m_activations)
self.assertEqual(compute_flops(model) // 10 ** 6, m_flops)
self.assertEqual(count_params(model) // 10 ** 6, m_params)

def test_complexity_calculation(self) -> None:
model = TestModel()
input_shape = (3, 10, 10)
num_elems = 3 * 10 * 10
self.assertEqual(compute_activations(model, input_shape=input_shape), num_elems)
self.assertEqual(
compute_flops(model, input_shape=input_shape),
num_elems
+ 0
+ (300 * 300), # TestModule + TestConvModule + TestModel.linear;
# TestModel.unused_linear is unused and shouldn't be counted
)
self.assertEqual(
count_params(model),
(2 * 3) + (2 * 3 * 4 * 4) + (4 * 5) + (300 * 300) + (10 * 10) + (2 * 2),
) # TestModule.linear + TestConvModule + TestConvModule.linear +
# TestModel.linear + TestModel.extra_params + TestModel.unused_linear
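
As a further usage note (not part of the commit's tests), the lower-level `compute_complexity` entry point exposes the new `compute_unique` flag directly. A hedged sketch, assuming the `compute_complexity` signature shown in the profiler diff above and a made-up `SharedBlock` module:

```python
# Sketch only: counting a submodule that runs twice per forward pass.
import torch.nn as nn

from classy_vision.generic.profiler import compute_complexity


def count_output_elems(layer, x, out):
    # any compute_fn with the (layer, x, out) signature works; this one counts outputs
    return out.numel()


class SharedBlock(nn.Module):  # hypothetical module reusing a submodule
    def __init__(self):
        super().__init__()
        self.shared = nn.Linear(8, 8, bias=False)

    def forward(self, x):
        # the same submodule runs twice in a single forward pass
        return self.shared(self.shared(x))


model = SharedBlock()
# counts self.shared on every call (twice here)
every_call = compute_complexity(model, count_output_elems, input_shape=(8,))
# counts each unique module only once
unique = compute_complexity(
    model, count_output_elems, input_shape=(8,), compute_unique=True
)
```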
