[WIP] Adding PPC plot to Arviz-Plots #55
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have left several comments. Please go over all of them first before trying to address any of them and ask questions about anything that might not be clear.
I think the main issue is a conceptual one about what the PlotCollection class is and how to use it.
src/arviz_plots/plots/ppcplot.py
Outdated
|
||
Parameters | ||
---------- | ||
dt : DataTree or dict of {str : DataTree} |
There was a problem hiding this comment.
It should take only DataTree input, not dictionary of DataTree, otherwise the code below won't work.
src/arviz_plots/plots/ppcplot.py
Outdated
raise TypeError("`group` argument must be either `posterior` or `prior`") | ||
|
||
for groups in (f"{group}_predictive", "observed_data"): | ||
if not hasattr(dt, groups): |
There was a problem hiding this comment.
It would be better to check this against .children instead of class attributes.
src/arviz_plots/plots/ppcplot.py
Outdated
if group == "posterior": | ||
predictive_data_group = "posterior_predictive" | ||
if observed is None: | ||
observed = True | ||
elif group == "prior": | ||
predictive_data_group = "prior_predictive" | ||
if observed is None: | ||
observed = False |
There was a problem hiding this comment.
Suggested change:
- if group == "posterior":
-     predictive_data_group = "posterior_predictive"
-     if observed is None:
-         observed = True
- elif group == "prior":
-     predictive_data_group = "prior_predictive"
-     if observed is None:
-         observed = False
+ predictive_data_group = f"{group}_predictive"
+ if observed is None:
+     observed = group == "posterior"
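A minimal sketch of the simplified logic in this suggestion, wrapped in a hypothetical helper so it can be run in isolation (`resolve_predictive_defaults` is an illustrative name, not part of arviz-plots):

```python
def resolve_predictive_defaults(group, observed=None):
    """Return the predictive group name and the resolved `observed` flag.

    Hypothetical helper for illustration only.
    """
    if group not in ("posterior", "prior"):
        raise TypeError("`group` argument must be either `posterior` or `prior`")
    # the predictive group name follows directly from `group`
    predictive_data_group = f"{group}_predictive"
    # observed data is shown by default only for posterior predictive checks
    if observed is None:
        observed = group == "posterior"
    return predictive_data_group, observed
```

An explicit `observed=True` or `observed=False` passed by the user is left untouched; only the `None` placeholder is resolved.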
src/arviz_plots/plots/ppcplot.py
Outdated
observed = False | ||
|
||
if observed: | ||
observed_data_group = "observed_data" |
There was a problem hiding this comment.
It isn't really necessary to define this, as it will always be observed_data, but it doesn't really hurt. What is important, however, is that everything that has to do with observed data happens inside an if observed: block, like here.
src/arviz_plots/plots/ppcplot.py
Outdated
if random_seed is not None: | ||
np.random.seed(random_seed) |
There was a problem hiding this comment.
Suggested change:
- if random_seed is not None:
-     np.random.seed(random_seed)
+ if random_seed is not None:
+     rng = np.random.default_rng(random_seed)
Then use rng.method, ref https://numpy.org/doc/stable/reference/random/generator.html
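As a rough illustration of the Generator-based approach, the sketch below draws sample indexes with `rng.choice`. The seed and sizes are made-up values. Note that `np.random.default_rng(None)` is also valid (it seeds from OS entropy), so the generator can be created unconditionally:

```python
import numpy as np

random_seed = 42  # illustrative value; None would also be accepted
# a local Generator avoids mutating NumPy's global random state
rng = np.random.default_rng(random_seed)

total_pp_samples = 4 * 500  # e.g. 4 chains x 500 draws (made-up sizes)
num_pp_samples = 50
# pick distinct flat sample indexes
pp_sample_ix = rng.choice(total_pp_samples, size=num_pp_samples, replace=False)
```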
src/arviz_plots/plots/ppcplot.py
Outdated
print(f"\nplot_collection.viz = {plot_collection.viz}") | ||
|
||
# picking random pp dataset sample indexes | ||
total_pp_samples = plot_collection.data.sizes["chain"] * plot_collection.data.sizes["draw"] |
There was a problem hiding this comment.
This will need to be general to allow any value in sample_dims, but I'd worry about this once everything is working with chain and draw.
src/arviz_plots/plots/ppcplot.py
Outdated
observed_distribution = distribution.sel(model="observed_data") | ||
print( | ||
f"\nobserved distri = {observed_distribution.obs!r}" | ||
) # the obs data was auto broadcasted all over the chain and draw dims | ||
# but we can ignore the other dims and just subset it to 1 chain and 1 draw and then use the | ||
# resulting subsetted variable data | ||
observed_distribution = observed_distribution.sel(chain=0, draw=0) |
There was a problem hiding this comment.
Both of these operations can be skipped if the predictive and observed data aren't concatenated.
src/arviz_plots/plots/ppcplot.py
Outdated
# density calculation for observed variables | ||
density_kwargs = copy(plot_kwargs.get(kind, {})) | ||
|
||
if density_kwargs is not False: | ||
density_dims, _, density_ignore = filter_aes(plot_collection, aes_map, "kde", sample_dims) | ||
print(f"\ndensity_dims = {density_dims}\ndensity_ignore= {density_ignore}") | ||
|
||
for dim in flatten: # flatten is the list of user defined dims to flatten or all dims | ||
obs_density_dims = {dim}.union( | ||
density_dims | ||
) # dims to be reduced now includes the flatten ones and not just sample dims | ||
|
||
obs_density = observed_distribution.azstats.kde( | ||
dims=obs_density_dims, **stats_kwargs.get("density", {}) | ||
) | ||
print(f"\nobserved data density = {obs_density}") | ||
|
||
plot_collection.map( | ||
line_xy, "kde", data=obs_density, ignore_aes=density_ignore, **density_kwargs | ||
) |
There was a problem hiding this comment.
All this code is taking care of https://arviz-plots.readthedocs.io/en/latest/contributing/new_plot.html#adding-artists-to-the-plot. The first and main comment is that the observed data is one artist, and the pp lines (all of them) are another artist. As each artist must have a different id, kind can't be used for either. We could use, for example, observed and predictive, or something like this.
Regarding step 1, only the first line (the one defining density_kwargs) should be outside the if; if it is False, nothing else should happen. Moreover, as this is plotting the observed data, it should be inside an if observed: block.
Given we already have the plot_kwargs={"observed": False} way of not plotting the observed data, I am not sure the observed argument is needed; we could use setdefault on plot_kwargs so it is not shown when group=prior.
Regarding step 2, you have to be careful to use the same id when calling filter_aes, unless there is a good and documented reason not to do so. Currently it is hardcoded to "kde".
Regarding step 3, this is quite a minor comment, but I would also allow different stats_kwargs for each type of data. Their ranges and distributions can be quite different, so users might want to increase the number of points of the kde for one but not the other, for example.
Step 4 is skipped completely, which is the reason why you get lines following the default matplotlib color cycle as they are plotted, each line having a different color. You should set the color of the line provided it doesn't have an aesthetic mapping for color.
Step 5 should again use the same id as the key in plot_kwargs and filter_aes/aes_map. Even more important than above, you cannot have multiple calls to .map using the same name.
src/arviz_plots/plots/ppcplot.py
Outdated
pp_densities = [] | ||
pp_xs = [] | ||
|
||
for i in range(predictive_distribution.sizes["ppc_dim"]): |
There was a problem hiding this comment.
You should not be looping manually; instead, you should define the faceting and aesthetics that generate the plot you want. Here you want multiple lines all with the same properties, thus you need the overlay aesthetic.
See for example this simplified version of plot_ppc using plot_dist:
pc = PlotCollection.wrap(
centered.posterior_predictive.ds,
cols=["__variable__"],
aes={"overlay": ["chain", "draw"]},
)
plot_dist_artists = ("credible_interval", "point_estimate", "point_estimate_text", "title")
plot_dist(
centered,
group="posterior_predictive",
plot_kwargs={
**{key: False for key in plot_dist_artists},
"kde": {"color": "C0", "alpha": .2}
},
sample_dims=["school"],
plot_collection=pc,
)
plot_dist(
centered,
group="observed_data",
plot_kwargs={
**{key: False for key in plot_dist_artists},
"kde": {"color": "black"}
},
sample_dims=["school"],
plot_collection=pc,
# we could add aes_map={"kde": []} so the overlay mapping is not used here,
# but as it is defined over dimensions that don't exist on this variable, PlotCollection
# already skips it as it is not applicable to this case
);
It only works with centered_eight, plots all pp samples... but I hope it can help you understand how to use PlotCollection.
There was a problem hiding this comment.
So this is sort of the approach that should be taken when calling map in plot_ppc as well, right? As in, sample_dims should be all of the non chain and draw dimensions (by default, though a user could select fewer), since these are the dims being flattened and used to generate KDEs, while the actual pp sample dims chain and draw are set as "overlay" in the aesthetic mapping?
There was a problem hiding this comment.
In the example, sample_dims has nothing to do with .map. The need for sample_dims="school" is because sample_dims are the dimensions being reduced by plot_dist when computing the kde. But that happens in the kde computation step, before calling .map.
In plot_ppc, we split the dimensions into 3 non-overlapping groups:
- overlay dimensions -> here sample_dims
- facet dimensions -> now modifiable also via pc_kwargs, dimensions for which multiple plots will be generated
- rest/leftover/reduce -> what used to be the flatten argument (now defined by elimination); these are the ones that should be reduced when computing the kde
In this particular example, they are respectively [chain, draw], [] and [school].
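The three-way split can be sketched in plain Python. The helper name `split_dims` and the overlap check are assumptions made for illustration; they are not taken from the PR:

```python
def split_dims(all_dims, sample_dims, facet_dims):
    """Split dimensions into overlay, facet and reduce groups.

    Hypothetical helper: the reduce group is defined by elimination,
    i.e. whatever is neither an overlay nor a facet dimension.
    """
    overlay = list(sample_dims)
    facet = list(facet_dims)
    overlap = set(overlay) & set(facet)
    if overlap:
        # the three groups are meant to be non-overlapping
        raise ValueError(f"Dimensions used for both overlay and facet: {sorted(overlap)}")
    taken = set(overlay) | set(facet)
    # keep the original dimension order for the leftover dims
    reduce_dims = [dim for dim in all_dims if dim not in taken]
    return overlay, facet, reduce_dims
```

For the centered eight example discussed here, `split_dims(["chain", "draw", "school"], ["chain", "draw"], [])` yields the groups `["chain", "draw"]`, `[]` and `["school"]`.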
There was a problem hiding this comment.
Got it, this way of splitting makes sense. I've modified the code in my latest commit to take into account these three kinds of dimensions (for overlay/facetting/reducing)
Latest commit takes into account your suggestions and comments @OriolAbril. The Here is the current output when plot_ppc is called with the
Codecov Report
Attention: Patch coverage is
Additional details and impacted files
@@            Coverage Diff             @@
##             main      #55      +/-   ##
==========================================
+ Coverage   84.80%   85.13%   +0.33%
==========================================
  Files          21       22       +1
  Lines        2336     2509     +173
==========================================
+ Hits         1981     2136     +155
- Misses        355      373      +18
src/arviz_plots/plots/ppcplot.py
Outdated
for groups in (f"{group}_predictive", "observed_data"): | ||
if groups not in dt.children: | ||
raise TypeError(f'`data` argument must have the groups "{groups}" for ppcplot') |
There was a problem hiding this comment.
I find it a bit confusing to use groups when it is a string with a single group name.
There was a problem hiding this comment.
Also, when observed=False we don't need to have the observed_data group.
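One possible shape for this check, sketched with a plain container standing in for dt.children. The helper name `check_required_groups` is hypothetical, and the error message is adapted from the reviewed code:

```python
def check_required_groups(children, group, observed):
    """Raise if a group needed by the plot is missing.

    `children` stands in for `dt.children`; any container supporting `in` works.
    """
    required = [f"{group}_predictive"]
    if observed:
        # observed data is only required when it will actually be plotted
        required.append("observed_data")
    for group_name in required:
        if group_name not in children:
            raise TypeError(
                f'`data` argument must have the group "{group_name}" for ppcplot'
            )
    return required
```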
src/arviz_plots/plots/ppcplot.py
Outdated
reduce_dims = [] | ||
reduce_dims.append( | ||
str( | ||
dim for dim in pp_distribution.dims if dim not in facet_dims.union(sample_dims) | ||
) # set by elimination | ||
) | ||
reduce_dims = ["school"] |
There was a problem hiding this comment.
Suggested change:
- reduce_dims = []
- reduce_dims.append(
-     str(
-         dim for dim in pp_distribution.dims if dim not in facet_dims.union(sample_dims)
-     )  # set by elimination
- )
- reduce_dims = ["school"]
+ reduce_dims = [dim for dim in pp_distribution.dims if dim not in facet_dims.union(sample_dims)]
src/arviz_plots/plots/ppcplot.py
Outdated
kind=None, | ||
facet_dims=None, | ||
data_pairs=None, | ||
aggregate=True, |
There was a problem hiding this comment.
We agreed in #11 to have aggregate=False as the default.
src/arviz_plots/plots/ppcplot.py
Outdated
if facet_dims is None: | ||
facet_dims = [] | ||
|
||
# check for duplication of facet_dims in top level arg input and pc_kwargs | ||
if "cols" in pc_kwargs and len(facet_dims) > 0: | ||
duplicated_dims = set(facet_dims).intersection(pc_kwargs["cols"]) | ||
with warnings.catch_warnings(): | ||
warnings.filterwarnings( | ||
"ignore", | ||
message=f"""Facet dimensions have been defined twice. | ||
Both in the the top level function arguments and in `pc_kwargs`. | ||
The `cols` key in `pc_kwargs` will take precedence. | ||
|
||
facet_dims = {facet_dims} | ||
pc_kwargs["cols"] = {pc_kwargs["cols"]} | ||
Duplicated dimensions: {duplicated_dims}.""", | ||
) | ||
# setting facet_dims to pc_kwargs defined values since it is used | ||
# later to calculate dims to reduce | ||
facet_dims = list(set(pc_kwargs["col"]).difference({"__variable__"})) |
There was a problem hiding this comment.
Personally, I'd prefer to error out if both are provided. If the error message is clear, the fix is equally clear: use one or the other, but not both.
We could also issue a warning and continue ahead, but then I think facet_dims should take preference as it is a "higher" level argument; I think kwargs overriding top level arguments is quite unintuitive. In that case I would also keep things as simple as possible, along the lines of "both "cols" in pc_kwargs and facet_dims are provided, "cols" in pc_kwargs will be ignored, to have it taken into account use facet_dims=None".
Side note: the code doesn't issue any warning; it activates a filter to catch and ignore warnings whose message matches the message= field in the filterwarnings call.
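The difference between installing a filter and actually emitting a warning can be checked with the standard library alone. The function names and the shortened message below are illustrative:

```python
import warnings

MESSAGE = "Facet dimensions have been defined twice"  # shortened, illustrative

def filter_only():
    # mirrors the reviewed code: installs an "ignore" filter, emits nothing
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", message=MESSAGE)

def warn_user():
    # what is needed for the user to actually see a warning
    warnings.warn(MESSAGE, UserWarning)

def count_warnings(func):
    """Run `func` and return how many warnings it emitted."""
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        func()
    return len(caught)
```

Here `count_warnings(filter_only)` is 0 while `count_warnings(warn_user)` is 1, which is why the original code never showed anything to the user.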
There was a problem hiding this comment.
Extra note: reviewing the plot_ridge PR, I realized that a similar thing already happens in plot_forest and we raise errors, so for consistency I think we should raise errors here too. ref https://github.com/arviz-devs/arviz-plots/blob/main/src/arviz_plots/plots/forestplot.py#L242-L249
There was a problem hiding this comment.
Ah my bad, I just used the error bit from plot_dist as a template without checking it properly. I'll replace it with the raise ValueError(...) you pointed out from plot_forest.
Also, I think the first option sounds good: raise an error if both are provided and ask the user to only provide one of the args.
src/arviz_plots/plots/ppcplot.py
Outdated
# checking if plot_kwargs["observed"] or plot_kwargs["aggregate"] is not inconsistent with | ||
# top level bool arguments `observed`` and `aggregate` | ||
if "observed" in plot_kwargs: | ||
if plot_kwargs["observed"] != observed: | ||
with warnings.catch_warnings(): | ||
warnings.filterwarnings( | ||
"ignore", | ||
message="""`plot_kwargs['observed']` inconsistency detected. | ||
It is not the same as with the top level `observed` argument. | ||
`plot_kwargs['observed']` will take precedence.""", | ||
) | ||
observed = plot_kwargs["observed"] |
There was a problem hiding this comment.
Similar comment here. Here I am a bit less sure about error vs warning though.
I might also modify the logic, especially if going the warning route, to something like:
if plot_kwargs.get("observed", False) is not False and observed:
    # raise warning or error
Otherwise, people who do not read the documentation very carefully or who are used to the API of the other plots might expect using the "observed" key in plot_kwargs to make the line appear, but those will be ignored if observed=False.
There was a problem hiding this comment.
I'm a bit confused by this logic (wouldn't this imply that if both are true, the error is raised?). Did you mean
if plot_kwargs.get("observed", False) is not False and not observed:
? (implying plot_kwargs["observed"] is set to true but observed is false, so an error should be raised as they contradict)
Also, what about the condition when observed is true but plot_kwargs["observed"] is false: should we display the error for this condition too due to the contradiction, or just let the top-level argument observed take precedence?
Because it might be the case that observed is not set at all by the user but defaults to true (which happens in the code before this check) because of the posterior predictive group being plotted, but the user has set plot_kwargs["observed"] to false to not have it be displayed. In this case, should the top level argument still take preference even though it took a default value and wasn't explicitly defined in the function call?
There was a problem hiding this comment.
You are right, the code makes no sense, sorry. I started writing with an idea in mind on how to do it, then changed my mind but not the text.
The two options should be considered:
if (
observed and plot_kwargs.get("observed", True) is False
or not observed and plot_kwargs.get("observed", False) is not False
):
# raise warning or error
There was a problem hiding this comment.
Whenever they are incompatible we should raise an error, thus, none will ever take precedence over the other because they either mean the same thing or an error is thrown
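A small runnable sketch of the "error whenever they are incompatible" rule, generalized over the artist name so the same check could serve both observed and aggregate. The helper name `check_toggle_consistency` and the error wording are assumptions:

```python
def check_toggle_consistency(name, flag, plot_kwargs):
    """Raise ValueError if top level `flag` contradicts `plot_kwargs[name]`.

    Hypothetical helper; a missing key never counts as a contradiction.
    """
    disabled_in_kwargs = plot_kwargs.get(name, True) is False
    enabled_in_kwargs = plot_kwargs.get(name, False) is not False
    if (flag and disabled_in_kwargs) or (not flag and enabled_in_kwargs):
        raise ValueError(
            f"`{name}` and `plot_kwargs['{name}']` are inconsistent: "
            f"`{name}`={flag}, `plot_kwargs['{name}']`={plot_kwargs[name]!r}"
        )
```

Indexing `plot_kwargs[name]` in the message is safe because both failing branches can only be reached when the key is present.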
There was a problem hiding this comment.
Actually, depending on #66 we could simplify it to "plot_kwargs["observed"] can't be False, use observed top level kwarg for that".
There was a problem hiding this comment.
I've used the first incompatibility logic for now- could we keep a note of this and make modifications after the proposals in #66 are finalized?
src/arviz_plots/plots/ppcplot.py
Outdated
# making sure both posterior/prior predictive group and observed_data group exists in | ||
# datatree provided | ||
if observed: | ||
for group_name in (f"{predictive_data_group}", "observed_data"): |
There was a problem hiding this comment.
Suggested change:
- for group_name in (f"{predictive_data_group}", "observed_data"):
+ for group_name in (predictive_data_group, "observed_data"):
src/arviz_plots/plots/ppcplot.py
Outdated
if backend is None: | ||
backend = rcParams["plot.backend"] |
There was a problem hiding this comment.
This is duplicated with the check above now
src/arviz_plots/plots/ppcplot.py
Outdated
aes_map={ | ||
kind: value | ||
for key, value in aes_map.items() | ||
if key == "predictive" and key not in pp_density_ignore | ||
}, |
There was a problem hiding this comment.
I don't understand what is happening here. The keys of the aes_map dict need to be ids (aka artist names) for the visual elements in plot_dist; the values are sets of aesthetic names. It indicates which aesthetics should be applied to each artist.
No key can be in pp_density_ignore, as pp_density_ignore contains aesthetic names. That means only the "predictive" key will be used, which makes sense as it is what we want to plot, and it is "renamed" to kind. But kind is also hardcoded even if inside the dict comprehension, so if the loop were to actually do anything then it would break, as it would attempt to generate a dictionary with repeated keys.
There was a problem hiding this comment.
I think I meant to put if key == "predictive" and value not in pp_density_ignore so it checks the aesthetic names against those in pp_density_ignore. About the dict comprehension, yeah, it doesn't seem too useful in retrospect. I'll refactor this to how it's done in plot_trace_dist when plot_dist is called internally.
src/arviz_plots/plots/ppcplot.py
Outdated
**{key: False for key in plot_dist_artists}, | ||
kind: dict(pp_kwargs.items()), | ||
}, | ||
stats_kwargs={key: value for key, value in stats_kwargs.items() if key == "predictive"}, |
There was a problem hiding this comment.
Somewhat similar comment here. Make sure the code makes sense and that it is consistent with the documentation.
src/arviz_plots/plots/ppcplot.py
Outdated
# ---------STEP 2 (PPC AGGREGATE)------------- | ||
|
||
aggregate_kwargs = copy(plot_kwargs.get("aggregate", {})) | ||
if aggregate and pp_kwargs is not False and aggregate_kwargs is not False: |
There was a problem hiding this comment.
This should not depend on pp_kwargs, only on aggregate related kwargs. Also, this should be consistent with whatever happens with observed and plot_kwargs["observed"]: if one raises a warning when there is duplication, the other should too; if it is an error instead, both should raise an error (ideally, if an error, at the top of the function before any computation or plotting happens).
src/arviz_plots/plots/ppcplot.py
Outdated
# setting aggregate aes_map to `[]` so `overlay` isn't applied | ||
aes_map.setdefault("aggregate", []) |
There was a problem hiding this comment.
This should happen in the aes_map defaults (it doesn't really matter if we set that and then aggregate=False)
src/arviz_plots/plots/ppcplot.py
Outdated
) | ||
obs_kwargs = copy(plot_kwargs.get("observed", {})) | ||
|
||
if obs_kwargs is not False: |
There was a problem hiding this comment.
This still has the double level logic with observed and obs_kwargs.
Made some modifications/fixes. Some changes related to input kwargs (the incompatibility checks between top level args and Updated the docstring as well so users are aware that the values they pass into
I'll try to review the code tomorrow; in the meantime, some quick answers to the questions above. I don't think it matters to users that plot_dist is called internally; they only care about the valid keys and where their values are eventually passed (so they know what values are valid). This will also match how things are documented in plot_trace_dist. stats_kwargs should allow for different kwargs being passed to the different datasets, see #55 (comment)
Okay, I'll adjust the docstring to reflect that so it's simpler. I could mention it in a comment in the code though, so that anyone looking at the source code can make out where it's being passed internally. Also sorry, missed that. I'll adjust the stats_kwargs usage logic to allow different kwargs to be passed for "predictive", "aggregate" and "observed".
There was a problem hiding this comment.
This needs tests so we can check the code does what we want to happen in all the different possible scenarios
src/arviz_plots/plots/ppcplot.py
Outdated
coords=None, | ||
sample_dims=None, | ||
kind=None, | ||
facet_dims=None, | ||
data_pairs=None, |
There was a problem hiding this comment.
data_pairs argument should be kept
src/arviz_plots/plots/ppcplot.py
Outdated
|
||
# pp distribution group plotting logic | ||
pp_distribution = process_group_variables_coords( | ||
dt, group=predictive_data_group, var_names=var_names, filter_vars=filter_vars, coords=coords |
There was a problem hiding this comment.
data_pairs should be taken into account here:
Suggested change:
- dt, group=predictive_data_group, var_names=var_names, filter_vars=filter_vars, coords=coords
+ dt, group=predictive_data_group, var_names=None if var_names is None else [data_pairs.get(var_name, var_name) for var_name in var_names], filter_vars=filter_vars, coords=coords
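The name translation in this suggestion can be isolated in a small helper, sketched below. `predictive_var_names` is a hypothetical name, and treating `data_pairs=None` as an empty mapping is an added assumption, since the one-liner above would fail on `None.get`:

```python
def predictive_var_names(var_names, data_pairs=None):
    """Map observed variable names to their predictive counterparts.

    Hypothetical helper: names absent from `data_pairs` are kept unchanged.
    """
    if var_names is None:
        # None means "all variables", so there is nothing to translate
        return None
    if data_pairs is None:
        data_pairs = {}
    return [data_pairs.get(var_name, var_name) for var_name in var_names]
```

With `data_pairs={"y": "y_hat"}`, the observed name `"y"` is looked up as `"y_hat"` in the predictive group, while any other name passes through as is.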
Just added some tests. I also modified the Also restored the data_pairs arg |
About the tests- the 'kind' and 'group' arguments I've parametrized both via pytest and via hypothesis.strategies. Should only one be picked or this kept? The hypothesis tests are also failing a few times. At least one seems to be because of trying to toggle the 'observed' artist off with plot_kwargs but not in the top level arg, so the ValueError set for this gets raised. The others I am still unsure of why they're failing. |
src/arviz_plots/plots/ppcplot.py
Outdated
Dimensions to loop over in plotting posterior/prior predictive. | ||
Note: Dims not in sample_dims or facet_dims (below) will be reduced by default. | ||
Defaults to ``rcParams["data.sample_dims"]`` | ||
kind : {"kde", "cumulative", "scatter"}, optional |
There was a problem hiding this comment.
kind : {"kde", "cumulative", "scatter"}, optional | |
kind : {"kde", "cumulative", "scatter"}, optional |
I see you made some changes like renaming "cumulative" to "ecdf" internally, but that will be confusing for users. There should be 4 possible values accepted by kind: kde, ecdf, hist and scatter.
src/arviz_plots/plots/ppcplot.py
Outdated
Dictionary containing relations between observed data and posterior/prior predictive data. | ||
Dictionary structure: | ||
* key = observed data var_name | ||
* value = posterior/prior predictive var_name | ||
For example, data_pairs = {'y' : 'y_hat'}. | ||
If None, it will assume that the observed data and the posterior/prior predictive data | ||
have the same variable name |
There was a problem hiding this comment.
Suggested change:
- Dictionary containing relations between observed data and posterior/prior predictive data.
- Dictionary structure:
- * key = observed data var_name
- * value = posterior/prior predictive var_name
- For example, data_pairs = {'y' : 'y_hat'}.
- If None, it will assume that the observed data and the posterior/prior predictive data
- have the same variable name
+ Dictionary containing relations between observed data and posterior/prior predictive data.
+ Dictionary keys are variable names corresponding to observed data and dictionary values
+ are variable names corresponding to posterior/prior predictive. For example, data_pairs = {'y' : 'y_hat'}.
+ By default, it will assume that the observed data and the posterior/prior predictive data
+ have the same variable names
src/arviz_plots/plots/ppcplot.py
Outdated
aggregate: bool, optional | ||
If True, predictive data will be aggregated over both sample_dims and reduce_dims. | ||
Defaults to False. |
There was a problem hiding this comment.
Suggested change:
- aggregate: bool, optional
-     If True, predictive data will be aggregated over both sample_dims and reduce_dims.
-     Defaults to False.
+ aggregate: bool, default False
+     If True, predictive data will be aggregated over both sample_dims and reduce_dims.
When the default is the value that is used, it should be documented here directly. If it is a placeholder or if it depends on other values, then use optional and explain the behaviour in the description.
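To make the two documentation styles concrete, here is a stub with one parameter of each kind. The function `plot_ppc_stub` is hypothetical, and its docstring text is adapted from lines quoted in this review:

```python
def plot_ppc_stub(aggregate=False, sample_dims=None):
    """Illustrate the two ways of documenting defaults.

    Parameters
    ----------
    aggregate : bool, default False
        If True, predictive data will be aggregated over both sample_dims
        and reduce_dims. The default is the value actually used, so it is
        stated in the type line.
    sample_dims : list of str, optional
        None is only a placeholder here, so the real behaviour is explained
        in the description. Defaults to ``rcParams["data.sample_dims"]``.
    """
    return aggregate, sample_dims
```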
src/arviz_plots/plots/ppcplot.py
Outdated
Defaults to empty list. A warning is raised if `pc_kwargs` is also used to define | ||
dims to facet over with the `cols` key and `pc_kwargs` takes precedence. |
There was a problem hiding this comment.
Suggested change:
- Defaults to empty list. A warning is raised if `pc_kwargs` is also used to define
- dims to facet over with the `cols` key and `pc_kwargs` takes precedence.
+ Defaults to empty list.
src/arviz_plots/plots/ppcplot.py
Outdated
pp_sample_ix = rng.choice(total_pp_samples, size=num_pp_samples, replace=False) | ||
|
||
# print(f"\npp_sample_ix: {pp_sample_ix!r}") | ||
|
||
# stacking sample dims into a new 'ppc_dim' dimension | ||
pp_distribution = pp_distribution.stack(ppc_dim=sample_dims) | ||
|
||
# Select the desired samples | ||
pp_distribution = pp_distribution.isel(ppc_dim=pp_sample_ix) | ||
|
||
# renaming sample_dims so that rest of plot will consider this as sample_dims | ||
sample_dims = ["ppc_dim"] |
There was a problem hiding this comment.
I think we should skip all these steps if num_pp_samples == total_pp_samples. I think nothing within the function itself should change (test probably will) and it would allow using aesthetics for chain/draw when plotting all samples
There was a problem hiding this comment.
We might also want to distinguish between length 1 sample_dims (in which case it is not necessary to stack nor to rename and overwrite its value) and length >1 sample_dims (in which case stacking is necessary).
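The decision logic from these two comments can be written down without any plotting code. The helper `plan_subsampling`, its return format and the error for oversized requests are assumptions for illustration, with `sizes` standing in for the dataset's dimension sizes:

```python
from math import prod

def plan_subsampling(sizes, sample_dims, num_pp_samples=None):
    """Decide whether subsampling and stacking are needed.

    Hypothetical helper; `sizes` maps dimension names to their lengths.
    """
    total_pp_samples = prod(sizes[dim] for dim in sample_dims)
    if num_pp_samples is None:
        num_pp_samples = total_pp_samples
    if num_pp_samples > total_pp_samples:
        raise ValueError("`num_pp_samples` exceeds the number of available samples")
    # nothing to select when every sample is plotted
    subsample = num_pp_samples < total_pp_samples
    # a single sample dimension can be indexed directly, no stacking required
    stack = subsample and len(sample_dims) > 1
    return {"subsample": subsample, "stack": stack, "total_pp_samples": total_pp_samples}
```

When `subsample` is False, sample_dims keeps its original value, so aesthetics mapped to chain or draw remain usable.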
src/arviz_plots/visuals/__init__.py
Outdated
if flatten is not False: | ||
xvalues = xvalues.values.flatten() # flatten xvalues | ||
if len(xvalues.shape) != 1: | ||
raise ValueError(f"Expected unidimensional data but got {xvalues.sizes}") |
There was a problem hiding this comment.
Can you try what happens when xvalues is always flattened and no error is raised?
src/arviz_plots/plots/ppcplot.py
Outdated
# checking if plot_kwargs["observed"] or plot_kwargs["aggregate"] is not inconsistent with
# top level bool arguments `observed` and `aggregate`

# observed args logic check:
# observed will be True/False depending on user-input/group- true for posterior, false for prior
# if observed and no plot_kwargs['observed'], no prob
# if observed = True and plot_kwargs['observed'] = True, no prob (observed is plotted)
# if observed = True and plot_kwargs['observed'] = False, prob (error/warning raised)
# if observed = False and plot_kwargs['observed'] = True, prob (error/warning raised)
# if observed = False and plot_kwargs['observed'] = False, no prob (observed is not plotted)
if (
    observed
    and plot_kwargs.get("observed", True) is False
    or not observed
    and plot_kwargs.get("observed", False) is not False
):
    # raise warning or error
    raise ValueError(
        f"""
`observed` and `plot_kwargs["observed"]` inconsistency detected.
`observed` = {observed}{default_observed}
`plot_kwargs["observed"]` = {plot_kwargs["observed"]}
Please make sure `observed` and `plot_kwargs["observed"]` have the same value."""
    )

# same check for aggregate:
if (
    aggregate
    and plot_kwargs.get("aggregate", True) is False
    or not aggregate
    and plot_kwargs.get("aggregate", False) is not False
):
    # raise warning or error
    raise ValueError(
        f"""
`aggregate` and `plot_kwargs["aggregate"]` inconsistency detected.
`aggregate` = {aggregate}
`plot_kwargs["aggregate"]` = {plot_kwargs["aggregate"]}
Please make sure `aggregate` and `plot_kwargs["aggregate"]` have the same value."""
    )
Seeing the tests and having reviewed plot_forest, I realized this is different behaviour from `plot_forest`. It might make sense to mimic the behaviour completely, which would be raising an error if `plot_kwargs["aggregate"]` (or observed) is `False`, continue on ahead otherwise.
This would mean we could have a dict in that key that gets ignored and the element is not plotted, but that is also what happens in plot_forest and I think is good enough behaviour.
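One possible reading of this suggestion, sketched as a standalone helper. `resolve_artist_kwargs` is a hypothetical name, not an arviz-plots function, and the sketch assumes each `plot_kwargs` value is either a dict or `False`. It raises only when the top-level flag asks for an element that `plot_kwargs` disables, and silently ignores a dict left under the key of an element that is switched off.

```python
def resolve_artist_kwargs(enabled, plot_kwargs, key):
    """Return the kwargs for artist `key`, or None if it should not be plotted.

    Hypothetical helper illustrating the behaviour described in the review;
    assumes plot_kwargs values are dicts or False.
    """
    value = plot_kwargs.get(key, {})
    if not enabled:
        # Element switched off at the top level: whatever sits under
        # plot_kwargs[key] is ignored and nothing is plotted.
        return None
    if value is False:
        raise ValueError(
            f"`{key}` is requested but `plot_kwargs['{key}']` is False; "
            f"use `{key}=False` to skip this element."
        )
    # Copy so later setdefault calls do not mutate the user's dict.
    return dict(value) if isinstance(value, dict) else {}
```

Compared with the symmetric check in the diff, this drops the error for `aggregate=False` combined with a dict in `plot_kwargs["aggregate"]`, which is the trade-off the comment calls good enough.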
tests/test_hypothesis_plots.py
Outdated
},
),
kind=ppc_kind_value,
group=ppc_group,
All relevant plot specific arguments should be here, especially if they affect what gets plotted or how the data is processed within the function. Thus, `observed` and `aggregate` (toggling `observed_rug` on too if `observed` is true), `facet_dims` and `num_pp_samples` should be here.
tests/test_plots.py
Outdated
@@ -246,3 +272,52 @@ def test_plot_ridge_aes_labels_shading(self, backend, datatree_4d, pseudo_dim):
if pseudo_dim != "__variable__":
    assert all(0 in child["alpha"] for child in pc.aes.children.values())
assert any(pseudo_dim in child["shade"].dims for child in pc.viz.children.values())

@pytest.mark.parametrize("group", ("prior", "posterior"))
No need to parametrize by group. It plays an extremely minor role in the function and it is also being checked in the hypothesis tests.
tests/test_plots.py
Outdated
@pytest.mark.parametrize("group", ("prior", "posterior"))
@pytest.mark.parametrize("kind", ("kde", "cumulative"))
@pytest.mark.parametrize("facet_dims", (["group"], ["hierarchy"], None))
def test_plot_ppc_4d(self, datatree_4d, facet_dims, kind, group, backend):
There should be assertions to check `facet_dims` is working as expected. For example, making sure that dimension is present in `pc.viz["obs"]["plot"]`.
Modified the There are still 2 failing conditions in the hypothesis tests though -I'm not sure why yet- where the datatree seems to be interpreted as a function for some reason |
Some checks on this should still be present |
Resolved the latest reviews with these last commits: the predictive values stacking/subselection logic was updated and the tests were modified. The Hypothesis tests still seem to be failing for 2 conditions, though.
Just rebased this PR too |
sample_dims=ppc_sample_dims,
drawed_samples=draw_num_pp_samples(ppc_group, ppc_sample_dims),
)
def test_plot_ppc(
def test_plot_ppc( | |
def test_plot_ppc( | |
datatree, |
Added support for kind="scatter" type plots in
src/arviz_plots/plots/ppcplot.py
Outdated
)
pc_kwargs["aes"] = pc_kwargs.get("aes", {}).copy()
pc_kwargs["aes"].setdefault("overlay", sample_dims)  # setting overlay dim
pc_kwargs.setdefault("y", np.linspace(0.01, 0.1, 11))
don't set default values
src/arviz_plots/plots/ppcplot.py
Outdated
data=obs_distribution,
ignore_aes=obs_density_ignore,
xname=False,
y=0,
move to -1 so it doesn't overlap with the pp ones
tests/test_hypothesis_plots.py
Outdated
observed=ppc_observed,
# observed_rug=ppc_observed,
use st.composite for these
@st.composite
def observed_and_rug(draw):
    observed = draw(st.booleans())
    if not observed:
        return (False, False)
    rug = draw(st.booleans())
    return (observed, rug)
…t keys in stats_kwargs for each of the 3 major artists, passed as stats_kwargs['density'] through internally to plot_dist
Fixed some more test related issues (new errors still persisting though). The 'alpha' keyword was missing from the 'none' backend for the 'hist' function, unlike the matplotlib and bokeh backends, which is why it hadn't been detected yet. Another fixed test fail was due to the fact that different artists have different dimensions (for example, the 'predictive' artists are overlayed along sample dims but the 'observed' artist does not have these dimensions and the 'aggregate' artist has these reduced instead of overlayed) so the dimensions were varying when the 'predictive' key in Suggestion by @OriolAbril to deal with the above is to check dimensions of the artists in the tests. Current issue with this though is since a lot of the artists Current tests bypass this if 'predictive' is False in Todo:
…each artists' dimensions
…te instead of '--' manually
Adding plot_ppc (issue #11).
This is a PPC plot draft that currently plots the actual observed values and prior/posterior predictive values, flattening across all dimensions except sample dims (chain, draw) by default
WIP:
📚 Documentation preview 📚: https://arviz-plots--55.org.readthedocs.build/en/55/