Skip to content

Commit

Permalink
rerun tutorials with diffrax solver
Browse files Browse the repository at this point in the history
  • Loading branch information
etch4966 committed Dec 6, 2024
1 parent f83e67b commit 9eaec4d
Show file tree
Hide file tree
Showing 5 changed files with 224 additions and 116 deletions.
147 changes: 109 additions & 38 deletions docs/source/tutorials/benchmark.ipynb

Large diffs are not rendered by default.

54 changes: 27 additions & 27 deletions docs/source/tutorials/biomodels_curation.ipynb

Large diffs are not rendered by default.

18 changes: 9 additions & 9 deletions docs/source/tutorials/gradient_descent.ipynb

Large diffs are not rendered by default.

107 changes: 72 additions & 35 deletions docs/source/tutorials/jax_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,26 +3,28 @@
from jax import jit, lax, vmap
from jax.experimental.ode import odeint
import jax.numpy as jnp
from diffrax import ODETerm, Tsit5, Dopri5, Dopri8, Euler, Midpoint, Heun, Bosh3, Ralston
from typing import Any

from sbmltoodejax import jaxfuncs

t0 = 0.0

y0 = jnp.array([0.0, 0.999999999999999, 0.0, 3.0, 7.0])
y_indexes = {'Timeract': 0, 'CellCact': 1, 'Effectoract': 2, 'HR': 3, 'NHEJ': 4}
y0 = jnp.array([90.0, 10.0, 280.0, 10.0, 10.0, 280.0, 10.0, 10.0])
y_indexes = {'MKKK': 0, 'MKKK_P': 1, 'MKK': 2, 'MKK_P': 3, 'MKK_PP': 4, 'MAPK': 5, 'MAPK_P': 6, 'MAPK_PP': 7}

w0 = jnp.array([10.0, 10.0, 10.0, 9.000000000000002])
w_indexes = {'Effectorina': 0, 'Damage': 1, 'Timerinact': 2, 'CellCina': 3}
w0 = jnp.array([])
w_indexes = {}

c = jnp.array([10.0, 10.0, 10.0, 2.0, 10.0, 1.0, 10.0, 1.0, 1.0, 1.0, 1.0, 10.0, 1.0])
c_indexes = {'CellCycletot': 0, 'Effectortot': 1, 'Timertot': 2, 'Kd2t': 3, 'Kti2t': 4, 'Kcc2ch': 5, 'Kt2cc': 6, 'Kcc2a': 7, 'Kd2ch': 8, 'Kch2cc': 9, 'Km1': 10, 'Km10': 11, 'nucleus': 12}
c = jnp.array([1.0, 2.5, 9.0, 1.0, 10.0, 0.25, 8.0, 0.025, 15.0, 0.025, 15.0, 0.75, 15.0, 0.75, 15.0, 0.025, 15.0, 0.025, 15.0, 0.5, 15.0, 0.5, 15.0])
c_indexes = {'uVol': 0, 'J0_V1': 1, 'J0_Ki': 2, 'J0_n': 3, 'J0_K1': 4, 'J1_V2': 5, 'J1_KK2': 6, 'J2_k3': 7, 'J2_KK3': 8, 'J3_k4': 9, 'J3_KK4': 10, 'J4_V5': 11, 'J4_KK5': 12, 'J5_V6': 13, 'J5_KK6': 14, 'J6_k7': 15, 'J6_KK7': 16, 'J7_k8': 17, 'J7_KK8': 18, 'J8_V9': 19, 'J8_KK9': 20, 'J9_V10': 21, 'J9_KK10': 22}

class RateofSpeciesChange(eqx.Module):
stoichiometricMatrix = jnp.array([[0.0], [0.0], [0.0], [0.0], [0.0]], dtype=jnp.float32)
stoichiometricMatrix = jnp.array([[-1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [1.0, -1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, -1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, -1.0, 1.0, -1.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 1.0, -1.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -1.0, 0.0, 0.0, 1.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, -1.0, 1.0, -1.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, -1.0, 0.0]], dtype=jnp.float32)

@jit
def __call__(self, y, t, w, c):
rateRuleVector = jnp.array([self.RateTimeract(y, w, c, t), self.RateCellCact(y, w, c, t), self.RateEffectoract(y, w, c, t), self.RateHR(y, w, c, t), self.RateNHEJ(y, w, c, t)], dtype=jnp.float32)
rateRuleVector = jnp.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], dtype=jnp.float32)

reactionVelocities = self.calc_reaction_velocities(y, w, c, t)

Expand All @@ -32,36 +34,53 @@ def __call__(self, y, t, w, c):


