diff --git a/fvcore/nn/jit_analysis.py b/fvcore/nn/jit_analysis.py
index dafe996..3f560af 100644
--- a/fvcore/nn/jit_analysis.py
+++ b/fvcore/nn/jit_analysis.py
@@ -111,6 +111,22 @@ def _named_modules_without_dup(model: nn.Module) -> Iterator[Tuple[str, nn.Module]]:
             yield name, mod
 
 
+def _maybe_flatten(object) -> List[torch.Tensor]:
+    # Try its best to find all tensors within the object and put them
+    # into a flattened list. Custom structures cannot be recognized.
+    # TODO: improve coverage of other structures, e.g. by using __dict__
+    ret = []
+    if isinstance(object, torch.Tensor):
+        ret.append(object)
+    if isinstance(object, (list, tuple)):
+        for x in object:
+            ret.extend(_maybe_flatten(x))
+    if isinstance(object, dict):
+        for x in object.values():
+            ret.extend(_maybe_flatten(x))
+    return ret
+
+
 def _get_scoped_trace_graph(
     module: nn.Module,
     inputs: Union[Tensor, Tuple[Tensor, ...]],
@@ -149,8 +165,11 @@ def __call__(self, module: nn.Module, inputs: Any, outputs: Any) -> Any:
             tracing_state = torch._C._get_tracing_state()
             if tracing_state:
                 tracing_state.pop_scope()
+            # Don't save all intermediate tensors on GPU. There could be a lot.
+            all_output_tensors.extend([x.cpu() for x in _maybe_flatten(outputs)])
             return outputs
 
+    all_output_tensors: List[torch.Tensor] = []
     hook_handles: List[Any] = []
 
     def register_hooks(mod: nn.Module, name: str) -> None:
@@ -173,6 +192,27 @@ def register_hooks(mod: nn.Module, name: str) -> None:
         name = aliases[mod]
         register_hooks(mod, name)
 
+    class WrapperModule(nn.Module):
+        def __init__(self, module):
+            super().__init__()
+            self._wrapped = module
+
+        def forward(self, *args):
+            # Some intermediate tensors may not be directly connected to the final model
+            # output, for example due to:
+            # * control flow not observed by tracing
+            # * tensor -> numpy/int conversion
+            # Operations that produce such tensors will get pruned by PyTorch's DCE,
+            # but we want to include them in the graph.
+            # There is currently no way to disable DCE. So we capture all tensors we can
+            # and return them here, to reduce missing flops.
+            outputs = self._wrapped(*args)
+            return outputs, all_output_tensors
+
+    # Hooks are registered before wrapping with their original scope names, so
+    # adding a wrapper here won't affect scopes.
+    module = WrapperModule(module)
+
     graph, _ = _get_trace_graph(module, inputs)
 
     for handle in hook_handles:
diff --git a/tests/test_jit_model_analysis.py b/tests/test_jit_model_analysis.py
index 1b212fc..369815c 100644
--- a/tests/test_jit_model_analysis.py
+++ b/tests/test_jit_model_analysis.py
@@ -6,13 +6,14 @@
 import unittest
 import warnings
 from collections import Counter
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Union
 
 import torch
 import torch.nn as nn
 from fvcore.nn.flop_count import FlopCountAnalysis
 from fvcore.nn.jit_analysis import JitModelAnalysis
 from fvcore.nn.jit_handles import addmm_flop_jit, conv_flop_jit, Handle, linear_flop_jit
+from torch.nn import functional as F
 
 
 class NestedNetInnerModule(nn.Module):
@@ -283,20 +284,28 @@ class TraceWarningNet(nn.Module):
     will be skipped and raise a warning.
     """
 
+    class IntLinear(nn.Linear):
+        """
+        A linear layer that outputs a Python number, and therefore cannot be traced.
+ """ + + def forward(self, x) -> Union[float, int]: + return F.linear(x, self.weight, self.bias).item() + def __init__(self) -> None: super().__init__() self.input_size = (10,) fc1_in, fc1_out = 10, 1 fc2_in, fc2_out = 10, 10 - self.fc1 = nn.Linear(in_features=fc1_in, out_features=fc1_out) + self.fc1 = TraceWarningNet.IntLinear(in_features=fc1_in, out_features=fc1_out) self.fc2 = nn.Linear(in_features=fc2_in, out_features=fc2_out) self.fc1_flops: int = fc1_in * fc1_out self.fc2_flops: int = fc2_in * fc2_out def forward(self, x: torch.Tensor) -> torch.Tensor: - y = self.fc1(x).item() + y = self.fc1(x) warnings.warn("Dummy RuntimeWarning.", RuntimeWarning) if y < 0.0: x = self.fc2(x) @@ -806,6 +815,24 @@ def test_disable_warnings(self) -> None: self.assertTrue(any(uncalled_msg in s for s in cm.output)) self.assertTrue(any(uncalled_modules in s for s in cm.output)) + def test_capture_intermediate_outputs(self) -> None: + class TestCaptureNet(nn.Module): + def __init__(self) -> None: + super().__init__() + self.fc1 = nn.Linear(10, 1) + self.fc2 = nn.Linear(10, 10) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + y = self.fc1(x) + del y # unused by output + return self.fc2(x) + 2 + + model = TestCaptureNet() + inputs = (torch.randn((1, 10)),) + analyzer = FlopCountAnalysis(model=model, inputs=inputs) + _ = analyzer.total() + self.assertEqual(analyzer.uncalled_modules(), set()) + def test_skip_uncalled_containers_warnings(self) -> None: # uncalled containers should not warn