Skip to content

Commit

Permalink
add option for using diffrax diffeqsolve or iterative solve
Browse files Browse the repository at this point in the history
  • Loading branch information
etch4966 committed Dec 12, 2024
1 parent 9eaec4d commit f644d5b
Show file tree
Hide file tree
Showing 9 changed files with 417 additions and 413 deletions.
164 changes: 75 additions & 89 deletions docs/source/tutorials/biomodels_curation.ipynb

Large diffs are not rendered by default.

114 changes: 56 additions & 58 deletions docs/source/tutorials/gradient_descent.ipynb

Large diffs are not rendered by default.

175 changes: 81 additions & 94 deletions docs/source/tutorials/jax_model.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,26 @@
import diffrax
import equinox as eqx
from functools import partial
from jax import jit, lax, vmap
from jax.experimental.ode import odeint
import jax.numpy as jnp
from diffrax import ODETerm, Tsit5, Dopri5, Dopri8, Euler, Midpoint, Heun, Bosh3, Ralston
from typing import Any

from sbmltoodejax import jaxfuncs

t0 = 0.0

y0 = jnp.array([90.0, 10.0, 280.0, 10.0, 10.0, 280.0, 10.0, 10.0])
y_indexes = {'MKKK': 0, 'MKKK_P': 1, 'MKK': 2, 'MKK_P': 3, 'MKK_PP': 4, 'MAPK': 5, 'MAPK_P': 6, 'MAPK_PP': 7}
y0 = jnp.array([0.0, 0.999999999999999, 0.0, 3.0, 7.0])
y_indexes = {'Timeract': 0, 'CellCact': 1, 'Effectoract': 2, 'HR': 3, 'NHEJ': 4}

w0 = jnp.array([])
w_indexes = {}
w0 = jnp.array([10.0, 10.0, 10.0, 9.000000000000002])
w_indexes = {'Effectorina': 0, 'Damage': 1, 'Timerinact': 2, 'CellCina': 3}

c = jnp.array([1.0, 2.5, 9.0, 1.0, 10.0, 0.25, 8.0, 0.025, 15.0, 0.025, 15.0, 0.75, 15.0, 0.75, 15.0, 0.025, 15.0, 0.025, 15.0, 0.5, 15.0, 0.5, 15.0])
c_indexes = {'uVol': 0, 'J0_V1': 1, 'J0_Ki': 2, 'J0_n': 3, 'J0_K1': 4, 'J1_V2': 5, 'J1_KK2': 6, 'J2_k3': 7, 'J2_KK3': 8, 'J3_k4': 9, 'J3_KK4': 10, 'J4_V5': 11, 'J4_KK5': 12, 'J5_V6': 13, 'J5_KK6': 14, 'J6_k7': 15, 'J6_KK7': 16, 'J7_k8': 17, 'J7_KK8': 18, 'J8_V9': 19, 'J8_KK9': 20, 'J9_V10': 21, 'J9_KK10': 22}
c = jnp.array([10.0, 10.0, 10.0, 2.0, 10.0, 1.0, 10.0, 1.0, 1.0, 1.0, 1.0, 10.0, 1.0])
c_indexes = {'CellCycletot': 0, 'Effectortot': 1, 'Timertot': 2, 'Kd2t': 3, 'Kti2t': 4, 'Kcc2ch': 5, 'Kt2cc': 6, 'Kcc2a': 7, 'Kd2ch': 8, 'Kch2cc': 9, 'Km1': 10, 'Km10': 11, 'nucleus': 12}

class RateofSpeciesChange(eqx.Module):
stoichiometricMatrix = jnp.array([[-1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [1.0, -1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, -1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, -1.0, 1.0, -1.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 1.0, -1.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -1.0, 0.0, 0.0, 1.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, -1.0, 1.0, -1.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, -1.0, 0.0]], dtype=jnp.float32)
stoichiometricMatrix = jnp.array([[0.0], [0.0], [0.0], [0.0], [0.0]], dtype=jnp.float32)

@jit
def __call__(self, y, t, w, c):
rateRuleVector = jnp.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], dtype=jnp.float32)
rateRuleVector = jnp.array([self.RateTimeract(y, w, c, t), self.RateCellCact(y, w, c, t), self.RateEffectoract(y, w, c, t), self.RateHR(y, w, c, t), self.RateNHEJ(y, w, c, t)], dtype=jnp.float32)

reactionVelocities = self.calc_reaction_velocities(y, w, c, t)

