Feat (brevitas_examples/llm): inference_mode support #1129

Merged: 9 commits, Dec 17, 2024
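This PR wires the existing quant_inference_mode context manager into the LLM example entry point and extends the inference handlers with dynamic and groupwise variants. A condensed usage sketch based on the brevitas_examples/llm/main.py change further down; model, calibration_loader, validation_loader, tokenizer, seqlen and compute_perplexity stand in for objects built by the example script:

# Condensed sketch of the usage introduced in brevitas_examples/llm/main.py below.
# `model`, `calibration_loader`, `validation_loader`, `tokenizer`, `seqlen` and
# `compute_perplexity` are placeholders taken from the example script.
import torch

from brevitas.export.inference.manager import quant_inference_mode

with torch.no_grad(), quant_inference_mode(model):
    # Warm-up forward pass so quantization metadata is cached before the
    # inference handlers take over.
    model(**calibration_loader[0])
    quant_ppl = compute_perplexity(
        model, validation_loader, context_length=seqlen // 2, tokenizer=tokenizer)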
src/brevitas/common.py: 5 changes (3 additions, 2 deletions)
@@ -3,6 +3,7 @@

from abc import ABCMeta
from abc import abstractmethod
import warnings

from brevitas import config

@@ -29,8 +30,8 @@ def export_mode(self):
    @export_mode.setter
    def export_mode(self, value):
        if value and config.JIT_ENABLED:
            raise RuntimeError(
                "Export mode with BREVITAS_JIT is currently not supported. Save the model' "
            warnings.warn(
                "Export mode with BREVITAS_JIT might fail. If so, save the model' "
                "state_dict to a .pth, load it back with BREVITAS_JIT=0, and call export.")
        if value and self.training:
            raise RuntimeError("Can't enter export mode during training, only during inference")
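The relaxed check above now warns instead of raising; a minimal sketch of the workaround the message describes, assuming a hypothetical build_model() factory and leaving the choice of export function to the user:

# Sketch of the save/reload flow suggested by the warning above. build_model()
# is a hypothetical user-defined factory that rebuilds the same quantized model.
import torch

# Run 1 (BREVITAS_JIT=1): quantize/calibrate, then persist the weights.
torch.save(model.state_dict(), 'quant_model.pth')

# Run 2 (launched with BREVITAS_JIT=0): rebuild, reload, then call the chosen
# Brevitas export function on the reloaded model.
model = build_model()
model.load_state_dict(torch.load('quant_model.pth'))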
src/brevitas/export/inference/handler.py: 195 changes (165 additions, 30 deletions)
@@ -6,83 +6,164 @@
from typing import Tuple

import torch
from torch import Tensor
import torch.nn as nn

from brevitas import is_dynamo_compiling
from brevitas.function.ops import max_float
from brevitas.function.ops import max_int
from brevitas.function.ops import min_int
from brevitas.proxy.float_parameter_quant import WeightFloatQuantProxyFromInjector
from brevitas.proxy.float_runtime_quant import ActFloatQuantProxyFromInjector
from brevitas.proxy.float_runtime_quant import ActFloatQuantProxyFromInjectorBase
from brevitas.proxy.groupwise_float_parameter_quant import \
    GroupwiseWeightFloatQuantProxyFromInjector
from brevitas.proxy.groupwise_float_runtime_quant import GroupwiseActFloatQuantProxyFromInjector
from brevitas.proxy.groupwise_int_parameter_quant import GroupwiseWeightQuantProxyFromInjector
from brevitas.proxy.parameter_quant import BiasQuantProxyFromInjector
from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjector
from brevitas.proxy.runtime_quant import ActQuantProxyFromInjector
from brevitas.proxy.runtime_quant import DynamicActQuantProxyFromInjector
from brevitas.quant.experimental.mx_quant_ocp import GroupwiseActQuantProxyFromInjector
from brevitas.utils.torch_utils import float_internal_scale


class InferenceHandler(torch.nn.Module, ABC):

    def attach_debug_info(self, module):
    def attach_debug_info(self, module: nn.Module):
        pass

    @abstractmethod
    def prepare_for_export(self, module):
    def prepare_for_export(self, module: nn.Module):
        pass

    @abstractmethod
    def quantize(self, x):
    def quantize(self, x: Tensor):
        pass

    @abstractmethod
    def dequantize(self, x):
    def dequantize(self, x: Tensor):
        pass


class IntInferencetHandler(InferenceHandler):
    handled_layer = (ActQuantProxyFromInjector, BiasQuantProxyFromInjector)

    def attach_debug_info(self, module):
        pass
    def __init__(self):
        super().__init__()
        self.register_buffer('scale', torch.ones(1))
        self.register_buffer('zero_point', torch.ones(0))

    def prepare_for_export(self, module):
    def prepare_for_export(self, module: nn.Module):
        if module.is_quant_enabled:
            self.scale = module.scale()
            self.zero_point = module.zero_point().to(self.scale.device)
            self.bit_width = module.bit_width()
            self.min_clamp = min_int(module.is_signed, module.is_narrow_range, self.bit_width)
            self.max_clamp = max_int(module.is_signed, module.is_narrow_range, self.bit_width)

    def quantize(self, x):
        return torch.clamp(
            torch.round(x / self.scale + self.zero_point), self.min_clamp, self.max_clamp)
    def quantize(self, x: Tensor, scale: Tensor, zero_point: Tensor) -> Tuple[Tensor]:
        return torch.clamp(torch.round(x / scale + zero_point), self.min_clamp, self.max_clamp)

    def dequantize(self, x):
        return (x - self.zero_point) * self.scale
    def dequantize(self, x: Tensor, scale: Tensor, zero_point: Tensor) -> Tensor:
        return (x - zero_point) * scale

    def forward(self, x, unused_scale=None) -> Tuple[torch.Tensor]:
        return self.dequantize(self.quantize(x)), self.scale, self.zero_point, self.bit_width
    def forward(self, x: Tensor, unused_scale: Tensor = None) -> Tuple[Tensor]:
        return self.dequantize(self.quantize(x, self.scale, self.zero_point), self.scale, self.zero_point), self.scale, self.zero_point, self.bit_width


class IntWeightInferencetHandler(IntInferencetHandler):
    handled_layer = WeightQuantProxyFromInjector

    def prepare_for_export(self, module):
    def __init__(self):
        super().__init__()
        self.register_buffer('cached_weight', torch.ones(1))

    def prepare_for_export(self, module: nn.Module):
        super().prepare_for_export(module)
        if module.is_quant_enabled:
            self.cached_weight = None
            super().prepare_for_export(module)
            if module._cached_weight is not None and not module.cache_inference_quant_weight_metadata_only:
                self.cached_weight = module._cached_weight.value
            else:
                self.cached_weight = None

    def forward(self, x) -> Tuple[torch.Tensor]:
    def forward(self, x: Tensor) -> Tuple[Tensor]:
        if self.cached_weight is not None:
            x = self.cached_weight
        else:
            x = self.dequantize(self.quantize(x))
            x = self.dequantize(
                self.quantize(x, self.scale, self.zero_point), self.scale, self.zero_point)
        return x, self.scale, self.zero_point, self.bit_width


class DynamicIntInferenceHandler(IntInferencetHandler):
    handled_layer = DynamicActQuantProxyFromInjector

    def prepare_for_export(self, module: nn.Module):
        if module.is_quant_enabled:
            self.module_forward = module.fused_activation_quant_proxy

    def forward(self, x: Tensor, unused_scale: Tensor = None) -> Tuple[Tensor]:
        return self.module_forward(x)


class GroupwiseIntInferenceHandler(IntInferencetHandler):
    handled_layer = GroupwiseActQuantProxyFromInjector

    def prepare_for_export(self, module):
        if module.is_quant_enabled:
            self.module_forward = module.fused_activation_quant_proxy
            self.group_dim = module.group_dim

    def forward(self, x: Tensor, unused_scale: Tensor = None) -> Tuple[Tensor]:
        x, *other = self.module_forward(x)
        if is_dynamo_compiling():
            start_dim = self.group_dim if self.group_dim != -1 else -2
            x = x.flatten(start_dim, start_dim + 1)
        output_args = tuple([x] + list(other))
        return output_args


class GroupwiseIntWeightInferenceHandler(IntWeightInferencetHandler):
    handled_layer = GroupwiseWeightQuantProxyFromInjector

    def prepare_for_export(self, module):
        super().prepare_for_export(module)
        if module.is_quant_enabled:
            self.input_view = module.input_view_impl
            self.flattened_view = module.apply_input_view
            if module._cached_weight is not None and not module.cache_inference_quant_weight_metadata_only:
                self.cached_weight = module._cached_weight.quant_tensor.value_
            else:
                self.cached_weight = None

    def forward(self, x: Tensor) -> Tuple[Tensor]:
        if self.scale.shape != ():
            scale = self.input_view(self.scale)
        else:
            scale = self.scale
        if self.zero_point.shape != ():
            zero_point = self.input_view(self.zero_point)
        else:
            zero_point = self.zero_point
        if self.cached_weight is not None:
            out = self.cached_weight
        else:
            x = self.input_view(x)
            out = self.dequantize(self.quantize(x, scale, zero_point), scale, zero_point)
            if is_dynamo_compiling():
                out = self.flattened_view(out)
        return out, scale, zero_point, self.bit_width


class FloatInferencetHandler(InferenceHandler):
    handled_layer = (ActFloatQuantProxyFromInjector, BiasQuantProxyFromInjector)

    def __init__(self):
        super().__init__()
        self.register_buffer('scale', torch.ones(1))
        self.register_buffer('zero_point', torch.ones(0))

    def prepare_for_export(self, module):
        if module.is_quant_enabled:
            self.scale = module.scale()
@@ -109,14 +190,13 @@ def prepare_for_export(self, module):
                self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias)
            self.min_value = torch.tensor(0.) if not module.is_signed else -self.max_value

    def quantize(self, x):
    def quantize(self, x: Tensor, scale: Tensor, zero_point: Tensor) -> Tuple[Tensor]:
        # Compute masks
        inf_mask = x.isinf()
        p_max_val_mask = x > self.max_value
        n_max_val_mask = -x > self.max_value

        # Quantize
        x = x / self.scale
        x = x / scale
        internal_scale = float_internal_scale(
            x, self.mantissa_bit_width, self.fp_internal_scale_min, self.eps)
        x = internal_scale * self.float_to_int_impl(x / internal_scale)
@@ -128,26 +208,81 @@ def quantize(self, x):

        return x

    def dequantize(self, x):
        return (x - self.zero_point) * self.scale
    def dequantize(self, x: Tensor, scale: Tensor, zero_point: Tensor) -> Tensor:
        return (x - zero_point) * scale

    def forward(self, x) -> Tuple[torch.Tensor]:
        return self.dequantize(self.quantize(x)), self.scale, self.zero_point, self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias, self.saturating, self.inf_values, self.nan_values
    def forward(self, x: Tensor) -> Tuple[Tensor]:
        return self.dequantize(self.quantize(x, self.scale, self.zero_point), self.scale, self.zero_point), self.scale, self.zero_point, self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias, self.saturating, self.inf_values, self.nan_values


class FloatWeightInferencetHandler(FloatInferencetHandler):
    handled_layer = WeightFloatQuantProxyFromInjector

    def __init__(self):
        super().__init__()
        self.register_buffer('cached_weight', torch.ones(1))

    def prepare_for_export(self, module):
        super().prepare_for_export(module)
        if module.is_quant_enabled:
            self.cached_weight = None
            super().prepare_for_export(module)
            if module._cached_weight is not None and not module.cache_inference_quant_weight_metadata_only:
                self.cached_weight = module._cached_weight.value
            else:
                self.cached_weight = None

    def forward(self, x) -> Tuple[torch.Tensor]:
    def forward(self, x: Tensor) -> Tuple[Tensor]:
        if self.cached_weight is not None:
            x = self.cached_weight
        else:
            x = self.dequantize(self.quantize(x))
            x = self.dequantize(
                self.quantize(x, self.scale, self.zero_point), self.scale, self.zero_point)
        return x, self.scale, self.zero_point, self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias, self.saturating, self.inf_values, self.nan_values


class GroupwiseFloatInferenceHandler(FloatInferencetHandler):
    handled_layer = GroupwiseActFloatQuantProxyFromInjector

    def prepare_for_export(self, module: nn.Module):
        if module.is_quant_enabled:
            self.module_forward = module.fused_activation_quant_proxy
            self.group_dim = module.group_dim

    def forward(self, x: Tensor) -> Tuple[Tensor]:
        x, *other = self.module_forward(x)
        if is_dynamo_compiling():
            start_dim = self.group_dim if self.group_dim != -1 else -2
            x = x.flatten(start_dim, start_dim + 1)
        output_args = tuple([x] + list(other))
        return output_args


class GroupwiseFloatWeightInferenceHandler(FloatWeightInferencetHandler):
    handled_layer = GroupwiseWeightFloatQuantProxyFromInjector

    def prepare_for_export(self, module: nn.Module):
        super().prepare_for_export(module)
        if module.is_quant_enabled:
            self.input_view = module.input_view_impl
            self.flattened_view = module.apply_input_view
            if module._cached_weight is not None and not module.cache_inference_quant_weight_metadata_only:
                self.cached_weight = module._cached_weight.quant_tensor.value_
            else:
                self.cached_weight = None

    def forward(self, x: Tensor) -> Tuple[Tensor]:
        if self.scale.shape != ():
            scale = self.input_view(self.scale)
        else:
            scale = self.scale
        if self.zero_point.shape != ():
            zero_point = self.input_view(self.zero_point)
        else:
            zero_point = self.zero_point
        if self.cached_weight is not None:
            out = self.cached_weight
        else:
            x = self.input_view(x)
            out = self.dequantize(self.quantize(x, scale, zero_point), scale, zero_point)
            if is_dynamo_compiling():
                out = self.flattened_view(out)
        return out, scale, zero_point, self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias, self.saturating, self.inf_values, self.nan_values
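For reference, the integer fake-quantization round trip implemented by IntInferencetHandler above reduces to a rounded, clamped affine mapping. A self-contained sketch with made-up scale, zero-point and clamp values:

# Standalone sketch of the quantize/dequantize pair used by IntInferencetHandler;
# scale, zero_point, min_clamp and max_clamp are ordinarily captured from the
# proxy in prepare_for_export, here they are made-up values.
import torch

def int_quant(x, scale, zero_point, min_clamp, max_clamp):
    # Scale, shift by the zero-point, round to integers and clamp to the valid range
    return torch.clamp(torch.round(x / scale + zero_point), min_clamp, max_clamp)

def int_dequant(q, scale, zero_point):
    # Undo the zero-point shift and rescale back to floating point
    return (q - zero_point) * scale

x = torch.randn(4, 8)
scale, zero_point = torch.tensor(0.05), torch.tensor(0.)
q = int_quant(x, scale, zero_point, min_clamp=-128, max_clamp=127)
x_fake_quant = int_dequant(q, scale, zero_point)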
src/brevitas/export/inference/manager.py: 14 changes (12 additions, 2 deletions)
@@ -4,8 +4,13 @@
from torch.nn import Module
import torch.nn as nn

from brevitas.export.inference.handler import DynamicIntInferenceHandler
from brevitas.export.inference.handler import FloatInferencetHandler
from brevitas.export.inference.handler import FloatWeightInferencetHandler
from brevitas.export.inference.handler import GroupwiseFloatInferenceHandler
from brevitas.export.inference.handler import GroupwiseFloatWeightInferenceHandler
from brevitas.export.inference.handler import GroupwiseIntInferenceHandler
from brevitas.export.inference.handler import GroupwiseIntWeightInferenceHandler
from brevitas.export.inference.handler import IntInferencetHandler
from brevitas.export.inference.handler import IntWeightInferencetHandler
from brevitas.export.manager import _set_proxy_export_handler
@@ -65,14 +70,14 @@ def __exit__(self, type, value, traceback):
        # Disable all caching
        # deactivate export mode
        # restore return quant tensor
        InferenceManager.set_export_mode(self.model, enabled=False)
        self.model.apply(
            lambda m: _override_bias_caching_mode(m, enabled=False, metadata_only=False))
        self.model.apply(
            lambda m: _override_act_caching_mode(m, enabled=False, metadata_only=False))
        if self.cache_quant_weight:
            self.model.apply(
                lambda m: _override_weight_caching_mode(m, enabled=False, metadata_only=False))
        InferenceManager.set_export_mode(self.model, enabled=False)
        restore_return_quant_tensor(self.model, self.return_quant_tensor_state)

    def hook(self, module, inp, out):
@@ -91,9 +96,14 @@ def hook(self, module, inp, out):
class InferenceManager(BaseManager):
    handlers = [
        IntInferencetHandler,
        DynamicIntInferenceHandler,
        FloatInferencetHandler,
        IntWeightInferencetHandler,
        FloatWeightInferencetHandler]
        FloatWeightInferencetHandler,
        GroupwiseIntInferenceHandler,
        GroupwiseIntWeightInferenceHandler,
        GroupwiseFloatInferenceHandler,
        GroupwiseFloatWeightInferenceHandler]

    @classmethod
    def set_export_mode(cls, model: Module, enabled: bool):
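Each handler advertises the proxy classes it serves through its handled_layer attribute; a conceptual sketch of that dispatch, using an illustrative find_handler helper rather than the Brevitas internals:

# Conceptual sketch only: pick the first registered handler whose handled_layer
# matches the proxy instance. find_handler is an illustrative name, not part of
# the Brevitas API.
from typing import Optional, Sequence, Type

def find_handler(proxy, handlers: Sequence[type]) -> Optional[Type]:
    for handler_cls in handlers:
        # handled_layer can be a single proxy class or a tuple of classes
        if isinstance(proxy, handler_cls.handled_layer):
            return handler_cls
    return None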
src/brevitas/quant_tensor/groupwise_float_quant_tensor.py: 1 change (0 additions, 1 deletion)
@@ -92,7 +92,6 @@ def expand(self):
        curr_shape = self.value_.shape
        start_dim = self.group_dim if self.group_dim != -1 else -2
        new_value = self.value_.flatten(start_dim, start_dim + 1)
        new_value = self.value_.flatten(start_dim, start_dim + 1)
        if self.scale_.shape != ():
            new_scale = self.scale_.expand(curr_shape).flatten(start_dim, start_dim + 1)
        else:
src/brevitas_examples/llm/main.py: 7 changes (5 additions, 2 deletions)
@@ -14,6 +14,7 @@
from transformers.utils.fx import _SUPPORTED_MODELS

from brevitas.export import export_torch_qcdq
from brevitas.export.inference.manager import quant_inference_mode
from brevitas.export.onnx.standard.qcdq.manager import StdQCDQONNXManager
from brevitas.graph.equalize import GraphRotationEqualization
from brevitas.graph.equalize import LayerwiseActivationRotation
@@ -421,8 +422,10 @@ def main(args):

    if args.eval and not args.no_quantize:
        print("Model eval...")
        quant_ppl = compute_perplexity(
            model, validation_loader, context_length=args.seqlen // 2, tokenizer=tokenizer)
        with torch.no_grad(), quant_inference_mode(model):
            model(**calibration_loader[0])
            quant_ppl = compute_perplexity(
                model, validation_loader, context_length=args.seqlen // 2, tokenizer=tokenizer)
        print(f"Quantized perplexity ({args.dataset}): {quant_ppl:.3f}")
        remove_hooks(model)
