diff --git a/src/axolotl/monkeypatch/multipack.py b/src/axolotl/monkeypatch/multipack.py index 9043520108..44fc4cb473 100644 --- a/src/axolotl/monkeypatch/multipack.py +++ b/src/axolotl/monkeypatch/multipack.py @@ -17,6 +17,7 @@ "qwen2_moe", "falcon", "phi", + "phi3", "gemma", "gemma2", "gemmoe", diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 4f47d59bfb..8d24524a23 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -591,16 +591,10 @@ def load_model( "flash_attention_2" ) else: - if model_config.model_type in SUPPORTED_MULTIPACK_MODEL_TYPES: - model_kwargs["attn_implementation"] = "flash_attention_2" - model_config._attn_implementation = ( # pylint: disable=protected-access - "flash_attention_2" - ) - else: - model_kwargs["attn_implementation"] = "eager" - model_config._attn_implementation = ( # pylint: disable=protected-access - "eager" - ) + model_kwargs["attn_implementation"] = "flash_attention_2" + model_config._attn_implementation = ( # pylint: disable=protected-access + "flash_attention_2" + ) elif cfg.sdp_attention: model_kwargs["attn_implementation"] = "sdpa" model_config._attn_implementation = "sdpa" # pylint: disable=protected-access