Remove legacy code for patching _get_unpad_data
chiragjn committed Nov 9, 2024
1 parent ed1ecc3 commit 285193c
Showing 2 changed files with 18 additions and 61 deletions.
73 changes: 15 additions & 58 deletions src/axolotl/monkeypatch/multipack.py
@@ -1,4 +1,5 @@
 """multipack patching for v2 of sample packing"""
+
 import importlib
 
 import transformers
@@ -28,71 +29,27 @@


 def patch_for_multipack(model_type, model_name=None, has_remote_code=False):
-    if model_type == "gemmoe":
-        patch_remote(model_name, ".configuration_gemmoe", ".modeling_gemmoe")
-    elif model_type == "deepseek_v2":
-        patch_remote(model_name, ".configuration_deepseek", ".modeling_deepseek")
+    if has_remote_code:
+        patch_remote(model_name)
     elif hasattr(transformers, "modeling_flash_attention_utils"):
-        if not has_remote_code:
-            transformers.modeling_flash_attention_utils._get_unpad_data = (  # pylint: disable=protected-access
-                get_unpad_data
-            )
-        if model_type == "mixtral" and is_deepspeed_zero3_enabled():
-            patch_mixtral_moe_forward_zero3()
-        return
-
-    # retain for legacy
-    if model_type == "mixtral":
-        transformers.models.mixtral.modeling_mixtral._get_unpad_data = (  # pylint: disable=protected-access
-            get_unpad_data
-        )
-        if is_deepspeed_zero3_enabled():
-            patch_mixtral_moe_forward_zero3()
-    elif model_type == "llama":
-        if hasattr(transformers.models.llama.modeling_llama, "_get_unpad_data"):
-            transformers.models.llama.modeling_llama._get_unpad_data = (  # pylint: disable=protected-access
-                get_unpad_data
-            )
-    elif model_type == "mistral":
-        if hasattr(transformers.models.mistral.modeling_mistral, "_get_unpad_data"):
-            transformers.models.llama.modeling_llama._get_unpad_data = (  # pylint: disable=protected-access
-                get_unpad_data
-            )
-    elif model_type == "qwen2":
-        transformers.models.qwen2.modeling_qwen2._get_unpad_data = (  # pylint: disable=protected-access
-            get_unpad_data
-        )
-    elif model_type == "qwen2_moe":
-        transformers.models.qwen2_moe.modeling_qwen2_moe._get_unpad_data = (  # pylint: disable=protected-access
-            get_unpad_data
-        )
-    elif model_type == "falcon":
-        transformers.models.falcon.modeling_falcon._get_unpad_data = (  # pylint: disable=protected-access
-            get_unpad_data
-        )
-    elif model_type == "phi":
-        transformers.models.phi.modeling_phi._get_unpad_data = (  # pylint: disable=protected-access
-            get_unpad_data
-        )
-    elif model_type == "gemma":
-        transformers.models.gemma.modeling_gemma._get_unpad_data = (  # pylint: disable=protected-access
-            get_unpad_data
-        )
-    elif model_type == "gemma2":
-        transformers.models.gemma2.modeling_gemma2._get_unpad_data = (  # pylint: disable=protected-access
-            get_unpad_data
-        )
-    elif model_type == "starcoder2":
-        transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = (  # pylint: disable=protected-access
+        transformers.modeling_flash_attention_utils._get_unpad_data = (  # pylint: disable=protected-access
             get_unpad_data
         )
 
+    if model_type == "mixtral" and is_deepspeed_zero3_enabled():
+        patch_mixtral_moe_forward_zero3()
+
 
-def patch_remote(model_name, config_name, modeling_name):
+def patch_remote(model_name):
     model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
     # we need to load the model here in order for modeling_* to be available
     with init_empty_weights():
         AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
-    module_name = model_config.__class__.__module__.replace(config_name, modeling_name)
+    parts = model_config.__class__.__module__.split(".")
+    parts[-1] = parts[-1].replace("configuration_", "modeling_", 1)
+    module_name = ".".join(parts)
     modeling_arch = importlib.import_module(module_name)
-    modeling_arch._get_unpad_data = get_unpad_data  # pylint: disable=protected-access
+    if hasattr(modeling_arch, "_get_unpad_data"):
+        modeling_arch._get_unpad_data = (  # pylint: disable=protected-access
+            get_unpad_data
+        )
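
The reworked patch_remote no longer takes per-architecture configuration/modeling suffix arguments: it derives the modeling module from the module path of the loaded remote AutoConfig class, and the new hasattr guard simply skips architectures whose remote modeling code defines no _get_unpad_data. A minimal sketch of that name derivation, using a hypothetical remote-code module path (real paths depend on the model repository):

# Hypothetical module path of a remote-code AutoConfig class; actual paths
# depend on the repository that ships the trust_remote_code model.
config_module = "transformers_modules.some_org.some_model.configuration_foo"

# Swap the "configuration_" prefix of the last path segment for "modeling_".
parts = config_module.split(".")
parts[-1] = parts[-1].replace("configuration_", "modeling_", 1)
modeling_module = ".".join(parts)

print(modeling_module)
# -> transformers_modules.some_org.some_model.modeling_foo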
6 changes: 3 additions & 3 deletions src/axolotl/utils/models.py
@@ -394,9 +394,9 @@ def apply_patches(self) -> None:
             and self.cfg.flash_attention
             and self.cfg.sample_packing
         ):
-            has_remote_code = (
-                "auto_map" in self.model_config
-                and self.model_type in self.model_config["auto_map"]
+            has_remote_code = "auto_map" in self.model_config and (
+                self.model_type in self.model_config["auto_map"]
+                or "AutoModel" in self.model_config["auto_map"]
             )
             patch_for_multipack(
                 self.cfg.model_config_type,
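
The widened has_remote_code check in models.py also treats repositories that only register a bare AutoModel entry as remote code. A rough illustration with a made-up auto_map dict: the key names follow the Hugging Face remote-code convention, while the class names and the assumption that model_type holds an Auto class name such as "AutoModelForCausalLM" are illustrative only.

# Made-up auto_map as it might appear in a remote-code model's config.json.
model_config = {
    "auto_map": {
        "AutoConfig": "configuration_foo.FooConfig",
        "AutoModel": "modeling_foo.FooModel",
    }
}
model_type = "AutoModelForCausalLM"  # assumed Auto class name being loaded

# Old check: required model_type itself to appear among the auto_map keys.
old_has_remote_code = (
    "auto_map" in model_config and model_type in model_config["auto_map"]
)

# New check: also accepts repos that only register a bare AutoModel.
new_has_remote_code = "auto_map" in model_config and (
    model_type in model_config["auto_map"]
    or "AutoModel" in model_config["auto_map"]
)

print(old_has_remote_code, new_has_remote_code)  # False True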
