
Commit

add ut and add backend
Signed-off-by: changwangss <[email protected]>
changwangss committed Sep 18, 2024
1 parent f0e1e51 commit 8ecb856
Showing 5 changed files with 106 additions and 31 deletions.
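
At a glance, the commit teaches the transformers-like loading path to recognize AutoAWQ checkpoints: AwqConfig gains a backend field, models quantized by this library are tagged backend = "inc", and any other checkpoint with quant_method == "awq" is repacked to the optimum layout before its state dict is loaded. A minimal usage sketch in that spirit, mirroring the new unit test; it assumes the transformers-like AutoModelForCausalLM is importable from neural_compressor.transformers and that the public AutoAWQ checkpoint used by the test is reachable:

# Sketch only: mirrors test_loading_autoawq_model; the import path for the
# transformers-like AutoModelForCausalLM is assumed, not shown in this diff.
from transformers import AutoTokenizer

from neural_compressor.transformers import AutoModelForCausalLM

model_id = "casperhansen/opt-125m-awq"  # AutoAWQ checkpoint used by the new test
model = AutoModelForCausalLM.from_pretrained(model_id)  # repacked to optimum format on load
tokenizer = AutoTokenizer.from_pretrained(model_id)

input_ids = tokenizer("One day, the little girl", return_tensors="pt")["input_ids"]
gen_ids = model.generate(input_ids, do_sample=False, num_beams=4)
print(tokenizer.batch_decode(gen_ids, skip_special_tokens=True))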
2 changes: 1 addition & 1 deletion neural_compressor/torch/algorithms/weight_only/utility.py
@@ -1438,5 +1438,5 @@ def repack_awq_to_optimum_format(
Expected shape: (in_features // group_size, out_features)
"""
unpack_qweight, unpack_qzeros = unpack_awq(awq_qweight, awq_qzeros, awq_scales, bits, group_size)
qweight, qzeros = pack_from_tensors(unpack_qweight, unpack_qzeros, awq_scales)
qweight, qzeros = pack_from_tensors(unpack_qweight, unpack_qzeros, awq_scales, bits, group_size)
return qweight, qzeros, awq_scales
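
The one-line change above forwards bits and group_size into pack_from_tensors so the repacked tensors follow the caller's quantization parameters instead of implicit defaults. For orientation, a shape sketch under the usual AWQ and optimum/GPTQ packing conventions; the concrete layer sizes are illustrative and not taken from this diff:

# Illustrative shapes only: 4-bit weights, group_size 128, hypothetical layer sizes.
import torch

in_features, out_features, bits, group_size = 4096, 11008, 4, 128
pack_factor = 32 // bits  # eight 4-bit values packed into one int32

# AWQ layout (inputs to repack_awq_to_optimum_format)
awq_qweight = torch.zeros(in_features, out_features // pack_factor, dtype=torch.int32)
awq_qzeros = torch.zeros(in_features // group_size, out_features // pack_factor, dtype=torch.int32)
awq_scales = torch.zeros(in_features // group_size, out_features, dtype=torch.float16)

# Expected optimum/GPTQ layout (outputs), under the common convention:
#   qweight: (in_features // pack_factor, out_features)
#   qzeros:  (in_features // group_size, out_features // pack_factor)
#   scales:  (in_features // group_size, out_features), returned unchanged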
82 changes: 52 additions & 30 deletions neural_compressor/transformers/models/modeling_auto.py
@@ -47,7 +47,13 @@
from neural_compressor.torch.algorithms.weight_only.modules import INCWeightOnlyLinear
from neural_compressor.torch.utils import set_module

from ..quantization.utils import convert_dtype_torch2str, convert_to_quantized_model, replace_linear, save_low_bit
from ..quantization.utils import (
convert_dtype_torch2str,
convert_to_quantized_model,
repack_awq_and_load_state_dict,
replace_linear,
save_low_bit,
)
from ..utils import AutoRoundConfig, AwqConfig, GPTQConfig, RtnConfig, TeqConfig


@@ -179,6 +185,8 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
) and model.config.model_type == "chatglm":
model = model.float()
model = convert_to_quantized_model(model, quantization_config, device=device_map)
if isinstance(quantization_config, AwqConfig):
quantization_config.backend = "inc"
quantization_config.remove_redundant_parameters()
model.config.quantization_config = quantization_config
else:
@@ -295,6 +303,7 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs):
quantization_config = GPTQConfig.from_dict(quantization_config)
elif quantization_config["quant_method"] == "autoround":
quantization_config = AutoRoundConfig.from_dict(quantization_config)

assert quantization_config is not None, "Detect this model is not a low-bit model."

if commit_hash is None:
@@ -613,41 +622,54 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs):

with ContextManagers(init_contexts):
model = model_class(config, *model_args, **kwargs)

if quantization_config.quant_method.value == "awq" and quantization_config.backend != "inc":
if quantization_config.modules_to_not_convert is None:
quantization_config.modules_to_not_convert = ["lm_head", "transformer.output_layer", "embed_out"]
else:
quantization_config.modules_to_not_convert += ["lm_head", "transformer.output_layer", "embed_out"]
model = build_woq_model(model, quantization_config)

if is_sharded:
loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
else:
# Time to load the checkpoint
state_dict = load_state_dict(resolved_archive_file)
loaded_state_dict_keys = list(state_dict.keys())

# restore default dtype
if dtype_orig is not None:
torch.set_default_dtype(dtype_orig)

(
model,
missing_keys,
unexpected_keys,
mismatched_keys,
offload_index,
error_msgs,
) = model_class._load_pretrained_model(
model,
None,
loaded_state_dict_keys, # XXX: rename?
resolved_archive_file,
pretrained_model_name_or_path,
sharded_metadata=sharded_metadata,
_fast_init=_fast_init,
low_cpu_mem_usage=True,
offload_folder=offload_folder,
offload_state_dict=offload_state_dict,
dtype=torch_dtype,
keep_in_fp32_modules=[],
)
if quantization_config.quant_method.value == "awq" and quantization_config.backend != "inc":
if is_sharded:
loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
else:
state_dict = load_state_dict(resolved_archive_file)
loaded_state_dict_keys = list(state_dict.keys())
model = repack_awq_and_load_state_dict(
model, resolved_archive_file, loaded_state_dict_keys, quantization_config, is_sharded
)
else:
if is_sharded:
loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
else:
# Time to load the checkpoint
state_dict = load_state_dict(resolved_archive_file)
loaded_state_dict_keys = list(state_dict.keys())
(
model,
missing_keys,
unexpected_keys,
mismatched_keys,
offload_index,
error_msgs,
) = model_class._load_pretrained_model(
model,
None,
loaded_state_dict_keys, # XXX: rename?
resolved_archive_file,
pretrained_model_name_or_path,
sharded_metadata=sharded_metadata,
_fast_init=_fast_init,
low_cpu_mem_usage=True,
offload_folder=offload_folder,
offload_state_dict=offload_state_dict,
dtype=torch_dtype,
keep_in_fp32_modules=[],
)

# make sure token embedding weights are still tied if needed
model.tie_weights()
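The reorganized block above branches on the new backend flag: checkpoints produced by this library are tagged backend == "inc" in from_pretrained and keep the regular transformers loading path, while any other AWQ checkpoint is treated as an AutoAWQ export and repacked first. A self-contained sketch of that dispatch rule; the dataclass is a stand-in for AwqConfig, not the real class, and quant_method is simplified to a plain string:

from dataclasses import dataclass
from typing import Optional


@dataclass
class StubAwqConfig:
    quant_method: str = "awq"
    backend: Optional[str] = None  # from_pretrained writes "inc" after convert_to_quantized_model


def needs_awq_repack(cfg) -> bool:
    # Mirrors the condition guarding repack_awq_and_load_state_dict above.
    return cfg.quant_method == "awq" and cfg.backend != "inc"


print(needs_awq_repack(StubAwqConfig()))               # True: AutoAWQ export, repack first
print(needs_awq_repack(StubAwqConfig(backend="inc")))  # False: INC checkpoint, regular path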
38 changes: 38 additions & 0 deletions neural_compressor/transformers/quantization/utils.py
@@ -23,6 +23,7 @@

from neural_compressor.common.utils import LazyImport, logger
from neural_compressor.torch.algorithms.weight_only.modules import INCWeightOnlyLinear
from neural_compressor.torch.algorithms.weight_only.utility import repack_awq_to_optimum_format
from neural_compressor.torch.quantization import (
AutoRoundConfig,
AWQConfig,
@@ -654,3 +655,40 @@ def save_low_bit(self, save_directory: Union[str, os.PathLike], push_to_hub: boo
token=kwargs.get("token"),
)
self.quantization_config.save_pretrained(save_directory, **kwargs)


def repack_awq_and_load_state_dict(
model, resolved_archive_file, loaded_state_dict_keys, quantization_config, is_sharded
):
from transformers.modeling_utils import load_state_dict

bits = quantization_config.bits
group_size = quantization_config.group_size

state_dict = {}
if isinstance(resolved_archive_file, str):
resolved_archive_file = [resolved_archive_file]
assert isinstance(resolved_archive_file, list), "Please check whether the loaded weights are sharded correctly."
for shard_file in resolved_archive_file:
assert shard_file.endswith("safetensors"), "Please check the saved weight format; only safetensors is supported."
state_dict.update(load_state_dict(shard_file))
assert len(state_dict) > 0, "Loaded state_dict is empty; please check the checkpoint files."
for name, module in model.named_modules():
if isinstance(module, INCWeightOnlyLinear):
assert name + ".qweight" in loaded_state_dict_keys, f"Please check the state_dict key { name + '.qweight'}"
assert name + ".qzeros" in loaded_state_dict_keys, f"Please check the state_dict key {name + '.qzeros'}"
assert name + ".scales" in loaded_state_dict_keys, f"Please check the state_dict key { name + '.scales'}"
if name + ".scales" in loaded_state_dict_keys:
awq_qweight = state_dict[name + ".qweight"]
awq_qzeros = state_dict[name + ".qzeros"]
awq_scales = state_dict[name + ".scales"]
qweight, qzeros, awq_scales = repack_awq_to_optimum_format(
awq_qweight, awq_qzeros, awq_scales, bits, group_size
)
state_dict[name + ".qweight"] = qweight
state_dict[name + ".qzeros"] = qzeros
state_dict[name + ".scales"] = awq_scales

model.load_state_dict(state_dict, strict=False, assign=True)

return model
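
The helper walks every INCWeightOnlyLinear module, expects <module>.qweight, <module>.qzeros, and <module>.scales entries in the safetensors shards, rewrites them through repack_awq_to_optimum_format, and finally calls load_state_dict with assign=True, which assigns the incoming repacked tensors instead of copying them into the module's existing parameters. A tiny sketch of the key convention it relies on; the module name is hypothetical:

# Hypothetical module name; the suffixes match what repack_awq_and_load_state_dict checks.
module_name = "model.decoder.layers.0.self_attn.q_proj"
expected_keys = [f"{module_name}.{suffix}" for suffix in ("qweight", "qzeros", "scales")]
print(expected_keys)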
2 changes: 2 additions & 0 deletions neural_compressor/transformers/utils/quantization_config.py
@@ -409,6 +409,7 @@ def __init__(
zero_point: bool = True,
absorb_layer_dict: dict = {},
quant_lm_head: bool = False,
backend: str = None,
**kwargs,
):
self.quant_method = QuantizationMethod.AWQ
@@ -427,6 +428,7 @@
self.seq_len = seq_len
self.absorb_layer_dict = absorb_layer_dict
self.quant_lm_head = quant_lm_head
self.backend = backend
self.modules_to_not_convert = kwargs.get(
"modules_to_not_convert", ["lm_head", "transformer.output_layer", "embed_out"]
)
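The only schema change here is the new backend field; it defaults to None, so a config read from an AutoAWQ checkpoint is distinguishable from one written back by this library after quantization. A hypothetical construction follows; the import path and the bits/group_size arguments are assumed from the usual config surface rather than shown in this diff:

# Assumed import path and constructor arguments; backend is the field added in this commit.
from neural_compressor.transformers import AwqConfig

autoawq_style = AwqConfig(bits=4, group_size=128)                # backend stays None
inc_produced = AwqConfig(bits=4, group_size=128, backend="inc")  # what from_pretrained writes back
print(autoawq_style.backend, inc_produced.backend)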
13 changes: 13 additions & 0 deletions test/3x/torch/quantization/weight_only/test_transfomers.py
@@ -18,6 +18,9 @@
class TestTansformersLikeAPI:
def setup_class(self):
self.model_name_or_path = "hf-internal-testing/tiny-random-gptj"
self.autoawq_model = "casperhansen/opt-125m-awq"
self.prompt = "One day, the little girl"
self.generate_kwargs = dict(do_sample=False, temperature=0.9, num_beams=4)

def teardown_class(self):
shutil.rmtree("nc_workspace", ignore_errors=True)
@@ -111,3 +114,13 @@ def test_save_load(self):
loaded_model = AutoModelForCausalLM.from_pretrained(output_dir)
loaded_output = loaded_model(dummy_input)[0]
assert torch.equal(woq_output, loaded_output), "loaded output should be same. Please double check."

def test_loading_autoawq_model(self):
user_model = AutoModelForCausalLM.from_pretrained(self.autoawq_model)
tokenizer = AutoTokenizer.from_pretrained(self.autoawq_model)
input_ids = tokenizer(self.prompt, return_tensors="pt")["input_ids"]
gen_ids = user_model.generate(input_ids, **self.generate_kwargs)
gen_text = tokenizer.batch_decode(gen_ids, skip_special_tokens=True)
target_text = ["One day, the little girl in the back of my mind will ask me if I'm a"]
assert gen_text == target_text, "loading autoawq quantized model failed."
