Skip to content

Commit

Permalink
Feat (brevitas_examples/sdxl): inference_mode + compile
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Dec 17, 2024
1 parent f8c6d64 commit fca0882
Showing 1 changed file with 26 additions and 9 deletions.
35 changes: 26 additions & 9 deletions src/brevitas_examples/stable_diffusion/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from brevitas_examples.common.parse_utils import add_bool_arg
from brevitas_examples.common.parse_utils import quant_format_validator
from brevitas_examples.llm.llm_quant.export import BlockQuantProxyLevelManager
from brevitas_examples.llm.main import quant_inference_mode
from brevitas_examples.stable_diffusion.mlperf_evaluation.accuracy import compute_mlperf_fid
from brevitas_examples.stable_diffusion.sd_quant.constants import SD_2_1_EMBEDDINGS_SHAPE
from brevitas_examples.stable_diffusion.sd_quant.constants import SD_XL_EMBEDDINGS_SHAPE
Expand Down Expand Up @@ -247,7 +248,6 @@ def main(args):
else:
non_blacklist[name_to_add] += 1
print(f"Blacklisted layers: {set(blacklist)}")
print(f"Non blacklisted layers: {set(non_blacklist.keys())}")

    # Make sure that all LoRA layers are fused first, otherwise raise an error
for m in pipe.unet.modules():
Expand Down Expand Up @@ -610,14 +610,29 @@ def sdpa_zp_stats_type():
# with brevitas_proxy_inference_mode(pipe.unet):
if args.use_mlperf_inference:
print(f"Computing accuracy with MLPerf pipeline")
compute_mlperf_fid(
args.model,
args.path_to_coco,
pipe,
args.prompt,
output_dir,
args.device,
not args.vae_fp16_fix)
with torch.no_grad(), quant_inference_mode(pipe.unet):
            # Perform a single forward pass before eventually compiling
run_val_inference(
pipe,
args.resolution,
[calibration_prompts[0]], # We need a list
test_seeds,
args.device,
dtype,
total_steps=1,
use_negative_prompts=args.use_negative_prompts,
test_latents=latents,
guidance_scale=args.guidance_scale)
if args.compile:
pipe.unet = torch.compile(pipe.unet)
compute_mlperf_fid(
args.model,
args.path_to_coco,
pipe,
args.prompt,
output_dir,
args.device,
not args.vae_fp16_fix)
else:
print(f"Computing accuracy on default prompt")
testing_prompts = TESTING_PROMPTS[:args.prompt]
Expand Down Expand Up @@ -734,6 +749,8 @@ def sdpa_zp_stats_type():
'attention-slicing',
default=False,
help='Enable attention slicing. Default: Disabled')
add_bool_arg(
parser, 'compile', default=False, help='Compile during inference. Default: Disabled')
parser.add_argument(
'--export-target',
type=str,
Expand Down

0 comments on commit fca0882

Please sign in to comment.