Skip to content

Commit

Permalink
Add optional Axes argument to plot functions and return final Axes in…
Browse files Browse the repository at this point in the history
…stances
  • Loading branch information
andreArtelt committed May 30, 2024
1 parent b57c00a commit e1f2fcb
Showing 1 changed file with 59 additions and 18 deletions.
77 changes: 59 additions & 18 deletions epyt_flow/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import requests
from tqdm import tqdm
import numpy as np
import matplotlib
import matplotlib.pyplot as plt


Expand Down Expand Up @@ -67,7 +68,8 @@ def volume_to_level(tank_volume: float, tank_diameter: float) -> float:


def plot_timeseries_data(data: np.ndarray, labels: list[str] = None, x_axis_label: str = None,
y_axis_label: str = None, show: bool = True) -> None:
y_axis_label: str = None, show: bool = True,
ax: matplotlib.axes.Axes = None) -> matplotlib.axes.Axes:
"""
Plots a single or multiple time series.
Expand All @@ -91,7 +93,18 @@ def plot_timeseries_data(data: np.ndarray, labels: list[str] = None, x_axis_labe
show : `bool`, optional
If True, the plot/figure is shown in a window.
Only considered when 'ax' is None.
The default is True.
ax : `matplotlib.axes.Axes`, optional
If not None, 'ax' is used for plotting.
The default is None.
Returns
-------
`matplotlib.axes.Axes`
Plot.
"""
if not isinstance(data, np.ndarray):
raise TypeError(f"'data' must be an instance of 'numpy.ndarray' but not of '{type(data)}'")
Expand All @@ -111,28 +124,37 @@ def plot_timeseries_data(data: np.ndarray, labels: list[str] = None, x_axis_labe
f"but not of '{type(y_axis_label)}'")
if not isinstance(show, bool):
raise TypeError(f"'show' must be an instance of 'bool' but not of '{type(show)}'")
if ax is not None:
if not isinstance(ax, matplotlib.axes.Axes):
raise TypeError("ax' must be an instance of 'matplotlib.axes.Axes'" +
f"but not of '{type(ax)}'")

plt.figure()
fig = None
if ax is None:
fig, ax = plt.subplots()

labels = labels if labels is not None else [None] * data.shape[0]

for i in range(data.shape[0]):
plt.plot(data[i, :], ".-", label=labels[i])
ax.plot(data[i, :], ".-", label=labels[i])

if not any(label is None for label in labels):
plt.legend()
ax.legend()

if x_axis_label is not None:
plt.xlabel(x_axis_label)
ax.set_xlabel(x_axis_label)
if y_axis_label is not None:
plt.ylabel(y_axis_label)
ax.set_ylabel(y_axis_label)

if show is True:
if show is True and fig is not None:
plt.show()

return ax


def plot_timeseries_prediction(y: np.ndarray, y_pred: np.ndarray,
confidence_interval: np.ndarray = None, show: bool = True) -> None:
confidence_interval: np.ndarray = None, show: bool = True,
ax: matplotlib.axes.Axes = None) -> matplotlib.axes.Axes:
"""
Plots the prediction (e.g. forecast) of *single* time series together with the
ground truth time series. In addition, confidence intervals can be plotted as well.
Expand All @@ -151,7 +173,18 @@ def plot_timeseries_prediction(y: np.ndarray, y_pred: np.ndarray,
show : `bool`, optional
If True, the plot/figure is shown in a window.
Only considered when 'ax' is None.
The default is True.
ax : `matplotlib.axes.Axes`, optional
If not None, 'axes' is used for plotting.
The default is None.
Returns
-------
`matplotlib.axes.Axes`
Plot.
"""
if not isinstance(y_pred, np.ndarray):
raise TypeError("'y_pred' must be an instance of 'numpy.ndarray' " +
Expand All @@ -167,21 +200,29 @@ def plot_timeseries_prediction(y: np.ndarray, y_pred: np.ndarray,
raise ValueError("'y' must be a 1d array")
if not isinstance(show, bool):
raise TypeError(f"'show' must be an instance of 'bool' but not of '{type(show)}'")
if ax is not None:
if not isinstance(ax, matplotlib.axes.Axes):
raise TypeError("ax' must be an instance of 'matplotlib.axes.Axes'" +
f"but not of '{type(ax)}'")

plt.figure()
fig = None
if ax is None:
fig, ax = plt.subplots()

if confidence_interval is not None:
plt.fill_between(range(len(y_pred)),
y_pred - confidence_interval[0],
y_pred + confidence_interval[1],
alpha=0.5)
plt.plot(y_pred, ".-", label="Prediction")
plt.plot(y, ".-", label="Ground truth")
plt.legend()

if show is True:
ax.fill_between(range(len(y_pred)),
y_pred - confidence_interval[0],
y_pred + confidence_interval[1],
alpha=0.5)
ax.plot(y_pred, ".-", label="Prediction")
ax.plot(y, ".-", label="Ground truth")
ax.legend()

if show is True and fig is not None:
plt.show()

return ax


def download_if_necessary(download_path: str, url: str, verbose: bool = True) -> None:
"""
Expand Down

0 comments on commit e1f2fcb

Please sign in to comment.