Skip to content
This repository has been archived by the owner on Oct 25, 2024. It is now read-only.

Support huggingface popular weight format for weight-only quantization #1580

Merged
merged 8 commits into from
Jul 5, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -142,12 +142,7 @@

user_model = None

# tokenizer
if config.model_type == "llama":
from transformers import LlamaTokenizer
tokenizer = LlamaTokenizer.from_pretrained(args.model)
else:
tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=args.trust_remote_code)
tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=args.trust_remote_code)

quantization_config = None
if args.woq:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,8 @@
import gc
import math
import os
from ...utils import CpuInfo
from ....tools.utils import _ipex_version
from accelerate import init_empty_weights
from datasets import load_dataset
from neural_compressor import quantization
from neural_compressor.adaptor.torch_utils.model_wrapper import WeightOnlyLinear
from neural_compressor.utils.utility import LazyImport
Expand All @@ -31,7 +30,6 @@
is_ipex_available,
is_autoround_available,
)
from transformers import AutoTokenizer

if is_ipex_available():
import intel_extension_for_pytorch as ipex
Expand Down Expand Up @@ -273,10 +271,12 @@ def _replace_linear(
scale_dtype=quantization_config.scale_dtype,
blocksize=quantization_config.group_size,
scheme=quantization_config.scheme,
compression_dtype=getattr(module, "compression_dtype", torch.int8),
compression_dim=getattr(module, "compression_dim", 0),
compression_dtype=getattr(module, "compression_dtype",
torch.int8 if _ipex_version < "2.3.10" else torch.int32),
a32543254 marked this conversation as resolved.
Show resolved Hide resolved
compression_dim=getattr(module, "compression_dim", 0 if _ipex_version < "2.3.10" else 1),
device=device,
use_optimum_format=getattr(module, "use_optimum_format", False),
use_optimum_format=getattr(module, "use_optimum_format",
False if _ipex_version < "2.3.10" else True),
)
if quantization_config.quant_method.value == "gptq":
g_idx = getattr(module, "g_idx", torch.zeros(in_features, dtype=torch.int32).to(device))
Expand All @@ -297,6 +297,17 @@ def _replace_linear(
quantization_config.compute_dtype
),
device=torch.device(device),
) if _ipex_version < "2.3.10" else torch.ones(
(
math.ceil(
in_features / quantization_config.group_size
),
out_features,
),
dtype=convert_dtype_str2torch(
quantization_config.compute_dtype
),
device=torch.device(device),
)
),
module.qzeros if hasattr(module, "qzeros") else None,
Expand Down Expand Up @@ -348,11 +359,13 @@ def _replace_linear(
else:
if not hasattr(module, "qweight"):
n_pack = (
8 // DTYPE_BITS_MAPPING[quantization_config.weight_dtype]
(8 if _ipex_version < "2.3.10" else 32)
// DTYPE_BITS_MAPPING[quantization_config.weight_dtype]
)
weight = torch.zeros(
(math.ceil(out_features / n_pack), in_features),
dtype=torch.int8,
(math.ceil(out_features / n_pack), in_features) if _ipex_version < "2.3.10" else
(math.ceil(in_features / n_pack), out_features),
dtype=torch.int8 if _ipex_version < "2.3.10" else torch.int32,
device=torch.device(device),
)
model._modules[name].set_weights_bias(
Expand Down Expand Up @@ -592,7 +605,7 @@ def default_calib_func(model):
use_optimum_format=False,
scale_dtype=convert_dtype_str2torch(config.scale_dtype),
device="xpu",
)
) if _ipex_version < "2.3.10" else inc_model.export_compressed_model(use_optimum_format=True, device="xpu")

q_model = replace_linear(model, None, None, config, device=device)
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def convert_model_to_public(model):
# reorder weight and scales if they have been transposed
if model.device == "xpu" or (isinstance(model.device, torch.device) and model.device.type == "xpu"):
for name, module in model.named_modules():
if isinstance(module, WeightOnlyQuantizedLinear):
if isinstance(module, WeightOnlyQuantizedLinear) and not module.use_optimum_format:
if module.weight_transposed:
module.qweight.data = module.qweight.t_().contiguous()
module.scales.data = module.scales.t_().contiguous()
Expand All @@ -198,6 +198,7 @@ def convert_model_to_public(model):
]:
model = recover_export_model(model)


def make_contiguous(model):
for param in model.parameters():
if param.data.ndimension() > 1:
Expand Down
Loading