diff --git a/keras/src/models/model_test.py b/keras/src/models/model_test.py index eb83cad4235..6ed7d3c6543 100644 --- a/keras/src/models/model_test.py +++ b/keras/src/models/model_test.py @@ -1219,75 +1219,6 @@ def test_functional_deeply_nested_outputs_struct_losses(self): ) self.assertListEqual(hist_keys, ref_keys) - @parameterized.named_parameters( - ("tf_saved_model", "tf_saved_model"), - ("onnx", "onnx"), - ) - @pytest.mark.skipif( - backend.backend() not in ("tensorflow", "jax", "torch"), - reason=( - "Currently, `Model.export` only supports the tensorflow, jax and " - "torch backends." - ), - ) - @pytest.mark.skipif( - testing.jax_uses_gpu(), reason="Leads to core dumps on CI" - ) - def test_export(self, export_format): - if export_format == "tf_saved_model" and testing.torch_uses_gpu(): - self.skipTest("Leads to core dumps on CI") - - temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") - model = _get_model() - x1 = np.random.rand(1, 3).astype("float32") - x2 = np.random.rand(1, 3).astype("float32") - ref_output = model([x1, x2]) - - model.export(temp_filepath, format=export_format) - - if export_format == "tf_saved_model": - import tensorflow as tf - - revived_model = tf.saved_model.load(temp_filepath) - self.assertAllClose(ref_output, revived_model.serve([x1, x2])) - - # Test with a different batch size - if backend.backend() == "torch": - # TODO: Dynamic shape is not supported yet in the torch backend - return - revived_model.serve( - [ - np.concatenate([x1, x1], axis=0), - np.concatenate([x2, x2], axis=0), - ] - ) - elif export_format == "onnx": - import onnxruntime - - ort_session = onnxruntime.InferenceSession(temp_filepath) - ort_inputs = { - k.name: v for k, v in zip(ort_session.get_inputs(), [x1, x2]) - } - self.assertAllClose( - ref_output, ort_session.run(None, ort_inputs)[0] - ) - - # Test with a different batch size - if backend.backend() == "torch": - # TODO: Dynamic shape is not supported yet in the torch backend - return - ort_inputs = { - k.name: v - for k, v in zip( - ort_session.get_inputs(), - [ - np.concatenate([x1, x1], axis=0), - np.concatenate([x2, x2], axis=0), - ], - ) - } - ort_session.run(None, ort_inputs) - def test_export_error(self): temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") model = _get_model()