BEiT model refactoring (speechbrain#9)
lenglaender authored and calpt committed Aug 25, 2023
1 parent b077e26 commit 6d937be
Showing 5 changed files with 54 additions and 1,228 deletions.
4 changes: 4 additions & 0 deletions src/adapter_transformers/mixins/__init__.py
@@ -1,5 +1,6 @@
 from .albert import AlbertModelAdaptersMixin
 from .bart import BartDecoderAdaptersMixin, BartEncoderAdaptersMixin, BartModelAdaptersMixin
+from .beit import BeitIntermediateAdaptersMixin, BeitModelAdaptersMixin, BeitOutputAdaptersMixin
 from .bert import BertLayerAdaptersMixin, BertModelAdaptersMixin
 from .distilbert import DistilBertModelAdaptersMixin, DistilBertTransformerAdaptersMixin
 from .t5 import T5BlockAdaptersMixin, T5ModelAdaptersMixin, T5ModelAdaptersWithHeadsMixin
@@ -11,6 +12,9 @@
     "BartEncoder": BartEncoderAdaptersMixin,
     "BartDecoder": BartDecoderAdaptersMixin,
     "BartModel": BartModelAdaptersMixin,
+    "BeitIntermediate": BeitIntermediateAdaptersMixin,
+    "BeitOutput": BeitOutputAdaptersMixin,
+    "BeitModel": BeitModelAdaptersMixin,
     "BertLayer": BertLayerAdaptersMixin,
     "BertModel": BertModelAdaptersMixin,
     "Transformer": DistilBertTransformerAdaptersMixin,
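Note: the mapping above registers the adapter mixins by the class name of the corresponding Hugging Face module. Purely to illustrate how such a table can be consumed, here is a minimal sketch; the helper name and the re-classing approach are assumptions of this note, not the library's actual wrap_model implementation (wrap_model is used in adapter_model.py further down).

import torch.nn as nn

# Hypothetical helper -- illustrates the class-name -> mixin lookup only.
def apply_adapter_mixins(model: nn.Module, mixin_map: dict, config) -> nn.Module:
    for module in model.modules():
        mixin = mixin_map.get(type(module).__name__)
        if mixin is not None and not isinstance(module, mixin):
            # Dynamically mix the adapter mixin into the module's class, then
            # let it register its adapter components (LoRA wrappers, AdapterLayers).
            module.__class__ = type(type(module).__name__, (mixin, type(module)), {})
            module.init_adapters(config)
    return model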
47 changes: 34 additions & 13 deletions src/adapter_transformers/mixins/beit.py
@@ -1,33 +1,54 @@
-import logging
 from typing import Iterable, Tuple
 
 import torch.nn as nn
 
 from ..layer import AdapterLayer
-from ..model_mixin import ModelAdaptersMixin, ModelWithHeadsAdaptersMixin
+from ..lora import Linear as LoRALinear
+from ..model_mixin import ModelBaseAdaptersMixin
+from ..prefix_tuning import PrefixTuningShim
 
 
-logger = logging.getLogger(__name__)
+class BeitSelfAttentionAdaptersMixin:
+    def init_adapters(self, config):
+        self.location_key = "self"
+
+        # Wrap layers for LoRA
+        self.query = LoRALinear.wrap(self.query, "selfattn", config, attn_key="q")
+        self.key = LoRALinear.wrap(self.key, "selfattn", config, attn_key="k")
+        self.value = LoRALinear.wrap(self.value, "selfattn", config, attn_key="v")
+
+        self.prefix_tuning = PrefixTuningShim(self.location_key + "_prefix" if self.location_key else None, config)
+
+
+class BeitIntermediateAdaptersMixin:
+    def init_adapters(self, config):
+        # Wrap layers for LoRA
+        self.dense = LoRALinear.wrap(self.dense, "intermediate", config)
+
+
+class BeitOutputAdaptersMixin:
+    def init_adapters(self, config):
+        # Wrap layers for LoRA
+        self.dense = LoRALinear.wrap(self.dense, "output", config)
 
 
 class BeitLayerAdaptersMixin:
     """Adds adapters to the BeitLayer module."""
 
-    def _init_adapter_modules(self):
-        self.attention_adapters = AdapterLayer("mh_adapter", self.config)
-        self.attention_adapters._init_adapter_modules()
-
-        self.output_adapters = AdapterLayer("output_adapter", self.config)
-        self.output_adapters._init_adapter_modules()
+    def init_adapters(self, config):
+        self.attention_adapters = AdapterLayer("mh_adapter")
+        self.output_adapters = AdapterLayer("output_adapter")
 
 
-class BeitModelAdaptersMixin(ModelAdaptersMixin):
+class BeitModelAdaptersMixin(ModelBaseAdaptersMixin):
     """Adds adapters to the BeitModel module."""
 
+    def init_adapters(self, config):
+        super().init_adapters(config)
+
     def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]:
         for i, layer in enumerate(self.encoder.layer):
             yield i, layer
 
-
-class BeitModelWithHeadsAdaptersMixin(ModelWithHeadsAdaptersMixin):
-    pass
+    def set_input_embeddings(self, value):
+        self.embeddings.patch_embeddings = value
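Note: the refactored mixins wrap BEiT's query/key/value and intermediate/output projections with LoRALinear and attach a PrefixTuningShim, replacing the old _init_adapter_modules pattern with init_adapters. A rough usage sketch follows, assuming the package imports as adapter_transformers (matching the source paths in this commit); the checkpoint and adapter name are placeholders, not taken from the commit.

from adapter_transformers import BeitAdapterModel, LoRAConfig

# Placeholder checkpoint and adapter name.
model = BeitAdapterModel.from_pretrained("microsoft/beit-base-patch16-224")

# The LoRALinear wrapping added above is what lets a LoRA config target
# BEiT's self-attention and MLP projections.
model.add_adapter("beit_lora", config=LoRAConfig(r=8, alpha=16))
model.train_adapter("beit_lora")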
16 changes: 0 additions & 16 deletions src/adapter_transformers/models/beit/__init__.py
@@ -23,27 +23,11 @@
 
 _import_structure = {
     "adapter_model": ["BeitAdapterModel"],
-    "modeling_bert": [
-        "BEIT_PRETRAINED_MODEL_ARCHIVE_LIST",
-        "BeitForImageClassification",
-        "BeitForMaskedImageModeling",
-        "BeitForSemanticSegmentation",
-        "BeitModel",
-        "BeitPreTrainedModel",
-    ],
 }
 
 
 if TYPE_CHECKING:
     from .adapter_model import BeitAdapterModel
-    from .modeling_beit import (
-        BEIT_PRETRAINED_MODEL_ARCHIVE_LIST,
-        BeitForImageClassification,
-        BeitForMaskedImageModeling,
-        BeitForSemanticSegmentation,
-        BeitModel,
-        BeitPreTrainedModel,
-    )
 
 else:
     import sys
10 changes: 8 additions & 2 deletions src/adapter_transformers/models/beit/adapter_model.py
@@ -2,11 +2,17 @@
 
 import torch
 
+from transformers.models.beit.modeling_beit import (
+    BEIT_INPUTS_DOCSTRING,
+    BEIT_START_DOCSTRING,
+    BeitModel,
+    BeitPreTrainedModel,
+)
 from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward
 
 from ...context import AdapterSetup
 from ...heads import ImageClassificationHead, ModelWithFlexibleHeadsAdaptersMixin
-from .modeling_beit import BEIT_INPUTS_DOCSTRING, BEIT_START_DOCSTRING, BeitModel, BeitPreTrainedModel
+from ...wrappers import wrap_model
 
 
 @add_start_docstrings(
@@ -17,7 +23,7 @@ class BeitAdapterModel(ModelWithFlexibleHeadsAdaptersMixin, BeitPreTrainedModel)
     def __init__(self, config):
         super().__init__(config)
 
-        self.beit = BeitModel(config)
+        self.beit = wrap_model(BeitModel(config))
 
         self._init_head_modules()
 
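Note: for completeness, a rough end-to-end sketch of using the flexible-heads model defined above for image classification. The checkpoint, adapter/head name, and label count are placeholders, and the head-registration call follows the adapter-transformers ImageClassificationHead API, so treat the exact method name as an assumption rather than something this commit shows.

import torch

from adapter_transformers import BeitAdapterModel

model = BeitAdapterModel.from_pretrained("microsoft/beit-base-patch16-224")
model.add_adapter("cls_task")
model.add_image_classification_head("cls_task", num_labels=10)
model.set_active_adapters("cls_task")

pixel_values = torch.randn(1, 3, 224, 224)  # dummy image batch
outputs = model(pixel_values=pixel_values)
print(outputs.logits.shape)  # expected: torch.Size([1, 10])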
(Diff for the fifth changed file was not loaded in this view.)