Skip to content

Commit

Permalink
Merge pull request #23 from arviz-devs/ecdf
Browse files Browse the repository at this point in the history
add ecdf to distplot
  • Loading branch information
aloctavodia authored Dec 14, 2023
2 parents 1fd1267 + 2f81a7b commit e1a36c5
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 7 deletions.
37 changes: 30 additions & 7 deletions src/arviz_plots/plots/distplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from arviz_plots.plot_collection import PlotCollection
from arviz_plots.plots.utils import filter_aes
from arviz_plots.visuals import (
ecdf_line,
labelled_title,
line_x,
line_xy,
Expand Down Expand Up @@ -89,6 +90,9 @@ def plot_dist(
-------
PlotCollection
"""
if ci_kind not in ["hdi", "eti", None]:
raise ValueError("ci_kind must be either 'hdi' or 'eti'")

if sample_dims is None:
sample_dims = rcParams["data.sample_dims"]
if ci_prob is None:
Expand Down Expand Up @@ -129,17 +133,30 @@ def plot_dist(
)

if aes_map is None:
aes_map = {"kde": plot_collection.aes_set}
aes_map = {kind: plot_collection.aes_set}
if labeller is None:
labeller = BaseLabeller()

# density
density_dims, _, density_ignore = filter_aes(plot_collection, aes_map, "kde", sample_dims)
density_dims, _, density_ignore = filter_aes(plot_collection, aes_map, kind, sample_dims)

if kind == "kde":
density = distribution.azstats.kde(dims=density_dims, **stats_kwargs.get("density", {}))
plot_collection.map(
line_xy, "kde", data=density, ignore_aes=density_ignore, **plot_kwargs.get("kde", {})
)

elif kind == "ecdf":
density = distribution.azstats.ecdf(dims=density_dims, **stats_kwargs.get("density", {}))
print(density)
plot_collection.map(
ecdf_line,
"ecdf",
data=density,
ignore_aes=density_ignore,
**plot_kwargs.get("ecdf", {}),
)

else:
raise NotImplementedError("coming soon")

Expand All @@ -155,8 +172,7 @@ def plot_dist(
ci = distribution.azstats.hdi(
prob=ci_prob, dims=ci_dims, **stats_kwargs.get("credible_interval", {})
)
else:
raise NotImplementedError("coming soon")

ci_kwargs = plot_kwargs.get("credible_interval", {}).copy()
if "color" not in ci_aes:
ci_kwargs.setdefault("color", "gray")
Expand All @@ -170,13 +186,20 @@ def plot_dist(
point = distribution.mean(dim=pe_dims, **stats_kwargs.get("point_estimate", {}))
else:
raise NotImplementedError("coming soon")

point_density_diff = [dim for dim in density.sel(plot_axis="y").dims if dim not in point.dims]
point_y = 0.03 * density.sel(plot_axis="y", drop=True).max(dim=["kde_dim"] + point_density_diff)
if kind == "kde":
point_y = 0.03 * density.sel(plot_axis="y", drop=True).max(
dim=["kde_dim"] + point_density_diff
)
elif kind == "ecdf":
point_y = 0.03 * density.sel(plot_axis="y", drop=True).max(dim=point_density_diff)

point = xr.concat((point, point_y), dim="plot_axis").assign_coords(plot_axis=["x", "y"])

pe_kwargs = plot_kwargs.get("point_estimate", {}).copy()
if "color" not in pe_aes:
pe_kwargs.setdefault("color", "darkcyan")
pe_kwargs.setdefault("color", "gray")
plot_collection.map(
scatter_x,
"point_estimate",
Expand All @@ -186,7 +209,7 @@ def plot_dist(
)
pet_kwargs = plot_kwargs.get("point_estimate_text", {}).copy()
if "color" not in pe_aes:
pet_kwargs.setdefault("color", "darkcyan")
pet_kwargs.setdefault("color", "gray")
pet_kwargs.setdefault("horizontal_align", "center")
pet_kwargs.setdefault("point_label", "x")
plot_collection.map(
Expand Down
6 changes: 6 additions & 0 deletions src/arviz_plots/visuals/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,12 @@ def scatter_x(da, target, backend, y=None, **kwargs):
return plot_backend.scatter(da, y, target, **kwargs)


def ecdf_line(values, target, backend, **kwargs):
"""Plot an ecdf line."""
plot_backend = import_module(f"arviz_plots.backend.{backend}")
return plot_backend.line(values.sel(plot_axis="x"), values.sel(plot_axis="y"), target, **kwargs)


def point_estimate_text(
da, target, backend, *, point_estimate, x=None, y=None, point_label="x", **kwargs
):
Expand Down

0 comments on commit e1a36c5

Please sign in to comment.