diff --git a/optimum/onnxruntime/trainer.py b/optimum/onnxruntime/trainer.py
index 9bc2bb5134d..86c333adb3f 100644
--- a/optimum/onnxruntime/trainer.py
+++ b/optimum/onnxruntime/trainer.py
@@ -55,7 +55,6 @@
 from torch.utils.data import Dataset, RandomSampler
 from transformers.data.data_collator import DataCollator
 from transformers.debug_utils import DebugOption, DebugUnderflowOverflow
-from transformers.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_zero3_enabled
 from transformers.modeling_utils import PreTrainedModel, unwrap_model
 from transformers.tokenization_utils_base import PreTrainedTokenizerBase
 from transformers.trainer import Trainer
@@ -81,10 +80,10 @@
     is_apex_available,
     is_sagemaker_dp_enabled,
     is_sagemaker_mp_enabled,
-    is_torch_tpu_available,
 )
 
 from ..utils import logging
+from ..utils.import_utils import check_if_transformers_greater
 from .training_args import ORTOptimizerNames, ORTTrainingArguments
 from .utils import (
     is_onnxruntime_training_available,
@@ -94,8 +93,25 @@
 if is_apex_available():
     from apex import amp
 
-if is_torch_tpu_available(check_device=False):
-    import torch_xla.core.xla_model as xm
+if check_if_transformers_greater("4.33"):
+    from transformers.integrations.deepspeed import (
+        deepspeed_init,
+        deepspeed_load_checkpoint,
+        is_deepspeed_zero3_enabled,
+    )
+else:
+    from transformers.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_zero3_enabled
+
+if check_if_transformers_greater("4.39"):
+    from transformers.utils import is_torch_xla_available
+
+    if is_torch_xla_available():
+        import torch_xla.core.xla_model as xm
+else:
+    from transformers.utils import is_torch_tpu_available
+
+    if is_torch_tpu_available(check_device=False):
+        import torch_xla.core.xla_model as xm
 
 if TYPE_CHECKING:
     import optuna
diff --git a/optimum/onnxruntime/trainer_seq2seq.py b/optimum/onnxruntime/trainer_seq2seq.py
index 2e43ee89e00..1565ffa6acb 100644
--- a/optimum/onnxruntime/trainer_seq2seq.py
+++ b/optimum/onnxruntime/trainer_seq2seq.py
@@ -19,10 +19,10 @@
 import torch
 from torch import nn
 from torch.utils.data import Dataset
-from transformers.deepspeed import is_deepspeed_zero3_enabled
 from transformers.trainer_utils import PredictionOutput
 from transformers.utils import is_accelerate_available, logging
 
+from ..utils.import_utils import check_if_transformers_greater
 from .trainer import ORTTrainer
 
 
@@ -33,6 +33,11 @@
         "The package `accelerate` is required to use the ORTTrainer. Please install it following https://huggingface.co/docs/accelerate/basic_tutorials/install."
     )
 
 
+if check_if_transformers_greater("4.33"):
+    from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
+else:
+    from transformers.deepspeed import is_deepspeed_zero3_enabled
+
 logger = logging.get_logger(__name__)
 
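The pattern throughout this diff is version-gated imports: `transformers` moved its DeepSpeed helpers to `transformers.integrations.deepspeed` in 4.33 and replaced `is_torch_tpu_available` with `is_torch_xla_available` in 4.39, so each import site branches on the installed version via `check_if_transformers_greater`. For context, here is a minimal sketch of how such a helper can be written with `packaging.version`; the actual implementation in `optimum.utils.import_utils` may differ in detail:

```python
# Sketch of a version-gating helper in the spirit of
# optimum.utils.import_utils.check_if_transformers_greater;
# the real implementation may differ.
import importlib.metadata

from packaging import version


def check_if_transformers_greater(target_version: str) -> bool:
    """Return True if the installed transformers version is >= target_version."""
    try:
        installed = importlib.metadata.version("transformers")
    except importlib.metadata.PackageNotFoundError:
        # transformers is not installed at all.
        return False
    return version.parse(installed) >= version.parse(target_version)


# Usage mirroring the diff: pick the import location that matches
# the installed transformers release.
if check_if_transformers_greater("4.33"):
    from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
else:
    from transformers.deepspeed import is_deepspeed_zero3_enabled
```

Branching at import time like this keeps the compatibility logic in one place, so the rest of the trainer code can use a single name (e.g. `is_deepspeed_zero3_enabled`) regardless of which `transformers` release is installed.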