From f16293d0e0780276b6ef3e0861c28e4f9990aa54 Mon Sep 17 00:00:00 2001 From: Daniel Weindl Date: Tue, 19 Sep 2023 21:57:34 +0200 Subject: [PATCH] refactor --- python/sdist/amici/numpy.py | 28 ++++++++++++++++++++++++++++ python/sdist/amici/plotting.py | 30 +----------------------------- 2 files changed, 29 insertions(+), 29 deletions(-) diff --git a/python/sdist/amici/numpy.py b/python/sdist/amici/numpy.py index 23ebfdbbc4..d9b34b6447 100644 --- a/python/sdist/amici/numpy.py +++ b/python/sdist/amici/numpy.py @@ -10,9 +10,12 @@ import amici import numpy as np +import sympy as sp from . import ExpData, ExpDataPtr, Model, ReturnData, ReturnDataPtr +StrOrExpr = Union[str, sp.Expr] + class SwigPtrView(collections.abc.Mapping): """ @@ -429,3 +432,28 @@ def _entity_type_from_id( return symbol raise KeyError(f"Unknown symbol {entity_id}.") + + +def evaluate(expr: StrOrExpr, rdata: ReturnDataView) -> np.array: + """Evaluate a symbolic expression based on the given simulation outputs. + + :param expr: + A symbolic expression, e.g. a sympy expression or a string that can be sympified. + Can include state variable, expression, and observable IDs, depending on whether + the respective data is available in the simulation results. + Parameters are not yet supported. + :param rdata: + The simulation results. + + :return: + The evaluated expression for the simulation output timepoints. + """ + from sympy.utilities.lambdify import lambdify + + if isinstance(expr, str): + expr = sp.sympify(expr) + + arg_names = list(sorted(expr.free_symbols, key=lambda x: x.name)) + func = lambdify(arg_names, expr, "numpy") + args = [rdata.by_id(arg.name) for arg in arg_names] + return func(*args) diff --git a/python/sdist/amici/plotting.py b/python/sdist/amici/plotting.py index e99be7d969..330e74edea 100644 --- a/python/sdist/amici/plotting.py +++ b/python/sdist/amici/plotting.py @@ -6,15 +6,12 @@ from typing import Iterable, Optional, Sequence, Union import matplotlib.pyplot as plt -import numpy as np import pandas as pd import seaborn as sns -import sympy as sp from matplotlib.axes import Axes from . import Model, ReturnDataView - -StrOrExpr = Union[str, sp.Expr] +from .numpy import StrOrExpr, evaluate def plot_state_trajectories( @@ -121,31 +118,6 @@ def plot_jacobian(rdata: ReturnDataView): plotObservableTrajectories = plot_observable_trajectories -def evaluate(expr: StrOrExpr, rdata: ReturnDataView) -> np.array: - """Evaluate a symbolic expression based on the given simulation outputs. - - :param expr: - A symbolic expression, e.g. a sympy expression or a string that can be sympified. - Can include state variable, expression, and observable IDs, depending on whether - the respective data is available in the simulation results. - Parameters are not yet supported. - :param rdata: - The simulation results. - - :return: - The evaluated expression for the simulation output timepoints. - """ - from sympy.utilities.lambdify import lambdify - - if isinstance(expr, str): - expr = sp.sympify(expr) - - arg_names = list(sorted(expr.free_symbols, key=lambda x: x.name)) - func = lambdify(arg_names, expr, "numpy") - args = [rdata.by_id(arg.name) for arg in arg_names] - return func(*args) - - def plot_expressions( exprs: Union[Sequence[StrOrExpr], StrOrExpr], rdata: ReturnDataView ) -> None: