
Fix loading nested Functional models from config #19509

Closed

Conversation


torzdf commented Apr 14, 2024

This PR addresses issue #19326

During investigation I found that this bug does not just occur with a shared model structure, but with any model that contains nested Functional models.

The code below (and the linked gist) demonstrates a simpler reproducible example:

```python
from keras import Input, Model
from keras.saving import load_model
from keras.layers import Dense

# Model 1 definition
input1 = Input((5,), name="input1")
dense1 = Dense(3, name="dense1")(input1)
model1 = Model(input1, dense1, name="model1")

# Model 2 definition
input2 = Input((3,), name="input2")
dense2 = Dense(3, name="dense2")(input2)
model2 = Model(input2, dense2, name="model2")

# Define outputs through the 2 models
outputs = model2(model1(input1))

# Nest the models
model = Model(input1, outputs)
model.summary()

model.save("test.keras")
test = load_model("test.keras")  # Fails with an index out of range error
```

The implemented fix reverses what get_config() does: there, the kept-node index is incremented by 1 when the operation is a Functional model.

When loading, the node_index is now decremented by 1 if the layer is a Functional model.
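As a rough illustration of the load-side adjustment (the function and argument names below are placeholders, not the actual identifiers in keras/models/functional.py):

```python
import keras

# Illustrative sketch only: names are placeholders and do not necessarily
# match the real identifiers in keras/models/functional.py.
def adjust_node_index(operation, node_index):
    # Mirror of the save-side behaviour: nested Functional models have
    # their kept-node index shifted up by one in get_config(), so shift
    # it back down by one when rebuilding the graph from config.
    if isinstance(operation, keras.Model) and node_index > 0:
        return node_index - 1
    return node_index
```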

This fixes the issue and the model loads successfully. In addition, the get_config() outputs of the saved and reloaded models match.

Whilst this fix works, I am not 100% sure that it is the correct approach. Initially I was looking to amend the saving function; however, I have since discovered that Keras 2 and Keras 3 (beyond some syntactical changes) store the output_layers in the same way (with a node index of 1 when the number of outputs had a len() of 0), so for backwards compatibility reasons I focused the fix on the loading function.
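For reference, the relevant fragment of the serialized config for the simple example above looks roughly like this (each entry is [layer_name, node_index, tensor_index]; the exact layout shown here is illustrative):

```python
# Roughly what the Functional config records for the nested example above.
# Note the node index of 1 on the nested model's output reference.
config_fragment = {
    "input_layers": [["input1", 0, 0]],
    "output_layers": [["model2", 1, 0]],
}
```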

If this fix looks good, please let me know and I will implement a unit test. Advice on the best approach for the test would be appreciated. Initially I was looking to compare the original config with the loaded config, but because tuples are converted to lists on reload, comparing the configs would require iterating through them. It may be enough just to make sure that models structured this way load without failure; please let me know.
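One option, sketched below, would be a small hypothetical helper (not an existing Keras utility) that recursively converts tuples to lists so the two configs can be compared with a plain equality check:

```python
def _normalize(obj):
    """Recursively convert tuples to lists so a saved config and a
    reloaded config can be compared directly."""
    if isinstance(obj, (list, tuple)):
        return [_normalize(item) for item in obj]
    if isinstance(obj, dict):
        return {key: _normalize(value) for key, value in obj.items()}
    return obj

# Usage in a test:
# assert _normalize(model.get_config()) == _normalize(test.get_config())
```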


google-cla bot commented Apr 14, 2024

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.


codecov-commenter commented Apr 14, 2024

Codecov Report

Attention: Patch coverage is 0%, with 2 lines in your changes missing coverage. Please review.

Project coverage is 76.27%. Comparing base (b267f93) to head (13c3c91).

| Files | Patch % | Lines |
| --- | --- | --- |
| keras/models/functional.py | 0.00% | 1 Missing and 1 partial ⚠️ |
Additional details and impacted files
```
@@            Coverage Diff             @@
##           master   #19509      +/-   ##
==========================================
- Coverage   76.28%   76.27%   -0.01%
==========================================
  Files         367      367
  Lines       41233    41235       +2
  Branches     8076     8077       +1
==========================================
  Hits        31453    31453
- Misses       8059     8060       +1
- Partials     1721     1722       +1
```
| Flag | Coverage Δ |
| --- | --- |
| keras | 76.13% <0.00%> (-0.01%) ⬇️ |
| keras-jax | 60.30% <0.00%> (-0.01%) ⬇️ |
| keras-numpy | 54.27% <0.00%> (-0.01%) ⬇️ |
| keras-tensorflow | 61.55% <0.00%> (-0.01%) ⬇️ |
| keras-torch | 60.41% <0.00%> (-0.01%) ⬇️ |

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Collaborator

fchollet left a comment


Thanks for the PR!

I tried adding the following unit test in saving_lib_test.py:

    def test_nested_functional_model_saving(self):
        def shared():
            inputs = keras.layers.Input(shape=(4, ))
            outputs = keras.layers.Dense(2)(inputs)
            return keras.Model(inputs, outputs=outputs)

        def split():
            inputs = keras.layers.Input(shape=(2, ))
            outputs = keras.layers.Dense(2)(inputs)
            return keras.Model(inputs, outputs=outputs)

        inputs = [keras.Input((4,)), keras.Input((4,))]
        shared_model = shared()
        shared_a = shared_model(inputs[0])
        shared_b = shared_model(inputs[1])
        out_a = split()(shared_a)
        out_b = split()(shared_b)
        model = keras.Model(inputs, outputs=[out_a, out_b])

        temp_filepath = os.path.join(self.get_temp_dir(), "nested_func.keras")
        model.save(temp_filepath)
        new_model = keras.saving.load_model(temp_filepath)
        x = [np.random.random((2, 4)), np.random.random((2, 4))]
        ref_out = model(x)
        out = new_model(x)
        self.assertAllClose(ref_out[0], out[0])
        self.assertAllClose(ref_out[1], out[1])

While the model can be reloaded after the change in this PR, the reloaded version is incorrect (the second output has the wrong values). It sounds like we need a different fix. CC @SamanehSaadat

Author

torzdf commented Apr 14, 2024

> While the model can be reloaded after the change in this PR, the reloaded version is incorrect (the second output has the wrong values). It sounds like we need a different fix. CC @SamanehSaadat

Well... that's frustrating (and a little odd!). If I get a chance tomorrow, I'll have more of a poke around.

Actually, I see the difference in the model summaries. My fix was tested with the 'simpler' structure shown earlier.

Collaborator

fchollet commented Apr 14, 2024

I think your fix works for functional model nesting. However, there's an issue with nested functional model sharing. It's a bit challenging to reproduce, but here's the simplest model I found that reproduces it:

    def test_nested_shared_functional_model_saving(self):
        def func(in_size=4, out_size=2, name=None):
            inputs = keras.layers.Input(shape=(in_size,))
            outputs = keras.layers.Dense(out_size)(inputs)
            return keras.Model(inputs, outputs=outputs, name=name)

        inputs = [keras.Input((4,)), keras.Input((4,))]
        func_shared = func(out_size=4, name="func_shared")
        shared_a = func_shared(inputs[0])
        shared_b = func_shared(inputs[1])
        out_a = keras.layers.Dense(2)(shared_a)
        out_b = keras.layers.Dense(2)(shared_b)
        model = keras.Model(inputs, outputs=[out_a, out_b])

        temp_filepath = os.path.join(self.get_temp_dir(), "nested_shared_func.keras")
        model.save(temp_filepath)
        model.summary()
        new_model = keras.saving.load_model(temp_filepath)
        new_model.summary()
        x = [np.random.random((2, 4)), np.random.random((2, 4))]
        ref_out = model(x)
        out = new_model(x)
        self.assertAllClose(ref_out[0], out[0])
        self.assertAllClose(ref_out[1], out[1])

Indeed, the summaries differ. Basically, func_shared doesn't look like it's connected to the second input. However, during the deserialization process I can see that func_shared is called on both inputs. This will require careful investigation.
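One quick way to see the symptom is to compare the call sites of the shared sub-model before and after reloading. The snippet below relies on the private `_inbound_nodes` attribute, so it is a debugging aid only and assumes that attribute is available in the installed Keras version:

```python
# Debugging aid (assumes the private `_inbound_nodes` attribute exists).
# The shared sub-model is called on both inputs, so each should report
# two call sites; the reloaded model appears to lose one of them.
reloaded_shared = new_model.get_layer("func_shared")
print("original call sites:", len(func_shared._inbound_nodes))
print("reloaded call sites:", len(reloaded_shared._inbound_nodes))
```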

Collaborator

fchollet commented Apr 14, 2024

I have fixed the issue. It was a silly bug and an easy fix, but tracking down the problem was challenging. I had to bring out my finest-quality print statements. 7dae3e9

Thanks for reporting, and for looking into it initially!

fchollet closed this Apr 14, 2024
Author

torzdf commented Apr 14, 2024

> I have fixed the issue. It was a silly bug and an easy fix, but tracking down the problem was challenging. I had to bring out my finest-quality print statements. 7dae3e9
>
> Thanks for reporting, and for looking into it initially!

Sounds like you were approaching it in much the same way as me! I'm glad you managed to find it. Thanks for the fix!
