Merge pull request #71 from microsoft/registrable
Registrable
sordonia authored Jul 26, 2024
2 parents caad815 + ee86648 commit a4a8340
Showing 26 changed files with 266 additions and 265 deletions.
21 changes: 6 additions & 15 deletions mttl/models/containers/__init__.py
@@ -16,6 +16,7 @@
 )
 from mttl.models.library.expert import Expert
 from mttl.models.library.expert_library import ExpertLibrary
+from mttl.models.modifiers.base import Modifier
 from mttl.utils import logger


@@ -152,21 +153,20 @@ def replace_selector_for_container(
     modifier_type: str,
     selector_config: SelectorConfig,
     training_config: Config = None,
-    selector_weights: dict = None,
     force_replace: bool = False,
 ):
     """
     Assigns a selector to the expert containers in the transformer model.
     """
-    from mttl.models.modifiers.modify_model import get_modifier_type
-
     expert_containers = []
     for _, module in dict(transformer.named_modules()).items():
         for _, layer in dict(module.named_children()).items():
             if isinstance(layer, ExpertContainer):
                 # check if the container holds the same modifier type, e.g. LoRAConfig --> "lora"
                 for supports_config in layer.__supports_configs__:
-                    container_modifier = get_modifier_type(supports_config)
+                    container_modifier = Modifier.get_name_by_config_class(
+                        supports_config
+                    )
                     # selector does not apply to this container
                     if not container_modifier == modifier_type:
                         continue
@@ -179,10 +179,6 @@
             f"No expert containers found for modifier type: {modifier_type}. Cannot assign a selector! Load some experts beforehand."
         )
 
-    # stores the selectors per container type
-    if not hasattr(transformer, "selectors"):
-        transformer.selectors = {}
-
     if not modifier_type in transformer.selectors:
         transformer.selectors[modifier_type] = {}

@@ -205,11 +201,6 @@
         n_selectors += isinstance(selector, Selector)
         n_selectors_views += isinstance(selector, SelectorView)
 
-    if selector_weights is not None:
-        raise NotImplementedError(
-            "Support for `selector_weights` is not implemented yet."
-        )
-
     return n_selectors, n_selectors_views


@@ -312,9 +303,9 @@ def add_expert_to_transformer(
     from mttl.models.containers.hard_prompts_container import (
         add_hard_prompt_to_transformer,
     )
-    from mttl.models.modifiers.modify_model import get_modifier_type
+    from mttl.models.modifiers.modify_model import get_modifier_name
 
-    model_modifier = get_modifier_type(expert_config)
+    model_modifier = get_modifier_name(expert_config)
 
     if model_modifier == "hard_prompt":
         return add_hard_prompt_to_transformer(
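Note on the change above: the free function get_modifier_type is replaced by a registry lookup on the Modifier base class, which is the "Registrable" pattern the PR title refers to. Below is a minimal sketch of how such a base class could work; only Modifier.get_name_by_config_class (and the analogous Selector.register later in this diff) are visible here, so the registry internals are assumptions, not mttl's actual code:

class Modifier:
    # Assumed internals: map name -> modifier class and config class -> name.
    _name_to_class: dict = {}
    _config_to_name: dict = {}

    @classmethod
    def register(cls, name, config_cls):
        """Class decorator binding a modifier name to its config class."""

        def wrapper(modifier_cls):
            cls._name_to_class[name] = modifier_cls
            cls._config_to_name[config_cls] = name
            return modifier_cls

        return wrapper

    @classmethod
    def get_name_by_config_class(cls, config_cls):
        # Reverse lookup, e.g. LoRAConfig -> "lora", as used above.
        return cls._config_to_name.get(config_cls)

Under a scheme like this, get_modifier_name(config) in modify_model.py can reduce to the same reverse lookup, which would make the get_modifier_type to get_modifier_name rename in the files below purely mechanical.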
2 changes: 1 addition & 1 deletion mttl/models/containers/base.py
@@ -20,7 +20,7 @@
 from mttl.models.modifiers.base import ModifierConfig, ModifyMixin
 from mttl.models.modifiers.kv_adapter import KVAdapter, KVAdapterConfig
 from mttl.models.modifiers.lora import LoRA, LoRAConfig, SkilledLoRA, SkilledLoRAConfig
-from mttl.models.modifiers.modify_model import get_modifier_type
+from mttl.models.modifiers.modify_model import get_modifier_name
 
 
 class Container(abc.ABC):
6 changes: 2 additions & 4 deletions mttl/models/containers/hard_prompts_container.py
@@ -10,9 +10,7 @@
     TaskNameSelector,
 )
 from mttl.models.library.expert import Expert
-from mttl.models.modifiers.base import ModifyMixin
 from mttl.models.modifiers.hard_prompts import HardPrompt, HardPromptConfig
-from mttl.models.modifiers.modify_model import register_modifier
 
 
 class HardPromptDecoderWrapper:
@@ -115,12 +113,12 @@ def on_add_expert(
         action="route",
         is_default=False,
     ) -> None:
-        from mttl.models.modifiers.modify_model import get_modifier_type
+        from mttl.models.modifiers.modify_model import get_modifier_name
 
         if action == "merge":
             raise ValueError("Merging is not supported for hard prompts.")
 
-        if get_modifier_type(expert.expert_config) == "hard_prompt":
+        if get_modifier_name(expert.expert_config) == "hard_prompt":
             expert_module = HardPrompt(
                 expert.expert_config, prompt_init=expert.expert_weights
             )
2 changes: 1 addition & 1 deletion mttl/models/containers/kv_containers.py
@@ -1,7 +1,7 @@
 from torch import nn
 
 from mttl.models.containers.lora_containers import ExpertContainer
-from mttl.models.containers.selectors.base_selectors import KVTaskNameSelector
+from mttl.models.containers.selectors.base import KVTaskNameSelector
 from mttl.models.library.expert import Expert
 from mttl.models.modifiers.kv_adapter import KVAdapter, KVAdapterConfig
 
10 changes: 5 additions & 5 deletions mttl/models/containers/lora_containers.py
@@ -4,7 +4,7 @@
 
 from mttl.logging import warn_once
 from mttl.models.containers.base import ExpertContainer
-from mttl.models.containers.selectors.base_selectors import (
+from mttl.models.containers.selectors.base import (
     BatchExpertsAndWeightsSelectorOutput,
     BatchExpertsSelectorOutput,
     BatchSequenceExpertsAndWeightsSelectorOutput,
@@ -13,7 +13,7 @@
 )
 from mttl.models.library.expert import Expert
 from mttl.models.modifiers.lora import LoRA, LoRAConfig, SkilledLoRA, SkilledLoRAConfig
-from mttl.models.modifiers.modify_model import get_modifier_type
+from mttl.models.modifiers.modify_model import get_modifier_name
 
 
 class LoRAExpertContainer(ExpertContainer):
@@ -53,7 +53,7 @@ def on_add_expert(
 
         # We may want to add a SkilledLoRA directly, if we are loading an MHR model for example
         LoRA_cls = {"lora": LoRA, "skilled_lora": SkilledLoRA}[
-            get_modifier_type(expert.expert_config)
+            get_modifier_name(expert.expert_config)
         ]
 
         modifier_module = LoRA_cls(
@@ -274,7 +274,7 @@ def __getitem__(self, name) -> Union[LoRA, SkilledLoRA]:
         weights: dict[str, Tensor] = self.experts.get_skill_weights(index_of)
 
         config = self.expert_infos[name].expert_config
-        modifier_type = get_modifier_type(config)
+        modifier_type = get_modifier_name(config)
 
         if modifier_type == "lora":
             assert self.dummy_config.n_splits == 1
@@ -302,7 +302,7 @@ def on_add_expert(self, expert: Expert, action="route", is_default=False) -> None:
         self._check_config(expert.expert_config)
 
         # We may want to add a SkilledLoRA directly, if we are loading an MHR model for example
-        lora_type = get_modifier_type(expert.expert_config)
+        lora_type = get_modifier_name(expert.expert_config)
         LoRA_cls = {"lora": LoRA, "skilled_lora": SkilledLoRA}[lora_type]
 
         modifier_module = LoRA_cls(
3 changes: 2 additions & 1 deletion mttl/models/containers/selectors/__init__.py
@@ -1,5 +1,6 @@
 from mttl.models.containers.selectors.arrow_selector import *
 from mttl.models.containers.selectors.average_activation_selector import *
-from mttl.models.containers.selectors.base_selectors import *
+from mttl.models.containers.selectors.base import *
 from mttl.models.containers.selectors.moe_selector import *
 from mttl.models.containers.selectors.phatgoose_selector import *
+from mttl.models.containers.selectors.poly_selector import *
6 changes: 3 additions & 3 deletions mttl/models/containers/selectors/arrow_selector.py
@@ -1,9 +1,9 @@
 from dataclasses import dataclass
 
-from mttl.models.containers.selectors.base_selectors import (
+from mttl.models.containers.selectors.base import (
     PerTokenSelector,
     PerTokenSelectorConfig,
-    register_multi_expert_selector,
+    Selector,
 )
 
 
@@ -42,7 +42,7 @@ class ArrowSelectorConfig(PerTokenSelectorConfig):
     proto_norm_fn: str = "id"
 
 
-@register_multi_expert_selector("arrow_router", ArrowSelectorConfig)
+@Selector.register("arrow_router", ArrowSelectorConfig)
 class ArrowSelector(PerTokenSelector):
     def _load_from_library(self):
         """Fetches prototypes from the library."""
mttl/models/containers/selectors/average_activation_selector.py
@@ -1,9 +1,9 @@
 from dataclasses import dataclass
 
-from mttl.models.containers.selectors.base_selectors import (
+from mttl.models.containers.selectors.base import (
     PerTokenSelector,
     PerTokenSelectorConfig,
-    register_multi_expert_selector,
+    Selector,
 )
 
 
@@ -47,7 +47,7 @@ class AverageActivationSelectorConfig(PerTokenSelectorConfig):
     proto_norm_fn: str = "id"
 
 
-@register_multi_expert_selector("avg_act_router", AverageActivationSelectorConfig)
+@Selector.register("avg_act_router", AverageActivationSelectorConfig)
 class AverageActivationSelector(PerTokenSelector):
     def _load_from_library(self):
         """Fetches prototypes from the library."""
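The selector files show the other half of the refactor: the module-level register_multi_expert_selector helper gives way to a Selector.register classmethod used as a class decorator. A rough sketch under the same assumed registry internals as the Modifier sketch above; only the decorator's name and its (name, config_cls) call signature appear in this diff:

class Selector:
    # Assumed registry: selector name -> (selector class, config class).
    _registry: dict = {}

    @classmethod
    def register(cls, name, config_cls):
        def wrapper(selector_cls):
            cls._registry[name] = (selector_cls, config_cls)
            return selector_cls

        return wrapper

Usage, as it appears in arrow_selector.py after this change:

@Selector.register("arrow_router", ArrowSelectorConfig)
class ArrowSelector(PerTokenSelector):
    ...

Hosting the registry on the base class lets each subclass register itself at import time, which is presumably why selectors/__init__.py star-imports every selector module, including the newly listed poly_selector.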

0 comments on commit a4a8340
