diff --git a/python/sdist/amici/jaxcodeprinter.py b/python/sdist/amici/jaxcodeprinter.py index f2d5b29248..ed9181cc09 100644 --- a/python/sdist/amici/jaxcodeprinter.py +++ b/python/sdist/amici/jaxcodeprinter.py @@ -2,6 +2,7 @@ import re from collections.abc import Iterable +from logging import warning import sympy as sp from sympy.printing.numpy import NumPyPrinter @@ -22,6 +23,7 @@ def doprint(self, expr: sp.Expr, assign_to: str | None = None) -> str: ) from e def _print_AmiciSpline(self, expr: sp.Expr) -> str: + warning("Spline interpolation is support in JAX is untested") # FIXME: untested, where are spline nodes coming from anyways? return f'interp1d(time, {self.doprint(expr.args[2:])}, kind="cubic")' diff --git a/python/tests/test_jax.py b/python/tests/test_jax.py index d66f258e24..d124a6e1be 100644 --- a/python/tests/test_jax.py +++ b/python/tests/test_jax.py @@ -11,6 +11,7 @@ from beartype import beartype from amici.pysb_import import pysb2amici +from amici.testing import TemporaryDirectoryWinSafe from numpy.testing import assert_allclose pysb = pytest.importorskip("pysb") @@ -28,17 +29,17 @@ def test_conversion(): pysb.Rule("conv", a(s="a") >> a(s="b"), pysb.Parameter("kcat", 0.05)) pysb.Observable("ab", a(s="b")) - outdir = model.name - pysb2amici(model, outdir, verbose=True, observables=["ab"]) + with TemporaryDirectoryWinSafe(prefix=model.name) as outdir: + pysb2amici(model, outdir, verbose=True, observables=["ab"]) - model_module = amici.import_model_module( - module_name=model.name, module_path=outdir - ) + model_module = amici.import_model_module( + module_name=model.name, module_path=outdir + ) - ts = tuple(np.linspace(0, 1, 10)) - p = jnp.stack((1.0, 0.1), axis=-1) - k = tuple() - _test_model(model_module, ts, p, k) + ts = tuple(np.linspace(0, 1, 10)) + p = jnp.stack((1.0, 0.1), axis=-1) + k = tuple() + _test_model(model_module, ts, p, k) @pytest.mark.filterwarnings( @@ -74,23 +75,23 @@ def test_dimerization(): pysb.Observable("a_obs", a()) pysb.Observable("b_obs", b()) - outdir = model.name - pysb2amici( - model, - outdir, - verbose=True, - observables=["a_obs", "b_obs"], - constant_parameters=["ksyn_a", "ksyn_b"], - ) - - model_module = amici.import_model_module( - module_name=model.name, module_path=outdir - ) - - ts = tuple(np.linspace(0, 1, 10)) - p = jnp.stack((5, 0.5, 0.5, 0.5), axis=-1) - k = (0.5, 5) - _test_model(model_module, ts, p, k) + with TemporaryDirectoryWinSafe(prefix=model.name) as outdir: + pysb2amici( + model, + outdir, + verbose=True, + observables=["a_obs", "b_obs"], + constant_parameters=["ksyn_a", "ksyn_b"], + ) + + model_module = amici.import_model_module( + module_name=model.name, module_path=outdir + ) + + ts = tuple(np.linspace(0, 1, 10)) + p = jnp.stack((5, 0.5, 0.5, 0.5), axis=-1) + k = (0.5, 5) + _test_model(model_module, ts, p, k) def _test_model(model_module, ts, p, k):