Merge pull request #75 from microsoft/selectors-cache
Selectors cache
sordonia authored Aug 1, 2024
2 parents 4d6dcc9 + e24ae31 commit e004593
Showing 7 changed files with 95 additions and 230 deletions.
55 changes: 26 additions & 29 deletions mttl/models/containers/__init__.py
@@ -14,6 +14,7 @@
SelectorView,
get_selector,
)
from mttl.models.containers.selectors.base import SelectorsCache
from mttl.models.library.expert import Expert
from mttl.models.library.expert_library import ExpertLibrary
from mttl.models.modifiers.base import Modifier
@@ -88,8 +89,7 @@ def add_expert_library_to_transformer(
expert_library: ExpertLibrary,
action: str = "route",
default_expert: str = None,
routing_config: SelectorConfig = None,
training_config: Config = None,
selector_config: SelectorConfig = None,
):
for expert_name, expert_dump in expert_library.items():
add_expert_to_transformer(
@@ -99,8 +99,7 @@ def add_expert_library_to_transformer(
expert_dump.expert_weights,
action=action,
is_default=expert_name == default_expert,
routing_config=routing_config,
training_config=training_config,
selector_config=selector_config,
)


@@ -109,7 +108,7 @@ def create_selector_for_container(
container,
modifier_type: str,
selector_config: SelectorConfig,
training_config: Config = None,
selector_cache: SelectorsCache,
) -> Selector:
if container.selector is not None and container.selector.config == selector_config:
# selector already exists and has the same config
@@ -122,23 +121,21 @@
# we create a new selector if it doesn't exist for this identifier, or
# if we are replacing a previous one of a different type
create_new_selector = (
identifier not in transformer.selectors[modifier_type]
or transformer.selectors[modifier_type].get(identifier).config
!= selector_config
not selector_cache.get(modifier_type, identifier)
or selector_cache.get(modifier_type, identifier).config != selector_config
)
if create_new_selector:
# Special case when you have a decoder layer in an enc-dec model
selector = get_selector(
selector_config,
layer=container.layer,
training_config=training_config,
)
transformer.selectors[modifier_type][identifier] = selector
selector_cache.insert(modifier_type, identifier, selector)

# selector needs to know how many times it will be called per forward pass in order to be able to reset the cache
selector.total_calls_per_forward += 1
else:
selector: Selector = transformer.selectors[modifier_type][identifier]
selector: Selector = selector_cache.get(modifier_type, identifier)
# selector needs to know how many times it will be called per forward pass in order to be able to reset the cache
selector.total_calls_per_forward += 1
selector = selector.create_view()
@@ -152,7 +149,7 @@ def replace_selector_for_container(
transformer,
modifier_type: str,
selector_config: SelectorConfig,
training_config: Config = None,
selector_cache: SelectorsCache,
force_replace: bool = False,
):
"""
@@ -179,13 +176,10 @@
f"No expert containers found for modifier type: {modifier_type}. Cannot assign a selector! Load some experts beforehand."
)

if not modifier_type in transformer.selectors:
transformer.selectors[modifier_type] = {}

if force_replace:
for container in expert_containers:
container.selector = None
transformer.selectors[modifier_type] = {}
selector_cache.clear(modifier_type)

n_selectors = 0
n_selectors_views = 0
@@ -196,7 +190,7 @@
container,
modifier_type,
selector_config,
training_config,
selector_cache,
)
n_selectors += isinstance(selector, Selector)
n_selectors_views += isinstance(selector, SelectorView)
@@ -285,15 +279,19 @@ def add_expert_to_transformer(
expert: Expert,
action: str = "route",
is_default: bool = False,
routing_config: SelectorConfig = None,
training_config: Config = None,
selector_config: SelectorConfig = None,
selector_cache: SelectorsCache = None,
) -> None:
"""
Routine to add an expert to the transformer architecture.
Params:
transformer: the transformer model to modify
Config: the config of the model to which the expert is added
expert: expert instance that needs to be added
action: whether to route or merge this expert, default is `route`
is_default: whether the expert should be set as default
selector_config: selector configuration to use for the model
selector_cache: cache to store the selectors for the model
"""
expert_config = expert.expert_config

@@ -332,8 +330,8 @@ def add_expert_to_transformer(
expert_config,
layer,
lora_merge_after=(
routing_config.lora_merge_after
if routing_config
selector_config.lora_merge_after
if selector_config
else False
),
)
@@ -367,31 +365,30 @@ def add_expert_to_transformer(
transformer.named_parameters(), expert_config.tie_params
)
tie_params(transformer, expert_config, target_2_source_param)
####################

if not added_layers:
raise ValueError(
"You were trying to add an expert but no expert containers were created, this is likely due to a misconfiguration of the expert config."
" `modify_layers` and `modify_modules` did not return a match for the current model."
)

if routing_config is not None:
if selector_config is not None:
replace_selector_for_container(
transformer,
model_modifier,
routing_config,
training_config,
selector_config,
selector_cache,
)

if not transformer.selectors[model_modifier]:
if not selector_cache.get(model_modifier):
raise ValueError(
"No selectors were created but a routing config was specified. Check your routing_config and model architecture."
"No selectors were created but a routing config was specified. Check your selector_config and model architecture."
)

logger.debug(
"Added expert %s, with %s selectors",
expert.name,
len(transformer.selectors[model_modifier]),
len(selector_cache.get(model_modifier)),
)

logger.debug("Patched layers: %s", added_layers)
2 changes: 1 addition & 1 deletion mttl/models/containers/selectors/poly_selector.py
@@ -167,9 +167,9 @@ def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)

self.module_logits_dict = nn.ParameterDict()
self.training_config = kwargs["training_config"]
self.init_gap = [-1e-3, 1e-3]
self.device = kwargs["layer"].weight.device
self.finetune_task_name = self.config.finetune_task_name

def _get_weights(self):
weights = torch.cat(
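
Note on this file: the change removes the training_config kwarg from the selector's constructor; the finetune task name is now read from the selector's own config. A small runnable sketch of the new constructor shape, with stand-in class and config names (the real selector lives in poly_selector.py):

import torch.nn as nn
from types import SimpleNamespace


class PolySelectorSketch(nn.Module):
    # Stand-in illustrating only the constructor change above.
    def __init__(self, config, layer, **kwargs):
        super().__init__()
        self.config = config
        self.module_logits_dict = nn.ParameterDict()
        self.init_gap = [-1e-3, 1e-3]
        self.device = layer.weight.device
        # previously: self.training_config = kwargs["training_config"]
        # the task name now travels with the selector config instead:
        self.finetune_task_name = self.config.finetune_task_name


selector = PolySelectorSketch(
    config=SimpleNamespace(finetune_task_name="mnli"),  # hypothetical task name
    layer=nn.Linear(8, 8),
)
print(selector.finetune_task_name)  # mnli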
54 changes: 18 additions & 36 deletions mttl/models/expert_model.py
@@ -16,6 +16,7 @@
from mttl.models.containers.selectors.base import (
LoadableLibraryMixin,
LoadableSelectorConfig,
SelectorsCache,
)
from mttl.models.expert_config import ExpertConfig
from mttl.models.expert_context import InfoContainer
@@ -271,6 +272,7 @@ class MultiExpertModel(ExpertModel):

def __init__(self, **config_kwargs):
config_kwargs["model_modifier"] = None

super().__init__(**config_kwargs)

# config about the routing
@@ -282,8 +284,12 @@ def __init__(self, **config_kwargs):
)

# inject memory for adding selectors
self.model.selectors = {}
self.experts_names = []
self.selector_cache = SelectorsCache()
self.experts_infos = {}

@property
def experts_names(self):
return list(self.experts_infos.keys())

@classmethod
def from_pretrained_library(
@@ -350,13 +356,8 @@ def experts_containers(self) -> List[ExpertContainer]:
return containers

@property
def selectors(self) -> Dict[str, List[Selector]]:
selectors = defaultdict(list)
for modifier, selectors_dict in self.model.selectors.items():
for selector in selectors_dict.values():
if isinstance(selector, Selector):
selectors[modifier].append(selector)
return selectors
def selectors(self) -> Dict[str, Dict[str, Selector]]:
return self.selector_cache.cache

def delete_expert_container(self):
"""
@@ -366,7 +367,9 @@ def delete_expert_container(self):
for c_name, child in dict(module.named_children()).items():
if isinstance(child, ExpertContainer) and len(child.experts) > 0:
setattr(module, c_name, child.layer)
self.experts_names.clear()

self.selector_cache.clear()
self.experts_infos.clear()

def add_experts_from_library(self, library):
import concurrent.futures
@@ -482,12 +485,12 @@ def add_expert_instance(
expert_instance,
action=action,
is_default=expert_instance.name == "default" or is_default,
routing_config=self.selector_config,
training_config=self.training_config,
selector_config=self.selector_config,
selector_cache=self.selector_cache,
)

if action != "merge":
self.experts_names.append(expert_instance.name)
self.experts_infos[expert_instance.name] = expert_instance.expert_info
# reload the expert instance to fill the weights properly if this was an empty expert
expert_instance = self.get_expert_instance(expert_instance.name)
return expert_instance
@@ -503,9 +506,10 @@ def set_selector(
self.model,
modifier_type,
selector_config,
self.selector_cache,
force_replace=True,
)
assert self.model.selectors[modifier_type]
assert self.selector_cache.get(modifier_type)
logger.info(
"Created {} selectors and {} views.".format(n_selectors, n_selectors_views)
)
@@ -527,28 +531,6 @@ def extract_parameters(self, p_name_pattern=".*lora.*"):
para_list.append(param.reshape(-1))
return torch.cat(para_list)

def get_task_embeddings(self):
"""
Retrieves the task embeddings for the loaded experts.
This method assumes that the names of the loaded experts correspond to the tasks they are made for.
Returns:
embeddings (dict): A dictionary containing the task embeddings for each expert.
The keys are the expert names and the values are the corresponding embeddings.
"""
if len(self.experts_names) == 0:
return self.extract_parameters()

embeddings = {}
for exp_name in self.experts_names:
embeddings[exp_name] = (
self.extract_parameters(p_name_pattern=rf".*{exp_name}\..*lora.*")
.detach()
.cpu()
)
return embeddings

def get_expert_instance(self, expert_name):
"""
Retrieves an instance of the specified expert from the model.
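
Note on this file: MultiExpertModel now owns a SelectorsCache and an experts_infos dict instead of a model-level selectors dict and a separate experts_names list. experts_names becomes a read-only property derived from experts_infos, add_expert_instance records expert_instance.expert_info, and delete_expert_container clears both structures. A minimal, self-contained illustration of that bookkeeping pattern (class and method names here are stand-ins, not the real API):

class ExpertBookkeepingSketch:
    def __init__(self):
        self.experts_infos = {}  # expert name -> expert info

    @property
    def experts_names(self):
        # Derived view: the name list can never drift out of sync with the infos.
        return list(self.experts_infos.keys())

    def add_expert(self, name, info):
        self.experts_infos[name] = info

    def clear(self):
        self.experts_infos.clear()


book = ExpertBookkeepingSketch()
book.add_expert("expert_a", {"modifier": "lora"})
book.add_expert("expert_b", {"modifier": "lora"})
print(book.experts_names)  # ['expert_a', 'expert_b']
book.clear()
print(book.experts_names)  # []

Keeping a single source of truth (experts_infos) and deriving the name list from it is what lets add_expert_instance and delete_expert_container stay consistent without extra list maintenance.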
(Diffs for the remaining 4 changed files were not loaded.)
