[ORT Training] Some important updates of ONNX Runtime training APIs #1335

Merged 18 commits on Oct 18, 2023
Changes from 8 commits
157 changes: 118 additions & 39 deletions optimum/onnxruntime/trainer.py
@@ -52,22 +52,18 @@

# isort: on

import huggingface_hub.utils as hf_hub_utils
import numpy as np
import torch
import torch.distributed as dist
from torch import nn
from torch.utils.data import DataLoader, Dataset
from torch.utils.data import DataLoader, Dataset, RandomSampler
from transformers.data.data_collator import DataCollator
from transformers.debug_utils import DebugOption, DebugUnderflowOverflow
from transformers.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_zero3_enabled
from transformers.dependency_versions_check import dep_version_check
from transformers.file_utils import (
is_apex_available,
is_sagemaker_dp_enabled,
is_sagemaker_mp_enabled,
is_torch_tpu_available,
)
from transformers.modeling_utils import PreTrainedModel, unwrap_model
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.trainer import Trainer
from transformers.trainer_callback import TrainerCallback, TrainerState
@@ -99,6 +95,13 @@
speed_metrics,
)
from transformers.training_args import ParallelMode
from transformers.utils import (
is_apex_available,
is_peft_available,
is_sagemaker_dp_enabled,
is_sagemaker_mp_enabled,
is_torch_tpu_available,
)

from ..exporters import TasksManager
from ..exporters.onnx import OnnxConfigWithPast, export, export_models, get_decoder_models_for_export
Expand Down Expand Up @@ -138,6 +141,9 @@
from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
from fairscale.optim import OSS

if is_peft_available():
from peft import PeftModel

if TYPE_CHECKING:
import optuna

@@ -297,7 +303,6 @@ def __init__(
self,
model: Union[PreTrainedModel, nn.Module] = None,
tokenizer: Optional[PreTrainedTokenizerBase] = None,
feature: str = "feature-extraction",
args: ORTTrainingArguments = None,
data_collator: Optional[DataCollator] = None,
train_dataset: Optional[Dataset] = None,
@@ -308,6 +313,7 @@
optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
onnx_model_path: Union[str, os.PathLike] = None,
feature: Optional[str] = None,
):
super().__init__(
model=model,
@@ -333,7 +339,14 @@ def __init__(

self.model = model

self.feature = feature
if feature is None:
try:
self.feature = TasksManager.infer_task_from_model(self.model)
except KeyError:
pass
else:
self.feature = feature
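
For illustration, the new default behaviour when `feature` is omitted can be reduced to a small standalone sketch. The helper name `resolve_feature` is hypothetical; only `TasksManager.infer_task_from_model` comes from the change above:

```python
from optimum.exporters import TasksManager

def resolve_feature(model, feature=None):
    # Hypothetical helper mirroring the constructor logic above:
    # an explicit feature wins, otherwise the task is inferred from the model.
    if feature is not None:
        return feature
    try:
        return TasksManager.infer_task_from_model(model)
    except KeyError:
        # Task could not be inferred; in the trainer this case is simply skipped.
        return None
```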

self.onnx_model_path = onnx_model_path
self.exported_with_loss = False
if self.args.local_rank:
@@ -447,7 +460,12 @@ def train(
if resume_from_checkpoint is None:
raise ValueError(f"No valid checkpoint found in output directory ({args.output_dir})")

if resume_from_checkpoint is not None and not is_sagemaker_mp_enabled() and args.deepspeed is None:
if (
resume_from_checkpoint is not None
and not is_sagemaker_mp_enabled()
and not self.is_deepspeed_enabled
and not self.is_fsdp_enabled
):
self._load_from_checkpoint(resume_from_checkpoint)

# If model was re-initialized, put it on the right device and update self.model_wrapped
@@ -459,12 +477,25 @@
inner_training_loop = find_executable_batch_size(
self._inner_training_loop, self._train_batch_size, args.auto_find_batch_size
)
return inner_training_loop(
args=args,
resume_from_checkpoint=resume_from_checkpoint,
trial=trial,
ignore_keys_for_eval=ignore_keys_for_eval,
)
if args.push_to_hub:
try:
# Disable progress bars when uploading models during checkpoints to avoid polluting stdout
hf_hub_utils.disable_progress_bars()
return inner_training_loop(
args=args,
resume_from_checkpoint=resume_from_checkpoint,
trial=trial,
ignore_keys_for_eval=ignore_keys_for_eval,
)
finally:
hf_hub_utils.enable_progress_bars()
else:
return inner_training_loop(
args=args,
resume_from_checkpoint=resume_from_checkpoint,
trial=trial,
ignore_keys_for_eval=ignore_keys_for_eval,
)
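
As context, the pattern introduced above boils down to suppressing `huggingface_hub` progress bars while the training loop runs with `push_to_hub`, then restoring them in a `finally` block so they come back even on failure. The `run_quietly` wrapper below is purely illustrative:

```python
import huggingface_hub.utils as hf_hub_utils

def run_quietly(fn, *args, **kwargs):
    # Disable Hub progress bars so checkpoint uploads do not pollute stdout,
    # and re-enable them no matter how `fn` exits.
    hf_hub_utils.disable_progress_bars()
    try:
        return fn(*args, **kwargs)
    finally:
        hf_hub_utils.enable_progress_bars()
```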

def _inner_training_loop(
self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None
@@ -514,14 +545,6 @@ def _inner_training_loop(
f" {args.max_steps}"
)

# Compute absolute values for logging, eval, and save if given as ratio
if args.logging_steps and args.logging_steps < 1:
args.logging_steps = math.ceil(max_steps * args.logging_steps)
if args.eval_steps and args.eval_steps < 1:
args.eval_steps = math.ceil(max_steps * args.eval_steps)
if args.save_steps and args.save_steps < 1:
args.save_steps = math.ceil(max_steps * args.save_steps)

if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug:
if self.args.n_gpu > 1:
# nn.DataParallel(model) replicates the model, creating new variables and module
@@ -571,13 +594,30 @@ def _inner_training_loop(
self.state = TrainerState()
self.state.is_hyper_param_search = trial is not None

# Compute absolute values for logging, eval, and save if given as ratio
if args.logging_steps is not None:
if args.logging_steps < 1:
self.state.logging_steps = math.ceil(max_steps * args.logging_steps)
else:
self.state.logging_steps = args.logging_steps
if args.eval_steps is not None:
if args.eval_steps < 1:
self.state.eval_steps = math.ceil(max_steps * args.eval_steps)
else:
self.state.eval_steps = args.eval_steps
if args.save_steps is not None:
if args.save_steps < 1:
self.state.save_steps = math.ceil(max_steps * args.save_steps)
else:
self.state.save_steps = args.save_steps
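
The ratio-to-absolute conversion above now lands on the `TrainerState` instead of mutating the args. A small worked example of what it computes (the values are made up):

```python
import math

max_steps = 1_000
logging_steps = 0.05   # interpreted as a ratio of max_steps because it is < 1
eval_steps = 200       # already absolute, kept as-is

state_logging_steps = math.ceil(max_steps * logging_steps) if logging_steps < 1 else logging_steps
state_eval_steps = math.ceil(max_steps * eval_steps) if eval_steps < 1 else eval_steps

assert state_logging_steps == 50
assert state_eval_steps == 200
```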

# Activate gradient checkpointing if needed
if args.gradient_checkpointing:
self.model.gradient_checkpointing_enable()

model = self._wrap_model(self.model_wrapped) # Wrap unless the ORTModule is already wrapped, eg. wrap DDP

if is_sagemaker_mp_enabled() and resume_from_checkpoint is not None:
if (is_sagemaker_mp_enabled() or self.is_fsdp_enabled) and resume_from_checkpoint is not None:
self._load_from_checkpoint(resume_from_checkpoint, model)

# as the model is wrapped, don't use `accelerator.prepare`
@@ -703,11 +743,27 @@ def _inner_training_loop(

self.control = self.callback_handler.on_train_begin(args, self.state, self.control)

# Temp: remove after transformers 4.34 release
def get_dataloader_sampler(dataloader):
if hasattr(dataloader, "batch_sampler") and dataloader.batch_sampler is not None:
return get_dataloader_sampler(dataloader.batch_sampler)
elif hasattr(dataloader, "sampler"):
return dataloader.sampler

# Skip the first epochs_trained epochs to get the random state of the dataloader at the right point.
if not args.ignore_data_skip:
for epoch in range(epochs_trained):
for _ in train_dataloader:
break
sampler = get_dataloader_sampler(train_dataloader)
is_random_sampler = isinstance(sampler, RandomSampler)
if not is_random_sampler:
# We just need to begin an iteration to create the randomization of the sampler.
for _ in train_dataloader:
break
else:
# Otherwise we need to call the whooooole sampler cause there is some random operation added
# AT THE VERY END!
sampler = sampler if sampler is not None else []
_ = list(sampler)
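
The data-skip change above distinguishes random from non-random samplers before replaying the dataloader's random state. A self-contained sketch of the same idea, with the helper restated so it runs on its own (the toy dataset is illustrative):

```python
import torch
from torch.utils.data import DataLoader, RandomSampler, TensorDataset

def get_dataloader_sampler(dataloader):
    # Walk through an optional batch_sampler wrapper to reach the underlying sampler.
    if getattr(dataloader, "batch_sampler", None) is not None:
        return get_dataloader_sampler(dataloader.batch_sampler)
    return getattr(dataloader, "sampler", None)

dataset = TensorDataset(torch.arange(10))
loader = DataLoader(dataset, batch_size=2, sampler=RandomSampler(dataset))

sampler = get_dataloader_sampler(loader)
if isinstance(sampler, RandomSampler):
    _ = list(sampler)            # the permutation is drawn at the very end, so consume the whole sampler
else:
    next(iter(loader), None)     # a single batch is enough to seed non-random samplers
```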

total_batched_samples = 0
for epoch in range(epochs_trained, num_train_epochs):
@@ -718,7 +774,7 @@ def _inner_training_loop(
self._past = None

steps_in_epoch = (
len(train_dataloader)
len(epoch_iterator)
if len_dataloader is not None
else args.max_steps * args.gradient_accumulation_steps
)
@@ -730,13 +786,13 @@ def _inner_training_loop(
rng_to_sync = False
steps_skipped = 0
if steps_trained_in_current_epoch > 0:
skip_first_batches(epoch_iterator, steps_trained_in_current_epoch)
epoch_iterator = skip_first_batches(epoch_iterator, steps_trained_in_current_epoch)
steps_skipped = steps_trained_in_current_epoch
steps_trained_in_current_epoch = 0
rng_to_sync = True

step = -1
for step, inputs in enumerate(train_dataloader):
for step, inputs in enumerate(epoch_iterator):
total_batched_samples += 1
if rng_to_sync:
self._load_rng_state(resume_from_checkpoint)
@@ -893,12 +949,15 @@ def _inner_training_loop(
# Delete the last checkpoint when save_total_limit=1 if it's different from the best checkpoint and process allowed to save.
if self.args.should_save and self.state.best_model_checkpoint is not None and self.args.save_total_limit == 1:
for checkpoint in checkpoints_sorted:
if checkpoint != self.state.best_model_checkpoint:
if not os.path.samefile(checkpoint, self.state.best_model_checkpoint):
logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
shutil.rmtree(checkpoint)

self.control = self.callback_handler.on_train_end(args, self.state, self.control)

# Wait for the checkpoint to be uploaded.
self._finish_current_push()

return TrainOutput(self.state.global_step, train_loss, metrics)
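
The checkpoint-cleanup comparison now uses `os.path.samefile`, which treats different spellings of the same path as equal. A short example of why string comparison is not enough (the paths are illustrative):

```python
import os

os.makedirs("output/checkpoint-500", exist_ok=True)

best = "output/checkpoint-500"
candidate = "./output/checkpoint-500/"

assert candidate != best                   # string comparison says they differ, so the old code could delete the best checkpoint
assert os.path.samefile(candidate, best)   # samefile sees through the different spellings
```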

def evaluate(
@@ -1177,7 +1236,8 @@ def evaluation_loop_ort(
loss, logits, labels = self.prediction_step_ort(
ort_model, inputs, prediction_loss_only, ignore_keys=ignore_keys
)
inputs_decode = inputs["input_ids"] if args.include_inputs_for_metrics else None
main_input_name = getattr(self.model, "main_input_name", "input_ids")
inputs_decode = self._prepare_input(inputs[main_input_name]) if args.include_inputs_for_metrics else None

# Update containers on host
if loss is not None:
@@ -1207,7 +1267,11 @@ def evaluation_loop_ort(
self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)

# Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
if args.eval_accumulation_steps is not None and self.accelerator.sync_gradients:
if (
args.eval_accumulation_steps is not None
and (step + 1) % args.eval_accumulation_steps == 0
and (self.accelerator.sync_gradients or version.parse(accelerate_version) > version.parse("0.20.3"))
):
if losses_host is not None:
losses = nested_numpify(losses_host)
all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)
@@ -1291,6 +1355,10 @@ def evaluation_loop_ort(

return EvalLoopOutput(predictions=all_preds, label_ids=all_labels, metrics=metrics, num_samples=num_samples)

#
# Deprecated code
#

def prediction_loop_ort(
self,
dataloader: DataLoader,
@@ -1389,7 +1457,8 @@ def prediction_loop_ort(
loss, logits, labels = self.prediction_step_ort(
ort_model, inputs, prediction_loss_only, ignore_keys=ignore_keys
)
inputs_decode = inputs["input_ids"] if args.include_inputs_for_metrics else None
main_input_name = getattr(self.model, "main_input_name", "input_ids")
inputs_decode = self._prepare_input(inputs[main_input_name]) if args.include_inputs_for_metrics else None

if loss is not None:
losses = loss.repeat(batch_size)
@@ -1561,7 +1630,11 @@ def compute_loss_ort(self, model, inputs, return_outputs=False):
self._past = outputs[self.args.past_index]

if labels is not None:
if "text-generation" in self.feature:
if is_peft_available() and isinstance(model, PeftModel):
model_name = unwrap_model(model.base_model)._get_name()
else:
model_name = unwrap_model(model)._get_name()
if model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
loss = self.label_smoother(outputs, labels, shift_labels=True)
else:
loss = self.label_smoother(outputs, labels)
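
A sketch of the causal-LM check that replaces the old `"text-generation" in self.feature` test; `gpt2` here is only an example checkpoint, and the snippet assumes `transformers` is installed:

```python
from transformers import AutoModelForCausalLM
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES

model = AutoModelForCausalLM.from_pretrained("gpt2")
model_name = model._get_name()   # "GPT2LMHeadModel"

# Causal LMs need their labels shifted by one position before label smoothing.
shift_labels = model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values()
assert shift_labels
```
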
@@ -1714,18 +1787,24 @@ def _wrap_model(self, model, training=True, dataloader=None):

auto_wrap_policy = None
auto_wrapper_callable = None
if self.args.fsdp_config["fsdp_min_num_params"] > 0:
default_transformer_cls_names_to_wrap = getattr(model, "_no_split_modules", None)
fsdp_transformer_layer_cls_to_wrap = self.args.fsdp_config.get(
"transformer_layer_cls_to_wrap", default_transformer_cls_names_to_wrap
)

if self.args.fsdp_config["min_num_params"] > 0:
auto_wrap_policy = functools.partial(
size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["fsdp_min_num_params"]
size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["min_num_params"]
)
elif self.args.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None:
elif fsdp_transformer_layer_cls_to_wrap is not None:
transformer_cls_to_wrap = set()
for layer_class in self.args.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]:
for layer_class in fsdp_transformer_layer_cls_to_wrap:
transformer_cls = get_module_class_from_name(model, layer_class)
if transformer_cls is None:
raise Exception("Could not find the transformer layer class to wrap in the model.")
else:
transformer_cls_to_wrap.add(transformer_cls)

auto_wrap_policy = functools.partial(
transformer_auto_wrap_policy,
# Transformer layer class to wrap
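
For reference, a minimal standalone sketch of the size-based branch above, with a plain dict standing in for `self.args.fsdp_config` and an illustrative threshold:

```python
import functools

from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

fsdp_config = {"min_num_params": 100_000_000}   # assumed example value

auto_wrap_policy = None
if fsdp_config["min_num_params"] > 0:
    # Wrap every submodule whose parameter count exceeds the threshold.
    auto_wrap_policy = functools.partial(
        size_based_auto_wrap_policy, min_num_params=fsdp_config["min_num_params"]
    )
```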