Skip to content

Commit

Permalink
fixup
Browse files Browse the repository at this point in the history
  • Loading branch information
FFroehlich committed Oct 25, 2024
1 parent 2f3834d commit 5366632
Showing 1 changed file with 17 additions and 9 deletions.
26 changes: 17 additions & 9 deletions python/sdist/amici/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def _run(
x0 = self.x0(ps, k)

Check warning on line 171 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L171

Added line #L171 was not covered by tests

# Dynamic simulation
if dynamic and ts_dyn.shape[0] > 0:
if dynamic == "true":
x, tcl, stats = self._solve(

Check warning on line 175 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L174-L175

Added lines #L174 - L175 were not covered by tests
ts_dyn, ps, k, x0, checkpointed=checkpointed
)
Expand Down Expand Up @@ -220,7 +220,9 @@ def run(
pscale: np.ndarray,
dynamic=True,
):
return self._run(ts, ts_dyn, p, k, k_preeq, my, pscale, dynamic)
return self._run(

Check warning on line 223 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L223

Added line #L223 was not covered by tests
ts, ts_dyn, p, k, k_preeq, my, pscale, dynamic=dynamic
)

@eqx.filter_jit
def srun(
Expand All @@ -236,7 +238,7 @@ def srun(
):
(llh, (x, obs, stats)), sllh = (

Check warning on line 239 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L239

Added line #L239 was not covered by tests
jax.value_and_grad(self._run, 1, True)
)(ts, ts_dyn, p, k, k_preeq, my, pscale, dynamic)
)(ts, ts_dyn, p, k, k_preeq, my, pscale, dynamic=dynamic)
return llh, sllh, (x, obs, stats)

Check warning on line 242 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L242

Added line #L242 was not covered by tests

@eqx.filter_jit
Expand All @@ -253,10 +255,10 @@ def s2run(
):
(llh, (x, obs, stats)), sllh = (

Check warning on line 256 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L256

Added line #L256 was not covered by tests
jax.value_and_grad(self._run, 1, True)
)(ts, ts_dyn, p, k, k_preeq, my, pscale, dynamic)
)(ts, ts_dyn, p, k, k_preeq, my, pscale, dynamic=dynamic)

s2llh = jax.hessian(self._run, 1, True)(

Check warning on line 260 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L260

Added line #L260 was not covered by tests
ts, ts_dyn, p, k, k_preeq, my, pscale, dynamic
ts, ts_dyn, p, k, k_preeq, my, pscale, dynamic=dynamic
)

return llh, sllh, s2llh, (x, obs, stats)

Check warning on line 264 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L264

Added line #L264 was not covered by tests
Expand All @@ -271,28 +273,34 @@ def run_simulation(
my = np.asarray(edata.getObservedData())
pscale = np.asarray(edata.pscale)
ts_dyn = ts[np.isfinite(ts)]
dynamic = len(ts_dyn) > 0 and np.max(ts_dyn) > 0
dynamic = "true" if len(ts_dyn) and np.max(ts_dyn) > 0 else "false"

Check warning on line 276 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L269-L276

Added lines #L269 - L276 were not covered by tests

rdata_kwargs = dict()

Check warning on line 278 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L278

Added line #L278 was not covered by tests

if sensitivity_order == amici.SensitivityOrder.none:
(

Check warning on line 281 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L280-L281

Added lines #L280 - L281 were not covered by tests
rdata_kwargs["llh"],
(rdata_kwargs["x"], rdata_kwargs["y"], rdata_kwargs["stats"]),
) = self.run(ts, ts_dyn, p, k, k_preeq, my, pscale, dynamic)
) = self.run(
ts, ts_dyn, p, k, k_preeq, my, pscale, dynamic=dynamic
)
elif sensitivity_order == amici.SensitivityOrder.first:
(

Check warning on line 288 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L287-L288

Added lines #L287 - L288 were not covered by tests
rdata_kwargs["llh"],
rdata_kwargs["sllh"],
(rdata_kwargs["x"], rdata_kwargs["y"], rdata_kwargs["stats"]),
) = self.srun(ts, ts_dyn, p, k, k_preeq, my, pscale, dynamic)
) = self.srun(
ts, ts_dyn, p, k, k_preeq, my, pscale, dynamic=dynamic
)
elif sensitivity_order == amici.SensitivityOrder.second:
(

Check warning on line 296 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L295-L296

Added lines #L295 - L296 were not covered by tests
rdata_kwargs["llh"],
rdata_kwargs["sllh"],
rdata_kwargs["s2llh"],
(rdata_kwargs["x"], rdata_kwargs["y"], rdata_kwargs["stats"]),
) = self.s2run(ts, ts_dyn, p, k, k_preeq, my, pscale, dynamic)
) = self.s2run(
ts, ts_dyn, p, k, k_preeq, my, pscale, dynamic=dynamic
)

for field in rdata_kwargs.keys():
if field == "llh":
Expand Down

0 comments on commit 5366632

Please sign in to comment.