Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

customize roulette #291

Merged
merged 3 commits into from
Oct 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions preliz/internal/optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def func(params, dist, x_vals):
init_vals = np.array(dist.params)[none_idx]
bounds = np.array(dist.params_support)[none_idx]
bounds = list(zip(*bounds))

opt = least_squares(func, x0=init_vals, args=(dist, x_vals), bounds=bounds)
params = get_params(dist, opt["x"], none_idx, fixed)
dist._parametrization(**params)
Expand All @@ -88,7 +89,8 @@ def func(params, dist, x_vals, ecdf):
bounds = list(zip(*bounds))

opt = least_squares(func, x0=init_vals, args=(dist, x_vals, ecdf), bounds=bounds)
dist._update(*opt["x"])
params = get_params(dist, opt["x"], none_idx, fixed)
dist._parametrization(**params)
loss = opt["cost"]
return loss

Expand Down Expand Up @@ -203,7 +205,7 @@ def get_distributions(dist_names):
return dists


def fit_to_ecdf(selected_distributions, x_vals, ecdf, mean, std, x_min, x_max):
def fit_to_ecdf(selected_distributions, x_vals, ecdf, mean, std, x_min, x_max, extra_pros):
"""
Minimize the difference between the cdf and the ecdf over a grid of values
defined by x_min and x_max
Expand All @@ -212,6 +214,8 @@ def fit_to_ecdf(selected_distributions, x_vals, ecdf, mean, std, x_min, x_max):
"""
fitted = Loss(len(selected_distributions))
for dist in selected_distributions:
if dist.__class__.__name__ in extra_pros:
dist._parametrization(**extra_pros[dist.__class__.__name__])
if dist.__class__.__name__ == "BetaScaled":
update_bounds_beta_scaled(dist, x_min, x_max)

Expand Down
5 changes: 5 additions & 0 deletions preliz/tests/test_quartile_int.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from test_helper import run_notebook


def test_roulette():
run_notebook("quartile_int.ipynb")
2 changes: 1 addition & 1 deletion preliz/tests/test_roulette.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,6 @@ def test_roulette_mock():
w_distributions = distributions[idx:]

fitted_dist = on_leave_fig(
fig.canvas, grid, w_distributions, w_repr, x_min, x_max, ncols, ax_fit
fig.canvas, grid, w_distributions, w_repr, x_min, x_max, ncols, "", ax_fit
)
assert fitted_dist.__class__.__name__ == dist
2 changes: 1 addition & 1 deletion preliz/unidimensional/quartile_int.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def get_widgets(q1, q2, q3, dist_names=None):

if dist_names is None:

default_dist = ["Normal", "BetaScaled", "Gamma", "LogNormal"]
default_dist = ["Normal", "BetaScaled", "Gamma", "LogNormal", "StudentT"]

