Skip to content

Commit

Permalink
hqq upgrades (#491)
Browse files (browse the repository at this point in the history)
  • Loading branch information
flozi00 authored Jun 6, 2024
1 parent c71861a commit 1b528e0
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 132 deletions.
10 changes: 5 additions & 5 deletions server/lorax_server/layers/hqq.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@

HAS_HQQ = True
try:
- from hqq.core.quantize import BaseQuantizeConfig, HQQLinear
-
+ from hqq.core.quantize import BaseQuantizeConfig, HQQBackend, HQQLinear
+ HQQLinear.set_backend(HQQBackend.ATEN)
class HQQLinearLayer(HQQLinear):
@property
def weight(self) -> torch.Tensor:
Expand All @@ -16,11 +16,11 @@ def weight(self) -> torch.Tensor:

def get_hqq_linear(quantize, weight, bias=None) -> HQQLinearLayer:
if quantize == "hqq-4bit":
- quant_config = BaseQuantizeConfig(nbits=4, group_size=64, quant_zero=True, quant_scale=False)
+ quant_config = BaseQuantizeConfig(nbits=4, group_size=64, quant_zero=True, quant_scale=True, offload_meta=True, compute_dtype=torch.bfloat16)
elif quantize == "hqq-3bit":
- quant_config = BaseQuantizeConfig(nbits=3, group_size=64, quant_zero=True, quant_scale=False)
+ quant_config = BaseQuantizeConfig(nbits=3, group_size=64, quant_zero=True, quant_scale=True, offload_meta=True, compute_dtype=torch.bfloat16)
elif quantize == "hqq-2bit":
- quant_config = BaseQuantizeConfig(nbits=2, group_size=16, quant_zero=True, quant_scale=False)
+ quant_config = BaseQuantizeConfig(nbits=2, group_size=16, quant_zero=True, quant_scale=True, offload_meta=True, compute_dtype=torch.bfloat16)

# init nn.linear from weight and bias
layer = nn.Linear(weight.shape[1], weight.shape[0], bias=bias is not None)
Expand Down
130 changes: 4 additions & 126 deletions server/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion server/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ torch = { version = "2.3.0", optional = true }
peft = { version = "0.4.0", optional = true }
boto3 = "^1.28.34"
urllib3 = "<=1.26.18"
- hqq = { version = "^0.1.2", optional = true }
+ hqq = { version = "^0.1.7", optional = true }
stanford-stk = { version = "^0.7.0", markers = "sys_platform == 'linux'" }
outlines = { version = "^0.0.40", optional = true }
prometheus-client = "^0.20.0"
Expand Down

0 comments on commit 1b528e0

Please sign in to comment.