Skip to content

Commit

Permalink
remove sharedddp
Browse files Browse the repository at this point in the history
  • Loading branch information
Adam Louly committed Oct 11, 2023
1 parent c8cf353 commit d61f046
Show file tree
Hide file tree
Showing 8 changed files with 27 additions and 100 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ ARG PYTHON_EXE=$MINICONDA_PREFIX/bin/python
# (Optional) Intall test dependencies
RUN $PYTHON_EXE -m pip install git+https://github.com/huggingface/transformers
RUN $PYTHON_EXE -m pip install datasets accelerate evaluate coloredlogs absl-py rouge_score seqeval scipy sacrebleu nltk scikit-learn parameterized sentencepiece
RUN $PYTHON_EXE -m pip install fairscale deepspeed mpi4py
RUN $PYTHON_EXE -m pip install deepspeed mpi4py
# RUN $PYTHON_EXE -m pip install optuna ray sigopt wandb

# PyTorch
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ RUN pip install pygit2 pgzip
# (Optional) Intall test dependencies
RUN pip install git+https://github.com/huggingface/transformers
RUN pip install datasets accelerate evaluate coloredlogs absl-py rouge_score seqeval scipy sacrebleu nltk scikit-learn parameterized sentencepiece
RUN pip install fairscale deepspeed mpi4py
RUN pip install deepspeed mpi4py
# RUN pip install optuna ray sigopt wandb

# Install onnxruntime-training dependencies
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ RUN pip install pygit2 pgzip
# (Optional) Intall test dependencies
RUN pip install git+https://github.com/huggingface/transformers
RUN pip install datasets accelerate evaluate coloredlogs absl-py rouge_score seqeval scipy sacrebleu nltk scikit-learn parameterized sentencepiece
RUN pip install fairscale deepspeed mpi4py
RUN pip install deepspeed mpi4py
# RUN pip install optuna ray sigopt wandb

# Install onnxruntime-training dependencies
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ ARG PYTHON_EXE=$MINICONDA_PREFIX/bin/python
# (Optional) Intall test dependencies
RUN $PYTHON_EXE -m pip install git+https://github.com/huggingface/transformers
RUN $PYTHON_EXE -m pip install datasets accelerate evaluate coloredlogs absl-py rouge_score seqeval scipy sacrebleu nltk scikit-learn parameterized sentencepiece
RUN $PYTHON_EXE -m pip install fairscale deepspeed mpi4py
RUN $PYTHON_EXE -m pip install deepspeed mpi4py
# RUN $PYTHON_EXE -m pip install optuna ray sigopt wandb

# PyTorch
Expand Down
90 changes: 21 additions & 69 deletions optimum/onnxruntime/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,7 @@

# Integrations must be imported before ML frameworks:
# isort: off
from transformers.integrations import (
hp_params,
is_fairscale_available,
)
from transformers.integrations import hp_params

