Skip to content

Commit

Permalink
Updated ppa widgets to include references (#328)
Browse files Browse the repository at this point in the history
* Updated ppa widgets to include references

* Run black

* Add margins between widgets

* Update the gif

* Add docs for dictionary option to references
  • Loading branch information
rohanbabbar04 authored Feb 22, 2024
1 parent 4404ec3 commit 950e2a7
Show file tree
Hide file tree
Showing 3 changed files with 709 additions and 653 deletions.
1,339 changes: 694 additions & 645 deletions docs/examples/observed_space_examples.ipynb

Large diffs are not rendered by default.

Binary file modified docs/examples/ppa.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
23 changes: 15 additions & 8 deletions preliz/predictive/ppa.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Prior predictive check assistant."""

import logging
import ast
from random import shuffle

try:
Expand Down Expand Up @@ -38,9 +39,9 @@ def ppa(
model : PreliZ model
draws : int
Number of draws from the prior and prior predictive distribution
references : int, float, list or tuple
references : int, float, list, tuple or dictionary
Value(s) used as reference points representing prior knowledge. For example expected
values or values that are considered extreme.
values or values that are considered extreme. Use a dictionary for labeled references.
boundaries : tuple
Hard boundaries (lower, upper). Posterior predictive samples with values outside these
boundaries will be excluded from the analysis.
Expand All @@ -55,15 +56,19 @@ def ppa(

_log.info(""""This is an experimental method under development, use with caution.""")

if isinstance(references, (float, int)):
references = [references]

filter_dists = FilterDistribution(fmodel, draws, references, boundaries, target, engine)
filter_dists()

output = widgets.Output()

with output:
references_widget = widgets.Text(
value=str(references),
placeholder="Int, Float or tuple",
description="references: ",
disabled=False,
layout=widgets.Layout(width="230px", margin="0 20px 0 0"),
)
button_carry_on = widgets.Button(description="carry on")
button_return_prior = widgets.Button(description="return prior")
radio_buttons_kind = widgets.RadioButtons(
Expand All @@ -85,16 +90,17 @@ def ppa(

def kind_(_):
kind = radio_buttons_kind.value

plot_pp_samples(
filter_dists.pp_samples,
filter_dists.display_pp_idxs,
references,
ast.literal_eval(references_widget.value),
kind,
check_button_sharex.value,
filter_dists.fig,
)

references_widget.observe(kind_, names=["value"])

radio_buttons_kind.observe(kind_, names=["value"])

check_button_sharex.observe(kind_, names=["value"])
Expand All @@ -114,9 +120,10 @@ def click(event):
filter_dists.fig.canvas.mpl_connect("button_press_event", click)

controls = widgets.VBox([button_carry_on, button_return_prior])
plot_combine = widgets.VBox([radio_buttons_kind, check_button_sharex])

display( # pylint:disable=undefined-variable
widgets.HBox([controls, radio_buttons_kind, check_button_sharex, output])
widgets.HBox([references_widget, plot_combine, controls, output])
)


Expand Down

0 comments on commit 950e2a7

Please sign in to comment.