diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py b/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py index 2c7d839e4346..9be670505aad 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py @@ -103,7 +103,7 @@ def parse_args(): ) parser.add_argument( "--vae_precision", - type="choice", + type=str, choices=["fp32", "fp16", "bf16"], default="fp32", help=(