Skip to content

Commit

Permalink
add predictive_finder method (experimental) (#293)
Browse files Browse the repository at this point in the history
* add predictive finder experimental method

* update examples

* make pylint happy
  • Loading branch information
aloctavodia authored Oct 25, 2023
1 parent d0f0cfb commit a5528dd
Show file tree
Hide file tree
Showing 8 changed files with 1,125 additions and 413 deletions.
1,053 changes: 773 additions & 280 deletions docs/examples/observed_space_examples.ipynb

Large diffs are not rendered by default.

7 changes: 5 additions & 2 deletions preliz/internal/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,10 @@ def parse_arguments(lst, regex):
return result


def get_prior_pp_samples(fmodel, draws):
def get_prior_pp_samples(fmodel, draws, values=None):
if values is None:
values = []

match = match_return_variables(fmodel)
if match:
variables = [var.strip() for var in match.group(1).split(",")]
Expand All @@ -63,7 +66,7 @@ def get_prior_pp_samples(fmodel, draws):
pp_samples_ = []
prior_samples_ = {name: [] for name in variables[:-1]}
for _ in range(draws):
for name, value in zip(variables, fmodel()):
for name, value in zip(variables, fmodel(*values)):
if name == obs_rv:
pp_samples_.append(value)
else:
Expand Down
102 changes: 65 additions & 37 deletions preliz/internal/plot_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,47 +405,51 @@ def looper(*args, **kwargs):
_, ax = plt.subplots()
ax.set_xlim(x_min, x_max, auto=auto)

alpha = max(0.01, 1 - iterations * 0.009)

if kind_plot == "hist":
if results[0].dtype.kind == "i":
bins = np.arange(np.min(results), np.max(results) + 1.5) - 0.5
if len(bins) < 30:
ax.set_xticks(bins + 0.5)
else:
bins = "auto"
ax.hist(
results,
alpha=alpha,
density=True,
color=["C0"] * iterations,
bins=bins,
histtype="step",
)
ax.hist(
np.concatenate(results),
density=True,
bins=bins,
color="k",
ls="--",
histtype="step",
)
elif kind_plot == "kde":
for result in results:
plt.plot(*_kde_linear(result, grid_len=100), "C0", alpha=alpha)
plt.plot(*_kde_linear(np.concatenate(results), grid_len=100), "k--")
elif kind_plot == "ecdf":
plt.plot(
np.sort(results, axis=1).T,
np.linspace(0, 1, len(results[0]), endpoint=False),
color="C0",
)
a = np.concatenate(results)
plt.plot(np.sort(a), np.linspace(0, 1, len(a), endpoint=False), "k--")
plot_repr(results, kind_plot, iterations, ax)

return looper


def plot_repr(results, kind_plot, iterations, ax):
    """
    Draw each set of samples in gray and the pooled samples as a dashed black line.

    Parameters
    ----------
    results : sequence of arrays
        One array of samples per iteration. The "ecdf" branch sorts along
        ``axis=1``, so the arrays are assumed to share a length (TODO confirm).
    kind_plot : str
        "hist", "kde" or "ecdf". Any other value draws nothing.
    iterations : int
        Number of sample sets. It sets the transparency of the individual
        curves and the number of colors handed to the histogram call.
    ax : matplotlib axes
        Axes on which everything is drawn.
    """
    # The more sample sets, the fainter each individual curve (never below 0.01).
    alpha = max(0.01, 1 - iterations * 0.009)

    if kind_plot == "hist":
        if results[0].dtype.kind == "i":
            # Integer samples: unit-width bins whose edges sit on half-integers.
            bins = np.arange(np.min(results), np.max(results) + 1.5) - 0.5
            # Only label every bin when there are few enough of them.
            if len(bins) < 30:
                ax.set_xticks(bins + 0.5)
        else:
            bins = "auto"

        step_kwargs = {"density": True, "bins": bins, "histtype": "step"}
        ax.hist(results, alpha=alpha, color=["0.5"] * iterations, **step_kwargs)
        ax.hist(np.concatenate(results), color="k", ls="--", **step_kwargs)
        return

    if kind_plot == "kde":
        for samples in results:
            ax.plot(*_kde_linear(samples, grid_len=100), "0.5", alpha=alpha)
        ax.plot(*_kde_linear(np.concatenate(results), grid_len=100), "k--")
        return

    if kind_plot == "ecdf":
        # One column per sample set after the transpose, so a single call
        # draws every individual ECDF.
        sorted_by_set = np.sort(results, axis=1).T
        heights = np.linspace(0, 1, len(results[0]), endpoint=False)
        ax.plot(sorted_by_set, heights, color="0.5")

        pooled = np.concatenate(results)
        ax.plot(np.sort(pooled), np.linspace(0, 1, len(pooled), endpoint=False), "k--")


def plot_pp_samples(pp_samples, pp_samples_idxs, references, kind="pdf", sharex=True, fig=None):
row_colum = int(np.ceil(len(pp_samples_idxs) ** 0.5))

Expand Down Expand Up @@ -576,3 +580,27 @@ def representations(fitted_dist, kind_plot, ax):
elif kind_plot == "ppf":
fitted_dist.plot_ppf(pointinterval=True, legend="title", ax=ax)
ax.set_xlim(-0.01, 1)


def create_figure(figsize):
    """
    Initialize a matplotlib figure with one subplot.

    The subplot has its y-axis ticks removed. The canvas gets
    ``header_visible`` and ``footer_visible`` set to False and
    ``toolbar_position`` set to "right" (presumably honored by the
    interactive widget canvas -- verify against the backend in use).

    Parameters
    ----------
    figsize : tuple
        Forwarded to ``plt.subplots`` as ``figsize``.

    Returns
    -------
    fig, axes
        The new figure and its single axes.
    """
    fig, axes = plt.subplots(nrows=1, ncols=1, figsize=figsize, constrained_layout=True)
    axes.set_yticks([])

    canvas = fig.canvas
    canvas.header_visible = False
    canvas.footer_visible = False
    canvas.toolbar_position = "right"

    return fig, axes


def reset_dist_panel(ax, yticks):
    """
    Clean the distribution subplot so it can be redrawn.

    Parameters
    ----------
    ax : matplotlib axes
        Axes to clear.
    yticks : bool
        When truthy, the y-axis ticks are removed after clearing.
    """
    # Wipe everything currently drawn on the axes.
    ax.cla()

    if yticks:
        ax.set_yticks([])

    # Recompute the data limits and rescale the view to match.
    ax.relim()
    ax.autoscale_view()
70 changes: 70 additions & 0 deletions preliz/internal/predictive_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
from sys import modules

import numpy as np


from .plot_helper import (
repr_to_matplotlib,
)

from ..unidimensional import mle


def back_fitting(model, subset, new_families=True):
    """
    Use MLE to fit a subset of the prior samples to the marginal prior distributions.

    Parameters
    ----------
    model : dict
        Maps each variable name to its distribution. Every distribution is
        refitted in place through ``_fit_mle``.
    subset : dict
        Maps each variable name to the selected prior samples.
    new_families : bool
        When True, the report gains a second section. Distributions whose
        class name is excluded by ``get_distributions`` keep their family;
        for the rest, the candidate chosen through ``mle`` is reported.

    Returns
    -------
    str
        Report with one ``name = distribution`` line per variable.
    array
        Concatenated ``params`` of the distributions stored in ``model``.
    """
    report = ["Your selection is consistent with the priors (original families):\n"]

    for name, dist in model.items():
        dist._fit_mle(subset[name])
        report.append(f"{name} = {repr_to_matplotlib(dist)}\n")

    if new_families:
        report.append("\nYour selection is consistent with the priors (new families):\n")

        exclude, candidates = get_distributions()
        for name, dist in model.items():
            samples = subset[name]
            if dist.__class__.__name__ in exclude:
                dist._fit_mle(samples)
                chosen = dist
            else:
                # NOTE(review): the first index returned by mle is presumably
                # the best-fitting candidate -- confirm against mle.
                ranking, _ = mle(candidates, samples, plot=False)
                chosen = candidates[ranking[0]]
            report.append(f"{name} = {repr_to_matplotlib(chosen)}\n")

    # The parameters come from the distributions kept in ``model``; a new
    # family picked above is only reported, never stored.
    fitted_params = np.concatenate([dist.params for dist in model.values()])
    return "".join(report), fitted_params


def get_distributions():
    """
    Build the list of candidate distributions.

    Every name in ``preliz.distributions.__all__`` is instantiated without
    arguments, and the instances whose class name is not excluded are kept.

    Returns
    -------
    exclude : list of str
        Class names that are left out of the candidates.
    distributions : list
        Default-constructed instances of the remaining distributions.
    """
    exclude = [
        "Beta",
        "BetaScaled",
        "Triangular",
        "TruncatedNormal",
        "Uniform",
        "VonMises",
        "Categorical",
        "DiscreteUniform",
        "HyperGeometric",
        # Was spelled "zeroInflatedBinomial", which never matched the class name.
        "ZeroInflatedBinomial",
        "ZeroInflatedNegativeBinomial",
        "ZeroInflatedPoisson",
        "MvNormal",
    ]
    # Looked up through sys.modules; the package must already be imported.
    dist_module = modules["preliz.distributions"]
    instances = (getattr(dist_module, a_dist)() for a_dist in dist_module.__all__)
    distributions = [dist for dist in instances if dist.__class__.__name__ not in exclude]
    return exclude, distributions


def select_prior_samples(selected, prior_samples, model):
    """
    Pick the prior samples that correspond to the selected prior predictive samples.

    Parameters
    ----------
    selected : index
        Positions of the chosen samples, used to index every array of
        ``prior_samples``.
    prior_samples : dict
        Maps each variable name to its prior samples.
    model : dict
        Its keys name the variables to extract.

    Returns
    -------
    dict
        One entry per key of ``model`` with the selected prior samples.
    """
    subsample = {}
    for rv in model.keys():
        subsample[rv] = prior_samples[rv][selected]
    return subsample
3 changes: 2 additions & 1 deletion preliz/predictive/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .ppa import ppa
from .predictive_explorer import predictive_explorer
from .predictive_finder import predictive_finder


__all__ = ["ppa", "predictive_explorer"]
__all__ = ["ppa", "predictive_explorer", "predictive_finder"]
69 changes: 3 additions & 66 deletions preliz/predictive/ppa.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import logging
from random import shuffle
from sys import modules

try:
import ipywidgets as widgets
Expand All @@ -17,12 +16,11 @@
check_inside_notebook,
plot_pp_samples,
plot_pp_mean,
repr_to_matplotlib,
)
from ..internal.parser import inspect_source, parse_function_for_ppa, get_prior_pp_samples
from ..internal.predictive_helper import back_fitting, select_prior_samples
from ..distributions.continuous import Normal
from ..distributions.distributions import Distribution
from ..unidimensional import mle

