Skip to content

Commit

Permalink
Evaluate and plot symbolic expressions based on simulation results (#…
Browse files Browse the repository at this point in the history
…2152)

Adds functions `amici.numpy.evaluate` and `amici.plotting.plot_expressions` to evaluate or directly plot symbolic expressions of model quantities, respectively.

Demo: see end of this section https://amici--2152.org.readthedocs.build/en/2152/ExampleSteadystate.html#Plotting-trajectories
  • Loading branch information
dweindl authored Nov 1, 2023
1 parent ea95896 commit 71295a5
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 6 deletions.
30 changes: 26 additions & 4 deletions python/examples/example_steadystate/ExampleSteadystate.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,7 @@
"source": [
"### Importing the module and loading the model\n",
"\n",
"If everything went well, we need to add the previously selected model output directory to our PYTHON_PATH and are then ready to load newly generated model:"
"If everything went well, we can now import the newly generated Python module containing our model:"
]
},
{
Expand Down Expand Up @@ -392,7 +392,7 @@
"source": [
"model = model_module.getModel()\n",
"\n",
"print(\"Model name:\", model.getName())\n",
"print(\"Model name: \", model.getName())\n",
"print(\"Model parameters:\", model.getParameterIds())\n",
"print(\"Model outputs: \", model.getObservableIds())\n",
"print(\"Model states: \", model.getStateIds())"
Expand Down Expand Up @@ -917,10 +917,32 @@
"source": [
"import amici.plotting\n",
"\n",
"amici.plotting.plotStateTrajectories(rdata, model=None)\n",
"amici.plotting.plotObservableTrajectories(rdata, model=None)"
"amici.plotting.plot_state_trajectories(rdata, model=None)\n",
"amici.plotting.plot_observable_trajectories(rdata, model=None)"
]
},
{
"cell_type": "markdown",
"source": [
"We can also evaluate symbolic expressions of model quantities using `amici.numpy.evaluate`, or directly plot the results using `amici.plotting.plot_expressions`:"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"amici.plotting.plot_expressions(\n",
" \"observable_x1 + observable_x2 + observable_x3\", rdata=rdata\n",
")"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down
28 changes: 28 additions & 0 deletions python/sdist/amici/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,12 @@

import amici
import numpy as np
import sympy as sp

from . import ExpData, ExpDataPtr, Model, ReturnData, ReturnDataPtr

StrOrExpr = Union[str, sp.Expr]


class SwigPtrView(collections.abc.Mapping):
"""
Expand Down Expand Up @@ -429,3 +432,28 @@ def _entity_type_from_id(
return symbol

raise KeyError(f"Unknown symbol {entity_id}.")


def evaluate(expr: StrOrExpr, rdata: ReturnDataView) -> np.array:
"""Evaluate a symbolic expression based on the given simulation outputs.
:param expr:
A symbolic expression, e.g. a sympy expression or a string that can be sympified.
Can include state variable, expression, and observable IDs, depending on whether
the respective data is available in the simulation results.
Parameters are not yet supported.
:param rdata:
The simulation results.
:return:
The evaluated expression for the simulation output timepoints.
"""
from sympy.utilities.lambdify import lambdify

if isinstance(expr, str):
expr = sp.sympify(expr)

arg_names = list(sorted(expr.free_symbols, key=lambda x: x.name))
func = lambdify(arg_names, expr, "numpy")
args = [rdata.by_id(arg.name) for arg in arg_names]
return func(*args)
26 changes: 25 additions & 1 deletion python/sdist/amici/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@
--------
Plotting related functions
"""
from typing import Iterable, Optional
from typing import Iterable, Optional, Sequence, Union

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from matplotlib.axes import Axes

from . import Model, ReturnDataView
from .numpy import StrOrExpr, evaluate


def plot_state_trajectories(
Expand Down Expand Up @@ -115,3 +116,26 @@ def plot_jacobian(rdata: ReturnDataView):
# backwards compatibility
plotStateTrajectories = plot_state_trajectories
plotObservableTrajectories = plot_observable_trajectories


def plot_expressions(
exprs: Union[Sequence[StrOrExpr], StrOrExpr], rdata: ReturnDataView
) -> None:
"""Plot the given expressions evaluated on the given simulation outputs.
:param exprs:
A symbolic expression, e.g. a sympy expression or a string that can be sympified.
Can include state variable, expression, and observable IDs, depending on whether
the respective data is available in the simulation results.
Parameters are not yet supported.
:param rdata:
The simulation results.
"""
if not isinstance(exprs, Sequence) or isinstance(exprs, str):
exprs = [exprs]

for expr in exprs:
plt.plot(rdata.t, evaluate(expr, rdata), label=str(expr))

plt.legend()
plt.gca().set_xlabel("$t$")
27 changes: 26 additions & 1 deletion python/tests/test_rdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
import amici
import numpy as np
import pytest
from numpy.testing import assert_array_equal
from amici.numpy import evaluate
from numpy.testing import assert_almost_equal, assert_array_equal


@pytest.fixture(scope="session")
Expand Down Expand Up @@ -39,3 +40,27 @@ def test_rdata_by_id(rdata_by_id_fixture):
assert_array_equal(
rdata.by_id(model.getStateIds()[1], "sx", model), rdata.sx[:, :, 1]
)


def test_evaluate(rdata_by_id_fixture):
# get IDs of model components
model, rdata = rdata_by_id_fixture
expr0_id = model.getExpressionIds()[0]
state1_id = model.getStateIds()[1]
observable0_id = model.getObservableIds()[0]

# ensure `evaluate` works for atoms
expr0 = rdata.by_id(expr0_id)
assert_array_equal(expr0, evaluate(expr0_id, rdata=rdata))

state1 = rdata.by_id(state1_id)
assert_array_equal(state1, evaluate(state1_id, rdata=rdata))

observable0 = rdata.by_id(observable0_id)
assert_array_equal(observable0, evaluate(observable0_id, rdata=rdata))

# ensure `evaluate` works for expressions
assert_almost_equal(
expr0 + state1 * observable0,
evaluate(f"{expr0_id} + {state1_id} * {observable0_id}", rdata=rdata),
)

0 comments on commit 71295a5

Please sign in to comment.