Model Initialization Context #695

Closed · wants to merge 2 commits
34 changes: 34 additions & 0 deletions src/llmcompressor/transformers/utils/helpers.py
@@ -4,6 +4,7 @@
"""

import inspect
import logging
import os
from collections import OrderedDict
from contextlib import suppress
@@ -34,6 +35,7 @@
    "POSSIBLE_TOKENIZER_FILES",
    "download_repo_from_huggingface_hub",
    "download_model_directory",
    "FastInitialization",
]


@@ -490,3 +492,35 @@ def download_model_directory(pretrained_model_name_or_path: str, **kwargs):
    return download_repo_from_huggingface_hub(
        repo_id=pretrained_model_name_or_path, **kwargs
    )


class FastInitialization:
    # keep references to the original torch initializers so they can be restored
    kaiming_uniform_ = torch.nn.init.kaiming_uniform_
    uniform_ = torch.nn.init.uniform_
    normal_ = torch.nn.init.normal_

    transformers_logger = logging.getLogger("transformers.modeling_utils")

    def __enter__(self):
        # capture the log level in effect when the context is entered so the
        # correct level is restored on exit
        self.restore_log_level = self.transformers_logger.getEffectiveLevel()

        # Skip the initializer step. This accelerates loading,
        # especially for quantized models.
        torch.nn.init.kaiming_uniform_ = self.skip
        torch.nn.init.uniform_ = self.skip
        torch.nn.init.normal_ = self.skip

        # temporarily set the log level to ERROR to avoid printing the long missing
        # and unexpected key messages (these are EXPECTED for quantized models)
        self.transformers_logger.setLevel(level=logging.ERROR)

    def __exit__(self, _exc_type, _exc_val, _exc_tb):
        # restore the original initializer functions
        torch.nn.init.kaiming_uniform_ = self.kaiming_uniform_
        torch.nn.init.uniform_ = self.uniform_
        torch.nn.init.normal_ = self.normal_

        # restore the transformers logging level now that the model shell is loaded
        self.transformers_logger.setLevel(level=self.restore_log_level)

    def skip(self, *args, **kwargs):
        # no-op replacement for the torch.nn.init functions patched in __enter__
        pass
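
A minimal usage sketch of the new context manager (not part of the diff; it assumes the import path shown in the file header above, and the model path is a placeholder):

    from transformers import AutoModelForCausalLM

    from llmcompressor.transformers.utils.helpers import FastInitialization

    # weight initializers become no-ops and the transformers missing/unexpected
    # key warnings are silenced only while the context is active
    with FastInitialization():
        model = AutoModelForCausalLM.from_pretrained("path/to/quantized-model")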
29 changes: 5 additions & 24 deletions src/llmcompressor/transformers/wrap.py
@@ -1,8 +1,6 @@
import logging
from pathlib import Path
from typing import Optional, Union

import torch
from accelerate import load_checkpoint_and_dispatch
from compressed_tensors.compressors import ModelCompressor
from compressed_tensors.quantization import (
@@ -17,6 +15,7 @@
    modify_save_pretrained,
)
from llmcompressor.transformers.utils.helpers import (
    FastInitialization,
    download_model_directory,
    resolve_recipe,
)
@@ -54,16 +53,6 @@ def from_pretrained(
            pretrained_model_name_or_path directory and applied if found
        :return the created model for causal language modeling
        """

        def skip(*args, **kwargs):
            pass

        # Skip the initializer step. This accelerates the loading
        # of the models, especially for the quantized models
        torch.nn.init.kaiming_uniform_ = skip
        torch.nn.init.uniform_ = skip
        torch.nn.init.normal_ = skip

        pretrained_model_name_or_path = (
            pretrained_model_name_or_path.as_posix()
            if isinstance(pretrained_model_name_or_path, Path)
@@ -79,12 +68,6 @@ def skip(*args, **kwargs):
            pretrained_model_name_or_path, **kwargs
        )

        # temporarily set the log level to error, to ignore printing out long missing
        # and unexpected key error messages (these are EXPECTED for quantized models)
        transformers_logger = logging.getLogger("transformers.modeling_utils")
        restore_log_level = transformers_logger.getEffectiveLevel()
        transformers_logger.setLevel(level=logging.ERROR)

if kwargs.get("trust_remote_code"):
# By artifically aliasing the
# class name to the
Expand All @@ -94,9 +77,10 @@ def skip(*args, **kwargs):
# (has_remote_code and trust_remote_code) == True
cls.__name__ = hf_model_class.__name__

        model = super(hf_model_class, cls).from_pretrained(
            pretrained_model_name_or_path, *model_args, **kwargs
        )
        with FastInitialization():
            model = super(hf_model_class, cls).from_pretrained(
                pretrained_model_name_or_path, *model_args, **kwargs
            )

        if model.dtype != model.config.torch_dtype:
            logger.warning(
@@ -107,9 +91,6 @@
                "set torch_dtype=`auto` in the SparseAutoModel creation call."
            )

        # restore transformers logging level now that model shell is loaded
        transformers_logger.setLevel(level=restore_log_level)

        # HfQuantizer Quantization
        if hasattr(model.config, "quantization_config"):
            return model
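
For context on why the no-op initializers speed up loading (an aside, not part of this PR): torch layers such as nn.Linear run kaiming_uniform_ over their full weight matrices at construction time, so replacing that function with a no-op skips the work. A rough, self-contained sketch with an illustrative layer size:

    import time

    import torch

    def _noop(*args, **kwargs):
        pass

    start = time.perf_counter()
    torch.nn.Linear(8192, 8192)  # reset_parameters() runs kaiming_uniform_ on the weight
    print(f"with init: {time.perf_counter() - start:.3f}s")

    original = torch.nn.init.kaiming_uniform_
    torch.nn.init.kaiming_uniform_ = _noop  # one of the monkeypatches FastInitialization applies
    try:
        start = time.perf_counter()
        torch.nn.Linear(8192, 8192)  # weight values are left uninitialized
        print(f"init skipped: {time.perf_counter() - start:.3f}s")
    finally:
        torch.nn.init.kaiming_uniform_ = original  # always restore the real initializer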