diff --git a/src/brevitas/export/inference/handler.py b/src/brevitas/export/inference/handler.py
index 1416014ec..35201d45d 100644
--- a/src/brevitas/export/inference/handler.py
+++ b/src/brevitas/export/inference/handler.py
@@ -7,15 +7,18 @@
 
 import torch
 
+from brevitas import is_dynamo_compiling
 from brevitas.function.ops import max_float
 from brevitas.function.ops import max_int
 from brevitas.function.ops import min_int
 from brevitas.proxy.float_parameter_quant import WeightFloatQuantProxyFromInjector
 from brevitas.proxy.float_runtime_quant import ActFloatQuantProxyFromInjector
 from brevitas.proxy.float_runtime_quant import ActFloatQuantProxyFromInjectorBase
+from brevitas.proxy.groupwise_int_parameter_quant import GroupwiseWeightQuantProxyFromInjector
+from brevitas.proxy.groupwise_int_runtime_quant import GroupwiseActQuantProxyFromInjector
 from brevitas.proxy.parameter_quant import BiasQuantProxyFromInjector
 from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjector
 from brevitas.proxy.runtime_quant import ActQuantProxyFromInjector
 from brevitas.utils.torch_utils import float_internal_scale
 
 
@@ -40,6 +43,11 @@ def dequantize(self, x):
 class IntInferencetHandler(InferenceHandler):
     handled_layer = (ActQuantProxyFromInjector, BiasQuantProxyFromInjector)
 
+    def __init__(self):
+        super().__init__()
+        self.register_buffer('scale', torch.ones(1))
+        self.register_buffer('zero_point', torch.ones(0))
+
     def attach_debug_info(self, module):
         pass
 
@@ -51,12 +59,11 @@ def prepare_for_export(self, module):
             self.min_clamp = min_int(module.is_signed, module.is_narrow_range, self.bit_width)
             self.max_clamp = max_int(module.is_signed, module.is_narrow_range, self.bit_width)
 
-    def quantize(self, x):
-        return torch.clamp(
-            torch.round(x / self.scale + self.zero_point), self.min_clamp, self.max_clamp)
+    def quantize(self, x, scale, zero_point):
+        return torch.clamp(torch.round(x / scale + zero_point), self.min_clamp, self.max_clamp)
 
-    def dequantize(self, x):
-        return (x - self.zero_point) * self.scale
+    def dequantize(self, x, scale, zero_point):
+        return (x - zero_point) * scale
 
     def forward(self, x, unused_scale=None) -> Tuple[torch.Tensor]:
-        return self.dequantize(self.quantize(x)), self.scale, self.zero_point, self.bit_width
+        return self.dequantize(self.quantize(x, self.scale, self.zero_point), self.scale, self.zero_point), self.scale, self.zero_point, self.bit_width
@@ -65,6 +72,10 @@ def forward(self, x, unused_scale=None) -> Tuple[torch.Tensor]:
 class IntWeightInferencetHandler(IntInferencetHandler):
     handled_layer = WeightQuantProxyFromInjector
 
+    def __init__(self):
+        super().__init__()
+        self.register_buffer('cached_weight', torch.ones(1))
+
     def prepare_for_export(self, module):
         if module.is_quant_enabled:
             self.cached_weight = None
@@ -76,7 +87,8 @@ def forward(self, x) -> Tuple[torch.Tensor]:
         if self.cached_weight is not None:
             x = self.cached_weight
         else:
-            x = self.dequantize(self.quantize(x))
+            x = self.dequantize(
+                self.quantize(x, self.scale, self.zero_point), self.scale, self.zero_point)
         return x, self.scale, self.zero_point, self.bit_width
 
 
@@ -114,7 +126,6 @@ def quantize(self, x):
         inf_mask = x.isinf()
         p_max_val_mask = x > self.max_value
         n_max_val_mask = -x > self.max_value
-        # Quantize
         x = x / self.scale
         internal_scale = float_internal_scale(
             x, self.mantissa_bit_width, self.fp_internal_scale_min, self.eps)
@@ -151,3 +162,46 @@ def forward(self, x) -> Tuple[torch.Tensor]:
         else:
             x = self.dequantize(self.quantize(x))
         return x, self.scale, self.zero_point, self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias, self.saturating, self.inf_values, self.nan_values
+
+
+class GroupwiseIntInferenceHandler(IntInferencetHandler):
+    handled_layer = GroupwiseActQuantProxyFromInjector
+
+    def prepare_for_export(self, module):
+        if module.is_quant_enabled:
+            self.module_forward = module.fused_activation_quant_proxy
+            self.flattened_view = module.apply_input_view
+            self.input_view = module.input_view_impl
+            self.group_dim = module.group_dim
+
+    def forward(self, x, unused_scale=None) -> Tuple[torch.Tensor]:
+        x, *other = self.module_forward(x)
+        if is_dynamo_compiling():
+            start_dim = self.group_dim if self.group_dim != -1 else -2
+            x = x.flatten(start_dim, start_dim + 1)
+        return x, *other
+
+
+class GroupwiseIntWeightInferenceHandler(IntWeightInferencetHandler):
+    handled_layer = GroupwiseWeightQuantProxyFromInjector
+
+    def prepare_for_export(self, module):
+        super().prepare_for_export(module)
+        self.input_view = module.input_view_impl
+        self.flattened_view = module.apply_input_view
+
+    def forward(self, x, unused_scale=None) -> Tuple[torch.Tensor]:
+        x = self.input_view(x)
+        if self.scale.shape != ():
+            scale = self.input_view(self.scale)
+        else:
+            scale = self.scale
+
+        if self.zero_point.shape != ():
+            zero_point = self.input_view(self.zero_point)
+        else:
+            zero_point = self.zero_point
+        out = self.dequantize(self.quantize(x, scale, zero_point), scale, zero_point)
+        if is_dynamo_compiling():
+            out = self.flattened_view(out)
+        return out, scale, zero_point, self.bit_width
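Note (illustrative, not part of the patch): the new GroupwiseIntWeightInferenceHandler applies the proxy's input_view to the weight, scale and zero-point so that per-group parameters broadcast inside quantize()/dequantize(), then flattens the group dimensions back when compiling. A minimal pure-PyTorch sketch of that pattern, using a hypothetical [out_features, in_features] weight and group size:

import torch

group_size = 4
w = torch.randn(8, 16)                                # hypothetical [out_features, in_features] weight
w_g = w.view(8, 16 // group_size, group_size)         # grouped view along the input-channel dim
scale = w_g.abs().amax(dim=-1, keepdim=True) / 127.0  # one scale per group, broadcastable
zero_point = torch.zeros_like(scale)

q = torch.clamp(torch.round(w_g / scale + zero_point), -128, 127)  # same math as quantize()
w_dq = (q - zero_point) * scale                       # same math as dequantize()
w_dq = w_dq.flatten(-2, -1)                           # merge groups back to [out_features, in_features]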
diff --git a/src/brevitas/export/inference/manager.py b/src/brevitas/export/inference/manager.py
index 936106884..ab09e83e5 100644
--- a/src/brevitas/export/inference/manager.py
+++ b/src/brevitas/export/inference/manager.py
@@ -6,6 +6,8 @@
 
 from brevitas.export.inference.handler import FloatInferencetHandler
 from brevitas.export.inference.handler import FloatWeightInferencetHandler
+from brevitas.export.inference.handler import GroupwiseIntInferenceHandler
+from brevitas.export.inference.handler import GroupwiseIntWeightInferenceHandler
 from brevitas.export.inference.handler import IntInferencetHandler
 from brevitas.export.inference.handler import IntWeightInferencetHandler
 from brevitas.export.manager import _set_proxy_export_handler
@@ -93,7 +95,9 @@ class InferenceManager(BaseManager):
         IntInferencetHandler,
         FloatInferencetHandler,
         IntWeightInferencetHandler,
-        FloatWeightInferencetHandler]
+        FloatWeightInferencetHandler,
+        GroupwiseIntInferenceHandler,
+        GroupwiseIntWeightInferenceHandler]
 
     @classmethod
     def set_export_mode(cls, model: Module, enabled: bool):
diff --git a/src/brevitas/proxy/groupwise_int_runtime_quant.py b/src/brevitas/proxy/groupwise_int_runtime_quant.py
index 453cb3f9b..ea91d1996 100644
--- a/src/brevitas/proxy/groupwise_int_runtime_quant.py
+++ b/src/brevitas/proxy/groupwise_int_runtime_quant.py
@@ -12,10 +12,7 @@ class GroupwiseActQuantProxyFromInjector(ActQuantProxyFromInjector):
     def __init__(self, quant_layer, quant_injector):
         super().__init__(quant_layer, quant_injector)
         self.cache_class = _CachedIOGroupwiseInt
+        self.group_dim = self.quant_injector.group_dim
 
-    @property
-    def group_dim(self):
-        return self.quant_injector.group_dim
-
     @property
     def group_size(self):
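Note (illustrative, not part of the patch): group_dim is now resolved once in __init__, so at run time the handlers read a plain attribute instead of a property. On the activation path, GroupwiseIntInferenceHandler uses it to merge the group dimensions back while compiling; the shapes below are hypothetical and only show the flatten step:

import torch

group_dim = 1                                   # hypothetical: groups sit at dim 1
x = torch.randn(2, 8, 4, 32)                    # [batch, n_groups, group_size, hidden]
start_dim = group_dim if group_dim != -1 else -2
x_flat = x.flatten(start_dim, start_dim + 1)    # -> shape [2, 32, 32]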
diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py
index 6004ec97d..bd9669bd8 100644
--- a/src/brevitas_examples/llm/main.py
+++ b/src/brevitas_examples/llm/main.py
@@ -14,6 +14,7 @@
 from transformers.utils.fx import _SUPPORTED_MODELS
 
 from brevitas.export import export_torch_qcdq
+from brevitas.export.inference.manager import quant_inference_mode
 from brevitas.export.onnx.standard.qcdq.manager import StdQCDQONNXManager
 from brevitas.graph.equalize import GraphRotationEqualization
 from brevitas.graph.equalize import LayerwiseActivationRotation
@@ -421,8 +422,10 @@ def main(args):
 
     if args.eval and not args.no_quantize:
         print("Model eval...")
-        quant_ppl = compute_perplexity(
-            model, validation_loader, context_length=args.seqlen // 2, tokenizer=tokenizer)
+        with torch.no_grad(), quant_inference_mode(model):
+            model(**calibration_loader[0])
+            quant_ppl = compute_perplexity(
+                model, validation_loader, context_length=args.seqlen // 2, tokenizer=tokenizer)
         print(f"Quantized perplexity ({args.dataset}): {quant_ppl:.3f}")
     remove_hooks(model)
 
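Note (illustrative, not part of the patch): a minimal, self-contained sketch of the evaluation pattern used in main.py above. The QuantLinear layer and input shapes are placeholders; the point is the context manager plus one warm-up forward before evaluation, mirroring the calibration batch used in the example script:

import torch

from brevitas.export.inference.manager import quant_inference_mode
from brevitas.nn import QuantLinear

model = torch.nn.Sequential(QuantLinear(16, 8, bias=False)).eval()
x = torch.randn(2, 16)

with torch.no_grad(), quant_inference_mode(model):
    model(x)        # warm-up forward inside the context, as main.py does with calibration_loader[0]
    out = model(x)  # later calls are served by the inference handlers registered in manager.py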