Skip to content

Commit

Permalink
add nan safe log&divide (#2611)
Browse files Browse the repository at this point in the history
  • Loading branch information
FFroehlich authored Dec 2, 2024
1 parent 851d389 commit f3a97c2
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 10 deletions.
2 changes: 1 addition & 1 deletion python/sdist/amici/jax.template.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from interpax import interp1d
from pathlib import Path

from amici.jax.model import JAXModel
from amici.jax.model import JAXModel, safe_log, safe_div


class JAXModel_TPL_MODEL_NAME(JAXModel):
Expand Down
36 changes: 36 additions & 0 deletions python/sdist/amici/jax/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,3 +559,39 @@ def simulate_condition(
stats_dyn=stats_dyn,
stats_posteq=stats_posteq,
)


def safe_log(x: jnp.float_) -> jnp.float_:
"""
Safe logarithm that returns `jnp.log(jnp.finfo(jnp.float_).eps)` for x <= 0.
:param x:
input
:return:
logarithm of x
"""
# see https://docs.kidger.site/equinox/api/debug/, need double jnp.where to guard
# against nans in forward & backward passes
safe_x = jnp.where(
x > jnp.finfo(jnp.float_).eps, x, jnp.finfo(jnp.float_).eps
)
return jnp.where(
x > 0, jnp.log(safe_x), jnp.log(jnp.finfo(jnp.float_).eps)
)


def safe_div(x: jnp.float_, y: jnp.float_) -> jnp.float_:
"""
Safe division that returns `x/jnp.finfo(jnp.float_).eps` for `y == 0`.
:param x:
numerator
:param y:
denominator
:return:
x / y
"""
# see https://docs.kidger.site/equinox/api/debug/, need double jnp.where to guard
# against nans in forward & backward passes
safe_y = jnp.where(y != 0, y, jnp.finfo(jnp.float_).eps)
return jnp.where(y != 0, x / safe_y, x / jnp.finfo(jnp.float_).eps)
9 changes: 9 additions & 0 deletions python/sdist/amici/jaxcodeprinter.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,15 @@ def _print_AmiciSpline(self, expr: sp.Expr) -> str:
# FIXME: untested, where are spline nodes coming from anyways?
return f'interp1d(time, {self.doprint(expr.args[2:])}, kind="cubic")'

def _print_log(self, expr: sp.Expr) -> str:
return f"safe_log({self.doprint(expr.args[0])})"

def _print_Mul(self, expr: sp.Expr) -> str:
numer, denom = expr.as_numer_denom()
if denom == 1:
return super()._print_Mul(expr)
return f"safe_div({self.doprint(numer)}, {self.doprint(denom)})"

def _get_sym_lines(
self,
symbols: sp.Matrix | Iterable[str],
Expand Down
12 changes: 3 additions & 9 deletions tests/benchmark-models/test_petab_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,14 +299,8 @@ def test_jax_llh(benchmark_problem):

np.random.seed(cur_settings.rng_seed)

problems_for_gradient_check_jax = list(
set(problems_for_gradient_check) - {"Laske_PLOSComputBiol2019"}
# Laske has nan values in gradient due to nan values in observables that are not used in the likelihood
# but are problematic during backpropagation
)

problem_parameters = None
if problem_id in problems_for_gradient_check_jax:
if problem_id in problems_for_gradient_check:
point = petab_problem.x_nominal_free_scaled
for _ in range(20):
amici_solver.setSensitivityMethod(amici.SensitivityMethod.adjoint)
Expand Down Expand Up @@ -361,14 +355,14 @@ def test_jax_llh(benchmark_problem):
err_msg=f"LLH mismatch for {problem_id}",
)

if problem_id in problems_for_gradient_check_jax:
if problem_id in problems_for_gradient_check:
sllh_amici = r_amici[SLLH]
np.testing.assert_allclose(
sllh_jax.parameters,
np.array([sllh_amici[pid] for pid in jax_problem.parameter_ids]),
rtol=1e-2,
atol=1e-2,
err_msg=f"SLLH mismatch for {problem_id}",
err_msg=f"SLLH mismatch for {problem_id}, {dict(zip(jax_problem.parameter_ids, sllh_jax.parameters))}",
)


Expand Down

0 comments on commit f3a97c2

Please sign in to comment.