From f85baca54ff2b3bf06e3741df7123452f6f3462e Mon Sep 17 00:00:00 2001 From: Rohan138 Date: Tue, 3 Sep 2024 17:15:35 +0000 Subject: [PATCH 1/3] change deepspeed to integrations.deepspeed --- optimum/onnxruntime/trainer.py | 2 +- optimum/onnxruntime/trainer_seq2seq.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/optimum/onnxruntime/trainer.py b/optimum/onnxruntime/trainer.py index 9bc2bb5134d..766d6ca0c21 100644 --- a/optimum/onnxruntime/trainer.py +++ b/optimum/onnxruntime/trainer.py @@ -55,7 +55,7 @@ from torch.utils.data import Dataset, RandomSampler from transformers.data.data_collator import DataCollator from transformers.debug_utils import DebugOption, DebugUnderflowOverflow -from transformers.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_zero3_enabled +from transformers.integrations.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_zero3_enabled from transformers.modeling_utils import PreTrainedModel, unwrap_model from transformers.tokenization_utils_base import PreTrainedTokenizerBase from transformers.trainer import Trainer diff --git a/optimum/onnxruntime/trainer_seq2seq.py b/optimum/onnxruntime/trainer_seq2seq.py index 2e43ee89e00..77a4e15bb83 100644 --- a/optimum/onnxruntime/trainer_seq2seq.py +++ b/optimum/onnxruntime/trainer_seq2seq.py @@ -19,7 +19,7 @@ import torch from torch import nn from torch.utils.data import Dataset -from transformers.deepspeed import is_deepspeed_zero3_enabled +from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled from transformers.trainer_utils import PredictionOutput from transformers.utils import is_accelerate_available, logging From dba908c81dfbe8d5ce49fa3593b986bbdafe3732 Mon Sep 17 00:00:00 2001 From: Rohan138 Date: Thu, 5 Sep 2024 19:44:54 +0000 Subject: [PATCH 2/3] add version check and change tpu to xla --- optimum/onnxruntime/trainer.py | 24 ++++++++++++++++++++---- optimum/onnxruntime/trainer_seq2seq.py | 10 +++++++++- 2 files changed, 29 insertions(+), 5 deletions(-) diff --git a/optimum/onnxruntime/trainer.py b/optimum/onnxruntime/trainer.py index 766d6ca0c21..27710e4425f 100644 --- a/optimum/onnxruntime/trainer.py +++ b/optimum/onnxruntime/trainer.py @@ -53,9 +53,9 @@ import torch.distributed as dist from torch import nn from torch.utils.data import Dataset, RandomSampler +from transformers import __version__ as transformers_version from transformers.data.data_collator import DataCollator from transformers.debug_utils import DebugOption, DebugUnderflowOverflow -from transformers.integrations.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_zero3_enabled from transformers.modeling_utils import PreTrainedModel, unwrap_model from transformers.tokenization_utils_base import PreTrainedTokenizerBase from transformers.trainer import Trainer @@ -81,7 +81,6 @@ is_apex_available, is_sagemaker_dp_enabled, is_sagemaker_mp_enabled, - is_torch_tpu_available, ) from ..utils import logging @@ -91,11 +90,28 @@ ) +if version.parse(transformers_version) >= version.parse("4.33"): + from transformers.integrations.deepspeed import ( + deepspeed_init, + deepspeed_load_checkpoint, + is_deepspeed_zero3_enabled, + ) +else: + from transformers.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_zero3_enabled + if is_apex_available(): from apex import amp -if is_torch_tpu_available(check_device=False): - import torch_xla.core.xla_model as xm +if version.parse(transformers_version) >= version.parse("4.39"): + from transformers.utils import is_torch_xla_available + + if is_torch_xla_available(): + import torch_xla.core.xla_model as xm +else: + from transformers.utils import is_torch_tpu_available + + if is_torch_tpu_available(check_device=False): + import torch_xla.core.xla_model as xm if TYPE_CHECKING: import optuna diff --git a/optimum/onnxruntime/trainer_seq2seq.py b/optimum/onnxruntime/trainer_seq2seq.py index 77a4e15bb83..abe711574a0 100644 --- a/optimum/onnxruntime/trainer_seq2seq.py +++ b/optimum/onnxruntime/trainer_seq2seq.py @@ -17,9 +17,9 @@ from typing import Any, Dict, List, Optional, Tuple, Union import torch +from packaging import version from torch import nn from torch.utils.data import Dataset -from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled from transformers.trainer_utils import PredictionOutput from transformers.utils import is_accelerate_available, logging @@ -33,6 +33,14 @@ "The package `accelerate` is required to use the ORTTrainer. Please install it following https://huggingface.co/docs/accelerate/basic_tutorials/install." ) +from transformers import __version__ as transformers_version + + +if version.parse(transformers_version) >= version.parse("4.33"): + from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled +else: + from transformers.deepspeed import is_deepspeed_zero3_enabled + logger = logging.get_logger(__name__) From 4c8199209c83729bdb3f12d2df8eb0ba385d00f8 Mon Sep 17 00:00:00 2001 From: Rohan138 Date: Fri, 6 Sep 2024 19:26:24 +0000 Subject: [PATCH 3/3] add version check --- optimum/onnxruntime/trainer.py | 12 ++++++------ optimum/onnxruntime/trainer_seq2seq.py | 7 ++----- 2 files changed, 8 insertions(+), 11 deletions(-) diff --git a/optimum/onnxruntime/trainer.py b/optimum/onnxruntime/trainer.py index 27710e4425f..86c333adb3f 100644 --- a/optimum/onnxruntime/trainer.py +++ b/optimum/onnxruntime/trainer.py @@ -53,7 +53,6 @@ import torch.distributed as dist from torch import nn from torch.utils.data import Dataset, RandomSampler -from transformers import __version__ as transformers_version from transformers.data.data_collator import DataCollator from transformers.debug_utils import DebugOption, DebugUnderflowOverflow from transformers.modeling_utils import PreTrainedModel, unwrap_model @@ -84,13 +83,17 @@ ) from ..utils import logging +from ..utils.import_utils import check_if_transformers_greater from .training_args import ORTOptimizerNames, ORTTrainingArguments from .utils import ( is_onnxruntime_training_available, ) -if version.parse(transformers_version) >= version.parse("4.33"): +if is_apex_available(): + from apex import amp + +if check_if_transformers_greater("4.33"): from transformers.integrations.deepspeed import ( deepspeed_init, deepspeed_load_checkpoint, @@ -99,10 +102,7 @@ else: from transformers.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_zero3_enabled -if is_apex_available(): - from apex import amp - -if version.parse(transformers_version) >= version.parse("4.39"): +if check_if_transformers_greater("4.39"): from transformers.utils import is_torch_xla_available if is_torch_xla_available(): diff --git a/optimum/onnxruntime/trainer_seq2seq.py b/optimum/onnxruntime/trainer_seq2seq.py index abe711574a0..1565ffa6acb 100644 --- a/optimum/onnxruntime/trainer_seq2seq.py +++ b/optimum/onnxruntime/trainer_seq2seq.py @@ -17,12 +17,12 @@ from typing import Any, Dict, List, Optional, Tuple, Union import torch -from packaging import version from torch import nn from torch.utils.data import Dataset from transformers.trainer_utils import PredictionOutput from transformers.utils import is_accelerate_available, logging +from ..utils.import_utils import check_if_transformers_greater from .trainer import ORTTrainer @@ -33,10 +33,7 @@ "The package `accelerate` is required to use the ORTTrainer. Please install it following https://huggingface.co/docs/accelerate/basic_tutorials/install." ) -from transformers import __version__ as transformers_version - - -if version.parse(transformers_version) >= version.parse("4.33"): +if check_if_transformers_greater("4.33"): from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled else: from transformers.deepspeed import is_deepspeed_zero3_enabled