From a9bdd92459bf4f1468e54536f481a28f5e81354f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Thu, 25 Apr 2024 21:37:43 +0100 Subject: [PATCH] [ENH] probability distributions: illustrative and didactic plotting functionality (#275) This PR adds a `plot` method to distributions, which allows users to create plots of various distribution defining functions. For array distributions, an array of marginal subplots is created, of the corresponding plots of the marginal distributions. The array of plots is axis aligned for comparability. --- skpro/distributions/base/_base.py | 136 +++++++++++++++++- skpro/distributions/tests/test_proba_basic.py | 39 +++++ 2 files changed, 174 insertions(+), 1 deletion(-) diff --git a/skpro/distributions/base/_base.py b/skpro/distributions/base/_base.py index 3c528e24d..2c58c7d1a 100644 --- a/skpro/distributions/base/_base.py +++ b/skpro/distributions/base/_base.py @@ -11,7 +11,10 @@ import pandas as pd from skpro.base import BaseObject -from skpro.utils.validation._dependencies import _check_estimator_deps +from skpro.utils.validation._dependencies import ( + _check_estimator_deps, + _check_soft_dependencies, +) class BaseDistribution(BaseObject): @@ -1113,6 +1116,130 @@ def gen_unif(): raise NotImplementedError(self._method_error_msg("sample", "error")) + def plot(self, fun="pdf", ax=None, **kwargs): + """Plot the distribution. + + Different distribution defining functions can be selected for plotting + via the ``fun`` parameter. + The functions available are the same as the methods of the distribution class, + e.g., ``"pdf"``, ``"cdf"``, ``"ppf"``. + + For array distribution, the marginal distribution at each entry is plotted, + as a separate subplot. + + Parameters + ---------- + fun : str, optional, default="pdf" + the function to plot, one of "pdf", "cdf", "ppf" + ax : matplotlib Axes object, optional + matplotlib Axes to plot in + if not provided, defaults to current axes (``plot.gca``) + kwargs : keyword arguments + passed to the plotting function + + Returns + ------- + fig : matplotlib.Figure, only returned if self is array distribution + matplotlig Figure object for subplots + ax : matplotlib.Axes + the axis or axes on which the plot is drawn + """ + _check_soft_dependencies("matplotlib", obj="distribution plot") + + from matplotlib.pyplot import subplots + + if self.ndim > 0: + if "x_bounds" not in kwargs: + upper = self.ppf(0.999).values.flatten().max() + lower = self.ppf(0.001).values.flatten().min() + x_bounds = (lower, upper) + else: + x_bounds = kwargs.pop("x_bounds") + if "sharex" not in kwargs: + sharex = True + else: + sharex = kwargs.pop("sharex") + if "sharey" not in kwargs: + sharey = True + else: + sharey = kwargs.pop("sharey") + + x_argname = _get_first_argname(getattr(self, fun)) + + def get_ax(ax, i, j, shape): + """Get axes at iloc i, j - API unifier for 2D and 1D subplot figures. + + Covers inconsistency in matplotlib where creation of (m, 1) matrix + of subplots creates a 1D object and not a 2D object. + """ + if shape[1] == 1: + return ax[i] + else: + return ax[i, j] + + shape = self.shape + fig, ax = subplots(shape[0], shape[1], sharex=sharex, sharey=sharey) + for i, j in np.ndindex(shape): + d_ij = self.iloc[i, j] + ax_ij = get_ax(ax, i, j, shape) + d_ij.plot( + fun=fun, + ax=ax_ij, + x_bounds=x_bounds, + print_labels="off", + x_argname=x_argname, + **kwargs, + ) + for i in range(shape[0]): + ax_i0 = get_ax(ax, i, 0, shape) + ax_i0.set_ylabel(f"{self.index[i]}") + for j in range(shape[1]): + ax_0j = get_ax(ax, 0, j, shape) + ax_0j.set_title(f"{self.columns[j]}") + fig.supylabel(f"{fun}({x_argname})") + fig.supxlabel(f"{x_argname}") + return fig, ax + + # for now, all plots default ot this function + # but this could be changed to a dispatch mechanism + # e.g., using this line instead + # plot_fun_name = f"_plot_{fun}" + plot_fun_name = "_plot_single" + + ax = getattr(self, plot_fun_name)(ax=ax, fun=fun, **kwargs) + return ax + + def _plot_single(self, ax=None, **kwargs): + """Plot the pdf of the distribution.""" + import matplotlib.pyplot as plt + + fun = kwargs.pop("fun") + print_labels = kwargs.pop("print_labels", "on") + x_argname = kwargs.pop("x_argname", "x") + + # obtain x axis bounds for plotting + if "x_bounds" in kwargs: + lower, upper = kwargs.pop("x_bounds") + elif fun != "ppf": + lower, upper = self.ppf(0.001), self.ppf(0.999) + + if fun == "ppf": + lower, upper = 0.001, 0.999 + + x_arr = np.linspace(lower, upper, 100) + y_arr = [getattr(self, fun)(x) for x in x_arr] + y_arr = np.array(y_arr) + + if ax is None: + ax = plt.gca() + + ax.plot(x_arr, y_arr, **kwargs) + + if print_labels == "on": + ax.set_xlabel(f"{x_argname}") + ax.set_ylabel(f"{fun}({x_argname})") + return ax + class _Indexer: """Indexer for BaseDistribution, for pandas-like index in loc and iloc property.""" @@ -1314,6 +1441,13 @@ def is_scalar_notnone(obj): return obj is not None and np.isscalar(obj) +def _get_first_argname(fun): + """Get the name of the first argument of a function as str.""" + from inspect import signature + + return list(signature(fun).parameters.keys())[0] + + def _coerce_to_pd_index_or_none(x): """Coerce to pd.Index, if not None, else return None.""" if x is None: diff --git a/skpro/distributions/tests/test_proba_basic.py b/skpro/distributions/tests/test_proba_basic.py index f6cda5462..6df048018 100644 --- a/skpro/distributions/tests/test_proba_basic.py +++ b/skpro/distributions/tests/test_proba_basic.py @@ -4,9 +4,12 @@ __author__ = ["fkiraly"] +import numpy as np import pandas as pd import pytest +from skpro.utils.validation._dependencies import _check_soft_dependencies + def test_proba_example(): """Test one subsetting case for BaseDistribution.""" @@ -97,3 +100,39 @@ def test_proba_index_coercion(): assert isinstance(n.columns, pd.Index) assert n.index.equals(pd.RangeIndex(1)) assert n.columns.equals(pd.Index([1, 2, 3])) + + +@pytest.mark.skipif( + not _check_soft_dependencies("matplotlib", severity="none"), + reason="skip if matplotlib is not available", +) +@pytest.mark.parametrize("fun", ["pdf", "ppf", "cdf"]) +def test_proba_plotting(fun): + """Test that plotting functions do not crash and return ax as expected.""" + from matplotlib.axes import Axes + from matplotlib.figure import Figure + + from skpro.distributions.normal import Normal + + # default case, 2D distribution with n_columns>1 + n = Normal(mu=[[0, 1], [2, 3], [4, 5]], sigma=1) + fig, ax = n.plot(fun=fun) + assert isinstance(fig, Figure) + assert isinstance(ax, np.ndarray) + assert ax.shape == n.shape + assert all([isinstance(a, Axes) for a in ax.flatten()]) + assert all([a.get_figure() == fig for a in ax.flatten()]) + + # 1D case requires special treatment of axes + n = Normal(mu=[[1], [2], [3]], sigma=1) + fig, ax = n.plot(fun=fun) + assert isinstance(fig, Figure) + assert isinstance(ax, type(ax)) + assert ax.shape == (n.shape[0],) + assert all([isinstance(a, Axes) for a in ax.flatten()]) + assert all([a.get_figure() == fig for a in ax.flatten()]) + + # scalar case + n = Normal(mu=1, sigma=1) + ax = n.plot(fun=fun) + assert isinstance(ax, Axes)