Skip to content

Commit

Permalink
ref
Browse files Browse the repository at this point in the history
  • Loading branch information
aloctavodia committed Nov 2, 2024
1 parent ae4056e commit 684563b
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 21 deletions.
27 changes: 13 additions & 14 deletions preliz/ppls/agnostic.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,9 @@


from preliz import distributions
from preliz.internal.distribution_helper import init_vals, get_distributions
from preliz.internal.distribution_helper import init_vals
from preliz.internal.plot_helper import plot_repr


from preliz.distributions import Gamma, Normal, HalfNormal
from preliz.unidimensional.mle import mle
from preliz.ppls.pymc_io import get_model_information, write_pymc_string
from preliz.ppls.bambi_io import (
Expand All @@ -33,7 +32,7 @@
pass


def posterior_to_prior(model, idata, alternative=None, engine="auto"):
def posterior_to_prior(model, idata, new_families=None, engine="auto"):
"""
Fit a posterior from a model to its prior
Expand All @@ -46,7 +45,7 @@ def posterior_to_prior(model, idata, alternative=None, engine="auto"):
model : A PyMC or a Bambi Model
idata : InferenceData
InferenceData with a posterior group.
alternative : "auto", list or dict
new_families : "auto", list or dict
Defaults to None, the samples are fit to the original prior distribution.
If "auto", the method evaluates the fit to the original prior plus a set of
predefined distributions.
Expand All @@ -68,7 +67,7 @@ def posterior_to_prior(model, idata, alternative=None, engine="auto"):

_, _, preliz_model, _, untransformed_var_info, *_ = get_model_information(model)

new_priors = back_fitting_idata(idata, preliz_model, alternative)
new_priors = back_fitting_idata(idata, preliz_model, new_families)

if engine == "bambi":
new_model = write_bambi_string(new_priors, untransformed_var_info)
Expand All @@ -78,25 +77,25 @@ def posterior_to_prior(model, idata, alternative=None, engine="auto"):
return new_model


def back_fitting_idata(idata, model_info, alternative):
def back_fitting_idata(idata, model_info, new_families):
new_priors = {}
posterior = idata.posterior.stack(sample=("chain", "draw"))

if alternative is None:
if new_families is None:
for var, dist in model_info.items():
idx, _ = mle([dist], posterior[var].values, plot=False)
new_priors[var] = dist
else:
for var, dist in model_info.items():
dists = [dist]

if alternative == "auto":
alt = get_distributions(["Normal", "HalfNormal", "Gamma"])
if new_families == "auto":
alt = [Normal(), HalfNormal(), Gamma()]
dists += [a for a in alt if dist.__class__.__name__ != a.__class__.__name__]
elif isinstance(alternative, list):
dists += alternative
elif isinstance(alternative, dict):
dists += alternative.get(var, [])
elif isinstance(new_families, list):
dists += new_families
elif isinstance(new_families, dict):
dists += new_families.get(var, [])

idx, _ = mle(dists, posterior[var].values, plot=False)
new_priors[var] = dists[idx[0]]
Expand Down
2 changes: 1 addition & 1 deletion preliz/predictive/ppe.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def ppe(model, target, method="projective", engine="auto", random_state=0):
with model:
idata = fit(method="pathfinder", num_samples=1000)

new_priors = back_fitting_idata(idata, preliz_model, alternative=False)
new_priors = back_fitting_idata(idata, preliz_model, new_families=False)
if engine == "bambi":
new_model = write_bambi_string(new_priors, untransformed_var_info)
elif engine == "pymc":
Expand Down
12 changes: 6 additions & 6 deletions preliz/tests/test_posterior_to_prior.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@

def test_p2p_pymc():
pz.posterior_to_prior(model, idata)
assert 'Gamma\x1b[0m("b", alpha=' in pz.posterior_to_prior(model, idata, alternative="auto")
pz.posterior_to_prior(model, idata, alternative=[pz.LogNormal()])
assert 'Gamma\x1b[0m("b", alpha=' in pz.posterior_to_prior(model, idata, new_families="auto")
pz.posterior_to_prior(model, idata, new_families=[pz.LogNormal()])
assert 'Gamma\x1b[0m("b", mu=' in pz.posterior_to_prior(
model, idata, alternative={"b": [pz.Gamma(mu=0)]}
model, idata, new_families={"b": [pz.Gamma(mu=0)]}
)


Expand All @@ -39,9 +39,9 @@ def test_p2p_pymc():
def test_p2p_bambi():
pz.posterior_to_prior(bmb_model, bmb_idata)
assert 'Gamma\x1b[0m", alpha=' in pz.posterior_to_prior(
bmb_model, bmb_idata, alternative="auto"
bmb_model, bmb_idata, new_families="auto"
)
pz.posterior_to_prior(bmb_model, bmb_idata, alternative=[pz.LogNormal()])
pz.posterior_to_prior(bmb_model, bmb_idata, new_families=[pz.LogNormal()])
assert 'Normal\x1b[0m", mu=' in pz.posterior_to_prior(
bmb_model, bmb_idata, alternative={"Intercept": [pz.Normal(mu=1, sigma=1)]}
bmb_model, bmb_idata, new_families={"Intercept": [pz.Normal(mu=1, sigma=1)]}
)

0 comments on commit 684563b

Please sign in to comment.