Skip to content

Commit

Permalink
Refactor and extend Roulette (#546)
Browse files Browse the repository at this point in the history
* Refactor roulette

* extend test
  • Loading branch information
aloctavodia authored Oct 1, 2024
1 parent 4ee6344 commit cf3cc2b
Show file tree
Hide file tree
Showing 6 changed files with 458 additions and 474 deletions.
12 changes: 7 additions & 5 deletions preliz/internal/distribution_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,13 @@ def process_extra(input_string):
name = match[0]
args = match[1].split(",")
arg_dict = {}
for arg in args:
key, value = arg.split("=")
arg_dict[key.strip()] = float(value)
result_dict[name] = arg_dict

try:
for arg in args:
key, value = arg.split("=")
arg_dict[key.strip()] = float(value)
result_dict[name] = arg_dict
except ValueError:
pass
return result_dict


Expand Down
5 changes: 4 additions & 1 deletion preliz/internal/optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,10 @@ def fit_to_ecdf(selected_distributions, x_vals, ecdf, mean, std, x_min, x_max, e
fitted = Loss(len(selected_distributions))
for dist in selected_distributions:
if dist.__class__.__name__ in extra_pros:
dist._parametrization(**extra_pros[dist.__class__.__name__])
try:
dist._parametrization(**extra_pros[dist.__class__.__name__])
except TypeError:
pass
if dist.__class__.__name__ == "BetaScaled":
update_bounds_beta_scaled(dist, x_min, x_max)

Expand Down
64 changes: 56 additions & 8 deletions preliz/tests/roulette.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
"import ipytest\n",
"ipytest.autoconfig()\n",
"\n",
"from preliz import roulette"
"from preliz import Roulette"
]
},
{
Expand All @@ -22,16 +22,64 @@
"metadata": {},
"outputs": [],
"source": [
"%%ipytest\n",
"\n",
"@pytest.mark.parametrize(\"x_min, x_max, nrows, ncols, figsize\", [\n",
" (0, 10, 10, 10, None), # Test default behavior\n",
" (-5, 5, 10, 10, None), # Test different domain\n",
" (0, 10, 5, 5, None), # Test different grid dimensions\n",
"@pytest.mark.parametrize(\"x_min, x_max, nrows, ncols, figsize, dist_names, params\", [\n",
" (0, 10, 10, 10, None, None, None), # Test default behavior\n",
" (-5, 5, 10, 10, None, None, None), # Test different domain\n",
" (0, 10, 5, 5, None, None, None), # Test different grid dimensions\n",
" (0, 10, 10, 10, (10, 8)), # Test custom figsize\n",
" (0, 10, 10, 10, None, [\"Normal\", \"StudentT\"], \"Normal(mu=0), StudentT(nu=0.001)\"), # Test custom dist and params\n",
"])\n",
"def test_roulette(x_min, x_max, nrows, ncols, figsize):\n",
" roulette(x_min, x_max, nrows, ncols, figsize)"
" Roulette(x_min, x_max, nrows, ncols, figsize)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "70ae102b",
"metadata": {},
"outputs": [],
"source": [
"def test_roulette_initialization():\n",
" roulette = Roulette(x_min=0, x_max=10, nrows=10, ncols=11)\n",
" assert roulette._x_min == 0\n",
" assert roulette._x_max == 10\n",
" assert roulette._nrows == 10\n",
" assert roulette._ncols == 11\n",
" assert roulette._figsize == (8, 6)\n",
"\n",
"\n",
"def test_roulette_update_grid():\n",
" roulette = Roulette(x_min=0, x_max=10, nrows=10, ncols=11)\n",
" roulette._widgets['w_x_min'].value = 1\n",
" roulette._widgets['w_x_max'].value = 9\n",
" roulette._widgets['w_nrows'].value = 8\n",
" roulette._widgets['w_ncols'].value = 9\n",
" roulette._update_grid()\n",
" assert roulette._x_min == 1\n",
" assert roulette._x_max == 9\n",
" assert roulette._nrows == 8\n",
" assert roulette._ncols == 9\n",
"\n",
"\n",
"def test_roulette_weights_to_ecdf():\n",
" roulette = Roulette(x_min=0, x_max=10, nrows=10, ncols=11)\n",
" roulette._grid._weights = {0: 2, 1: 6, 2: 10, 3: 10, 4: 7, 5: 3, 6: 1, 7: 1, 8: 1, 9: 1}\n",
" x_vals, cum_sum, probabilities, mean, std, filled_columns = roulette._weights_to_ecdf()\n",
" assert len(x_vals) == 10\n",
" assert len(cum_sum) == 10\n",
" assert len(probabilities) == 10\n",
" assert filled_columns == 10\n",
"\n",
"\n",
"def test_roulette_on_leave_fig():\n",
" roulette = Roulette(x_min=0, x_max=10, nrows=10, ncols=11)\n",
" roulette._grid._weights = {0: 2, 1: 6, 2: 10, 3: 10, 4: 7, 5: 3, 6: 1, 7: 1, 8: 1, 9: 1}\n",
" roulette._widgets['w_distributions'].value = [\"Gamma\", \"LogNormal\", \"StudentT\", \"BetaScaled\", \"Normal\"]\n",
" roulette._widgets['w_repr'].value = \"pdf\"\n",
" roulette._on_leave_fig()\n",
" assert roulette.dist is not None\n",
" assert roulette.hist is not None"
]
}
],
Expand Down
23 changes: 0 additions & 23 deletions preliz/tests/test_roulette.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,5 @@
from test_helper import run_notebook
from preliz.unidimensional.roulette import create_figure, create_grid, Rectangles, on_leave_fig


def test_roulette():
run_notebook("roulette.ipynb")


def test_roulette_mock():
x_min = 0
x_max = 10
ncols = 10
nrows = 10

fig, ax_grid, ax_fit = create_figure((10, 9))
coll = create_grid(x_min, x_max, nrows, ncols, ax=ax_grid)
grid = Rectangles(fig, coll, nrows, ncols, ax_grid)
grid.weights = {0: 2, 1: 6, 2: 10, 3: 10, 4: 7, 5: 3, 6: 1, 7: 1, 8: 1, 9: 1}
w_repr = "kde"
distributions = ["Gamma", "LogNormal", "StudentT", "BetaScaled", "Normal"]

for idx, dist in enumerate(distributions):
w_distributions = distributions[idx:]

fitted_dist = on_leave_fig(
fig.canvas, grid, w_distributions, w_repr, x_min, x_max, ncols, "", ax_fit
)
assert fitted_dist.__class__.__name__ == dist
4 changes: 2 additions & 2 deletions preliz/unidimensional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,6 @@
from .mle import mle
from .quartile import quartile
from .quartile_int import quartile_int
from .roulette import roulette
from .roulette import Roulette

__all__ = ["beta_mode", "maxent", "mle", "roulette", "quartile", "quartile_int"]
__all__ = ["beta_mode", "maxent", "mle", "Roulette", "quartile", "quartile_int"]
Loading

0 comments on commit cf3cc2b

Please sign in to comment.