diff --git a/python/sdist/amici/plotting.py b/python/sdist/amici/plotting.py index bd1f3a8ba1..edf4a33156 100644 --- a/python/sdist/amici/plotting.py +++ b/python/sdist/amici/plotting.py @@ -6,6 +6,7 @@ from typing import Iterable, Optional, Sequence, Union import matplotlib.pyplot as plt +import numpy as np import pandas as pd import seaborn as sns from matplotlib.axes import Axes @@ -16,42 +17,55 @@ def plot_state_trajectories( rdata: ReturnDataView, - state_indices: Optional[Iterable[int]] = None, + state_indices: Optional[Sequence[int]] = None, ax: Optional[Axes] = None, model: Model = None, prefer_names: bool = True, + marker=None, ) -> None: """ - Plot state trajectories + Plot state trajectories. :param rdata: AMICI simulation results as returned by - :func:`amici.amici.runAmiciSimulation` - + :func:`amici.amici.runAmiciSimulation`. :param state_indices: - Indices of states for which trajectories are to be plotted - + Indices of state variables for which trajectories are to be plotted. :param ax: - matplotlib Axes instance to plot into - + :class:`matplotlib.pyplot.Axes` instance to plot into. :param model: - amici model instance - + The model *rdata* was generated from. :param prefer_names: Whether state names should be preferred over IDs, if available. + :param marker: + Point marker for plotting (see + `matplotlib documentation `_). """ if not ax: fig, ax = plt.subplots() if not state_indices: state_indices = range(rdata["x"].shape[1]) - for ix in state_indices: - if model is None: - label = f"$x_{{{ix}}}$" - elif prefer_names and model.getStateNames()[ix]: - label = model.getStateNames()[ix] - else: - label = model.getStateIds()[ix] - ax.plot(rdata["t"], rdata["x"][:, ix], label=label) + + if marker is None: + # Show marker if only one time point is available, + # otherwise nothing will be shown + marker = "o" if len(rdata.t) == 1 else None + + if model is None and rdata.ptr.state_ids is None: + labels = [f"$x_{{{ix}}}$" for ix in state_indices] + elif model is not None and prefer_names: + labels = np.asarray(model.getStateNames())[list(state_indices)] + labels = [ + l if l else model.getStateIds()[ix] + for ix, l in enumerate(labels) + ] + elif model is not None: + labels = np.asarray(model.getStateIds())[list(state_indices)] + else: + labels = np.asarray(rdata.ptr.state_ids)[list(state_indices)] + + for ix, label in zip(state_indices, labels): + ax.plot(rdata["t"], rdata["x"][:, ix], marker=marker, label=label) ax.set_xlabel("$t$") ax.set_ylabel("$x(t)$") ax.legend() @@ -64,38 +78,54 @@ def plot_observable_trajectories( ax: Optional[Axes] = None, model: Model = None, prefer_names: bool = True, + marker=None, ) -> None: """ - Plot observable trajectories + Plot observable trajectories. :param rdata: AMICI simulation results as returned by - :func:`amici.amici.runAmiciSimulation` - + :func:`amici.amici.runAmiciSimulation`. :param observable_indices: - Indices of observables for which trajectories are to be plotted - + Indices of observables for which trajectories are to be plotted. :param ax: - matplotlib Axes instance to plot into - + :class:`matplotlib.pyplot.Axes` instance to plot into. :param model: - amici model instance - + The model *rdata* was generated from. :param prefer_names: - Whether observables names should be preferred over IDs, if available. + Whether observable names should be preferred over IDs, if available. + :param marker: + Point marker for plotting (see + `matplotlib documentation `_). + """ if not ax: fig, ax = plt.subplots() if not observable_indices: observable_indices = range(rdata["y"].shape[1]) - for iy in observable_indices: - if model is None: - label = f"$y_{{{iy}}}$" - elif prefer_names and model.getObservableNames()[iy]: - label = model.getObservableNames()[iy] - else: - label = model.getObservableIds()[iy] - ax.plot(rdata["t"], rdata["y"][:, iy], label=label) + + if marker is None: + # Show marker if only one time point is available, + # otherwise nothing will be shown + marker = "o" if len(rdata.t) == 1 else None + + if model is None and rdata.ptr.observable_ids is None: + labels = [f"$y_{{{iy}}}$" for iy in observable_indices] + elif model is not None and prefer_names: + labels = np.asarray(model.getObservableNames())[ + list(observable_indices) + ] + labels = [ + l if l else model.getObservableIds()[ix] + for ix, l in enumerate(labels) + ] + elif model is not None: + labels = np.asarray(model.getObservableIds())[list(observable_indices)] + else: + labels = np.asarray(rdata.ptr.observable_ids)[list(observable_indices)] + + for iy, label in zip(observable_indices, labels): + ax.plot(rdata["t"], rdata["y"][:, iy], marker=marker, label=label) ax.set_xlabel("$t$") ax.set_ylabel("$y(t)$") ax.legend() @@ -106,8 +136,8 @@ def plot_jacobian(rdata: ReturnDataView): """Plot Jacobian as heatmap.""" df = pd.DataFrame( data=rdata.J, - index=rdata._swigptr.state_ids_solver, - columns=rdata._swigptr.state_ids_solver, + index=rdata.ptr.state_ids_solver, + columns=rdata.ptr.state_ids_solver, ) sns.heatmap(df, center=0.0) plt.title("Jacobian") @@ -124,10 +154,10 @@ def plot_expressions( """Plot the given expressions evaluated on the given simulation outputs. :param exprs: - A symbolic expression, e.g. a sympy expression or a string that can be sympified. - Can include state variable, expression, and observable IDs, depending on whether - the respective data is available in the simulation results. - Parameters are not yet supported. + A symbolic expression, e.g., a sympy expression or a string that can be + sympified. It Can include state variable, expression, and + observable IDs, depending on whether the respective data is available + in the simulation results. Parameters are not yet supported. :param rdata: The simulation results. """ diff --git a/python/sdist/setup.cfg b/python/sdist/setup.cfg index b0945eca6e..8e8797fda7 100644 --- a/python/sdist/setup.cfg +++ b/python/sdist/setup.cfg @@ -52,7 +52,10 @@ test = pytest-rerunfailures coverage shyaml - antimony + antimony>=2.13 + # see https://github.com/sys-bio/antimony/issues/92 + # unsupported x86_64 / x86_64h + antimony!=2.14; platform_system=='Darwin' and platform_machine in 'x86_64h' vis = matplotlib seaborn