
Commit

fix some issues
wenhuach21 committed Jan 3, 2025

1 parent 8ebf551 commit ba4234b
Showing 3 changed files with 25 additions and 11 deletions.
6 changes: 4 additions & 2 deletions auto_round/auto_quantizer.py
@@ -599,14 +599,16 @@ def cpu_post_init(self, model):

         for n, layer in tqdm(layers, desc=message, total=len(layers),
                              leave=True):
-            from auto_round_extension.qbits import qbits_qlinear_classes
+            from auto_round_extension.qbits import qbits_qlinear_classes, qbits_awq_classes
             from auto_round_extension.ipex import ipex_qlinear_classes
             if isinstance(layer, qbits_qlinear_classes):
                 if dep_check:
                     layer.req_check()
                 layer.post_init()
                 dep_check = False
-            if isinstance(layer, ipex_qlinear_classes):
+            elif isinstance(layer, ipex_qlinear_classes):
                 layer.post_init()
+            elif isinstance(layer, qbits_awq_classes):
+                layer.post_init()
 
         return model
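
For context, the dispatch above checks each layer against tuples of QuantLinear classes with isinstance, and switching the second branch from if to elif makes the branches mutually exclusive. Below is a minimal, self-contained sketch of that pattern; QBitsLinear, IpexLinear and QBitsAWQLinear are placeholder classes for illustration, not the real layers from auto_round_extension.

# Minimal sketch of the if/elif dispatch pattern used in cpu_post_init.
# The three classes here are placeholders, not the real extension layers.
class QBitsLinear:
    def req_check(self):
        print("qbits dependency check")

    def post_init(self):
        print("qbits post_init")


class IpexLinear:
    def post_init(self):
        print("ipex post_init")


class QBitsAWQLinear:
    def post_init(self):
        print("qbits-awq post_init")


qbits_qlinear_classes = (QBitsLinear,)
ipex_qlinear_classes = (IpexLinear,)
qbits_awq_classes = (QBitsAWQLinear,)

dep_check = True
for layer in [QBitsLinear(), IpexLinear(), QBitsAWQLinear()]:
    if isinstance(layer, qbits_qlinear_classes):
        if dep_check:
            layer.req_check()  # dependency check runs only once
        layer.post_init()
        dep_check = False
    elif isinstance(layer, ipex_qlinear_classes):  # elif: exactly one branch per layer
        layer.post_init()
    elif isinstance(layer, qbits_awq_classes):
        layer.post_init()
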
3 changes: 3 additions & 0 deletions auto_round_extension/qbits/__init__.py
@@ -2,5 +2,8 @@
 from auto_round_extension.qbits.qlinear_qbits_gptq import (
     QuantLinear as QBitsGPTQQuantLinear,
 )
+from auto_round_extension.qbits.qbits_awq import QuantLinear as QBitsAWQQuantLinear
 
 qbits_qlinear_classes = (QBitsQuantLinear, QBitsGPTQQuantLinear)
+
+qbits_awq_classes = (QBitsAWQQuantLinear,)
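
A rough usage sketch of these exported tuples, assuming a model whose modules include the QBits QuantLinear layers above; `model` is a placeholder for an already loaded, QBits-quantized torch.nn.Module, not something defined in this file.

# Rough usage sketch: isinstance accepts a tuple of classes, so one check
# covers every class in the family. `model` is a placeholder here.
from auto_round_extension.qbits import qbits_qlinear_classes, qbits_awq_classes

for name, module in model.named_modules():
    if isinstance(module, qbits_qlinear_classes):  # QBits / QBits-GPTQ layers
        module.post_init()
    elif isinstance(module, qbits_awq_classes):    # QBits-AWQ layers
        module.post_init()
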
27 changes: 18 additions & 9 deletions auto_round_extension/qbits/qbits_awq.py
@@ -90,15 +90,24 @@ def __init__(self, w_bit, group_size, in_features, out_features, bias, zero_poin
         assert self.in_features % self.group_size == 0
         assert out_features % (32 // self.w_bit) == 0
         self.pack_num = 32 // self.w_bit
 
-        self.register_buffer(
-            "qzeros",
-            torch.zeros(
-                (in_features // self.group_size, out_features // self.pack_num),
-                dtype=torch.int8,
-                device=dev,
-            ) if self.zero_point else None,
-        )
+        if self.zero_point:
+            self.register_buffer(
+                "qzeros",
+                torch.zeros(
+                    (in_features // self.group_size, out_features // self.pack_num),
+                    dtype=torch.int8,
+                    device=dev,
+                )
+            )
+        else:
+            self.register_buffer(
+                "qzeros",
+                torch.ones(
+                    (in_features // self.group_size, out_features // self.pack_num),
+                    dtype=torch.int8,
+                    device=dev,
+                ) * 8
+            )
         self.register_buffer(
             "scales",
             torch.zeros(
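
Filling qzeros with 8 when zero_point is False reads like a fixed mid-range zero point for symmetric 4-bit quantization (2 ** 4 // 2 == 8). Below is a small numeric sketch of that idea only; it is an illustration, not the qbits kernel's actual computation.

# Small numeric sketch: symmetric 4-bit (de)quantization around a fixed
# zero point of 8, which is what initializing qzeros to 8 suggests.
# Illustration only; not the qbits kernel's actual math.
import torch

w_bit = 4
zero = 2 ** w_bit // 2          # fixed mid-range zero point: 8
scale = 0.05                    # example per-group scale

w = torch.tensor([-0.40, -0.05, 0.00, 0.15, 0.35])
q = torch.clamp(torch.round(w / scale) + zero, 0, 2 ** w_bit - 1)  # ints in [0, 15]
w_hat = (q - zero) * scale                                         # dequantized values
print(q.tolist(), w_hat.tolist())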
