Feat (examples): initial Stable Diffusion support (#715)
Signed-off-by: Alessandro Pappalardo <[email protected]>
volcacius authored Nov 10, 2023
1 parent 98213ab commit 7750ea8
Showing 16 changed files with 439 additions and 43 deletions.
8 changes: 6 additions & 2 deletions src/brevitas/graph/quantize.py
@@ -322,12 +322,16 @@ def quantize(
     return graph_model


-def layerwise_quantize(model: nn.Module, compute_layer_map: dict = LAYERWISE_COMPUTE_LAYER_MAP):
+def layerwise_quantize(
+        model: nn.Module,
+        compute_layer_map: dict = LAYERWISE_COMPUTE_LAYER_MAP,
+        name_blacklist=None):
     ignore_missing_keys_state = config.IGNORE_MISSING_KEYS
     config.IGNORE_MISSING_KEYS = True
     training_state = model.training
     model.eval()
-    model = layerwise_layer_handler(model, layer_map=compute_layer_map)
+    model = layerwise_layer_handler(
+        model, layer_map=compute_layer_map, name_blacklist=name_blacklist)
     model.train(training_state)
     config.IGNORE_MISSING_KEYS = ignore_missing_keys_state
     return model
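Note: a minimal sketch of the new name_blacklist hook; the toy model and the blacklisted child name are illustrative, not from this commit. Matching happens on immediate child names during the recursive search, so blacklisting '1' skips the second Linear of an nn.Sequential:

# Hypothetical usage of layerwise_quantize with name_blacklist.
import torch.nn as nn

from brevitas.graph.quantize import layerwise_quantize

model = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 4))
# Children of nn.Sequential are named '0' and '1'; blacklisting '1'
# leaves the second Linear in float while the first is quantized.
model = layerwise_quantize(model, name_blacklist=['1'])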
31 changes: 23 additions & 8 deletions src/brevitas/graph/quantize_impl.py
@@ -1,5 +1,6 @@
 # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
 # SPDX-License-Identifier: BSD-3-Clause
+from inspect import isclass
 import operator
 from typing import Dict, List, Optional

@@ -493,31 +494,45 @@ def layer_handler(
     return model


+def _module_class_name(module_class_or_str):
+    name = module_class_or_str.__module__ + '.' + module_class_or_str.__name__ if isclass(
+        module_class_or_str) else module_class_or_str
+    return name
+
+
 def find_module(
-        model: nn.Module, layer_map: Dict[nn.Module, Optional[Dict]], module_to_replace: List):
+        model: nn.Module,
+        layer_map: Dict[nn.Module, Optional[Dict]],
+        module_to_replace: List,
+        name_blacklist):
     """
     Iterate through the model, looking at the immediate children of every module for supported
     modules. This allows us to stop the search as soon as we meet a supported top-level module.
     Specifically, it allows us to map nn.MultiheadAttention to its quantized counterpart rather
     than its Linear submodules.
     """
-    if isinstance(model, tuple(layer_map.keys())):
+    if _module_class_name(type(model)) in layer_map.keys():
         module_to_replace.append(model)
     else:
-        for module in model.children():
-            find_module(module, layer_map, module_to_replace)
+        for name, module in model.named_children():
+            if name_blacklist is not None and name in name_blacklist:
+                continue
+            find_module(module, layer_map, module_to_replace, name_blacklist)


-def layerwise_layer_handler(model: nn.Module, layer_map: Dict[nn.Module, Optional[Dict]]):
+def layerwise_layer_handler(
+        model: nn.Module, layer_map: Dict[nn.Module, Optional[Dict]], name_blacklist=None):
     """
     Replace FP weight layers with their corresponding quantized version
     """
+    # Normalize all module lookups to fully qualified strings
+    layer_map = {_module_class_name(m): v for m, v in layer_map.items()}
     module_to_replace = []
-    find_module(model, layer_map, module_to_replace)
+    find_module(model, layer_map, module_to_replace, name_blacklist)
     rewriters = []
     for module in module_to_replace:
-        if layer_map[type(module)] is not None:
-            quant_module_class, quant_module_kwargs = layer_map[type(module)]
+        if layer_map[_module_class_name(type(module))] is not None:
+            quant_module_class, quant_module_kwargs = layer_map[_module_class_name(type(module))]
             rewriter = ModuleToModuleByInstance(module, quant_module_class, **quant_module_kwargs)
             rewriters.append(rewriter)
     for rewriter in rewriters:
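Note: with _module_class_name, layer maps can mix class objects and fully qualified class-name strings, so third-party layers (e.g. from diffusers) can be targeted without importing the package. A small sketch of the normalization (illustrative, not part of the commit):

# How layer-map keys are normalized to fully qualified strings (sketch).
import torch.nn as nn

from brevitas.graph.quantize_impl import _module_class_name

# A class object maps to its fully qualified name...
assert _module_class_name(nn.Linear) == 'torch.nn.modules.linear.Linear'
# ...while a string key passes through unchanged, so no diffusers import is needed.
assert _module_class_name('diffusers.models.lora.LoRACompatibleConv') == \
    'diffusers.models.lora.LoRACompatibleConv'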
2 changes: 1 addition & 1 deletion src/brevitas/quant_tensor/__init__.py
@@ -14,7 +14,7 @@

 from .torch_handler import QUANT_TENSOR_FN_HANDLER

-IS_VALID_ATOL = 1e-1
+IS_VALID_ATOL = 2e-1


 class QuantTensorBase(NamedTuple):
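Note: IS_VALID_ATOL bounds how far a QuantTensor's value may sit from its integer grid before the tensor is treated as invalid. A simplified sketch of the kind of check it feeds into (an assumption for illustration, not the exact Brevitas implementation):

# Simplified sketch of an integer-grid validity check (assumption, not the
# exact Brevitas implementation).
import torch

IS_VALID_ATOL = 2e-1

def looks_valid(value: torch.Tensor, scale: torch.Tensor, zero_point: torch.Tensor) -> bool:
    # A valid quantized value, rescaled to the integer domain, should sit
    # within atol of a whole number.
    pre_round = value / scale + zero_point
    return torch.allclose(pre_round, torch.round(pre_round), atol=IS_VALID_ATOL)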
22 changes: 22 additions & 0 deletions src/brevitas_examples/common/generative/nn.py
@@ -0,0 +1,22 @@
+from brevitas.nn import QuantConv2d
+from brevitas.nn import QuantLinear
+
+
+class LoRACompatibleQuantConv2d(QuantConv2d):
+    """
+    A QuantConv2d layer that can be used as a replacement for LoRACompatibleConv.
+    It doesn't actually support LoRA; it only matches the same forward pass.
+    """
+
+    def forward(self, hidden_states, scale: float = 1.0):
+        return super().forward(hidden_states)
+
+
+class LoRACompatibleQuantLinear(QuantLinear):
+    """
+    A QuantLinear layer that can be used as a replacement for LoRACompatibleLinear.
+    It doesn't actually support LoRA; it only matches the same forward pass.
+    """
+
+    def forward(self, hidden_states, scale: float = 1.0):
+        return super().forward(hidden_states)
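Note: the extra scale argument only keeps these layers call-compatible with diffusers' LoRACompatible layers; it is accepted and ignored. For instance (hypothetical shapes):

# scale is accepted for call compatibility and ignored (hypothetical shapes).
import torch

from brevitas_examples.common.generative.nn import LoRACompatibleQuantLinear

layer = LoRACompatibleQuantLinear(64, 64)
hidden_states = torch.randn(1, 64)
out = layer(hidden_states, scale=0.5)  # behaves identically to layer(hidden_states)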
src/brevitas_examples/common/generative/quantize.py
@@ -31,16 +31,18 @@
 from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerChannelFloatMSE
 from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerTensorFloat
 from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerTensorFloatMSE
-from brevitas_examples.llm.llm_quant.quantizers import Fp8e4m3WeightSymmetricGroupQuant
-from brevitas_examples.llm.llm_quant.quantizers import Int8ActDynamicPerGroupFloat
-from brevitas_examples.llm.llm_quant.quantizers import Int8ActDynamicPerRowFloat
-from brevitas_examples.llm.llm_quant.quantizers import Int8ActDynamicPerTensorFloat
-from brevitas_examples.llm.llm_quant.quantizers import Int8ActPerRowFloat
-from brevitas_examples.llm.llm_quant.quantizers import Int8ActPerRowFloatMSE
-from brevitas_examples.llm.llm_quant.quantizers import IntWeightSymmetricGroupQuant
-from brevitas_examples.llm.llm_quant.quantizers import ShiftedUint8ActPerRowFloat
-from brevitas_examples.llm.llm_quant.quantizers import ShiftedUint8ActPerRowFloatMSE
-from brevitas_examples.llm.llm_quant.quantizers import ShiftedUintWeightAsymmetricGroupQuant
+from brevitas_examples.common.generative.nn import LoRACompatibleQuantConv2d
+from brevitas_examples.common.generative.nn import LoRACompatibleQuantLinear
+from brevitas_examples.common.generative.quantizers import Fp8e4m3WeightSymmetricGroupQuant
+from brevitas_examples.common.generative.quantizers import Int8ActDynamicPerGroupFloat
+from brevitas_examples.common.generative.quantizers import Int8ActDynamicPerRowFloat
+from brevitas_examples.common.generative.quantizers import Int8ActDynamicPerTensorFloat
+from brevitas_examples.common.generative.quantizers import Int8ActPerRowFloat
+from brevitas_examples.common.generative.quantizers import Int8ActPerRowFloatMSE
+from brevitas_examples.common.generative.quantizers import IntWeightSymmetricGroupQuant
+from brevitas_examples.common.generative.quantizers import ShiftedUint8ActPerRowFloat
+from brevitas_examples.common.generative.quantizers import ShiftedUint8ActPerRowFloatMSE
+from brevitas_examples.common.generative.quantizers import ShiftedUintWeightAsymmetricGroupQuant

 WEIGHT_QUANT_MAP = {
     'int': {
@@ -132,8 +134,9 @@ def quantize_model(
         weight_group_size,
         quantize_weight_zero_point,
         weight_quant_format='int',
+        name_blacklist=None,
         input_bit_width=None,
-        input_quant_format=None,
+        input_quant_format='',
         input_scale_precision=None,
         input_scale_type=None,
         input_param_method=None,
@@ -190,7 +193,6 @@ def quantize_model(
     # Modify the weight quantizer based on the arguments passed in
     weight_quant = weight_quant.let(
         **{
-            'bit_width': weight_bit_width,
             'narrow_range': False,
             'block_size': weight_group_size,
             'quantize_zero_point': quantize_weight_zero_point},
@@ -309,7 +311,15 @@ def quantize_model(
                 'group_dim': 1, 'group_size': input_group_size})

     quant_linear_kwargs = {
-        'input_quant': linear_2d_input_quant, 'weight_quant': weight_quant, 'dtype': dtype}
+        'input_quant': linear_2d_input_quant,
+        'weight_quant': weight_quant,
+        'weight_bit_width': weight_bit_width,
+        'dtype': dtype}
+    quant_conv_kwargs = {
+        'input_quant': input_quant,
+        'weight_quant': weight_quant,
+        'weight_bit_width': weight_bit_width,
+        'dtype': dtype}

     quant_mha_kwargs = {
         'in_proj_input_quant': input_quant,
@@ -333,10 +343,14 @@ def quantize_model(

     layer_map = {
         nn.Linear: (qnn.QuantLinear, quant_linear_kwargs),
+        nn.Conv2d: (qnn.QuantConv2d, quant_conv_kwargs),
+        'diffusers.models.lora.LoRACompatibleLinear':
+            (LoRACompatibleQuantLinear, quant_linear_kwargs),
+        'diffusers.models.lora.LoRACompatibleConv': (LoRACompatibleQuantConv2d, quant_conv_kwargs),
         nn.MultiheadAttention: (qnn.QuantMultiheadAttention, quant_mha_kwargs)}

     if quantize_embedding:
         quant_embedding_kwargs = {'weight_quant': weight_quant, 'dtype': dtype}
         layer_map[nn.Embedding] = (qnn.QuantEmbedding, quant_embedding_kwargs)

-    layerwise_quantize(model=model, compute_layer_map=layer_map)
+    layerwise_quantize(model=model, compute_layer_map=layer_map, name_blacklist=name_blacklist)
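Note: to tie the pieces together, a hedged sketch of driving layerwise_quantize over a diffusers UNet with string-keyed entries and a name blacklist. The checkpoint, quantizer kwargs, and blacklisted name are illustrative assumptions; a real run would go through quantize_model with its full argument set:

# Illustrative sketch; checkpoint, kwargs and blacklisted name are assumptions.
import torch.nn as nn
from diffusers import UNet2DConditionModel

import brevitas.nn as qnn
from brevitas.graph.quantize import layerwise_quantize
from brevitas_examples.common.generative.nn import LoRACompatibleQuantConv2d
from brevitas_examples.common.generative.nn import LoRACompatibleQuantLinear

unet = UNet2DConditionModel.from_pretrained(
    'runwayml/stable-diffusion-v1-5', subfolder='unet')
# String keys target diffusers layers without baking the classes into Brevitas.
layer_map = {
    nn.Linear: (qnn.QuantLinear, {'weight_bit_width': 8}),
    'diffusers.models.lora.LoRACompatibleLinear':
        (LoRACompatibleQuantLinear, {'weight_bit_width': 8}),
    'diffusers.models.lora.LoRACompatibleConv':
        (LoRACompatibleQuantConv2d, {'weight_bit_width': 8})}
# 'conv_out' is skipped by name, keeping the UNet's output conv in float.
unet = layerwise_quantize(
    model=unet, compute_layer_map=layer_map, name_blacklist=['conv_out'])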
@@ -11,8 +11,6 @@
 from brevitas.core.scaling import ParameterFromStatsFromParameterScaling
 from brevitas.core.stats import AbsMinMax
 from brevitas.core.stats import NegativeMinOrZero
-from brevitas.core.stats import NegativePercentileOrZero
-from brevitas.core.zero_point import ParameterFromRuntimeZeroPoint
 from brevitas.core.zero_point import ParameterFromStatsFromParameterZeroPoint
 from brevitas.inject import ExtendedInjector
 from brevitas.inject import this
33 changes: 33 additions & 0 deletions src/brevitas_examples/common/parse_utils.py
@@ -0,0 +1,33 @@
+"""
+Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
+SPDX-License-Identifier: MIT
+"""
+
+import argparse
+import re
+
+
+class CustomValidator(object):
+
+    def __init__(self, pattern):
+        self._pattern = re.compile(pattern)
+
+    def __call__(self, value):
+        if not self._pattern.match(value):
+            raise argparse.ArgumentTypeError(
+                "Argument has to match '{}'".format(self._pattern.pattern))
+        return value
+
+
+quant_format_validator = CustomValidator(r"int|e[1-8]m[1-8]")
+
+
+def add_bool_arg(parser, name, default, help, str_true=False):
+    dest = name.replace('-', '_')
+    group = parser.add_mutually_exclusive_group(required=False)
+    if str_true:
+        group.add_argument('--' + name, dest=dest, type=str, help=help)
+    else:
+        group.add_argument('--' + name, dest=dest, action='store_true', help='Enable ' + help)
+    group.add_argument('--no-' + name, dest=dest, action='store_false', help='Disable ' + help)
+    parser.set_defaults(**{dest: default})
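Note: the shared validator and boolean-flag helper plug into argparse as below (the flag names are illustrative):

# Sketch of quant_format_validator and add_bool_arg with argparse
# (flag names are illustrative).
import argparse

from brevitas_examples.common.parse_utils import add_bool_arg
from brevitas_examples.common.parse_utils import quant_format_validator

parser = argparse.ArgumentParser()
# Accepts 'int' or FP formats such as 'e4m3'/'e5m2'; anything else errors out.
parser.add_argument(
    '--weight-quant-format', type=quant_format_validator, default='int')
# Generates paired --quantize-weight-zero-point / --no-quantize-weight-zero-point flags.
add_bool_arg(
    parser, 'quantize-weight-zero-point', default=False, help='weight zero-point quantization.')

args = parser.parse_args(['--weight-quant-format', 'e4m3', '--quantize-weight-zero-point'])
assert args.weight_quant_format == 'e4m3'
assert args.quantize_weight_zero_point is True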
@@ -185,7 +185,6 @@ def compile_vicuna_layer(
                 torch.ops.aten.split.Tensor,
                 torch.ops.aten.split_with_sizes,]),
         )(hidden_states, attention_mask, position_ids)
-        print(fx_g.graph)
     else:
         with export_context_manager(vicuna_layer, export_class):
             fx_g = make_fx(
17 changes: 2 additions & 15 deletions src/brevitas_examples/llm/main.py
@@ -12,6 +12,8 @@

 from brevitas.export import export_onnx_qcdq
 from brevitas.export import export_torch_qcdq
+from brevitas_examples.common.generative.quantize import quantize_model
+from brevitas_examples.common.parse_utils import quant_format_validator
 from brevitas_examples.llm.llm_quant.bias_corr import apply_bias_correction
 from brevitas_examples.llm.llm_quant.calibrate import apply_calibration
 from brevitas_examples.llm.llm_quant.data import get_c4
@@ -21,25 +23,10 @@
 from brevitas_examples.llm.llm_quant.gptq import apply_gptq
 from brevitas_examples.llm.llm_quant.ln_affine_merge import apply_layernorm_affine_merge
 from brevitas_examples.llm.llm_quant.prepare_for_quantize import replace_mha_with_quantizable_layers
-from brevitas_examples.llm.llm_quant.quantize import quantize_model
 from brevitas_examples.llm.llm_quant.run_utils import CastFloat16ToFloat32
 from brevitas_examples.llm.llm_quant.run_utils import get_model_impl


-class CustomValidator(object):
-
-    def __init__(self, pattern):
-        self._pattern = re.compile(pattern)
-
-    def __call__(self, value):
-        if not self._pattern.match(value):
-            raise argparse.ArgumentTypeError(
-                "Argument has to match '{}'".format(self._pattern.pattern))
-        return value
-
-
 parser = argparse.ArgumentParser()
-quant_format_validator = CustomValidator(r"int|e[1-8]m[1-8]")
 parser.add_argument(
     '--model',
     type=str,