Skip to content

Commit

Permalink
[ENH] probability distributions: illustrative and didactic plotting f…
Browse files Browse the repository at this point in the history
…unctionality (#275)

This PR adds a `plot` method to distributions, which allows users to
create plots of various distribution defining functions.

For array distributions, an array of marginal subplots is created, of
the corresponding plots of the marginal distributions.
The array of plots is axis aligned for comparability.
  • Loading branch information
fkiraly authored Apr 25, 2024
1 parent c3ef686 commit a9bdd92
Show file tree
Hide file tree
Showing 2 changed files with 174 additions and 1 deletion.
136 changes: 135 additions & 1 deletion skpro/distributions/base/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@
import pandas as pd

from skpro.base import BaseObject
from skpro.utils.validation._dependencies import _check_estimator_deps
from skpro.utils.validation._dependencies import (
_check_estimator_deps,
_check_soft_dependencies,
)


class BaseDistribution(BaseObject):
Expand Down Expand Up @@ -1113,6 +1116,130 @@ def gen_unif():

raise NotImplementedError(self._method_error_msg("sample", "error"))

def plot(self, fun="pdf", ax=None, **kwargs):
"""Plot the distribution.
Different distribution defining functions can be selected for plotting
via the ``fun`` parameter.
The functions available are the same as the methods of the distribution class,
e.g., ``"pdf"``, ``"cdf"``, ``"ppf"``.
For array distribution, the marginal distribution at each entry is plotted,
as a separate subplot.
Parameters
----------
fun : str, optional, default="pdf"
the function to plot, one of "pdf", "cdf", "ppf"
ax : matplotlib Axes object, optional
matplotlib Axes to plot in
if not provided, defaults to current axes (``plot.gca``)
kwargs : keyword arguments
passed to the plotting function
Returns
-------
fig : matplotlib.Figure, only returned if self is array distribution
matplotlig Figure object for subplots
ax : matplotlib.Axes
the axis or axes on which the plot is drawn
"""
_check_soft_dependencies("matplotlib", obj="distribution plot")

from matplotlib.pyplot import subplots

if self.ndim > 0:
if "x_bounds" not in kwargs:
upper = self.ppf(0.999).values.flatten().max()
lower = self.ppf(0.001).values.flatten().min()
x_bounds = (lower, upper)
else:
x_bounds = kwargs.pop("x_bounds")
if "sharex" not in kwargs:
sharex = True
else:
sharex = kwargs.pop("sharex")
if "sharey" not in kwargs:
sharey = True
else:
sharey = kwargs.pop("sharey")

x_argname = _get_first_argname(getattr(self, fun))

def get_ax(ax, i, j, shape):
"""Get axes at iloc i, j - API unifier for 2D and 1D subplot figures.
Covers inconsistency in matplotlib where creation of (m, 1) matrix
of subplots creates a 1D object and not a 2D object.
"""
if shape[1] == 1:
return ax[i]
else:
return ax[i, j]

shape = self.shape
fig, ax = subplots(shape[0], shape[1], sharex=sharex, sharey=sharey)
for i, j in np.ndindex(shape):
d_ij = self.iloc[i, j]
ax_ij = get_ax(ax, i, j, shape)
d_ij.plot(
fun=fun,
ax=ax_ij,
x_bounds=x_bounds,
print_labels="off",
x_argname=x_argname,
**kwargs,
)
for i in range(shape[0]):
ax_i0 = get_ax(ax, i, 0, shape)
ax_i0.set_ylabel(f"{self.index[i]}")
for j in range(shape[1]):
ax_0j = get_ax(ax, 0, j, shape)
ax_0j.set_title(f"{self.columns[j]}")
fig.supylabel(f"{fun}({x_argname})")
fig.supxlabel(f"{x_argname}")
return fig, ax

# for now, all plots default ot this function
# but this could be changed to a dispatch mechanism
# e.g., using this line instead
# plot_fun_name = f"_plot_{fun}"
plot_fun_name = "_plot_single"

ax = getattr(self, plot_fun_name)(ax=ax, fun=fun, **kwargs)
return ax

def _plot_single(self, ax=None, **kwargs):
"""Plot the pdf of the distribution."""
import matplotlib.pyplot as plt

fun = kwargs.pop("fun")
print_labels = kwargs.pop("print_labels", "on")
x_argname = kwargs.pop("x_argname", "x")

# obtain x axis bounds for plotting
if "x_bounds" in kwargs:
lower, upper = kwargs.pop("x_bounds")
elif fun != "ppf":
lower, upper = self.ppf(0.001), self.ppf(0.999)

if fun == "ppf":
lower, upper = 0.001, 0.999

x_arr = np.linspace(lower, upper, 100)
y_arr = [getattr(self, fun)(x) for x in x_arr]
y_arr = np.array(y_arr)

if ax is None:
ax = plt.gca()

ax.plot(x_arr, y_arr, **kwargs)

if print_labels == "on":
ax.set_xlabel(f"{x_argname}")
ax.set_ylabel(f"{fun}({x_argname})")
return ax


class _Indexer:
"""Indexer for BaseDistribution, for pandas-like index in loc and iloc property."""
Expand Down Expand Up @@ -1314,6 +1441,13 @@ def is_scalar_notnone(obj):
return obj is not None and np.isscalar(obj)


def _get_first_argname(fun):
"""Get the name of the first argument of a function as str."""
from inspect import signature

return list(signature(fun).parameters.keys())[0]


def _coerce_to_pd_index_or_none(x):
"""Coerce to pd.Index, if not None, else return None."""
if x is None:
Expand Down
39 changes: 39 additions & 0 deletions skpro/distributions/tests/test_proba_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@

__author__ = ["fkiraly"]

import numpy as np
import pandas as pd
import pytest

from skpro.utils.validation._dependencies import _check_soft_dependencies


def test_proba_example():
"""Test one subsetting case for BaseDistribution."""
Expand Down Expand Up @@ -97,3 +100,39 @@ def test_proba_index_coercion():
assert isinstance(n.columns, pd.Index)
assert n.index.equals(pd.RangeIndex(1))
assert n.columns.equals(pd.Index([1, 2, 3]))


@pytest.mark.skipif(
not _check_soft_dependencies("matplotlib", severity="none"),
reason="skip if matplotlib is not available",
)
@pytest.mark.parametrize("fun", ["pdf", "ppf", "cdf"])
def test_proba_plotting(fun):
"""Test that plotting functions do not crash and return ax as expected."""
from matplotlib.axes import Axes
from matplotlib.figure import Figure

from skpro.distributions.normal import Normal

# default case, 2D distribution with n_columns>1
n = Normal(mu=[[0, 1], [2, 3], [4, 5]], sigma=1)
fig, ax = n.plot(fun=fun)
assert isinstance(fig, Figure)
assert isinstance(ax, np.ndarray)
assert ax.shape == n.shape
assert all([isinstance(a, Axes) for a in ax.flatten()])
assert all([a.get_figure() == fig for a in ax.flatten()])

# 1D case requires special treatment of axes
n = Normal(mu=[[1], [2], [3]], sigma=1)
fig, ax = n.plot(fun=fun)
assert isinstance(fig, Figure)
assert isinstance(ax, type(ax))
assert ax.shape == (n.shape[0],)
assert all([isinstance(a, Axes) for a in ax.flatten()])
assert all([a.get_figure() == fig for a in ax.flatten()])

# scalar case
n = Normal(mu=1, sigma=1)
ax = n.plot(fun=fun)
assert isinstance(ax, Axes)

0 comments on commit a9bdd92

Please sign in to comment.