diff --git a/models/base.py b/models/base.py index 464ba7e..de4e929 100644 --- a/models/base.py +++ b/models/base.py @@ -1,8 +1,8 @@ -from abc import ABC, abstractmethod - import torch import torch.nn as nn +from abc import ABC, abstractmethod +from fvcore.nn import FlopCountAnalysis, parameter_count class TorchModelWrapper(nn.Module, ABC): def __init__(self, model_name): @@ -37,4 +37,17 @@ def replace_modules(self, replace_dict): from models.utils import replace_modules replace_modules(self.model, replace_dict) + def profile(self): + random_input = torch.randn(self.input_size) + flop_counter = FlopCountAnalysis(self.model, random_input) + # ignore batch norm + flop_counter._ignored_ops.add("aten::batch_norm") + del flop_counter._op_handles["aten::batch_norm"] + + param_counter = parameter_count(self.model) + macs = flop_counter.total() + params = param_counter[''] + + print(f"MACs: {macs}, Params: {params}") + from models.utils import generate_onnx_files \ No newline at end of file diff --git a/models/detection/coco.py b/models/detection/coco.py index 0dd4a2f..15e27b6 100644 --- a/models/detection/coco.py +++ b/models/detection/coco.py @@ -18,6 +18,7 @@ def load_model(self, eval=True): self.yolo = YOLO(self.model_name) self.model = self.yolo.model self.model_fixer() + self.input_size = (1, 3, self.yolo.overrides['imgsz'], self.yolo.overrides['imgsz']) # utlralytics conv bn fusion not working after compression, disable it def _fuse(verbose=True):