diff --git a/mttl/models/containers/expert_containers.py b/mttl/models/containers/expert_containers.py
index 59d464bb1..9d656a2ac 100644
--- a/mttl/models/containers/expert_containers.py
+++ b/mttl/models/containers/expert_containers.py
@@ -3,7 +3,7 @@
 
 import torch
 from pyparsing import abstractmethod
-from torch import nn
+from torch import Tensor, nn
 
 from mttl.config import Config
 from mttl.logging import warn_once
@@ -44,7 +44,6 @@ def __init__(self, config, layer, selector=None):
         self.selector = selector or TaskNameSelector()
 
         self.expert_infos = {}
-        self.expert_names = []
         self.default_expert_name = None
 
     def assign_selector(self, selector: Selector) -> None:
@@ -80,7 +79,6 @@ def add_expert(self, expert: Expert, action="merge", is_default=False) -> None:
         self.on_add_expert(expert, action=action, is_default=is_default)
 
         self.expert_infos[expert.name] = expert_info
-        self.expert_names.append(expert.name)
         self.default_expert_name: str | None = (
             expert.name if is_default else self.default_expert_name
         )
@@ -88,6 +86,10 @@ def add_expert(self, expert: Expert, action="merge", is_default=False) -> None:
             expert.name, expert_info=expert_info, is_default=is_default
         )
 
+    @property
+    def expert_names(self) -> list:
+        return list(self.expert_infos.keys())
+
     def _check_config(self, expert_config: Union[Config, ModifierConfig]):
         """Checks if the config is supported and converts it to the supported config type if needed."""
         if isinstance(expert_config, Config):
@@ -423,8 +425,8 @@ def __getitem__(self, name) -> Union[LoRA, SkilledLoRA]:
         Arrow adds lora modules to the container, while MHR adds skilled lora
         modules to the container.
         """
-        index_of = self.expert_names.index(name)
-        weights = self.experts.get_skill_weights(index_of)
+        index_of: int = self.expert_names.index(name)
+        weights: dict[str, Tensor] = self.experts.get_skill_weights(index_of)
 
         config = self.expert_infos[name].expert_config
         modifier_type = get_modifier_type(config)
diff --git a/mttl/models/containers/selectors/base_selectors.py b/mttl/models/containers/selectors/base_selectors.py
index 20f16b9ba..94255baba 100644
--- a/mttl/models/containers/selectors/base_selectors.py
+++ b/mttl/models/containers/selectors/base_selectors.py
@@ -270,7 +270,6 @@ def __init__(self, config=None, **kwargs):
         self.config = config
 
         self.expert_infos = {}
-        self.expert_names = []
         self.selector_views = []
         self.forward_cache = None
         self.default_expert_name = None
@@ -279,6 +278,10 @@ def __init__(self, config=None, **kwargs):
         # dependency injection filled from ExpertContainer
         self.__layer_name__ = None
 
+    @property
+    def expert_names(self) -> list:
+        return list(self.expert_infos.keys())
+
     @property
     def clear_cache(self):
         reset_cache = self._calls_counter >= self.total_calls_per_forward
@@ -345,7 +348,6 @@ def add_expert(
             self.default_expert_name = expert_name
 
         self.expert_infos[expert_name] = expert_info
-        self.expert_names.append(expert_name)
 
 
 class SelectorView:
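
For context, a minimal standalone sketch of the invariant this refactor relies on: by deriving `expert_names` from `expert_infos` instead of maintaining a parallel list, the two can never drift out of sync, and ordering is preserved because Python dicts keep insertion order (3.7+), so `expert_names.index(name)` still yields a stable index for `get_skill_weights`. The `ContainerSketch` class below is a hypothetical stand-in, not an mttl class; only the property body mirrors the diff.

```python
# Hypothetical stand-in illustrating the single-source-of-truth pattern
# from the diff; not the actual mttl ExpertContainer/Selector classes.
class ContainerSketch:
    def __init__(self):
        self.expert_infos: dict[str, object] = {}

    def add_expert(self, name: str, info: object) -> None:
        # Only the dict is mutated; there is no separate name list to update.
        self.expert_infos[name] = info

    @property
    def expert_names(self) -> list:
        # Mirrors the property added in the diff: names derived on the fly.
        return list(self.expert_infos.keys())


c = ContainerSketch()
c.add_expert("arrow", object())
c.add_expert("mhr", object())
assert c.expert_names == ["arrow", "mhr"]  # insertion order preserved
assert c.expert_names.index("mhr") == 1    # stable index, as __getitem__ needs
```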