Skip to content

Commit

Permalink
Adds JAX version checking for Export (#923)
Browse files Browse the repository at this point in the history
* Add JAX version checking

* Fix formatting

* Add test for invalid version
  • Loading branch information
nkovela1 authored Sep 19, 2023
1 parent f2c3766 commit 1f9dd43
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 0 deletions.
10 changes: 10 additions & 0 deletions keras_core/export/export_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,16 @@ def __init__(self):
"The export API is only compatible with JAX and TF backends."
)

# TODO(nkovela): Make JAX version checking programatic.
if backend.backend() == "jax":
from jax import __version__ as jax_v

if jax_v > "0.4.15":
raise ValueError(
"The export API is only compatible with JAX version 0.4.15 "
f"and prior. Your JAX version: {jax_v}"
)

@property
def variables(self):
return self._tf_trackable.variables
Expand Down
17 changes: 17 additions & 0 deletions keras_core/export/export_lib_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Tests for inference-only model/layer exporting utilities."""
import os
import sys

import numpy as np
import pytest
Expand Down Expand Up @@ -28,6 +29,10 @@ def get_model():
backend.backend() not in ("tensorflow", "jax"),
reason="Export only currently supports the TF and JAX backends.",
)
@pytest.mark.skipif(
backend.backend() == "jax" and sys.modules["jax"].__version__ > "0.4.15",
reason="The export API is only compatible with JAX version <= 0.4.15.",
)
class ExportArchiveTest(testing.TestCase):
def test_standard_model_export(self):
temp_filepath = os.path.join(self.get_temp_dir(), "exported_model")
Expand Down Expand Up @@ -537,6 +542,18 @@ def test_model_export_method(self):
)


@pytest.mark.skipif(
backend.backend() != "jax" or sys.modules["jax"].__version__ <= "0.4.15",
reason="This test is for invalid JAX versions, i.e. versions > 0.4.15.",
)
class VersionTest(testing.TestCase):
def test_invalid_jax_version(self):
with self.assertRaisesRegex(
ValueError, "only compatible with JAX version"
):
_ = export_lib.ExportArchive()


@pytest.mark.skipif(
backend.backend() != "tensorflow",
reason="TFSM Layer reloading is only for the TF backend.",
Expand Down

0 comments on commit 1f9dd43

Please sign in to comment.