Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] register_checkpoint_saver API was not stable under TF version 2.11. #473

Merged
merged 1 commit into from
Oct 17, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import inspect
import functools
import os.path
from packaging import version
import re

from tensorflow_recommenders_addons import dynamic_embedding as de
Expand Down Expand Up @@ -54,6 +55,7 @@
from tensorflow.python.training.saving import functional_saver
from tensorflow.python.util import compat
from tensorflow.python.util import nest
from tensorflow import version as tf_version

tf_original_save_func = tf_saved_model_save.save
if keras_saved_model_save is not None:
Expand Down Expand Up @@ -566,7 +568,9 @@ def restore(self, sess, save_path):


def patch_on_tf_save_restore():
try:
if version.parse(tf_version.VERSION) < version.parse("2.11"):
functional_saver._SingleDeviceSaver = _DynamicEmbeddingSingleDeviceSaver
else:
from tensorflow.python.saved_model.registration.registration import register_checkpoint_saver
class_obj = de.Variable
predicate = lambda x: isinstance(x, class_obj)
Expand All @@ -584,8 +588,6 @@ def patch_on_tf_save_restore():
k_name = param.name
kwargs[k_name] = prekwargs[k_name]
register_checkpoint_saver(**kwargs)
except:
functional_saver._SingleDeviceSaver = _DynamicEmbeddingSingleDeviceSaver
saver.Saver = _DynamicEmbeddingSaver
# # Replace origin saving function is too dangerous.
# tf_saved_model_save.save = functools.partial(de.keras.models._de_keras_save_func,
Expand Down
Loading