dist_names = [
"AsymmetricLaplace",
Expand Down
107 changes: 67 additions & 40 deletions preliz/unidimensional/roulette.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@
pass
from ..internal.optimization import fit_to_ecdf, get_distributions
from ..internal.plot_helper import check_inside_notebook, representations
from ..internal.distribution_helper import process_extra


def roulette(x_min=0, x_max=10, nrows=10, ncols=11, figsize=None):
def roulette(x_min=0, x_max=10, nrows=10, ncols=11, dist_names=None, figsize=None):
"""
Prior elicitation for 1D distribution using the roulette method.

Expand All @@ -29,6 +30,10 @@ def roulette(x_min=0, x_max=10, nrows=10, ncols=11, figsize=None):
Number of rows for the grid. Defaults to 10.
ncols: Optional[int]
Number of columns for the grid. Defaults to 11.
dist_names: list
List of distributions names to be used in the elicitation. If None, almost all 1D
distributions available in PreliZ will be used. Some distributions like Uniform or
Cauchy are omitted by default.
figsize: Optional[Tuple[int, int]]
Figure size. If None it will be defined automatically.

Expand All @@ -44,8 +49,12 @@ def roulette(x_min=0, x_max=10, nrows=10, ncols=11, figsize=None):

check_inside_notebook(need_widget=True)

w_x_min, w_x_max, w_ncols, w_nrows, w_repr, w_distributions = get_widgets(
x_min, x_max, nrows, ncols
w_x_min, w_x_max, w_ncols, w_nrows, w_extra, w_repr, w_distributions = get_widgets(
x_min,
x_max,
nrows,
ncols,
dist_names,
)

output = widgets.Output()
Expand Down Expand Up @@ -90,6 +99,7 @@ def on_leave_fig_(_):
w_x_min.value,
w_x_max.value,
w_ncols.value,
w_extra.value,
ax_fit,
)

Expand All @@ -113,11 +123,12 @@ def on_value_change(change):
w_x_min.value,
w_x_max.value,
w_ncols.value,
w_extra.value,
ax_fit,
),
)

controls = widgets.VBox([w_x_min, w_x_max, w_nrows, w_ncols])
controls = widgets.VBox([w_x_min, w_x_max, w_nrows, w_ncols, w_extra])

display(widgets.HBox([controls, w_repr, w_distributions])) # pylint:disable=undefined-variable

Expand Down Expand Up @@ -200,11 +211,12 @@ def __call__(self, event):
self.fig.canvas.draw()


def on_leave_fig(canvas, grid, dist_names, kind_plot, x_min, x_max, ncols, ax):
def on_leave_fig(canvas, grid, dist_names, kind_plot, x_min, x_max, ncols, extra, ax):
x_min = float(x_min)
x_max = float(x_max)
ncols = float(ncols)
x_range = x_max - x_min
extra_pros = process_extra(extra)

x_vals, ecdf, mean, std, filled_columns = weights_to_ecdf(grid.weights, x_min, x_range, ncols)

Expand All @@ -222,6 +234,7 @@ def on_leave_fig(canvas, grid, dist_names, kind_plot, x_min, x_max, ncols, ax):
std,
x_min,
x_max,
extra_pros,
)

if fitted_dist is None:
Expand Down Expand Up @@ -280,9 +293,10 @@ def reset_dist_panel(x_min, x_max, ax, yticks):
ax.autoscale_view()


def get_widgets(x_min, x_max, nrows, ncols):
def get_widgets(x_min, x_max, nrows, ncols, dist_names):

width_entry_text = widgets.Layout(width="150px")
width_repr_text = widgets.Layout(width="250px")
width_distribution_text = widgets.Layout(width="150px", height="125px")

w_x_min = widgets.FloatText(
Expand Down Expand Up @@ -319,6 +333,14 @@ def get_widgets(x_min, x_max, nrows, ncols):
layout=width_entry_text,
)

w_extra = widgets.Textarea(
value="",
placeholder="Pass extra parameters",
description="params:",
disabled=False,
layout=width_repr_text,
)

w_repr = widgets.RadioButtons(
options=["pdf", "cdf", "ppf"],
value="pdf",
Expand All @@ -327,39 +349,44 @@ def get_widgets(x_min, x_max, nrows, ncols):
layout=width_entry_text,
)

default_dist = ["Normal", "BetaScaled", "Gamma", "LogNormal", "StudentT"]

dist_names = [
"AsymmetricLaplace",
"BetaScaled",
"ChiSquared",
"ExGaussian",
"Exponential",
"Gamma",
"Gumbel",
"HalfNormal",
"HalfStudentT",
"InverseGamma",
"Laplace",
"LogNormal",
"Logistic",
# "LogitNormal", # fails if we add chips at x_value= 1
"Moyal",
"Normal",
"Pareto",
"Rice",
"SkewNormal",
"StudentT",
"Triangular",
"VonMises",
"Wald",
"Weibull",
"BetaBinomial",
"DiscreteWeibull",
"Geometric",
"NegativeBinomial",
"Poisson",
]
if dist_names is None:

default_dist = ["Normal", "BetaScaled", "Gamma", "LogNormal", "StudentT"]

dist_names = [
"AsymmetricLaplace",
"BetaScaled",
"ChiSquared",
"ExGaussian",
"Exponential",
"Gamma",
"Gumbel",
"HalfNormal",
"HalfStudentT",
"InverseGamma",
"Laplace",
"LogNormal",
"Logistic",
# "LogitNormal", # fails if we add chips at x_value= 1
"Moyal",
"Normal",
"Pareto",
"Rice",
"SkewNormal",
"StudentT",
"Triangular",
"VonMises",
"Wald",
"Weibull",
"BetaBinomial",
"DiscreteWeibull",
"Geometric",
"NegativeBinomial",
"Poisson",
]

else:
default_dist = dist_names

w_distributions = widgets.SelectMultiple(
options=dist_names,
Expand All @@ -369,4 +396,4 @@ def get_widgets(x_min, x_max, nrows, ncols):
layout=width_distribution_text,
)

return w_x_min, w_x_max, w_ncols, w_nrows, w_repr, w_distributions
return w_x_min, w_x_max, w_ncols, w_nrows, w_extra, w_repr, w_distributions
Loading