Skip to content

Commit

Permalink
fix the monkeypatch
Browse files Browse the repository at this point in the history
  • Loading branch information
winglian committed Nov 19, 2024
1 parent 1ff78d6 commit afb8218
Showing 1 changed file with 8 additions and 9 deletions.
17 changes: 8 additions & 9 deletions src/axolotl/monkeypatch/modeling_zero3_int8_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,11 @@
see https://github.com/huggingface/transformers/pull/32943/files
"""
import inspect
import logging

import transformers
import transformers.modeling_utils
from accelerate.logging import get_logger
from transformers import modeling_utils

LOG = get_logger("axolotl.monkeypatch.modeling_zero3_int8_lora")
LOG = logging.getLogger("axolotl.monkeypatch.modeling_zero3_int8_lora")

ORIGINAL_LOAD_CODE = """
if is_fsdp_enabled() or is_deepspeed_zero3_enabled():
Expand Down Expand Up @@ -38,7 +37,7 @@

def get_modeling_state_dict_code() -> str:
load_code = inspect.getsource(
transformers.modeling_utils._load_state_dict_into_meta_model # pylint: disable=protected-access
modeling_utils._load_state_dict_into_meta_model # pylint: disable=protected-access
)
return load_code

Expand All @@ -54,7 +53,7 @@ def patch_modeling_state_dict_code():
"""

load_code = get_modeling_state_dict_code()
transformers.modeling_utils._original_load_state_dict_into_meta_model = ( # pylint: disable=protected-access
modeling_utils._original_load_state_dict_into_meta_model = ( # pylint: disable=protected-access
load_code
)
assert (
Expand All @@ -69,7 +68,7 @@ def patch_modeling_state_dict_code():
)

items_to_import = []
for item in dir(transformers.modeling_utils):
for item in dir(modeling_utils):
if item in load_code:
items_to_import.append(item)

Expand All @@ -80,5 +79,5 @@ def patch_modeling_state_dict_code():
globals(),
)
exec(load_code, globals()) # pylint: disable=exec-used # nosec B102
LOG.info("patching _load_state_dict_into_meta_model", main_process_only=True)
transformers.modeling_utils._load_state_dict_into_meta_model = _fixed_load_state_dict_into_meta_model # pylint: disable=protected-access,undefined-variable # noqa: F821
LOG.info("patching _load_state_dict_into_meta_model")
modeling_utils._load_state_dict_into_meta_model = _fixed_load_state_dict_into_meta_model # pylint: disable=protected-access,undefined-variable # noqa: F821

0 comments on commit afb8218

Please sign in to comment.