diff --git a/python/sdist/amici/plotting.py b/python/sdist/amici/plotting.py index da718c1ec7..aa93bd7011 100644 --- a/python/sdist/amici/plotting.py +++ b/python/sdist/amici/plotting.py @@ -3,15 +3,18 @@ -------- Plotting related functions """ -from typing import Iterable, Optional +from typing import Iterable, Optional, Sequence, Union import matplotlib.pyplot as plt import pandas as pd import seaborn as sns +import sympy as sp from matplotlib.axes import Axes from . import Model, ReturnDataView +StrOrExpr = Union[str, sp.Expr] + def plot_state_trajectories( rdata: ReturnDataView, @@ -115,3 +118,49 @@ def plot_jacobian(rdata: ReturnDataView): # backwards compatibility plotStateTrajectories = plot_state_trajectories plotObservableTrajectories = plot_observable_trajectories + + +def evaluate(expr: StrOrExpr, rdata: ReturnDataView) -> "numpy.array": + """Evaluate a symbolic expression based on the given simulation outputs. + + :param expr: + A symbolic expression, e.g. a sympy expression or a string that can be sympified. + Can include state variable, expression, and observable IDs, depending on whether + the respective data is available in the simulation results. + Parameters are not yet supported. + :param rdata: + The simulation results. + + :return: + The evaluated expression for the simulation output timepoints. + """ + from sympy.utilities.lambdify import lambdify + + if isinstance(expr, str): + expr = sp.sympify(expr) + + arg_names = list(sorted(expr.free_symbols, key=lambda x: x.name)) + func = lambdify(arg_names, expr, "numpy") + args = [rdata.by_id(arg.name) for arg in arg_names] + return func(*args) + + +def plot_expressions( + exprs: Union[Sequence[StrOrExpr], StrOrExpr], rdata: ReturnDataView +) -> None: + """Plot the given expressions evaluated on the given simulation outputs. + + :param exprs: + A symbolic expression, e.g. a sympy expression or a string that can be sympified. + Can include state variable, expression, and observable IDs, depending on whether + the respective data is available in the simulation results. + Parameters are not yet supported. + :param rdata: + The simulation results. + """ + if not isinstance(exprs, Sequence): + exprs = [exprs] + + for expr in exprs: + plt.plot(rdata.t, evaluate(expr, rdata), label=str(expr)) + plt.legend()