
Merge pull request #25 from experimaestro/manipulate-modulelist
Manipulate modulelist
bpiwowar authored Dec 1, 2023
2 parents de68d02 + 99842dc commit 59538f3
Showing 2 changed files with 40 additions and 7 deletions.
26 changes: 19 additions & 7 deletions src/xpmir/learning/optim.py
@@ -82,21 +82,22 @@ def to(self, *args, **kwargs):
         return torch.nn.Module.to(self, *args, **kwargs)


-class ModuleList(Config, Initializable, torch.nn.Module):
+class ModuleList(Module, Initializable):
     """Groups different models together, to be used within the Learner"""

-    modules: Param[List[Module]]
+    sub_modules: Param[List[Module]]

-    def __init__(self):
-        Initializable.__init__(self)
-        torch.nn.Module.__init__(self)
+    def __post_init__(self):
+        # Register sub-modules
+        for ix, sub_module in enumerate(self.sub_modules):
+            self.add_module(str(ix), sub_module)

     def __initialize__(self, *args, **kwargs):
-        for module in self.modules:
+        for module in self.sub_modules:
             module.initialize(*args, **kwargs)

     def __call__(self, *args, **kwargs):
-        raise AssertionError("This module cannot be used as such: it is just a ")
+        raise AssertionError("This module cannot be used as such")

     def to(self, *args, **kwargs):
         return torch.nn.Module.to(self, *args, **kwargs)
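The substance of this hunk is that `ModuleList` now derives from the project's `Module` config and registers its children through `torch.nn.Module.add_module` in the `__post_init__` hook, so that the container's `parameters()` iterator actually reaches the sub-modules' parameters. A minimal plain-PyTorch sketch of that mechanism (illustrative names, independent of experimaestro):

    import torch.nn as nn

    container = nn.Module()
    sub = nn.Linear(2, 3)

    # Before registration the container exposes no parameters at all
    assert list(container.parameters()) == []

    # add_module registers the child under a string name, exactly what
    # __post_init__ above does with str(ix) for each sub-module
    container.add_module("0", sub)
    assert set(container.parameters()) == set(sub.parameters())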
@@ -136,6 +137,9 @@ def __init__(self):
     def __validate__(self):
         return self.includes or self.excludes

+    def __repr__(self) -> str:
+        return f"RegexParameterFilter({self.includes}, {self.excludes})"
+
     def __call__(self, name, params) -> bool:
         # Look first at included
         if self.includes:
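The new `__repr__` presumably exists so that the `RuntimeError` introduced in the next hunk prints a readable description of the offending filter rather than the default object repr. Assuming `includes`/`excludes` are the filter's list-of-regex parameters (as `__validate__` and `__call__` suggest), the effect looks like:

    # Illustrative values only; the exact defaults are not shown in this diff
    f = RegexParameterFilter(includes=[r"bias$"])
    repr(f)  # -> "RegexParameterFilter(['bias$'], None)" instead of
             #    "<...RegexParameterFilter object at 0x...>"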
@@ -179,6 +183,9 @@ def create_optimizer(
             for name, param in module.named_parameters()
             if (self.filter is None or self.filter(name, param)) and filter(name, param)
         ]
+        if not params:
+            raise RuntimeError(f"Parameter list is empty with {self.filter}")
+
         optimizer = self.optimizer(params)
         return optimizer
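Failing fast here also yields a better message than what PyTorch itself would produce a line later: handing `torch.optim` an empty parameter list raises a generic `ValueError` that names no filter. A quick plain-PyTorch reproduction of the failure mode this guard intercepts:

    import torch

    try:
        torch.optim.Adam([])  # what an over-strict filter used to produce
    except ValueError as err:
        print(err)  # "optimizer got an empty parameter list"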

@@ -210,6 +217,11 @@ def initialize(
         self.scheduler_steps = -1  # Number of scheduler steps
         self.num_training_steps = num_training_steps

+        try:
+            next(module.parameters())
+        except StopIteration:
+            raise RuntimeError("No parameters to optimize in the module")
+
         filter = DuplicateParameterFilter()
         for param_optimizer in param_optimizers:
             optimizer = param_optimizer.create_optimizer(module, filter)
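`Module.parameters()` returns a lazy generator, so calling `next(...)` and catching `StopIteration` is the idiomatic way to test for at least one parameter without materializing the whole list. The same probe in isolation:

    import torch.nn as nn

    def has_parameters(module: nn.Module) -> bool:
        # parameters() is lazy; next() raises StopIteration on an empty module
        try:
            next(module.parameters())
            return True
        except StopIteration:
            return False

    assert has_parameters(nn.Linear(2, 3))
    assert not has_parameters(nn.ReLU())  # activations own no parameters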
21 changes: 21 additions & 0 deletions src/xpmir/test/learning/test_optim.py
@@ -0,0 +1,21 @@
+from experimaestro import ObjectStore
+import torch.nn as nn
+from itertools import chain
+from xpmir.learning.optim import Module, ModuleList
+
+
+class MyModule(Module):
+    def __post_init__(self) -> None:
+        self.linear = nn.Linear(2, 3)
+
+
+def test_module_list():
+    a = MyModule()
+    b = MyModule()
+    container = ModuleList(sub_modules=[a, b])
+
+    store = ObjectStore()
+    container = container.instance(objects=store)
+    a = a.instance(objects=store)
+    b = b.instance(objects=store)
+    assert set(container.parameters()) == set(chain(a.parameters(), b.parameters()))
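The test exercises experimaestro's two-phase model: `MyModule()` builds a configuration, `.instance(objects=store)` realizes it, and sharing one `ObjectStore` lets `a`, `b`, and the objects registered inside `container` resolve to the same realized instances, which is what makes the set comparison meaningful. For concreteness, the expected cardinality of that set can be spelled out in plain PyTorch:

    import torch.nn as nn

    linear = nn.Linear(2, 3)  # what each MyModule wraps
    shapes = [tuple(p.shape) for p in linear.parameters()]
    assert shapes == [(3, 2), (3,)]  # one weight and one bias tensor
    # Two MyModule instances therefore contribute four parameter tensors,
    # which is exactly the set the assertion above compares.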
