Skip to content

Commit

Permalink
Remove duplicate export tests in model_test. (#20735)
Browse files Browse the repository at this point in the history
  • Loading branch information
hertschuh authored Jan 8, 2025
1 parent f97be63 commit fd2955f
Showing 1 changed file with 0 additions and 69 deletions.
69 changes: 0 additions & 69 deletions keras/src/models/model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1219,75 +1219,6 @@ def test_functional_deeply_nested_outputs_struct_losses(self):
)
self.assertListEqual(hist_keys, ref_keys)

@parameterized.named_parameters(
("tf_saved_model", "tf_saved_model"),
("onnx", "onnx"),
)
@pytest.mark.skipif(
backend.backend() not in ("tensorflow", "jax", "torch"),
reason=(
"Currently, `Model.export` only supports the tensorflow, jax and "
"torch backends."
),
)
@pytest.mark.skipif(
testing.jax_uses_gpu(), reason="Leads to core dumps on CI"
)
def test_export(self, export_format):
if export_format == "tf_saved_model" and testing.torch_uses_gpu():
self.skipTest("Leads to core dumps on CI")

temp_filepath = os.path.join(self.get_temp_dir(), "exported_model")
model = _get_model()
x1 = np.random.rand(1, 3).astype("float32")
x2 = np.random.rand(1, 3).astype("float32")
ref_output = model([x1, x2])

model.export(temp_filepath, format=export_format)

if export_format == "tf_saved_model":
import tensorflow as tf

revived_model = tf.saved_model.load(temp_filepath)
self.assertAllClose(ref_output, revived_model.serve([x1, x2]))

# Test with a different batch size
if backend.backend() == "torch":
# TODO: Dynamic shape is not supported yet in the torch backend
return
revived_model.serve(
[
np.concatenate([x1, x1], axis=0),
np.concatenate([x2, x2], axis=0),
]
)
elif export_format == "onnx":
import onnxruntime

ort_session = onnxruntime.InferenceSession(temp_filepath)
ort_inputs = {
k.name: v for k, v in zip(ort_session.get_inputs(), [x1, x2])
}
self.assertAllClose(
ref_output, ort_session.run(None, ort_inputs)[0]
)

# Test with a different batch size
if backend.backend() == "torch":
# TODO: Dynamic shape is not supported yet in the torch backend
return
ort_inputs = {
k.name: v
for k, v in zip(
ort_session.get_inputs(),
[
np.concatenate([x1, x1], axis=0),
np.concatenate([x2, x2], axis=0),
],
)
}
ort_session.run(None, ort_inputs)

def test_export_error(self):
temp_filepath = os.path.join(self.get_temp_dir(), "exported_model")
model = _get_model()
Expand Down

0 comments on commit fd2955f

Please sign in to comment.