Skip to content

Commit

Permalink
bambi parser (#300)
Browse files Browse the repository at this point in the history
* draft bambi parser

* draft bambi parser

* fix linter
  • Loading branch information
aloctavodia authored Nov 29, 2023
1 parent 5823441 commit 60480f8
Show file tree
Hide file tree
Showing 3 changed files with 133 additions and 52 deletions.
129 changes: 95 additions & 34 deletions preliz/internal/parser.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import importlib
import inspect
import re
from sys import modules
Expand All @@ -22,7 +23,8 @@ def parse_function_for_pred_textboxes(source, signature):
slidify = list(signature.parameters.keys())
regex = r"\b" + r"\b|\b".join(slidify) + r"\b"

matches = match_preliz_dist(source)
all_dist_str = dist_as_str()
matches = match_preliz_dist(all_dist_str, source, "preliz")

for match in matches:
dist_name_str = match.group(2)
Expand Down Expand Up @@ -54,53 +56,112 @@ def parse_arguments(lst, regex):
return result


def get_prior_pp_samples(fmodel, draws, values=None):
def get_prior_pp_samples(fmodel, variables, draws, engine=None, values=None):
if values is None:
values = []

match = match_return_variables(fmodel)
if match:
variables = [var.strip() for var in match.group(1).split(",")]

obs_rv = variables[-1] # only one observed for the moment
pp_samples_ = []
prior_samples_ = {name: [] for name in variables[:-1]}
for _ in range(draws):
for name, value in zip(variables, fmodel(*values)):
if name == obs_rv:
pp_samples_.append(value)
else:
prior_samples_[name].append(value)
if engine == "preliz":
obs_rv = variables[-1] # only one observed for the moment
pp_samples_ = []
prior_samples_ = {name: [] for name in variables[:-1]}
for _ in range(draws):
for name, value in zip(variables, fmodel(*values)):
if name == obs_rv:
pp_samples_.append(value)
else:
prior_samples_[name].append(value)

pp_samples = np.stack(pp_samples_)
prior_samples = {key: np.array(val) for key, val in prior_samples_.items()}
pp_samples = np.stack(pp_samples_)
prior_samples = {key: np.array(val) for key, val in prior_samples_.items()}
elif engine == "bambi":
*prior_samples_, pp_samples = fmodel(*values)
prior_samples = {name: np.array(val) for name, val in zip(variables[:-1], prior_samples_)}

return pp_samples, prior_samples, obs_rv
return pp_samples, prior_samples


def parse_function_for_ppa(source, obs_rv):
model = {}
def from_preliz(fmodel):
source = inspect.getsource(fmodel)
variables = match_return_variables(source)
# Find the priors we want to change
all_dist_str = dist_as_str()
matches = match_preliz_dist(all_dist_str, source, "preliz")
# Create a dictionary with the priors
model = dict_model(matches, variables)

matches = match_preliz_dist(source)
for match in matches:
var_name = match.group(0).split("=")[0].strip()
if var_name != obs_rv:
dist = getattr(modules["preliz.distributions"], match.group(2))
model[var_name] = dist()
return variables, model

return model

def from_bambi(fmodel, draws):
module_name = fmodel.__module__
module = importlib.import_module(module_name)

def match_preliz_dist(source):
all_distributions = modules["preliz.distributions"].__all__
all_dist_str = "|".join(all_distributions)
# Get the source code of the original function
original_source = inspect.getsource(fmodel)

# Define a pattern to find the line where the model is built
pattern = re.compile(r"(\s+)([a-zA-Z_]\w*)\s*=\s*.*?Model(.*)")

# Find the match in the source code
match = pattern.search(original_source)

# Extract the indentation and variable name
indentation = match.group(1)
variable_name = match.group(2)

# Find the variables after the return statement
return_variables = match_return_variables(original_source)

if return_variables:
# Build the new source code
new_source = original_source.replace(
match.group(0),
f"{match.group(0)}"
f"{indentation}{variable_name}.build()\n"
f"{indentation}variables = [{variable_name}.backend.model.named_vars[v] "
f"for v in {return_variables}]\n"
f'{indentation}{", ".join(return_variables)} = pm.draw(variables, draws={draws})',
)

# Find the priors we want to change
all_dist_str = dist_as_str()
matches = match_preliz_dist(all_dist_str, new_source, "bambi")
# Create a dictionary with the priors
model = dict_model(matches, return_variables)

regex = rf"(.*?({all_dist_str}).*?)\(([^()]*(?:\([^()]*\)[^()]*)*)\)"
# Execute the new source code to redefine the function
exec(new_source, module.__dict__) # pylint: disable=exec-used
modified_fmodel = getattr(module, fmodel.__name__)

return modified_fmodel, return_variables, model


def match_preliz_dist(all_dist_str, source, engine):
if engine == "preliz":
regex = rf"(.*?({all_dist_str}).*?)\(([^()]*(?:\([^()]*\)[^()]*)*)\)"
if engine == "bambi":
regex = rf'(\w+)\s*=\s*(?:\w+\.)?Prior\("({all_dist_str})",\s*((?:\w+=\w+(?:,?\s*)?)*)\s*\)'
matches = re.finditer(regex, source)
return matches


def match_return_variables(fmodel):
source = inspect.getsource(fmodel)
def match_return_variables(source):
match = re.search(r"return (\w+(\s*,\s*\w+)*)", source)
return match
return [var.strip() for var in match.group(1).split(",")]


def dist_as_str():
all_distributions = modules["preliz.distributions"].__all__
return "|".join(all_distributions)


def dict_model(matches, return_variables):
model = {}
obs_rv = return_variables[-1]
for match in matches:
var_name = match.group(0).split("=")[0].strip()
if var_name != obs_rv:
dist = getattr(modules["preliz.distributions"], match.group(2))
model[var_name] = dist()

return model
25 changes: 16 additions & 9 deletions preliz/predictive/ppa.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,17 @@
plot_pp_samples,
plot_pp_mean,
)
from ..internal.parser import inspect_source, parse_function_for_ppa, get_prior_pp_samples
from ..internal.parser import get_prior_pp_samples, from_preliz, from_bambi
from ..internal.predictive_helper import back_fitting, select_prior_samples
from ..distributions.continuous import Normal
from ..distributions.distributions import Distribution

