From 8ecb856df23cd7750598d8988d8107a4ff97952b Mon Sep 17 00:00:00 2001
From: changwangss
Date: Wed, 18 Sep 2024 01:29:07 -0700
Subject: [PATCH] add ut and add backend

Signed-off-by: changwangss
---
 .../torch/algorithms/weight_only/utility.py   |  2 +-
 .../transformers/models/modeling_auto.py      | 82 ++++++++++++-------
 .../transformers/quantization/utils.py        | 38 +++++++++
 .../transformers/utils/quantization_config.py |  2 +
 .../weight_only/test_transfomers.py           | 13 +++
 5 files changed, 106 insertions(+), 31 deletions(-)

diff --git a/neural_compressor/torch/algorithms/weight_only/utility.py b/neural_compressor/torch/algorithms/weight_only/utility.py
index 7405aca960d..6f4256534bd 100644
--- a/neural_compressor/torch/algorithms/weight_only/utility.py
+++ b/neural_compressor/torch/algorithms/weight_only/utility.py
@@ -1438,5 +1438,5 @@ def repack_awq_to_optimum_format(
         Expected shape: (in_features // group_size, out_features)
     """
     unpack_qweight, unpack_qzeros = unpack_awq(awq_qweight, awq_qzeros, awq_scales, bits, group_size)
-    qweight, qzeros = pack_from_tensors(unpack_qweight, unpack_qzeros, awq_scales)
+    qweight, qzeros = pack_from_tensors(unpack_qweight, unpack_qzeros, awq_scales, bits, group_size)
     return qweight, qzeros, awq_scales
diff --git a/neural_compressor/transformers/models/modeling_auto.py b/neural_compressor/transformers/models/modeling_auto.py
index a4a91e27f03..657d2a9bd49 100644
--- a/neural_compressor/transformers/models/modeling_auto.py
+++ b/neural_compressor/transformers/models/modeling_auto.py
@@ -47,7 +47,13 @@
 from neural_compressor.torch.algorithms.weight_only.modules import INCWeightOnlyLinear
 from neural_compressor.torch.utils import set_module
 
-from ..quantization.utils import convert_dtype_torch2str, convert_to_quantized_model, replace_linear, save_low_bit
+from ..quantization.utils import (
+    convert_dtype_torch2str,
+    convert_to_quantized_model,
+    repack_awq_and_load_state_dict,
+    replace_linear,
+    save_low_bit,
+)
 from ..utils import AutoRoundConfig, AwqConfig, GPTQConfig, RtnConfig, TeqConfig
 
 
@@ -179,6 +185,8 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
             ) and model.config.model_type == "chatglm":
                 model = model.float()
             model = convert_to_quantized_model(model, quantization_config, device=device_map)
+            if isinstance(quantization_config, AwqConfig):
+                quantization_config.backend = "inc"
             quantization_config.remove_redundant_parameters()
             model.config.quantization_config = quantization_config
         else:
@@ -295,6 +303,7 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs):
             quantization_config = GPTQConfig.from_dict(quantization_config)
         elif quantization_config["quant_method"] == "autoround":
             quantization_config = AutoRoundConfig.from_dict(quantization_config)
+        assert quantization_config is not None, "Detect this model is not a low-bit model."
 
         if commit_hash is None:
@@ -613,41 +622,54 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs):
 
         with ContextManagers(init_contexts):
             model = model_class(config, *model_args, **kwargs)
-
+        if quantization_config.quant_method.value == "awq" and quantization_config.backend != "inc":
+            if quantization_config.modules_to_not_convert is None:
+                quantization_config.modules_to_not_convert = ["lm_head", "transformer.output_layer", "embed_out"]
+            else:
+                quantization_config.modules_to_not_convert += ["lm_head", "transformer.output_layer", "embed_out"]
         model = build_woq_model(model, quantization_config)
-        if is_sharded:
-            loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
-        else:
-            # Time to load the checkpoint
-            state_dict = load_state_dict(resolved_archive_file)
-            loaded_state_dict_keys = list(state_dict.keys())
-
         # restore default dtype
         if dtype_orig is not None:
             torch.set_default_dtype(dtype_orig)
 
-        (
-            model,
-            missing_keys,
-            unexpected_keys,
-            mismatched_keys,
-            offload_index,
-            error_msgs,
-        ) = model_class._load_pretrained_model(
-            model,
-            None,
-            loaded_state_dict_keys,  # XXX: rename?
-            resolved_archive_file,
-            pretrained_model_name_or_path,
-            sharded_metadata=sharded_metadata,
-            _fast_init=_fast_init,
-            low_cpu_mem_usage=True,
-            offload_folder=offload_folder,
-            offload_state_dict=offload_state_dict,
-            dtype=torch_dtype,
-            keep_in_fp32_modules=[],
-        )
+        if quantization_config.quant_method.value == "awq" and quantization_config.backend != "inc":
+            if is_sharded:
+                loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
+            else:
+                state_dict = load_state_dict(resolved_archive_file)
+                loaded_state_dict_keys = list(state_dict.keys())
+            model = repack_awq_and_load_state_dict(
+                model, resolved_archive_file, loaded_state_dict_keys, quantization_config, is_sharded
+            )
+        else:
+            if is_sharded:
+                loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
+            else:
+                # Time to load the checkpoint
+                state_dict = load_state_dict(resolved_archive_file)
+                loaded_state_dict_keys = list(state_dict.keys())
+            (
+                model,
+                missing_keys,
+                unexpected_keys,
+                mismatched_keys,
+                offload_index,
+                error_msgs,
+            ) = model_class._load_pretrained_model(
+                model,
+                None,
+                loaded_state_dict_keys,  # XXX: rename?
+                resolved_archive_file,
+                pretrained_model_name_or_path,
+                sharded_metadata=sharded_metadata,
+                _fast_init=_fast_init,
+                low_cpu_mem_usage=True,
+                offload_folder=offload_folder,
+                offload_state_dict=offload_state_dict,
+                dtype=torch_dtype,
+                keep_in_fp32_modules=[],
+            )
 
         # make sure token embedding weights are still tied if needed
         model.tie_weights()
 
diff --git a/neural_compressor/transformers/quantization/utils.py b/neural_compressor/transformers/quantization/utils.py
index 6f209344348..e66e573e3b2 100644
--- a/neural_compressor/transformers/quantization/utils.py
+++ b/neural_compressor/transformers/quantization/utils.py
@@ -23,6 +23,7 @@
 
 from neural_compressor.common.utils import LazyImport, logger
 from neural_compressor.torch.algorithms.weight_only.modules import INCWeightOnlyLinear
+from neural_compressor.torch.algorithms.weight_only.utility import repack_awq_to_optimum_format
 from neural_compressor.torch.quantization import (
     AutoRoundConfig,
     AWQConfig,
@@ -654,3 +655,40 @@ def save_low_bit(self, save_directory: Union[str, os.PathLike], push_to_hub: boo
             token=kwargs.get("token"),
         )
         self.quantization_config.save_pretrained(save_directory, **kwargs)
+
+
+def repack_awq_and_load_state_dict(
+    model, resolved_archive_file, loaded_state_dict_keys, quantization_config, is_sharded
+):
+    from transformers.modeling_utils import load_state_dict
+
+    bits = quantization_config.bits
+    group_size = quantization_config.group_size
+
+    state_dict = {}
+    if isinstance(resolved_archive_file, str):
+        resolved_archive_file = [resolved_archive_file]
+    assert isinstance(resolved_archive_file, list), "Please check whether the loaded weights are sharded."
+    for shard_file in resolved_archive_file:
+        assert shard_file.endswith("safetensors"), "Please check the saved weight format; safetensors is expected."
+        state_dict.update(load_state_dict(shard_file))
+    assert len(state_dict.keys()) > 0, "Please check the state_dict loading; no weights were loaded."
+    for name, module in model.named_modules():
+        if isinstance(module, INCWeightOnlyLinear):
+            assert name + ".qweight" in loaded_state_dict_keys, f"Please check the state_dict key {name + '.qweight'}"
+            assert name + ".qzeros" in loaded_state_dict_keys, f"Please check the state_dict key {name + '.qzeros'}"
+            assert name + ".scales" in loaded_state_dict_keys, f"Please check the state_dict key {name + '.scales'}"
+            if name + ".scales" in loaded_state_dict_keys:
+                awq_qweight = state_dict[name + ".qweight"]
+                awq_qzeros = state_dict[name + ".qzeros"]
+                awq_scales = state_dict[name + ".scales"]
+                qweight, qzeros, awq_scales = repack_awq_to_optimum_format(
+                    awq_qweight, awq_qzeros, awq_scales, bits, group_size
+                )
+                state_dict[name + ".qweight"] = qweight
+                state_dict[name + ".qzeros"] = qzeros
+                state_dict[name + ".scales"] = awq_scales
+
+    model.load_state_dict(state_dict, strict=False, assign=True)
+
+    return model
diff --git a/neural_compressor/transformers/utils/quantization_config.py b/neural_compressor/transformers/utils/quantization_config.py
index 925cc3ccc7a..13dff04dc4f 100644
--- a/neural_compressor/transformers/utils/quantization_config.py
+++ b/neural_compressor/transformers/utils/quantization_config.py
@@ -409,6 +409,7 @@ def __init__(
         zero_point: bool = True,
         absorb_layer_dict: dict = {},
         quant_lm_head: bool = False,
+        backend: str = None,
         **kwargs,
     ):
         self.quant_method = QuantizationMethod.AWQ
@@ -427,6 +428,7 @@
         self.seq_len = seq_len
         self.absorb_layer_dict = absorb_layer_dict
         self.quant_lm_head = quant_lm_head
+        self.backend = backend
         self.modules_to_not_convert = kwargs.get(
             "modules_to_not_convert", ["lm_head", "transformer.output_layer", "embed_out"]
         )
diff --git a/test/3x/torch/quantization/weight_only/test_transfomers.py b/test/3x/torch/quantization/weight_only/test_transfomers.py
index 95a89f86f68..e9194d9a371 100644
--- a/test/3x/torch/quantization/weight_only/test_transfomers.py
+++ b/test/3x/torch/quantization/weight_only/test_transfomers.py
@@ -18,6 +18,9 @@ class TestTansformersLikeAPI:
     def setup_class(self):
         self.model_name_or_path = "hf-internal-testing/tiny-random-gptj"
+        self.autoawq_model = "casperhansen/opt-125m-awq"
+        self.prompt = "One day, the little girl"
+        self.generate_kwargs = dict(do_sample=False, temperature=0.9, num_beams=4)
 
     def teardown_class(self):
         shutil.rmtree("nc_workspace", ignore_errors=True)
 
@@ -111,3 +114,13 @@ def test_save_load(self):
         loaded_model = AutoModelForCausalLM.from_pretrained(output_dir)
         loaded_output = loaded_model(dummy_input)[0]
         assert torch.equal(woq_output, loaded_output), "loaded output should be same. Please double check."
+
+    def test_loading_autoawq_model(self):
+        user_model = AutoModelForCausalLM.from_pretrained(self.autoawq_model)
+        tokenizer = AutoTokenizer.from_pretrained(self.autoawq_model)
+        input_ids = tokenizer(self.prompt, return_tensors="pt")["input_ids"]
+        self.generate_kwargs = dict(do_sample=False, temperature=0.9, num_beams=4)
+        gen_ids = user_model.generate(input_ids, **self.generate_kwargs)
+        gen_text = tokenizer.batch_decode(gen_ids, skip_special_tokens=True)
+        target_text = ["One day, the little girl in the back of my mind will ask me if I'm a"]
+        assert gen_text == target_text, "loading autoawq quantized model failed."
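
Usage sketch (not part of the patch, shown only to illustrate the path the new test exercises): loading an AutoAWQ checkpoint through the INC transformers-like API without passing a quantization_config should detect quant_method == "awq" with backend != "inc", build the WOQ model, and repack the AWQ tensors via repack_awq_and_load_state_dict. The import path and generation arguments below mirror the new test; treat them as assumptions rather than a definitive recipe.

    from transformers import AutoTokenizer

    from neural_compressor.transformers import AutoModelForCausalLM  # assumed import path, matching the test module

    model_name = "casperhansen/opt-125m-awq"  # AWQ checkpoint used by test_loading_autoawq_model
    model = AutoModelForCausalLM.from_pretrained(model_name)  # dispatches to the AWQ repack path added above
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    input_ids = tokenizer("One day, the little girl", return_tensors="pt")["input_ids"]
    gen_ids = model.generate(input_ids, do_sample=False, num_beams=4)  # mirrors the test's generate_kwargs
    print(tokenizer.batch_decode(gen_ids, skip_special_tokens=True))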