Skip to content

Commit

Permalink
Allow containers of homogenous superclass in state dict (#96)
Browse files Browse the repository at this point in the history
Updates homogenous container check to allow subclasses that all derive
from the same superclass to be "mixed" in the same container. This
effectively allows lists/dicts of modules or params to be registered in
the state dict.

L0 test: 1959 passed, 47 skipped, 1 deselected, 1 warning

---------

Signed-off-by: Akhil Goel <[email protected]>
  • Loading branch information
akhilg-nv authored Aug 15, 2024
1 parent 0bb2b0b commit cef7968
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 6 deletions.
16 changes: 15 additions & 1 deletion tripy/tests/frontend/module/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def __init__(self):

def __call__(self):
return self.param + self.dummy1() + self.dummy2()


class ListNetwork(tp.Module):
def __init__(self):
Expand Down Expand Up @@ -80,6 +80,16 @@ def __call__(self):
return out


class MixedNetwork(tp.Module):
def __init__(self):
super().__init__()
self.mixed_list = [DummyOp(tp.zeros((2,), dtype=tp.float32)), DummyNestedOp(tp.ones((2,), dtype=tp.float32))]
self.mixed_dict = {"dummy": DummyOp(tp.zeros((2,), dtype=tp.float32)), "dummy_nested": DummyNestedOp(tp.ones((2,), dtype=tp.float32))}

def __call__(self):
return self.mixed_list[0]() + self.mixed_list[1]() + self.mixed_dict["dummy"]() + self.mixed_dict["dummy_nested"]()


class ComplexNetwork(tp.Module):
def __init__(self):
super().__init__()
Expand Down Expand Up @@ -115,6 +125,10 @@ def list_network():
def dict_network():
yield DictNetwork()

@pytest.fixture
def mixed_network():
yield MixedNetwork()


@pytest.fixture
def complex_network():
Expand Down
26 changes: 26 additions & 0 deletions tripy/tests/frontend/module/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,32 @@ def test_add_dict_params(self, dict_network):
assert tripy_params["new_params.param0"] is param0
assert tripy_params["new_params.param1"] is param1

class TestMixedModule:
def test_basic_structure(self, mixed_network):
module = mixed_network
assert hasattr(module, "mixed_list")
assert hasattr(module, "mixed_dict")
assert isinstance(module.mixed_list, list)
assert isinstance(module.mixed_dict, dict)
assert all(isinstance(list_module, tp.Module) for list_module in module.mixed_list)
assert all(isinstance(dict_module, tp.Module) for dict_module in module.mixed_dict.values())

def test_state_dict(self, mixed_network):
module = mixed_network
print(module.state_dict())
tensor = tp.ones((2,))
external_state_dict = {
"mixed_list.0.nested.param": tensor,
"mixed_list.1.param": tensor,
"mixed_dict.dummy.nested.param": tensor,
"mixed_dict.dummy_nested.param": tensor,
}
module.load_from_state_dict(external_state_dict)
assert module.mixed_list[0].nested.param is tensor
assert module.mixed_list[1].param is tensor
assert module.mixed_dict["dummy"].nested.param is tensor
assert module.mixed_dict["dummy_nested"].param is tensor


class TestComplexModule:
def test_basic_structure(self, complex_network):
Expand Down
10 changes: 5 additions & 5 deletions tripy/tripy/frontend/module/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ def _check_param_compatible(original_param, new_param, param_name):
)


def _is_homogeneous_container(container: Sequence):
return len(set(map(type, container))) == 1
def _is_homogeneous_container(container: Sequence, typ: T):
return all(isinstance(op, typ) for op in container)


def _contains_types(container: Sequence, types: type):
Expand Down Expand Up @@ -106,7 +106,7 @@ def __setattr__(self, name: str, value: Any) -> None:

if isinstance(value, List) or isinstance(value, Dict):
container = value if isinstance(value, List) else value.values()
if _contains_types(container, [Parameter, Module]) and not _is_homogeneous_container(container):
if _contains_types(container, [Parameter, Module]) and not _is_homogeneous_container(container, Parameter):
logger.warning("A container of mixed types will not be registered with this module's state_dict().")

def state_dict(self) -> Dict[str, Parameter]:
Expand Down Expand Up @@ -276,13 +276,13 @@ def _iterate_members_of_type(self, typ: T) -> Iterator[Tuple[str, T]]:
for name, value in vars(self).items():
if isinstance(value, typ):
yield name, value
elif isinstance(value, List) and _contains_types(value, [typ]) and _is_homogeneous_container(value):
elif isinstance(value, List) and _contains_types(value, [typ]) and _is_homogeneous_container(value, typ):
for i, obj in enumerate(value):
yield f"{name}.{i}", obj
elif (
isinstance(value, Dict)
and _contains_types(value.values(), [typ])
and _is_homogeneous_container(value.values())
and _is_homogeneous_container(value.values(), typ)
):
for key, obj in value.items():
yield f"{name}.{key}", obj
Expand Down

0 comments on commit cef7968

Please sign in to comment.