From d3bb48031fcb79ef94de19c99710b10403acc1cf Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Sat, 14 Dec 2024 15:22:55 +0000 Subject: [PATCH] fix --- src/brevitas/export/inference/handler.py | 2 +- src/brevitas/proxy/groupwise_int_runtime_quant.py | 7 +++---- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/src/brevitas/export/inference/handler.py b/src/brevitas/export/inference/handler.py index 35201d45d..e09e9a1d3 100644 --- a/src/brevitas/export/inference/handler.py +++ b/src/brevitas/export/inference/handler.py @@ -66,7 +66,7 @@ def dequantize(self, x, scale, zero_point): return (x - zero_point) * scale def forward(self, x, unused_scale=None) -> Tuple[torch.Tensor]: - return self.dequantize(self.quantize(x)), self.scale, self.zero_point, self.bit_width + return self.dequantize(self.quantize(x, self.scale, self.zero_point), self.scale, self.zero_point), self.scale, self.zero_point, self.bit_width class IntWeightInferencetHandler(IntInferencetHandler): diff --git a/src/brevitas/proxy/groupwise_int_runtime_quant.py b/src/brevitas/proxy/groupwise_int_runtime_quant.py index ea91d1996..453cb3f9b 100644 --- a/src/brevitas/proxy/groupwise_int_runtime_quant.py +++ b/src/brevitas/proxy/groupwise_int_runtime_quant.py @@ -12,11 +12,10 @@ class GroupwiseActQuantProxyFromInjector(ActQuantProxyFromInjector): def __init__(self, quant_layer, quant_injector): super().__init__(quant_layer, quant_injector) self.cache_class = _CachedIOGroupwiseInt - self.group_dim = self.quant_injector.group_dim - # @property - # def group_dim(self): - # return self.quant_injector.group_dim + @property + def group_dim(self): + return self.quant_injector.group_dim @property def group_size(self):