diff --git a/src/brevitas/graph/quantize.py b/src/brevitas/graph/quantize.py index ffeb17e62..63143c4e5 100644 --- a/src/brevitas/graph/quantize.py +++ b/src/brevitas/graph/quantize.py @@ -322,12 +322,16 @@ def quantize( return graph_model -def layerwise_quantize(model: nn.Module, compute_layer_map: dict = LAYERWISE_COMPUTE_LAYER_MAP): +def layerwise_quantize( + model: nn.Module, + compute_layer_map: dict = LAYERWISE_COMPUTE_LAYER_MAP, + name_blacklist=None): ignore_missing_keys_state = config.IGNORE_MISSING_KEYS config.IGNORE_MISSING_KEYS = True training_state = model.training model.eval() - model = layerwise_layer_handler(model, layer_map=compute_layer_map) + model = layerwise_layer_handler( + model, layer_map=compute_layer_map, name_blacklist=name_blacklist) model.train(training_state) config.IGNORE_MISSING_KEYS = ignore_missing_keys_state return model diff --git a/src/brevitas/graph/quantize_impl.py b/src/brevitas/graph/quantize_impl.py index 130f977a2..4fc8e5c66 100644 --- a/src/brevitas/graph/quantize_impl.py +++ b/src/brevitas/graph/quantize_impl.py @@ -1,5 +1,6 @@ # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause +from inspect import isclass import operator from typing import Dict, List, Optional @@ -493,31 +494,45 @@ def layer_handler( return model +def _module_class_name(module_class_or_str): + name = module_class_or_str.__module__ + '.' + module_class_or_str.__name__ if isclass( + module_class_or_str) else module_class_or_str + return name + + def find_module( - model: nn.Module, layer_map: Dict[nn.Module, Optional[Dict]], module_to_replace: List): + model: nn.Module, + layer_map: Dict[nn.Module, Optional[Dict]], + module_to_replace: List, + name_blacklist): """ Iterate through the model looking at immediate children of every module to look for supported modules. This allows us to stop the search when we meet a top-level module that is supported. Specifically, it allows to map nn.MultiheadAttetion to its quantized counterpart and not its Linear submodules. """ - if isinstance(model, tuple(layer_map.keys())): + if _module_class_name(type(model)) in layer_map.keys(): module_to_replace.append(model) else: - for module in model.children(): - find_module(module, layer_map, module_to_replace) + for name, module in model.named_children(): + if name_blacklist is not None and name in name_blacklist: + continue + find_module(module, layer_map, module_to_replace, name_blacklist) -def layerwise_layer_handler(model: nn.Module, layer_map: Dict[nn.Module, Optional[Dict]]): +def layerwise_layer_handler( + model: nn.Module, layer_map: Dict[nn.Module, Optional[Dict]], name_blacklist=None): """ Replace FP weight layers with their corresponding quantized version """ + # Normalize all module lookups to fully qualified strings + layer_map = {_module_class_name(m): v for m, v in layer_map.items()} module_to_replace = [] - find_module(model, layer_map, module_to_replace) + find_module(model, layer_map, module_to_replace, name_blacklist) rewriters = [] for module in module_to_replace: - if layer_map[type(module)] is not None: - quant_module_class, quant_module_kwargs = layer_map[type(module)] + if layer_map[_module_class_name(type(module))] is not None: + quant_module_class, quant_module_kwargs = layer_map[_module_class_name(type(module))] rewriter = ModuleToModuleByInstance(module, quant_module_class, **quant_module_kwargs) rewriters.append(rewriter) for rewriter in rewriters: diff --git a/src/brevitas/quant_tensor/__init__.py b/src/brevitas/quant_tensor/__init__.py index 05593f5a3..e017f5de3 100644 --- a/src/brevitas/quant_tensor/__init__.py +++ b/src/brevitas/quant_tensor/__init__.py @@ -14,7 +14,7 @@ from .torch_handler import QUANT_TENSOR_FN_HANDLER -IS_VALID_ATOL = 1e-1 +IS_VALID_ATOL = 2e-1 class QuantTensorBase(NamedTuple): diff --git a/src/brevitas_examples/common/generative/nn.py b/src/brevitas_examples/common/generative/nn.py new file mode 100644 index 000000000..af43900b6 --- /dev/null +++ b/src/brevitas_examples/common/generative/nn.py @@ -0,0 +1,22 @@ +from brevitas.nn import QuantConv2d +from brevitas.nn import QuantLinear + + +class LoRACompatibleQuantConv2d(QuantConv2d): + """ + A QuantConv2d layer that can be used with as a replacement for LoRACompatibleConv. + It doesn't actually support LoRA, it only matches the same forward pass. + """ + + def forward(self, hidden_states, scale: float = 1.0): + return super().forward(hidden_states) + + +class LoRACompatibleQuantLinear(QuantLinear): + """ + A QuantLinear layer that can be used with as a replacement for LoRACompatibleLinear. + It doesn't actually support LoRA, it only matches the same forward pass. + """ + + def forward(self, hidden_states, scale: float = 1.0): + return super().forward(hidden_states) diff --git a/src/brevitas_examples/llm/llm_quant/quant_blocks.py b/src/brevitas_examples/common/generative/quant_blocks.py similarity index 100% rename from src/brevitas_examples/llm/llm_quant/quant_blocks.py rename to src/brevitas_examples/common/generative/quant_blocks.py diff --git a/src/brevitas_examples/llm/llm_quant/quantize.py b/src/brevitas_examples/common/generative/quantize.py similarity index 89% rename from src/brevitas_examples/llm/llm_quant/quantize.py rename to src/brevitas_examples/common/generative/quantize.py index 647dd35a2..c0d76559d 100644 --- a/src/brevitas_examples/llm/llm_quant/quantize.py +++ b/src/brevitas_examples/common/generative/quantize.py @@ -31,16 +31,18 @@ from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerChannelFloatMSE from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerTensorFloat from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerTensorFloatMSE -from brevitas_examples.llm.llm_quant.quantizers import Fp8e4m3WeightSymmetricGroupQuant -from brevitas_examples.llm.llm_quant.quantizers import Int8ActDynamicPerGroupFloat -from brevitas_examples.llm.llm_quant.quantizers import Int8ActDynamicPerRowFloat -from brevitas_examples.llm.llm_quant.quantizers import Int8ActDynamicPerTensorFloat -from brevitas_examples.llm.llm_quant.quantizers import Int8ActPerRowFloat -from brevitas_examples.llm.llm_quant.quantizers import Int8ActPerRowFloatMSE -from brevitas_examples.llm.llm_quant.quantizers import IntWeightSymmetricGroupQuant -from brevitas_examples.llm.llm_quant.quantizers import ShiftedUint8ActPerRowFloat -from brevitas_examples.llm.llm_quant.quantizers import ShiftedUint8ActPerRowFloatMSE -from brevitas_examples.llm.llm_quant.quantizers import ShiftedUintWeightAsymmetricGroupQuant +from brevitas_examples.common.generative.nn import LoRACompatibleQuantConv2d +from brevitas_examples.common.generative.nn import LoRACompatibleQuantLinear +from brevitas_examples.common.generative.quantizers import Fp8e4m3WeightSymmetricGroupQuant +from brevitas_examples.common.generative.quantizers import Int8ActDynamicPerGroupFloat +from brevitas_examples.common.generative.quantizers import Int8ActDynamicPerRowFloat +from brevitas_examples.common.generative.quantizers import Int8ActDynamicPerTensorFloat +from brevitas_examples.common.generative.quantizers import Int8ActPerRowFloat +from brevitas_examples.common.generative.quantizers import Int8ActPerRowFloatMSE +from brevitas_examples.common.generative.quantizers import IntWeightSymmetricGroupQuant +from brevitas_examples.common.generative.quantizers import ShiftedUint8ActPerRowFloat +from brevitas_examples.common.generative.quantizers import ShiftedUint8ActPerRowFloatMSE +from brevitas_examples.common.generative.quantizers import ShiftedUintWeightAsymmetricGroupQuant WEIGHT_QUANT_MAP = { 'int': { @@ -132,8 +134,9 @@ def quantize_model( weight_group_size, quantize_weight_zero_point, weight_quant_format='int', + name_blacklist=None, input_bit_width=None, - input_quant_format=None, + input_quant_format='', input_scale_precision=None, input_scale_type=None, input_param_method=None, @@ -190,7 +193,6 @@ def quantize_model( # Modify the weight quantizer based on the arguments passed in weight_quant = weight_quant.let( **{ - 'bit_width': weight_bit_width, 'narrow_range': False, 'block_size': weight_group_size, 'quantize_zero_point': quantize_weight_zero_point}, @@ -309,7 +311,15 @@ def quantize_model( 'group_dim': 1, 'group_size': input_group_size}) quant_linear_kwargs = { - 'input_quant': linear_2d_input_quant, 'weight_quant': weight_quant, 'dtype': dtype} + 'input_quant': linear_2d_input_quant, + 'weight_quant': weight_quant, + 'weight_bit_width': weight_bit_width, + 'dtype': dtype} + quant_conv_kwargs = { + 'input_quant': input_quant, + 'weight_quant': weight_quant, + 'weight_bit_width': weight_bit_width, + 'dtype': dtype} quant_mha_kwargs = { 'in_proj_input_quant': input_quant, @@ -333,10 +343,14 @@ def quantize_model( layer_map = { nn.Linear: (qnn.QuantLinear, quant_linear_kwargs), + nn.Conv2d: (qnn.QuantConv2d, quant_conv_kwargs), + 'diffusers.models.lora.LoRACompatibleLinear': + (LoRACompatibleQuantLinear, quant_linear_kwargs), + 'diffusers.models.lora.LoRACompatibleConv': (LoRACompatibleQuantConv2d, quant_conv_kwargs), nn.MultiheadAttention: (qnn.QuantMultiheadAttention, quant_mha_kwargs)} if quantize_embedding: quant_embedding_kwargs = {'weight_quant': weight_quant, 'dtype': dtype} layer_map[nn.Embedding] = (qnn.QuantEmbedding, quant_embedding_kwargs) - layerwise_quantize(model=model, compute_layer_map=layer_map) + layerwise_quantize(model=model, compute_layer_map=layer_map, name_blacklist=name_blacklist) diff --git a/src/brevitas_examples/llm/llm_quant/quantizers.py b/src/brevitas_examples/common/generative/quantizers.py similarity index 97% rename from src/brevitas_examples/llm/llm_quant/quantizers.py rename to src/brevitas_examples/common/generative/quantizers.py index 28590a0e8..5c7e82513 100644 --- a/src/brevitas_examples/llm/llm_quant/quantizers.py +++ b/src/brevitas_examples/common/generative/quantizers.py @@ -11,8 +11,6 @@ from brevitas.core.scaling import ParameterFromStatsFromParameterScaling from brevitas.core.stats import AbsMinMax from brevitas.core.stats import NegativeMinOrZero -from brevitas.core.stats import NegativePercentileOrZero -from brevitas.core.zero_point import ParameterFromRuntimeZeroPoint from brevitas.core.zero_point import ParameterFromStatsFromParameterZeroPoint from brevitas.inject import ExtendedInjector from brevitas.inject import this diff --git a/src/brevitas_examples/common/parse_utils.py b/src/brevitas_examples/common/parse_utils.py new file mode 100644 index 000000000..0b13b69b8 --- /dev/null +++ b/src/brevitas_examples/common/parse_utils.py @@ -0,0 +1,33 @@ +""" +Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. +SPDX-License-Identifier: MIT +""" + +import argparse +import re + + +class CustomValidator(object): + + def __init__(self, pattern): + self._pattern = re.compile(pattern) + + def __call__(self, value): + if not self._pattern.match(value): + raise argparse.ArgumentTypeError( + "Argument has to match '{}'".format(self._pattern.pattern)) + return value + + +quant_format_validator = CustomValidator(r"int|e[1-8]m[1-8]") + + +def add_bool_arg(parser, name, default, help, str_true=False): + dest = name.replace('-', '_') + group = parser.add_mutually_exclusive_group(required=False) + if str_true: + group.add_argument('--' + name, dest=dest, type=str, help=help) + else: + group.add_argument('--' + name, dest=dest, action='store_true', help='Enable ' + help) + group.add_argument('--no-' + name, dest=dest, action='store_false', help='Disable ' + help) + parser.set_defaults(**{dest: default}) diff --git a/src/brevitas_examples/llm/llm_quant/sharded_mlir_group_export.py b/src/brevitas_examples/llm/llm_quant/sharded_mlir_group_export.py index a234c86d0..ef0a72880 100644 --- a/src/brevitas_examples/llm/llm_quant/sharded_mlir_group_export.py +++ b/src/brevitas_examples/llm/llm_quant/sharded_mlir_group_export.py @@ -185,7 +185,6 @@ def compile_vicuna_layer( torch.ops.aten.split.Tensor, torch.ops.aten.split_with_sizes,]), )(hidden_states, attention_mask, position_ids) - print(fx_g.graph) else: with export_context_manager(vicuna_layer, export_class): fx_g = make_fx( diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py index 5f640dcda..17c0d5fe7 100644 --- a/src/brevitas_examples/llm/main.py +++ b/src/brevitas_examples/llm/main.py @@ -12,6 +12,8 @@ from brevitas.export import export_onnx_qcdq from brevitas.export import export_torch_qcdq +from brevitas_examples.common.generative.quantize import quantize_model +from brevitas_examples.common.parse_utils import quant_format_validator from brevitas_examples.llm.llm_quant.bias_corr import apply_bias_correction from brevitas_examples.llm.llm_quant.calibrate import apply_calibration from brevitas_examples.llm.llm_quant.data import get_c4 @@ -21,25 +23,10 @@ from brevitas_examples.llm.llm_quant.gptq import apply_gptq from brevitas_examples.llm.llm_quant.ln_affine_merge import apply_layernorm_affine_merge from brevitas_examples.llm.llm_quant.prepare_for_quantize import replace_mha_with_quantizable_layers -from brevitas_examples.llm.llm_quant.quantize import quantize_model from brevitas_examples.llm.llm_quant.run_utils import CastFloat16ToFloat32 from brevitas_examples.llm.llm_quant.run_utils import get_model_impl - -class CustomValidator(object): - - def __init__(self, pattern): - self._pattern = re.compile(pattern) - - def __call__(self, value): - if not self._pattern.match(value): - raise argparse.ArgumentTypeError( - "Argument has to match '{}'".format(self._pattern.pattern)) - return value - - parser = argparse.ArgumentParser() -quant_format_validator = CustomValidator(r"int|e[1-8]m[1-8]") parser.add_argument( '--model', type=str, diff --git a/src/brevitas_examples/stable_diffusion/__init__.py b/src/brevitas_examples/stable_diffusion/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/brevitas_examples/stable_diffusion/main.py b/src/brevitas_examples/stable_diffusion/main.py new file mode 100644 index 000000000..6986c5549 --- /dev/null +++ b/src/brevitas_examples/stable_diffusion/main.py @@ -0,0 +1,222 @@ +""" +Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. +SPDX-License-Identifier: MIT +""" + +import argparse +from datetime import datetime +import json +import os +import re +import time + +from diffusers import StableDiffusionPipeline +import torch +from torch import nn + +from brevitas_examples.common.generative.quantize import quantize_model +from brevitas_examples.common.parse_utils import add_bool_arg +from brevitas_examples.common.parse_utils import quant_format_validator +from brevitas_examples.stable_diffusion.sd_quant.constants import SD_2_1_EMBEDDINGS_SHAPE +from brevitas_examples.stable_diffusion.sd_quant.export import export_torchscript_weight_group_quant +from brevitas_examples.stable_diffusion.sd_quant.utils import generate_latents +from brevitas_examples.stable_diffusion.sd_quant.utils import generate_unet_rand_inputs +from brevitas_examples.stable_diffusion.sd_quant.utils import unet_input_shape + +TEST_SEED = 123456 + + +def run_test_inference( + pipe, resolution, prompts, seeds, output_path, device, dtype, name_prefix=''): + with torch.no_grad(): + if not os.path.exists(output_path): + os.mkdir(output_path) + test_latents = generate_latents(seeds, device, dtype, unet_input_shape(resolution)) + + for name, prompt in prompts.items(): + print(f"Generating: {name}") + images = pipe([prompt] * len(seeds), latents=test_latents).images + for i, seed in enumerate(seeds): + file_path = os.path.join(output_path, f"{name_prefix}{name}_{seed}.png") + print(f"Saving to {file_path}") + images[i].save(file_path) + + +def main(args): + + # Select dtype + if args.float16: + dtype = torch.float16 + else: + dtype = torch.float32 + + # Create output dir. Move to tmp if None + ts = datetime.fromtimestamp(time.time()) + str_ts = ts.strftime("%Y%m%d_%H%M%S") + output_dir = os.path.join(args.output_path, f'{str_ts}') + os.mkdir(output_dir) + + # Dump args to json + with open(os.path.join(output_dir, 'args.json'), 'w') as fp: + json.dump(vars(args), fp) + + # Extend seeds based on batch_size + test_seeds = [TEST_SEED] + [TEST_SEED + i for i in range(1, args.batch_size)] + + # Load model from float checkpoint + print(f"Loading model from {args.model}...") + pipe = StableDiffusionPipeline.from_pretrained(args.model, torch_dtype=dtype) + print(f"Model loaded from {args.model}.") + + # Enable attention slicing + if args.attention_slicing: + pipe.enable_attention_slicing() + + # Extract list of layers to avoid + blacklist = [] + for name, _ in pipe.unet.named_modules(): + if 'time_emb' in name or 'conv_in' in name: + blacklist.append(name) + print(f"Blacklisted layers: {blacklist}") + + # Make sure there all LoRA layers are fused first, otherwise raise an error + for m in pipe.unet.modules(): + if hasattr(m, 'lora_layer') and m.lora_layer is not None: + raise RuntimeError("LoRA layers should be fused in before calling into quantization.") + + # Quantize model + if args.quantize: + + def bit_width_fn(module): + if isinstance(module, nn.Linear): + return args.linear_weight_bit_width + elif isinstance(module, nn.Conv2d): + return args.conv_weight_bit_width + else: + raise RuntimeError(f"Module {module} not supported.") + + weight_bit_width = lambda module: bit_width_fn(module) + + print("Applying model quantization...") + quantize_model( + pipe.unet, + dtype=dtype, + name_blacklist=blacklist, + weight_quant_format=args.weight_quant_format, + weight_quant_type=args.weight_quant_type, + weight_bit_width=weight_bit_width, + weight_param_method=args.weight_param_method, + weight_scale_precision=args.weight_scale_precision, + weight_quant_granularity=args.weight_quant_granularity, + weight_group_size=args.weight_group_size, + quantize_weight_zero_point=args.quantize_weight_zero_point) + print("Model quantization applied.") + + # Move model to target device + print(f"Moving model to {args.device}...") + pipe = pipe.to(args.device) + + # Perform inference + if args.prompt: + print(f"Running inference with prompt '{args.prompt}' ...") + prompts = {'manual_prompt': args.prompt} + run_test_inference( + pipe, args.resolution, prompts, test_seeds, output_dir, args.device, dtype) + + if args.export_target: + # Move to cpu and to float32 to enable CPU export + pipe.unet.to('cpu').to(torch.float32) + pipe.unet.eval() + if args.export_target == 'torchscript_weight_group_quant': + assert args.weight_quant_granularity == 'per_group', "Per-group quantization required." + assert args.weight_quant_format == 'int', "Only integer quantization supported for export." + trace_inputs = generate_unet_rand_inputs( + embedding_shape=SD_2_1_EMBEDDINGS_SHAPE, + unet_input_shape=unet_input_shape(args.resolution), + device='cpu', + dtype=torch.float32) + export_torchscript_weight_group_quant(pipe, trace_inputs, output_dir) + else: + raise ValueError(f"{args.export_target} not recognized.") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='Stable Diffusion quantization') + parser.add_argument( + '-m', + '--model', + type=str, + default='/scratch/hf_models/stable-diffusion-2-1-base', + help='Path or name of the model.') + parser.add_argument( + '-d', '--device', type=str, default='cuda:0', help='Target device for quantized model.') + parser.add_argument('-b', '--batch-size', type=int, default=4, help='Batch size.') + parser.add_argument( + '--prompt', + type=str, + default='An austronaut riding a horse on Mars.', + help='Manual prompt for testing.') + parser.add_argument( + '--resolution', + type=int, + default=512, + help='Resolution along height and width dimension. Default: 512.') + add_bool_arg( + parser, + 'output-path', + str_true=True, + default='.', + help='Path where to generate output folder.') + add_bool_arg(parser, 'quantize', default=True, help='Toggle quantization.') + add_bool_arg(parser, 'float16', default=True, help='Enable float16 execution.') + add_bool_arg(parser, 'attention-slicing', default=False, help='Enable attention slicing.') + parser.add_argument( + '--export-target', + type=str, + default='', + choices=['', 'torchscript_weight_group_quant'], + help='Target export flow.') + parser.add_argument( + '--conv-weight-bit-width', type=int, default=8, help='Weight bit width. Default: 8.') + parser.add_argument( + '--linear-weight-bit-width', type=int, default=8, help='Weight bit width. Default: 4.') + parser.add_argument( + '--weight-param-method', + type=str, + default='stats', + choices=['stats', 'mse'], + help='How scales/zero-point are determined. Default: stats.') + parser.add_argument( + '--weight-scale-precision', + type=str, + default='float_scale', + choices=['float_scale', 'po2_scale'], + help='Whether scale is a float value or a po2. Default: float_scale.') + parser.add_argument( + '--weight-quant-type', + type=str, + default='asym', + choices=['sym', 'asym'], + help='Weight quantization type. Default: asym.') + parser.add_argument( + '--weight-quant-format', + type=quant_format_validator, + default='int', + help= + 'Weight quantization type. Either int or eXmY, with X+Y==weight_bit_width-1. Default: int.') + parser.add_argument( + '--weight-quant-granularity', + type=str, + default='per_group', + choices=['per_channel', 'per_tensor', 'per_group'], + help='Granularity for scales/zero-point of weights. Default: per_group.') + parser.add_argument( + '--weight-group-size', + type=int, + default=16, + help='Group size for per_group weight quantization. Default: 16.') + add_bool_arg( + parser, 'quantize-weight-zero-point', default=True, help='Quantize weight zero-point.') + args = parser.parse_args() + print("Args: " + str(vars(args))) + main(args) diff --git a/src/brevitas_examples/stable_diffusion/sd_quant/__init__.py b/src/brevitas_examples/stable_diffusion/sd_quant/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/brevitas_examples/stable_diffusion/sd_quant/constants.py b/src/brevitas_examples/stable_diffusion/sd_quant/constants.py new file mode 100644 index 000000000..7359a06dd --- /dev/null +++ b/src/brevitas_examples/stable_diffusion/sd_quant/constants.py @@ -0,0 +1,6 @@ +""" +Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. +SPDX-License-Identifier: MIT +""" + +SD_2_1_EMBEDDINGS_SHAPE = (77, 1024) diff --git a/src/brevitas_examples/stable_diffusion/sd_quant/export.py b/src/brevitas_examples/stable_diffusion/sd_quant/export.py new file mode 100644 index 000000000..624d183b2 --- /dev/null +++ b/src/brevitas_examples/stable_diffusion/sd_quant/export.py @@ -0,0 +1,47 @@ +""" +Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. +SPDX-License-Identifier: MIT +""" + +import os + +import torch +from torch import nn +from torch._decomp import get_decompositions + +from brevitas.backport.fx.experimental.proxy_tensor import make_fx +from brevitas.export.manager import _force_requires_grad_false +from brevitas.export.manager import _JitTraceExportWrapper +from brevitas_examples.llm.llm_quant.export import brevitas_proxy_export_mode + + +class UnetExportWrapper(nn.Module): + + def __init__(self, unet): + super().__init__() + self.unet = unet + + def forward(self, *args, **kwargs): + return self.unet(*args, **kwargs, return_dict=False) + + +def export_torchscript_weight_group_quant(pipe, trace_inputs, output_dir): + with brevitas_proxy_export_mode(pipe.unet): + fx_g = make_fx( + UnetExportWrapper(pipe.unet), + decomposition_table=get_decompositions([ + torch.ops.aten.embedding_dense_backward, + torch.ops.aten.native_layer_norm_backward, + torch.ops.aten.slice_backward, + torch.ops.aten.select_backward, + torch.ops.aten.norm.ScalarOpt_dim, + torch.ops.aten.native_group_norm, + torch.ops.aten.upsample_bilinear2d.vec, + torch.ops.aten.split.Tensor, + torch.ops.aten.split_with_sizes,]), + )(*trace_inputs.values()) + _force_requires_grad_false(fx_g) + jit_g = torch.jit.trace(_JitTraceExportWrapper(fx_g), tuple(trace_inputs.values())) + output_path = os.path.join(output_dir, 'unet.ts') + print(f"Saving unet to {output_path} ...") + torch.jit.save(jit_g, output_path) diff --git a/src/brevitas_examples/stable_diffusion/sd_quant/utils.py b/src/brevitas_examples/stable_diffusion/sd_quant/utils.py new file mode 100644 index 000000000..7a88aebdf --- /dev/null +++ b/src/brevitas_examples/stable_diffusion/sd_quant/utils.py @@ -0,0 +1,49 @@ +""" +Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. +SPDX-License-Identifier: MIT +""" + +import torch + + +def unet_input_shape(resolution): + return (4, resolution // 8, resolution // 8) + + +def generate_latents(seeds, device, dtype, input_shape): + """ + Generate a concatenation of latents of a given input_shape + (batch size excluded) on a target device from one or more seeds. + """ + latents = None + if not isinstance(seeds, (list, tuple)): + seeds = [seeds] + for seed in seeds: + generator = torch.Generator(device=device) + generator = generator.manual_seed(seed) + image_latents = torch.randn((1, *input_shape), + generator=generator, + device=device, + dtype=dtype) + latents = image_latents if latents is None else torch.cat((latents, image_latents)) + return latents + + +def generate_unet_rand_inputs( + embedding_shape, + unet_input_shape, + batch_size=1, + device='cpu', + dtype=torch.float32, + with_return_dict_false=False): + sample = torch.randn(batch_size, *unet_input_shape, device=device, dtype=dtype) + unet_rand_inputs = { + 'sample': + sample, + 'timestep': + torch.tensor(1, dtype=torch.int64, device=device), + 'encoder_hidden_states': + torch.randn(batch_size, *embedding_shape, device=device, dtype=dtype)} + if with_return_dict_false: + unet_rand_inputs['return_dict'] = False + return unet_rand_inputs