diff --git a/optimum/onnxruntime/trainer.py b/optimum/onnxruntime/trainer.py index 2b7fb654303..9bc2bb5134d 100644 --- a/optimum/onnxruntime/trainer.py +++ b/optimum/onnxruntime/trainer.py @@ -459,7 +459,14 @@ def _inner_training_loop( # Wrap the model with `ORTModule` logger.info("Wrap ORTModule for ONNX Runtime training.") - model = ORTModule(self.model) + if self.args.save_onnx: + from torch_ort import DebugOptions + + model = ORTModule( + self.model, DebugOptions(save_onnx=self.args.save_onnx, onnx_prefix=self.args.onnx_prefix) + ) + else: + model = ORTModule(self.model) self.model_wrapped = model self.model = model diff --git a/optimum/onnxruntime/training_args.py b/optimum/onnxruntime/training_args.py index b05da6a5ede..6aec362c07c 100644 --- a/optimum/onnxruntime/training_args.py +++ b/optimum/onnxruntime/training_args.py @@ -79,6 +79,29 @@ class ORTTrainingArguments(TrainingArguments): }, ) + save_onnx: Optional[bool] = field( + default=False, + metadata={ + "help": "Configure ORTModule to save onnx models. Defaults to False. \ + The output directory of the onnx models by default is set to args.output_dir. \ + To change the output directory, the environment variable ORTMODULE_SAVE_ONNX_PATH can be \ + set to the destination directory path." + }, + ) + + onnx_prefix: Optional[str] = field( + default=None, + metadata={"help": "Prefix for the saved ORTModule file names. Must be provided if save_onnx is True."}, + ) + + onnx_log_level: Optional[str] = field( + default="WARNING", + metadata={ + "help": "Configure ORTModule log level. Defaults to WARNING. \ + onnx_log_level can also be set to one of VERBOSE, INFO, WARNING, ERROR, FATAL." + }, + ) + # This method will not need to be overriden after the deprecation of `--adafactor` in version 5 of 🤗 Transformers. def __post_init__(self): # expand paths, if not os.makedirs("~/bar") will make directory @@ -244,6 +267,13 @@ def __post_init__(self): if version.parse(version.parse(torch.__version__).base_version) == version.parse("2.0.0") and self.fp16: raise ValueError("--optim adamw_torch_fused with --fp16 requires PyTorch>2.0") + if self.save_onnx: + if not self.onnx_prefix: + raise ValueError("onnx_prefix must be provided if save_onnx is True") + if not os.getenv("ORTMODULE_SAVE_ONNX_PATH", None): + os.environ["ORTMODULE_SAVE_ONNX_PATH"] = self.output_dir + os.environ["ORTMODULE_LOG_LEVEL"] = self.onnx_log_level + if ( is_torch_available() and (self.device.type != "cuda")