Skip to content

Commit

Permalink
FIX: support non-sequential recursive module
Browse files Browse the repository at this point in the history
  • Loading branch information
T-K-233 committed Oct 16, 2024
1 parent bf31616 commit b199ebd
Showing 1 changed file with 10 additions and 11 deletions.
21 changes: 10 additions & 11 deletions converter/src/torchconverter/tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def print_graph(self):
"""
self.gm.graph.print_tabular()

def get_modules_from_sequential(self, module: torch.nn.Sequential, indicies: List[int]) -> torch.nn.Module:
def _get_inner_module(self, module: torch.nn.Module, target_hierarchy: List[str]) -> torch.nn.Module:
"""
Get a module in a nn.Sequential layer.
This function will recursively unpack the nn.Sequential layers to get the innermost module.
Expand All @@ -143,9 +143,14 @@ def get_modules_from_sequential(self, module: torch.nn.Sequential, indicies: Lis
Returns:
The innermost module.
"""
if len(indicies) == 0:
return module
return self.get_modules_from_sequential(module[indicies[0]], indicies[1:])
module_name = target_hierarchy[0]
target_hierarchy = target_hierarchy[1:]
submodule = getattr(module, module_name)

if len(target_hierarchy) == 0:
return submodule

return self._get_inner_module(submodule, target_hierarchy)

def get_module(self, module_name: str) -> torch.nn.Module:
"""
Expand All @@ -161,13 +166,7 @@ def get_module(self, module_name: str) -> torch.nn.Module:
if "." in module_name:
# if we have nn.Sequential layers
target_hierarchy = module_name.split(".")
sequential_name = target_hierarchy[0]

# indicies = target_hierarchy[1:]
indicies = [int(x) for x in target_hierarchy[1:]]

module = getattr(self.model, sequential_name)
return self.get_modules_from_sequential(module, indicies)
return self._get_inner_module(self.model, target_hierarchy)

return getattr(self.model, module_name)

Expand Down

0 comments on commit b199ebd

Please sign in to comment.