diff --git a/preliz/internal/parser.py b/preliz/internal/parser.py index 32e0b0fd..73aed135 100644 --- a/preliz/internal/parser.py +++ b/preliz/internal/parser.py @@ -1,3 +1,4 @@ +import importlib import inspect import re from sys import modules @@ -22,7 +23,8 @@ def parse_function_for_pred_textboxes(source, signature): slidify = list(signature.parameters.keys()) regex = r"\b" + r"\b|\b".join(slidify) + r"\b" - matches = match_preliz_dist(source) + all_dist_str = dist_as_str() + matches = match_preliz_dist(all_dist_str, source, "preliz") for match in matches: dist_name_str = match.group(2) @@ -54,53 +56,112 @@ def parse_arguments(lst, regex): return result -def get_prior_pp_samples(fmodel, draws, values=None): +def get_prior_pp_samples(fmodel, variables, draws, engine=None, values=None): if values is None: values = [] - match = match_return_variables(fmodel) - if match: - variables = [var.strip() for var in match.group(1).split(",")] - - obs_rv = variables[-1] # only one observed for the moment - pp_samples_ = [] - prior_samples_ = {name: [] for name in variables[:-1]} - for _ in range(draws): - for name, value in zip(variables, fmodel(*values)): - if name == obs_rv: - pp_samples_.append(value) - else: - prior_samples_[name].append(value) + if engine == "preliz": + obs_rv = variables[-1] # only one observed for the moment + pp_samples_ = [] + prior_samples_ = {name: [] for name in variables[:-1]} + for _ in range(draws): + for name, value in zip(variables, fmodel(*values)): + if name == obs_rv: + pp_samples_.append(value) + else: + prior_samples_[name].append(value) - pp_samples = np.stack(pp_samples_) - prior_samples = {key: np.array(val) for key, val in prior_samples_.items()} + pp_samples = np.stack(pp_samples_) + prior_samples = {key: np.array(val) for key, val in prior_samples_.items()} + elif engine == "bambi": + *prior_samples_, pp_samples = fmodel(*values) + prior_samples = {name: np.array(val) for name, val in zip(variables[:-1], prior_samples_)} - return pp_samples, prior_samples, obs_rv + return pp_samples, prior_samples -def parse_function_for_ppa(source, obs_rv): - model = {} +def from_preliz(fmodel): + source = inspect.getsource(fmodel) + variables = match_return_variables(source) + # Find the priors we want to change + all_dist_str = dist_as_str() + matches = match_preliz_dist(all_dist_str, source, "preliz") + # Create a dictionary with the priors + model = dict_model(matches, variables) - matches = match_preliz_dist(source) - for match in matches: - var_name = match.group(0).split("=")[0].strip() - if var_name != obs_rv: - dist = getattr(modules["preliz.distributions"], match.group(2)) - model[var_name] = dist() + return variables, model - return model +def from_bambi(fmodel, draws): + module_name = fmodel.__module__ + module = importlib.import_module(module_name) -def match_preliz_dist(source): - all_distributions = modules["preliz.distributions"].__all__ - all_dist_str = "|".join(all_distributions) + # Get the source code of the original function + original_source = inspect.getsource(fmodel) + + # Define a pattern to find the line where the model is built + pattern = re.compile(r"(\s+)([a-zA-Z_]\w*)\s*=\s*.*?Model(.*)") + + # Find the match in the source code + match = pattern.search(original_source) + + # Extract the indentation and variable name + indentation = match.group(1) + variable_name = match.group(2) + + # Find the variables after the return statement + return_variables = match_return_variables(original_source) + + if return_variables: + # Build the new source code + new_source = original_source.replace( + match.group(0), + f"{match.group(0)}" + f"{indentation}{variable_name}.build()\n" + f"{indentation}variables = [{variable_name}.backend.model.named_vars[v] " + f"for v in {return_variables}]\n" + f'{indentation}{", ".join(return_variables)} = pm.draw(variables, draws={draws})', + ) + + # Find the priors we want to change + all_dist_str = dist_as_str() + matches = match_preliz_dist(all_dist_str, new_source, "bambi") + # Create a dictionary with the priors + model = dict_model(matches, return_variables) - regex = rf"(.*?({all_dist_str}).*?)\(([^()]*(?:\([^()]*\)[^()]*)*)\)" + # Execute the new source code to redefine the function + exec(new_source, module.__dict__) # pylint: disable=exec-used + modified_fmodel = getattr(module, fmodel.__name__) + + return modified_fmodel, return_variables, model + + +def match_preliz_dist(all_dist_str, source, engine): + if engine == "preliz": + regex = rf"(.*?({all_dist_str}).*?)\(([^()]*(?:\([^()]*\)[^()]*)*)\)" + if engine == "bambi": + regex = rf'(\w+)\s*=\s*(?:\w+\.)?Prior\("({all_dist_str})",\s*((?:\w+=\w+(?:,?\s*)?)*)\s*\)' matches = re.finditer(regex, source) return matches -def match_return_variables(fmodel): - source = inspect.getsource(fmodel) +def match_return_variables(source): match = re.search(r"return (\w+(\s*,\s*\w+)*)", source) - return match + return [var.strip() for var in match.group(1).split(",")] + + +def dist_as_str(): + all_distributions = modules["preliz.distributions"].__all__ + return "|".join(all_distributions) + + +def dict_model(matches, return_variables): + model = {} + obs_rv = return_variables[-1] + for match in matches: + var_name = match.group(0).split("=")[0].strip() + if var_name != obs_rv: + dist = getattr(modules["preliz.distributions"], match.group(2)) + model[var_name] = dist() + + return model diff --git a/preliz/predictive/ppa.py b/preliz/predictive/ppa.py index e087c271..be1245b6 100644 --- a/preliz/predictive/ppa.py +++ b/preliz/predictive/ppa.py @@ -17,7 +17,7 @@ plot_pp_samples, plot_pp_mean, ) -from ..internal.parser import inspect_source, parse_function_for_ppa, get_prior_pp_samples +from ..internal.parser import get_prior_pp_samples, from_preliz, from_bambi from ..internal.predictive_helper import back_fitting, select_prior_samples from ..distributions.continuous import Normal from ..distributions.distributions import Distribution @@ -25,7 +25,9 @@ _log = logging.getLogger("preliz") -def ppa(fmodel, draws=2000, references=0, boundaries=(-np.inf, np.inf), target=None): +def ppa( + fmodel, draws=2000, references=0, boundaries=(-np.inf, np.inf), target=None, engine="preliz" +): """ Prior predictive check assistant. @@ -46,6 +48,8 @@ def ppa(fmodel, draws=2000, references=0, boundaries=(-np.inf, np.inf), target=N Target distribution. The first shown distributions will be selected to be as close as possible to `target`. Available options are, a PreliZ distribution or a 2-tuple with the first element representing the mean and the second the standard deviation. + engine : str + Library used to define the model. Either `preliz` or `bambi`. Defaults to `preliz` """ check_inside_notebook(need_widget=True) @@ -54,7 +58,7 @@ def ppa(fmodel, draws=2000, references=0, boundaries=(-np.inf, np.inf), target=N if isinstance(references, (float, int)): references = [references] - filter_dists = FilterDistribution(fmodel, draws, references, boundaries, target) + filter_dists = FilterDistribution(fmodel, draws, references, boundaries, target, engine) filter_dists() output = widgets.Output() @@ -117,19 +121,19 @@ def click(event): class FilterDistribution: # pylint:disable=too-many-instance-attributes - def __init__(self, fmodel, draws, references, boundaries, target): + def __init__(self, fmodel, draws, references, boundaries, target, engine): self.fmodel = fmodel self.source = "" # string representation of the model self.draws = draws self.references = references self.boundaries = boundaries self.target = target + self.engine = engine self.pp_samples = None # prior predictive samples self.prior_samples = None # prior samples used for backfitting self.display_pp_idxs = None # indices of the pp_samples to be displayed self.pp_octiles = None # octiles computed from pp_samples self.kdt = None # KDTree used to find similar distributions - self.obs_rv = None # name of the observed random variable self.model = None # parsed model used for backfitting self.clicked = [] # axes clicked by the user self.choices = [] # indices of the pp_samples selected by the user and not yet used to @@ -144,11 +148,14 @@ def __init__(self, fmodel, draws, references, boundaries, target): def __call__(self): - self.pp_samples, self.prior_samples, self.obs_rv = get_prior_pp_samples( - self.fmodel, self.draws + if self.engine == "preliz": + variables, self.model = from_preliz(self.fmodel) + elif self.engine == "bambi": + self.fmodel, variables, self.model = from_bambi(self.fmodel, self.draws) + + self.pp_samples, self.prior_samples = get_prior_pp_samples( + self.fmodel, variables, self.draws, self.engine ) - self.source, _ = inspect_source(self.fmodel) - self.model = parse_function_for_ppa(self.source, self.obs_rv) if self.target is not None: self.add_target_dist() diff --git a/preliz/predictive/predictive_finder.py b/preliz/predictive/predictive_finder.py index b6486b08..0008d9d7 100644 --- a/preliz/predictive/predictive_finder.py +++ b/preliz/predictive/predictive_finder.py @@ -10,13 +10,13 @@ pass from ..internal.plot_helper import create_figure, check_inside_notebook, plot_repr, reset_dist_panel -from ..internal.parser import inspect_source, parse_function_for_ppa, get_prior_pp_samples +from ..internal.parser import get_prior_pp_samples, from_bambi, from_preliz from ..internal.predictive_helper import back_fitting, select_prior_samples _log = logging.getLogger("preliz") -def predictive_finder(fmodel, target, draws=100, steps=5, figsize=None): +def predictive_finder(fmodel, target, draws=100, steps=5, engine="preliz", figsize=None): """ Prior predictive finder. @@ -39,6 +39,8 @@ def predictive_finder(fmodel, target, draws=100, steps=5, figsize=None): initial guess. If your initial prior predictive distribution is far from the target distribution you may need to increase the number of steps. Alternatively, you can click on the figure or press the `carry on` button many times. + engine : str + Library used to define the model. Either `preliz` or `bambi`. Defaults to `preliz`. figsize : tuple Figure size. If None, the default is (8, 6). """ @@ -54,7 +56,7 @@ def predictive_finder(fmodel, target, draws=100, steps=5, figsize=None): button_carry_on, button_return_prior, w_repr = get_widgets() - match_distribution = MatchDistribution(fig, fmodel, target, draws, steps, ax_fit) + match_distribution = MatchDistribution(fig, fmodel, target, draws, steps, engine, ax_fit) plot_pp_samples(match_distribution.pp_samples, draws, target, w_repr.value, fig, ax_fit) fig.suptitle( @@ -102,32 +104,43 @@ def on_return_prior(fig, ax_fit): class MatchDistribution: # pylint:disable=too-many-instance-attributes - def __init__(self, fig, fmodel, target, draws, steps, ax): + def __init__(self, fig, fmodel, target, draws, steps, engine, ax): self.fig = fig self.fmodel = fmodel self.target = target self.target_octiles = self.target.ppf([0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875]) self.draws = draws self.steps = steps + self.engine = engine self.ax = ax - self.pp_samples, _, obs_rv = get_prior_pp_samples(self.fmodel, self.draws) - self.model = parse_function_for_ppa(inspect_source(self.fmodel)[0], obs_rv) self.values = None self.string = None + if self.engine == "preliz": + self.variables, self.model = from_preliz(self.fmodel) + elif self.engine == "bambi": + self.fmodel, self.variables, self.model = from_bambi(self.fmodel, self.draws) + + self.pp_samples, _ = get_prior_pp_samples( + self.fmodel, self.variables, self.draws, self.engine + ) + def __call__(self, kind_plot): self.fig.texts = [] for _ in range(self.steps): - pp_samples, prior_samples, _ = get_prior_pp_samples( - self.fmodel, self.draws, self.values + pp_samples, prior_samples = get_prior_pp_samples( + self.fmodel, self.variables, self.draws, self.engine, self.values ) values_to_fit = select( prior_samples, pp_samples, self.draws, self.target_octiles, self.model ) self.string, self.values = back_fitting(self.model, values_to_fit, new_families=False) - self.pp_samples = [self.fmodel(*self.values)[-1] for _ in range(self.draws)] + if self.engine == "preliz": + self.pp_samples = [self.fmodel(*self.values)[-1] for _ in range(self.draws)] + elif self.engine == "bambi": + self.pp_samples = self.fmodel(*self.values)[-1] reset_dist_panel(self.ax, True) plot_repr(self.pp_samples, kind_plot, self.draws, self.ax)