diff --git a/mttl/models/containers/selectors/moe_selector.py b/mttl/models/containers/selectors/moe_selector.py
index 0bd7ad6de..4c2652cba 100644
--- a/mttl/models/containers/selectors/moe_selector.py
+++ b/mttl/models/containers/selectors/moe_selector.py
@@ -12,6 +12,7 @@
     SelectorOutput,
     forward_with_cache,
 )
+from mttl.models.expert_context import InfoContainer
 from mttl.models.library.expert import ExpertInfo
 
 
@@ -70,9 +71,8 @@ def forward(self, input, **kwargs) -> BatchSequenceExpertsAndWeightsSelectorOutput:
             # soft routing
             selected_experts = SelectorOutput.ALL_EXPERTS
 
-        g = getattr(self.info_container, "routing_gates", [])
+        g = InfoContainer.get().routing_gates
         g.append(router_logits)
-        self.info_container.routing_gates = g
 
         return BatchSequenceExpertsAndWeightsSelectorOutput(
             experts=selected_experts, weights=routing_weights
diff --git a/mttl/models/containers/selectors/poly_selector.py b/mttl/models/containers/selectors/poly_selector.py
index 3607e070b..d69a9821f 100644
--- a/mttl/models/containers/selectors/poly_selector.py
+++ b/mttl/models/containers/selectors/poly_selector.py
@@ -146,14 +146,13 @@ def get_routing_weights(self, task_name, **selector_kwargs) -> Dict:
     def on_add_expert(
         self, expert_name: str, expert_info: ExpertInfo, is_default=False
     ):
-        if self.n_experts == self.module_logits.shape[-1]:
-            # we need additional space in the routing to accomodate the incoming expert
-            self.module_logits.data = torch.empty(
-                self.n_tasks + 1, self.config.n_splits * (self.n_experts + 1)
-            ).uniform_(-1e-3, 1e-3)
-
-            # Last expert is exactly uniform
-            self.module_logits.data[-1] = 0.0
+        # we need additional space in the routing to accommodate the incoming expert
+        self.module_logits.data = torch.empty(
+            self.n_tasks + 1, self.config.n_splits * (self.n_experts + 1)
+        ).uniform_(-1e-3, 1e-3)
+
+        # Last expert is exactly uniform
+        self.module_logits.data[-1] = 0.0
 
 
 @dataclass
diff --git a/mttl/models/modifiers/kv_adapter.py b/mttl/models/modifiers/kv_adapter.py
index 900a17f60..419a9f474 100644
--- a/mttl/models/modifiers/kv_adapter.py
+++ b/mttl/models/modifiers/kv_adapter.py
@@ -11,7 +11,7 @@
 from transformers import Cache
 from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
 
-from mttl.models.modifiers.base import Modifier, ModifierConfig
+from mttl.models.modifiers.base import Modifier, ModifierConfig, ModifyMixin
 
 
 @dataclass
@@ -25,7 +25,7 @@ class KVAdapterConfig(ModifierConfig):
 
 
 @Modifier.register("kv_adapter", config_cls=KVAdapterConfig)
-class KVAdapter(Modifier):
+class KVAdapter(Modifier, ModifyMixin):
     """
     Modifier augmenting the self-attention with additional learnable KV pairs.
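
The moe_selector change swaps per-selector `info_container` attributes for a shared context object fetched via `InfoContainer.get()`. Below is a minimal sketch of that pattern, assuming only what the diff shows (a class-level `get()` returning an object whose `routing_gates` is a mutable list); everything else here is hypothetical and not mttl's actual implementation:

```python
from typing import Any, List, Optional


class InfoContainer:
    """Hypothetical stand-in for mttl.models.expert_context.InfoContainer."""

    _active: Optional["InfoContainer"] = None

    def __init__(self) -> None:
        # Router logits appended by each selector during a forward pass.
        self.routing_gates: List[Any] = []

    @classmethod
    def get(cls) -> "InfoContainer":
        # Lazily create the active container; the real class likely scopes
        # this to the current forward call rather than a process singleton.
        if cls._active is None:
            cls._active = cls()
        return cls._active


def record_gates(router_logits: Any) -> None:
    # Mirrors the new code path: no getattr fallback and no write-back,
    # since routing_gates is a list owned by the container itself.
    InfoContainer.get().routing_gates.append(router_logits)
```

This is also why the `self.info_container.routing_gates = g` write-back disappears: appending to the shared list mutates it in place.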
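
The poly_selector change makes `on_add_expert` grow the routing table unconditionally rather than only when it is full. A standalone illustration of the resize follows; shapes and init values are taken from the diff, while the free-standing names (`n_tasks`, `n_splits`, `n_experts`, `grow_for_new_expert`) are placeholders for the selector's attributes:

```python
import torch

# Assumed dimensions, for illustration only.
n_tasks, n_splits, n_experts = 4, 2, 3

module_logits = torch.nn.Parameter(torch.empty(n_tasks, n_splits * n_experts))


def grow_for_new_expert() -> None:
    global module_logits
    # Re-allocate with one extra task row and one extra expert per split,
    # initialized near zero so routing starts close to uniform.
    new_logits = torch.empty(
        n_tasks + 1, n_splits * (n_experts + 1)
    ).uniform_(-1e-3, 1e-3)
    # Zeroing the last row makes it exactly uniform after the softmax
    # the selector applies downstream.
    new_logits[-1] = 0.0
    # Assigning .data sidesteps the Parameter's shape check.
    module_logits.data = new_logits
```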