From 1f9dd4311e7e20c3a55e97dec9016eac2c3ce703 Mon Sep 17 00:00:00 2001 From: Neel Kovelamudi <60985914+nkovela1@users.noreply.github.com> Date: Tue, 19 Sep 2023 15:41:48 -0700 Subject: [PATCH] Adds JAX version checking for Export (#923) * Add JAX version checking * Fix formatting * Add test for invalid version --- keras_core/export/export_lib.py | 10 ++++++++++ keras_core/export/export_lib_test.py | 17 +++++++++++++++++ 2 files changed, 27 insertions(+) diff --git a/keras_core/export/export_lib.py b/keras_core/export/export_lib.py index 5dae2602c..7af83a403 100644 --- a/keras_core/export/export_lib.py +++ b/keras_core/export/export_lib.py @@ -91,6 +91,16 @@ def __init__(self): "The export API is only compatible with JAX and TF backends." ) + # TODO(nkovela): Make JAX version checking programatic. + if backend.backend() == "jax": + from jax import __version__ as jax_v + + if jax_v > "0.4.15": + raise ValueError( + "The export API is only compatible with JAX version 0.4.15 " + f"and prior. Your JAX version: {jax_v}" + ) + @property def variables(self): return self._tf_trackable.variables diff --git a/keras_core/export/export_lib_test.py b/keras_core/export/export_lib_test.py index 2dfdf950f..91f5c48a8 100644 --- a/keras_core/export/export_lib_test.py +++ b/keras_core/export/export_lib_test.py @@ -1,5 +1,6 @@ """Tests for inference-only model/layer exporting utilities.""" import os +import sys import numpy as np import pytest @@ -28,6 +29,10 @@ def get_model(): backend.backend() not in ("tensorflow", "jax"), reason="Export only currently supports the TF and JAX backends.", ) +@pytest.mark.skipif( + backend.backend() == "jax" and sys.modules["jax"].__version__ > "0.4.15", + reason="The export API is only compatible with JAX version <= 0.4.15.", +) class ExportArchiveTest(testing.TestCase): def test_standard_model_export(self): temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") @@ -537,6 +542,18 @@ def test_model_export_method(self): ) +@pytest.mark.skipif( + backend.backend() != "jax" or sys.modules["jax"].__version__ <= "0.4.15", + reason="This test is for invalid JAX versions, i.e. versions > 0.4.15.", +) +class VersionTest(testing.TestCase): + def test_invalid_jax_version(self): + with self.assertRaisesRegex( + ValueError, "only compatible with JAX version" + ): + _ = export_lib.ExportArchive() + + @pytest.mark.skipif( backend.backend() != "tensorflow", reason="TFSM Layer reloading is only for the TF backend.",