diff --git a/src/brevitas_examples/stable_diffusion/main.py b/src/brevitas_examples/stable_diffusion/main.py index c7b5367a4..b1b5732f6 100644 --- a/src/brevitas_examples/stable_diffusion/main.py +++ b/src/brevitas_examples/stable_diffusion/main.py @@ -43,6 +43,7 @@ from brevitas_examples.common.parse_utils import add_bool_arg from brevitas_examples.common.parse_utils import quant_format_validator from brevitas_examples.llm.llm_quant.export import BlockQuantProxyLevelManager +from brevitas_examples.llm.main import quant_inference_mode from brevitas_examples.stable_diffusion.mlperf_evaluation.accuracy import compute_mlperf_fid from brevitas_examples.stable_diffusion.sd_quant.constants import SD_2_1_EMBEDDINGS_SHAPE from brevitas_examples.stable_diffusion.sd_quant.constants import SD_XL_EMBEDDINGS_SHAPE @@ -247,7 +248,6 @@ def main(args): else: non_blacklist[name_to_add] += 1 print(f"Blacklisted layers: {set(blacklist)}") - print(f"Non blacklisted layers: {set(non_blacklist.keys())}") # Make sure there all LoRA layers are fused first, otherwise raise an error for m in pipe.unet.modules(): @@ -610,14 +610,29 @@ def sdpa_zp_stats_type(): # with brevitas_proxy_inference_mode(pipe.unet): if args.use_mlperf_inference: print(f"Computing accuracy with MLPerf pipeline") - compute_mlperf_fid( - args.model, - args.path_to_coco, - pipe, - args.prompt, - output_dir, - args.device, - not args.vae_fp16_fix) + with torch.no_grad(), quant_inference_mode(pipe.unet): + # Perform a single forward pass before eventually compiling + run_val_inference( + pipe, + args.resolution, + [calibration_prompts[0]], # We need a list + test_seeds, + args.device, + dtype, + total_steps=1, + use_negative_prompts=args.use_negative_prompts, + test_latents=latents, + guidance_scale=args.guidance_scale) + if args.compile: + pipe.unet = torch.compile(pipe.unet) + compute_mlperf_fid( + args.model, + args.path_to_coco, + pipe, + args.prompt, + output_dir, + args.device, + not args.vae_fp16_fix) else: 
print(f"Computing accuracy on default prompt") testing_prompts = TESTING_PROMPTS[:args.prompt] @@ -734,6 +749,8 @@ def sdpa_zp_stats_type(): 'attention-slicing', default=False, help='Enable attention slicing. Default: Disabled') + add_bool_arg( + parser, 'compile', default=False, help='Compile during inference. Default: Disabled') parser.add_argument( '--export-target', type=str,