diff --git a/skpro/distributions/base/_base.py b/skpro/distributions/base/_base.py
index 3c528e24d..2c58c7d1a 100644
--- a/skpro/distributions/base/_base.py
+++ b/skpro/distributions/base/_base.py
@@ -11,7 +11,10 @@
 import pandas as pd
 
 from skpro.base import BaseObject
-from skpro.utils.validation._dependencies import _check_estimator_deps
+from skpro.utils.validation._dependencies import (
+    _check_estimator_deps,
+    _check_soft_dependencies,
+)
 
 
 class BaseDistribution(BaseObject):
@@ -1113,6 +1116,144 @@ def gen_unif():
 
         raise NotImplementedError(self._method_error_msg("sample", "error"))
 
+    def plot(self, fun="pdf", ax=None, **kwargs):
+        """Plot the distribution.
+
+        Different distribution defining functions can be selected for plotting
+        via the ``fun`` parameter.
+        The functions available are the same as the methods of the distribution class,
+        e.g., ``"pdf"``, ``"cdf"``, ``"ppf"``.
+
+        For array distribution, the marginal distribution at each entry is plotted,
+        as a separate subplot.
+
+        Parameters
+        ----------
+        fun : str, optional, default="pdf"
+            the function to plot, one of "pdf", "cdf", "ppf"
+        ax : matplotlib Axes object, optional
+            matplotlib Axes to plot in, used only if self is a scalar distribution
+            if not provided, defaults to current axes (``plt.gca``)
+            ignored if self is an array distribution, a new figure is created then
+        kwargs : keyword arguments
+            passed to the plotting function, except the following special keys:
+
+            * ``x_bounds`` : tuple (lower, upper), range of arguments to plot over,
+              default = 0.001 and 0.999 quantiles; ignored if ``fun="ppf"``,
+              where the range of arguments is always (0.001, 0.999)
+            * ``sharex``, ``sharey`` : passed to ``matplotlib.pyplot.subplots``,
+              default = True, used only if self is an array distribution
+
+        Returns
+        -------
+        fig : matplotlib.Figure, only returned if self is array distribution
+            matplotlib Figure object for subplots
+        ax : matplotlib.Axes, or numpy array of matplotlib.Axes
+            the axis or axes on which the plot is drawn
+            for array distribution, in the default format returned by ``subplots``
+        """
+        _check_soft_dependencies("matplotlib", obj="distribution plot")
+
+        from matplotlib.pyplot import subplots
+
+        if self.ndim > 0:
+            if "x_bounds" in kwargs:
+                x_bounds = kwargs.pop("x_bounds")
+            else:
+                upper = self.ppf(0.999).values.flatten().max()
+                lower = self.ppf(0.001).values.flatten().min()
+                x_bounds = (lower, upper)
+            sharex = kwargs.pop("sharex", True)
+            sharey = kwargs.pop("sharey", True)
+
+            x_argname = _get_first_argname(getattr(self, fun))
+
+            # squeeze=False always yields a 2D array of axes, so that indexing
+            # by [i, j] also works for shapes (1, n), (m, 1) and (1, 1),
+            # for which matplotlib by default returns a 1D array or a single Axes
+            shape = self.shape
+            fig, ax_grid = subplots(
+                shape[0], shape[1], sharex=sharex, sharey=sharey, squeeze=False
+            )
+            for i, j in np.ndindex(shape):
+                self.iloc[i, j].plot(
+                    fun=fun,
+                    ax=ax_grid[i, j],
+                    x_bounds=x_bounds,
+                    print_labels="off",
+                    x_argname=x_argname,
+                    **kwargs,
+                )
+            for i in range(shape[0]):
+                ax_grid[i, 0].set_ylabel(f"{self.index[i]}")
+            for j in range(shape[1]):
+                ax_grid[0, j].set_title(f"{self.columns[j]}")
+            fig.supylabel(f"{fun}({x_argname})")
+            fig.supxlabel(f"{x_argname}")
+
+            # return axes in the default (squeezed) matplotlib format
+            ax = ax_grid.item() if ax_grid.size == 1 else ax_grid.squeeze()
+            return fig, ax
+
+        # for now, all plots default to this function
+        # but this could be changed to a dispatch mechanism
+        # e.g., using this line instead
+        # plot_fun_name = f"_plot_{fun}"
+        plot_fun_name = "_plot_single"
+
+        return getattr(self, plot_fun_name)(ax=ax, fun=fun, **kwargs)
+
+    def _plot_single(self, ax=None, **kwargs):
+        """Plot a single distribution defining function on one matplotlib axes.
+
+        Parameters
+        ----------
+        ax : matplotlib Axes object, optional
+            axes to plot in, defaults to current axes (``plt.gca``)
+        kwargs : keyword arguments
+            passed to ``ax.plot``, except the following special keys:
+
+            * ``fun`` : str, name of the method to plot, default = "pdf"
+            * ``x_bounds`` : tuple (lower, upper), range of arguments to plot over,
+              default = 0.001 and 0.999 quantiles; ignored if ``fun="ppf"``
+            * ``print_labels`` : str, axes are labelled if "on" (default)
+            * ``x_argname`` : str, name of the argument in labels, default = "x"
+
+        Returns
+        -------
+        ax : matplotlib.Axes
+            the axes on which the plot is drawn
+        """
+        import matplotlib.pyplot as plt
+
+        fun = kwargs.pop("fun", "pdf")
+        print_labels = kwargs.pop("print_labels", "on")
+        x_argname = kwargs.pop("x_argname", "x")
+        x_bounds = kwargs.pop("x_bounds", None)
+
+        # obtain x axis bounds for plotting
+        if fun == "ppf":
+            # arguments of ppf are probabilities, so quantile bounds do not apply
+            lower, upper = 0.001, 0.999
+        elif x_bounds is not None:
+            lower, upper = x_bounds
+        else:
+            lower, upper = self.ppf(0.001), self.ppf(0.999)
+
+        plot_fun = getattr(self, fun)
+        x_arr = np.linspace(lower, upper, 100)
+        y_arr = np.array([plot_fun(x) for x in x_arr])
+
+        if ax is None:
+            ax = plt.gca()
+
+        ax.plot(x_arr, y_arr, **kwargs)
+
+        if print_labels == "on":
+            ax.set_xlabel(f"{x_argname}")
+            ax.set_ylabel(f"{fun}({x_argname})")
+        return ax
+
 
 class _Indexer:
     """Indexer for BaseDistribution, for pandas-like index in loc and iloc property."""
