Skip to content

Commit

Permalink
update from keras2onnx to tf2onnx (huggingface#15162)
Browse files Browse the repository at this point in the history
  • Loading branch information
gante authored Jan 14, 2022
1 parent 1b730c3 commit ebc4edf
Show file tree
Hide file tree
Showing 6 changed files with 24 additions and 24 deletions.
8 changes: 4 additions & 4 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,6 @@
"jax>=0.2.8",
"jaxlib>=0.1.65",
"jieba",
"keras2onnx",
"nltk",
"numpy>=1.17",
"onnxconverter-common",
Expand Down Expand Up @@ -147,6 +146,7 @@
"starlette",
"tensorflow-cpu>=2.3",
"tensorflow>=2.3",
"tf2onnx",
"timeout-decorator",
"timm",
"tokenizers>=0.10.1",
Expand Down Expand Up @@ -229,8 +229,8 @@ def run(self):
extras["ja"] = deps_list("fugashi", "ipadic", "unidic_lite", "unidic")
extras["sklearn"] = deps_list("scikit-learn")

extras["tf"] = deps_list("tensorflow", "onnxconverter-common", "keras2onnx")
extras["tf-cpu"] = deps_list("tensorflow-cpu", "onnxconverter-common", "keras2onnx")
extras["tf"] = deps_list("tensorflow", "onnxconverter-common", "tf2onnx")
extras["tf-cpu"] = deps_list("tensorflow-cpu", "onnxconverter-common", "tf2onnx")

extras["torch"] = deps_list("torch")

Expand All @@ -243,7 +243,7 @@ def run(self):

extras["tokenizers"] = deps_list("tokenizers")
extras["onnxruntime"] = deps_list("onnxruntime", "onnxruntime-tools")
extras["onnx"] = deps_list("onnxconverter-common", "keras2onnx") + extras["onnxruntime"]
extras["onnx"] = deps_list("onnxconverter-common", "tf2onnx") + extras["onnxruntime"]
extras["modelcreation"] = deps_list("cookiecutter")

extras["sagemaker"] = deps_list("sagemaker")
Expand Down
8 changes: 4 additions & 4 deletions src/transformers/convert_graph_to_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ def convert_pytorch(nlp: Pipeline, opset: int, output: Path, use_external_format

def convert_tensorflow(nlp: Pipeline, opset: int, output: Path):
"""
Export a TensorFlow backed pipeline to ONNX Intermediate Representation (IR
Export a TensorFlow backed pipeline to ONNX Intermediate Representation (IR)
Args:
nlp: The pipeline to be exported
Expand All @@ -312,10 +312,10 @@ def convert_tensorflow(nlp: Pipeline, opset: int, output: Path):
try:
import tensorflow as tf

from keras2onnx import __version__ as k2ov
from keras2onnx import convert_keras, save_model
from tf2onnx import __version__ as t2ov
from tf2onnx import convert_keras, save_model

print(f"Using framework TensorFlow: {tf.version.VERSION}, keras2onnx: {k2ov}")
print(f"Using framework TensorFlow: {tf.version.VERSION}, tf2onnx: {t2ov}")

# Build
input_names, output_names, dynamic_axes, tokens = infer_shapes(nlp, "tf")
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/dependency_versions_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
"jax": "jax>=0.2.8",
"jaxlib": "jaxlib>=0.1.65",
"jieba": "jieba",
"keras2onnx": "keras2onnx",
"nltk": "nltk",
"numpy": "numpy>=1.17",
"onnxconverter-common": "onnxconverter-common",
Expand Down Expand Up @@ -57,6 +56,7 @@
"starlette": "starlette",
"tensorflow-cpu": "tensorflow-cpu>=2.3",
"tensorflow": "tensorflow>=2.3",
"tf2onnx": "tf2onnx",
"timeout-decorator": "timeout-decorator",
"timm": "timm",
"tokenizers": "tokenizers>=0.10.1",
Expand Down
12 changes: 6 additions & 6 deletions src/transformers/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,12 +175,12 @@
_sympy_available = False


_keras2onnx_available = importlib.util.find_spec("keras2onnx") is not None
_tf2onnx_available = importlib.util.find_spec("tf2onnx") is not None
try:
_keras2onnx_version = importlib_metadata.version("keras2onnx")
logger.debug(f"Successfully imported keras2onnx version {_keras2onnx_version}")
_tf2onnx_version = importlib_metadata.version("tf2onnx")
logger.debug(f"Successfully imported tf2onnx version {_tf2onnx_version}")
except importlib_metadata.PackageNotFoundError:
_keras2onnx_available = False
_tf2onnx_available = False

_onnx_available = importlib.util.find_spec("onnxruntime") is not None
try:
Expand Down Expand Up @@ -429,8 +429,8 @@ def is_coloredlogs_available():
return _coloredlogs_available


def is_keras2onnx_available():
return _keras2onnx_available
def is_tf2onnx_available():
return _tf2onnx_available


def is_onnx_available():
Expand Down
8 changes: 4 additions & 4 deletions src/transformers/testing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
is_faiss_available,
is_flax_available,
is_ftfy_available,
is_keras2onnx_available,
is_librosa_available,
is_onnx_available,
is_pandas_available,
Expand All @@ -49,6 +48,7 @@
is_soundfile_availble,
is_spacy_available,
is_tensorflow_probability_available,
is_tf2onnx_available,
is_tf_available,
is_timm_available,
is_tokenizers_available,
Expand Down Expand Up @@ -246,9 +246,9 @@ def require_rjieba(test_case):
return test_case


def require_keras2onnx(test_case):
if not is_keras2onnx_available():
return unittest.skip("test requires keras2onnx")(test_case)
def require_tf2onnx(test_case):
if not is_tf2onnx_available():
return unittest.skip("test requires tf2onnx")(test_case)
else:
return test_case

Expand Down
10 changes: 5 additions & 5 deletions tests/test_modeling_tf_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@
_tf_gpu_memory_limit,
is_pt_tf_cross_test,
is_staging_test,
require_keras2onnx,
require_tf,
require_tf2onnx,
slow,
)
from transformers.utils import logging
Expand Down Expand Up @@ -254,24 +254,24 @@ def test_onnx_compliancy(self):

self.assertEqual(len(incompatible_ops), 0, incompatible_ops)

@require_keras2onnx
@require_tf2onnx
@slow
def test_onnx_runtime_optimize(self):
if not self.test_onnx:
return

import keras2onnx
import onnxruntime
import tf2onnx

config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

for model_class in self.all_model_classes:
model = model_class(config)
model(model.dummy_inputs)

onnx_model = keras2onnx.convert_keras(model, model.name, target_opset=self.onnx_min_opset)
onnx_model_proto, _ = tf2onnx.convert.from_keras(model, opset=self.onnx_min_opset)

onnxruntime.InferenceSession(onnx_model.SerializeToString())
onnxruntime.InferenceSession(onnx_model_proto.SerializeToString())

def test_keras_save_load(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
Expand Down

0 comments on commit ebc4edf

Please sign in to comment.