Skip to content

Commit

Permalink
Merge pull request #59 from microsoft/folder_rename
Browse files Browse the repository at this point in the history
rename folder and move expert_containers one level up
  • Loading branch information
sordonia authored Jul 9, 2024
2 parents 929fad6 + 21d28bb commit 1f45a0f
Show file tree
Hide file tree
Showing 62 changed files with 113 additions and 116 deletions.
2 changes: 1 addition & 1 deletion mttl/cli/convert_library_to_hf_phi2.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))

from mttl.models.modifiers.expert_containers.expert_library import (
from mttl.models.library.expert_library import (
ExpertLibrary,
)

Expand Down
2 changes: 1 addition & 1 deletion mttl/cli/show_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"""

import click
from mttl.models.modifiers.expert_containers.expert_library import ExpertLibrary
from mttl.models.library.expert_library import ExpertLibrary
from rich.console import Console
from rich.table import Table

Expand Down
2 changes: 1 addition & 1 deletion mttl/dataloader/alpaca_dataset_readers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import torch

from mttl.models.modifiers.expert_containers.expert_library import DatasetLibrary
from mttl.models.library.expert_library import DatasetLibrary


class AlpacaTemplateForHash(
Expand Down
2 changes: 1 addition & 1 deletion mttl/dataloader/oasst1_readers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import torch
from mttl.models.modifiers.expert_containers.expert_library import DatasetLibrary
from mttl.models.library.expert_library import DatasetLibrary
from mttl.utils import hash_example, logger


Expand Down
2 changes: 1 addition & 1 deletion mttl/dataloader/platypus_dataset_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
)
import numpy as np

from mttl.models.modifiers.expert_containers.expert_library import DatasetLibrary
from mttl.models.library.expert_library import DatasetLibrary


class PlatypusTemplate:
Expand Down
2 changes: 1 addition & 1 deletion mttl/datamodule/arc_data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from dataclasses import dataclass
import os

from mttl.models.modifiers.expert_containers.expert_library import DatasetLibrary
from mttl.models.library.expert_library import DatasetLibrary


@dataclass
Expand Down
2 changes: 1 addition & 1 deletion mttl/datamodule/bbh_data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import os

from mttl.datamodule.mt_seq_to_seq_module import augment_few_shot_task
from mttl.models.modifiers.expert_containers.expert_library import DatasetLibrary
from mttl.models.library.expert_library import DatasetLibrary


@dataclass
Expand Down
2 changes: 1 addition & 1 deletion mttl/datamodule/codex_data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os
import numpy
from mttl.datamodule.base import DefaultDataModule, DatasetConfig
from mttl.models.modifiers.expert_containers.expert_library import DatasetLibrary
from mttl.models.library.expert_library import DatasetLibrary


class CodexDataConfig(DatasetConfig):
Expand Down
2 changes: 1 addition & 1 deletion mttl/datamodule/facts_lm_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from mttl.datamodule.platypus_module import PlatypusConfig
from mttl.datamodule.base import DefaultDataModule

from mttl.models.modifiers.expert_containers.expert_library import DatasetLibrary
from mttl.models.library.expert_library import DatasetLibrary
from mttl.utils import logger


Expand Down
2 changes: 1 addition & 1 deletion mttl/datamodule/hellaswag_data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
MultiChoiceDataModule,
DatasetConfig,
)
from mttl.models.modifiers.expert_containers.expert_library import DatasetLibrary
from mttl.models.library.expert_library import DatasetLibrary


@dataclass
Expand Down
2 changes: 1 addition & 1 deletion mttl/datamodule/humaneval_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from mttl.datamodule.base import DefaultDataModule, DatasetConfig
from mttl.datamodule.utils import maybe_filter_hf_dataset_by_task
from mttl.models.modifiers.expert_containers.expert_library import DatasetLibrary
from mttl.models.library.expert_library import DatasetLibrary


@dataclass
Expand Down
2 changes: 1 addition & 1 deletion mttl/datamodule/mbpp_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import os
import numpy
from mttl.datamodule.utils import maybe_filter_hf_dataset_by_task
from mttl.models.modifiers.expert_containers.expert_library import DatasetLibrary
from mttl.models.library.expert_library import DatasetLibrary


@dataclass
Expand Down
2 changes: 1 addition & 1 deletion mttl/datamodule/mmlu_data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
DefaultCollator,
DefaultDataModule,
)
from mttl.models.modifiers.expert_containers.expert_library import DatasetLibrary
from mttl.models.library.expert_library import DatasetLibrary


#################################################
Expand Down
2 changes: 1 addition & 1 deletion mttl/datamodule/mt_seq_to_seq_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from mttl.datamodule.utils import maybe_filter_hf_dataset_by_task, logger
from dataclasses import dataclass

from mttl.models.modifiers.expert_containers.expert_library import DatasetLibrary
from mttl.models.library.expert_library import DatasetLibrary


def is_phi2_eval_task(task):
Expand Down
2 changes: 1 addition & 1 deletion mttl/datamodule/ni_data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from mttl.datamodule.utils import maybe_filter_hf_dataset_by_task
from mttl.datamodule.base import DefaultCollator, DefaultDataModule, DatasetConfig
from mttl.models.modifiers.expert_containers.expert_library import DatasetLibrary
from mttl.models.library.expert_library import DatasetLibrary
from mttl.utils import logger


Expand Down
2 changes: 1 addition & 1 deletion mttl/datamodule/openbookqa_data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from dataclasses import dataclass
import os

from mttl.models.modifiers.expert_containers.expert_library import DatasetLibrary
from mttl.models.library.expert_library import DatasetLibrary


@dataclass
Expand Down
2 changes: 1 addition & 1 deletion mttl/datamodule/piqa_data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from dataclasses import dataclass
import os

from mttl.models.modifiers.expert_containers.expert_library import DatasetLibrary
from mttl.models.library.expert_library import DatasetLibrary


@dataclass
Expand Down
2 changes: 1 addition & 1 deletion mttl/datamodule/retrieval_lm_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from dataclasses import dataclass

from mttl.datamodule.utils import get_tokenizer
from mttl.models.modifiers.expert_containers.expert_library import DatasetLibrary
from mttl.models.library.expert_library import DatasetLibrary
from mttl.utils import logger


Expand Down
2 changes: 1 addition & 1 deletion mttl/datamodule/superglue_data_module.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
from dataclasses import dataclass
from mttl.datamodule.base import DatasetConfig, MultiChoiceDataModule
from mttl.models.modifiers.expert_containers.expert_library import DatasetLibrary
from mttl.models.library.expert_library import DatasetLibrary


@dataclass
Expand Down
2 changes: 1 addition & 1 deletion mttl/datamodule/winogrande_data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from dataclasses import dataclass
import os

from mttl.models.modifiers.expert_containers.expert_library import DatasetLibrary
from mttl.models.library.expert_library import DatasetLibrary


def doc_to_text(doc):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
import re
from mttl.config import Config
from mttl.models.modifiers.expert_containers.expert_library import ExpertLibrary
from mttl.models.modifiers.expert_containers.selectors import (
from mttl.models.library.expert_library import ExpertLibrary
from mttl.models.containers.selectors import (
Selector,
SelectorConfig,
SelectorView,
get_selector,
)
from mttl.models.modifiers.expert_containers.expert_containers import *
from mttl.models.containers.expert_containers import *
from mttl.models.modifiers.modify_model import CONFIGS_TO_MODIFIERS
from mttl.utils import logger
from mttl.models.modifiers.expert_containers.expert import Expert
from mttl.models.library.expert import Expert


def _extract_identifier(string, match_on="finegrained"):
Expand Down Expand Up @@ -307,7 +307,7 @@ def add_expert_to_transformer(
raise ValueError("Expert name cannot be empty!")

from mttl.models.modifiers.modify_model import get_modifier_type
from mttl.models.modifiers.expert_containers.hard_prompts_container import (
from mttl.models.containers.hard_prompts_container import (
add_hard_prompt_to_transformer,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
ModifierConfig,
ModifyMixin,
)
from mttl.models.modifiers.expert_containers.selectors import (
from mttl.models.containers.selectors import (
BatchExpertsAndWeightsSelectorOutput,
BatchExpertsSelectorOutput,
BatchSequenceExpertsAndWeightsSelectorOutput,
Expand All @@ -27,15 +27,15 @@
SkilledLoRAConfig,
)
from mttl.models.modifiers.kv_adapter import KVAdapter, KVAdapterConfig
from mttl.models.modifiers.expert_containers.expert import Expert
from mttl.models.library.expert import Expert
from mttl.models.modifiers.modify_model import get_modifier_type


class ExpertContainer:
__supports_configs__ = []

def __init__(self, config, info_container, layer, selector=None):
from mttl.models.modifiers.expert_containers.selectors import TaskNameSelector
from mttl.models.containers.selectors import TaskNameSelector

self.config = config
self.layer = layer
Expand Down Expand Up @@ -184,7 +184,7 @@ def add_expert(
action="merge",
is_default=False,
) -> None:
from mttl.models.modifiers.expert_containers import filter_expert_weights
from mttl.models.containers import filter_expert_weights

if expert.name in self.expert_infos:
raise ValueError(
Expand Down Expand Up @@ -577,7 +577,7 @@ def add_expert(
is_default=False,
**kwargs,
) -> None:
from mttl.models.modifiers.expert_containers import filter_expert_weights
from mttl.models.containers import filter_expert_weights

expert_weights = filter_expert_weights(
self.__layer_name__, expert.expert_weights
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@
from typing import Any, Dict
from mttl.models.modifiers.base import ModifyMixin

from mttl.models.modifiers.expert_containers.selectors import (
from mttl.models.containers.selectors import (
Selector,
TaskNameSelector,
BatchExpertsSelectorOutput,
)
from mttl.models.modifiers.expert_containers import ExpertContainer
from mttl.models.containers import ExpertContainer
from mttl.models.modifiers.hard_prompts import HardPrompt, HardPromptConfig
from mttl.models.modifiers.modify_model import register_modifier
from mttl.models.modifiers.expert_containers.expert import Expert
from mttl.models.library.expert import Expert


class HardPromptDecoderWrapper(nn.Module):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import numpy as np
from torch import nn
import torch.nn.functional as F
from mttl.models.modifiers.expert_containers.expert import ExpertInfo
from mttl.models.library.expert import ExpertInfo
from mttl.models.modifiers.routing import RoutingInfo
from torch.distributions import Bernoulli, Categorical
from mttl.models.utils import MetricLogger
Expand Down
22 changes: 10 additions & 12 deletions mttl/models/expert_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,21 @@
import torch
from transformers import PreTrainedModel

from mttl.models.modifiers.expert_containers.library_transforms import (
from mttl.models.library.library_transforms import (
ArrowConfig,
HiddenStateComputerConfig,
)
from mttl.models.modifiers.lora import SkilledLoRAConfig

from mttl.models.modifiers.expert_containers import add_expert_to_transformer
from mttl.models.modifiers.expert_containers.expert_library import ExpertLibrary
from mttl.models.containers import add_expert_to_transformer
from mttl.models.library.expert_library import ExpertLibrary
from mttl.models.modifiers.routing import RoutingInfo
from mttl.utils import logger
from mttl.models.modifiers.expert_containers.expert import Expert, ExpertInfo
from mttl.models.modifiers.expert_containers.expert_containers import (
from mttl.models.library.expert import Expert, ExpertInfo
from mttl.models.containers.expert_containers import (
ExpertContainer,
)
from mttl.models.modifiers.expert_containers.selectors import Selector, SelectorConfig
from mttl.models.containers.selectors import Selector, SelectorConfig


import torch
Expand All @@ -31,16 +31,14 @@
from mttl.models.modifiers import modify_transformer
from mttl.models.modifiers.base import ModifierConfig

from mttl.models.modifiers.expert_containers.expert import ExpertInfo
from mttl.models.modifiers.expert_containers.selectors import SelectorConfig
from mttl.models.library.expert import ExpertInfo
from mttl.models.containers.selectors import SelectorConfig
from mttl.models.modifiers.routing import RoutingInfo
from mttl.models.utils import (
EfficientCheckpointModule,
model_loader_helper,
prepare_model_for_kbit_training,
)
from mttl.models.expert_config import ExpertConfig
from mttl.models.ranker.adapter_ranker import AdapterRankerHelper


torch.set_float32_matmul_precision("high")
Expand Down Expand Up @@ -397,7 +395,7 @@ def load_expert(
is_default: bool = False,
expert_library: ExpertLibrary = None,
):
from mttl.models.modifiers.expert_containers.expert import load_expert
from mttl.models.library.expert import load_expert

expert = load_expert(
expert_path,
Expand Down Expand Up @@ -457,7 +455,7 @@ def set_selector(
selector_config: SelectorConfig,
selector_weights: dict = None,
):
from mttl.models.modifiers.expert_containers import (
from mttl.models.containers import (
replace_selector_for_container,
)

Expand Down
Empty file added mttl/models/library/__init__.py
Empty file.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
)

from mttl.utils import logger, remote_login
from mttl.models.modifiers.expert_containers.expert import (
from mttl.models.library.expert import (
Expert,
load_expert,
ExpertInfo,
Expand Down
Loading

0 comments on commit 1f45a0f

Please sign in to comment.