@@ -1314,6 +1455,13 @@ def is_scalar_notnone(obj):
     return obj is not None and np.isscalar(obj)
 
 
+def _get_first_argname(fun):
+    """Get the name of the first argument of a function, as str."""
+    from inspect import signature
+
+    return list(signature(fun).parameters)[0]
+
+
 def _coerce_to_pd_index_or_none(x):
     """Coerce to pd.Index, if not None, else return None."""
     if x is None:
diff --git a/skpro/distributions/tests/test_proba_basic.py b/skpro/distributions/tests/test_proba_basic.py
index f6cda5462..6df048018 100644
--- a/skpro/distributions/tests/test_proba_basic.py
+++ b/skpro/distributions/tests/test_proba_basic.py
@@ -4,9 +4,12 @@
 
 __author__ = ["fkiraly"]
 
+import numpy as np
 import pandas as pd
 import pytest
 
+from skpro.utils.validation._dependencies import _check_soft_dependencies
+
 
 def test_proba_example():
     """Test one subsetting case for BaseDistribution."""
@@ -97,3 +100,48 @@ def test_proba_index_coercion():
     assert isinstance(n.columns, pd.Index)
     assert n.index.equals(pd.RangeIndex(1))
     assert n.columns.equals(pd.Index([1, 2, 3]))
+
+
+@pytest.mark.skipif(
+    not _check_soft_dependencies("matplotlib", severity="none"),
+    reason="skip if matplotlib is not available",
+)
+@pytest.mark.parametrize("fun", ["pdf", "ppf", "cdf"])
+def test_proba_plotting(fun):
+    """Test that plotting functions do not crash and return ax as expected."""
+    from matplotlib.axes import Axes
+    from matplotlib.figure import Figure
+
+    from skpro.distributions.normal import Normal
+
+    # default case, 2D distribution with n_columns>1
+    n = Normal(mu=[[0, 1], [2, 3], [4, 5]], sigma=1)
+    fig, ax = n.plot(fun=fun)
+    assert isinstance(fig, Figure)
+    assert isinstance(ax, np.ndarray)
+    assert ax.shape == n.shape
+    assert all(isinstance(a, Axes) for a in ax.flatten())
+    assert all(a.get_figure() == fig for a in ax.flatten())
+
+    # single column case, matplotlib returns a 1D array of axes
+    n = Normal(mu=[[1], [2], [3]], sigma=1)
+    fig, ax = n.plot(fun=fun)
+    assert isinstance(fig, Figure)
+    assert isinstance(ax, np.ndarray)
+    assert ax.shape == (n.shape[0],)
+    assert all(isinstance(a, Axes) for a in ax.flatten())
+    assert all(a.get_figure() == fig for a in ax.flatten())
+
+    # single row case, matplotlib also returns a 1D array of axes
+    n = Normal(mu=[[1, 2, 3]], sigma=1)
+    fig, ax = n.plot(fun=fun)
+    assert isinstance(fig, Figure)
+    assert isinstance(ax, np.ndarray)
+    assert ax.shape == (n.shape[1],)
+    assert all(isinstance(a, Axes) for a in ax.flatten())
+    assert all(a.get_figure() == fig for a in ax.flatten())
+
+    # scalar case
+    n = Normal(mu=1, sigma=1)
+    ax = n.plot(fun=fun)
+    assert isinstance(ax, Axes)