diff --git a/src/brevitas/export/manager.py b/src/brevitas/export/manager.py index 2805c6174..7b7e7a145 100644 --- a/src/brevitas/export/manager.py +++ b/src/brevitas/export/manager.py @@ -166,11 +166,15 @@ def _trace_fn_dispatcher(cls, fn, input, *args, **kwargs): @classmethod def handler_from_module(cls, module: Module, no_inheritance=False): for handler in cls.handlers: + if not isinstance(handler.handled_layer, tuple): + handled_classes = (handler.handled_layer,) + else: + handled_classes = handler.handled_layer if no_inheritance: - if type(module) == handler.handled_layer: + if type(module) in handled_classes: return handler else: - if isinstance(module, handler.handled_layer): + if any([isinstance(module, handler) for handler in handled_classes]): return handler return None diff --git a/src/brevitas/nn/mixin/base.py b/src/brevitas/nn/mixin/base.py index d64271cb5..59b559787 100644 --- a/src/brevitas/nn/mixin/base.py +++ b/src/brevitas/nn/mixin/base.py @@ -8,12 +8,15 @@ from typing import Optional, Tuple, Union from warnings import warn +import packaging.version +import torch from torch import nn from torch import Tensor import torch.jit from torch.nn.utils.rnn import PackedSequence from brevitas import config +from brevitas import torch_version from brevitas.common import ExportMixin from brevitas.inject import ExtendedInjector from brevitas.inject import Injector @@ -26,6 +29,11 @@ from .utils import filter_kwargs +if torch_version < packaging.version.parse('2.0'): + is_dynamo_compiling = lambda: False +else: + is_dynamo_compiling = torch._dynamo.is_compiling + class QuantProxyMixin(object): __metaclass__ = ABCMeta @@ -85,7 +93,7 @@ def unpack_input(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTe qt_class = self.get_quant_tensor_class(inp) if qt_class is not None: inp = qt_class(*inp) - if not torch._C._get_tracing_state(): + if not torch._C._get_tracing_state() and not is_dynamo_compiling(): if isinstance(inp, QuantTensor): inp = inp.set(value=inp.value.rename(None)) else: diff --git a/src/brevitas/proxy/parameter_quant.py b/src/brevitas/proxy/parameter_quant.py index 77a806ee8..f91d90a75 100644 --- a/src/brevitas/proxy/parameter_quant.py +++ b/src/brevitas/proxy/parameter_quant.py @@ -4,10 +4,19 @@ from abc import ABC from abc import ABCMeta from abc import abstractmethod -from typing import Any, Optional, Tuple, Union +from typing import Any, List, Optional, Tuple, Union from warnings import warn +import packaging.version import torch + +from brevitas import torch_version + +if torch_version < packaging.version.parse('2.0'): + is_dynamo_compiling = lambda: False +else: + is_dynamo_compiling = torch._dynamo.is_compiling + from torch import Tensor import torch.nn as nn from typing_extensions import Protocol @@ -16,6 +25,7 @@ from brevitas import config from brevitas.function import max_int from brevitas.inject import BaseInjector as Injector +from brevitas.quant_tensor import _unpack_quant_tensor from brevitas.quant_tensor import IntQuantTensor from brevitas.quant_tensor import QuantTensor from brevitas.utils.quant_utils import _CachedIO @@ -122,15 +132,25 @@ def forward(self, x: torch.Tensor) -> Union[Tensor, QuantTensor]: # - quantization flow if self.export_mode: out = self.export_handler(x) - out = self.create_quant_tensor(out) + if is_dynamo_compiling(): + out = out[0] + else: + out = self.create_quant_tensor(out) elif self._cached_weight is not None and not self.cache_inference_quant_weight_metadata_only: - out = self._cached_weight.quant_tensor + if is_dynamo_compiling(): + out = self._cached_weight.value + else: + out = self._cached_weight.quant_tensor else: out = self.tensor_quant(x) - out = self.create_quant_tensor(out) - if not self.training and self.cache_inference_quant_weight and self._cached_weight is None: - self._cached_weight = self.cache_class( - out.detach(), metadata_only=self.cache_inference_quant_weight_metadata_only) + if is_dynamo_compiling(): + out = out[0] + else: + out = self.create_quant_tensor(out) + if not self.training and self.cache_inference_quant_weight and self._cached_weight is None: + self._cached_weight = self.cache_class( + out.detach(), + metadata_only=self.cache_inference_quant_weight_metadata_only) else: # quantization disabled out = self.apply_input_view(x) return out @@ -151,9 +171,10 @@ def tracked_parameter_list(self): def get_cached(self, attr): if self._cached_bias is None: - warn( - "No quant bias cache found, set cache_inference_quant_bias=True and run an " - "inference pass first") + if not is_dynamo_compiling(): + warn( + "No quant bias cache found, set cache_inference_quant_bias=True and run an " + "inference pass first") return None if self.training: warn("Cached quant bias scale is being used in training mode.") @@ -268,7 +289,7 @@ class BiasQuantProxyFromInjector(BiasQuantProxyFromInjectorBase): def scale(self): if not self.is_quant_enabled: return None - if self.requires_input_scale and self.is_quant_enabled and self.is_quant_enabled: + if self.requires_input_scale and self.is_quant_enabled: cache = self.get_cached('scale') return cache zhs = self._zero_hw_sentinel() @@ -335,12 +356,13 @@ def forward( out, out_scale, out_zp, out_bit_width = impl(x, input_scale) else: out, out_scale, out_zp, out_bit_width = impl(x) - out = IntQuantTensor( - out, out_scale, out_zp, out_bit_width, self.is_signed, self.training) - if not self.training and self.cache_inference_quant_bias: - cached_bias = _CachedIO( - out.detach(), metadata_only=self.cache_inference_quant_bias_metadata_only) - self._cached_bias = cached_bias + if not is_dynamo_compiling(): + out = IntQuantTensor( + out, out_scale, out_zp, out_bit_width, self.is_signed, self.training) + if not self.training and self.cache_inference_quant_bias: + cached_bias = _CachedIO( + out.detach(), metadata_only=self.cache_inference_quant_bias_metadata_only) + self._cached_bias = cached_bias else: out = x return out diff --git a/src/brevitas/proxy/quant_proxy.py b/src/brevitas/proxy/quant_proxy.py index 9c4255773..7f61d3c45 100644 --- a/src/brevitas/proxy/quant_proxy.py +++ b/src/brevitas/proxy/quant_proxy.py @@ -122,7 +122,7 @@ def add_tracked_module(self, module: nn.Module) -> None: raise RuntimeError("Trying to add None as a parent module.") def apply_input_view(self, x): - return self.quant_injector.input_view_impl(x) + return self.tensor_quant.int_quant.input_view_impl(x) def _load_from_state_dict( self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, diff --git a/src/brevitas/proxy/runtime_quant.py b/src/brevitas/proxy/runtime_quant.py index 511f914e6..7a5381148 100644 --- a/src/brevitas/proxy/runtime_quant.py +++ b/src/brevitas/proxy/runtime_quant.py @@ -5,7 +5,15 @@ from abc import abstractmethod from typing import Any, Optional, Tuple, Union +import packaging.version import torch + +from brevitas import torch_version + +if torch_version < packaging.version.parse('2.0'): + is_dynamo_compiling = lambda: False +else: + is_dynamo_compiling = torch._dynamo.is_compiling from torch import nn from torch import Tensor from torch.nn import Identity @@ -115,6 +123,9 @@ def retrieve_attribute(self, attribute, force_eval): elif self._cached_act is None: return None + def apply_input_view(self, x): + return self.fused_activation_quant_proxy.tensor_quant.int_quant.input_view_impl(x) + @property def is_quant_enabled(self): return self._is_quant_enabled and not self.disable_quant @@ -176,15 +187,18 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]: # If y is an empty QuantTensor, we need to check if this is a passthrough proxy, # otherwise return a simple Tensor - # If the second value (i.e., scale) is None, then quant is disabled - if isinstance(y, tuple) and y[1] is not None: - out = self.create_quant_tensor(y) - elif self.is_passthrough_act and isinstance(x, QuantTensor): - # preserve quant_metadata - y = y[0] - out = self.create_quant_tensor(y, x=x) - else: + if is_dynamo_compiling(): out = y[0] + else: + # If the second value (i.e., scale) is None, then quant is disabled + if y[1] is not None: + out = self.create_quant_tensor(y) + elif self.is_passthrough_act and isinstance(x, QuantTensor): + # preserve scale/zp/bit/sign even without output quant + y = y[0] + out = self.create_quant_tensor(y, x=x) + else: + out = y[0] if not self.training and self.cache_inference_quant_act and isinstance(out, QuantTensor): cached_out = self.cache_class(out.detach(), self.cache_quant_io_metadata_only) diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py index bac596be5..661328d7e 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py +++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py @@ -22,7 +22,6 @@ from brevitas.graph.target.flexml import quantize_flexml from brevitas.inject import value import brevitas.nn as qnn -from brevitas.quant.experimental.float import Fp8e4m3Act from brevitas.quant.experimental.float import Fp8e4m3ActPerTensorFloat from brevitas.quant.experimental.float import Fp8e4m3ActPerTensorFloatMSE from brevitas.quant.experimental.float import Fp8e4m3WeightPerChannelFloat diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py index c960a89e6..a561dc0c4 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py +++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py @@ -18,6 +18,7 @@ from brevitas.export import export_onnx_qcdq from brevitas.export import export_torch_qcdq +from brevitas.export.inference.manager import inference_mode from brevitas.graph.quantize import preprocess_for_quantize from brevitas.graph.target.flexml import preprocess_for_flexml_quantize from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_act_equalization @@ -269,6 +270,13 @@ def parse_type(v, default_type): help='Use unsigned act quant when possible (default: enabled)') +def generate_ref_input(args, device, dtype): + model_config = get_model_config(args.model_name) + center_crop_shape = model_config['center_crop_shape'] + img_shape = center_crop_shape + return torch.ones(1, 3, img_shape, img_shape, device=device, dtype=dtype) + + def main(): args = parser.parse_args() dtype = getattr(torch, args.dtype) @@ -474,23 +482,28 @@ def main(): # Validate the quant_model on the validation dataloader print("Starting validation:") - validate(val_loader, quant_model, stable=dtype != torch.bfloat16) + with torch.no_grad(), inference_mode(quant_model): + param = next(iter(quant_model.parameters())) + device, dtype = param.device, param.dtype + ref_input = generate_ref_input(args, device, dtype) + quant_model(ref_input) + quant_model = torch.compile(quant_model, fullgraph=True, dynamic=True) + validate(val_loader, quant_model, stable=dtype != torch.bfloat16) if args.export_onnx_qcdq or args.export_torch_qcdq: # Generate reference input tensor to drive the export process - model_config = get_model_config(args.model_name) - center_crop_shape = model_config['center_crop_shape'] - img_shape = center_crop_shape - device, dtype = next(model.parameters()).device, next(model.parameters()).dtype - ref_input = torch.ones(1, 3, img_shape, img_shape, device=device, dtype=dtype) + param = next(iter(quant_model.parameters())) + device, dtype = param.device, param.dtype + ref_input = generate_ref_input(args, device, dtype) export_name = os.path.join(args.export_dir, config) if args.export_onnx_qcdq: export_name = export_name + '.onnx' - export_onnx_qcdq(model, ref_input, export_name, opset_version=args.onnx_opset_version) + export_onnx_qcdq( + quant_model, ref_input, export_name, opset_version=args.onnx_opset_version) if args.export_torch_qcdq: export_name = export_name + '.pt' - export_torch_qcdq(model, ref_input, export_name) + export_torch_qcdq(quant_model, ref_input, export_name) if __name__ == '__main__': diff --git a/src/brevitas_examples/imagenet_classification/utils.py b/src/brevitas_examples/imagenet_classification/utils.py index d506b8a61..460e7d77f 100644 --- a/src/brevitas_examples/imagenet_classification/utils.py +++ b/src/brevitas_examples/imagenet_classification/utils.py @@ -1,5 +1,3 @@ -import csv - import torch import torchvision.datasets as datasets import torchvision.transforms as transforms diff --git a/src/brevitas_examples/stable_diffusion/main.py b/src/brevitas_examples/stable_diffusion/main.py index a1c4fef53..3db0bca9a 100644 --- a/src/brevitas_examples/stable_diffusion/main.py +++ b/src/brevitas_examples/stable_diffusion/main.py @@ -29,6 +29,7 @@ from brevitas.graph.base import ModuleToModuleByClass from brevitas.graph.calibrate import bias_correction_mode from brevitas.graph.calibrate import calibration_mode +from brevitas.graph.calibrate import inference_mode from brevitas.graph.calibrate import load_quant_model_mode from brevitas.graph.equalize import activation_equalization_mode from brevitas.graph.gptq import gptq_mode @@ -149,7 +150,6 @@ def main(args): calibration_prompts = CALIBRATION_PROMPTS if args.calibration_prompt_path is not None: calibration_prompts = load_calib_prompts(args.calibration_prompt_path) - print(args.calibration_prompt, len(calibration_prompts)) assert args.calibration_prompt <= len(calibration_prompts) , f"Only {len(calibration_prompts)} prompts are available" calibration_prompts = calibration_prompts[:args.calibration_prompt] @@ -231,8 +231,6 @@ def main(args): non_blacklist[name_to_add] = 1 else: non_blacklist[name_to_add] += 1 - print(f"Blacklisted layers: {set(blacklist)}") - print(f"Non blacklisted layers: {non_blacklist}") # Make sure there all LoRA layers are fused first, otherwise raise an error for m in pipe.unet.modules():