diff --git a/optimum/onnxruntime/training_args.py b/optimum/onnxruntime/training_args.py
index a0cb7c8e983..7c3171855c1 100644
--- a/optimum/onnxruntime/training_args.py
+++ b/optimum/onnxruntime/training_args.py
@@ -397,6 +397,7 @@ def __post_init__(self):
         ):
             raise ValueError("`min_num_params` and `transformer_layer_cls_to_wrap` are mutually exclusive.")
         self.fsdp_config["xla"] = self.fsdp_config.get("xla", False)
+        self.fsdp_config["xla_fsdp_v2"] = self.fsdp_config.get("xla_fsdp_v2", False)
         self.fsdp_config["xla_fsdp_grad_ckpt"] = self.fsdp_config.get("xla_fsdp_grad_ckpt", False)
         if self.fsdp_config["xla"]:
             if len(self.fsdp) > 0: