Skip to content

Commit

Permalink
update diffrax & jaxlib (#2632)
Browse files Browse the repository at this point in the history
* update diffrax, remove ts_init

* fix petab simulate

* Update pyproject.toml

* Update pytest.ini
  • Loading branch information
FFroehlich authored Dec 18, 2024
1 parent 4522522 commit 3d73624
Show file tree
Hide file tree
Showing 6 changed files with 17 additions and 31 deletions.
2 changes: 2 additions & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -24,5 +24,7 @@ filterwarnings =
ignore:.*PyDevIPCompleter6.*:DeprecationWarning
# ignore numpy log(0) warnings (np.log(0) = -inf)
ignore:divide by zero encountered in log:RuntimeWarning
# ignore jax deprecation warnings
ignore:jax.* is deprecated:DeprecationWarning

norecursedirs = .git amici_models build doc documentation matlab models ThirdParty amici sdist examples
3 changes: 1 addition & 2 deletions python/examples/example_jax_petab/ExampleJaxPEtab.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@
"simulation_condition = (\"model1_data1\",)\n",
"\n",
"# Load condition-specific data\n",
"ts_init, ts_dyn, ts_posteq, my, iys, iy_trafos = jax_problem._measurements[\n",
"ts_dyn, ts_posteq, my, iys, iy_trafos = jax_problem._measurements[\n",
" simulation_condition\n",
"]\n",
"\n",
Expand All @@ -371,7 +371,6 @@
"def grad_ts_dyn(tt):\n",
" return jax_problem.model.simulate_condition(\n",
" p=p,\n",
" ts_init=ts_init,\n",
" ts_dyn=tt,\n",
" ts_posteq=ts_posteq,\n",
" my=jnp.array(my),\n",
Expand Down
15 changes: 4 additions & 11 deletions python/sdist/amici/jax/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,6 @@ def _sigmays(
def simulate_condition(
self,
p: jt.Float[jt.Array, "np"],
ts_init: jt.Float[jt.Array, "nt_preeq"],
ts_dyn: jt.Float[jt.Array, "nt_dyn"],
ts_posteq: jt.Float[jt.Array, "nt_posteq"],
my: jt.Float[jt.Array, "nt"],
Expand All @@ -462,13 +461,9 @@ def simulate_condition(
:param p:
parameters for simulation ordered according to ids in :ivar parameter_ids:
:param ts_init:
time points that do not require simulation. Usually valued 0.0, but needs to be shaped according to
the number of observables that are evaluated before dynamic simulation.
:param ts_dyn:
time points for dynamic simulation. Usually valued > 0.0 and sorted in monotonically increasing order.
Duplicate time points are allowed to facilitate the evaluation of multiple observables at specific time
points.
time points for dynamic simulation. Sorted in monotonically increasing order but duplicate time points are
allowed to facilitate the evaluation of multiple observables at specific time points.
:param ts_posteq:
time points for post-equilibration. Usually valued \Infty, but needs to be shaped according to
the number of observables that are evaluated after post-equilibration.
Expand Down Expand Up @@ -508,8 +503,6 @@ def simulate_condition(
x_solver = self._x_solver(x)
tcl = self._tcl(x, p)

x_preq = jnp.repeat(x_solver.reshape(1, -1), ts_init.shape[0], axis=0)

# Dynamic simulation
if ts_dyn.shape[0]:
x_dyn, stats_dyn = self._solve(
Expand Down Expand Up @@ -541,8 +534,8 @@ def simulate_condition(
x_solver.reshape(1, -1), ts_posteq.shape[0], axis=0
)

ts = jnp.concatenate((ts_init, ts_dyn, ts_posteq), axis=0)
x = jnp.concatenate((x_preq, x_dyn, x_posteq), axis=0)
ts = jnp.concatenate((ts_dyn, ts_posteq), axis=0)
x = jnp.concatenate((x_dyn, x_posteq), axis=0)

nllhs = self._nllhs(ts, x, p, tcl, my, iys)
llh = -jnp.sum(nllhs)
Expand Down
16 changes: 5 additions & 11 deletions python/sdist/amici/jax/petab.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ class JAXProblem(eqx.Module):
:ivar _parameter_mappings:
:class:`ParameterMappingForCondition` instances for each simulation condition.
:ivar _measurements:
Subset measurement dataframes for each simulation condition.
Preprocessed arrays for each simulation condition.
:ivar _petab_problem:
PEtab problem to simulate.
"""
Expand All @@ -87,7 +87,6 @@ class JAXProblem(eqx.Module):
np.ndarray,
np.ndarray,
np.ndarray,
np.ndarray,
],
]
_petab_measurement_indices: dict[tuple[str, ...], tuple[int, ...]]
Expand Down Expand Up @@ -187,7 +186,6 @@ def _get_measurements(
np.ndarray,
np.ndarray,
np.ndarray,
np.ndarray,
],
],
dict[tuple[str, ...], tuple[int, ...]],
Expand All @@ -213,11 +211,9 @@ def _get_measurements(
)

ts = m[petab.TIME]
ts_preeq = ts[np.isfinite(ts) & (ts == 0)]
ts_dyn = ts[np.isfinite(ts) & (ts > 0)]
ts_dyn = ts[np.isfinite(ts)]
ts_posteq = ts[np.logical_not(np.isfinite(ts))]
index = pd.concat([ts_preeq, ts_dyn, ts_posteq]).index
ts_preeq = ts_preeq.values
index = pd.concat([ts_dyn, ts_posteq]).index
ts_dyn = ts_dyn.values
ts_posteq = ts_posteq.values
my = m[petab.MEASUREMENT].values
Expand Down Expand Up @@ -245,7 +241,6 @@ def _get_measurements(
iy_trafos = np.zeros_like(iys)

measurements[tuple(simulation_condition)] = (
ts_preeq,
ts_dyn,
ts_posteq,
my,
Expand Down Expand Up @@ -492,7 +487,7 @@ def run_simulation(
:return:
Tuple of output value and simulation statistics
"""
ts_preeq, ts_dyn, ts_posteq, my, iys, iy_trafos = self._measurements[
ts_dyn, ts_posteq, my, iys, iy_trafos = self._measurements[
simulation_condition
]
p = self.load_parameters(simulation_condition[0])
Expand All @@ -501,7 +496,6 @@ def run_simulation(
)
return self.model.simulate_condition(
p=p,
ts_init=jax.lax.stop_gradient(jnp.array(ts_preeq)),
ts_dyn=jax.lax.stop_gradient(jnp.array(ts_dyn)),
ts_posteq=jax.lax.stop_gradient(jnp.array(ts_posteq)),
my=jax.lax.stop_gradient(jnp.array(my)),
Expand Down Expand Up @@ -650,7 +644,7 @@ def petab_simulate(
for sc, ys in y.items():
obs = [
problem.model.observable_ids[io]
for io in problem._measurements[sc][4]
for io in problem._measurements[sc][3]
]
t = jnp.concat(problem._measurements[sc][:2])
df_sc = pd.DataFrame(
Expand Down
8 changes: 4 additions & 4 deletions python/sdist/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -83,11 +83,11 @@ examples = [
"scipy",
]
jax = [
"jax>=0.4.34,<0.4.36",
"jaxlib>=0.4.34",
"diffrax>=0.6.0",
"jax>=0.4.36",
"jaxlib>=0.4.36",
"diffrax>=0.6.1",
"jaxtyping>=0.2.34",
"equinox>=0.11.8",
"equinox>=0.11.10",
"optimistix>=0.0.9",
"interpax>=0.3.3",
]
Expand Down
4 changes: 1 addition & 3 deletions python/tests/test_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,8 +179,7 @@ def check_fields_jax(
iys = iys.flatten()
iy_trafos = np.zeros_like(iys)

ts_init = ts[ts == 0]
ts_dyn = ts[ts > 0]
ts_dyn = ts
ts_posteq = np.array([])

par_dict = {
Expand All @@ -190,7 +189,6 @@ def check_fields_jax(

p = jnp.array([par_dict[par_id] for par_id in jax_model.parameter_ids])
kwargs = {
"ts_init": jnp.array(ts_init),
"ts_dyn": jnp.array(ts_dyn),
"ts_posteq": jnp.array(ts_posteq),
"my": jnp.array(my),
Expand Down

0 comments on commit 3d73624

Please sign in to comment.