diff --git a/src/llmcompressor/transformers/utils/helpers.py b/src/llmcompressor/transformers/utils/helpers.py
index 401a454cf..b0487b04d 100644
--- a/src/llmcompressor/transformers/utils/helpers.py
+++ b/src/llmcompressor/transformers/utils/helpers.py
@@ -4,6 +4,7 @@
 """
 
 import inspect
+import logging
 import os
 from collections import OrderedDict
 from contextlib import suppress
@@ -34,6 +35,7 @@
     "POSSIBLE_TOKENIZER_FILES",
     "download_repo_from_huggingface_hub",
     "download_model_directory",
+    "FastInitialization",
 ]
 
 
@@ -490,3 +492,35 @@ def download_model_directory(pretrained_model_name_or_path: str, **kwargs):
     return download_repo_from_huggingface_hub(
         repo_id=pretrained_model_name_or_path, **kwargs
     )
+
+
+class FastInitialization:
+    kaiming_uniform_ = torch.nn.init.kaiming_uniform_
+    uniform_ = torch.nn.init.uniform_
+    normal_ = torch.nn.init.normal_
+
+    transformers_logger = logging.getLogger("transformers.modeling_utils")
+    restore_log_level = transformers_logger.getEffectiveLevel()
+
+    def __enter__(self):
+        # Skip the initializer step. This accelerates the loading
+        # of the models, especially for the quantized models
+        torch.nn.init.kaiming_uniform_ = self.skip
+        torch.nn.init.uniform_ = self.skip
+        torch.nn.init.normal_ = self.skip
+
+        # temporarily set the log level to error, to ignore printing out long missing
+        # and unexpected key error messages (these are EXPECTED for quantized models)
+        self.transformers_logger.setLevel(level=logging.ERROR)
+
+    def __exit__(self, _exc_type, _exc_val, _exc_tb):
+        # restore original functions
+        torch.nn.init.kaiming_uniform_ = self.kaiming_uniform_
+        torch.nn.init.uniform_ = self.uniform_
+        torch.nn.init.normal_ = self.normal_
+
+        # restore transformers logging level now that model shell is loaded
+        self.transformers_logger.setLevel(level=self.restore_log_level)
+
+    def skip(self, *args, **kwargs):
+        pass
diff --git a/src/llmcompressor/transformers/wrap.py b/src/llmcompressor/transformers/wrap.py
index e97efeded..ec638a197 100644
--- a/src/llmcompressor/transformers/wrap.py
+++ b/src/llmcompressor/transformers/wrap.py
@@ -1,8 +1,6 @@
-import logging
 from pathlib import Path
 from typing import Optional, Union
 
-import torch
 from accelerate import load_checkpoint_and_dispatch
 from compressed_tensors.compressors import ModelCompressor
 from compressed_tensors.quantization import (
@@ -17,6 +15,7 @@
     modify_save_pretrained,
 )
 from llmcompressor.transformers.utils.helpers import (
+    FastInitialization,
     download_model_directory,
     resolve_recipe,
 )
@@ -54,16 +53,6 @@ def from_pretrained(
             pretrained_model_name_or_path directory and applied if found
         :return the created model for causal language modeling
         """
-
-        def skip(*args, **kwargs):
-            pass
-
-        # Skip the initializer step. This accelerates the loading
-        # of the models, especially for the quantized models
-        torch.nn.init.kaiming_uniform_ = skip
-        torch.nn.init.uniform_ = skip
-        torch.nn.init.normal_ = skip
-
         pretrained_model_name_or_path = (
             pretrained_model_name_or_path.as_posix()
             if isinstance(pretrained_model_name_or_path, Path)
@@ -79,12 +68,6 @@ def skip(*args, **kwargs):
             pretrained_model_name_or_path, **kwargs
         )
 
-        # temporarily set the log level to error, to ignore printing out long missing
-        # and unexpected key error messages (these are EXPECTED for quantized models)
-        transformers_logger = logging.getLogger("transformers.modeling_utils")
-        restore_log_level = transformers_logger.getEffectiveLevel()
-        transformers_logger.setLevel(level=logging.ERROR)
-
         if kwargs.get("trust_remote_code"):
             # By artifically aliasing the
             # class name to the
@@ -94,9 +77,10 @@ def skip(*args, **kwargs):
             # (has_remote_code and trust_remote_code) == True
             cls.__name__ = hf_model_class.__name__
 
-        model = super(hf_model_class, cls).from_pretrained(
-            pretrained_model_name_or_path, *model_args, **kwargs
-        )
+        with FastInitialization():
+            model = super(hf_model_class, cls).from_pretrained(
+                pretrained_model_name_or_path, *model_args, **kwargs
+            )
 
         if model.dtype != model.config.torch_dtype:
             logger.warning(
@@ -107,9 +91,6 @@ def skip(*args, **kwargs):
                 "set torch_dtype=`auto` in the SparseAutoModel creation call."
            )
 
-        # restore transformers logging level now that model shell is loaded
-        transformers_logger.setLevel(level=restore_log_level)
-
         # HfQuantizer Quantization
         if hasattr(model.config, "quantization_config"):
             return model
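
For reference, a minimal usage sketch of the new FastInitialization context manager (not part of the diff): the reinitialize helper below is only a hypothetical stand-in for the per-module weight initialization that transformers runs during from_pretrained, and the sketch assumes a build of llmcompressor that includes this change. Inside the with block the torch.nn.init functions are no-ops and the transformers.modeling_utils logger is raised to ERROR; both are restored on exit, which is what lets the wrapped from_pretrained call above skip re-initializing weights that the checkpoint is about to overwrite.

import torch

from llmcompressor.transformers.utils.helpers import FastInitialization


def reinitialize(module: torch.nn.Linear) -> None:
    # hypothetical stand-in for transformers' per-module weight init on load
    torch.nn.init.kaiming_uniform_(module.weight)
    torch.nn.init.normal_(module.bias)


layer = torch.nn.Linear(16, 16)
weights_before = layer.weight.detach().clone()

with FastInitialization():
    # the init functions are patched to no-ops here, so this call leaves the
    # existing tensors untouched and no init work is performed
    reinitialize(layer)

assert torch.equal(layer.weight, weights_before)

# outside the context the original initializers are active again
reinitialize(layer)
assert not torch.equal(layer.weight, weights_before)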