flop count: capture output tensors from all layers so that unused layers are correctly counted.

Reviewed By: ericmintun

Differential Revision: D32242109

fbshipit-source-id: 362b68a2b7c50b1ec2efd0d415d7ec3e6b2ba1c8
ppwwyyxx committed May 21, 2022
1 parent e4f0b3d commit a5bc776
Showing 2 changed files with 70 additions and 3 deletions.
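
For context, here is a minimal sketch of the scenario this commit addresses (the module and variable names are hypothetical, modeled on the new test added below): an op whose result never reaches the model's final output is pruned from the traced graph by dead-code elimination, so its flops previously went unreported and the owning layer showed up as uncalled.

```python
import torch
import torch.nn as nn

from fvcore.nn.flop_count import FlopCountAnalysis


class TwoBranchNet(nn.Module):
    """Hypothetical model: fc1's result is never used by the returned output."""

    def __init__(self) -> None:
        super().__init__()
        self.fc1 = nn.Linear(10, 1)
        self.fc2 = nn.Linear(10, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = self.fc1(x)  # dead code as far as the traced graph is concerned
        del y
        return self.fc2(x) + 2


analyzer = FlopCountAnalysis(model=TwoBranchNet(), inputs=(torch.randn(1, 10),))
print(analyzer.total())             # with this change, fc1's flops are included
print(analyzer.uncalled_modules())  # and fc1 is no longer reported as uncalled
```
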
40 changes: 40 additions & 0 deletions fvcore/nn/jit_analysis.py
@@ -111,6 +111,22 @@ def _named_modules_without_dup(model: nn.Module) -> Iterator[Tuple[str, nn.Modul
yield name, mod


def _maybe_flatten(object) -> List[torch.Tensor]:
# Tries its best to find all tensors within the object and put them
# into a flattened list. Custom structures cannot be recognized.
# TODO: improve coverage of other structures, e.g. by using __dict__
ret = []
if isinstance(object, torch.Tensor):
ret.append(object)
if isinstance(object, (list, tuple)):
for x in object:
ret.extend(_maybe_flatten(x))
if isinstance(object, dict):
for x in object.values():
ret.extend(_maybe_flatten(x))
return ret


def _get_scoped_trace_graph(
module: nn.Module,
inputs: Union[Tensor, Tuple[Tensor, ...]],
@@ -149,8 +165,11 @@ def __call__(self, module: nn.Module, inputs: Any, outputs: Any) -> Any:
tracing_state = torch._C._get_tracing_state()
if tracing_state:
tracing_state.pop_scope()
# Don't save all intermediate tensors on GPU. There could be a lot.
all_output_tensors.extend([x.cpu() for x in _maybe_flatten(outputs)])
return outputs

all_output_tensors: List[torch.Tensor] = []
hook_handles: List[Any] = []

def register_hooks(mod: nn.Module, name: str) -> None:
@@ -173,6 +192,27 @@ def register_hooks(mod: nn.Module, name: str) -> None:
name = aliases[mod]
register_hooks(mod, name)

class WrapperModule(nn.Module):
def __init__(self, module):
super().__init__()
self._wrapped = module

def forward(self, *args):
# Some intermediate tensors may not be directly connected to the final model
# output, for example due to:
# * control flow not observed by tracing
# * tensor -> numpy/int conversion
# Operations that produce such tensors will get pruned by pytorch's DCE,
# but we want to include them in the graph.
# There is currently no way to disable DCE. So we capture all tensors we can
# and return them here, to reduce missing flops.
outputs = self._wrapped(*args)
return outputs, all_output_tensors

# Hooks are registered before wrapping with their original scope names, so
# adding a wrapper here won't affect scopes.
module = WrapperModule(module)

graph, _ = _get_trace_graph(module, inputs)

for handle in hook_handles:
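
As a quick usage sketch of the new `_maybe_flatten` helper (not part of the commit, and assuming the private helper stays importable from `fvcore.nn.jit_analysis`): tensors nested inside lists, tuples, and dicts are collected into a flat list, while tensors hidden inside custom objects are missed, as the TODO above notes.

```python
import torch

# Private helper added above; imported here purely for illustration.
from fvcore.nn.jit_analysis import _maybe_flatten

t1, t2, t3 = torch.ones(1), torch.ones(2), torch.ones(3)
nested = {"a": t1, "b": [t2, (t3, "not a tensor")], "c": 42}
flat = _maybe_flatten(nested)
assert len(flat) == 3 and flat[0] is t1 and flat[1] is t2 and flat[2] is t3


class Holder:
    """A custom structure the helper cannot see into."""

    def __init__(self, t: torch.Tensor) -> None:
        self.t = t


assert _maybe_flatten(Holder(t1)) == []  # the tensor inside is missed (see the TODO)
```
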
33 changes: 30 additions & 3 deletions tests/test_jit_model_analysis.py
@@ -6,13 +6,14 @@
import unittest
import warnings
from collections import Counter
from typing import Any, Dict, List
from typing import Any, Dict, List, Union

import torch
import torch.nn as nn
from fvcore.nn.flop_count import FlopCountAnalysis
from fvcore.nn.jit_analysis import JitModelAnalysis
from fvcore.nn.jit_handles import addmm_flop_jit, conv_flop_jit, Handle, linear_flop_jit
from torch.nn import functional as F


class NestedNetInnerModule(nn.Module):
@@ -283,20 +284,28 @@ class TraceWarningNet(nn.Module):
will be skipped and raise a warning.
"""

class IntLinear(nn.Linear):
"""
A linear that outputs int, therefore cannot be traced.
"""

def forward(self, x) -> Union[float, int]:
return F.linear(x, self.weight, self.bias).item()

def __init__(self) -> None:
super().__init__()
self.input_size = (10,)
fc1_in, fc1_out = 10, 1
fc2_in, fc2_out = 10, 10

self.fc1 = nn.Linear(in_features=fc1_in, out_features=fc1_out)
self.fc1 = TraceWarningNet.IntLinear(in_features=fc1_in, out_features=fc1_out)
self.fc2 = nn.Linear(in_features=fc2_in, out_features=fc2_out)

self.fc1_flops: int = fc1_in * fc1_out
self.fc2_flops: int = fc2_in * fc2_out

def forward(self, x: torch.Tensor) -> torch.Tensor:
y = self.fc1(x).item()
y = self.fc1(x)
warnings.warn("Dummy RuntimeWarning.", RuntimeWarning)
if y < 0.0:
x = self.fc2(x)
@@ -806,6 +815,24 @@ def test_disable_warnings(self) -> None:
self.assertTrue(any(uncalled_msg in s for s in cm.output))
self.assertTrue(any(uncalled_modules in s for s in cm.output))

def test_capture_intermediate_outputs(self) -> None:
class TestCaptureNet(nn.Module):
def __init__(self) -> None:
super().__init__()
self.fc1 = nn.Linear(10, 1)
self.fc2 = nn.Linear(10, 10)

def forward(self, x: torch.Tensor) -> torch.Tensor:
y = self.fc1(x)
del y # unused by output
return self.fc2(x) + 2

model = TestCaptureNet()
inputs = (torch.randn((1, 10)),)
analyzer = FlopCountAnalysis(model=model, inputs=inputs)
_ = analyzer.total()
self.assertEqual(analyzer.uncalled_modules(), set())

def test_skip_uncalled_containers_warnings(self) -> None:
# uncalled containers should not warn

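
Finally, a small standalone sketch (not part of the commit) of the tracer behavior the `WrapperModule` comments above describe: a pure op whose result never reaches the traced output is typically removed by PyTorch's dead-code elimination, which is why the flop counter could not see it without the wrapper.

```python
import torch


def f(x: torch.Tensor) -> torch.Tensor:
    unused = torch.matmul(x, x)  # pure op; its result never reaches the output
    return x + 1


traced = torch.jit.trace(f, (torch.randn(4, 4),))
# Inspecting traced.graph typically shows no matmul node: dead-code elimination
# dropped it. Returning such tensors from the wrapper keeps them in the graph
# so their flops can be counted.
print(traced.graph)
```
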
