Skip to content

Commit

Permalink
Compile
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Sep 7, 2024
1 parent b714943 commit eb9b9b4
Show file tree
Hide file tree
Showing 9 changed files with 99 additions and 43 deletions.
8 changes: 6 additions & 2 deletions src/brevitas/export/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,11 +166,15 @@ def _trace_fn_dispatcher(cls, fn, input, *args, **kwargs):
@classmethod
def handler_from_module(cls, module: Module, no_inheritance=False):
for handler in cls.handlers:
if not isinstance(handler.handled_layer, tuple):
handled_classes = (handler.handled_layer,)
else:
handled_classes = handler.handled_layer
if no_inheritance:
if type(module) == handler.handled_layer:
if type(module) in handled_classes:
return handler
else:
if isinstance(module, handler.handled_layer):
if any([isinstance(module, handler) for handler in handled_classes]):
return handler
return None

Expand Down
10 changes: 9 additions & 1 deletion src/brevitas/nn/mixin/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,15 @@
from typing import Optional, Tuple, Union
from warnings import warn

import packaging.version
import torch
from torch import nn
from torch import Tensor
import torch.jit
from torch.nn.utils.rnn import PackedSequence

from brevitas import config
from brevitas import torch_version
from brevitas.common import ExportMixin
from brevitas.inject import ExtendedInjector
from brevitas.inject import Injector
Expand All @@ -26,6 +29,11 @@

from .utils import filter_kwargs

if torch_version < packaging.version.parse('2.0'):
is_dynamo_compiling = lambda: False
else:
is_dynamo_compiling = torch._dynamo.is_compiling


