Merge pull request #73 from microsoft/fix_merge_tests
fix tests
sordonia authored Jul 29, 2024
2 parents a4a8340 + b83f290 commit e1f3f1c
Showing 3 changed files with 11 additions and 12 deletions.
4 changes: 2 additions & 2 deletions mttl/models/containers/selectors/moe_selector.py
@@ -12,6 +12,7 @@
     SelectorOutput,
     forward_with_cache,
 )
+from mttl.models.expert_context import InfoContainer
 from mttl.models.library.expert import ExpertInfo


@@ -70,9 +71,8 @@ def forward(self, input, **kwargs) -> BatchSequenceExpertsAndWeightsSelectorOutput:
         # soft routing
         selected_experts = SelectorOutput.ALL_EXPERTS

-        g = getattr(self.info_container, "routing_gates", [])
+        g = InfoContainer.get().routing_gates
         g.append(router_logits)
-        self.info_container.routing_gates = g

         return BatchSequenceExpertsAndWeightsSelectorOutput(
             experts=selected_experts, weights=routing_weights
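For readers following the change: the selector no longer reads routing_gates from its own info_container attribute and writes the list back; it appends to the list held by the shared InfoContainer context. A minimal sketch of that pattern, assuming InfoContainer.get() returns the container active for the current forward pass and that routing_gates is a plain Python list (the real class lives in mttl.models.expert_context and may differ):

class InfoContainer:
    # Illustrative stand-in for mttl.models.expert_context.InfoContainer;
    # names mirror the diff, the real implementation may differ.
    _active = None

    def __init__(self):
        self.routing_gates = []  # router logits collected during a forward pass

    @classmethod
    def get(cls):
        # Return the currently active container, creating one if needed.
        if cls._active is None:
            cls._active = cls()
        return cls._active

# Appending mutates the shared list in place, so the old write-back
# (self.info_container.routing_gates = g) becomes unnecessary:
#     g = InfoContainer.get().routing_gates
#     g.append(router_logits)

Compared with the old getattr(..., []) fallback, the new line appears to rely on a container already being in place when forward runs.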
15 changes: 7 additions & 8 deletions mttl/models/containers/selectors/poly_selector.py
@@ -146,14 +146,13 @@ def get_routing_weights(self, task_name, **selector_kwargs) -> Dict:
     def on_add_expert(
         self, expert_name: str, expert_info: ExpertInfo, is_default=False
     ):
-        if self.n_experts == self.module_logits.shape[-1]:
-            # we need additional space in the routing to accomodate the incoming expert
-            self.module_logits.data = torch.empty(
-                self.n_tasks + 1, self.config.n_splits * (self.n_experts + 1)
-            ).uniform_(-1e-3, 1e-3)
-
-            # Last expert is exactly uniform
-            self.module_logits.data[-1] = 0.0
+        # we need additional space in the routing to accomodate the incoming expert
+        self.module_logits.data = torch.empty(
+            self.n_tasks + 1, self.config.n_splits * (self.n_experts + 1)
+        ).uniform_(-1e-3, 1e-3)
+
+        # Last expert is exactly uniform
+        self.module_logits.data[-1] = 0.0


 @dataclass
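The net effect of removing the guard is that on_add_expert now always re-allocates module_logits with room for one more expert, rather than only when the old shape check held. A self-contained sketch of the resulting shapes, using made-up sizes (n_tasks=2, n_splits=4, n_experts=3) purely to show the arithmetic:

import torch

# Hypothetical sizes, chosen only to illustrate the re-allocation.
n_tasks, n_splits, n_experts = 2, 4, 3

# Room for the incoming expert: (n_tasks + 1) rows, n_splits * (n_experts + 1) columns.
module_logits = torch.empty(
    n_tasks + 1, n_splits * (n_experts + 1)
).uniform_(-1e-3, 1e-3)

# Mirrors the "Last expert is exactly uniform" step in the diff: the last row
# starts at zero so its routing begins uniform.
module_logits[-1] = 0.0

print(module_logits.shape)  # torch.Size([3, 16])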
4 changes: 2 additions & 2 deletions mttl/models/modifiers/kv_adapter.py
@@ -11,7 +11,7 @@
 from transformers import Cache
 from transformers.models.llama.modeling_llama import apply_rotary_pos_emb

-from mttl.models.modifiers.base import Modifier, ModifierConfig
+from mttl.models.modifiers.base import Modifier, ModifierConfig, ModifyMixin


 @dataclass
@@ -25,7 +25,7 @@ class KVAdapterConfig(ModifierConfig):


 @Modifier.register("kv_adapter", config_cls=KVAdapterConfig)
-class KVAdapter(Modifier):
+class KVAdapter(Modifier, ModifyMixin):
     """
     Modifier augmenting the self-attention with additional learnable KV pairs.
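The kv_adapter.py change only adds ModifyMixin to the class bases, so KVAdapter picks up whatever class-level helpers the mixin defines while its Modifier behavior stays untouched. A generic sketch of that mixin pattern; the helper name modify_transformer and its behavior below are hypothetical stand-ins, not a statement of the real mttl API:

# Generic mixin illustration; names below are hypothetical, not the mttl API.
class Modifier:
    def __init__(self, config):
        self.config = config

class ModifyMixin:
    @classmethod
    def modify_transformer(cls, model, config):
        # Hypothetical helper: walk `model` and wrap the targeted modules
        # with instances of `cls` according to `config`.
        raise NotImplementedError

class KVAdapter(Modifier, ModifyMixin):
    """Keeps all Modifier behavior and gains the mixin's classmethod."""

# KVAdapter.modify_transformer(...) is now reachable; the MRO is
# KVAdapter -> Modifier -> ModifyMixin -> object, so no Modifier method is overridden.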
