Skip to content

Commit

Permalink
Fix: formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
nickfraser committed Sep 10, 2024
1 parent 223700a commit bec3da1
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 2 deletions.
8 changes: 7 additions & 1 deletion src/brevitas_examples/stable_diffusion/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,7 +520,13 @@ def input_zp_stats_type():
if args.use_mlperf_inference:
print(f"Computing accuracy with MLPerf pipeline")
compute_mlperf_fid(
args.model, args.path_to_coco, pipe, args.prompt, output_dir, args.device, not args.vae_fp16_fix)
args.model,
args.path_to_coco,
pipe,
args.prompt,
output_dir,
args.device,
not args.vae_fp16_fix)
else:
print(f"Computing accuracy on default prompt")
testing_prompts = TESTING_PROMPTS[:args.prompt]
Expand Down
3 changes: 2 additions & 1 deletion src/brevitas_examples/stable_diffusion/sd_quant/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@ def export_quant_params(pipe, output_dir, export_vae=False):
print(f"Saving vae to {vae_output_path} ...")
from brevitas.export.onnx.standard.qcdq.manager import StdQCDQONNXManager
export_manager = StdQCDQONNXManager
export_manager.change_weight_export(export_weight_q_node=True) # We're exporting FP weights + quantization parameters
export_manager.change_weight_export(
export_weight_q_node=True) # We're exporting FP weights + quantization parameters
quant_params = dict()
state_dict = pipe.unet.state_dict()
state_dict = {k: v for (k, v) in state_dict.items() if 'tensor_quant' not in k}
Expand Down

0 comments on commit bec3da1

Please sign in to comment.