Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] Adapter for Scipy Distributions #287

Merged
merged 11 commits into from
May 3, 2024
Prev Previous commit
Next Next commit
docs: add docstring on scipy adapters
malikrafsan committed Apr 29, 2024
commit 936f0d1243163e71ae257ff8ba2ac7cccbf9b127
4 changes: 2 additions & 2 deletions skpro/distributions/adapters/scipy/__init__.py
Original file line number Diff line number Diff line change
@@ -2,6 +2,6 @@
# copyright: skpro developers, BSD-3-Clause License (see LICENSE file)

from skpro.distributions.adapters.scipy._empirical import empirical_from_discrete
from skpro.distributions.adapters.scipy._distribution import _ScipyAdapter
from skpro.distributions.adapters.scipy._distribution import _ScipyAdapter, _ScipyDiscreteAdapter

__all__ = ["empirical_from_discrete", "_ScipyAdapter"]
__all__ = ["empirical_from_discrete", "_ScipyAdapter", "_ScipyDiscreteAdapter"]
49 changes: 48 additions & 1 deletion skpro/distributions/adapters/scipy/_distribution.py
Original file line number Diff line number Diff line change
@@ -6,9 +6,16 @@
from scipy.stats import rv_continuous, rv_discrete
from scipy.stats import rv_continuous

__all__ = ["_ScipyAdapter"]
__all__ = ["_ScipyAdapter", "_ScipyDiscreteAdapter"]

class _ScipyAdapter(BaseDistribution):
"""Adapter for scipy distributions.

This class is an adapter for scipy distributions. It provides a common
interface for all scipy distributions. The class is abstract
and should not be instantiated directly.
"""

_distribution_attr = "_dist"

def __init__(self, index=None, columns=None):
@@ -17,9 +24,18 @@ def __init__(self, index=None, columns=None):
super().__init__(index, columns)

def _get_scipy_object(self) -> Union[rv_continuous, rv_discrete]:
"""Abstract method to get the scipy distribution object.

Should import the scipy distribution object and return it.
"""
raise NotImplementedError("abstract method")

def _get_scipy_param(self) -> tuple[list, dict]:
"""Abstract method to get the scipy distribution parameters.

Should return a tuple with two elements: a list of positional arguments (args)
and a dictionary of keyword arguments (kwds).
"""
raise NotImplementedError("abstract method")

def _mean(self):
@@ -51,3 +67,34 @@ def _ppf(self, q: pd.DataFrame):
obj: Union[rv_continuous, rv_discrete] = getattr(self, self._distribution_attr)
args, kwds = self._get_scipy_param()
return obj.ppf(q, *args, **kwds)

class _ScipyDiscreteAdapter(_ScipyAdapter):
    """Abstract adapter for discrete scipy distributions.

    Builds on ``_ScipyAdapter`` and adds probability mass function methods
    (``pmf`` and ``log_pmf``) that delegate to the wrapped scipy object.
    Not meant to be instantiated directly; concrete distributions subclass
    it and supply the scipy object and its parameters.
    """

    def _get_scipy_object(self) -> rv_discrete:
        """Return the scipy discrete distribution object (abstract)."""
        raise NotImplementedError("abstract method")

    def _pmf(self, x: pd.DataFrame):
        """Evaluate the scipy ``pmf`` of the wrapped distribution at x."""
        # the scipy object is stored under the attribute named by
        # ``_distribution_attr``
        scipy_dist: rv_discrete = getattr(self, self._distribution_attr)
        positional, keywords = self._get_scipy_param()
        return scipy_dist.pmf(x, *positional, **keywords)

    def pmf(self, x: pd.DataFrame):
        """Probability mass function at x, routed through ``_boilerplate``."""
        return self._boilerplate("_pmf", x=x)

    def _log_pmf(self, x: pd.DataFrame):
        """Evaluate the scipy ``logpmf`` of the wrapped distribution at x."""
        scipy_dist: rv_discrete = getattr(self, self._distribution_attr)
        positional, keywords = self._get_scipy_param()
        return scipy_dist.logpmf(x, *positional, **keywords)

    def log_pmf(self, x: pd.DataFrame):
        """Log probability mass function at x, routed through ``_boilerplate``."""
        return self._boilerplate("_log_pmf", x=x)
2 changes: 1 addition & 1 deletion skpro/distributions/fisk.py
Original file line number Diff line number Diff line change
@@ -29,7 +29,7 @@ class Fisk(BaseDistribution):

Example
-------
>>> from skpro.distributions.fisk import Fisk
>>> from skpro.distributions.fisk import FiskScipy as Fisk

>>> d = Fisk(beta=[[1, 1], [2, 3], [4, 5]], alpha=2)
"""
24 changes: 24 additions & 0 deletions skpro/distributions/fisk_scipy.py
Original file line number Diff line number Diff line change
@@ -4,6 +4,30 @@
__all__ = ["FiskScipy"]

class FiskScipy(_ScipyAdapter):
r"""Fisk distribution, aka log-logistic distribution.

The Fisk distribution is parametrized by a scale parameter :math:`\alpha`
and a shape parameter :math:`\beta`, such that the cumulative distribution
function (CDF) is given by:

.. math:: F(x) = \left(1 + \left(\frac{x}{\alpha}\right)^{-\beta}\right)^{-1}

Parameters
----------
alpha : float or array of float (1D or 2D), must be positive
scale parameter of the distribution
beta : float or array of float (1D or 2D), must be positive
shape parameter of the distribution
index : pd.Index, optional, default = RangeIndex
columns : pd.Index, optional, default = RangeIndex

Example
-------
>>> from skpro.distributions.fisk_scipy import FiskScipy as Fisk

>>> d = Fisk(beta=[[1, 1], [2, 3], [4, 5]], alpha=2)
"""

_tags = {
"capabilities:approx": ["energy", "pdfnorm"],
"capabilities:exact": ["mean", "var", "pdf", "log_pdf", "cdf", "ppf"],
20 changes: 18 additions & 2 deletions skpro/distributions/poisson_scipy.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,25 @@
from scipy.stats import rv_discrete, poisson
from skpro.distributions.adapters.scipy import _ScipyAdapter
from skpro.distributions.adapters.scipy import _ScipyDiscreteAdapter

__all__ = ["PoissonScipy"]

class PoissonScipy(_ScipyAdapter):
class PoissonScipy(_ScipyDiscreteAdapter):
"""Poisson distribution.

Parameters
----------
mu : float or array of float (1D or 2D)
mean of the distribution
index : pd.Index, optional, default = RangeIndex
columns : pd.Index, optional, default = RangeIndex

Example
-------
>>> from skpro.distributions import PoissonScipy as Poisson

>>> distr = Poisson(mu=[[1, 1], [2, 3], [4, 5]])
"""

_tags = {
"capabilities:approx": ["ppf", "energy"],
"capabilities:exact": ["mean", "var", "pmf", "log_pmf", "cdf"],