diff --git a/converter/src/torchconverter/tracer.py b/converter/src/torchconverter/tracer.py index b970c8d..340ebbf 100644 --- a/converter/src/torchconverter/tracer.py +++ b/converter/src/torchconverter/tracer.py @@ -131,7 +131,7 @@ def print_graph(self): """ self.gm.graph.print_tabular() - def get_modules_from_sequential(self, module: torch.nn.Sequential, indicies: List[int]) -> torch.nn.Module: + def _get_inner_module(self, module: torch.nn.Module, target_hierarchy: List[str]) -> torch.nn.Module: """ Get a module in a nn.Sequential layer. This function will recursively unpack the nn.Sequential layers to get the innermost module. @@ -143,9 +143,14 @@ def get_modules_from_sequential(self, module: torch.nn.Sequential, indicies: Lis Returns: The innermost module. """ - if len(indicies) == 0: - return module - return self.get_modules_from_sequential(module[indicies[0]], indicies[1:]) + module_name = target_hierarchy[0] + target_hierarchy = target_hierarchy[1:] + submodule = getattr(module, module_name) + + if len(target_hierarchy) == 0: + return submodule + + return self._get_inner_module(submodule, target_hierarchy) def get_module(self, module_name: str) -> torch.nn.Module: """ @@ -161,13 +166,7 @@ def get_module(self, module_name: str) -> torch.nn.Module: if "." in module_name: # if we have nn.Sequential layers target_hierarchy = module_name.split(".") - sequential_name = target_hierarchy[0] - - # indicies = target_hierarchy[1:] - indicies = [int(x) for x in target_hierarchy[1:]] - - module = getattr(self.model, sequential_name) - return self.get_modules_from_sequential(module, indicies) + return self._get_inner_module(self.model, target_hierarchy) return getattr(self.model, module_name)