Cleanup
Giuseppe5 committed Sep 10, 2024
1 parent d8cab12 commit f5446a8
Showing 6 changed files with 123 additions and 53 deletions.
90 changes: 77 additions & 13 deletions src/brevitas/export/inference/handler.py
@@ -1,3 +1,8 @@
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

from abc import ABC
from abc import abstractmethod
from typing import Tuple

import torch
@@ -14,9 +19,26 @@
from brevitas.utils.torch_utils import float_internal_scale


class IntInferencetHandler(torch.nn.Module):
handled_layer = (
ActQuantProxyFromInjector, WeightQuantProxyFromInjector, BiasQuantProxyFromInjector)
class InferenceHandler(torch.nn.Module, ABC):

def attach_debug_info(self, module):
pass

@abstractmethod
def prepare_for_export(self, module):
pass

@abstractmethod
def quantize(self, x):
pass

@abstractmethod
def dequantize(self, x):
pass


class IntInferencetHandler(InferenceHandler):
handled_layer = (ActQuantProxyFromInjector, BiasQuantProxyFromInjector)

def attach_debug_info(self, module):
pass
@@ -29,22 +51,38 @@ def prepare_for_export(self, module):
self.min_clamp = min_int(module.is_signed, module.is_narrow_range, self.bit_width)
self.max_clamp = max_int(module.is_signed, module.is_narrow_range, self.bit_width)

def quant(self, x):
def quantize(self, x):
return torch.clamp(
torch.round(x / self.scale + self.zero_point), self.min_clamp, self.max_clamp)

def dequant(self, x):
def dequantize(self, x):
return (x - self.zero_point) * self.scale

def forward(self, x, unused_scale=None) -> Tuple[torch.Tensor]:
return self.dequant(self.quant(x)), self.scale, self.zero_point, self.bit_width
return self.dequantize(self.quantize(x)), self.scale, self.zero_point, self.bit_width


class FloatInferencetHandler(IntInferencetHandler):
handled_layer = (ActFloatQuantProxyFromInjector, WeightFloatQuantProxyFromInjector)
class IntWeightInferencetHandler(IntInferencetHandler):
handled_layer = WeightQuantProxyFromInjector

def prepare_for_export(self, module):
if module.is_quant_enabled:
self.cached_weight = None
if module._cached_weight is not None and not module.cache_inference_quant_weight_metadata_only:
self.cached_weight = module._cached_weight
else:
super().prepare_for_export(module)

def forward(self, x) -> Tuple[torch.Tensor]:
if self.cached_weight is not None:
x = self.cached_weight
else:
x = self.dequantize(self.quantize(x))
return x, self.scale, self.zero_point, self.bit_width

def attach_debug_info(self, module):
pass

class FloatInferencetHandler(InferenceHandler):
handled_layer = (ActFloatQuantProxyFromInjector, BiasQuantProxyFromInjector)

def prepare_for_export(self, module):
if module.is_quant_enabled:
@@ -72,20 +110,46 @@ def prepare_for_export(self, module):
self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias)
self.min_value = torch.tensor(0.) if not module.is_signed else -self.max_value

def quant(self, x):
def quantize(self, x):
# Compute masks
inf_mask = x.isinf()
p_max_val_mask = x > self.max_value
n_max_val_mask = -x > self.max_value

# Quantize
x = x / self.scale
internal_scale = float_internal_scale(
x, self.mantissa_bit_width, self.fp_internal_scale_min, self.eps)
x = internal_scale * self.float_to_int_impl(x / internal_scale)

# Clamp
x = self.float_clamp_impl.saturating_clamp(x, self.max_value, self.min_value)
if not self.saturating:
x = self.float_clamp_impl.inf_nan_clamp(x, inf_mask, p_max_val_mask, n_max_val_mask)

return x

def forward(self, x, unused_scale=None) -> Tuple[torch.Tensor]:
return self.dequant(self.quant(x)), self.scale, self.zero_point, self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias, self.saturating, self.inf_values, self.nan_values
def dequantize(self, x):
return (x - self.zero_point) * self.scale

def forward(self, x) -> Tuple[torch.Tensor]:
return self.dequantize(self.quantize(x)), self.scale, self.zero_point, self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias, self.saturating, self.inf_values, self.nan_values


class FloatWeightInferencetHandler(FloatInferencetHandler):
handled_layer = (ActFloatQuantProxyFromInjector, WeightFloatQuantProxyFromInjector)

def prepare_for_export(self, module):
if module.is_quant_enabled:
self.cached_weight = None
if module._cached_weight is not None and not module.cache_inference_quant_weight_metadata_only:
self.cached_weight = module._cached_weight
else:
super().prepare_for_export(module)

def forward(self, x) -> Tuple[torch.Tensor]:
if self.cached_weight is not None:
x = self.cached_weight
else:
x = self.dequantize(self.quantize(x))
return x, self.scale, self.zero_point, self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias, self.saturating, self.inf_values, self.nan_values
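
In short, IntInferencetHandler and IntWeightInferencetHandler implement a plain fake-quantization round trip: quantize() divides by the scale, adds the zero point, rounds, and clamps to the integer range, while dequantize() maps the result back to the real domain. A minimal standalone sketch of the same math, with illustrative values standing in for the scale, zero point and clamping bounds the handler pulls from the proxy module:

    import torch

    # Illustrative stand-ins for module.scale(), module.zero_point() and the
    # min_int/max_int bounds derived from bit width and signedness.
    scale = torch.tensor(0.05)
    zero_point = torch.tensor(0.)
    min_clamp, max_clamp = -128, 127  # signed 8-bit, non-narrow range

    def quantize(x):
        # Round to the integer grid and clamp to the representable range.
        return torch.clamp(torch.round(x / scale + zero_point), min_clamp, max_clamp)

    def dequantize(q):
        # Map back to the real domain.
        return (q - zero_point) * scale

    x = torch.randn(4)
    x_fq = dequantize(quantize(x))  # what forward() returns alongside the quant metadata
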
13 changes: 11 additions & 2 deletions src/brevitas/export/inference/manager.py
@@ -1,8 +1,13 @@
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

from torch.nn import Module
import torch.nn as nn

from brevitas.export.inference.handler import FloatInferencetHandler
from brevitas.export.inference.handler import FloatWeightInferencetHandler
from brevitas.export.inference.handler import IntInferencetHandler
from brevitas.export.inference.handler import IntWeightInferencetHandler
from brevitas.export.manager import _set_proxy_export_handler
from brevitas.export.manager import _set_proxy_export_mode
from brevitas.export.manager import _set_recurrent_layer_export_handler
@@ -32,7 +37,7 @@ def _override_weight_caching_mode(m: nn.Module, enabled: bool, metadata_only: bo
_override_caching_mode(m, 'weight', enabled, metadata_only)


class inference_mode:
class quant_inference_mode:

def __init__(self, model, cache_quant_weight=False, enabled=True):
self.model = model
@@ -84,7 +89,11 @@ def hook(self, module, inp, out):

# Inheritance from BaseManager is not technically needed
class InferenceManager(BaseManager):
handlers = [IntInferencetHandler, FloatInferencetHandler]
handlers = [
IntInferencetHandler,
FloatInferencetHandler,
IntWeightInferencetHandler,
FloatWeightInferencetHandler]

@classmethod
def set_export_mode(cls, model: Module, enabled: bool):
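
The renamed context manager is used the same way as before: wrap inference in quant_inference_mode together with torch.no_grad(), and run one forward pass inside the context before compiling or validating, as the callers updated in this commit do. A minimal usage sketch, using a single QuantLinear as a stand-in for a fully quantized network (the stand-in model and input are not taken from this diff):

    import torch

    from brevitas.export.inference import quant_inference_mode
    from brevitas.nn import QuantLinear

    quant_model = QuantLinear(16, 8, bias=True)  # stand-in Brevitas-quantized module
    quant_model.eval()
    x = torch.randn(4, 16)

    with torch.no_grad(), quant_inference_mode(quant_model):
        quant_model(x)  # warm-up call, mirroring the ptq_evaluate change below
        out = quant_model(x)
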
6 changes: 0 additions & 6 deletions src/brevitas/proxy/parameter_quant.py
@@ -136,19 +136,13 @@ def forward(self, x: torch.Tensor) -> Union[Tensor, QuantTensor]:
if self.is_quant_enabled:
# If quant is enabled the priority is:
# - export mode
# - cached weight
# - quantization flow
if self.export_mode:
out = self.export_handler(x)
if is_dynamo_compiling():
out = out[0]
else:
out = self.create_quant_tensor(out)
elif self._cached_weight is not None and not self.cache_inference_quant_weight_metadata_only:
if is_dynamo_compiling():
out = self._cached_weight.value
else:
out = self._cached_weight.quant_tensor
else:
out = self.tensor_quant(x)
if is_dynamo_compiling():
@@ -18,7 +18,7 @@

from brevitas.export import export_onnx_qcdq
from brevitas.export import export_torch_qcdq
from brevitas.export.inference.manager import inference_mode
from brevitas.export.inference import quant_inference_mode
from brevitas.graph.quantize import preprocess_for_quantize
from brevitas.graph.target.flexml import preprocess_for_flexml_quantize
from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_act_equalization
@@ -268,6 +268,7 @@ def parse_type(v, default_type):
'uint_sym_act_for_unsigned_values',
default=True,
help='Use unsigned act quant when possible (default: enabled)')
add_bool_arg(parser, 'compile', default=False, help='Use torch.compile (default: disabled)')


def generate_ref_input(args, device, dtype):
@@ -482,13 +483,14 @@ def main():

# Validate the quant_model on the validation dataloader
print("Starting validation:")
with torch.no_grad(), inference_mode(quant_model):
with torch.no_grad(), quant_inference_mode(quant_model):
param = next(iter(quant_model.parameters()))
device, dtype = param.device, param.dtype
ref_input = generate_ref_input(args, device, dtype)
quant_model(ref_input)
quant_model = torch.compile(quant_model, fullgraph=True, dynamic=True)
validate(val_loader, quant_model, stable=dtype != torch.bfloat16)
compiled_model = torch.compile(
quant_model, fullgraph=True, dynamic=True, disable=not args.compile)
validate(val_loader, compiled_model, stable=dtype != torch.bfloat16)

if args.export_onnx_qcdq or args.export_torch_qcdq:
# Generate reference input tensor to drive the export process
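
The new --compile switch feeds straight into torch.compile's disable flag, so a single call site covers both the eager and the compiled path. A small sketch of the same gating idiom (maybe_compile is a hypothetical helper name, not part of this commit):

    import torch

    def maybe_compile(model, enabled: bool):
        # With disable=True, torch.compile becomes a no-op and the model runs eagerly.
        return torch.compile(model, fullgraph=True, dynamic=True, disable=not enabled)
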
4 changes: 3 additions & 1 deletion src/brevitas_examples/stable_diffusion/main.py
@@ -29,7 +29,6 @@
from brevitas.graph.base import ModuleToModuleByClass
from brevitas.graph.calibrate import bias_correction_mode
from brevitas.graph.calibrate import calibration_mode
from brevitas.graph.calibrate import inference_mode
from brevitas.graph.calibrate import load_quant_model_mode
from brevitas.graph.equalize import activation_equalization_mode
from brevitas.graph.gptq import gptq_mode
@@ -150,6 +149,7 @@ def main(args):
calibration_prompts = CALIBRATION_PROMPTS
if args.calibration_prompt_path is not None:
calibration_prompts = load_calib_prompts(args.calibration_prompt_path)
print(args.calibration_prompt, len(calibration_prompts))
assert args.calibration_prompt <= len(calibration_prompts) , f"Only {len(calibration_prompts)} prompts are available"
calibration_prompts = calibration_prompts[:args.calibration_prompt]

@@ -231,6 +231,8 @@ def main(args):
non_blacklist[name_to_add] = 1
else:
non_blacklist[name_to_add] += 1
print(f"Blacklisted layers: {set(blacklist)}")
print(f"Non blacklisted layers: {non_blacklist}")

# Make sure that all LoRA layers are fused first, otherwise raise an error
for m in pipe.unet.modules():
53 changes: 26 additions & 27 deletions tests/brevitas_end_to_end/test_torchvision_models.py
@@ -13,7 +13,7 @@
from brevitas import torch_version
from brevitas.export import export_onnx_qcdq
from brevitas.export import export_torch_qcdq
from brevitas.export.inference.manager import inference_mode
from brevitas.export.inference import quant_inference_mode
from brevitas.graph.calibrate import calibration_mode
from brevitas.graph.quantize import layerwise_quantize
from brevitas.graph.quantize import quantize
@@ -28,19 +28,19 @@
MODEL_LIST = [
'vit_b_32',
'efficientnet_b0',
# 'mobilenet_v3_small',
# 'mobilenet_v2',
# 'resnet50',
# 'resnet18',
# 'mnasnet0_5',
# 'alexnet',
# 'googlenet',
# 'vgg11',
# 'densenet121',
# 'deeplabv3_resnet50',
# 'fcn_resnet50',
# 'regnet_x_400mf',
# 'squeezenet1_0',
'mobilenet_v3_small',
'mobilenet_v2',
'resnet50',
'resnet18',
'mnasnet0_5',
'alexnet',
'googlenet',
'vgg11',
'densenet121',
'deeplabv3_resnet50',
'fcn_resnet50',
'regnet_x_400mf',
'squeezenet1_0',
'inception_v3']


@@ -71,7 +71,7 @@ def quantize_float(model):

@fixture
@parametrize('model_name', MODEL_LIST)
@parametrize('quantize_fn', [quantize_float])
@parametrize('quantize_fn', [quantize_float, quantize, layerwise_quantize, quantize_flexml])
def torchvision_model(model_name, quantize_fn):

inp = torch.randn(BATCH, IN_CH, HEIGHT, WIDTH)
@@ -122,28 +122,27 @@ def test_torchvision_graph_quantization_flexml_qcdq_onnx(
pytest.skip('Model not instantiated')
if enable_compile:
model_name = test_id.split("-")[1]
if torch_version <= version.parse('2.0'):
pytest.skip("Pytorch 2.0 is required to test compile")
if torch_version <= version.parse('2.2'):
pytest.skip("Pytorch 2.2 is required to test compile")
else:
torch._dynamo.config.capture_scalar_outputs = True
if 'vit' in model_name:
pytest.skip("QuantMHA not supported with compile")

inp = torch.randn(BATCH, IN_CH, HEIGHT, WIDTH)

quantize_fn_name = test_id.split("-")[0]
if enable_compile:
torch._dynamo.config.capture_scalar_outputs = True
with torch.no_grad(), inference_mode(torchvision_model):
prehook_non_compiled_out = torchvision_model(inp)
post_hook_non_compiled_out = torchvision_model(inp)
assert torch.allclose(prehook_non_compiled_out, post_hook_non_compiled_out)
with torch.no_grad(), quant_inference_mode(torchvision_model):
prehook_non_compiled_out = torchvision_model(inp)
post_hook_non_compiled_out = torchvision_model(inp)
assert torch.allclose(prehook_non_compiled_out, post_hook_non_compiled_out)

if enable_compile:
compiled_model = torch.compile(torchvision_model, fullgraph=True)
compiled_out = compiled_model(inp)

# This fails! Compile might need more small-scoped tests for accuracy evaluation
# assert torch.allclose(post_hook_non_compiled_out, compiled_out)
else:
torchvision_model(inp)
# This fails! Compile might need more small-scoped tests for accuracy evaluation
# assert torch.allclose(post_hook_non_compiled_out, compiled_out)

if quantize_fn_name != 'quantize_float' and not enable_compile:
export_onnx_qcdq(torchvision_model, args=inp)
