
Commit b5f4de3

fix style (#1715)
JingyaHuang authored Feb 23, 2024
1 parent 990c203 commit b5f4de3
Showing 8 changed files with 32 additions and 19 deletions.
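Most of the hunks below apply the same formatting pattern: a conditional expression passed as an argument is wrapped in parentheses so the formatter can break each clause onto its own line. A minimal before/after sketch of that pattern, using hypothetical names rather than code from any file in this commit:

def fake_metrics(preds):
    # Hypothetical stand-in for a compute_metrics callback.
    return {"accuracy": 1.0}

do_eval = True

# Old layout: the conditional hangs directly off the assignment or keyword argument.
compute_metrics = fake_metrics if do_eval else None

# New layout: parentheses keep the conditional together while letting it span lines.
compute_metrics = (
    fake_metrics
    if do_eval
    else None
)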
@@ -23,6 +23,7 @@
import evaluate
import numpy as np
import torch
import transformers
from datasets import load_dataset
from PIL import Image
from torchvision.transforms import (
@@ -35,8 +36,6 @@
Resize,
ToTensor,
)

import transformers
from transformers import (
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
AutoConfig,
@@ -51,6 +50,7 @@

from optimum import ORTTrainer, ORTTrainingArguments


""" Fine-tuning a 🤗 Transformers model for image classification"""

logger = logging.getLogger(__name__)
@@ -447,4 +447,4 @@ def val_transforms(example_batch):


if __name__ == "__main__":
main()
main()
6 changes: 3 additions & 3 deletions examples/onnxruntime/training/language-modeling/run_clm.py
@@ -598,9 +598,9 @@ def compute_metrics(eval_preds):
# Data collator will default to DataCollatorWithPadding, so we change it.
data_collator=default_data_collator,
compute_metrics=compute_metrics if training_args.do_eval and not is_torch_tpu_available() else None,
preprocess_logits_for_metrics=preprocess_logits_for_metrics
if training_args.do_eval and not is_torch_tpu_available()
else None,
preprocess_logits_for_metrics=(
preprocess_logits_for_metrics if training_args.do_eval and not is_torch_tpu_available() else None
),
)

# Training
6 changes: 3 additions & 3 deletions examples/onnxruntime/training/language-modeling/run_mlm.py
@@ -617,9 +617,9 @@ def compute_metrics(eval_preds):
tokenizer=tokenizer,
data_collator=data_collator,
compute_metrics=compute_metrics if training_args.do_eval and not is_torch_tpu_available() else None,
preprocess_logits_for_metrics=preprocess_logits_for_metrics
if training_args.do_eval and not is_torch_tpu_available()
else None,
preprocess_logits_for_metrics=(
preprocess_logits_for_metrics if training_args.do_eval and not is_torch_tpu_available() else None
),
)

# Training
2 changes: 2 additions & 0 deletions tests/bettertransformer/test_audio.py
@@ -49,6 +49,7 @@ class BetterTransformersBarkTest(BetterTransformersTestMixin, unittest.TestCase)
Since `Bark` is a text-to-speech model, it is preferrable
to define its own testing class.
"""

SUPPORTED_ARCH = ["bark"]

FULL_GRID = {
@@ -185,6 +186,7 @@ class BetterTransformersAudioTest(BetterTransformersTestMixin, unittest.TestCase
r"""
Testing suite for Audio models - tests all the tests defined in `BetterTransformersTestMixin`
"""

SUPPORTED_ARCH = ["wav2vec2", "hubert"]

FULL_GRID = {
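The blank lines added in this and the following test files separate each class docstring from the first class-level attribute, the spacing recent formatter releases produce. A minimal hypothetical sketch of the resulting layout, with names chosen only for illustration:

import unittest


class ExampleAudioTest(unittest.TestCase):
    """
    Hypothetical testing class used only to show the spacing.
    """

    SUPPORTED_ARCH = ["example"]  # the blank line above this attribute is the style being applied

    def test_placeholder(self):
        self.assertIn("example", self.SUPPORTED_ARCH)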
1 change: 1 addition & 0 deletions tests/bettertransformer/test_encoder.py
@@ -37,6 +37,7 @@ class BetterTransformersEncoderTest(BetterTransformersTestMixin):
- if the converted model produces the same logits as the original model.
- if the converted model is faster than the original model.
"""

SUPPORTED_ARCH = [
"albert",
"bert",
1 change: 1 addition & 0 deletions tests/bettertransformer/test_encoder_decoder.py
@@ -35,6 +35,7 @@ class BetterTransformersEncoderDecoderTest(BetterTransformersTestMixin, unittest
- if the converted model produces the same logits as the original model.
- if the converted model is faster than the original model.
"""

SUPPORTED_ARCH = [
"bart",
"blenderbot",
1 change: 1 addition & 0 deletions tests/bettertransformer/test_vision.py
@@ -27,6 +27,7 @@ class BetterTransformersVisionTest(BetterTransformersTestMixin, unittest.TestCas
r"""
Testing suite for Vision Models - tests all the tests defined in `BetterTransformersTestMixin`
"""

SUPPORTED_ARCH = ["blip-2", "clip", "clip_text_model", "deit", "vilt", "vit", "vit_mae", "vit_msn", "yolos"]

def prepare_inputs_for_class(self, model_id, model_type, batch_size=3, **preprocessor_kwargs):
28 changes: 18 additions & 10 deletions tests/exporters/onnx/test_onnx_config_loss.py
@@ -76,10 +76,12 @@ def test_onnx_config_with_loss(self):
ort_sess = onnxruntime.InferenceSession(
onnx_model_path.as_posix(),
providers=[
"CUDAExecutionProvider"
if torch.cuda.is_available()
and "CUDAExecutionProvider" in onnxruntime.get_available_providers()
else "CPUExecutionProvider"
(
"CUDAExecutionProvider"
if torch.cuda.is_available()
and "CUDAExecutionProvider" in onnxruntime.get_available_providers()
else "CPUExecutionProvider"
)
],
)
framework = "pt" if isinstance(model, PreTrainedModel) else "tf"
@@ -145,9 +147,12 @@ def test_onnx_decoder_model_with_config_with_loss(self):
ort_sess = onnxruntime.InferenceSession(
onnx_model_path.as_posix(),
providers=[
"CUDAExecutionProvider"
if torch.cuda.is_available() and "CUDAExecutionProvider" in onnxruntime.get_available_providers()
else "CPUExecutionProvider"
(
"CUDAExecutionProvider"
if torch.cuda.is_available()
and "CUDAExecutionProvider" in onnxruntime.get_available_providers()
else "CPUExecutionProvider"
)
],
)

@@ -206,9 +211,12 @@ def test_onnx_seq2seq_model_with_config_with_loss(self):
ort_sess = onnxruntime.InferenceSession(
onnx_model_path.as_posix(),
providers=[
"CUDAExecutionProvider"
if torch.cuda.is_available() and "CUDAExecutionProvider" in onnxruntime.get_available_providers()
else "CPUExecutionProvider"
(
"CUDAExecutionProvider"
if torch.cuda.is_available()
and "CUDAExecutionProvider" in onnxruntime.get_available_providers()
else "CPUExecutionProvider"
)
],
)
batch = 3
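The three hunks above share the same provider-selection logic when building the ONNX Runtime inference session; only its indentation changes. A standalone sketch of that logic, assuming onnxruntime and torch are installed and using a placeholder model path:

import onnxruntime
import torch

# Prefer the CUDA provider only when torch sees a GPU and the installed onnxruntime
# build actually ships CUDAExecutionProvider; otherwise fall back to CPU.
provider = (
    "CUDAExecutionProvider"
    if torch.cuda.is_available() and "CUDAExecutionProvider" in onnxruntime.get_available_providers()
    else "CPUExecutionProvider"
)
ort_sess = onnxruntime.InferenceSession("model.onnx", providers=[provider])  # "model.onnx" is a hypothetical path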
