Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor #580

Merged
merged 3 commits into from
Nov 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 1 addition & 9 deletions preliz/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

Tools to help you pick a prior
"""
import logging
from os import path as os_path

from matplotlib import rcParams
Expand All @@ -18,13 +17,6 @@

__version__ = "0.11.0"

_log = logging.getLogger("preliz")

if not logging.root.handlers:
_log.setLevel(logging.INFO)
if len(_log.handlers) == 0:
handler = logging.StreamHandler()
_log.addHandler(handler)

# Allow legend outside plot in maxent to be included when saving a figure
# We may want to make this more explicit by having preliz.rcParams
Expand All @@ -37,4 +29,4 @@
style.core.reload_library()

# clean namespace
del logging, os_path, rcParams, _preliz_style_path, _log
del os_path, rcParams, _preliz_style_path
29 changes: 3 additions & 26 deletions preliz/internal/distribution_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def from_precision(precision):


def to_precision(sigma):
precision = 1 / sigma**2
precision = 1 / (eps + sigma**2)
return precision


Expand Down Expand Up @@ -148,38 +148,15 @@ def num_kurtosis(dist):
}


def get_distributions(dist_names=None, exclude=None):
def get_distributions(dist_names=None):

if dist_names is None:
all_distributions = modules["preliz.distributions"].__all__
else:
all_distributions = dist_names

if exclude is None:
exclude = []
if exclude == "auto":
exclude = [
"Beta",
"BetaScaled",
"Triangular",
"TruncatedNormal",
"Uniform",
"VonMises",
"Categorical",
"DiscreteUniform",
"HyperGeometric",
"zeroInflatedBinomial",
"ZeroInflatedNegativeBinomial",
"ZeroInflatedPoisson",
"MvNormal",
"Mixture",
]

distributions = []
for a_dist in all_distributions:
dist = getattr(modules["preliz.distributions"], a_dist)()
if dist.__class__.__name__ not in exclude:
distributions.append(dist)
if exclude:
return exclude, distributions
distributions.append(dist)
return distributions
13 changes: 0 additions & 13 deletions preliz/internal/logging.py

This file was deleted.

208 changes: 0 additions & 208 deletions preliz/internal/parser.py

This file was deleted.

61 changes: 1 addition & 60 deletions preliz/internal/plot_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,14 @@
try:
from IPython import get_ipython
from ipywidgets import FloatSlider, IntSlider, FloatText, IntText, Checkbox, ToggleButton
from pymc import sample_prior_predictive
except ImportError:
pass

import numpy as np
import matplotlib.pyplot as plt
from matplotlib import _pylab_helpers, get_backend
from matplotlib.ticker import MaxNLocator
from .logging import disable_pymc_sampling_logs
from .narviz import hdi, kde
from preliz.internal.narviz import hdi, kde


def plot_pointinterval(distribution, interval="hdi", levels=None, rotated=False, ax=None):
Expand Down Expand Up @@ -425,63 +423,6 @@ def looper(*args, **kwargs):
return looper


def bambi_plot_decorator(func, iterations, kind_plot, references, plot_func):
def looper(*args, **kwargs):
kwargs.pop("__resample__")
x_min = kwargs.pop("__x_min__")
x_max = kwargs.pop("__x_max__")
if not kwargs.pop("__set_xlim__"):
x_min = None
x_max = None
auto = True
else:
auto = False

model = func(*args, **kwargs)
model.build()
with disable_pymc_sampling_logs():
idata = model.prior_predictive(iterations)
results = (
idata["prior_predictive"].stack(sample=("chain", "draw"))[model.response_name].values.T
)

_, ax = plt.subplots()
ax.set_xlim(x_min, x_max, auto=auto)
if plot_func is None:
plot_repr(results, kind_plot, references, iterations, ax)
else:
plot_func(results, ax)

return looper


def pymc_plot_decorator(func, iterations, kind_plot, references, plot_func):
def looper(*args, **kwargs):
kwargs.pop("__resample__")
x_min = kwargs.pop("__x_min__")
x_max = kwargs.pop("__x_max__")
if not kwargs.pop("__set_xlim__"):
x_min = None
x_max = None
auto = True
else:
auto = False
with func(*args, **kwargs) as model:
obs_name = model.observed_RVs[0].name
with disable_pymc_sampling_logs():
idata = sample_prior_predictive(samples=iterations)
results = idata["prior_predictive"].stack(sample=("chain", "draw"))[obs_name].values.T

_, ax = plt.subplots()
ax.set_xlim(x_min, x_max, auto=auto)
if plot_func is None:
plot_repr(results, kind_plot, references, iterations, ax)
else:
plot_func(results, ax)

return looper


def plot_repr(results, kind_plot, references, iterations, ax):
alpha = max(0.01, 1 - iterations * 0.009)

Expand Down
Loading
Loading