def calc_reaction_velocities(self, y, w, c, t):
reactionVelocities = jnp.array([0], dtype=jnp.float32)
reactionVelocities = jnp.array([self.J0(y, w, c, t), self.J1(y, w, c, t), self.J2(y, w, c, t), self.J3(y, w, c, t), self.J4(y, w, c, t), self.J5(y, w, c, t), self.J6(y, w, c, t), self.J7(y, w, c, t), self.J8(y, w, c, t), self.J9(y, w, c, t)], dtype=jnp.float32)

return reactionVelocities

def RateTimeract(self, y, w, c, t):
return c[3] * (w[1]/1.0) * (w[2]/1.0) / (c[10] + (w[2]/1.0)) - c[4] * (y[0]/1.0) / (c[10] + (y[0]/1.0))

def RateCellCact(self, y, w, c, t):
return (c[7] + (y[1]/1.0)) * (w[3]/1.0) / (c[11] + (w[3]/1.0)) - c[6] * (y[0]/1.0) * (y[1]/1.0) / (c[11] + (y[1]/1.0)) - c[9] * (y[1]/1.0) * (y[2]/1.0) / (c[11] + (y[1]/1.0))
def J0(self, y, w, c, t):
return c[0] * c[1] * (y[0]/1.0) / ((1 + ((y[7]/1.0) / c[2])**c[3]) * (c[4] + (y[0]/1.0)))

def RateEffectoract(self, y, w, c, t):
return c[8] * (w[1]/1.0) * (w[0]/1.0) / (c[11] + (w[0]/1.0)) - c[5] * (y[1]/1.0) * (y[2]/1.0) / (c[11] + (y[2]/1.0))

def RateHR(self, y, w, c, t):
return -(y[3]/1.0) * 0.2
def J1(self, y, w, c, t):
return c[0] * c[5] * (y[1]/1.0) / (c[6] + (y[1]/1.0))

def RateNHEJ(self, y, w, c, t):
return -(y[4]/1.0) * 0.5

class AssignmentRule(eqx.Module):
@jit
def __call__(self, y, w, c, t):
w = w.at[0].set(1.0 * ((c[1]/1.0) - (y[2]/1.0)))
def J2(self, y, w, c, t):
return c[0] * c[7] * (y[1]/1.0) * (y[2]/1.0) / (c[8] + (y[2]/1.0))


def J3(self, y, w, c, t):
return c[0] * c[9] * (y[1]/1.0) * (y[3]/1.0) / (c[10] + (y[3]/1.0))


def J4(self, y, w, c, t):
return c[0] * c[11] * (y[4]/1.0) / (c[12] + (y[4]/1.0))


def J5(self, y, w, c, t):
return c[0] * c[13] * (y[3]/1.0) / (c[14] + (y[3]/1.0))

w = w.at[1].set(1.0 * ((y[3]/1.0) + (y[4]/1.0)))

w = w.at[2].set(1.0 * ((c[2]/1.0) - (y[0]/1.0)))
def J6(self, y, w, c, t):
return c[0] * c[15] * (y[4]/1.0) * (y[5]/1.0) / (c[16] + (y[5]/1.0))

w = w.at[3].set(1.0 * ((c[0]/1.0) - (y[1]/1.0)))

def J7(self, y, w, c, t):
return c[0] * c[17] * (y[4]/1.0) * (y[6]/1.0) / (c[18] + (y[6]/1.0))


def J8(self, y, w, c, t):
return c[0] * c[19] * (y[7]/1.0) / (c[20] + (y[7]/1.0))


def J9(self, y, w, c, t):
return c[0] * c[21] * (y[6]/1.0) / (c[22] + (y[6]/1.0))

class AssignmentRule(eqx.Module):
@jit
def __call__(self, y, w, c, t):
return w

class ModelStep(eqx.Module):
Expand All @@ -73,37 +92,55 @@ class ModelStep(eqx.Module):
rtol: float = eqx.static_field()
mxstep: int = eqx.static_field()
assignmentfunc: AssignmentRule
solver_type: str = eqx.static_field()
solver: Any = eqx.static_field()

def __init__(self, y_indexes={'Timeract': 0, 'CellCact': 1, 'Effectoract': 2, 'HR': 3, 'NHEJ': 4}, w_indexes={'Effectorina': 0, 'Damage': 1, 'Timerinact': 2, 'CellCina': 3}, c_indexes={'CellCycletot': 0, 'Effectortot': 1, 'Timertot': 2, 'Kd2t': 3, 'Kti2t': 4, 'Kcc2ch': 5, 'Kt2cc': 6, 'Kcc2a': 7, 'Kd2ch': 8, 'Kch2cc': 9, 'Km1': 10, 'Km10': 11, 'nucleus': 12}, atol=1e-06, rtol=1e-12, mxstep=5000000):
def __init__(self, y_indexes={'MKKK': 0, 'MKKK_P': 1, 'MKK': 2, 'MKK_P': 3, 'MKK_PP': 4, 'MAPK': 5, 'MAPK_P': 6, 'MAPK_PP': 7}, w_indexes={}, c_indexes={'uVol': 0, 'J0_V1': 1, 'J0_Ki': 2, 'J0_n': 3, 'J0_K1': 4, 'J1_V2': 5, 'J1_KK2': 6, 'J2_k3': 7, 'J2_KK3': 8, 'J3_k4': 9, 'J3_KK4': 10, 'J4_V5': 11, 'J4_KK5': 12, 'J5_V6': 13, 'J5_KK6': 14, 'J6_k7': 15, 'J6_KK7': 16, 'J7_k8': 17, 'J7_KK8': 18, 'J8_V9': 19, 'J8_KK9': 20, 'J9_V10': 21, 'J9_KK10': 22}, atol=1e-06, rtol=1e-12, mxstep=5000000, solver_type='diffrax', diffrax_solver='Tsit5'):

