simplify handling for newer multipack patches so they can be added in a single place (#1270)
winglian authored Feb 7, 2024
1 parent 411293b commit 5698943
Showing 7 changed files with 46 additions and 88 deletions.
3 changes: 2 additions & 1 deletion src/axolotl/core/trainer_builder.py
@@ -28,6 +28,7 @@
 from transformers.trainer_utils import seed_worker
 from trl import DPOTrainer
 
+from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
 from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
 from axolotl.utils.callbacks import (
     EvalFirstStepCallback,
@@ -994,7 +995,7 @@ def build_collator(
             ]
         ]
         if use_batch_sampler_collator:
-            if self.cfg.model_config_type in ["mixtral", "qwen2", "falcon", "phi"]:
+            if self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
                 collator = V2BatchSamplerDataCollatorForSeq2Seq
             elif (
                 self.cfg.model_config_type in ["llama"]
12 changes: 0 additions & 12 deletions src/axolotl/monkeypatch/falcon/__init__.py

This file was deleted.

11 changes: 0 additions & 11 deletions src/axolotl/monkeypatch/mixtral/__init__.py
@@ -2,9 +2,6 @@
 Patches to support multipack for mixtral
 """
 import torch
-import transformers
-
-from axolotl.monkeypatch.utils import get_unpad_data
 
 
 def patch_mixtral_moe_forward_zero3() -> None:
@@ -51,11 +48,3 @@ def moe_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:

     MixtralBLockSparseTop2MLP.forward = mlp_forward
     MixtralSparseMoeBlock.forward = moe_forward
-
-
-def replace_mixtral_attn_with_multipack_flash_attn(for_zero3=False):
-    transformers.models.mixtral.modeling_mixtral._get_unpad_data = (  # pylint: disable=protected-access
-        get_unpad_data
-    )
-    if for_zero3:
-        patch_mixtral_moe_forward_zero3()
30 changes: 30 additions & 0 deletions src/axolotl/monkeypatch/multipack.py
@@ -0,0 +1,30 @@
"""multipack patching for v2 of sample packing"""

import transformers
from transformers.integrations import is_deepspeed_zero3_enabled

from axolotl.monkeypatch.mixtral import patch_mixtral_moe_forward_zero3
from axolotl.monkeypatch.utils import get_unpad_data

SUPPORTED_MULTIPACK_MODEL_TYPES = ["mixtral", "qwen2", "falcon", "phi"]


def patch_for_multipack(model_type):
if model_type == "mixtral":
transformers.models.mixtral.modeling_mixtral._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
if is_deepspeed_zero3_enabled():
patch_mixtral_moe_forward_zero3()
elif model_type == "qwen2":
transformers.models.qwen2.modeling_qwen2._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif model_type == "falcon":
transformers.models.falcon.modeling_falcon._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif model_type == "phi":
transformers.models.phi.modeling_phi._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
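
For orientation, a minimal usage sketch of this new single entry point, assuming an axolotl install that includes this change; it mirrors the load_model hunk in src/axolotl/utils/models.py below, with plain variables standing in for the cfg fields:

# Sketch only: stand-ins for cfg.model_config_type / cfg.flash_attention /
# cfg.sample_packing; the real call site is shown in the models.py diff below.
from axolotl.monkeypatch.multipack import (
    SUPPORTED_MULTIPACK_MODEL_TYPES,
    patch_for_multipack,
)

model_config_type = "qwen2"
flash_attention = True
sample_packing = True

if (
    model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES
    and flash_attention
    and sample_packing
):
    # Points the architecture's _get_unpad_data at axolotl's packing-aware
    # version (and, for mixtral, applies the MoE forward fix when DeepSpeed
    # ZeRO-3 is enabled).
    patch_for_multipack(model_config_type)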
12 changes: 0 additions & 12 deletions src/axolotl/monkeypatch/phi/__init__.py

This file was deleted.

12 changes: 0 additions & 12 deletions src/axolotl/monkeypatch/qwen2/__init__.py

This file was deleted.

54 changes: 14 additions & 40 deletions src/axolotl/utils/models.py
@@ -29,6 +29,10 @@
 from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
 
 from axolotl.models.mamba import fix_mamba_attn_for_loss
+from axolotl.monkeypatch.multipack import (
+    SUPPORTED_MULTIPACK_MODEL_TYPES,
+    patch_for_multipack,
+)
 from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
 from axolotl.utils.bench import log_gpu_memory_usage
 from axolotl.utils.chat_templates import chat_templates
@@ -299,8 +303,15 @@ def load_model(
             shifted-sparse attention does not currently support sample packing."
         )
 
-    # Modify all llama derived models in one block
-    if cfg.is_llama_derived_model:
+    if (
+        cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES
+        and cfg.flash_attention
+        and cfg.sample_packing
+    ):
+        patch_for_multipack(cfg.model_config_type)
+    elif cfg.is_llama_derived_model:
+        # Modify all llama derived models in one block
+
         if cfg.flash_attention:
             from axolotl.monkeypatch.llama_attn_hijack_flash import (
                 replace_llama_attn_with_flash_attn,
@@ -354,43 +365,6 @@ def load_model(
         LOG.info("patching mistral with flash attention")
         replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)
 
-    if (
-        cfg.model_config_type == "mixtral"
-        and cfg.flash_attention
-        and cfg.sample_packing
-    ):
-        from axolotl.monkeypatch.mixtral import (
-            replace_mixtral_attn_with_multipack_flash_attn,
-        )
-
-        LOG.info("patching mixtral with flash attention")
-        mixtral_patch_kwargs = {}
-        if is_deepspeed_zero3_enabled():
-            mixtral_patch_kwargs["for_zero3"] = True
-        replace_mixtral_attn_with_multipack_flash_attn(**mixtral_patch_kwargs)
-
-    if cfg.model_config_type == "falcon" and cfg.flash_attention and cfg.sample_packing:
-        from axolotl.monkeypatch.falcon import (
-            replace_falcon_attn_with_multipack_flash_attn,
-        )
-
-        LOG.info("patching falcon with flash attention")
-        replace_falcon_attn_with_multipack_flash_attn()
-
-    if cfg.model_config_type == "phi" and cfg.flash_attention and cfg.sample_packing:
-        from axolotl.monkeypatch.phi import replace_phi_attn_with_multipack_flash_attn
-
-        LOG.info("patching phi with flash attention")
-        replace_phi_attn_with_multipack_flash_attn()
-
-    if cfg.model_config_type == "qwen2" and cfg.flash_attention and cfg.sample_packing:
-        from axolotl.monkeypatch.qwen2 import (
-            replace_qwen2_attn_with_multipack_flash_attn,
-        )
-
-        LOG.info("patching qwen2 with flash attention")
-        replace_qwen2_attn_with_multipack_flash_attn()
-
     if cfg.is_llama_derived_model and cfg.sample_packing and not inference:
         from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask
 
@@ -501,7 +475,7 @@ def load_model(
                     "flash_attention_2"
                 )
             else:
-                if model_config.model_type in ["mixtral", "qwen2", "falcon", "phi"]:
+                if model_config.model_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
                     model_kwargs["attn_implementation"] = "flash_attention_2"
                     model_config._attn_implementation = (  # pylint: disable=protected-access
                         "flash_attention_2"
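
The practical effect of the consolidation is that supporting a further flash-attention architecture now only touches src/axolotl/monkeypatch/multipack.py. A hypothetical sketch of what that could look like; "somemodel" is a placeholder architecture, nothing here is part of this commit, and the hasattr guard only exists to keep the illustration importable even though no such transformers module exists:

# Hypothetical only: mirrors the existing branches of patch_for_multipack
# for a made-up "somemodel" architecture.
import transformers

from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
from axolotl.monkeypatch.utils import get_unpad_data

SUPPORTED_MULTIPACK_MODEL_TYPES.append("somemodel")

if hasattr(transformers.models, "somemodel"):
    transformers.models.somemodel.modeling_somemodel._get_unpad_data = (  # pylint: disable=protected-access
        get_unpad_data
    )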
