Fix (examples/stable_diffusion): fix for bitwidth and export (#824)
Giuseppe5 authored Feb 6, 2024
1 parent 45a7acc commit c768441
Showing 2 changed files with 40 additions and 16 deletions.
src/brevitas_examples/stable_diffusion/main.py (42 changes: 29 additions, 13 deletions)
@@ -10,15 +10,20 @@
 import re
 import time
 
+from dependencies import value
 from diffusers import StableDiffusionPipeline
 import torch
 from torch import nn
 
+from brevitas.export.onnx.standard.qcdq.manager import StdQCDQONNXManager
+from brevitas.export.torch.qcdq.manager import TorchQCDQManager
 from brevitas_examples.common.generative.quantize import quantize_model
 from brevitas_examples.common.parse_utils import add_bool_arg
 from brevitas_examples.common.parse_utils import quant_format_validator
+from brevitas_examples.llm.llm_quant.export import BlockQuantProxyLevelManager
 from brevitas_examples.stable_diffusion.sd_quant.constants import SD_2_1_EMBEDDINGS_SHAPE
-from brevitas_examples.stable_diffusion.sd_quant.export import export_torchscript_weight_group_quant
+from brevitas_examples.stable_diffusion.sd_quant.export import export_onnx
+from brevitas_examples.stable_diffusion.sd_quant.export import export_torchscript
 from brevitas_examples.stable_diffusion.sd_quant.utils import generate_latents
 from brevitas_examples.stable_diffusion.sd_quant.utils import generate_unet_rand_inputs
 from brevitas_examples.stable_diffusion.sd_quant.utils import unet_input_shape
@@ -87,24 +92,23 @@ def main(args):
     # Quantize model
     if args.quantize:
 
-        def bit_width_fn(module):
+        @value
+        def bit_width(module):
             if isinstance(module, nn.Linear):
                 return args.linear_weight_bit_width
             elif isinstance(module, nn.Conv2d):
                 return args.conv_weight_bit_width
             else:
                 raise RuntimeError(f"Module {module} not supported.")
 
-        weight_bit_width = lambda module: bit_width_fn(module)
-
         print("Applying model quantization...")
         quantize_model(
             pipe.unet,
             dtype=dtype,
             name_blacklist=blacklist,
             weight_quant_format=args.weight_quant_format,
             weight_quant_type=args.weight_quant_type,
-            weight_bit_width=weight_bit_width,
+            weight_bit_width=bit_width,
             weight_param_method=args.weight_param_method,
             weight_scale_precision=args.weight_scale_precision,
             weight_quant_granularity=args.weight_quant_granularity,
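Why `@value` rather than the old lambda: `quantize_model` resolves its quantizer options through the `dependencies` injection framework, so a bare callable is treated as an ordinary value rather than being invoked per module, which is presumably why the ad-hoc lambda was dropped. Decorating the function with `@value` registers it as a lazily computed dependency whose `module` parameter is injected when each layer is quantized. A minimal sketch of that mechanism using the `dependencies` package directly (the `Example` container and `base_bits` name below are illustrative, not code from the commit):

```python
from dependencies import Injector, value


class Example(Injector):
    # Plain attributes are injected by name.
    base_bits = 8

    # A @value function is evaluated lazily: its parameters are
    # resolved by name from the other members of the container.
    @value
    def weight_bit_width(base_bits):
        return base_bits


assert Example.weight_bit_width == 8
```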
@@ -127,17 +131,29 @@ def bit_width_fn(module):
     # Move to cpu and to float32 to enable CPU export
     pipe.unet.to('cpu').to(torch.float32)
     pipe.unet.eval()
-    if args.export_target == 'torchscript_weight_group_quant':
-        assert args.weight_quant_granularity == 'per_group', "Per-group quantization required."
+    device = next(iter(pipe.unet.parameters())).device
+    dtype = next(iter(pipe.unet.parameters())).dtype
+    if args.export_target:
+        assert args.weight_quant_format == 'int', "Only integer quantization supported for export."
         trace_inputs = generate_unet_rand_inputs(
             embedding_shape=SD_2_1_EMBEDDINGS_SHAPE,
             unet_input_shape=unet_input_shape(args.resolution),
-            device='cpu',
-            dtype=torch.float32)
-        export_torchscript_weight_group_quant(pipe, trace_inputs, output_dir)
-    else:
-        raise ValueError(f"{args.export_target} not recognized.")
+            device=device,
+            dtype=dtype)
+        if args.export_target == 'torchscript':
+            if args.weight_quant_granularity == 'per_group':
+                export_manager = BlockQuantProxyLevelManager
+            else:
+                export_manager = TorchQCDQManager
+                export_manager.change_weight_export(export_weight_q_node=True)
+            export_torchscript(pipe, trace_inputs, output_dir, export_manager)
+        elif args.export_target == 'onnx':
+            if args.weight_quant_granularity == 'per_group':
+                export_manager = BlockQuantProxyLevelManager
+            else:
+                export_manager = StdQCDQONNXManager
+                export_manager.change_weight_export(export_weight_q_node=True)
+            export_onnx(pipe, trace_inputs, output_dir, export_manager)
 
 
 if __name__ == "__main__":
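The dispatch above pairs each export target with a Brevitas export manager: per-group (block) weight quantization uses `BlockQuantProxyLevelManager` for both targets, while the other granularities go through the QCDQ managers with the weight quantize node kept in the exported graph. A condensed sketch of that selection (the helper name `select_export_manager` is mine, not from the commit):

```python
from brevitas.export.onnx.standard.qcdq.manager import StdQCDQONNXManager
from brevitas.export.torch.qcdq.manager import TorchQCDQManager
from brevitas_examples.llm.llm_quant.export import BlockQuantProxyLevelManager


def select_export_manager(export_target: str, weight_quant_granularity: str):
    # Grouped (block) weight quantization is exported at proxy level
    # for both TorchScript and ONNX targets.
    if weight_quant_granularity == 'per_group':
        return BlockQuantProxyLevelManager
    manager = TorchQCDQManager if export_target == 'torchscript' else StdQCDQONNXManager
    # Keep the weight quantize node in the exported graph, as the diff does.
    manager.change_weight_export(export_weight_q_node=True)
    return manager
```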
@@ -174,7 +190,7 @@ def bit_width_fn(module):
         '--export-target',
         type=str,
         default='',
-        choices=['', 'torchscript_weight_group_quant'],
+        choices=['', 'torchscript', 'onnx'],
         help='Target export flow.')
     parser.add_argument(
         '--conv-weight-bit-width', type=int, default=8, help='Weight bit width. Default: 8.')
src/brevitas_examples/stable_diffusion/sd_quant/export.py (14 changes: 11 additions, 3 deletions)
@@ -12,6 +12,7 @@
 from brevitas.backport.fx.experimental.proxy_tensor import make_fx
 from brevitas.export.manager import _force_requires_grad_false
 from brevitas.export.manager import _JitTraceExportWrapper
+from brevitas_examples.llm.llm_quant.export import BlockQuantProxyLevelManager
 from brevitas_examples.llm.llm_quant.export import brevitas_proxy_export_mode
 
 
@@ -25,8 +26,8 @@ def forward(self, *args, **kwargs):
         return self.unet(*args, **kwargs, return_dict=False)
 
 
-def export_torchscript_weight_group_quant(pipe, trace_inputs, output_dir):
-    with brevitas_proxy_export_mode(pipe.unet):
+def export_torchscript(pipe, trace_inputs, output_dir, export_manager):
+    with torch.no_grad(), brevitas_proxy_export_mode(pipe.unet, export_manager):
         fx_g = make_fx(
             UnetExportWrapper(pipe.unet),
             decomposition_table=get_decompositions([
@@ -39,9 +40,16 @@ def export_torchscript_weight_group_quant(pipe, trace_inputs, output_dir):
                 torch.ops.aten.upsample_bilinear2d.vec,
                 torch.ops.aten.split.Tensor,
                 torch.ops.aten.split_with_sizes,]),
-        )(*trace_inputs.values())
+        )(*tuple(trace_inputs.values()))
     _force_requires_grad_false(fx_g)
     jit_g = torch.jit.trace(_JitTraceExportWrapper(fx_g), tuple(trace_inputs.values()))
     output_path = os.path.join(output_dir, 'unet.ts')
     print(f"Saving unet to {output_path} ...")
     torch.jit.save(jit_g, output_path)
+
+
+def export_onnx(pipe, trace_inputs, output_dir, export_manager):
+    output_path = os.path.join(output_dir, 'unet.onnx')
+    print(f"Saving unet to {output_path} ...")
+    with torch.no_grad(), brevitas_proxy_export_mode(pipe.unet, export_manager):
+        torch.onnx.export(pipe.unet, args=tuple(trace_inputs.values()), f=output_path)
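As a sanity check, both artifacts load back with stock tooling and need nothing Brevitas-specific at inference time. A sketch under stated assumptions: the files were written to `output_dir` as above, and the input shapes below are illustrative placeholders for what `generate_unet_rand_inputs` and `SD_2_1_EMBEDDINGS_SHAPE` actually produce:

```python
import onnxruntime as ort
import torch

# Illustrative stand-ins for the traced UNet inputs: latent sample,
# timestep, text-encoder hidden states. Real shapes come from
# generate_unet_rand_inputs(...) in the example script.
sample = torch.randn(1, 4, 64, 64)
timestep = torch.tensor(1)
encoder_hidden_states = torch.randn(1, 77, 1024)

# TorchScript round trip.
ts_unet = torch.jit.load('output_dir/unet.ts').eval()
with torch.no_grad():
    ts_out = ts_unet(sample, timestep, encoder_hidden_states)

# ONNX round trip through onnxruntime.
session = ort.InferenceSession(
    'output_dir/unet.onnx', providers=['CPUExecutionProvider'])
feeds = {
    inp.name: tensor.numpy()
    for inp, tensor in zip(session.get_inputs(), (sample, timestep, encoder_hidden_states))}
onnx_out = session.run(None, feeds)
```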
