[ORT Training] Some important updates of ONNX Runtime training APIs #1335

Merged · 18 commits · Oct 18, 2023
@@ -22,6 +22,7 @@ CMD nvidia-smi
ENV DEBIAN_FRONTEND noninteractive

# Versions
# available options 3.8, 3.9, 3.10, 3.11
ARG PYTHON_VERSION=3.9
ARG TORCH_CUDA_VERSION=cu118
ARG TORCH_VERSION=2.0.0
@@ -34,7 +35,7 @@ SHELL ["/bin/bash", "-c"]
# Install and update tools to minimize security vulnerabilities
RUN apt-get update
RUN apt-get install -y software-properties-common wget apt-utils patchelf git libprotobuf-dev protobuf-compiler cmake \
bzip2 ca-certificates libglib2.0-0 libxext6 libsm6 libxrender1 mercurial subversion libopenmpi-dev && \
bzip2 ca-certificates libglib2.0-0 libxext6 libsm6 libxrender1 mercurial subversion libopenmpi-dev ffmpeg && \
apt-get clean
RUN unattended-upgrade
RUN apt-get autoremove -y
(next file)
@@ -33,7 +33,7 @@ ARG TORCHVISION_VERSION=0.14.1
# Install and update tools to minimize security vulnerabilities
RUN apt-get update
RUN apt-get install -y software-properties-common wget apt-utils patchelf git libprotobuf-dev protobuf-compiler cmake \
bzip2 ca-certificates libglib2.0-0 libxext6 libsm6 libxrender1 mercurial subversion libopenmpi-dev && \
bzip2 ca-certificates libglib2.0-0 libxext6 libsm6 libxrender1 mercurial subversion libopenmpi-dev ffmpeg && \
apt-get clean
RUN unattended-upgrade
RUN apt-get autoremove -y
(next file)
@@ -34,7 +34,7 @@ ARG TORCHVISION_VERSION=0.15.1
# Install and update tools to minimize security vulnerabilities
RUN apt-get update
RUN apt-get install -y software-properties-common wget apt-utils patchelf git libprotobuf-dev protobuf-compiler cmake \
bzip2 ca-certificates libglib2.0-0 libxext6 libsm6 libxrender1 mercurial subversion libopenmpi-dev && \
bzip2 ca-certificates libglib2.0-0 libxext6 libsm6 libxrender1 mercurial subversion libopenmpi-dev ffmpeg && \
apt-get clean
RUN unattended-upgrade
RUN apt-get autoremove -y
(next file)
@@ -34,7 +34,7 @@ SHELL ["/bin/bash", "-c"]
# Install and update tools to minimize security vulnerabilities
RUN apt-get update
RUN apt-get install -y software-properties-common wget apt-utils patchelf git libprotobuf-dev protobuf-compiler cmake \
bzip2 ca-certificates libglib2.0-0 libxext6 libsm6 libxrender1 mercurial subversion libopenmpi-dev && \
bzip2 ca-certificates libglib2.0-0 libxext6 libsm6 libxrender1 mercurial subversion libopenmpi-dev ffmpeg && \
apt-get clean
RUN unattended-upgrade
RUN apt-get autoremove -y
(next file)
@@ -16,6 +16,7 @@
import logging
import os
import sys
import warnings
from dataclasses import dataclass, field
from typing import Optional

@@ -54,7 +55,7 @@
logger = logging.getLogger(__name__)

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.26.0")
check_min_version("4.34.0")

require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt")

@@ -141,12 +142,28 @@ class ModelArguments:
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
)
image_processor_name: str = field(default=None, metadata={"help": "Name or path of preprocessor config."})
token: str = field(
default=None,
metadata={
"help": (
"The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
"generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
)
},
)
use_auth_token: bool = field(
default=None,
metadata={
"help": "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token`."
},
)
trust_remote_code: bool = field(
default=False,
metadata={
"help": (
"Will use the token generated when running `huggingface-cli login` (necessary to use this script "
"with private models)."
"Whether or not to allow for custom models defined on the Hub in their own modeling files. This option"
"should only be set to `True` for repositories you trust and in which you have read the code, as it will "
"execute code present on the Hub on your local machine."
)
},
)
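For context on the change above: the example replaces the old boolean `use_auth_token` with a string `token` field and adds `trust_remote_code`, keeping `use_auth_token` only as a deprecated alias. Below is a minimal sketch, not part of the PR, of how such fields surface on the command line through `HfArgumentParser`; the stripped-down dataclass is a hypothetical stand-in for the example's `ModelArguments`.

```python
# Hypothetical stand-in for ModelArguments, reduced to the two new fields.
from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser


@dataclass
class AuthArguments:
    token: Optional[str] = field(
        default=None, metadata={"help": "HTTP bearer token for remote files."}
    )
    trust_remote_code: bool = field(
        default=False, metadata={"help": "Allow custom model code from the Hub."}
    )


# Invoked e.g. as: python script.py --token hf_xxx --trust_remote_code
(auth_args,) = HfArgumentParser(AuthArguments).parse_args_into_dataclasses()
print(auth_args.token, auth_args.trust_remote_code)
```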
@@ -162,32 +179,24 @@ def collate_fn(examples):
return {"pixel_values": pixel_values, "labels": labels}


@dataclass
class InferenceArguments:
"""
Arguments for inference(evaluate, predict).
"""

inference_with_ort: bool = field(
default=False,
metadata={"help": "Whether use ONNX Runtime as backend for inference. Default set to false."},
)


def main():
# See all possible arguments in src/transformers/training_args.py
# or by passing the --help flag to this script.
# We now keep distinct sets of args, for a cleaner separation of concerns.

parser = HfArgumentParser((ModelArguments, DataTrainingArguments, ORTTrainingArguments, InferenceArguments))
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, ORTTrainingArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
# If we pass only one argument to the script and it's the path to a json file,
# let's parse it to get our arguments.
model_args, data_args, training_args, inference_args = parser.parse_json_file(
json_file=os.path.abspath(sys.argv[1])
)
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
else:
model_args, data_args, training_args, inference_args = parser.parse_args_into_dataclasses()
model_args, data_args, training_args = parser.parse_args_into_dataclasses()

if model_args.use_auth_token is not None:
warnings.warn("The `use_auth_token` argument is deprecated and will be removed in v4.34.", FutureWarning)
if model_args.token is not None:
raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
model_args.token = model_args.use_auth_token

# Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
# information sent is the one passed as arguments along with your Python/PyTorch versions.
@@ -200,6 +209,10 @@ def main():
handlers=[logging.StreamHandler(sys.stdout)],
)

if training_args.should_log:
# The default of training_args.log_level is passive, so we set log level at info here to have that default.
transformers.utils.logging.set_verbosity_info()

log_level = training_args.get_process_log_level()
logger.setLevel(log_level)
transformers.utils.logging.set_verbosity(log_level)
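The `should_log` block added above mirrors the current Transformers examples: when this process is supposed to log and `log_level` is left at its `passive` default, verbosity is raised to INFO before the explicit `set_verbosity(log_level)` call. A minimal standalone sketch, assuming `ORTTrainingArguments` exposes the same helpers as `transformers.TrainingArguments` (the output directory is a placeholder):

```python
# Minimal sketch of the logging setup shown in the diff above.
import transformers
from optimum.onnxruntime import ORTTrainingArguments

training_args = ORTTrainingArguments(output_dir="tmp_out")  # placeholder output dir

if training_args.should_log:
    # log_level defaults to "passive", so INFO becomes the effective default
    transformers.utils.logging.set_verbosity_info()

log_level = training_args.get_process_log_level()
transformers.utils.logging.set_verbosity(log_level)
```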
@@ -209,7 +222,7 @@ def main():
# Log on each process the small summary:
logger.warning(
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
+ f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
+ f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
)
logger.info(f"Training/evaluation parameters {training_args}")

@@ -238,7 +251,7 @@ def main():
data_args.dataset_config_name,
cache_dir=model_args.cache_dir,
task="image-classification",
use_auth_token=True if model_args.use_auth_token else None,
token=model_args.token,
)
else:
data_files = {}
@@ -285,22 +298,25 @@ def compute_metrics(p):
finetuning_task="image-classification",
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
token=model_args.token,
trust_remote_code=model_args.trust_remote_code,
)
model = AutoModelForImageClassification.from_pretrained(
model_args.model_name_or_path,
from_tf=bool(".ckpt" in model_args.model_name_or_path),
config=config,
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
token=model_args.token,
trust_remote_code=model_args.trust_remote_code,
ignore_mismatched_sizes=model_args.ignore_mismatched_sizes,
)
image_processor = AutoImageProcessor.from_pretrained(
model_args.image_processor_name or model_args.model_name_or_path,
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
token=model_args.token,
trust_remote_code=model_args.trust_remote_code,
)

# Define torchvision transforms to be applied to each image.
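All three `from_pretrained` calls above now pass `token` and `trust_remote_code` instead of the removed `use_auth_token`. A minimal sketch of the updated loading pattern outside the example script; the checkpoint name is a placeholder, not something taken from this PR.

```python
# Sketch of the new-style loading arguments; checkpoint and token values are
# for illustration only.
from transformers import AutoImageProcessor, AutoModelForImageClassification

checkpoint = "google/vit-base-patch16-224-in21k"  # placeholder checkpoint

model = AutoModelForImageClassification.from_pretrained(
    checkpoint,
    token=None,               # replaces use_auth_token; None falls back to the cached login token
    trust_remote_code=False,  # only set True for repositories whose code you trust
)
image_processor = AutoImageProcessor.from_pretrained(
    checkpoint,
    token=None,
    trust_remote_code=False,
)
```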
@@ -367,7 +383,6 @@ def val_transforms(example_batch):
compute_metrics=compute_metrics,
tokenizer=image_processor,
data_collator=collate_fn,
feature="image-classification",
)

# Training
@@ -385,7 +400,7 @@ def val_transforms(example_batch):

# Evaluation
if training_args.do_eval:
metrics = trainer.evaluate(inference_with_ort=inference_args.inference_with_ort)
metrics = trainer.evaluate()
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
