Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Jax export #1861

Merged
merged 94 commits into from
Nov 19, 2024
Merged
Changes from 1 commit
Commits
Show all changes
94 commits
Select commit Hold shift + click to select a range
328d462
basic prototype
FFroehlich Aug 25, 2022
ffa5afb
Merge branch 'develop' into jax_export
FFroehlich Aug 26, 2022
d4f8552
add dimerization example, add second order code, refactor jit
FFroehlich Aug 26, 2022
d37a850
remove equinox dependency, list dependencies
FFroehlich Aug 26, 2022
ff37c7e
make jax optional
FFroehlich Aug 26, 2022
c3a77f7
Merge branch 'develop' into jax_export
FFroehlich Aug 26, 2022
7cd8553
support conservation laws
FFroehlich Aug 26, 2022
5177ad7
fixup
FFroehlich Aug 26, 2022
5612cfc
fix jit nesting
FFroehlich Aug 26, 2022
2dd0377
use vmap for vectorization
FFroehlich Aug 26, 2022
e9bd14f
fixups
FFroehlich Aug 26, 2022
bbb5246
add multithreaded simulation runner
FFroehlich Aug 26, 2022
9bd1004
fix my
FFroehlich Aug 26, 2022
1b06c24
Merge branch 'develop' into jax_export
FFroehlich Sep 9, 2022
599aa71
fixes
FFroehlich Sep 13, 2022
51812d6
merge
FFroehlich Apr 10, 2024
3fbd17a
fixup merge
FFroehlich Apr 10, 2024
5974d47
fix install
FFroehlich Apr 10, 2024
37cdc81
actually generate code
FFroehlich Apr 10, 2024
9e6a0ff
fix
FFroehlich Apr 10, 2024
22b2b38
fix
FFroehlich Apr 10, 2024
48a2e49
add better default coefficients, fix jax
FFroehlich Apr 10, 2024
481216d
ignore fujita in jax
FFroehlich Apr 10, 2024
85b8173
ignore smith
FFroehlich Apr 10, 2024
b213adb
optimize & fix bachmann
FFroehlich Apr 11, 2024
a1f37b7
fix import/wokflow
FFroehlich Apr 11, 2024
e09bb2f
Update __init__.template.py
FFroehlich Apr 12, 2024
d8d1900
fix jax imports
FFroehlich Apr 12, 2024
c24fe6b
Update setup.cfg
FFroehlich Apr 12, 2024
1ec591c
add preequilibration support
FFroehlich Apr 12, 2024
aebe07c
fix jax tests
FFroehlich Apr 13, 2024
4125c51
add filterwarning
FFroehlich Apr 14, 2024
8143cc2
fix parameter transformation
FFroehlich Apr 14, 2024
781bb3b
Merge branch 'develop' into jax_export
FFroehlich Oct 19, 2024
81e2aeb
reenable ruff format
FFroehlich Oct 19, 2024
c01f707
post merge cleanup
FFroehlich Oct 19, 2024
a5d356a
"fix" splines
FFroehlich Oct 19, 2024
9a021cf
Update .pre-commit-config.yaml
FFroehlich Oct 19, 2024
a02d215
Merge branch 'develop' into jax_export
FFroehlich Oct 19, 2024
50193d8
force optimistix 0.0.9
FFroehlich Oct 21, 2024
d6c5bcd
Merge branch 'jax_export' of https://github.com/AMICI-dev/AMICI into …
FFroehlich Oct 21, 2024
7faae32
add support for heavyside functions
FFroehlich Oct 21, 2024
907acb7
cleanup & actually run tests
FFroehlich Oct 21, 2024
82a01ba
simply tests + add support for non-dynamic simulation in jax
FFroehlich Oct 22, 2024
7c3aef9
Merge branch 'develop' into jax_export
FFroehlich Oct 23, 2024
c548c93
fix for NONCONST_CLS
FFroehlich Oct 24, 2024
7c27a21
fix petab path
FFroehlich Oct 24, 2024
b84dbdb
Merge branch 'develop' into jax_export
FFroehlich Oct 24, 2024
37b9329
Merge branch 'develop' into jax_export
FFroehlich Oct 24, 2024
956b0a6
fixup merge
FFroehlich Oct 24, 2024
2f3834d
support postequilibration
FFroehlich Oct 25, 2024
5366632
fixup
FFroehlich Oct 25, 2024
5a86f4c
fix
FFroehlich Oct 25, 2024
480b75a
fix gradients
FFroehlich Oct 25, 2024
8b9c10a
fix hessian
FFroehlich Oct 25, 2024
7dc81ac
Update test_petab_benchmark.py
FFroehlich Oct 25, 2024
866c811
Merge branch 'develop' into jax_export
FFroehlich Oct 27, 2024
02a1272
skip smith in jax
FFroehlich Oct 27, 2024
51bd18c
exclude more models
FFroehlich Oct 27, 2024
c7c5d4b
refactor: remove use of edatas
FFroehlich Nov 9, 2024
a514deb
update template
FFroehlich Nov 9, 2024
498681a
Update .pre-commit-config.yaml
FFroehlich Nov 9, 2024
4a5e7d2
Merge branch 'develop' into jax_export
FFroehlich Nov 11, 2024
f745be0
fix python jax tests
FFroehlich Nov 12, 2024
a64f89b
simplify petab interface
FFroehlich Nov 12, 2024
7292451
add parameter values to model class
FFroehlich Nov 12, 2024
da02106
refactor parameter mapping
FFroehlich Nov 12, 2024
a46e65d
refactor & simplify
FFroehlich Nov 12, 2024
404d82e
refsctor
FFroehlich Nov 16, 2024
e399f4c
update template
FFroehlich Nov 16, 2024
eaae778
Update .pre-commit-config.yaml
FFroehlich Nov 16, 2024
d79cfc1
refactor fix test
FFroehlich Nov 16, 2024
94aa679
Update petab.py
FFroehlich Nov 16, 2024
b129c86
fixups
FFroehlich Nov 17, 2024
9b6a62b
fixup
FFroehlich Nov 17, 2024
74cd498
add documentation and typing
FFroehlich Nov 17, 2024
d94714b
add runtime typechecks to jax tests
FFroehlich Nov 17, 2024
0a9fcdf
add coverage from benchmark tests
FFroehlich Nov 17, 2024
186805c
add api versioning and reenable jit compilation
FFroehlich Nov 17, 2024
250f9dd
review comments
FFroehlich Nov 18, 2024
dc4992e
use temporary directories
FFroehlich Nov 18, 2024
d547509
fix doc
FFroehlich Nov 18, 2024
82bfe31
Update test_jax.py
FFroehlich Nov 18, 2024
a010803
don't generate code if jax/diffrax not available
FFroehlich Nov 18, 2024
d9ae05e
Merge branch 'develop' into jax_export
FFroehlich Nov 18, 2024
f7c2c10
add example
FFroehlich Nov 19, 2024
5dc8735
fix doc
FFroehlich Nov 19, 2024
784ab2c
fix notebook symlink
FFroehlich Nov 19, 2024
d528168
update notebook
FFroehlich Nov 19, 2024
24d8c09
Update ExampleJaxPEtab.ipynb
FFroehlich Nov 19, 2024
5393e6c
Update ExampleJaxPEtab.ipynb
FFroehlich Nov 19, 2024
a22f099
fix compilation issue
FFroehlich Nov 19, 2024
a585414
Merge branch 'develop' into jax_export
FFroehlich Nov 19, 2024
c242b15
fix
FFroehlich Nov 19, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
optimize & fix bachmann
FFroehlich committed Apr 11, 2024
commit b213adb92ec2f4d5d844732df12d4901cb9f5038
288 changes: 159 additions & 129 deletions python/sdist/amici/jax.py
Original file line number Diff line number Diff line change
@@ -3,232 +3,262 @@
from concurrent.futures import ThreadPoolExecutor

