Skip to content

Commit

Permalink
refactor fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
FFroehlich committed Nov 16, 2024
1 parent eaae778 commit d79cfc1
Show file tree
Hide file tree
Showing 5 changed files with 213 additions and 157 deletions.
2 changes: 1 addition & 1 deletion python/sdist/amici/de_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ def jnp_array_str(array) -> str:
),
indent,
)
)
)[indent:]
for eq_name in eq_names
},
**{
Expand Down
32 changes: 13 additions & 19 deletions python/sdist/amici/jax.template.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,16 @@ class JAXModel_TPL_MODEL_NAME(JAXModel):
def __init__(self):
super().__init__()

@staticmethod
def xdot(t, x, args):
def xdot(self, t, x, args):

pk, tcl = args

TPL_X_SYMS = x
TPL_PK_SYMS = pk
TPL_TCL_SYMS = tcl
TPL_W_SYMS = JAXModel_TPL_MODEL_NAME._w(t, x, pk, tcl)
TPL_W_SYMS = self._w(t, x, pk, tcl)

TPL_XDOT_EQ
TPL_XDOT_EQ

return TPL_XDOT_RET

Expand All @@ -29,7 +28,7 @@ def _w(t, x, pk, tcl):
TPL_PK_SYMS = pk
TPL_TCL_SYMS = tcl

TPL_W_EQ
TPL_W_EQ

return TPL_W_RET

Expand All @@ -38,7 +37,7 @@ def x0(pk):

TPL_PK_SYMS = pk

TPL_X0_EQ
TPL_X0_EQ

return TPL_X0_RET

Expand All @@ -47,7 +46,7 @@ def x_solver(x):

TPL_X_RDATA_SYMS = x

TPL_X_SOLVER_EQ
TPL_X_SOLVER_EQ

return TPL_X_SOLVER_RET

Expand All @@ -57,7 +56,7 @@ def x_rdata(x, tcl):
TPL_X_SYMS = x
TPL_TCL_SYMS = tcl

TPL_X_RDATA_EQ
TPL_X_RDATA_EQ

return TPL_X_RDATA_RET

Expand All @@ -67,7 +66,7 @@ def tcl(x, pk):
TPL_X_RDATA_SYMS = x
TPL_PK_SYMS = pk

TPL_TOTAL_CL_EQ
TPL_TOTAL_CL_EQ

return TPL_TOTAL_CL_RET

Expand All @@ -77,7 +76,7 @@ def y(self, t, x, pk, tcl):
TPL_PK_SYMS = pk
TPL_W_SYMS = self._w(t, x, pk, tcl)

TPL_Y_EQ
TPL_Y_EQ

return TPL_Y_RET

Expand All @@ -86,24 +85,19 @@ def sigmay(self, y, pk):

TPL_Y_SYMS = y

TPL_SIGMAY_EQ
TPL_SIGMAY_EQ

return TPL_SIGMAY_RET


def llh(self, t, x, pk, tcl, my, iy):
y = self.y(t, x, pk, tcl)
TPL_Y_SYMS = y
sigmay = self.sigmay(y, pk)
TPL_SIGMAY_SYMS = sigmay
TPL_SIGMAY_SYMS = self.sigmay(y, pk)

TPL_JY_EQ
TPL_JY_EQ

return jnp.array([
TPL_JY_RET.at[iy].get(),
y.at[iy].get(),
sigmay.at[iy].get()
])
return TPL_JY_RET.at[iy].get()

@property
def observable_ids(self):
Expand Down
153 changes: 91 additions & 62 deletions python/sdist/amici/jax/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import diffrax
import equinox as eqx
import jax.numpy as jnp
import numpy as np
import jax

# always use 64-bit precision. No-brainer on CPUs and GPUs don't make sense for stiff systems.
Expand All @@ -16,10 +15,12 @@ class JAXModel(eqx.Module):
JAXModel must provide model specific implementations of abstract methods.
"""

@staticmethod
@abstractmethod
def xdot(
t: jnp.float_, x: jnp.ndarray, args: tuple[jnp.ndarray, jnp.ndarray]
self,
t: jnp.float_,
x: jnp.ndarray,
args: tuple[jnp.ndarray, jnp.ndarray],
) -> jnp.ndarray:
"""
Right-hand side of the ODE system.
Expand Down Expand Up @@ -190,21 +191,6 @@ def parameter_ids(self) -> list[str]:
"""
...

def _preeq(self, p, solver, controller, max_steps):
"""
Pre-equilibration of the model.
:param p:
parameters
:return:
Initial state vector
"""
x0 = self.x_solver(self.x0(p))
tcl = self.tcl(x0, p)
return self._eq(p, tcl, x0, solver, controller, max_steps)

def _posteq(self, p, x, tcl, solver, controller, max_steps):
return self._eq(p, tcl, x, solver, controller, max_steps)

def _eq(self, p, tcl, x0, solver, controller, max_steps):
sol = diffrax.diffeqsolve(
diffrax.ODETerm(self.xdot),
Expand All @@ -216,27 +202,27 @@ def _eq(self, p, tcl, x0, solver, controller, max_steps):
y0=x0,
stepsize_controller=controller,
max_steps=max_steps,
adjoint=diffrax.DirectAdjoint(),
event=diffrax.Event(cond_fn=diffrax.steady_state_event()),
)
return sol.ys[-1, :]
return sol.ys[-1, :], sol.stats

def _solve(self, ts, p, x0, solver, controller, max_steps):
tcl = self.tcl(x0, p)
def _solve(self, p, ts, tcl, x0, solver, controller, max_steps, adjoint):
sol = diffrax.diffeqsolve(
diffrax.ODETerm(self.xdot),
solver,
args=(p, tcl),
t0=0.0,
t1=ts[-1],
dt0=None,
y0=self.x_solver(x0),
y0=x0,
stepsize_controller=controller,
max_steps=max_steps,
adjoint=diffrax.RecursiveCheckpointAdjoint(),
adjoint=adjoint,
saveat=diffrax.SaveAt(ts=ts),
throw=False,
)
return sol.ys, tcl, sol.stats
return sol.ys, sol.stats

def _x_rdata(self, x, tcl):
return jax.vmap(self.x_rdata, in_axes=(0, None))(x, tcl)
Expand All @@ -246,62 +232,105 @@ def _outputs(self, ts, x, p, tcl, my, iys) -> jnp.float_:
ts, x, p, tcl, my, iys
)

def _y(self, ts, xs, p, tcl, iys):
return jax.vmap(
lambda t, x, p, tcl, iy: self.y(t, x, p, tcl).at[iy].get(),
in_axes=(0, 0, None, None, 0),
)(ts, xs, p, tcl, iys)

def _sigmay(self, ts, xs, p, tcl, iys):
return jax.vmap(
lambda t, x, p, tcl, iy: self.sigmay(self.y(t, x, p, tcl), p)
.at[iy]
.get(),
in_axes=(0, 0, None, None, 0),
)(ts, xs, p, tcl, iys)

# @eqx.filter_jit
def simulate_condition(
self,
ts: np.ndarray,
ts_dyn: np.ndarray,
my: np.ndarray,
iys: np.ndarray,
p: jnp.ndarray,
p_preeq: jnp.ndarray,
dynamic: bool,
ts_preeq: jnp.ndarray,
ts_dyn: jnp.ndarray,
ts_posteq: jnp.ndarray,
my: jnp.ndarray,
iys: jnp.ndarray,
solver: diffrax.AbstractSolver,
controller: diffrax.AbstractStepSizeController,
adjoint: diffrax.AbstractAdjoint,
max_steps: int,
ret: str = "llh",
):
# Pre-equilibration
if p_preeq.shape[0] > 0:
x0 = self._preeq(p_preeq, solver, controller, max_steps)
x0 = self.x0(p_preeq)
tcl = self.tcl(x0, p_preeq)
current_x = self.x_solver(x0)
current_x, stats_preeq = self._eq(
p_preeq, tcl, current_x, solver, controller, max_steps
)
# update tcl with new parameters
tcl = self.tcl(self.x_rdata(current_x, tcl), p)
else:
x0 = self.x0(p)
current_x = self.x_solver(x0)
stats_preeq = None

tcl = self.tcl(x0, p)
x_preq = jnp.repeat(
current_x.reshape(1, -1), ts_preeq.shape[0], axis=0
)

# Dynamic simulation
if dynamic:
x, tcl, stats = self._solve(
ts_dyn, p, x0, solver, controller, max_steps
if ts_dyn.shape[0] > 0:
x_dyn, stats_dyn = self._solve(
p,
ts_dyn,
tcl,
current_x,
solver,
controller,
max_steps,
adjoint,
)
current_x = x_dyn[-1, :]
else:
x = jnp.repeat(
self.x_solver(x0).reshape(1, -1),
len(ts_dyn),
axis=0,
x_dyn = jnp.repeat(
current_x.reshape(1, -1), ts_dyn.shape[0], axis=0
)
tcl = self.tcl(x0, p)
stats = None
stats_dyn = None

# Post-equilibration
if len(ts) > len(ts_dyn):
if len(ts_dyn) > 0:
x_final = x[-1, :]
else:
x_final = self.x_solver(x0)
x_posteq = self._posteq(
p, x_final, tcl, solver, controller, max_steps
)
x_posteq = jnp.repeat(
x_posteq.reshape(1, -1),
len(ts) - len(ts_dyn),
axis=0,
if ts_posteq.shape[0] > 0:
current_x, stats_posteq = self._eq(
p, tcl, current_x, solver, controller, max_steps
)
if len(ts_dyn) > 0:
x = jnp.concatenate((x, x_posteq), axis=0)
else:
x = x_posteq

outputs = self._outputs(ts, x, p, tcl, my, iys)
llh = -jnp.sum(outputs[:, 0])
obs = outputs[:, 1]
sigmay = outputs[:, 2]
x_rdata = jnp.stack(self._x_rdata(x, tcl), axis=1)
return llh, dict(llh=llh, x=x_rdata, y=obs, sigmay=sigmay, stats=stats)
else:
stats_posteq = None

x_posteq = jnp.repeat(
current_x.reshape(1, -1), ts_posteq.shape[0], axis=0
)

ts = jnp.concatenate((ts_preeq, ts_dyn, ts_posteq), axis=0)
x = jnp.concatenate((x_preq, x_dyn, x_posteq), axis=0)

llhs = self._outputs(ts, x, p, tcl, my, iys)
llh = -jnp.sum(llhs)
return {
"llh": llh,
"llhs": llhs,
"x": self._x_rdata(x, tcl),
"x_solver": x,
"y": self._y(ts, x, p, tcl, iys),
"sigmay": self._sigmay(ts, x, p, tcl, iys),
"x0": self.x_rdata(x_preq[-1, :], tcl),
"x0_solver": x_preq[-1, :],
"tcl": tcl,
"res": self._y(ts, x, p, tcl, iys) - my,
}[ret], dict(
stats_preeq=stats_preeq,
stats_dyn=stats_dyn,
stats_posteq=stats_posteq,
)
Loading

0 comments on commit d79cfc1

Please sign in to comment.