From 301453713e57d708ce830e00a9fb0076b60087ca Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Tue, 9 Apr 2024 13:28:56 +0200
Subject: [PATCH] Feat (examples): add support for Stable Diffusion XL (#909)

---
 src/brevitas/graph/equalize.py              |   4 +-
 src/brevitas/nn/equalized_layer.py          |   2 +-
 .../stable_diffusion/README.md              | 141 +++++++++++
 .../stable_diffusion/main.py                | 236 +++++++++++++++---
 .../stable_diffusion/sd_quant/constants.py  |   1 +
 .../stable_diffusion/sd_quant/export.py     |  24 +-
 .../stable_diffusion/sd_quant/utils.py      |  36 +++
 7 files changed, 379 insertions(+), 65 deletions(-)
 create mode 100644 src/brevitas_examples/stable_diffusion/README.md

diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py
index e6538421e..fa63bf80d 100644
--- a/src/brevitas/graph/equalize.py
+++ b/src/brevitas/graph/equalize.py
@@ -21,6 +21,7 @@ from brevitas.graph.utils import get_module
 from brevitas.graph.utils import get_node
 from brevitas.nn.equalized_layer import EqualizedModule
+from brevitas.nn.equalized_layer import INPUT_NAMES
 from brevitas.nn.quant_scale_bias import ScaleBias
 from brevitas.utils.torch_utils import KwargsForwardHook
@@ -970,8 +971,7 @@ def forward_stats_hook(self, module, *args, name, batch_dim=0, use_inp=True, **k
                 self.float_act_map[name] = None
                 return
 
-        possible_input_kwargs = ['input', 'inp', 'query']
-        input_kwarg = [x for x in kwargs.keys() if x in possible_input_kwargs][0]
+        input_kwarg = [x for x in kwargs.keys() if x in INPUT_NAMES][0]
         if use_inp:
             x = kwargs[input_kwarg]
         elif not use_inp:
diff --git a/src/brevitas/nn/equalized_layer.py b/src/brevitas/nn/equalized_layer.py
index 35f636604..7093c8c17 100644
--- a/src/brevitas/nn/equalized_layer.py
+++ b/src/brevitas/nn/equalized_layer.py
@@ -4,7 +4,7 @@
 
 from brevitas.nn.quant_mha import QuantMultiheadAttention
 
-INPUT_NAMES = ['input', 'inp', 'query', 'x']
+INPUT_NAMES = ['input', 'inp', 'query', 'x', 'hidden_states']
 
 
 class EqualizedModule(torch.nn.Module):
diff --git a/src/brevitas_examples/stable_diffusion/README.md b/src/brevitas_examples/stable_diffusion/README.md
new file mode 100644
index 000000000..8c0d1dae8
--- /dev/null
+++ b/src/brevitas_examples/stable_diffusion/README.md
@@ -0,0 +1,141 @@
+# Stable Diffusion Quantization
+
+This example currently supports Stable Diffusion 2.1 and Stable Diffusion XL.
+
+The following PTQ techniques are currently supported:
+- Activation Equalization (e.g., SmoothQuant), layerwise (with the addition of Mul ops)
+- Activation Calibration, in the case of static activation quantization
+- GPTQ
+- Bias Correction
+
+These techniques can be applied to both integer and floating point quantization.
+
+We support ONNX integer export, and we plan to release export for floating point quantization (e.g., FP8) soon.
+To export the model in fp16, enable `export-cuda-float16`. This will perform the tracing necessary for export on the GPU, leaving the model in fp16.
+If the flag is not enabled, the model will be moved to CPU and cast to float32 before export, because of missing fp16 CPU kernels.
+
+NB: when exporting Stable Diffusion XL, make sure to enable the `is-sd-xl` flag. The flag is not needed when export is not performed.
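+
+As a quick reference, the sketch below shows roughly how `main.py` applies Brevitas quantization to the UNet. It is a simplified, weight-only configuration; the model name and argument values are illustrative, and the full set of options is documented under Run below.
+
+```python
+import torch
+from diffusers import DiffusionPipeline
+
+from brevitas_examples.common.generative.quantize import quantize_model
+
+# Illustrative 8-bit weight-only setup; main.py exposes all of these as CLI flags.
+pipe = DiffusionPipeline.from_pretrained(
+    'stabilityai/stable-diffusion-2-1', torch_dtype=torch.float16)
+pipe = pipe.to('cuda:0')
+
+quantize_model(
+    pipe.unet,
+    dtype=torch.float16,
+    device='cuda:0',
+    weight_bit_width=8,
+    weight_quant_format='int',
+    weight_quant_type='asym',
+    weight_param_method='stats',
+    weight_scale_precision='float_scale',
+    weight_quant_granularity='per_channel',
+    weight_group_size=16,
+    quantize_weight_zero_point=True)
+```
+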
+ + +## Run +usage: main.py [-h] [-m MODEL] [-d DEVICE] [-b BATCH_SIZE] [--prompt PROMPT] + [--resolution RESOLUTION] + [--output-path OUTPUT_PATH | --no-output-path] + [--quantize | --no-quantize] + [--activation-equalization | --no-activation-equalization] + [--gptq | --no-gptq] [--float16 | --no-float16] + [--attention-slicing | --no-attention-slicing] + [--is-sd-xl | --no-is-sd-xl] [--export-target {,onnx}] + [--export-weight-q-node | --no-export-weight-q-node] + [--conv-weight-bit-width CONV_WEIGHT_BIT_WIDTH] + [--linear-weight-bit-width LINEAR_WEIGHT_BIT_WIDTH] + [--conv-input-bit-width CONV_INPUT_BIT_WIDTH] + [--linear-input-bit-width LINEAR_INPUT_BIT_WIDTH] + [--weight-param-method {stats,mse}] + [--input-param-method {stats,mse}] + [--weight-scale-precision {float_scale,po2_scale}] + [--input-scale-precision {float_scale,po2_scale}] + [--weight-quant-type {sym,asym}] + [--input-quant-type {sym,asym}] + [--weight-quant-format WEIGHT_QUANT_FORMAT] + [--input-quant-format INPUT_QUANT_FORMAT] + [--weight-quant-granularity {per_channel,per_tensor,per_group}] + [--input-quant-granularity {per_tensor}] + [--input-scale-type {static,dynamic}] + [--weight-group-size WEIGHT_GROUP_SIZE] + [--quantize-weight-zero-point | --no-quantize-weight-zero-point] + [--export-cuda-float16 | --no-export-cuda-float16] + +Stable Diffusion quantization + +options: + -h, --help show this help message and exit + -m MODEL, --model MODEL + Path or name of the model. + -d DEVICE, --device DEVICE + Target device for quantized model. + -b BATCH_SIZE, --batch-size BATCH_SIZE + Batch size. Default: 4 + --prompt PROMPT Manual prompt for testing. Default: An austronaut + riding a horse on Mars. + --resolution RESOLUTION + Resolution along height and width dimension. Default: + 512. + --output-path OUTPUT_PATH + Path where to generate output folder. + --no-output-path Disable Path where to generate output folder. + --quantize Enable Toggle quantization. Default: Enabled + --no-quantize Disable Toggle quantization. Default: Enabled + --activation-equalization + Enable Toggle Activation Equalization. Default: + Disabled + --no-activation-equalization + Disable Toggle Activation Equalization. Default: + Disabled + --gptq Enable Toggle gptq. Default: Disabled + --no-gptq Disable Toggle gptq. Default: Disabled + --float16 Enable Enable float16 execution. Default: Enabled + --no-float16 Disable Enable float16 execution. Default: Enabled + --attention-slicing Enable Enable attention slicing. Default: Disabled + --no-attention-slicing + Disable Enable attention slicing. Default: Disabled + --is-sd-xl Enable Enable this flag to correctly export SDXL. + Default: Disabled + --no-is-sd-xl Disable Enable this flag to correctly export SDXL. + Default: Disabled + --export-target {,onnx} + Target export flow. + --export-weight-q-node + Enable Enable export of floating point weights + QDQ + rather than integer weights + DQ. Default: Disabled + --no-export-weight-q-node + Disable Enable export of floating point weights + QDQ + rather than integer weights + DQ. Default: Disabled + --conv-weight-bit-width CONV_WEIGHT_BIT_WIDTH + Weight bit width. Default: 8. + --linear-weight-bit-width LINEAR_WEIGHT_BIT_WIDTH + Weight bit width. Default: 8. + --conv-input-bit-width CONV_INPUT_BIT_WIDTH + Input bit width. Default: None (not quantized) + --linear-input-bit-width LINEAR_INPUT_BIT_WIDTH + Input bit width. Default: None (not quantized). + --weight-param-method {stats,mse} + How scales/zero-point are determined. Default: stats. 
+ --input-param-method {stats,mse} + How scales/zero-point are determined. Default: stats. + --weight-scale-precision {float_scale,po2_scale} + Whether scale is a float value or a po2. Default: + float_scale. + --input-scale-precision {float_scale,po2_scale} + Whether scale is a float value or a po2. Default: + float_scale. + --weight-quant-type {sym,asym} + Weight quantization type. Default: asym. + --input-quant-type {sym,asym} + Input quantization type. Default: asym. + --weight-quant-format WEIGHT_QUANT_FORMAT + Weight quantization type. Either int or eXmY, with + X+Y==weight_bit_width-1. Default: int. + --input-quant-format INPUT_QUANT_FORMAT + Weight quantization type. Either int or eXmY, with + X+Y==weight_bit_width-1. Default: int. + --weight-quant-granularity {per_channel,per_tensor,per_group} + Granularity for scales/zero-point of weights. Default: + per_channel. + --input-quant-granularity {per_tensor} + Granularity for scales/zero-point of inputs. Default: + per_tensor. + --input-scale-type {static,dynamic} + Whether to do static or dynamic input quantization. + Default: static. + --weight-group-size WEIGHT_GROUP_SIZE + Group size for per_group weight quantization. Default: + 16. + --quantize-weight-zero-point + Enable Quantize weight zero-point. Default: Enabled + --no-quantize-weight-zero-point + Disable Quantize weight zero-point. Default: Enabled + --export-cuda-float16 + Enable Export FP16 on CUDA. Default: Disabled + --no-export-cuda-float16 + Disable Export FP16 on CUDA. Default: Disabled diff --git a/src/brevitas_examples/stable_diffusion/main.py b/src/brevitas_examples/stable_diffusion/main.py index 69677b0f6..3fa5e9ca1 100644 --- a/src/brevitas_examples/stable_diffusion/main.py +++ b/src/brevitas_examples/stable_diffusion/main.py @@ -7,29 +7,40 @@ from datetime import datetime import json import os -import re import time from dependencies import value -from diffusers import StableDiffusionPipeline +from diffusers import DiffusionPipeline import torch from torch import nn +from tqdm import tqdm from brevitas.export.onnx.standard.qcdq.manager import StdQCDQONNXManager from brevitas.export.torch.qcdq.manager import TorchQCDQManager +from brevitas.graph.calibrate import bias_correction_mode +from brevitas.graph.calibrate import calibration_mode +from brevitas.graph.equalize import activation_equalization_mode +from brevitas.graph.gptq import gptq_mode +from brevitas.nn.equalized_layer import EqualizedModule +from brevitas.utils.torch_utils import KwargsForwardHook from brevitas_examples.common.generative.quantize import quantize_model from brevitas_examples.common.parse_utils import add_bool_arg from brevitas_examples.common.parse_utils import quant_format_validator from brevitas_examples.llm.llm_quant.export import BlockQuantProxyLevelManager from brevitas_examples.stable_diffusion.sd_quant.constants import SD_2_1_EMBEDDINGS_SHAPE +from brevitas_examples.stable_diffusion.sd_quant.constants import SD_XL_EMBEDDINGS_SHAPE from brevitas_examples.stable_diffusion.sd_quant.export import export_onnx -from brevitas_examples.stable_diffusion.sd_quant.export import export_torchscript from brevitas_examples.stable_diffusion.sd_quant.utils import generate_latents -from brevitas_examples.stable_diffusion.sd_quant.utils import generate_unet_rand_inputs +from brevitas_examples.stable_diffusion.sd_quant.utils import generate_unet_21_rand_inputs +from brevitas_examples.stable_diffusion.sd_quant.utils import generate_unet_xl_rand_inputs from 
brevitas_examples.stable_diffusion.sd_quant.utils import unet_input_shape TEST_SEED = 123456 +VALIDATION_PROMPTS = { + 'validation_prompt_0': 'A cat playing with a ball', + 'validation_prompt_1': 'A dog running on the beach'} + def run_test_inference( pipe, resolution, prompts, seeds, output_path, device, dtype, name_prefix=''): @@ -47,8 +58,21 @@ def run_test_inference( images[i].save(file_path) +def run_val_inference(pipe, resolution, prompts, seeds, output_path, device, dtype, name_prefix=''): + with torch.no_grad(): + test_latents = generate_latents(seeds, device, dtype, unet_input_shape(resolution)) + + for name, prompt in prompts.items(): + print(f"Generating: {name}") + # We don't want to generate any image, so we return only the latent encoding pre VAE + pipe([prompt] * len(seeds), latents=test_latents, output_type='latent') + + def main(args): + if args.export_target: + assert args.weight_quant_format == 'int', "Currently only integer quantization supported for export." + # Select dtype if args.float16: dtype = torch.float16 @@ -70,7 +94,7 @@ def main(args): # Load model from float checkpoint print(f"Loading model from {args.model}...") - pipe = StableDiffusionPipeline.from_pretrained(args.model, torch_dtype=dtype) + pipe = DiffusionPipeline.from_pretrained(args.model, torch_dtype=dtype) print(f"Model loaded from {args.model}.") # Enable attention slicing @@ -89,11 +113,30 @@ def main(args): if hasattr(m, 'lora_layer') and m.lora_layer is not None: raise RuntimeError("LoRA layers should be fused in before calling into quantization.") + # Move model to target device + print(f"Moving model to {args.device}...") + pipe = pipe.to(args.device) + + if args.activation_equalization: + with activation_equalization_mode(pipe.unet, alpha=0.5, layerwise=True, add_mul_node=True): + # Workaround to expose `in_features` attribute from the Hook Wrapper + for m in pipe.unet.modules(): + if isinstance(m, KwargsForwardHook) and hasattr(m.module, 'in_features'): + m.in_features = m.module.in_features + prompts = VALIDATION_PROMPTS + run_val_inference( + pipe, args.resolution, prompts, test_seeds, output_dir, args.device, dtype) + + # Workaround to expose `in_features` attribute from the EqualizedModule Wrapper + for m in pipe.unet.modules(): + if isinstance(m, EqualizedModule) and hasattr(m.module, 'in_features'): + m.in_features = m.module.in_features + # Quantize model if args.quantize: @value - def bit_width(module): + def weight_bit_width(module): if isinstance(module, nn.Linear): return args.linear_weight_bit_width elif isinstance(module, nn.Conv2d): @@ -101,24 +144,71 @@ def bit_width(module): else: raise RuntimeError(f"Module {module} not supported.") + # XOR between the two input_bit_width. 
Either they are both None, or neither of them is + assert (args.linear_input_bit_width is None) == (args.conv_input_bit_width is None), 'Both input bit width must be specified or left to None' + + is_input_quantized = args.linear_input_bit_width is not None and args.conv_input_bit_width is not None + if is_input_quantized: + + @value + def input_bit_width(module): + if isinstance(module, nn.Linear): + return args.linear_input_bit_width + elif isinstance(module, nn.Conv2d): + return args.conv_input_bit_width + else: + raise RuntimeError(f"Module {module} not supported.") + else: + input_bit_width = None + print("Applying model quantization...") quantize_model( pipe.unet, dtype=dtype, + device=args.device, name_blacklist=blacklist, + weight_bit_width=weight_bit_width, weight_quant_format=args.weight_quant_format, weight_quant_type=args.weight_quant_type, - weight_bit_width=bit_width, weight_param_method=args.weight_param_method, weight_scale_precision=args.weight_scale_precision, weight_quant_granularity=args.weight_quant_granularity, weight_group_size=args.weight_group_size, - quantize_weight_zero_point=args.quantize_weight_zero_point) + quantize_weight_zero_point=args.quantize_weight_zero_point, + input_bit_width=input_bit_width, + input_quant_format=args.input_quant_format, + input_scale_type=args.input_scale_type, + input_scale_precision=args.input_scale_precision, + input_param_method=args.input_param_method, + input_quant_type=args.input_quant_type, + input_quant_granularity=args.input_quant_granularity) print("Model quantization applied.") - # Move model to target device - print(f"Moving model to {args.device}...") - pipe = pipe.to(args.device) + if is_input_quantized and args.input_scale_type == 'static': + print("Applying activation calibration") + with calibration_mode(pipe.unet): + prompts = VALIDATION_PROMPTS + run_val_inference( + pipe, args.resolution, prompts, test_seeds, output_dir, args.device, dtype) + + if args.gptq: + print("Applying GPTQ. It can take several hours") + with gptq_mode(pipe.unet, + create_weight_orig=False, + use_quant_activations=False, + return_forward_output=True, + act_order=True) as gptq: + prompts = VALIDATION_PROMPTS + for _ in tqdm(range(gptq.num_layers)): + run_val_inference( + pipe, args.resolution, prompts, test_seeds, output_dir, args.device, dtype) + gptq.update() + + print("Applying bias correction") + with bias_correction_mode(pipe.unet): + prompts = VALIDATION_PROMPTS + run_val_inference( + pipe, args.resolution, prompts, test_seeds, output_dir, args.device, dtype) # Perform inference if args.prompt: @@ -134,26 +224,26 @@ def bit_width(module): pipe.unet.eval() device = next(iter(pipe.unet.parameters())).device dtype = next(iter(pipe.unet.parameters())).dtype - if args.export_target: - assert args.weight_quant_format == 'int', "Only integer quantization supported for export." 
- trace_inputs = generate_unet_rand_inputs( - embedding_shape=SD_2_1_EMBEDDINGS_SHAPE, + + # Define tracing input + if args.is_sd_xl: + generate_fn = generate_unet_xl_rand_inputs + shape = SD_XL_EMBEDDINGS_SHAPE + else: + generate_fn = generate_unet_21_rand_inputs + shape = SD_2_1_EMBEDDINGS_SHAPE + trace_inputs = generate_fn( + embedding_shape=shape, unet_input_shape=unet_input_shape(args.resolution), device=device, dtype=dtype) - if args.export_target == 'torchscript': - if args.weight_quant_granularity == 'per_group': - export_manager = BlockQuantProxyLevelManager - else: - export_manager = TorchQCDQManager - export_manager.change_weight_export(export_weight_q_node=True) - export_torchscript(pipe, trace_inputs, output_dir, export_manager) - elif args.export_target == 'onnx': + + if args.export_target == 'onnx': if args.weight_quant_granularity == 'per_group': export_manager = BlockQuantProxyLevelManager else: export_manager = StdQCDQONNXManager - export_manager.change_weight_export(export_weight_q_node=True) + export_manager.change_weight_export(export_weight_q_node=args.export_weight_q_node) export_onnx(pipe, trace_inputs, output_dir, export_manager) @@ -167,12 +257,12 @@ def bit_width(module): help='Path or name of the model.') parser.add_argument( '-d', '--device', type=str, default='cuda:0', help='Target device for quantized model.') - parser.add_argument('-b', '--batch-size', type=int, default=4, help='Batch size.') + parser.add_argument('-b', '--batch-size', type=int, default=4, help='Batch size. Default: 4') parser.add_argument( '--prompt', type=str, default='An austronaut riding a horse on Mars.', - help='Manual prompt for testing.') + help='Manual prompt for testing. Default: An austronaut riding a horse on Mars.') parser.add_argument( '--resolution', type=int, @@ -184,57 +274,125 @@ def bit_width(module): str_true=True, default='.', help='Path where to generate output folder.') - add_bool_arg(parser, 'quantize', default=True, help='Toggle quantization.') - add_bool_arg(parser, 'float16', default=True, help='Enable float16 execution.') - add_bool_arg(parser, 'attention-slicing', default=False, help='Enable attention slicing.') + add_bool_arg(parser, 'quantize', default=True, help='Toggle quantization. Default: Enabled') + add_bool_arg( + parser, + 'activation-equalization', + default=False, + help='Toggle Activation Equalization. Default: Disabled') + add_bool_arg(parser, 'gptq', default=False, help='Toggle gptq. Default: Disabled') + add_bool_arg(parser, 'float16', default=True, help='Enable float16 execution. Default: Enabled') + add_bool_arg( + parser, + 'attention-slicing', + default=False, + help='Enable attention slicing. Default: Disabled') + add_bool_arg( + parser, + 'is-sd-xl', + default=False, + help='Enable this flag to correctly export SDXL. Default: Disabled') parser.add_argument( - '--export-target', - type=str, - default='', - choices=['', 'torchscript', 'onnx'], - help='Target export flow.') + '--export-target', type=str, default='', choices=['', 'onnx'], help='Target export flow.') + add_bool_arg( + parser, + 'export-weight-q-node', + default=False, + help= + 'Enable export of floating point weights + QDQ rather than integer weights + DQ. Default: Disabled' + ) parser.add_argument( '--conv-weight-bit-width', type=int, default=8, help='Weight bit width. Default: 8.') parser.add_argument( - '--linear-weight-bit-width', type=int, default=8, help='Weight bit width. Default: 4.') + '--linear-weight-bit-width', type=int, default=8, help='Weight bit width. 
Default: 8.') + parser.add_argument( + '--conv-input-bit-width', + type=int, + default=None, + help='Input bit width. Default: None (not quantized)') + parser.add_argument( + '--linear-input-bit-width', + type=int, + default=None, + help='Input bit width. Default: None (not quantized).') parser.add_argument( '--weight-param-method', type=str, default='stats', choices=['stats', 'mse'], help='How scales/zero-point are determined. Default: stats.') + parser.add_argument( + '--input-param-method', + type=str, + default='stats', + choices=['stats', 'mse'], + help='How scales/zero-point are determined. Default: stats.') parser.add_argument( '--weight-scale-precision', type=str, default='float_scale', choices=['float_scale', 'po2_scale'], help='Whether scale is a float value or a po2. Default: float_scale.') + parser.add_argument( + '--input-scale-precision', + type=str, + default='float_scale', + choices=['float_scale', 'po2_scale'], + help='Whether scale is a float value or a po2. Default: float_scale.') parser.add_argument( '--weight-quant-type', type=str, default='asym', choices=['sym', 'asym'], help='Weight quantization type. Default: asym.') + parser.add_argument( + '--input-quant-type', + type=str, + default='asym', + choices=['sym', 'asym'], + help='Input quantization type. Default: asym.') parser.add_argument( '--weight-quant-format', type=quant_format_validator, default='int', help= 'Weight quantization type. Either int or eXmY, with X+Y==weight_bit_width-1. Default: int.') + parser.add_argument( + '--input-quant-format', + type=quant_format_validator, + default='int', + help= + 'Weight quantization type. Either int or eXmY, with X+Y==weight_bit_width-1. Default: int.') parser.add_argument( '--weight-quant-granularity', type=str, - default='per_group', + default='per_channel', choices=['per_channel', 'per_tensor', 'per_group'], - help='Granularity for scales/zero-point of weights. Default: per_group.') + help='Granularity for scales/zero-point of weights. Default: per_channel.') + parser.add_argument( + '--input-quant-granularity', + type=str, + default='per_tensor', + choices=['per_tensor'], + help='Granularity for scales/zero-point of inputs. Default: per_tensor.') + parser.add_argument( + '--input-scale-type', + type=str, + default='static', + choices=['static', 'dynamic'], + help='Whether to do static or dynamic input quantization. Default: static.') parser.add_argument( '--weight-group-size', type=int, default=16, help='Group size for per_group weight quantization. Default: 16.') add_bool_arg( - parser, 'quantize-weight-zero-point', default=True, help='Quantize weight zero-point.') - add_bool_arg(parser, 'export-cuda-float16', default=False, help='Export FP16 on CUDA') + parser, + 'quantize-weight-zero-point', + default=True, + help='Quantize weight zero-point. Default: Enabled') + add_bool_arg( + parser, 'export-cuda-float16', default=False, help='Export FP16 on CUDA. 
Default: Disabled') args = parser.parse_args() print("Args: " + str(vars(args))) main(args) diff --git a/src/brevitas_examples/stable_diffusion/sd_quant/constants.py b/src/brevitas_examples/stable_diffusion/sd_quant/constants.py index 7359a06dd..1c288f7fa 100644 --- a/src/brevitas_examples/stable_diffusion/sd_quant/constants.py +++ b/src/brevitas_examples/stable_diffusion/sd_quant/constants.py @@ -4,3 +4,4 @@ """ SD_2_1_EMBEDDINGS_SHAPE = (77, 1024) +SD_XL_EMBEDDINGS_SHAPE = (77, 2048) diff --git a/src/brevitas_examples/stable_diffusion/sd_quant/export.py b/src/brevitas_examples/stable_diffusion/sd_quant/export.py index 45ccf7990..b466d6303 100644 --- a/src/brevitas_examples/stable_diffusion/sd_quant/export.py +++ b/src/brevitas_examples/stable_diffusion/sd_quant/export.py @@ -26,30 +26,8 @@ def forward(self, *args, **kwargs): return self.unet(*args, **kwargs, return_dict=False) -def export_torchscript(pipe, trace_inputs, output_dir, export_manager): - with torch.no_grad(), brevitas_proxy_export_mode(pipe.unet, export_manager): - fx_g = make_fx( - UnetExportWrapper(pipe.unet), - decomposition_table=get_decompositions([ - torch.ops.aten.embedding_dense_backward, - torch.ops.aten.native_layer_norm_backward, - torch.ops.aten.slice_backward, - torch.ops.aten.select_backward, - torch.ops.aten.norm.ScalarOpt_dim, - torch.ops.aten.native_group_norm, - torch.ops.aten.upsample_bilinear2d.vec, - torch.ops.aten.split.Tensor, - torch.ops.aten.split_with_sizes,]), - )(*tuple(trace_inputs.values())) - _force_requires_grad_false(fx_g) - jit_g = torch.jit.trace(_JitTraceExportWrapper(fx_g), tuple(trace_inputs.values())) - output_path = os.path.join(output_dir, 'unet.ts') - print(f"Saving unet to {output_path} ...") - torch.jit.save(jit_g, output_path) - - def export_onnx(pipe, trace_inputs, output_dir, export_manager): output_path = os.path.join(output_dir, 'unet.onnx') print(f"Saving unet to {output_path} ...") with torch.no_grad(), brevitas_proxy_export_mode(pipe.unet, export_manager): - torch.onnx.export(pipe.unet, args=tuple(trace_inputs.values()), f=output_path) + torch.onnx.export(pipe.unet, args=trace_inputs, f=output_path) diff --git a/src/brevitas_examples/stable_diffusion/sd_quant/utils.py b/src/brevitas_examples/stable_diffusion/sd_quant/utils.py index 7a88aebdf..b2c30176f 100644 --- a/src/brevitas_examples/stable_diffusion/sd_quant/utils.py +++ b/src/brevitas_examples/stable_diffusion/sd_quant/utils.py @@ -47,3 +47,39 @@ def generate_unet_rand_inputs( if with_return_dict_false: unet_rand_inputs['return_dict'] = False return unet_rand_inputs + + +def generate_unet_21_rand_inputs( + embedding_shape, + unet_input_shape, + batch_size=1, + device='cpu', + dtype=torch.float32, + with_return_dict_false=False): + unet_rand_inputs = generate_unet_rand_inputs( + embedding_shape, unet_input_shape, batch_size, device, dtype, with_return_dict_false) + return tuple(unet_rand_inputs.values()) + + +def generate_unet_xl_rand_inputs( + embedding_shape, + unet_input_shape, + batch_size=1, + device='cpu', + dtype=torch.float32, + with_return_dict_false=False): + # We need to pass a combination of args and kwargs to ONNX export + # If we pass all kwargs, something breaks + # If we pass only the last element as kwargs, since it is a dict, it has a weird interaction and something breaks + # The solution is to pass only one argument as args, and everything else as kwargs + unet_rand_inputs = generate_unet_rand_inputs( + embedding_shape, unet_input_shape, batch_size, device, dtype, with_return_dict_false) + sample = 
unet_rand_inputs['sample'] + del unet_rand_inputs['sample'] + unet_rand_inputs['timestep_cond'] = None + unet_rand_inputs['cross_attention_kwargs'] = None + unet_rand_inputs['added_cond_kwargs'] = { + "text_embeds": torch.randn(1, 1280, dtype=dtype, device=device), + "time_ids": torch.randn(1, 6, dtype=dtype, device=device)} + inputs = (sample, unet_rand_inputs) + return inputs
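
A note on the tuple returned by `generate_unet_xl_rand_inputs`: `torch.onnx.export` interprets a trailing dict in `args` as keyword arguments, which is why `export_onnx` can simply call `torch.onnx.export(pipe.unet, args=trace_inputs, f=output_path)` with one positional tensor followed by a dict of everything else. A minimal sketch of that behavior, using a toy module and made-up shapes rather than the real UNet:

```python
import torch


class ToyUNet(torch.nn.Module):
    """Stand-in for the UNet: one positional tensor plus keyword-only extras."""

    def forward(self, sample, timestep=None, encoder_hidden_states=None):
        return sample * timestep + encoder_hidden_states.mean()


# One positional argument, then a trailing dict that torch.onnx.export
# interprets as keyword arguments for the forward call.
trace_inputs = (
    torch.randn(1, 4, 64, 64),
    {
        'timestep': torch.tensor(1.0),
        'encoder_hidden_states': torch.randn(1, 77, 2048)})

torch.onnx.export(ToyUNet(), args=trace_inputs, f='toy_unet.onnx')
```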