Skip to content

Commit

Permalink
Merge branch 'develop' into doc_edata
Browse files Browse the repository at this point in the history
  • Loading branch information
dweindl authored Jan 5, 2024
2 parents 14864f4 + dd07008 commit 02faf2e
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 43 deletions.
114 changes: 72 additions & 42 deletions python/sdist/amici/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Iterable, Optional, Sequence, Union

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib.axes import Axes
Expand All @@ -16,42 +17,55 @@

def plot_state_trajectories(
rdata: ReturnDataView,
state_indices: Optional[Iterable[int]] = None,
state_indices: Optional[Sequence[int]] = None,
ax: Optional[Axes] = None,
model: Model = None,
prefer_names: bool = True,
marker=None,
) -> None:
"""
Plot state trajectories
Plot state trajectories.
:param rdata:
AMICI simulation results as returned by
:func:`amici.amici.runAmiciSimulation`
:func:`amici.amici.runAmiciSimulation`.
:param state_indices:
Indices of states for which trajectories are to be plotted
Indices of state variables for which trajectories are to be plotted.
:param ax:
matplotlib Axes instance to plot into
:class:`matplotlib.pyplot.Axes` instance to plot into.
:param model:
amici model instance
The model *rdata* was generated from.
:param prefer_names:
Whether state names should be preferred over IDs, if available.
:param marker:
Point marker for plotting (see
`matplotlib documentation <https://matplotlib.org/stable/api/markers_api.html>`_).
"""
if not ax:
fig, ax = plt.subplots()
if not state_indices:
state_indices = range(rdata["x"].shape[1])
for ix in state_indices:
if model is None:
label = f"$x_{{{ix}}}$"
elif prefer_names and model.getStateNames()[ix]:
label = model.getStateNames()[ix]
else:
label = model.getStateIds()[ix]
ax.plot(rdata["t"], rdata["x"][:, ix], label=label)

if marker is None:
# Show marker if only one time point is available,
# otherwise nothing will be shown
marker = "o" if len(rdata.t) == 1 else None

if model is None and rdata.ptr.state_ids is None:
labels = [f"$x_{{{ix}}}$" for ix in state_indices]
elif model is not None and prefer_names:
labels = np.asarray(model.getStateNames())[list(state_indices)]
labels = [
l if l else model.getStateIds()[ix]
for ix, l in enumerate(labels)
]
elif model is not None:
labels = np.asarray(model.getStateIds())[list(state_indices)]
else:
labels = np.asarray(rdata.ptr.state_ids)[list(state_indices)]

for ix, label in zip(state_indices, labels):
ax.plot(rdata["t"], rdata["x"][:, ix], marker=marker, label=label)
ax.set_xlabel("$t$")
ax.set_ylabel("$x(t)$")
ax.legend()
Expand All @@ -64,38 +78,54 @@ def plot_observable_trajectories(
ax: Optional[Axes] = None,
model: Model = None,
prefer_names: bool = True,
marker=None,
) -> None:
"""
Plot observable trajectories
Plot observable trajectories.
:param rdata:
AMICI simulation results as returned by
:func:`amici.amici.runAmiciSimulation`
:func:`amici.amici.runAmiciSimulation`.
:param observable_indices:
Indices of observables for which trajectories are to be plotted
Indices of observables for which trajectories are to be plotted.
:param ax:
matplotlib Axes instance to plot into
:class:`matplotlib.pyplot.Axes` instance to plot into.
:param model:
amici model instance
The model *rdata* was generated from.
:param prefer_names:
Whether observables names should be preferred over IDs, if available.
Whether observable names should be preferred over IDs, if available.
:param marker:
Point marker for plotting (see
`matplotlib documentation <https://matplotlib.org/stable/api/markers_api.html>`_).
"""
if not ax:
fig, ax = plt.subplots()
if not observable_indices:
observable_indices = range(rdata["y"].shape[1])
for iy in observable_indices:
if model is None:
label = f"$y_{{{iy}}}$"
elif prefer_names and model.getObservableNames()[iy]:
label = model.getObservableNames()[iy]
else:
label = model.getObservableIds()[iy]
ax.plot(rdata["t"], rdata["y"][:, iy], label=label)

if marker is None:
# Show marker if only one time point is available,
# otherwise nothing will be shown
marker = "o" if len(rdata.t) == 1 else None

if model is None and rdata.ptr.observable_ids is None:
labels = [f"$y_{{{iy}}}$" for iy in observable_indices]
elif model is not None and prefer_names:
labels = np.asarray(model.getObservableNames())[
list(observable_indices)
]
labels = [
l if l else model.getObservableIds()[ix]
for ix, l in enumerate(labels)
]
elif model is not None:
labels = np.asarray(model.getObservableIds())[list(observable_indices)]
else:
labels = np.asarray(rdata.ptr.observable_ids)[list(observable_indices)]

for iy, label in zip(observable_indices, labels):
ax.plot(rdata["t"], rdata["y"][:, iy], marker=marker, label=label)
ax.set_xlabel("$t$")
ax.set_ylabel("$y(t)$")
ax.legend()
Expand All @@ -106,8 +136,8 @@ def plot_jacobian(rdata: ReturnDataView):
"""Plot Jacobian as heatmap."""
df = pd.DataFrame(
data=rdata.J,
index=rdata._swigptr.state_ids_solver,
columns=rdata._swigptr.state_ids_solver,
index=rdata.ptr.state_ids_solver,
columns=rdata.ptr.state_ids_solver,
)
sns.heatmap(df, center=0.0)
plt.title("Jacobian")
Expand All @@ -124,10 +154,10 @@ def plot_expressions(
"""Plot the given expressions evaluated on the given simulation outputs.
:param exprs:
A symbolic expression, e.g. a sympy expression or a string that can be sympified.
Can include state variable, expression, and observable IDs, depending on whether
the respective data is available in the simulation results.
Parameters are not yet supported.
A symbolic expression, e.g., a sympy expression or a string that can be
sympified. It Can include state variable, expression, and
observable IDs, depending on whether the respective data is available
in the simulation results. Parameters are not yet supported.
:param rdata:
The simulation results.
"""
Expand Down
5 changes: 4 additions & 1 deletion python/sdist/setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,10 @@ test =
pytest-rerunfailures
coverage
shyaml
antimony
antimony>=2.13
# see https://github.com/sys-bio/antimony/issues/92
# unsupported x86_64 / x86_64h
antimony!=2.14; platform_system=='Darwin' and platform_machine in 'x86_64h'
vis =
matplotlib
seaborn
Expand Down

0 comments on commit 02faf2e

Please sign in to comment.