Skip to content

Commit

Permalink
Add diffrax solver support to modulegeneration and utils
Browse files Browse the repository at this point in the history
- Added diffrax solver as an alternative to odeint
- Updated modulegeneration to support diffrax solver configuration
  • Loading branch information
Dylan Esguerra committed Dec 3, 2024
1 parent ca8aaa2 commit baabb33
Show file tree
Hide file tree
Showing 9 changed files with 88 additions and 1,583 deletions.
61 changes: 46 additions & 15 deletions sbmltoodejax/modulegeneration.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ def GenerateModel(modelData, outputFilePath,
deltaT: float =0.1,
atol: float=1e-6,
rtol: float = 1e-12,
mxstep: int = 5000000
mxstep: int = 5000000,
solver_type: str = 'odeint',
diffrax_solver: str = 'Tsit5'
):
"""
This function takes model data created by :func:`~sbmltoodejax.parse.ParseSBMLFile` and generates a python file containing
Expand Down Expand Up @@ -125,7 +127,9 @@ def GenerateModel(modelData, outputFilePath,
outputFile.write("from functools import partial\n")
outputFile.write("from jax import jit, lax, vmap\n")
outputFile.write("from jax.experimental.ode import odeint\n")
outputFile.write("import jax.numpy as jnp\n\n")
outputFile.write("import jax.numpy as jnp\n")
outputFile.write("from diffrax import ODETerm, Tsit5, Dopri5, Dopri8, Euler, Midpoint, Heun, Bosh3, Ralston\n")
outputFile.write("from typing import Any\n\n")
outputFile.write("from sbmltoodejax import jaxfuncs\n\n")


Expand Down Expand Up @@ -512,42 +516,59 @@ def ParseRHS(rawRateLaw, extended_param_names=[], reaction_name=None, yvar="y",
outputFile.write("\tatol: float = eqx.static_field()\n")
outputFile.write("\trtol: float = eqx.static_field()\n")
outputFile.write("\tmxstep: int = eqx.static_field()\n")
outputFile.write(f"\tassignmentfunc: {AssignmentRuleName}\n\n")
outputFile.write(f"\tassignmentfunc: {AssignmentRuleName}\n")
outputFile.write("\tsolver_type: str = eqx.static_field()\n")
outputFile.write("\tsolver: Any = eqx.static_field()\n\n")

outputFile.write(f"\tdef __init__(self, "
f"y_indexes={y_indexes}, "
f"w_indexes={w_indexes}, "
f"c_indexes={c_indexes}, "
f"atol={atol}, rtol={rtol}, mxstep={mxstep}):\n\n")
f"atol={atol}, rtol={rtol}, mxstep={mxstep}, "
f"solver_type='{solver_type}', diffrax_solver='{diffrax_solver}'):\n\n")

outputFile.write("\t\tself.y_indexes = y_indexes\n")
outputFile.write("\t\tself.w_indexes = w_indexes\n")
outputFile.write("\t\tself.c_indexes = c_indexes\n\n")

outputFile.write("\t\tself.c_indexes = c_indexes\n")
outputFile.write(f"\t\tself.ratefunc = {RateofSpeciesChangeName}()\n")
outputFile.write("\t\tself.rtol = rtol\n")
outputFile.write("\t\tself.atol = atol\n")
outputFile.write("\t\tself.mxstep = mxstep\n")

outputFile.write(f"\t\tself.assignmentfunc = {AssignmentRuleName}()\n\n")

outputFile.write(f"\t\tself.assignmentfunc = {AssignmentRuleName}()\n")
outputFile.write("\t\tself.solver_type = solver_type\n")
outputFile.write("\t\tif solver_type == 'odeint':\n")
outputFile.write("\t\t\tself.solver = odeint\n")
outputFile.write("\t\telif solver_type == 'diffrax':\n")
outputFile.write("\t\t\tfrom diffrax import ODETerm, Tsit5, Dopri5, Dopri8, Euler, Midpoint, Heun, Bosh3, Ralston\n")
outputFile.write("\t\t\tvalid_solvers = {'Tsit5', 'Dopri5', 'Dopri8', 'Euler', 'Midpoint', 'Heun', 'Bosh3', 'Ralston'}\n")
outputFile.write("\t\t\tif diffrax_solver not in valid_solvers:\n")
outputFile.write("\t\t\t\traise ValueError(f'Unknown diffrax solver: {diffrax_solver}')\n")
outputFile.write(f"\t\t\tself.solver = {diffrax_solver}()\n")
outputFile.write("\t\telse:\n")
outputFile.write("\t\t\traise ValueError(f'Unknown solver type: {solver_type}')\n\n")

outputFile.write("\t@jit\n")
outputFile.write("\tdef __call__(self, y, w, c, t, deltaT):\n")
outputFile.write("\t\ty_new = odeint(self.ratefunc, y, jnp.array([t, t + deltaT]), w, c, atol=self.atol, rtol=self.rtol, mxstep=self.mxstep)[-1]\t\n")
outputFile.write("\t\tt_new = t + deltaT\t\n")
outputFile.write("\t\tw_new = self.assignmentfunc(y_new, w, c, t_new)\t\n")
outputFile.write("\t\treturn y_new, w_new, c, t_new\t\n\n")
outputFile.write("\t\tif self.solver_type == 'odeint':\n")
outputFile.write("\t\t\ty_new = odeint(self.ratefunc, y, jnp.array([t, t + deltaT]), w, c, atol=self.atol, rtol=self.rtol, mxstep=self.mxstep)[-1]\n")
outputFile.write("\t\telse: # diffrax\n")
outputFile.write("\t\t\tterm = ODETerm(lambda t, y, args: self.ratefunc(y, t, *args))\n")
outputFile.write("\t\t\ttprev, tnext = t, t + deltaT\n")
outputFile.write("\t\t\tstate = self.solver.init(term, tprev, tnext, y, (w, c))\n")
outputFile.write("\t\t\ty_new, _, _, _, _ = self.solver.step(term, tprev, tnext, y, (w, c), state, made_jump=False)\n")
outputFile.write("\t\tt_new = t + deltaT\n")
outputFile.write("\t\tw_new = self.assignmentfunc(y_new, w, c, t_new)\n")
outputFile.write("\t\treturn y_new, w_new, c, t_new\n\n")

# ================================================================================================================================

outputFile.write("class " + ModelRolloutName + "(eqx.Module):\n")
outputFile.write("\tdeltaT: float = eqx.static_field()\n")
outputFile.write(f"\tmodelstepfunc: {ModelStepName}\n\n")

outputFile.write(f"\tdef __init__(self, deltaT={deltaT}, atol={atol}, rtol={rtol}, mxstep={mxstep}):\n\n")
outputFile.write(f"\tdef __init__(self, deltaT={deltaT}, atol={atol}, rtol={rtol}, mxstep={mxstep}, solver_type='{solver_type}', diffrax_solver='{diffrax_solver}'):\n\n")
outputFile.write("\t\tself.deltaT = deltaT\n")
outputFile.write(f"\t\tself.modelstepfunc = {ModelStepName}(atol=atol, rtol=rtol, mxstep=mxstep)\n\n")
outputFile.write(f"\t\tself.modelstepfunc = {ModelStepName}(atol=atol, rtol=rtol, mxstep=mxstep, solver_type=solver_type, diffrax_solver=diffrax_solver)\n\n")

outputFile.write("\t@partial(jit, static_argnames=(\"n_steps\",))\n")
outputFile.write("\tdef __call__(self, n_steps, "
Expand All @@ -571,3 +592,13 @@ def ParseRHS(rawRateLaw, extended_param_names=[], reaction_name=None, yvar="y",

# ================================================================================================================================
outputFile.close()










Loading

0 comments on commit baabb33

Please sign in to comment.