Skip to content

Commit

Permalink
Merge pull request #19 from arviz-devs/distplot
Browse files Browse the repository at this point in the history
rename plot_posterior to plot_dist and small improvements
  • Loading branch information
aloctavodia authored Oct 30, 2023
2 parents 2e9f219 + d8f0c2e commit e22af52
Show file tree
Hide file tree
Showing 8 changed files with 231 additions and 219 deletions.
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ repos:
rev: v4.3.0
hooks:
- id: check-added-large-files
args: ['--maxkb=1500']
- id: check-merge-conflict

- repo: https://github.com/PyCQA/isort
Expand Down
2 changes: 1 addition & 1 deletion docs/source/api/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
.. autosummary::
:toctree: generated/
arviz_plots.plot_posterior
arviz_plots.plot_dist
arviz_plots.plot_trace
```

Expand Down
2 changes: 1 addition & 1 deletion docs/source/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ Plotting backends have to be added manually.
:caption: Example notebooks
tutorials/intro_to_plotmuseum
tutorials/plot_posterior_examples
tutorials/plot_dist_examples
```

```{toctree}
Expand Down
187 changes: 187 additions & 0 deletions docs/source/tutorials/plot_dist_examples.ipynb

Large diffs are not rendered by default.

187 changes: 0 additions & 187 deletions docs/source/tutorials/plot_posterior_examples.ipynb

This file was deleted.

4 changes: 2 additions & 2 deletions src/arviz_plots/plots/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Batteries-included ArviZ plots."""

from .posteriorplot import plot_posterior
from .posteriorplot import plot_dist
from .traceplot import plot_trace

__all__ = ["plot_posterior", "plot_trace"]
__all__ = ["plot_dist", "plot_trace"]
45 changes: 34 additions & 11 deletions src/arviz_plots/plots/posteriorplot.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
"""Posterior plot code."""
"""dist plot code."""
import arviz_stats # pylint: disable=unused-import
import xarray as xr
from arviz_base import rcParams
from arviz_base.labels import BaseLabeller
from arviz_base.utils import _var_names

from arviz_plots.plot_collection import PlotCollection
from arviz_plots.plots.utils import filter_aes
Expand All @@ -16,8 +17,11 @@
)


def plot_posterior(
ds,
def plot_dist(
dt,
var_names=None,
filter_vars=None,
group="posterior",
sample_dims=None,
kind=None,
point_estimate=None,
Expand All @@ -35,8 +39,17 @@ def plot_posterior(
Parameters
----------
ds : Dataset
dt : DataTree
Input data
var_names: str or list of str, optional
One or more variables to be plotted.
Prefix the variables by ~ when you want to exclude them from the plot.
filter_vars: {None, “like”, “regex”}, optional, default=None
If None (default), interpret var_names as the real variables names.
If “like”, interpret var_names as substrings of the real variables names.
If “regex”, interpret var_names as regular expressions on the real variables names.
group : str, optional
Group to be plotted. Defaults to ``posterior``
sample_dims : iterable, optional
Dimensions to reduce unless mapped to an aesthetic.
Defaults to ``rcParams["data.sample_dims"]``
Expand Down Expand Up @@ -96,15 +109,21 @@ def plot_posterior(
if stats_kwargs is None:
stats_kwargs = {}

distribution = dt[group].ds
var_names = _var_names(var_names, distribution, filter_vars)

if var_names is not None:
distribution = dt[group].ds[var_names]

if plot_collection is None:
if backend is None:
backend = rcParams["plot.backend"]
pc_kwargs.setdefault("col_wrap", 5)
pc_kwargs.setdefault(
"cols", ["__variable__"] + [dim for dim in ds.dims if dim not in sample_dims]
"cols", ["__variable__"] + [dim for dim in distribution.dims if dim not in sample_dims]
)
plot_collection = PlotCollection.wrap(
ds,
distribution,
backend=backend,
**pc_kwargs,
)
Expand All @@ -117,7 +136,7 @@ def plot_posterior(
# density
density_dims, _, density_ignore = filter_aes(plot_collection, aes_map, "kde", sample_dims)
if kind == "kde":
density = ds.azstats.kde(dims=density_dims, **stats_kwargs.get("density", {}))
density = distribution.azstats.kde(dims=density_dims, **stats_kwargs.get("density", {}))
plot_collection.map(
line_xy, "kde", data=density, ignore_aes=density_ignore, **plot_kwargs.get("kde", {})
)
Expand All @@ -129,9 +148,13 @@ def plot_posterior(
plot_collection, aes_map, "credible_interval", sample_dims
)
if ci_kind == "eti":
ci = ds.azstats.eti(prob=ci_prob, dims=ci_dims, **stats_kwargs.get("credible_interval", {}))
ci = distribution.azstats.eti(
prob=ci_prob, dims=ci_dims, **stats_kwargs.get("credible_interval", {})
)
elif ci_kind == "hdi":
ci = ds.azstats.hdi(prob=ci_prob, dims=ci_dims, **stats_kwargs.get("credible_interval", {}))
ci = distribution.azstats.hdi(
prob=ci_prob, dims=ci_dims, **stats_kwargs.get("credible_interval", {})
)
else:
raise NotImplementedError("coming soon")
ci_kwargs = plot_kwargs.get("credible_interval", {}).copy()
Expand All @@ -142,9 +165,9 @@ def plot_posterior(
# point estimate
pe_dims, pe_aes, pe_ignore = filter_aes(plot_collection, aes_map, "point_estimate", sample_dims)
if point_estimate == "median":
point = ds.median(dim=pe_dims, **stats_kwargs.get("point_estimate", {}))
point = distribution.median(dim=pe_dims, **stats_kwargs.get("point_estimate", {}))
elif point_estimate == "mean":
point = ds.mean(dim=pe_dims, **stats_kwargs.get("point_estimate", {}))
point = distribution.mean(dim=pe_dims, **stats_kwargs.get("point_estimate", {}))
else:
raise NotImplementedError("coming soon")
point_density_diff = [dim for dim in density.sel(plot_axis="y").dims if dim not in point.dims]
Expand Down
22 changes: 5 additions & 17 deletions tests/test_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,9 @@
"""Test batteries-included plots."""
import numpy as np
import pytest
import xarray as xr
from arviz_base import convert_to_datatree
from arviz_base import from_dict

from arviz_plots import plot_posterior, plot_trace


@pytest.fixture(scope="module")
def data(seed=31):
rng = np.random.default_rng(seed)
mu = rng.normal(size=(4, 100))
theta = rng.normal(size=(4, 100, 7))

return xr.Dataset(
{"mu": (["chain", "draw"], mu), "theta": (["chain", "draw", "hierarchy"], theta)},
)
from arviz_plots import plot_dist, plot_trace


@pytest.fixture(scope="module")
Expand All @@ -25,13 +13,13 @@ def datatree(seed=31):
mu = rng.normal(size=(4, 100))
theta = rng.normal(size=(4, 100, 7))

return convert_to_datatree({"mu": mu, "theta": theta})
return from_dict({"posterior": {"mu": mu, "theta": theta}}, dims={"theta": ["hierarchy"]})


@pytest.mark.parametrize("backend", ["matplotlib", "bokeh"])
class TestPlots:
def test_plot_posterior(self, data, backend):
pc = plot_posterior(data, backend=backend)
def test_plot_dist(self, datatree, backend):
pc = plot_dist(datatree, backend=backend)
assert not pc.aes["mu"]
assert "kde" in pc.viz["mu"]
assert "hierarchy" not in pc.viz["mu"].dims
Expand Down

0 comments on commit e22af52

Please sign in to comment.