import diffrax
import equinox as eqx
import jax.numpy as jnp
import numpy as np
import jax
from functools import partial
from collections.abc import Iterable

import amici

jax.config.update("jax_enable_x64", True)

Check warning on line 14 in python/sdist/amici/jax.py

Codecov / codecov/patch

python/sdist/amici/jax.py#L14

Added line #L14 was not covered by tests


class JAXModel:
class JAXModel(eqx.Module):
_unscale_funs = {

Check warning on line 18 in python/sdist/amici/jax.py

Codecov / codecov/patch

python/sdist/amici/jax.py#L17-L18

Added lines #L17 - L18 were not covered by tests
amici.ParameterScaling.none: lambda x: x,
amici.ParameterScaling.ln: lambda x: jnp.exp(x),
amici.ParameterScaling.log10: lambda x: jnp.power(10, x),
}
solver: diffrax.AbstractSolver
controller: diffrax.AbstractStepSizeController
atol: float
rtol: float
pcoeff: float
icoeff: float
dcoeff: float
maxsteps: int
term: diffrax.ODETerm
sensi_order: amici.SensitivityOrder

Check warning on line 32 in python/sdist/amici/jax.py

Codecov / codecov/patch

python/sdist/amici/jax.py#L23-L32

Added lines #L23 - L32 were not covered by tests

