
flowersteam/sbmltoodejax



About

SBMLtoODEjax is a lightweight library that automatically parses and converts SBML models into python models written end-to-end in JAX, a high-performance numerical computing library with automatic differentiation capabilities. SBMLtoODEjax is targeted at researchers who aim to incorporate SBML-specified ordinary differential equation (ODE) models into their python projects and machine learning pipelines, in order to perform efficient numerical simulation and optimization with only a few lines of code (by taking advantage of JAX’s core transformation features).

SBMLtoODEjax extends SBMLtoODEpy, a python library developed in 2019 for converting SBML files into python files written in Numpy/Scipy. The conventions chosen for the generated variables and modules differ slightly from the standard SBML conventions (used in the SBMLtoODEpy library), with the aim of accommodating more flexible manipulations while preserving a JAX-like functional programming style.
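To make the functional style concrete, here is a minimal sketch of a single integration step written as a pure function. It assumes, based on the ModelRollout excerpt further down, that y holds species amounts, w holds variables derived from them, c holds constant parameters and t is time. The toy reaction A -> B and the explicit Euler update are illustrative stand-ins, not the code SBMLtoODEjax generates.

```python
import jax.numpy as jnp

# Hypothetical toy model (not generated by SBMLtoODEjax): reaction A -> B with
# mass-action rate k * A. Assumed convention: y = species amounts, w = derived
# variables, c = constant parameters, t = current time.

def toy_step(y, w, c, t, deltaT):
    """Pure function: returns a new (y, w, c, t) tuple and never mutates its inputs."""
    k = c[0]
    rate = k * y[0]                      # mass-action flux of A -> B
    dydt = jnp.array([-rate, rate])      # dA/dt, dB/dt
    y_next = y + deltaT * dydt           # explicit Euler update (illustrative only)
    w_next = jnp.array([k * y_next[0]])  # recompute the derived rate from the new state
    return y_next, w_next, c, t + deltaT

y0 = jnp.array([1.0, 0.0])
w0 = jnp.array([0.0])
c = jnp.array([0.5])
y1, w1, c1, t1 = toy_step(y0, w0, c, 0.0, 0.1)
```

Because state goes in and comes out explicitly, a step like this can be chained with lax.scan and transformed with jit, vmap or grad without any hidden side effects.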

👉 In short, SBMLtoODEjax facilitates the re-use of biological network models and their manipulation in python projects, while tailoring them to take advantage of JAX's main features for efficient and parallel computations.

📖 The documentation, notebook tutorials and public API are available at https://developmentalsystems.org/sbmltoodejax/.

Installation

The latest stable release of SBMLtoODEjax can be installed via pip:

pip install sbmltoodejax

Requires SBMLtoODEpy, JAX (cpu) and Equinox.

Why use SBMLtoODEjax?

Simplicity and extensibility

SBMLtoODEjax retains the simplicity of the original SBMLtoODEpy library to facilitate the incorporation and refactoring of the ODE models into one’s own python projects. As shown below, one can load and simulate existing SBML files with only a few lines of python code.

Figure 1: Example code (left) and output snapshot (right) reproducing the original simulation results of Kholodenko's 2000 paper, whose model is hosted on the BioModels website.

👉 Check our Numerical Simulation tutorial to reproduce the results yourself and see more examples.

JAX-friendly

The generated python models are tailored to take advantage of JAX's main features.

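import equinox as eqx
import jax.numpy as jnp
from jax import jit, lax
from typing import Callable

class ModelRollout(eqx.Module):
    deltaT: float            # integration time step
    modelstepfunc: Callable  # one integration step: (y, w, c, t, deltaT) -> (y, w, c, t)

    def __call__(self, n_steps, y0, w0, c, t0=0.0):

        @jit # use of jit transformation decorator
        def f(carry, x):
            y, w, c, t = carry
            # advance the state by one step and record the current (y, w, t)
            return self.modelstepfunc(y, w, c, t, self.deltaT), (y, w, t)

        # use of scan primitive to replace for loop (and reduce compilation time)
        (y, w, c, t), (ys, ws, ts) = lax.scan(f, (y0, w0, c, t0), jnp.arange(n_steps))
        # put the time axis last, e.g. ys has shape (n_species, n_steps) for a 1-D y0
        ys = jnp.moveaxis(ys, 0, -1)
        ws = jnp.moveaxis(ws, 0, -1)

        return ys, ws, ts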

As shown above, model rollouts use the jit transformation and the scan primitive to reduce the compilation and execution time of the recursive ODE integration steps, which is particularly useful when running large numbers of steps (long reaction times). Models also inherit from the Equinox module abstraction and are therefore registered as PyTree containers, which facilitates the application of JAX core transformations to any SBMLtoODEjax object.
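Equinox performs PyTree registration automatically for every eqx.Module. The sketch below does the same by hand with jax.tree_util.register_pytree_node_class on a hypothetical ToyRollout class (an Euler rollout of dy/dt = -c·y, not generated code), to show why this matters: once registered, the whole model object can be passed through jax.jit like any array, with its array fields traced and its plain-Python fields kept static.

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.tree_util.register_pytree_node_class
class ToyRollout:
    """Hypothetical model: Euler rollout of dy/dt = -c * y (not SBMLtoODEjax code)."""

    def __init__(self, c, deltaT):
        self.c = c            # array field -> PyTree leaf (traced by jit/grad/vmap)
        self.deltaT = deltaT  # plain float -> static auxiliary data

    def tree_flatten(self):
        return (self.c,), self.deltaT

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(children[0], aux_data)

    def __call__(self, y0, n_steps):
        def f(y, _):
            return y - self.deltaT * self.c * y, y  # (new carry, recorded state)
        _, ys = lax.scan(f, y0, None, length=n_steps)
        return ys

model = ToyRollout(jnp.array(0.5), 0.1)
leaves = jax.tree_util.tree_leaves(model)  # only the array field c is a leaf

# the model object itself is a valid jit argument; n_steps stays a Python constant
run = jax.jit(lambda m, y0: m(y0, 100))
ys = run(model, jnp.array(1.0))
```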

Efficient simulation and optimization

The application of JAX core transformations, such as just-in-time compilation (jit), automatic vectorization (vmap) and automatic differentiation (grad), to the generated models makes it easy (and seamless) to run simulations efficiently and in parallel.
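As a minimal sketch of composing two of these transformations, the code below differentiates through a scan-based rollout with jax.grad and compiles the gradient with jax.jit. The toy decay model dy/dt = -k·y and its explicit Euler discretisation are assumptions, chosen so that the result can be checked against the closed form y_N = y_0 (1 - k·Δt)^N.

```python
import jax
import jax.numpy as jnp
from jax import lax

DELTA_T = 0.1
N_STEPS = 50

def final_amount(k, y0):
    """Euler rollout of the toy model dy/dt = -k * y; returns y after N_STEPS steps."""
    def step(y, _):
        return y - DELTA_T * k * y, None
    y_final, _ = lax.scan(step, y0, None, length=N_STEPS)
    return y_final

# jit-compiled gradient of the final amount with respect to the rate constant k
dfinal_dk = jax.jit(jax.grad(final_amount))

k, y0 = jnp.array(0.5), jnp.array(1.0)
g = dfinal_dk(k, y0)

# derivative of the closed form y_N = y0 * (1 - k*dt)^N with respect to k
expected = -N_STEPS * DELTA_T * y0 * (1 - k * DELTA_T) ** (N_STEPS - 1)
```

Reverse-mode differentiation flows through every scan step, so the same pattern applies to any parameter that enters the step function.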

For instance, as shown below, one can vectorize calls to the model rollout with only a few lines of python code and perform batched computations efficiently, which is particularly useful for large batch sizes.

Figure 2: (left) Example code to vectorize calls to the model rollout. (right) Results of a (rudimentary) benchmark comparing the average simulation time of models implemented with SBMLtoODEpy versus SBMLtoODEjax, for different numbers of rollouts (i.e. batch sizes).
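A minimal sketch of the batching pattern, assuming a toy rollout function in place of a generated SBMLtoODEjax model: jax.vmap maps the single-rollout function over a leading batch axis of initial conditions and rate constants, and jax.jit compiles the batched version once. The rollout name, the two species and the rate range are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax import lax

DELTA_T = 0.1
N_STEPS = 200

def rollout(y0, k):
    """Toy stand-in for a model rollout: Euler steps of dy/dt = -k * y.
    Returns the trajectory with the time axis last, like ModelRollout."""
    def step(y, _):
        return y - DELTA_T * k * y, y
    _, ys = lax.scan(step, y0, None, length=N_STEPS)
    return jnp.moveaxis(ys, 0, -1)  # (n_species, n_steps)

# vectorize over a batch of initial conditions AND rate constants, then compile once
batched_rollout = jax.jit(jax.vmap(rollout, in_axes=(0, 0)))

batch_size = 1000
y0s = jnp.ones((batch_size, 2))                    # 2 hypothetical species per rollout
ks = jnp.linspace(0.1, 1.0, batch_size)[:, None]   # one rate per rollout, broadcast over species
ys = batched_rollout(y0s, ks)                      # shape (batch_size, 2, N_STEPS)
```

Writing the single-rollout function first and letting vmap add the batch dimension keeps the model code free of batch-handling logic.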

👉 Check our Benchmarking notebook for additional details on the benchmark results.

Finally, as shown below, SBMLtoODEjax models can also be integrated within Optax pipelines (Optax is a gradient processing and optimization library for JAX), making it possible to optimize model parameters and/or external interventions with stochastic gradient descent.

Figure 3: (left) Default simulation results of biomodel #145, which models ATP-induced intracellular calcium oscillations, and the target sine-wave pattern for the Ca_Cyt concentration. (middle) Training loss obtained when running the Optax optimization loop, with the Adam optimizer, over the model's kinetic parameters c. (right) Simulation results obtained after optimization.
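The shape of such an optimization loop can be sketched without Optax. The code below swaps in plain gradient descent (instead of Adam) on a toy Euler rollout and fits a single rate constant k to a synthetic target trajectory generated with k = 0.3. The model, names and learning rate are illustrative assumptions; in an Optax pipeline, the manual update line would be replaced by the optimizer's update step.

```python
import jax
import jax.numpy as jnp
from jax import lax

DELTA_T = 0.1
N_STEPS = 100
LEARNING_RATE = 0.5

def rollout(k):
    """Toy model: Euler rollout of dy/dt = -k * y starting from y = 1."""
    def step(y, _):
        return y - DELTA_T * k * y, y
    _, ys = lax.scan(step, jnp.array(1.0), None, length=N_STEPS)
    return ys

target = rollout(jnp.array(0.3))  # synthetic target pattern from a known rate

def loss(k):
    """Mean squared error between the simulated and target trajectories."""
    return jnp.mean((rollout(k) - target) ** 2)

@jax.jit
def update(k):
    """One plain gradient-descent step; returns the new k and the current loss."""
    l, g = jax.value_and_grad(loss)(k)
    return k - LEARNING_RATE * g, l

k = jnp.array(0.5)  # initial guess
losses = []
for _ in range(200):
    k, l = update(k)
    losses.append(float(l))
```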

👉 Check our Gradient Descent tutorial to reproduce the result yourself and try more advanced optimization use cases.

All contributions are welcome!

SBMLtoODEjax is in its early stage and any sort of contribution will be highly appreciated.

Suggested contributions

There are several use cases that are not handled by the current codebase, including:

  1. Events: SBML files with events (discrete occurrences that can trigger discontinuous changes in the model) are not handled.
  2. Math functions: we handle a large portion, but not all, of the functions that may be used in SBML files (see mathFuncs in sbmltoodejax.modulegeneration.GenerateModel).
  3. Custom solvers: to integrate the model's equations, we use JAX's experimental odeint solver but do not yet allow for other solvers.
  4. NaN/negative values: numerical simulation sometimes leads to NaN values (or negative values for the species amounts), which could be due either to wrong parsing or to solver issues.
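Item 3 refers to jax.experimental.ode.odeint, JAX's adaptive-step (Dormand-Prince) Runge-Kutta solver. The sketch below shows how a single step of length deltaT could be taken with it, plus a simple check for the NaN/negative values from item 4. The odeint_step helper and the decay model are hypothetical; SBMLtoODEjax's own step function may be organized differently.

```python
import jax.numpy as jnp
from jax.experimental.ode import odeint

def rhs(y, t, k):
    """Toy right-hand side dy/dt = -k * y, in odeint's func(y, t, *args) order."""
    return -k * y

def odeint_step(y, t, k, deltaT):
    """Hypothetical helper: integrate from t to t + deltaT and return the end state."""
    ts = jnp.array([t, t + deltaT])
    ys = odeint(rhs, y, ts, k, rtol=1e-6, atol=1e-6)  # ys has shape (2, n_species)
    return ys[-1]

y1 = odeint_step(jnp.array([1.0, 2.0]), 0.0, jnp.array(0.5), 0.1)

# flag the failure modes from item 4: NaNs or negative species amounts
is_valid = bool(jnp.all(jnp.isfinite(y1)) & jnp.all(y1 >= 0))
```

Supporting other solvers would amount to swapping the odeint call inside such a step for a different integrator with the same inputs and outputs.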

This means that a large portion of the possible SBML files cannot yet be simulated. For instance, as we detail in the image below, out of the 1048 curated models that one can load from the BioModels website, only 232 can be successfully simulated (given the default initial conditions) in SBMLtoODEjax.

👉 Please consider contributing and check our Contribution Guidelines to learn how to do so.

License

The SBMLtoODEjax project is licensed under the MIT license.

Acknowledgements

SBMLtoODEjax builds on:

  • SBMLtoODEpy's parsing and conversion of SBML files, by Steve M. Ruggiero and Ashlee N. Ford
  • JAX's composable transformations, by the Google team
  • Equinox's module abstraction, by Patrick Kidger
  • BasiCO's access to the BioModels REST API, by the COPASI team

Our documentation was also inspired by the GPJax documentation, by Thomas Pinder and team.

Citing SBMLtoODEjax

If you use SBMLtoODEjax in your research, please cite the paper:

@inproceedings{etcheverry2023sbmltoodejax,
title={SBMLtoODEjax: Efficient Simulation and Optimization of Biological Network Models in JAX},
author={Mayalen Etcheverry and Michael Levin and Clement Moulin-Frier and Pierre-Yves Oudeyer},
booktitle={NeurIPS 2023 AI for Science Workshop},
year={2023},
url={https://openreview.net/forum?id=exP6UntwqJ}
}