Expand All @@ -34,120 +30,111 @@ def __call__(self, y, t, w, c):


def calc_reaction_velocities(self, y, w, c, t):
reactionVelocities = jnp.array([self.J0(y, w, c, t), self.J1(y, w, c, t), self.J2(y, w, c, t), self.J3(y, w, c, t), self.J4(y, w, c, t), self.J5(y, w, c, t), self.J6(y, w, c, t), self.J7(y, w, c, t), self.J8(y, w, c, t), self.J9(y, w, c, t)], dtype=jnp.float32)
reactionVelocities = jnp.array([0], dtype=jnp.float32)

return reactionVelocities

def RateTimeract(self, y, w, c, t):
return c[3] * (w[1]/1.0) * (w[2]/1.0) / (c[10] + (w[2]/1.0)) - c[4] * (y[0]/1.0) / (c[10] + (y[0]/1.0))

def J0(self, y, w, c, t):
return c[0] * c[1] * (y[0]/1.0) / ((1 + ((y[7]/1.0) / c[2])**c[3]) * (c[4] + (y[0]/1.0)))


def J1(self, y, w, c, t):
return c[0] * c[5] * (y[1]/1.0) / (c[6] + (y[1]/1.0))


def J2(self, y, w, c, t):
return c[0] * c[7] * (y[1]/1.0) * (y[2]/1.0) / (c[8] + (y[2]/1.0))


def J3(self, y, w, c, t):
return c[0] * c[9] * (y[1]/1.0) * (y[3]/1.0) / (c[10] + (y[3]/1.0))

def RateCellCact(self, y, w, c, t):
return (c[7] + (y[1]/1.0)) * (w[3]/1.0) / (c[11] + (w[3]/1.0)) - c[6] * (y[0]/1.0) * (y[1]/1.0) / (c[11] + (y[1]/1.0)) - c[9] * (y[1]/1.0) * (y[2]/1.0) / (c[11] + (y[1]/1.0))

def J4(self, y, w, c, t):
return c[0] * c[11] * (y[4]/1.0) / (c[12] + (y[4]/1.0))
def RateEffectoract(self, y, w, c, t):
return c[8] * (w[1]/1.0) * (w[0]/1.0) / (c[11] + (w[0]/1.0)) - c[5] * (y[1]/1.0) * (y[2]/1.0) / (c[11] + (y[2]/1.0))

def RateHR(self, y, w, c, t):
return -(y[3]/1.0) * 0.2

def J5(self, y, w, c, t):
return c[0] * c[13] * (y[3]/1.0) / (c[14] + (y[3]/1.0))


def J6(self, y, w, c, t):
return c[0] * c[15] * (y[4]/1.0) * (y[5]/1.0) / (c[16] + (y[5]/1.0))


def J7(self, y, w, c, t):
return c[0] * c[17] * (y[4]/1.0) * (y[6]/1.0) / (c[18] + (y[6]/1.0))
def RateNHEJ(self, y, w, c, t):
return -(y[4]/1.0) * 0.5

class AssignmentRule(eqx.Module):
@jit
def __call__(self, y, w, c, t):
w = w.at[0].set(1.0 * ((c[1]/1.0) - (y[2]/1.0)))

def J8(self, y, w, c, t):
return c[0] * c[19] * (y[7]/1.0) / (c[20] + (y[7]/1.0))
w = w.at[1].set(1.0 * ((y[3]/1.0) + (y[4]/1.0)))

w = w.at[2].set(1.0 * ((c[2]/1.0) - (y[0]/1.0)))

def J9(self, y, w, c, t):
return c[0] * c[21] * (y[6]/1.0) / (c[22] + (y[6]/1.0))
w = w.at[3].set(1.0 * ((c[0]/1.0) - (y[1]/1.0)))

class AssignmentRule(eqx.Module):
@jit
def __call__(self, y, w, c, t):
return w

class ModelStep(eqx.Module):
class ModelRollout(eqx.Module):
y_indexes: dict = eqx.static_field()
w_indexes: dict = eqx.static_field()
c_indexes: dict = eqx.static_field()
ratefunc: RateofSpeciesChange
atol: float = eqx.static_field()
rtol: float = eqx.static_field()
mxstep: int = eqx.static_field()
assignmentfunc: AssignmentRule
solver_type: str = eqx.static_field()
solver: Any = eqx.static_field()
ode_term: diffrax.ODETerm
solver: diffrax.AbstractERK = eqx.static_field()
iterative_solve: bool = eqx.static_field()