def __init__(self):
self.solver = diffrax.Kvaerno5()
self.atol: float = 1e-8
self.rtol: float = 1e-8
self.pcoeff: float = 0.4
self.icoeff: float = 0.3
self.dcoeff: float = 0.0
self.maxsteps: int = 2**10
self.controller = diffrax.PIDController(

Check warning on line 42 in python/sdist/amici/jax.py

Codecov / codecov/patch

python/sdist/amici/jax.py#L34-L42

Added lines #L34 - L42 were not covered by tests
rtol=self.rtol,
atol=self.atol,
pcoeff=self.pcoeff,
icoeff=self.icoeff,
dcoeff=self.dcoeff,
)
self.term = diffrax.ODETerm(self.xdot)
self.sensi_order = amici.SensitivityOrder.none

Check warning on line 50 in python/sdist/amici/jax.py

Codecov / codecov/patch

python/sdist/amici/jax.py#L49-L50

Added lines #L49 - L50 were not covered by tests

@staticmethod
@abstractmethod
def xdot(self, t, x, args):
def xdot(t, x, args):
...

Check warning on line 55 in python/sdist/amici/jax.py

Codecov / codecov/patch

python/sdist/amici/jax.py#L52-L55

Added lines #L52 - L55 were not covered by tests

@staticmethod
@abstractmethod
def _w(self, t, x, p, k, tcl):
def _w(t, x, p, k, tcl):
...

Check warning on line 60 in python/sdist/amici/jax.py

Codecov / codecov/patch

python/sdist/amici/jax.py#L57-L60

Added lines #L57 - L60 were not covered by tests

@staticmethod
@abstractmethod
def x0(self, p, k):
def x0(p, k):
...

Check warning on line 65 in python/sdist/amici/jax.py

Codecov / codecov/patch

python/sdist/amici/jax.py#L62-L65

Added lines #L62 - L65 were not covered by tests

@staticmethod
@abstractmethod
def x_solver(self, x):
def x_solver(x):
...

Check warning on line 70 in python/sdist/amici/jax.py

Codecov / codecov/patch

python/sdist/amici/jax.py#L67-L70

Added lines #L67 - L70 were not covered by tests

@staticmethod
@abstractmethod
def x_rdata(self, x, tcl):
def x_rdata(x, tcl):
...

Check warning on line 75 in python/sdist/amici/jax.py

Codecov / codecov/patch

python/sdist/amici/jax.py#L72-L75

Added lines #L72 - L75 were not covered by tests

@staticmethod
@abstractmethod
def tcl(self, x, p, k):
def tcl(x, p, k):
...

Check warning on line 80 in python/sdist/amici/jax.py

Codecov / codecov/patch

python/sdist/amici/jax.py#L77-L80

Added lines #L77 - L80 were not covered by tests

@staticmethod
@abstractmethod
def y(self, t, x, p, k, tcl):
def y(t, x, p, k, tcl):
...

Check warning on line 85 in python/sdist/amici/jax.py

Codecov / codecov/patch

python/sdist/amici/jax.py#L82-L85

Added lines #L82 - L85 were not covered by tests

@staticmethod
@abstractmethod
def sigmay(self, y, p, k):
def sigmay(y, p, k):
...

Check warning on line 90 in python/sdist/amici/jax.py

Codecov / codecov/patch

python/sdist/amici/jax.py#L87-L90

Added lines #L87 - L90 were not covered by tests

@staticmethod
@abstractmethod
def Jy(self, y, my, sigmay):
def Jy(y, my, sigmay):
...

