Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update transformers imports for deepspeed and is_torch_xla_available #2012

Merged
merged 3 commits into from
Sep 7, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion optimum/onnxruntime/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
from torch.utils.data import Dataset, RandomSampler
from transformers.data.data_collator import DataCollator
from transformers.debug_utils import DebugOption, DebugUnderflowOverflow
from transformers.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_zero3_enabled
from transformers.integrations.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_zero3_enabled
Rohan138 marked this conversation as resolved.
Show resolved Hide resolved
from transformers.modeling_utils import PreTrainedModel, unwrap_model
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.trainer import Trainer
Expand Down
2 changes: 1 addition & 1 deletion optimum/onnxruntime/trainer_seq2seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import torch
from torch import nn
from torch.utils.data import Dataset
from transformers.deepspeed import is_deepspeed_zero3_enabled
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
Rohan138 marked this conversation as resolved.
Show resolved Hide resolved
from transformers.trainer_utils import PredictionOutput
from transformers.utils import is_accelerate_available, logging

Expand Down
Loading