Skip to content

Commit

Permalink
refactor and small fixes to distributions
Browse files Browse the repository at this point in the history
  • Loading branch information
aloctavodia committed Oct 31, 2024
1 parent 6cce4ae commit f442953
Showing 1 changed file with 66 additions and 189 deletions.
255 changes: 66 additions & 189 deletions preliz/distributions/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# pylint: disable=import-outside-toplevel
from collections import namedtuple
from copy import copy
import warnings

try:
from ipywidgets import interactive
Expand Down Expand Up @@ -44,7 +45,11 @@ def __repr__(self):
if name in ["Truncated", "Censored", "Hurdle"]:
name += self.dist.__class__.__name__
if name == "Mixture":
name = "Mixture" + "".join([dist.__class__.__name__ for dist in self.dist]) + "\n"
name = (
"Mixture"
+ "".join(dict.fromkeys(dist.__class__.__name__ for dist in self.dist))
+ "\n"
)

if self.is_frozen:
if "Mixture" in name:
Expand Down Expand Up @@ -98,6 +103,11 @@ def summary(self, mass=0.94, interval="hdi", fmt=".2f"):
name = "Truncated" + self.dist.__class__.__name__
elif name == "Censored":
name = "Censored" + self.dist.__class__.__name__
elif name == "Mixture":
name = "Mixture" + "".join(
dict.fromkeys(dist.__class__.__name__ for dist in self.dist)
)

attr = namedtuple(name, ["mean", "median", "std", "lower", "upper"])
mean = float(f"{self.mean():{fmt}}")
median = float(f"{self.median():{fmt}}")
Expand All @@ -120,62 +130,6 @@ def summary(self, mass=0.94, interval="hdi", fmt=".2f"):
else:
return None

def rvs(self, *args, **kwds):
"""Random sample
Parameters
----------
size : int or tuple of ints, optional
Defining number of random variates. Defaults to 1.
random_state : {None, int, numpy.random.Generator, numpy.random.RandomState}
Defaults to None
"""
return self.rvs(*args, **kwds)

def cdf(self, x, *args, **kwds):
"""Cumulative distribution function.
Parameters
----------
x : array_like
Values on which to evaluate the cdf
"""
return self.cdf(x, *args, **kwds)

def ppf(self, q, *args, **kwds):
"""Percent point function (inverse of cdf).
Parameters
----------
x : array_like
Values on which to evaluate the inverse of the cdf
"""
return self.ppf(q, *args, **kwds)

def mean(self):
"""Mean of the distribution."""
return self.mean()

def median(self):
"""Median of the distribution."""
return self.median()

def std(self):
"""Standard deviation of the distribution."""
return self.std()

def var(self):
"""Variance of the distribution."""
return self.var()

def skewness(self):
"""Skewness of the distribution."""
return self.stats(moment="s")

def kurtosis(self):
"""Kurtosis of the distribution"""
return self.stats(moments="k")