Check warning on line 95 in python/sdist/amici/jax.py

Codecov / codecov/patch

python/sdist/amici/jax.py#L92-L95

Added lines #L92 - L95 were not covered by tests

def unscale_p(self, p, pscale):
return jnp.stack(
[
self._unscale_funs[pscale_i](p_i)
for p_i, pscale_i in zip(p, pscale)
]
)

def get_solver(self):
return JAXSolver(model=self)


class JAXSolver:
def __init__(self, model: JAXModel):
self.model: JAXModel = model
self.solver: diffrax.AbstractSolver = diffrax.Kvaerno5()
self.atol: float = 1e-8
self.rtol: float = 1e-8
self.pcoeff: float = 0.4
self.icoeff: float = 0.3
self.dcoeff: float = 0.0
self.maxsteps: int = int(1e6)
self.sensi_mode: amici.SensitivityMethod = (
amici.SensitivityMethod.adjoint
)
self.sensi_order: amici.SensitivityOrder = amici.SensitivityOrder.none
return jax.vmap(

Check warning on line 98 in python/sdist/amici/jax.py

Codecov / codecov/patch

python/sdist/amici/jax.py#L97-L98

Added lines #L97 - L98 were not covered by tests
lambda p_i, pscale_i: jnp.stack(
(p_i, jnp.exp(p_i), jnp.power(10, p_i))
)
.at[pscale_i]
.get()
)(p, pscale)

def _solve(self, ts, p, k):
x0 = self.model.x0(p, k)
tcl = self.model.tcl(x0, p, k)
x0 = self.x0(p, k)
tcl = self.tcl(x0, p, k)
sol = diffrax.diffeqsolve(

Check warning on line 109 in python/sdist/amici/jax.py

Codecov / codecov/patch

python/sdist/amici/jax.py#L106-L109

Added lines #L106 - L109 were not covered by tests
diffrax.ODETerm(self.model.xdot),
self.term,
self.solver,
args=(p, k, tcl),
t0=0.0,
t1=ts[-1],
dt0=None,
y0=self.model.x_solver(x0),
stepsize_controller=diffrax.PIDController(
rtol=self.rtol,
atol=self.atol,
pcoeff=self.pcoeff,
icoeff=self.icoeff,
dcoeff=self.dcoeff,
),
y0=self.x_solver(x0),
stepsize_controller=self.controller,
max_steps=self.maxsteps,
saveat=diffrax.SaveAt(ts=ts),
)
return sol.ys, tcl
return sol.ys, tcl, sol.stats

Check warning on line 121 in python/sdist/amici/jax.py

Codecov / codecov/patch

python/sdist/amici/jax.py#L121

Added line #L121 was not covered by tests

