Skip to content

Commit

Permalink
add test and example
Browse files Browse the repository at this point in the history
  • Loading branch information
aloctavodia committed Oct 23, 2024
1 parent 4881140 commit bde60ad
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 12 deletions.
23 changes: 15 additions & 8 deletions src/arviz_plots/plots/psensedistplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,10 +95,17 @@ def plot_psense_dist(
Examples
--------
TBD: exaples specific to usage of arguments that are unique to plot_psense_dist or behave
differently than in the rest of the plots. i.e. we might want to have var_names,
prior_var_names and likelihood_var_names or var_names as a dict
with posterior, prior and likelihood keys allowed.
Select a single variable and generate a point-interval plot
.. plot::
:context: close-figs
>>> from arviz_plots import plot_dist, style
>>> style.use("arviz-clean")
>>> from arviz_base import load_arviz_data
>>> rugby = load_arviz_data('rugby')
>>> plot_psense_dist(rugby, var_names=["sd_att"], plot_kwargs={"kde":False})
.. minigallery:: plot_psense_dist
Expand Down Expand Up @@ -131,8 +138,8 @@ def plot_psense_dist(
# Instead we could have weighted KDEs/ecdfs/etc
ds_prior = new_ds(dt, "prior", alphas, sample_dims=sample_dims)
ds_likelihood = new_ds(dt, "likelihood", alphas, sample_dims=sample_dims)
distribution = concat([ds_prior, ds_likelihood], dim="group").assign_coords(
{"group": ["prior", "likelihood"]}
distribution = concat([ds_prior, ds_likelihood], dim="__group__").assign_coords(
{"__group__": ["prior", "likelihood"]}
)
distribution = process_group_variables_coords(
distribution, group=None, var_names=var_names, filter_vars=filter_vars, coords=coords
Expand Down Expand Up @@ -166,11 +173,11 @@ def plot_psense_dist(
pc_kwargs.setdefault("y", [-0.4, -0.225, -0.05])
pc_kwargs["aes"].setdefault("color", ["alpha"])
pc_kwargs["aes"].setdefault("y", ["alpha"])
pc_kwargs.setdefault("cols", ["group"])
pc_kwargs.setdefault("cols", ["__group__"])
pc_kwargs.setdefault(
"rows",
["__variable__"]
+ [dim for dim in distribution.dims if dim not in sample_dims + ["group", "alpha"]],
+ [dim for dim in distribution.dims if dim not in sample_dims + ["__group__", "alpha"]],
)

figsize = pc_kwargs["plot_grid_kws"].get("figsize", None)
Expand Down
65 changes: 63 additions & 2 deletions tests/test_hypothesis_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,16 @@
from arviz_base import from_dict
from datatree import DataTree
from hypothesis import given
from scipy.stats import halfnorm, norm

from arviz_plots import plot_dist, plot_ess, plot_ess_evolution, plot_forest, plot_ridge
from arviz_plots import (
plot_dist,
plot_ess,
plot_ess_evolution,
plot_forest,
plot_psense_dist,
plot_ridge,
)

pytestmark = pytest.mark.usefixtures("no_artist_kwargs")

Expand All @@ -17,13 +25,28 @@
def datatree(seed=31):
rng = np.random.default_rng(seed)
mu = rng.normal(size=(3, 50))
tau = rng.normal(size=(3, 50, 2))
tau = np.exp(rng.normal(size=(3, 50, 2)))
theta = rng.normal(size=(3, 50, 2, 3))
mu_prior = norm(0, 3).logpdf(mu)
tau_prior = halfnorm(scale=5).logpdf(tau)
theta_prior = norm(0, 1).logpdf(theta)
theta_orig = rng.uniform(size=3)
idxs0 = rng.choice(np.arange(2), size=29)
idxs1 = rng.choice(np.arange(3), size=29)
x = np.linspace(0, 1, 29)
obs = rng.normal(loc=x + theta_orig[idxs1], scale=3)
log_lik = norm(
mu[:, :, None] * x[None, None, :] + theta[:, :, idxs0, idxs1], tau[:, :, idxs0]
).logpdf(obs[None, None, :])
log_lik = log_lik / log_lik.var()
diverging = rng.choice([True, False], size=(3, 50), p=[0.1, 0.9])

dt = from_dict(
{
"posterior": {"mu": mu, "theta": theta, "tau": tau},
"log_prior": {"mu": mu_prior, "theta": theta_prior, "tau": tau_prior},
"log_likelihood": {"y": log_lik},
"observed_data": {"y": obs},
"sample_stats": {"diverging": diverging},
},
dims={"theta": ["hierarchy", "group"], "tau": ["hierarchy"]},
Expand Down Expand Up @@ -294,3 +317,41 @@ def test_plot_ess_evolution(datatree, relative, n_points, extra_methods, min_ess
assert all(key in child for child in pc.viz.children.values())
elif key != "remove_axis":
assert all(key in child for child in pc.viz.children.values())


@given(
plot_kwargs=st.fixed_dictionaries(
{},
optional={
"kind": plot_kwargs_value,
"credible_interval": plot_kwargs_value,
"point_estimate": plot_kwargs_value,
"point_estimate_text": plot_kwargs_value,
"title": plot_kwargs_value,
"remove_axis": st.just(False),
},
),
alphas=st.sampled_from(((0.9, 1.1), None)),
kind=kind_value,
point_estimate=point_estimate_value,
ci_kind=ci_kind_value,
)
def test_plot_psense(datatree, alphas, kind, point_estimate, ci_kind, plot_kwargs):
kind_kwargs = plot_kwargs.pop("kind", None)
if kind_kwargs is not None:
plot_kwargs[kind] = kind_kwargs
pc = plot_psense_dist(
datatree,
alphas=alphas,
backend="none",
kind=kind,
ci_kind=ci_kind,
point_estimate=point_estimate,
plot_kwargs=plot_kwargs,
)
assert all("plot" in child for child in pc.viz.children.values())
for key, value in plot_kwargs.items():
if value is False:
assert all(key not in child for child in pc.viz.children.values())
elif key != "remove_axis":
assert all(key in child for child in pc.viz.children.values())
2 changes: 0 additions & 2 deletions tests/test_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,8 +385,6 @@ def test_plot_ess_evolution_sample(self, datatree_sample, backend):
assert "hierarchy" in pc.viz["theta"].dims

def test_plot_psense_dist(self, datatree, backend):
print(datatree["log_prior"])
print(datatree["log_likelihood"])
pc = plot_psense_dist(datatree, backend=backend)
assert "chart" in pc.viz.data_vars
assert "plot" not in pc.viz.data_vars
Expand Down

0 comments on commit bde60ad

Please sign in to comment.