From 285193c1933dac665ae08b9eef95a355117bf8a2 Mon Sep 17 00:00:00 2001 From: Chirag Jain Date: Sat, 9 Nov 2024 18:54:54 +0000 Subject: [PATCH] Remove legacy code for patching _get_unpad_data --- src/axolotl/monkeypatch/multipack.py | 73 ++++++---------------------- src/axolotl/utils/models.py | 6 +-- 2 files changed, 18 insertions(+), 61 deletions(-) diff --git a/src/axolotl/monkeypatch/multipack.py b/src/axolotl/monkeypatch/multipack.py index ca8b8b3664..3ee89d2e5c 100644 --- a/src/axolotl/monkeypatch/multipack.py +++ b/src/axolotl/monkeypatch/multipack.py @@ -1,4 +1,5 @@ """multipack patching for v2 of sample packing""" + import importlib import transformers @@ -28,71 +29,27 @@ def patch_for_multipack(model_type, model_name=None, has_remote_code=False): - if model_type == "gemmoe": - patch_remote(model_name, ".configuration_gemmoe", ".modeling_gemmoe") - elif model_type == "deepseek_v2": - patch_remote(model_name, ".configuration_deepseek", ".modeling_deepseek") + if has_remote_code: + patch_remote(model_name) elif hasattr(transformers, "modeling_flash_attention_utils"): - if not has_remote_code: - transformers.modeling_flash_attention_utils._get_unpad_data = ( # pylint: disable=protected-access - get_unpad_data - ) - if model_type == "mixtral" and is_deepspeed_zero3_enabled(): - patch_mixtral_moe_forward_zero3() - return - - # retain for legacy - if model_type == "mixtral": - transformers.models.mixtral.modeling_mixtral._get_unpad_data = ( # pylint: disable=protected-access - get_unpad_data - ) - if is_deepspeed_zero3_enabled(): - patch_mixtral_moe_forward_zero3() - elif model_type == "llama": - if hasattr(transformers.models.llama.modeling_llama, "_get_unpad_data"): - transformers.models.llama.modeling_llama._get_unpad_data = ( # pylint: disable=protected-access - get_unpad_data - ) - elif model_type == "mistral": - if hasattr(transformers.models.mistral.modeling_mistral, "_get_unpad_data"): - transformers.models.llama.modeling_llama._get_unpad_data = ( # pylint: disable=protected-access - get_unpad_data - ) - elif model_type == "qwen2": - transformers.models.qwen2.modeling_qwen2._get_unpad_data = ( # pylint: disable=protected-access - get_unpad_data - ) - elif model_type == "qwen2_moe": - transformers.models.qwen2_moe.modeling_qwen2_moe._get_unpad_data = ( # pylint: disable=protected-access - get_unpad_data - ) - elif model_type == "falcon": - transformers.models.falcon.modeling_falcon._get_unpad_data = ( # pylint: disable=protected-access - get_unpad_data - ) - elif model_type == "phi": - transformers.models.phi.modeling_phi._get_unpad_data = ( # pylint: disable=protected-access - get_unpad_data - ) - elif model_type == "gemma": - transformers.models.gemma.modeling_gemma._get_unpad_data = ( # pylint: disable=protected-access - get_unpad_data - ) - elif model_type == "gemma2": - transformers.models.gemma2.modeling_gemma2._get_unpad_data = ( # pylint: disable=protected-access - get_unpad_data - ) - elif model_type == "starcoder2": - transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = ( # pylint: disable=protected-access + transformers.modeling_flash_attention_utils._get_unpad_data = ( # pylint: disable=protected-access get_unpad_data ) + if model_type == "mixtral" and is_deepspeed_zero3_enabled(): + patch_mixtral_moe_forward_zero3() + -def patch_remote(model_name, config_name, modeling_name): +def patch_remote(model_name): model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True) # we need to load the model here in order for modeling_* to be available with init_empty_weights(): AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True) - module_name = model_config.__class__.__module__.replace(config_name, modeling_name) + parts = model_config.__class__.__module__.split(".") + parts[-1] = parts[-1].replace("configuration_", "modeling_", 1) + module_name = ".".join(parts) modeling_arch = importlib.import_module(module_name) - modeling_arch._get_unpad_data = get_unpad_data # pylint: disable=protected-access + if hasattr(modeling_arch, "_get_unpad_data"): + modeling_arch._get_unpad_data = ( # pylint: disable=protected-access + get_unpad_data + ) diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index d3643bf4e3..4da4973918 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -394,9 +394,9 @@ def apply_patches(self) -> None: and self.cfg.flash_attention and self.cfg.sample_packing ): - has_remote_code = ( - "auto_map" in self.model_config - and self.model_type in self.model_config["auto_map"] + has_remote_code = "auto_map" in self.model_config and ( + self.model_type in self.model_config["auto_map"] + or "AutoModel" in self.model_config["auto_map"] ) patch_for_multipack( self.cfg.model_config_type,