_log = logging.getLogger("preliz")

Expand All @@ -33,12 +31,11 @@ def ppa(
"""
Prior predictive check assistant.
This is experimental
This is an experimental method under development, use with caution.
Parameters
----------
model : PreliZ model
Model associated to ``idata``.
draws : int
Number of draws from the prior and prior predictive distribution
summary : str
Expand Down Expand Up @@ -176,7 +173,7 @@ def on_return_prior(fig, selected, model):
if len(selected) > 4:
subsample = select_prior_samples(selected, prior_samples, model)

string = back_fitting(model, subsample)
string, _ = back_fitting(model, subsample)

fig.clf()
plt.text(0.05, 0.5, string, fontsize=14)
Expand Down Expand Up @@ -412,63 +409,3 @@ def collect_more_samples(
return selected, shown
else:
return selected, shown


def select_prior_samples(selected, prior_samples, model):
    """
    Pick the prior samples that correspond to the selected prior predictive samples.

    Parameters
    ----------
    selected : index
        Positions of the chosen samples, used to index every array of
        ``prior_samples``.
    prior_samples : dict
        Maps each variable name to its prior samples.
    model : dict
        Its keys name the variables to extract.

    Returns
    -------
    dict
        One entry per key of ``model`` with the selected prior samples.
    """
    subsample = {}
    for rv in model.keys():
        subsample[rv] = prior_samples[rv][selected]
    return subsample


def back_fitting(model, subset, new_families=True):
    """
    Use MLE to fit a subset of the prior samples to the marginal prior distributions.

    Parameters
    ----------
    model : dict
        Maps each variable name to its distribution. Every distribution is
        refitted in place through ``_fit_mle``.
    subset : dict
        Maps each variable name to the selected prior samples.
    new_families : bool
        When True, the report gains a second section. Distributions whose
        class name is excluded by ``get_distributions`` keep their family;
        for the rest, the candidate chosen through ``mle`` is reported.

    Returns
    -------
    str
        Report with one ``name = distribution`` line per variable.
    """
    report = ["Your selection is consistent with the priors (original families):\n"]

    for name, dist in model.items():
        dist._fit_mle(subset[name])
        report.append(f"{name} = {repr_to_matplotlib(dist)}\n")

    if new_families:
        report.append("\nYour selection is consistent with the priors (new families):\n")

        exclude, candidates = get_distributions()
        for name, dist in model.items():
            samples = subset[name]
            if dist.__class__.__name__ in exclude:
                dist._fit_mle(samples)
                chosen = dist
            else:
                # NOTE(review): the first index returned by mle is presumably
                # the best-fitting candidate -- confirm against mle.
                ranking, _ = mle(candidates, samples, plot=False)
                chosen = candidates[ranking[0]]
            report.append(f"{name} = {repr_to_matplotlib(chosen)}\n")

    return "".join(report)


def get_distributions():
    """
    Build the list of candidate distributions.

    Every name in ``preliz.distributions.__all__`` is instantiated without
    arguments, and the instances whose class name is not excluded are kept.

    Returns
    -------
    exclude : list of str
        Class names that are left out of the candidates.
    distributions : list
        Default-constructed instances of the remaining distributions.
    """
    exclude = [
        "Beta",
        "BetaScaled",
        "Triangular",
        "TruncatedNormal",
        "Uniform",
        "VonMises",
        "Categorical",
        "DiscreteUniform",
        "HyperGeometric",
        # Was spelled "zeroInflatedBinomial", which never matched the class name.
        "ZeroInflatedBinomial",
        "ZeroInflatedNegativeBinomial",
        "ZeroInflatedPoisson",
        "MvNormal",
    ]
    # Looked up through sys.modules; the package must already be imported.
    dist_module = modules["preliz.distributions"]
    instances = (getattr(dist_module, a_dist)() for a_dist in dist_module.__all__)
    distributions = [dist for dist in instances if dist.__class__.__name__ not in exclude]
    return exclude, distributions
Loading

0 comments on commit a5528dd

Please sign in to comment.