diff --git a/tensorflow_hub/BUILD b/tensorflow_hub/BUILD index 68df7469..e834177a 100644 --- a/tensorflow_hub/BUILD +++ b/tensorflow_hub/BUILD @@ -45,14 +45,8 @@ py_library( visibility = ["//visibility:public"], deps = [ # Dependencies of the tensorflow_hub library. - ":image_util", - ":keras_layer", - ":module", ":module_v2", - ":native_module", - ":saved_model_module", - ":feature_column", - ":feature_column_v2", + ":keras_layer", ":config", # Internal dependency., ], @@ -63,64 +57,14 @@ py_library( srcs = ["config.py"], srcs_version = "PY3", deps = [ + # Deps of config. ":compressed_module_resolver", - ":native_module", ":registry", ":resolver", ":uncompressed_module_resolver", ], ) -py_library( - name = "feature_column", - srcs = ["feature_column.py"], - srcs_version = "PY3", - deps = [ - ":module", - "//tensorflow_hub:expect_tensorflow_installed", - # "tensorflow/python/feature_column", - ], -) - -py_test( - name = "feature_column_test", - srcs = ["feature_column_test.py"], - python_version = "PY3", - srcs_version = "PY3", - deps = [ - ":tensorflow_hub", - ":test_utils", - "//tensorflow_hub:expect_numpy_installed", - "//tensorflow_hub:expect_tensorflow_installed", - ":expect_tensorflow_hub_includes_feature_column_apis", - ], -) - -py_library( - name = "feature_column_v2", - srcs = ["feature_column_v2.py"], - srcs_version = "PY3", - deps = [ - ":module", - "//tensorflow_hub:expect_tensorflow_installed", - # "tensorflow/python/feature_column", - ], -) - -py_test( - name = "feature_column_v2_test", - srcs = ["feature_column_v2_test.py"], - python_version = "PY3", - srcs_version = "PY3", - deps = [ - ":tensorflow_hub", - ":test_utils", - "//tensorflow_hub:expect_numpy_installed", - "//tensorflow_hub:expect_tensorflow_installed", - ":expect_tensorflow_hub_includes_feature_column_v2_apis", - ], -) - py_library( name = "compressed_module_resolver", srcs = ["compressed_module_resolver.py"], @@ -146,108 +90,6 @@ py_test( ], ) -py_library( - name = "image_util", - srcs = ["image_util.py"], - srcs_version = "PY3", - deps = [ - ":all_protos_py_pb2", - ":native_module", - ], -) - -py_test( - name = "image_util_test", - srcs = ["image_util_test.py"], - python_version = "PY3", - srcs_version = "PY3", - deps = [ - ":image_util", - ":module", - ":native_module", - "//tensorflow_hub:expect_tensorflow_installed", - ], -) - -py_library( - name = "module", - srcs = [ - "module.py", - "module_impl.py", - "module_spec.py", - ], - srcs_version = "PY3", - deps = [ - ":registry", - ":tensor_info", - "//tensorflow_hub:expect_tensorflow_installed", - ], -) - -py_test( - name = "module_test", - srcs = ["module_test.py"], - python_version = "PY3", - srcs_version = "PY3", - deps = [ - ":module", - ":tensor_info", - "//tensorflow_hub:expect_tensorflow_installed", - ], -) - -py_library( - name = "saved_model_module", - srcs = ["saved_model_module.py"], - srcs_version = "PY3", - deps = [ - ":module", - ":native_module", - ":saved_model_lib", - ":tf_utils", - "//tensorflow_hub:expect_tensorflow_installed", - ], -) - -py_test( - name = "saved_model_module_test", - srcs = ["saved_model_module_test.py"], - python_version = "PY3", - srcs_version = "PY3", - deps = [ - ":tensorflow_hub", - "//tensorflow_hub:expect_tensorflow_installed", - ], -) - -py_library( - name = "native_module", - srcs = ["native_module.py"], - srcs_version = "PY3", - deps = [ - ":all_protos_py_pb2", - ":meta_graph_lib", - ":module", - ":saved_model_lib", - ":tf_utils", - "//tensorflow_hub:expect_tensorflow_installed", - ], -) - -py_test( - name = "native_module_test", - srcs = ["native_module_test.py"], - python_version = "PY3", - srcs_version = "PY3", - deps = [ - ":all_protos_py_pb2", - ":native_module", - ":tensorflow_hub", - "//tensorflow_hub:expect_numpy_installed", - "//tensorflow_hub:expect_tensorflow_installed", - ], -) - py_library( name = "resolver", srcs = ["resolver.py"], @@ -266,10 +108,13 @@ py_test( python_version = "PY3", srcs_version = "PY3", deps = [ + ":compressed_module_resolver", + ":registry", ":resolver", ":tensorflow_hub", ":test_utils", ":tf_utils", + ":uncompressed_module_resolver", "//tensorflow_hub:expect_tensorflow_installed", ], ) @@ -420,7 +265,7 @@ tf_hub_proto_library( srcs = [ "image_module_info.proto", "module_attachment.proto", - "module_def.proto", + # End proto files. ], visibility = ["//:__subpackages__"], ) @@ -431,18 +276,6 @@ py_library( name = "expect_tensorflow_installed", ) -# We expect TensorFlow Hub to support feature_column_v2 related APIs. -# This is for internal bookkeeping only and will always be true. -py_library( - name = "expect_tensorflow_hub_includes_feature_column_v2_apis", -) - -# We expect TensorFlow Hub to support feature_column related APIs. -# This is for internal bookkeeping only and will always be true. -py_library( - name = "expect_tensorflow_hub_includes_feature_column_apis", -) - # We expect numpy to already be installed on the system, e.g. via # `pip install numpy` py_library( @@ -461,8 +294,7 @@ py_library( srcs = ["module_v2.py"], srcs_version = "PY3", deps = [ - ":module", - ":native_module", + ":registry", "//tensorflow_hub:expect_tensorflow_installed", ], ) diff --git a/tensorflow_hub/__init__.py b/tensorflow_hub/__init__.py index 69c0cfa3..13b02806 100644 --- a/tensorflow_hub/__init__.py +++ b/tensorflow_hub/__init__.py @@ -114,25 +114,9 @@ def _ensure_keras_2_importable(): # pylint: disable=g-import-not-at-top # pylint: disable=g-bad-import-order # Symbols exposed via tensorflow_hub. -from tensorflow_hub.feature_column_v2 import text_embedding_column_v2 -from tensorflow_hub.feature_column import image_embedding_column -from tensorflow_hub.feature_column import sparse_text_embedding_column -from tensorflow_hub.feature_column import text_embedding_column -from tensorflow_hub.image_util import attach_image_module_info -from tensorflow_hub.image_util import get_expected_image_size -from tensorflow_hub.image_util import get_num_image_channels -from tensorflow_hub.image_util import ImageModuleInfo from tensorflow_hub.keras_layer import KerasLayer -from tensorflow_hub.module import eval_function_for_module -from tensorflow_hub.module import load_module_spec -from tensorflow_hub.module import Module -from tensorflow_hub.module_spec import ModuleSpec from tensorflow_hub.module_v2 import load from tensorflow_hub.module_v2 import resolve -from tensorflow_hub.native_module import add_signature -from tensorflow_hub.native_module import attach_message -from tensorflow_hub.native_module import create_module_spec -from tensorflow_hub.saved_model_module import create_module_spec_from_saved_model from tensorflow_hub.version import __version__ from tensorflow_hub.config import _run, _get_extra_deps # pylint: disable=g-multiple-import diff --git a/tensorflow_hub/config.py b/tensorflow_hub/config.py index 314e82f9..d72591f1 100644 --- a/tensorflow_hub/config.py +++ b/tensorflow_hub/config.py @@ -14,8 +14,8 @@ # ============================================================================== """Configuration to bind implementations on the API.""" +# Begin imports for config.py. from tensorflow_hub import compressed_module_resolver -from tensorflow_hub import native_module from tensorflow_hub import registry from tensorflow_hub import resolver from tensorflow_hub import uncompressed_module_resolver @@ -34,7 +34,7 @@ def _install_default_resolvers(): def _run(): _install_default_resolvers() - registry.loader.add_implementation(native_module.Loader()) + # End add resolvers. def _get_extra_deps(): diff --git a/tensorflow_hub/e2e_test.py b/tensorflow_hub/e2e_test.py index 48581e3c..4637ab29 100644 --- a/tensorflow_hub/e2e_test.py +++ b/tensorflow_hub/e2e_test.py @@ -16,18 +16,12 @@ import os import tarfile -import tempfile from absl import logging import tensorflow as tf import tensorflow_hub as hub from tensorflow_hub import test_utils -from tensorflow_hub import tf_utils - -# pylint: disable=g-direct-tensorflow-import -from tensorflow.python.ops.lookup_ops import index_to_string_table_from_file -# pylint: enable=g-direct-tensorflow-import class End2EndTest(tf.test.TestCase): @@ -64,86 +58,36 @@ def _generate_module(self): self._create_tgz(module_export_path) def test_http_locations(self): - with tf.Graph().as_default(): - self._generate_module() - - m = hub.Module("http://localhost:%d/test_module.tgz" % self.server_port) - out = m(11) - with tf.compat.v1.Session() as sess: - self.assertAllClose(sess.run(out), 121) - - # Test caching using custom filesystem (file://) to make sure that the - # TF Hub library can operate on such paths. - try: - root_dir = "file://%s" % self.get_temp_dir() - cache_dir = "%s_%s" % (root_dir, "cache") - tf.compat.v1.gfile.MakeDirs(cache_dir) - os.environ["TFHUB_CACHE_DIR"] = cache_dir - m = hub.Module("http://localhost:%d/test_module.tgz" % self.server_port) - out = m(11) - with tf.compat.v1.train.MonitoredSession() as sess: - self.assertAllClose(sess.run(out), 121) - - cache_content = sorted(tf.compat.v1.gfile.ListDirectory(cache_dir)) - logging.info("Cache context: %s", str(cache_content)) - self.assertEqual(2, len(cache_content)) - self.assertTrue(cache_content[1].endswith(".descriptor.txt")) - module_files = sorted(tf.compat.v1.gfile.ListDirectory( - os.path.join(cache_dir, cache_content[0]))) - self.assertListEqual( - ["assets", "saved_model.pb", "tfhub_module.pb", "variables"], - module_files) - finally: - os.unsetenv("TFHUB_CACHE_DIR") - - def test_module_export_vocab_on_custom_fs(self): - root_dir = "file://%s" % self.get_temp_dir() - export_dir = "%s_%s" % (root_dir, "export") - tf.compat.v1.gfile.MakeDirs(export_dir) - # Create a module with a vocab file located on a custom filesystem. - vocab_dir = os.path.join(root_dir, "vocab_location") - tf.compat.v1.gfile.MakeDirs(vocab_dir) - vocab_filename = os.path.join(vocab_dir, "tokens.txt") - tf_utils.atomic_write_string_to_file(vocab_filename, "one", False) - - def create_assets_module_fn(): - - def assets_module_fn(): - indices = tf.compat.v1.placeholder(dtype=tf.int64, name="indices") - table = index_to_string_table_from_file( - vocabulary_file=vocab_filename, default_value="UNKNOWN") - outputs = table.lookup(indices) - hub.add_signature(inputs=indices, outputs=outputs) - - return assets_module_fn - - with tf.Graph().as_default(): - assets_module_fn = create_assets_module_fn() - spec = hub.create_module_spec(assets_module_fn) - embedding_module = hub.Module(spec) - with tf.compat.v1.Session() as sess: - sess.run(tf.compat.v1.tables_initializer()) - embedding_module.export(export_dir, sess) - - module_files = tf.compat.v1.gfile.ListDirectory(export_dir) - self.assertListEqual( - ["assets", "saved_model.pb", "tfhub_module.pb", "variables"], - sorted(module_files)) - module_files = tf.compat.v1.gfile.ListDirectory(os.path.join(export_dir, - "assets")) - self.assertListEqual(["tokens.txt"], module_files) - - # def test_resolve(self): - # with tf.Graph().as_default(): - # self._generate_module() - - # module_dir = hub.resolve( - # "http://localhost:%d/test_module.tgz" % self.server_port) - # self.assertIn(tempfile.gettempdir(), module_dir) - # module_files = sorted(tf.compat.v1.gfile.ListDirectory(module_dir)) - # self.assertEqual( - # ["assets", "saved_model.pb", "tfhub_module.pb", "variables"], - # module_files) + self._generate_module() + + m = hub.load("http://localhost:%d/test_module.tgz" % self.server_port) + self.assertAllClose(m(11), 121) + + # Test caching using custom filesystem (file://) to make sure that the + # TF Hub library can operate on such paths. + try: + root_dir = "file://%s" % self.get_temp_dir() + cache_dir = "%s_%s" % (root_dir, "cache") + tf.compat.v1.gfile.MakeDirs(cache_dir) + os.environ["TFHUB_CACHE_DIR"] = cache_dir + m = hub.load("http://localhost:%d/test_module.tgz" % self.server_port) + self.assertAllClose(m(11), 121) + + cache_content = sorted(tf.compat.v1.gfile.ListDirectory(cache_dir)) + logging.info("Cache context: %s", str(cache_content)) + self.assertEqual(2, len(cache_content)) + self.assertTrue(cache_content[1].endswith(".descriptor.txt")) + module_files = sorted( + tf.compat.v1.gfile.ListDirectory( + os.path.join(cache_dir, cache_content[0]) + ) + ) + self.assertListEqual( + ["assets", "fingerprint.pb", "saved_model.pb", "variables"], + module_files, + ) + finally: + os.unsetenv("TFHUB_CACHE_DIR") def test_load(self): if not hasattr(tf.compat.v1.saved_model, "load_v2"): diff --git a/tensorflow_hub/estimator.py b/tensorflow_hub/estimator.py deleted file mode 100644 index 93b27469..00000000 --- a/tensorflow_hub/estimator.py +++ /dev/null @@ -1,222 +0,0 @@ -# Copyright 2018 The TensorFlow Hub Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Utilities to use Modules with Estimators.""" - -import os - -from absl import logging -import tensorflow as tf -from tensorflow.compat.v1 import estimator as tf_estimator -from tensorflow_hub import tf_utils - -# A collection of pairs (key: string, module: Module) used internally to -# propagate modules from where they are defined to the export hook. -# The collection key is a tuple (not a string) in order to make it invisible -# from user apis such as `get_all_collection_keys()` and manual exporting to -# meta_graphs. -_EXPORT_MODULES_COLLECTION = ("__tfhub_export_modules",) - - -def register_module_for_export(module, export_name): - """Register a Module to be exported under `export_name`. - - Warning: Deprecated. This belongs to the hub.Module API and TF1 Hub format. - - This function registers `module` to be exported by `LatestModuleExporter` - under a subdirectory named `export_name`. - - Note that `export_name` must be unique for each module exported from the - current graph. It only controls the export subdirectory name and it has - no scope effects such as the `name` parameter during Module instantiation. - - THIS FUNCTION IS DEPRECATED. - - Args: - module: Module instance to be exported. - export_name: subdirectory name to use when performing the export. - - Raises: - ValueError: if `export_name` is already taken in the current graph. - """ - for used_name, _ in tf.compat.v1.get_collection(_EXPORT_MODULES_COLLECTION): - if used_name == export_name: - raise ValueError( - "There is already a module registered to be exported as %r" % - export_name) - tf.compat.v1.add_to_collection(_EXPORT_MODULES_COLLECTION, - (export_name, module)) - - -class LatestModuleExporter(tf_estimator.Exporter): - """Regularly exports registered modules into timestamped directories. - - Warning: Deprecated. This belongs to the hub.Module API and TF1 Hub format. - - Modules can be registered to be exported by this class by calling - `register_module_for_export` when constructing the graph. The - `export_name` provided determines the subdirectory name used when - exporting. - - In addition to exporting, this class also garbage collects older exports. - - Example use with EvalSpec: - - ```python - train_spec = tf.estimator.TrainSpec(...) - eval_spec = tf.estimator.EvalSpec( - input_fn=eval_input_fn, - exporters=[ - hub.LatestModuleExporter("tf_hub", serving_input_fn), - ]) - tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec) - ``` - - See `LatestModuleExporter.export()` for a direct use example. - - THIS FUNCTION IS DEPRECATED. - """ - - def __init__(self, name, serving_input_fn, exports_to_keep=5): - """Creates an `Exporter` to use with `tf.estimator.EvalSpec`. - - Args: - name: unique name of this `Exporter`, which will be used in the export - path. - serving_input_fn: A function with no arguments that returns a - ServingInputReceiver. This is used with the `estimator` passed to - `export()` to build the graph (in PREDICT mode) that registers the - modules for export. The model in that graph is never run, so the actual - data provided by this input fn does not matter. - exports_to_keep: Number of exports to keep. Older exports will be garbage - collected. Defaults to 5. Set to None to disable garbage collection. - - Raises: - ValueError: if any argument is invalid. - """ - self._name = name - self._serving_input_fn = serving_input_fn - - self._exports_to_keep = exports_to_keep - if exports_to_keep is not None and exports_to_keep <= 0: - raise ValueError( - "`exports_to_keep`, if provided, must be a positive number") - - @property - def name(self): - return self._name - - def export(self, - estimator, - export_path, - checkpoint_path=None, - eval_result=None, - is_the_final_export=None): - """Actually performs the export of registered Modules. - - This method creates a timestamped directory under `export_path` - with one sub-directory (named `export_name`) per module registered - via `register_module_for_export`. - - Example use: - - ```python - estimator = ... (Create estimator with modules registered for export)... - exporter = hub.LatestModuleExporter("tf_hub", serving_input_fn) - exporter.export(estimator, export_path, estimator.latest_checkpoint()) - ``` - - Args: - estimator: the `Estimator` from which to export modules. - export_path: A string containing a directory where to write the export - timestamped directories. - checkpoint_path: The checkpoint path to export. If `None`, - `estimator.latest_checkpoint()` is used. - eval_result: Unused. - is_the_final_export: Unused. - - Returns: - The path to the created timestamped directory containing the exported - modules. - """ - if checkpoint_path is None: - checkpoint_path = estimator.latest_checkpoint() - - export_dir = tf_utils.get_timestamped_export_dir(export_path) - temp_export_dir = tf_utils.get_temp_export_dir(export_dir) - - session = _make_estimator_serving_session(estimator, self._serving_input_fn, - checkpoint_path) - with session: - export_modules = tf.compat.v1.get_collection(_EXPORT_MODULES_COLLECTION) - if export_modules: - for export_name, module in export_modules: - module_export_path = os.path.join(temp_export_dir, - tf.compat.as_bytes(export_name)) - module.export(module_export_path, session) - tf.compat.v1.gfile.Rename(temp_export_dir, export_dir) - tf_utils.garbage_collect_exports(export_path, self._exports_to_keep) - return export_dir - else: - logging.warn("LatestModuleExporter found zero modules to export. " - "Use hub.register_module_for_export() if needed.") - # No export_dir has been created. - return None - - -def _make_estimator_serving_session(estimator, serving_input_fn, - checkpoint_path): - """Returns a session constructed using `estimator` and `serving_input_fn`. - - The Estimator API does not provide an API to construct a graph and session, - making it necessary for this function to replicate how an estimator builds - a graph. - - This code is based on `Estimator.export_savedmodel` (another function that - has to replicate how an estimator builds a graph). - - Args: - estimator: tf.Estimator to use when constructing the session. - serving_input_fn: A function that takes no arguments and returns a - `ServingInputReceiver`. It is used to construct the session. - checkpoint_path: The checkpoint path to restore in the session. Must not be - None. - """ - with tf.Graph().as_default() as g: - mode = tf_estimator.ModeKeys.PREDICT - tf.compat.v1.train.create_global_step(g) - tf.compat.v1.set_random_seed(estimator.config.tf_random_seed) - serving_input_receiver = serving_input_fn() - - estimator_spec = estimator.model_fn( - features=serving_input_receiver.features, - labels=None, - mode=mode, - config=estimator.config) - - # pylint: disable=protected-access - # Note that MonitoredSession(), despite the name is not a Session, and - # can't be used to export Modules as one can't use them with Savers. - # As so this must use a raw tf.Session(). - session = tf.compat.v1.Session(config=estimator._session_config) - # pylint: enable=protected-access - - with session.as_default(): - # TODO(b/71839662): Consider if this needs to support TPUEstimatorSpec - # which does not have a scaffold member. - # pylint: disable=line-too-long - saver_for_restore = estimator_spec.scaffold.saver or tf.compat.v1.train.Saver( - sharded=True) - saver_for_restore.restore(session, checkpoint_path) - return session diff --git a/tensorflow_hub/estimator_test.py b/tensorflow_hub/estimator_test.py deleted file mode 100644 index 8e429a00..00000000 --- a/tensorflow_hub/estimator_test.py +++ /dev/null @@ -1,164 +0,0 @@ -# Copyright 2018 The TensorFlow Hub Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for tensorflow_hub.estimator.""" - -import os -import tempfile - -import tensorflow as tf -from tensorflow.compat.v1 import estimator as tf_estimator -import tensorflow_hub as hub - -_TEXT_FEATURE_NAME = "text" -_EXPORT_MODULE_NAME = "embedding-text" - - -def _input_fn(): - """An input fn.""" - features = { - _TEXT_FEATURE_NAME: tf.constant([ - "Example 1 feature", "Example 2"]), - } - labels = tf.constant([False, True]) - return features, labels - - -def _serving_input_fn(): - """A serving input fn.""" - text_features = tf.compat.v1.placeholder(dtype=tf.string, shape=[None]) - return tf_estimator.export.ServingInputReceiver( - features={_TEXT_FEATURE_NAME: text_features}, - receiver_tensors=text_features) - - -def text_module_fn(): - weights = tf.compat.v1.get_variable( - "weights", dtype=tf.float32, shape=[100, 10]) - # initializer=tf.random_uniform_initializer()) - text = tf.compat.v1.placeholder(tf.string, shape=[None]) - hash_buckets = tf.compat.v1.string_to_hash_bucket_fast( - text, weights.get_shape()[0]) - embeddings = tf.compat.v1.gather(weights, hash_buckets) - hub.add_signature(inputs=text, outputs=embeddings) - - -def _get_model_fn(register_module=False): - def _model_fn(features, labels, mode): - """A model_fn that uses a mock TF-Hub module.""" - del labels - - spec = hub.create_module_spec(text_module_fn) - embedding = hub.Module(spec) - if register_module: - hub.register_module_for_export(embedding, _EXPORT_MODULE_NAME) - predictions = embedding(features[_TEXT_FEATURE_NAME]) - loss = tf.constant(0.0) - - global_step = tf.compat.v1.train.get_global_step() - train_op = tf.compat.v1.assign_add(global_step, 1) - - return tf_estimator.EstimatorSpec( - mode=mode, - predictions=predictions, - loss=loss, - train_op=train_op) - - return _model_fn - - -class EstimatorTest(tf.test.TestCase): - - def testLatestModuleExporterDirectly(self): - model_dir = tempfile.mkdtemp(dir=self.get_temp_dir()) - export_base_dir = os.path.join( - tempfile.mkdtemp(dir=self.get_temp_dir()), "export") - - estimator = tf_estimator.Estimator( - _get_model_fn(register_module=True), model_dir=model_dir) - estimator.train(input_fn=_input_fn, steps=1) - - exporter = hub.LatestModuleExporter("exporter_name", _serving_input_fn) - export_dir = exporter.export(estimator=estimator, - export_path=export_base_dir, - eval_result=None, - is_the_final_export=None) - - # Check that a timestamped directory is created in the expected location. - timestamp_dirs = tf.compat.v1.gfile.ListDirectory(export_base_dir) - self.assertEquals(1, len(timestamp_dirs)) - self.assertEquals( - tf.compat.as_bytes(os.path.join(export_base_dir, timestamp_dirs[0])), - tf.compat.as_bytes(export_dir)) - - # Check the timestamped directory containts the exported modules inside. - expected_module_dir = os.path.join( - tf.compat.as_bytes(export_dir), - tf.compat.as_bytes(_EXPORT_MODULE_NAME)) - self.assertTrue(tf.compat.v1.gfile.IsDirectory(expected_module_dir)) - - def test_latest_module_exporter_with_no_modules(self): - model_dir = tempfile.mkdtemp(dir=self.get_temp_dir()) - export_base_dir = os.path.join(tempfile.mkdtemp(dir=self.get_temp_dir()), - "export") - self.assertFalse(tf.compat.v1.gfile.Exists(export_base_dir)) - - estimator = tf_estimator.Estimator( - _get_model_fn(register_module=False), model_dir=model_dir) - estimator.train(input_fn=_input_fn, steps=1) - - exporter = hub.LatestModuleExporter("exporter_name", _serving_input_fn) - export_dir = exporter.export(estimator=estimator, - export_path=export_base_dir, - eval_result=None, - is_the_final_export=None) - - # Check the result. - self.assertIsNone(export_dir) - - # Check that a no directory has been created in the expected location. - self.assertFalse(tf.compat.v1.gfile.Exists(export_base_dir)) - - def test_latest_module_exporter_with_eval_spec(self): - model_dir = tempfile.mkdtemp(dir=self.get_temp_dir()) - estimator = tf_estimator.Estimator( - _get_model_fn(register_module=True), model_dir=model_dir) - exporter = hub.LatestModuleExporter( - "tf_hub", _serving_input_fn, exports_to_keep=2) - estimator.train(_input_fn, max_steps=1) - export_base_dir = os.path.join(model_dir, "export", "tf_hub") - - exporter.export(estimator, export_base_dir) - timestamp_dirs = tf.compat.v1.gfile.ListDirectory(export_base_dir) - self.assertEquals(1, len(timestamp_dirs)) - oldest_timestamp = timestamp_dirs[0] - - expected_module_dir = os.path.join(export_base_dir, - timestamp_dirs[0], - _EXPORT_MODULE_NAME) - self.assertTrue(tf.compat.v1.gfile.IsDirectory(expected_module_dir)) - - exporter.export(estimator, export_base_dir) - timestamp_dirs = tf.compat.v1.gfile.ListDirectory(export_base_dir) - self.assertEquals(2, len(timestamp_dirs)) - - # Triggering yet another export should clean the oldest export. - exporter.export(estimator, export_base_dir) - timestamp_dirs = tf.compat.v1.gfile.ListDirectory(export_base_dir) - self.assertEquals(2, len(timestamp_dirs)) - self.assertFalse(oldest_timestamp in timestamp_dirs) - - -if __name__ == "__main__": - tf.test.main() diff --git a/tensorflow_hub/feature_column.py b/tensorflow_hub/feature_column.py deleted file mode 100644 index b58e1249..00000000 --- a/tensorflow_hub/feature_column.py +++ /dev/null @@ -1,573 +0,0 @@ -# Copyright 2018 The TensorFlow Hub Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Utilities to use Modules as feature columns.""" - -import collections - -import tensorflow as tf -from tensorflow_hub import image_util -from tensorflow_hub import module - -# TODO(b/73987364): It is not possible to extend feature columns without -# depending on TensorFlow internal implementation details. -# pylint: disable=g-direct-tensorflow-import -from tensorflow.python.feature_column import feature_column -from tensorflow.python.feature_column import feature_column_v2 -# pylint: enable=g-direct-tensorflow-import - - -class DenseFeatureColumn( - feature_column._DenseColumn, # pylint: disable=protected-access - feature_column_v2.DenseColumn): - - @property - def dtype(self): - return tf.float32 - - -_MODULE_RESOURCE_STRING = "module" - - -def text_embedding_column(key, module_spec, trainable=False): - """Uses a Module to construct a dense representation from a text feature. - - TODO(b/131678043): This does not work yet with TF2. - - This feature column can be used on an input feature whose values are strings - of arbitrary size. - - The result of this feature column is the result of passing its `input` - through the module `m` instantiated from `module_spec`, as per - `result = m(input)`. The `result` must have dtype float32 and shape - `[batch_size, num_features]` with a known value of num_features. - - Example: - - ```python - comment = hub.text_embedding_column("comment", "/tmp/text-module") - feature_columns = [comment, ...] - ... - features = { - "comment": np.array(["wow, much amazing", "so easy", ...]), - ... - } - labels = np.array([[1], [0], ...]) - # If running TF 2.x, use `tf.compat.v1.estimator.inputs.numpy_input_fn` - input_fn = tf.estimator.inputs.numpy_input_fn(features, labels, - shuffle=True) - estimator = tf.estimator.DNNClassifier(hidden_units, feature_columns) - estimator.train(input_fn, max_steps=100) - ``` - - Args: - key: A string or `_FeatureColumn` identifying the text feature. - module_spec: A ModuleSpec defining the Module to instantiate or a path where - to load a ModuleSpec via `load_module_spec` - trainable: Whether or not the Module is trainable. False by default, meaning - the pre-trained weights are frozen. This is different from the ordinary - tf.feature_column.embedding_column(), but that one is intended for - training from scratch. - - Returns: - `_DenseColumn` that converts from text input. - - Raises: - ValueError: if module_spec is not suitable for use in this feature column. - """ - return _TextEmbeddingColumn( - key=key, module_spec_path=module_spec, trainable=trainable) - - -def _check_module_is_text_embedding(module_spec): - """Raises ValueError if `module_spec` is not a text-embedding module. - - Args: - module_spec: A `ModuleSpec` to test. - - Raises: - ValueError: if `module_spec` default signature is not compatible with - Tensor(string, shape=(?,)) -> Tensor(float32, shape=(?,K)). - """ - issues = [] - - # Find issues with signature inputs. - input_info_dict = module_spec.get_input_info_dict() - if len(input_info_dict) != 1: - issues.append("Module default signature must require only one input") - else: - input_info, = input_info_dict.values() - input_shape = input_info.get_shape() - if not (input_info.dtype == tf.string and input_shape.ndims == 1 and - input_shape.as_list() == [None]): - issues.append("Module default signature must have only one input " - "tf.Tensor(shape=(?,), dtype=string)") - - # Find issues with signature outputs. - output_info_dict = module_spec.get_output_info_dict() - if "default" not in output_info_dict: - issues.append("Module default signature must have a 'default' output.") - else: - output_info = output_info_dict["default"] - output_shape = output_info.get_shape() - if not (output_info.dtype == tf.float32 and output_shape.ndims == 2 and - not output_shape.as_list()[0] and output_shape.as_list()[1]): - issues.append("Module default signature must have a 'default' output of " - "tf.Tensor(shape=(?,K), dtype=float32).") - - if issues: - raise ValueError("Module is not a text-embedding: %r" % issues) - - -class _TextEmbeddingColumn( - DenseFeatureColumn, - collections.namedtuple("_ModuleEmbeddingColumn", - ("key", "module_spec_path", "trainable"))): - """Returned by text_embedding_column(). Do not use directly.""" - - def __init__(self, key, module_spec_path, trainable): - self.module_spec = module.as_module_spec(self.module_spec_path) - _check_module_is_text_embedding(self.module_spec) - super().__init__() - - @property - def _is_v2_column(self): - return True - - @property - def parents(self): - """See 'FeatureColumn` base class.""" - return [self.key] - - @property - def name(self): - """Returns string. Used for variable_scope and naming.""" - if not hasattr(self, "_name"): - key_name = self.key if isinstance(self.key, str) else self.key.name - self._name = "{}_hub_module_embedding".format(key_name) - return self._name - - def create_state(self, state_manager): - """Imports the module along with all variables.""" - # Note: state_manager._trainable is not public but is the pattern used - # to propagate the "trainable" state that used to be received via - # self._get_dense_tensor. - trainable = self.trainable and state_manager._trainable # pylint: disable=protected-access - m = module.Module(self.module_spec, trainable=trainable) - state_manager.add_resource(self, _MODULE_RESOURCE_STRING, m) - - def _transform_feature(self, inputs): - """Returns intermediate representation (usually a `Tensor`).""" - return inputs.get(self.key) - - def transform_feature(self, transformation_cache, state_manager): - return transformation_cache.get(self.key, state_manager) - - @property - def _parse_example_spec(self): - """Returns a `tf.Example` parsing spec as dict.""" - return self.parse_example_spec - - @property - def parse_example_spec(self): - """Returns a `tf.Example` parsing spec as dict.""" - return {self.key: tf.compat.v1.FixedLenFeature([1], tf.string)} - - @property - def _variable_shape(self): - """`TensorShape` of `_get_dense_tensor`, without batch dimension.""" - return self.variable_shape - - @property - def variable_shape(self): - """`TensorShape` of `_get_dense_tensor`, without batch dimension.""" - return self.module_spec.get_output_info_dict()["default"].get_shape()[1:] - - def _get_dense_tensor_for_input_tensor(self, input_tensor, text_module): - text_batch = tf.reshape(input_tensor, shape=[-1]) - return text_module(text_batch) - - def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None): - """Returns a `Tensor`.""" - del weight_collections - input_tensor = inputs.get(self) - text_module = module.Module( - self.module_spec, trainable=self.trainable and trainable) - return self._get_dense_tensor_for_input_tensor(input_tensor, text_module) - - def get_dense_tensor(self, transformation_cache, state_manager): - """Returns a `Tensor`.""" - input_tensor = transformation_cache.get(self, state_manager) - text_module = state_manager.get_resource(self, _MODULE_RESOURCE_STRING) - return self._get_dense_tensor_for_input_tensor(input_tensor, text_module) - - def get_config(self): - if not isinstance(self.module_spec_path, str): - raise NotImplementedError( - "Can only generate a valid config for `hub.text_embedding_column`" - "that uses a string `module_spec`.\n\n" - "Got `type(module_spec)`: {}".format(type(self.module_spec_path))) - config = dict(zip(self._fields, self)) - return config - - @classmethod - def from_config(cls, config, custom_objects=None, columns_by_name=None): - copied_config = config.copy() - return cls(**copied_config) - - -def image_embedding_column(key, module_spec, image_size=None): - """Uses a Module to get a dense 1-D representation from the pixels of images. - - TODO(b/131678043): This does not work yet with TF2. - - This feature column can be used on images, represented as float32 tensors of - RGB pixel data in the range [0,1]. This can be read from a numeric_column() - if the tf.Example input data happens to have decoded images, all with the - same shape [height, width, 3]. More commonly, the input_fn will have code to - explicitly decode images, resize them (possibly after performing data - augmentation such as random crops etc.), and provide a batch of shape - [batch_size, height, width, 3]. - - The result of this feature column is the result of passing its `input` - through the module `m` instantiated from `module_spec`, as per - `result = m({"images": input})`. The `result` must have dtype float32 and - shape `[batch_size, num_features]` with a known value of num_features. - - Example: - - ```python - image_column = hub.image_embedding_column("embeddings", "/tmp/image-module") - feature_columns = [image_column, ...] - estimator = tf.estimator.LinearClassifier(feature_columns, ...) - height, width = hub.get_expected_image_size(image_column.module_spec) - input_fn = ... # Provides "embeddings" with shape [None, height, width, 3]. - estimator.train(input_fn, ...) - ``` - - Args: - key: A string or `_FeatureColumn` identifying the input image data. - module_spec: A string handle or a `ModuleSpec` identifying the module. - image_size: Optional. If specified it should be a tuple of image height and - width to use with the module. Note that it depends on the module on - whether the default size can be overridden and what the permissible - values are. - - Returns: - `_DenseColumn` that converts from pixel data. - - Raises: - ValueError: if module_spec is not suitable for use in this feature column. - """ - # Configuration stored in a feature column should be hashable or user can - # get a TypeError when using it with DenseFeatures. If a user passes a list - # cast it to a tuple to avoid wasted debugging time. - if isinstance(image_size, list): - image_size = tuple(image_size) - return _ImageEmbeddingColumn(key=key, module_spec_path=module_spec, - image_size=image_size) - - -def _check_module_is_image_embedding(module_spec, check_image_size): - """Raises ValueError if `module_spec` is not usable as image embedding. - - Args: - module_spec: A `_ModuleSpec` to test. - check_image_size: Whether to check for compatibility with - get_expected_image_size. - - Raises: - ValueError: if `module_spec` default signature is not compatible with - mappingan "images" input to a Tensor(float32, shape=(_,K)). - """ - issues = [] - - # Find issues with "default" signature inputs. The common signatures for - # image models prescribe a specific name; we trust it if we find it - # and if we can do the necessary inference of input shapes from it. - input_info_dict = module_spec.get_input_info_dict() - if (list(input_info_dict.keys()) != ["images"] or - input_info_dict["images"].dtype != tf.float32): - issues.append("Module 'default' signature must require a single input, " - "which must have type float32 and name 'images'.") - else: - try: - if check_image_size: - image_util.get_expected_image_size(module_spec) - except ValueError as e: - issues.append("Module does not support hub.get_expected_image_size(); " - "original error was:\n" + str(e)) # Raised again below. - - # Find issues with "default" signature outputs. We test that the dtype and - # shape is appropriate for use in input_layer(). - output_info_dict = module_spec.get_output_info_dict() - if "default" not in output_info_dict: - issues.append("Module 'default' signature must have a 'default' output.") - else: - output_type = output_info_dict["default"].dtype - output_shape = output_info_dict["default"].get_shape() - if not (output_type == tf.float32 and output_shape.ndims == 2 and - output_shape.dims[1].value): - issues.append("Module 'default' signature must have a 'default' output " - "of tf.Tensor(shape=(_,K), dtype=float32).") - - if issues: - raise ValueError("Module is not usable as image embedding: %r" % issues) - - -class _ImageEmbeddingColumn(DenseFeatureColumn, - collections.namedtuple("_ImageEmbeddingColumn", - ("key", "module_spec_path", - "image_size")) - ): - """Returned by image_embedding_column(). Do not use directly.""" - - def __init__(self, key, module_spec_path, image_size): - self.module_spec = module.as_module_spec(self.module_spec_path) - _check_module_is_image_embedding(self.module_spec, - check_image_size=self.image_size is None) - super().__init__() - - @property - def _is_v2_column(self): - return True - - @property - def parents(self): - """See 'FeatureColumn` base class.""" - return [self.key] - - @property - def name(self): - """Returns string. Used for variable_scope and naming.""" - if not hasattr(self, "_name"): - key_name = self.key if isinstance(self.key, str) else self.key.name - self._name = "{}_hub_module_embedding".format(key_name) - return self._name - - def create_state(self, state_manager): - """Imports the module along with all variables.""" - # Module is not trainable by default. - m = module.Module(self.module_spec) - state_manager.add_resource(self, _MODULE_RESOURCE_STRING, m) - - def _transform_feature(self, inputs): - """Returns intermediate representation (usually a `Tensor`).""" - return inputs.get(self.key) - - def transform_feature(self, transformation_cache, state_manager): - return transformation_cache.get(self.key, state_manager) - - @property - def _parse_example_spec(self): - """Returns a `tf.Example` parsing spec as dict.""" - return self.parse_example_spec - - @property - def parse_example_spec(self): - """Returns a `tf.Example` parsing spec as dict.""" - if self.image_size: - height, width = self.image_size - else: - height, width = image_util.get_expected_image_size(self.module_spec) - input_shape = [height, width, 3] - return {self.key: tf.compat.v1.FixedLenFeature(input_shape, tf.float32)} - - @property - def _variable_shape(self): - """`TensorShape` of `_get_dense_tensor`, without batch dimension.""" - return self.variable_shape - - @property - def variable_shape(self): - """`TensorShape` of `_get_dense_tensor`, without batch dimension.""" - return self.module_spec.get_output_info_dict()["default"].get_shape()[1:] - - def _get_dense_tensor_for_images(self, images, image_module): - return image_module({"images": images}) - - def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None): - del weight_collections, trainable # Unused. - images = inputs.get(self) - image_module = module.Module(self.module_spec) - return self._get_dense_tensor_for_images(images, image_module) - - def get_dense_tensor(self, transformation_cache, state_manager): - images = transformation_cache.get(self, state_manager) - image_module = state_manager.get_resource(self, _MODULE_RESOURCE_STRING) - return self._get_dense_tensor_for_images(images, image_module) - - def get_config(self): - if not isinstance(self.module_spec_path, str): - raise NotImplementedError( - "Can only generate a valid config for `hub.image_embedding_column`" - "that uses a string `module_spec`.\n\n" - "Got `type(module_spec)`: {}".format(type(self.module_spec_path))) - config = dict(zip(self._fields, self)) - return config - - @classmethod - def from_config(cls, config, custom_objects=None, columns_by_name=None): - copied_config = config.copy() - return cls(**copied_config) - - -def sparse_text_embedding_column(key, - module_spec, - combiner, - default_value, - trainable=False): - """Uses a Module to construct dense representations from sparse text features. - - TODO(b/131678043): This does not work yet with TF2. - - The input to this feature column is a batch of multiple strings with - arbitrary size, assuming the input is a SparseTensor. - - This type of feature column is typically suited for modules that operate on - pre-tokenized text to produce token level embeddings which are combined with - the combiner into a text embedding. The combiner always treats the tokens as a - bag of words rather than a sequence. - - The output (i.e., transformed input layer) is a DenseTensor, with shape - [batch_size, num_embedding_dim]. - - For Example: - - ```python - comment = hub.sparse_text_embedding_column("comment", "/tmp/text_module") - feature_columns = [comment, ...] - ... - features = { - "comment": tf.SparseTensor(indices=[[0, 0], [1, 2]], - values=['sparse', 'embedding'], - dense_shape=[3, 4]), - ... - } - estimator = tf.estimator.DNNClassifier(hidden_units, feature_columns) - ``` - - Args: - key: A string or `_FeatureColumn` identifying the text feature. - module_spec: A string handle or a `_ModuleSpec` identifying the module. - combiner: a string specifying reducing op for embeddings in the same - Example. Currently, 'mean', 'sqrtn', 'sum' are supported. Using - combiner=None is undefined. - default_value: default value for Examples where the text feature is empty. - Note, it's recommended to have default_value consistent OOV tokens, in - case there was special handling of OOV in the text module. If None, the - text feature is assumed be non-empty for each Example. - trainable: Whether or not the Module is trainable. False by default, meaning - the pre-trained weights are frozen. This is different from the ordinary - tf.feature_column.embedding_column(), but that one is intended for - training from scratch. - - Returns: - `_DenseColumn` that converts from text input. - - Raises: - ValueError: if module_spec is not suitable for use in this feature column. - ValueError: if combiner not in ('mean', 'sqrtn', 'sum'). - """ - module_spec = module.as_module_spec(module_spec) - _check_module_is_text_embedding(module_spec) - if combiner not in ("mean", "sqrtn", "sum"): - raise ValueError("combiner must be 'mean', 'sqrtn' or 'sum': %r" % combiner) - return _SparseTextEmbeddingColumn( - key=key, - module_spec=module_spec, - trainable=trainable, - default_value=default_value, - combiner=combiner) - - -class _SparseTextEmbeddingColumn( - DenseFeatureColumn, # pylint: disable=protected-access - collections.namedtuple( - "_ModuleEmbeddingColumn", - ("key", "combiner", "module_spec", "default_value", "trainable"))): - """Returned by sparse_text_embedding_column(). Do not use directly.""" - - @property - def _is_v2_column(self): - return True - - @property - def parents(self): - """See 'FeatureColumn` base class.""" - return [self.key] - - @property - def name(self): - """Returns string. Used for variable_scope and naming.""" - if not hasattr(self, "_name"): - key_name = self.key if isinstance(self.key, str) else self.key.name - self._name = "{}_hub_module_embedding".format(key_name) - return self._name - - def _transform_feature(self, inputs): - """Returns intermediate representation (usually a `Tensor`).""" - return inputs.get(self.key) - - def transform_feature(self, transformation_cache, state_manager): - return transformation_cache.get(self.key, state_manager) - - @property - def _parse_example_spec(self): - """Returns a `tf.Example` parsing spec as dict.""" - return self.parse_example_spec - - @property - def parse_example_spec(self): - """Returns a `tf.Example` parsing spec as dict.""" - return {self.key: tf.compat.v1.VarLenFeature(tf.string)} - - @property - def _variable_shape(self): - """`TensorShape` of `_get_dense_tensor`, without batch dimension.""" - return self.variable_shape - - @property - def variable_shape(self): - """`TensorShape` of `_get_dense_tensor`, without batch dimension.""" - return self.module_spec.get_output_info_dict()["default"].get_shape()[1:] - - def _get_dense_tensor_for_inputs(self, text_batch, trainable): - m = module.Module(self.module_spec, trainable=self.trainable and trainable) - - if self.default_value is not None: - text_batch = tf.sparse.fill_empty_rows(text_batch, self.default_value)[0] - embedded_tokens = m(text_batch.values) - embedding_ids = tf.SparseTensor( - indices=text_batch.indices, - values=tf.range(tf.shape(text_batch.indices)[0], dtype=tf.int32), - dense_shape=text_batch.dense_shape) - - return tf.nn.embedding_lookup_sparse( - params=embedded_tokens, - sp_ids=embedding_ids, - sp_weights=None, - combiner=self.combiner) - - def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None): - """Returns a `Tensor`.""" - del weight_collections - text_batch = inputs.get(self) - return self._get_dense_tensor_for_inputs(text_batch, self.trainable and - trainable) - - def get_dense_tensor(self, transformation_cache, state_manager): - """Returns a `Tensor`.""" - input_tensor = transformation_cache.get(self, state_manager) - return self._get_dense_tensor_for_inputs(input_tensor, self.trainable) diff --git a/tensorflow_hub/feature_column_test.py b/tensorflow_hub/feature_column_test.py deleted file mode 100644 index 19a58120..00000000 --- a/tensorflow_hub/feature_column_test.py +++ /dev/null @@ -1,516 +0,0 @@ -# Copyright 2018 The TensorFlow Hub Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for tensorflow_hub.feature_column.""" - -import os -import unittest - -import tensorflow as tf -import tensorflow_hub as hub - -# pylint: disable=g-import-not-at-top -# Use Keras 2. -version_fn = getattr(tf.keras, "version", None) -if version_fn and version_fn().startswith("3."): - import tf_keras as keras -else: - keras = tf.keras - -# pylint: disable=g-direct-tensorflow-import -from tensorflow.python.feature_column import feature_column_v2 -from tensorflow.python.ops.lookup_ops import HashTable -from tensorflow.python.ops.lookup_ops import KeyValueTensorInitializer -# pylint: enable=g-direct-tensorflow-import - - -def text_module_fn(): - embeddings = [ - ("", [0, 0, 0, 0]), # OOV items are mapped to this embedding. - ("hello world", [1, 2, 3, 4]), - ("pair-programming", [5, 5, 5, 5]), - ] - keys = tf.constant([item[0] for item in embeddings], dtype=tf.string) - indices = tf.constant(list(range(len(embeddings))), dtype=tf.int64) - tbl_init = KeyValueTensorInitializer(keys, indices) - table = HashTable(tbl_init, 0) - - weights_initializer = tf.cast( - tf.constant(list([item[1] for item in embeddings])), tf.float32) - - weights = tf.compat.v1.get_variable( - "weights", dtype=tf.float32, initializer=weights_initializer) - - text_tensor = tf.compat.v1.placeholder(dtype=tf.string, name="text", - shape=[None]) - indices_tensor = table.lookup(text_tensor) - embedding_tensor = tf.gather(weights, indices_tensor) - hub.add_signature(inputs=text_tensor, outputs=embedding_tensor) - - -def invalid_text_module_fn(): - text = tf.compat.v1.placeholder(tf.string, shape=[10]) - hub.add_signature(inputs=text, outputs=tf.zeros([10, 3])) - - -def export_module_spec(spec, export_path): - """Export module with random initialization.""" - with tf.compat.v1.Graph().as_default(): - m = hub.Module(spec) - with tf.compat.v1.Session() as session: - session.run(tf.compat.v1.initializers.global_variables()) - m.export(export_path, session) - - -class CommonColumnTest(tf.test.TestCase): - - def setUp(self): - self.spec = hub.create_module_spec(text_module_fn) - - @unittest.mock.patch.object(feature_column_v2._StateManagerImpl, - "add_resource") - def testFeatureColumnsWithResources(self, mock_add_resource): - feature_column = hub.text_embedding_column("text_a", self.spec) - self.assertTrue(feature_column_v2.is_feature_column_v2([feature_column])) - - -class TextEmbeddingColumnTest(tf.test.TestCase): - - def setUp(self): - self.spec = hub.create_module_spec(text_module_fn) - - def testVariableShape(self): - text_column = hub.text_embedding_column("text", self.spec, trainable=False) - self.assertEqual(text_column._variable_shape, [4]) - - def testParents(self): - text_column = hub.text_embedding_column("text", self.spec, trainable=False) - self.assertEqual(["text"], text_column.parents) - - def testMakeParseExampleSpec(self): - text_column = hub.text_embedding_column("text", self.spec, trainable=False) - parsing_spec = tf.compat.v1.feature_column.make_parse_example_spec( - [text_column]) - self.assertEqual( - parsing_spec, - {"text": tf.compat.v1.FixedLenFeature([1], dtype=tf.string)}) - - def testInputLayer(self): - features = { - "text_a": ["hello world", "pair-programming"], - "text_b": ["hello world", "oov token"], - } - feature_columns = [ - hub.text_embedding_column("text_a", self.spec, trainable=False), - hub.text_embedding_column("text_b", self.spec, trainable=False), - ] - with tf.Graph().as_default(): - input_layer = tf.compat.v1.feature_column.input_layer(features, - feature_columns) - with tf.compat.v1.train.MonitoredSession() as sess: - output = sess.run(input_layer) - self.assertAllEqual( - output, [[1, 2, 3, 4, 1, 2, 3, 4], [5, 5, 5, 5, 0, 0, 0, 0]]) - - def testDenseFeatures(self): - features = { - "text_a": ["hello world", "pair-programming"], - "text_b": ["hello world", "oov token"], - } - feature_columns = [ - hub.text_embedding_column("text_a", self.spec, trainable=False), - hub.text_embedding_column("text_b", self.spec, trainable=False), - ] - if not feature_column_v2.is_feature_column_v2(feature_columns): - self.skipTest("Resources not implemented in the state manager of feature " - "column v2.") - with tf.Graph().as_default(): - # We want to test with dense_features_v2.DenseFeatures. This symbol was - # added in https://github.com/tensorflow/tensorflow/commit/64586f18724f737393071125a91b19adf013cf8a. - feature_layer = keras.layers.DenseFeatures(feature_columns) - feature_layer_out = feature_layer(features) - with tf.compat.v1.train.MonitoredSession() as sess: - output = sess.run(feature_layer_out) - self.assertAllEqual( - output, [[1, 2, 3, 4, 1, 2, 3, 4], [5, 5, 5, 5, 0, 0, 0, 0]]) - - def testDenseFeatures_shareAcrossApplication(self): - features = { - "text": ["hello world", "pair-programming"], - } - feature_columns = [ - hub.text_embedding_column("text", self.spec, trainable=True), - ] - if not feature_column_v2.is_feature_column_v2(feature_columns): - self.skipTest("Resources not implemented in the state manager of feature " - "column v2.") - with tf.Graph().as_default(): - # We want to test with dense_features_v2.DenseFeatures. This symbol was - # added in https://github.com/tensorflow/tensorflow/commit/64586f18724f737393071125a91b19adf013cf8a. - feature_layer = keras.layers.DenseFeatures(feature_columns) - feature_layer_out_1 = feature_layer(features) - feature_layer_out_2 = feature_layer(features) - - # We define loss only on the first layer. Since layers should have shared - # weights, we expect the second layer will change too. - loss = feature_layer_out_1 - tf.constant(0.005) - optimizer = tf.compat.v1.train.GradientDescentOptimizer(learning_rate=0.7) - train_op = optimizer.minimize(loss) - - with tf.compat.v1.train.MonitoredSession() as sess: - before_update_1 = sess.run(feature_layer_out_1) - sess.run(train_op) - after_update_1 = sess.run(feature_layer_out_1) - after_update_2 = sess.run(feature_layer_out_2) - - self.assertAllEqual(before_update_1, [[1, 2, 3, 4], - [5, 5, 5, 5]]) - self.assertAllEqual(after_update_1, after_update_2) - - def testTrainableEmbeddingColumn(self): - feature_columns = [ - hub.text_embedding_column("text", self.spec, trainable=True), - ] - - with tf.Graph().as_default(): - features = { - "text": ["hello world", "pair-programming"], - } - target = [[1, 1, 1, 1], [4, 3, 2, 1]] - input_layer = tf.compat.v1.feature_column.input_layer(features, - feature_columns) - - loss = tf.cast( - tf.compat.v1.losses.mean_squared_error(input_layer, target), - tf.float64) - optimizer = tf.compat.v1.train.GradientDescentOptimizer( - learning_rate=0.97) - train_op = optimizer.minimize(loss) - - with tf.compat.v1.train.MonitoredSession() as sess: - self.assertAllEqual(sess.run(input_layer), [[1, 2, 3, 4], [5, 5, 5, 5]]) - for _ in range(10): - sess.run(train_op) - self.assertAllClose(sess.run(input_layer), target, atol=0.5) - - def testInvalidTextModule(self): - spec = hub.create_module_spec(invalid_text_module_fn) - with self.assertRaisesRegexp(ValueError, "only one input"): - hub.text_embedding_column("coment", spec, trainable=False) - - def testConfig(self): - module_path = os.path.join(self.get_temp_dir(), "module") - export_module_spec(self.spec, module_path) - text_column = hub.text_embedding_column("text", module_path) - config = text_column.get_config() - cloned_text_column = hub.feature_column._TextEmbeddingColumn.from_config( - config) - self.assertEqual(cloned_text_column.module_spec_path, - text_column.module_spec_path) - - with self.assertRaisesRegexp(NotImplementedError, "Can only generate"): - text_column = hub.text_embedding_column("text", self.spec) - config = text_column.get_config() - - -def create_image_module_fn(image_size, randomly_initialized=False): - def image_module_fn(): - """Maps 1x2 images to sums of each color channel.""" - height, width = image_size - images = tf.compat.v1.placeholder(dtype=tf.float32, - shape=[None, height, width, 3]) - if randomly_initialized: - initializer = tf.compat.v1.random_uniform_initializer( - minval=-1, maxval=1, dtype=tf.float32) - else: - initializer = tf.compat.v1.constant_initializer(1.0, dtype=tf.float32) - weight = tf.compat.v1.get_variable( - name="weight", shape=[1], initializer=initializer) - sum_channels = tf.reduce_sum(images, axis=[1, 2]) * weight - hub.add_signature(inputs={"images": images}, outputs=sum_channels) - return image_module_fn - - -class ImageEmbeddingColumnTest(tf.test.TestCase): - - def setUp(self): - self.spec = hub.create_module_spec(create_image_module_fn([1, 2])) - self.randomly_initialized_spec = hub.create_module_spec( - create_image_module_fn([1, 2], randomly_initialized=True)) - - def testExpectedImageSize(self): - image_column = hub.image_embedding_column("image", self.spec) - # The usage comment recommends this code pattern, so we test it here. - self.assertSequenceEqual( - hub.get_expected_image_size(image_column.module_spec), [1, 2]) - - def testVariableShape(self): - image_column = hub.image_embedding_column("image", self.spec) - self.assertEqual(image_column.variable_shape, [3]) - - def testParents(self): - image_column = hub.image_embedding_column("image", self.spec) - self.assertEqual(["image"], image_column.parents) - - def testMakeParseExampleSpec(self): - image_column = hub.image_embedding_column("image", self.spec) - parsing_spec = tf.compat.v1.feature_column.make_parse_example_spec( - [image_column]) - self.assertEqual( - parsing_spec, - {"image": tf.compat.v1.FixedLenFeature([1, 2, 3], dtype=tf.float32)}) - - def testImageSizeManuallySpecified(self): - spec = hub.create_module_spec(create_image_module_fn([None, None])) - image_column = hub.image_embedding_column("image", spec, - image_size=[229, 229]) - parsing_spec = tf.compat.v1.feature_column.make_parse_example_spec( - [image_column]) - self.assertEqual( - parsing_spec, - {"image": tf.compat.v1.FixedLenFeature([229, 229, 3], - dtype=tf.float32)}) - - def testInputLayer(self): - features = { - "image_a": [[[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]], - [[[0.7, 0.7, 0.7], [0.1, 0.2, 0.3]]]], - "image_b": [[[[0.1, 0.2, 0.1], [0.2, 0.1, 0.2]]], - [[[0.1, 0.2, 0.3], [0.3, 0.2, 0.1]]]], - } - feature_columns = [ - hub.image_embedding_column("image_a", self.spec), - hub.image_embedding_column("image_b", self.spec), - ] - with tf.Graph().as_default(): - input_layer = tf.compat.v1.feature_column.input_layer(features, - feature_columns) - with tf.compat.v1.train.MonitoredSession() as sess: - output = sess.run(input_layer) - self.assertAllClose( - output, - [[0.5, 0.7, 0.9, 0.3, 0.3, 0.3], [0.8, 0.9, 1.0, 0.4, 0.4, 0.4]]) - - def testDenseFeatures(self): - features = { - "image_a": [[[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]], - [[[0.7, 0.7, 0.7], [0.1, 0.2, 0.3]]]], - "image_b": [[[[0.1, 0.2, 0.1], [0.2, 0.1, 0.2]]], - [[[0.1, 0.2, 0.3], [0.3, 0.2, 0.1]]]], - } - feature_columns = [ - hub.image_embedding_column("image_a", self.spec), - hub.image_embedding_column("image_b", self.spec), - ] - if not feature_column_v2.is_feature_column_v2(feature_columns): - self.skipTest("Resources not implemented in the state manager of feature " - "column v2.") - with tf.Graph().as_default(): - # We want to test with dense_features_v2.DenseFeatures. This symbol was - # added in https://github.com/tensorflow/tensorflow/commit/64586f18724f737393071125a91b19adf013cf8a. - feature_layer = keras.layers.DenseFeatures(feature_columns) - feature_layer_out = feature_layer(features) - with tf.compat.v1.train.MonitoredSession() as sess: - output = sess.run(feature_layer_out) - self.assertAllClose( - output, - [[0.5, 0.7, 0.9, 0.3, 0.3, 0.3], [0.8, 0.9, 1.0, 0.4, 0.4, 0.4]]) - - def testDenseFeatures_shareAcrossApplication(self): - features = { - "image": [[[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]], - [[[0.7, 0.7, 0.7], [0.1, 0.2, 0.3]]]], - } - feature_columns = [ - hub.image_embedding_column("image", self.randomly_initialized_spec), - ] - if not feature_column_v2.is_feature_column_v2(feature_columns): - self.skipTest("Resources not implemented in the state manager of feature " - "column v2.") - with tf.Graph().as_default(): - # We want to test with dense_features_v2.DenseFeatures. This symbol was - # added in https://github.com/tensorflow/tensorflow/commit/64586f18724f737393071125a91b19adf013cf8a. - feature_layer = keras.layers.DenseFeatures(feature_columns) - feature_layer_out_1 = feature_layer(features) - feature_layer_out_2 = feature_layer(features) - - with tf.compat.v1.train.MonitoredSession() as sess: - output_1 = sess.run(feature_layer_out_1) - output_2 = sess.run(feature_layer_out_2) - - self.assertAllClose(output_1, output_2) - - def testConfig(self): - module_path = os.path.join(self.get_temp_dir(), "module") - export_module_spec(self.spec, module_path) - image_column = hub.image_embedding_column("image", module_path) - config = image_column.get_config() - cloned_image_column = hub.feature_column._ImageEmbeddingColumn.from_config( - config) - self.assertEqual(cloned_image_column.module_spec_path, - image_column.module_spec_path) - - with self.assertRaisesRegexp(NotImplementedError, "Can only generate"): - image_column = hub.image_embedding_column("image", self.spec) - config = image_column.get_config() - - def testName(self): - image_column = hub.image_embedding_column( - tf.feature_column.numeric_column("image"), self.spec) - self.assertEqual("image_hub_module_embedding", image_column.name) - - -class SparseTextEmbeddingColumnTest(tf.test.TestCase): - - def setUp(self): - self.spec = hub.create_module_spec(text_module_fn) - - def testVariableShape(self): - text_column = hub.sparse_text_embedding_column( - "text", self.spec, combiner="mean", default_value=None, trainable=False) - self.assertEqual(text_column._variable_shape, [4]) - - def testMakeParseExampleSpec(self): - text_column = hub.sparse_text_embedding_column( - "text", self.spec, combiner="mean", default_value=None, trainable=False) - parsing_spec = tf.compat.v1.feature_column.make_parse_example_spec( - [text_column]) - self.assertEqual( - parsing_spec, {"text": tf.compat.v1.VarLenFeature(tf.string)}) - - def testParents(self): - text_column = hub.sparse_text_embedding_column( - "text", self.spec, "sum", "", trainable=False) - self.assertEqual(["text"], text_column.parents) - - def testInputLayer(self): - with tf.Graph().as_default(): - text_a = tf.SparseTensor( - values=["hello world", "pair-programming", "hello world"], - indices=[[0, 0], [0, 1], [1, 0]], - dense_shape=[2, 2]) - text_b = tf.SparseTensor( - values=["hello world", "oov token"], - indices=[[0, 0], [0, 1]], - dense_shape=[2, 3]) - - features = { - "text_a": text_a, - "text_b": text_b, - } - feature_columns = [ - hub.sparse_text_embedding_column( - "text_a", - self.spec, - combiner="mean", - default_value="__UNKNOWN__", - trainable=False), - hub.sparse_text_embedding_column( - "text_b", - self.spec, - combiner="mean", - default_value="__UNKNOWN__", - trainable=False), - ] - input_layer = tf.compat.v1.feature_column.input_layer(features, - feature_columns) - with tf.compat.v1.train.MonitoredSession() as sess: - output = sess.run(input_layer) - self.assertAllEqual( - output, - [[3, 3.5, 4, 4.5, 0.5, 1, 1.5, 2], [1, 2, 3, 4, 0, 0, 0, 0]]) - # ([1, 2, 3, 4] + [5, 5, 5, 5])/2 extend ([1, 2, 3, 4] + [0, 0, 0, 0])/2 - # [1, 2, 3, 4] extend [0, 0, 0, 0] - - def testTrainableEmbeddingColumn(self): - feature_columns = [ - hub.sparse_text_embedding_column( - "text", - self.spec, - combiner="mean", - default_value=None, - trainable=True), - ] - - with tf.Graph().as_default(): - text = tf.SparseTensor( - values=["hello world", "pair-programming"], - indices=[[0, 0], [1, 0]], - dense_shape=[2, 2]) - - target = [[1, 1, 1, 1], [4, 3, 2, 1]] - input_layer = tf.compat.v1.feature_column.input_layer({"text": text}, - feature_columns) - - loss = tf.compat.v1.losses.mean_squared_error(input_layer, target) - optimizer = tf.compat.v1.train.GradientDescentOptimizer( - learning_rate=0.97) - train_op = optimizer.minimize(loss) - - with tf.compat.v1.train.MonitoredSession() as sess: - self.assertAllEqual(sess.run(input_layer), [[1, 2, 3, 4], [5, 5, 5, 5]]) - for _ in range(10): - sess.run(train_op) - self.assertAllClose(sess.run(input_layer), target, atol=0.5) - - def testEmptySparseTensorBatch(self): - feature_columns = [ - hub.sparse_text_embedding_column( - "text", - self.spec, - combiner="mean", - default_value="default", - trainable=True), - ] - - with tf.Graph().as_default(): - text = tf.SparseTensor( - values=tf.constant([], dtype=tf.string, shape=[0]), - indices=tf.constant([], dtype=tf.int64, shape=[0, 2]), - dense_shape=[3, 0]) - - input_layer = tf.compat.v1.feature_column.input_layer({"text": text}, - feature_columns) - - with tf.compat.v1.train.MonitoredSession() as sess: - embeddings = sess.run(input_layer) - self.assertAllEqual(embeddings, - [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]) - - def testEmptySparseTensorRow(self): - feature_columns = [ - hub.sparse_text_embedding_column( - "text", - self.spec, - combiner="mean", - default_value="default", - trainable=True), - ] - - with tf.Graph().as_default(): - text = tf.SparseTensor( - values=tf.constant(["hello world"], dtype=tf.string, shape=[1]), - indices=tf.constant([[0, 0]], dtype=tf.int64, shape=[1, 2]), - dense_shape=[2, 1]) - - input_layer = tf.compat.v1.feature_column.input_layer({"text": text}, - feature_columns) - - with tf.compat.v1.train.MonitoredSession() as sess: - embeddings = sess.run(input_layer) - self.assertAllEqual(embeddings, [[1, 2, 3, 4], [0, 0, 0, 0]]) - - -if __name__ == "__main__": - tf.test.main() diff --git a/tensorflow_hub/image_util.py b/tensorflow_hub/image_util.py deleted file mode 100644 index 92650dcb..00000000 --- a/tensorflow_hub/image_util.py +++ /dev/null @@ -1,142 +0,0 @@ -# Copyright 2018 The TensorFlow Hub Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Helper functions for TF-Hub modules that handle images.""" - -from tensorflow_hub import image_module_info_pb2 -from tensorflow_hub import native_module - - -# Models in TF1 Hub format for images can provide further information for the -# utilities in this file by attaching an ImageModuleInfo message under this key. -IMAGE_MODULE_INFO_KEY = "image_module_info" - - -# The externally visible name of the message is hub.ImageModuleInfo -ImageModuleInfo = image_module_info_pb2.ImageModuleInfo # pylint: disable=invalid-name - - -def attach_image_module_info(image_module_info): - """Attaches an ImageModuleInfo message from within a module_fn. - - Warning: Deprecated. This belongs to the hub.Module API and TF1 Hub format. - - THIS FUNCTION IS DEPRECATED. - - Args: - image_module_info: an ImageModuleInfo message. - """ - native_module.attach_message(IMAGE_MODULE_INFO_KEY, image_module_info) - - -def get_image_module_info(module_or_spec, required=False): - """Returns the module's attached ImageModuleInfo message, or None if missing. - - Warning: Deprecated. This belongs to the hub.Module API and TF1 Hub format - - THIS FUNCTION IS DEPRECATED. - - Args: - module_or_spec: a hub.Module or module_spec object. - required: if true, raises KeyError instead of returning None. - """ - return module_or_spec.get_attached_message( - IMAGE_MODULE_INFO_KEY, ImageModuleInfo, required=required) - - -def get_expected_image_size(module_or_spec, signature=None, input_name=None): - """Returns expected [height, width] dimensions of an image input. - - TODO(b/139530454): This does not work yet with TF2. - - Args: - module_or_spec: a Module or ModuleSpec that accepts image inputs. - signature: a string with the key of the signature in question. - If None, the default signature is used. - input_name: a string with the input name for images. If None, the - conventional input name `images` for the default signature is used. - - Returns: - A list if integers `[height, width]`. - - Raises: - ValueError: If the size information is missing or malformed. - """ - # First see if an attached ImageModuleInfo provides this information. - image_module_info = get_image_module_info(module_or_spec) - if image_module_info: - size = image_module_info.default_image_size - if size.height and size.width: - return [size.height, size.width] - - # Else inspect the input shape in the module signature. - if input_name is None: - input_name = "images" - input_info_dict = module_or_spec.get_input_info_dict(signature) - try: - shape = input_info_dict[input_name].get_shape() - except KeyError: - raise ValueError("Module is missing input '%s' in signature '%s'." % - (input_name, signature or "default")) - try: - _, height, width, _ = shape.as_list() - if not height or not width: - raise ValueError - except ValueError: - raise ValueError( - "Shape of module input is %s, " - "expected [batch_size, height, width, num_channels] " - "with known height and width." % shape) - return [height, width] - - -def get_num_image_channels(module_or_spec, signature=None, input_name=None): - """Returns expected num_channels dimensions of an image input. - - This is for advanced users only who expect to handle modules with - image inputs that might not have the 3 usual RGB channels. - - TODO(b/139530454): This does not work yet with TF2. - - Args: - module_or_spec: a Module or ModuleSpec that accepts image inputs. - signature: a string with the key of the signature in question. - If None, the default signature is used. - input_name: a string with the input name for images. If None, the - conventional input name `images` for the default signature is used. - - Returns: - An integer with the number of input channels to the module. - - Raises: - ValueError: If the channel information is missing or malformed. - """ - if input_name is None: - input_name = "images" - input_info_dict = module_or_spec.get_input_info_dict(signature) - try: - shape = input_info_dict[input_name].get_shape() - except KeyError: - raise ValueError("Module is missing input '%s' in signature '%s'." % - (input_name, signature or "default")) - try: - _, _, _, num_channels = shape.as_list() - if num_channels is None: - raise ValueError - except ValueError: - raise ValueError( - "Shape of module input is %s, " - "expected [batch_size, height, width, num_channels] " - "with known num_channels" % shape) - return num_channels diff --git a/tensorflow_hub/image_util_test.py b/tensorflow_hub/image_util_test.py deleted file mode 100644 index ca97ef75..00000000 --- a/tensorflow_hub/image_util_test.py +++ /dev/null @@ -1,70 +0,0 @@ -# Copyright 2018 The TensorFlow Hub Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for image_util.py.""" - -import tensorflow as tf - -from tensorflow_hub import image_util -from tensorflow_hub import module -from tensorflow_hub import native_module - - -def image_module_fn(): - images = tf.compat.v1.placeholder(dtype=tf.float32, shape=[None, 2, 4, 3]) - sum_by_channels = tf.reduce_sum(images, [1, 2]) - sum_all = tf.reduce_sum(images, [1, 2, 3]) - native_module.add_signature(inputs=dict(images=images), - outputs=dict(default=sum_all, - sum_by_channels=sum_by_channels)) - - -def image_module_fn_with_info(): - images = tf.compat.v1.placeholder(dtype=tf.float32, - shape=[None, None, None, 3]) - sum_all = tf.reduce_sum(images, [1, 2, 3]) - native_module.add_signature(inputs=dict(images=images), - outputs=dict(default=sum_all)) - image_module_info = image_util.ImageModuleInfo() - size = image_module_info.default_image_size - size.height, size.width = 2, 4 - image_util.attach_image_module_info(image_module_info) - - -class ImageModuleTest(tf.test.TestCase): - - def testGetExpectedImageSizeFromShape(self): - with tf.Graph().as_default(): - spec = native_module.create_module_spec(image_module_fn) - self.assertAllEqual(image_util.get_expected_image_size(spec), [2, 4]) - m = module.Module(spec) - self.assertAllEqual(image_util.get_expected_image_size(m), [2, 4]) - - def testGetExpectedImageSizeFromImageModuleInfo(self): - with tf.Graph().as_default(): - spec = native_module.create_module_spec(image_module_fn_with_info) - self.assertAllEqual(image_util.get_expected_image_size(spec), [2, 4]) - m = module.Module(spec) - self.assertAllEqual(image_util.get_expected_image_size(m), [2, 4]) - - def testGetNumImageChannels(self): - with tf.Graph().as_default(): - spec = native_module.create_module_spec(image_module_fn) - self.assertEqual(image_util.get_num_image_channels(spec), 3) - m = module.Module(spec) - self.assertEqual(image_util.get_num_image_channels(m), 3) - - -if __name__ == "__main__": - tf.test.main() diff --git a/tensorflow_hub/keras_layer_test.py b/tensorflow_hub/keras_layer_test.py index f44e9566..615d678e 100644 --- a/tensorflow_hub/keras_layer_test.py +++ b/tensorflow_hub/keras_layer_test.py @@ -91,32 +91,6 @@ def call_fn(inputs): tf.saved_model.save(obj, export_dir) -def _save_half_plus_one_hub_module_v1(path): - """Writes a model in TF1 Hub format to compute y = wx + 1, with w trainable.""" - - def half_plus_one(): - x = tf.compat.v1.placeholder(shape=(None, 1), dtype=tf.float32) - # Use TF1 native tf_keras_v2.layers instead of tf_keras_v2.layers as they - # correctly update TF collections, such as REGULARIZATION_LOSS. - times_w = tf_keras_v2.layers.Dense( - units=1, - kernel_initializer=tf_keras_v2.initializers.Constant([[0.5]]), - kernel_regularizer=tf_keras_v2.regularizers.l2(0.01), - use_bias=False, - ) - plus_1 = tf_keras_v2.layers.Dense( - units=1, - kernel_initializer=tf_keras_v2.initializers.Constant([[1.0]]), - bias_initializer=tf_keras_v2.initializers.Constant([1.0]), - trainable=False, - ) - y = plus_1(times_w(x)) - hub.add_signature(inputs=x, outputs=y) - - spec = hub.create_module_spec(half_plus_one) - _export_module_spec_with_init_weights(spec, path) - - def _save_2d_text_embedding(export_dir, save_from_keras=False): """Writes SavedModel to compute y = length(text)*w, with w trainable.""" @@ -352,35 +326,12 @@ def keras_default(x, training=False): tf.saved_model.save(obj, path, signatures={"plus_one": obj.plus_one}) -def _save_plus_one_hub_module_v1(path): - """Writes a model in TF1 Hub format that increments the input by one.""" - - def plus_one(): - x = tf.compat.v1.placeholder(dtype=tf.float32, name="x") - y = x + 1 - hub.add_signature(inputs=x, outputs=y) - - spec = hub.create_module_spec(plus_one) - _export_module_spec_with_init_weights(spec, path) - - -def _export_module_spec_with_init_weights(spec, path): - """Initializes initial weights of a TF1.x HubModule and saves it.""" - with tf.compat.v1.Graph().as_default(): - module = hub.Module(spec, trainable=True) - with tf.compat.v1.Session() as session: - session.run(tf.compat.v1.global_variables_initializer()) - module.export(path, session) - - -def _dispatch_model_format(model_format, saved_model_fn, hub_module_fn, *args): +def _dispatch_model_format(model_format, saved_model_fn, *args): """Dispatches the correct save function based on the model format.""" if model_format == "TF2SavedModel_SavedRaw": saved_model_fn(*args, save_from_keras=False) elif model_format == "TF2SavedModel_SavedFromKeras": saved_model_fn(*args, save_from_keras=True) - elif model_format == "TF1HubModule": - hub_module_fn(*args) else: raise ValueError("Unrecognized format: " + format) @@ -392,8 +343,7 @@ class KerasTest(tf.test.TestCase, parameterized.TestCase): ("TF2SavedModel_SavedFromKeras")) def testHalfPlusOneRetraining(self, model_format): export_dir = os.path.join(self.get_temp_dir(), "half-plus-one") - _dispatch_model_format(model_format, _save_half_plus_one_model, - _save_half_plus_one_hub_module_v1, export_dir) + _dispatch_model_format(model_format, _save_half_plus_one_model, export_dir) # Import the half-plus-one model into a consumer model. inp = tf_keras_v2.layers.Input(shape=(1,), dtype=tf.float32) imported = hub.KerasLayer(export_dir, trainable=True) @@ -440,8 +390,7 @@ def testHalfPlusOneRetraining(self, model_format): ("TF2SavedModel_SavedFromKeras")) def testRegularizationLoss(self, model_format): export_dir = os.path.join(self.get_temp_dir(), "half-plus-one") - _dispatch_model_format(model_format, _save_half_plus_one_model, - _save_half_plus_one_hub_module_v1, export_dir) + _dispatch_model_format(model_format, _save_half_plus_one_model, export_dir) # Import the half-plus-one model into a consumer model. inp = tf_keras_v2.layers.Input(shape=(1,), dtype=tf.float32) imported = hub.KerasLayer(export_dir, trainable=False) @@ -675,18 +624,6 @@ def testResaveWithMixedPrecision(self, save_from_keras): finally: tf_keras_v2.mixed_precision.set_global_policy("float32") - def testComputeOutputShapeNonEager(self): - export_dir = os.path.join(self.get_temp_dir(), "half-plus-one") - _save_half_plus_one_hub_module_v1(export_dir) - - with tf.compat.v1.Graph().as_default(): - # Output shape is required when computing output shape outside of eager - # mode. - layer = hub.KerasLayer(export_dir, output_shape=(1,)) - self.assertEqual([None, 1], - layer.compute_output_shape((None, 1)).as_list()) - self.assertEqual([3, 1], layer.compute_output_shape((3, 1)).as_list()) - @parameterized.named_parameters(("SavedRaw", False), ("SavedFromKeras", True)) def testGetConfigFromConfig(self, save_from_keras): export_dir = os.path.join(self.get_temp_dir(), "half-plus-one") @@ -736,41 +673,17 @@ def testSaveModelConfig(self, save_from_keras): class KerasLayerTest(tf.test.TestCase, parameterized.TestCase): """Unit tests for KerasLayer.""" - @parameterized.parameters(("TF1HubModule"), ("TF2SavedModel_SavedRaw")) - def test_load_with_defaults(self, model_format): + def test_load_with_defaults(self): + model_format = "TF2SavedModel_SavedRaw" export_dir = os.path.join(self.get_temp_dir(), "plus_one_" + model_format) - _dispatch_model_format(model_format, _save_plus_one_saved_model_v2, - _save_plus_one_hub_module_v1, export_dir) + _dispatch_model_format( + model_format, _save_plus_one_saved_model_v2, export_dir + ) inputs, expected_outputs = 10., 11. # Test modules perform increment op. layer = hub.KerasLayer(export_dir) output = layer(inputs) self.assertEqual(output, expected_outputs) - @parameterized.parameters( - ("TF1HubModule", None, None, True), - ("TF1HubModule", None, None, False), - ("TF1HubModule", "default", None, True), - ("TF1HubModule", None, "default", False), - ("TF1HubModule", "default", "default", False), - ) - def test_load_legacy_hub_module_v1_with_signature(self, model_format, - signature, output_key, - as_dict): - export_dir = os.path.join(self.get_temp_dir(), "plus_one_" + model_format) - _dispatch_model_format(model_format, _save_plus_one_saved_model_v2, - _save_plus_one_hub_module_v1, export_dir) - inputs, expected_outputs = 10., 11. # Test modules perform increment op. - layer = hub.KerasLayer( - export_dir, - signature=signature, - output_key=output_key, - signature_outputs_as_dict=as_dict) - output = layer(inputs) - if as_dict: - self.assertEqual(output, {"default": expected_outputs}) - else: - self.assertEqual(output, expected_outputs) - @parameterized.parameters( ("TF2SavedModel_SavedRaw", None, None, False), ("TF2SavedModel_SavedRaw", "serving_default", None, True), @@ -780,8 +693,9 @@ def test_load_callable_saved_model_v2_with_signature(self, model_format, signature, output_key, as_dict): export_dir = os.path.join(self.get_temp_dir(), "plus_one_" + model_format) - _dispatch_model_format(model_format, _save_plus_one_saved_model_v2, - _save_plus_one_hub_module_v1, export_dir) + _dispatch_model_format( + model_format, _save_plus_one_saved_model_v2, export_dir + ) inputs, expected_outputs = 10., 11. # Test modules perform increment op. layer = hub.KerasLayer( export_dir, @@ -807,20 +721,16 @@ def test_load_callable_keras_default_saved_model_v2_with_signature(self): self.assertEqual(output["output_0"], expected_outputs) @parameterized.parameters( - ("TF1HubModule", None, None, True), - ("TF1HubModule", None, None, False), - ("TF1HubModule", "default", None, True), - ("TF1HubModule", None, "default", False), - ("TF1HubModule", "default", "default", False), - ("TF2SavedModel_SavedRaw", None, None, False), - ("TF2SavedModel_SavedRaw", "serving_default", None, True), - ("TF2SavedModel_SavedRaw", "serving_default", "output_0", False), + (None, None, False), + ("serving_default", None, True), + ("serving_default", "output_0", False), ) - def test_keras_layer_get_config(self, model_format, signature, output_key, - as_dict): + def test_keras_layer_get_config(self, signature, output_key, as_dict): + model_format = "TF2SavedModel_SavedRaw" export_dir = os.path.join(self.get_temp_dir(), "plus_one_" + model_format) - _dispatch_model_format(model_format, _save_plus_one_saved_model_v2, - _save_plus_one_hub_module_v1, export_dir) + _dispatch_model_format( + model_format, _save_plus_one_saved_model_v2, export_dir + ) inputs = 10. # Test modules perform increment op. layer = hub.KerasLayer( export_dir, @@ -875,21 +785,6 @@ def test_keras_layer_fails_if_output_is_not_dict(self): ValueError, "Specifying `output_key` is forbidden if output type *"): layer(10.) - def test_keras_layer_fails_if_output_key_not_in_layer_outputs(self): - export_dir = os.path.join(self.get_temp_dir(), "hub_module_v1_mini") - _save_plus_one_hub_module_v1(export_dir) - layer = hub.KerasLayer(export_dir, output_key="unknown") - with self.assertRaisesRegex( - ValueError, "KerasLayer output does not contain the output key*"): - layer(10.) - - def test_keras_layer_fails_if_hub_module_trainable(self): - export_dir = os.path.join(self.get_temp_dir(), "hub_module_v1_mini") - _save_plus_one_hub_module_v1(export_dir) - layer = hub.KerasLayer(export_dir, trainable=True) - with self.assertRaisesRegex(ValueError, "trainable.*=.*True.*unsupported"): - layer(10.) - def test_keras_layer_fails_if_signature_trainable(self): export_dir = os.path.join(self.get_temp_dir(), "saved_model_v2_mini") _save_plus_one_saved_model_v2(export_dir, save_from_keras=False) diff --git a/tensorflow_hub/module.py b/tensorflow_hub/module.py deleted file mode 100644 index 6c57e93c..00000000 --- a/tensorflow_hub/module.py +++ /dev/null @@ -1,628 +0,0 @@ -# Copyright 2018 The TensorFlow Hub Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""The deprecated hub.Module class of TensorFlow Hub.""" - -import contextlib - -import tensorflow as tf - -from tensorflow_hub import module_spec -from tensorflow_hub import registry -from tensorflow_hub import tensor_info - - -def as_module_spec(spec): - if isinstance(spec, module_spec.ModuleSpec): - return spec - elif isinstance(spec, str): - return load_module_spec(spec) - else: - raise ValueError("Unknown module spec type: %r" % type(spec)) - - -def load_module_spec(path): - """Loads a ModuleSpec from a TF Hub service or the filesystem. - - Warning: Deprecated. This belongs to the hub.Module API and TF1 Hub format. - For TF2, switch to plain SavedModels and hub.load(); see also hub.resolve(). - - THIS FUNCTION IS DEPRECATED. - - Args: - path: string describing the location of a module. There are several - supported path encoding schemes: - a) A URL like "https://tfhub.dev/the/module/1" referring to tfhub.dev or - another service implementing https://www.tensorflow.org/hub/hosting. - b) A URL like "https://example.com/module.tar.gz" that points to a - compressed tarball directly, as long as that web server ignores - the query parameters added by https://www.tensorflow.org/hub/hosting. - c) Any filesystem location of a module directory (e.g. /module_dir - for a local filesystem). All filesystems implementations provided - by Tensorflow are supported. - d) Private name resolution schemes added by the maintainer of your - local installation of the tensorflow_hub library (usually none). - - Returns: - A ModuleSpec. - - Raises: - ValueError: on unexpected values in the module spec. - tf.errors.OpError: on file handling exceptions. - """ - path = registry.resolver(path) - return registry.loader(path) - - -def export_module_spec(spec, path, checkpoint_path, name_transform_fn): - """Helper function to ModuleSpec.export().""" - with tf.Graph().as_default(): - m = Module(spec) - assign_map = { - name_transform_fn(name): value for name, value in m.variable_map.items() - } - tf.compat.v1.train.init_from_checkpoint(checkpoint_path, assign_map) - init_op = tf.compat.v1.initializers.global_variables() - with tf.compat.v1.Session() as session: - session.run(init_op) - m.export(path, session) - - -# Module class provides a unified access to all ModuleSpecs implementations and -# should not contain specific implementation code in it (e.g. SavedModel code). -class Module(object): - """Part of a TensorFlow 1 model that can be transferred between models. - - Warning: Deprecated. The hub.Module API works for TF1 only. - For TF2, switch to plain SavedModels and hub.load(). - - A Module represents a part of a TensorFlow graph that can be exported to disk - (based on the SavedModel format) and later re-loaded. A Module has a defined - interface that allows it to be used in a replaceable way, with little or no - knowledge of its internals and its serialization format. Example: - - ```python - m = hub.Module("/tmp/text-embedding") - embeddings = m(sentences) - ``` - - The module to instantiate is defined by its spec (a `ModuleSpec` or a - path where to load it from) which contains the module weights, assets and - signatures. - - During instantiation the Module adds the state (e.g. variables and tables ops) - to the current graph. Afterwards, the method `__call__()` allows to apply the - module `signatures` multiple times, which adds ops for the computation. - - A Module may provide different variants of its graph for different purposes - (say, training or serving, which may behave differently, e.g., for batch - normalization). Graph variants are identified by sets of string-valued tags. - The graph variant used to create a module that is exported must define all the - variables needed by any other graph variant that is subsequently used. - - To make it possible to easily replace a module with another, they all assume - that they will be used with common TensorFlow conventions such as session - initialization and restore, use of collections for variables, regularization - losses and updates, etc. - - THIS FUNCTION IS DEPRECATED. - """ - - def __init__(self, spec, trainable=False, name="module", tags=None): - """Constructs a Module to be used in the current graph. - - This creates the module `state-graph` under an unused variable_scope based - on `name`. During this call a Module will: - - - Add GLOBAL_VARIABLES under its scope. Those variables may be added to - to the TRAINABLE_VARIABLES collection (depending on `trainable` parameter) - and to the MODEL_VARIABLES. The variables must be initialized before use, - and can be checkpointed as usual. - - - Add ops to the INIT_TABLE_OPS collection, which must be run during session - initialization and add constant tensors to ASSET_FILEPATHS that are - needed during the execution of such ops. - - - Add tensors to the REGULARIZATION_LOSSES collection (depending on - `trainable` parameter). - - Args: - spec: A ModuleSpec defining the Module to instantiate or a path where - to load a ModuleSpec from via `load_module_spec`. - trainable: whether the Module is trainable. If False, no variables are - added to TRAINABLE_VARIABLES collection, and no tensors are added to - REGULARIZATION_LOSSES collection. - name: A string, the variable scope name under which to create the Module. - It will be uniquified and the equivalent name scope must be unused. - tags: A set of strings specifying the graph variant to use. - - Raises: - RuntimeError: explaning the reason why it failed to instantiate the - Module. - ValueError: if the requested graph variant does not exists. - tf.errors.NotFoundError: if the requested graph contains unknown ops. - """ - self._graph = tf.compat.v1.get_default_graph() - self._spec = as_module_spec(spec) - self._trainable = trainable - - self._tags = set(tags or []) - if self._tags not in self._spec.get_tags(): - tags = sorted(list(tags)) if tags else tags - raise ValueError("No such graph variant: tags=%r" % tags) - - abs_state_scope = _try_get_state_scope(name, mark_name_scope_used=False) - self._name = abs_state_scope.split("/")[-2] - - abs_parent_scope = abs_state_scope.split("/")[:-2] - if abs_parent_scope: - abs_parent_scope = "/".join(abs_parent_scope) + "/" - else: - abs_parent_scope = "" - - with tf.name_scope(abs_parent_scope): - # pylint: disable=protected-access - self._impl = self._spec._create_impl( - name=self._name, - trainable=self._trainable, - tags=self._tags) - # pylint: enable=protected-access - - def __call__(self, inputs=None, # pylint: disable=invalid-name - _sentinel=None, signature=None, as_dict=None): - """Instantiates a module signature in the graph. - - Example calls: - - ```python - # Use default signature with one input and default output. - embeddings = m(["hello world", "good morning"]) - - # Use "encode" signature with one input and default output. - encodings = m(["hello world"], signature="encode") - - # Use default signature with input dict and output dict. - dict_outputs = m({"text": [...], "lang": [...]}, as_dict=True) - ``` - - The method __call__() allows to create the graph ops that compute a - signature outputs given the inputs and using this module instance state. - Each signature can be applied multiple times with different inputs and they - all share the same module state. - - A Module may define multiple signatures. Use `signature=` to identify - the specific signature to instantiate. If omitted or None, the default - signature is used. - - A signature may define various outputs. Use `as_dict=True` to return a dict - of all outputs. If omitted or False, the output named 'default' is - returned. - - During this call a Module will: - - - Add ops in the current name scope to convert the inputs in tensors to feed - to the signature. - - - Add ops to the UPDATE_OPS collection which depend on at least one of the - provided inputs if the Module was constructed with `trainable=True`. - - - Add constant tensors to ASSET_FILEPATHS, even if those are not needed - directly needed for the signature. - - Note: `hub.Module` implementation depends on graph pruning that happens - usually during `session.run` as so it can lead to errors when used inside - function graphs that execute all its ops (e.g. `tf.data.Dataset.map`). - - Args: - inputs: Inputs to the signature. A dict from input names to input tensors - (incl. composite tensors, such as `SparseTensor` or `RaggedTensor`). - If the signature only expects one input, one may pass - a single value. If the signature has no inputs, it may be omitted. - _sentinel: Used to prevent positional parameters besides `inputs`. - signature: A string with the signature name to apply. If none, the - default signature is used. - as_dict: A boolean indicating whether to the return all the outputs - of the signature as a dict or return only the default output. - - Returns: - A tensor (incl. composite tensors, such as `SparseTensor` or - `RaggedTensor`) if the signature defines a default output; or a dict from - strings (output names) to tensors (incl. composite tensors) if - `as_dict=True` is used. - - Raises: - TypeError: If there is a mismatch on arguments, inputs or outputs of - the module signature. - RuntimeError: If there are errors during creation of the signature graph. - """ - if self._graph is not tf.compat.v1.get_default_graph(): - raise RuntimeError( - "Module must be applied in the graph it was instantiated for.") - - signature = self._impl.get_signature_name(signature) - # SavedModel non-default signatures automatically includes ':' in them, - # but that is an invalid character for a name that is used as part - # of variable scopes. - safe_signature = signature.replace(":", "_") - name = "%s_apply_%s" % (self._name, safe_signature) - - input_tensor_infos = self._spec.get_input_info_dict(signature, self._tags) - output_tensor_infos = self._spec.get_output_info_dict(signature, self._tags) - _check_supported_types(input_tensor_infos, signature, "input") - _check_supported_types(output_tensor_infos, signature, "output") - - dict_inputs = _convert_dict_inputs(inputs, input_tensor_infos) - dict_outputs = self._impl.create_apply_graph( - signature=signature, - input_tensors=dict_inputs, - name=name) - return _prepare_outputs(dict_outputs, as_dict=as_dict) - - def get_signature_names(self): - """Returns the module's signature names as an iterable of strings.""" - return self._spec.get_signature_names(tags=self._tags) - - def get_input_info_dict(self, signature=None): - """Describes the inputs required by a signature. - - Args: - signature: A string with the signature to get inputs information for. - If None, the default signature is used if defined. - - Returns: - The result of ModuleSpec.get_input_info_dict() for the given signature, - and the graph variant selected by `tags` when this Module was initialized. - - Raises: - KeyError: if there is no such signature. - """ - return self._spec.get_input_info_dict(signature=signature, tags=self._tags) - - def get_output_info_dict(self, signature=None): - """Describes the outputs provided by a signature. - - Args: - signature: A string with the signature to get ouputs information for. - If None, the default signature is used if defined. - - Returns: - The result of ModuleSpec.get_output_info_dict() for the given signature, - and the graph variant selected by `tags` when this Module was initialized. - - Raises: - KeyError: if there is no such signature. - """ - return self._spec.get_output_info_dict(signature=signature, tags=self._tags) - - def get_attached_message(self, key, message_type, required=False): - """Calls ModuleSpec.get_attached_message(); see there for more.""" - return self._spec.get_attached_message(key, message_type, - tags=self._tags, required=required) - - def export(self, path, session): - """Exports the module with the variables from the session in `path`. - - Note that it is the module definition in the ModuleSpec used to create this - module that gets exported. The session is only used to provide the value - of variables. - - Args: - path: path where to export the module to. - session: session where to export the variables from. - - Raises: - RuntimeError: if there is an issue during the export. - """ - if self._graph is not tf.compat.v1.get_default_graph(): - raise RuntimeError("default graph differs from the graph where the " - "module was instantiated.") - if self._graph is not session.graph: - raise RuntimeError("session graph differs from the graph where the " - "module was instantiated.") - self._impl.export(path, session) - - @property - def variable_map(self): - """Map from original variable names into tf.Variables (or lists of them). - - This map translates between variable names relative to the module and the - corresponding Variable objects that have been created by instantiating it - in the current graph (with the applicable scoping added). Each key in the - map is a variable name as created by running the module's defining - `module_fn` in the root scope of an empty graph. Each value in the map is - a Variable object, or in case of partitioned variables a list of Variable - objects. - - This property can be used with `tf.init_from_checkpoint` as `assignment_map` - in order to restore a pre-trained checkpoint into a Module before calling - `Module.export()`. - - Returns: - A dict from the variable names in the Module to the instantiated - tf.Variables or list of tf.Variables (if partitioned). The keys of this - map are the same regardless of the scope of where the Module was - instantiated. - """ - return self._impl.variable_map - - @property - def variables(self): - """Returns the list of all tf.Variables created by module instantiation.""" - result = [] - for _, value in sorted(self.variable_map.items()): - if isinstance(value, list): - result.extend(value) - else: - result.append(value) - return result - - -def _try_get_state_scope(name, mark_name_scope_used=True): - """Returns a fresh variable/name scope for a module's state. - - In order to import a module into a given scope without major complications - we require the scope to be empty. This function deals with deciding an unused - scope where to define the module state. This is non trivial in cases where - name_scope and variable_scopes are out of sync, e.g. tpus or re-entering - scopes. - - Args: - name: A string with the name of the module as supplied by the client. - mark_name_scope_used: a boolean, indicating whether to mark the name - scope of the returned value as used. - - Raises: - RuntimeError: if the name scope of the freshly created variable scope is - already used. - """ - tmp_scope_name = tf.compat.v1.get_variable_scope().name - if tmp_scope_name: - tmp_scope_name += "/" - with tf.name_scope(tmp_scope_name): - # Pick an unused variable scope. - with tf.compat.v1.variable_scope( - None, default_name=name, auxiliary_name_scope=False) as vs: - abs_state_scope = vs.name + "/" - # Verify that the name scope is available and mark it used if requested. - graph = tf.compat.v1.get_default_graph() - unique_name_scope = graph.unique_name(name, mark_name_scope_used) + "/" - if unique_name_scope != abs_state_scope: - raise RuntimeError( - "variable_scope %s was unused but the corresponding " - "name_scope was already taken." % abs_state_scope) - return abs_state_scope - - -def _check_supported_types(tensor_infos, signature, arg_type): - """Raises a ValueError if any infos have unsupported types. - - Args: - tensor_infos: dictionary from string name to TensorInfo. - signature: string signature name. - arg_type: `'input'` or `'output'` (for the error message) - """ - for name, info in sorted(tensor_infos.items()): - if not info.is_supported_type: - raise ValueError( - "Signature %r expects a %s for %s %r, which is not supported" - " by this version of tensorflow_hub." % - (signature, info.type_spec.value_type.__name__, arg_type, name)) - - -def _prepare_dict_inputs(inputs, tensor_info_map): - """Converts inputs to a dict of inputs and checks extra/missing args. - - Args: - inputs: inputs fed to Module.__call__(). - tensor_info_map: A map from string to `tensor_info.ParsedTensorInfo` - describing the signature inputs. - - Returns: - A dict of values with the same keys as tensor_info_map. - - Raises: - TypeError: If it fails to convert the input values into a dict of tensors - to feed to the signature instantiation. - """ - if inputs is None: - dict_inputs = {} - elif isinstance(inputs, dict): - dict_inputs = inputs - elif len(tensor_info_map) == 1: - dict_inputs = {list(tensor_info_map.keys())[0]: inputs} - elif not tensor_info_map: - raise TypeError("Signature expects no inputs.") - else: - raise TypeError("Signature expects multiple inputs. Use a dict.") - - dict_inputs_keys = set(dict_inputs.keys()) - tensor_info_map_keys = set(tensor_info_map.keys()) - if dict_inputs_keys != tensor_info_map_keys: - raise TypeError("Cannot convert dict_inputs: missing %r, extra given %r" % - (sorted(list(tensor_info_map_keys - dict_inputs_keys)), - sorted(list(dict_inputs_keys - tensor_info_map_keys)))) - - return dict_inputs - - -def _convert_dict_inputs(inputs, tensor_info_map): - """Converts from inputs into dict of input tensors. - - This handles: - - putting inputs into a dict, per _prepare_dict_inputs(), - - converting all input values into tensors compatible with the - expected input tensor (dtype, shape). - - check composite tensor types. - - Args: - inputs: inputs fed to Module.__call__(). - tensor_info_map: A map from string to `tensor_info.ParsedTensorInfo` - describing the signature inputs. - - Returns: - A dict of tensors to feed to the signature instantiation. - - Raises: - TypeError: If it fails to convert the input values into a dict of tensors - to feed to the signature instantiation. - """ - dict_inputs = _prepare_dict_inputs(inputs, tensor_info_map) - return tensor_info.convert_dict_to_compatible_tensor(dict_inputs, - tensor_info_map) - - -def _prepare_outputs(dict_outputs, as_dict): - """Converts from dict outputs into the return value of Module.__call__(). - - Args: - dict_outputs: A dict output from applying a signature. - as_dict: A boolean indicating whether to return the outputs of the Module - as a dict or return the output named 'default. - - Returns: - A tensor with the output named 'default' or a dict of output tensors if - `as_dict=True`. - - Raises: - TypeError: If as_dict is False and there is no output named 'default'. - """ - if as_dict: - return dict_outputs - if "default" in dict_outputs: - return dict_outputs["default"] - else: - raise TypeError("There is no output named 'default'. Use as_dict=True.") - - -def _spec_to_placeholder(type_spec, name): - """Returns a tensor or composite tensor placeholder for the given TypeSpec. - - Args: - type_spec: A `TypeSpec`. - name: The name prefix for the placeholder tensors. - - Returns: - A placeholder tensor (if `type_spec` is a `TensorSpec`); or a value with - type `type_spec.value_type`, whose component tensors are placeholders. - """ - flat_specs = tf.nest.flatten(type_spec, expand_composites=True) - if len(flat_specs) == 1: - flat_names = [name] - else: - flat_names = [ - "{}.component_{}".format(name, i) for i in range(len(flat_specs)) - ] - placeholders = [ - tf.compat.v1.placeholder(s.dtype, s.shape, name) - for (s, name) in zip(flat_specs, flat_names) - ] - return tf.nest.pack_sequence_as( - type_spec, placeholders, expand_composites=True) - - -@contextlib.contextmanager -def eval_function_for_module(spec, tags=None): - """Context manager that yields a function to directly evaluate a hub.Module. - - Warning: Deprecated. This belongs to the hub.Module API and TF1 Hub format. - For TF2, switch to plain SavedModels and hub.load(). Eager execution in - TF2 obviates the need for this helper. - - This creates a separate graph, in which all of the signatures of the module - are instantiated. Then, it creates a session and initializes the module - variables. Finally, it returns a function which can be used to evaluate the - module signatures. - - The function returned by eval_function_for_module has the same syntax as - Module.__call__ , except that inputs and outputs are not tensors but actual - values as used with Session.run(). - - ```python - with hub.eval_function_for_module("/tmp/text-embedding") as f: - # The module can be directly evaluated using f without constructing a graph. - embeddings = f(["Hello world!",], signature="mysignature") - ``` - - THIS FUNCTION IS DEPRECATED. - - Args: - spec: A ModuleSpec defining the Module to instantiate or a path where to - load a ModuleSpec from via `load_module_spec`. - tags: A set of strings specifying the graph variant to use. - - Yields: - A function whose keyword arguments are fed into the tfhub module and which - returns a dictionary with the value of the output tensors. - - Raises: - RuntimeError: explaning the reason why it failed to instantiate the - Module. - ValueError: if the requested graph variant does not exists. - """ - # We create a separate graph and add all the signatures of the module to it. - original_graph = tf.compat.v1.get_default_graph() - with tf.Graph().as_default(): - module = Module(spec, tags=tags) - input_tensors_per_signature = {} - output_tensors_per_signature = {} - for signature in module.get_signature_names(): - # We scope with the signature name as different signatures will likely - # contain tensors with the same name (e.g. the input and output tensors). - with tf.compat.v1.variable_scope(signature): - input_tensors = {} - for name, tensorinfo in module.get_input_info_dict(signature).items(): - if tensorinfo.is_sparse: - # There's a bug in sparse_placeholder that causes it to break if - # we pass in `TensorShape(None)` -- work around it by passing in - # `None` instead. - shape = tensorinfo.get_shape() - effective_shape = None if shape.dims is None else shape.as_list() - if tensorinfo.is_sparse: - input_tensors[name] = tf.compat.v1.sparse_placeholder( - tensorinfo.dtype, shape=effective_shape, name=name) - else: - input_tensors[name] = _spec_to_placeholder(tensorinfo.type_spec, - name) - input_tensors_per_signature[signature] = input_tensors - output_tensors_per_signature[signature] = module( - input_tensors_per_signature[signature], - signature=signature, - as_dict=True) - - # Evaluating the tfhub module requires an active tensorflow session. - with tf.compat.v1.train.SingularMonitoredSession() as sess: - - def func( - inputs=None, - _sentinel=None, # pylint: disable=invalid-name - signature=None, - as_dict=None): - """Function that directly evaluates a signature in the module.""" - signature = signature or "default" - input_tensors = input_tensors_per_signature[signature] - - dict_inputs = _prepare_dict_inputs(inputs, input_tensors) - - # The input arguments are directly fed into the session. - feed_dict = { - input_tensors[key]: value for key, value in dict_inputs.items() - } - output = output_tensors_per_signature[signature] - output = _prepare_outputs(output, as_dict) - return sess.run(output, feed_dict=feed_dict) - - with original_graph.as_default(): - # Yield the function since that will keep the session alive until the - # user exits the context. - yield func diff --git a/tensorflow_hub/module_def.proto b/tensorflow_hub/module_def.proto deleted file mode 100644 index 18525bc4..00000000 --- a/tensorflow_hub/module_def.proto +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright 2018 The TensorFlow Hub Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= -syntax = "proto3"; - -package tensorflow_hub; - -// A Hub Module is stored in a directory with a file 'tfhub_module.pb' -// containing a serialized protocol message of this type. The further contents -// of the directory depend on the storage format described by the message. -message ModuleDef { - enum Format { - // This value is never set in a ModuleDef message. Proto parsing will return - // it in lieu of a missing value. - FORMAT_UNSPECIFIED = 0; - - // Hub SavedModel format v3: - // - The remaining files in the Module directory are a SavedModel - // with variables and assets. - // - The reader must adhere to the required_features protocol (see below). - FORMAT_V3 = 3; - } - // The storage format of this module. Unknown values (likely from future - // formats) or an unspecified value should be treated as unsupported. - Format format = 1; - - // List of feature names that must be supported by the reader to successfully - // interpret this module. - // - // Features for FORMAT_V3: - // - None yet. - repeated string required_features = 2; -} diff --git a/tensorflow_hub/module_test.py b/tensorflow_hub/module_test.py deleted file mode 100644 index 26e8fe48..00000000 --- a/tensorflow_hub/module_test.py +++ /dev/null @@ -1,344 +0,0 @@ -# Copyright 2018 The TensorFlow Hub Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Unit tests for tensorflow_hub.module.""" - -import tensorflow as tf -from tensorflow_hub import module -from tensorflow_hub import module_impl -from tensorflow_hub import module_spec -from tensorflow_hub import tensor_info - - -class TestConvertInputsOutputs(tf.test.TestCase): - - def testSingleInput(self): - inputs_info = { - "x": tensor_info.ParsedTensorInfo( - tf.float32, - tf.TensorShape([None]), - is_sparse=False), - } - def _check(dict_inputs): - self.assertEqual(len(dict_inputs), 1) - self.assertEqual(dict_inputs["x"].dtype, tf.float32) - self.assertTrue(dict_inputs["x"].shape.is_compatible_with([None])) - - _check(module._convert_dict_inputs([1, 2], inputs_info)) - _check(module._convert_dict_inputs({"x": [1, 2]}, inputs_info)) - - with self.assertRaisesRegexp(TypeError, r"missing \['x'\]"): - module._convert_dict_inputs(None, inputs_info) - - with self.assertRaisesRegexp(TypeError, r"extra given \['y'\]"): - module._convert_dict_inputs({"x": [1, 2], "y": [1, 2]}, inputs_info) - - def testNoInputs(self): - self.assertEqual(module._convert_dict_inputs(None, {}), {}) - self.assertEqual(module._convert_dict_inputs({}, {}), {}) - - with self.assertRaisesRegexp(TypeError, "expects no inputs"): - module._convert_dict_inputs([None], {}) - - with self.assertRaisesRegexp(TypeError, "expects no inputs"): - module._convert_dict_inputs(1, {}) - - with self.assertRaisesRegexp(TypeError, r"extra given \['x'\]"): - module._convert_dict_inputs({"x": 1}, {}) - - def testMultipleInputs(self): - inputs_info = { - "x": tensor_info.ParsedTensorInfo( - tf.float32, - tf.TensorShape([None]), - is_sparse=False), - "y": tensor_info.ParsedTensorInfo( - tf.float32, - tf.TensorShape([None]), - is_sparse=False), - } - def _check(dict_inputs): - self.assertEqual(len(dict_inputs), 2) - for key in ("x", "y"): - self.assertEqual(dict_inputs[key].dtype, tf.float32) - self.assertTrue(dict_inputs[key].shape.is_compatible_with([None])) - - _check(module._convert_dict_inputs({"x": [1, 2], "y": [1, 2]}, - inputs_info)) - - with self.assertRaisesRegexp(TypeError, r"missing \['x', 'y'\]"): - module._convert_dict_inputs(None, inputs_info) - with self.assertRaisesRegexp(TypeError, r"missing \['x', 'y'\]"): - module._convert_dict_inputs({}, inputs_info) - with self.assertRaisesRegexp(TypeError, r"missing \['x', 'y'\]"): - module._convert_dict_inputs({"z": 1}, inputs_info) - - with self.assertRaisesRegexp( - TypeError, "Signature expects multiple inputs. Use a dict."): - module._convert_dict_inputs(1, inputs_info) - - def testOutputWithDefault(self): - outputs = {"default": "result", "extra": "dbg info"} - self.assertEquals(module._prepare_outputs(outputs, as_dict=False), "result") - self.assertEquals(module._prepare_outputs(outputs, as_dict=True), outputs) - - def testDictOutput(self): - outputs = {"x": 1, "y": 2} - self.assertEquals(module._prepare_outputs(outputs, as_dict=True), outputs) - with self.assertRaisesRegexp(TypeError, r"Use as_dict=True."): - self.assertEquals(module._prepare_outputs(outputs, as_dict=False), - outputs) - - -class GetStateScopeTest(tf.test.TestCase): - - def testGetStateScope(self): - with tf.Graph().as_default(): - self.assertEqual(module._try_get_state_scope("a"), "a/") - self.assertEqual(module._try_get_state_scope("a"), "a_1/") - - def testGetStateScope_UsesVariableScope(self): - with tf.Graph().as_default(): - self.assertEqual(module._try_get_state_scope("a"), "a/") - with tf.compat.v1.variable_scope(None, default_name="a") as vs: - self.assertEqual(vs.name, "a_1") - - def testGetStateScope_UsesNameScope(self): - with tf.Graph().as_default(): - self.assertEqual(module._try_get_state_scope("a"), "a/") - with tf.compat.v1.name_scope("a") as ns: - self.assertEqual(ns, "a_1/") - - def testGetStateScope_UnusedNameScope(self): - with tf.Graph().as_default(): - self.assertEqual(module._try_get_state_scope("a", False), "a/") - with tf.compat.v1.name_scope("a") as ns: - self.assertEqual(ns, "a/") - - self.assertEqual(module._try_get_state_scope("a", False), "a_1/") - with tf.compat.v1.name_scope("a") as ns: - self.assertEqual(ns, "a_1/") - - def testGetStateScope_AlreadyUsedNameScope(self): - with tf.Graph().as_default(): - with tf.compat.v1.name_scope("a"): - pass - with self.assertRaisesRegexp(RuntimeError, - "name_scope was already taken"): - module._try_get_state_scope("a", False) - - def testGetStateScopeWithActiveScopes(self): - with tf.Graph().as_default(): - with tf.compat.v1.name_scope("foo"): - abs_scope = module._try_get_state_scope("a", False) - self.assertEqual(abs_scope, "a/") - with tf.compat.v1.name_scope(abs_scope) as ns: - self.assertEqual(ns, "a/") - - with tf.Graph().as_default(): - with tf.compat.v1.variable_scope("vs"): - self.assertEqual(module._try_get_state_scope("a", False), "vs/a/") - with tf.compat.v1.name_scope(name="a") as ns: - self.assertEqual(ns, "vs/a/") - - with tf.Graph().as_default(): - with tf.compat.v1.name_scope("foo"): - with tf.compat.v1.variable_scope("vs"): - self.assertEquals(module._try_get_state_scope("a", False), "vs/a/") - - -class _ModuleSpec(module_spec.ModuleSpec): - - def get_tags(self): - return [set(), set(["special"])] - - def get_signature_names(self, tags=None): - if tags == set(["special"]): - return iter(["default", "extra", "sparse", "ragged"]) - else: - return iter(["default"]) - - def get_input_info_dict(self, signature=None, tags=None): - if signature == "ragged" and tags == set(["special"]): - result = { - "x": - tensor_info.ParsedTensorInfo.from_type_spec( - type_spec=tf.RaggedTensorSpec( - shape=[None, None, None, 3], dtype=tf.float32, - ragged_rank=2)), - } - else: - result = { - "x": - tensor_info.ParsedTensorInfo( - tf.float32, - tf.TensorShape([None]), - is_sparse=(signature == "sparse" and - tags == set(["special"]))), - } - if tags == set(["special"]) and signature == "extra": - result["y"] = result["x"] - return result - - def get_output_info_dict(self, signature=None, tags=None): - result = { - "default": tensor_info.ParsedTensorInfo( - tf.float32, - tf.TensorShape([None]), - is_sparse=False), - } - if tags == set(["special"]) and signature == "extra": - result["z"] = result["default"] - return result - - def _create_impl(self, name, trainable, tags): - return _ModuleImpl(name, trainable) - - # native_module_test.py covers setting and getting attached messages. - def _get_attached_bytes(self, key, tags): - del key, tags # Unused. - return None - - -class _ModuleImpl(module_impl.ModuleImpl): - - def __init__(self, name, trainable): - super().__init__() - with tf.compat.v1.variable_scope(name): - pass - - def create_apply_graph(self, signature, input_tensors, name): - with tf.compat.v1.name_scope(name): - if signature == "sparse": - input_tensors = { - key: tf.compat.v1.sparse_tensor_to_dense(value) - for key, value in input_tensors.items() - } - result = {"default": 2 * input_tensors["x"]} - if signature == "extra": - result["z"] = 2 * input_tensors["x"] + 3 * input_tensors["y"] - return result - - def export(self, path, session): - raise NotImplementedError() - - @property - def variable_map(self): - raise NotImplementedError() - - -class ModuleTest(tf.test.TestCase): - - def testModuleSingleInput(self): - with tf.Graph().as_default(): - m = module.Module(_ModuleSpec()) - result = m([1, 2]) - with tf.compat.v1.Session() as session: - self.assertAllEqual(session.run(result), [2, 4]) - - def testModuleDictInput(self): - with tf.Graph().as_default(): - m = module.Module(_ModuleSpec()) - result = m({"x": [1, 2]}) - with tf.compat.v1.Session() as session: - self.assertAllEqual(session.run(result), [2, 4]) - - def testModuleDictOutput(self): - with tf.Graph().as_default(): - m = module.Module(_ModuleSpec()) - result = m([1, 2], as_dict=True) - self.assertIsInstance(result, dict) - self.assertAllEqual(list(result.keys()), ["default"]) - - def testModuleInNestedScope(self): - with tf.Graph().as_default(): - with tf.compat.v1.variable_scope("foo"): - m = module.Module(_ModuleSpec()) - result = m([1, 2]) - with tf.compat.v1.Session() as session: - self.assertAllEqual(session.run(result), [2, 4]) - - def testModuleInterfaceGettersDefaultSignatureAndTags(self): - with tf.Graph().as_default(): - m = module.Module(_ModuleSpec()) - self.assertItemsEqual(m.get_signature_names(), ["default"]) - self.assertItemsEqual(m.get_input_info_dict().keys(), ["x"]) - self.assertItemsEqual(m.get_output_info_dict().keys(), ["default"]) - - def testModuleInterfaceGettersExplicitSignatureAndTags(self): - """Tests that tags from Module(...) apply to module.get_*().""" - with tf.Graph().as_default(): - m = module.Module(_ModuleSpec(), tags={"special"}) - self.assertItemsEqual(m.get_signature_names(), - ["default", "extra", "sparse", "ragged"]) - self.assertItemsEqual(m.get_input_info_dict(signature="extra").keys(), - ["x", "y"]) - self.assertItemsEqual(m.get_output_info_dict(signature="extra").keys(), - ["z", "default"]) - - -class EvalFunctionForModuleTest(tf.test.TestCase): - """Tests for hub.eval_function_for_module(...). - - This tests that hub.eval_function_for_module parses input variables, - signatures and tags correctly and that it returns the correct output. - End-to-end tests with the native module are done in native_module_test.py. - """ - - def testSingleInput(self): - with module.eval_function_for_module(_ModuleSpec()) as f: - self.assertAllEqual(f([1, 2]), [2, 4]) - - def testSparseInput(self): - with module.eval_function_for_module(_ModuleSpec(), tags={"special"}) as f: - self.assertAllEqual( - f(tf.compat.v1.SparseTensorValue([[0]], [1], [2]), # Value is [1, 0]. - signature="sparse"), - [2, 0]) - - # TODO(b/273203177): Enable once dtype can be propagated to numpy. - # def testRaggedInput(self): - # with module.eval_function_for_module(_ModuleSpec(), tags={"special"}) as f: - # rt = tf.compat.v1.ragged.constant_value( - # [[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9]]], - # [[[20, 21, 22], [30, 35, 38], [0, 2, 0]]]], - # ragged_rank=2) - - # self.assertAllEqual(f(rt, signature="ragged").to_list(), - # [[[[2, 4, 6], [8, 10, 12]], [[14, 16, 18]]], - # [[[40, 42, 44], [60, 70, 76], [0, 4, 0]]]]) - - def testDictInput(self): - with module.eval_function_for_module(_ModuleSpec()) as f: - self.assertAllEqual(f({"x": [1, 2]}), [2, 4]) - - def testDictOutput(self): - with module.eval_function_for_module(_ModuleSpec()) as f: - result = f({"x": [1, 2]}, as_dict=True) - self.assertTrue(isinstance(result, dict)) - self.assertAllEqual(list(result.keys()), ["default"]) - - def testSignature(self): - with module.eval_function_for_module(_ModuleSpec()) as f: - self.assertAllEqual(f([1, 2]), [2, 4]) - - def testExplicitSignatureAndTags(self): - with module.eval_function_for_module(_ModuleSpec(), tags={"special"}) as f: - result = f(dict(x=[1], y=[2]), signature="extra", as_dict=True) - self.assertAllEqual(result["default"], [2]) - self.assertAllEqual(result["z"], [8]) - - -if __name__ == "__main__": - tf.test.main() diff --git a/tensorflow_hub/module_v2.py b/tensorflow_hub/module_v2.py index 09c32994..81d82e5a 100644 --- a/tensorflow_hub/module_v2.py +++ b/tensorflow_hub/module_v2.py @@ -17,9 +17,16 @@ import os import tensorflow as tf -from tensorflow_hub import native_module from tensorflow_hub import registry +_MODULE_PROTO_FILENAME_PB = "tfhub_module.pb" + + +def _get_module_proto_path(module_dir): + return os.path.join( + tf.compat.as_bytes(module_dir), + tf.compat.as_bytes(_MODULE_PROTO_FILENAME_PB)) + def resolve(handle): """Resolves a module handle into a path. @@ -91,8 +98,7 @@ def load(handle, tags=None, options=None): if not isinstance(handle, str): raise ValueError("Expected a string, got %s" % handle) module_path = resolve(handle) - is_hub_module_v1 = tf.io.gfile.exists( - native_module.get_module_proto_path(module_path)) + is_hub_module_v1 = tf.io.gfile.exists(_get_module_proto_path(module_path)) if tags is None and is_hub_module_v1: tags = [] diff --git a/tensorflow_hub/module_v2_test.py b/tensorflow_hub/module_v2_test.py index 937b05cc..b281d323 100644 --- a/tensorflow_hub/module_v2_test.py +++ b/tensorflow_hub/module_v2_test.py @@ -17,7 +17,6 @@ import os from absl.testing import parameterized import tensorflow as tf -import tensorflow_hub as hub from tensorflow_hub import module_v2 @@ -32,69 +31,16 @@ def plus_one(x): tf.saved_model.save(obj, path) -def _save_plus_one_hub_module_v1(path): - - def plus_one(): - x = tf.compat.v1.placeholder(dtype=tf.float32, name='x') - y = x + 1 - hub.add_signature(inputs=x, outputs=y) - - spec = hub.create_module_spec(plus_one) - - with tf.compat.v1.Graph().as_default(): - module = hub.Module(spec, trainable=True) - with tf.compat.v1.Session() as session: - session.run(tf.compat.v1.global_variables_initializer()) - module.export(path, session) - - -def _save_sparse_plus_one_hub_module_v1(path): - - def plus_one(): - x = tf.compat.v1.sparse.placeholder(dtype=tf.float32, name='x') - y = tf.identity(tf.SparseTensor(x.indices, x.values + 1, x.dense_shape)) - hub.add_signature(inputs=x, outputs=y) - - spec = hub.create_module_spec(plus_one) - - with tf.compat.v1.Graph().as_default(): - module = hub.Module(spec, trainable=True) - with tf.compat.v1.Session() as session: - session.run(tf.compat.v1.global_variables_initializer()) - module.export(path, session) - - -def _save_ragged_plus_one_hub_module_v1(path): - - def plus_one(): - x = tf.compat.v1.ragged.placeholder( - dtype=tf.float32, ragged_rank=1, value_shape=[], name='x') - y = tf.identity(x + 1) - hub.add_signature(inputs=x, outputs=y) - - spec = hub.create_module_spec(plus_one) - - with tf.compat.v1.Graph().as_default(): - module = hub.Module(spec, trainable=True) - with tf.compat.v1.Session() as session: - session.run(tf.compat.v1.global_variables_initializer()) - module.export(path, session) - - class ModuleV2Test(tf.test.TestCase, parameterized.TestCase): @parameterized.named_parameters( - ('v1_implicit_tags', 'hub_module_v1_mini', None, True), - ('v1_explicit_tags', 'hub_module_v1_mini', [], True), - ('v2_implicit_tags', 'saved_model_v2_mini', None, False), - ('v2_explicit_tags', 'saved_model_v2_mini', ['serve'], False), + ('v2_implicit_tags', None, False), + ('v2_explicit_tags', ['serve'], False), ) - def test_load(self, module_name, tags, is_hub_module_v1): + def test_load(self, tags, is_hub_module_v1): + module_name = 'saved_model_v2_mini' export_dir = os.path.join(self.get_temp_dir(), module_name) - if module_name == 'hub_module_v1_mini': - _save_plus_one_hub_module_v1(export_dir) - else: - _save_plus_one_saved_model_v2(export_dir) + _save_plus_one_saved_model_v2(export_dir) m = module_v2.load(export_dir, tags) self.assertEqual(m._is_hub_module_v1, is_hub_module_v1) @@ -105,39 +51,6 @@ def test_load_incomplete_model_fails(self): with self.assertRaisesRegex(ValueError, 'contains neither'): module_v2.load(temp_dir) - def test_load_sparse(self): - if any(tf.__version__.startswith(bad) for bad in ['1.', '2.0.']): - self.skipTest('load_v1_in_v2 did not handle sparse tensors correctly' - 'in TensorFlow version %r.' % (tf.__version__,)) - export_dir = os.path.join(self.get_temp_dir(), 'sparse') - _save_sparse_plus_one_hub_module_v1(export_dir) - m = module_v2.load(export_dir) - self.assertTrue(m._is_hub_module_v1) - plus_one = m.signatures['default'] - st = tf.sparse.from_dense([[1.0, 2.0, 0.0], [0.0, 3.0, 0.0]]) - actual = plus_one( - default_indices=st.indices, - default_values=st.values, - default_dense_shape=st.dense_shape)['default'] - expected = [2.0, 3.0, 4.0] - self.assertAllEqual(actual.values, expected) - - def test_load_ragged(self): - if any(tf.__version__.startswith(bad) for bad in - ['1.', '2.0.', '2.1.', '2.2.', '2.3.']): - self.skipTest('load_v1_in_v2 did not handle composite tensors correctly' - 'in TensorFlow version %r.' % (tf.__version__,)) - export_dir = os.path.join(self.get_temp_dir(), 'ragged') - _save_ragged_plus_one_hub_module_v1(export_dir) - m = module_v2.load(export_dir) - self.assertTrue(m._is_hub_module_v1) - plus_one = m.signatures['default'] - rt = tf.ragged.constant([[1.0, 8.0], [3.0]]) - actual = plus_one(default_component_0=rt.values, - default_component_1=rt.row_splits)['default'] - expected = [2.0, 9.0, 4.0] - self.assertAllEqual(actual.values, expected) - def test_load_without_string(self): with self.assertRaisesRegex(ValueError, 'Expected a string, got.*'): module_v2.load(0) diff --git a/tensorflow_hub/native_module.py b/tensorflow_hub/native_module.py deleted file mode 100644 index 20952ac8..00000000 --- a/tensorflow_hub/native_module.py +++ /dev/null @@ -1,1180 +0,0 @@ -# Copyright 2018 The TensorFlow Hub Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""The implementation of deprecated hub.Module backed by TF1 Hub format.""" - -import collections -import os -import re - -from absl import logging -import tensorflow as tf - -from tensorflow_hub import meta_graph_lib -from tensorflow_hub import module_def_pb2 -from tensorflow_hub import module_impl -from tensorflow_hub import module_spec -from tensorflow_hub import saved_model_lib -from tensorflow_hub import tensor_info -from tensorflow_hub import tf_utils - -# pylint: disable=g-direct-tensorflow-import -from tensorflow.core.protobuf import meta_graph_pb2 - -# TODO(b/72732111): Get this APIs or similar functionality to be public. -# They are needed to identify the "state-ops" in a graph and to load C -# registered ops into the python register for import_meta_graph to succeed -# without having to do "import library_that_register_missing_op". -# pylint: disable=g-bad-import-order -from tensorflow.core.framework import op_def_pb2 -from tensorflow.core.framework import types_pb2 -from tensorflow.python.framework import op_def_registry -from tensorflow.python import pywrap_tensorflow as c_api -# pylint: enable=g-bad-import-order -# pylint: enable=g-direct-tensorflow-import - -# Align `op_def_registry` API between TensorFlow 1.X and 2.X. -if not hasattr(op_def_registry, "get"): - - def get(name): - registered_ops = op_def_registry.get_registered_ops() - return registered_ops.get(name) - - op_def_registry.get = get - -if not hasattr(op_def_registry, "sync"): - - def _remove_non_deprecated_descriptions(op_def): - for input_arg in op_def.input_arg: - input_arg.description = "" - for output_arg in op_def.output_arg: - output_arg.description = "" - for attr in op_def.attr: - attr.description = "" - - op_def.summary = "" - op_def.description = "" - - def sync(): - p_buffer = c_api.TF_GetAllOpList() - cpp_op_list = op_def_pb2.OpList() - cpp_op_list.ParseFromString(c_api.TF_GetBuffer(p_buffer)) - - registered_ops = op_def_registry.get_registered_ops() - for op_def in cpp_op_list.op: - # If an OpList is registered from a gen_*_ops.py, it does not any - # descriptions. Strip them here as well to satisfy validation in - # register_op_list. - _remove_non_deprecated_descriptions(op_def) - registered_ops[op_def.name] = op_def - - op_def_registry.sync = sync - -_MODULE_PROTO_FILENAME_PB = "tfhub_module.pb" - -_MODULE_V3_SUPPORTED_FEATURES = frozenset([]) # None yet. - -_SUPPORTED_COLLECTIONS = set([ - # GLOBAL_VARIABLES, TRAINABLE_VARIABLES and MODEL_VARIABLES hold - # tf.Variable objects saved in CollectionDef.bytes_list as serialized - # VariableDef proto. - tf.compat.v1.GraphKeys.GLOBAL_VARIABLES, - tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES, - tf.compat.v1.GraphKeys.MODEL_VARIABLES, - # This holds tf.Operation objects, saved in CollectionDef.node_list. - tf.compat.v1.GraphKeys.TABLE_INITIALIZERS, - # This holds tf.Tensor objects, saved in CollectionDef.node_list. - tf.compat.v1.GraphKeys.UPDATE_OPS, - # This holds tf.Tensor objects, saved in CollectionDef.node_list. - # These are imported to help fine-tuning (unlike LOSSES, which the - # importing model redefines from scratch). - tf.compat.v1.GraphKeys.REGULARIZATION_LOSSES, - # This holds constant tensors of type string. - tf.compat.v1.GraphKeys.ASSET_FILEPATHS, - # This holds serialized CondContextDef protos in CollectionDef.bytes_list. - tf.compat.v1.GraphKeys.COND_CONTEXT, - # This holds serialized WhileContextDef protos in CollectionDef.bytes_list. - tf.compat.v1.GraphKeys.WHILE_CONTEXT, - # saved_model_lib uses this collection internally for ModuleAttachments. - saved_model_lib.ATTACHMENT_COLLECTION_SAVED, -]) - - -def get_module_proto_path(module_dir): - return os.path.join( - tf.compat.as_bytes(module_dir), - tf.compat.as_bytes(_MODULE_PROTO_FILENAME_PB)) - - -class Loader(object): - """Loader for Hub modules in the native format.""" - - def is_supported(self, path): - return True - - def _get_module_def_proto(self, path): - module_def_path = get_module_proto_path(path) - module_def_proto = module_def_pb2.ModuleDef() - with tf.compat.v1.gfile.Open(module_def_path, "rb") as f: - module_def_proto.ParseFromString(f.read()) - return module_def_proto - - def _module_def_proto_to_module_spec(self, path): - saved_model_handler = saved_model_lib.load(path) - checkpoint_filename = saved_model_lib.get_variables_path(path) - return _ModuleSpec(saved_model_handler, checkpoint_filename) - - def __call__(self, path): - module_def_proto = self._get_module_def_proto(path) - - if module_def_proto.format != module_def_pb2.ModuleDef.FORMAT_V3: - raise ValueError("Unsupported module def format: %r" % - module_def_proto.format) - - required_features = set(module_def_proto.required_features) - unsupported_features = (required_features - _MODULE_V3_SUPPORTED_FEATURES) - - if unsupported_features: - raise ValueError("Unsupported features: %r" % list(unsupported_features)) - - return self._module_def_proto_to_module_spec(path) - - -def create_module_spec(module_fn, tags_and_args=None, drop_collections=None): - """Creates a ModuleSpec from a function that builds the module's graph. - - Warning: Deprecated. This belongs to the hub.Module API and TF1 Hub format. - For TF2, switch to plain SavedModels. - - The `module_fn` is called on a new graph (not the current one) to build the - graph of the module and define its signatures via `hub.add_signature()`. - Example: - - ```python - # Define a text embedding module. - def my_text_module_fn(): - text_input = tf.placeholder(dtype=tf.string, shape=[None]) - embeddings = compute_embedding(text_input) - hub.add_signature(inputs=text_input, outputs=embeddings) - ``` - - See `add_signature()` for documentation on adding multiple input/output - signatures. - - NOTE: The `module_fn` is called on a graph that uses resource variables - by default. If you want old-style variables ("ref variables"), then - you can use `with tf.variable_scope("", use_resource=False)` in `module_fn`. - - Multiple graph variants can be defined by using the `tags_and_args` argument. - For example, the code: - - ```python - hub.create_module_spec( - module_fn, - tags_and_args=[({"train"}, {"is_training":True}), - (set(), {"is_training":False})]) - ``` - - calls `module_fn` twice, once as `module_fn(is_training=True)` and once as - `module_fn(is_training=False)` to define the respective graph variants: - for training with tags {"train"} and for inference with the empty set of tags. - Using the empty set aligns the inference case with the default in - Module.__init__(). - - THIS FUNCTION IS DEPRECATED. - - Args: - module_fn: a function to build a graph for the Module. - tags_and_args: Optional list of tuples (tags, kwargs) of tags and keyword - args used to define graph variants. If omitted, it is interpreted as - [(set(), {})], meaning `module_fn` is called once with no args. - drop_collections: list of collection to drop. - - Returns: - A ModuleSpec. - - Raises: - ValueError: if it fails to construct the ModuleSpec due to bad or - unsupported values in the arguments or in the graphs constructed by - `module_fn`. - """ - if not drop_collections: - drop_collections = [] - - report_tags = True - if not tags_and_args: - tags_and_args = [(set(), {})] - report_tags = False - - saved_model_handler = saved_model_lib.SavedModelHandler() - for tags, args in tags_and_args: - with tf.Graph().as_default() as graph: - with tf.compat.v1.variable_scope("", use_resource=True): - module_fn(**args) - - for collection_key in drop_collections: - del tf.compat.v1.get_collection_ref(collection_key)[:] - - err = find_state_op_colocation_error(graph, tags if report_tags else None) - if err: raise ValueError(err) - saved_model_handler.add_graph_copy(graph, tags=tags) - - return _ModuleSpec(saved_model_handler, checkpoint_variables_path=None) - - -def add_signature(name=None, inputs=None, outputs=None): - """Adds a signature to the module definition. - - Warning: Deprecated. This belongs to the hub.Module API and TF1 Hub format. - For TF2, switch to plain SavedModels. - - NOTE: This must be called within a `module_fn` that is defining a hub.Module. - - THIS FUNCTION IS DEPRECATED. - - Args: - name: Signature name as a string. If omitted, it is interpreted as 'default' - and is the signature used when `Module.__call__` `signature` is not - specified. - inputs: A dict from input name to Tensor or composite tensor (such as - SparseTensor or RaggedTensor) to feed when applying the signature. If a - single tensor is passed, it is interpreted as a dict with a single - 'default' entry. - outputs: A dict from output name to Tensor or composite tensor (such as - SparseTensor or RaggedTensor) to return from applying the signature. If a - single tensor is passed, it is interpreted as a dict with a single - 'default' entry. - - Raises: - ValueError: if the arguments are invalid. - """ - if not name: - name = "default" - if inputs is None: - inputs = {} - if outputs is None: - outputs = {} - if not isinstance(inputs, dict): - inputs = {"default": inputs} - if not isinstance(outputs, dict): - outputs = {"default": outputs} - message = find_signature_inputs_from_multivalued_ops(inputs) - if message: logging.error(message) - message = find_signature_input_colocation_error(name, inputs) - if message: raise ValueError(message) - message = find_signature_type_errors(name, inputs, outputs) - if message: raise ValueError(message) - saved_model_lib.add_signature(name, inputs, outputs) - - -def attach_message(key, message): - """Adds an attached message to the module definition. - - Warning: Deprecated. This belongs to the hub.Module API and TF1 Hub format. - For TF2, switch to plain SavedModels. - - NOTE: This must be called within a `module_fn` that is defining a hub.Module. - - See ModuleSpec.get_attached_message() for an introduction to attached messages - and the API for module consumers. - - To define a new type of attached message: - - * Select a reasonably descriptive name as a unique key. For now, keys must - be valid Python identifiers that start with a letter. Punctuation besides - underscores ('_') is reserved for future use in hierarchical names. - - * Define a Protocol Buffer message type to store the value for the key. - (Use generic containers like google.protobuf.Value only if running - the protocol compiler is infeasible for your build process.) - - * For module consumers, consider providing a small library that encapsulates - the specific call to get_attached_message() behind a higher-level - interface and supplies the right message type for parsing. - - Attached messages work best for few messages of moderate size. - Avoid a large number of messages; use repetition within messages instead. - Avoid large messages (megabytes); consider module assets instead. - - For modules with multiple graph versions, each graph version stores separately - what was attached from within the call to `module_fn` that defines its graph. - - THIS FUNCTION IS DEPRECATED. - - Args: - key: A string with the unique key to retrieve this message. Must start - with a letter and contain only letters, digits and underscores. If used - repeatedly within one invocation of `module_fn`, then only the message - from the final call will be returned by `get_attached_message()`. - message: A protocol message object, to be stored in serialized form. - - Raises: - ValueError: if `key` is not a string of the form of a Python identifier. - """ - if not re.match(r"[a-zA-Z][a-zA-Z0-9_]*$", key): - raise ValueError( - "hub.attach_message() called with malformed key '%s'" % key) - saved_model_lib.attach_bytes(key, message.SerializeToString()) - - -class _ModuleSpec(module_spec.ModuleSpec): - """ModuleSpec for Hub's native Module format (backed by SavedModel).""" - - def __init__(self, saved_model_handler, checkpoint_variables_path, - check_collections=True): - """Private constructor. - - Args: - saved_model_handler: SavedModelHandler backing up this Module definition. - checkpoint_variables_path: An optional string to the checkpoint where this - Module variables are checkpointed. If given the variables initializers - are overridden to load from it. - check_collections: Whether to check collections are supported. - - Raises: - ValueError: if SavedModel contains any unexpected value. - """ - check_unique_tags(saved_model_handler.get_tags()) - if check_collections: - check_collections_are_supported( - saved_model_handler, _SUPPORTED_COLLECTIONS) - self._saved_model_handler = saved_model_handler - self._checkpoint_variables_path = checkpoint_variables_path - self._module_attachments = { - tags: saved_model_handler.get_attached_bytes_map(tags) - for tags in saved_model_handler.get_tags()} - - def get_tags(self): - return self._saved_model_handler.get_tags() - - def get_signature_names(self, tags=None): - meta_graph = self._saved_model_handler.get_meta_graph(tags=tags) - return list(meta_graph.signature_def.keys()) - - def get_input_info_dict(self, signature=None, tags=None): - signature_def = self._get_signature_def(signature, tags) - return tensor_info.parse_tensor_info_map(signature_def.inputs) - - def get_output_info_dict(self, signature=None, tags=None): - signature_def = self._get_signature_def(signature, tags) - return tensor_info.parse_tensor_info_map(signature_def.outputs) - - def _get_signature_def(self, signature, tags): - meta_graph = self._saved_model_handler.get_meta_graph(tags=tags) - if signature is None: - signature = "default" - signature_def = meta_graph.signature_def.get(signature) - if signature_def is None: - raise ValueError("Signature %r is missing from meta graph." % signature) - return signature_def - - def _get_attached_bytes(self, key, tags): - return self._module_attachments[frozenset(tags or [])].get(key) - - def _create_impl(self, name, trainable, tags): - meta_graph = self._saved_model_handler.get_meta_graph(tags=tags) - return _ModuleImpl( - spec=self, - meta_graph=meta_graph, - trainable=trainable, - checkpoint_path=self._checkpoint_variables_path, - name=name) - - def _export(self, path, variables_saver): - """Internal. - - Args: - path: string where to export the module to. - variables_saver: an unary-function that writes the module variables - checkpoint on the given path. - """ - self._saved_model_handler.export(path, variables_saver=variables_saver) - - module_def_proto = module_def_pb2.ModuleDef() - module_def_proto.format = module_def_pb2.ModuleDef.FORMAT_V3 - module_def_filename = get_module_proto_path(path) - tf_utils.atomic_write_string_to_file( - module_def_filename, - module_def_proto.SerializeToString(), - overwrite=False) - logging.info("Exported TF-Hub module to: %s", path) - - -class _ModuleImpl(module_impl.ModuleImpl): - """A Module instantiation backed by a MetaGraphDef.""" - - def __init__(self, spec, meta_graph, trainable, checkpoint_path, name): - """Private constructor. - - Args: - spec: _ModuleSpec instance. - meta_graph: MetaGraphDef to use - trainable: whether module is trainable. - checkpoint_path: None or a string to the variables checkpoints. - name: variable and scope name where to instantiate the Module. Must be an - unused name scope. - """ - self._spec = spec - self._meta_graph = meta_graph - self._trainable = trainable - self._checkpoint_path = checkpoint_path - - register_ops_if_needed({ - op.name for op in self._meta_graph.meta_info_def.stripped_op_list.op}) - - if _is_tpu_graph_function(): - # TODO(b/129142908): Hub should not use `tf.init_scope` since that makes - # it incompatible with tf.compat.v1.wrap_function. For now the only use - # case where hub used it was for tpu compatibility. This should be cleaned - # up at an early convinience. - scope_func = tf.init_scope - else: - scope_func = lambda: tf.control_dependencies(None) - - # Clear dependencies so modules can be constructed from deep inside - # functions that have dependencies active. Note that the dependencies - # would be active when applying the Module signature, just not active - # when creating the Module state. This use case has showed up in some - # TPU training code. - with scope_func(): - self._init_state(name) - - def _init_state(self, name): - variable_tensor_map, self._state_map = self._create_state_graph(name) - self._variable_map = recover_partitioned_variable_map( - get_node_map_from_tensor_map(variable_tensor_map)) - if self._variable_map and self._checkpoint_path: - tf.compat.v1.train.init_from_checkpoint(self._checkpoint_path, - self._variable_map) - - # Build Saver so it can be used later on to export the variables. - if self._variable_map: - self._saver = tf.compat.v1.train.Saver( - self._variable_map, - sharded=True, - write_version=tf.compat.v1.train.SaverDef.V2) - else: - self._saver = None - - def _create_state_graph(self, name): - """Creates the graph nodes that hold the state of the Module. - - Args: - name: name scope to create the state graph in. - - Returns: - A tuple consisting of: - variables_tensor_map: a map from tensor names in the original graph def - to the created Variables objects. - state_map: a map from tensors names in the original graph def to the - instantiated tensors to be used as a state_map. - """ - import_collections = [ - tf.compat.v1.GraphKeys.GLOBAL_VARIABLES, - tf.compat.v1.GraphKeys.MODEL_VARIABLES, - tf.compat.v1.GraphKeys.TABLE_INITIALIZERS, - # A typical use of assets is a vocab file to initialize a table. - tf.compat.v1.GraphKeys.ASSET_FILEPATHS, - tf.compat.v1.GraphKeys.COND_CONTEXT, - tf.compat.v1.GraphKeys.WHILE_CONTEXT, - ] - if self._trainable: - # TODO(b/64049014): Import UPDATE_OPS which do not depend on inputs. - import_collections.extend([tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES, - tf.compat.v1.GraphKeys.REGULARIZATION_LOSSES]) - - absolute_scope_name = tf.compat.v1.get_default_graph().unique_name( - name, mark_as_used=False) - relative_scope_name = absolute_scope_name.split("/")[-1] - assert relative_scope_name == name # verify name scope was indeed unused. - - meta_graph = meta_graph_pb2.MetaGraphDef() - meta_graph.CopyFrom(self._meta_graph) - - meta_graph_lib.filter_collections(meta_graph, import_collections) - meta_graph_lib.prefix_shared_name_attributes(meta_graph, - absolute_scope_name) - - tf.compat.v1.train.import_meta_graph( - meta_graph, - input_map={}, - import_scope=relative_scope_name) - - # Build a list from the variable name in the module definition to the actual - # instantiated variables. - variables_tensor_map = {} - for var in tf.compat.v1.global_variables(): - if var.op.name.startswith(absolute_scope_name + "/"): - variables_tensor_map[var.name[len(absolute_scope_name)+1:]] = var - - # Build a map of tensors to feed from the state-graph into subsequent - # apply-graphs. - def _get_tensor(tensor_name): - return tf.compat.v1.get_default_graph().get_tensor_by_name( - meta_graph_lib.prepend_name_scope( - tensor_name, import_scope=absolute_scope_name)) - - state_op_names = list_registered_stateful_ops_without_inputs( - meta_graph.graph_def) - state_map = get_state_map(meta_graph, state_op_names, set(), _get_tensor) - - return variables_tensor_map, state_map - - def create_apply_graph(self, signature, input_tensors, name): - """See `ModuleImpl.create_apply_graph`.""" - signature_def = self._meta_graph.signature_def.get(signature) - meta_graph = meta_graph_pb2.MetaGraphDef() - meta_graph.CopyFrom(self._meta_graph) - apply_graph = tf.compat.v1.get_default_graph() - infeed_map = tensor_info.build_input_map(signature_def.inputs, - input_tensors) - - # Build a input map to feed when importing the apply-graph by augmenting the - # state_map with the input args. This allows an input to override a tensor - # from the state-graph. - feed_map = dict(self._state_map) - # If we are applying the module in a function with a TPUReplicateContext, we - # must capture the state tensors in generating our feedmap and prune out - # assign ops. Function graph semantics are different in that all ops are - # executed regardless of dependency. - # TODO(b/112575006): The following adds functionality of function call - # within a TPU context. Work to generalize this for all function calls is - # ongoing. - if _is_tpu_graph_function(): - for k, v in self._state_map.items(): - feed_map[k] = apply_graph.capture(v) - meta_graph_lib.prune_unused_nodes(meta_graph, signature_def) - # After we prune the metagraph def, we might need to prune away - # infeeds which no longer exist. - meta_graph_lib.prune_feed_map(meta_graph, infeed_map) - elif apply_graph.building_function: - # Log a warning if a user is using a hub module in function graph. - # This is only expected to work if the function graph is pruned and - # not all nodes are executed. - # - # E.g. it could work with "tf.compat.v1.wrap_function", but it will not - # work with defun, Dataset.map_fn, etc... - logging.warning( - "Using TF1 Hub format while building a function: %s. " - "This can lead to errors if the function is not pruned.", - apply_graph.name) - - # As state ops in the apply graph are unused, replace them with Placeholders - # so that in a heirarchical instantiation, apply_graph state ops are - # ignored. - replace_apply_state( - meta_graph, - list_registered_stateful_ops_without_inputs(meta_graph.graph_def), - feed_map) - feed_map.update(infeed_map) - - # Make state tensors enter the current context. This way the Module can be - # applied inside a control flow structure such as a while_loop. - control_flow = apply_graph._get_control_flow_context() # pylint: disable=protected-access - if control_flow: - for key, value in sorted(feed_map.items()): - feed_map[key] = control_flow.AddValue(value) - - # Don't mark the name as used at this point - import_scoped_meta_graph will - # start using it. - absolute_scope_name = apply_graph.unique_name(name, mark_as_used=False) - relative_scope_name = absolute_scope_name.split("/")[-1] - - import_collections = [ - # In most cases ASSET_FILEPATHS are only used for the TABLE_INITIALIZERS - # ops, however one could create a graph that uses an asset at any other - # time. As so everytime we bring the tensor with that has the asset - # filename we must annotate it as so, so later re-exports have that - # semantic information and can handle it. - tf.compat.v1.GraphKeys.ASSET_FILEPATHS, - tf.compat.v1.GraphKeys.COND_CONTEXT, - tf.compat.v1.GraphKeys.WHILE_CONTEXT, - ] - if self._trainable: - import_collections.extend([tf.compat.v1.GraphKeys.UPDATE_OPS]) - - meta_graph_lib.filter_collections(meta_graph, import_collections) - meta_graph_lib.prefix_shared_name_attributes(meta_graph, - absolute_scope_name) - if len(meta_graph.collection_def) and _is_tpu_graph_function(): - raise NotImplementedError( - "Applying modules with collections inside TPU functions is not " - "supported. Collections found: %s" % str(meta_graph.collection_def)) - - tf.compat.v1.train.import_meta_graph( - meta_graph, - input_map=feed_map, - import_scope=relative_scope_name) - fix_colocation_after_import(input_map=feed_map, - absolute_import_scope=absolute_scope_name) - - def get_tensor(name): - # When trying to output an input tensor there are no nodes created within - # the apply scope. So one must look into the input map. - try: - return feed_map[name] - except KeyError: - return apply_graph.get_tensor_by_name( - meta_graph_lib.prepend_name_scope( - name, import_scope=absolute_scope_name)) - - return tensor_info.build_output_map(signature_def.outputs, get_tensor) - - def export(self, path, session): - """See `Module.export`.""" - def variables_saver(variables_path): - if self._saver: - self._saver.save( - session, variables_path, - write_meta_graph=False, - write_state=False) - - self._spec._export(path, variables_saver) # pylint: disable=protected-access - - @property - def variable_map(self): - """See `Module.variable_map`.""" - return self._variable_map - - -def is_registered_stateful_op_without_inputs(name): - """Checks if an op is registered, stateful and does not expect inputs.""" - op_def = op_def_registry.get(name) - return op_def is not None and (op_def.is_stateful and not op_def.input_arg) - - -def list_registered_stateful_ops_without_inputs(graph_def): - """Returns set of registered stateful ops that do not expect inputs. - - This list is used to identify the ops to be included in the state-graph and - that are subsequently fed into the apply-graphs. - - Args: - graph_def: GraphDef to list ops from. - - Returns: - A set of strings. - """ - used_ops = (node.op for node in graph_def.node) - return {op for op in used_ops if is_registered_stateful_op_without_inputs(op)} - - -def get_state_map(meta_graph, state_ops, unsupported_state_ops, - get_tensor_by_name): - """Returns a map from tensor names to tensors that hold the state.""" - state_map = {} - for node in meta_graph.graph_def.node: - if node.op in state_ops: - tensor_name = node.name + ":0" - tensor = get_tensor_by_name(tensor_name) - num_outputs = len(tensor.op.outputs) - if num_outputs != 1: - raise ValueError("Stateful op %s has %d outputs, expected 1" % - (node.op, num_outputs)) - state_map[tensor_name] = tensor - if node.op in unsupported_state_ops: - raise ValueError("Unsupported stateful op: %s" % node.op) - return state_map - - -def replace_apply_state(meta_graph, state_ops, feed_map): - """Replaces state ops with non state Placeholder ops for the apply graph.""" - for node in meta_graph.graph_def.node: - keys_to_purge = [] - tensor_name = node.name + ":0" - # Verify that the node is a state op and that its due to be rewired - # in the feedmap. - if node.op in state_ops and tensor_name in feed_map: - node.op = "Placeholder" - for key in node.attr: - # Only shape and dtype are required for Placeholder. Remove other - # attributes. - if key != "shape": - keys_to_purge.append(key) - for key in keys_to_purge: - del node.attr[key] - node.attr["dtype"].type = types_pb2.DT_RESOURCE - - -def get_node_map_from_tensor_map(tensor_map): - """Converts the keys from tensor name to node name. - - Args: - tensor_map: Map where keys are full tensor names and values are tensors. - - Returns: - Map same as tensor_map, except keys have the output_number stripped. - """ - return { - _split_tensor_name(key)[0]: value - for key, value in tensor_map.items() - } - - -def _split_tensor_name(tensor_name): - """Given a tensor name as node_name:output_number, returns both parts.""" - result = re.match(r"(.*):(\d+)$", tensor_name) - if not result: - raise ValueError( - "Unexpected format for tensor name. Expected node_name:output_number. " - "Got %r" % tensor_name) - return result.group(1), int(result.group(2)) - - -def _extract_variable_parts(variable_key, variable): - """Matches a variable to individual parts. - - Args: - variable_key: String identifier of the variable in the module scope. - variable: Variable tensor. - - Returns: - partitioned: Whether the variable is partitioned. - name: Name of the variable up to the partitioning. - offset: Offset of the variable into the full variable. - - Raises: - RuntimeError: In case of unexpected variable format. - """ - name, offset, partitioned = None, None, False - # pylint: disable=protected-access - if variable._save_slice_info: - name = variable_key[:variable_key.rfind("/")] - if not variable._save_slice_info.full_name.endswith(name): - raise RuntimeError("Unexpected handling of partitioned variable.") - offset = variable._save_slice_info.var_offset[0] - partitioned = True - # pylint: enable=protected-access - return partitioned, name, offset - - -def recover_partitioned_variable_map(var_node_map): - """Builds a proper variable map if it contains PartitionedVariables. - - Args: - var_node_map: A map to tf.Variables. PartitionedVariables show up in this - map as N entries with keys "/part_n". - - Returns: - A map to tf.Variables or to list of tf.Variables for each - PartitionedVariables in `var_node_map`. - - Raises: - RuntimeError: if there are issues recovering the PartitionedVariables. - """ - offset_variables_map = {} - for var_key, var_tensor in var_node_map.items(): - match, var_name, offset = _extract_variable_parts(var_key, var_tensor) - - if not match: - # This is a standard variable, so we can safely add it to the output. - if var_key in offset_variables_map: - raise RuntimeError( - "Variable %s exists both as a single and partitioned variable.") - offset_variables_map[var_key] = var_tensor - continue - - if var_name not in offset_variables_map: - offset_variables_map[var_name] = {} - elif not isinstance(offset_variables_map[var_name], dict): - raise RuntimeError( - "Variable %s exists both as a single and partitioned variable.") - - # Duplicated variable offsets should not exist. - if offset in offset_variables_map[var_name]: - raise RuntimeError( - "Variable map contains duplicate offset %d for variable [%s]" % - (offset, var_name)) - offset_variables_map[var_name][offset] = var_tensor - - variables_map = {} - # Use offsets for sorting, then strip them from the dictionary and keep only - # a list of variables per each variable name. - for var_name, var_value in offset_variables_map.items(): - if not isinstance(var_value, dict): - variables_map[var_name] = var_value - continue - shapes = [var_tensor.shape[1:] for var_tensor in var_value.values()] - if not all(shape == shapes[0] for shape in shapes): - raise RuntimeError("Shapes not compatible: %s" % (shapes)) - for _, tensor in sorted(var_value.items()): - variables_map[var_name] = [ - tensor for _, tensor in sorted(var_value.items()) - ] - - return variables_map - - -def check_unique_tags(tag_list): - """Checks that tag list contains each set of tags only once.""" - frozen_tags_seen = set() - for tags in tag_list: - frozen_tags = frozenset(tags) - if frozen_tags in frozen_tags_seen: - raise ValueError("Tags %r used repeatedly" % tags) - frozen_tags_seen.add(frozen_tags) - - -def check_collections_are_supported(saved_model_handler, supported): - """Checks that SavedModelHandler only uses supported collections.""" - for meta_graph in saved_model_handler.meta_graphs: - used_collection_keys = set(meta_graph.collection_def.keys()) - unsupported = used_collection_keys - supported - if unsupported: - raise ValueError("Unsupported collections in graph: %s\n" - "Use hub.create_module_spec(..., drop_collections=[...])" - " as appropriate." % list(unsupported)) - - -def get_unsupported_collections(used_collection_keys): - return list(set(used_collection_keys) - _SUPPORTED_COLLECTIONS) - - -def register_ops_if_needed(graph_ops): - """Register graph ops absent in op_def_registry, if present in c++ registry. - - Args: - graph_ops: set with graph op names to register. - - Raises: - tf.errors.NotFoundError: if `graph_ops` contains ops that are not in either - python or c++ registry. - """ - if all(op_def_registry.get(op) is not None for op in graph_ops): - return - - # Note: Only raise missing op ValueError after trying to load ops. - # This allows the test to exercise all the calls into TensorFlow - # without having to write a C + python test. - op_def_registry.sync() - missing_ops = {op for op in graph_ops if op_def_registry.get(op) is None} - if missing_ops: - raise tf.errors.NotFoundError( - None, None, - "Graph ops missing from the python registry (%s) are also absent from " - "the c++ registry." % missing_ops) - - -def fix_colocation_after_import(input_map, absolute_import_scope): - """Fixes colocation attributes after import according to input_map. - - This function is meant to be called after importing a GraphDef, in order - to rewrite colocate_with constrains analogous to how inputs to ops - are rewritten by input_map during import. It also updates devices accordingly. - - The nodes in the given import scope of the current default graph have their - colocation attributes (that is, the "loc:@..." values in the "_class" attr) - rewritten as follows: If, before the call, op x has attribute loc:@y, and - `input_map` replaces an output of y with an output of z, then loc:@y gets - replaced by the colocation attributes of z (that is, loc:@z, if no other - constraints are in play). - - This style of rewriting imposes the following requirements: - - * If an output of node y is an input tensor in a signature of the module, - y must not have any colocation attributes on it, such that colocations - with y are expressed by loc:@y and can be adjusted with a rewriting rule - for it. Function `find_signature_input_colocation_error()` checks this - during module creation. - - * If y1 is a state node, its colocation constraints must only reference - other state nodes, say, y2. Since all outputs of state nodes are mapped - the same way, all their rewriting rules together will do the same thing. - Function `find_state_op_colocation_error()` checks this during module - creation. - - * Other nodes may have arbitrary colocation attributes. - - Mapping of inputs works with tensors, while colocation constraints work with - ops. Issues may arise when mapping tensors from ops with multiple outputs. - If the outputs of y are replaced by outputs of distinct ops z1, z2, ..., - rewriting of loc:@y becomes ambiguous unless z1, z2, ... have equal - colocation_groups) If some but not all outputs of y are replaced, it - becomes ambiguous whether to rewrite loc:@y at all. For now, this is - handled conservatively by raising an error (instead of rewriting to the - union of all applicable constraints). This should be very rare: all state - ops so far have single outputs (and even if not, the rewriting would be - consistent); input ops usually are placeholders, which have single outputs. - - Args: - input_map: a dict mapping from tensor names in the imported graph to - existing Tensors, typically the same as passed to tf.import_graph_def(). - absolute_import_scope: a string with the full name of the import scope, - comprising the current scope when import_graph_def() as called plus - the import_scope passed to it. - - Raises: - ValueError: if one imported op has its multiple outputs and they are - remapped in a way that causes conflicting colocation rewrites. - """ - attr_map = _build_colocation_attr_map(input_map, absolute_import_scope) - _apply_colocation_attr_map(attr_map, absolute_import_scope) - - -class _ConsistentValue(object): - """Helper for deferred consistency checking for values from multiple sources. - - Suppose you compute some value from multiple sources that should all be - consistent. This class helps you store the value (with context on sources) - and provides a getter method to either get the consistent value or raise - a exception with a meaningful custom error message. - - Usage example: - - remainder = _ConsistentValue() - for x in (105, 205, 305, 406): - remainder.Set(x % 100, {"x": x}) - print(remainder.GetConsistentValueOrRaise( - "Got {old_value} at {old_x} but became {new_value} at {new_x}.")) - - will print "5" three times and then raise - ValueError("Got 5 at 105 but became 6 at 406."). - """ - def __init__(self): - self.has_error = False - self.value = None - self._context = {} - - def Set(self, value, context=None): - """Receives a value for the object and some context on its source.""" - if self.has_error: return - if self.value is None: - self.value = value - self._context["old_value"] = value - self._context.update({"old_" + k: v for k, v in context.items()}) - elif self.value != value: - self.has_error = True - self._context["new_value"] = value - self._context.update({"new_" + k: v for k, v in context.items()}) - - def GetConsistentValueOrRaise(self, error_format, context=None): - """Gets consistent value or raises ValueError with formatted contexts.""" - if self.has_error: - full_context = dict(self._context) - if context: full_context.update(context) - raise ValueError(error_format.format(**full_context)) - return self.value - - -def _build_colocation_attr_map(input_map, absolute_import_scope): - """Returns a dict mapping from pre-import to post-import colocation attrs. - - Args: - input_map: as for fix_colocation_after_import. - absolute_import_scope: as for fix_colocation_after_import. - - Returns: - A dict that maps bytes `"loc:@" + absolute_import_scope + "/foo"` - to _ConsistentValues set to the lists of bytes `["loc:@...", ...]` - according to the rewriting scheme of fix_colocation_after_import. - In case of an inconsistent rewriting, _ConsistentValue.has_error is true. - """ - colocation_attr_map = collections.defaultdict(_ConsistentValue) - used_outputs_of_imported_ops = collections.defaultdict(set) - # Collect mappings from the input_map. - for imported_tensor_name, mapped_tensor in input_map.items(): - imported_tensor_name = absolute_import_scope + "/" + imported_tensor_name - imported_op_name, imported_index = _split_tensor_name(imported_tensor_name) - key = tf.compat.as_bytes("loc:@" + imported_op_name) - colocation_attr_map[key].Set( - mapped_tensor.op.colocation_groups(), - {"reason": "input '%s' is substituted by '%s'" % ( - imported_tensor_name, mapped_tensor.name)}) - used_outputs_of_imported_ops[imported_op_name].add(imported_index) - # Add unchanged mappings for additional, non-remapped outputs of ops touched - # by the input_map. For now, these just signal inconsistency when used. - for imported_op_name, used_outputs in used_outputs_of_imported_ops.items(): - imported_op = tf.compat.v1.get_default_graph().get_operation_by_name( - imported_op_name) - unused_outputs = set(range(len(imported_op.outputs))) - used_outputs - if not unused_outputs: continue - key = tf.compat.as_bytes("loc:@" + imported_op_name) - if imported_op.colocation_groups() != [key]: - # This should never happen: state nodes are remapped fully, input nodes - # are prevented from having colocation attributes. - raise ValueError( - "Internal error: tensors from op '%s' are partially remapped in " - "import but op.colocation_groups=%s cannot be captured in a " - "simple rewrite rule." % - (imported_op_name, imported_op.colocation_groups())) - colocation_attr_map[key].Set( - [key], - {"reason": "tensor '%s:%s' is not substituted by inputs" % ( - imported_op_name, - ",".join(str(i) for i in sorted(unused_outputs)))}) - - return colocation_attr_map - - -def _apply_colocation_attr_map(colocation_attr_map, absolute_import_scope): - """Rewrites colocation constraints in the current default graph. - - Nodes in `absolute_import_scope` get their "_class" attr lists rewritten - according to `colocation_attr_map`: each entry that matches a key gets - replaced by the associated values (with deduplication). The node's device - is updated accordingly. - - Args: - colocation_attr_map: as returned by _build_colocation_attr_map. - absolute_import_scope: as for fix_colocation_after_import. - - Raises: - ValueError: if rewriting runs into an inconsistent value in - `colocation_attr_map`. - """ - graph = tf.compat.v1.get_default_graph() - for op in graph.get_operations(): - # Rewrite the values of the "_class" attr that store colocation constraints. - # NOTE: The colocation_group loc:@X of a node with itself is not stored - # explicitly as an attr, so rewrite errors for loc:@X are not triggered - # by the mere existence of X. - if not op.name.startswith(absolute_import_scope + "/"): continue - try: - class_values = op.get_attr("_class") - except ValueError: - continue # No _class attr found; nothing to do. - new_attr_value = tf.compat.v1.AttrValue() - new_coloc_groups = [] - for class_value in class_values: - if class_value.startswith(tf.compat.as_bytes("loc:@")): - if class_value not in colocation_attr_map: - rewritten_class_value = [class_value] - else: - rewritten_class_value = (colocation_attr_map[ - class_value].GetConsistentValueOrRaise( - "Failed to rewrite colocation constraints while applying " - "hub.Module:\n" - "The module graph contains a node {op!r} " - "that has a colocation constraint {class_value!r} " - "with ambiguous rewriting {old_value!r} vs {new_value!r} " - "because {old_reason} and {new_reason}, respectively.\n" - "To fix, avoid publishing a module with inputs comprising " - "multiple outputs of one op that is referenced in " - "tf.colocate_with(...) constraints on other ops.", - {"op": op.name, "class_value": class_value})) - new_coloc_groups.extend(rewritten_class_value) - else: - new_attr_value.list.s.append(class_value) - new_coloc_groups = sorted(set(new_coloc_groups)) - new_attr_value.list.s.extend(new_coloc_groups) - op._set_attr("_class", new_attr_value) # pylint: disable=protected-access - - # Mimic the code of tf.import_graph_def(): If there are colocation - # constraints, use any of them to set the device (overriding what the - # device function stack would do), without attempting to merge or check for - # equality. If they were inconsistent, TensorFlow's C++ runtime would fail - # anyways due to conflicting colocation constraints. - # Note that Hub imports GraphDefs with devices cleared, so this code deals - # with the result of import_graph_def, not a setting saved in the module. - if new_coloc_groups: - new_coloc_device = "" - for new_coloc_group in new_coloc_groups: - assert new_coloc_group.startswith(tf.compat.as_bytes("loc:@")) - new_coloc_target_op = graph.get_operation_by_name( - tf.compat.as_str_any(new_coloc_group[5:])) - new_coloc_device = new_coloc_target_op.device - if new_coloc_device: break - # Set this, even if empty, to avoid retaining an outdated value. - op._set_device(new_coloc_device) # pylint: disable=protected-access - - -def find_state_op_colocation_error(graph, reported_tags=None): - """Returns error message for colocation of state ops, or None if ok.""" - state_op_types = list_registered_stateful_ops_without_inputs( - graph.as_graph_def()) - state_op_map = {op.name: op for op in graph.get_operations() - if op.type in state_op_types} - for op in state_op_map.values(): - for colocation_group in op.colocation_groups(): - if not (colocation_group.startswith(tf.compat.as_bytes("loc:@")) and - tf.compat.as_str_any(colocation_group[5:]) in state_op_map): - tags_prefix = ("" if reported_tags is None else - "in the graph for tags %s, " % reported_tags) - return ( - "A state-holding node x of a module's graph (e.g., a Variable op) " - "must not be subject to a tf.colocate_with(y) constraint " - "unless y is also a state-holding node.\n" - "Details: %snode '%s' has op '%s', which counts as state-holding, " - "but Operation.colocation_groups() == %s. " % - (tags_prefix, op.name, op.type, op.colocation_groups())) - return None - - -def find_signature_input_colocation_error(signature_name, inputs): - """Returns error message for colocation of signature inputs, or None if ok.""" - for input_name, tensor in inputs.items(): - ops = [t.op for t in tf.nest.flatten(tensor, expand_composites=True)] - for op in ops: - expected_colocation_groups = [tf.compat.as_bytes("loc:@" + op.name)] - if op.colocation_groups() != expected_colocation_groups: - return ( - "A tensor x used as input in a signature must not be subject to a " - "tf.colocate_with(y) constraint. (The reverse would be allowed.)\n" - "Details: tensor '%s' appears %s input '%s' of signature '%s' " - "but has Tensor.op.colocation_groups() == %s" % - (tensor, ("as" if len(ops) == 1 else "in"), input_name, - signature_name, op.colocation_groups())) - return None - - -def find_signature_inputs_from_multivalued_ops(inputs): - """Returns error message for module inputs from ops with multiple outputs.""" - dense_inputs = [] # List of (str, Tensor), with CompositeTensors decomposed. - for name, tensor in sorted(inputs.items()): - if isinstance(tensor, tf.SparseTensor): - dense_inputs.extend(("%s.%s" % (name, attr), getattr(tensor, attr)) - for attr in ("indices", "values", "dense_shape")) - elif tf_utils.is_composite_tensor(tensor): - components = tf.nest.flatten(tensor, expand_composites=True) - dense_inputs.extend(("%s.component_%d" % (name, i), c) - for (i, c) in enumerate(components)) - else: - dense_inputs.append((name, tensor)) - warnings = [(name, tensor.name) for name, tensor in dense_inputs - if len(tensor.op.outputs) != 1] - if warnings: - return ( - "WARNING: The inputs declared in hub.add_signature() should be tensors " - "from ops with a single output, or else uses of tf.colocate_with() on " - "that op can trigger fatal errors when the module is applied and " - "colocation constraints have to be rewritten.\nAffected inputs: %s" % - ", ".join("%s='%s'" % pair for pair in warnings)) - return None - - -def find_signature_type_errors(signature_name, inputs, outputs): - """Return error message for inputs or outputs with incorrect types.""" - errors = ([("input", name, tensor) - for name, tensor in sorted(inputs.items()) - if not isinstance(tensor, tf_utils.SUPPORTED_ARGUMENT_TYPES)] + - [("output", name, tensor) - for name, tensor in sorted(outputs.items()) - if not isinstance(tensor, tf_utils.SUPPORTED_ARGUMENT_TYPES)]) - if errors: - ok_types = ", ".join(t.__name__ for t in tf_utils.SUPPORTED_ARGUMENT_TYPES) - bad_types = "\n".join(" * %s '%s' has type %s" % - (source, name, type(value).__name__) - for (source, name, value) in errors) - return ( - "The inputs and outputs declared in hub.add_signature() for signature " - "'%s' should have one of the types that are supported by this version " - "of tensorflow_hub: %s.\n%s" % (signature_name, ok_types, bad_types)) - return None - - -def _is_tpu_graph_function(): - graph = tf.compat.v1.get_default_graph() - return (graph.building_function and - type(graph._get_control_flow_context()).__name__.endswith( # pylint: disable=protected-access - "TPUReplicateContext")) diff --git a/tensorflow_hub/native_module_test.py b/tensorflow_hub/native_module_test.py deleted file mode 100644 index d4cdfc25..00000000 --- a/tensorflow_hub/native_module_test.py +++ /dev/null @@ -1,1978 +0,0 @@ -# Copyright 2018 The TensorFlow Hub Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for tensorflow_hub.native_module.""" - -import os - -import numpy as np -import tensorflow as tf -import tensorflow_hub as hub - -from tensorflow_hub import module_def_pb2 -from tensorflow_hub import native_module -from tensorflow_hub import tf_utils - -# pylint: disable=g-import-not-at-top -# Use Keras 2. -version_fn = getattr(tf.keras, "version", None) -if version_fn and version_fn().startswith("3."): - import tf_keras # pylint: disable=unused-import - from tf_keras.api._v1 import keras as tf_keras_v1 # pylint: disable=unused-import -else: - tf_keras = tf.keras # Keras 2 - tf_keras_v1 = tf.compat.v1.keras - -# pylint: disable=g-direct-tensorflow-import -from tensorflow.python.framework import function -from tensorflow.python.framework import test_util -from tensorflow.python.ops.control_flow_ops import ControlFlowContext -from tensorflow.python.ops.lookup_ops import HashTable -from tensorflow.python.ops.lookup_ops import index_to_string_table_from_file -from tensorflow.python.ops.lookup_ops import KeyValueTensorInitializer -# pylint: enable=g-direct-tensorflow-import - - -def load_module_spec(spec): - """Force use of native_module implementation.""" - return native_module.Loader()(spec) - - -def multi_signature_module(): - x = tf.compat.v1.placeholder(tf.float32, shape=[None]) - native_module.add_signature("double", {"x": x}, {"y": 2*x}) - - z = tf.compat.v1.placeholder(tf.float32, shape=[None]) - native_module.add_signature("square", {"z": z}, {"z_out": z*z}) - - -def batch_norm_module(training): - x = tf.compat.v1.placeholder(tf.float32, shape=[None, 3]) - y = tf_keras_v1.__internal__.legacy.layers.batch_normalization( - x, training=training - ) - native_module.add_signature(inputs=x, outputs=y) - - -def module_with_variables(): - tf.compat.v1.get_variable( - name="weights", - shape=[3], - initializer=tf.compat.v1.zeros_initializer()) - tf.compat.v1.get_variable( - name="partition", - shape=[4], - initializer=tf.compat.v1.zeros_initializer(), - partitioner=tf.compat.v1.fixed_size_partitioner(3)) - hub.add_signature(outputs=tf.constant(1.0)) - - -class NativeModuleTest(tf.test.TestCase): - - def testModuleWithMissingRequiredFeature(self): - path = os.path.join(self.get_temp_dir(), "required-feature") - tf.compat.v1.gfile.MakeDirs(path) - proto_path = native_module.get_module_proto_path(path) - with tf.compat.v1.gfile.Open(proto_path, mode="wb") as f: - module_def_proto = module_def_pb2.ModuleDef() - module_def_proto.format = module_def_pb2.ModuleDef.FORMAT_V3 - module_def_proto.required_features.extend(["foo-test-missing"]) - f.write(module_def_proto.SerializeToString()) - with self.assertRaisesRegexp(ValueError, "foo-test-missing"): - load_module_spec(path) - - def testMultiSignatureSpec(self): - spec = native_module.create_module_spec(multi_signature_module) - self.assertAllEqual(sorted(spec.get_signature_names()), - ["double", "square"]) - self.assertAllEqual(list(spec.get_input_info_dict("double").keys()), ["x"]) - self.assertAllEqual(list(spec.get_output_info_dict("double").keys()), ["y"]) - self.assertAllEqual(list(spec.get_input_info_dict("square").keys()), ["z"]) - self.assertAllEqual(list(spec.get_output_info_dict("square").keys()), - ["z_out"]) - - def testDefaultTagSpec(self): - spec = native_module.create_module_spec(multi_signature_module) - self.assertAllEqual(sorted(spec.get_tags()), [set()]) - - def testMultiTagSpec(self): - spec = native_module.create_module_spec( - batch_norm_module, - [({"training"}, {"training": True}), - ({"inference"}, {"training": False})]) - self.assertAllEqual(sorted(spec.get_tags()), - [set(["training"]), set(["inference"])]) - - def testModuleWithVariablesAndNoCheckpoint(self): - with tf.Graph().as_default(): - spec = native_module.create_module_spec(module_with_variables) - spec._create_impl(name="module", trainable=False, tags=None) - self.assertAllEqual( - [x.op.name for x in tf.compat.v1.global_variables()], - [ - "module/weights", - "module/partition/part_0", - "module/partition/part_1", - "module/partition/part_2", - ]) - - with tf.compat.v1.Session() as session: - session.run(tf.compat.v1.initializers.global_variables()) - expected_values = [ - [0.0, 0.0, 0.0], - [0.0, 0.0], - [0.0], - [0.0], - ] - for a, b in zip(session.run(tf.compat.v1.global_variables()), - expected_values): - self.assertAllEqual(a, b) - - def testNoSignaturesPresent(self): - - def wrong_module_fn(): - x = tf.compat.v1.placeholder(tf.float32, shape=[None, 3]) - return tf.identity(x) - - with self.assertRaises(ValueError) as cm: - spec = native_module.create_module_spec(wrong_module_fn) - self.assertIn("No signatures present", str(cm.exception)) - - def testUnsupportedCollections(self): - - def module_fn(): - scale = tf.compat.v1.get_variable("x", (), collections=["my_scope"]) - x = tf.compat.v1.placeholder(tf.float32, shape=[None, 3]) - native_module.add_signature("my_func", {"x": x}, {"y": x*scale}) - - with self.assertRaises(ValueError) as cm: - _ = native_module.create_module_spec(module_fn) - self.assertIn("Unsupported collections in graph", cm) - - with tf.Graph().as_default() as tmp_graph: - module_fn() - unsupported_collections = native_module.get_unsupported_collections( - tmp_graph.get_all_collection_keys()) - self.assertEqual(["my_scope"], unsupported_collections) - - _ = native_module.create_module_spec( - module_fn, drop_collections=unsupported_collections) - - -class RecoverPartitionedVariableMapTest(tf.test.TestCase): - - def testRecoverPartitionedVariableMap(self): - with tf.Graph().as_default(): - with tf.compat.v1.variable_scope("test"): - partitioner = tf.compat.v1.fixed_size_partitioner(3) - tf.compat.v1.get_variable( - initializer=tf.ones([11, 5]), - name="partitioned_variable", - partitioner=partitioner) - tf.compat.v1.get_variable( - initializer=tf.ones([11, 5]), - name="normal_variable") - - all_vars = tf.compat.v1.global_variables() - all_vars_dict = {var.op.name[5:]: var for var in all_vars} - self.assertEqual(set(all_vars_dict.keys()), set([ - "partitioned_variable/part_0", - "partitioned_variable/part_1", - "partitioned_variable/part_2", - "normal_variable"])) - - self.assertEqual(len(all_vars_dict), 4) - var_map = native_module.recover_partitioned_variable_map(all_vars_dict) - self.assertEqual(set(var_map.keys()), set([ - "partitioned_variable", "normal_variable"])) - - # Verify order of the partitioned variable list - self.assertAllEqual( - [v.op.name for v in var_map["partitioned_variable"]], - [ - "test/partitioned_variable/part_0", - "test/partitioned_variable/part_1", - "test/partitioned_variable/part_2", - ]) - - -def stateless_module_fn(): - x = tf.compat.v1.placeholder(tf.int64) - y = x*x - hub.add_signature(inputs=x, outputs=y) - - -def unused_input_module_fn(): - x = tf.compat.v1.placeholder(tf.int64) - y = tf.compat.v1.placeholder(tf.int64) - result = x*x - hub.add_signature( - inputs={"x": x, "unused": y}, - outputs=result) - - -def double_module_fn(): - w = tf.Variable(2.0) - x = tf.compat.v1.placeholder(dtype=tf.float32) - hub.add_signature(inputs=x, outputs=x*w) - - -def create_partitioned_variable_module_fn(partitions, shape): - """Returns a module summing one normal and one partitioned variable.""" - def module_fn(): - """A module summing one normal and one partitioned variable.""" - partitioner = tf.compat.v1.fixed_size_partitioner(partitions) - var_1 = tf.compat.v1.get_variable( - initializer=tf.ones(shape), - name="partitioned_variable", - partitioner=partitioner) - var_2 = tf.compat.v1.get_variable( - initializer=tf.ones(shape), name="normal_variable") - hub.add_signature(outputs=var_1 + var_2) - - return module_fn - - -class TFHubStatelessModuleTest(tf.test.TestCase): - - def testLoadModuleFromFuncDef(self): - with tf.compat.v1.Session() as sess: - v = tf.compat.v1.placeholder(tf.int64) - spec = hub.create_module_spec(stateless_module_fn) - m = hub.Module(spec) - y = m(v) - self.assertEqual(sess.run(y, feed_dict={v: 10}), 100) - - def testUnusedInputModule(self): - with tf.compat.v1.Session() as sess: - v1 = tf.compat.v1.placeholder(tf.int64) - v2 = tf.compat.v1.placeholder(tf.int64) - spec = hub.create_module_spec(unused_input_module_fn) - m = hub.Module(spec) - out = m({"x": v1, "unused": v2}) - self.assertEqual(sess.run(out, feed_dict={v1: 10, v2: 4}), 100) - - def testConvertToTensor(self): - spec = hub.create_module_spec(stateless_module_fn) - with tf.compat.v1.Session() as sess: - m = hub.Module(spec) - y = m([10, 2]) - self.assertAllEqual(sess.run(y), [100, 4]) - with tf.compat.v1.Session() as sess: - m = hub.Module(spec) - with self.assertRaises(TypeError): - m("hello") - - def testArgErrors(self): - spec = hub.create_module_spec(stateless_module_fn) - with tf.compat.v1.Session(): - m = hub.Module(spec) - with self.assertRaisesRegexp(TypeError, "missing"): - m() - - @test_util.run_v1_only("b/138681007") - def testUseWithinWhileLoop(self): - with tf.Graph().as_default(): - spec = hub.create_module_spec(double_module_fn) - m = hub.Module(spec) - i = tf.constant(0) - x = tf.constant(10.0) - p = tf.compat.v1.placeholder(dtype=tf.int32) - c = lambda i, x: tf.less(i, p) - b = lambda i, x: (tf.add(i, 1), m(x)) - oi, ox = tf.while_loop(c, b, [i, x]) # ox = v**p * x - v = m.variables[0] - dodv = tf.gradients(ox, v)[0] # d ox / dv = p*v**(p-1) * x - dodx = tf.gradients(ox, x)[0] # d ox / dx = v**p - with tf.compat.v1.Session() as sess: - sess.run(tf.compat.v1.global_variables_initializer()) - self.assertAllEqual(sess.run([oi, ox], feed_dict={p: 1}), [1, 20]) - self.assertAllEqual(sess.run([oi, ox], feed_dict={p: 2}), [2, 40]) - self.assertAllEqual(sess.run([oi, ox], feed_dict={p: 4}), [4, 160]) - # Gradients also use the control flow structures setup earlier. - # Also check they are working properly. - self.assertAllEqual(sess.run([dodv, dodx], feed_dict={p: 1}), [10, 2]) - self.assertAllEqual(sess.run([dodv, dodx], feed_dict={p: 2}), [40, 4]) - self.assertAllEqual(sess.run([dodv, dodx], feed_dict={p: 4}), [320, 16]) - - # tf.map_fn() is merely a wrapper around tf.while(), but just to be sure... - @test_util.run_v1_only("b/138681007") - def testUseWithinMap(self): - with tf.Graph().as_default(): - spec = hub.create_module_spec(double_module_fn) - m = hub.Module(spec) - x = tf.constant([1.0, 11.0, 101.0]) - y = tf.map_fn(m, x) - with tf.compat.v1.Session() as sess: - sess.run(tf.compat.v1.global_variables_initializer()) - self.assertAllEqual(sess.run(y), [2, 22, 202]) - - def testClearControlDependenciesForModuleStateButNotApplyGraphs(self): - module_spec = hub.create_module_spec(stateless_module_fn) - - with tf.Graph().as_default() as g1: - v = tf.compat.v1.placeholder(dtype=tf.int64, name="v") - m = hub.Module(module_spec) - m(v) - - with tf.Graph().as_default() as g2: - v = tf.compat.v1.placeholder(dtype=tf.int64, name="v") - with tf.control_dependencies([v]): - m = hub.Module(module_spec) - m(v) - - # The MLIR Bridge may arbitrarily apply internal-only attributes to graphs - # that are not intended to be consumed by external users. In this particular - # instance, the bridge will add an internal-only attribute as the result of - # using manual control dependencies above. Filter out the internal-only - # attributes to allow equality checking on the meaningfully public structure - # of the graph. - g2_graph_def = g2.as_graph_def() - for node in g2_graph_def.node: - for attr in list(node.attr.keys()): - if attr.startswith("_"): - del node.attr[attr] - - self.assertEqual(g1.as_graph_def(), g2_graph_def) - - with tf.Graph().as_default() as g3: - v = tf.compat.v1.placeholder(dtype=tf.int64, name="v") - m = hub.Module(module_spec) - m(v) - - with tf.Graph().as_default() as g4: - v = tf.compat.v1.placeholder(dtype=tf.int64, name="v") - m = hub.Module(module_spec) - with tf.control_dependencies([v]): - m(v) - - self.assertNotEqual(g3.as_graph_def(), g4.as_graph_def()) - - -def sparse_square_module_fn(): - x = tf.compat.v1.sparse_placeholder(dtype=tf.int64, name="x") - out = tf.SparseTensor(x.indices, x.values * x.values, x.dense_shape) - hub.add_signature(inputs=x, outputs=out) - - -class TFHubSparseTensorModuleTest(tf.test.TestCase): - - def testSparseTensors(self): - square_spec = hub.create_module_spec(sparse_square_module_fn) - - with tf.Graph().as_default(): - square = hub.Module(square_spec) - v = tf.compat.v1.sparse_placeholder(dtype=tf.int64, name="v") - y = square(v) - - with tf.compat.v1.Session().as_default(): - indices = [[0, 0], [0, 1], [1, 1]] - values = [10, 2, 1] - shape = [2, 2] - v1 = tf.compat.v1.SparseTensorValue(indices, values, shape) - v2 = y.eval(feed_dict={v: v1}) - v4 = y.eval(feed_dict={v: v2}) - - self.assertAllEqual(v4.indices, indices) # Unchanged. - self.assertAllEqual(v4.values, [t**4 for t in values]) # Squared twice. - self.assertAllEqual(v4.dense_shape, shape) # Unchanged. - - -def ragged_square_module_fn(): - x = tf.compat.v1.ragged.placeholder( - tf.int64, ragged_rank=1, value_shape=[], name="x") - out = x.with_values(x.values * x.values) - hub.add_signature(inputs=x, outputs=out) - - -def ragged_combine_fn(x, y, z): - return x + tf.reduce_sum(y) * z - - -def ragged_combine_module_fn(): - x = tf.compat.v1.ragged.placeholder( - tf.int64, ragged_rank=1, value_shape=[], name="x") - y = tf.compat.v1.ragged.placeholder( - tf.int64, ragged_rank=3, value_shape=[2], name="y") - z = tf.compat.v1.placeholder(tf.int64, shape=[], name="z") - - combined = ragged_combine_fn(x, y, z) - hub.add_signature(inputs={"x": x, "y": y, "z": z}, outputs=combined) - - -class TFHubRaggedTensorModuleTest(tf.test.TestCase): - - def testSquare(self): - if tf.__version__.startswith("2.3."): - self.skipTest("tf.compat.v1.ragged.placeholder erroneously adds " - "validation, upsetting hub.Module") - square_spec = hub.create_module_spec(ragged_square_module_fn) - - with tf.Graph().as_default(): - square = hub.Module(square_spec) - v = tf.compat.v1.ragged.placeholder( - tf.int64, ragged_rank=1, value_shape=[], name="v") - y = square(v) - - with tf.compat.v1.Session().as_default() as sess: - v1 = tf.compat.v1.ragged.constant_value([[10, 2], [1]]) - v2 = sess.run(y, feed_dict={v: v1}) - v4 = sess.run(y, feed_dict={v: v2}) - - self.assertAllEqual(v4.row_splits, v1.row_splits) # Unchanged. - self.assertAllEqual(v4.values, - [t**4 for t in v1.values]) # Squared twice. - - def testCombine(self): - if tf.__version__.startswith("2.3."): - self.skipTest("tf.compat.v1.ragged.placeholder erroneously adds " - "validation, upsetting hub.Module") - module_spec = hub.create_module_spec(ragged_combine_module_fn) - - with tf.Graph().as_default(): - combine = hub.Module(module_spec) - a = tf.compat.v1.ragged.placeholder( - tf.int64, ragged_rank=1, value_shape=[], name="a") - b = tf.compat.v1.ragged.placeholder( - tf.int64, ragged_rank=3, value_shape=[2], name="b") - c = tf.compat.v1.placeholder(tf.int64, shape=[], name="c") - out = combine({"x": a, "y": b, "z": c}) - - with tf.compat.v1.Session().as_default() as sess: - a_value = tf.compat.v1.ragged.constant_value([[10, 20], [30, 40, 50]]) - b_value = tf.compat.v1.ragged.constant_value( - [[[[[1, 2], [3, 4]], []], [[[5, 6]]]], - [[[[7, 8], [9, 10], [11, 12]]], [[[13, 14]]]]], - ragged_rank=3) - c_value = np.array(100) - - feed_dict = feed_dict = {a: a_value, b: b_value, c: c_value} - result = sess.run(out, feed_dict=feed_dict) - expected = sess.run(ragged_combine_fn(a, b, c), feed_dict=feed_dict) - self.assertAllEqual(result, expected) - - -class AddSignatureWithUnsupportedTypesTest(tf.test.TestCase): - """Tests that adding signatures w/ unsupported types raises an exception. - - We use IndexedSlice as a convenient example of an unsupported composite - tensor type. If IndexedSlices gets added to the list of supported types, - then these tests would need to be updated. - """ - - def testUnsupportedInputInSignature(self): - - def module_fn(): - x = tf.IndexedSlices( - tf.compat.v1.placeholder(tf.int64, None), - tf.compat.v1.placeholder(tf.int64, None)) - y = x.values - hub.add_signature(inputs={"x": x}, outputs=y) - - with self.assertRaisesRegex(ValueError, "should have one of the types"): - hub.create_module_spec(module_fn) - - def testUnsupportedOutputInSignature(self): - - def module_fn(): - x = tf.compat.v1.placeholder(tf.int64, None) - y = tf.compat.v1.placeholder(tf.int64, None) - z = tf.IndexedSlices(x, y) - hub.add_signature(inputs={"x": x, "y": y}, outputs=z) - - with self.assertRaisesRegex(ValueError, "should have one of the types"): - hub.create_module_spec(module_fn) - - def testUnsupportedInputInCall(self): - """Ensure that Module.__call__ flags unsupported args. - - This could occur if a module is built using version N, which supports some - type, but then loaded with version N-1, which doesn't support that type. - We should be able to *load* the module (since it may contain other - signatures that are supported on version N-1), but we shouldn't be able to - call the signature that uses the unsupported type. - """ - original_supported_arg_types = tf_utils.SUPPORTED_ARGUMENT_TYPES - try: - def module_fn(): - x = tf.compat.v1.sparse.placeholder(tf.int64, [2, 3]) - y = x.values - hub.add_signature(inputs={"x": x}, outputs=y) - - spec = hub.create_module_spec(module_fn) - with tf.Graph().as_default(): - module = hub.Module(spec) - input_tensor = tf.sparse.from_dense([[1, 0, 2], [3, 0, 0]]) - - # Simulate being in a version that doesn't support SparseTensor. - tf_utils.SUPPORTED_ARGUMENT_TYPES = (tf.Tensor,) - with self.assertRaisesRegex( - ValueError, "Signature 'default' expects a SparseTensor for input " - "'x', which is not supported by this version of tensorflow_hub."): - module(input_tensor) - - finally: - tf_utils.SUPPORTED_ARGUMENT_TYPES = original_supported_arg_types - - def testUnsupportedOutputInCall(self): - """Ensure that Module.__call__ flags unsupported outputs.""" - original_supported_arg_types = tf_utils.SUPPORTED_ARGUMENT_TYPES - try: - def module_fn(): - x = tf.compat.v1.placeholder(tf.int64, [4]) - y = tf.RaggedTensor.from_row_lengths(x, [3, 0, 1]) - hub.add_signature(inputs={"x": x}, outputs=y) - - spec = hub.create_module_spec(module_fn) - with tf.Graph().as_default(): - module = hub.Module(spec) - input_tensor = tf.compat.v1.ragged.constant([1, 2, 3, 4]) - - # Simulate being in a version that doesn't support RaggedTensor. - tf_utils.SUPPORTED_ARGUMENT_TYPES = (tf.Tensor,) - with self.assertRaisesRegex( - ValueError, - "Signature 'default' expects a RaggedTensor for output 'default', " - "which is not supported by this version of tensorflow_hub."): - module(input_tensor) - - finally: - tf_utils.SUPPORTED_ARGUMENT_TYPES = original_supported_arg_types - - -def stateful_module_fn(): - v = tf.compat.v1.get_variable( - "var123", shape=[3], - initializer=tf.compat.v1.constant_initializer([1.0, 2.0, 3.0])) - hub.add_signature(outputs=v.value()) - - -def stateful_rv_module_fn(): - r = tf.compat.v1.get_variable( - "rv_var123", shape=[], - initializer=tf.compat.v1.constant_initializer(10.0), - use_resource=True) - hub.add_signature(outputs=r.value()) - - -class TPUReplicateContext(ControlFlowContext): - - def __init__(self): - super().__init__() - self._name = "TPUReplicateContext" - - def AddOp(self, _): - pass - - def AddValue(self, x): - return x - - def to_control_flow_context_def(self, context_def, export_scope=None): - super().to_control_flow_context_def( - context_def, export_scope) - - -def stateful_random_rv_module_fn(): - r = tf.compat.v1.get_variable( - "rv_var123", - shape=[], - initializer=tf.compat.v1.random_uniform_initializer(), - use_resource=True) - hub.add_signature(outputs=r.value()) - - -def stateful_rv_with_input_module_fn(): - x = tf.compat.v1.placeholder(dtype=tf.float32, name="x") - # Add a placeholder/variable that doesn't go to an output. - y = tf.compat.v1.placeholder(dtype=tf.float32, name="y") - r = tf.compat.v1.get_variable( - "rv_var123", - shape=[], - initializer=tf.compat.v1.constant_initializer(10.0), - use_resource=True) - t = tf.compat.v1.get_variable( - "rv_var456", - shape=[], - initializer=tf.compat.v1.constant_initializer(10.0), - use_resource=True) - t.assign(y) - res = x + r - hub.add_signature(inputs={"x": x}, outputs=res) - - -def control_dependency_module_fn(): - const_op = tf.constant(1.0, name="dependency_op") - with tf.control_dependencies([const_op]): - res = tf.constant(3.0) + tf.constant(2.0) - hub.add_signature(inputs={}, outputs=res) - - -def stateful_non_rv_module_fn(): - v = tf.compat.v1.get_variable( - "var123", shape=[], - initializer=tf.compat.v1.constant_initializer(10.0), - use_resource=False) - hub.add_signature(outputs=v.value()) - - -def stateful_module_fn_with_colocation(): - v = tf.compat.v1.get_variable( - "var123", shape=[], - initializer=tf.compat.v1.constant_initializer(1.0), - use_resource=False) - v_value = v.value() - x = tf.compat.v1.placeholder(dtype=tf.float32, name="x") - with tf.compat.v1.colocate_with(v), tf.compat.v1.colocate_with(x): - y = tf.add(v_value, x, name="y") - hub.add_signature(inputs=x, outputs=y) - - -class TFHubStatefulModuleTest(tf.test.TestCase): - - def testVariables(self): - with tf.Graph().as_default(): - spec = hub.create_module_spec(stateful_module_fn) - m = hub.Module(spec, name="test") - out = m() - self.assertEqual(list(m.variable_map.keys()), ["var123"]) - self.assertEqual(m.variable_map["var123"].name, "test/var123:0") - self.assertEqual([v.name for v in m.variables], ["test/var123:0"]) - with tf.compat.v1.Session() as sess: - sess.run(tf.compat.v1.global_variables_initializer()) - self.assertAllClose(sess.run(out), [1.0, 2.0, 3.0]) - - def testResourceVariables(self): - with tf.Graph().as_default(): - spec = hub.create_module_spec(stateful_rv_module_fn) - m = hub.Module(spec, name="test_rv") - out = m() - self.assertEqual(list(m.variable_map.keys()), ["rv_var123"]) - self.assertEqual(m.variable_map["rv_var123"].name, "test_rv/rv_var123:0") - self.assertEqual([v.name for v in m.variables], ["test_rv/rv_var123:0"]) - - # Check that "shared_name" attributes are adapted correctly: - var_handle_op_name = "test_rv/rv_var123" - var_handle_op = tf.compat.v1.get_default_graph().get_operation_by_name( - var_handle_op_name) - self.assertEqual( - var_handle_op.get_attr("shared_name"), - tf.compat.as_bytes(var_handle_op_name)) - - export_path = os.path.join(self.get_temp_dir(), "resource-variables") - with tf.compat.v1.Session() as sess: - sess.run(tf.compat.v1.global_variables_initializer()) - self.assertAllClose(sess.run(out), 10.0) - m.export(export_path, sess) - - with tf.Graph().as_default(): - f = hub.Module(export_path) - out = f() - - # Test colocation constraints on the read op in the apply graph. - # It has two legal values: - # - Colocation with the VarHandleOp in the state graph. - # - No constraint, in which case it reports its own colocation_group. - # This appears to happen at the time of this writing (March 2018) - # because the Python code relies on the TensorFlow core to handle - # VariableReadOps as a special case and colocate them with their - # VarHandleOp input, which is mapped to the state graph. - # In any case, the point is to *not* colocate with the stillborn copy - # of the VarHandleOp in the apply graph scope. - if out.op.colocation_groups() != [ - tf.compat.as_bytes("loc:@" + out.op.name)]: - self.assertItemsEqual(out.op.colocation_groups(), - [tf.compat.as_bytes("loc:@module/rv_var123")]) - - # Check that "shared_name" attributes are adapted correctly: - var_handle_op_name = "module/rv_var123" - var_handle_op = tf.compat.v1.get_default_graph().get_operation_by_name( - var_handle_op_name) - self.assertEqual( - var_handle_op.get_attr("shared_name"), - tf.compat.as_bytes(var_handle_op_name)) - - # Create a saver for the whole graph. - saver = tf.compat.v1.train.Saver() - - with tf.compat.v1.Session() as sess: - sess.run(tf.compat.v1.global_variables_initializer()) - self.assertAllClose(sess.run(out), 10.0) - - # Make sure that the variable names stored in a checkpoint of the graph - # are as expected. - variables_path = os.path.join(self.get_temp_dir(), "variables") - saver.save( - sess, variables_path, write_meta_graph=False, write_state=False) - variable_names_and_shapes = tf.compat.v1.train.list_variables( - ckpt_dir_or_file=variables_path) - variable_names = set(name for name, _ in variable_names_and_shapes) - self.assertEqual(variable_names, {"module/rv_var123"}) - - def testNonResourceVariables(self): - with tf.Graph().as_default(): - spec = hub.create_module_spec(stateful_non_rv_module_fn) - m = hub.Module(spec, name="test_non_rv") - out = m() - self.assertEqual(list(m.variable_map.keys()), ["var123"]) - self.assertEqual(m.variable_map["var123"].name, "test_non_rv/var123:0") - self.assertEqual([v.name for v in m.variables], ["test_non_rv/var123:0"]) - - export_path = os.path.join(self.get_temp_dir(), "non-resource-variables") - with tf.compat.v1.Session() as sess: - sess.run(tf.compat.v1.global_variables_initializer()) - self.assertAllClose(sess.run(out), 10.0) - m.export(export_path, sess) - - with tf.Graph().as_default(): - f = hub.Module(export_path) - out = f() - - # Test that the read op in the apply graph gets colocated with the - # variable in the state graph scope "module/" (and not the stillborn - # copy in the apply graph scope). - self.assertItemsEqual(out.op.colocation_groups(), - [tf.compat.as_bytes("loc:@module/var123")]) - - # Create a saver for the whole graph. - saver = tf.compat.v1.train.Saver() - - with tf.compat.v1.Session() as sess: - sess.run(tf.compat.v1.global_variables_initializer()) - self.assertAllClose(sess.run(out), 10.0) - - # Make sure that the variable names stored in a checkpoint of the - # graph are as expected. - variables_path = os.path.join(self.get_temp_dir(), "variables") - saver.save( - sess, variables_path, write_meta_graph=False, write_state=False) - variable_names_and_shapes = tf.compat.v1.train.list_variables( - ckpt_dir_or_file=variables_path) - variable_names = set(name for name, _ in variable_names_and_shapes) - self.assertEqual(variable_names, {"module/var123"}) - - @test_util.run_v1_only("b/138681007") - def testNonResourceVariableInWhileLoop(self): - with tf.Graph().as_default(): - # This test uses non-Resource variables to see an actual colocation - # constraint propagated to the context Enter op. The long comment on - # colocation in testResourceVariables explains why they may not offer - # that. - spec = hub.create_module_spec(stateful_non_rv_module_fn) - m = hub.Module(spec) - cond = lambda i, x: tf.less(i, 4) - def body(i, x): - v = m() - self.assertItemsEqual(v.op.colocation_groups(), - [tf.compat.as_bytes("loc:@module/var123")]) - return (tf.add(i, 1), 2*x) - oi, ox = tf.while_loop(cond, body, [0, 10.0]) - with tf.compat.v1.Session() as sess: - sess.run(tf.compat.v1.global_variables_initializer()) - self.assertAllEqual(sess.run([oi, ox]), [4, 160.0]) - - @test_util.run_v1_only("b/138681007") - def testNonResourceVariableInCond(self): - with tf.Graph().as_default(): - spec = hub.create_module_spec(stateful_non_rv_module_fn) - m = hub.Module(spec) - pred = tf.compat.v1.placeholder(tf.bool) - def true_fn(): - v = m() - self.assertItemsEqual(v.op.colocation_groups(), - [tf.compat.as_bytes("loc:@module/var123")]) - return v - def false_fn(): - return tf.constant(9.0) - out = tf.cond(pred, true_fn, false_fn) - with tf.compat.v1.Session() as sess: - sess.run(tf.compat.v1.global_variables_initializer()) - self.assertEqual(sess.run(out, feed_dict={pred: True}), 10.0) - self.assertEqual(sess.run(out, feed_dict={pred: False}), 9.0) - - def testVariableColocationPropagation(self): - with tf.Graph().as_default(): - spec = hub.create_module_spec(stateful_module_fn_with_colocation) - m = hub.Module(spec) - u1 = tf.constant(1, name="u1") - u2 = tf.constant(2, name="u2") - with tf.compat.v1.colocate_with(u1), tf.compat.v1.colocate_with(u2): - x = tf.constant(100.0, name="x") - y = m(x) - self.assertItemsEqual(y.op.colocation_groups(), - [tf.compat.as_bytes("loc:@module/var123"), - tf.compat.as_bytes("loc:@u1"), - tf.compat.as_bytes("loc:@u2")]) - with tf.compat.v1.Session() as sess: - sess.run(tf.compat.v1.global_variables_initializer()) - self.assertEqual(sess.run(y), 101.0) - - def testPartitionedVariables(self): - with tf.Graph().as_default(): - spec = hub.create_module_spec( - create_partitioned_variable_module_fn(partitions=3, shape=[7, 3])) - m = hub.Module(spec, name="test") - out = m() - self.assertEqual(len(m.variable_map), 2) - self.assertEqual(m.variable_map["normal_variable"].name, - "test/normal_variable:0") - self.assertAllEqual([ - variable.name for variable in m.variable_map["partitioned_variable"] - ], [ - "test/partitioned_variable/part_0:0", - "test/partitioned_variable/part_1:0", - "test/partitioned_variable/part_2:0" - ]) - self.assertAllEqual( # Check deterministric order (by variable_map key). - [variable.name for variable in m.variables], - ["test/normal_variable:0", - "test/partitioned_variable/part_0:0", - "test/partitioned_variable/part_1:0", - "test/partitioned_variable/part_2:0"]) - with tf.compat.v1.Session() as sess: - sess.run(tf.compat.v1.global_variables_initializer()) - self.assertAllClose(sess.run(out), 2 * np.ones([7, 3])) - - def testLargePartitionedVariables(self): - with tf.Graph().as_default(): - spec = hub.create_module_spec( - create_partitioned_variable_module_fn(partitions=25, shape=[600, 3])) - m = hub.Module(spec, name="test") - out = m() - self.assertEqual(len(m.variable_map), 2) - self.assertEqual(len(m.variable_map["partitioned_variable"]), 25) - self.assertEqual(len(m.variables), 26) - with tf.compat.v1.Session() as sess: - sess.run(tf.compat.v1.global_variables_initializer()) - self.assertAllClose(sess.run(out), 2 * np.ones([600, 3])) - - def testLoadTrainableModuleFromFuncDef(self): - with tf.compat.v1.Session() as sess: - spec = hub.create_module_spec(stateful_module_fn) - m = hub.Module(spec, trainable=True) - x = m() - step = tf.Variable(0, trainable=False, name="global_step") - train_op = tf.compat.v1.train.GradientDescentOptimizer(0.40).minimize( - loss=tf.compat.v1.losses.mean_squared_error(x, [3.1, 3.2, 3.3]), - global_step=step) - sess.run(tf.compat.v1.global_variables_initializer()) - for _ in range(50): - sess.run(train_op) - got = sess.run(x) - self.assertAllClose(got, [3.1, 3.2, 3.3]) - - # TODO(b/112575006): The following tests verify functionality of function call - # within a TPU context. Work to generalize this for all function calls is - # ongoing. - def testTPUModuleInitializeOnceWithDefun(self): - spec = hub.create_module_spec(stateful_random_rv_module_fn) - - @function.Defun() - def import_computation(): - context = TPUReplicateContext() - context.Enter() - m = hub.Module(spec, name="module_", trainable=True) - return [m(), m()] - - with tf.compat.v1.Graph().as_default(), tf.compat.v1.Session() as sess: - x = import_computation() - sess.run(tf.compat.v1.global_variables_initializer()) - got = sess.run(x) - # Check the values are equal. If the initializer ran on each call, - # the values would be different. - self.assertEqual(got[0], got[1]) - - def testTPUPruneWithUnusedInput(self): - spec = hub.create_module_spec(unused_input_module_fn) - - @function.Defun() - def import_computation(x): - context = TPUReplicateContext() - context.Enter() - m = hub.Module(spec, name="module_", trainable=True) - return m({ - "x": tf.cast(x, dtype=tf.int64), - "unused": tf.constant(2, dtype=tf.int64) - }) - - with tf.compat.v1.Graph().as_default(), tf.compat.v1.Session() as sess: - x = import_computation(5) - got = sess.run(x) - self.assertEqual(got, 25) - - def testTPUModuleDoesntPruneControlDependencies(self): - spec = hub.create_module_spec(control_dependency_module_fn) - - @function.Defun() - def import_computation(): - context = TPUReplicateContext() - context.Enter() - m = hub.Module(spec, name="module_", trainable=True) - return m() - - with tf.compat.v1.Graph().as_default(), tf.compat.v1.Session() as sess: - x = import_computation() - got = sess.run(x) - self.assertEqual(got, 5.0) - # If the op got pruned, the following get_operation_by_name should fail - # with a dependency error. - tf.compat.v1.get_default_graph().get_operation_by_name( - "module_/dependency_op") - - def testTPUModuleWithDefun(self): - spec = hub.create_module_spec(stateful_rv_with_input_module_fn) - - @function.Defun() - def import_computation(first, second): - context = TPUReplicateContext() - context.Enter() - m = hub.Module(spec, name="module_", trainable=True) - return [m(first), m(second)] - - with tf.compat.v1.Graph().as_default(), tf.compat.v1.Session() as sess: - x = import_computation(9.0, 6.0) - sess.run(tf.compat.v1.global_variables_initializer()) - got = sess.run(x) - self.assertEqual(got, (19.0, 16.0)) - - def testTPUModuleWithTFEDefun(self): - with tf.compat.v1.Graph().as_default() as graph: - with tf.compat.v1.Session() as sess: - spec = hub.create_module_spec(stateful_rv_with_input_module_fn) - - @tf.function - def import_computation(first, second): - context = TPUReplicateContext() - context.Enter() - m = hub.Module(spec, trainable=True) - return [m(first), m(second)] - - x = import_computation(9.0, 6.0) - sess.run(tf.compat.v1.global_variables_initializer()) - got = sess.run(x) - self.assertEqual(got, [19.0, 16.0]) - - def testTPUModuleWithWrapFunc(self): - spec = hub.create_module_spec(stateful_rv_with_input_module_fn) - - def import_computation(first, second): - context = TPUReplicateContext() - context.Enter() - m = hub.Module(spec, trainable=True) - return [m(first), m(second)] - - with tf.compat.v1.Graph().as_default(), tf.compat.v1.Session() as sess: - x = tf.compat.v1.wrap_function( - import_computation, - [tf.TensorSpec((), tf.float32), - tf.TensorSpec((), tf.float32)]) - sess.run(tf.compat.v1.global_variables_initializer()) - got = sess.run(x(9.0, 6.0)) - self.assertEqual(got, [19.0, 16.0]) - - def _exportModulewithTrainedVariable(self): - export_path = os.path.join(self.get_temp_dir(), "var-module") - with tf.Graph().as_default(): - spec = hub.create_module_spec(stateful_module_fn) - m = hub.Module(spec, trainable=True) - assign_op = tf.compat.v1.assign(m.variable_map["var123"], - tf.constant([9.0, 9.0, 9.0])) - with tf.compat.v1.Session() as sess: - sess.run(assign_op) - m.export(export_path, sess) - return export_path - - def testModuleWithTrainedVariable(self): - with tf.Graph().as_default(): - f = hub.Module(self._exportModulewithTrainedVariable()) - out = f() - with tf.compat.v1.Session() as sess: - sess.run(tf.compat.v1.global_variables_initializer()) - got = sess.run(out) - self.assertAllClose(got, [9.0, 9.0, 9.0]) - - def testModuleEvalWithTrainedVariable(self): - export_path = self._exportModulewithTrainedVariable() - with hub.eval_function_for_module(export_path) as f: - self.assertAllClose(f(), [9.0, 9.0, 9.0]) - - -def table_lookup_module_fn(): - x = tf.compat.v1.placeholder(dtype=tf.int64, name="x") - keys = tf.constant([0, 1, 2], dtype=tf.int64) - values = tf.constant(["index0", "hello", "world"]) - - tbl_init = KeyValueTensorInitializer(keys, values) - table = HashTable(tbl_init, "UNK") - hub.add_signature(inputs=x, outputs=table.lookup(x)) - - -class TFHubTableLookupModuleTest(tf.test.TestCase): - - def _exportModuleWithTable(self): - export_path = os.path.join(self.get_temp_dir(), "table-module") - with tf.Graph().as_default(): - spec = hub.create_module_spec(table_lookup_module_fn) - m = hub.Module(spec) - # Export requires a session to work regardless of the module having no - # variables to export. - with tf.compat.v1.Session() as sess: - m.export(export_path, sess) - return export_path - - def testModuleWithTable(self): - with tf.Graph().as_default(): - v = tf.compat.v1.placeholder(dtype=tf.int64) - f = hub.Module(self._exportModuleWithTable()) - y = f(v) - with tf.compat.v1.Session() as sess: - sess.run(tf.compat.v1.tables_initializer()) - got = sess.run(y, feed_dict={v: [0, 1, 2, 3]}) - self.assertAllEqual(list(got), [b"index0", b"hello", b"world", b"UNK"]) - - def testModuleEvalWithTable(self): - with hub.eval_function_for_module(self._exportModuleWithTable()) as f: - got = f([0, 1, 2, 3]) - self.assertAllEqual(list(got), [b"index0", b"hello", b"world", b"UNK"]) - - -def do_table_lookup(indices, vocabulary_file): - table = index_to_string_table_from_file( - vocabulary_file=vocabulary_file, - default_value="UNKNOWN") - return table.lookup(indices) - - -def layers_module_fn(): - """Module that exercises the use of layers.""" - # This is a plain linear map Mx+b regularized by the sum of the squares - # of the coefficients in M and b. - x = tf.compat.v1.placeholder(dtype=tf.float32, shape=[None, 2], name="x") - def l2(weights): - """Applies l2 regularization to weights.""" - with tf.control_dependencies([weights]): - return 2.0 * tf.compat.v1.nn.l2_loss(weights) - - h = tf_keras_v1.__internal__.legacy.layers.dense( - x, 2, activation=None, kernel_regularizer=l2, bias_regularizer=l2 - ) - hub.add_signature(inputs=x, outputs=h) - - -class TFHubLayersModuleTest(tf.test.TestCase): - - def testModuleWithLayers(self): - export_path = os.path.join(self.get_temp_dir(), "layers-module") - - sample_input = [[1.0, 2.0], [3.1, 10.0]] - - spec = hub.create_module_spec(layers_module_fn) - with tf.Graph().as_default(): - m = hub.Module(spec, trainable=False) - x = tf.compat.v1.placeholder(dtype=tf.float32) - y = m(x) - with tf.compat.v1.Session() as sess: - sess.run(tf.compat.v1.global_variables_initializer()) - sample_output = sess.run(y, feed_dict={x: sample_input}) - m.export(export_path, sess) - - with tf.Graph().as_default(): - x = tf.compat.v1.placeholder(dtype=tf.float32) - y = hub.Module(export_path)(x) - - with tf.compat.v1.Session() as sess: - sess.run(tf.compat.v1.global_variables_initializer()) - got = sess.run(y, feed_dict={x: sample_input}) - self.assertAllEqual(got, sample_output) - - def testModuleWithRegularizedLayers(self): - # The linear map y = Mx + b with L2 regularization on M and b - # when trained at x = [1,1] with L2 loss towards the target y' = [4,4] - # learns M = [[1,1],[1,1]], b = [1,1], y = [3,3], with eight balanced - # loss terms: the elements of M, b, and y' - y are all distance 1 from zero. - train_input = [[1.0, 1.0]] - target = [[4.0, 4.0]] - - spec = hub.create_module_spec(layers_module_fn) - with tf.Graph().as_default(): - m = hub.Module(spec, trainable=True) - x = tf.compat.v1.placeholder(dtype=tf.float32) - y = m(x) - squared_loss = tf.compat.v1.losses.mean_squared_error(y, target, - weights=2.0) - # Recover REGULARIZATION_LOSSES from the module. - total_loss = squared_loss + tf.compat.v1.losses.get_regularization_loss() - step = tf.Variable(0, trainable=False, name="global_step") - train = tf.compat.v1.train.GradientDescentOptimizer(0.1).minimize( - loss=total_loss, global_step=step) - with tf.compat.v1.Session() as sess: - sess.run(tf.compat.v1.global_variables_initializer()) - for _ in range(50): - sess.run(train, feed_dict={x: train_input}) - # Verify M = [[1,1],[1,1]], b = [1,1] by evaluating at three points. - # Without regularization, the result would be an underdetermined mess. - out = sess.run(y, feed_dict={x: [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]]}) - self.assertAllClose( - out, [[1.0, 1.0], [2.0, 2.0], [2.0, 2.0]], atol=0.001) - - -def valid_colocation_module_fn(): - w = tf.Variable(42 + 69, name="w") - # w.op has the same name on resource and non-resource variables - with tf.compat.v1.colocate_with(w.op): - # A colocation reference among state nodes is ok. - v = tf.Variable(1.0, name="v") - assert v.op.colocation_groups() == [tf.compat.as_bytes("loc:@w")] - # A colocation reference from other nodes to state nodes is ok. - y = tf.add(v, 1, name="y") - assert y.op.colocation_groups() == [tf.compat.as_bytes("loc:@w")] - x = tf.compat.v1.placeholder(dtype=tf.float32, name="x") - with tf.compat.v1.colocate_with(x): - # A colocation reference from other nodes to input nodes is ok. - z = tf.add(x, 1, name="z") - assert z.op.colocation_groups() == [tf.compat.as_bytes("loc:@x")] - hub.add_signature(inputs=dict(x=x), outputs=dict(y=y, z=z)) - - -def bad_input_colocation_module_fn(): - u = tf.add(42, 69, name="u") - with tf.compat.v1.colocate_with(u): - # Inputs must not reference other nodes for colocation. - x = tf.compat.v1.placeholder(tf.float32, name="x") - y = x + 1.0 - hub.add_signature(inputs=x, outputs=y) - - -def bad_state_colocation_module_fn(): - u = tf.add(42, 69, name="u") - with tf.compat.v1.colocate_with(u): - # State-holding nodes must not reference other nodes for colocation. - v = tf.Variable(1.0, name="v") - x = tf.compat.v1.placeholder(dtype=tf.float32) - y = x + v - hub.add_signature(inputs=x, outputs=y) - - -def brittle_multivalued_colocation_module_fn(): - x, y = tf.split([1, 2], 2, name="split") - with tf.compat.v1.colocate_with(x), tf.compat.v1.colocate_with(y): - z = tf.add(x, y, name="add") - assert z.op.colocation_groups() == [tf.compat.as_bytes("loc:@split")] - hub.add_signature(inputs=dict(x=x, y=y), outputs=z, name="both") - hub.add_signature(inputs=dict(x=x), outputs=z, name="partial") - - -class ColocationRewritingTest(tf.test.TestCase): - - def testValidCase(self): - """Tests a complex, valid case end-to-end.""" - spec = hub.create_module_spec(valid_colocation_module_fn) - with tf.Graph().as_default(): - u = tf.constant(7.0, name="u") - m = hub.Module(spec, name="m") - outputs = m(dict(x=u), as_dict=True) - self.assertItemsEqual(outputs["y"].op.colocation_groups(), - [tf.compat.as_bytes("loc:@m/w")]) - self.assertItemsEqual(outputs["z"].op.colocation_groups(), - [tf.compat.as_bytes("loc:@u")]) - - def testBadInputColocation(self): - """Tests catching bad colocation of inputs during create_module_spec.""" - with self.assertRaisesRegexp(ValueError, "(?s)input.*colocate.*loc:@u"): - _ = hub.create_module_spec(bad_input_colocation_module_fn) - - def testBadStateColocation(self): - """Tests catching bad colocation of states during create_module_spec.""" - with self.assertRaisesRegexp(ValueError, "(?s)state.*colocate.*loc:@u"): - _ = hub.create_module_spec(bad_state_colocation_module_fn) - - def testInputsFromMultivaluedOp(self): - """Tests warning for inputs from multivalued ops in create_module_spec.""" - # Ideally, one would be able to write - # with self.assertLogs("blah"): hub.create_module_spec(module_fn) - # but in the absence of assertions on logs, we test the underlying helper - # in the environment seen from within a module_fn. - with tf.Graph().as_default(): - first, _ = tf.split([[1, 2], [3, 4]], 2, name="split1") - _, second = tf.split([[5, 6], [7, 8]], 2, name="split2") - third = tf.constant(105, name="const") - message = native_module.find_signature_inputs_from_multivalued_ops( - dict(first=first, second=second, third=third)) - self.assertRegexpMatches( - message, - ".*single output.*\n" - "Affected inputs: first='split1:0', second='split2:1'$") - # Also test the case of no errors. - with tf.Graph().as_default(): - first = tf.constant(101) - second = tf.constant(102) - third = tf.constant(103) - message = native_module.find_signature_inputs_from_multivalued_ops( - dict(first=first, second=second, third=third)) - self.assertIsNone(message) - - def testSparseInputsFromMultivaluedOp(self): - """Tests warning for SparseTensor inputs from multivalued ops.""" - with tf.Graph().as_default(): - one, _ = tf.compat.v1.sparse_split( - sp_input=tf.SparseTensor(indices=[[0, 1], [1, 2]], values=[1, 2], - dense_shape=[2, 3]), - num_split=2, axis=0, name="op1") - _, two = tf.compat.v1.sparse_split( - sp_input=tf.SparseTensor(indices=[[0, 0], [1, 1]], values=[3, 4], - dense_shape=[2, 3]), - num_split=2, axis=0, name="op2") - three = tf.SparseTensor(indices=[[0]], values=[5], dense_shape=[2]) - message = native_module.find_signature_inputs_from_multivalued_ops( - dict(one=one, two=two, three=three)) - self.assertRegexpMatches( - message, - ".*single output.*\nAffected inputs: " - "one.indices='op1:0', one.values='op1:2', one.dense_shape='op1:4', " - "two.indices='op2:1', two.values='op2:3', two.dense_shape='op2:5'$") - # Also test the case of no errors. - with tf.Graph().as_default(): - one = tf.SparseTensor(indices=[[0]], values=[1], dense_shape=[2]) - two = tf.SparseTensor(indices=[[1]], values=[2], dense_shape=[2]) - message = native_module.find_signature_inputs_from_multivalued_ops( - dict(one=one, two=two, three=three)) - self.assertIsNone(message) - - def testBrittleColocationWithInputsFromMultivaluedOp(self): - """Tests handling of ambiguous rewrites during module.__call__.""" - spec = hub.create_module_spec(brittle_multivalued_colocation_module_fn) - with tf.Graph().as_default(): - u = tf.constant([1], name="u") - with tf.compat.v1.colocate_with(u): - v = tf.constant([2], name="v") - w = tf.constant([3], name="w") - m = hub.Module(spec, name="m") - # It works if both inputs are mapped to ops with equal colocation groups. - assert u.op.colocation_groups() == v.op.colocation_groups() - z = m(dict(x=u, y=v), signature="both") - self.assertItemsEqual(z.op.colocation_groups(), - [tf.compat.as_bytes("loc:@u")]) - # It crashes in the general case. - assert u.op.colocation_groups() != w.op.colocation_groups() - with self.assertRaisesRegexp( - ValueError, - # In Python 3 (but not 2), colocation groups are lists of bytes, - # which are formatted with a leading "b" just before the quotes. - r"(?s)Failed to rewrite .*b?'loc:@m_apply_both_1/split' .*" - "\[b?'loc:@[uw]'\] vs \[b?'loc:@[wu]'\]"): - z = m(dict(x=u, y=w), signature="both") - - def testBadColocationWithPartialInputsFromMultivaluedOp(self): - spec = hub.create_module_spec(brittle_multivalued_colocation_module_fn) - with tf.Graph().as_default(): - u = tf.constant([1], name="u") - m = hub.Module(spec, name="m") - with self.assertRaisesRegexp( - ValueError, - r"(?s)Failed to rewrite .*b?'loc:@m_apply_partial/split' .*" - "\[b?'loc:@u'\] vs \[b?'loc:@m_apply_partial/split'\]"): - z = m(dict(x=u), signature="partial") - - -def update_ops_module_fn(): - counter = tf.Variable(0, trainable=False) - tf.compat.v1.add_to_collection(tf.compat.v1.GraphKeys.UPDATE_OPS, - counter.assign_add(1)) - hub.add_signature(inputs=None, outputs=counter.value()) - - -class TFHubUpdateOpsTest(tf.test.TestCase): - - def testUpdateOps(self): - spec = hub.create_module_spec(update_ops_module_fn) - with tf.compat.v1.Session() as sess: - trainable_module = hub.Module(spec, trainable=True) - fixed_module = hub.Module(spec, trainable=False) - - # TODO(b/62433105): Understand what is the desired behaviour of UPDATE_OPS - # and applying a Module multiple times. For now UPDATE_OPS probably only - # do something reasonable if each Module is applied exactly one time. - trainable_module() - fixed_module() - - variable = tf.Variable(0.0) - step = tf.Variable(0, trainable=False, name="global_step") - update_ops = tf.compat.v1.get_collection( - tf.compat.v1.GraphKeys.UPDATE_OPS) - with tf.control_dependencies(update_ops): - train_op = tf.compat.v1.train.GradientDescentOptimizer(0.1).minimize( - loss=variable, - global_step=step) - - sess.run(tf.compat.v1.global_variables_initializer()) - sess.run(train_op) - trainable_module_vars = list(trainable_module.variable_map.values()) - self.assertEqual(len(trainable_module_vars), 1) - self.assertEqual(sess.run(trainable_module_vars[0]), 1) - fixed_module_vars = list(fixed_module.variable_map.values()) - self.assertEqual(len(fixed_module_vars), 1) - self.assertEqual(sess.run(fixed_module_vars[0]), 0) - - -def batch_norm_module_fn(is_training): - """Module that exercises batch normalization, incl. UPDATE_OPS.""" - x = tf.compat.v1.placeholder(dtype=tf.float32, shape=[None, 1], name="x") - y = tf_keras_v1.__internal__.legacy.layers.batch_normalization( - momentum=0.4, inputs=x, fused=False, training=is_training - ) - hub.add_signature(inputs=x, outputs=y) - - -class TFHubBatchNormModuleTest(tf.test.TestCase): - - # This test is intended to verify the following: - # 1) A module_fn that uses batch normalization through tf.layers.contrib - # (and its underlying utilities from tf.nn) can be used to create, - # export, load and use the Module. - # 2) Batch normalization learns the scale and offset parameters for its - # output as it should. - # 3) The UPDATE_OPS added internally for the moving_mean and moving_variance - # over the training data are properly executed at training time, and their - # results are used at serving time, without further change. - def testModuleWithBatchNorm(self): - export_path = os.path.join(self.get_temp_dir(), "batch-norm-module") - # This test resorts to lookup by name to retrieve the moving mean, - # because tf.contrib.layers.batch_norm() does not return it, and even if, - # module_fn() has no way to return it next to the result for training. - moving_mean_name = ( - "module/batch_normalization/moving_mean/Read/ReadVariableOp:0") - - batch_norm_train_tags = ["batch_norm_trains"] - batch_norm_fixed_tags = ["batch_norm_fixed"] - spec = hub.create_module_spec( - batch_norm_module_fn, - [(batch_norm_train_tags, {"is_training": True}), - (batch_norm_fixed_tags, {"is_training": False})]) - # Test Module creation and training. - with tf.Graph().as_default() as g: - m = hub.Module(spec, trainable=True, tags=batch_norm_train_tags) - # The module is trained on a fixed batch of inputs, which has a mean - # of 12.0 and some sample variance of a less obvious value. The module - # learns scale and offset parameters that achieve the mapping x --> 2*x - # for the observed mean and variance. - x = tf.constant([[11.0], [12.0], [13.0]]) - training_mean = [12.0] - y_target = tf.constant([[22.0], [24.0], [26.0]]) - y = m(x) - step = tf.Variable(0, trainable=False, name="global_step") - moving_mean = g.get_tensor_by_name(moving_mean_name) - update_ops = tf.compat.v1.get_collection( - tf.compat.v1.GraphKeys.UPDATE_OPS) - with tf.control_dependencies(update_ops): - train = tf.compat.v1.train.GradientDescentOptimizer(0.1).minimize( - loss=tf.compat.v1.losses.mean_squared_error(y, y_target), - global_step=step) - with tf.compat.v1.Session() as sess: - sess.run(tf.compat.v1.global_variables_initializer()) - self.assertAllClose(sess.run(moving_mean), [0.0]) - for _ in range(100): - sess.run([train]) - trained_moving_mean, trained_y = sess.run([moving_mean, y]) - self.assertAllClose(trained_moving_mean, training_mean) - self.assertAllClose(trained_y, [[22.0], [24.0], [26.0]]) - # Test export. - m.export(export_path, sess) - - # Test import and use. - spec = load_module_spec(export_path) - with tf.Graph().as_default() as g: - # The module gets run for inference on inputs with different mean and - # variance. However, both mean and variance as well as offset and scale - # are now frozen to the values from learning, so the same mapping - # x --> 2*x is recovered. - x = tf.constant([[10.0], [20.0], [30.0]]) - y = hub.Module( - spec, tags=batch_norm_fixed_tags)(x) - moving_mean = g.get_tensor_by_name(moving_mean_name) - with tf.compat.v1.Session() as sess: - sess.run(tf.compat.v1.global_variables_initializer()) - for _ in range(100): - served_moving_mean, served_y = sess.run([moving_mean, y]) - # No update occurs to the moving_mean from training time. - self.assertAllClose(served_moving_mean, training_mean) - # Prediction results are correct. - self.assertAllClose(served_y, [[20.0], [40.0], [60.0]]) - - -def multiple_outputs_module_fn(): - x = tf.compat.v1.placeholder(dtype=tf.float32) - v = tf.Variable([3.0]) - hub.add_signature( - inputs={"x": x}, - outputs={"y": v * x, "z": v * v * x}) - - -class TFHubMultipleOutputsTest(tf.test.TestCase): - - def testMultipleOutputs(self): - with tf.compat.v1.Session() as sess: - spec = hub.create_module_spec(multiple_outputs_module_fn) - m = hub.Module(spec) - output = m(tf.constant([2.0]), as_dict=True) - output1 = output["y"] - output2 = output["z"] - sess.run(tf.compat.v1.global_variables_initializer()) - self.assertAllClose(sess.run(output1), [6.0]) - self.assertAllClose(sess.run(output2), [18.0]) - - -def create_assets_module_fn(vocabulary_file): - - def assets_module_fn(): - indices = tf.compat.v1.placeholder(dtype=tf.int64, name="indices") - outputs = do_table_lookup(indices, vocabulary_file) - hub.add_signature(inputs=indices, outputs=outputs) - - return assets_module_fn - - -def create_consumer_module_fn(exported_hub_module): - - def consumer_module_fn(): - indices = tf.compat.v1.placeholder(dtype=tf.int64, name="indices") - inner_module = hub.Module(exported_hub_module) - inner_module_output = inner_module(indices) - output = tf.identity(inner_module_output) - hub.add_signature(inputs=indices, outputs=output) - - return consumer_module_fn - - -class TFHubAssetsTest(tf.test.TestCase): - - def create_vocab_file(self, path, vocab): - vocabulary_file = os.path.join(self.get_temp_dir(), "tokens.txt") - with open(vocabulary_file, "w+") as vocab_file: - for line in vocab: - vocab_file.write(line) - vocab_file.write(os.linesep) - return vocabulary_file - - def testAssets(self): - export_path = os.path.join(self.get_temp_dir(), "assets-module") - vocabulary_file = self.create_vocab_file("tokens.txt", - ["emerson", "lake", "palmer"]) - with tf.Graph().as_default(): - assets_module_fn = create_assets_module_fn(vocabulary_file) - spec = hub.create_module_spec(assets_module_fn) - embedding_module = hub.Module(spec) - output = embedding_module(tf.constant([1, 2], dtype=tf.int64)) - with tf.compat.v1.Session() as sess: - sess.run(tf.compat.v1.tables_initializer()) - self.assertAllEqual(list(sess.run(output)), [b"lake", b"palmer"]) - embedding_module.export(export_path, sess) - - asset_file = os.path.join(*[export_path, "assets", "tokens.txt"]) - # Check that asset file got written to the expected place: - self.assertTrue(tf.compat.v1.gfile.Exists(asset_file)) - - # Assets should be hermetic, so we can delete the original vocab file: - tf.compat.v1.gfile.Remove(vocabulary_file) - - with tf.Graph().as_default(): - spec = load_module_spec(export_path) - embedding_module = hub.Module(spec) - output = embedding_module(tf.constant([1, 2], dtype=tf.int64)) - with tf.compat.v1.Session() as sess: - sess.run(tf.compat.v1.tables_initializer()) - # Check functionality: - self.assertAllEqual(list(sess.run(output)), [b"lake", b"palmer"]) - # Check that the ASSET_FILEPATHS collection was restored properly: - asset_filepaths_collection = tf.compat.v1.get_collection( - tf.compat.v1.GraphKeys.ASSET_FILEPATHS) - asset_filepaths = [sess.run(tensor) - for tensor in asset_filepaths_collection] - # ASSET_FILEPATHS are added for the state graph and for the apply graph: - self.assertAllEqual(asset_filepaths, - [tf.compat.as_bytes(asset_file)] * 2) - - def testDuplicateAssetCopy(self): - export_path = os.path.join(self.get_temp_dir(), "assets-module") - - def module_with_duplicate_asset(): - vocabulary_file = self.create_vocab_file("tokens2.txt", ["1", "2", "3"]) - indices1 = tf.compat.v1.placeholder(dtype=tf.int64, name="indices1") - indices2 = tf.compat.v1.placeholder(dtype=tf.int64, name="indices2") - hub.add_signature( - inputs={ - "indices_1": indices1, - "indices_2": indices2, - }, - outputs={ - "x": do_table_lookup(indices1, vocabulary_file), - "y": do_table_lookup(indices2, vocabulary_file), - }) - - with tf.Graph().as_default(): - spec = hub.create_module_spec(module_with_duplicate_asset) - module_a = hub.Module(spec) - module_a({"indices_1": tf.constant([1, 2], dtype=tf.int64), - "indices_2": tf.constant([1, 2], dtype=tf.int64)}, as_dict=True) - with tf.compat.v1.Session() as sess: - sess.run(tf.compat.v1.tables_initializer()) - module_a.export(export_path, sess) - - def testExportedConsumerModelWorksIfItUsesHubModuleWithAssets(self): - # 1. Create and export a module with assets. - module_export_path = os.path.join(self.get_temp_dir(), "small-module") - vocabulary_file = self.create_vocab_file("tokens.txt", - ["emerson", "lake", "palmer"]) - assets_module_fn = create_assets_module_fn(vocabulary_file) - spec = hub.create_module_spec(assets_module_fn) - with tf.Graph().as_default(): - small_module = hub.Module(spec) - with tf.compat.v1.Session() as sess: - small_module.export(module_export_path, sess) - # 2. Remove the original vocab file and move the module to another location. - tf.compat.v1.gfile.Remove(vocabulary_file) - inner_module_path = os.path.join(self.get_temp_dir(), "inner-module") - tf.compat.v1.gfile.Rename(module_export_path, inner_module_path) - del module_export_path - # 3. Use the module in a consumer model (which is another module here). - module_export_path = os.path.join(self.get_temp_dir(), "consumer-module") - consumer_module_fn = create_consumer_module_fn(inner_module_path) - spec = hub.create_module_spec(consumer_module_fn) - with tf.Graph().as_default(): - consumer_module = hub.Module(spec) - with tf.compat.v1.Session() as sess: - consumer_module.export(module_export_path, sess) - # 4. Delete the inner module on disk and move the consumer model to a final - # location for serving. - tf.compat.v1.gfile.DeleteRecursively(inner_module_path) - module_serving_path = os.path.join(self.get_temp_dir(), "serving-module") - tf.compat.v1.gfile.Rename(module_export_path, module_serving_path) - # 5. Make sure the model can be served successfully. - with tf.Graph().as_default(): - serving_module = hub.Module(module_serving_path) - output = serving_module(tf.constant([1, 2], dtype=tf.int64)) - with tf.compat.v1.Session() as sess: - sess.run(tf.compat.v1.tables_initializer()) - self.assertAllEqual(list(sess.run(output)), [b"lake", b"palmer"]) - - -def another_stateful_module_fn(): - """Stateful module with inputs.""" - module_input = tf.compat.v1.placeholder(dtype=tf.float32) - variable = tf.Variable([3.0], name="iamtheoneandonly") - hub.add_signature(inputs=module_input, outputs=module_input*variable) - - -class TFHubApplyStatefulModuleMultipleTimesTest(tf.test.TestCase): - - def testApplyStatefulModuleMultipleTimes(self): - export_path = os.path.join(self.get_temp_dir(), "another-module") - - with tf.compat.v1.Session() as sess: - spec = hub.create_module_spec(another_stateful_module_fn) - stateful_module = hub.Module(spec, trainable=True) - times2 = stateful_module(tf.constant([2.0])) - times3 = stateful_module(tf.constant([3.0])) - step = tf.Variable(0, trainable=False, name="global_step") - # Training will adapt the hidden variable to be approximately 2: - train = tf.compat.v1.train.GradientDescentOptimizer(0.05).minimize( - loss=tf.compat.v1.losses.mean_squared_error(times2, [4.0]), - global_step=step) - - sess.run(tf.compat.v1.global_variables_initializer()) - for _ in range(50): - sess.run(train) - self.assertAllClose(sess.run(times2), [4.0]) - self.assertAllClose(sess.run(times3), [6.0]) - stateful_module.export(export_path, sess) - with tf.compat.v1.Session() as sess: - stateful_module = hub.Module(export_path) - times4 = stateful_module(tf.constant([4.0])) - times5 = stateful_module(tf.constant([5.0])) - sess.run(tf.compat.v1.global_variables_initializer()) - self.assertAllClose(sess.run(times4), [8.0]) - self.assertAllClose(sess.run(times5), [10.0]) - - def testMultipleApplicationsInDifferentScopes(self): - with tf.Graph().as_default(): - export_path = os.path.join(self.get_temp_dir(), "module-applied-in-scope") - - spec = hub.create_module_spec(another_stateful_module_fn) - stateful_module = hub.Module(spec, name="moduleA") - with tf.name_scope("foo"): - with tf.compat.v1.variable_scope("bar"): - times2 = stateful_module(tf.constant([2.0])) - with tf.name_scope("baz"): - times3 = stateful_module(tf.constant([3.0])) - - with tf.compat.v1.Session() as sess: - sess.run(tf.compat.v1.global_variables_initializer()) - self.assertAllClose(sess.run(times2), [6.0]) - self.assertAllClose(sess.run(times3), [9.0]) - self.assertEqual(len(stateful_module.variable_map), 1) - self.assertEqual( - stateful_module.variable_map["iamtheoneandonly"].name, - "moduleA/iamtheoneandonly:0") - stateful_module.export(export_path, sess) - - # Check minimal functionality of the exported module. - with tf.Graph().as_default(): - stateful_module = hub.Module(export_path, name="moduleB") - times2 = stateful_module(tf.constant([2.0])) - with tf.compat.v1.Session() as sess: - sess.run(tf.compat.v1.global_variables_initializer()) - self.assertAllClose(sess.run(times2), [6.0]) - - -def multiple_signature_module_fn(): - """Stateful module with multiple signatures.""" - weight = tf.Variable([3.0]) - - x_input = tf.compat.v1.placeholder(dtype=tf.float32) - x_output = tf.multiply(x_input, weight) - hub.add_signature("mul", inputs=x_input, outputs=x_output) - - y_input = tf.compat.v1.placeholder(dtype=tf.float32) - y_output = tf.divide(y_input, weight) - hub.add_signature("div", inputs=y_input, outputs=y_output) - - -class TFHubModuleWithMultipleSignatures(tf.test.TestCase): - - def testGetSignatures(self): - spec = hub.create_module_spec(multiple_signature_module_fn) - self.assertEqual(sorted(spec.get_signature_names()), ["div", "mul"]) - - def testModuleWithMultipleSignatures(self): - with tf.Graph().as_default(): - spec = hub.create_module_spec(multiple_signature_module_fn) - module_a = hub.Module(spec, name="moduleA") - in_tensor = tf.compat.v1.placeholder(dtype=tf.float32) - out_tensor_a = module_a(in_tensor, signature="mul") - out_tensor_b = module_a(out_tensor_a, signature="div") - - with tf.compat.v1.Session() as sess: - sess.run(tf.compat.v1.global_variables_initializer()) - in_values = [6, 3, 1] - self.assertAllClose( - sess.run(out_tensor_b, feed_dict={in_tensor: in_values}), in_values) - - -def cond_module_fn(): - """Computes relu(x) with a conditional.""" - x = tf.compat.v1.placeholder(dtype=tf.float32, name="x", shape=[]) - result = tf.cond(0 < x, lambda: tf.identity(x), lambda: tf.constant(0.0)) - hub.add_signature(inputs=x, outputs=result) - - -def nested_cond_module_fn(): - """Computes relu(x) with nested conditionals.""" - x = tf.compat.v1.placeholder(dtype=tf.float32, name="x", shape=[]) - # pylint: disable=g-long-lambda - result = tf.cond( - 0 < x, - lambda: tf.cond(3 < x, - lambda: tf.identity(x), - lambda: tf.multiply(x, 1.0)), - lambda: tf.cond(x < -3, - lambda: tf.constant(0.0), - lambda: tf.multiply(0.0, 1.0))) - # pylint: enable=g-long-lambda - hub.add_signature(inputs=x, outputs=result) - - -def while_module_fn(): - """Compute x^n with while_loop.""" - x = tf.compat.v1.placeholder(dtype=tf.float32, name="x", shape=[]) - n = tf.compat.v1.placeholder(dtype=tf.int32, name="n") - _, pow_x = tf.while_loop( - lambda i, ix: i < n, lambda i, ix: [tf.add(i, 1), ix * x], - [tf.constant(0), tf.constant(1.0)]) - hub.add_signature(inputs={"x": x, "n": n}, outputs=pow_x) - - -def nested_control_flow_module_fn(): - """Compute the sum of elements greater than 'a' with nested control flow.""" - elems = tf.compat.v1.placeholder( - dtype=tf.float32, name="elems", shape=[None]) - a = tf.compat.v1.placeholder(dtype=tf.float32, name="a") - - def sum_above_a(acc, x): - return acc + tf.cond(x > a, lambda: x, lambda: 0.0) - - hub.add_signature( - inputs={"elems": elems, "a": a}, - outputs=tf.foldl(sum_above_a, elems, initializer=tf.constant(0.0))) - - -class TFHubModulesWithControlFlow(tf.test.TestCase): - - def _testCondModule(self): - self._testReluModule(cond_module_fn) - - def testCondModule(self): - self._testCondModule() - - @test_util.enable_control_flow_v2 - def testCondModuleWithControlFlowV2(self): - self._testCondModule() - - def _testModuleWithNestedConds(self): - self._testReluModule(nested_cond_module_fn) - - def testModuleWithNestedConds(self): - self._testModuleWithNestedConds() - - @test_util.enable_control_flow_v2 - def testModuleWithNestedCondsWithControlFlowV2(self): - self._testModuleWithNestedConds() - - def _testReluModule(self, module_fn): - spec = hub.create_module_spec(module_fn) - with tf.Graph().as_default(): - with tf.compat.v1.Session() as sess: - x = tf.compat.v1.placeholder(dtype=tf.float32, name="x") - relu_module = hub.Module(spec) - y = relu_module(x) - grad = tf.gradients([y], [x]) - self.assertAllClose(sess.run(y, {x: 9.1}), 9.1) - self.assertAllClose(sess.run(y, {x: -2.4}), 0.0) - self.assertAllClose(sess.run(grad, {x: 2}), [1.0]) - self.assertAllClose(sess.run(grad, {x: -2}), [0.0]) - - def _testWhileModule(self): - spec = hub.create_module_spec(while_module_fn) - with tf.Graph().as_default(): - with tf.compat.v1.Session() as sess: - x = tf.compat.v1.placeholder(tf.float32) - n = tf.compat.v1.placeholder(tf.int32) - pow_module = hub.Module(spec) - y = pow_module({"x": x, "n": n}) - grad = tf.gradients([y], [x]) - self.assertAllClose(sess.run(y, {x: 9.1, n: 1}), 9.1) - self.assertAllClose(sess.run(y, {x: 2.4, n: 2}), 5.76) - self.assertAllClose(sess.run(grad, {x: 2, n: 3}), [12.0]) - - def testWhileModule(self): - self._testWhileModule() - - @test_util.enable_control_flow_v2 - def testWhileModuleWithControlFlowV2(self): - self._testWhileModule() - - @test_util.run_v1_only("b/138681007") - def testUseModuleWithWhileLoopInsideCond(self): - spec = hub.create_module_spec(while_module_fn) - with tf.Graph().as_default(): - m = hub.Module(spec) - cond = tf.cond( - tf.equal(tf.constant(0), tf.constant(0)), - lambda: m({"x": tf.constant(3.0), "n": tf.constant(2)}), - lambda: tf.constant(4.0)) - with tf.compat.v1.Session() as sess: - self.assertEqual(sess.run(cond), 9.0) - - def _testNestedControlFlowModule(self): - spec = hub.create_module_spec(nested_control_flow_module_fn) - with tf.Graph().as_default(): - with tf.compat.v1.Session() as sess: - elems = tf.compat.v1.placeholder(tf.float32, shape=[None]) - a = tf.compat.v1.placeholder(tf.float32) - m = hub.Module(spec) - out = m({"elems": elems, "a": a}) - grad = tf.gradients([out], [elems]) - self.assertAllClose( - sess.run(out, { - a: 1.1, - elems: [10, 0, 0.5, 1.2] - }), 11.2) - - self.assertAllClose(sess.run(grad, {a: 1, elems: [10, 0, 0.5, 1.2]}), - [[1.0, 0.0, 0.0, 1.0]]) - - def testNestedControlFlowModule(self): - self._testNestedControlFlowModule() - - @test_util.enable_control_flow_v2 - def testNestedControlFlowModuleWithControlFlowV2(self): - self._testNestedControlFlowModule() - - -def attached_messages_module_fn(tagged=0): - x = tf.compat.v1.placeholder(tf.float32, shape=[None]) - hub.add_signature(inputs={"x": x}, outputs={"y": 2*x}) - # For brevity, this test borrows two well-known, stable message types - # from TensorFlow. They are not likely choices for actual uses. - hub.attach_message("numbers", - tf.compat.v1.train.Int64List(value=[-3])) # Overwritten. - hub.attach_message("numbers", tf.compat.v1.train.Int64List(value=[42, 69])) - hub.attach_message("letters", tf.compat.v1.train.BytesList(value=[ - tf.compat.as_bytes("abc"), tf.compat.as_bytes("xyz")])) - hub.attach_message("tagged", tf.compat.v1.train.Int64List(value=[tagged])) - - -class TFHubModuleWithAttachedMessages(tf.test.TestCase): - - def testModuleSpec(self): - """This is the general test for ModuleSpec and native_module._ModuleSpec.""" - spec = hub.create_module_spec(attached_messages_module_fn) - attached_letters = spec.get_attached_message("letters", - tf.compat.v1.train.BytesList) - self.assertSequenceEqual( - attached_letters.value, - [tf.compat.as_bytes("abc"), - tf.compat.as_bytes("xyz")]) - attached_numbers = spec.get_attached_message("numbers", - tf.compat.v1.train.Int64List) - self.assertSequenceEqual(attached_numbers.value, [42, 69]) - attached_train = spec.get_attached_message("tagged", - tf.compat.v1.train.Int64List) - self.assertSequenceEqual(attached_train.value, [0]) - self.assertIsNone(spec.get_attached_message("bad", - tf.compat.v1.train.BytesList)) - with self.assertRaises(KeyError): - spec.get_attached_message("bad", - tf.compat.v1.train.BytesList, required=True) - - def testModule(self): - """Tests forwarding from Module to ModuleSpec.""" - spec = hub.create_module_spec(attached_messages_module_fn) - with tf.Graph().as_default(): - module = hub.Module(spec) - attached = module.get_attached_message("numbers", - tf.compat.v1.train.Int64List) - self.assertSequenceEqual(attached.value, [42, 69]) - - def testGraphVersions(self): - """Tests native_module._ModuleSpec for explicit tags arguments.""" - tags_and_args = [(set(), {"tagged": 1}), - ({"double", "the", "value"}, {"tagged": 2})] - spec = hub.create_module_spec(attached_messages_module_fn, - tags_and_args=tags_and_args) - for tags, args in tags_and_args: - attached_to_spec = spec.get_attached_message( - "tagged", tf.compat.v1.train.Int64List, tags=tags) - self.assertSequenceEqual(attached_to_spec.value, [args["tagged"]]) - with tf.Graph().as_default(): - module = hub.Module(spec, tags=tags) - attached_to_module = module.get_attached_message( - "tagged", tf.compat.v1.train.Int64List) - self.assertSequenceEqual(attached_to_module.value, [args["tagged"]]) - - def testSeparateCopies(self): - """Mutating returned objects does not affect future returned values.""" - spec = hub.create_module_spec(attached_messages_module_fn) - attached_numbers = spec.get_attached_message("numbers", - tf.compat.v1.train.Int64List) - self.assertSequenceEqual(attached_numbers.value, [42, 69]) - attached_numbers.Clear() - self.assertSequenceEqual(attached_numbers.value, []) - attached_numbers = spec.get_attached_message("numbers", - tf.compat.v1.train.Int64List) - self.assertSequenceEqual(attached_numbers.value, [42, 69]) - - -class TFHubOpsTest(tf.test.TestCase): - - def testRegisterLinkedOpsError(self): - with self.assertRaisesRegexp(tf.errors.NotFoundError, "non-existent-op"): - native_module.register_ops_if_needed({"non-existent-op"}) - - -class TFHubExportSpecTest(tf.test.TestCase): - - def f(self, x, dim=10): - return tf_keras_v1.__internal__.legacy.layers.dense(x, dim) - - def module_fn(self, dim=10): - x = tf.compat.v1.placeholder(dtype=tf.float32, shape=[None, dim]) - y = self.f(x, dim=dim) - hub.add_signature(inputs=x, outputs=y) - - def createCheckpoint(self, scope=None): - checkpoint_path = os.path.join(self.get_temp_dir(), "model") - with tf.Graph().as_default(): - x = tf.compat.v1.get_variable( - "x", [32, 10], initializer=tf.compat.v1.initializers.random_normal()) - if scope: - with tf.compat.v1.variable_scope(scope): - y = self.f(x) - else: - y = self.f(x) - tf_keras_v1.__internal__.legacy.layers.dense(y, 20) - - saver = tf.compat.v1.train.Saver() - init_op = tf.compat.v1.initializers.global_variables() - - with tf.compat.v1.Session() as session: - session.run(init_op) - saver.save(session, checkpoint_path) - - return checkpoint_path - - def testExportModuleSpec(self): - checkpoint_path = self.createCheckpoint() - export_path = os.path.join(self.get_temp_dir(), "module1") - - spec = hub.create_module_spec(self.module_fn) - spec.export(export_path, - checkpoint_path=checkpoint_path) - - def testExportModuleSpec_withWrongShape(self): - checkpoint_path = self.createCheckpoint(scope="block") - export_path = os.path.join(self.get_temp_dir(), "module2") - - spec = hub.create_module_spec(lambda: self.module_fn(dim=20)) - with self.assertRaisesRegexp(ValueError, "doesn't match with shape of"): - spec.export(export_path, - checkpoint_path=checkpoint_path, - name_transform_fn=lambda x: "block/" + x) - - def testExportModuleSpec_withWrongScope(self): - checkpoint_path = self.createCheckpoint("block2") - export_path = os.path.join(self.get_temp_dir(), "module3") - - spec = hub.create_module_spec(self.module_fn) - with self.assertRaisesRegexp(ValueError, "bias is not found in"): - spec.export(export_path, - checkpoint_path=checkpoint_path, - name_transform_fn=lambda x: "block/" + x) - - -class TFHubUsageWithEager(tf.test.TestCase): - - def testWrapFunction(self): - if not tf.executing_eagerly(): - self.skipTest("Test requires eager.") - - spec = hub.create_module_spec(stateful_rv_with_input_module_fn) - - initializers = [] - def use_module(x, y): - m = hub.Module(spec, name="module_", trainable=True) - initializers.append(tf.compat.v1.initializers.global_variables()) - return [m(x), m(y)] - - input_signature = [ - tf.TensorSpec((), tf.float32), - tf.TensorSpec((), tf.float32), - ] - - f = tf.compat.v1.wrap_function(use_module, input_signature) - f.prune([], initializers)() - self.assertAllEqual( - [x.numpy() for x in f(9.0, 6.0)], - [19.0, 16.0]) - - -if __name__ == "__main__": - tf.test.main() diff --git a/tensorflow_hub/resolver_test.py b/tensorflow_hub/resolver_test.py index e2a14e2b..980c16d2 100644 --- a/tensorflow_hub/resolver_test.py +++ b/tensorflow_hub/resolver_test.py @@ -408,7 +408,7 @@ def testModuleRunningWithUncompressedContext(self): "_request_gcs_location", return_value=module_export_path) as mocked_urlopen: with test_utils.UncompressedLoadFormatContext(): - m = hub.Module("https://tfhub.dev/google/model/1") + m = hub.load("https://tfhub.dev/google/model/1") mocked_urlopen.assert_called_once_with( "https://tfhub.dev/google/model/1?tf-hub-format=uncompressed") out = m(11) @@ -424,7 +424,7 @@ def _assert_resolver_is_called(self, http_resolver): with mock.patch.object( http_resolver, "__call__", side_effect=ValueError) as mocked_call: try: - hub.Module(module_url) + hub.load(module_url) self.fail("Failure expected since mock raises it as side effect.") except ValueError: pass diff --git a/tensorflow_hub/test_utils.py b/tensorflow_hub/test_utils.py index 923ca411..930ba0fe 100644 --- a/tensorflow_hub/test_utils.py +++ b/tensorflow_hub/test_utils.py @@ -21,7 +21,6 @@ from absl import flags import tensorflow as tf -import tensorflow_hub as hub from tensorflow_hub import resolver @@ -159,17 +158,16 @@ def get_test_data_path(file_or_dirname): def export_module(module_export_path): """Create and export a simple module to the specified path.""" - def _stateless_module_fn(): + class SquareModule(tf.Module): """Simple module that squares an input.""" - x = tf.compat.v1.placeholder(tf.int64) - y = x * x - hub.add_signature(inputs=x, outputs=y) - - spec = hub.create_module_spec(_stateless_module_fn) - m = hub.Module(spec, name="test_module") - with tf.compat.v1.Session() as sess: - sess.run(tf.compat.v1.global_variables_initializer()) - m.export(module_export_path, sess) + + @tf.function + def __call__(self, x): + return x * x + + module = SquareModule() + module(tf.constant(0.0)) + tf.saved_model.save(module, module_export_path) class EnvVariableContextManager(object):