def __init__(self, y_indexes={'MKKK': 0, 'MKKK_P': 1, 'MKK': 2, 'MKK_P': 3, 'MKK_PP': 4, 'MAPK': 5, 'MAPK_P': 6, 'MAPK_PP': 7}, w_indexes={}, c_indexes={'uVol': 0, 'J0_V1': 1, 'J0_Ki': 2, 'J0_n': 3, 'J0_K1': 4, 'J1_V2': 5, 'J1_KK2': 6, 'J2_k3': 7, 'J2_KK3': 8, 'J3_k4': 9, 'J3_KK4': 10, 'J4_V5': 11, 'J4_KK5': 12, 'J5_V6': 13, 'J5_KK6': 14, 'J6_k7': 15, 'J6_KK7': 16, 'J7_k8': 17, 'J7_KK8': 18, 'J8_V9': 19, 'J8_KK9': 20, 'J9_V10': 21, 'J9_KK10': 22}, atol=1e-06, rtol=1e-12, mxstep=5000000, solver_type='diffrax', diffrax_solver='Tsit5'):
def __init__(self, y_indexes={'Timeract': 0, 'CellCact': 1, 'Effectoract': 2, 'HR': 3, 'NHEJ': 4}, w_indexes={'Effectorina': 0, 'Damage': 1, 'Timerinact': 2, 'CellCina': 3}, c_indexes={'CellCycletot': 0, 'Effectortot': 1, 'Timertot': 2, 'Kd2t': 3, 'Kti2t': 4, 'Kcc2ch': 5, 'Kt2cc': 6, 'Kcc2a': 7, 'Kd2ch': 8, 'Kch2cc': 9, 'Km1': 10, 'Km10': 11, 'nucleus': 12}, solver=diffrax.Tsit5(), iterative_solve=True):

self.y_indexes = y_indexes
self.w_indexes = w_indexes
self.c_indexes = c_indexes

self.ratefunc = RateofSpeciesChange()
self.rtol = rtol
self.atol = atol
self.mxstep = mxstep
self.assignmentfunc = AssignmentRule()
self.solver_type = solver_type
if solver_type == 'odeint':
self.solver = odeint
elif solver_type == 'diffrax':
from diffrax import ODETerm, Tsit5, Dopri5, Dopri8, Euler, Midpoint, Heun, Bosh3, Ralston
valid_solvers = {'Tsit5', 'Dopri5', 'Dopri8', 'Euler', 'Midpoint', 'Heun', 'Bosh3', 'Ralston'}
if diffrax_solver not in valid_solvers:
raise ValueError(f'Unknown diffrax solver: {diffrax_solver}')
self.solver = eval(diffrax_solver)()
else:
raise ValueError(f'Unknown solver type: {solver_type}')

@jit
def __call__(self, y, w, c, t, deltaT):
if self.solver_type == 'odeint':
y_new = self.solver(self.ratefunc, y, jnp.array([t, t + deltaT]), w, c, atol=self.atol, rtol=self.rtol, mxstep=self.mxstep)[-1]
else: # diffrax
term = ODETerm(lambda t, y, args: self.ratefunc(y, t, *args))
tprev, tnext = t, t + deltaT
state = self.solver.init(term, tprev, tnext, y, (w, c))
y_new, _, _, _, _ = self.solver.step(term, tprev, tnext, y, (w, c), state, made_jump=False)
def ode_func(t, y, args):
w, c = args
# Update w using the assignment rule
w = self.assignmentfunc(y, w, c, t)

# Calculate the rate of change
dy_dt = self.ratefunc(y, t, w, c)

return dy_dt

self.ode_term = diffrax.ODETerm(ode_func)
self.solver = solver
self.iterative_solve = iterative_solve

@eqx.filter_jit
def step(self, y, w, c, t, deltaT=0.1):
t_new = t + deltaT
state = self.solver.init(self.ode_term, t, t_new, y, (w, c))
y_new, _, _, _, _ = self.solver.step(self.ode_term, t, t_new, y, (w, c), state, made_jump=False)
w_new = self.assignmentfunc(y_new, w, c, t_new)
return y_new, w_new, c, t_new

class ModelRollout(eqx.Module):
deltaT: float = eqx.static_field()
modelstepfunc: ModelStep

