From 7f73055a4bab65b325303cfe304b3ee215bc77ae Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Sat, 27 Apr 2024 19:22:55 +0100 Subject: [PATCH 01/22] Feat (examples/stable_diffusion): improvements to SD --- src/brevitas/graph/gpfq.py | 5 + src/brevitas/graph/gptq.py | 5 + src/brevitas/graph/gpxq.py | 13 + .../common/generative/quantize.py | 30 +- .../stable_diffusion/README.md | 63 +- .../stable_diffusion/main.py | 367 +++++++++-- .../mlperf_evaluation/accuracy.py | 526 +++++++++++++++ .../mlperf_evaluation/backend.py | 610 ++++++++++++++++++ .../mlperf_evaluation/dataset.py | 359 +++++++++++ .../mlperf_evaluation/requirements.txt | 8 + .../stable_diffusion/sd_quant/export.py | 14 +- .../stable_diffusion/sd_quant/utils.py | 73 +++ 12 files changed, 1991 insertions(+), 82 deletions(-) create mode 100644 src/brevitas_examples/stable_diffusion/mlperf_evaluation/accuracy.py create mode 100644 src/brevitas_examples/stable_diffusion/mlperf_evaluation/backend.py create mode 100644 src/brevitas_examples/stable_diffusion/mlperf_evaluation/dataset.py create mode 100644 src/brevitas_examples/stable_diffusion/mlperf_evaluation/requirements.txt diff --git a/src/brevitas/graph/gpfq.py b/src/brevitas/graph/gpfq.py index fd7df9223..9a8adc8e9 100644 --- a/src/brevitas/graph/gpfq.py +++ b/src/brevitas/graph/gpfq.py @@ -298,6 +298,9 @@ def single_layer_update(self): # No permutation, permutation tensor is a ordered index perm = torch.tensor(range(weight.shape[-1]), device=dev) permutation_list.append(perm) + + self.reactivate_quantization() + for t in range(weight.shape[-1]): for group_index in range(self.groups): U[group_index] += torch.matmul( @@ -397,6 +400,8 @@ def single_layer_update(self): perm = torch.tensor(range(weight.shape[-1]), device=dev) permutation_list.append(perm) + self.reactivate_quantization() + for t in range(weight.shape[-1]): for group_index in range(self.groups): U[group_index] += torch.matmul( diff --git a/src/brevitas/graph/gptq.py b/src/brevitas/graph/gptq.py index 31d31433b..0861fd15c 100644 --- a/src/brevitas/graph/gptq.py +++ b/src/brevitas/graph/gptq.py @@ -86,9 +86,11 @@ def catch_stopfwd(self, *args, **kwargs): # If we want to return the output of the network, we need to disable all hooks for name, gpxq_class in self.gpxq_layers.items(): gpxq_class.disable_pre_forward_hook = True + out = self.orig_forward(*args, **kwargs) for name, gpxq_class in self.gpxq_layers.items(): gpxq_class.disable_pre_forward_hook = False + return out def initialize_module_optimizer( @@ -134,6 +136,7 @@ def __init__( device='cpu', dtype=torch.float32) self.nsamples = 0 + self.done = False assert torch_version >= version.parse('1.10'), "GPTQ requires torch 1.10 or higher" @@ -257,6 +260,8 @@ def single_layer_update(self, percdamp=.01): finally: del self.H + self.reactivate_quantization() + for i1 in range(0, self.columns, self.blocksize): i2 = min(i1 + self.blocksize, self.columns) count = i2 - i1 diff --git a/src/brevitas/graph/gpxq.py b/src/brevitas/graph/gpxq.py index fdbaee52f..2b46ac4f4 100644 --- a/src/brevitas/graph/gpxq.py +++ b/src/brevitas/graph/gpxq.py @@ -164,6 +164,10 @@ def __enter__(self): return self def __exit__(self, type, value, traceback): + for name, layer in self.gpxq_layers.items(): + if not layer.done: + layer.reactivate_quantization() + if isinstance(self.model, (GraphModule, TorchGraphModule)): self.model.__class__.forward = self.orig_forward else: @@ -219,6 +223,10 @@ def __init__( self.disable_pre_forward_hook = False # Some layers require knowledge from quant inputs to compute quant weights self.quant_metadata = None + self.disable_quant_inference = DisableEnableQuantization() + self.return_quant_tensor_state = disable_return_quant_tensor(self.layer) + self.disable_quant_inference.disable_param_quantization(self.layer, False) + self.done = False def process_input(self, inp): # Input is a tuple, so we take first element @@ -255,6 +263,11 @@ def update_batch(self): def single_layer_update(self): pass + def reactivate_quantization(self): + self.done = True + self.disable_quant_inference.enable_param_quantization(self.layer, False) + restore_return_quant_tensor(self.layer, self.return_quant_tensor_state) + def get_quant_weights(self, i, i1, permutation_list): # We need to recompute quant weights at runtime since our float weights are being updated # Add offset in case of blockwise computation diff --git a/src/brevitas_examples/common/generative/quantize.py b/src/brevitas_examples/common/generative/quantize.py index 31ab57361..0b1614e74 100644 --- a/src/brevitas_examples/common/generative/quantize.py +++ b/src/brevitas_examples/common/generative/quantize.py @@ -4,6 +4,7 @@ """ import re +import torch from torch import nn from brevitas import nn as qnn @@ -13,6 +14,8 @@ from brevitas.quant.experimental.float import Fp8e4m3ActPerTensorFloat from brevitas.quant.experimental.float import Fp8e4m3WeightPerChannelFloat from brevitas.quant.experimental.float import Fp8e4m3WeightPerTensorFloat +from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeightPerChannelFloat +from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeightPerTensorFloat from brevitas.quant.fixed_point import Int8ActPerTensorFixedPoint from brevitas.quant.fixed_point import Int8ActPerTensorFixedPointMSE from brevitas.quant.fixed_point import Int8WeightPerChannelFixedPoint @@ -79,7 +82,14 @@ 'per_channel': { 'sym': Fp8e4m3WeightPerChannelFloat}, 'per_group': { - 'sym': Fp8e4m3WeightSymmetricGroupQuant}},}}} + 'sym': Fp8e4m3WeightSymmetricGroupQuant}}}}, + 'float_ocp': { + 'float_scale': { + 'stats': { + 'per_tensor': { + 'sym': Fp8e4m3OCPWeightPerTensorFloat}, + 'per_channel': { + 'sym': Fp8e4m3OCPWeightPerChannelFloat}}}}} INPUT_QUANT_MAP = { 'int': { @@ -142,7 +152,10 @@ def quantize_model( input_group_size=None, quantize_input_zero_point=False, quantize_embedding=False, - device=None): + use_ocp=False, + device=None, + weight_kwargs=None, + input_kwargs=None): """ Replace float layers with quant layers in the target model """ @@ -154,6 +167,8 @@ def quantize_model( 'exponent_bit_width': int(weight_quant_format[1]), 'mantissa_bit_width': int(weight_quant_format[3])} weight_quant_format = 'float' + if use_ocp: + weight_quant_format += '_ocp' else: weight_float_format = {} if re.compile(r'e[1-8]m[1-8]').match(input_quant_format): @@ -161,6 +176,8 @@ def quantize_model( 'exponent_bit_width': int(input_quant_format[1]), 'mantissa_bit_width': int(input_quant_format[3])} input_quant_format = 'float' + if use_ocp: + input_quant_format += '_ocp' else: input_float_format = {} @@ -178,6 +195,11 @@ def quantize_model( linear_input_quant = INPUT_QUANT_MAP[input_quant_format][input_scale_type][ input_scale_precision][input_param_method][input_quant_granularity][input_quant_type] + if input_kwargs is not None: + input_quant = input_quant.let(**input_kwargs) + sym_input_quant = sym_input_quant.let(**input_kwargs) + linear_input_quant = linear_input_quant.let(**input_kwargs) + else: input_quant = None sym_input_quant = None @@ -190,6 +212,10 @@ def quantize_model( 'narrow_range': False, 'quantize_zero_point': quantize_weight_zero_point}, **weight_float_format) + if dtype == torch.float16: + weight_quant = weight_quant.let(**{'scaling_min_val': 1e-4}) + if weight_kwargs is not None: + weight_quant = weight_quant.let(**weight_kwargs) # Set the group_size is we're doing groupwise quantization if weight_quant_granularity == 'per_group': diff --git a/src/brevitas_examples/stable_diffusion/README.md b/src/brevitas_examples/stable_diffusion/README.md index 30754b3d8..1c71e4431 100644 --- a/src/brevitas_examples/stable_diffusion/README.md +++ b/src/brevitas_examples/stable_diffusion/README.md @@ -16,18 +16,25 @@ We support ONNX integer export, and we are planning to release soon export for f To export the model with fp16 scale factors, enable `export-cuda-float16`. This will performing the tracing necessary for export on GPU, leaving the model in fp16. If the flag is not enabled, the model will be moved to CPU and cast to float32 before export because of missing CPU kernels in fp16. +To use MLPerf inference setup, check and install the correct requirements specified in the `requirements.txt` file under mlperf_evaluation. ## Run ```bash usage: main.py [-h] [-m MODEL] [-d DEVICE] [-b BATCH_SIZE] [--prompt PROMPT] - [--resolution RESOLUTION] + [--calibration-prompt CALIBRATION_PROMPT] + [--calibration-prompt-path CALIBRATION_PROMPT_PATH] + [--checkpoint-name CHECKPOINT_NAME] + [--path-to-latents PATH_TO_LATENTS] [--resolution RESOLUTION] + [--guidance-scale GUIDANCE_SCALE] + [--calibration-steps CALIBRATION_STEPS] [--output-path OUTPUT_PATH | --no-output-path] [--quantize | --no-quantize] [--activation-equalization | --no-activation-equalization] - [--gptq | --no-gptq] [--float16 | --no-float16] + [--gptq | --no-gptq] [--bias-correction | --no-bias-correction] + [--dtype {float32,float16,bfloat16}] [--attention-slicing | --no-attention-slicing] - [--export-target {,onnx}] + [--export-target {,torch,onnx}] [--export-weight-q-node | --no-export-weight-q-node] [--conv-weight-bit-width CONV_WEIGHT_BIT_WIDTH] [--linear-weight-bit-width LINEAR_WEIGHT_BIT_WIDTH] @@ -47,6 +54,9 @@ usage: main.py [-h] [-m MODEL] [-d DEVICE] [-b BATCH_SIZE] [--prompt PROMPT] [--weight-group-size WEIGHT_GROUP_SIZE] [--quantize-weight-zero-point | --no-quantize-weight-zero-point] [--export-cuda-float16 | --no-export-cuda-float16] + [--use-mlperf-inference | --no-use-mlperf-inference] + [--use-ocp | --no-use-ocp] + [--use-negative-prompts | --no-use-negative-prompts] Stable Diffusion quantization @@ -57,12 +67,27 @@ options: -d DEVICE, --device DEVICE Target device for quantized model. -b BATCH_SIZE, --batch-size BATCH_SIZE - Batch size. Default: 4 - --prompt PROMPT Manual prompt for testing. Default: An austronaut - riding a horse on Mars. + How many seeds to use for each image during + validation. Default: 2 + --prompt PROMPT Number of prompt to use for testing. Default: 4. Max: + 4 + --calibration-prompt CALIBRATION_PROMPT + Number of prompt to use for calibration. Default: 2 + --calibration-prompt-path CALIBRATION_PROMPT_PATH + Path to calibration prompt + --checkpoint-name CHECKPOINT_NAME + Name to use to store the checkpoint. If not provided, + no checkpoint is saved. + --path-to-latents PATH_TO_LATENTS + Load pre-defined latents. If not provided, they are + generated based on an internal seed. --resolution RESOLUTION Resolution along height and width dimension. Default: 512. + --guidance-scale GUIDANCE_SCALE + Guidance scale. + --calibration-steps CALIBRATION_STEPS + Percentage of steps used during calibration --output-path OUTPUT_PATH Path where to generate output folder. --no-output-path Disable Path where to generate output folder. @@ -76,12 +101,15 @@ options: Disabled --gptq Enable Toggle gptq. Default: Disabled --no-gptq Disable Toggle gptq. Default: Disabled - --float16 Enable Enable float16 execution. Default: Enabled - --no-float16 Disable Enable float16 execution. Default: Enabled + --bias-correction Enable Toggle bias-correction. Default: Enabled + --no-bias-correction Disable Toggle bias-correction. Default: Enabled + --dtype {float32,float16,bfloat16} + Model Dtype, choices are float32, float16, bfloat16. + Default: float16 --attention-slicing Enable Enable attention slicing. Default: Disabled --no-attention-slicing Disable Enable attention slicing. Default: Disabled - --export-target {,onnx} + --export-target {,torch,onnx} Target export flow. --export-weight-q-node Enable Enable export of floating point weights + QDQ @@ -137,4 +165,21 @@ options: Enable Export FP16 on CUDA. Default: Disabled --no-export-cuda-float16 Disable Export FP16 on CUDA. Default: Disabled + --use-mlperf-inference + Enable Evaluate FID score with MLPerf pipeline. + Default: False + --no-use-mlperf-inference + Disable Evaluate FID score with MLPerf pipeline. + Default: False + --use-ocp Enable Use OCP format for float quantization. Default: + True + --no-use-ocp Disable Use OCP format for float quantization. + Default: True + --use-negative-prompts + Enable Use negative prompts during + generation/calibration. Default: Enabled + --no-use-negative-prompts + Disable Use negative prompts during + generation/calibration. Default: Enabled + ``` diff --git a/src/brevitas_examples/stable_diffusion/main.py b/src/brevitas_examples/stable_diffusion/main.py index 5d626accb..b5c2eee5f 100644 --- a/src/brevitas_examples/stable_diffusion/main.py +++ b/src/brevitas_examples/stable_diffusion/main.py @@ -5,6 +5,7 @@ import argparse from datetime import datetime +from functools import partial import json import os import time @@ -12,24 +13,32 @@ from dependencies import value from diffusers import DiffusionPipeline from diffusers import StableDiffusionXLPipeline +import numpy as np +import pandas as pd import torch from torch import nn +from torchmetrics.image.fid import FrechetInceptionDistance from tqdm import tqdm from brevitas.export.onnx.standard.qcdq.manager import StdQCDQONNXManager +from brevitas.export.torch.qcdq.manager import TorchQCDQManager from brevitas.graph.calibrate import bias_correction_mode from brevitas.graph.calibrate import calibration_mode from brevitas.graph.equalize import activation_equalization_mode from brevitas.graph.gptq import gptq_mode +from brevitas.inject.enum import QuantType from brevitas.nn.equalized_layer import EqualizedModule +from brevitas.nn.forward_handlers import brevitas_proxy_inference_mode from brevitas.utils.torch_utils import KwargsForwardHook from brevitas_examples.common.generative.quantize import quantize_model from brevitas_examples.common.parse_utils import add_bool_arg from brevitas_examples.common.parse_utils import quant_format_validator from brevitas_examples.llm.llm_quant.export import BlockQuantProxyLevelManager +from brevitas_examples.stable_diffusion.mlperf_evaluation.accuracy import compute_mlperf_fid from brevitas_examples.stable_diffusion.sd_quant.constants import SD_2_1_EMBEDDINGS_SHAPE from brevitas_examples.stable_diffusion.sd_quant.constants import SD_XL_EMBEDDINGS_SHAPE from brevitas_examples.stable_diffusion.sd_quant.export import export_onnx +from brevitas_examples.stable_diffusion.sd_quant.export import export_torch_export from brevitas_examples.stable_diffusion.sd_quant.utils import generate_latents from brevitas_examples.stable_diffusion.sd_quant.utils import generate_unet_21_rand_inputs from brevitas_examples.stable_diffusion.sd_quant.utils import generate_unet_xl_rand_inputs @@ -37,35 +46,88 @@ TEST_SEED = 123456 -VALIDATION_PROMPTS = { - 'validation_prompt_0': 'A cat playing with a ball', - 'validation_prompt_1': 'A dog running on the beach'} +NEGATIVE_PROMPTS = ["normal quality, low quality, worst quality, low res, blurry, nsfw, nude"] + +CALIBRATION_PROMPTS = [ + 'A man in a space suit playing a guitar, inspired by Cyril Rolando, highly detailed illustration, full color illustration, very detailed illustration, dan mumford and alex grey style', + 'a living room, bright modern Scandinavian style house, large windows, magazine photoshoot, 8k, studio lighting', + 'cute rabbit in a spacesuit', + 'minimalistic plolygon geometric car in brutalism warehouse, Rick Owens'] + +TESTING_PROMPTS = [ + 'batman, cute modern disney style, Pixar 3d portrait, ultra detailed, gorgeous, 3d zbrush, trending on dribbble, 8k render', + 'A beautiful stack of rocks sitting on top of a beach, a picture, red black white golden colors, chakras, packshot, stock photo', + 'A painting of a fish on a black background, a digital painting, by Jason Benjamin, colorful vector illustration, mixed media style illustration, epic full color illustration, mascot illustration', + 'close up photo of a rabbit, forest in spring, haze, halation, bloom, dramatic atmosphere, centred, rule of thirds, 200mm 1.4f macro shot' +] + + +def load_calib_prompts(calib_data_path, sep="\t"): + df = pd.read_csv(calib_data_path, sep=sep) + lst = df["caption"].tolist() + return lst def run_test_inference( - pipe, resolution, prompts, seeds, output_path, device, dtype, name_prefix=''): + pipe, + resolution, + prompts, + seeds, + output_path, + device, + dtype, + use_negative_prompts, + guidance_scale, + name_prefix=''): + images = dict() with torch.no_grad(): if not os.path.exists(output_path): os.mkdir(output_path) test_latents = generate_latents(seeds, device, dtype, unet_input_shape(resolution)) - - for name, prompt in prompts.items(): - print(f"Generating: {name}") - images = pipe([prompt] * len(seeds), latents=test_latents).images - for i, seed in enumerate(seeds): - file_path = os.path.join(output_path, f"{name_prefix}{name}_{seed}.png") + neg_prompts = NEGATIVE_PROMPTS * len(seeds) if use_negative_prompts else [] + for prompt in prompts: + prompt_images = pipe([prompt] * len(seeds), + latents=test_latents, + negative_prompt=neg_prompts, + guidance_scale=guidance_scale).images + images[prompt] = prompt_images + + i = 0 + for prompt, prompt_images in images.items(): + for image in prompt_images: + file_path = os.path.join(output_path, f"{name_prefix}{i}.png") print(f"Saving to {file_path}") - images[i].save(file_path) - - -def run_val_inference(pipe, resolution, prompts, seeds, output_path, device, dtype, name_prefix=''): + image.save(file_path) + i += 1 + return images + + +def run_val_inference( + pipe, + resolution, + prompts, + seeds, + device, + dtype, + use_negative_prompts, + guidance_scale, + total_steps, + test_latents=None): with torch.no_grad(): - test_latents = generate_latents(seeds, device, dtype, unet_input_shape(resolution)) - for name, prompt in prompts.items(): - print(f"Generating: {name}") + if test_latents is None: + test_latents = generate_latents(seeds[0], device, dtype, unet_input_shape(resolution)) + + neg_prompts = NEGATIVE_PROMPTS if use_negative_prompts else [] + for prompt in prompts: # We don't want to generate any image, so we return only the latent encoding pre VAE - pipe([prompt] * len(seeds), latents=test_latents, output_type='latent') + pipe( + prompt, + negative_prompt=neg_prompts[0], + latents=test_latents, + output_type='latent', + guidance_scale=guidance_scale, + num_inference_steps=total_steps) def main(args): @@ -73,11 +135,21 @@ def main(args): if args.export_target: assert args.weight_quant_format == 'int', "Currently only integer quantization supported for export." - # Select dtype - if args.float16: - dtype = torch.float16 - else: - dtype = torch.float32 + dtype = getattr(torch, args.dtype) + + calibration_prompts = CALIBRATION_PROMPTS + if args.calibration_prompt_path is not None: + calibration_prompts = load_calib_prompts(args.calibration_prompt_path) + prompts = list() + for i, v in enumerate(calibration_prompts): + if i == args.calibration_prompt: + break + prompts.append(v) + calibration_prompts = prompts + + latents = None + if args.path_to_latents is not None: + latents = torch.load(args.path_to_latents).to(torch.float16) # Create output dir. Move to tmp if None ts = datetime.fromtimestamp(time.time()) @@ -97,6 +169,29 @@ def main(args): pipe = DiffusionPipeline.from_pretrained(args.model, torch_dtype=dtype) print(f"Model loaded from {args.model}.") + # Move model to target device + print(f"Moving model to {args.device}...") + pipe = pipe.to(args.device) + + if args.prompt > 0 and not args.use_mlperf_inference: + print(f"Running inference with prompt ...") + prompts = [] + for i, v in enumerate(TESTING_PROMPTS): + if i == args.prompt: + break + prompts.append(v) + float_images = run_test_inference( + pipe, + args.resolution, + prompts, + test_seeds, + output_dir, + args.device, + dtype, + guidance_scale=args.guidance_scale, + use_negative_prompts=args.use_negative_prompts, + name_prefix='float_') + # Detect Stable Diffusion XL pipeline is_sd_xl = isinstance(pipe, StableDiffusionXLPipeline) @@ -116,19 +211,23 @@ def main(args): if hasattr(m, 'lora_layer') and m.lora_layer is not None: raise RuntimeError("LoRA layers should be fused in before calling into quantization.") - # Move model to target device - print(f"Moving model to {args.device}...") - pipe = pipe.to(args.device) - if args.activation_equalization: with activation_equalization_mode(pipe.unet, alpha=0.5, layerwise=True, add_mul_node=True): # Workaround to expose `in_features` attribute from the Hook Wrapper for m in pipe.unet.modules(): if isinstance(m, KwargsForwardHook) and hasattr(m.module, 'in_features'): m.in_features = m.module.in_features - prompts = VALIDATION_PROMPTS run_val_inference( - pipe, args.resolution, prompts, test_seeds, output_dir, args.device, dtype) + pipe, + args.resolution, + calibration_prompts, + test_seeds, + args.device, + dtype, + total_steps=args.calibration_steps, + use_negative_prompts=args.use_negative_prompts, + test_latents=latents, + guidance_scale=args.guidance_scale) # Workaround to expose `in_features` attribute from the EqualizedModule Wrapper for m in pipe.unet.modules(): @@ -147,22 +246,28 @@ def weight_bit_width(module): else: raise RuntimeError(f"Module {module} not supported.") - # XOR between the two input_bit_width. Either they are both None, or neither of them is - assert (args.linear_input_bit_width is None) == (args.conv_input_bit_width is None), 'Both input bit width must be specified or left to None' + @value + def input_bit_width(module): + if isinstance(module, nn.Linear): + return args.linear_input_bit_width + elif isinstance(module, nn.Conv2d): + return args.conv_input_bit_width + else: + raise RuntimeError(f"Module {module} not supported.") - is_input_quantized = args.linear_input_bit_width is not None and args.conv_input_bit_width is not None - if is_input_quantized: + input_kwargs = dict() + if args.linear_input_bit_width is None or args.conv_input_bit_width is None: @value - def input_bit_width(module): - if isinstance(module, nn.Linear): - return args.linear_input_bit_width - elif isinstance(module, nn.Conv2d): - return args.conv_input_bit_width + def input_quant_type(module): + if args.linear_input_bit_width is None and isinstance(module, nn.Linear): + return QuantType.FP + elif args.conv_input_bit_width is None and isinstance(module, nn.Conv2d): + return QuantType.FP else: - raise RuntimeError(f"Module {module} not supported.") - else: - input_bit_width = None + return QuantType.INT + + input_kwargs['quant_type'] = input_quant_type print("Applying model quantization...") quantize_model( @@ -184,46 +289,115 @@ def input_bit_width(module): input_scale_precision=args.input_scale_precision, input_param_method=args.input_param_method, input_quant_type=args.input_quant_type, - input_quant_granularity=args.input_quant_granularity) + input_quant_granularity=args.input_quant_granularity, + use_ocp=args.use_ocp, + input_kwargs=input_kwargs) print("Model quantization applied.") - if is_input_quantized and args.input_scale_type == 'static': + if (args.linear_input_bit_width is not None or + args.conv_input_bit_width is not None) and args.input_scale_type == 'static': print("Applying activation calibration") - with calibration_mode(pipe.unet): - prompts = VALIDATION_PROMPTS + with brevitas_proxy_inference_mode(pipe.unet), torch.no_grad(), calibration_mode(pipe.unet): run_val_inference( - pipe, args.resolution, prompts, test_seeds, output_dir, args.device, dtype) + pipe, + args.resolution, + calibration_prompts, + test_seeds, + args.device, + dtype, + total_steps=args.calibration_steps, + use_negative_prompts=args.use_negative_prompts, + test_latents=latents, + guidance_scale=args.guidance_scale) + pipe.set_progress_bar_config(disable=True) if args.gptq: print("Applying GPTQ. It can take several hours") - with gptq_mode(pipe.unet, + with torch.no_grad(), gptq_mode(pipe.unet, create_weight_orig=False, use_quant_activations=False, return_forward_output=True, act_order=True) as gptq: - prompts = VALIDATION_PROMPTS for _ in tqdm(range(gptq.num_layers)): run_val_inference( - pipe, args.resolution, prompts, test_seeds, output_dir, args.device, dtype) + pipe, + args.resolution, + calibration_prompts, + test_seeds, + args.device, + dtype, + total_steps=args.calibration_steps, + use_negative_prompts=args.use_negative_prompts, + test_latents=latents, + guidance_scale=args.guidance_scale) gptq.update() + torch.cuda.empty_cache() + pipe.set_progress_bar_config(disable=False) - print("Applying bias correction") - with bias_correction_mode(pipe.unet): - prompts = VALIDATION_PROMPTS - run_val_inference( - pipe, args.resolution, prompts, test_seeds, output_dir, args.device, dtype) + if args.bias_correction: + print("Applying bias correction") + with brevitas_proxy_inference_mode(pipe.unet), bias_correction_mode(pipe.unet): + run_val_inference( + pipe, + args.resolution, + calibration_prompts, + test_seeds, + args.device, + dtype, + total_steps=args.calibration_steps, + use_negative_prompts=args.use_negative_prompts, + test_latents=latents, + guidance_scale=args.guidance_scale) + + if args.checkpoint_name is not None: + torch.save(pipe.unet.state_dict(), args.checkpoint_name) # Perform inference - if args.prompt: - print(f"Running inference with prompt '{args.prompt}' ...") - prompts = {'manual_prompt': args.prompt} - run_test_inference( - pipe, args.resolution, prompts, test_seeds, output_dir, args.device, dtype) + if args.prompt > 0: + with brevitas_proxy_inference_mode(pipe.unet): + if args.use_mlperf_inference: + print(f"Computing accuracy with MLPerf pipeline") + compute_mlperf_fid(pipe.unet, args.prompt) + else: + print(f"Computing accuracy on default prompt") + prompts = list() + for i, v in enumerate(TESTING_PROMPTS): + if i == args.prompt: + break + prompts.append(v) + quant_images = run_test_inference( + pipe, + args.resolution, + prompts, + test_seeds, + output_dir, + args.device, + dtype, + use_negative_prompts=args.use_negative_prompts, + guidance_scale=args.guidance_scale, + name_prefix='quant_') + + float_images_values = float_images.values() + float_images_values = [x for x_nested in float_images_values for x in x_nested] + float_images_values = torch.tensor([ + np.array(image) for image in float_images_values]) + float_images_values = float_images_values.permute(0, 3, 1, 2) + + quant_images_values = quant_images.values() + quant_images_values = [x for x_nested in quant_images_values for x in x_nested] + quant_images_values = torch.tensor([ + np.array(image) for image in quant_images_values]) + quant_images_values = quant_images_values.permute(0, 3, 1, 2) + + fid = FrechetInceptionDistance(normalize=False) + fid.update(float_images_values, real=True) + fid.update(quant_images_values, real=False) + print(f"FID: {float(fid.compute())}") if args.export_target: # Move to cpu and to float32 to enable CPU export - if not (args.float16 and args.export_cuda_float16): - pipe.unet.to('cpu').to(torch.float32) + if not (dtype == torch.float16 and args.export_cuda_float16): + pipe.unet.to('cpu').to(dtype) pipe.unet.eval() device = next(iter(pipe.unet.parameters())).device dtype = next(iter(pipe.unet.parameters())).dtype @@ -248,6 +422,13 @@ def input_bit_width(module): export_manager = StdQCDQONNXManager export_manager.change_weight_export(export_weight_q_node=args.export_weight_q_node) export_onnx(pipe, trace_inputs, output_dir, export_manager) + if args.export_target == 'torch': + if args.weight_quant_granularity == 'per_group': + export_manager = BlockQuantProxyLevelManager + else: + export_manager = TorchQCDQManager + export_manager.change_weight_export(export_weight_q_node=args.export_weight_q_node) + export_torch_export(pipe, trace_inputs, output_dir, export_manager) if __name__ == "__main__": @@ -260,17 +441,46 @@ def input_bit_width(module): help='Path or name of the model.') parser.add_argument( '-d', '--device', type=str, default='cuda:0', help='Target device for quantized model.') - parser.add_argument('-b', '--batch-size', type=int, default=4, help='Batch size. Default: 4') + parser.add_argument( + '-b', + '--batch-size', + type=int, + default=2, + help='How many seeds to use for each image during validation. Default: 2') parser.add_argument( '--prompt', + type=int, + default=4, + help='Number of prompt to use for testing. Default: 4. Max: 4') + parser.add_argument( + '--calibration-prompt', + type=int, + default=2, + help='Number of prompt to use for calibration. Default: 2') + parser.add_argument( + '--calibration-prompt-path', type=str, default=None, help='Path to calibration prompt') + parser.add_argument( + '--checkpoint-name', type=str, - default='An austronaut riding a horse on Mars.', - help='Manual prompt for testing. Default: An austronaut riding a horse on Mars.') + default=None, + help='Name to use to store the checkpoint. If not provided, no checkpoint is saved.') + parser.add_argument( + '--path-to-latents', + type=str, + default=None, + help= + 'Load pre-defined latents. If not provided, they are generated based on an internal seed.') parser.add_argument( '--resolution', type=int, default=512, help='Resolution along height and width dimension. Default: 512.') + parser.add_argument('--guidance-scale', type=float, default=7.5, help='Guidance scale.') + parser.add_argument( + '--calibration-steps', + type=float, + default=8, + help='Percentage of steps used during calibration') add_bool_arg( parser, 'output-path', @@ -284,14 +494,24 @@ def input_bit_width(module): default=False, help='Toggle Activation Equalization. Default: Disabled') add_bool_arg(parser, 'gptq', default=False, help='Toggle gptq. Default: Disabled') - add_bool_arg(parser, 'float16', default=True, help='Enable float16 execution. Default: Enabled') + add_bool_arg( + parser, 'bias-correction', default=True, help='Toggle bias-correction. Default: Enabled') + parser.add_argument( + '--dtype', + default='float16', + choices=['float32', 'float16', 'bfloat16'], + help='Model Dtype, choices are float32, float16, bfloat16. Default: float16') add_bool_arg( parser, 'attention-slicing', default=False, help='Enable attention slicing. Default: Disabled') parser.add_argument( - '--export-target', type=str, default='', choices=['', 'onnx'], help='Target export flow.') + '--export-target', + type=str, + default='', + choices=['', 'torch', 'onnx'], + help='Target export flow.') add_bool_arg( parser, 'export-weight-q-node', @@ -391,6 +611,21 @@ def input_bit_width(module): help='Quantize weight zero-point. Default: Enabled') add_bool_arg( parser, 'export-cuda-float16', default=False, help='Export FP16 on CUDA. Default: Disabled') + add_bool_arg( + parser, + 'use-mlperf-inference', + default=False, + help='Evaluate FID score with MLPerf pipeline. Default: False') + add_bool_arg( + parser, + 'use-ocp', + default=True, + help='Use OCP format for float quantization. Default: True') + add_bool_arg( + parser, + 'use-negative-prompts', + default=True, + help='Use negative prompts during generation/calibration. Default: Enabled') args = parser.parse_args() print("Args: " + str(vars(args))) main(args) diff --git a/src/brevitas_examples/stable_diffusion/mlperf_evaluation/accuracy.py b/src/brevitas_examples/stable_diffusion/mlperf_evaluation/accuracy.py new file mode 100644 index 000000000..71353e9cb --- /dev/null +++ b/src/brevitas_examples/stable_diffusion/mlperf_evaluation/accuracy.py @@ -0,0 +1,526 @@ +""" +Code is adapted from the MLPerf text-to-image pipeline: https://github.com/mlcommons/inference/tree/master/text_to_image +Available under the following LICENSE: + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS +""" + +import logging +import os +import pathlib +import random + +import numpy as np +from PIL import Image +from scipy import linalg +import torch +from torch.nn.functional import adaptive_avg_pool2d +import torchvision.transforms as TF +from tqdm import tqdm + +from brevitas_examples.inception import InceptionV3 +from brevitas_examples.stable_diffusion.mlperf_evaluation.backend import BackendPytorch +from brevitas_examples.stable_diffusion.mlperf_evaluation.backend import Item +from brevitas_examples.stable_diffusion.mlperf_evaluation.backend import RunnerBase +from brevitas_examples.stable_diffusion.mlperf_evaluation.dataset import Coco +from brevitas_examples.stable_diffusion.mlperf_evaluation.dataset import ImagesDataset + +IMAGE_EXTENSIONS = {'bmp', 'jpg', 'jpeg', 'pgm', 'png', 'ppm', 'tif', 'tiff', 'webp'} + +logging.basicConfig(level=logging.INFO) +log = logging.getLogger() + + +def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): + """Numpy implementation of the Frechet Distance. + The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) + and X_2 ~ N(mu_2, C_2) is + d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). + + Stable version by Dougal J. Sutherland. + + Params: + -- mu1 : Numpy array containing the activations of a layer of the + inception net (like returned by the function 'get_predictions') + for generated samples. + -- mu2 : The sample mean over activations, precalculated on an + representative data set. + -- sigma1: The covariance matrix over activations for generated samples. + -- sigma2: The covariance matrix over activations, precalculated on an + representative data set. + + Returns: + -- : The Frechet Distance. + """ + + mu1 = np.atleast_1d(mu1) + mu2 = np.atleast_1d(mu2) + + sigma1 = np.atleast_2d(sigma1) + sigma2 = np.atleast_2d(sigma2) + + assert (mu1.shape == mu2.shape), "Training and test mean vectors have different lengths" + assert (sigma1.shape == sigma2.shape), "Training and test covariances have different dimensions" + + diff = mu1 - mu2 + + # Product might be almost singular + covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) + if not np.isfinite(covmean).all(): + msg = ( + "fid calculation produces singular product; " + "adding %s to diagonal of cov estimates") % eps + print(msg) + offset = np.eye(sigma1.shape[0]) * eps + covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) + + # Numerical error might give slight imaginary component + if np.iscomplexobj(covmean): + if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): + m = np.max(np.abs(covmean.imag)) + raise ValueError("Imaginary component {}".format(m)) + covmean = covmean.real + + tr_covmean = np.trace(covmean) + + return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean + + +def get_activations(files, model, batch_size=50, dims=2048, device="cpu", num_workers=1): + """Calculates the activations of the pool_3 layer for all images. + + Params: + -- files : List of image files paths + -- model : Instance of inception model + -- batch_size : Batch size of images for the model to process at once. + Make sure that the number of samples is a multiple of + the batch size, otherwise some samples are ignored. This + behavior is retained to match the original FID score + implementation. + -- dims : Dimensionality of features returned by Inception + -- device : Device to run calculations + -- num_workers : Number of parallel dataloader workers + + Returns: + -- A numpy array of dimension (num images, dims) that contains the + activations of the given tensor when feeding inception with the + query tensor. + """ + model.eval() + + if batch_size > len(files): + print(( + "Warning: batch size is bigger than the data size. " + "Setting batch size to data size")) + batch_size = len(files) + + dataset = ImagesDataset(files, transforms=TF.ToTensor()) + dataloader = torch.utils.data.DataLoader( + dataset, + batch_size=batch_size, + shuffle=False, + drop_last=False, + num_workers=num_workers, + ) + + pred_arr = np.empty((len(files), dims)) + + start_idx = 0 + + for batch in tqdm(dataloader): + batch = batch.to(device) + + with torch.no_grad(): + pred = model(batch)[0] + + # If model output is not scalar, apply global spatial average pooling. + # This happens if you choose a dimensionality not equal 2048. + if pred.size(2) != 1 or pred.size(3) != 1: + pred = adaptive_avg_pool2d(pred, output_size=(1, 1)) + + pred = pred.squeeze(3).squeeze(2).cpu().numpy() + + pred_arr[start_idx:start_idx + pred.shape[0]] = pred + + start_idx = start_idx + pred.shape[0] + + return pred_arr + + +def calculate_activation_statistics( + files, model, batch_size=50, dims=2048, device="cpu", num_workers=1): + """Calculation of the statistics used by the FID. + Params: + -- files : List of image files paths + -- model : Instance of inception model + -- batch_size : The images numpy array is split into batches with + batch size batch_size. A reasonable batch size + depends on the hardware. + -- dims : Dimensionality of features returned by Inception + -- device : Device to run calculations + -- num_workers : Number of parallel dataloader workers + + Returns: + -- mu : The mean over samples of the activations of the pool_3 layer of + the inception model. + -- sigma : The covariance matrix of the activations of the pool_3 layer of + the inception model. + """ + act = get_activations(files, model, batch_size, dims, device, num_workers) + mu = np.mean(act, axis=0) + sigma = np.cov(act, rowvar=False) + return mu, sigma + + +def compute_statistics_of_path( + path, + model, + batch_size, + dims, + device, + num_workers=1, + subset_size=None, + shuffle_seed=None, + ds=None): + if path.endswith(".npz"): + with np.load(path) as f: + m, s = f["mu"][:], f["sigma"][:] + else: + path = pathlib.Path(path) + files = [file for ext in IMAGE_EXTENSIONS for file in path.glob("*.{}".format(ext))] + + files = ds.get_imgs([i for i in range(10)]) + files = [file.permute(1, 2, 0).numpy() for file in files] + if subset_size is not None: + random.seed(shuffle_seed) + files = random.sample(files, subset_size) + m, s = calculate_activation_statistics(files, model, batch_size, dims, device, num_workers) + + return m, s + + +def compute_fid( + results, + statistics_path, + device, + dims=2048, + num_workers=1, + batch_size=1, + subset_size=None, + shuffle_seed=None, + ds=None, +): + imgs = [Image.fromarray(e).convert("RGB") for e in results] + device = torch.device(device if torch.cuda.is_available() else "cpu") + if num_workers is None: + try: + num_cpus = len(os.sched_getaffinity(0)) + except AttributeError: + # os.sched_getaffinity is not available under Windows, use + # os.cpu_count instead (which may not return the *available* number + # of CPUs). + num_cpus = os.cpu_count() + + num_workers = min(num_cpus, 8) if num_cpus is not None else 0 + else: + num_workers = num_workers + # assert statistics_path.endswith(".npz") + + block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims] + + model = InceptionV3([block_idx]).to(device) + + m1, s1 = compute_statistics_of_path( + statistics_path, + model, + batch_size, + dims, + device, + num_workers, + subset_size, + shuffle_seed, + ds=ds + ) + + m2, s2 = calculate_activation_statistics(imgs, model, batch_size, dims, device, num_workers) + + fid_value = calculate_frechet_distance(m1, s1, m2, s2) + + return fid_value + + +class PostProcessCoco: + + def __init__( + self, + device="cpu", + dtype="uint8", + statistics_path=os.path.join(os.path.dirname(__file__), "tools", "val2014.npz")): + self.results = [] + self.good = 0 + self.total = 0 + self.content_ids = [] + self.clip_scores = [] + self.fid_scores = [] + self.device = device if torch.cuda.is_available() else "cpu" + if dtype == "uint8": + self.dtype = torch.uint8 + self.numpy_dtype = np.uint8 + else: + raise ValueError(f"dtype must be one of: uint8") + self.statistics_path = statistics_path + + def add_results(self, results): + self.results.extend(results) + + def __call__(self, results, ids, expected=None, result_dict=None): + self.content_ids.extend(ids) + return [(t.cpu().permute(1, 2, 0).float().numpy() * 255).round().astype(self.numpy_dtype) + for t in results] + + def save_images(self, ids, ds): + info = [] + idx = {} + for i, id in enumerate(self.content_ids): + if id in ids: + idx[id] = i + if not os.path.exists("images/"): + os.makedirs("images/", exist_ok=True) + for id in ids: + caption = ds.get_caption(id) + generated = Image.fromarray(self.results[idx[id]]) + image_path_tmp = f"images/{self.content_ids[idx[id]]}.png" + generated.save(image_path_tmp) + info.append((self.content_ids[idx[id]], caption)) + with open("images/captions.txt", "w+") as f: + for id, caption in info: + f.write(f"{id} {caption}\n") + + def start(self): + self.results = [] + + def finalize(self, result_dict, ds=None, output_dir=None): + log.info("Accumulating results") + + fid_score = compute_fid(self.results, self.statistics_path, self.device, ds=ds) + result_dict["FID_SCORE"] = fid_score + + return result_dict + + +def compute_mlperf_fid(model_to_replace=None, samples_to_evaluate=500): + + post_proc = PostProcessCoco( + statistics_path='/scratch/users/gfranco/datasets/coco/tools/val2014.npz') + + dtype = next(iter(model_to_replace.parameters())).dtype + res_dict = {} + model = BackendPytorch( + '/scratch/hf_models/stable-diffusion-xl-base-1.0/stable-diffusion-xl-base-1.0/', + 'xl', + steps=20, + batch_size=1, + precision=dtype) + model.load() + + if model_to_replace is not None: + model.pipe.unet = model_to_replace + + ds = Coco( + data_path='/scratch/users/gfranco/datasets/coco', + name="coco-1024", + pre_process=torch.nn.Identity, + count=None, + threads=1, + pipe_tokenizer=model.pipe.tokenizer, + pipe_tokenizer_2=model.pipe.tokenizer_2, + latent_dtype=dtype, + latent_device='cuda', + latent_framework='torch', + **{"image_size": [3, 1024, 1024]}, + ) + model.pipe.set_progress_bar_config(disable=True) + with torch.no_grad(): + runner = RunnerBase(model, ds, 1, post_proc=post_proc, max_batchsize=1) + runner.start_run(res_dict, True) + idx = list(range(0, samples_to_evaluate)) + ds.load_query_samples(idx) + data, label = ds.get_samples(idx) + runner.run_one_item(Item(idx, idx, data, label)) + post_proc.finalize(res_dict, ds=ds) + log.info(res_dict) diff --git a/src/brevitas_examples/stable_diffusion/mlperf_evaluation/backend.py b/src/brevitas_examples/stable_diffusion/mlperf_evaluation/backend.py new file mode 100644 index 000000000..cf79421a0 --- /dev/null +++ b/src/brevitas_examples/stable_diffusion/mlperf_evaluation/backend.py @@ -0,0 +1,610 @@ +""" +Code is adapted from the MLPerf text-to-image pipeline: https://github.com/mlcommons/inference/tree/master/text_to_image +Available under the following LICENSE: + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS +""" +import array +import logging +import time +from typing import Optional + +from diffusers import EulerDiscreteScheduler +from diffusers import StableDiffusionXLPipeline +import numpy as np +import torch +from tqdm import tqdm + +logging.basicConfig(level=logging.INFO) +log = logging.getLogger() + + +class Item: + """An item that we queue for processing by the thread pool.""" + + def __init__(self, query_id, content_id, inputs, img=None): + self.query_id = query_id + self.content_id = content_id + self.img = img + self.inputs = inputs + self.start = time.time() + + +class RunnerBase: + + def __init__(self, model, ds, threads, post_proc=None, max_batchsize=128): + self.take_accuracy = False + self.ds = ds + self.model = model + self.post_process = post_proc + self.threads = threads + self.take_accuracy = False + self.max_batchsize = max_batchsize + self.result_timing = [] + + def handle_tasks(self, tasks_queue): + pass + + def start_run(self, result_dict, take_accuracy): + self.result_dict = result_dict + self.result_timing = [] + self.take_accuracy = take_accuracy + self.post_process.start() + + def run_one_item(self, qitem: Item): + # run the prediction + processed_results = [] + try: + results = self.model.predict(qitem.inputs) + processed_results = self.post_process( + results, qitem.content_id, qitem.inputs, self.result_dict) + if self.take_accuracy: + self.post_process.add_results(processed_results) + self.result_timing.append(time.time() - qitem.start) + except Exception as ex: # pylint: disable=broad-except + src = [self.ds.get_item_loc(i) for i in qitem.content_id] + log.error("thread: failed on contentid=%s, %s", src, ex) + # since post_process will not run, fake empty responses + processed_results = [[]] * len(qitem.query_id) + finally: + response_array_refs = [] + response = [] + for idx, query_id in enumerate(qitem.query_id): + response_array = array.array( + "B", np.array(processed_results[idx], np.uint8).tobytes()) + response_array_refs.append(response_array) + bi = response_array.buffer_info() + response.append((query_id, bi[0], bi[1])) + # lg.QuerySamplesComplete(response) + + def enqueue(self, query_samples): + idx = [q.index for q in query_samples] + query_id = [q.id for q in query_samples] + if len(query_samples) < self.max_batchsize: + data, label = self.ds.get_samples(idx) + self.run_one_item(Item(query_id, idx, data, label)) + else: + bs = self.max_batchsize + for i in range(0, len(idx), bs): + data, label = self.ds.get_samples(idx[i:i + bs]) + self.run_one_item(Item(query_id[i:i + bs], idx[i:i + bs], data, label)) + + def finish(self): + pass + + +class BackendPytorch: + + def __init__( + self, + model_path=None, + model_id="xl", + guidance=8, + steps=20, + batch_size=1, + device="cuda", + precision=torch.float32, + negative_prompt="normal quality, low quality, worst quality, low res, blurry, nsfw, nude", + ): + self.inputs = [] + self.outputs = [] + + self.model_path = model_path + if model_id == "xl": + self.model_id = "stabilityai/stable-diffusion-xl-base-1.0" + else: + raise ValueError(f"{model_id} is not a valid model id") + + self.device = device if torch.cuda.is_available() else "cpu" + self.dtype = precision + + if torch.cuda.is_available(): + self.local_rank = 0 + self.world_size = 1 + + self.guidance = guidance + self.steps = steps + self.negative_prompt = negative_prompt + self.max_length_neg_prompt = 77 + self.batch_size = batch_size + + def version(self): + return torch.__version__ + + # def name(self): + # return "pytorch-SUT" + + def image_format(self): + return "NCHW" + + def load(self): + if self.model_path is None: + log.warning( + "Model path not provided, running with default hugging face weights\n" + "This may not be valid for official submissions") + self.scheduler = EulerDiscreteScheduler.from_pretrained( + self.model_id, subfolder="scheduler") + self.pipe = StableDiffusionXLPipeline.from_pretrained( + self.model_id, + scheduler=self.scheduler, + safety_checker=None, + add_watermarker=False, + variant="fp16" if (self.dtype == torch.float16) else None, + torch_dtype=self.dtype, + ) + # self.pipe.unet = torch.compile(self.pipe.unet, mode="reduce-overhead", fullgraph=True) + else: + self.scheduler = EulerDiscreteScheduler.from_pretrained( + self.model_id, subfolder="scheduler") + self.pipe = StableDiffusionXLPipeline.from_pretrained( + self.model_path, + scheduler=self.scheduler, + safety_checker=None, + add_watermarker=False, + variant="fp16" if (self.dtype == torch.float16) else None, + torch_dtype=self.dtype, + ) + # self.pipe.unet = torch.compile(self.pipe.unet, mode="reduce-overhead", fullgraph=True) + + self.pipe.to(self.device) + #self.pipe.set_progress_bar_config(disable=True) + + self.negative_prompt_tokens = self.pipe.tokenizer( + self.convert_prompt(self.negative_prompt, self.pipe.tokenizer), + padding="max_length", + max_length=self.max_length_neg_prompt, + truncation=True, + return_tensors="pt", + ) + self.negative_prompt_tokens_2 = self.pipe.tokenizer_2( + self.convert_prompt(self.negative_prompt, self.pipe.tokenizer_2), + padding="max_length", + max_length=self.max_length_neg_prompt, + truncation=True, + return_tensors="pt", + ) + return self + + def convert_prompt(self, prompt, tokenizer): + tokens = tokenizer.tokenize(prompt) + unique_tokens = set(tokens) + for token in unique_tokens: + if token in tokenizer.added_tokens_encoder: + replacement = token + i = 1 + while f"{token}_{i}" in tokenizer.added_tokens_encoder: + replacement += f" {token}_{i}" + i += 1 + + prompt = prompt.replace(token, replacement) + + return prompt + + def encode_tokens( + self, + pipe: StableDiffusionXLPipeline, + text_input: torch.Tensor, + text_input_2: Optional[torch.Tensor] = None, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[torch.Tensor] = None, + negative_prompt_2: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the input tokens into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + device = device or pipe._execution_device + batch_size = text_input.input_ids.shape[0] + + # Define tokenizers and text encoders + tokenizers = ([pipe.tokenizer, pipe.tokenizer_2] if pipe.tokenizer is not None else [ + pipe.tokenizer_2]) + text_encoders = ([pipe.text_encoder, pipe.text_encoder_2] + if pipe.text_encoder is not None else [pipe.text_encoder_2]) + + if prompt_embeds is None: + text_input_2 = text_input_2 or text_input + + # textual inversion: procecss multi-vector tokens if necessary + prompt_embeds_list = [] + text_inputs_list = [text_input, text_input_2] + for text_inputs, tokenizer, text_encoder in zip( + text_inputs_list, tokenizers, text_encoders + ): + text_input_ids = text_inputs.input_ids + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + # "2" because SDXL always indexes from the penultimate layer. + prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = ( + negative_prompt is None and pipe.config.force_zeros_for_empty_prompt) + if (do_classifier_free_guidance and negative_prompt_embeds is None and + zero_out_negative_prompt): + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + + # normalize str to list + negative_prompt_inputs = ( + negative_prompt.input_ids.repeat(batch_size, 1) if + (len(negative_prompt.input_ids.shape) == 1) else negative_prompt.input_ids) + negative_prompt_2_inputs = ( + negative_prompt_2.input_ids.repeat(batch_size, 1) if + (len(negative_prompt_2.input_ids.shape) == 1) else negative_prompt_2.input_ids) + + uncond_inputs = [negative_prompt_inputs, negative_prompt_2_inputs] + + negative_prompt_embeds_list = [] + for uncond_input, tokenizer, text_encoder in zip( + uncond_inputs, tokenizers, text_encoders + ): + negative_prompt_embeds = text_encoder( + uncond_input.to(device), + output_hidden_states=True, + ) + # We are only ALWAYS interested in the pooled output of the final text encoder + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + if pipe.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=pipe.text_encoder_2.dtype, device=device) + else: + prompt_embeds = prompt_embeds.to(dtype=pipe.unet.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + if pipe.text_encoder_2 is not None: + negative_prompt_embeds = negative_prompt_embeds.to( + dtype=pipe.text_encoder_2.dtype, device=device) + else: + negative_prompt_embeds = negative_prompt_embeds.to( + dtype=pipe.unet.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view( + batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1) + if do_classifier_free_guidance: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat( + 1, num_images_per_prompt).view(bs_embed * num_images_per_prompt, -1) + return ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) + + def prepare_inputs(self, inputs, i): + if self.batch_size == 1: + return self.encode_tokens( + self.pipe, + inputs[i]["input_tokens"], + inputs[i]["input_tokens_2"], + negative_prompt=self.negative_prompt_tokens, + negative_prompt_2=self.negative_prompt_tokens_2, + ) + else: + prompt_embeds = [] + negative_prompt_embeds = [] + pooled_prompt_embeds = [] + negative_pooled_prompt_embeds = [] + for prompt in inputs[i:min(i + self.batch_size, len(inputs))]: + assert isinstance(prompt, dict) + text_input = prompt["input_tokens"] + text_input_2 = prompt["input_tokens_2"] + ( + p_e, + n_p_e, + p_p_e, + n_p_p_e, + ) = self.encode_tokens( + self.pipe, + text_input, + text_input_2, + negative_prompt=self.negative_prompt_tokens, + negative_prompt_2=self.negative_prompt_tokens_2, + ) + prompt_embeds.append(p_e) + negative_prompt_embeds.append(n_p_e) + pooled_prompt_embeds.append(p_p_e) + negative_pooled_prompt_embeds.append(n_p_p_e) + + prompt_embeds = torch.cat(prompt_embeds) + negative_prompt_embeds = torch.cat(negative_prompt_embeds) + pooled_prompt_embeds = torch.cat(pooled_prompt_embeds) + negative_pooled_prompt_embeds = torch.cat(negative_pooled_prompt_embeds) + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + def predict(self, inputs): + images = [0] * len(inputs) + with torch.no_grad(): + for i in tqdm(range(0, len(inputs), self.batch_size)): + max_index = min(i + self.batch_size, len(inputs)) + latents_input = [inputs[idx]["latents"] for idx in range(i, max_index)] + latents_input = torch.cat(latents_input).to(self.device) + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.prepare_inputs(inputs, i) + generated = self.pipe( + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + guidance_scale=self.guidance, + num_inference_steps=self.steps, + output_type="pt", + latents=latents_input, + ).images + images[i:i + max_index] = generated.cpu() + # images.extend(generated) + return images diff --git a/src/brevitas_examples/stable_diffusion/mlperf_evaluation/dataset.py b/src/brevitas_examples/stable_diffusion/mlperf_evaluation/dataset.py new file mode 100644 index 000000000..ab7f92ea1 --- /dev/null +++ b/src/brevitas_examples/stable_diffusion/mlperf_evaluation/dataset.py @@ -0,0 +1,359 @@ +""" +Code is adapted from the MLPerf text-to-image pipeline: https://github.com/mlcommons/inference/tree/master/text_to_image +Available under the following LICENSE: + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS +""" + +import logging +import os +import time + +import numpy as np +import pandas as pd +from PIL import Image +import torch + +logging.basicConfig(level=logging.INFO) +log = logging.getLogger() + + +class Dataset: + + def __init__(self): + self.arrival = None + self.image_list = [] + self.caption_list = [] + self.items_inmemory = {} + self.last_loaded = -1 + + def preprocess(self, use_cache=True): + raise NotImplementedError("Dataset:preprocess") + + def get_item_count(self): + return len(self.image_list) + + def get_list(self): + raise NotImplementedError("Dataset:get_list") + + def load_query_samples(self, sample_list): + self.items_inmemory = {} + for sample in sample_list: + self.items_inmemory[sample] = self.get_item(sample) + self.last_loaded = time.time() + + def unload_query_samples(self, sample_list): + if sample_list: + for sample in sample_list: + if sample in self.items_inmemory: + del self.items_inmemory[sample] + else: + self.items_inmemory = {} + + def get_samples(self, id_list): + data = [{ + "input_tokens": self.items_inmemory[id]["input_tokens"], + "input_tokens_2": self.items_inmemory[id]["input_tokens_2"], + "latents": self.items_inmemory[id]["latents"],} for id in id_list] + images = [self.items_inmemory[id]["file_name"] for id in id_list] + return data, images + + def get_item(self, id): + raise NotImplementedError("Dataset:get_item") + + +class ImagesDataset(torch.utils.data.Dataset): + + def __init__(self, imgs, transforms=None): + self.imgs = imgs + self.transforms = transforms + + def __len__(self): + return len(self.imgs) + + def __getitem__(self, i): + img = self.imgs[i] + if self.transforms is not None: + img = self.transforms(img) + return img + + +class Coco(Dataset): + + def __init__( + self, + data_path, + name=None, + image_size=None, + pre_process=None, + pipe_tokenizer=None, + pipe_tokenizer_2=None, + latent_dtype=torch.float32, + latent_device="cuda", + latent_framework="torch", + **kwargs, + ): + super().__init__() + self.captions_df = pd.read_csv(f"{data_path}/captions/captions.tsv", sep="\t") + self.image_size = image_size + self.preprocessed_dir = os.path.abspath(f"{data_path}/preprocessed/") + self.img_dir = os.path.abspath(f"{data_path}/validation/data/") + self.name = name + + # Preprocess prompts + self.captions_df["input_tokens"] = self.captions_df["caption"].apply( + lambda x: self.preprocess(x, pipe_tokenizer)) + self.captions_df["input_tokens_2"] = self.captions_df["caption"].apply( + lambda x: self.preprocess(x, pipe_tokenizer_2)) + self.latent_dtype = latent_dtype + self.latent_device = latent_device if torch.cuda.is_available() else "cpu" + if latent_framework == "torch": + self.latents = ( + torch.load(f"{data_path}/latents/latents.pt").to(latent_dtype).to(latent_device)) + elif latent_framework == "numpy": + self.latents = ( + torch.Tensor( + np.load(f"{data_path}/latents/latents.npy")).to(latent_dtype).to(latent_device)) + + def preprocess(self, prompt, tokenizer): + converted_prompt = self.convert_prompt(prompt, tokenizer) + return tokenizer( + converted_prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + def image_to_tensor(self, img): + img = np.asarray(img) + if len(img.shape) == 2: + img = np.expand_dims(img, axis=-1) + tensor = torch.Tensor(img.transpose([2, 0, 1])).to(torch.uint8) + if tensor.shape[0] == 1: + tensor = tensor.repeat(3, 1, 1) + return tensor + + def preprocess_images(self, file_name): + img = Image.open(self.img_dir + "/" + file_name) + tensor = self.image_to_tensor(img) + target_name = file_name.split(".")[0] + target_path = self.preprocessed_dir + "/" + target_name + ".pt" + if not os.path.exists(target_path): + torch.save(tensor, target_path) + return target_path + + def convert_prompt(self, prompt, tokenizer): + tokens = tokenizer.tokenize(prompt) + unique_tokens = set(tokens) + for token in unique_tokens: + if token in tokenizer.added_tokens_encoder: + replacement = token + i = 1 + while f"{token}_{i}" in tokenizer.added_tokens_encoder: + replacement += f" {token}_{i}" + i += 1 + + prompt = prompt.replace(token, replacement) + + return prompt + + def get_item(self, id): + return dict(self.captions_df.loc[id], latents=self.latents) + + def get_item_count(self): + return len(self.captions_df) + + def get_img(self, id): + img = Image.open(self.img_dir + "/" + self.captions_df.loc[id]["file_name"]) + return self.image_to_tensor(img) + + def get_imgs(self, id_list): + image_list = [] + for id in id_list: + image_list.append(self.get_img(id)) + return image_list + + def get_caption(self, i): + return self.get_item(i)["caption"] + + def get_captions(self, id_list): + return [self.get_caption(id) for id in id_list] + + def get_item_loc(self, id): + return self.img_dir + "/" + self.captions_df.loc[id]["file_name"] diff --git a/src/brevitas_examples/stable_diffusion/mlperf_evaluation/requirements.txt b/src/brevitas_examples/stable_diffusion/mlperf_evaluation/requirements.txt new file mode 100644 index 000000000..3b453267e --- /dev/null +++ b/src/brevitas_examples/stable_diffusion/mlperf_evaluation/requirements.txt @@ -0,0 +1,8 @@ +accelerate==0.23.0 +diffusers==0.21.2 +open-clip-torch==2.7.0 +opencv-python==4.8.1.78 +pycocotools==2.0.7 +scipy==1.9.1 +torchmetrics[image]==1.2.0 +transformers==4.33.2 diff --git a/src/brevitas_examples/stable_diffusion/sd_quant/export.py b/src/brevitas_examples/stable_diffusion/sd_quant/export.py index b466d6303..7ce70e783 100644 --- a/src/brevitas_examples/stable_diffusion/sd_quant/export.py +++ b/src/brevitas_examples/stable_diffusion/sd_quant/export.py @@ -7,12 +7,7 @@ import torch from torch import nn -from torch._decomp import get_decompositions -from brevitas.backport.fx.experimental.proxy_tensor import make_fx -from brevitas.export.manager import _force_requires_grad_false -from brevitas.export.manager import _JitTraceExportWrapper -from brevitas_examples.llm.llm_quant.export import BlockQuantProxyLevelManager from brevitas_examples.llm.llm_quant.export import brevitas_proxy_export_mode @@ -31,3 +26,12 @@ def export_onnx(pipe, trace_inputs, output_dir, export_manager): print(f"Saving unet to {output_path} ...") with torch.no_grad(), brevitas_proxy_export_mode(pipe.unet, export_manager): torch.onnx.export(pipe.unet, args=trace_inputs, f=output_path) + + +def export_torch_export(pipe, trace_inputs, output_dir, export_manager): + output_path = os.path.join(output_dir, 'unet.onnx') + print(trace_inputs[1]) + print(f"Saving unet to {output_path} ...") + with torch.no_grad(), brevitas_proxy_export_mode(pipe.unet, export_manager): + torch.export.export( + UnetExportWrapper(pipe.unet), args=(trace_inputs[0],), kwargs=trace_inputs[1]) diff --git a/src/brevitas_examples/stable_diffusion/sd_quant/utils.py b/src/brevitas_examples/stable_diffusion/sd_quant/utils.py index b2c30176f..a5af383ef 100644 --- a/src/brevitas_examples/stable_diffusion/sd_quant/utils.py +++ b/src/brevitas_examples/stable_diffusion/sd_quant/utils.py @@ -3,8 +3,81 @@ SPDX-License-Identifier: MIT """ +from contextlib import contextmanager + import torch +from brevitas.export.common.handler.base import BaseHandler +from brevitas.export.manager import _set_proxy_export_handler +from brevitas.export.manager import _set_proxy_export_mode +from brevitas.export.manager import BaseManager +from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer +from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjector + + +class InferenceWeightProxyHandler(BaseHandler): + handled_layer = WeightQuantProxyFromInjector + + def __init__(self): + super(InferenceWeightProxyHandler, self).__init__() + self.scale = None + self.zero_point = None + self.bit_width = None + self.float_weight = None + + def prepare_for_export(self, module): + assert len(module.tracked_module_list) == 1, "Shared quantizers not supported." + quant_layer = module.tracked_module_list[0] + self.float_weight = quant_layer.quant_weight() + quant_layer.weight.data = quant_layer.weight.data.cpu() + self.scale = module.scale() + self.zero_point = module.zero_point() + self.bit_width = module.bit_width() + + def forward(self, x): + return self.float_weight, self.scale, self.zero_point, self.bit_width + + +class InferenceWeightProxyManager(BaseManager): + handlers = [InferenceWeightProxyHandler] + + @classmethod + def set_export_handler(cls, module): + if hasattr(module, + 'requires_export_handler') and module.requires_export_handler and not isinstance( + module, (WeightQuantProxyFromInjector)): + return + _set_proxy_export_handler(cls, module) + + +def store_mapping_tensor_state_dict(model): + mapping = dict() + for module in model.modules(): + if isinstance(module, QuantWeightBiasInputOutputLayer): + mapping[module.weight.data_ptr()] = module.weight.device + return mapping + + +def restore_mapping(model, mapping): + for module in model.modules(): + if isinstance(module, QuantWeightBiasInputOutputLayer): + module.weight.data = module.weight.data.to(mapping[module.weight.data_ptr()]) + + +@contextmanager +def brevitas_proxy_inference_mode(model): + mapping = store_mapping_tensor_state_dict(model) + is_training = model.training + model.eval() + model.apply(InferenceWeightProxyManager.set_export_handler) + _set_proxy_export_mode(model, enabled=True) + try: + yield model + finally: + restore_mapping(model, mapping) + _set_proxy_export_mode(model, enabled=False) + model.train(is_training) + def unet_input_shape(resolution): return (4, resolution // 8, resolution // 8) From 31fdcb71ba76057f389fb89df4010e6e32087abc Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Fri, 31 May 2024 13:19:06 +0100 Subject: [PATCH 02/22] Update --- src/brevitas/quant/shifted_scaled_int.py | 2 +- .../stable_diffusion/main.py | 190 ++++++++++-------- .../mlperf_evaluation/accuracy.py | 5 +- .../stable_diffusion/sd_quant/utils.py | 37 +++- 4 files changed, 136 insertions(+), 98 deletions(-) diff --git a/src/brevitas/quant/shifted_scaled_int.py b/src/brevitas/quant/shifted_scaled_int.py index 936737571..72507e56a 100644 --- a/src/brevitas/quant/shifted_scaled_int.py +++ b/src/brevitas/quant/shifted_scaled_int.py @@ -50,7 +50,7 @@ class ShiftedUint8ActPerTensorFixedPointMSE(MSEAsymmetricScale, class ShiftedUint8ActPerTensorFloat(ShiftedParamFromPercentileUintQuant, - ParamFromRuntimePercentileIntervalScaling, + ParamFromRuntimeMinMaxScaling, PerTensorFloatScaling8bit, ActQuantSolver): """ diff --git a/src/brevitas_examples/stable_diffusion/main.py b/src/brevitas_examples/stable_diffusion/main.py index b5c2eee5f..4c08a569e 100644 --- a/src/brevitas_examples/stable_diffusion/main.py +++ b/src/brevitas_examples/stable_diffusion/main.py @@ -24,11 +24,12 @@ from brevitas.export.torch.qcdq.manager import TorchQCDQManager from brevitas.graph.calibrate import bias_correction_mode from brevitas.graph.calibrate import calibration_mode +from brevitas.graph.calibrate import load_quant_model_mode from brevitas.graph.equalize import activation_equalization_mode from brevitas.graph.gptq import gptq_mode from brevitas.inject.enum import QuantType from brevitas.nn.equalized_layer import EqualizedModule -from brevitas.nn.forward_handlers import brevitas_proxy_inference_mode +from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer from brevitas.utils.torch_utils import KwargsForwardHook from brevitas_examples.common.generative.quantize import quantize_model from brevitas_examples.common.parse_utils import add_bool_arg @@ -39,6 +40,7 @@ from brevitas_examples.stable_diffusion.sd_quant.constants import SD_XL_EMBEDDINGS_SHAPE from brevitas_examples.stable_diffusion.sd_quant.export import export_onnx from brevitas_examples.stable_diffusion.sd_quant.export import export_torch_export +from brevitas_examples.stable_diffusion.sd_quant.utils import brevitas_proxy_inference_mode from brevitas_examples.stable_diffusion.sd_quant.utils import generate_latents from brevitas_examples.stable_diffusion.sd_quant.utils import generate_unet_21_rand_inputs from brevitas_examples.stable_diffusion.sd_quant.utils import generate_unet_xl_rand_inputs @@ -119,7 +121,7 @@ def run_val_inference( test_latents = generate_latents(seeds[0], device, dtype, unet_input_shape(resolution)) neg_prompts = NEGATIVE_PROMPTS if use_negative_prompts else [] - for prompt in prompts: + for prompt in tqdm(prompts): # We don't want to generate any image, so we return only the latent encoding pre VAE pipe( prompt, @@ -203,7 +205,7 @@ def main(args): blacklist = [] for name, _ in pipe.unet.named_modules(): if 'time_emb' in name or 'conv_in' in name: - blacklist.append(name) + blacklist.append(name.split('.')[-1]) print(f"Blacklisted layers: {blacklist}") # Make sure there all LoRA layers are fused first, otherwise raise an error @@ -211,7 +213,8 @@ def main(args): if hasattr(m, 'lora_layer') and m.lora_layer is not None: raise RuntimeError("LoRA layers should be fused in before calling into quantization.") - if args.activation_equalization: + if args.activation_equalization and args.load_checkpoint is None: + pipe.set_progress_bar_config(disable=True) with activation_equalization_mode(pipe.unet, alpha=0.5, layerwise=True, add_mul_node=True): # Workaround to expose `in_features` attribute from the Hook Wrapper for m in pipe.unet.modules(): @@ -231,8 +234,8 @@ def main(args): # Workaround to expose `in_features` attribute from the EqualizedModule Wrapper for m in pipe.unet.modules(): - if isinstance(m, EqualizedModule) and hasattr(m.module, 'in_features'): - m.in_features = m.module.in_features + if isinstance(m, EqualizedModule) and hasattr(m.layer, 'in_features'): + m.in_features = m.layer.in_features # Quantize model if args.quantize: @@ -293,32 +296,56 @@ def input_quant_type(module): use_ocp=args.use_ocp, input_kwargs=input_kwargs) print("Model quantization applied.") - - if (args.linear_input_bit_width is not None or - args.conv_input_bit_width is not None) and args.input_scale_type == 'static': - print("Applying activation calibration") - with brevitas_proxy_inference_mode(pipe.unet), torch.no_grad(), calibration_mode(pipe.unet): - run_val_inference( - pipe, - args.resolution, - calibration_prompts, - test_seeds, - args.device, - dtype, - total_steps=args.calibration_steps, - use_negative_prompts=args.use_negative_prompts, - test_latents=latents, - guidance_scale=args.guidance_scale) + list_of_names = [ + n for n, + m in pipe.unet.named_modules() if isinstance(m, QuantWeightBiasInputOutputLayer)] pipe.set_progress_bar_config(disable=True) + if args.load_checkpoint is not None: + with load_quant_model_mode(pipe.unet): + pipe.unet.load_state_dict(torch.load(args.load_checkpoint)) + pipe = pipe.to(args.device) + else: + if (args.linear_input_bit_width is not None or + args.conv_input_bit_width is not None) and args.input_scale_type == 'static': + print("Applying activation calibration") + with torch.no_grad(), calibration_mode(pipe.unet): + run_val_inference( + pipe, + args.resolution, + calibration_prompts, + test_seeds, + args.device, + dtype, + total_steps=args.calibration_steps, + use_negative_prompts=args.use_negative_prompts, + test_latents=latents, + guidance_scale=args.guidance_scale) - if args.gptq: - print("Applying GPTQ. It can take several hours") - with torch.no_grad(), gptq_mode(pipe.unet, - create_weight_orig=False, - use_quant_activations=False, - return_forward_output=True, - act_order=True) as gptq: - for _ in tqdm(range(gptq.num_layers)): + if args.gptq: + print("Applying GPTQ. It can take several hours") + with torch.no_grad(), gptq_mode(pipe.unet, + create_weight_orig=False, + use_quant_activations=False, + return_forward_output=True, + act_order=True) as gptq: + for _ in tqdm(range(gptq.num_layers)): + run_val_inference( + pipe, + args.resolution, + calibration_prompts, + test_seeds, + args.device, + dtype, + total_steps=args.calibration_steps, + use_negative_prompts=args.use_negative_prompts, + test_latents=latents, + guidance_scale=args.guidance_scale) + gptq.update() + torch.cuda.empty_cache() + + if args.bias_correction: + print("Applying bias correction") + with bias_correction_mode(pipe.unet): run_val_inference( pipe, args.resolution, @@ -330,69 +357,49 @@ def input_quant_type(module): use_negative_prompts=args.use_negative_prompts, test_latents=latents, guidance_scale=args.guidance_scale) - gptq.update() - torch.cuda.empty_cache() - pipe.set_progress_bar_config(disable=False) - - if args.bias_correction: - print("Applying bias correction") - with brevitas_proxy_inference_mode(pipe.unet), bias_correction_mode(pipe.unet): - run_val_inference( - pipe, - args.resolution, - calibration_prompts, - test_seeds, - args.device, - dtype, - total_steps=args.calibration_steps, - use_negative_prompts=args.use_negative_prompts, - test_latents=latents, - guidance_scale=args.guidance_scale) - - if args.checkpoint_name is not None: + + if args.checkpoint_name is not None and args.load_checkpoint is None: torch.save(pipe.unet.state_dict(), args.checkpoint_name) # Perform inference if args.prompt > 0: - with brevitas_proxy_inference_mode(pipe.unet): - if args.use_mlperf_inference: - print(f"Computing accuracy with MLPerf pipeline") - compute_mlperf_fid(pipe.unet, args.prompt) - else: - print(f"Computing accuracy on default prompt") - prompts = list() - for i, v in enumerate(TESTING_PROMPTS): - if i == args.prompt: - break - prompts.append(v) - quant_images = run_test_inference( - pipe, - args.resolution, - prompts, - test_seeds, - output_dir, - args.device, - dtype, - use_negative_prompts=args.use_negative_prompts, - guidance_scale=args.guidance_scale, - name_prefix='quant_') - - float_images_values = float_images.values() - float_images_values = [x for x_nested in float_images_values for x in x_nested] - float_images_values = torch.tensor([ - np.array(image) for image in float_images_values]) - float_images_values = float_images_values.permute(0, 3, 1, 2) - - quant_images_values = quant_images.values() - quant_images_values = [x for x_nested in quant_images_values for x in x_nested] - quant_images_values = torch.tensor([ - np.array(image) for image in quant_images_values]) - quant_images_values = quant_images_values.permute(0, 3, 1, 2) - - fid = FrechetInceptionDistance(normalize=False) - fid.update(float_images_values, real=True) - fid.update(quant_images_values, real=False) - print(f"FID: {float(fid.compute())}") + # with brevitas_proxy_inference_mode(pipe.unet): + if args.use_mlperf_inference: + print(f"Computing accuracy with MLPerf pipeline") + compute_mlperf_fid(pipe, args.prompt) + else: + print(f"Computing accuracy on default prompt") + prompts = list() + for i, v in enumerate(TESTING_PROMPTS): + if i == args.prompt: + break + prompts.append(v) + quant_images = run_test_inference( + pipe, + args.resolution, + prompts, + test_seeds, + output_dir, + args.device, + dtype, + use_negative_prompts=args.use_negative_prompts, + guidance_scale=args.guidance_scale, + name_prefix='quant_') + + float_images_values = float_images.values() + float_images_values = [x for x_nested in float_images_values for x in x_nested] + float_images_values = torch.tensor([np.array(image) for image in float_images_values]) + float_images_values = float_images_values.permute(0, 3, 1, 2) + + quant_images_values = quant_images.values() + quant_images_values = [x for x_nested in quant_images_values for x in x_nested] + quant_images_values = torch.tensor([np.array(image) for image in quant_images_values]) + quant_images_values = quant_images_values.permute(0, 3, 1, 2) + + fid = FrechetInceptionDistance(normalize=False) + fid.update(float_images_values, real=True) + fid.update(quant_images_values, real=False) + print(f"FID: {float(fid.compute())}") if args.export_target: # Move to cpu and to float32 to enable CPU export @@ -464,6 +471,11 @@ def input_quant_type(module): type=str, default=None, help='Name to use to store the checkpoint. If not provided, no checkpoint is saved.') + parser.add_argument( + '--load-checkpoint', + type=str, + default=None, + help='Path to checkpoint to load. If provided, PTQ techniques are skipped.') parser.add_argument( '--path-to-latents', type=str, diff --git a/src/brevitas_examples/stable_diffusion/mlperf_evaluation/accuracy.py b/src/brevitas_examples/stable_diffusion/mlperf_evaluation/accuracy.py index 71353e9cb..b1826e9ad 100644 --- a/src/brevitas_examples/stable_diffusion/mlperf_evaluation/accuracy.py +++ b/src/brevitas_examples/stable_diffusion/mlperf_evaluation/accuracy.py @@ -488,18 +488,19 @@ def compute_mlperf_fid(model_to_replace=None, samples_to_evaluate=500): post_proc = PostProcessCoco( statistics_path='/scratch/users/gfranco/datasets/coco/tools/val2014.npz') - dtype = next(iter(model_to_replace.parameters())).dtype + dtype = next(iter(model_to_replace.unet.parameters())).dtype res_dict = {} model = BackendPytorch( '/scratch/hf_models/stable-diffusion-xl-base-1.0/stable-diffusion-xl-base-1.0/', 'xl', steps=20, batch_size=1, + device='cpu', precision=dtype) model.load() if model_to_replace is not None: - model.pipe.unet = model_to_replace + model.pipe = model_to_replace ds = Coco( data_path='/scratch/users/gfranco/datasets/coco', diff --git a/src/brevitas_examples/stable_diffusion/sd_quant/utils.py b/src/brevitas_examples/stable_diffusion/sd_quant/utils.py index a5af383ef..2700dd032 100644 --- a/src/brevitas_examples/stable_diffusion/sd_quant/utils.py +++ b/src/brevitas_examples/stable_diffusion/sd_quant/utils.py @@ -23,19 +23,44 @@ def __init__(self): self.scale = None self.zero_point = None self.bit_width = None + self.dtype = None self.float_weight = None + def scaling_impl(self, proxy_module): + return proxy_module.tensor_quant.scaling_impl + + def zero_point_impl(self, proxy_module): + return proxy_module.tensor_quant.zero_point_impl + + def bit_width_impl(self, proxy_module): + return proxy_module.tensor_quant.msb_clamp_bit_width_impl + + def export_scale(self, proxy_module, bit_width): + scaling_impl = self.scaling_impl(proxy_module) + int_scaling_impl = proxy_module.tensor_quant.int_scaling_impl + int_threshold = int_scaling_impl(bit_width) + threshold = scaling_impl.stats_scaling_impl(scaling_impl.parameter_list_stats()) + return threshold / int_threshold + + def export_zero_point(self, proxy_module, weight, scale, bit_width): + zero_point_impl = self.zero_point_impl(proxy_module) + return zero_point_impl(weight, scale, bit_width) + def prepare_for_export(self, module): assert len(module.tracked_module_list) == 1, "Shared quantizers not supported." + self.bit_width = self.bit_width_impl(module)() + assert self.bit_width <= 8., "Only 8b or lower is supported." quant_layer = module.tracked_module_list[0] self.float_weight = quant_layer.quant_weight() - quant_layer.weight.data = quant_layer.weight.data.cpu() - self.scale = module.scale() - self.zero_point = module.zero_point() - self.bit_width = module.bit_width() + self.dtype = self.float_weight.value.dtype + # if (self.float_weight.zero_point != 0.).any(): + # self.zero_point = self.export_zero_point(module, quant_layer.weight, self.scale, self.bit_width).detach().cpu() + # self.scale = self.export_scale(module, self.bit_width).detach().cpu() + # quant_layer.weight.data = quant_layer.weight.data.cpu() def forward(self, x): - return self.float_weight, self.scale, self.zero_point, self.bit_width + + return self.float_weight.value, self.float_weight.scale, self.float_weight.zero_point, self.bit_width class InferenceWeightProxyManager(BaseManager): @@ -70,7 +95,7 @@ def brevitas_proxy_inference_mode(model): is_training = model.training model.eval() model.apply(InferenceWeightProxyManager.set_export_handler) - _set_proxy_export_mode(model, enabled=True) + _set_proxy_export_mode(model, enabled=True, proxy_class=WeightQuantProxyFromInjector) try: yield model finally: From 4eb29c6f60af0cafa31108619e591478863a1796 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 21 May 2024 13:04:38 +0100 Subject: [PATCH 03/22] Missing file --- .../mlperf_evaluation/accuracy.py | 2 +- .../mlperf_evaluation/inception.py | 510 ++++++++++++++++++ 2 files changed, 511 insertions(+), 1 deletion(-) create mode 100644 src/brevitas_examples/stable_diffusion/mlperf_evaluation/inception.py diff --git a/src/brevitas_examples/stable_diffusion/mlperf_evaluation/accuracy.py b/src/brevitas_examples/stable_diffusion/mlperf_evaluation/accuracy.py index b1826e9ad..ec1f2e510 100644 --- a/src/brevitas_examples/stable_diffusion/mlperf_evaluation/accuracy.py +++ b/src/brevitas_examples/stable_diffusion/mlperf_evaluation/accuracy.py @@ -193,12 +193,12 @@ import torchvision.transforms as TF from tqdm import tqdm -from brevitas_examples.inception import InceptionV3 from brevitas_examples.stable_diffusion.mlperf_evaluation.backend import BackendPytorch from brevitas_examples.stable_diffusion.mlperf_evaluation.backend import Item from brevitas_examples.stable_diffusion.mlperf_evaluation.backend import RunnerBase from brevitas_examples.stable_diffusion.mlperf_evaluation.dataset import Coco from brevitas_examples.stable_diffusion.mlperf_evaluation.dataset import ImagesDataset +from brevitas_examples.stable_diffusion.mlperf_evaluation.inception import InceptionV3 IMAGE_EXTENSIONS = {'bmp', 'jpg', 'jpeg', 'pgm', 'png', 'ppm', 'tif', 'tiff', 'webp'} diff --git a/src/brevitas_examples/stable_diffusion/mlperf_evaluation/inception.py b/src/brevitas_examples/stable_diffusion/mlperf_evaluation/inception.py new file mode 100644 index 000000000..29268fb7d --- /dev/null +++ b/src/brevitas_examples/stable_diffusion/mlperf_evaluation/inception.py @@ -0,0 +1,510 @@ +""" +Code is adapted from the MLPerf text-to-image pipeline: https://github.com/mlcommons/inference/tree/master/text_to_image +Available under the following LICENSE: + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision + +try: + from torchvision.models.utils import load_state_dict_from_url +except ImportError: + from torch.utils.model_zoo import load_url as load_state_dict_from_url + +# Inception weights ported to Pytorch from +# http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz +FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth' # noqa: E501 + + +class InceptionV3(nn.Module): + """Pretrained InceptionV3 network returning feature maps""" + + # Index of default block of inception to return, + # corresponds to output of final average pooling + DEFAULT_BLOCK_INDEX = 3 + + # Maps feature dimensionality to their output blocks indices + BLOCK_INDEX_BY_DIM = { + 64: 0, # First max pooling features + 192: 1, # Second max pooling featurs + 768: 2, # Pre-aux classifier features + 2048: 3 # Final average pooling features + } + + def __init__( + self, + output_blocks=(DEFAULT_BLOCK_INDEX,), + resize_input=True, + normalize_input=True, + requires_grad=False, + use_fid_inception=True): + """Build pretrained InceptionV3 + + Parameters + ---------- + output_blocks : list of int + Indices of blocks to return features of. Possible values are: + - 0: corresponds to output of first max pooling + - 1: corresponds to output of second max pooling + - 2: corresponds to output which is fed to aux classifier + - 3: corresponds to output of final average pooling + resize_input : bool + If true, bilinearly resizes input to width and height 299 before + feeding input to model. As the network without fully connected + layers is fully convolutional, it should be able to handle inputs + of arbitrary size, so resizing might not be strictly needed + normalize_input : bool + If true, scales the input from range (0, 1) to the range the + pretrained Inception network expects, namely (-1, 1) + requires_grad : bool + If true, parameters of the model require gradients. Possibly useful + for finetuning the network + use_fid_inception : bool + If true, uses the pretrained Inception model used in Tensorflow's + FID implementation. If false, uses the pretrained Inception model + available in torchvision. The FID Inception model has different + weights and a slightly different structure from torchvision's + Inception model. If you want to compute FID scores, you are + strongly advised to set this parameter to true to get comparable + results. + """ + super(InceptionV3, self).__init__() + + self.resize_input = resize_input + self.normalize_input = normalize_input + self.output_blocks = sorted(output_blocks) + self.last_needed_block = max(output_blocks) + + assert self.last_needed_block <= 3, \ + 'Last possible output block index is 3' + + self.blocks = nn.ModuleList() + + if use_fid_inception: + inception = fid_inception_v3() + else: + inception = _inception_v3(weights='DEFAULT') + + # Block 0: input to maxpool1 + block0 = [ + inception.Conv2d_1a_3x3, + inception.Conv2d_2a_3x3, + inception.Conv2d_2b_3x3, + nn.MaxPool2d(kernel_size=3, stride=2)] + self.blocks.append(nn.Sequential(*block0)) + + # Block 1: maxpool1 to maxpool2 + if self.last_needed_block >= 1: + block1 = [ + inception.Conv2d_3b_1x1, + inception.Conv2d_4a_3x3, + nn.MaxPool2d(kernel_size=3, stride=2)] + self.blocks.append(nn.Sequential(*block1)) + + # Block 2: maxpool2 to aux classifier + if self.last_needed_block >= 2: + block2 = [ + inception.Mixed_5b, + inception.Mixed_5c, + inception.Mixed_5d, + inception.Mixed_6a, + inception.Mixed_6b, + inception.Mixed_6c, + inception.Mixed_6d, + inception.Mixed_6e,] + self.blocks.append(nn.Sequential(*block2)) + + # Block 3: aux classifier to final avgpool + if self.last_needed_block >= 3: + block3 = [ + inception.Mixed_7a, + inception.Mixed_7b, + inception.Mixed_7c, + nn.AdaptiveAvgPool2d(output_size=(1, 1))] + self.blocks.append(nn.Sequential(*block3)) + + for param in self.parameters(): + param.requires_grad = requires_grad + + def forward(self, inp): + """Get Inception feature maps + + Parameters + ---------- + inp : torch.autograd.Variable + Input tensor of shape Bx3xHxW. Values are expected to be in + range (0, 1) + + Returns + ------- + List of torch.autograd.Variable, corresponding to the selected output + block, sorted ascending by index + """ + outp = [] + x = inp + + if self.resize_input: + x = F.interpolate(x, size=(299, 299), mode='bilinear', align_corners=False) + + if self.normalize_input: + x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1) + + for idx, block in enumerate(self.blocks): + x = block(x) + if idx in self.output_blocks: + outp.append(x) + + if idx == self.last_needed_block: + break + + return outp + + +def _inception_v3(*args, **kwargs): + """Wraps `torchvision.models.inception_v3`""" + try: + version = tuple(map(int, torchvision.__version__.split('.')[:2])) + except ValueError: + # Just a caution against weird version strings + version = (0,) + + # Skips default weight inititialization if supported by torchvision + # version. See https://github.com/mseitzer/pytorch-fid/issues/28. + if version >= (0, 6): + kwargs['init_weights'] = False + + # Backwards compatibility: `weights` argument was handled by `pretrained` + # argument prior to version 0.13. + if version < (0, 13) and 'weights' in kwargs: + if kwargs['weights'] == 'DEFAULT': + kwargs['pretrained'] = True + elif kwargs['weights'] is None: + kwargs['pretrained'] = False + else: + raise ValueError( + 'weights=={} not supported in torchvision {}'.format( + kwargs['weights'], torchvision.__version__)) + del kwargs['weights'] + + return torchvision.models.inception_v3(*args, **kwargs) + + +def fid_inception_v3(): + """Build pretrained Inception model for FID computation + + The Inception model for FID computation uses a different set of weights + and has a slightly different structure than torchvision's Inception. + + This method first constructs torchvision's Inception and then patches the + necessary parts that are different in the FID Inception model. + """ + inception = _inception_v3(num_classes=1008, aux_logits=False, weights=None) + inception.Mixed_5b = FIDInceptionA(192, pool_features=32) + inception.Mixed_5c = FIDInceptionA(256, pool_features=64) + inception.Mixed_5d = FIDInceptionA(288, pool_features=64) + inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128) + inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160) + inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160) + inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192) + inception.Mixed_7b = FIDInceptionE_1(1280) + inception.Mixed_7c = FIDInceptionE_2(2048) + + state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True) + inception.load_state_dict(state_dict) + return inception + + +class FIDInceptionA(torchvision.models.inception.InceptionA): + """InceptionA block patched for FID computation""" + + def __init__(self, in_channels, pool_features): + super(FIDInceptionA, self).__init__(in_channels, pool_features) + + def forward(self, x): + branch1x1 = self.branch1x1(x) + + branch5x5 = self.branch5x5_1(x) + branch5x5 = self.branch5x5_2(branch5x5) + + branch3x3dbl = self.branch3x3dbl_1(x) + branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) + branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) + + # Patch: Tensorflow's average pool does not use the padded zero's in + # its average calculation + branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, count_include_pad=False) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool] + return torch.cat(outputs, 1) + + +class FIDInceptionC(torchvision.models.inception.InceptionC): + """InceptionC block patched for FID computation""" + + def __init__(self, in_channels, channels_7x7): + super(FIDInceptionC, self).__init__(in_channels, channels_7x7) + + def forward(self, x): + branch1x1 = self.branch1x1(x) + + branch7x7 = self.branch7x7_1(x) + branch7x7 = self.branch7x7_2(branch7x7) + branch7x7 = self.branch7x7_3(branch7x7) + + branch7x7dbl = self.branch7x7dbl_1(x) + branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl) + branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl) + branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl) + branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl) + + # Patch: Tensorflow's average pool does not use the padded zero's in + # its average calculation + branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, count_include_pad=False) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool] + return torch.cat(outputs, 1) + + +class FIDInceptionE_1(torchvision.models.inception.InceptionE): + """First InceptionE block patched for FID computation""" + + def __init__(self, in_channels): + super(FIDInceptionE_1, self).__init__(in_channels) + + def forward(self, x): + branch1x1 = self.branch1x1(x) + + branch3x3 = self.branch3x3_1(x) + branch3x3 = [ + self.branch3x3_2a(branch3x3), + self.branch3x3_2b(branch3x3),] + branch3x3 = torch.cat(branch3x3, 1) + + branch3x3dbl = self.branch3x3dbl_1(x) + branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) + branch3x3dbl = [ + self.branch3x3dbl_3a(branch3x3dbl), + self.branch3x3dbl_3b(branch3x3dbl),] + branch3x3dbl = torch.cat(branch3x3dbl, 1) + + # Patch: Tensorflow's average pool does not use the padded zero's in + # its average calculation + branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, count_include_pad=False) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] + return torch.cat(outputs, 1) + + +class FIDInceptionE_2(torchvision.models.inception.InceptionE): + """Second InceptionE block patched for FID computation""" + + def __init__(self, in_channels): + super(FIDInceptionE_2, self).__init__(in_channels) + + def forward(self, x): + branch1x1 = self.branch1x1(x) + + branch3x3 = self.branch3x3_1(x) + branch3x3 = [ + self.branch3x3_2a(branch3x3), + self.branch3x3_2b(branch3x3),] + branch3x3 = torch.cat(branch3x3, 1) + + branch3x3dbl = self.branch3x3dbl_1(x) + branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) + branch3x3dbl = [ + self.branch3x3dbl_3a(branch3x3dbl), + self.branch3x3dbl_3b(branch3x3dbl),] + branch3x3dbl = torch.cat(branch3x3dbl, 1) + + # Patch: The FID Inception model uses max pooling instead of average + # pooling. This is likely an error in this specific Inception + # implementation, as other Inception models use average pooling here + # (which matches the description in the paper). + branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] + return torch.cat(outputs, 1) From a93ad7c22ae2495f0c404beffb94ca93682572f8 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 21 May 2024 13:59:02 +0100 Subject: [PATCH 04/22] New update --- src/brevitas/quant/shifted_scaled_int.py | 2 +- .../common/generative/quantize.py | 22 +++++++++-- .../stable_diffusion/main.py | 37 +++++++++++++++---- 3 files changed, 49 insertions(+), 12 deletions(-) diff --git a/src/brevitas/quant/shifted_scaled_int.py b/src/brevitas/quant/shifted_scaled_int.py index 72507e56a..936737571 100644 --- a/src/brevitas/quant/shifted_scaled_int.py +++ b/src/brevitas/quant/shifted_scaled_int.py @@ -50,7 +50,7 @@ class ShiftedUint8ActPerTensorFixedPointMSE(MSEAsymmetricScale, class ShiftedUint8ActPerTensorFloat(ShiftedParamFromPercentileUintQuant, - ParamFromRuntimeMinMaxScaling, + ParamFromRuntimePercentileIntervalScaling, PerTensorFloatScaling8bit, ActQuantSolver): """ diff --git a/src/brevitas_examples/common/generative/quantize.py b/src/brevitas_examples/common/generative/quantize.py index 0b1614e74..346f41c81 100644 --- a/src/brevitas_examples/common/generative/quantize.py +++ b/src/brevitas_examples/common/generative/quantize.py @@ -8,8 +8,10 @@ from torch import nn from brevitas import nn as qnn +from brevitas.core.stats.stats_op import NegativeMinOrZero from brevitas.core.zero_point import ParameterFromStatsFromParameterZeroPoint from brevitas.graph.quantize import layerwise_quantize +from brevitas.inject.enum import StatsOp from brevitas.quant.experimental.float import Fp8e4m3Act from brevitas.quant.experimental.float import Fp8e4m3ActPerTensorFloat from brevitas.quant.experimental.float import Fp8e4m3WeightPerChannelFloat @@ -150,6 +152,7 @@ def quantize_model( input_quant_type=None, input_quant_granularity=None, input_group_size=None, + input_stats_op='percentile', quantize_input_zero_point=False, quantize_embedding=False, use_ocp=False, @@ -195,10 +198,21 @@ def quantize_model( linear_input_quant = INPUT_QUANT_MAP[input_quant_format][input_scale_type][ input_scale_precision][input_param_method][input_quant_granularity][input_quant_type] - if input_kwargs is not None: - input_quant = input_quant.let(**input_kwargs) - sym_input_quant = sym_input_quant.let(**input_kwargs) - linear_input_quant = linear_input_quant.let(**input_kwargs) + if input_kwargs is None: + input_kwargs = dict() + + if input_stats_op == 'minmax': + if input_quant_type == 'asym': + input_scaling_stats_op = StatsOp.MIN_MAX + # zero_point_stats_impl = NegativeMinOrZero + # input_kwargs['zero_point_stats_impl'] = zero_point_stats_impl + else: + input_scaling_stats_op = StatsOp.MAX + input_kwargs['scaling_stats_op'] = input_scaling_stats_op + + input_quant = input_quant.let(**input_kwargs) + sym_input_quant = sym_input_quant.let(**input_kwargs) + linear_input_quant = linear_input_quant.let(**input_kwargs) else: input_quant = None diff --git a/src/brevitas_examples/stable_diffusion/main.py b/src/brevitas_examples/stable_diffusion/main.py index 4c08a569e..253ca3ac7 100644 --- a/src/brevitas_examples/stable_diffusion/main.py +++ b/src/brevitas_examples/stable_diffusion/main.py @@ -220,6 +220,8 @@ def main(args): for m in pipe.unet.modules(): if isinstance(m, KwargsForwardHook) and hasattr(m.module, 'in_features'): m.in_features = m.module.in_features + if args.dry_run: + calibration_prompts = [calibration_prompts[0]] run_val_inference( pipe, args.resolution, @@ -293,18 +295,29 @@ def input_quant_type(module): input_param_method=args.input_param_method, input_quant_type=args.input_quant_type, input_quant_granularity=args.input_quant_granularity, + input_stats_op=args.input_stats_op, use_ocp=args.use_ocp, input_kwargs=input_kwargs) print("Model quantization applied.") - list_of_names = [ - n for n, - m in pipe.unet.named_modules() if isinstance(m, QuantWeightBiasInputOutputLayer)] + pipe.set_progress_bar_config(disable=True) + if args.dry_run: + with torch.no_grad(): + run_val_inference( + pipe, + args.resolution, [calibration_prompts[0]], + test_seeds, + args.device, + dtype, + total_steps=args.calibration_steps, + use_negative_prompts=args.use_negative_prompts, + test_latents=latents, + guidance_scale=args.guidance_scale) if args.load_checkpoint is not None: with load_quant_model_mode(pipe.unet): pipe.unet.load_state_dict(torch.load(args.load_checkpoint)) pipe = pipe.to(args.device) - else: + elif not args.dry_run: if (args.linear_input_bit_width is not None or args.conv_input_bit_width is not None) and args.input_scale_type == 'static': print("Applying activation calibration") @@ -342,7 +355,6 @@ def input_quant_type(module): guidance_scale=args.guidance_scale) gptq.update() torch.cuda.empty_cache() - if args.bias_correction: print("Applying bias correction") with bias_correction_mode(pipe.unet): @@ -358,11 +370,11 @@ def input_quant_type(module): test_latents=latents, guidance_scale=args.guidance_scale) - if args.checkpoint_name is not None and args.load_checkpoint is None: + if args.checkpoint_name is not None and args.load_checkpoint is None and not args.dry_run: torch.save(pipe.unet.state_dict(), args.checkpoint_name) # Perform inference - if args.prompt > 0: + if args.prompt > 0 and not args.dry_run: # with brevitas_proxy_inference_mode(pipe.unet): if args.use_mlperf_inference: print(f"Computing accuracy with MLPerf pipeline") @@ -557,6 +569,12 @@ def input_quant_type(module): default='stats', choices=['stats', 'mse'], help='How scales/zero-point are determined. Default: stats.') + parser.add_argument( + '--input-stats-op', + type=str, + default='minmax', + choices=['minmax', 'percentile'], + help='How scales/zero-point are determined. Default: stats.') parser.add_argument( '--weight-scale-precision', type=str, @@ -638,6 +656,11 @@ def input_quant_type(module): 'use-negative-prompts', default=True, help='Use negative prompts during generation/calibration. Default: Enabled') + add_bool_arg( + parser, + 'dry-run', + default=False, + help='Generate a quantized model without any calibration. Default: Disabled') args = parser.parse_args() print("Args: " + str(vars(args))) main(args) From 52d612000cbcf305fce3206d25caa207e559dda8 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 21 May 2024 17:13:39 +0100 Subject: [PATCH 05/22] Updated readme --- .../stable_diffusion/README.md | 29 +++++++++++++++++++ .../stable_diffusion/main.py | 4 +-- 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/src/brevitas_examples/stable_diffusion/README.md b/src/brevitas_examples/stable_diffusion/README.md index 1c71e4431..03393d0a4 100644 --- a/src/brevitas_examples/stable_diffusion/README.md +++ b/src/brevitas_examples/stable_diffusion/README.md @@ -18,6 +18,23 @@ If the flag is not enabled, the model will be moved to CPU and cast to float32 b To use MLPerf inference setup, check and install the correct requirements specified in the `requirements.txt` file under mlperf_evaluation. +For example, to perform weight-only quantization on SDXL, the following can be used: + +`python main.py --resolution 1024 --batch 1 --model /path/to/sdxl --prompt 500 --conv-weight-bit-width 8 --linear-weight-bit-width 8 --dtype float16 --weight-quant-type sym --calibration-steps 8 --guidance-scale 8. --use-negative-prompts --calibration-prompt 500 --activation-eq --use-mlperf` + +To add activation quantization: + +`--linear-input-bit 8 --conv-input-bit 8` + +To choose between `static` or `dynamic` activation quantization, set the flag: `--input-scale-type` to either option + +To include export: +`--export-target torch` or `--export-target onnx` + +To perform a dry-run quantization, where only the structure of the quantized model is preserved but no calibration of the quantized parameter is performed, add the `--dry-run` flag. + + + ## Run ```bash @@ -25,6 +42,7 @@ usage: main.py [-h] [-m MODEL] [-d DEVICE] [-b BATCH_SIZE] [--prompt PROMPT] [--calibration-prompt CALIBRATION_PROMPT] [--calibration-prompt-path CALIBRATION_PROMPT_PATH] [--checkpoint-name CHECKPOINT_NAME] + [--load-checkpoint LOAD_CHECKPOINT] [--path-to-latents PATH_TO_LATENTS] [--resolution RESOLUTION] [--guidance-scale GUIDANCE_SCALE] [--calibration-steps CALIBRATION_STEPS] @@ -42,6 +60,7 @@ usage: main.py [-h] [-m MODEL] [-d DEVICE] [-b BATCH_SIZE] [--prompt PROMPT] [--linear-input-bit-width LINEAR_INPUT_BIT_WIDTH] [--weight-param-method {stats,mse}] [--input-param-method {stats,mse}] + [--input-stats-op {minmax,percentile}] [--weight-scale-precision {float_scale,po2_scale}] [--input-scale-precision {float_scale,po2_scale}] [--weight-quant-type {sym,asym}] @@ -57,6 +76,7 @@ usage: main.py [-h] [-m MODEL] [-d DEVICE] [-b BATCH_SIZE] [--prompt PROMPT] [--use-mlperf-inference | --no-use-mlperf-inference] [--use-ocp | --no-use-ocp] [--use-negative-prompts | --no-use-negative-prompts] + [--dry-run | --no-dry-run] Stable Diffusion quantization @@ -78,6 +98,9 @@ options: --checkpoint-name CHECKPOINT_NAME Name to use to store the checkpoint. If not provided, no checkpoint is saved. + --load-checkpoint LOAD_CHECKPOINT + Path to checkpoint to load. If provided, PTQ + techniques are skipped. --path-to-latents PATH_TO_LATENTS Load pre-defined latents. If not provided, they are generated based on an internal seed. @@ -129,6 +152,8 @@ options: How scales/zero-point are determined. Default: stats. --input-param-method {stats,mse} How scales/zero-point are determined. Default: stats. + --input-stats-op {minmax,percentile} + Define what statics op to use . Default: minmax. --weight-scale-precision {float_scale,po2_scale} Whether scale is a float value or a po2. Default: float_scale. @@ -181,5 +206,9 @@ options: --no-use-negative-prompts Disable Use negative prompts during generation/calibration. Default: Enabled + --dry-run Enable Generate a quantized model without any + calibration. Default: Disabled + --no-dry-run Disable Generate a quantized model without any + calibration. Default: Disabled ``` diff --git a/src/brevitas_examples/stable_diffusion/main.py b/src/brevitas_examples/stable_diffusion/main.py index 253ca3ac7..3cc87cac8 100644 --- a/src/brevitas_examples/stable_diffusion/main.py +++ b/src/brevitas_examples/stable_diffusion/main.py @@ -215,7 +215,7 @@ def main(args): if args.activation_equalization and args.load_checkpoint is None: pipe.set_progress_bar_config(disable=True) - with activation_equalization_mode(pipe.unet, alpha=0.5, layerwise=True, add_mul_node=True): + with activation_equalization_mode(pipe.unet, alpha=0.9, layerwise=True, add_mul_node=True): # Workaround to expose `in_features` attribute from the Hook Wrapper for m in pipe.unet.modules(): if isinstance(m, KwargsForwardHook) and hasattr(m.module, 'in_features'): @@ -574,7 +574,7 @@ def input_quant_type(module): type=str, default='minmax', choices=['minmax', 'percentile'], - help='How scales/zero-point are determined. Default: stats.') + help='Define what statics op to use . Default: minmax.') parser.add_argument( '--weight-scale-precision', type=str, From dccdc12cd877a07d427f7f82934b76bbe3fac136 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Wed, 22 May 2024 10:08:13 +0100 Subject: [PATCH 06/22] FP OCP weight act support --- .../common/generative/quantize.py | 88 ++++++++++++++----- 1 file changed, 67 insertions(+), 21 deletions(-) diff --git a/src/brevitas_examples/common/generative/quantize.py b/src/brevitas_examples/common/generative/quantize.py index 346f41c81..df2a6b16c 100644 --- a/src/brevitas_examples/common/generative/quantize.py +++ b/src/brevitas_examples/common/generative/quantize.py @@ -16,8 +16,12 @@ from brevitas.quant.experimental.float import Fp8e4m3ActPerTensorFloat from brevitas.quant.experimental.float import Fp8e4m3WeightPerChannelFloat from brevitas.quant.experimental.float import Fp8e4m3WeightPerTensorFloat +from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPActPerTensorFloat from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeightPerChannelFloat from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeightPerTensorFloat +from brevitas.quant.experimental.float_quant_ocp import Fp8e5m2OCPActPerTensorFloat +from brevitas.quant.experimental.float_quant_ocp import Fp8e5m2OCPWeightPerChannelFloat +from brevitas.quant.experimental.float_quant_ocp import Fp8e5m2OCPWeightPerTensorFloat from brevitas.quant.fixed_point import Int8ActPerTensorFixedPoint from brevitas.quant.fixed_point import Int8ActPerTensorFixedPointMSE from brevitas.quant.fixed_point import Int8WeightPerChannelFixedPoint @@ -86,12 +90,20 @@ 'per_group': { 'sym': Fp8e4m3WeightSymmetricGroupQuant}}}}, 'float_ocp': { - 'float_scale': { - 'stats': { - 'per_tensor': { - 'sym': Fp8e4m3OCPWeightPerTensorFloat}, - 'per_channel': { - 'sym': Fp8e4m3OCPWeightPerChannelFloat}}}}} + 'e4m3': { + 'float_scale': { + 'stats': { + 'per_tensor': { + 'sym': Fp8e4m3OCPWeightPerTensorFloat}, + 'per_channel': { + 'sym': Fp8e4m3OCPWeightPerChannelFloat}}}}, + 'e5m2': { + 'float_scale': { + 'stats': { + 'per_tensor': { + 'sym': Fp8e5m2OCPWeightPerTensorFloat}, + 'per_channel': { + 'sym': Fp8e5m2OCPWeightPerChannelFloat}}}}}} INPUT_QUANT_MAP = { 'int': { @@ -129,7 +141,18 @@ 'per_tensor': { 'sym': Fp8e4m3ActPerTensorFloat},}}}, 'no_scale': { - 'sym': Fp8e4m3Act,}}} + 'sym': Fp8e4m3Act,}}, + 'float_ocp': { + 'e4m3': { + 'float_scale': { + 'stats': { + 'per_tensor': { + 'sym': Fp8e4m3OCPActPerTensorFloat}}}}, + 'e5m2': { + 'float_scale': { + 'stats': { + 'per_tensor': { + 'sym': Fp8e5m2OCPActPerTensorFloat}}}}}} def quantize_model( @@ -163,40 +186,63 @@ def quantize_model( Replace float layers with quant layers in the target model """ # Retrive base input and weight quantizers - + ocp_weight_format = None + ocp_input_format = None # match against custom float format if re.compile(r'e[1-8]m[1-8]').match(weight_quant_format): weight_float_format = { 'exponent_bit_width': int(weight_quant_format[1]), 'mantissa_bit_width': int(weight_quant_format[3])} - weight_quant_format = 'float' - if use_ocp: + if ocp_weight_format: weight_quant_format += '_ocp' + ocp_weight_format = weight_quant_format + weight_quant_format = 'float' else: weight_float_format = {} if re.compile(r'e[1-8]m[1-8]').match(input_quant_format): input_float_format = { 'exponent_bit_width': int(input_quant_format[1]), 'mantissa_bit_width': int(input_quant_format[3])} - input_quant_format = 'float' - if use_ocp: + if ocp_weight_format: input_quant_format += '_ocp' + ocp_input_format = input_quant_format + input_quant_format = 'float' else: input_float_format = {} - weight_quant = WEIGHT_QUANT_MAP[weight_quant_format][weight_scale_precision][ - weight_param_method][weight_quant_granularity][weight_quant_type] + if ocp_weight_format is not None: + weight_quant = WEIGHT_QUANT_MAP[weight_quant_format][ocp_weight_format][ + weight_scale_precision][weight_param_method][weight_quant_granularity][ + weight_quant_type] + else: + weight_quant = WEIGHT_QUANT_MAP[weight_quant_format][weight_scale_precision][ + weight_param_method][weight_quant_granularity][weight_quant_type] + if input_bit_width is not None and input_scale_type == 'no_scale': input_quant = sym_input_quant = linear_input_quant = INPUT_QUANT_MAP[input_quant_format][ input_scale_type][input_quant_type] elif input_bit_width is not None: - input_quant = INPUT_QUANT_MAP[input_quant_format][input_scale_type][input_scale_precision][ - input_param_method][input_quant_granularity][input_quant_type] - # Some activations in MHA should always be symmetric - sym_input_quant = INPUT_QUANT_MAP[input_quant_format][input_scale_type][ - input_scale_precision][input_param_method][input_quant_granularity]['sym'] - linear_input_quant = INPUT_QUANT_MAP[input_quant_format][input_scale_type][ - input_scale_precision][input_param_method][input_quant_granularity][input_quant_type] + if ocp_input_format: + input_quant = INPUT_QUANT_MAP[input_quant_format][ocp_input_format][input_scale_type][ + input_scale_precision][input_param_method][input_quant_granularity][ + input_quant_type] + # Some activations in MHA should always be symmetric + sym_input_quant = INPUT_QUANT_MAP[input_quant_format][ocp_input_format][ + input_scale_type][input_scale_precision][input_param_method][ + input_quant_granularity]['sym'] + linear_input_quant = INPUT_QUANT_MAP[input_quant_format][ocp_input_format][ + input_scale_type][input_scale_precision][input_param_method][ + input_quant_granularity][input_quant_type] + else: + input_quant = INPUT_QUANT_MAP[input_quant_format][input_scale_type][ + input_scale_precision][input_param_method][input_quant_granularity][ + input_quant_type] + # Some activations in MHA should always be symmetric + sym_input_quant = INPUT_QUANT_MAP[input_quant_format][input_scale_type][ + input_scale_precision][input_param_method][input_quant_granularity]['sym'] + linear_input_quant = INPUT_QUANT_MAP[input_quant_format][input_scale_type][ + input_scale_precision][input_param_method][input_quant_granularity][ + input_quant_type] if input_kwargs is None: input_kwargs = dict() From 66428069558a4031978f282d8a5f95d3a876d1dc Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Wed, 22 May 2024 10:13:58 +0100 Subject: [PATCH 07/22] Save checkpoint with dry-run --- src/brevitas_examples/stable_diffusion/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/brevitas_examples/stable_diffusion/main.py b/src/brevitas_examples/stable_diffusion/main.py index 3cc87cac8..85e240333 100644 --- a/src/brevitas_examples/stable_diffusion/main.py +++ b/src/brevitas_examples/stable_diffusion/main.py @@ -370,7 +370,7 @@ def input_quant_type(module): test_latents=latents, guidance_scale=args.guidance_scale) - if args.checkpoint_name is not None and args.load_checkpoint is None and not args.dry_run: + if args.checkpoint_name is not None and args.load_checkpoint is None: torch.save(pipe.unet.state_dict(), args.checkpoint_name) # Perform inference From 2f2ef858cb2653b19ac2fe8c9db7f5e50a4cd6c2 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Wed, 22 May 2024 22:50:50 +0100 Subject: [PATCH 08/22] New config --- src/brevitas/nn/quant_layer.py | 7 +++-- .../common/generative/quantize.py | 4 +-- .../stable_diffusion/main.py | 29 +++++++++++++------ 3 files changed, 27 insertions(+), 13 deletions(-) diff --git a/src/brevitas/nn/quant_layer.py b/src/brevitas/nn/quant_layer.py index 43d97a071..215299837 100644 --- a/src/brevitas/nn/quant_layer.py +++ b/src/brevitas/nn/quant_layer.py @@ -166,8 +166,11 @@ def _load_from_state_dict( bias_key = prefix + 'bias' # If the state dict has a bias and the module does not, bias correction was used # We add a bias module to prevent failing during the load of the state dict - if bias_key in state_dict and self.bias is None and self._quant_load_model_mode: + if (bias_key in state_dict) and (self.bias is None) and self._quant_load_model_mode: self.register_parameter( - 'bias', torch.nn.Parameter(torch.zeros(self.out_channels)).to(self.weight.device)) + 'bias', + torch.nn.Parameter( + torch.zeros( + self.out_channels, device=self.weight.device, dtype=self.weight.dtype))) super(QuantWeightBiasInputOutputLayer, self)._load_from_state_dict( state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) diff --git a/src/brevitas_examples/common/generative/quantize.py b/src/brevitas_examples/common/generative/quantize.py index df2a6b16c..bfc2cf898 100644 --- a/src/brevitas_examples/common/generative/quantize.py +++ b/src/brevitas_examples/common/generative/quantize.py @@ -250,8 +250,8 @@ def quantize_model( if input_stats_op == 'minmax': if input_quant_type == 'asym': input_scaling_stats_op = StatsOp.MIN_MAX - # zero_point_stats_impl = NegativeMinOrZero - # input_kwargs['zero_point_stats_impl'] = zero_point_stats_impl + zero_point_stats_impl = NegativeMinOrZero + input_kwargs['zero_point_stats_impl'] = zero_point_stats_impl else: input_scaling_stats_op = StatsOp.MAX input_kwargs['scaling_stats_op'] = input_scaling_stats_op diff --git a/src/brevitas_examples/stable_diffusion/main.py b/src/brevitas_examples/stable_diffusion/main.py index 85e240333..c83ec290e 100644 --- a/src/brevitas_examples/stable_diffusion/main.py +++ b/src/brevitas_examples/stable_diffusion/main.py @@ -202,25 +202,25 @@ def main(args): pipe.enable_attention_slicing() # Extract list of layers to avoid - blacklist = [] - for name, _ in pipe.unet.named_modules(): - if 'time_emb' in name or 'conv_in' in name: - blacklist.append(name.split('.')[-1]) - print(f"Blacklisted layers: {blacklist}") + # blacklist = [] + # for name, _ in pipe.unet.named_modules(): + # if 'time_emb' in name or 'conv_in' in name: + # blacklist.append(name) + # print(f"Blacklisted layers: {blacklist}") # Make sure there all LoRA layers are fused first, otherwise raise an error for m in pipe.unet.modules(): if hasattr(m, 'lora_layer') and m.lora_layer is not None: raise RuntimeError("LoRA layers should be fused in before calling into quantization.") - if args.activation_equalization and args.load_checkpoint is None: + if args.activation_equalization: pipe.set_progress_bar_config(disable=True) with activation_equalization_mode(pipe.unet, alpha=0.9, layerwise=True, add_mul_node=True): # Workaround to expose `in_features` attribute from the Hook Wrapper for m in pipe.unet.modules(): if isinstance(m, KwargsForwardHook) and hasattr(m.module, 'in_features'): m.in_features = m.module.in_features - if args.dry_run: + if args.dry_run or args.load_checkpoint is not None: calibration_prompts = [calibration_prompts[0]] run_val_inference( pipe, @@ -279,7 +279,7 @@ def input_quant_type(module): pipe.unet, dtype=dtype, device=args.device, - name_blacklist=blacklist, + # name_blacklist=blacklist, weight_bit_width=weight_bit_width, weight_quant_format=args.weight_quant_format, weight_quant_type=args.weight_quant_type, @@ -300,7 +300,16 @@ def input_quant_type(module): input_kwargs=input_kwargs) print("Model quantization applied.") + for name, module in pipe.unet.named_modules(): + if 'time_emb' in name: # or 'conv_in' in name: + if hasattr(module, 'input_quant'): + module.input_quant.quant_injector = module.input_quant.quant_injector.let( + **{'quant_type': QuantType.FP}) + module.input_quant.init_tensor_quant() + # blacklist.append(name.split('.')[-1]) + pipe.set_progress_bar_config(disable=True) + if args.dry_run: with torch.no_grad(): run_val_inference( @@ -313,9 +322,11 @@ def input_quant_type(module): use_negative_prompts=args.use_negative_prompts, test_latents=latents, guidance_scale=args.guidance_scale) + if args.load_checkpoint is not None: with load_quant_model_mode(pipe.unet): - pipe.unet.load_state_dict(torch.load(args.load_checkpoint)) + pipe = pipe.to('cpu') + pipe.unet.load_state_dict(torch.load(args.load_checkpoint, map_location='cpu')) pipe = pipe.to(args.device) elif not args.dry_run: if (args.linear_input_bit_width is not None or From c8c9771ab8a114eca2ea14f67e52179d67096fc4 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Wed, 22 May 2024 22:58:18 +0100 Subject: [PATCH 09/22] Reduced dry-run inference time --- src/brevitas_examples/stable_diffusion/main.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/brevitas_examples/stable_diffusion/main.py b/src/brevitas_examples/stable_diffusion/main.py index c83ec290e..021c462f1 100644 --- a/src/brevitas_examples/stable_diffusion/main.py +++ b/src/brevitas_examples/stable_diffusion/main.py @@ -220,8 +220,10 @@ def main(args): for m in pipe.unet.modules(): if isinstance(m, KwargsForwardHook) and hasattr(m.module, 'in_features'): m.in_features = m.module.in_features + total_steps = args.calibration_steps if args.dry_run or args.load_checkpoint is not None: calibration_prompts = [calibration_prompts[0]] + total_steps = 1 run_val_inference( pipe, args.resolution, @@ -229,7 +231,7 @@ def main(args): test_seeds, args.device, dtype, - total_steps=args.calibration_steps, + total_steps=total_steps, use_negative_prompts=args.use_negative_prompts, test_latents=latents, guidance_scale=args.guidance_scale) @@ -318,7 +320,7 @@ def input_quant_type(module): test_seeds, args.device, dtype, - total_steps=args.calibration_steps, + total_steps=1, use_negative_prompts=args.use_negative_prompts, test_latents=latents, guidance_scale=args.guidance_scale) From c48cf4c8d3dff84040607a35ac93bb96a3f5d09f Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Thu, 23 May 2024 11:43:11 +0100 Subject: [PATCH 10/22] parametrized args --- .../common/generative/quantize.py | 28 ++++++++++++------- .../common/generative/quantizers.py | 11 ++++++++ .../stable_diffusion/main.py | 9 +++++- .../mlperf_evaluation/accuracy.py | 16 ++++------- 4 files changed, 43 insertions(+), 21 deletions(-) diff --git a/src/brevitas_examples/common/generative/quantize.py b/src/brevitas_examples/common/generative/quantize.py index bfc2cf898..1b0fc98ff 100644 --- a/src/brevitas_examples/common/generative/quantize.py +++ b/src/brevitas_examples/common/generative/quantize.py @@ -42,6 +42,7 @@ from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerTensorFloatMSE from brevitas_examples.common.generative.nn import LoRACompatibleQuantConv2d from brevitas_examples.common.generative.nn import LoRACompatibleQuantLinear +from brevitas_examples.common.generative.quantizers import Fp8e4m3DynamicOCPActPerTensorFloat from brevitas_examples.common.generative.quantizers import Fp8e4m3WeightSymmetricGroupQuant from brevitas_examples.common.generative.quantizers import Int8DynamicActPerGroupFloat from brevitas_examples.common.generative.quantizers import Int8DynamicActPerRowFloat @@ -143,16 +144,23 @@ 'no_scale': { 'sym': Fp8e4m3Act,}}, 'float_ocp': { - 'e4m3': { - 'float_scale': { - 'stats': { - 'per_tensor': { - 'sym': Fp8e4m3OCPActPerTensorFloat}}}}, - 'e5m2': { - 'float_scale': { - 'stats': { - 'per_tensor': { - 'sym': Fp8e5m2OCPActPerTensorFloat}}}}}} + 'static': { + 'e4m3': { + 'float_scale': { + 'stats': { + 'per_tensor': { + 'sym': Fp8e4m3OCPActPerTensorFloat}}}}, + 'e5m2': { + 'float_scale': { + 'stats': { + 'per_tensor': { + 'sym': Fp8e5m2OCPActPerTensorFloat}}}}}, + 'dynamic': { + 'e4m3': { + 'float_scale': { + 'stats': { + 'per_tensor': { + 'sym': Fp8e4m3DynamicOCPActPerTensorFloat}}}}}}} def quantize_model( diff --git a/src/brevitas_examples/common/generative/quantizers.py b/src/brevitas_examples/common/generative/quantizers.py index 76e2e4099..445ea45b6 100644 --- a/src/brevitas_examples/common/generative/quantizers.py +++ b/src/brevitas_examples/common/generative/quantizers.py @@ -17,6 +17,7 @@ from brevitas.inject import value from brevitas.proxy.runtime_quant import DynamicActQuantProxyFromInjector from brevitas.quant.experimental.float import Fp8e4m3WeightPerChannelFloat +from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPActPerTensorFloat from brevitas.quant.scaled_int import Int8ActPerTensorFloat from brevitas.quant.scaled_int import Int8ActPerTensorFloatMSE from brevitas.quant.scaled_int import Int8WeightPerChannelFloat @@ -167,3 +168,13 @@ class ShiftedUint8DynamicActPerRowFloat(DynamicActProxyMixin, ShiftedUint8ActPer scaling_per_output_channel = True zero_point_impl = RuntimeDynamicStatsZeroPoint zero_point_stats_impl = NegativeMinOrZero + + +class Fp8e4m3DynamicOCPActPerTensorFloat(DynamicActProxyMixin, Fp8e4m3OCPActPerTensorFloat): + """ + Symmetric quantizer with per tensor dynamic scale. + """ + scaling_impl = RuntimeDynamicStatsScaling + scaling_stats_input_view_shape_impl = OverTensorView + scaling_stats_op = 'min_max' + dynamic_scaling_broadcastable_fn = lambda x, shape: x.view(SCALAR_SHAPE) diff --git a/src/brevitas_examples/stable_diffusion/main.py b/src/brevitas_examples/stable_diffusion/main.py index 021c462f1..ce6f59c3a 100644 --- a/src/brevitas_examples/stable_diffusion/main.py +++ b/src/brevitas_examples/stable_diffusion/main.py @@ -391,7 +391,7 @@ def input_quant_type(module): # with brevitas_proxy_inference_mode(pipe.unet): if args.use_mlperf_inference: print(f"Computing accuracy with MLPerf pipeline") - compute_mlperf_fid(pipe, args.prompt) + compute_mlperf_fid(args.model, args.path_to_coco, pipe, args.prompt) else: print(f"Computing accuracy on default prompt") prompts = list() @@ -507,6 +507,13 @@ def input_quant_type(module): default=None, help= 'Load pre-defined latents. If not provided, they are generated based on an internal seed.') + parser.add_argument( + '--path-to-coco', + type=str, + default=None, + help= + 'Path to MLPerf compliant Coco dataset. Used when the --use-mlperf flag is set. Default: None' + ) parser.add_argument( '--resolution', type=int, diff --git a/src/brevitas_examples/stable_diffusion/mlperf_evaluation/accuracy.py b/src/brevitas_examples/stable_diffusion/mlperf_evaluation/accuracy.py index ec1f2e510..784ea16bd 100644 --- a/src/brevitas_examples/stable_diffusion/mlperf_evaluation/accuracy.py +++ b/src/brevitas_examples/stable_diffusion/mlperf_evaluation/accuracy.py @@ -483,27 +483,23 @@ def finalize(self, result_dict, ds=None, output_dir=None): return result_dict -def compute_mlperf_fid(model_to_replace=None, samples_to_evaluate=500): +def compute_mlperf_fid(path_to_sdxl, path_to_coco, model_to_replace=None, samples_to_evaluate=500): - post_proc = PostProcessCoco( - statistics_path='/scratch/users/gfranco/datasets/coco/tools/val2014.npz') + assert os.path.isfile(path_to_coco + '/tools/val2014.npz'), "Val2014.npz file required. Check the MLPerf directory for instructions" + + post_proc = PostProcessCoco(statistics_path=path_to_coco + '/tools/val2014.npz') dtype = next(iter(model_to_replace.unet.parameters())).dtype res_dict = {} model = BackendPytorch( - '/scratch/hf_models/stable-diffusion-xl-base-1.0/stable-diffusion-xl-base-1.0/', - 'xl', - steps=20, - batch_size=1, - device='cpu', - precision=dtype) + path_to_sdxl, 'xl', steps=20, batch_size=1, device='cpu', precision=dtype) model.load() if model_to_replace is not None: model.pipe = model_to_replace ds = Coco( - data_path='/scratch/users/gfranco/datasets/coco', + data_path=path_to_coco, name="coco-1024", pre_process=torch.nn.Identity, count=None, From fc1569b6c39b5097051e1e1b96e271960556a36e Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Thu, 23 May 2024 17:51:07 +0100 Subject: [PATCH 11/22] Updated readme --- .../stable_diffusion/README.md | 17 +++++++++++++++++ .../mlperf_evaluation/requirements.txt | 1 + 2 files changed, 18 insertions(+) diff --git a/src/brevitas_examples/stable_diffusion/README.md b/src/brevitas_examples/stable_diffusion/README.md index 03393d0a4..db64b4e41 100644 --- a/src/brevitas_examples/stable_diffusion/README.md +++ b/src/brevitas_examples/stable_diffusion/README.md @@ -1,5 +1,21 @@ # Stable Diffusion Quantization +## Requirements + +For MLPerf inference execution, it is recommended to follow the MLPerf instruction to download the dataset and all relevant files, +such as pre-generated latents and captions for calibration. + +Similarly, a new python enviornment should be used with python<=3.10, installing first the requirements specified in +`requirements.txt` in stable_diffusion/mlperf_evaluation. + + +Afterwards, install brevitas with: +```bash +pip install -e .[export] +``` + +## Quantization Options + It supports Stable Diffusion 2.1 and Stable Diffusion XL. The following PTQ techniques are currently supported: @@ -9,6 +25,7 @@ The following PTQ techniques are currently supported: - Bias Correction These techniques can be applied for both integer and floating point quantization. + Activation quantization is optional, and disabled by default. To enable, set both `conv-input-bit-width` and `linear-input-bit-width`. We support ONNX integer export, and we are planning to release soon export for floating point quantization (e.g., FP8). diff --git a/src/brevitas_examples/stable_diffusion/mlperf_evaluation/requirements.txt b/src/brevitas_examples/stable_diffusion/mlperf_evaluation/requirements.txt index 3b453267e..690f7b0b0 100644 --- a/src/brevitas_examples/stable_diffusion/mlperf_evaluation/requirements.txt +++ b/src/brevitas_examples/stable_diffusion/mlperf_evaluation/requirements.txt @@ -5,4 +5,5 @@ opencv-python==4.8.1.78 pycocotools==2.0.7 scipy==1.9.1 torchmetrics[image]==1.2.0 +tqdm transformers==4.33.2 From 1f915b4bdbb8cd579656f65fbdb8689bbf86e64d Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Mon, 27 May 2024 10:29:59 +0100 Subject: [PATCH 12/22] Quantize input zp --- src/brevitas_examples/stable_diffusion/main.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/brevitas_examples/stable_diffusion/main.py b/src/brevitas_examples/stable_diffusion/main.py index ce6f59c3a..7ee7495c0 100644 --- a/src/brevitas_examples/stable_diffusion/main.py +++ b/src/brevitas_examples/stable_diffusion/main.py @@ -290,6 +290,7 @@ def input_quant_type(module): weight_quant_granularity=args.weight_quant_granularity, weight_group_size=args.weight_group_size, quantize_weight_zero_point=args.quantize_weight_zero_point, + quantize_input_zero_point=args.quantize_input_zero_point, input_bit_width=input_bit_width, input_quant_format=args.input_quant_format, input_scale_type=args.input_scale_type, @@ -309,7 +310,6 @@ def input_quant_type(module): **{'quant_type': QuantType.FP}) module.input_quant.init_tensor_quant() # blacklist.append(name.split('.')[-1]) - pipe.set_progress_bar_config(disable=True) if args.dry_run: @@ -659,6 +659,11 @@ def input_quant_type(module): 'quantize-weight-zero-point', default=True, help='Quantize weight zero-point. Default: Enabled') + add_bool_arg( + parser, + 'quantize-input-zero-point', + default=False, + help='Quantize input zero-point. Default: Enabled') add_bool_arg( parser, 'export-cuda-float16', default=False, help='Export FP16 on CUDA. Default: Disabled') add_bool_arg( From c7932da0d72fea8df6a97df02fa3434460941626 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Fri, 31 May 2024 13:16:51 +0100 Subject: [PATCH 13/22] update --- .../common/generative/quantize.py | 14 +-- .../stable_diffusion/README.md | 50 ++++++++-- .../stable_diffusion/main.py | 93 +++++++++++++++---- 3 files changed, 119 insertions(+), 38 deletions(-) diff --git a/src/brevitas_examples/common/generative/quantize.py b/src/brevitas_examples/common/generative/quantize.py index 1b0fc98ff..2e0fe12d0 100644 --- a/src/brevitas_examples/common/generative/quantize.py +++ b/src/brevitas_examples/common/generative/quantize.py @@ -183,7 +183,6 @@ def quantize_model( input_quant_type=None, input_quant_granularity=None, input_group_size=None, - input_stats_op='percentile', quantize_input_zero_point=False, quantize_embedding=False, use_ocp=False, @@ -201,7 +200,7 @@ def quantize_model( weight_float_format = { 'exponent_bit_width': int(weight_quant_format[1]), 'mantissa_bit_width': int(weight_quant_format[3])} - if ocp_weight_format: + if use_ocp: weight_quant_format += '_ocp' ocp_weight_format = weight_quant_format weight_quant_format = 'float' @@ -211,7 +210,7 @@ def quantize_model( input_float_format = { 'exponent_bit_width': int(input_quant_format[1]), 'mantissa_bit_width': int(input_quant_format[3])} - if ocp_weight_format: + if use_ocp: input_quant_format += '_ocp' ocp_input_format = input_quant_format input_quant_format = 'float' @@ -255,15 +254,6 @@ def quantize_model( if input_kwargs is None: input_kwargs = dict() - if input_stats_op == 'minmax': - if input_quant_type == 'asym': - input_scaling_stats_op = StatsOp.MIN_MAX - zero_point_stats_impl = NegativeMinOrZero - input_kwargs['zero_point_stats_impl'] = zero_point_stats_impl - else: - input_scaling_stats_op = StatsOp.MAX - input_kwargs['scaling_stats_op'] = input_scaling_stats_op - input_quant = input_quant.let(**input_kwargs) sym_input_quant = sym_input_quant.let(**input_kwargs) linear_input_quant = linear_input_quant.let(**input_kwargs) diff --git a/src/brevitas_examples/stable_diffusion/README.md b/src/brevitas_examples/stable_diffusion/README.md index db64b4e41..01c05b2e8 100644 --- a/src/brevitas_examples/stable_diffusion/README.md +++ b/src/brevitas_examples/stable_diffusion/README.md @@ -60,7 +60,8 @@ usage: main.py [-h] [-m MODEL] [-d DEVICE] [-b BATCH_SIZE] [--prompt PROMPT] [--calibration-prompt-path CALIBRATION_PROMPT_PATH] [--checkpoint-name CHECKPOINT_NAME] [--load-checkpoint LOAD_CHECKPOINT] - [--path-to-latents PATH_TO_LATENTS] [--resolution RESOLUTION] + [--path-to-latents PATH_TO_LATENTS] + [--path-to-coco PATH_TO_COCO] [--resolution RESOLUTION] [--guidance-scale GUIDANCE_SCALE] [--calibration-steps CALIBRATION_STEPS] [--output-path OUTPUT_PATH | --no-output-path] @@ -77,7 +78,8 @@ usage: main.py [-h] [-m MODEL] [-d DEVICE] [-b BATCH_SIZE] [--prompt PROMPT] [--linear-input-bit-width LINEAR_INPUT_BIT_WIDTH] [--weight-param-method {stats,mse}] [--input-param-method {stats,mse}] - [--input-stats-op {minmax,percentile}] + [--input-scale-stats-op {minmax,percentile}] + [--input-zp-stats-op {minmax,percentile}] [--weight-scale-precision {float_scale,po2_scale}] [--input-scale-precision {float_scale,po2_scale}] [--weight-quant-type {sym,asym}] @@ -89,11 +91,16 @@ usage: main.py [-h] [-m MODEL] [-d DEVICE] [-b BATCH_SIZE] [--prompt PROMPT] [--input-scale-type {static,dynamic}] [--weight-group-size WEIGHT_GROUP_SIZE] [--quantize-weight-zero-point | --no-quantize-weight-zero-point] + [--quantize-input-zero-point | --no-quantize-input-zero-point] [--export-cuda-float16 | --no-export-cuda-float16] [--use-mlperf-inference | --no-use-mlperf-inference] [--use-ocp | --no-use-ocp] [--use-negative-prompts | --no-use-negative-prompts] [--dry-run | --no-dry-run] + [--quantize-time-emb | --no-quantize-time-emb] + [--quantize-conv-in | --no-quantize-conv-in] + [--quantize-input-time-emb | --no-quantize-input-time-emb] + [--quantize-input-conv-in | --no-quantize-input-conv-in] Stable Diffusion quantization @@ -113,21 +120,24 @@ options: --calibration-prompt-path CALIBRATION_PROMPT_PATH Path to calibration prompt --checkpoint-name CHECKPOINT_NAME - Name to use to store the checkpoint. If not provided, - no checkpoint is saved. + Name to use to store the checkpoint in the output dir. + If not provided, no checkpoint is saved. --load-checkpoint LOAD_CHECKPOINT Path to checkpoint to load. If provided, PTQ techniques are skipped. --path-to-latents PATH_TO_LATENTS Load pre-defined latents. If not provided, they are generated based on an internal seed. + --path-to-coco PATH_TO_COCO + Path to MLPerf compliant Coco dataset. Used when the + --use-mlperf flag is set. Default: None --resolution RESOLUTION Resolution along height and width dimension. Default: 512. --guidance-scale GUIDANCE_SCALE Guidance scale. --calibration-steps CALIBRATION_STEPS - Percentage of steps used during calibration + Steps used during calibration --output-path OUTPUT_PATH Path where to generate output folder. --no-output-path Disable Path where to generate output folder. @@ -169,8 +179,12 @@ options: How scales/zero-point are determined. Default: stats. --input-param-method {stats,mse} How scales/zero-point are determined. Default: stats. - --input-stats-op {minmax,percentile} - Define what statics op to use . Default: minmax. + --input-scale-stats-op {minmax,percentile} + Define what statics op to use for input scale. + Default: minmax. + --input-zp-stats-op {minmax,percentile} + Define what statics op to use for input zero point. + Default: minmax. --weight-scale-precision {float_scale,po2_scale} Whether scale is a float value or a po2. Default: float_scale. @@ -203,6 +217,10 @@ options: Enable Quantize weight zero-point. Default: Enabled --no-quantize-weight-zero-point Disable Quantize weight zero-point. Default: Enabled + --quantize-input-zero-point + Enable Quantize input zero-point. Default: Enabled + --no-quantize-input-zero-point + Disable Quantize input zero-point. Default: Enabled --export-cuda-float16 Enable Export FP16 on CUDA. Default: Disabled --no-export-cuda-float16 @@ -227,5 +245,23 @@ options: calibration. Default: Disabled --no-dry-run Disable Generate a quantized model without any calibration. Default: Disabled + --quantize-time-emb Enable Quantize time embedding layers. Default: True + --no-quantize-time-emb + Disable Quantize time embedding layers. Default: True + --quantize-conv-in Enable Quantize first conv layer. Default: True + --no-quantize-conv-in + Disable Quantize first conv layer. Default: True + --quantize-input-time-emb + Enable Quantize input to time embedding layers. + Default: Disabled + --no-quantize-input-time-emb + Disable Quantize input to time embedding layers. + Default: Disabled + --quantize-input-conv-in + Enable Quantize input to first conv layer. Default: + Enabled + --no-quantize-input-conv-in + Disable Quantize input to first conv layer. Default: + Enabled ``` diff --git a/src/brevitas_examples/stable_diffusion/main.py b/src/brevitas_examples/stable_diffusion/main.py index 7ee7495c0..cc7b4d333 100644 --- a/src/brevitas_examples/stable_diffusion/main.py +++ b/src/brevitas_examples/stable_diffusion/main.py @@ -20,6 +20,7 @@ from torchmetrics.image.fid import FrechetInceptionDistance from tqdm import tqdm +from brevitas.core.stats.stats_op import NegativeMinOrZero from brevitas.export.onnx.standard.qcdq.manager import StdQCDQONNXManager from brevitas.export.torch.qcdq.manager import TorchQCDQManager from brevitas.graph.calibrate import bias_correction_mode @@ -28,6 +29,7 @@ from brevitas.graph.equalize import activation_equalization_mode from brevitas.graph.gptq import gptq_mode from brevitas.inject.enum import QuantType +from brevitas.inject.enum import StatsOp from brevitas.nn.equalized_layer import EqualizedModule from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer from brevitas.utils.torch_utils import KwargsForwardHook @@ -202,11 +204,13 @@ def main(args): pipe.enable_attention_slicing() # Extract list of layers to avoid - # blacklist = [] - # for name, _ in pipe.unet.named_modules(): - # if 'time_emb' in name or 'conv_in' in name: - # blacklist.append(name) - # print(f"Blacklisted layers: {blacklist}") + blacklist = [] + for name, _ in pipe.unet.named_modules(): + if 'time_emb' in name and not args.quantize_time_emb: + blacklist.append(name) + if 'conv_in' in name and not args.quantize_conv_in: + blacklist.append(name) + print(f"Blacklisted layers: {blacklist}") # Make sure there all LoRA layers are fused first, otherwise raise an error for m in pipe.unet.modules(): @@ -266,7 +270,7 @@ def input_bit_width(module): if args.linear_input_bit_width is None or args.conv_input_bit_width is None: @value - def input_quant_type(module): + def input_quant_enabled(module): if args.linear_input_bit_width is None and isinstance(module, nn.Linear): return QuantType.FP elif args.conv_input_bit_width is None and isinstance(module, nn.Conv2d): @@ -274,14 +278,36 @@ def input_quant_type(module): else: return QuantType.INT - input_kwargs['quant_type'] = input_quant_type + input_kwargs['quant_type'] = input_quant_enabled + + if args.input_scale_stats_op == 'minmax': + + @value + def input_scale_stats_type(): + if args.input_quant_type == 'asym': + input_scaling_stats_op = StatsOp.MIN_MAX + else: + input_scaling_stats_op = StatsOp.MAX + return input_scaling_stats_op + + input_kwargs['scaling_stats_op'] = input_scale_stats_type + + if args.input_zp_stats_op == 'minmax': + + @value + def input_zp_stats_type(): + if args.input_quant_type == 'asym': + zero_point_stats_impl = NegativeMinOrZero + return zero_point_stats_impl + + input_kwargs['zero_point_stats_impl'] = input_zp_stats_type print("Applying model quantization...") quantize_model( pipe.unet, dtype=dtype, device=args.device, - # name_blacklist=blacklist, + name_blacklist=blacklist, weight_bit_width=weight_bit_width, weight_quant_format=args.weight_quant_format, weight_quant_type=args.weight_quant_type, @@ -298,18 +324,25 @@ def input_quant_type(module): input_param_method=args.input_param_method, input_quant_type=args.input_quant_type, input_quant_granularity=args.input_quant_granularity, - input_stats_op=args.input_stats_op, use_ocp=args.use_ocp, input_kwargs=input_kwargs) print("Model quantization applied.") + skipped_layers = [] for name, module in pipe.unet.named_modules(): - if 'time_emb' in name: # or 'conv_in' in name: + if 'time_emb' in name and not args.quantize_input_time_emb: if hasattr(module, 'input_quant'): module.input_quant.quant_injector = module.input_quant.quant_injector.let( **{'quant_type': QuantType.FP}) module.input_quant.init_tensor_quant() - # blacklist.append(name.split('.')[-1]) + skipped_layers.append(name) + if 'conv_in' in name and not args.quantize_input_conv_in: + if hasattr(module, 'input_quant'): + module.input_quant.quant_injector = module.input_quant.quant_injector.let( + **{'quant_type': QuantType.FP}) + module.input_quant.init_tensor_quant() + skipped_layers.append(name) + print(f"Skipped input quantization for layers: {skipped_layers}") pipe.set_progress_bar_config(disable=True) if args.dry_run: @@ -384,7 +417,7 @@ def input_quant_type(module): guidance_scale=args.guidance_scale) if args.checkpoint_name is not None and args.load_checkpoint is None: - torch.save(pipe.unet.state_dict(), args.checkpoint_name) + torch.save(pipe.unet.state_dict(), os.path.join(output_dir, args.checkpoint_name)) # Perform inference if args.prompt > 0 and not args.dry_run: @@ -495,7 +528,9 @@ def input_quant_type(module): '--checkpoint-name', type=str, default=None, - help='Name to use to store the checkpoint. If not provided, no checkpoint is saved.') + help= + 'Name to use to store the checkpoint in the output dir. If not provided, no checkpoint is saved.' + ) parser.add_argument( '--load-checkpoint', type=str, @@ -521,10 +556,7 @@ def input_quant_type(module): help='Resolution along height and width dimension. Default: 512.') parser.add_argument('--guidance-scale', type=float, default=7.5, help='Guidance scale.') parser.add_argument( - '--calibration-steps', - type=float, - default=8, - help='Percentage of steps used during calibration') + '--calibration-steps', type=float, default=8, help='Steps used during calibration') add_bool_arg( parser, 'output-path', @@ -590,11 +622,17 @@ def input_quant_type(module): choices=['stats', 'mse'], help='How scales/zero-point are determined. Default: stats.') parser.add_argument( - '--input-stats-op', + '--input-scale-stats-op', type=str, default='minmax', choices=['minmax', 'percentile'], - help='Define what statics op to use . Default: minmax.') + help='Define what statics op to use for input scale. Default: minmax.') + parser.add_argument( + '--input-zp-stats-op', + type=str, + default='minmax', + choices=['minmax', 'percentile'], + help='Define what statics op to use for input zero point. Default: minmax.') parser.add_argument( '--weight-scale-precision', type=str, @@ -686,6 +724,23 @@ def input_quant_type(module): 'dry-run', default=False, help='Generate a quantized model without any calibration. Default: Disabled') + add_bool_arg( + parser, + 'quantize-time-emb', + default=True, + help='Quantize time embedding layers. Default: True') + add_bool_arg( + parser, 'quantize-conv-in', default=True, help='Quantize first conv layer. Default: True') + add_bool_arg( + parser, + 'quantize-input-time-emb', + default=False, + help='Quantize input to time embedding layers. Default: Disabled') + add_bool_arg( + parser, + 'quantize-input-conv-in', + default=True, + help='Quantize input to first conv layer. Default: Enabled') args = parser.parse_args() print("Args: " + str(vars(args))) main(args) From 09777b75e5d679898963b0230fbc9489102d4aa3 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Thu, 30 May 2024 14:39:31 +0100 Subject: [PATCH 14/22] Small fixes --- src/brevitas/core/function_wrapper/clamp.py | 2 +- src/brevitas_examples/stable_diffusion/main.py | 6 +++--- .../stable_diffusion/mlperf_evaluation/accuracy.py | 12 +++++++++++- 3 files changed, 15 insertions(+), 5 deletions(-) diff --git a/src/brevitas/core/function_wrapper/clamp.py b/src/brevitas/core/function_wrapper/clamp.py index cca7da087..70d1fc23f 100644 --- a/src/brevitas/core/function_wrapper/clamp.py +++ b/src/brevitas/core/function_wrapper/clamp.py @@ -87,7 +87,7 @@ class FloatClamp(brevitas.jit.ScriptModule): I.e. setting inf to 1101.111 (E4M3) is not a valid code. """ - __constants__ = ['saturating', 'inf_values', 'nan_values', 'signed', 'max_available_float'] + __constants__ = ['saturating', 'inf_values', 'nan_values', 'signed'] def __init__( self, diff --git a/src/brevitas_examples/stable_diffusion/main.py b/src/brevitas_examples/stable_diffusion/main.py index cc7b4d333..cf6889dfd 100644 --- a/src/brevitas_examples/stable_diffusion/main.py +++ b/src/brevitas_examples/stable_diffusion/main.py @@ -207,9 +207,9 @@ def main(args): blacklist = [] for name, _ in pipe.unet.named_modules(): if 'time_emb' in name and not args.quantize_time_emb: - blacklist.append(name) + blacklist.append(name.split('.')[-1]) if 'conv_in' in name and not args.quantize_conv_in: - blacklist.append(name) + blacklist.append(name.split('.')[-1]) print(f"Blacklisted layers: {blacklist}") # Make sure there all LoRA layers are fused first, otherwise raise an error @@ -424,7 +424,7 @@ def input_zp_stats_type(): # with brevitas_proxy_inference_mode(pipe.unet): if args.use_mlperf_inference: print(f"Computing accuracy with MLPerf pipeline") - compute_mlperf_fid(args.model, args.path_to_coco, pipe, args.prompt) + compute_mlperf_fid(args.model, args.path_to_coco, pipe, args.prompt, output_dir) else: print(f"Computing accuracy on default prompt") prompts = list() diff --git a/src/brevitas_examples/stable_diffusion/mlperf_evaluation/accuracy.py b/src/brevitas_examples/stable_diffusion/mlperf_evaluation/accuracy.py index 784ea16bd..fcb6a04df 100644 --- a/src/brevitas_examples/stable_diffusion/mlperf_evaluation/accuracy.py +++ b/src/brevitas_examples/stable_diffusion/mlperf_evaluation/accuracy.py @@ -180,6 +180,7 @@ END OF TERMS AND CONDITIONS """ +import json import logging import os import pathlib @@ -483,7 +484,12 @@ def finalize(self, result_dict, ds=None, output_dir=None): return result_dict -def compute_mlperf_fid(path_to_sdxl, path_to_coco, model_to_replace=None, samples_to_evaluate=500): +def compute_mlperf_fid( + path_to_sdxl, + path_to_coco, + model_to_replace=None, + samples_to_evaluate=500, + output_dir=None): assert os.path.isfile(path_to_coco + '/tools/val2014.npz'), "Val2014.npz file required. Check the MLPerf directory for instructions" @@ -521,3 +527,7 @@ def compute_mlperf_fid(path_to_sdxl, path_to_coco, model_to_replace=None, sample runner.run_one_item(Item(idx, idx, data, label)) post_proc.finalize(res_dict, ds=ds) log.info(res_dict) + if output_dir is not None: + # Dump args to json + with open(os.path.join(output_dir, 'results_mlperf.json'), 'w') as fp: + json.dump(res_dict, fp) From 2c5c0d5475b757f9395e5e1ddaf26c589b937696 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Fri, 31 May 2024 13:25:53 +0100 Subject: [PATCH 15/22] remove gpxq changes --- src/brevitas/graph/gpfq.py | 5 ----- src/brevitas/graph/gptq.py | 5 ----- src/brevitas/graph/gpxq.py | 13 ------------- 3 files changed, 23 deletions(-) diff --git a/src/brevitas/graph/gpfq.py b/src/brevitas/graph/gpfq.py index 9a8adc8e9..fd7df9223 100644 --- a/src/brevitas/graph/gpfq.py +++ b/src/brevitas/graph/gpfq.py @@ -298,9 +298,6 @@ def single_layer_update(self): # No permutation, permutation tensor is a ordered index perm = torch.tensor(range(weight.shape[-1]), device=dev) permutation_list.append(perm) - - self.reactivate_quantization() - for t in range(weight.shape[-1]): for group_index in range(self.groups): U[group_index] += torch.matmul( @@ -400,8 +397,6 @@ def single_layer_update(self): perm = torch.tensor(range(weight.shape[-1]), device=dev) permutation_list.append(perm) - self.reactivate_quantization() - for t in range(weight.shape[-1]): for group_index in range(self.groups): U[group_index] += torch.matmul( diff --git a/src/brevitas/graph/gptq.py b/src/brevitas/graph/gptq.py index 0861fd15c..31d31433b 100644 --- a/src/brevitas/graph/gptq.py +++ b/src/brevitas/graph/gptq.py @@ -86,11 +86,9 @@ def catch_stopfwd(self, *args, **kwargs): # If we want to return the output of the network, we need to disable all hooks for name, gpxq_class in self.gpxq_layers.items(): gpxq_class.disable_pre_forward_hook = True - out = self.orig_forward(*args, **kwargs) for name, gpxq_class in self.gpxq_layers.items(): gpxq_class.disable_pre_forward_hook = False - return out def initialize_module_optimizer( @@ -136,7 +134,6 @@ def __init__( device='cpu', dtype=torch.float32) self.nsamples = 0 - self.done = False assert torch_version >= version.parse('1.10'), "GPTQ requires torch 1.10 or higher" @@ -260,8 +257,6 @@ def single_layer_update(self, percdamp=.01): finally: del self.H - self.reactivate_quantization() - for i1 in range(0, self.columns, self.blocksize): i2 = min(i1 + self.blocksize, self.columns) count = i2 - i1 diff --git a/src/brevitas/graph/gpxq.py b/src/brevitas/graph/gpxq.py index 2b46ac4f4..fdbaee52f 100644 --- a/src/brevitas/graph/gpxq.py +++ b/src/brevitas/graph/gpxq.py @@ -164,10 +164,6 @@ def __enter__(self): return self def __exit__(self, type, value, traceback): - for name, layer in self.gpxq_layers.items(): - if not layer.done: - layer.reactivate_quantization() - if isinstance(self.model, (GraphModule, TorchGraphModule)): self.model.__class__.forward = self.orig_forward else: @@ -223,10 +219,6 @@ def __init__( self.disable_pre_forward_hook = False # Some layers require knowledge from quant inputs to compute quant weights self.quant_metadata = None - self.disable_quant_inference = DisableEnableQuantization() - self.return_quant_tensor_state = disable_return_quant_tensor(self.layer) - self.disable_quant_inference.disable_param_quantization(self.layer, False) - self.done = False def process_input(self, inp): # Input is a tuple, so we take first element @@ -263,11 +255,6 @@ def update_batch(self): def single_layer_update(self): pass - def reactivate_quantization(self): - self.done = True - self.disable_quant_inference.enable_param_quantization(self.layer, False) - restore_return_quant_tensor(self.layer, self.return_quant_tensor_state) - def get_quant_weights(self, i, i1, permutation_list): # We need to recompute quant weights at runtime since our float weights are being updated # Add offset in case of blockwise computation From 55fc1d4aedf0f0a181fb448c3d75253e37d82cc5 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Fri, 31 May 2024 14:19:25 +0100 Subject: [PATCH 16/22] update --- src/brevitas_examples/stable_diffusion/README.md | 3 +++ src/brevitas_examples/stable_diffusion/main.py | 13 +++++++++---- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/src/brevitas_examples/stable_diffusion/README.md b/src/brevitas_examples/stable_diffusion/README.md index 01c05b2e8..ad2dc57a8 100644 --- a/src/brevitas_examples/stable_diffusion/README.md +++ b/src/brevitas_examples/stable_diffusion/README.md @@ -75,6 +75,7 @@ usage: main.py [-h] [-m MODEL] [-d DEVICE] [-b BATCH_SIZE] [--prompt PROMPT] [--conv-weight-bit-width CONV_WEIGHT_BIT_WIDTH] [--linear-weight-bit-width LINEAR_WEIGHT_BIT_WIDTH] [--conv-input-bit-width CONV_INPUT_BIT_WIDTH] + [--act-eq-alpha ACT_EQ_ALPHA] [--linear-input-bit-width LINEAR_INPUT_BIT_WIDTH] [--weight-param-method {stats,mse}] [--input-param-method {stats,mse}] @@ -173,6 +174,8 @@ options: Weight bit width. Default: 8. --conv-input-bit-width CONV_INPUT_BIT_WIDTH Input bit width. Default: None (not quantized) + --act-eq-alpha ACT_EQ_ALPHA + Alpha for activation equalization. Default: 0.9 --linear-input-bit-width LINEAR_INPUT_BIT_WIDTH Input bit width. Default: None (not quantized). --weight-param-method {stats,mse} diff --git a/src/brevitas_examples/stable_diffusion/main.py b/src/brevitas_examples/stable_diffusion/main.py index cf6889dfd..8aa960d0f 100644 --- a/src/brevitas_examples/stable_diffusion/main.py +++ b/src/brevitas_examples/stable_diffusion/main.py @@ -136,9 +136,6 @@ def run_val_inference( def main(args): - if args.export_target: - assert args.weight_quant_format == 'int', "Currently only integer quantization supported for export." - dtype = getattr(torch, args.dtype) calibration_prompts = CALIBRATION_PROMPTS @@ -219,7 +216,10 @@ def main(args): if args.activation_equalization: pipe.set_progress_bar_config(disable=True) - with activation_equalization_mode(pipe.unet, alpha=0.9, layerwise=True, add_mul_node=True): + with activation_equalization_mode(pipe.unet, + alpha=args.act_eq_alpha, + layerwise=True, + add_mul_node=True): # Workaround to expose `in_features` attribute from the Hook Wrapper for m in pipe.unet.modules(): if isinstance(m, KwargsForwardHook) and hasattr(m.module, 'in_features'): @@ -604,6 +604,11 @@ def input_zp_stats_type(): type=int, default=None, help='Input bit width. Default: None (not quantized)') + parser.add_argument( + '--act-eq-alpha', + type=float, + default=0.9, + help='Alpha for activation equalization. Default: 0.9') parser.add_argument( '--linear-input-bit-width', type=int, From 8323acd0fbc11a862164527b4eefe39ed035c672 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Fri, 14 Jun 2024 15:43:32 +0100 Subject: [PATCH 17/22] Update --- .../common/generative/quantize.py | 104 ++++++++++++-- .../stable_diffusion/README.md | 37 +---- .../stable_diffusion/main.py | 132 ++++++------------ .../stable_diffusion/sd_quant/export.py | 19 --- .../stable_diffusion/sd_quant/utils.py | 98 ------------- 5 files changed, 140 insertions(+), 250 deletions(-) diff --git a/src/brevitas_examples/common/generative/quantize.py b/src/brevitas_examples/common/generative/quantize.py index 2e0fe12d0..44cc7262e 100644 --- a/src/brevitas_examples/common/generative/quantize.py +++ b/src/brevitas_examples/common/generative/quantize.py @@ -163,8 +163,7 @@ 'sym': Fp8e4m3DynamicOCPActPerTensorFloat}}}}}}} -def quantize_model( - model, +def generate_quantizers( dtype, weight_bit_width, weight_param_method, @@ -174,7 +173,6 @@ def quantize_model( weight_group_size, quantize_weight_zero_point, weight_quant_format='int', - name_blacklist=None, input_bit_width=None, input_quant_format='', input_scale_precision=None, @@ -184,7 +182,6 @@ def quantize_model( input_quant_granularity=None, input_group_size=None, quantize_input_zero_point=False, - quantize_embedding=False, use_ocp=False, device=None, weight_kwargs=None, @@ -200,20 +197,20 @@ def quantize_model( weight_float_format = { 'exponent_bit_width': int(weight_quant_format[1]), 'mantissa_bit_width': int(weight_quant_format[3])} + ocp_weight_format = weight_quant_format + weight_quant_format = 'float' if use_ocp: weight_quant_format += '_ocp' - ocp_weight_format = weight_quant_format - weight_quant_format = 'float' else: weight_float_format = {} if re.compile(r'e[1-8]m[1-8]').match(input_quant_format): input_float_format = { 'exponent_bit_width': int(input_quant_format[1]), 'mantissa_bit_width': int(input_quant_format[3])} + ocp_input_format = input_quant_format + input_quant_format = 'float' if use_ocp: input_quant_format += '_ocp' - ocp_input_format = input_quant_format - input_quant_format = 'float' else: input_float_format = {} @@ -230,15 +227,15 @@ def quantize_model( input_scale_type][input_quant_type] elif input_bit_width is not None: if ocp_input_format: - input_quant = INPUT_QUANT_MAP[input_quant_format][ocp_input_format][input_scale_type][ + input_quant = INPUT_QUANT_MAP[input_quant_format][input_scale_type][ocp_input_format][ input_scale_precision][input_param_method][input_quant_granularity][ input_quant_type] # Some activations in MHA should always be symmetric - sym_input_quant = INPUT_QUANT_MAP[input_quant_format][ocp_input_format][ - input_scale_type][input_scale_precision][input_param_method][ + sym_input_quant = INPUT_QUANT_MAP[input_quant_format][input_scale_type][ + ocp_input_format][input_scale_precision][input_param_method][ input_quant_granularity]['sym'] - linear_input_quant = INPUT_QUANT_MAP[input_quant_format][ocp_input_format][ - input_scale_type][input_scale_precision][input_param_method][ + linear_input_quant = INPUT_QUANT_MAP[input_quant_format][input_scale_type][ + ocp_input_format][input_scale_precision][input_param_method][ input_quant_granularity][input_quant_type] else: input_quant = INPUT_QUANT_MAP[input_quant_format][input_scale_type][ @@ -365,6 +362,21 @@ def quantize_model( linear_input_quant = linear_input_quant.let( **{ 'group_dim': -1, 'group_size': input_group_size}) + return linear_input_quant, weight_quant, input_quant, q_scaled_quant, k_transposed_quant, v_quant, attn_output_weights_quant + + +def generate_quant_maps( + linear_input_quant, + weight_quant, + input_quant, + q_scaled_quant, + k_transposed_quant, + v_quant, + attn_output_weights_quant, + dtype, + device, + input_quant_format, + quantize_embedding): quant_linear_kwargs = { 'input_quant': linear_input_quant, @@ -380,7 +392,7 @@ def quantize_model( 'in_proj_bias_quant': None, 'softmax_input_quant': None, 'attn_output_weights_quant': attn_output_weights_quant, - 'attn_output_weights_signed': input_quant_format == 'float', + 'attn_output_weights_signed': 'float' in input_quant_format, 'q_scaled_quant': q_scaled_quant, 'k_transposed_quant': k_transposed_quant, 'v_quant': v_quant, @@ -406,7 +418,71 @@ def quantize_model( if quantize_embedding: quant_embedding_kwargs = {'weight_quant': weight_quant, 'dtype': dtype, 'device': device} layer_map[nn.Embedding] = (qnn.QuantEmbedding, quant_embedding_kwargs) + return layer_map + + +def quantize_model( + model, + dtype, + weight_bit_width, + weight_param_method, + weight_scale_precision, + weight_quant_type, + weight_quant_granularity, + weight_group_size, + quantize_weight_zero_point, + weight_quant_format='int', + name_blacklist=None, + input_bit_width=None, + input_quant_format='', + input_scale_precision=None, + input_scale_type=None, + input_param_method=None, + input_quant_type=None, + input_quant_granularity=None, + input_group_size=None, + quantize_input_zero_point=False, + quantize_embedding=False, + use_ocp=False, + device=None, + weight_kwargs=None, + input_kwargs=None): + linear_input_quant, weight_quant, input_quant, q_scaled_quant, k_transposed_quant, v_quant, attn_output_weights_quant = generate_quantizers( + dtype, + weight_bit_width, + weight_param_method, + weight_scale_precision, + weight_quant_type, + weight_quant_granularity, + weight_group_size, + quantize_weight_zero_point, + weight_quant_format, + input_bit_width, + input_quant_format, + input_scale_precision, + input_scale_type, + input_param_method, + input_quant_type, + input_quant_granularity, + input_group_size, + quantize_input_zero_point, + use_ocp, + device, + weight_kwargs, + input_kwargs) + layer_map = generate_quant_maps( + linear_input_quant, + weight_quant, + input_quant, + q_scaled_quant, + k_transposed_quant, + v_quant, + attn_output_weights_quant, + dtype, + device, + input_quant_format, + quantize_embedding) model = layerwise_quantize( model=model, compute_layer_map=layer_map, name_blacklist=name_blacklist) return model diff --git a/src/brevitas_examples/stable_diffusion/README.md b/src/brevitas_examples/stable_diffusion/README.md index ad2dc57a8..e1f98b7da 100644 --- a/src/brevitas_examples/stable_diffusion/README.md +++ b/src/brevitas_examples/stable_diffusion/README.md @@ -30,7 +30,7 @@ Activation quantization is optional, and disabled by default. To enable, set bot We support ONNX integer export, and we are planning to release soon export for floating point quantization (e.g., FP8). -To export the model with fp16 scale factors, enable `export-cuda-float16`. This will performing the tracing necessary for export on GPU, leaving the model in fp16. +To export the model with fp16 scale factors, disable `export-cpu-float32`. This will performing the tracing necessary for export on GPU, leaving the model in fp16. If the flag is not enabled, the model will be moved to CPU and cast to float32 before export because of missing CPU kernels in fp16. To use MLPerf inference setup, check and install the correct requirements specified in the `requirements.txt` file under mlperf_evaluation. @@ -70,7 +70,7 @@ usage: main.py [-h] [-m MODEL] [-d DEVICE] [-b BATCH_SIZE] [--prompt PROMPT] [--gptq | --no-gptq] [--bias-correction | --no-bias-correction] [--dtype {float32,float16,bfloat16}] [--attention-slicing | --no-attention-slicing] - [--export-target {,torch,onnx}] + [--export-target {,onnx}] [--export-weight-q-node | --no-export-weight-q-node] [--conv-weight-bit-width CONV_WEIGHT_BIT_WIDTH] [--linear-weight-bit-width LINEAR_WEIGHT_BIT_WIDTH] @@ -93,15 +93,11 @@ usage: main.py [-h] [-m MODEL] [-d DEVICE] [-b BATCH_SIZE] [--prompt PROMPT] [--weight-group-size WEIGHT_GROUP_SIZE] [--quantize-weight-zero-point | --no-quantize-weight-zero-point] [--quantize-input-zero-point | --no-quantize-input-zero-point] - [--export-cuda-float16 | --no-export-cuda-float16] + [--export-cpu-float32 | --no-export-cpu-float32] [--use-mlperf-inference | --no-use-mlperf-inference] [--use-ocp | --no-use-ocp] [--use-negative-prompts | --no-use-negative-prompts] [--dry-run | --no-dry-run] - [--quantize-time-emb | --no-quantize-time-emb] - [--quantize-conv-in | --no-quantize-conv-in] - [--quantize-input-time-emb | --no-quantize-input-time-emb] - [--quantize-input-conv-in | --no-quantize-input-conv-in] Stable Diffusion quantization @@ -160,7 +156,7 @@ options: --attention-slicing Enable Enable attention slicing. Default: Disabled --no-attention-slicing Disable Enable attention slicing. Default: Disabled - --export-target {,torch,onnx} + --export-target {,onnx} Target export flow. --export-weight-q-node Enable Enable export of floating point weights + QDQ @@ -224,10 +220,9 @@ options: Enable Quantize input zero-point. Default: Enabled --no-quantize-input-zero-point Disable Quantize input zero-point. Default: Enabled - --export-cuda-float16 - Enable Export FP16 on CUDA. Default: Disabled - --no-export-cuda-float16 - Disable Export FP16 on CUDA. Default: Disabled + --export-cpu-float32 Enable Export FP32 on CPU. Default: Disabled + --no-export-cpu-float32 + Disable Export FP32 on CPU. Default: Disabled --use-mlperf-inference Enable Evaluate FID score with MLPerf pipeline. Default: False @@ -248,23 +243,5 @@ options: calibration. Default: Disabled --no-dry-run Disable Generate a quantized model without any calibration. Default: Disabled - --quantize-time-emb Enable Quantize time embedding layers. Default: True - --no-quantize-time-emb - Disable Quantize time embedding layers. Default: True - --quantize-conv-in Enable Quantize first conv layer. Default: True - --no-quantize-conv-in - Disable Quantize first conv layer. Default: True - --quantize-input-time-emb - Enable Quantize input to time embedding layers. - Default: Disabled - --no-quantize-input-time-emb - Disable Quantize input to time embedding layers. - Default: Disabled - --quantize-input-conv-in - Enable Quantize input to first conv layer. Default: - Enabled - --no-quantize-input-conv-in - Disable Quantize input to first conv layer. Default: - Enabled ``` diff --git a/src/brevitas_examples/stable_diffusion/main.py b/src/brevitas_examples/stable_diffusion/main.py index 8aa960d0f..af5b9203d 100644 --- a/src/brevitas_examples/stable_diffusion/main.py +++ b/src/brevitas_examples/stable_diffusion/main.py @@ -22,17 +22,18 @@ from brevitas.core.stats.stats_op import NegativeMinOrZero from brevitas.export.onnx.standard.qcdq.manager import StdQCDQONNXManager -from brevitas.export.torch.qcdq.manager import TorchQCDQManager from brevitas.graph.calibrate import bias_correction_mode from brevitas.graph.calibrate import calibration_mode from brevitas.graph.calibrate import load_quant_model_mode from brevitas.graph.equalize import activation_equalization_mode from brevitas.graph.gptq import gptq_mode -from brevitas.inject.enum import QuantType +from brevitas.graph.quantize import layerwise_quantize from brevitas.inject.enum import StatsOp from brevitas.nn.equalized_layer import EqualizedModule -from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer +from brevitas.nn.quant_activation import QuantIdentity from brevitas.utils.torch_utils import KwargsForwardHook +from brevitas_examples.common.generative.quantize import generate_quant_maps +from brevitas_examples.common.generative.quantize import generate_quantizers from brevitas_examples.common.generative.quantize import quantize_model from brevitas_examples.common.parse_utils import add_bool_arg from brevitas_examples.common.parse_utils import quant_format_validator @@ -41,8 +42,6 @@ from brevitas_examples.stable_diffusion.sd_quant.constants import SD_2_1_EMBEDDINGS_SHAPE from brevitas_examples.stable_diffusion.sd_quant.constants import SD_XL_EMBEDDINGS_SHAPE from brevitas_examples.stable_diffusion.sd_quant.export import export_onnx -from brevitas_examples.stable_diffusion.sd_quant.export import export_torch_export -from brevitas_examples.stable_diffusion.sd_quant.utils import brevitas_proxy_inference_mode from brevitas_examples.stable_diffusion.sd_quant.utils import generate_latents from brevitas_examples.stable_diffusion.sd_quant.utils import generate_unet_21_rand_inputs from brevitas_examples.stable_diffusion.sd_quant.utils import generate_unet_xl_rand_inputs @@ -141,12 +140,9 @@ def main(args): calibration_prompts = CALIBRATION_PROMPTS if args.calibration_prompt_path is not None: calibration_prompts = load_calib_prompts(args.calibration_prompt_path) - prompts = list() - for i, v in enumerate(calibration_prompts): - if i == args.calibration_prompt: - break - prompts.append(v) - calibration_prompts = prompts + print(args.calibration_prompt, len(calibration_prompts)) + assert args.calibration_prompt <= len(calibration_prompts) , f"Only {len(calibration_prompts)} prompts are available" + calibration_prompts = calibration_prompts[:args.calibration_prompt] latents = None if args.path_to_latents is not None: @@ -176,15 +172,11 @@ def main(args): if args.prompt > 0 and not args.use_mlperf_inference: print(f"Running inference with prompt ...") - prompts = [] - for i, v in enumerate(TESTING_PROMPTS): - if i == args.prompt: - break - prompts.append(v) + testing_prompts = TESTING_PROMPTS[:args.prompt] float_images = run_test_inference( pipe, args.resolution, - prompts, + testing_prompts, test_seeds, output_dir, args.device, @@ -203,9 +195,7 @@ def main(args): # Extract list of layers to avoid blacklist = [] for name, _ in pipe.unet.named_modules(): - if 'time_emb' in name and not args.quantize_time_emb: - blacklist.append(name.split('.')[-1]) - if 'conv_in' in name and not args.quantize_conv_in: + if 'time_emb' in name: blacklist.append(name.split('.')[-1]) print(f"Blacklisted layers: {blacklist}") @@ -263,23 +253,12 @@ def input_bit_width(module): return args.linear_input_bit_width elif isinstance(module, nn.Conv2d): return args.conv_input_bit_width + elif isinstance(module, QuantIdentity): + return args.quant_identity_bit_width else: raise RuntimeError(f"Module {module} not supported.") input_kwargs = dict() - if args.linear_input_bit_width is None or args.conv_input_bit_width is None: - - @value - def input_quant_enabled(module): - if args.linear_input_bit_width is None and isinstance(module, nn.Linear): - return QuantType.FP - elif args.conv_input_bit_width is None and isinstance(module, nn.Conv2d): - return QuantType.FP - else: - return QuantType.INT - - input_kwargs['quant_type'] = input_quant_enabled - if args.input_scale_stats_op == 'minmax': @value @@ -303,11 +282,9 @@ def input_zp_stats_type(): input_kwargs['zero_point_stats_impl'] = input_zp_stats_type print("Applying model quantization...") - quantize_model( - pipe.unet, + quantizers = generate_quantizers( dtype=dtype, device=args.device, - name_blacklist=blacklist, weight_bit_width=weight_bit_width, weight_quant_format=args.weight_quant_format, weight_quant_type=args.weight_quant_type, @@ -326,23 +303,30 @@ def input_zp_stats_type(): input_quant_granularity=args.input_quant_granularity, use_ocp=args.use_ocp, input_kwargs=input_kwargs) + + layer_map = generate_quant_maps( + *quantizers, dtype, args.device, args.input_quant_format, False) + + linear_qkwargs = layer_map[torch.nn.Linear][1] + linear_qkwargs[ + 'input_quant'] = None if args.linear_input_bit_width is None else linear_qkwargs[ + 'input_quant'] + linear_qkwargs[ + 'weight_quant'] = None if args.linear_weight_bit_width == 0 else linear_qkwargs[ + 'weight_quant'] + layer_map[torch.nn.Linear] = (layer_map[torch.nn.Linear][0], linear_qkwargs) + + conv_qkwargs = layer_map[torch.nn.Conv2d][1] + conv_qkwargs['input_quant'] = None if args.conv_input_bit_width is None else conv_qkwargs[ + 'input_quant'] + conv_qkwargs['weight_quant'] = None if args.conv_weight_bit_width == 0 else conv_qkwargs[ + 'weight_quant'] + layer_map[torch.nn.Conv2d] = (layer_map[torch.nn.Conv2d][0], conv_qkwargs) + + pipe.unet = layerwise_quantize( + model=pipe.unet, compute_layer_map=layer_map, name_blacklist=blacklist) print("Model quantization applied.") - skipped_layers = [] - for name, module in pipe.unet.named_modules(): - if 'time_emb' in name and not args.quantize_input_time_emb: - if hasattr(module, 'input_quant'): - module.input_quant.quant_injector = module.input_quant.quant_injector.let( - **{'quant_type': QuantType.FP}) - module.input_quant.init_tensor_quant() - skipped_layers.append(name) - if 'conv_in' in name and not args.quantize_input_conv_in: - if hasattr(module, 'input_quant'): - module.input_quant.quant_injector = module.input_quant.quant_injector.let( - **{'quant_type': QuantType.FP}) - module.input_quant.init_tensor_quant() - skipped_layers.append(name) - print(f"Skipped input quantization for layers: {skipped_layers}") pipe.set_progress_bar_config(disable=True) if args.dry_run: @@ -427,15 +411,13 @@ def input_zp_stats_type(): compute_mlperf_fid(args.model, args.path_to_coco, pipe, args.prompt, output_dir) else: print(f"Computing accuracy on default prompt") - prompts = list() - for i, v in enumerate(TESTING_PROMPTS): - if i == args.prompt: - break - prompts.append(v) + testing_prompts = TESTING_PROMPTS[:args.prompt] + assert args.prompt <= len(TESTING_PROMPTS), f"Only {len(TESTING_PROMPTS)} prompts are available" + quant_images = run_test_inference( pipe, args.resolution, - prompts, + testing_prompts, test_seeds, output_dir, args.device, @@ -461,8 +443,8 @@ def input_zp_stats_type(): if args.export_target: # Move to cpu and to float32 to enable CPU export - if not (dtype == torch.float16 and args.export_cuda_float16): - pipe.unet.to('cpu').to(dtype) + if args.export_cpu_float32: + pipe.unet.to('cpu').to(torch.float32) pipe.unet.eval() device = next(iter(pipe.unet.parameters())).device dtype = next(iter(pipe.unet.parameters())).dtype @@ -487,13 +469,6 @@ def input_zp_stats_type(): export_manager = StdQCDQONNXManager export_manager.change_weight_export(export_weight_q_node=args.export_weight_q_node) export_onnx(pipe, trace_inputs, output_dir, export_manager) - if args.export_target == 'torch': - if args.weight_quant_granularity == 'per_group': - export_manager = BlockQuantProxyLevelManager - else: - export_manager = TorchQCDQManager - export_manager.change_weight_export(export_weight_q_node=args.export_weight_q_node) - export_torch_export(pipe, trace_inputs, output_dir, export_manager) if __name__ == "__main__": @@ -583,11 +558,7 @@ def input_zp_stats_type(): default=False, help='Enable attention slicing. Default: Disabled') parser.add_argument( - '--export-target', - type=str, - default='', - choices=['', 'torch', 'onnx'], - help='Target export flow.') + '--export-target', type=str, default='', choices=['', 'onnx'], help='Target export flow.') add_bool_arg( parser, 'export-weight-q-node', @@ -708,7 +679,7 @@ def input_zp_stats_type(): default=False, help='Quantize input zero-point. Default: Enabled') add_bool_arg( - parser, 'export-cuda-float16', default=False, help='Export FP16 on CUDA. Default: Disabled') + parser, 'export-cpu-float32', default=False, help='Export FP32 on CPU. Default: Disabled') add_bool_arg( parser, 'use-mlperf-inference', @@ -729,23 +700,6 @@ def input_zp_stats_type(): 'dry-run', default=False, help='Generate a quantized model without any calibration. Default: Disabled') - add_bool_arg( - parser, - 'quantize-time-emb', - default=True, - help='Quantize time embedding layers. Default: True') - add_bool_arg( - parser, 'quantize-conv-in', default=True, help='Quantize first conv layer. Default: True') - add_bool_arg( - parser, - 'quantize-input-time-emb', - default=False, - help='Quantize input to time embedding layers. Default: Disabled') - add_bool_arg( - parser, - 'quantize-input-conv-in', - default=True, - help='Quantize input to first conv layer. Default: Enabled') args = parser.parse_args() print("Args: " + str(vars(args))) main(args) diff --git a/src/brevitas_examples/stable_diffusion/sd_quant/export.py b/src/brevitas_examples/stable_diffusion/sd_quant/export.py index 7ce70e783..70c9dda75 100644 --- a/src/brevitas_examples/stable_diffusion/sd_quant/export.py +++ b/src/brevitas_examples/stable_diffusion/sd_quant/export.py @@ -11,27 +11,8 @@ from brevitas_examples.llm.llm_quant.export import brevitas_proxy_export_mode -class UnetExportWrapper(nn.Module): - - def __init__(self, unet): - super().__init__() - self.unet = unet - - def forward(self, *args, **kwargs): - return self.unet(*args, **kwargs, return_dict=False) - - def export_onnx(pipe, trace_inputs, output_dir, export_manager): output_path = os.path.join(output_dir, 'unet.onnx') print(f"Saving unet to {output_path} ...") with torch.no_grad(), brevitas_proxy_export_mode(pipe.unet, export_manager): torch.onnx.export(pipe.unet, args=trace_inputs, f=output_path) - - -def export_torch_export(pipe, trace_inputs, output_dir, export_manager): - output_path = os.path.join(output_dir, 'unet.onnx') - print(trace_inputs[1]) - print(f"Saving unet to {output_path} ...") - with torch.no_grad(), brevitas_proxy_export_mode(pipe.unet, export_manager): - torch.export.export( - UnetExportWrapper(pipe.unet), args=(trace_inputs[0],), kwargs=trace_inputs[1]) diff --git a/src/brevitas_examples/stable_diffusion/sd_quant/utils.py b/src/brevitas_examples/stable_diffusion/sd_quant/utils.py index 2700dd032..b2c30176f 100644 --- a/src/brevitas_examples/stable_diffusion/sd_quant/utils.py +++ b/src/brevitas_examples/stable_diffusion/sd_quant/utils.py @@ -3,106 +3,8 @@ SPDX-License-Identifier: MIT """ -from contextlib import contextmanager - import torch -from brevitas.export.common.handler.base import BaseHandler -from brevitas.export.manager import _set_proxy_export_handler -from brevitas.export.manager import _set_proxy_export_mode -from brevitas.export.manager import BaseManager -from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer -from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjector - - -class InferenceWeightProxyHandler(BaseHandler): - handled_layer = WeightQuantProxyFromInjector - - def __init__(self): - super(InferenceWeightProxyHandler, self).__init__() - self.scale = None - self.zero_point = None - self.bit_width = None - self.dtype = None - self.float_weight = None - - def scaling_impl(self, proxy_module): - return proxy_module.tensor_quant.scaling_impl - - def zero_point_impl(self, proxy_module): - return proxy_module.tensor_quant.zero_point_impl - - def bit_width_impl(self, proxy_module): - return proxy_module.tensor_quant.msb_clamp_bit_width_impl - - def export_scale(self, proxy_module, bit_width): - scaling_impl = self.scaling_impl(proxy_module) - int_scaling_impl = proxy_module.tensor_quant.int_scaling_impl - int_threshold = int_scaling_impl(bit_width) - threshold = scaling_impl.stats_scaling_impl(scaling_impl.parameter_list_stats()) - return threshold / int_threshold - - def export_zero_point(self, proxy_module, weight, scale, bit_width): - zero_point_impl = self.zero_point_impl(proxy_module) - return zero_point_impl(weight, scale, bit_width) - - def prepare_for_export(self, module): - assert len(module.tracked_module_list) == 1, "Shared quantizers not supported." - self.bit_width = self.bit_width_impl(module)() - assert self.bit_width <= 8., "Only 8b or lower is supported." - quant_layer = module.tracked_module_list[0] - self.float_weight = quant_layer.quant_weight() - self.dtype = self.float_weight.value.dtype - # if (self.float_weight.zero_point != 0.).any(): - # self.zero_point = self.export_zero_point(module, quant_layer.weight, self.scale, self.bit_width).detach().cpu() - # self.scale = self.export_scale(module, self.bit_width).detach().cpu() - # quant_layer.weight.data = quant_layer.weight.data.cpu() - - def forward(self, x): - - return self.float_weight.value, self.float_weight.scale, self.float_weight.zero_point, self.bit_width - - -class InferenceWeightProxyManager(BaseManager): - handlers = [InferenceWeightProxyHandler] - - @classmethod - def set_export_handler(cls, module): - if hasattr(module, - 'requires_export_handler') and module.requires_export_handler and not isinstance( - module, (WeightQuantProxyFromInjector)): - return - _set_proxy_export_handler(cls, module) - - -def store_mapping_tensor_state_dict(model): - mapping = dict() - for module in model.modules(): - if isinstance(module, QuantWeightBiasInputOutputLayer): - mapping[module.weight.data_ptr()] = module.weight.device - return mapping - - -def restore_mapping(model, mapping): - for module in model.modules(): - if isinstance(module, QuantWeightBiasInputOutputLayer): - module.weight.data = module.weight.data.to(mapping[module.weight.data_ptr()]) - - -@contextmanager -def brevitas_proxy_inference_mode(model): - mapping = store_mapping_tensor_state_dict(model) - is_training = model.training - model.eval() - model.apply(InferenceWeightProxyManager.set_export_handler) - _set_proxy_export_mode(model, enabled=True, proxy_class=WeightQuantProxyFromInjector) - try: - yield model - finally: - restore_mapping(model, mapping) - _set_proxy_export_mode(model, enabled=False) - model.train(is_training) - def unet_input_shape(resolution): return (4, resolution // 8, resolution // 8) From 88184f22e500fbab85c03e196c1df3e7f28d823c Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Thu, 20 Jun 2024 19:52:49 +0100 Subject: [PATCH 18/22] update --- src/brevitas/graph/equalize.py | 22 +++-- .../common/generative/quantize.py | 33 ++++++- .../stable_diffusion/main.py | 93 ++++++++++++++++-- .../stable_diffusion/sd_quant/export.py | 97 +++++++++++++++++++ 4 files changed, 227 insertions(+), 18 deletions(-) diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py index fa63bf80d..c15f46a8f 100644 --- a/src/brevitas/graph/equalize.py +++ b/src/brevitas/graph/equalize.py @@ -201,6 +201,7 @@ def __init__( add_mul_node=True, layerwise=True, enabled=True, + blacklist_layers=None, co_optimize_act_weights=False) -> None: self.model = model self.alpha = alpha @@ -210,7 +211,8 @@ def __init__( if layerwise: if not self.add_mul_node: raise ValueError("Layerwise activation equalization requires add_mul_node") - self.graph_act_eq = LayerwiseActivationEqualization(self.model) + self.graph_act_eq = LayerwiseActivationEqualization( + self.model, blacklist_layers=blacklist_layers) else: if not isinstance(self.model, (TorchGraphModule, GraphModule)): raise TypeError( @@ -996,15 +998,21 @@ def remove_hooks(self): class LayerwiseActivationEqualization(ActivationEqualization): - def __init__(self, model, scale_computation_type: str = 'maxabs'): + def __init__( + self, + model, + scale_computation_type: str = 'maxabs', + blacklist_layers: Optional[List[str]] = None): super(LayerwiseActivationEqualization, self).__init__(model, scale_computation_type) self.float_act_map = {} self.batch_dim_act_map = {} self.hooks = [] self.add_mul_node = True + self.blacklist_layers = blacklist_layers regions: List[Region] = [] - self.find_module(model, regions) + name = '' + self.find_module(model, name, regions) self.regions = regions if self.scale_computation_type == 'maxabs': @@ -1012,20 +1020,22 @@ def __init__(self, model, scale_computation_type: str = 'maxabs'): elif self.scale_computation_type == 'range': self.scale_fn = _channel_range - def find_module(self, model, regions: List): + def find_module(self, model, name, regions: List): """ Iterate through the model looking at immediate children of every module to look for supported modules. This allows us to stop the search when we meet a top-level module that is supported. """ if isinstance(model, _supported_layers) and not isinstance(model, _batch_norm + (nn.LayerNorm,)): + if self.blacklist_layers is not None and name in self.blacklist_layers: + return weight = get_weight_sink(model) eq_indexes = EqualizationIndexes(0, weight.shape[0], 0) region = Region(sinks={'sinks0': eq_indexes}, name_to_module={'sinks0': model}) regions.append(region) else: - for module in model.children(): - self.find_module(module, regions) + for name, module in model.named_children(): + self.find_module(module, name, regions) def setup(self): for region in self.regions: diff --git a/src/brevitas_examples/common/generative/quantize.py b/src/brevitas_examples/common/generative/quantize.py index 44cc7262e..6d84cf1ad 100644 --- a/src/brevitas_examples/common/generative/quantize.py +++ b/src/brevitas_examples/common/generative/quantize.py @@ -16,6 +16,10 @@ from brevitas.quant.experimental.float import Fp8e4m3ActPerTensorFloat from brevitas.quant.experimental.float import Fp8e4m3WeightPerChannelFloat from brevitas.quant.experimental.float import Fp8e4m3WeightPerTensorFloat +from brevitas.quant.experimental.float_quant_fnuz import Fp8e4m3FNUZActPerTensorFloat +from brevitas.quant.experimental.float_quant_fnuz import Fp8e4m3FNUZWeightPerChannelFloat +from brevitas.quant.experimental.float_quant_fnuz import Fp8e4m3FNUZWeightPerTensorFloat +from brevitas.quant.experimental.float_quant_fnuz import Fp8e5m2FNUZActPerTensorFloat from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPActPerTensorFloat from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeightPerChannelFloat from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeightPerTensorFloat @@ -104,7 +108,15 @@ 'per_tensor': { 'sym': Fp8e5m2OCPWeightPerTensorFloat}, 'per_channel': { - 'sym': Fp8e5m2OCPWeightPerChannelFloat}}}}}} + 'sym': Fp8e5m2OCPWeightPerChannelFloat}}}}}, + 'float_fnuz': { + 'e4m3': { + 'float_scale': { + 'stats': { + 'per_tensor': { + 'sym': Fp8e4m3FNUZWeightPerTensorFloat}, + 'per_channel': { + 'sym': Fp8e4m3FNUZWeightPerChannelFloat}}}}}} INPUT_QUANT_MAP = { 'int': { @@ -154,7 +166,19 @@ 'float_scale': { 'stats': { 'per_tensor': { - 'sym': Fp8e5m2OCPActPerTensorFloat}}}}}, + 'sym': Fp8e5m2OCPActPerTensorFloat}}}}}}, + 'float_fnuz': { + 'static': { + 'e4m3': { + 'float_scale': { + 'stats': { + 'per_tensor': { + 'sym': Fp8e4m3FNUZActPerTensorFloat}}}}, + 'e5m2': { + 'float_scale': { + 'stats': { + 'per_tensor': { + 'sym': Fp8e5m2FNUZActPerTensorFloat}}}}}, 'dynamic': { 'e4m3': { 'float_scale': { @@ -183,6 +207,7 @@ def generate_quantizers( input_group_size=None, quantize_input_zero_point=False, use_ocp=False, + use_fnuz=False, device=None, weight_kwargs=None, input_kwargs=None): @@ -201,6 +226,8 @@ def generate_quantizers( weight_quant_format = 'float' if use_ocp: weight_quant_format += '_ocp' + elif use_fnuz: + weight_quant_format += '_fnuz' else: weight_float_format = {} if re.compile(r'e[1-8]m[1-8]').match(input_quant_format): @@ -211,6 +238,8 @@ def generate_quantizers( input_quant_format = 'float' if use_ocp: input_quant_format += '_ocp' + elif use_fnuz: + input_quant_format += '_fnuz' else: input_float_format = {} diff --git a/src/brevitas_examples/stable_diffusion/main.py b/src/brevitas_examples/stable_diffusion/main.py index af5b9203d..4527b65b6 100644 --- a/src/brevitas_examples/stable_diffusion/main.py +++ b/src/brevitas_examples/stable_diffusion/main.py @@ -13,6 +13,8 @@ from dependencies import value from diffusers import DiffusionPipeline from diffusers import StableDiffusionXLPipeline +from diffusers.models.attention_processor import Attention +from diffusers.models.attention_processor import AttnProcessor import numpy as np import pandas as pd import torch @@ -22,6 +24,7 @@ from brevitas.core.stats.stats_op import NegativeMinOrZero from brevitas.export.onnx.standard.qcdq.manager import StdQCDQONNXManager +from brevitas.graph.base import ModuleToModuleByClass from brevitas.graph.calibrate import bias_correction_mode from brevitas.graph.calibrate import calibration_mode from brevitas.graph.calibrate import load_quant_model_mode @@ -42,12 +45,15 @@ from brevitas_examples.stable_diffusion.sd_quant.constants import SD_2_1_EMBEDDINGS_SHAPE from brevitas_examples.stable_diffusion.sd_quant.constants import SD_XL_EMBEDDINGS_SHAPE from brevitas_examples.stable_diffusion.sd_quant.export import export_onnx +from brevitas_examples.stable_diffusion.sd_quant.export import export_quant_params +from brevitas_examples.stable_diffusion.sd_quant.nn import QuantAttention from brevitas_examples.stable_diffusion.sd_quant.utils import generate_latents from brevitas_examples.stable_diffusion.sd_quant.utils import generate_unet_21_rand_inputs from brevitas_examples.stable_diffusion.sd_quant.utils import generate_unet_xl_rand_inputs from brevitas_examples.stable_diffusion.sd_quant.utils import unet_input_shape TEST_SEED = 123456 +# TODO: add deterministc flags NEGATIVE_PROMPTS = ["normal quality, low quality, worst quality, low res, blurry, nsfw, nude"] @@ -206,10 +212,12 @@ def main(args): if args.activation_equalization: pipe.set_progress_bar_config(disable=True) - with activation_equalization_mode(pipe.unet, - alpha=args.act_eq_alpha, - layerwise=True, - add_mul_node=True): + with activation_equalization_mode( + pipe.unet, + alpha=args.act_eq_alpha, + layerwise=True, + blacklist_layers=blacklist if args.exclude_blacklist_act_eq else None, + add_mul_node=True): # Workaround to expose `in_features` attribute from the Hook Wrapper for m in pipe.unet.modules(): if isinstance(m, KwargsForwardHook) and hasattr(m.module, 'in_features'): @@ -302,6 +310,7 @@ def input_zp_stats_type(): input_quant_type=args.input_quant_type, input_quant_granularity=args.input_quant_granularity, use_ocp=args.use_ocp, + use_fnuz=args.use_fnuz, input_kwargs=input_kwargs) layer_map = generate_quant_maps( @@ -323,6 +332,55 @@ def input_zp_stats_type(): 'weight_quant'] layer_map[torch.nn.Conv2d] = (layer_map[torch.nn.Conv2d][0], conv_qkwargs) + if args.quantize_sdp_1 or args.quantize_sdp_2: + float_sdpa_quantizers = generate_quantizers( + dtype=dtype, + device=args.device, + weight_bit_width=weight_bit_width, + weight_quant_format='e4m3', + weight_quant_type='sym', + weight_param_method=args.weight_param_method, + weight_scale_precision=args.weight_scale_precision, + weight_quant_granularity=args.weight_quant_granularity, + weight_group_size=args.weight_group_size, + quantize_weight_zero_point=args.quantize_weight_zero_point, + quantize_input_zero_point=args.quantize_input_zero_point, + input_bit_width=input_bit_width, + input_quant_format='e4m3', + input_scale_type=args.input_scale_type, + input_scale_precision=args.input_scale_precision, + input_param_method=args.input_param_method, + input_quant_type='sym', + input_quant_granularity=args.input_quant_granularity, + use_ocp=args.use_ocp, + use_fnuz=args.use_fnuz, + input_kwargs=input_kwargs) + input_quant = float_sdpa_quantizers[0] + input_quant = input_quant.let(**{'bit_width': args.linear_output_bit_width}) + if args.quantize_sdp_2: + rewriter = ModuleToModuleByClass( + Attention, + QuantAttention, + softmax_output_quant=input_quant, + query_dim=lambda module: module.to_q.in_features, + dim_head=lambda module: int(1 / (module.scale ** 2)), + processor=AttnProcessor(), + is_equalized=args.activation_equalization) + import brevitas.config as config + config.IGNORE_MISSING_KEYS = True + pipe.unet = rewriter.apply(pipe.unet) + config.IGNORE_MISSING_KEYS = False + pipe.unet = pipe.unet.to(args.device) + pipe.unet = pipe.unet.to(dtype) + quant_kwargs = layer_map[torch.nn.Linear][1] + what_to_quantize = [] + if args.quantize_sdp_1: + what_to_quantize.extend(['to_q', 'to_k']) + if args.quantize_sdp_2: + what_to_quantize.extend(['to_v']) + quant_kwargs['output_quant'] = lambda module, name: input_quant if any(ending in name for ending in what_to_quantize) else None + layer_map[torch.nn.Linear] = (layer_map[torch.nn.Linear][0], quant_kwargs) + pipe.unet = layerwise_quantize( model=pipe.unet, compute_layer_map=layer_map, name_blacklist=blacklist) print("Model quantization applied.") @@ -469,6 +527,8 @@ def input_zp_stats_type(): export_manager = StdQCDQONNXManager export_manager.change_weight_export(export_weight_q_node=args.export_weight_q_node) export_onnx(pipe, trace_inputs, output_dir, export_manager) + if args.export_target == 'params_only': + export_quant_params(pipe, output_dir) if __name__ == "__main__": @@ -488,10 +548,7 @@ def input_zp_stats_type(): default=2, help='How many seeds to use for each image during validation. Default: 2') parser.add_argument( - '--prompt', - type=int, - default=4, - help='Number of prompt to use for testing. Default: 4. Max: 4') + '--prompt', type=int, default=4, help='Number of prompt to use for testing. Default: 4') parser.add_argument( '--calibration-prompt', type=int, @@ -558,7 +615,11 @@ def input_zp_stats_type(): default=False, help='Enable attention slicing. Default: Disabled') parser.add_argument( - '--export-target', type=str, default='', choices=['', 'onnx'], help='Target export flow.') + '--export-target', + type=str, + default='', + choices=['', 'onnx', 'params_only'], + help='Target export flow.') add_bool_arg( parser, 'export-weight-q-node', @@ -673,6 +734,11 @@ def input_zp_stats_type(): 'quantize-weight-zero-point', default=True, help='Quantize weight zero-point. Default: Enabled') + add_bool_arg( + parser, + 'exclude-blacklist-act-eq', + default=False, + help='Exclude unquantized layers from activation equalization. Default: Disabled') add_bool_arg( parser, 'quantize-input-zero-point', @@ -688,8 +754,13 @@ def input_zp_stats_type(): add_bool_arg( parser, 'use-ocp', - default=True, + default=False, help='Use OCP format for float quantization. Default: True') + add_bool_arg( + parser, + 'use-nfuz', + default=True, + help='Use NFUZ format for float quantization. Default: True') add_bool_arg( parser, 'use-negative-prompts', @@ -700,6 +771,8 @@ def input_zp_stats_type(): 'dry-run', default=False, help='Generate a quantized model without any calibration. Default: Disabled') + add_bool_arg(parser, 'quantize-sdp-1', default=False, help='Quantize SDP. Default: Disabled') + add_bool_arg(parser, 'quantize-sdp-2', default=False, help='Quantize SDP. Default: Disabled') args = parser.parse_args() print("Args: " + str(vars(args))) main(args) diff --git a/src/brevitas_examples/stable_diffusion/sd_quant/export.py b/src/brevitas_examples/stable_diffusion/sd_quant/export.py index 70c9dda75..541541a5f 100644 --- a/src/brevitas_examples/stable_diffusion/sd_quant/export.py +++ b/src/brevitas_examples/stable_diffusion/sd_quant/export.py @@ -3,11 +3,16 @@ SPDX-License-Identifier: MIT """ +import json import os +from safetensors.torch import save_file import torch from torch import nn +from brevitas.nn.equalized_layer import EqualizedModule +from brevitas.nn.quant_layer import QuantNonLinearActLayer +from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer from brevitas_examples.llm.llm_quant.export import brevitas_proxy_export_mode @@ -16,3 +21,95 @@ def export_onnx(pipe, trace_inputs, output_dir, export_manager): print(f"Saving unet to {output_path} ...") with torch.no_grad(), brevitas_proxy_export_mode(pipe.unet, export_manager): torch.onnx.export(pipe.unet, args=trace_inputs, f=output_path) + + +def handle_quant_param(layer, layer_dict): + input_scale = layer.input_quant.export_handler.symbolic_kwargs['dequantize_symbolic_kwargs'][ + 'scale'].data + input_zp = layer.input_quant.export_handler.symbolic_kwargs['dequantize_symbolic_kwargs'][ + 'zero_point'].data - 128. + weight_scale = layer.weight_quant.export_handler.symbolic_kwargs['dequantize_symbolic_kwargs'][ + 'scale'].data + weight_zp = layer.weight_quant.export_handler.symbolic_kwargs['dequantize_symbolic_kwargs'][ + 'zero_point'].data - 128. # apply offset to have signed zp + if layer.output_quant.export_handler.symbolic_kwargs is not None: + output_scale = layer.output_quant.export_handler.symbolic_kwargs[ + 'dequantize_symbolic_kwargs']['scale'].data + + layer_dict['output_scale'] = output_scale.numpy().tolist() + layer_dict['output_scale_shape'] = output_scale.shape + layer_dict['input_scale'] = input_scale.numpy().tolist() + layer_dict['input_scale_shape'] = input_scale.shape + layer_dict['input_zp'] = input_zp.numpy().tolist() + layer_dict['input_zp_shape'] = input_zp.shape + layer_dict['input_zp_dtype'] = str(torch.int8) + layer_dict['weight_scale'] = weight_scale.numpy().tolist() + nelems = layer.weight.shape[0] + weight_scale_shape = [nelems] + [1] * (layer.weight.data.ndim - 1) + layer_dict['weight_scale_shape'] = weight_scale_shape + layer_dict['weight_zp'] = weight_zp.numpy().tolist() + layer_dict['weight_zp_shape'] = weight_scale_shape + layer_dict['weight_zp_dtype'] = str(torch.int8) + return layer_dict + + +def export_quant_params(pipe, output_dir): + quant_output_path = os.path.join(output_dir, 'quant_params.json') + output_path = os.path.join(output_dir, 'params.safetensors') + print(f"Saving unet to {output_path} ...") + from brevitas.export.onnx.standard.qcdq.manager import StdQCDQONNXManager + quant_params = dict() + state_dict = pipe.unet.state_dict() + state_dict = {k: v for (k, v) in state_dict.items() if 'tensor_quant' not in k} + state_dict = {k: v for (k, v) in state_dict.items() if not k.endswith('.scale.weight')} + state_dict = {k.replace('.layer.', '.'): v for (k, v) in state_dict.items()} + + handled_quant_layers = set() + with torch.no_grad(), brevitas_proxy_export_mode(pipe.unet, StdQCDQONNXManager): + for name, module in pipe.unet.named_modules(): + if isinstance(module, EqualizedModule): + if id(module.layer) in handled_quant_layers: + raise RuntimeError("This should not happen") + if isinstance(module.layer, QuantWeightBiasInputOutputLayer): + layer_dict = dict() + full_name = name + smoothquant_param = module.scale.weight + + layer_dict['smoothquant_mul'] = smoothquant_param.data.numpy().tolist() + layer_dict['smoothquant_mul_shape'] = module.scale.runtime_shape + layer_dict = handle_quant_param(module.layer, layer_dict) + + quant_params[full_name] = layer_dict + handled_quant_layers.add(id(module.layer)) + else: + layer_dict = dict() + full_name = name + smoothquant_param = module.scale.weight + + layer_dict['smoothquant_mul'] = smoothquant_param.data.numpy().tolist() + layer_dict['smoothquant_mul_shape'] = module.scale.runtime_shape + quant_params[full_name] = layer_dict + handled_quant_layers.add(id(module.layer)) + elif isinstance( + module, + QuantWeightBiasInputOutputLayer) and id(module) not in handled_quant_layers: + layer_dict = dict() + layer_dict = handle_quant_param(module, layer_dict) + quant_params[full_name] = layer_dict + handled_quant_layers.add(id(module)) + elif isinstance(module, QuantNonLinearActLayer): + layer_dict = dict() + act_scale = module.act_quant.export_handler.symbolic_kwargs[ + 'dequantize_symbolic_kwargs']['scale'].data + act_zp = module.act_quant.export_handler.symbolic_kwargs[ + 'dequantize_symbolic_kwargs']['zero_point'].data + layer_dict['act_scale'] = act_scale.numpy().tolist() + layer_dict['act_scale_shape'] = act_scale.shape + layer_dict['act_zp'] = act_zp.to(torch.float32).numpy().tolist() + layer_dict['act_zp_shape'] = act_zp.shape + layer_dict['act_zp_dtype'] = str(act_zp.dtype) + quant_params[full_name] = layer_dict + handled_quant_layers.add(id(module)) + with open(quant_output_path, 'w') as file: + json.dump(quant_params, file, indent=" ") + save_file(state_dict, output_path) From b761430f3cd486bb4480e3889b6dbaffed621a2a Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Thu, 20 Jun 2024 20:03:41 +0100 Subject: [PATCH 19/22] Missing file --- .../stable_diffusion/sd_quant/nn.py | 127 ++++++++++++++++++ 1 file changed, 127 insertions(+) create mode 100644 src/brevitas_examples/stable_diffusion/sd_quant/nn.py diff --git a/src/brevitas_examples/stable_diffusion/sd_quant/nn.py b/src/brevitas_examples/stable_diffusion/sd_quant/nn.py new file mode 100644 index 000000000..e240c3a36 --- /dev/null +++ b/src/brevitas_examples/stable_diffusion/sd_quant/nn.py @@ -0,0 +1,127 @@ +from typing import Optional + +from diffusers.models.attention_processor import Attention +import torch + +from brevitas.graph.base import ModuleInstanceToModuleInstance +from brevitas.nn.equalized_layer import EqualizedModule +from brevitas.nn.quant_activation import QuantIdentity +from brevitas.nn.quant_scale_bias import ScaleBias +from brevitas.quant_tensor import _unpack_quant_tensor + + +class QuantAttention(Attention): + + def __init__( + self, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias=False, + upcast_attention: bool = False, + upcast_softmax: bool = False, + cross_attention_norm: Optional[str] = None, + cross_attention_norm_num_groups: int = 32, + added_kv_proj_dim: Optional[int] = None, + norm_num_groups: Optional[int] = None, + spatial_norm_dim: Optional[int] = None, + out_bias: bool = True, + scale_qk: bool = True, + only_cross_attention: bool = False, + eps: float = 1e-5, + rescale_output_factor: float = 1.0, + residual_connection: bool = False, + _from_deprecated_attn_block=False, + processor: Optional["AttnProcessor"] = None, + softmax_output_quant=None, + is_equalized=False): + super().__init__( + query_dim, + cross_attention_dim, + heads, + dim_head, + dropout, + bias, + upcast_attention, + upcast_softmax, + cross_attention_norm, + cross_attention_norm_num_groups, + added_kv_proj_dim, + norm_num_groups, + spatial_norm_dim, + out_bias, + scale_qk, + only_cross_attention, + eps, + rescale_output_factor, + residual_connection, + _from_deprecated_attn_block, + processor, + ) + + self.output_softmax_quant = QuantIdentity(softmax_output_quant) + if is_equalized: + replacements = [] + for n, m in self.named_modules(): + if isinstance(m, torch.nn.Linear): + in_channels = m.in_features + eq_m = EqualizedModule(ScaleBias(in_channels, False, (1, 1, -1)), m) + r = ModuleInstanceToModuleInstance(m, eq_m) + replacements.append(r) + for r in replacements: + r.apply(self) + + def get_attention_scores( + self, + query: torch.Tensor, + key: torch.Tensor, + attention_mask: torch.Tensor = None) -> torch.Tensor: + r""" + Compute the attention scores. + + Args: + query (`torch.Tensor`): The query tensor. + key (`torch.Tensor`): The key tensor. + attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied. + + Returns: + `torch.Tensor`: The attention probabilities/scores. + """ + dtype = query.dtype + if self.upcast_attention: + query = query.float() + key = key.float() + + if attention_mask is None: + baddbmm_input = torch.empty( + query.shape[0], + query.shape[1], + key.shape[1], + dtype=query.dtype, + device=query.device) + beta = 0 + else: + baddbmm_input = attention_mask + beta = 1 + + attention_scores = torch.baddbmm( + baddbmm_input, + query, + key.transpose(-1, -2), + beta=beta, + alpha=self.scale, + ) + del baddbmm_input + + if self.upcast_softmax: + attention_scores = attention_scores.float() + + attention_probs = attention_scores.softmax(dim=-1) + del attention_scores + + attention_probs = attention_probs.to(dtype) + + attention_probs = _unpack_quant_tensor(self.output_softmax_quant(attention_probs)) + return attention_probs From 057bb0f2c94c5add551116c0117da8de5d174152 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Fri, 21 Jun 2024 13:07:35 +0100 Subject: [PATCH 20/22] Add clip score --- .../mlperf_evaluation/accuracy.py | 114 +++++++++++++++++- 1 file changed, 113 insertions(+), 1 deletion(-) diff --git a/src/brevitas_examples/stable_diffusion/mlperf_evaluation/accuracy.py b/src/brevitas_examples/stable_diffusion/mlperf_evaluation/accuracy.py index fcb6a04df..8e10f107e 100644 --- a/src/brevitas_examples/stable_diffusion/mlperf_evaluation/accuracy.py +++ b/src/brevitas_examples/stable_diffusion/mlperf_evaluation/accuracy.py @@ -185,11 +185,14 @@ import os import pathlib import random +from typing import List, Optional, Tuple, Union import numpy as np +import open_clip from PIL import Image from scipy import linalg import torch +import torch.nn as nn from torch.nn.functional import adaptive_avg_pool2d import torchvision.transforms as TF from tqdm import tqdm @@ -207,6 +210,109 @@ log = logging.getLogger() +class CLIPEncoder(nn.Module): + """ + A class for encoding images and texts using a specified CLIP model and computing the similarity between them. + + Attributes: + ----------- + clip_version: str + The version of the CLIP model to be used. + pretrained: str + The pre-trained weights to load. + model: nn.Module + The CLIP model. + preprocess: Callable + The preprocessing transform to apply to the input image. + device: str + The device to which the model is moved. + """ + + def __init__( + self, + clip_version: str = 'ViT-B/32', + pretrained: Optional[str] = '', + cache_dir: Optional[str] = None, + device: str = 'cpu'): + """ + Initializes the CLIPEncoder with the specified CLIP model version and pre-trained weights. + + Parameters: + ----------- + clip_version: str, optional + The version of the CLIP model to be used. Defaults to 'ViT-B/32'. + pretrained: str, optional + The pre-trained weights to load. If not provided, it defaults based on clip_version. + cache_dir: str, optional + The directory to cache the model. Defaults to None. + device: str, optional + The device to which the model is moved. Defaults to 'cuda'. + """ + super().__init__() + + self.clip_version = clip_version + self.pretrained = pretrained if pretrained else self._get_default_pretrained() + + self.model, _, self.preprocess = open_clip.create_model_and_transforms(self.clip_version, + pretrained=self.pretrained, + cache_dir=cache_dir) + + self.model.eval() + self.model.to(device) + self.device = device + + def _get_default_pretrained(self) -> str: + """Returns the default pretrained weights based on the clip_version.""" + if self.clip_version == 'ViT-H-14': + return 'laion2b_s32b_b79k' + elif self.clip_version == 'ViT-g-14': + return 'laion2b_s12b_b42k' + else: + return 'openai' + + @torch.no_grad() + def get_clip_score( + self, text: Union[str, List[str]], image: Union[Image.Image, + torch.Tensor]) -> torch.Tensor: + """ + Computes the similarity score between the given text(s) and image using the CLIP model. + + Parameters: + ----------- + text: Union[str, List[str]] + The text or list of texts to compare with the image. + image: Image.Image + The input image. + + Returns: + -------- + torch.Tensor + The similarity score between the text(s) and image. + """ + # Preprocess the image and move it to the specified device + image = self.preprocess(image).unsqueeze(0).to(self.device) + + # Normalize the image features + image_features = self.model.encode_image(image).float() + image_features /= image_features.norm(dim=-1, keepdim=True) + + # If a single text string is provided, convert it to a list + if not isinstance(text, (list, tuple)): + text = [text] + + # Tokenize the text and move it to the specified device + text = open_clip.tokenize(text).to(self.device) + + # Normalize the text features + text_features = self.model.encode_text(text).float() + text_features /= text_features.norm(dim=-1, keepdim=True) + + # Compute the similarity between the image and text features + similarity = image_features @ text_features.T + + return similarity + + def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): """Numpy implementation of the Frechet Distance. The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) @@ -476,11 +582,17 @@ def start(self): self.results = [] def finalize(self, result_dict, ds=None, output_dir=None): + clip = CLIPEncoder(device=self.device) + dataset_size = len(self.results) log.info("Accumulating results") + for i in range(0, dataset_size): + caption = ds.get_caption(self.content_ids[i]) + generated = Image.fromarray(self.results[i]) + self.clip_scores.append(100 * clip.get_clip_score(caption, generated).item()) fid_score = compute_fid(self.results, self.statistics_path, self.device, ds=ds) result_dict["FID_SCORE"] = fid_score - + result_dict["CLIP_SCORE"] = np.mean(self.clip_scores) return result_dict From f4310442df01eafdc0408cfc341707d49262c214 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Fri, 21 Jun 2024 13:29:36 +0100 Subject: [PATCH 21/22] Review --- .../common/generative/quantize.py | 24 +++++++------- .../stable_diffusion/README.md | 32 ++++++++++++++----- .../stable_diffusion/main.py | 18 ++++++----- .../stable_diffusion/sd_quant/export.py | 4 +-- 4 files changed, 48 insertions(+), 30 deletions(-) diff --git a/src/brevitas_examples/common/generative/quantize.py b/src/brevitas_examples/common/generative/quantize.py index 6d84cf1ad..36bac29d5 100644 --- a/src/brevitas_examples/common/generative/quantize.py +++ b/src/brevitas_examples/common/generative/quantize.py @@ -215,14 +215,14 @@ def generate_quantizers( Replace float layers with quant layers in the target model """ # Retrive base input and weight quantizers - ocp_weight_format = None - ocp_input_format = None + std_float_weight_quant_format = None + std_float_input_format = None # match against custom float format if re.compile(r'e[1-8]m[1-8]').match(weight_quant_format): weight_float_format = { 'exponent_bit_width': int(weight_quant_format[1]), 'mantissa_bit_width': int(weight_quant_format[3])} - ocp_weight_format = weight_quant_format + std_float_weight_quant_format = weight_quant_format weight_quant_format = 'float' if use_ocp: weight_quant_format += '_ocp' @@ -234,7 +234,7 @@ def generate_quantizers( input_float_format = { 'exponent_bit_width': int(input_quant_format[1]), 'mantissa_bit_width': int(input_quant_format[3])} - ocp_input_format = input_quant_format + std_float_input_format = input_quant_format input_quant_format = 'float' if use_ocp: input_quant_format += '_ocp' @@ -243,8 +243,8 @@ def generate_quantizers( else: input_float_format = {} - if ocp_weight_format is not None: - weight_quant = WEIGHT_QUANT_MAP[weight_quant_format][ocp_weight_format][ + if 'ocp' in weight_quant_format or 'fnuz' in weight_quant_format: + weight_quant = WEIGHT_QUANT_MAP[weight_quant_format][std_float_weight_quant_format][ weight_scale_precision][weight_param_method][weight_quant_granularity][ weight_quant_type] else: @@ -255,16 +255,16 @@ def generate_quantizers( input_quant = sym_input_quant = linear_input_quant = INPUT_QUANT_MAP[input_quant_format][ input_scale_type][input_quant_type] elif input_bit_width is not None: - if ocp_input_format: - input_quant = INPUT_QUANT_MAP[input_quant_format][input_scale_type][ocp_input_format][ - input_scale_precision][input_param_method][input_quant_granularity][ - input_quant_type] + if 'ocp' in input_quant_format or 'fnuz' in input_quant_format: + input_quant = INPUT_QUANT_MAP[input_quant_format][input_scale_type][ + std_float_input_format][input_scale_precision][input_param_method][ + input_quant_granularity][input_quant_type] # Some activations in MHA should always be symmetric sym_input_quant = INPUT_QUANT_MAP[input_quant_format][input_scale_type][ - ocp_input_format][input_scale_precision][input_param_method][ + std_float_input_format][input_scale_precision][input_param_method][ input_quant_granularity]['sym'] linear_input_quant = INPUT_QUANT_MAP[input_quant_format][input_scale_type][ - ocp_input_format][input_scale_precision][input_param_method][ + std_float_input_format][input_scale_precision][input_param_method][ input_quant_granularity][input_quant_type] else: input_quant = INPUT_QUANT_MAP[input_quant_format][input_scale_type][ diff --git a/src/brevitas_examples/stable_diffusion/README.md b/src/brevitas_examples/stable_diffusion/README.md index e1f98b7da..1685bd4a9 100644 --- a/src/brevitas_examples/stable_diffusion/README.md +++ b/src/brevitas_examples/stable_diffusion/README.md @@ -46,7 +46,7 @@ To add activation quantization: To choose between `static` or `dynamic` activation quantization, set the flag: `--input-scale-type` to either option To include export: -`--export-target torch` or `--export-target onnx` +`--export-target onnx` To perform a dry-run quantization, where only the structure of the quantized model is preserved but no calibration of the quantized parameter is performed, add the `--dry-run` flag. @@ -70,7 +70,7 @@ usage: main.py [-h] [-m MODEL] [-d DEVICE] [-b BATCH_SIZE] [--prompt PROMPT] [--gptq | --no-gptq] [--bias-correction | --no-bias-correction] [--dtype {float32,float16,bfloat16}] [--attention-slicing | --no-attention-slicing] - [--export-target {,onnx}] + [--export-target {,onnx,params_only}] [--export-weight-q-node | --no-export-weight-q-node] [--conv-weight-bit-width CONV_WEIGHT_BIT_WIDTH] [--linear-weight-bit-width LINEAR_WEIGHT_BIT_WIDTH] @@ -92,12 +92,15 @@ usage: main.py [-h] [-m MODEL] [-d DEVICE] [-b BATCH_SIZE] [--prompt PROMPT] [--input-scale-type {static,dynamic}] [--weight-group-size WEIGHT_GROUP_SIZE] [--quantize-weight-zero-point | --no-quantize-weight-zero-point] + [--exclude-blacklist-act-eq | --no-exclude-blacklist-act-eq] [--quantize-input-zero-point | --no-quantize-input-zero-point] [--export-cpu-float32 | --no-export-cpu-float32] [--use-mlperf-inference | --no-use-mlperf-inference] - [--use-ocp | --no-use-ocp] + [--use-ocp | --no-use-ocp] [--use-nfuz | --no-use-nfuz] [--use-negative-prompts | --no-use-negative-prompts] [--dry-run | --no-dry-run] + [--quantize-sdp-1 | --no-quantize-sdp-1] + [--quantize-sdp-2 | --no-quantize-sdp-2] Stable Diffusion quantization @@ -110,8 +113,7 @@ options: -b BATCH_SIZE, --batch-size BATCH_SIZE How many seeds to use for each image during validation. Default: 2 - --prompt PROMPT Number of prompt to use for testing. Default: 4. Max: - 4 + --prompt PROMPT Number of prompt to use for testing. Default: 4 --calibration-prompt CALIBRATION_PROMPT Number of prompt to use for calibration. Default: 2 --calibration-prompt-path CALIBRATION_PROMPT_PATH @@ -156,7 +158,7 @@ options: --attention-slicing Enable Enable attention slicing. Default: Disabled --no-attention-slicing Disable Enable attention slicing. Default: Disabled - --export-target {,onnx} + --export-target {,onnx,params_only} Target export flow. --export-weight-q-node Enable Enable export of floating point weights + QDQ @@ -169,11 +171,11 @@ options: --linear-weight-bit-width LINEAR_WEIGHT_BIT_WIDTH Weight bit width. Default: 8. --conv-input-bit-width CONV_INPUT_BIT_WIDTH - Input bit width. Default: None (not quantized) + Input bit width. Default: 0 (not quantized) --act-eq-alpha ACT_EQ_ALPHA Alpha for activation equalization. Default: 0.9 --linear-input-bit-width LINEAR_INPUT_BIT_WIDTH - Input bit width. Default: None (not quantized). + Input bit width. Default: 0 (not quantized). --weight-param-method {stats,mse} How scales/zero-point are determined. Default: stats. --input-param-method {stats,mse} @@ -216,6 +218,12 @@ options: Enable Quantize weight zero-point. Default: Enabled --no-quantize-weight-zero-point Disable Quantize weight zero-point. Default: Enabled + --exclude-blacklist-act-eq + Enable Exclude unquantized layers from activation + equalization. Default: Disabled + --no-exclude-blacklist-act-eq + Disable Exclude unquantized layers from activation + equalization. Default: Disabled --quantize-input-zero-point Enable Quantize input zero-point. Default: Enabled --no-quantize-input-zero-point @@ -233,6 +241,10 @@ options: True --no-use-ocp Disable Use OCP format for float quantization. Default: True + --use-nfuz Enable Use NFUZ format for float quantization. + Default: True + --no-use-nfuz Disable Use NFUZ format for float quantization. + Default: True --use-negative-prompts Enable Use negative prompts during generation/calibration. Default: Enabled @@ -243,5 +255,9 @@ options: calibration. Default: Disabled --no-dry-run Disable Generate a quantized model without any calibration. Default: Disabled + --quantize-sdp-1 Enable Quantize SDP. Default: Disabled + --no-quantize-sdp-1 Disable Quantize SDP. Default: Disabled + --quantize-sdp-2 Enable Quantize SDP. Default: Disabled + --no-quantize-sdp-2 Disable Quantize SDP. Default: Disabled ``` diff --git a/src/brevitas_examples/stable_diffusion/main.py b/src/brevitas_examples/stable_diffusion/main.py index 4527b65b6..0718563e2 100644 --- a/src/brevitas_examples/stable_diffusion/main.py +++ b/src/brevitas_examples/stable_diffusion/main.py @@ -53,7 +53,7 @@ from brevitas_examples.stable_diffusion.sd_quant.utils import unet_input_shape TEST_SEED = 123456 -# TODO: add deterministc flags +torch.manual_seed(TEST_SEED) NEGATIVE_PROMPTS = ["normal quality, low quality, worst quality, low res, blurry, nsfw, nude"] @@ -318,7 +318,7 @@ def input_zp_stats_type(): linear_qkwargs = layer_map[torch.nn.Linear][1] linear_qkwargs[ - 'input_quant'] = None if args.linear_input_bit_width is None else linear_qkwargs[ + 'input_quant'] = None if args.linear_input_bit_width == 0 else linear_qkwargs[ 'input_quant'] linear_qkwargs[ 'weight_quant'] = None if args.linear_weight_bit_width == 0 else linear_qkwargs[ @@ -326,8 +326,8 @@ def input_zp_stats_type(): layer_map[torch.nn.Linear] = (layer_map[torch.nn.Linear][0], linear_qkwargs) conv_qkwargs = layer_map[torch.nn.Conv2d][1] - conv_qkwargs['input_quant'] = None if args.conv_input_bit_width is None else conv_qkwargs[ - 'input_quant'] + conv_qkwargs[ + 'input_quant'] = None if args.conv_input_bit_width == 0 else conv_qkwargs['input_quant'] conv_qkwargs['weight_quant'] = None if args.conv_weight_bit_width == 0 else conv_qkwargs[ 'weight_quant'] layer_map[torch.nn.Conv2d] = (layer_map[torch.nn.Conv2d][0], conv_qkwargs) @@ -355,6 +355,8 @@ def input_zp_stats_type(): use_ocp=args.use_ocp, use_fnuz=args.use_fnuz, input_kwargs=input_kwargs) + # We generate all quantizers, but we are only interested in activation quantization for + # the output of softmax and the output of QKV input_quant = float_sdpa_quantizers[0] input_quant = input_quant.let(**{'bit_width': args.linear_output_bit_width}) if args.quantize_sdp_2: @@ -634,8 +636,8 @@ def input_zp_stats_type(): parser.add_argument( '--conv-input-bit-width', type=int, - default=None, - help='Input bit width. Default: None (not quantized)') + default=0, + help='Input bit width. Default: 0 (not quantized)') parser.add_argument( '--act-eq-alpha', type=float, @@ -644,8 +646,8 @@ def input_zp_stats_type(): parser.add_argument( '--linear-input-bit-width', type=int, - default=None, - help='Input bit width. Default: None (not quantized).') + default=0, + help='Input bit width. Default: 0 (not quantized).') parser.add_argument( '--weight-param-method', type=str, diff --git a/src/brevitas_examples/stable_diffusion/sd_quant/export.py b/src/brevitas_examples/stable_diffusion/sd_quant/export.py index 541541a5f..ac737f276 100644 --- a/src/brevitas_examples/stable_diffusion/sd_quant/export.py +++ b/src/brevitas_examples/stable_diffusion/sd_quant/export.py @@ -27,7 +27,7 @@ def handle_quant_param(layer, layer_dict): input_scale = layer.input_quant.export_handler.symbolic_kwargs['dequantize_symbolic_kwargs'][ 'scale'].data input_zp = layer.input_quant.export_handler.symbolic_kwargs['dequantize_symbolic_kwargs'][ - 'zero_point'].data - 128. + 'zero_point'].data weight_scale = layer.weight_quant.export_handler.symbolic_kwargs['dequantize_symbolic_kwargs'][ 'scale'].data weight_zp = layer.weight_quant.export_handler.symbolic_kwargs['dequantize_symbolic_kwargs'][ @@ -49,7 +49,7 @@ def handle_quant_param(layer, layer_dict): layer_dict['weight_scale_shape'] = weight_scale_shape layer_dict['weight_zp'] = weight_zp.numpy().tolist() layer_dict['weight_zp_shape'] = weight_scale_shape - layer_dict['weight_zp_dtype'] = str(torch.int8) + layer_dict['weight_zp_dtype'] = str(torch.uint8) return layer_dict From 6542813f337857467507f48b33eb7b9a5549b4b2 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Fri, 21 Jun 2024 15:41:52 +0200 Subject: [PATCH 22/22] Update export.py --- src/brevitas_examples/stable_diffusion/sd_quant/export.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/brevitas_examples/stable_diffusion/sd_quant/export.py b/src/brevitas_examples/stable_diffusion/sd_quant/export.py index ac737f276..64bcac34f 100644 --- a/src/brevitas_examples/stable_diffusion/sd_quant/export.py +++ b/src/brevitas_examples/stable_diffusion/sd_quant/export.py @@ -49,7 +49,7 @@ def handle_quant_param(layer, layer_dict): layer_dict['weight_scale_shape'] = weight_scale_shape layer_dict['weight_zp'] = weight_zp.numpy().tolist() layer_dict['weight_zp_shape'] = weight_scale_shape - layer_dict['weight_zp_dtype'] = str(torch.uint8) + layer_dict['weight_zp_dtype'] = str(torch.int8) return layer_dict