class QuantProxyMixin(object):
__metaclass__ = ABCMeta
Expand Down Expand Up @@ -85,7 +93,7 @@ def unpack_input(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTe
qt_class = self.get_quant_tensor_class(inp)
if qt_class is not None:
inp = qt_class(*inp)
if not torch._C._get_tracing_state():
if not torch._C._get_tracing_state() and not is_dynamo_compiling():
if isinstance(inp, QuantTensor):
inp = inp.set(value=inp.value.rename(None))
else:
Expand Down
56 changes: 39 additions & 17 deletions src/brevitas/proxy/parameter_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,19 @@
from abc import ABC
from abc import ABCMeta
from abc import abstractmethod
from typing import Any, Optional, Tuple, Union
from typing import Any, List, Optional, Tuple, Union
from warnings import warn

import packaging.version
import torch

from brevitas import torch_version

if torch_version < packaging.version.parse('2.0'):
is_dynamo_compiling = lambda: False
else:
is_dynamo_compiling = torch._dynamo.is_compiling

from torch import Tensor
import torch.nn as nn
from typing_extensions import Protocol
Expand All @@ -16,6 +25,7 @@
from brevitas import config
from brevitas.function import max_int
from brevitas.inject import BaseInjector as Injector
from brevitas.quant_tensor import _unpack_quant_tensor
from brevitas.quant_tensor import IntQuantTensor
from brevitas.quant_tensor import QuantTensor
from brevitas.utils.quant_utils import _CachedIO
Expand Down Expand Up @@ -122,15 +132,25 @@ def forward(self, x: torch.Tensor) -> Union[Tensor, QuantTensor]:
# - quantization flow
if self.export_mode:
out = self.export_handler(x)
out = self.create_quant_tensor(out)
if is_dynamo_compiling():
out = out[0]
else:
out = self.create_quant_tensor(out)
elif self._cached_weight is not None and not self.cache_inference_quant_weight_metadata_only:
out = self._cached_weight.quant_tensor
if is_dynamo_compiling():
out = self._cached_weight.value
else:
out = self._cached_weight.quant_tensor
else:
out = self.tensor_quant(x)
out = self.create_quant_tensor(out)
if not self.training and self.cache_inference_quant_weight and self._cached_weight is None:
self._cached_weight = self.cache_class(
out.detach(), metadata_only=self.cache_inference_quant_weight_metadata_only)
if is_dynamo_compiling():
out = out[0]
else:
out = self.create_quant_tensor(out)
if not self.training and self.cache_inference_quant_weight and self._cached_weight is None:
self._cached_weight = self.cache_class(
out.detach(),
metadata_only=self.cache_inference_quant_weight_metadata_only)
else: # quantization disabled
out = self.apply_input_view(x)
return out
Expand All @@ -151,9 +171,10 @@ def tracked_parameter_list(self):

def get_cached(self, attr):
if self._cached_bias is None:
warn(
"No quant bias cache found, set cache_inference_quant_bias=True and run an "
"inference pass first")
if not is_dynamo_compiling():
warn(
"No quant bias cache found, set cache_inference_quant_bias=True and run an "
"inference pass first")
return None
if self.training:
warn("Cached quant bias scale is being used in training mode.")
Expand Down Expand Up @@ -268,7 +289,7 @@ class BiasQuantProxyFromInjector(BiasQuantProxyFromInjectorBase):
def scale(self):
if not self.is_quant_enabled:
return None
if self.requires_input_scale and self.is_quant_enabled and self.is_quant_enabled:
if self.requires_input_scale and self.is_quant_enabled:
cache = self.get_cached('scale')
return cache
zhs = self._zero_hw_sentinel()
Expand Down Expand Up @@ -335,12 +356,13 @@ def forward(
out, out_scale, out_zp, out_bit_width = impl(x, input_scale)
else:
out, out_scale, out_zp, out_bit_width = impl(x)
out = IntQuantTensor(
out, out_scale, out_zp, out_bit_width, self.is_signed, self.training)
if not self.training and self.cache_inference_quant_bias:
cached_bias = _CachedIO(
out.detach(), metadata_only=self.cache_inference_quant_bias_metadata_only)
self._cached_bias = cached_bias
if not is_dynamo_compiling():
out = IntQuantTensor(
out, out_scale, out_zp, out_bit_width, self.is_signed, self.training)
if not self.training and self.cache_inference_quant_bias:
cached_bias = _CachedIO(
out.detach(), metadata_only=self.cache_inference_quant_bias_metadata_only)
self._cached_bias = cached_bias
else:
out = x
return out
2 changes: 1 addition & 1 deletion src/brevitas/proxy/quant_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def add_tracked_module(self, module: nn.Module) -> None:
raise RuntimeError("Trying to add None as a parent module.")

def apply_input_view(self, x):
return self.quant_injector.input_view_impl(x)
return self.tensor_quant.int_quant.input_view_impl(x)

def _load_from_state_dict(
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
Expand Down
30 changes: 22 additions & 8 deletions src/brevitas/proxy/runtime_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,15 @@
from abc import abstractmethod
from typing import Any, Optional, Tuple, Union

import packaging.version
import torch

from brevitas import torch_version

if torch_version < packaging.version.parse('2.0'):
is_dynamo_compiling = lambda: False
else:
is_dynamo_compiling = torch._dynamo.is_compiling
from torch import nn
from torch import Tensor
from torch.nn import Identity
Expand Down Expand Up @@ -115,6 +123,9 @@ def retrieve_attribute(self, attribute, force_eval):
elif self._cached_act is None:
return None

def apply_input_view(self, x):
return self.fused_activation_quant_proxy.tensor_quant.int_quant.input_view_impl(x)

@property
def is_quant_enabled(self):
return self._is_quant_enabled and not self.disable_quant
Expand Down Expand Up @@ -176,15 +187,18 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]:
# If y is an empty QuantTensor, we need to check if this is a passthrough proxy,
# otherwise return a simple Tensor

# If the second value (i.e., scale) is None, then quant is disabled
if isinstance(y, tuple) and y[1] is not None:
out = self.create_quant_tensor(y)
elif self.is_passthrough_act and isinstance(x, QuantTensor):
# preserve quant_metadata
y = y[0]
out = self.create_quant_tensor(y, x=x)
else:
if is_dynamo_compiling():
out = y[0]
else:
# If the second value (i.e., scale) is None, then quant is disabled
if y[1] is not None:
out = self.create_quant_tensor(y)
elif self.is_passthrough_act and isinstance(x, QuantTensor):
# preserve scale/zp/bit/sign even without output quant
y = y[0]
out = self.create_quant_tensor(y, x=x)
else:
out = y[0]

if not self.training and self.cache_inference_quant_act and isinstance(out, QuantTensor):
cached_out = self.cache_class(out.detach(), self.cache_quant_io_metadata_only)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from brevitas.graph.target.flexml import quantize_flexml
from brevitas.inject import value
import brevitas.nn as qnn
from brevitas.quant.experimental.float import Fp8e4m3Act
from brevitas.quant.experimental.float import Fp8e4m3ActPerTensorFloat
from brevitas.quant.experimental.float import Fp8e4m3ActPerTensorFloatMSE
from brevitas.quant.experimental.float import Fp8e4m3WeightPerChannelFloat
Expand Down
29 changes: 21 additions & 8 deletions src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from brevitas.export import export_onnx_qcdq
from brevitas.export import export_torch_qcdq
from brevitas.export.inference.manager import inference_mode
from brevitas.graph.quantize import preprocess_for_quantize
from brevitas.graph.target.flexml import preprocess_for_flexml_quantize
from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_act_equalization
Expand Down Expand Up @@ -269,6 +270,13 @@ def parse_type(v, default_type):
help='Use unsigned act quant when possible (default: enabled)')


def generate_ref_input(args, device, dtype):
model_config = get_model_config(args.model_name)
center_crop_shape = model_config['center_crop_shape']
img_shape = center_crop_shape
return torch.ones(1, 3, img_shape, img_shape, device=device, dtype=dtype)


def main():
args = parser.parse_args()
dtype = getattr(torch, args.dtype)
Expand Down Expand Up @@ -474,23 +482,28 @@ def main():

# Validate the quant_model on the validation dataloader
print("Starting validation:")
validate(val_loader, quant_model, stable=dtype != torch.bfloat16)
with torch.no_grad(), inference_mode(quant_model):
param = next(iter(quant_model.parameters()))
device, dtype = param.device, param.dtype
ref_input = generate_ref_input(args, device, dtype)
quant_model(ref_input)
quant_model = torch.compile(quant_model, fullgraph=True, dynamic=True)
validate(val_loader, quant_model, stable=dtype != torch.bfloat16)

if args.export_onnx_qcdq or args.export_torch_qcdq:
# Generate reference input tensor to drive the export process
model_config = get_model_config(args.model_name)
center_crop_shape = model_config['center_crop_shape']
img_shape = center_crop_shape
device, dtype = next(model.parameters()).device, next(model.parameters()).dtype
ref_input = torch.ones(1, 3, img_shape, img_shape, device=device, dtype=dtype)
param = next(iter(quant_model.parameters()))
device, dtype = param.device, param.dtype
ref_input = generate_ref_input(args, device, dtype)

export_name = os.path.join(args.export_dir, config)
if args.export_onnx_qcdq:
export_name = export_name + '.onnx'
export_onnx_qcdq(model, ref_input, export_name, opset_version=args.onnx_opset_version)
export_onnx_qcdq(
quant_model, ref_input, export_name, opset_version=args.onnx_opset_version)
if args.export_torch_qcdq:
export_name = export_name + '.pt'
export_torch_qcdq(model, ref_input, export_name)
export_torch_qcdq(quant_model, ref_input, export_name)


if __name__ == '__main__':
Expand Down
2 changes: 0 additions & 2 deletions src/brevitas_examples/imagenet_classification/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import csv

import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
Expand Down
4 changes: 1 addition & 3 deletions src/brevitas_examples/stable_diffusion/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from brevitas.graph.base import ModuleToModuleByClass
from brevitas.graph.calibrate import bias_correction_mode
from brevitas.graph.calibrate import calibration_mode
from brevitas.graph.calibrate import inference_mode
from brevitas.graph.calibrate import load_quant_model_mode
from brevitas.graph.equalize import activation_equalization_mode
from brevitas.graph.gptq import gptq_mode
Expand Down Expand Up @@ -149,7 +150,6 @@ def main(args):
calibration_prompts = CALIBRATION_PROMPTS
if args.calibration_prompt_path is not None:
calibration_prompts = load_calib_prompts(args.calibration_prompt_path)
print(args.calibration_prompt, len(calibration_prompts))
assert args.calibration_prompt <= len(calibration_prompts) , f"Only {len(calibration_prompts)} prompts are available"
calibration_prompts = calibration_prompts[:args.calibration_prompt]

Expand Down Expand Up @@ -231,8 +231,6 @@ def main(args):
non_blacklist[name_to_add] = 1
else:
non_blacklist[name_to_add] += 1
print(f"Blacklisted layers: {set(blacklist)}")
print(f"Non blacklisted layers: {non_blacklist}")

# Make sure there all LoRA layers are fused first, otherwise raise an error
for m in pipe.unet.modules():
Expand Down

0 comments on commit eb9b9b4

Please sign in to comment.