Remove ShardedDDP as it was deprecated from Transformers. #1443

Merged (2 commits) on Oct 12, 2023
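Background: `ShardedDDP` and the `OSS` optimizer wrapper come from fairscale, and Transformers deprecated the `sharded_ddp` training argument in favor of PyTorch FSDP (the warning deleted below cites Transformers v4.33). This PR removes those code paths from `ORTTrainer`/`ORTTrainingArguments` and drops fairscale from the test images. A minimal migration sketch, assuming a script that previously passed `sharded_ddp="simple"` (argument values here are illustrative):

```python
from optimum.onnxruntime import ORTTrainingArguments

# Before (path removed by this PR):
#     args = ORTTrainingArguments(..., sharded_ddp="simple")
# After: shard parameters and gradients with PyTorch FSDP instead.
args = ORTTrainingArguments(
    output_dir="out",
    fsdp="full_shard",  # FSDP options can be combined, e.g. "full_shard offload"
    fp16=True,
)
```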
@@ -57,7 +57,7 @@ ARG PYTHON_EXE=$MINICONDA_PREFIX/bin/python
# (Optional) Intall test dependencies
RUN $PYTHON_EXE -m pip install git+https://github.com/huggingface/transformers
RUN $PYTHON_EXE -m pip install datasets accelerate evaluate coloredlogs absl-py rouge_score seqeval scipy sacrebleu nltk scikit-learn parameterized sentencepiece
-RUN $PYTHON_EXE -m pip install fairscale deepspeed mpi4py
+RUN $PYTHON_EXE -m pip install deepspeed mpi4py
# RUN $PYTHON_EXE -m pip install optuna ray sigopt wandb

# PyTorch
@@ -49,7 +49,7 @@ RUN pip install pygit2 pgzip
# (Optional) Intall test dependencies
RUN pip install git+https://github.com/huggingface/transformers
RUN pip install datasets accelerate evaluate coloredlogs absl-py rouge_score seqeval scipy sacrebleu nltk scikit-learn parameterized sentencepiece
-RUN pip install fairscale deepspeed mpi4py
+RUN pip install deepspeed mpi4py
# RUN pip install optuna ray sigopt wandb

# Install onnxruntime-training dependencies
@@ -48,7 +48,7 @@ RUN pip install pygit2 pgzip
# (Optional) Intall test dependencies
RUN pip install git+https://github.com/huggingface/transformers
RUN pip install datasets accelerate evaluate coloredlogs absl-py rouge_score seqeval scipy sacrebleu nltk scikit-learn parameterized sentencepiece
-RUN pip install fairscale deepspeed mpi4py
+RUN pip install deepspeed mpi4py
# RUN pip install optuna ray sigopt wandb

# Install onnxruntime-training dependencies
@@ -57,7 +57,7 @@ ARG PYTHON_EXE=$MINICONDA_PREFIX/bin/python
# (Optional) Intall test dependencies
RUN $PYTHON_EXE -m pip install git+https://github.com/huggingface/transformers
RUN $PYTHON_EXE -m pip install datasets accelerate evaluate coloredlogs absl-py rouge_score seqeval scipy sacrebleu nltk scikit-learn parameterized sentencepiece
-RUN $PYTHON_EXE -m pip install fairscale deepspeed mpi4py
+RUN $PYTHON_EXE -m pip install deepspeed mpi4py
# RUN $PYTHON_EXE -m pip install optuna ray sigopt wandb

# PyTorch
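The four Dockerfile hunks above (the ONNX Runtime training/test images) only lose the fairscale wheel; the remaining dependencies are unchanged. A hypothetical post-build sanity check inside a rebuilt image could look like:

```python
# Hypothetical check: fairscale should be absent from the image,
# and ORTTrainer should still import without it.
import importlib.util

assert importlib.util.find_spec("fairscale") is None, "fairscale is no longer expected in the image"

from optimum.onnxruntime import ORTTrainer  # must not raise ImportError  # noqa: F401

print("ORTTrainer imports cleanly without fairscale")
```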
87 changes: 20 additions & 67 deletions optimum/onnxruntime/trainer.py
@@ -28,10 +28,7 @@

# Integrations must be imported before ML frameworks:
# isort: off
-from transformers.integrations import (
-hp_params,
-is_fairscale_available,
-)
+from transformers.integrations import hp_params

from transformers.utils import is_accelerate_available
from packaging import version
@@ -60,7 +57,6 @@
from transformers.data.data_collator import DataCollator
from transformers.debug_utils import DebugOption, DebugUnderflowOverflow
from transformers.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_zero3_enabled
-from transformers.dependency_versions_check import dep_version_check
from transformers.file_utils import (
is_apex_available,
is_sagemaker_dp_enabled,
@@ -88,7 +84,6 @@
EvalPrediction,
HPSearchBackend,
PredictionOutput,
-ShardedDDPOption,
TrainOutput,
denumpify_detensorize,
enable_full_determinism,
@@ -133,11 +128,6 @@
if is_torch_tpu_available(check_device=False):
import torch_xla.core.xla_model as xm

-if is_fairscale_available():
-dep_version_check("fairscale")
-from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
-from fairscale.optim import OSS

if TYPE_CHECKING:
import optuna

@@ -533,12 +523,7 @@ def _inner_training_loop(
else:
debug_overflow = DebugUnderflowOverflow(self.model) # noqa

-delay_optimizer_creation = (
-self.sharded_ddp is not None
-and self.sharded_ddp != ShardedDDPOption.SIMPLE
-or self.fsdp is not None
-or self.is_fsdp_enabled
-)
+delay_optimizer_creation = is_sagemaker_mp_enabled() or self.fsdp is not None or self.is_fsdp_enabled

# Wrap the model with `ORTModule`
logger.info("Wrap ORTModule for ONNX Runtime training.")
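Note on the simplified condition: optimizer creation is delayed whenever the model will still be wrapped or sharded after this point (SageMaker Model Parallel, FSDP), so that the optimizer's parameter groups reference the wrapped module rather than the original one. A condensed sketch of the surrounding flow, not the literal code:

```python
# Sketch: build the optimizer late when wrapping will replace the parameters.
delay_optimizer_creation = is_sagemaker_mp_enabled() or self.fsdp is not None or self.is_fsdp_enabled

if not delay_optimizer_creation:
    self.create_optimizer_and_scheduler(num_training_steps=max_steps)

model = self._wrap_model(self.model_wrapped)  # may shard / re-parent parameters

if delay_optimizer_creation:
    self.create_optimizer_and_scheduler(num_training_steps=max_steps)  # now sees the wrapped params
```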
@@ -582,7 +567,7 @@ def _inner_training_loop(

# as the model is wrapped, don't use `accelerator.prepare`
# this is for unhandled cases such as
-# Fairscale Sharded DDP, FSDP-XLA, SageMaker MP/DP, DataParallel, IPEX
+# FSDP-XLA, SageMaker MP/DP, DataParallel, IPEX
use_accelerator_prepare = True if model is self.model else False

if delay_optimizer_creation:
@@ -793,10 +778,6 @@ def _inner_training_loop(
if args.max_grad_norm is not None and args.max_grad_norm > 0:
# deepspeed does its own clipping

-if self.do_grad_scaling:
-# AMP: gradients need unscaling
-self.scaler.unscale_(self.optimizer)

if is_sagemaker_mp_enabled() and args.fp16:
self.optimizer.clip_master_grads(args.max_grad_norm)
elif hasattr(self.optimizer, "clip_grad_norm"):
@@ -812,18 +793,8 @@ def _inner_training_loop(
)

# Optimizer step
-optimizer_was_run = True
-if is_torch_tpu_available():
-raise NotImplementedError("`ORTTrainer` is not supported by TPU!")
-elif self.do_grad_scaling:
-scale_before = self.scaler.get_scale()
-self.scaler.step(self.optimizer)
-self.scaler.update()
-scale_after = self.scaler.get_scale()
-optimizer_was_run = scale_before <= scale_after
-else:
-self.optimizer.step()
-optimizer_was_run = not self.accelerator.optimizer_step_was_skipped
+self.optimizer.step()
+optimizer_was_run = not self.accelerator.optimizer_step_was_skipped

if optimizer_was_run:
# Delay optimizer scheduling until metrics are generated
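The deleted branches drove `torch.cuda.amp.GradScaler` by hand (unscale, step, compare the scale before and after to detect a skipped step). With the Accelerate-backed trainer the prepared optimizer handles that internally, so the loop only consults `optimizer_step_was_skipped`. A minimal sketch of that pattern, assuming `model`, `optimizer`, and `lr_scheduler` were passed through `accelerator.prepare(...)` with `Accelerator(mixed_precision="fp16")`:

```python
# Minimal sketch: Accelerate owns the grad scaler; the loop only checks the skip flag.
outputs = model(**batch)
accelerator.backward(outputs.loss)  # scales the loss under fp16
optimizer.step()                    # unscales internally; skipped on inf/nan gradients
if not accelerator.optimizer_step_was_skipped:
    lr_scheduler.step()             # only advance the schedule when a real step happened
optimizer.zero_grad()
```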
@@ -1689,19 +1660,8 @@ def _wrap_model(self, model, training=True, dataloader=None):
if not training:
return model

-# Distributed training (should be after apex fp16 initialization)
-if self.sharded_ddp is not None:
-# Sharded DDP!
-if self.sharded_ddp == ShardedDDPOption.SIMPLE:
-model = ShardedDDP(model, self.optimizer)
-else:
-raise NotImplementedError(
-"Fairscale's zero_dp_2 and zero_dp_3 are not compatible with `torch_ort.ORTModule`"
-" used in `ORTTrainer`. Use `--sharded_ddp simpe` or deepspeed stage 2 if you want"
-"the gradient to be sharded."
-)
# Distributed training using PyTorch FSDP
-elif self.fsdp is not None:
+if self.fsdp is not None:
try:
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP
from torch_xla.distributed.fsdp import checkpoint_module
@@ -1806,27 +1766,20 @@ def create_optimizer(self):
else:
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)

-if self.sharded_ddp == ShardedDDPOption.SIMPLE:
-self.optimizer = OSS(
-params=optimizer_grouped_parameters,
-optim=optimizer_cls,
-**optimizer_kwargs,
-)
-else:
-self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
-if optimizer_cls.__name__ == "Adam8bit":
-import bitsandbytes
-
-manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
-
-skipped = 0
-for module in opt_model.modules():
-if isinstance(module, nn.Embedding):
-skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
-logger.info(f"skipped {module}: {skipped/2**20}M params")
-manager.register_module_override(module, "weight", {"optim_bits": 32})
-logger.debug(f"bitsandbytes: will optimize {module} in fp32")
-logger.info(f"skipped: {skipped/2**20}M params")
+self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
+if optimizer_cls.__name__ == "Adam8bit":
+import bitsandbytes
+
+manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
+
+skipped = 0
+for module in opt_model.modules():
+if isinstance(module, nn.Embedding):
+skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
+logger.info(f"skipped {module}: {skipped/2**20}M params")
+manager.register_module_override(module, "weight", {"optim_bits": 32})
+logger.debug(f"bitsandbytes: will optimize {module} in fp32")
+logger.info(f"skipped: {skipped/2**20}M params")

if is_sagemaker_mp_enabled():
raise NotImplementedError(
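The deleted `create_optimizer` branch wrapped the optimizer in fairscale's `OSS` (sharded optimizer state) when `--sharded_ddp simple` was set. As the removed error message already suggested, DeepSpeed ZeRO stage 2 (or FSDP) now covers optimizer and gradient sharding. A hedged sketch with an illustrative minimal config:

```python
from optimum.onnxruntime import ORTTrainingArguments

# Illustrative ZeRO-2 config; real configs typically also define optimizer/scheduler sections.
ds_config = {
    "zero_optimization": {"stage": 2},  # shard optimizer state and gradients
    "train_micro_batch_size_per_gpu": "auto",
    "gradient_accumulation_steps": "auto",
    "fp16": {"enabled": "auto"},
}

args = ORTTrainingArguments(output_dir="out", deepspeed=ds_config, fp16=True)
```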
2 changes: 0 additions & 2 deletions optimum/onnxruntime/trainer_seq2seq.py
@@ -547,7 +547,6 @@ def prediction_step_ort(
has_labels = "labels" in inputs
inputs = self._prepare_inputs(inputs)

-# XXX: adapt synced_gpus for fairscale as well
# Priority (handled in generate):
# gen_kwargs > model.generation_config > default GenerationConfig()

@@ -658,7 +657,6 @@ def prediction_step(
has_labels = "labels" in inputs
inputs = self._prepare_inputs(inputs)

-# XXX: adapt synced_gpus for fairscale as well
# Priority (handled in generate):
# gen_kwargs > model.generation_config > default GenerationConfig()

25 changes: 1 addition & 24 deletions optimum/onnxruntime/training_args.py
@@ -29,7 +29,6 @@
HubStrategy,
IntervalStrategy,
SchedulerType,
-ShardedDDPOption,
)
from transformers.training_args import OptimizerNames, default_logdir, logger
from transformers.utils import (
@@ -211,8 +210,6 @@ def __post_init__(self):
" `--half_precision_backend apex`: GPU bf16 is not supported by apex. Use"
" `--half_precision_backend cuda_amp` instead"
)
-if not (self.sharded_ddp == "" or not self.sharded_ddp):
-raise ValueError("sharded_ddp is not supported with bf16")

if self.lr_scheduler_type == SchedulerType.REDUCE_ON_PLATEAU:
if self.evaluation_strategy == IntervalStrategy.NO:
@@ -329,26 +326,6 @@ def __post_init__(self):
" during training"
)

-if not (self.sharded_ddp == "" or not self.sharded_ddp):
-warnings.warn(
-"using `sharded_ddp` is deprecated and will be removed in version 4.33"
-" of 🤗 Transformers. Use `fsdp` instead",
-FutureWarning,
-)
-if isinstance(self.sharded_ddp, bool):
-self.sharded_ddp = "simple" if self.sharded_ddp else ""
-if isinstance(self.sharded_ddp, str):
-self.sharded_ddp = [ShardedDDPOption(s) for s in self.sharded_ddp.split()]
-if self.sharded_ddp == [ShardedDDPOption.OFFLOAD]:
-raise ValueError(
-"`--sharded_ddp offload` can't work on its own. It needs to be added to `--sharded_ddp zero_dp_2` or "
-'`--sharded_ddp zero_dp_3`. For example, `--sharded_ddp "zero_dp_2 offload"`.'
-)
-elif len(self.sharded_ddp) > 1 and ShardedDDPOption.SIMPLE in self.sharded_ddp:
-raise ValueError("`--sharded_ddp simple` is not compatible with any other option.")
-elif ShardedDDPOption.ZERO_DP_2 in self.sharded_ddp and ShardedDDPOption.ZERO_DP_3 in self.sharded_ddp:
-raise ValueError("`--sharded_ddp zero_dp_2` is not compatible with `--sharded_ddp zero_dp_3`.")

if isinstance(self.fsdp, bool):
self.fsdp = "full_shard" if self.fsdp else ""
if isinstance(self.fsdp, str):
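With `sharded_ddp` gone, `fsdp` is the only sharding flag that `__post_init__` still normalizes (the context lines above show the start of that block). For reference, a hedged example of the accepted forms, matching the Transformers `fsdp` argument of this era; the values are illustrative and normally come from a distributed launch:

```python
from optimum.onnxruntime import ORTTrainingArguments

# Accepted forms, normalized in __post_init__ (illustrative):
#   fsdp=True                 -> "full_shard"
#   fsdp="shard_grad_op"      -> shard gradients and optimizer state only
#   fsdp="full_shard offload" -> full sharding plus CPU offload
args = ORTTrainingArguments(output_dir="out", fsdp="full_shard offload")
```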
@@ -516,7 +493,7 @@ def __post_init__(self):
)

# if training args is specified, it will override the one specified in the accelerate config
-if self.half_precision_backend != "apex" and len(self.sharded_ddp) == 0:
+if self.half_precision_backend != "apex":
mixed_precision_dtype = os.environ.get("ACCELERATE_MIXED_PRECISION", "no")
if self.fp16:
mixed_precision_dtype = "fp16"
2 changes: 1 addition & 1 deletion tests/onnxruntime/docker/Dockerfile_onnxruntime_trainer
@@ -57,7 +57,7 @@ ARG PYTHON_EXE=$MINICONDA_PREFIX/bin/python
# (Optional) Intall test dependencies
RUN $PYTHON_EXE -m pip install git+https://github.com/huggingface/transformers
RUN $PYTHON_EXE -m pip install datasets accelerate evaluate coloredlogs absl-py rouge_score seqeval scipy sacrebleu nltk scikit-learn parameterized sentencepiece
-RUN $PYTHON_EXE -m pip install fairscale deepspeed mpi4py
+RUN $PYTHON_EXE -m pip install deepspeed mpi4py
# RUN $PYTHON_EXE -m pip install optuna ray sigopt wandb

# PyTorch