From f3a97c20a638a553a7440bbde299bfb2a83cf10d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabian=20Fr=C3=B6hlich?= Date: Mon, 2 Dec 2024 15:37:21 +0000 Subject: [PATCH] add nan safe log÷ (#2611) --- python/sdist/amici/jax.template.py | 2 +- python/sdist/amici/jax/model.py | 36 +++++++++++++++++++ python/sdist/amici/jaxcodeprinter.py | 9 +++++ .../benchmark-models/test_petab_benchmark.py | 12 ++----- 4 files changed, 49 insertions(+), 10 deletions(-) diff --git a/python/sdist/amici/jax.template.py b/python/sdist/amici/jax.template.py index 9b566281ca..ddddb8a64b 100644 --- a/python/sdist/amici/jax.template.py +++ b/python/sdist/amici/jax.template.py @@ -3,7 +3,7 @@ from interpax import interp1d from pathlib import Path -from amici.jax.model import JAXModel +from amici.jax.model import JAXModel, safe_log, safe_div class JAXModel_TPL_MODEL_NAME(JAXModel): diff --git a/python/sdist/amici/jax/model.py b/python/sdist/amici/jax/model.py index e037c44a2f..ac86b547a6 100644 --- a/python/sdist/amici/jax/model.py +++ b/python/sdist/amici/jax/model.py @@ -559,3 +559,39 @@ def simulate_condition( stats_dyn=stats_dyn, stats_posteq=stats_posteq, ) + + +def safe_log(x: jnp.float_) -> jnp.float_: + """ + Safe logarithm that returns `jnp.log(jnp.finfo(jnp.float_).eps)` for x <= 0. + + :param x: + input + :return: + logarithm of x + """ + # see https://docs.kidger.site/equinox/api/debug/, need double jnp.where to guard + # against nans in forward & backward passes + safe_x = jnp.where( + x > jnp.finfo(jnp.float_).eps, x, jnp.finfo(jnp.float_).eps + ) + return jnp.where( + x > 0, jnp.log(safe_x), jnp.log(jnp.finfo(jnp.float_).eps) + ) + + +def safe_div(x: jnp.float_, y: jnp.float_) -> jnp.float_: + """ + Safe division that returns `x/jnp.finfo(jnp.float_).eps` for `y == 0`. + + :param x: + numerator + :param y: + denominator + :return: + x / y + """ + # see https://docs.kidger.site/equinox/api/debug/, need double jnp.where to guard + # against nans in forward & backward passes + safe_y = jnp.where(y != 0, y, jnp.finfo(jnp.float_).eps) + return jnp.where(y != 0, x / safe_y, x / jnp.finfo(jnp.float_).eps) diff --git a/python/sdist/amici/jaxcodeprinter.py b/python/sdist/amici/jaxcodeprinter.py index ed9181cc09..6cfce97b35 100644 --- a/python/sdist/amici/jaxcodeprinter.py +++ b/python/sdist/amici/jaxcodeprinter.py @@ -27,6 +27,15 @@ def _print_AmiciSpline(self, expr: sp.Expr) -> str: # FIXME: untested, where are spline nodes coming from anyways? return f'interp1d(time, {self.doprint(expr.args[2:])}, kind="cubic")' + def _print_log(self, expr: sp.Expr) -> str: + return f"safe_log({self.doprint(expr.args[0])})" + + def _print_Mul(self, expr: sp.Expr) -> str: + numer, denom = expr.as_numer_denom() + if denom == 1: + return super()._print_Mul(expr) + return f"safe_div({self.doprint(numer)}, {self.doprint(denom)})" + def _get_sym_lines( self, symbols: sp.Matrix | Iterable[str], diff --git a/tests/benchmark-models/test_petab_benchmark.py b/tests/benchmark-models/test_petab_benchmark.py index 2c56089409..74c84d37a9 100644 --- a/tests/benchmark-models/test_petab_benchmark.py +++ b/tests/benchmark-models/test_petab_benchmark.py @@ -299,14 +299,8 @@ def test_jax_llh(benchmark_problem): np.random.seed(cur_settings.rng_seed) - problems_for_gradient_check_jax = list( - set(problems_for_gradient_check) - {"Laske_PLOSComputBiol2019"} - # Laske has nan values in gradient due to nan values in observables that are not used in the likelihood - # but are problematic during backpropagation - ) - problem_parameters = None - if problem_id in problems_for_gradient_check_jax: + if problem_id in problems_for_gradient_check: point = petab_problem.x_nominal_free_scaled for _ in range(20): amici_solver.setSensitivityMethod(amici.SensitivityMethod.adjoint) @@ -361,14 +355,14 @@ def test_jax_llh(benchmark_problem): err_msg=f"LLH mismatch for {problem_id}", ) - if problem_id in problems_for_gradient_check_jax: + if problem_id in problems_for_gradient_check: sllh_amici = r_amici[SLLH] np.testing.assert_allclose( sllh_jax.parameters, np.array([sllh_amici[pid] for pid in jax_problem.parameter_ids]), rtol=1e-2, atol=1e-2, - err_msg=f"SLLH mismatch for {problem_id}", + err_msg=f"SLLH mismatch for {problem_id}, {dict(zip(jax_problem.parameter_ids, sllh_jax.parameters))}", )