def __init__(self, deltaT=0.1, atol=1e-06, rtol=1e-12, mxstep=5000000, solver_type='diffrax', diffrax_solver='Tsit5'):
@eqx.filter_jit
def __call__(self, t1,y0=jnp.array([0.0, 0.999999999999999, 0.0, 3.0, 7.0]), w0=jnp.array([10.0, 10.0, 10.0, 9.000000000000002]), c=jnp.array([10.0, 10.0, 10.0, 2.0, 10.0, 1.0, 10.0, 1.0, 1.0, 1.0, 1.0, 10.0, 1.0]), t0=0.0, deltaT=0.1, stepsize_controller=diffrax.PIDController(atol=1e-06, rtol=1e-12), max_steps=5000000):

self.deltaT = deltaT
self.modelstepfunc = ModelStep(atol=atol, rtol=rtol, mxstep=mxstep, solver_type=solver_type, diffrax_solver=diffrax_solver)
# Number of steps
n_steps = int(t1 / deltaT)

@partial(jit, static_argnames=("n_steps",))
def __call__(self, n_steps, y0=jnp.array([90.0, 10.0, 280.0, 10.0, 10.0, 280.0, 10.0, 10.0]), w0=jnp.array([]), c=jnp.array([1.0, 2.5, 9.0, 1.0, 10.0, 0.25, 8.0, 0.025, 15.0, 0.025, 15.0, 0.75, 15.0, 0.75, 15.0, 0.025, 15.0, 0.025, 15.0, 0.5, 15.0, 0.5, 15.0]), t0=0.0):
# Solve the ODE system
if self.iterative_solve:
def f(carry, x):
y, w, c, t = carry
return self.step(y, w, c, t, deltaT), (y, w, t)
(y, w, c, t), (ys, ws, ts) = lax.scan(f, (y0, w0, c, t0), jnp.arange(n_steps))

@jit
def f(carry, x):
y, w, c, t = carry
return self.modelstepfunc(y, w, c, t, self.deltaT), (y, w, t)
(y, w, c, t), (ys, ws, ts) = lax.scan(f, (y0, w0, c, t0), jnp.arange(n_steps))
ys = jnp.moveaxis(ys, 0, -1)
ws = jnp.moveaxis(ws, 0, -1)
else:
sol = diffrax.diffeqsolve(
self.ode_term,
self.solver,
t0=t0,
t1=t1,
dt0=deltaT,
y0=y0,
args=(w0, c),
saveat=diffrax.SaveAt(ts=jnp.linspace(t0, t1, n_steps)),
stepsize_controller=stepsize_controller,
max_steps=max_steps
)

# Extract results and recompute ws
ts = sol.ts
ys = sol.ys
ws = vmap(lambda t, y: self.assignmentfunc(y, w0, c, t))(ts, ys)
ys = jnp.moveaxis(ys, 0, -1) #(n_species, n_steps)
ws = jnp.moveaxis(ws, 0, -1) #(n_params, n_steps)
return ys, ws, ts

13 changes: 6 additions & 7 deletions docs/source/tutorials/parallel_execution.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -228,8 +228,7 @@
"source": [
"# Run simulation\n",
"n_secs = 100\n",
"n_steps = int(n_secs / model.deltaT)\n",
"default_ys, default_ws, ts = model(n_steps)"
"default_ys, default_ws, ts = model(n_secs)"
]
},
{
Expand Down Expand Up @@ -261,8 +260,8 @@
],
"source": [
"# Plot time-course evolution and corresponding trajectories in phase space\n",
"plot_time_trajectory(ts, default_ys, model.modelstepfunc.y_indexes)\n",
"plot_phase_space_trajectories(default_ys, model.modelstepfunc.y_indexes)"
"plot_time_trajectory(ts, default_ys, model.y_indexes)\n",
"plot_phase_space_trajectories(default_ys, model.y_indexes)"
]
},
{
Expand Down Expand Up @@ -359,7 +358,7 @@
"batched_model = vmap(model, in_axes=(None, 0), out_axes=(0, 0, None))\n",
"\n",
"# run simulation in batch mode\n",
"batched_ys, batched_ws, ts = batched_model(n_steps, batched_y0)"
"batched_ys, batched_ws, ts = batched_model(n_secs, batched_y0)"
]
},
{
Expand All @@ -372,7 +371,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 11,
"id": "78db96eb-5a14-483e-9212-4500cae401b2",
"metadata": {},
"outputs": [
Expand All @@ -388,7 +387,7 @@
}
],
"source": [
"plot_phase_space_trajectories(batched_ys, model.modelstepfunc.y_indexes)"
"plot_phase_space_trajectories(batched_ys, model.y_indexes)"
]
}
],
Expand Down
Loading

0 comments on commit f644d5b

Please sign in to comment.