enable lora_merge_after in Coalesced container (#62)
pclucas14 authored Jul 11, 2024
1 parent 04a51d6 commit 35fcaa4
Showing 1 changed file with 20 additions and 4 deletions.
mttl/models/containers/expert_containers.py (24 changes: 20 additions & 4 deletions)
@@ -379,9 +379,17 @@ class CoalescedLoRAExpertContainer(LoRAExpertContainer):

     __supports_configs__ = [SkilledLoRAConfig, LoRAConfig]

-    def __init__(self, config, info_container, layer, selector=None, **kwargs):
+    def __init__(
+        self,
+        config,
+        info_container,
+        layer,
+        selector=None,
+        lora_merge_after=False,
+        **kwargs,
+    ):
         MergeableAdapter.__init__(self)
-        super().__init__(config, info_container, layer, selector)
+        super().__init__(config, info_container, layer, selector, lora_merge_after)

         if not isinstance(self.layer, nn.Linear):
             raise ValueError(
@@ -466,7 +474,11 @@ def route(self, input, selection, **kwargs):
             )

             module_output = SkilledLoRA.parallel_linear_weighted_forward(
-                input, [self.experts], weights, dim_names=["batch", "experts"]
+                input,
+                [self.experts],
+                weights,
+                dim_names=["batch", "experts"],
+                merge_after=self.lora_merge_after,
             )
             return module_output
         elif (
@@ -497,7 +509,11 @@ def route(self, input, selection, **kwargs):
             assert weights.shape[-1] == self.experts.n_skills

             module_output = SkilledLoRA.parallel_linear_weighted_forward(
-                input, [self.experts], weights, dim_names=selection.dim_names
+                input,
+                [self.experts],
+                weights,
+                dim_names=selection.dim_names,
+                merge_after=self.lora_merge_after,
             )
             return module_output
         else:
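For context, here is a minimal sketch of what a merge_after-style switch typically toggles in a weighted multi-expert LoRA forward. The names and shapes below (weighted_lora_forward, lora_a, lora_b) are illustrative assumptions, not the actual SkilledLoRA.parallel_linear_weighted_forward internals: with merge_after=False the expert A/B matrices are mixed by the routing weights before the low-rank product is applied, while with merge_after=True each expert's LoRA output is computed first and the outputs are mixed afterwards.

# Illustrative sketch only -- not the mttl SkilledLoRA implementation.
# Assumed shapes: x [batch, d_in], lora_a [experts, d_in, rank],
# lora_b [experts, rank, d_out], weights [batch, experts].
import torch


def weighted_lora_forward(x, lora_a, lora_b, weights, merge_after=False):
    if merge_after:
        # Apply every expert's low-rank update first, then mix the outputs.
        per_expert = torch.einsum("bi,eir,ero->beo", x, lora_a, lora_b)
        return torch.einsum("be,beo->bo", weights, per_expert)
    # Mix the expert A/B matrices per example, then apply the merged adapter.
    merged_a = torch.einsum("be,eir->bir", weights, lora_a)
    merged_b = torch.einsum("be,ero->bro", weights, lora_b)
    return torch.einsum("bi,bir,bro->bo", x, merged_a, merged_b)

Note that the two orderings are not equivalent in general: mixing the A and B matrices before taking their product introduces cross-expert terms, which is presumably why the container now exposes the choice through lora_merge_after rather than hard-coding one behavior.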
