Skip to content

Commit

Permalink
Save "layers" first for subclassed Functional models (#18982)
Browse files Browse the repository at this point in the history
  • Loading branch information
mattdangerw authored Dec 22, 2023
1 parent fe2f54a commit aaf4289
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 0 deletions.
6 changes: 6 additions & 0 deletions keras/saving/saving_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,12 @@ def _walk_trackable(trackable):
raise ValueError(f"Invalid obj_type: {obj_type}")
attr_skiplist = get_attr_skiplist(obj_type)

# Save all layers directly tracked by Sequential and Functional first.
# This helps avoid ordering concerns for subclassed Sequential or Functional
# models with extra attributes--the internal Keras state take precedence.
if obj_type in ("Sequential", "Functional"):
yield "layers", trackable.layers

for child_attr in sorted(dir(trackable), key=lambda x: _name_key(x)):
if child_attr.startswith("__") or child_attr in attr_skiplist:
continue
Expand Down
65 changes: 65 additions & 0 deletions keras/saving/saving_lib_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,32 @@ def compile(self, *args, **kwargs):
super().compile(*args, **kwargs)


@keras.saving.register_keras_serializable(package="my_custom_package")
class SubclassFunctional(keras.Model):
"""Subclassed functional identical to `_get_basic_functional_model`."""

def __init__(self, **kwargs):
inputs = keras.Input(shape=(4,), batch_size=2)
dense = keras.layers.Dense(1, name="first_dense")
x = dense(inputs)
outputs = keras.layers.Dense(1, name="second_dense")(x)
super().__init__(inputs=inputs, outputs=outputs, **kwargs)
# Attrs for layers in the functional graph should not affect saving
self.layer_attr = dense

@property
def layer_property(self):
# Properties for layers in the functional graph should not affect saving
return self.layer_attr

def get_config(self):
return {}

@classmethod
def from_config(cls, config):
return cls(**config)


@keras.saving.register_keras_serializable(package="my_custom_package")
def my_mean_squared_error(y_true, y_pred):
"""Identical to built-in `mean_squared_error`, but as a custom fn."""
Expand Down Expand Up @@ -200,6 +226,17 @@ def _get_basic_functional_model(compile=True):
return functional_model


def _get_subclassed_functional_model(compile=True):
functional_model = SubclassFunctional()
if compile:
functional_model.compile(
optimizer="adam",
loss=my_mean_squared_error,
metrics=[keras.metrics.Hinge(), "mse"],
)
return functional_model


@pytest.mark.requires_trainable_backend
class SavingTest(testing.TestCase):
def _test_inference_after_instantiation(self, model):
Expand Down Expand Up @@ -234,6 +271,10 @@ def test_inference_after_instantiation_custom_functional(self):
model = _get_custom_functional_model(compile=False)
self._test_inference_after_instantiation(model)

def test_inference_after_instantiation_subclassed_functional(self):
model = _get_subclassed_functional_model(compile=False)
self._test_inference_after_instantiation(model)

def _test_compile_preserved(self, model):
x_ref = np.random.random((2, 4))
y_ref = np.random.random((2, 1))
Expand Down Expand Up @@ -286,6 +327,10 @@ def test_compile_preserved_custom_functional(self):
model = _get_custom_functional_model(compile=True)
self._test_compile_preserved(model)

def test_compile_preserved_subclassed_functional(self):
model = _get_subclassed_functional_model(compile=True)
self._test_compile_preserved(model)

def test_saving_preserve_unbuilt_state(self):
temp_filepath = os.path.join(self.get_temp_dir(), "my_model.keras")
subclassed_model = CustomModelX()
Expand Down Expand Up @@ -432,6 +477,26 @@ def test_load_weights_only_with_keras_file(self):
model.load_weights(temp_filepath)
self.assertAllClose(model.predict(ref_input), ref_output, atol=1e-6)

def test_save_weights_subclassed_functional(self):
# The subclassed and basic functional model should have the same
# weights structure.
temp_filepath = Path(
os.path.join(self.get_temp_dir(), "mymodel.weights.h5")
)
model = _get_basic_functional_model()
ref_input = np.random.random((2, 4))
ref_output = model.predict(ref_input)
# Test saving basic, loading subclassed.
saving_lib.save_weights_only(model, temp_filepath)
model = _get_subclassed_functional_model()
saving_lib.load_weights_only(model, temp_filepath)
self.assertAllClose(model.predict(ref_input), ref_output, atol=1e-6)
# Test saving subclassed, loading basic.
saving_lib.save_weights_only(model, temp_filepath)
model = _get_basic_functional_model()
saving_lib.load_weights_only(model, temp_filepath)
self.assertAllClose(model.predict(ref_input), ref_output, atol=1e-6)

def test_compile_arg(self):
temp_filepath = os.path.join(self.get_temp_dir(), "mymodel.keras")
model = _get_basic_functional_model()
Expand Down

0 comments on commit aaf4289

Please sign in to comment.