Skip to content

Commit

Permalink
expert_names as property
Browse files Browse the repository at this point in the history
  • Loading branch information
Alessandro Sordoni committed Jul 25, 2024
1 parent 2fc9f28 commit f3133a6
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 7 deletions.
12 changes: 7 additions & 5 deletions mttl/models/containers/expert_containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import torch
from pyparsing import abstractmethod
from torch import nn
from torch import Tensor, nn

from mttl.config import Config
from mttl.logging import warn_once
Expand Down Expand Up @@ -44,7 +44,6 @@ def __init__(self, config, layer, selector=None):
self.selector = selector or TaskNameSelector()

self.expert_infos = {}
self.expert_names = []
self.default_expert_name = None

def assign_selector(self, selector: Selector) -> None:
Expand Down Expand Up @@ -80,14 +79,17 @@ def add_expert(self, expert: Expert, action="merge", is_default=False) -> None:
self.on_add_expert(expert, action=action, is_default=is_default)

self.expert_infos[expert.name] = expert_info
self.expert_names.append(expert.name)
self.default_expert_name: str | None = (
expert.name if is_default else self.default_expert_name
)
self.selector.add_expert(
expert.name, expert_info=expert_info, is_default=is_default
)

@property
def expert_names(self) -> list[str]:
    """Names of all experts in this container, in insertion order.

    Derived from ``expert_infos`` (dicts preserve insertion order), so it
    always stays in sync with the experts actually registered — no separate
    list to maintain.
    """
    return list(self.expert_infos.keys())

def _check_config(self, expert_config: Union[Config, ModifierConfig]):
"""Checks if the config is supported and converts it to the supported config type if needed."""
if isinstance(expert_config, Config):
Expand Down Expand Up @@ -423,8 +425,8 @@ def __getitem__(self, name) -> Union[LoRA, SkilledLoRA]:
Arrow adds lora modules to the container, while MHR adds
skilled lora modules to the container.
"""
index_of = self.expert_names.index(name)
weights = self.experts.get_skill_weights(index_of)
index_of: int = self.expert_names.index(name)
weights: dict[str, Tensor] = self.experts.get_skill_weights(index_of)

config = self.expert_infos[name].expert_config
modifier_type = get_modifier_type(config)
Expand Down
6 changes: 4 additions & 2 deletions mttl/models/containers/selectors/base_selectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,6 @@ def __init__(self, config=None, **kwargs):

self.config = config
self.expert_infos = {}
self.expert_names = []
self.selector_views = []
self.forward_cache = None
self.default_expert_name = None
Expand All @@ -279,6 +278,10 @@ def __init__(self, config=None, **kwargs):
# dependency injection filled from ExpertContainer
self.__layer_name__ = None

@property
def expert_names(self) -> list[str]:
    """Names of the experts registered with this selector, in insertion order.

    Computed from ``expert_infos`` keys (insertion-ordered), so it cannot
    drift out of sync with ``add_expert`` the way a parallel list could.
    """
    return list(self.expert_infos.keys())

@property
def clear_cache(self):
reset_cache = self._calls_counter >= self.total_calls_per_forward
Expand Down Expand Up @@ -345,7 +348,6 @@ def add_expert(
self.default_expert_name = expert_name

self.expert_infos[expert_name] = expert_info
self.expert_names.append(expert_name)


class SelectorView:
Expand Down

0 comments on commit f3133a6

Please sign in to comment.