def _obs(self, ts, x, p, k, tcl):
return jax.vmap(self.model.y, in_axes=(0, 0, None, None, None))(
np.asarray(ts), x, p, k, tcl
return jax.vmap(self.y, in_axes=(0, 0, None, None, None))(

Check warning on line 124 in python/sdist/amici/jax.py

Codecov / codecov/patch

python/sdist/amici/jax.py#L123-L124

Added lines #L123 - L124 were not covered by tests
ts, x, p, k, tcl
)

def _sigmay(self, obs, p, k):
return jax.vmap(self.model.sigmay, in_axes=(0, None, None))(obs, p, k)
return jax.vmap(self.sigmay, in_axes=(0, None, None))(obs, p, k)

Check warning on line 129 in python/sdist/amici/jax.py

Codecov / codecov/patch

python/sdist/amici/jax.py#L128-L129

Added lines #L128 - L129 were not covered by tests

def _x_rdata(self, x, tcl):
return jax.vmap(self.model.x_rdata, in_axes=(0, None))(x, tcl)
return jax.vmap(self.x_rdata, in_axes=(0, None))(x, tcl)

Check warning on line 132 in python/sdist/amici/jax.py

Codecov / codecov/patch

python/sdist/amici/jax.py#L131-L132

Added lines #L131 - L132 were not covered by tests

def _loss(self, obs: jnp.ndarray, sigmay: jnp.ndarray, my: np.ndarray):
loss_fun = jax.vmap(self.model.Jy, in_axes=(0, 0, 0))
loss_fun = jax.vmap(self.Jy, in_axes=(0, 0, 0))
return -jnp.sum(loss_fun(obs, my, sigmay))

Check warning on line 136 in python/sdist/amici/jax.py

Codecov / codecov/patch

python/sdist/amici/jax.py#L134-L136

Added lines #L134 - L136 were not covered by tests

def _run(

Check warning on line 138 in python/sdist/amici/jax.py

Codecov / codecov/patch

python/sdist/amici/jax.py#L138

Added line #L138 was not covered by tests
self, ts: tuple, p: jnp.ndarray, k: tuple, my: tuple, pscale: tuple
self,
ts: np.ndarray,
p: np.ndarray,
k: jnp.ndarray,
my: jnp.ndarray,
pscale: np.ndarray,
):
ps = self.model.unscale_p(p, pscale)
x, tcl = self._solve(ts, ps, k)
ps = self.unscale_p(p, pscale)
x, tcl, stats = self._solve(ts, ps, k)
obs = self._obs(ts, x, ps, k, tcl)
my_r = np.asarray(my).reshape((len(ts), -1))
my_r = my.reshape((len(ts), -1))
sigmay = self._sigmay(obs, ps, k)
llh = self._loss(obs, sigmay, my_r)
x_rdata = self._x_rdata(x, tcl)
return llh, (x_rdata, obs)
return llh, (x_rdata, obs, stats)

Check warning on line 153 in python/sdist/amici/jax.py

Codecov / codecov/patch

python/sdist/amici/jax.py#L146-L153

Added lines #L146 - L153 were not covered by tests

@partial(jax.jit, static_argnames=("self", "ts", "k", "my", "pscale"))
@eqx.filter_jit
def run(

Check warning on line 156 in python/sdist/amici/jax.py

Codecov / codecov/patch

python/sdist/amici/jax.py#L155-L156

Added lines #L155 - L156 were not covered by tests
self, ts: tuple, p: jnp.ndarray, k: tuple, my: tuple, pscale: tuple
self,
ts: np.ndarray,
p: jnp.ndarray,
k: np.ndarray,
my: np.ndarray,
pscale: np.ndarray,
):
return self._run(ts, p, k, my, pscale)

Check warning on line 164 in python/sdist/amici/jax.py

Codecov / codecov/patch

python/sdist/amici/jax.py#L164

Added line #L164 was not covered by tests

@partial(jax.jit, static_argnames=("self", "ts", "k", "my", "pscale"))
@eqx.filter_jit
def srun(

Check warning on line 167 in python/sdist/amici/jax.py

Codecov / codecov/patch

python/sdist/amici/jax.py#L166-L167

Added lines #L166 - L167 were not covered by tests
self, ts: tuple, p: jnp.ndarray, k: tuple, my: tuple, pscale: tuple
self,
ts: np.ndarray,
p: jnp.ndarray,
k: np.ndarray,
my: np.ndarray,
pscale: np.ndarray,
):
(llh, (x, obs)), sllh = (jax.value_and_grad(self._run, 1, True))(
ts, p, k, my, pscale
)
return llh, sllh, (x, obs)
(llh, (x, obs, stats)), sllh = (

Check warning on line 175 in python/sdist/amici/jax.py

Codecov / codecov/patch

python/sdist/amici/jax.py#L175

Added line #L175 was not covered by tests
jax.value_and_grad(self._run, 1, True)
)(ts, p, k, my, pscale)
return llh, sllh, (x, obs, stats)

Check warning on line 178 in python/sdist/amici/jax.py

Codecov / codecov/patch

python/sdist/amici/jax.py#L178

Added line #L178 was not covered by tests

@partial(jax.jit, static_argnames=("self", "ts", "k", "my", "pscale"))
@eqx.filter_jit
def s2run(

Check warning on line 181 in python/sdist/amici/jax.py

Codecov / codecov/patch

python/sdist/amici/jax.py#L180-L181

Added lines #L180 - L181 were not covered by tests
self, ts: tuple, p: jnp.ndarray, k: tuple, my: tuple, pscale: tuple
self,
ts: np.ndarray,
p: jnp.ndarray,
k: np.ndarray,
my: np.ndarray,
pscale: np.ndarray,
):
(llh, (x, obs)), sllh = (jax.value_and_grad(self._run, 1, True))(
(llh, (_, _, _)), sllh = (jax.value_and_grad(self._run, 1, True))(

Check warning on line 189 in python/sdist/amici/jax.py

Codecov / codecov/patch

python/sdist/amici/jax.py#L189

Added line #L189 was not covered by tests
ts, p, k, my, pscale
)
s2llh, (x, obs) = jax.jacfwd(jax.grad(self._run, 1, True), 1, True)(
ts, p, k, my, pscale
)
return llh, sllh, s2llh, (x, obs)


def run_simulations(
model: JAXModel,
solver: JAXSolver,
edatas: Iterable[amici.ExpData],
num_threads: int = 1,
):
def run(edata):
return run_simulation(model, solver, edata)

if num_threads > 1:
with ThreadPoolExecutor(max_workers=num_threads) as pool:
results = pool.map(run, edatas)
else:
results = map(run, edatas)
return list(results)


def run_simulation(model: JAXModel, solver: JAXSolver, edata: amici.ExpData):
ts = tuple(edata.getTimepoints())
p = jnp.asarray(edata.parameters)
k = tuple(edata.fixedParameters)
my = tuple(edata.getObservedData())
pscale = tuple(edata.pscale)

rdata_kwargs = dict()

if solver.sensi_order == amici.SensitivityOrder.none:
(
rdata_kwargs["llh"],
(rdata_kwargs["x"], rdata_kwargs["y"]),
) = solver.run(ts, p, k, my, pscale)
elif solver.sensi_order == amici.SensitivityOrder.first:
(
rdata_kwargs["llh"],
rdata_kwargs["sllh"],
(rdata_kwargs["x"], rdata_kwargs["y"]),
) = solver.srun(ts, p, k, my, pscale)
elif solver.sensi_order == amici.SensitivityOrder.second:
(
rdata_kwargs["llh"],
rdata_kwargs["sllh"],
rdata_kwargs["s2llh"],
(rdata_kwargs["x"], rdata_kwargs["y"]),
) = solver.s2run(ts, p, k, my, pscale)

for field in rdata_kwargs.keys():
if field == "llh":
rdata_kwargs[field] = np.float64(rdata_kwargs[field])
elif field not in ["sllh", "s2llh"]:
rdata_kwargs[field] = np.asarray(rdata_kwargs[field]).T
if rdata_kwargs[field].ndim == 1:
rdata_kwargs[field] = np.expand_dims(rdata_kwargs[field], 1)

return ReturnDataJAX(**rdata_kwargs)
s2llh, (x, obs, stats) = jax.jacfwd(

Check warning on line 192 in python/sdist/amici/jax.py

Codecov / codecov/patch

python/sdist/amici/jax.py#L192

Added line #L192 was not covered by tests
jax.grad(self._run, 1, True), 1, True
)(ts, p, k, my, pscale)
return llh, sllh, s2llh, (x, obs, stats)

Check warning on line 195 in python/sdist/amici/jax.py

Codecov / codecov/patch

python/sdist/amici/jax.py#L195

Added line #L195 was not covered by tests

def run_simulation(self, edata: amici.ExpData):
ts = np.asarray(edata.getTimepoints())
p = jnp.asarray(edata.parameters)
k = np.asarray(edata.fixedParameters)
my = np.asarray(edata.getObservedData())
pscale = np.asarray(edata.pscale)

Check warning on line 202 in python/sdist/amici/jax.py

Codecov / codecov/patch

python/sdist/amici/jax.py#L197-L202

Added lines #L197 - L202 were not covered by tests

rdata_kwargs = dict()

Check warning on line 204 in python/sdist/amici/jax.py

Codecov / codecov/patch

python/sdist/amici/jax.py#L204

Added line #L204 was not covered by tests

if self.sensi_order == amici.SensitivityOrder.none:
(

Check warning on line 207 in python/sdist/amici/jax.py

Codecov / codecov/patch

python/sdist/amici/jax.py#L206-L207

Added lines #L206 - L207 were not covered by tests
rdata_kwargs["llh"],
(rdata_kwargs["x"], rdata_kwargs["y"], rdata_kwargs["stats"]),
) = self.run(ts, p, k, my, pscale)
elif self.sensi_order == amici.SensitivityOrder.first:
(

Check warning on line 212 in python/sdist/amici/jax.py

Codecov / codecov/patch

python/sdist/amici/jax.py#L211-L212

Added lines #L211 - L212 were not covered by tests
rdata_kwargs["llh"],
rdata_kwargs["sllh"],
(rdata_kwargs["x"], rdata_kwargs["y"], rdata_kwargs["stats"]),
) = self.srun(ts, p, k, my, pscale)
elif self.sensi_order == amici.SensitivityOrder.second:
(

Check warning on line 218 in python/sdist/amici/jax.py

Codecov / codecov/patch

python/sdist/amici/jax.py#L217-L218

Added lines #L217 - L218 were not covered by tests
rdata_kwargs["llh"],
rdata_kwargs["sllh"],
rdata_kwargs["s2llh"],
(rdata_kwargs["x"], rdata_kwargs["y"], rdata_kwargs["stats"]),
) = self.s2run(ts, p, k, my, pscale)

for field in rdata_kwargs.keys():
if field == "llh":
rdata_kwargs[field] = np.float64(rdata_kwargs[field])
elif field not in ["sllh", "s2llh"]:
rdata_kwargs[field] = np.asarray(rdata_kwargs[field]).T
if rdata_kwargs[field].ndim == 1:
rdata_kwargs[field] = np.expand_dims(

Check warning on line 231 in python/sdist/amici/jax.py

Codecov / codecov/patch

python/sdist/amici/jax.py#L225-L231

Added lines #L225 - L231 were not covered by tests
rdata_kwargs[field], 1
)

return ReturnDataJAX(**rdata_kwargs)

Check warning on line 235 in python/sdist/amici/jax.py

Codecov / codecov/patch

python/sdist/amici/jax.py#L235

Added line #L235 was not covered by tests

def run_simulations(

Check warning on line 237 in python/sdist/amici/jax.py

Codecov / codecov/patch

python/sdist/amici/jax.py#L237

Added line #L237 was not covered by tests
self,
edatas: Iterable[amici.ExpData],
num_threads: int = 1,
):
if num_threads > 1:
with ThreadPoolExecutor(max_workers=num_threads) as pool:
results = pool.map(self.run_simulation, edatas)

Check warning on line 244 in python/sdist/amici/jax.py

Codecov / codecov/patch

python/sdist/amici/jax.py#L242-L244

Added lines #L242 - L244 were not covered by tests
else:
results = map(self.run_simulation, edatas)
return list(results)

Check warning on line 247 in python/sdist/amici/jax.py

Codecov / codecov/patch

python/sdist/amici/jax.py#L246-L247

Added lines #L246 - L247 were not covered by tests


@dataclass
class ReturnDataJAX(dict):
x: np.array = None
sx: np.array = None
y: np.array = None
sy: np.array = None
sigmay: np.array = None
ssigmay: np.array = None
llh: np.array = None
sllh: np.array = None
stats: dict = None

Check warning on line 260 in python/sdist/amici/jax.py

Codecov / codecov/patch

python/sdist/amici/jax.py#L250-L260

Added lines #L250 - L260 were not covered by tests

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.__dict__ = self

Check warning on line 264 in python/sdist/amici/jax.py

Codecov / codecov/patch

python/sdist/amici/jax.py#L262-L264

Added lines #L262 - L264 were not covered by tests
31 changes: 20 additions & 11 deletions python/sdist/amici/jax.template.py
Original file line number Diff line number Diff line change
@@ -7,21 +7,23 @@ class JAXModel_TPL_MODEL_NAME(JAXModel):
def __init__(self):
super().__init__()

def xdot(self, t, x, args):
@staticmethod
def xdot(t, x, args):

p, k, tcl = args

TPL_X_SYMS = x
TPL_P_SYMS = p
TPL_K_SYMS = k
TPL_TCL_SYMS = tcl
TPL_W_SYMS = self._w(t, x, p, k, tcl)
TPL_W_SYMS = JAXModel_TPL_MODEL_NAME._w(t, x, p, k, tcl)

TPL_XDOT_EQ

return TPL_XDOT_RET

def _w(self, t, x, p, k, tcl):
@staticmethod
def _w(t, x, p, k, tcl):

TPL_X_SYMS = x
TPL_P_SYMS = p
@@ -32,7 +34,8 @@ def _w(self, t, x, p, k, tcl):

return TPL_W_RET

def x0(self, p, k):
@staticmethod
def x0(p, k):

TPL_P_SYMS = p
TPL_K_SYMS = k
@@ -41,15 +44,17 @@ def x0(self, p, k):

return TPL_X0_RET

def x_solver(self, x):
@staticmethod
def x_solver(x):

TPL_X_RDATA_SYMS = x

TPL_X_SOLVER_EQ

return TPL_X_SOLVER_RET

def x_rdata(self, x, tcl):
@staticmethod
def x_rdata(x, tcl):

TPL_X_SYMS = x
TPL_TCL_SYMS = tcl
@@ -58,7 +63,8 @@ def x_rdata(self, x, tcl):

return TPL_X_RDATA_RET

def tcl(self, x, p, k):
@staticmethod
def tcl(x, p, k):

TPL_X_RDATA_SYMS = x
TPL_P_SYMS = p
@@ -68,18 +74,20 @@ def tcl(self, x, p, k):

return TPL_TOTAL_CL_RET

def y(self, t, x, p, k, tcl):
@staticmethod
def y(t, x, p, k, tcl):

TPL_X_SYMS = x
TPL_P_SYMS = p
TPL_K_SYMS = k
TPL_W_SYMS = self._w(t, x, p, k, tcl)
TPL_W_SYMS = JAXModel_TPL_MODEL_NAME._w(t, x, p, k, tcl)

TPL_Y_EQ

return TPL_Y_RET

def sigmay(self, y, p, k):
@staticmethod
def sigmay(y, p, k):
TPL_Y_SYMS = y
TPL_P_SYMS = p
TPL_K_SYMS = k
@@ -88,7 +96,8 @@ def sigmay(self, y, p, k):

return TPL_SIGMAY_RET

def Jy(self, y, my, sigmay):
@staticmethod
def Jy(y, my, sigmay):
TPL_Y_SYMS = y
TPL_MY_SYMS = my
TPL_SIGMAY_SYMS = sigmay
2 changes: 2 additions & 0 deletions python/sdist/setup.cfg
Original file line number Diff line number Diff line change
@@ -52,6 +52,8 @@ pysb = pysb>=1.13.1
jax =
jax
diffrax
equinox
optimistix
test =
benchmark_models_petab @ git+https://github.com/Benchmarking-Initiative/Benchmark-Models-PEtab.git@master#subdirectory=src/python
h5py
7 changes: 2 additions & 5 deletions tests/benchmark-models/test_petab_model.py
Original file line number Diff line number Diff line change
@@ -145,7 +145,6 @@ def main():
llh = res[LLH]

if args.model_name not in (
"Bachmann_MSB2011",
"Beer_MolBioSystems2014",
"Brannmark_JBC2010",
"Fujita_SciSignal2010",
@@ -154,7 +153,6 @@ def main():
"Weber_BMC2015",
"Zheng_PNAS2012",
):
# Bachmann: integration failure even with 1e6 steps
# Beer: Heaviside
# Brannmark_JBC2010: preeq
# Fujita: Heaviside
@@ -164,7 +162,6 @@ def main():
# Zheng_PNAS2012: preeq

jax_model = model_module.get_jax_model()
jax_solver = jax_model.get_solver()
simulation_conditions = (
problem.get_simulation_conditions_from_measurement_df()
)
@@ -191,9 +188,9 @@ def main():
amici_model=amici_model,
)
# run once to JIT
amici.jax.run_simulations(jax_model, jax_solver, edatas)
jax_model.run_simulations(edatas)
start_jax = timer()
rdatas_jax = amici.jax.run_simulations(jax_model, jax_solver, edatas)
rdatas_jax = jax_model.run_simulations(edatas)
end_jax = timer()

t_jax = end_jax - start_jax

Unchanged files with check annotations Beta

code = re.sub(r"numpy\.", r"jnp.", code)
return code
except TypeError as e:

Check warning on line 19 in python/sdist/amici/jaxcodeprinter.py

Codecov / codecov/patch

python/sdist/amici/jaxcodeprinter.py#L19

Added line #L19 was not covered by tests
raise ValueError(
f'Encountered unsupported function in expression "{expr}"'
) from e