Skip to content

Commit

Permalink
Make test resilient to spurious warnings. (#20516)
Browse files Browse the repository at this point in the history
Test was counting warnings, but some other components can throw unrelated warnings.

This makes sure we only count the warnings we're looking for.
  • Loading branch information
hertschuh authored Nov 19, 2024
1 parent 3b8aba1 commit 660da94
Showing 1 changed file with 9 additions and 8 deletions.
17 changes: 9 additions & 8 deletions keras/src/models/functional_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,24 +498,25 @@ def compute_output_shape(self, x_shape):
# Note: it's not intended to work in symbolic mode (yet).

def test_warning_for_mismatched_inputs_structure(self):
def is_input_warning(w):
return str(w.message).startswith(
"The structure of `inputs` doesn't match the expected structure"
)

i1 = Input((2,))
i2 = Input((2,))
outputs = layers.Add()([i1, i2])
model = Model({"i1": i1, "i2": i2}, outputs)

with pytest.warns() as record:
model = Model({"i1": i1, "i2": i2}, outputs)
with pytest.warns() as warning_logs:
model.predict([np.ones((2, 2)), np.zeros((2, 2))], verbose=0)
self.assertLen(record, 1)
self.assertStartsWith(
str(record[0].message),
r"The structure of `inputs` doesn't match the expected structure:",
)
self.assertLen(list(filter(is_input_warning, warning_logs)), 1)

# No warning for mismatched tuples and lists.
model = Model([i1, i2], outputs)
with warnings.catch_warnings(record=True) as warning_logs:
model.predict((np.ones((2, 2)), np.zeros((2, 2))), verbose=0)
self.assertLen(warning_logs, 0)
self.assertLen(list(filter(is_input_warning, warning_logs)), 0)

def test_for_functional_in_sequential(self):
# Test for a v3.4.1 regression.
Expand Down

0 comments on commit 660da94

Please sign in to comment.