Skip to content

Commit

Permalink
Add plot_interactive to MvNormal (#337)
Browse files Browse the repository at this point in the history
  • Loading branch information
rohanbabbar04 authored Mar 4, 2024
1 parent 6afa0e5 commit 6441958
Show file tree
Hide file tree
Showing 3 changed files with 158 additions and 5 deletions.
103 changes: 103 additions & 0 deletions preliz/distributions/continuous_multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,3 +527,106 @@ def plot_ppf(
return plot_mvnormal(
self, "ppf", "marginals", pointinterval, interval, levels, None, figsize, ax
)

def plot_interactive(
self,
kind="pdf",
xy_lim="both",
pointinterval=True,
interval="hdi",
levels=None,
figsize=None,
):
"""
Interactive exploration of parameters
Parameters
----------
kind : str:
Type of plot. Available options are `pdf`, `cdf` and `ppf`.
xy_lim : str or tuple
Set the limits of the x-axis and/or y-axis.
Defaults to `"both"`, the limits of both axes are fixed for all subplots.
Use `"auto"` for automatic rescaling of x-axis and y-axis.
Or set them manually by passing a tuple of 4 elements,
the first two for x-axis, the last two for y-axis. The tuple can have `None`.
pointinterval : bool
Whether to include a plot of the quantiles. Defaults to False.
If `True` the default is to plot the median and two inter-quantiles ranges.
interval : str
Type of interval. Available options are the highest density interval `"hdi"` (default),
equal tailed interval `"eti"` or intervals defined by arbitrary `"quantiles"`.
levels : list
Mass of the intervals. For hdi or eti the number of elements should be 2 or 1.
For quantiles the number of elements should be 5, 3, 1 or 0
(in this last case nothing will be plotted).
figsize : tuple
Size of the figure
"""

check_inside_notebook()

args = dict(zip(self.param_names, self.params))
cov, tau = args.get("cov", None), args.get("tau", None)
self.__init__(**args) # pylint: disable=unnecessary-dunder-call
if kind == "pdf":
w_checkbox_marginals = widgets.Checkbox(
value=True,
description="marginals",
disabled=False,
indent=False,
)
plot_widgets = {"marginals": w_checkbox_marginals}
else:
plot_widgets = {}
for index, mu in enumerate(self.params[0]):
plot_widgets[f"mu-{index + 1}"] = get_slider(
f"mu-{index + 1}", mu, *self.params_support[0]
)

def plot(**args):
if kind == "pdf":
marginals = args.pop("marginals")
params = {"mu": np.asarray(list(args.values()), dtype=float), "cov": cov, "tau": tau}
self.__init__(**params) # pylint: disable=unnecessary-dunder-call
if kind == "pdf":
plot_mvnormal(
self,
"pdf",
marginals,
pointinterval,
interval,
levels,
"full",
figsize,
None,
xy_lim,
)
elif kind == "cdf":
plot_mvnormal(
self,
"cdf",
"marginals",
pointinterval,
interval,
levels,
"full",
figsize,
None,
xy_lim,
)
elif kind == "ppf":
plot_mvnormal(
self,
"cdf",
"marginals",
pointinterval,
interval,
levels,
None,
figsize,
None,
xy_lim,
)

return interactive(plot, **plot_widgets)
36 changes: 31 additions & 5 deletions preliz/internal/plot_helper_multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,16 @@ def find_hdi_contours(density, hdi_probs):


def plot_mvnormal(
dist, representation, marginals, pointinterval, interval, levels, support, figsize, axes
dist,
representation,
marginals,
pointinterval,
interval,
levels,
support,
figsize,
axes,
xy_lim="auto",
):
"""Plot pdf, cdf or ppf of Multivariate Normal distribution."""

Expand All @@ -233,6 +242,10 @@ def plot_mvnormal(
if figsize is None:
figsize = (12, 4)

if isinstance(xy_lim, tuple):
xlim = xy_lim[:2]
ylim = xy_lim[2:]

if marginals:
cols, rows = get_cols_rows(dim)

Expand All @@ -244,8 +257,18 @@ def plot_mvnormal(
ax.remove()

for mu_i, sigma_i, ax in zip(mu, sigma, axes):
marginal_dist = dist.marginal(mu_i, sigma_i)
if xy_lim == "both":
xlim = marginal_dist._finite_endpoints("full")
xvals = marginal_dist.xvals("restricted")
if representation == "pdf":
max_pdf = np.max(marginal_dist.pdf(xvals))
ylim = (-max_pdf * 0.075, max_pdf * 1.5)
elif representation == "ppf":
max_ppf = marginal_dist.ppf(0.999)
ylim = (-max_ppf * 0.075, max_ppf * 1.5)
if representation == "pdf":
dist.marginal(mu_i, sigma_i).plot_pdf(
marginal_dist.plot_pdf(
pointinterval=pointinterval,
interval=interval,
levels=levels,
Expand All @@ -254,7 +277,7 @@ def plot_mvnormal(
ax=ax,
)
elif representation == "cdf":
dist.marginal(mu_i, sigma_i).plot_cdf(
marginal_dist.plot_cdf(
pointinterval=pointinterval,
interval=interval,
levels=levels,
Expand All @@ -263,14 +286,17 @@ def plot_mvnormal(
ax=ax,
)
elif representation == "ppf":
dist.marginal(mu_i, sigma_i).plot_ppf(
marginal_dist.plot_ppf(
pointinterval=pointinterval,
interval=interval,
levels=levels,
legend=False,
ax=ax,
)

if xy_lim != "auto" and representation != "ppf":
ax.set_xlim(*xlim)
if xy_lim != "auto" and representation != "cdf":
ax.set_ylim(*ylim)
fig.text(0.5, 1, repr_to_matplotlib(dist), ha="center", va="center")

else:
Expand Down
24 changes: 24 additions & 0 deletions preliz/tests/test_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,30 @@ def test_mvnormal_plot(kwargs):
a_dist.plot_ppf(**kwargs)


@pytest.mark.parametrize(
"kwargs",
[
{},
{"xy_lim": "auto"},
{"pointinterval": True, "xy_lim": "auto"},
{"pointinterval": True, "levels": [0.1, 0.9], "xy_lim": "both"},
{"pointinterval": True, "interval": "eti", "levels": [0.9], "xy_lim": (0.3, 0.9, None, 1)},
{"pointinterval": True, "interval": "quantiles", "xy_lim": "both"},
{"pointinterval": True, "interval": "quantiles", "levels": [0.1, 0.5, 0.9]},
{"pointinterval": False, "figsize": (4, 4)},
],
)
def test_plot_interactive_mvnormal(kwargs):
mvnormal_tau = pz.MvNormal(mu=[-1, 2.4], tau=[[1, 0], [1, 1]])
mvnormal_cov = pz.MvNormal(mu=[3, -2], cov=[[1, 0], [0, 1]])
mvnormal_tau.plot_interactive(kind="pdf", **kwargs)
mvnormal_cov.plot_interactive(kind="pdf", **kwargs)
mvnormal_tau.plot_interactive(kind="cdf", **kwargs)
mvnormal_cov.plot_interactive(kind="cdf", **kwargs)
mvnormal_tau.plot_interactive(kind="ppf", **kwargs)
mvnormal_cov.plot_interactive(kind="ppf", **kwargs)


@pytest.fixture
def sample_ax():
return plt.subplot()
Expand Down

0 comments on commit 6441958

Please sign in to comment.