Skip to content

Commit

Permalink
Merge pull request #79 from microsoft/fix-circular-import
Browse files Browse the repository at this point in the history
Fix circular import
  • Loading branch information
sordonia authored Aug 8, 2024
2 parents 50ae650 + e42b4f1 commit 8c9a02c
Show file tree
Hide file tree
Showing 27 changed files with 781 additions and 671 deletions.
4 changes: 2 additions & 2 deletions mttl/models/containers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@
CoalescedLoRAExpertContainer,
LoRAExpertContainer,
)
from mttl.models.containers.selectors import (
from mttl.models.containers.selectors.base import (
Selector,
SelectorConfig,
SelectorsCache,
SelectorView,
get_selector,
)
from mttl.models.containers.selectors.base import SelectorsCache
from mttl.models.library.expert import Expert
from mttl.models.library.expert_library import ExpertLibrary
from mttl.models.modifiers.base import Modifier
Expand Down
8 changes: 3 additions & 5 deletions mttl/models/containers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@

from mttl.config import Config
from mttl.logging import warn_once
from mttl.models.containers.selectors import (
from mttl.models.containers.selectors.base import Selector, TaskNameSelector
from mttl.models.containers.selectors.kv_selector import KVTaskNameSelector
from mttl.models.containers.selectors.selector_output import (
BatchExpertsAndWeightsSelectorOutput,
BatchExpertsSelectorOutput,
BatchSequenceExpertsAndWeightsSelectorOutput,
ExpertsAndWeightsSelectorOutput,
KVTaskNameSelector,
Selector,
SelectorOutput,
)
from mttl.models.library.expert import Expert
Expand All @@ -39,8 +39,6 @@ class ExpertContainer(nn.Module, Container):
def __init__(self, config, layer, selector=None):
super().__init__()

from mttl.models.containers.selectors import TaskNameSelector

self.config = config
self.layer = layer
self.selector = selector or TaskNameSelector()
Expand Down
7 changes: 2 additions & 5 deletions mttl/models/containers/hard_prompts_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,8 @@
from torch import nn

from mttl.models.containers import ExpertContainer
from mttl.models.containers.selectors import (
BatchExpertsSelectorOutput,
Selector,
TaskNameSelector,
)
from mttl.models.containers.selectors.base import Selector, TaskNameSelector
from mttl.models.containers.selectors.selector_output import BatchExpertsSelectorOutput
from mttl.models.library.expert import Expert
from mttl.models.modifiers.hard_prompts import HardPrompt, HardPromptConfig

Expand Down
2 changes: 1 addition & 1 deletion mttl/models/containers/kv_containers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from torch import nn

from mttl.models.containers.lora_containers import ExpertContainer
from mttl.models.containers.selectors.base import KVTaskNameSelector
from mttl.models.containers.selectors.kv_selector import KVTaskNameSelector
from mttl.models.library.expert import Expert
from mttl.models.modifiers.kv_adapter import KVAdapter, KVAdapterConfig

Expand Down
2 changes: 1 addition & 1 deletion mttl/models/containers/lora_containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from mttl.logging import warn_once
from mttl.models.containers.base import ExpertContainer
from mttl.models.containers.selectors.base import (
from mttl.models.containers.selectors.selector_output import (
BatchExpertsAndWeightsSelectorOutput,
BatchExpertsSelectorOutput,
BatchSequenceExpertsAndWeightsSelectorOutput,
Expand Down
47 changes: 41 additions & 6 deletions mttl/models/containers/selectors/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,41 @@
from mttl.models.containers.selectors.arrow_selector import *
from mttl.models.containers.selectors.average_activation_selector import *
from mttl.models.containers.selectors.base import *
from mttl.models.containers.selectors.moe_selector import *
from mttl.models.containers.selectors.phatgoose_selector import *
from mttl.models.containers.selectors.poly_selector import *
# import everything from the selectors module
from mttl.models.containers.selectors.arrow_selector import (
ArrowSelector,
ArrowSelectorConfig,
)
from mttl.models.containers.selectors.average_activation_selector import (
AverageActivationSelector,
AverageActivationSelectorConfig,
)
from mttl.models.containers.selectors.base import (
Selector,
SelectorConfig,
TaskNameSelector,
TaskNameSelectorConfig,
)
from mttl.models.containers.selectors.moe_selector import (
MOERKHSSelector,
MOERKHSSelectorConfig,
)
from mttl.models.containers.selectors.per_token_selector import (
PerTokenSelector,
PerTokenSelectorConfig,
)
from mttl.models.containers.selectors.phatgoose_selector import (
PhatgooseSelector,
PhatgooseSelectorConfig,
)
from mttl.models.containers.selectors.poly_selector import (
PolySelector,
PolySelectorConfig,
)
from mttl.models.containers.selectors.selector_output import (
BatchExpertsAndWeightsSelectorOutput,
BatchExpertsSelectorOutput,
BatchExpertsSplitsAndWeightsSelectorOutput,
BatchSequenceExpertsAndWeightsSelectorOutput,
BatchSequenceExpertsSplitsAndWeightsSelectorOutput,
ExpertsAndWeightsSelectorOutput,
ExpertsSplitsAndWeightsSelectorOutput,
SelectorOutput,
)
5 changes: 2 additions & 3 deletions mttl/models/containers/selectors/arrow_selector.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from dataclasses import dataclass

from mttl.models.containers.selectors.base import (
from mttl.models.containers.selectors.base import Selector, artifacts_cache
from mttl.models.containers.selectors.per_token_selector import (
PerTokenSelector,
PerTokenSelectorConfig,
Selector,
artifacts_cache,
)


Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from dataclasses import dataclass

from mttl.models.containers.selectors.base import (
from mttl.models.containers.selectors.base import Selector, artifacts_cache
from mttl.models.containers.selectors.per_token_selector import (
PerTokenSelector,
PerTokenSelectorConfig,
Selector,
artifacts_cache,
)


Expand Down
Loading

0 comments on commit 8c9a02c

Please sign in to comment.