diff --git a/docs/overview.md b/docs/overview.md index 244bf9965f..943d33dbda 100644 --- a/docs/overview.md +++ b/docs/overview.md @@ -56,6 +56,7 @@ Identifiers and configuration classes are explained in more detail in the [next | `prefix_tuning_flat` | `PrefixTuningConfig(flat=True)` | [Prefix Tuning](methods.html#prefix-tuning) | | `lora` | `LoRAConfig()` | [LoRA](methods.html#lora) | | `vera` | `VeraConfig()` | [Vera](methods.html#vera) | +| `adamix` | `AdaMixConfig()` | [AdaMix](methods.html#adamix) | `ia3` | `IA3Config()` | [IA³](methods.html#ia-3) | | `mam` | `MAMConfig()` | [Mix-and-Match Adapters](method_combinations.html#mix-and-match-adapters) | | `unipelt` | `UniPELTConfig()` | [UniPELT](method_combinations.html#unipelt) | diff --git a/src/adapters/__init__.py b/src/adapters/__init__.py index b8424e0107..b883364e66 100644 --- a/src/adapters/__init__.py +++ b/src/adapters/__init__.py @@ -69,6 +69,7 @@ "StaticAdapterFusionConfig", "UniPELTConfig", "VeraConfig", + "AdaMixConfig", ], "context": [ "AdapterSetup", @@ -194,6 +195,7 @@ StaticAdapterFusionConfig, UniPELTConfig, VeraConfig, + AdaMixConfig, ) from .context import AdapterSetup, ForwardContext from .heads import ( diff --git a/src/adapters/configuration/adapter_config.py b/src/adapters/configuration/adapter_config.py index 41b8570699..786a1fcbcd 100644 --- a/src/adapters/configuration/adapter_config.py +++ b/src/adapters/configuration/adapter_config.py @@ -842,6 +842,28 @@ def __init__( ] super().__init__(*[c.replace(use_gating=True) for c in components]) + + +class AdaMixConfig(AdapterConfig): + """ + The 'Mixture of Adapter Experts' AdaMix method was proposed by Zhang et al. (2022). + See https://arxiv.org/abs/2205.12410. + """ + mh_adapter = True + output_adapter = True + reduction_factor = 16 + non_linearity = "relu" + + num_experts = 4 + expert_dropout = 0.1 + routing_algorithm = "linear" + use_load_balancing = True + load_balancing_weight = 0.01 + selection_mode = "top_k" + k = 2 + + def __init__(self, **kwargs): + super().__init__(**kwargs) # IMPORTANT: When adding a new config here, also add it to docs/overview.md! @@ -874,6 +896,7 @@ def __init__( "direft": DiReftConfig(), "mam": MAMConfig(), "unipelt": UniPELTConfig(), + "adamix": AdaMixConfig(), } DEFAULT_ADAPTER_CONFIG = "seq_bn" diff --git a/src/adapters/methods/adamix.py b/src/adapters/methods/adamix.py new file mode 100644 index 0000000000..e69de29bb2