Feat (brevitas_examples/llm): inference_mode support
Giuseppe5 committed Dec 13, 2024
1 parent 9acfc50 commit 10fdfe1
Showing 4 changed files with 75 additions and 13 deletions.
68 changes: 61 additions & 7 deletions src/brevitas/export/inference/handler.py
@@ -7,15 +7,18 @@

import torch

from brevitas import is_dynamo_compiling
from brevitas.function.ops import max_float
from brevitas.function.ops import max_int
from brevitas.function.ops import min_int
from brevitas.proxy.float_parameter_quant import WeightFloatQuantProxyFromInjector
from brevitas.proxy.float_runtime_quant import ActFloatQuantProxyFromInjector
from brevitas.proxy.float_runtime_quant import ActFloatQuantProxyFromInjectorBase
from brevitas.proxy.groupwise_int_parameter_quant import GroupwiseWeightQuantProxyFromInjector
from brevitas.proxy.parameter_quant import BiasQuantProxyFromInjector
from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjector
from brevitas.proxy.runtime_quant import ActQuantProxyFromInjector
from brevitas.quant.experimental.mx_quant_ocp import GroupwiseActQuantProxyFromInjector
from brevitas.utils.torch_utils import float_internal_scale


@@ -40,6 +43,11 @@ def dequantize(self, x):
class IntInferencetHandler(InferenceHandler):
    handled_layer = (ActQuantProxyFromInjector, BiasQuantProxyFromInjector)

    def __init__(self):
        super().__init__()
        self.register_buffer('scale', torch.ones(1))
        self.register_buffer('zero_point', torch.ones(0))

    def attach_debug_info(self, module):
        pass

@@ -51,12 +59,11 @@ def prepare_for_export(self, module):
            self.min_clamp = min_int(module.is_signed, module.is_narrow_range, self.bit_width)
            self.max_clamp = max_int(module.is_signed, module.is_narrow_range, self.bit_width)

    def quantize(self, x):
        return torch.clamp(
            torch.round(x / self.scale + self.zero_point), self.min_clamp, self.max_clamp)
    def quantize(self, x, scale, zero_point):
        return torch.clamp(torch.round(x / scale + zero_point), self.min_clamp, self.max_clamp)

    def dequantize(self, x):
        return (x - self.zero_point) * self.scale
    def dequantize(self, x, scale, zero_point):
        return (x - zero_point) * scale

    def forward(self, x, unused_scale=None) -> Tuple[torch.Tensor]:
        return self.dequantize(self.quantize(x)), self.scale, self.zero_point, self.bit_width
@@ -65,6 +72,10 @@ def forward(self, x, unused_scale=None) -> Tuple[torch.Tensor]:
class IntWeightInferencetHandler(IntInferencetHandler):
    handled_layer = WeightQuantProxyFromInjector

    def __init__(self):
        super().__init__()
        self.register_buffer('cached_weight', torch.ones(1))

    def prepare_for_export(self, module):
        if module.is_quant_enabled:
            self.cached_weight = None
@@ -76,7 +87,8 @@ def forward(self, x) -> Tuple[torch.Tensor]:
        if self.cached_weight is not None:
            x = self.cached_weight
        else:
            x = self.dequantize(self.quantize(x))
            x = self.dequantize(
                self.quantize(x, self.scale, self.zero_point), self.scale, self.zero_point)
        return x, self.scale, self.zero_point, self.bit_width


@@ -114,7 +126,6 @@ def quantize(self, x):
        inf_mask = x.isinf()
        p_max_val_mask = x > self.max_value
        n_max_val_mask = -x > self.max_value

        # Quantize
        x = x / self.scale
        internal_scale = float_internal_scale(
@@ -151,3 +162,46 @@ def forward(self, x) -> Tuple[torch.Tensor]:
        else:
            x = self.dequantize(self.quantize(x))
        return x, self.scale, self.zero_point, self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias, self.saturating, self.inf_values, self.nan_values


class GroupwiseIntInferenceHandler(IntInferencetHandler):
    handled_layer = GroupwiseActQuantProxyFromInjector

    def prepare_for_export(self, module):
        if module.is_quant_enabled:
            self.module_forward = module.fused_activation_quant_proxy
            self.flattened_view = module.apply_input_view
            self.input_view = module.input_view_impl
            self.group_dim = module.group_dim

    def forward(self, x, unused_scale=None) -> Tuple[torch.Tensor]:
        x, *other = self.module_forward(x)
        if is_dynamo_compiling:
            start_dim = self.group_dim if self.group_dim != -1 else -2
            x = x.flatten(start_dim, start_dim + 1)
        return x, *other


class GroupwiseIntWeightInferenceHandler(IntWeightInferencetHandler):
    handled_layer = GroupwiseWeightQuantProxyFromInjector

    def prepare_for_export(self, module):
        super().prepare_for_export(module)
        self.input_view = module.input_view_impl
        self.flattened_view = module.apply_input_view

    def forward(self, x, unused_scale=None) -> Tuple[torch.Tensor]:
        x = self.input_view(x)
        if self.scale.shape != ():
            scale = self.input_view(self.scale)
        else:
            scale = self.scale

        if self.zero_point.shape != ():
            zero_point = self.input_view(self.zero_point)
        else:
            zero_point = self.zero_point
        out = self.dequantize(self.quantize(x, scale, zero_point), scale, zero_point)
        if is_dynamo_compiling:
            out = self.flattened_view(out)
        return out, scale, zero_point, self.bit_width
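
For reference, the group-wise weight path added above boils down to: view the weight into groups, run the affine quantize/dequantize round trip with a per-group scale, and flatten the group dimensions back. Below is a minimal standalone sketch of that flow; the shapes, group size, scale computation and 8-bit clamp bounds are illustrative assumptions, not values taken from a Brevitas module.

import torch

# Assumed toy setup: an [out_features, in_features] weight split into groups of 4
# along the input dimension (mimicking what input_view_impl does for group-wise quant).
group_size = 4
w = torch.randn(8, 16)
w_g = w.view(8, 16 // group_size, group_size)

# Assumed per-group affine parameters (symmetric int8, for illustration only).
scale = w_g.abs().amax(dim=-1, keepdim=True) / 127.
zero_point = torch.zeros_like(scale)

# Same round trip as IntInferencetHandler.quantize/dequantize, with a per-group scale.
q = torch.clamp(torch.round(w_g / scale + zero_point), -128, 127)
deq = (q - zero_point) * scale

# Flatten the group dimensions back, as the handler does when dynamo is compiling.
w_out = deq.flatten(1, 2)
assert w_out.shape == w.shape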
6 changes: 5 additions & 1 deletion src/brevitas/export/inference/manager.py
@@ -6,6 +6,8 @@

from brevitas.export.inference.handler import FloatInferencetHandler
from brevitas.export.inference.handler import FloatWeightInferencetHandler
from brevitas.export.inference.handler import GroupwiseIntInferenceHandler
from brevitas.export.inference.handler import GroupwiseIntWeightInferenceHandler
from brevitas.export.inference.handler import IntInferencetHandler
from brevitas.export.inference.handler import IntWeightInferencetHandler
from brevitas.export.manager import _set_proxy_export_handler
@@ -93,7 +95,9 @@ class InferenceManager(BaseManager):
        IntInferencetHandler,
        FloatInferencetHandler,
        IntWeightInferencetHandler,
        FloatWeightInferencetHandler]
        FloatWeightInferencetHandler,
        GroupwiseIntInferenceHandler,
        GroupwiseIntWeightInferenceHandler]

    @classmethod
    def set_export_mode(cls, model: Module, enabled: bool):
7 changes: 4 additions & 3 deletions src/brevitas/proxy/groupwise_int_runtime_quant.py
@@ -12,10 +12,11 @@ class GroupwiseActQuantProxyFromInjector(ActQuantProxyFromInjector):
    def __init__(self, quant_layer, quant_injector):
        super().__init__(quant_layer, quant_injector)
        self.cache_class = _CachedIOGroupwiseInt
        self.group_dim = self.quant_injector.group_dim

    @property
    def group_dim(self):
        return self.quant_injector.group_dim
    # @property
    # def group_dim(self):
    #     return self.quant_injector.group_dim

    @property
    def group_size(self):
7 changes: 5 additions & 2 deletions src/brevitas_examples/llm/main.py
@@ -14,6 +14,7 @@
from transformers.utils.fx import _SUPPORTED_MODELS

from brevitas.export import export_torch_qcdq
from brevitas.export.inference.manager import quant_inference_mode
from brevitas.export.onnx.standard.qcdq.manager import StdQCDQONNXManager
from brevitas.graph.equalize import GraphRotationEqualization
from brevitas.graph.equalize import LayerwiseActivationRotation
@@ -421,8 +422,10 @@ def main(args):

    if args.eval and not args.no_quantize:
        print("Model eval...")
        quant_ppl = compute_perplexity(
            model, validation_loader, context_length=args.seqlen // 2, tokenizer=tokenizer)
        with torch.no_grad(), quant_inference_mode(model):
            model(**calibration_loader[0])
            quant_ppl = compute_perplexity(
                model, validation_loader, context_length=args.seqlen // 2, tokenizer=tokenizer)
        print(f"Quantized perplexity ({args.dataset}): {quant_ppl:.3f}")
        remove_hooks(model)

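The eval block now wraps perplexity computation in quant_inference_mode, with a single warm-up forward pass so the proxies can swap in the inference handlers and cache their scale/zero-point buffers. A minimal usage sketch on a toy layer; QuantLinear with its default int8 weight quantizer is an assumed stand-in for the LLM in the script, not code from this commit.

import torch

from brevitas.export.inference.manager import quant_inference_mode
from brevitas.nn import QuantLinear

layer = QuantLinear(16, 8, bias=False)  # assumed toy module with default int8 weight quant
x = torch.randn(2, 16)

with torch.no_grad(), quant_inference_mode(layer):
    layer(x)      # warm-up call, analogous to model(**calibration_loader[0]) above
    y = layer(x)  # subsequent calls run through the cached inference handlers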
