diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py b/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py index e0a4082ddd8b..4d0b9bef55f1 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py @@ -103,7 +103,7 @@ def parse_args(): ) parser.add_argument( "--vae_precision", - type="choice", + type=str, choices=["fp32", "fp16", "bf16"], default="fp32", help=(