from transformers.utils import is_accelerate_available
from packaging import version
Expand Down Expand Up @@ -60,7 +57,6 @@
from transformers.data.data_collator import DataCollator
from transformers.debug_utils import DebugOption, DebugUnderflowOverflow
from transformers.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_zero3_enabled
from transformers.dependency_versions_check import dep_version_check
from transformers.file_utils import (
is_apex_available,
is_sagemaker_dp_enabled,
Expand Down Expand Up @@ -88,7 +84,6 @@
EvalPrediction,
HPSearchBackend,
PredictionOutput,
ShardedDDPOption,
TrainOutput,
denumpify_detensorize,
enable_full_determinism,
Expand Down Expand Up @@ -133,11 +128,6 @@
if is_torch_tpu_available(check_device=False):
import torch_xla.core.xla_model as xm

if is_fairscale_available():
dep_version_check("fairscale")
from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
from fairscale.optim import OSS

if TYPE_CHECKING:
import optuna

Expand Down Expand Up @@ -533,12 +523,7 @@ def _inner_training_loop(
else:
debug_overflow = DebugUnderflowOverflow(self.model) # noqa

delay_optimizer_creation = (
self.sharded_ddp is not None
and self.sharded_ddp != ShardedDDPOption.SIMPLE
or self.fsdp is not None
or self.is_fsdp_enabled
)
delay_optimizer_creation = is_sagemaker_mp_enabled() or self.fsdp is not None or self.is_fsdp_enabled

# Wrap the model with `ORTModule`
logger.info("Wrap ORTModule for ONNX Runtime training.")
Expand Down Expand Up @@ -582,7 +567,7 @@ def _inner_training_loop(

# as the model is wrapped, don't use `accelerator.prepare`
# this is for unhandled cases such as
# Fairscale Sharded DDP, FSDP-XLA, SageMaker MP/DP, DataParallel, IPEX
# FSDP-XLA, SageMaker MP/DP, DataParallel, IPEX
use_accelerator_prepare = True if model is self.model else False

if delay_optimizer_creation:
Expand Down Expand Up @@ -793,10 +778,6 @@ def _inner_training_loop(
if args.max_grad_norm is not None and args.max_grad_norm > 0:
# deepspeed does its own clipping

if self.do_grad_scaling:
# AMP: gradients need unscaling
self.scaler.unscale_(self.optimizer)

if is_sagemaker_mp_enabled() and args.fp16:
self.optimizer.clip_master_grads(args.max_grad_norm)
elif hasattr(self.optimizer, "clip_grad_norm"):
Expand All @@ -807,23 +788,12 @@ def _inner_training_loop(
model.clip_grad_norm_(args.max_grad_norm)
else:
self.accelerator.clip_grad_norm_(
model.parameters(),
args.max_grad_norm,
model.parameters(), args.max_grad_norm,
)

# Optimizer step
optimizer_was_run = True
if is_torch_tpu_available():
raise NotImplementedError("`ORTTrainer` is not supported by TPU!")
elif self.do_grad_scaling:
scale_before = self.scaler.get_scale()
self.scaler.step(self.optimizer)
self.scaler.update()
scale_after = self.scaler.get_scale()
optimizer_was_run = scale_before <= scale_after
else:
self.optimizer.step()
optimizer_was_run = not self.accelerator.optimizer_step_was_skipped
self.optimizer.step()
optimizer_was_run = not self.accelerator.optimizer_step_was_skipped

if optimizer_was_run:
# Delay optimizer scheduling until metrics are generated
Expand Down Expand Up @@ -1689,19 +1659,8 @@ def _wrap_model(self, model, training=True, dataloader=None):
if not training:
return model

# Distributed training (should be after apex fp16 initialization)
if self.sharded_ddp is not None:
# Sharded DDP!
if self.sharded_ddp == ShardedDDPOption.SIMPLE:
model = ShardedDDP(model, self.optimizer)
else:
raise NotImplementedError(
"Fairscale's zero_dp_2 and zero_dp_3 are not compatible with `torch_ort.ORTModule`"
" used in `ORTTrainer`. Use `--sharded_ddp simpe` or deepspeed stage 2 if you want"
"the gradient to be sharded."
)
# Distributed training using PyTorch FSDP
elif self.fsdp is not None:
if self.fsdp is not None:
try:
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP
from torch_xla.distributed.fsdp import checkpoint_module
Expand Down Expand Up @@ -1806,27 +1765,20 @@ def create_optimizer(self):
else:
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)

if self.sharded_ddp == ShardedDDPOption.SIMPLE:
self.optimizer = OSS(
params=optimizer_grouped_parameters,
optim=optimizer_cls,
**optimizer_kwargs,
)
else:
self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
if optimizer_cls.__name__ == "Adam8bit":
import bitsandbytes

manager = bitsandbytes.optim.GlobalOptimManager.get_instance()

skipped = 0
for module in opt_model.modules():
if isinstance(module, nn.Embedding):
skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
logger.info(f"skipped {module}: {skipped/2**20}M params")
manager.register_module_override(module, "weight", {"optim_bits": 32})
logger.debug(f"bitsandbytes: will optimize {module} in fp32")
logger.info(f"skipped: {skipped/2**20}M params")
self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
if optimizer_cls.__name__ == "Adam8bit":
import bitsandbytes

manager = bitsandbytes.optim.GlobalOptimManager.get_instance()

skipped = 0
for module in opt_model.modules():
if isinstance(module, nn.Embedding):
skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
logger.info(f"skipped {module}: {skipped/2**20}M params")
manager.register_module_override(module, "weight", {"optim_bits": 32})
logger.debug(f"bitsandbytes: will optimize {module} in fp32")
logger.info(f"skipped: {skipped/2**20}M params")

if is_sagemaker_mp_enabled():
raise NotImplementedError(
Expand Down
2 changes: 0 additions & 2 deletions optimum/onnxruntime/trainer_seq2seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,7 +547,6 @@ def prediction_step_ort(
has_labels = "labels" in inputs
inputs = self._prepare_inputs(inputs)

# XXX: adapt synced_gpus for fairscale as well
# Priority (handled in generate):
# gen_kwargs > model.generation_config > default GenerationConfig()

Expand Down Expand Up @@ -658,7 +657,6 @@ def prediction_step(
has_labels = "labels" in inputs
inputs = self._prepare_inputs(inputs)

# XXX: adapt synced_gpus for fairscale as well
# Priority (handled in generate):
# gen_kwargs > model.generation_config > default GenerationConfig()

Expand Down
25 changes: 1 addition & 24 deletions optimum/onnxruntime/training_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
HubStrategy,
IntervalStrategy,
SchedulerType,
ShardedDDPOption,
)
from transformers.training_args import OptimizerNames, default_logdir, logger
from transformers.utils import (
Expand Down Expand Up @@ -211,8 +210,6 @@ def __post_init__(self):
" `--half_precision_backend apex`: GPU bf16 is not supported by apex. Use"
" `--half_precision_backend cuda_amp` instead"
)
if not (self.sharded_ddp == "" or not self.sharded_ddp):
raise ValueError("sharded_ddp is not supported with bf16")

if self.lr_scheduler_type == SchedulerType.REDUCE_ON_PLATEAU:
if self.evaluation_strategy == IntervalStrategy.NO:
Expand Down Expand Up @@ -329,26 +326,6 @@ def __post_init__(self):
" during training"
)

if not (self.sharded_ddp == "" or not self.sharded_ddp):
warnings.warn(
"using `sharded_ddp` is deprecated and will be removed in version 4.33"
" of 🤗 Transformers. Use `fsdp` instead",
FutureWarning,
)
if isinstance(self.sharded_ddp, bool):
self.sharded_ddp = "simple" if self.sharded_ddp else ""
if isinstance(self.sharded_ddp, str):
self.sharded_ddp = [ShardedDDPOption(s) for s in self.sharded_ddp.split()]
if self.sharded_ddp == [ShardedDDPOption.OFFLOAD]:
raise ValueError(
"`--sharded_ddp offload` can't work on its own. It needs to be added to `--sharded_ddp zero_dp_2` or "
'`--sharded_ddp zero_dp_3`. For example, `--sharded_ddp "zero_dp_2 offload"`.'
)
elif len(self.sharded_ddp) > 1 and ShardedDDPOption.SIMPLE in self.sharded_ddp:
raise ValueError("`--sharded_ddp simple` is not compatible with any other option.")
elif ShardedDDPOption.ZERO_DP_2 in self.sharded_ddp and ShardedDDPOption.ZERO_DP_3 in self.sharded_ddp:
raise ValueError("`--sharded_ddp zero_dp_2` is not compatible with `--sharded_ddp zero_dp_3`.")

if isinstance(self.fsdp, bool):
self.fsdp = "full_shard" if self.fsdp else ""
if isinstance(self.fsdp, str):
Expand Down Expand Up @@ -516,7 +493,7 @@ def __post_init__(self):
)

# if training args is specified, it will override the one specified in the accelerate config
if self.half_precision_backend != "apex" and len(self.sharded_ddp) == 0:
if self.half_precision_backend != "apex":
mixed_precision_dtype = os.environ.get("ACCELERATE_MIXED_PRECISION", "no")
if self.fp16:
mixed_precision_dtype = "fp16"
Expand Down
2 changes: 1 addition & 1 deletion tests/onnxruntime/docker/Dockerfile_onnxruntime_trainer
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ ARG PYTHON_EXE=$MINICONDA_PREFIX/bin/python
# (Optional) Intall test dependencies
RUN $PYTHON_EXE -m pip install git+https://github.com/huggingface/transformers
RUN $PYTHON_EXE -m pip install datasets accelerate evaluate coloredlogs absl-py rouge_score seqeval scipy sacrebleu nltk scikit-learn parameterized sentencepiece
RUN $PYTHON_EXE -m pip install fairscale deepspeed mpi4py
RUN $PYTHON_EXE -m pip install deepspeed mpi4py
# RUN $PYTHON_EXE -m pip install optuna ray sigopt wandb

# PyTorch
Expand Down

0 comments on commit d61f046

Please sign in to comment.