Commit
Feat (examples): add support for Stable Diffusion XL
Giuseppe5 committed Apr 8, 2024

1 parent 44cb08b commit 3036b4d
Showing 7 changed files with 379 additions and 65 deletions.
4 changes: 2 additions & 2 deletions src/brevitas/graph/equalize.py
@@ -21,6 +21,7 @@
from brevitas.graph.utils import get_module
from brevitas.graph.utils import get_node
from brevitas.nn.equalized_layer import EqualizedModule
from brevitas.nn.equalized_layer import INPUT_NAMES
from brevitas.nn.quant_scale_bias import ScaleBias
from brevitas.utils.torch_utils import KwargsForwardHook

@@ -970,8 +971,7 @@ def forward_stats_hook(self, module, *args, name, batch_dim=0, use_inp=True, **k
self.float_act_map[name] = None
return

possible_input_kwargs = ['input', 'inp', 'query']
input_kwarg = [x for x in kwargs.keys() if x in possible_input_kwargs][0]
input_kwarg = [x for x in kwargs.keys() if x in INPUT_NAMES][0]
if use_inp:
x = kwargs[input_kwarg]
elif not use_inp:
2 changes: 1 addition & 1 deletion src/brevitas/nn/equalized_layer.py
@@ -4,7 +4,7 @@

from brevitas.nn.quant_mha import QuantMultiheadAttention

INPUT_NAMES = ['input', 'inp', 'query', 'x']
INPUT_NAMES = ['input', 'inp', 'query', 'x', 'hidden_states']


class EqualizedModule(torch.nn.Module):
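The two hunks above work together: `forward_stats_hook` in equalize.py now looks up the activation by any name in the shared `INPUT_NAMES` list, and the list gains `'hidden_states'`, the keyword used by diffusers attention/transformer blocks. A minimal illustrative sketch (editorial, not part of the commit) of the kwarg selection:

```python
# Illustrative only: how the hook resolves which keyword argument carries the activation.
INPUT_NAMES = ['input', 'inp', 'query', 'x', 'hidden_states']

def pick_input_kwarg(kwargs: dict) -> str:
    # Take the first kwarg whose name is a recognized input name; 'hidden_states'
    # is what diffusers attention/transformer blocks pass, hence its addition here.
    matches = [name for name in kwargs if name in INPUT_NAMES]
    if not matches:
        raise KeyError("no recognized input kwarg found")
    return matches[0]

# e.g. pick_input_kwarg({'hidden_states': object(), 'scale': 1.0}) -> 'hidden_states'
```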
141 changes: 141 additions & 0 deletions src/brevitas_examples/stable_diffusion/README.md
@@ -0,0 +1,141 @@
# Stable Diffusion Quantization

This example currently supports Stable Diffusion 2.1 and Stable Diffusion XL.

The following PTQ techniques are currently supported:
- Activation Equalization (e.g., SmoothQuant), applied layerwise (with the addition of Mul ops)
- Activation Calibration, in the case of static activation quantization
- GPTQ
- Bias Correction

These techniques can be applied to both integer and floating-point quantization.
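As an illustration only (not a drop-in script), the steps above are chained in `main.py` roughly as sketched below; argument lists are elided, and `pipe`, `resolution`, `seeds`, `output_dir`, `device`, `dtype`, `VALIDATION_PROMPTS`, and the `run_val_inference` helper are assumed to be set up as in the example script.

```python
# Condensed, illustrative sketch of the PTQ flow implemented in main.py.
from brevitas.graph.calibrate import bias_correction_mode, calibration_mode
from brevitas.graph.equalize import activation_equalization_mode
from brevitas.graph.gptq import gptq_mode
from brevitas_examples.common.generative.quantize import quantize_model

# 1. SmoothQuant-style activation equalization, layerwise, with Mul ops inserted.
with activation_equalization_mode(pipe.unet, alpha=0.5, layerwise=True, add_mul_node=True):
    run_val_inference(pipe, resolution, VALIDATION_PROMPTS, seeds, output_dir, device, dtype)

# 2. Quantize the UNet weights (and, optionally, inputs).
quantize_model(pipe.unet, ...)  # the full keyword list is shown in main.py

# 3. Static activation calibration (only when inputs are quantized statically).
with calibration_mode(pipe.unet):
    run_val_inference(pipe, resolution, VALIDATION_PROMPTS, seeds, output_dir, device, dtype)

# 4. GPTQ.
with gptq_mode(pipe.unet, create_weight_orig=False, use_quant_activations=False,
               return_forward_output=True, act_order=True) as gptq:
    for _ in range(gptq.num_layers):
        run_val_inference(pipe, resolution, VALIDATION_PROMPTS, seeds, output_dir, device, dtype)
        gptq.update()

# 5. Bias correction.
with bias_correction_mode(pipe.unet):
    run_val_inference(pipe, resolution, VALIDATION_PROMPTS, seeds, output_dir, device, dtype)
```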

We support ONNX integer export, and we plan to release support for exporting floating-point quantized models (e.g., FP8) soon.
To export the model in fp16, enable `export-cuda-float16`. This performs the tracing necessary for export on GPU, leaving the model in fp16.
If the flag is not enabled, the model is moved to CPU and cast to float32 before export, because the required fp16 kernels are missing on CPU.

NB: when exporting Stable Diffusion XL, make sure to enable the `is-sd-xl` flag. The flag is not needed if export is not performed.
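For instance, quantizing and exporting SDXL might look like `python main.py -m <path-to-sdxl-checkpoint> --is-sd-xl --activation-equalization --gptq --export-target onnx --export-cuda-float16` (model path elided; see the full list of options below).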


## Run
usage: main.py [-h] [-m MODEL] [-d DEVICE] [-b BATCH_SIZE] [--prompt PROMPT]
[--resolution RESOLUTION]
[--output-path OUTPUT_PATH | --no-output-path]
[--quantize | --no-quantize]
[--activation-equalization | --no-activation-equalization]
[--gptq | --no-gptq] [--float16 | --no-float16]
[--attention-slicing | --no-attention-slicing]
[--is-sd-xl | --no-is-sd-xl] [--export-target {,onnx}]
[--export-weight-q-node | --no-export-weight-q-node]
[--conv-weight-bit-width CONV_WEIGHT_BIT_WIDTH]
[--linear-weight-bit-width LINEAR_WEIGHT_BIT_WIDTH]
[--conv-input-bit-width CONV_INPUT_BIT_WIDTH]
[--linear-input-bit-width LINEAR_INPUT_BIT_WIDTH]
[--weight-param-method {stats,mse}]
[--input-param-method {stats,mse}]
[--weight-scale-precision {float_scale,po2_scale}]
[--input-scale-precision {float_scale,po2_scale}]
[--weight-quant-type {sym,asym}]
[--input-quant-type {sym,asym}]
[--weight-quant-format WEIGHT_QUANT_FORMAT]
[--input-quant-format INPUT_QUANT_FORMAT]
[--weight-quant-granularity {per_channel,per_tensor,per_group}]
[--input-quant-granularity {per_tensor}]
[--input-scale-type {static,dynamic}]
[--weight-group-size WEIGHT_GROUP_SIZE]
[--quantize-weight-zero-point | --no-quantize-weight-zero-point]
[--export-cuda-float16 | --no-export-cuda-float16]

Stable Diffusion quantization

options:
-h, --help show this help message and exit
-m MODEL, --model MODEL
Path or name of the model.
-d DEVICE, --device DEVICE
Target device for quantized model.
-b BATCH_SIZE, --batch-size BATCH_SIZE
Batch size. Default: 4
--prompt PROMPT Manual prompt for testing. Default: An astronaut
riding a horse on Mars.
--resolution RESOLUTION
Resolution along height and width dimension. Default:
512.
--output-path OUTPUT_PATH
Path where to generate output folder.
--no-output-path Disable Path where to generate output folder.
--quantize Enable Toggle quantization. Default: Enabled
--no-quantize Disable Toggle quantization. Default: Enabled
--activation-equalization
Enable Toggle Activation Equalization. Default:
Disabled
--no-activation-equalization
Disable Toggle Activation Equalization. Default:
Disabled
--gptq Enable Toggle gptq. Default: Disabled
--no-gptq Disable Toggle gptq. Default: Disabled
--float16 Enable Enable float16 execution. Default: Enabled
--no-float16 Disable Enable float16 execution. Default: Enabled
--attention-slicing Enable Enable attention slicing. Default: Disabled
--no-attention-slicing
Disable Enable attention slicing. Default: Disabled
--is-sd-xl Enable Enable this flag to correctly export SDXL.
Default: Disabled
--no-is-sd-xl Disable Enable this flag to correctly export SDXL.
Default: Disabled
--export-target {,onnx}
Target export flow.
--export-weight-q-node
Enable Enable export of floating point weights + QDQ
rather than integer weights + DQ. Default: Disabled
--no-export-weight-q-node
Disable Enable export of floating point weights + QDQ
rather than integer weights + DQ. Default: Disabled
--conv-weight-bit-width CONV_WEIGHT_BIT_WIDTH
Weight bit width. Default: 8.
--linear-weight-bit-width LINEAR_WEIGHT_BIT_WIDTH
Weight bit width. Default: 8.
--conv-input-bit-width CONV_INPUT_BIT_WIDTH
Input bit width. Default: None (not quantized)
--linear-input-bit-width LINEAR_INPUT_BIT_WIDTH
Input bit width. Default: None (not quantized).
--weight-param-method {stats,mse}
How scales/zero-point are determined. Default: stats.
--input-param-method {stats,mse}
How scales/zero-point are determined. Default: stats.
--weight-scale-precision {float_scale,po2_scale}
Whether scale is a float value or a po2. Default:
float_scale.
--input-scale-precision {float_scale,po2_scale}
Whether scale is a float value or a po2. Default:
float_scale.
--weight-quant-type {sym,asym}
Weight quantization type. Default: asym.
--input-quant-type {sym,asym}
Input quantization type. Default: asym.
--weight-quant-format WEIGHT_QUANT_FORMAT
Weight quantization format. Either int or eXmY, with
X+Y==weight_bit_width-1. Default: int.
--input-quant-format INPUT_QUANT_FORMAT
Input quantization format. Either int or eXmY, with
X+Y==input_bit_width-1. Default: int.
--weight-quant-granularity {per_channel,per_tensor,per_group}
Granularity for scales/zero-point of weights. Default:
per_channel.
--input-quant-granularity {per_tensor}
Granularity for scales/zero-point of inputs. Default:
per_tensor.
--input-scale-type {static,dynamic}
Whether to do static or dynamic input quantization.
Default: static.
--weight-group-size WEIGHT_GROUP_SIZE
Group size for per_group weight quantization. Default:
16.
--quantize-weight-zero-point
Enable Quantize weight zero-point. Default: Enabled
--no-quantize-weight-zero-point
Disable Quantize weight zero-point. Default: Enabled
--export-cuda-float16
Enable Export FP16 on CUDA. Default: Disabled
--no-export-cuda-float16
Disable Export FP16 on CUDA. Default: Disabled
236 changes: 197 additions & 39 deletions src/brevitas_examples/stable_diffusion/main.py
@@ -7,29 +7,40 @@
from datetime import datetime
import json
import os
import re
import time

from dependencies import value
from diffusers import StableDiffusionPipeline
from diffusers import DiffusionPipeline
import torch
from torch import nn
from tqdm import tqdm

from brevitas.export.onnx.standard.qcdq.manager import StdQCDQONNXManager
from brevitas.export.torch.qcdq.manager import TorchQCDQManager
from brevitas.graph.calibrate import bias_correction_mode
from brevitas.graph.calibrate import calibration_mode
from brevitas.graph.equalize import activation_equalization_mode
from brevitas.graph.gptq import gptq_mode
from brevitas.nn.equalized_layer import EqualizedModule
from brevitas.utils.torch_utils import KwargsForwardHook
from brevitas_examples.common.generative.quantize import quantize_model
from brevitas_examples.common.parse_utils import add_bool_arg
from brevitas_examples.common.parse_utils import quant_format_validator
from brevitas_examples.llm.llm_quant.export import BlockQuantProxyLevelManager
from brevitas_examples.stable_diffusion.sd_quant.constants import SD_2_1_EMBEDDINGS_SHAPE
from brevitas_examples.stable_diffusion.sd_quant.constants import SD_XL_EMBEDDINGS_SHAPE
from brevitas_examples.stable_diffusion.sd_quant.export import export_onnx
from brevitas_examples.stable_diffusion.sd_quant.export import export_torchscript
from brevitas_examples.stable_diffusion.sd_quant.utils import generate_latents
from brevitas_examples.stable_diffusion.sd_quant.utils import generate_unet_rand_inputs
from brevitas_examples.stable_diffusion.sd_quant.utils import generate_unet_21_rand_inputs
from brevitas_examples.stable_diffusion.sd_quant.utils import generate_unet_xl_rand_inputs
from brevitas_examples.stable_diffusion.sd_quant.utils import unet_input_shape

TEST_SEED = 123456

VALIDATION_PROMPTS = {
'validation_prompt_0': 'A cat playing with a ball',
'validation_prompt_1': 'A dog running on the beach'}


def run_test_inference(
pipe, resolution, prompts, seeds, output_path, device, dtype, name_prefix=''):
@@ -47,8 +58,21 @@ def run_test_inference(
images[i].save(file_path)


def run_val_inference(pipe, resolution, prompts, seeds, output_path, device, dtype, name_prefix=''):
with torch.no_grad():
test_latents = generate_latents(seeds, device, dtype, unet_input_shape(resolution))

for name, prompt in prompts.items():
print(f"Generating: {name}")
# We don't want to generate any image, so we return only the latent encoding pre VAE
pipe([prompt] * len(seeds), latents=test_latents, output_type='latent')


def main(args):

if args.export_target:
assert args.weight_quant_format == 'int', "Currently only integer quantization supported for export."

# Select dtype
if args.float16:
dtype = torch.float16
@@ -70,7 +94,7 @@ def main(args):

# Load model from float checkpoint
print(f"Loading model from {args.model}...")
pipe = StableDiffusionPipeline.from_pretrained(args.model, torch_dtype=dtype)
pipe = DiffusionPipeline.from_pretrained(args.model, torch_dtype=dtype)
print(f"Model loaded from {args.model}.")

# Enable attention slicing
@@ -89,36 +113,102 @@ def main(args):
if hasattr(m, 'lora_layer') and m.lora_layer is not None:
raise RuntimeError("LoRA layers should be fused in before calling into quantization.")

# Move model to target device
print(f"Moving model to {args.device}...")
pipe = pipe.to(args.device)

if args.activation_equalization:
with activation_equalization_mode(pipe.unet, alpha=0.5, layerwise=True, add_mul_node=True):
# Workaround to expose `in_features` attribute from the Hook Wrapper
for m in pipe.unet.modules():
if isinstance(m, KwargsForwardHook) and hasattr(m.module, 'in_features'):
m.in_features = m.module.in_features
prompts = VALIDATION_PROMPTS
run_val_inference(
pipe, args.resolution, prompts, test_seeds, output_dir, args.device, dtype)

# Workaround to expose `in_features` attribute from the EqualizedModule Wrapper
for m in pipe.unet.modules():
if isinstance(m, EqualizedModule) and hasattr(m.module, 'in_features'):
m.in_features = m.module.in_features

# Quantize model
if args.quantize:

@value
def bit_width(module):
def weight_bit_width(module):
if isinstance(module, nn.Linear):
return args.linear_weight_bit_width
elif isinstance(module, nn.Conv2d):
return args.conv_weight_bit_width
else:
raise RuntimeError(f"Module {module} not supported.")

# XOR between the two input_bit_width. Either they are both None, or neither of them is
assert (args.linear_input_bit_width is None) == (args.conv_input_bit_width is None), 'Both input bit width must be specified or left to None'

is_input_quantized = args.linear_input_bit_width is not None and args.conv_input_bit_width is not None
if is_input_quantized:

@value
def input_bit_width(module):
if isinstance(module, nn.Linear):
return args.linear_input_bit_width
elif isinstance(module, nn.Conv2d):
return args.conv_input_bit_width
else:
raise RuntimeError(f"Module {module} not supported.")
else:
input_bit_width = None

print("Applying model quantization...")
quantize_model(
pipe.unet,
dtype=dtype,
device=args.device,
name_blacklist=blacklist,
weight_bit_width=weight_bit_width,
weight_quant_format=args.weight_quant_format,
weight_quant_type=args.weight_quant_type,
weight_bit_width=bit_width,
weight_param_method=args.weight_param_method,
weight_scale_precision=args.weight_scale_precision,
weight_quant_granularity=args.weight_quant_granularity,
weight_group_size=args.weight_group_size,
quantize_weight_zero_point=args.quantize_weight_zero_point)
quantize_weight_zero_point=args.quantize_weight_zero_point,
input_bit_width=input_bit_width,
input_quant_format=args.input_quant_format,
input_scale_type=args.input_scale_type,
input_scale_precision=args.input_scale_precision,
input_param_method=args.input_param_method,
input_quant_type=args.input_quant_type,
input_quant_granularity=args.input_quant_granularity)
print("Model quantization applied.")

# Move model to target device
print(f"Moving model to {args.device}...")
pipe = pipe.to(args.device)
if is_input_quantized and args.input_scale_type == 'static':
print("Applying activation calibration")
with calibration_mode(pipe.unet):
prompts = VALIDATION_PROMPTS
run_val_inference(
pipe, args.resolution, prompts, test_seeds, output_dir, args.device, dtype)

if args.gptq:
print("Applying GPTQ. It can take several hours")
with gptq_mode(pipe.unet,
create_weight_orig=False,
use_quant_activations=False,
return_forward_output=True,
act_order=True) as gptq:
prompts = VALIDATION_PROMPTS
for _ in tqdm(range(gptq.num_layers)):
run_val_inference(
pipe, args.resolution, prompts, test_seeds, output_dir, args.device, dtype)
gptq.update()

print("Applying bias correction")
with bias_correction_mode(pipe.unet):
prompts = VALIDATION_PROMPTS
run_val_inference(
pipe, args.resolution, prompts, test_seeds, output_dir, args.device, dtype)

# Perform inference
if args.prompt:
@@ -134,26 +224,26 @@ def bit_width(module):
pipe.unet.eval()
device = next(iter(pipe.unet.parameters())).device
dtype = next(iter(pipe.unet.parameters())).dtype
if args.export_target:
assert args.weight_quant_format == 'int', "Only integer quantization supported for export."
trace_inputs = generate_unet_rand_inputs(
embedding_shape=SD_2_1_EMBEDDINGS_SHAPE,

# Define tracing input
if args.is_sd_xl:
generate_fn = generate_unet_xl_rand_inputs
shape = SD_XL_EMBEDDINGS_SHAPE
else:
generate_fn = generate_unet_21_rand_inputs
shape = SD_2_1_EMBEDDINGS_SHAPE
trace_inputs = generate_fn(
embedding_shape=shape,
unet_input_shape=unet_input_shape(args.resolution),
device=device,
dtype=dtype)
if args.export_target == 'torchscript':
if args.weight_quant_granularity == 'per_group':
export_manager = BlockQuantProxyLevelManager
else:
export_manager = TorchQCDQManager
export_manager.change_weight_export(export_weight_q_node=True)
export_torchscript(pipe, trace_inputs, output_dir, export_manager)
elif args.export_target == 'onnx':

if args.export_target == 'onnx':
if args.weight_quant_granularity == 'per_group':
export_manager = BlockQuantProxyLevelManager
else:
export_manager = StdQCDQONNXManager
export_manager.change_weight_export(export_weight_q_node=True)
export_manager.change_weight_export(export_weight_q_node=args.export_weight_q_node)
export_onnx(pipe, trace_inputs, output_dir, export_manager)


@@ -167,12 +257,12 @@ def bit_width(module):
help='Path or name of the model.')
parser.add_argument(
'-d', '--device', type=str, default='cuda:0', help='Target device for quantized model.')
parser.add_argument('-b', '--batch-size', type=int, default=4, help='Batch size.')
parser.add_argument('-b', '--batch-size', type=int, default=4, help='Batch size. Default: 4')
parser.add_argument(
'--prompt',
type=str,
default='An astronaut riding a horse on Mars.',
help='Manual prompt for testing.')
help='Manual prompt for testing. Default: An astronaut riding a horse on Mars.')
parser.add_argument(
'--resolution',
type=int,
@@ -184,57 +274,125 @@ def bit_width(module):
str_true=True,
default='.',
help='Path where to generate output folder.')
add_bool_arg(parser, 'quantize', default=True, help='Toggle quantization.')
add_bool_arg(parser, 'float16', default=True, help='Enable float16 execution.')
add_bool_arg(parser, 'attention-slicing', default=False, help='Enable attention slicing.')
add_bool_arg(parser, 'quantize', default=True, help='Toggle quantization. Default: Enabled')
add_bool_arg(
parser,
'activation-equalization',
default=False,
help='Toggle Activation Equalization. Default: Disabled')
add_bool_arg(parser, 'gptq', default=False, help='Toggle gptq. Default: Disabled')
add_bool_arg(parser, 'float16', default=True, help='Enable float16 execution. Default: Enabled')
add_bool_arg(
parser,
'attention-slicing',
default=False,
help='Enable attention slicing. Default: Disabled')
add_bool_arg(
parser,
'is-sd-xl',
default=False,
help='Enable this flag to correctly export SDXL. Default: Disabled')
parser.add_argument(
'--export-target',
type=str,
default='',
choices=['', 'torchscript', 'onnx'],
help='Target export flow.')
'--export-target', type=str, default='', choices=['', 'onnx'], help='Target export flow.')
add_bool_arg(
parser,
'export-weight-q-node',
default=False,
help=
'Enable export of floating point weights + QDQ rather than integer weights + DQ. Default: Disabled'
)
parser.add_argument(
'--conv-weight-bit-width', type=int, default=8, help='Weight bit width. Default: 8.')
parser.add_argument(
'--linear-weight-bit-width', type=int, default=8, help='Weight bit width. Default: 4.')
'--linear-weight-bit-width', type=int, default=8, help='Weight bit width. Default: 8.')
parser.add_argument(
'--conv-input-bit-width',
type=int,
default=None,
help='Input bit width. Default: None (not quantized)')
parser.add_argument(
'--linear-input-bit-width',
type=int,
default=None,
help='Input bit width. Default: None (not quantized).')
parser.add_argument(
'--weight-param-method',
type=str,
default='stats',
choices=['stats', 'mse'],
help='How scales/zero-point are determined. Default: stats.')
parser.add_argument(
'--input-param-method',
type=str,
default='stats',
choices=['stats', 'mse'],
help='How scales/zero-point are determined. Default: stats.')
parser.add_argument(
'--weight-scale-precision',
type=str,
default='float_scale',
choices=['float_scale', 'po2_scale'],
help='Whether scale is a float value or a po2. Default: float_scale.')
parser.add_argument(
'--input-scale-precision',
type=str,
default='float_scale',
choices=['float_scale', 'po2_scale'],
help='Whether scale is a float value or a po2. Default: float_scale.')
parser.add_argument(
'--weight-quant-type',
type=str,
default='asym',
choices=['sym', 'asym'],
help='Weight quantization type. Default: asym.')
parser.add_argument(
'--input-quant-type',
type=str,
default='asym',
choices=['sym', 'asym'],
help='Input quantization type. Default: asym.')
parser.add_argument(
'--weight-quant-format',
type=quant_format_validator,
default='int',
help=
'Weight quantization format. Either int or eXmY, with X+Y==weight_bit_width-1. Default: int.')
parser.add_argument(
'--input-quant-format',
type=quant_format_validator,
default='int',
help=
'Input quantization format. Either int or eXmY, with X+Y==input_bit_width-1. Default: int.')
parser.add_argument(
'--weight-quant-granularity',
type=str,
default='per_group',
default='per_channel',
choices=['per_channel', 'per_tensor', 'per_group'],
help='Granularity for scales/zero-point of weights. Default: per_group.')
help='Granularity for scales/zero-point of weights. Default: per_channel.')
parser.add_argument(
'--input-quant-granularity',
type=str,
default='per_tensor',
choices=['per_tensor'],
help='Granularity for scales/zero-point of inputs. Default: per_tensor.')
parser.add_argument(
'--input-scale-type',
type=str,
default='static',
choices=['static', 'dynamic'],
help='Whether to do static or dynamic input quantization. Default: static.')
parser.add_argument(
'--weight-group-size',
type=int,
default=16,
help='Group size for per_group weight quantization. Default: 16.')
add_bool_arg(
parser, 'quantize-weight-zero-point', default=True, help='Quantize weight zero-point.')
add_bool_arg(parser, 'export-cuda-float16', default=False, help='Export FP16 on CUDA')
parser,
'quantize-weight-zero-point',
default=True,
help='Quantize weight zero-point. Default: Enabled')
add_bool_arg(
parser, 'export-cuda-float16', default=False, help='Export FP16 on CUDA. Default: Disabled')
args = parser.parse_args()
print("Args: " + str(vars(args)))
main(args)
1 change: 1 addition & 0 deletions src/brevitas_examples/stable_diffusion/sd_quant/constants.py
@@ -4,3 +4,4 @@
"""

SD_2_1_EMBEDDINGS_SHAPE = (77, 1024)
SD_XL_EMBEDDINGS_SHAPE = (77, 2048)
24 changes: 1 addition & 23 deletions src/brevitas_examples/stable_diffusion/sd_quant/export.py
@@ -26,30 +26,8 @@ def forward(self, *args, **kwargs):
return self.unet(*args, **kwargs, return_dict=False)


def export_torchscript(pipe, trace_inputs, output_dir, export_manager):
with torch.no_grad(), brevitas_proxy_export_mode(pipe.unet, export_manager):
fx_g = make_fx(
UnetExportWrapper(pipe.unet),
decomposition_table=get_decompositions([
torch.ops.aten.embedding_dense_backward,
torch.ops.aten.native_layer_norm_backward,
torch.ops.aten.slice_backward,
torch.ops.aten.select_backward,
torch.ops.aten.norm.ScalarOpt_dim,
torch.ops.aten.native_group_norm,
torch.ops.aten.upsample_bilinear2d.vec,
torch.ops.aten.split.Tensor,
torch.ops.aten.split_with_sizes,]),
)(*tuple(trace_inputs.values()))
_force_requires_grad_false(fx_g)
jit_g = torch.jit.trace(_JitTraceExportWrapper(fx_g), tuple(trace_inputs.values()))
output_path = os.path.join(output_dir, 'unet.ts')
print(f"Saving unet to {output_path} ...")
torch.jit.save(jit_g, output_path)


def export_onnx(pipe, trace_inputs, output_dir, export_manager):
output_path = os.path.join(output_dir, 'unet.onnx')
print(f"Saving unet to {output_path} ...")
with torch.no_grad(), brevitas_proxy_export_mode(pipe.unet, export_manager):
torch.onnx.export(pipe.unet, args=tuple(trace_inputs.values()), f=output_path)
torch.onnx.export(pipe.unet, args=trace_inputs, f=output_path)
36 changes: 36 additions & 0 deletions src/brevitas_examples/stable_diffusion/sd_quant/utils.py
@@ -47,3 +47,39 @@ def generate_unet_rand_inputs(
if with_return_dict_false:
unet_rand_inputs['return_dict'] = False
return unet_rand_inputs


def generate_unet_21_rand_inputs(
embedding_shape,
unet_input_shape,
batch_size=1,
device='cpu',
dtype=torch.float32,
with_return_dict_false=False):
unet_rand_inputs = generate_unet_rand_inputs(
embedding_shape, unet_input_shape, batch_size, device, dtype, with_return_dict_false)
return tuple(unet_rand_inputs.values())


def generate_unet_xl_rand_inputs(
embedding_shape,
unet_input_shape,
batch_size=1,
device='cpu',
dtype=torch.float32,
with_return_dict_false=False):
# We need to pass a combination of args and kwargs to the ONNX export.
# Passing everything as kwargs breaks the export; passing only the last element
# as kwargs also breaks, because that element is a dict and interacts badly with the export.
# The solution is to pass a single positional argument and everything else as kwargs.
unet_rand_inputs = generate_unet_rand_inputs(
embedding_shape, unet_input_shape, batch_size, device, dtype, with_return_dict_false)
sample = unet_rand_inputs['sample']
del unet_rand_inputs['sample']
unet_rand_inputs['timestep_cond'] = None
unet_rand_inputs['cross_attention_kwargs'] = None
unet_rand_inputs['added_cond_kwargs'] = {
"text_embeds": torch.randn(1, 1280, dtype=dtype, device=device),
"time_ids": torch.randn(1, 6, dtype=dtype, device=device)}
inputs = (sample, unet_rand_inputs)
return inputs
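To make the comment above concrete, here is an illustrative sketch (editorial, not part of the commit) of how the returned `(sample, kwargs_dict)` tuple is consumed, assuming `pipe` is the loaded pipeline and a 1024x1024 resolution: `torch.onnx.export` interprets a trailing dict in `args` as keyword arguments for the module's forward, which is exactly the args/kwargs split described in the comment.

```python
# Illustrative only: tracing inputs for SDXL and the ONNX export call.
import torch

from brevitas_examples.stable_diffusion.sd_quant.utils import generate_unet_xl_rand_inputs
from brevitas_examples.stable_diffusion.sd_quant.utils import unet_input_shape

trace_inputs = generate_unet_xl_rand_inputs(
    embedding_shape=(77, 2048),               # SD_XL_EMBEDDINGS_SHAPE
    unet_input_shape=unet_input_shape(1024),  # assumed resolution; main.py uses args.resolution
    device='cpu',
    dtype=torch.float32)

# The first element (`sample`) is positional; the trailing dict is interpreted by
# torch.onnx.export as the UNet's forward kwargs (e.g. `added_cond_kwargs`).
# Note: export_onnx additionally wraps this call in brevitas_proxy_export_mode
# with the chosen export manager.
torch.onnx.export(pipe.unet, args=trace_inputs, f='unet.onnx')
```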
