Skip to content

Commit

Permalink
use temporary directories
Browse files Browse the repository at this point in the history
  • Loading branch information
FFroehlich committed Nov 18, 2024
1 parent 250f9dd commit dc4992e
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 26 deletions.
2 changes: 2 additions & 0 deletions python/sdist/amici/jaxcodeprinter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import re
from collections.abc import Iterable
from logging import warning

import sympy as sp
from sympy.printing.numpy import NumPyPrinter
Expand All @@ -22,6 +23,7 @@ def doprint(self, expr: sp.Expr, assign_to: str | None = None) -> str:
) from e

def _print_AmiciSpline(self, expr: sp.Expr) -> str:
warning("Spline interpolation is support in JAX is untested")
# FIXME: untested, where are spline nodes coming from anyways?
return f'interp1d(time, {self.doprint(expr.args[2:])}, kind="cubic")'

Expand Down
53 changes: 27 additions & 26 deletions python/tests/test_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from beartype import beartype

from amici.pysb_import import pysb2amici
from amici.testing import TemporaryDirectoryWinSafe
from numpy.testing import assert_allclose

pysb = pytest.importorskip("pysb")
Expand All @@ -28,17 +29,17 @@ def test_conversion():
pysb.Rule("conv", a(s="a") >> a(s="b"), pysb.Parameter("kcat", 0.05))
pysb.Observable("ab", a(s="b"))

outdir = model.name
pysb2amici(model, outdir, verbose=True, observables=["ab"])
with TemporaryDirectoryWinSafe(prefix=model.name) as outdir:
pysb2amici(model, outdir, verbose=True, observables=["ab"])

model_module = amici.import_model_module(
module_name=model.name, module_path=outdir
)
model_module = amici.import_model_module(
module_name=model.name, module_path=outdir
)

ts = tuple(np.linspace(0, 1, 10))
p = jnp.stack((1.0, 0.1), axis=-1)
k = tuple()
_test_model(model_module, ts, p, k)
ts = tuple(np.linspace(0, 1, 10))
p = jnp.stack((1.0, 0.1), axis=-1)
k = tuple()
_test_model(model_module, ts, p, k)


@pytest.mark.filterwarnings(
Expand Down Expand Up @@ -74,23 +75,23 @@ def test_dimerization():
pysb.Observable("a_obs", a())
pysb.Observable("b_obs", b())

outdir = model.name
pysb2amici(
model,
outdir,
verbose=True,
observables=["a_obs", "b_obs"],
constant_parameters=["ksyn_a", "ksyn_b"],
)

model_module = amici.import_model_module(
module_name=model.name, module_path=outdir
)

ts = tuple(np.linspace(0, 1, 10))
p = jnp.stack((5, 0.5, 0.5, 0.5), axis=-1)
k = (0.5, 5)
_test_model(model_module, ts, p, k)
with TemporaryDirectoryWinSafe(prefix=model.name) as outdir:
pysb2amici(
model,
outdir,
verbose=True,
observables=["a_obs", "b_obs"],
constant_parameters=["ksyn_a", "ksyn_b"],
)

model_module = amici.import_model_module(
module_name=model.name, module_path=outdir
)

ts = tuple(np.linspace(0, 1, 10))
p = jnp.stack((5, 0.5, 0.5, 0.5), axis=-1)
k = (0.5, 5)
_test_model(model_module, ts, p, k)


def _test_model(model_module, ts, p, k):
Expand Down

0 comments on commit dc4992e

Please sign in to comment.