From af48625e86a71cc73656b30b966944832729c8ba Mon Sep 17 00:00:00 2001
From: Chirag Jain
Date: Wed, 30 Oct 2024 11:37:13 +0000
Subject: [PATCH] Fix remote code checking

---
 src/axolotl/core/trainer_builder.py  |  4 ++--
 src/axolotl/monkeypatch/multipack.py | 11 ++++++-----
 src/axolotl/utils/models.py          |  6 +++++-
 3 files changed, 13 insertions(+), 8 deletions(-)

diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index aab9a80b8b..f4a2f90019 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -896,13 +896,13 @@ def store_metrics(
         for key, value in metrics.items():
             self._stored_metrics[train_eval][key].append(value)
 
-    def _save_checkpoint(self, model, trial, metrics=None):
+    def _save_checkpoint(self, model, trial):
         # make sure the checkpoint dir exists, since trainer is flakey
         checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
         run_dir = self._get_output_dir(trial=trial)
         output_dir = os.path.join(run_dir, checkpoint_folder)
         os.makedirs(output_dir, exist_ok=True)
-        return super()._save_checkpoint(model, trial, metrics=metrics)
+        return super()._save_checkpoint(model, trial)
 
 
 class AxolotlMambaTrainer(AxolotlTrainer):
diff --git a/src/axolotl/monkeypatch/multipack.py b/src/axolotl/monkeypatch/multipack.py
index 85101cd3c4..ca8b8b3664 100644
--- a/src/axolotl/monkeypatch/multipack.py
+++ b/src/axolotl/monkeypatch/multipack.py
@@ -27,15 +27,16 @@
 ]
 
 
-def patch_for_multipack(model_type, model_name=None, is_remote_code=False):
+def patch_for_multipack(model_type, model_name=None, has_remote_code=False):
     if model_type == "gemmoe":
         patch_remote(model_name, ".configuration_gemmoe", ".modeling_gemmoe")
     elif model_type == "deepseek_v2":
         patch_remote(model_name, ".configuration_deepseek", ".modeling_deepseek")
-    elif hasattr(transformers, "modeling_flash_attention_utils") and not is_remote_code:
-        transformers.modeling_flash_attention_utils._get_unpad_data = (  # pylint: disable=protected-access
-            get_unpad_data
-        )
+    elif hasattr(transformers, "modeling_flash_attention_utils"):
+        if not has_remote_code:
+            transformers.modeling_flash_attention_utils._get_unpad_data = (  # pylint: disable=protected-access
+                get_unpad_data
+            )
         if model_type == "mixtral" and is_deepspeed_zero3_enabled():
             patch_mixtral_moe_forward_zero3()
         return
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index f3386cccfa..ba4a7446ee 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -394,10 +394,14 @@ def apply_patches(self) -> None:
             and self.cfg.flash_attention
             and self.cfg.sample_packing
         ):
+            has_remote_code = (
+                "auto_map" in self.model_config
+                and self.model_type in self.model_config["auto_map"]
+            )
             patch_for_multipack(
                 self.cfg.model_config_type,
                 model_name=self.cfg.base_model,
-                is_remote_code=self.cfg.trust_remote_code,
+                has_remote_code=has_remote_code,
             )
 
         if self.cfg.is_llama_derived_model: