Merge pull request #74 from microsoft/fix-task-to-expert
Fix task to expert
sordonia authored Jul 29, 2024
2 parents e1f3f1c + 330b75e commit 4d6dcc9
Showing 4 changed files with 110 additions and 61 deletions.
16 changes: 9 additions & 7 deletions mttl/models/containers/base.py
@@ -77,18 +77,20 @@ def add_expert(self, expert: Expert, action="route", is_default=False) -> None:
"Cannot set is_default if this expert is merged, change to 'route'."
)

self.on_add_expert(expert, action=action, is_default=is_default)

# if a new expert was added, we update the selector and information meta-data
if action != "merge":
update = action != "merge"
if update:
self.expert_infos[expert.name] = expert_info
self.selector.add_expert(
expert.name, expert_info=expert_info, is_default=is_default
)
self.default_expert_name: str | None = (
expert.name if is_default else self.default_expert_name
)

self.on_add_expert(expert, action=action, is_default=is_default)
if update:
# if a new expert was added, we update the selector and information meta-data
self.selector.add_expert(
expert.name, expert_info=expert_info, is_default=is_default
)

@property
def expert_names(self) -> list:
return list(self.expert_infos.keys())
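For reference, a rough sketch of how the body of `add_expert` reads after this change (reconstructed from the hunk above, not verbatim): expert bookkeeping and the default-expert update run first, `on_add_expert` installs the expert in the container, and only then is the selector notified.

    update = action != "merge"
    if update:
        self.expert_infos[expert.name] = expert_info
        self.default_expert_name: str | None = (
            expert.name if is_default else self.default_expert_name
        )

    self.on_add_expert(expert, action=action, is_default=is_default)
    if update:
        # if a new expert was added, we update the selector and information meta-data
        self.selector.add_expert(
            expert.name, expert_info=expert_info, is_default=is_default
        )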
130 changes: 83 additions & 47 deletions mttl/models/containers/selectors/base.py
@@ -1,6 +1,7 @@
import math
import threading
from abc import ABC
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List, Optional, Union

@@ -14,6 +15,7 @@
from mttl.logging import logger, warn_once
from mttl.models.expert_context import InfoContainer
from mttl.models.library.expert import ExpertInfo
from mttl.models.modifiers.base import Modifier
from mttl.models.ranker.adapter_ranker import AdapterRankerHelper
from mttl.models.ranker.classifier_ranker import ClusterPredictor
from mttl.models.utils import MetricLogger
@@ -35,6 +37,7 @@ class SelectorConfig:
router_granularity: str = "*"
lora_merge_after: bool = False
selector_logging: bool = True
num_experts: int = 0

def __eq__(self, other):
# compare all the attributes
@@ -60,7 +63,9 @@ def fromdict(cls, dumped: Dict) -> "SelectorConfig":

@classmethod
def from_training_config(
cls, training_config: Union["Config", "SelectorConfig"]
cls,
training_config: Union["Config", "SelectorConfig"],
ignore_prefix: str = None,
) -> Union["SelectorConfig", None]:
"""Build modifier config from the training config.
@@ -70,28 +75,71 @@ def from_training_config(
# nothing to do here
return training_config

# if called on the base class, we need to find the correct subclass
if training_config.router_selector is None:
return None
if cls == SelectorConfig:
# if called on the base class, we need to find the correct subclass
if training_config.router_selector is None:
return None

if training_config.router_selector not in Selector.registered_names():
raise ValueError(
f"Selector '{training_config.router_selector}' not found, has it been registered?"
)
if training_config.router_selector not in Selector.registered_names():
raise ValueError(
f"Selector '{training_config.router_selector}' not found, has it been registered?"
)

config_klass = Selector.get_config_class_by_name(
training_config.router_selector
)
config_klass = Selector.get_config_class_by_name(
training_config.router_selector
)
else:
config_klass = cls

kwargs = {}
for key in config_klass.__dataclass_fields__.keys():
# only overwrite default if value exists and is not None
train_cfg_value = getattr(training_config, key, None)
if ignore_prefix:
default_value = getattr(training_config, ignore_prefix + key, None)
else:
default_value = None

train_cfg_value = getattr(training_config, key, default_value)
if train_cfg_value is not None:
kwargs[key] = getattr(training_config, key)
kwargs[key] = train_cfg_value
return config_klass(**kwargs)
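# Hypothetical illustration of the new ignore_prefix argument: when the training
# config does not define a SelectorConfig field directly but exposes it under a
# prefixed name (the prefix "library_" here is an assumption, not from this commit),
# the prefixed attribute is used as a fallback value for that field.
selector_config = SelectorConfig.from_training_config(
    training_config, ignore_prefix="library_"
)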


class SelectorsCache:
"""Keep a cache of all added selectors indexed by both modifier and selector name."""

def __init__(self):
self.cache = defaultdict(dict)
self.clear()

def clear(self, modifier_name: str = None):
# initialize cache for all registered modifiers
if modifier_name is None:
for modifier_name in Modifier.registered_names():
self.cache[modifier_name] = {}
else:
self.cache[modifier_name] = {}

def insert(self, modifier_name: str, selector_name: str, selector: "Selector"):
self.cache[modifier_name][selector_name] = selector

def get(
self, modifier_name: str, selector_name: str = None
) -> Union["Selector", Dict]:
if selector_name is None:
return self.cache[modifier_name]
return self.cache[modifier_name].get(selector_name, None)

def items(self):
return iter(self.cache.items())

def __setitem__(self, key, value):
if key not in Modifier.registered_names():
raise ValueError(f"Modifier '{key}' not found, has it been registered?")

self.cache[key] = value
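# Minimal usage sketch of SelectorsCache (illustration only; "lora" must be a
# registered Modifier name and `selector` stands for any constructed Selector):
cache = SelectorsCache()
cache.insert("lora", "shared", selector)  # index by modifier name, then selector name
assert cache.get("lora", "shared") is selector
all_lora = cache.get("lora")  # dict of all selectors registered under "lora"
cache.clear("lora")  # reset only the "lora" entry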


@dataclass
class LoadableSelectorConfig(SelectorConfig):
"""Adds support for library_id and data_id, which specifies the unique identifier to load."""
@@ -258,6 +306,7 @@ def __init__(self, config=None, **kwargs):
self.default_expert_name = None
self.total_calls_per_forward = 0
self._calls_counter = 0
self._task_to_expert_name = {}
# dependency injection filled from ExpertContainer
self.__layer_name__ = None

@@ -304,6 +353,10 @@ def layer_name(self):

return self.__layer_name__

@property
def task_to_expert_name(self):
return getattr(self, "_task_to_expert_name", {})

@property
def n_experts(self):
return len(self.expert_names)
@@ -324,14 +377,28 @@ def on_add_expert(
def add_expert(
self, expert_name: str, expert_info: ExpertInfo = None, is_default=False
):
self.on_add_expert(expert_name, expert_info, is_default)
if expert_info is None or expert_info.expert_task_name is None:
logger.warning(
"Expert's task_name not set, assume task name corresponds to expert name!"
)
self._task_to_expert_name[expert_name] = expert_name
else:
for task_name in expert_info.expert_task_name.split(","):
if task_name in self._task_to_expert_name:
logger.warning(
f"Task name {task_name} already assigned to expert {self._task_to_expert_name[task_name]}"
)
self._task_to_expert_name[task_name] = expert_name

# standard bookkeeping for all selectors
if is_default:
self.default_expert_name = expert_name

self.expert_infos[expert_name] = expert_info

# call custom logic for add expert
self.on_add_expert(expert_name, expert_info, is_default)
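# Illustration of the task-to-expert bookkeeping above (expert and task names are
# hypothetical; ExpertInfo is assumed to accept these fields as keyword arguments):
info = ExpertInfo(expert_name="qa_expert", expert_task_name="squad,hotpotqa")
selector.add_expert("qa_expert", expert_info=info)
selector.add_expert("generic")  # no task metadata: the task name defaults to the expert name
assert selector.task_to_expert_name == {
    "squad": "qa_expert",
    "hotpotqa": "qa_expert",
    "generic": "generic",
}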


class SelectorView:
"""A view on a selector that allows it to call forward but doesn't act on add_expert.
@@ -420,37 +487,6 @@ def on_add_expert(
pass


class TaskToExpertMixin:
"""
Builds `task_to_expert_name` mapping on add_expert, useful for
routing (as in TaskNameSelector) or for logging in-distribution stats (PerTokenSelector)
"""

@property
def task_to_expert_name(self):
return getattr(self, "_task_to_expert_name", {})

def on_add_expert(
self, expert_name: str, expert_info: ExpertInfo = None, is_default=False
):
_task_to_expert_name = self.task_to_expert_name

if expert_info is None or expert_info.expert_task_name is None:
logger.warning(
"Expert's task_name not set, assume task name corresponds to expert name!"
)
_task_to_expert_name[expert_name] = expert_name
else:
for task_name in expert_info.expert_task_name.split(","):
if task_name in _task_to_expert_name:
logger.warning(
f"Task name {task_name} already assigned to expert {_task_to_expert_name[task_name]}"
)
_task_to_expert_name[task_name] = expert_name

self._task_to_expert_name = _task_to_expert_name


@dataclass
class PerTokenSelectorConfig(LoadableSelectorConfig):
router_temp: float = None
@@ -523,7 +559,7 @@ def get_expert_prototype_from_library_artifacts(


@Selector.register("per_token_router", PerTokenSelectorConfig)
class PerTokenSelector(Selector, TaskToExpertMixin, LoadableLibraryMixin):
class PerTokenSelector(Selector, LoadableLibraryMixin):
def __init__(self, config, **kwargs) -> None:
super().__init__(config, **kwargs)

@@ -734,7 +770,7 @@ class TaskNameSelectorConfig(SelectorConfig):


@Selector.register("task_selector", TaskNameSelectorConfig)
class TaskNameSelector(Selector, TaskToExpertMixin):
class TaskNameSelector(Selector):
def __init__(self, **kwargs) -> None:
super().__init__()

14 changes: 9 additions & 5 deletions mttl/models/containers/selectors/poly_selector.py
@@ -148,7 +148,7 @@ def on_add_expert(
):
# we need additional space in the routing to accommodate the incoming expert
self.module_logits.data = torch.empty(
self.n_tasks + 1, self.config.n_splits * (self.n_experts + 1)
self.n_tasks + 1, self.config.n_splits * self.n_experts
).uniform_(-1e-3, 1e-3)

# Last expert is exactly uniform
Expand All @@ -157,7 +157,8 @@ def on_add_expert(

@dataclass
class PolySelectorDirectConfig(PolySelectorConfig):
pass
# Used to initialize the logits of the expert for the current task
finetune_task_name: str = None


@Selector.register("poly_router_dir", PolySelectorDirectConfig)
Expand Down Expand Up @@ -187,16 +188,19 @@ def on_add_expert(
):
"""
Assume:
expert_task_name -- task name expert is pecialized at
expert_task_name -- task name expert is specialized in
self.config.finetune_task_name -- name of the task the model is currently trained on
If we encounter a module for the current task, we init it with one hot, otherwise with uniform.
"""
main_m = 1

expert_task_name = expert_info.expert_task_name
if expert_name not in self.module_logits_dict:
if self.training_config.finetune_task_name == expert_task_name:
if (
self.config.finetune_task_name
and self.task_to_expert_name[self.config.finetune_task_name]
== expert_name
):
self.init_gap = [
0,
0,
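In effect (a hedged paraphrase of the hunk above, not the verbatim body): when `config.finetune_task_name` maps, via `task_to_expert_name`, to the expert being added, the direct Poly selector initializes that expert's logit one-hot (zero init gap); otherwise it keeps the small uniform initialization, which is exactly what the updated test below asserts.

    # sketch only; names follow the diff
    is_current_task_expert = bool(
        self.config.finetune_task_name
        and self.task_to_expert_name[self.config.finetune_task_name] == expert_name
    )
    self.init_gap = [0, 0] if is_current_task_expert else [-1e-3, 1e-3]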
11 changes: 9 additions & 2 deletions tests/test_routed_multi_expert_model.py
@@ -241,6 +241,7 @@ def test_expert_selector_with_poly_routing(self, tmp_exp_config):

module = MultiExpertModel(**vars(config))
module.load_from_module_dict(module_dict, action="route")
assert module.selectors["lora"][0].init_gap == [-1e-3, 1e-3]

assert isinstance(
module.model.transformer.h[0].attn.attention.k_proj, LoRAExpertContainer
@@ -259,7 +260,7 @@ def test_expert_selector_with_poly_routing(self, tmp_exp_config):

# Test Base Llama model
output = module(batch)
assert np.allclose(output.item(), 10.15, atol=0.1)
assert np.allclose(output.item(), 9.68, atol=0.1)

# check the get_router_weights function
weights = {}
@@ -286,12 +287,18 @@ def test_expert_selector_with_poly_routing(self, tmp_exp_config):

# change router_granularity to finegrained
config.router_granularity = "finegrained"
config.finetune_task_name = "mod1"

module = MultiExpertModel(
**vars(config),
)
module.load_from_module_dict(module_dict)
assert module.selectors["lora"][0].init_gap == [0, 0]
assert module.selectors["lora"][0].module_logits_dict["mod1"].item() == 1.0
assert module.selectors["lora"][0].module_logits_dict["mod2"].item() == 0.0

output = module(batch)
assert np.allclose(output.item(), 10.15, atol=0.1)
assert np.allclose(output.item(), 9.68, atol=0.1)

weights = {}
for _, selector_list in module.selectors.items():
