Skip to content

Commit

Permalink
fix gradients
Browse files Browse the repository at this point in the history
  • Loading branch information
FFroehlich committed Oct 25, 2024
1 parent 5a86f4c commit 480b75a
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions python/sdist/amici/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ def srun(
dynamic=True,
):
(llh, (x, obs, stats)), sllh = (

Check warning on line 239 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L239

Added line #L239 was not covered by tests
jax.value_and_grad(self._run, 1, True)
jax.value_and_grad(self._run, 2, True)
)(ts, ts_dyn, p, k, k_preeq, my, pscale, dynamic=dynamic)
return llh, sllh, (x, obs, stats)

Check warning on line 242 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L242

Added line #L242 was not covered by tests

Expand All @@ -254,10 +254,10 @@ def s2run(
dynamic=True,
):
(llh, (x, obs, stats)), sllh = (

Check warning on line 256 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L256

Added line #L256 was not covered by tests
jax.value_and_grad(self._run, 1, True)
jax.value_and_grad(self._run, 2, True)
)(ts, ts_dyn, p, k, k_preeq, my, pscale, dynamic=dynamic)

s2llh = jax.hessian(self._run, 1, True)(
s2llh = jax.hessian(self._run, 2, True)(

Check warning on line 260 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L260

Added line #L260 was not covered by tests
ts, ts_dyn, p, k, k_preeq, my, pscale, dynamic=dynamic
)

Expand Down

0 comments on commit 480b75a

Please sign in to comment.