self.y_indexes = y_indexes
self.w_indexes = w_indexes
self.c_indexes = c_indexes

self.ratefunc = RateofSpeciesChange()
self.rtol = rtol
self.atol = atol
self.mxstep = mxstep
self.assignmentfunc = AssignmentRule()
self.solver_type = solver_type
if solver_type == 'odeint':
self.solver = odeint
elif solver_type == 'diffrax':
from diffrax import ODETerm, Tsit5, Dopri5, Dopri8, Euler, Midpoint, Heun, Bosh3, Ralston
valid_solvers = {'Tsit5', 'Dopri5', 'Dopri8', 'Euler', 'Midpoint', 'Heun', 'Bosh3', 'Ralston'}
if diffrax_solver not in valid_solvers:
raise ValueError(f'Unknown diffrax solver: {diffrax_solver}')
self.solver = eval(diffrax_solver)()
else:
raise ValueError(f'Unknown solver type: {solver_type}')

@jit
def __call__(self, y, w, c, t, deltaT):
y_new = odeint(self.ratefunc, y, jnp.array([t, t + deltaT]), w, c, atol=self.atol, rtol=self.rtol, mxstep=self.mxstep)[-1]
t_new = t + deltaT
w_new = self.assignmentfunc(y_new, w, c, t_new)
return y_new, w_new, c, t_new
if self.solver_type == 'odeint':
y_new = self.solver(self.ratefunc, y, jnp.array([t, t + deltaT]), w, c, atol=self.atol, rtol=self.rtol, mxstep=self.mxstep)[-1]
else: # diffrax
term = ODETerm(lambda t, y, args: self.ratefunc(y, t, *args))
tprev, tnext = t, t + deltaT
state = self.solver.init(term, tprev, tnext, y, (w, c))
y_new, _, _, _, _ = self.solver.step(term, tprev, tnext, y, (w, c), state, made_jump=False)
t_new = t + deltaT
w_new = self.assignmentfunc(y_new, w, c, t_new)
return y_new, w_new, c, t_new

class ModelRollout(eqx.Module):
deltaT: float = eqx.static_field()
modelstepfunc: ModelStep

def __init__(self, deltaT=0.1, atol=1e-06, rtol=1e-12, mxstep=5000000):
def __init__(self, deltaT=0.1, atol=1e-06, rtol=1e-12, mxstep=5000000, solver_type='diffrax', diffrax_solver='Tsit5'):

self.deltaT = deltaT
self.modelstepfunc = ModelStep(atol=atol, rtol=rtol, mxstep=mxstep)
self.modelstepfunc = ModelStep(atol=atol, rtol=rtol, mxstep=mxstep, solver_type=solver_type, diffrax_solver=diffrax_solver)

@partial(jit, static_argnames=("n_steps",))
def __call__(self, n_steps, y0=jnp.array([0.0, 0.999999999999999, 0.0, 3.0, 7.0]), w0=jnp.array([10.0, 10.0, 10.0, 9.000000000000002]), c=jnp.array([10.0, 10.0, 10.0, 2.0, 10.0, 1.0, 10.0, 1.0, 1.0, 1.0, 1.0, 10.0, 1.0]), t0=0.0):
def __call__(self, n_steps, y0=jnp.array([90.0, 10.0, 280.0, 10.0, 10.0, 280.0, 10.0, 10.0]), w0=jnp.array([]), c=jnp.array([1.0, 2.5, 9.0, 1.0, 10.0, 0.25, 8.0, 0.025, 15.0, 0.025, 15.0, 0.75, 15.0, 0.75, 15.0, 0.025, 15.0, 0.025, 15.0, 0.5, 15.0, 0.5, 15.0]), t0=0.0):

@jit
def f(carry, x):
Expand Down
14 changes: 7 additions & 7 deletions docs/source/tutorials/parallel_execution.ipynb

Large diffs are not rendered by default.

0 comments on commit 9eaec4d

Please sign in to comment.