Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Evaluate and plot symbolic expressions based on simulation results #2152

Merged
merged 13 commits into from
Nov 1, 2023
30 changes: 26 additions & 4 deletions python/examples/example_steadystate/ExampleSteadystate.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,7 @@
"source": [
"### Importing the module and loading the model\n",
"\n",
"If everything went well, we need to add the previously selected model output directory to our PYTHON_PATH and are then ready to load newly generated model:"
"If everything went well, we can now import the newly generated Python module containing our model:"
]
},
{
Expand Down Expand Up @@ -392,7 +392,7 @@
"source": [
"model = model_module.getModel()\n",
"\n",
"print(\"Model name:\", model.getName())\n",
"print(\"Model name: \", model.getName())\n",
"print(\"Model parameters:\", model.getParameterIds())\n",
"print(\"Model outputs: \", model.getObservableIds())\n",
"print(\"Model states: \", model.getStateIds())"
Expand Down Expand Up @@ -917,10 +917,32 @@
"source": [
"import amici.plotting\n",
"\n",
"amici.plotting.plotStateTrajectories(rdata, model=None)\n",
"amici.plotting.plotObservableTrajectories(rdata, model=None)"
"amici.plotting.plot_state_trajectories(rdata, model=None)\n",
"amici.plotting.plot_observable_trajectories(rdata, model=None)"
]
},
{
"cell_type": "markdown",
"source": [
"We can also evaluate symbolic expressions of model quantities using `amici.numpy.evaluate`, or directly plot the results using `amici.plotting.plot_expressions`:"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"amici.plotting.plot_expressions(\n",
" \"observable_x1 + observable_x2 + observable_x3\", rdata=rdata\n",
")"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down
28 changes: 28 additions & 0 deletions python/sdist/amici/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,12 @@

import amici
import numpy as np
import sympy as sp

from . import ExpData, ExpDataPtr, Model, ReturnData, ReturnDataPtr

StrOrExpr = Union[str, sp.Expr]


class SwigPtrView(collections.abc.Mapping):
"""
Expand Down Expand Up @@ -429,3 +432,28 @@ def _entity_type_from_id(
return symbol

raise KeyError(f"Unknown symbol {entity_id}.")


def evaluate(expr: StrOrExpr, rdata: ReturnDataView) -> np.array:
"""Evaluate a symbolic expression based on the given simulation outputs.

:param expr:
A symbolic expression, e.g. a sympy expression or a string that can be sympified.
Can include state variable, expression, and observable IDs, depending on whether
the respective data is available in the simulation results.
Parameters are not yet supported.
:param rdata:
The simulation results.

:return:
The evaluated expression for the simulation output timepoints.
"""
from sympy.utilities.lambdify import lambdify

if isinstance(expr, str):
expr = sp.sympify(expr)

arg_names = list(sorted(expr.free_symbols, key=lambda x: x.name))
func = lambdify(arg_names, expr, "numpy")
args = [rdata.by_id(arg.name) for arg in arg_names]
return func(*args)
26 changes: 25 additions & 1 deletion python/sdist/amici/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@
--------
Plotting related functions
"""
from typing import Iterable, Optional
from typing import Iterable, Optional, Sequence, Union

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from matplotlib.axes import Axes

from . import Model, ReturnDataView
from .numpy import StrOrExpr, evaluate


def plot_state_trajectories(
Expand Down Expand Up @@ -115,3 +116,26 @@
# backwards compatibility
plotStateTrajectories = plot_state_trajectories
plotObservableTrajectories = plot_observable_trajectories


def plot_expressions(

Check warning on line 121 in python/sdist/amici/plotting.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/plotting.py#L121

Added line #L121 was not covered by tests
exprs: Union[Sequence[StrOrExpr], StrOrExpr], rdata: ReturnDataView
) -> None:
"""Plot the given expressions evaluated on the given simulation outputs.

:param exprs:
A symbolic expression, e.g. a sympy expression or a string that can be sympified.
Can include state variable, expression, and observable IDs, depending on whether
the respective data is available in the simulation results.
Parameters are not yet supported.
:param rdata:
The simulation results.
"""
if not isinstance(exprs, Sequence) or isinstance(exprs, str):
exprs = [exprs]

Check warning on line 135 in python/sdist/amici/plotting.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/plotting.py#L134-L135

Added lines #L134 - L135 were not covered by tests

for expr in exprs:
plt.plot(rdata.t, evaluate(expr, rdata), label=str(expr))

Check warning on line 138 in python/sdist/amici/plotting.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/plotting.py#L137-L138

Added lines #L137 - L138 were not covered by tests

plt.legend()
plt.gca().set_xlabel("$t$")

Check warning on line 141 in python/sdist/amici/plotting.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/plotting.py#L140-L141

Added lines #L140 - L141 were not covered by tests
27 changes: 26 additions & 1 deletion python/tests/test_rdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
import amici
import numpy as np
import pytest
from numpy.testing import assert_array_equal
from amici.numpy import evaluate
from numpy.testing import assert_almost_equal, assert_array_equal


@pytest.fixture(scope="session")
Expand Down Expand Up @@ -39,3 +40,27 @@ def test_rdata_by_id(rdata_by_id_fixture):
assert_array_equal(
rdata.by_id(model.getStateIds()[1], "sx", model), rdata.sx[:, :, 1]
)


def test_evaluate(rdata_by_id_fixture):
# get IDs of model components
model, rdata = rdata_by_id_fixture
expr0_id = model.getExpressionIds()[0]
state1_id = model.getStateIds()[1]
observable0_id = model.getObservableIds()[0]

# ensure `evaluate` works for atoms
expr0 = rdata.by_id(expr0_id)
assert_array_equal(expr0, evaluate(expr0_id, rdata=rdata))

state1 = rdata.by_id(state1_id)
assert_array_equal(state1, evaluate(state1_id, rdata=rdata))

observable0 = rdata.by_id(observable0_id)
assert_array_equal(observable0, evaluate(observable0_id, rdata=rdata))

# ensure `evaluate` works for expressions
assert_almost_equal(
expr0 + state1 * observable0,
evaluate(f"{expr0_id} + {state1_id} * {observable0_id}", rdata=rdata),
)