-
Notifications
You must be signed in to change notification settings - Fork 54
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[ENH] Adapter for Scipy Distributions (#287)
Fixes #227 #### What does this implement/fix? Explain your changes. <!-- A clear and concise description of what you have implemented. --> - Adapter for Scipy distributions - Fisk Distribution using Scipy Adapter - Poisson Distribution using Scipy Adapter
- Loading branch information
1 parent
4441e55
commit 79dccf2
Showing
5 changed files
with
214 additions
and
172 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,7 @@ | ||
"""Adapters for probability distribution objects, scipy facing.""" | ||
# copyright: skpro developers, BSD-3-Clause License (see LICENSE file) | ||
|
||
from skpro.distributions.adapters.scipy._distribution import _ScipyAdapter | ||
from skpro.distributions.adapters.scipy._empirical import empirical_from_discrete | ||
|
||
__all__ = ["empirical_from_discrete"] | ||
__all__ = ["empirical_from_discrete", "_ScipyAdapter"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
# copyright: skpro developers, BSD-3-Clause License (see LICENSE file) | ||
"""Adapter for Scipy Distributions.""" | ||
|
||
__author__ = ["malikrafsan"] | ||
|
||
from typing import Union | ||
|
||
import pandas as pd | ||
from scipy.stats import rv_continuous, rv_discrete | ||
|
||
from skpro.distributions.base import BaseDistribution | ||
|
||
__all__ = ["_ScipyAdapter"] | ||
|
||
|
||
class _ScipyAdapter(BaseDistribution): | ||
"""Adapter for scipy distributions. | ||
This class is an adapter for scipy distributions. It provides a common | ||
interface for all scipy distributions. The class is abstract | ||
and should not be instantiated directly. | ||
""" | ||
|
||
_distribution_attr = "_dist" | ||
_tags = { | ||
"object_type": ["distribution", "scipy_distribution_adapter"], | ||
} | ||
|
||
def __init__(self, index=None, columns=None): | ||
obj = self._get_scipy_object() | ||
setattr(self, self._distribution_attr, obj) | ||
super().__init__(index, columns) | ||
|
||
def _get_scipy_object(self) -> Union[rv_continuous, rv_discrete]: | ||
"""Abstract method to get the scipy distribution object. | ||
Should import the scipy distribution object and return it. | ||
""" | ||
raise NotImplementedError("abstract method") | ||
|
||
def _get_scipy_param(self): | ||
"""Abstract method to get the scipy distribution parameters. | ||
Should return a tuple with two elements: a list of positional arguments (args) | ||
and a dictionary of keyword arguments (kwds). | ||
""" | ||
raise NotImplementedError("abstract method") | ||
|
||
def _mean(self): | ||
obj: Union[rv_continuous, rv_discrete] = getattr(self, self._distribution_attr) | ||
args, kwds = self._get_scipy_param() | ||
return obj.mean(*args, **kwds) | ||
|
||
def _var(self): | ||
obj: Union[rv_continuous, rv_discrete] = getattr(self, self._distribution_attr) | ||
args, kwds = self._get_scipy_param() | ||
return obj.var(*args, **kwds) | ||
|
||
def _pdf(self, x: pd.DataFrame): | ||
obj: Union[rv_continuous, rv_discrete] = getattr(self, self._distribution_attr) | ||
if isinstance(obj, rv_discrete): | ||
return 0 | ||
|
||
args, kwds = self._get_scipy_param() | ||
return obj.pdf(x, *args, **kwds) | ||
|
||
def _log_pdf(self, x: pd.DataFrame): | ||
obj: Union[rv_continuous, rv_discrete] = getattr(self, self._distribution_attr) | ||
if isinstance(obj, rv_discrete): | ||
return 0 | ||
|
||
args, kwds = self._get_scipy_param() | ||
return obj.logpdf(x, *args, **kwds) | ||
|
||
def _cdf(self, x: pd.DataFrame): | ||
obj: Union[rv_continuous, rv_discrete] = getattr(self, self._distribution_attr) | ||
args, kwds = self._get_scipy_param() | ||
return obj.cdf(x, *args, **kwds) | ||
|
||
def _ppf(self, p: pd.DataFrame): | ||
obj: Union[rv_continuous, rv_discrete] = getattr(self, self._distribution_attr) | ||
args, kwds = self._get_scipy_param() | ||
return obj.ppf(p, *args, **kwds) | ||
|
||
def _pmf(self, x: pd.DataFrame): | ||
"""Return the probability mass function evaluated at x.""" | ||
obj: Union[rv_continuous, rv_discrete] = getattr(self, self._distribution_attr) | ||
if isinstance(obj, rv_continuous): | ||
return 0 | ||
|
||
args, kwds = self._get_scipy_param() | ||
return obj.pmf(x, *args, **kwds) | ||
|
||
def pmf(self, x: pd.DataFrame): | ||
"""Return the probability mass function evaluated at x.""" | ||
return self._boilerplate("_pmf", x=x) | ||
|
||
def _log_pmf(self, x: pd.DataFrame): | ||
"""Return the log of the probability mass function evaluated at x.""" | ||
obj: Union[rv_continuous, rv_discrete] = getattr(self, self._distribution_attr) | ||
if isinstance(obj, rv_continuous): | ||
return 0 | ||
|
||
args, kwds = self._get_scipy_param() | ||
return obj.logpmf(x, *args, **kwds) | ||
|
||
def log_pmf(self, x: pd.DataFrame): | ||
"""Return the log of the probability mass function evaluated at x.""" | ||
return self._boilerplate("_log_pmf", x=x) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.