_log = logging.getLogger("preliz")


def ppa(fmodel, draws=2000, references=0, boundaries=(-np.inf, np.inf), target=None):
def ppa(
fmodel, draws=2000, references=0, boundaries=(-np.inf, np.inf), target=None, engine="preliz"
):
"""
Prior predictive check assistant.
Expand All @@ -46,6 +48,8 @@ def ppa(fmodel, draws=2000, references=0, boundaries=(-np.inf, np.inf), target=N
Target distribution. The first shown distributions will be selected to be as close
as possible to `target`. Available options are, a PreliZ distribution or a 2-tuple with
the first element representing the mean and the second the standard deviation.
engine : str
Library used to define the model. Either `preliz` or `bambi`. Defaults to `preliz`
"""
check_inside_notebook(need_widget=True)

Expand All @@ -54,7 +58,7 @@ def ppa(fmodel, draws=2000, references=0, boundaries=(-np.inf, np.inf), target=N
if isinstance(references, (float, int)):
references = [references]

filter_dists = FilterDistribution(fmodel, draws, references, boundaries, target)
filter_dists = FilterDistribution(fmodel, draws, references, boundaries, target, engine)
filter_dists()

output = widgets.Output()
Expand Down Expand Up @@ -117,19 +121,19 @@ def click(event):


class FilterDistribution: # pylint:disable=too-many-instance-attributes
def __init__(self, fmodel, draws, references, boundaries, target):
def __init__(self, fmodel, draws, references, boundaries, target, engine):
self.fmodel = fmodel
self.source = "" # string representation of the model
self.draws = draws
self.references = references
self.boundaries = boundaries
self.target = target
self.engine = engine
self.pp_samples = None # prior predictive samples
self.prior_samples = None # prior samples used for backfitting
self.display_pp_idxs = None # indices of the pp_samples to be displayed
self.pp_octiles = None # octiles computed from pp_samples
self.kdt = None # KDTree used to find similar distributions
self.obs_rv = None # name of the observed random variable
self.model = None # parsed model used for backfitting
self.clicked = [] # axes clicked by the user
self.choices = [] # indices of the pp_samples selected by the user and not yet used to
Expand All @@ -144,11 +148,14 @@ def __init__(self, fmodel, draws, references, boundaries, target):

def __call__(self):

self.pp_samples, self.prior_samples, self.obs_rv = get_prior_pp_samples(
self.fmodel, self.draws
if self.engine == "preliz":
variables, self.model = from_preliz(self.fmodel)
elif self.engine == "bambi":
self.fmodel, variables, self.model = from_bambi(self.fmodel, self.draws)

self.pp_samples, self.prior_samples = get_prior_pp_samples(
self.fmodel, variables, self.draws, self.engine
)
self.source, _ = inspect_source(self.fmodel)
self.model = parse_function_for_ppa(self.source, self.obs_rv)

if self.target is not None:
self.add_target_dist()
Expand Down
31 changes: 22 additions & 9 deletions preliz/predictive/predictive_finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@
pass

from ..internal.plot_helper import create_figure, check_inside_notebook, plot_repr, reset_dist_panel
from ..internal.parser import inspect_source, parse_function_for_ppa, get_prior_pp_samples
from ..internal.parser import get_prior_pp_samples, from_bambi, from_preliz
from ..internal.predictive_helper import back_fitting, select_prior_samples

_log = logging.getLogger("preliz")


def predictive_finder(fmodel, target, draws=100, steps=5, figsize=None):
def predictive_finder(fmodel, target, draws=100, steps=5, engine="preliz", figsize=None):
"""
Prior predictive finder.
Expand All @@ -39,6 +39,8 @@ def predictive_finder(fmodel, target, draws=100, steps=5, figsize=None):
initial guess. If your initial prior predictive distribution is far from the target
distribution you may need to increase the number of steps. Alternatively, you can
click on the figure or press the `carry on` button many times.
engine : str
Library used to define the model. Either `preliz` or `bambi`. Defaults to `preliz`.
figsize : tuple
Figure size. If None, the default is (8, 6).
"""
Expand All @@ -54,7 +56,7 @@ def predictive_finder(fmodel, target, draws=100, steps=5, figsize=None):

button_carry_on, button_return_prior, w_repr = get_widgets()

match_distribution = MatchDistribution(fig, fmodel, target, draws, steps, ax_fit)
match_distribution = MatchDistribution(fig, fmodel, target, draws, steps, engine, ax_fit)

plot_pp_samples(match_distribution.pp_samples, draws, target, w_repr.value, fig, ax_fit)
fig.suptitle(
Expand Down Expand Up @@ -102,32 +104,43 @@ def on_return_prior(fig, ax_fit):


class MatchDistribution: # pylint:disable=too-many-instance-attributes
def __init__(self, fig, fmodel, target, draws, steps, ax):
def __init__(self, fig, fmodel, target, draws, steps, engine, ax):
self.fig = fig
self.fmodel = fmodel
self.target = target
self.target_octiles = self.target.ppf([0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875])
self.draws = draws
self.steps = steps
self.engine = engine
self.ax = ax
self.pp_samples, _, obs_rv = get_prior_pp_samples(self.fmodel, self.draws)
self.model = parse_function_for_ppa(inspect_source(self.fmodel)[0], obs_rv)
self.values = None
self.string = None

if self.engine == "preliz":
self.variables, self.model = from_preliz(self.fmodel)
elif self.engine == "bambi":
self.fmodel, self.variables, self.model = from_bambi(self.fmodel, self.draws)

self.pp_samples, _ = get_prior_pp_samples(
self.fmodel, self.variables, self.draws, self.engine
)

def __call__(self, kind_plot):
self.fig.texts = []

for _ in range(self.steps):
pp_samples, prior_samples, _ = get_prior_pp_samples(
self.fmodel, self.draws, self.values
pp_samples, prior_samples = get_prior_pp_samples(
self.fmodel, self.variables, self.draws, self.engine, self.values
)
values_to_fit = select(
prior_samples, pp_samples, self.draws, self.target_octiles, self.model
)
self.string, self.values = back_fitting(self.model, values_to_fit, new_families=False)

self.pp_samples = [self.fmodel(*self.values)[-1] for _ in range(self.draws)]
if self.engine == "preliz":
self.pp_samples = [self.fmodel(*self.values)[-1] for _ in range(self.draws)]
elif self.engine == "bambi":
self.pp_samples = self.fmodel(*self.values)[-1]

reset_dist_panel(self.ax, True)
plot_repr(self.pp_samples, kind_plot, self.draws, self.ax)
Expand Down

0 comments on commit 60480f8

Please sign in to comment.