def moments(self, types="mvsk"):
"""
Compute moments of the distribution.
Expand Down Expand Up @@ -227,7 +181,7 @@ def eti(self, mass=0.94, fmt=".2f"):

if valid_scalar_params(self):
lower_tail, upper_tail = self.ppf([(1 - mass) / 2, 1 - (1 - mass) / 2])
if self.kind == "continuos" and fmt != "none":
if self.kind == "continuous" and fmt != "none":
lower_tail = float(f"{lower_tail:{fmt}}")
upper_tail = float(f"{upper_tail:{fmt}}")
elif self.kind == "discrete":
Expand All @@ -254,9 +208,12 @@ def hdi(self, mass=0.94, fmt=".2f"):
if not isinstance(fmt, str):
raise ValueError("Invalid format string.")

if self.__class__.__name__ == "Mixture":
warnings.warn("HDI may not be correct for multimodal distributions")

if valid_scalar_params(self):
lower_tail, upper_tail = optimize_hdi(self, mass)
if self.kind == "continuos" and fmt != "none":
if self.kind == "continuous" and fmt != "none":
lower_tail = float(f"{lower_tail:{fmt}}")
upper_tail = float(f"{upper_tail:{fmt}}")
return (lower_tail, upper_tail)
Expand Down Expand Up @@ -298,6 +255,12 @@ def to_pymc(self, name=None, **kwargs):
upper=self.params_dict["upper"],
**kwargs,
)
elif self.__class__.__name__ == "Mixture":
pymc_dist = pymc_class.dist(
self.weights,
[dist.to_pymc() for dist in self.dist],
**kwargs,
)
else:
pymc_dist = pymc_class.dist(**self.params_dict, **kwargs)
else:
Expand All @@ -311,6 +274,16 @@ def to_pymc(self, name=None, **kwargs):
upper=self.params_dict["upper"],
**kwargs,
)
elif self.__class__.__name__ == "Mixture":
pymc_dist = pymc_class(
name,
self.weights,
[
getattr(pm_dists, dist.__class__.__name__).dist(**dist.params_dict)
for dist in self.dist
],
**kwargs,
)
else:
pymc_dist = pymc_class(name, **self.params_dict, **kwargs)

Expand Down Expand Up @@ -371,6 +344,9 @@ def _finite_endpoints(self, support):
if isinstance(support, tuple):
lower_ep, upper_ep = support
else:
if support not in ["restricted", "full"]:
raise ValueError("Allowed values for the support are 'restricted' or 'full' ")

lower_ep, upper_ep = self.support

if not np.isfinite(lower_ep) or support == "restricted":
Expand All @@ -380,6 +356,35 @@ def _finite_endpoints(self, support):

return lower_ep, upper_ep

def xvals(self, support, n_points=None):
"""Provide x values in the support of the distribution. This is useful for example when
plotting.
Parameters
----------
support : str
Available options are `"full"` or `"restricted"`.
If `"full"` the values will cover the entire support of the distribution if the boundary
is finite, or the quantiles 0.0001 or 0.9999, if infinite.
If `"restricted"` the values will cover the quantile 0.0001 to 0.9999.
n_points : int
Number of values to return. Defaults to 1000 for continuous distributions
and 200 for discrete ones.
For discrete distributions the returned values may be fewer
than `n_points` if the actual number of discrete values in the support of the
distribution is smaller than `n_points`.
"""
lower_ep, upper_ep = self._finite_endpoints(support)

if self.kind == "continuous":
if n_points is None:
n_points = 1000
return _continuous_xvals(lower_ep, upper_ep, n_points)
else:
if n_points is None:
n_points = 200
return _discrete_xvals(lower_ep, upper_ep, n_points)

def plot_pdf(
self,
moments=None,
Expand Down Expand Up @@ -683,70 +688,6 @@ def __init__(self):
super().__init__()
self.kind = "continuous"

def xvals(self, support, n_points=1000):
"""Provide x values in the support of the distribution. This is useful for example when
plotting.
Parameters
----------
support : str
Available options are `full` or `restricted`. If `full` the values will cover the entire
support of the distribution, if finite, or the quantiles 0.0001 or 0.9999, if infinite.
If `restricted` the values will cover the quantile 0.0001 to 0.9999.
n_points : int
Number of values to return.
"""
half_n_points = int(n_points / 2)

if isinstance(support, tuple):
even = np.linspace(*support, n_points)
uneven = self.ppf(np.linspace(*self.cdf(support), n_points))
else:
lower_ep, upper_ep = self.support

if not np.isfinite(lower_ep) or support == "restricted":
lower_ep = 0.0001
if not np.isfinite(upper_ep) or support == "restricted":
upper_ep = 0.9999

even = np.linspace(*self.ppf([lower_ep, upper_ep]), half_n_points)
uneven = self.ppf(np.linspace(lower_ep, upper_ep, half_n_points))

return np.sort(np.concatenate([even, uneven]))

def _fit_mle(self, sample, **kwargs):
"""
Estimate the parameters of the distribution from a sample by maximizing the likelihood.
Parameters
----------
sample : array-like
a sample
kwargs : dict
keywords arguments passed to scipy.stats.rv_continuous.fit
"""
raise NotImplementedError

def pdf(self, x, *args, **kwds):
"""Probability density function at x.
Parameters
----------
x : array_like
Values on which to evaluate the pdf
"""
return self.pdf(x, *args, **kwds)

def logpdf(self, x, *args, **kwds):
"""Probability mass function at x.
Parameters
----------
x : array_like
Values on which to evaluate the pdf
"""
return self.logpdf(x, *args, **kwds)


class Discrete(Distribution):
"""Base class for discrete distributions."""
Expand All @@ -755,44 +696,6 @@ def __init__(self):
super().__init__()
self.kind = "discrete"

def xvals(self, support, n_points=200):
"""Provide x values in the support of the distribution. This is useful for example when
plotting.
Parameters
----------
support : str
Available options are `full` or `restricted`. If `full` the values will cover the entire
support of the distribution, if finite, or the quantiles 0.0001 or 0.9999, if infinite.
If `restricted` the values will cover the quantile 0.0001 to 0.9999.
n_points : int
Number of values to return. The returned values may be fewer than `n_points` if
the actual number of discrete values in the support of the distribution is smaller than
`n_points`.
"""
lower_ep, upper_ep = self._finite_endpoints(support)
return discrete_xvals(lower_ep, upper_ep, n_points)

def pdf(self, x, *args, **kwds):
"""Probability mass function at x.
Parameters
----------
x : array_like
Values on which to evaluate the pdf
"""
return self.pdf(x, *args, **kwds)

def logpdf(self, x, *args, **kwds):
"""Probability mass function at x.
Parameters
----------
x : array_like
Values on which to evaluate the pdf
"""
return self.logpdf(x, *args, **kwds)


class DistributionTransformer(Distribution):
"""Base class for distributions that transform other distributions"""
Expand All @@ -802,38 +705,12 @@ def __init__(self):
if not isinstance(self.dist, list):
self.kind = self.dist.kind

def xvals(self, support, n_points=None):
"""Provide x values in the support of the distribution. This is useful for example when
plotting.
Parameters
----------
support : str
Available options are `full` or `restricted`. If `full` the values will cover the entire
support of the distribution, if finite, or the quantiles 0.0001 or 0.9999, if infinite.
If `restricted` the values will cover the quantile 0.0001 to 0.9999.
n_points : int
Number of values to return. For discrete distributions the returned values may be fewer
than `n_points` if the actual number of discrete values in the support of the
distribution is smaller than `n_points`.
"""
lower_ep, upper_ep = self._finite_endpoints(support)

if self.kind == "continuous":
if n_points is None:
n_points = 1000
return continuous_xvals(lower_ep, upper_ep, n_points)
else:
if n_points is None:
n_points = 200
return discrete_xvals(lower_ep, upper_ep, n_points)


def continuous_xvals(lower_ep, upper_ep, n_points):
def _continuous_xvals(lower_ep, upper_ep, n_points):
return np.linspace(lower_ep, upper_ep, n_points)


def discrete_xvals(lower_ep, upper_ep, n_points):
def _discrete_xvals(lower_ep, upper_ep, n_points):
upper_ep = int(upper_ep)
lower_ep = int(lower_ep)
range_x = upper_ep - lower_ep
Expand Down

0 comments on commit f442953

Please sign in to comment.