diff --git a/server/lorax_server/models/custom_modeling/flash_llama_modeling.py b/server/lorax_server/models/custom_modeling/flash_llama_modeling.py
index 01b6a6f15..dcc54aa7d 100644
--- a/server/lorax_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/lorax_server/models/custom_modeling/flash_llama_modeling.py
@@ -195,7 +195,7 @@ def _load_gqa(config, prefix: str, weights):
         dim=0,
     )

-    if config.quantize not in ["gptq", "awq"]:
+    if isinstance(weight, torch.Tensor):
         weight = weight.to(dtype=weights.dtype).to(device=weights.device)

     head_size = config.hidden_size // config.num_attention_heads
diff --git a/server/lorax_server/utils/layers.py b/server/lorax_server/utils/layers.py
index e4c803302..61a336e51 100644
--- a/server/lorax_server/utils/layers.py
+++ b/server/lorax_server/utils/layers.py
@@ -7,6 +7,7 @@
 from accelerate import init_empty_weights
 from torch import nn
 from torch.nn import functional as F
+from loguru import logger

 from lorax_server.adapters.types import LORA, MEDUSA
 from lorax_server.utils.gptq.quant_linear import QuantLinear
@@ -166,8 +167,17 @@ def __init__(
         self,
         weight,
         bias,
+        scales=None,
+        quantized=False,
     ) -> None:
         super().__init__()
+
+        if quantized:
+            self.weight = weight
+            self.scale = scales
+            self.bias = bias if bias is not None else None
+            return
+
         # Get the device where the weight tensor is currently stored.
         device = weight.device
@@ -344,10 +354,20 @@ def get_linear(weight, bias, quantize, fan_in_fan_out=False):
                 quant_type="fp4",
             )
     elif quantize == "eetq":
-        if HAS_EETQ:
-            linear = EETQLinear(weight, bias)
-        else:
+        if not HAS_EETQ:
             raise ImportError("Please install EETQ from https://github.com/NetEase-FuXi/EETQ")
+
+        try:
+            qweight, scales = weight
+            linear = EETQLinear(
+                qweight,
+                bias,
+                scales,
+                True,
+            )
+        except Exception:
+            logger.info("Weight does not appear to be quantized; quantizing just-in-time")
+            linear = EETQLinear(weight, bias)
     elif quantize == "gptq":
         try:
             qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama = weight
diff --git a/server/lorax_server/utils/weights.py b/server/lorax_server/utils/weights.py
index 1bc777d92..03778bb5c 100644
--- a/server/lorax_server/utils/weights.py
+++ b/server/lorax_server/utils/weights.py
@@ -139,7 +139,7 @@ def get_tensor(self, tensor_name: str):
         tensor = f.get_tensor(tensor_name)
         # Special case for gptq which shouldn't convert
         # u4 which are disguised as int32
-        if tensor.dtype not in [torch.int32, torch.int64]:
+        if tensor.dtype not in [torch.int8, torch.int32, torch.int64]:
             tensor = tensor.to(dtype=self.dtype)
         tensor = tensor.to(device=self.device)
         return tensor
@@ -178,7 +178,7 @@ def get_partial_sharded(self, tensor_name: str, dim: int, range: Optional[Tuple[
             raise NotImplementedError("Let's make that generic when needed")
         # Special case for gptq which shouldn't convert
         # u4 which are disguised as int32
-        if tensor.dtype != torch.int32:
+        if tensor.dtype not in [torch.int8, torch.int32]:
             tensor = tensor.to(dtype=self.dtype)
         tensor = tensor.to(device=self.device)
         return tensor
@@ -226,6 +226,15 @@ def get_multi_weights_col(self, prefixes: List[Union[str, Tuple]], quantize: str
             bits, groupsize = self._get_gptq_params()
             weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
+        elif quantize == "eetq":
+            try:
+                qweight = torch.cat(self.get_sharded_list("qweight", prefixes, dim=1), dim=1)
+                scales = torch.cat(self.get_sharded_list("weight_scales", prefixes, dim=0), dim=0)
+                weight = (qweight, scales)
+            except RuntimeError:
+                logger.info("Weight does not appear to be quantized; loading it normally for JIT quantization later")
+                w = self.get_sharded_list("weight", prefixes, dim=0)
+                weight = torch.cat(w, dim=dim)
         else:
             w = self.get_sharded_list("weight", prefixes, dim=0)
             weight = torch.cat(w, dim=dim)
@@ -310,6 +319,14 @@ def get_multi_weights_row(self, prefix: str, quantize: str):
             g_idx = None
             use_exllama = False
             weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
+        elif quantize == "eetq":
+            try:
+                qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
+                scales = self.get_sharded(f"{prefix}.weight_scales", dim=0)
+                weight = (qweight, scales)
+            except RuntimeError:
+                logger.info("Weight does not appear to be quantized; loading it normally for JIT quantization later")
+                weight = self.get_sharded(f"{prefix}.weight", dim=1)
         else:
             weight = self.get_sharded(f"{prefix}.weight", dim=1)
         return weight
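
For reference, the change gives `EETQLinear` two construction paths: a pre-quantized path that takes the checkpoint's int8 `qweight` and fp16 `weight_scales` as-is, and the existing just-in-time path that quantizes an fp16 weight at load time. Below is a minimal sketch of the resulting layer, assuming EETQ's `quant_weights` and `w8_a16_gemm` entry points; the JIT half mirrors the pre-existing constructor body that the hunk elides, so treat it as illustrative rather than a verbatim copy.

```python
import torch
from EETQ import quant_weights, w8_a16_gemm


class EETQLinear(torch.nn.Module):
    def __init__(self, weight, bias, scales=None, quantized=False) -> None:
        super().__init__()

        if quantized:
            # Pre-quantized checkpoint path: the int8 qweight and fp16 scales
            # are already on the target device; use them unchanged.
            self.weight = weight
            self.scale = scales
            self.bias = bias
            return

        # JIT path: pack the fp16 weight into EETQ's int8 layout on the CPU,
        # then move the quantized weight and per-channel scales back to the GPU.
        device = weight.device
        weight = torch.t(weight).contiguous().cpu()
        weight, scale = quant_weights(weight, torch.int8, False)
        self.weight = weight.cuda(device)
        self.scale = scale.cuda(device)
        self.bias = bias.cuda(device) if bias is not None else None

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # int8-weight / fp16-activation GEMM provided by EETQ.
        output = w8_a16_gemm(input, self.weight, self.scale)
        return output + self.bias if self.bias is not None else output
```

One review note on the dispatch in `get_linear`: the `qweight, scales = weight` unpack inside `try` would also "succeed" for a plain two-row tensor, and the broad `except Exception` can hide unrelated construction failures. Checking `isinstance(weight, tuple)` would make the pre-quantized/JIT split explicit.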
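
Usage-wise, `get_linear(..., quantize="eetq")` now accepts either weight form; the `torch.int8` entries added to the skip lists in `get_tensor`/`get_partial_sharded` are what keep the packed `qweight` from being cast to the model dtype on load. A quick sketch with hypothetical shapes (requires EETQ and a CUDA device):

```python
import torch
from lorax_server.utils.layers import get_linear

# Pre-quantized checkpoint: int8 qweight plus fp16 per-channel scales.
qweight = torch.zeros(4096, 4096, dtype=torch.int8, device="cuda")
scales = torch.ones(4096, dtype=torch.float16, device="cuda")
layer = get_linear((qweight, scales), bias=None, quantize="eetq")

# Unquantized checkpoint: a plain fp16 weight falls through to the
# logged JIT-quantization path.
weight = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")
layer = get_linear(weight, bias=None, quantize="eetq")
```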