Refactor DistilBert modeling classes (speechbrain#7)
calpt committed Aug 25, 2023
1 parent 717de0d commit b077e26
Showing 8 changed files with 62 additions and 1,122 deletions.
3 changes: 3 additions & 0 deletions src/adapter_transformers/mixins/__init__.py
@@ -1,6 +1,7 @@
 from .albert import AlbertModelAdaptersMixin
 from .bart import BartDecoderAdaptersMixin, BartEncoderAdaptersMixin, BartModelAdaptersMixin
 from .bert import BertLayerAdaptersMixin, BertModelAdaptersMixin
+from .distilbert import DistilBertModelAdaptersMixin, DistilBertTransformerAdaptersMixin
 from .t5 import T5BlockAdaptersMixin, T5ModelAdaptersMixin, T5ModelAdaptersWithHeadsMixin


@@ -12,6 +13,8 @@
     "BartModel": BartModelAdaptersMixin,
     "BertLayer": BertLayerAdaptersMixin,
     "BertModel": BertModelAdaptersMixin,
+    "Transformer": DistilBertTransformerAdaptersMixin,
+    "DistilBertModel": DistilBertModelAdaptersMixin,
     "RobertaLayer": BertLayerAdaptersMixin,
     "RobertaModel": BertModelAdaptersMixin,
     "T5Block": T5BlockAdaptersMixin,
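The two new entries above register DistilBert's inner `Transformer` module and `DistilBertModel` itself in the class-name-to-mixin mapping. A minimal, self-contained sketch of how such a mapping could be applied; the helper and toy classes below are hypothetical stand-ins, not the library's actual wrapping code:

# Hypothetical sketch: mix an adapters mixin into a module class looked up by name.
# ToyMixin and ToyModel stand in for the real mixin and transformers classes.
class ToyMixin:
    def init_adapters(self, config):
        print(f"adapters initialized on {type(self).__name__}")


class ToyModel:
    pass


MODEL_MIXIN_MAPPING = {"ToyModel": ToyMixin}  # mirrors the mapping above


def with_adapters(module_cls):
    mixin = MODEL_MIXIN_MAPPING.get(module_cls.__name__)
    if mixin is None:
        return module_cls
    # Mixin first, so its methods take precedence in the method resolution order.
    return type(module_cls.__name__, (mixin, module_cls), {})


model = with_adapters(ToyModel)()
model.init_adapters(config=None)  # prints: adapters initialized on ToyModel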
56 changes: 40 additions & 16 deletions src/adapter_transformers/mixins/distilbert.py
@@ -1,34 +1,58 @@
-from typing import Iterable, Tuple
+from typing import Callable, Iterable, Tuple

 import torch.nn as nn

 from ..layer import AdapterLayer
-from ..model_mixin import (
-    EmbeddingAdaptersMixin,
-    EmbeddingAdaptersWrapperMixin,
-    InvertibleAdaptersMixin,
-    ModelAdaptersMixin,
-    ModelWithHeadsAdaptersMixin,
-)
+from ..lora import Linear as LoRALinear
+from ..model_mixin import EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelBaseAdaptersMixin
+from ..prefix_tuning import PrefixTuningShim


+class DistilBertMultiHeadSelfAttentionMixin:
+    """Adds adapters to the MultiHeadSelfAttention module of DistilBert."""

+    def init_adapters(self, config):
+        # Wrap layers for LoRA
+        self.q_lin = LoRALinear.wrap(self.q_lin, "selfattn", config, attn_key="q")
+        self.k_lin = LoRALinear.wrap(self.k_lin, "selfattn", config, attn_key="k")
+        self.v_lin = LoRALinear.wrap(self.v_lin, "selfattn", config, attn_key="v")

+        self.prefix_tuning = PrefixTuningShim("self", config)


 class DistilBertTransfomerBlockAdaptersMixin:
     """Adds adapters to the TransformerBlock module of DistilBert."""

-    def _init_adapter_modules(self):
-        self.attention_adapters = AdapterLayer("mh_adapter", self.config)
-        self.output_adapters = AdapterLayer("output_adapter", self.config)
-        self.attention_adapters._init_adapter_modules()
-        self.output_adapters._init_adapter_modules()
+    def init_adapters(self, config):
+        # Wrap layers for LoRA
+        self.ffn.lin1 = LoRALinear.wrap(self.ffn.lin1, "intermediate", config)
+        self.ffn.lin2 = LoRALinear.wrap(self.ffn.lin2, "output", config)

+        self.attention_adapters = AdapterLayer("mh_adapter")
+        self.output_adapters = AdapterLayer("output_adapter")


+class DistilBertTransformerAdaptersMixin:
+    """Adds adapters to the Transformer module of DistilBert."""

+    def forward(self, *args, **kwargs):
+        if hasattr(self, "pre_forward_fn"):
+            kwargs["x"] = self.pre_forward_fn(self, kwargs["x"])
+        return super().forward(*args, **kwargs)


-class DistilBertModelAdaptersMixin(EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelAdaptersMixin):
+class DistilBertModelAdaptersMixin(EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelBaseAdaptersMixin):
     """Adds adapters to the DistilBert module."""

     def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]:
         for i, layer in enumerate(self.transformer.layer):
             yield i, layer

+    def _hook_fn(self, module, input):
+        new_input = self.invertible_adapters_forward(input)
+        return new_input

-class DistilBertModelWithHeadsAdaptersMixin(EmbeddingAdaptersWrapperMixin, ModelWithHeadsAdaptersMixin):
-    pass
+    def hook_after_embeddings(self, hook_fn: Callable):
+        # PyTorch's built-in pre-forward hook does not pass the input ids.
+        # Therefore, we need to use a custom hook.
+        self.transformer.pre_forward_fn = hook_fn
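The new `DistilBertTransformerAdaptersMixin.forward` applies an optional `pre_forward_fn` to the hidden-states keyword argument `x` before running the regular forward pass, and `hook_after_embeddings` installs the invertible-adapters hook through that attribute. A toy, self-contained reproduction of the pattern; the classes below are stand-ins, not the real DistilBert modules:

# Toy reproduction of the pre_forward_fn pattern above; ToyTransformer stands in for
# DistilBert's Transformer, and the lambda for invertible_adapters_forward.
import torch
import torch.nn as nn


class ToyTransformer(nn.Module):
    def forward(self, x=None):
        return x * 2


class ToyTransformerWithAdapters(ToyTransformer):
    def forward(self, *args, **kwargs):
        # Apply the hook to the hidden-states kwarg before the regular forward pass.
        if hasattr(self, "pre_forward_fn"):
            kwargs["x"] = self.pre_forward_fn(self, kwargs["x"])
        return super().forward(*args, **kwargs)


transformer = ToyTransformerWithAdapters()
transformer.pre_forward_fn = lambda module, x: x + 1  # stands in for the invertible adapters
print(transformer(x=torch.ones(2)))  # tensor([4., 4.]) == (1 + 1) * 2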
2 changes: 1 addition & 1 deletion src/adapter_transformers/mixins/t5.py
@@ -1,4 +1,4 @@
-from typing import Callable, Iterable, Tuple
+from typing import Iterable, Tuple

 import torch.nn as nn

27 changes: 2 additions & 25 deletions src/adapter_transformers/models/distilbert/__init__.py
@@ -22,35 +22,12 @@


 _import_structure = {
-    "adapter_model": [
-        "DistilBertAdapterModel",
-        "DistilBertModelWithHeads",
-    ],
-    "modeling_distilbert": [
-        "DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
-        "DistilBertForMaskedLM",
-        "DistilBertForMultipleChoice",
-        "DistilBertForQuestionAnswering",
-        "DistilBertForSequenceClassification",
-        "DistilBertForTokenClassification",
-        "DistilBertModel",
-        "DistilBertPreTrainedModel",
-    ],
+    "adapter_model": ["DistilBertAdapterModel"],
 }


 if TYPE_CHECKING:
-    from .adapter_model import DistilBertAdapterModel, DistilBertModelWithHeads
-    from .modeling_distilbert import (
-        DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
-        DistilBertForMaskedLM,
-        DistilBertForMultipleChoice,
-        DistilBertForQuestionAnswering,
-        DistilBertForSequenceClassification,
-        DistilBertForTokenClassification,
-        DistilBertModel,
-        DistilBertPreTrainedModel,
-    )
+    from .adapter_model import DistilBertAdapterModel

 else:
     import sys
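The trimmed `_import_structure` now only exposes `DistilBertAdapterModel`. Under `TYPE_CHECKING` the symbol is imported eagerly for type checkers, while the runtime `else` branch (cut off above) would typically swap in a lazy module proxy. A hedged sketch of that standard Transformers pattern, assuming the usual `_LazyModule` helper; the exact boilerplate in this file is not shown in the diff:

# Assumed lazy-import boilerplate (typical Transformers pattern, not taken verbatim
# from this repository): defer the real import until first attribute access.
import sys
from typing import TYPE_CHECKING

from transformers.utils import _LazyModule

_import_structure = {"adapter_model": ["DistilBertAdapterModel"]}

if TYPE_CHECKING:
    from .adapter_model import DistilBertAdapterModel
else:
    sys.modules[__name__] = _LazyModule(
        __name__, globals()["__file__"], _import_structure, module_spec=__spec__
    )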
51 changes: 8 additions & 43 deletions src/adapter_transformers/models/distilbert/adapter_model.py
@@ -1,7 +1,11 @@
-import warnings

 import torch.nn as nn

+from transformers.models.distilbert.modeling_distilbert import (
+    DISTILBERT_INPUTS_DOCSTRING,
+    DISTILBERT_START_DOCSTRING,
+    DistilBertModel,
+    DistilBertPreTrainedModel,
+)
 from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward

 from ...heads import (
@@ -16,12 +20,7 @@
     TaggingHead,
 )
 from ...model_mixin import EmbeddingAdaptersWrapperMixin
-from .modeling_distilbert import (
-    DISTILBERT_INPUTS_DOCSTRING,
-    DISTILBERT_START_DOCSTRING,
-    DistilBertModel,
-    DistilBertPreTrainedModel,
-)
+from ...wrappers import wrap_model


 @add_start_docstrings(
@@ -33,7 +32,7 @@ class DistilBertAdapterModel(
 ):
     def __init__(self, config):
         super().__init__(config)
-        self.distilbert = DistilBertModel(config)
+        self.distilbert = wrap_model(DistilBertModel(config))

         self._init_head_modules()

@@ -249,37 +248,3 @@ def add_causal_lm_head(self, head_name, activation_function="gelu", overwrite_ok
             self, head_name, layers=2, activation_function=activation_function, layer_norm=True, bias=True
         )
         self.add_prediction_head(head, overwrite_ok=overwrite_ok)


-class DistilBertModelWithHeads(DistilBertAdapterModel):
-    def __init__(self, *args, **kwargs):
-        warnings.warn(
-            "This class has been renamed to `{}` in v3. "
-            "Please use the new class instead as this class might be removed in a future version.".format(
-                self.__class__.__bases__[0].__name__
-            ),
-            FutureWarning,
-        )
-        super().__init__(*args, **kwargs)

-    @classmethod
-    def from_config(cls, config):
-        warnings.warn(
-            "This class has been renamed to `{}` in v3. "
-            "Please use the new class instead as this class might be removed in a future version.".format(
-                cls.__bases__[0].__name__
-            ),
-            FutureWarning,
-        )
-        return super().from_config(config)

-    @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
-        warnings.warn(
-            "This class has been renamed to `{}` in v3. "
-            "Please use the new class instead as this class might be removed in a future version.".format(
-                cls.__bases__[0].__name__
-            ),
-            FutureWarning,
-        )
-        return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
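With the deprecated `DistilBertModelWithHeads` alias removed, `DistilBertAdapterModel` is the single flexible-head class left for DistilBert, and its backbone is now wrapped via `wrap_model` at construction time. A hedged usage sketch: the top-level import path and checkpoint name are assumptions, and the head call follows the `add_causal_lm_head` signature visible above:

# Assumed import path, based on the src/adapter_transformers/ layout in this diff.
from adapter_transformers import DistilBertAdapterModel

model = DistilBertAdapterModel.from_pretrained("distilbert-base-uncased")

# Head name is arbitrary; the call mirrors the add_causal_lm_head definition above.
model.add_causal_lm_head("lm_head", activation_function="gelu")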