[WIP] Adding PPC plot to Arviz-Plots #55

Open
wants to merge 17 commits into
base: main

imperorrp
Collaborator

@imperorrp imperorrp commented Jun 21, 2024

Adding plot_ppc (issue #11).

This PPC plot draft currently plots the observed values and the prior/posterior predictive values, flattening across all dimensions except the sample dims (chain, draw) by default.

WIP:

  • The prior/posterior predictive mean
  • Cumulative and scatter visualization

📚 Documentation preview 📚: https://arviz-plots--55.org.readthedocs.build/en/55/

@imperorrp
Collaborator Author

Plots generated by passing in the centered_eight datatree as data:

It seems like multiple colors are cycled through for the kde curves of the ppc samples, even though no aesthetic mapping was defined for them in kwargs, so the observed values curve isn't standing out. I'm not sure how to fix this:

azp.plot_ppc(data, data_pairs={"y":"y"}, num_pp_samples=50)
image

azp.plot_ppc(data, data_pairs={"y":"y"}, num_pp_samples=500)
image

When only 1 pp sample is selected though, the observed curve becomes clear (the darker one)-
azp.plot_ppc(data, data_pairs={"y":"y"}, num_pp_samples=1)
image

As a comparison for checking the visualization accuracy, this is the plot generated for the same centered_eight data by the legacy Arviz plot_ppc-
image

Member

@OriolAbril OriolAbril left a comment

I have left several comments. Please go over all of them first before trying to address any of them and ask questions about anything that might not be clear.

I think the main issue is a conceptual one about what is the PlotCollection class and how to use it.

src/arviz_plots/plots/ppcplot.py

Parameters
----------
dt : DataTree or dict of {str : DataTree}
Member

It should take only DataTree input, not dictionary of DataTree, otherwise the code below won't work.

raise TypeError("`group` argument must be either `posterior` or `prior`")

for groups in (f"{group}_predictive", "observed_data"):
if not hasattr(dt, groups):
Member

It would be better to check this against .children instead of class attributes.

Comment on lines 173 to 180
if group == "posterior":
predictive_data_group = "posterior_predictive"
if observed is None:
observed = True
elif group == "prior":
predictive_data_group = "prior_predictive"
if observed is None:
observed = False
Member

Suggested change
- if group == "posterior":
-     predictive_data_group = "posterior_predictive"
-     if observed is None:
-         observed = True
- elif group == "prior":
-     predictive_data_group = "prior_predictive"
-     if observed is None:
-         observed = False
+ predictive_data_group = f"{group}_predictive"
+ if observed is None:
+     observed = group == "posterior"

observed = False

if observed:
observed_data_group = "observed_data"
Member

It isn't really necessary to define this, as it will always be observed_data, but it doesn't really hurt. What is important, however, is that everything that has to do with observed data happens inside an if observed: block, like here.

Comment on lines 264 to 265
if random_seed is not None:
np.random.seed(random_seed)
Member

Suggested change
- if random_seed is not None:
-     np.random.seed(random_seed)
+ if random_seed is not None:
+     rng = np.random.default_rng(random_seed)

Then use rng.method, ref https://numpy.org/doc/stable/reference/random/generator.html
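
As a minimal sketch of the suggestion above (the helper name `pick_sample_indices` is hypothetical and not part of arviz-plots), a `Generator` created with `np.random.default_rng` keeps the seed local instead of mutating NumPy's global state. Since `default_rng(None)` seeds itself from fresh entropy, the `if random_seed is not None` guard is arguably unnecessary:

```python
import numpy as np


def pick_sample_indices(total_samples, num_samples, random_seed=None):
    """Pick `num_samples` distinct indices out of `total_samples` (hypothetical helper)."""
    # default_rng(None) draws fresh entropy, so no explicit None check is required
    rng = np.random.default_rng(random_seed)
    return rng.choice(total_samples, size=num_samples, replace=False)


first = pick_sample_indices(2000, 50, random_seed=0)
second = pick_sample_indices(2000, 50, random_seed=0)
```

With the same seed both calls return the same indices, which is what makes seeded plots reproducible.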

print(f"\nplot_collection.viz = {plot_collection.viz}")

# picking random pp dataset sample indexes
total_pp_samples = plot_collection.data.sizes["chain"] * plot_collection.data.sizes["draw"]
Member

This will need to be general to allow any value in sample_dims but I'd worry about this once everything is working with chain and draw.

Comment on lines 299 to 305
observed_distribution = distribution.sel(model="observed_data")
print(
f"\nobserved distri = {observed_distribution.obs!r}"
) # the obs data was auto broadcasted all over the chain and draw dims
# but we can ignore the other dims and just subset it to 1 chain and 1 draw and then use the
# resulting subsetted variable data
observed_distribution = observed_distribution.sel(chain=0, draw=0)
Member

Both of these operations can be skipped if the predictive and observed data aren't concatenated.

Comment on lines 308 to 327
# density calculation for observed variables
density_kwargs = copy(plot_kwargs.get(kind, {}))

if density_kwargs is not False:
density_dims, _, density_ignore = filter_aes(plot_collection, aes_map, "kde", sample_dims)
print(f"\ndensity_dims = {density_dims}\ndensity_ignore= {density_ignore}")

for dim in flatten: # flatten is the list of user defined dims to flatten or all dims
obs_density_dims = {dim}.union(
density_dims
) # dims to be reduced now includes the flatten ones and not just sample dims

obs_density = observed_distribution.azstats.kde(
dims=obs_density_dims, **stats_kwargs.get("density", {})
)
print(f"\nobserved data density = {obs_density}")

plot_collection.map(
line_xy, "kde", data=obs_density, ignore_aes=density_ignore, **density_kwargs
)
Member

All this code is taking care of https://arviz-plots.readthedocs.io/en/latest/contributing/new_plot.html#adding-artists-to-the-plot. The first and main comment is that the observed data is one artist, and the pp lines (all of them) are another artist. As each artist must have a different id, kind can't be used for either. We could use, for example, observed and predictive.

Regarding step 1, only the first line, the one defining density_kwargs, should be outside the if; if it is False, nothing else should happen. Moreover, as this is plotting the observed data, it should be inside an if observed: block.

Given we already have the plot_kwargs={"observed": False} way of not plotting the observed data, I am not sure the observed argument is needed. We could use setdefault on plot_kwargs so it is not shown when group=prior.

Regarding step 2, you have to be careful to use the same id when calling filter_aes, unless there is a good and documented reason not to do so. Currently it is hardcoded to "kde".

Regarding step 3, this is quite a minor comment, but I would also allow different stats_kwargs for each type of data, their ranges and distributions can be quite different, so users might want to increase the number of points of the kde for one but not the other for example.

Step 4 is skipped completely, which is the reason why you get lines following the default matplotlib color cycle as they are plotted, each line having a different color. You should set the color of the line provided it doesn't have an aesthetic mapping for color.

Step 5 should again use the same id as the key in plot_kwargs and filter_aes/aes_map. Even more important than above, you can not have multiple calls to .map using the same name.

pp_densities = []
pp_xs = []

for i in range(predictive_distribution.sizes["ppc_dim"]):
Member

You should not be looping manually, instead, you should define the facetting and aesthetics that generate the plot you want. Here you want multiple lines all with the same properties, thus, you need the overlay aesthetic.

See for example this simplified version of plot_ppc using plot_dist:

pc = PlotCollection.wrap(
    centered.posterior_predictive.ds,
    cols=["__variable__"],
    aes={"overlay": ["chain", "draw"]},
)
plot_dist_artists = ("credible_interval", "point_estimate", "point_estimate_text", "title")
plot_dist(
    centered,
    group="posterior_predictive",
    plot_kwargs={
        **{key: False for key in plot_dist_artists},
        "kde": {"color": "C0", "alpha": .2}
    },
    sample_dims=["school"],
    plot_collection=pc,
)
plot_dist(
    centered,
    group="observed_data",
    plot_kwargs={
        **{key: False for key in plot_dist_artists},
        "kde": {"color": "black"}
    },
    sample_dims=["school"],
    plot_collection=pc,
    # we could add aes_map={"kde": []} so the overlay mapping is not used here,
    # but as it is defined over dimensions that don't exist on this variable, PlotCollection
    # already skips it as it is not applicable to this case
);

image

It only works with centered_eight, plots all pp samples... but I hope it can help you understand how to use PlotCollection.

Collaborator Author

So this is sort of the approach that should be taken when calling map in plot_ppc as well, right? That is, sample_dims should be all of the non chain and draw dimensions (by default, though a user could select fewer), since these are the dims being flattened and used to generate KDEs, while the actual pp sample dims chain and draw are set as "overlay" in the aesthetic mapping?

Member

@OriolAbril OriolAbril Jun 21, 2024

In the example, sample_dims has nothing to do with .map. sample_dims="school" is needed because sample_dims are the dimensions reduced by plot_dist when computing the kde, and that happens in the kde computation step, before calling .map.

In plot_ppc, we split the dimensions into 3 non-overlapping groups:

  • overlay dimensions -> here sample_dims
  • facet dimensions -> now modifiable also via pc_kwargs, dimensions for which multiple plots will be generated
  • rest/leftover/reduce -> what used to be the flatten argument (now defined by elimination); these are the ones that should be reduced when computing the kde

In this particular example, they are respectively [chain, draw], [] and [school]
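
The three-way split described above can be sketched in plain Python. This is only an illustration of the bookkeeping: `split_dims` is a hypothetical helper, not a function in arviz-plots, and it works on dimension names alone.

```python
def split_dims(all_dims, sample_dims, facet_dims):
    """Split dimension names into overlay, facet and reduce groups (illustrative only)."""
    overlay = [dim for dim in all_dims if dim in sample_dims]
    facet = [dim for dim in all_dims if dim in facet_dims and dim not in sample_dims]
    used = set(overlay) | set(facet)
    # whatever is neither overlaid nor facetted is reduced when computing the kde
    reduce = [dim for dim in all_dims if dim not in used]
    return overlay, facet, reduce


overlay, facet, reduce = split_dims(
    all_dims=["chain", "draw", "school"],
    sample_dims=["chain", "draw"],
    facet_dims=[],
)
```

Passing `facet_dims=["school"]` instead would move `school` into the facet group and leave nothing to reduce.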

Collaborator Author

Got it, this way of splitting makes sense. I've modified the code in my latest commit to take into account these three kinds of dimensions (for overlay/facetting/reducing)

@imperorrp
Collaborator Author

imperorrp commented Jun 24, 2024

The latest commit takes into account your suggestions and comments @OriolAbril. The mean plotting might need some more work, as the current implementation uses all of the data to generate the mean. If the plot is facetted, the mean curves would need to be generated on the relevant subselections of the data, I suppose.

Here is the current output when plot_ppc is called with the centered-eight data:

azp.plot_ppc(data, num_pp_samples=500)

image

@codecov-commenter

codecov-commenter commented Jun 25, 2024

Codecov Report

Attention: Patch coverage is 88.88889% with 20 lines in your changes missing coverage. Please review.

Project coverage is 85.13%. Comparing base (f4a39af) to head (f358509).

| Files with missing lines | Patch % | Lines |
|---|---|---|
| src/arviz_plots/plots/ppcplot.py | 88.37% | 20 Missing ⚠️ |
Additional details and impacted files
@@            Coverage Diff             @@
##             main      #55      +/-   ##
==========================================
+ Coverage   84.80%   85.13%   +0.33%     
==========================================
  Files          21       22       +1     
  Lines        2336     2509     +173     
==========================================
+ Hits         1981     2136     +155     
- Misses        355      373      +18     


src/arviz_plots/plots/ppcplot.py
Comment on lines 154 to 156
for groups in (f"{group}_predictive", "observed_data"):
if groups not in dt.children:
raise TypeError(f'`data` argument must have the groups "{groups}" for ppcplot')
Member

I find it a bit confusing to use groups when it is a string with a single group name

Member

Also, when observed=False we don't need to have the observed_data group.

src/arviz_plots/plots/ppcplot.py (outdated)
Comment on lines 196 to 202
reduce_dims = []
reduce_dims.append(
str(
dim for dim in pp_distribution.dims if dim not in facet_dims.union(sample_dims)
) # set by elimination
)
reduce_dims = ["school"]
Member

Suggested change
- reduce_dims = []
- reduce_dims.append(
-     str(
-         dim for dim in pp_distribution.dims if dim not in facet_dims.union(sample_dims)
-     )  # set by elimination
- )
- reduce_dims = ["school"]
+ reduce_dims = [dim for dim in pp_distribution.dims if dim not in facet_dims.union(sample_dims)]
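
To see why the original lines needed the hardcoded `["school"]` fallback, the sketch below compares both approaches on plain Python values. The dimension names are illustrative, and `facet_dims` is assumed to be a set so that `.union()` is available:

```python
dims = ("chain", "draw", "school")
sample_dims = ["chain", "draw"]
facet_dims = set()  # assumption: a set, so that .union() works as in the diff

# original approach: str() of a generator expression is its repr, not a dim name
wrong = [str(dim for dim in dims if dim not in facet_dims.union(sample_dims))]

# suggested approach: a list comprehension collects the remaining dimension names
reduce_dims = [dim for dim in dims if dim not in facet_dims.union(sample_dims)]
```

Here `wrong` holds a single string starting with `<generator object`, while `reduce_dims` holds the leftover dimension names.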

kind=None,
facet_dims=None,
data_pairs=None,
aggregate=True,
Member

We agreed in #11 to have aggregate=False as the default.

Comment on lines 150 to 169
if facet_dims is None:
facet_dims = []

# check for duplication of facet_dims in top level arg input and pc_kwargs
if "cols" in pc_kwargs and len(facet_dims) > 0:
duplicated_dims = set(facet_dims).intersection(pc_kwargs["cols"])
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
message=f"""Facet dimensions have been defined twice.
Both in the the top level function arguments and in `pc_kwargs`.
The `cols` key in `pc_kwargs` will take precedence.

facet_dims = {facet_dims}
pc_kwargs["cols"] = {pc_kwargs["cols"]}
Duplicated dimensions: {duplicated_dims}.""",
)
# setting facet_dims to pc_kwargs defined values since it is used
# later to calculate dims to reduce
facet_dims = list(set(pc_kwargs["col"]).difference({"__variable__"}))
Member

Personally, I think I'd prefer to error out if both are provided. If the error message is clear, the fix should be equally clear: use one or the other but not both.

We could also issue a warning and continue ahead, but then I think facet_dims should take preference as it is a "higher" level argument, I think kwargs overriding top level arguments is quite unintuitive. In that case I would also keep things as simple as possible, along the lines of "both "cols" in pc_kwargs and facet_dims are provided, "cols" in pc_kwargs will be ignored, to have it taken into account use facet_dims=None"

Side note: Note the code doesn't issue any warning, it activates a filter to catch and ignore warnings whose message matches the message= field in the filterwarnings.
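
The side note can be checked with the standard library alone. In this sketch (function names are illustrative), `only_filters` mirrors the pattern in the diff and emits nothing, whereas `actually_warns` calls `warnings.warn`:

```python
import warnings


def only_filters():
    # mirrors the pattern in the diff: a filter is installed, but nothing is emitted
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", message="Facet dimensions")
    return "done"


def actually_warns():
    # warnings.warn is the call that emits a warning
    warnings.warn("Facet dimensions have been defined twice.", UserWarning)
    return "done"


with warnings.catch_warnings(record=True) as caught_filter:
    warnings.simplefilter("always")
    only_filters()

with warnings.catch_warnings(record=True) as caught_warn:
    warnings.simplefilter("always")
    actually_warns()
```

Only the second recording list ends up with an entry.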

Member

Extra note: reviewing the plot_ridge PR, I realized that a similar thing already happens in plot_forest and we raise errors, so for consistency I think we should raise errors here too. ref https://github.com/arviz-devs/arviz-plots/blob/main/src/arviz_plots/plots/forestplot.py#L242-L249

Collaborator Author

Ah my bad, I just used the error bit from plot_dist as a template without checking it properly. I'll replace with the raise ValueError(...) you pointed out from plot_forest.

Also, I think the first option sounds good- raise an error if both are provided and ask the user to only provide one of the args.
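
A rough sketch of that first option follows. The helper `resolve_facet_dims` and the wording of the error message are assumptions made for illustration; the actual check in `plot_forest` may look different.

```python
def resolve_facet_dims(facet_dims=None, pc_kwargs=None):
    """Return the dims to facet over, refusing ambiguous input (hypothetical helper)."""
    pc_kwargs = {} if pc_kwargs is None else pc_kwargs
    if facet_dims is not None and "cols" in pc_kwargs:
        raise ValueError(
            "Both `facet_dims` and `pc_kwargs['cols']` were provided; use only one of them."
        )
    if facet_dims is not None:
        return list(facet_dims)
    # "__variable__" is the pseudo-dimension for variables, not a dim to facet over
    return [dim for dim in pc_kwargs.get("cols", []) if dim != "__variable__"]
```

Raising before any computation keeps the failure cheap and the fix obvious to the caller.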

Comment on lines 178 to 189
# checking if plot_kwargs["observed"] or plot_kwargs["aggregate"] is not inconsistent with
# top level bool arguments `observed`` and `aggregate`
if "observed" in plot_kwargs:
if plot_kwargs["observed"] != observed:
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
message="""`plot_kwargs['observed']` inconsistency detected.
It is not the same as with the top level `observed` argument.
`plot_kwargs['observed']` will take precedence.""",
)
observed = plot_kwargs["observed"]
Member

Similar comment here. Here I am a bit less sure about error vs warning though.

I might also modify the logic, especially if going the warning route to something like:

if plot_kwargs.get("observed", False) is not False and observed:
    # raise warning or error

Otherwise, people who do not read the documentation very carefully, or who are used to the API of the other plots, might expect using the "observed" key in plot_kwargs to make the line appear, but those will be ignored if observed=False.

Collaborator Author

I'm a bit confused by this logic (Wouldn't this imply if both are true, the error is raised?) Did you mean
if plot_kwargs.get("observed", False) is not False and not observed:? (implying plot_kwargs["observed"] is set to true but observed is false, so an error should be raised as they contradict)

Also what about the condition when observed is true but plot_kwargs["observed"] is false- should we display the error for this condition too due to contradiction, or just let the top-level argument observed take precedence?

Because it might be the case that observed is not set at all by the user but defaults to true (which happens in the code before this check) because of the posterior predictive group being plotted, but the user has set plot_kwargs["observed"] to false to not have it be displayed- in this case should the top level argument still take preference even though it took a default value and wasn't explicitly defined in the function call?

Member

You are right, the code makes no sense, sorry. I started writing with one idea in mind for how to do it, then changed my mind but not the text.

The two options should be considered:

if (
    observed and plot_kwargs.get("observed", True) is False
    or not observed and plot_kwargs.get("observed", False) is not False
):
    # raise warning or error

Member

Whenever they are incompatible we should raise an error, thus, none will ever take precedence over the other because they either mean the same thing or an error is thrown
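
The incompatibility condition above can be wrapped in a small predicate to tabulate all the cases. The function name and the example style dictionary are illustrative only:

```python
def observed_conflict(observed, plot_kwargs):
    """True when top-level `observed` contradicts plot_kwargs["observed"]."""
    return bool(
        (observed and plot_kwargs.get("observed", True) is False)
        or (not observed and plot_kwargs.get("observed", False) is not False)
    )


cases = {
    "on, no key": observed_conflict(True, {}),
    "on, styled": observed_conflict(True, {"observed": {"color": "black"}}),
    "on, key False": observed_conflict(True, {"observed": False}),
    "off, no key": observed_conflict(False, {}),
    "off, key False": observed_conflict(False, {"observed": False}),
    "off, styled": observed_conflict(False, {"observed": {"color": "black"}}),
}
```

Only "on, key False" and "off, styled" are flagged, so in every accepted combination the two inputs mean the same thing and neither needs to take precedence.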

Member

Actually, depending on #66 we could simplify it to "plot_kwargs["observed"] can't be False, use observed top level kwarg for that"

Collaborator Author

I've used the first incompatibility logic for now- could we keep a note of this and make modifications after the proposals in #66 are finalized?

# making sure both posterior/prior predictive group and observed_data group exists in
# datatree provided
if observed:
for group_name in (f"{predictive_data_group}", "observed_data"):
Member

Suggested change
- for group_name in (f"{predictive_data_group}", "observed_data"):
+ for group_name in (predictive_data_group, "observed_data"):

Comment on lines 261 to 262
if backend is None:
backend = rcParams["plot.backend"]
Member

This is duplicated with the check above now

Comment on lines 333 to 337
aes_map={
kind: value
for key, value in aes_map.items()
if key == "predictive" and key not in pp_density_ignore
},
Member

I don't understand what is happening here. The keys of the aes_map dict need to be ids (aka artist names) for the visual elements in plot_dist, the values are sets of aesthetic names. It indicates which aesthetics should be applied to each artist.

No key can be in pp_density_ignore as pp_density_ignore contains aesthetic names. That means only the "predictive" key will be used, which makes sense as it is what we want to plot, and it is "renamed" to kind, but also kind is hardcoded even if inside the dict comprehension, so if the loop were to actually do anything then it would break as it would attempt to generate a dictionary with repeated keys.
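
The points above can be reproduced with plain dictionaries; all values here are made up for illustration. One nuance worth noting: Python does not raise on repeated keys in a dict comprehension, the last value silently wins, so the failure would be a quiet one.

```python
aes_map = {"predictive": {"overlay", "color"}, "observed": {"color"}}
pp_density_ignore = {"overlay"}  # aesthetic names, never artist ids
kind = "kde"

# artist ids and aesthetic names are different things, so this filter never matches
no_op_filter = [key for key in aes_map if key in pp_density_ignore]

# with a constant key, later entries silently overwrite earlier ones
collapsed = {kind: value for key, value in aes_map.items()}

# a direct lookup expresses "rename predictive to kind" without a loop
renamed = {kind: aes_map["predictive"]} if "predictive" in aes_map else {}
```

The direct lookup at the end is one possible simplification, not necessarily what the PR ended up using.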

Collaborator Author

I think I meant to put if key == "predictive" and value not in pp_density_ignore so it checks the aesthetic names against those in pp_density_ignore. About the dict comprehension, yeah, it doesn't seem too useful in retrospect. I'll refactor this to match how it's done in plot_trace_dist when plot_dist is called internally.

**{key: False for key in plot_dist_artists},
kind: dict(pp_kwargs.items()),
},
stats_kwargs={key: value for key, value in stats_kwargs.items() if key == "predictive"},
Member

Somewhat similar comment here. Make sure the code makes sense and that it is consistent with the documentation.

# ---------STEP 2 (PPC AGGREGATE)-------------

aggregate_kwargs = copy(plot_kwargs.get("aggregate", {}))
if aggregate and pp_kwargs is not False and aggregate_kwargs is not False:
Member

This should not depend on pp_kwargs, only on aggregate-related kwargs. Also, this should be consistent with whatever happens with observed and plot_kwargs["observed"]: if one raises a warning when there is duplication, the other should too; if it is an error instead, both should raise an error (ideally at the top of the function, before any computation or plotting happens).

Comment on lines 371 to 372
# setting aggregate aes_map to `[]` so `overlay` isn't applied
aes_map.setdefault("aggregate", [])
Member

This should happen in the aes_map defaults (it doesn't really matter if we set that and then aggregate=False)

)
obs_kwargs = copy(plot_kwargs.get("observed", {}))

if obs_kwargs is not False:
Member

This still has the double level logic with observed and obs_kwargs

@imperorrp
Collaborator Author

imperorrp commented Jul 7, 2024

Made some modifications/fixes. Some changes related to input kwargs (the incompatibility checks between top level args and plot_kwargs keys) might have to be modified further later on, depending on issue #66.

Updated the docstring as well so users are aware that the values they pass into plot_ppc related keys ("predictive", "aggregate", and "observed") in plot_kwargs and aes_map are transferred to plot_dist's 'kind' type artists internally. In the code for this, I've used plot_trace_dist conventions as far as possible but with some changes since plot_dist is only called once in that plot and not for multiple artists.

stats_kwargs is passed to all internally called plot_dists as is- so there is an assumption being made that users probably will want the same kind of statistical computation run for each plot_ppc artist they want to plot for consistency, which I think would make sense. But this can be modified to something like what's done for plot_kwargs and aes_map if wanted though.

@OriolAbril
Member

I'll try to review the code tomorrow; in the meantime, some quick answers to the questions above.

I don't think it matters to users that plot_dist is called internally. They only care about the valid keys and where their values are eventually passed (so they know what values are valid). This will also match how things are documented in plot_trace_dist.

stats_kwargs should allow for different kwargs being passed to the different datasets, see #55 (comment)

@imperorrp
Collaborator Author

imperorrp commented Jul 7, 2024

> I'll try to review the code tomorrow; in the meantime, some quick answers to the questions above.
>
> I don't think it matters to users that plot_dist is called internally. They only care about the valid keys and where their values are eventually passed (so they know what values are valid). This will also match how things are documented in plot_trace_dist.
>
> stats_kwargs should allow for different kwargs being passed to the different datasets, see #55 (comment)

Okay, I'll adjust the docstring to reflect that so it's simpler. I could still mention it in a code comment, though, so that anyone reading the source can see where the values are passed internally.

Also sorry, missed that. I'll adjust the stats_kwargs usage logic to allow different kwargs to be passed for "predictive", "aggregate" and "observed".

@imperorrp
Collaborator Author

Added observed_rug to plot the observed values' distribution. Since no masking is required as divergences aren't being plotted here, the trace_rug visual element was also modified a bit to allow for mask=None.

image

Member

@OriolAbril OriolAbril left a comment

This needs tests so we can check the code does what we want to happen in all the different possible scenarios

coords=None,
sample_dims=None,
kind=None,
facet_dims=None,
data_pairs=None,
Member

data_pairs argument should be kept


# pp distribution group plotting logic
pp_distribution = process_group_variables_coords(
dt, group=predictive_data_group, var_names=var_names, filter_vars=filter_vars, coords=coords
Member

data_pairs should be taken into account here:

Suggested change
- dt, group=predictive_data_group, var_names=var_names, filter_vars=filter_vars, coords=coords
+ dt, group=predictive_data_group, var_names=None if var_names is None else [data_pairs.get(var_name, var_name) for var_name in var_names], filter_vars=filter_vars, coords=coords
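
The mapping in the suggestion can be pulled into a small helper for readability. `predictive_var_names` is a hypothetical name, and the extra guard for `data_pairs=None` is an assumption about how the default might be handled:

```python
def predictive_var_names(var_names, data_pairs=None):
    """Map observed variable names to their predictive counterparts (hypothetical helper)."""
    if var_names is None:
        return None
    data_pairs = {} if data_pairs is None else data_pairs
    # names without an entry in data_pairs are assumed to match in both groups
    return [data_pairs.get(var_name, var_name) for var_name in var_names]


mapped = predictive_var_names(["y", "z"], {"y": "y_hat"})
```

Variables missing from `data_pairs` fall through unchanged, which matches the documented default of identical names in both groups.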

@imperorrp
Collaborator Author

Just added some tests.

I also modified the trace_rug function a bit again- with a 'flatten' keyword to automatically trigger flattening if a multidimensional dataarray is passed to it. When no dimensions of the ppcplot are being facetted, then the observed values subset passed to this function by pc.map are also multidimensional (2D in the case of the 4D datatree fixture), but trace_rug currently does not accept this.

Also restored the data_pairs arg

@imperorrp
Collaborator Author

About the tests: I've parametrized the 'kind' and 'group' arguments both via pytest and via hypothesis.strategies. Should only one be picked, or should both be kept?

The hypothesis tests are also failing a few times. At least one failure seems to come from trying to toggle the 'observed' artist off with plot_kwargs but not in the top level arg, so the ValueError set for this gets raised. I am still unsure why the others are failing.

Dimensions to loop over in plotting posterior/prior predictive.
Note: Dims not in sample_dims or facet_dims (below) will be reduced by default.
Defaults to ``rcParams["data.sample_dims"]``
kind : {"kde", "cumulative", "scatter"}, optional
Member

Suggested change
kind : {"kde", "cumulative", "scatter"}, optional
kind : {"kde", "cumulative", "scatter"}, optional

I see you made some changes, like renaming "cumulative" to "ecdf" internally, but that will be confusing for users. There should be 4 possible values accepted by kind: kde, ecdf, hist and scatter.

Comment on lines 79 to 85
Dictionary containing relations between observed data and posterior/prior predictive data.
Dictionary structure:
* key = observed data var_name
* value = posterior/prior predictive var_name
For example, data_pairs = {'y' : 'y_hat'}.
If None, it will assume that the observed data and the posterior/prior predictive data
have the same variable name
Member

Suggested change
- Dictionary containing relations between observed data and posterior/prior predictive data.
- Dictionary structure:
- * key = observed data var_name
- * value = posterior/prior predictive var_name
- For example, data_pairs = {'y' : 'y_hat'}.
- If None, it will assume that the observed data and the posterior/prior predictive data
- have the same variable name
+ Dictionary containing relations between observed data and posterior/prior predictive data.
+ Dictionary keys are variable names corresponding to observed data and dictionary values
+ are variable names corresponding to posterior/prior predictive. For example, data_pairs = {'y' : 'y_hat'}.
+ By default, it will assume that the observed data and the posterior/prior predictive data
+ have the same variable names

Comment on lines 86 to 88
aggregate: bool, optional
If True, predictive data will be aggregated over both sample_dims and reduce_dims.
Defaults to False.
Member

Suggested change
- aggregate: bool, optional
-     If True, predictive data will be aggregated over both sample_dims and reduce_dims.
-     Defaults to False.
+ aggregate: bool, default False
+     If True, predictive data will be aggregated over both sample_dims and reduce_dims.

When the default is the value that is actually used, it should be documented here directly. If it is a placeholder or depends on other values, use optional and explain the behaviour in the description.
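The two documentation styles described above can be sketched side by side. The function below is hypothetical and exists only to show the docstring layout; the meaning given to `None` is an assumption made for this example:

```python
def example(aggregate=False, num_pp_samples=None):
    """Illustrate both ways of documenting a default (hypothetical function).

    Parameters
    ----------
    aggregate : bool, default False
        The default is the value actually used, so it goes in the type line.
    num_pp_samples : int, optional
        ``None`` is only a placeholder. The effective value depends on other
        inputs, so the behaviour is explained in the description instead
        (in this sketch, ``None`` stands for "use every available sample").
    """
    return aggregate, num_pp_samples
```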

Comment on lines 76 to 77
Defaults to empty list. A warning is raised if `pc_kwargs` is also used to define
dims to facet over with the `cols` key and `pc_kwargs` takes precedence.
Member

Suggested change
Defaults to empty list. A warning is raised if `pc_kwargs` is also used to define
dims to facet over with the `cols` key and `pc_kwargs` takes precedence.
Defaults to empty list.

Comment on lines 289 to 300
pp_sample_ix = rng.choice(total_pp_samples, size=num_pp_samples, replace=False)

# print(f"\npp_sample_ix: {pp_sample_ix!r}")

# stacking sample dims into a new 'ppc_dim' dimension
pp_distribution = pp_distribution.stack(ppc_dim=sample_dims)

# Select the desired samples
pp_distribution = pp_distribution.isel(ppc_dim=pp_sample_ix)

# renaming sample_dims so that rest of plot will consider this as sample_dims
sample_dims = ["ppc_dim"]
Member

I think we should skip all these steps if num_pp_samples == total_pp_samples. I think nothing within the function itself should change (the tests probably will) and it would allow using aesthetics for chain/draw when plotting all samples.

Member

we might also want to distinguish between length 1 sample_dims (in which case it is not necessary to stack nor to rename and overwrite its value) and >1 length sample_dims (in which case stacking is necessary)
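The two comments above amount to a three-way decision. The sketch below is one possible reading of them, not the PR's implementation: `plan_pp_selection` is a hypothetical helper that only reports which steps would be needed, given the length of each sample dimension.

```python
import math


def plan_pp_selection(sample_sizes, num_pp_samples):
    """Decide which preprocessing steps are needed (hypothetical helper).

    `sample_sizes` maps each sample dimension name to its length,
    e.g. {"chain": 4, "draw": 500}.
    Returns (subsample, stack, sample_dims_after).
    """
    sample_dims = list(sample_sizes)
    total_pp_samples = math.prod(sample_sizes.values())
    if num_pp_samples > total_pp_samples:
        raise ValueError("num_pp_samples cannot exceed the available samples")
    # Plotting every sample: leave the data alone so chain/draw aesthetics work.
    if num_pp_samples == total_pp_samples:
        return False, False, sample_dims
    # A single sample dimension can be indexed directly: no stacking, no rename.
    if len(sample_dims) == 1:
        return True, False, sample_dims
    # Several sample dimensions: stack into "ppc_dim" and subselect along it.
    return True, True, ["ppc_dim"]
```

For instance, `plan_pp_selection({"chain": 4, "draw": 500}, 50)` asks for both stacking and subsampling, while requesting all 2000 samples leaves `sample_dims` untouched.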

Comment on lines 63 to 66
if flatten is not False:
    xvalues = xvalues.values.flatten()  # flatten xvalues
if len(xvalues.shape) != 1:
    raise ValueError(f"Expected unidimensional data but got {xvalues.sizes}")
Member

Can you try what happens when xvalues is always flattened and no error is raised?

Comment on lines 183 to 203
# checking if plot_kwargs["observed"] or plot_kwargs["aggregate"] is not inconsistent with
# top level bool arguments `observed`` and `aggregate`

# observed args logic check:
# observed will be True/False depending on user-input/group- true for posterior, false for prior
# if observed and no plot_kwargs['observed'], no prob
# if observed = True and plot_kwargs['observed'] = True, no prob (observed is plotted)
# if observed = True and plot_kwargs['observed'] = False, prob (error/warning raised)
# if observed = False and plot_kwargs['observed'] = True, prob (error/warning raised)
# if observed = False and plot_kwargs['observed'] = False, no prob (observed is not plotted)
if (
    observed
    and plot_kwargs.get("observed", True) is False
    or not observed
    and plot_kwargs.get("observed", False) is not False
):
    # raise warning or error
    raise ValueError(
        f"""
        `observed` and `plot_kwargs["observed"]` inconsistency detected.
        `observed` = {observed}{default_observed}
        `plot_kwargs["observed"]` = {plot_kwargs["observed"]}
        Please make sure `observed` and `plot_kwargs["observed"]` have the same value."""
    )

# same check for aggregate:
if (
    aggregate
    and plot_kwargs.get("aggregate", True) is False
    or not aggregate
    and plot_kwargs.get("aggregate", False) is not False
):
    # raise warning or error
    raise ValueError(
        f"""
        `aggregate` and `plot_kwargs["aggregate"]` inconsistency detected.
        `aggregate` = {aggregate}
        `plot_kwargs["aggregate"]` = {plot_kwargs["aggregate"]}
        Please make sure `aggregate` and `plot_kwargs["aggregate"]` have the same value."""
    )
Member

Seeing the tests and having reviewed plot_forest, I realized this is different behaviour from plot_forest. It might make sense to mimic the behaviour completely, which would be raising an error if plot_kwargs["aggregate"] (or observed) is False, continue on ahead otherwise.

This would mean we could have a dict in that key that gets ignored and the element is not plotted, but that is also what happens in plot_forest and I think is good enough behaviour.
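The simpler behaviour described above could look roughly like the following. This is a sketch under assumptions: `check_plot_kwargs` is a hypothetical name, and the error wording is illustrative rather than copied from plot_forest.

```python
def check_plot_kwargs(plot_kwargs, keys=("observed", "aggregate")):
    """Raise only when an element is explicitly disabled through plot_kwargs.

    Any other value (a missing key, True, or a dict of artist kwargs) is
    accepted; whether the element is drawn is left to the top-level argument.
    """
    for key in keys:
        if plot_kwargs.get(key, True) is False:
            raise ValueError(
                f"plot_kwargs['{key}'] can't be False, "
                f"use the `{key}` argument instead"
            )
```

Under this scheme `check_plot_kwargs({"observed": {"color": "black"}})` passes even when `observed=False` at the top level; the dict is simply ignored and the element is not plotted, which is the trade-off accepted in the comment.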

},
),
kind=ppc_kind_value,
group=ppc_group,
Member

All relevant plot specific arguments should be here, especially if they affect what gets plotted or how the data is processed within the function. Thus, observed and aggregate (toggling observed_rug on too if observed is true), facet_dims and num_pp_samples should be here.

@@ -246,3 +272,52 @@ def test_plot_ridge_aes_labels_shading(self, backend, datatree_4d, pseudo_dim):
if pseudo_dim != "__variable__":
assert all(0 in child["alpha"] for child in pc.aes.children.values())
assert any(pseudo_dim in child["shade"].dims for child in pc.viz.children.values())

@pytest.mark.parametrize("group", ("prior", "posterior"))
Member

No need to parametrize by group. It plays an extremely minor role in the function and it is also being checked in the hypothesis tests.

@pytest.mark.parametrize("group", ("prior", "posterior"))
@pytest.mark.parametrize("kind", ("kde", "cumulative"))
@pytest.mark.parametrize("facet_dims", (["group"], ["hierarchy"], None))
def test_plot_ppc_4d(self, datatree_4d, facet_dims, kind, group, backend):
Member

there should be assertions to check facet_dims is working as expected. For example, making sure that dimension is present in pc.viz["obs"]["plot"]

@imperorrp
Collaborator Author

  • Updated plot_ppc with changes: docstring updates, a check to only flatten and stack if num_pp_samples isn't total_pp_samples or len(sample_dims)>1, plot_forest-like checks for plot_kwargs['observed'] and plot_kwargs['aggregate'] to ensure they aren't False.

  • The trace_rug visual element now flattens by default

  • Parametrizations for group are removed from test_plots, parametrizations for facet_dims are added and assert checks for facet-dims in pc.viz["obs"]["plot"].dims were added. Since the flattening/stacking and updating of sample_dims to ppc_dims was made conditional on some logic now, the
    assert "ppc_dim" in pc.viz["obs"].dims
    statement doesn't work anymore- should we keep it with some logic added?

  • Also test_hypothesis_plots now has observed, observed_rug, aggregate, facet_dims, sample_dims, and a composite strategy for num_pp_samples (dependent on draws from group and sample_dims) added and plot_kwargs["observed"] and plot_kwargs["aggregate"] are always set to True.

Modified the datatree fixture in test_hypothesis_plots too to name the first two dims in each variable so they can be referenced by sample_dims later

There are still 2 failing conditions in the hypothesis tests, though. I'm not sure why yet, but the datatree seems to be interpreted as a function for some reason.

@OriolAbril
Member

Since the flattening/stacking and updating of sample_dims to ppc_dims was made conditional on some logic now, the
assert "ppc_dim" in pc.viz["obs"].dims
statement doesn't work anymore- should we keep it with some logic added?

Some checks on this should still be present

tests/test_hypothesis_plots.py (outdated, resolved)
tests/test_hypothesis_plots.py (outdated, resolved)
src/arviz_plots/plots/ppcplot.py (outdated, resolved)
@imperorrp
Collaborator Author

Resolved the latest reviews with these last commits: the stacking/subselection logic for the predictive values was updated and the tests were modified. The Hypothesis tests still seem to be failing for 2 conditions, though.

@imperorrp
Collaborator Author

Just rebased this PR too

sample_dims=ppc_sample_dims,
drawed_samples=draw_num_pp_samples(ppc_group, ppc_sample_dims),
)
def test_plot_ppc(
Member

Suggested change
def test_plot_ppc(
def test_plot_ppc(
datatree,

@imperorrp
Collaborator Author

Added support for kind="scatter" type plots in plot_ppc.

  • Some work on the 'y' aesthetic that is mapped to all of the predictive rugs needs to be done
  • When observed_rug=True and observed=True for kind='scatter', maybe turning one off would be good because they essentially have the same function
azp.plot_ppc(
        data,
        kind="scatter",
        num_pp_samples=10,
        aggregate=True,
        observed_rug=True,
    )

image

azp.plot_ppc(
        data,
        kind="scatter",
        num_pp_samples=5,
        aggregate=True,
    )

image

)
pc_kwargs["aes"] = pc_kwargs.get("aes", {}).copy()
pc_kwargs["aes"].setdefault("overlay", sample_dims) # setting overlay dim
pc_kwargs.setdefault("y", np.linspace(0.01, 0.1, 11))
Member

don't set default values

data=obs_distribution,
ignore_aes=obs_density_ignore,
xname=False,
y=0,
Member

move to -1 so it doesn't overlap with the pp ones

Comment on lines 244 to 245
observed=ppc_observed,
# observed_rug=ppc_observed,
Member

use st.composite for these

@st.composite
def observed_and_rug(draw):
    observed = draw(st.booleans())
    if not observed:
        return (False, False)
    rug = draw(st.booleans())
    return (observed, rug)

@imperorrp
Collaborator Author

Fixed some more test-related issues (new errors are still persisting, though). The 'alpha' keyword was missing from the 'hist' function of the 'none' backend, unlike in the matplotlib and bokeh backends, which is why it hadn't been detected yet.

Another fixed test failure was due to different artists having different dimensions (for example, the 'predictive' artists are overlaid along the sample dims, the 'observed' artist does not have these dimensions, and the 'aggregate' artist has them reduced instead of overlaid), so the dimensions varied when the 'predictive' key in plot_kwargs was turned off.

@OriolAbril suggested dealing with this by checking the dimensions of the artists in the tests. The current issue is that many of the artists plot_ppc creates come from internal calls to plot_dist, so they are named after their 'kind' and can share a name despite serving different visual purposes.

Current tests bypass this if 'predictive' is False in plot_kwargs

Todo:

  • Renaming of the artists to match plot_ppc plot_kwargs keys
  • Modifying the hypothesis tests to check dimensions of each of these artists
  • Fix new test errors
