From 6db78c20f41afebfdb447c50f8acd8bd23f5083a Mon Sep 17 00:00:00 2001 From: ShreeshaM07 Date: Tue, 11 Jun 2024 22:47:14 +0530 Subject: [PATCH 01/12] [ENH] Histogram distribution --- docs/source/api_reference/distributions.rst | 14 + skpro/distributions/__init__.py | 2 + skpro/distributions/base/__init__.py | 3 +- skpro/distributions/base/_base_array.py | 467 ++++++++++++++ skpro/distributions/histogram.py | 618 +++++++++++++++++++ skpro/distributions/tests/test_all_distrs.py | 2 +- 6 files changed, 1104 insertions(+), 2 deletions(-) create mode 100644 skpro/distributions/base/_base_array.py create mode 100644 skpro/distributions/histogram.py diff --git a/docs/source/api_reference/distributions.rst b/docs/source/api_reference/distributions.rst index e25e5726c..4157488d0 100644 --- a/docs/source/api_reference/distributions.rst +++ b/docs/source/api_reference/distributions.rst @@ -107,3 +107,17 @@ Sampling and multivariate composition :template: class.rst IID + +Array distributions +------------------- + +Continuous support +~~~~~~~~~~~~~~~~~~ + +.. currentmodule:: skpro.distributions + +.. autosummary:: + :toctree: auto_generated/ + :template: class.rst + + Histogram diff --git a/skpro/distributions/__init__.py b/skpro/distributions/__init__.py index e5f53ef13..50f337f42 100644 --- a/skpro/distributions/__init__.py +++ b/skpro/distributions/__init__.py @@ -16,6 +16,7 @@ "HalfLogistic", "HalfNormal", "IID", + "Histogram", "Laplace", "Logistic", "LogLaplace", @@ -45,6 +46,7 @@ from skpro.distributions.halfcauchy import HalfCauchy from skpro.distributions.halflogistic import HalfLogistic from skpro.distributions.halfnormal import HalfNormal +from skpro.distributions.histogram import Histogram from skpro.distributions.laplace import Laplace from skpro.distributions.logistic import Logistic from skpro.distributions.loglaplace import LogLaplace diff --git a/skpro/distributions/base/__init__.py b/skpro/distributions/base/__init__.py index 4c56bc6b9..e2191ba15 100644 --- a/skpro/distributions/base/__init__.py +++ b/skpro/distributions/base/__init__.py @@ -2,7 +2,8 @@ # copyright: skpro developers, BSD-3-Clause License (see LICENSE file) # adapted from sktime -__all__ = ["BaseDistribution", "_DelegatedDistribution"] +__all__ = ["BaseDistribution", "_DelegatedDistribution", "BaseArrayDistribution"] from skpro.distributions.base._base import BaseDistribution +from skpro.distributions.base._base_array import BaseArrayDistribution from skpro.distributions.base._delegate import _DelegatedDistribution diff --git a/skpro/distributions/base/_base_array.py b/skpro/distributions/base/_base_array.py new file mode 100644 index 000000000..3d12eb07d --- /dev/null +++ b/skpro/distributions/base/_base_array.py @@ -0,0 +1,467 @@ +# copyright: skpro developers, BSD-3-Clause License (see LICENSE file) +"""Base classes for probability array distribution objects.""" + +__author__ = ["ShreeshaM07"] + +__all__ = ["BaseArrayDistribution"] + +import numpy as np +import pandas as pd + +from skpro.base import BaseObject +from skpro.distributions.base import BaseDistribution +from skpro.distributions.base._base import ( + _coerce_to_pd_index_or_none, + is_scalar_notnone, +) + + +class BaseArrayDistribution(BaseDistribution, BaseObject): + """Base Array probability distribution.""" + + def __init__(self, index=None, columns=None): + self.index = _coerce_to_pd_index_or_none(index) + self.columns = _coerce_to_pd_index_or_none(columns) + + super().__init__(index=index, columns=columns) + + def _loc(self, rowidx=None, colidx=None): + if is_scalar_notnone(rowidx) and is_scalar_notnone(colidx): + return self._at(rowidx, colidx) + if is_scalar_notnone(rowidx): + rowidx = pd.Index([rowidx]) + if is_scalar_notnone(colidx): + colidx = pd.Index([colidx]) + + if rowidx is not None: + row_iloc = pd.Index(self.index.get_indexer_for(rowidx)) + else: + row_iloc = None + if colidx is not None: + col_iloc = pd.Index(self.columns.get_indexer_for(colidx)) + else: + col_iloc = None + return self._iloc(rowidx=row_iloc, colidx=col_iloc) + + def _subset_params(self, rowidx, colidx, coerce_scalar=False): + """Subset distribution parameters to given rows and columns. + + Parameters + ---------- + rowidx : None, numpy index/slice coercible, or int + Rows to subset to. If None, no subsetting is done. + colidx : None, numpy index/slice coercible, or int + Columns to subset to. If None, no subsetting is done. + coerce_scalar : bool, optional, default=False + If True, and the subsetted parameter is a scalar, coerce it to a scalar. + + Returns + ------- + dict + Dictionary with subsetted distribution parameters. + Keys are parameter names of ``self``, values are the subsetted parameters. + """ + params = self._get_dist_params() + + subset_param_dict = {} + for param, val in params.items(): + if val is None: + subset_param_dict[param] = None + continue + arr = val + arr_shape = 2 + # when rowidx and colidx are integer while plotting + if coerce_scalar: + arr = arr[rowidx][colidx] + subset_param_dict[param] = arr + continue + # subset the 2D distributions + if arr_shape == 2 and rowidx is not None: + _arr_shift = [] + if rowidx.values is not None and colidx is None: + rowidx_list = rowidx.values + for row in rowidx: + _arr_shift.append(arr[row]) + + elif rowidx.values is not None and colidx.values is not None: + rowidx_list = rowidx.values + colidx_list = colidx.values + for row in rowidx_list: + _arr_shift_row = [] + for col in colidx_list: + _arr_shift_row.append(arr[row][col]) + _arr_shift.append(_arr_shift_row) + arr = _arr_shift + + if arr_shape == 2 and rowidx is None: + _arr_shift = [] + if colidx is not None: + colidx_list = colidx.values + for row in range(len(arr)): + _arr_shift_row = [] + for col in colidx_list: + _arr_shift_row.append(arr[row][col]) + _arr_shift.append(_arr_shift_row) + arr = _arr_shift + + subset_param_dict[param] = arr + return subset_param_dict + + def _iloc(self, rowidx=None, colidx=None): + if is_scalar_notnone(rowidx) and is_scalar_notnone(colidx): + return self._iat(rowidx, colidx) + if is_scalar_notnone(rowidx): + rowidx = pd.Index([rowidx]) + if is_scalar_notnone(colidx): + colidx = pd.Index([colidx]) + + if rowidx is not None: + rowidx = pd.Index(rowidx) + if colidx is not None: + colidx = pd.Index(colidx) + + subset_params = self._subset_params(rowidx=rowidx, colidx=colidx) + + def subset_not_none(idx, subs): + if subs is not None: + return idx.take(pd.Index(subs)) + else: + return idx + + index_subset = subset_not_none(self.index, rowidx) + columns_subset = subset_not_none(self.columns, colidx) + + sk_distr_type = type(self) + return sk_distr_type( + index=index_subset, + columns=columns_subset, + **subset_params, + ) + + def _check_single_arr_distr(self, value): + return ( + isinstance(value[0], int) + or isinstance(value[0], np.integer) + or isinstance(value[0], float) + or isinstance(value[0], np.floating) + ) + + def _get_bc_params_dict( + self, dtype=None, oned_as="row", return_shape=False, **kwargs + ): + """Fully broadcast dict of parameters given param shapes and index, columns. + + Parameters + ---------- + kwargs : float, int, array of floats, or array of ints (1D or 2D) + Distribution parameters that are to be made broadcastable. If no positional + arguments are provided, all parameters of `self` are used except for `index` + and `columns`. + dtype : str, optional + broadcasted arrays are cast to all have datatype `dtype`. If None, then no + datatype casting is done. + oned_as : str, optional, "row" (default) or "col" + If 'row', then 1D arrays are treated as row vectors. If 'column', then 1D + arrays are treated as column vectors. + return_shape : bool, optional, default=False + If True, return shape tuple, and a boolean tuple + indicating which parameters are scalar. + + Returns + ------- + dict of float or integer arrays + Each element of the tuple represents a different broadcastable distribution + parameter. + shape : Tuple, only returned if ``return_shape`` is True + Shape of the broadcasted parameters. + Pair of row/column if not scalar, empty tuple if scalar. + is_scalar : Tuple of bools, only returned if ``return_is_scalar`` is True + Each element of the tuple is True if the corresponding parameter is scalar. + """ + number_of_params = len(kwargs) + if number_of_params == 0: + # Handle case where no positional arguments are provided + kwargs = self._get_dist_params() + number_of_params = len(kwargs) + + # def row_to_col(arr): + # """Convert 1D arrays to 2D col arrays, leave 2D arrays unchanged.""" + # if arr.ndim == 1 and oned_as == "col": + # return arr.reshape(-1, 1) + # return arr + + # kwargs_as_np = {k: row_to_col(np.array(v)) for k, v in kwargs.items()} + kwargs_as_np = {k: v for k, v in kwargs.items()} + + if hasattr(self, "index") and self.index is not None: + kwargs_as_np["index"] = self.index.to_numpy().reshape(-1, 1) + if hasattr(self, "columns") and self.columns is not None: + kwargs_as_np["columns"] = self.columns.to_numpy() + + bc_params = self.get_tags()["broadcast_params"] + + if bc_params is None: + bc_params = kwargs_as_np.keys() + + args_as_np = [kwargs_as_np[k] for k in bc_params] + + if all(self._check_single_arr_distr(value) for value in kwargs_as_np.values()): + # Convert all values in kwargs_as_np to np.array + kwargs_as_np = {key: np.array(value) for key, value in kwargs_as_np.items()} + shape = () + + if return_shape: + is_scalar = tuple([True] * (len(args_as_np) - 2)) + # print(kwargs_as_np,shape,is_scalar) + return kwargs_as_np, shape, is_scalar + return kwargs_as_np + + shape = (len(args_as_np[0]), len(args_as_np[0][0])) + # create broadcast_array which will be same shape as the original bins + # without considering the inner np.array containing the values of the bin edges + # and bin masses. This will later get replaced by the values after broadcasting + # index and columns. + broadcast_array = np.arange(len(args_as_np[0]) * len(args_as_np[0][0])).reshape( + shape + ) + + index_column_broadcast = [broadcast_array] * (len(args_as_np) - 2) + index_column_broadcast.append(kwargs_as_np["index"]) + index_column_broadcast.append(kwargs_as_np["columns"]) + + bc = np.broadcast_arrays(*index_column_broadcast) + if dtype is not None: + bc = [array.astype(dtype) for array in bc] + + for i in range(len(bc) - 2): + bc[i] = args_as_np[i] + + for i, k in enumerate(bc_params): + kwargs_as_np[k] = bc[i] + + if return_shape: + is_scalar = tuple([False] * (len(args_as_np) - 2)) + # print(kwargs_as_np,shape,is_scalar) + return kwargs_as_np, shape, is_scalar + return kwargs_as_np + + def pdf(self, x): + r"""Probability density function. + + Let :math:`X` be a random variables with the distribution of ``self``, + taking values in ``(N, n)`` ``DataFrame``-s + Let :math:`x\in \mathbb{R}^{N\times n}`. + By :math:`p_{X_{ij}}`, denote the marginal pdf of :math:`X` at the + :math:`(i,j)`-th entry. + + The output of this method, for input ``x`` representing :math:`x`, + is a ``DataFrame`` with same columns and indices as ``self``, + and entries :math:`p_{X_{ij}}(x_{ij})`. + + If ``self`` has a mixed or discrete distribution, this returns + the weighted continuous part of `self`'s distribution instead of the pdf, + i.e., the marginal pdf integrate to the weight of the continuous part. + + Parameters + ---------- + x : ``pandas.DataFrame`` or 2D ``np.ndarray`` + representing :math:`x`, as above + + Returns + ------- + ``pd.DataFrame`` with same columns and index as ``self`` + containing :math:`p_{X_{ij}}(x_{ij})`, as above + """ + distr_type = self.get_tag("distr:measuretype", "mixed", raise_error=False) + x = np.array(x) + if distr_type == "discrete": + return self._coerce_to_self_index_df(0, flatten=False) + + return self._boilerplate("_pdf", x=x) + + def log_pdf(self, x): + r"""Logarithmic probability density function. + + Numerically more stable than calling pdf and then taking logartihms. + + Let :math:`X` be a random variables with the distribution of ``self``, + taking values in `(N, n)` ``DataFrame``-s + Let :math:`x\in \mathbb{R}^{N\times n}`. + By :math:`p_{X_{ij}}`, denote the marginal pdf of :math:`X` at the + :math:`(i,j)`-th entry. + + The output of this method, for input ``x`` representing :math:`x`, + is a ``DataFrame`` with same columns and indices as ``self``, + and entries :math:`\log p_{X_{ij}}(x_{ij})`. + + If ``self`` has a mixed or discrete distribution, this returns + the weighted continuous part of `self`'s distribution instead of the pdf, + i.e., the marginal pdf integrate to the weight of the continuous part. + + Parameters + ---------- + x : ``pandas.DataFrame`` or 2D ``np.ndarray`` + representing :math:`x`, as above + + Returns + ------- + ``pd.DataFrame`` with same columns and index as ``self`` + containing :math:`\log p_{X_{ij}}(x_{ij})`, as above + """ + distr_type = self.get_tag("distr:measuretype", "mixed", raise_error=False) + x = np.array(x) + if distr_type == "discrete": + return self._coerce_to_self_index_df(-np.inf, flatten=False) + + return self._boilerplate("_log_pdf", x=x) + + def pmf(self, x): + r"""Probability mass function. + + Let :math:`X` be a random variables with the distribution of ``self``, + taking values in ``(N, n)`` ``DataFrame``-s + Let :math:`x\in \mathbb{R}^{N\times n}`. + By :math:`m_{X_{ij}}`, denote the marginal mass of :math:`X` at the + :math:`(i,j)`-th entry, i.e., + :math:`m_{X_{ij}}(x_{ij}) = \mathbb{P}(X_{ij} = x_{ij})`. + + The output of this method, for input ``x`` representing :math:`x`, + is a ``DataFrame`` with same columns and indices as ``self``, + and entries :math:`m_{X_{ij}}(x_{ij})`. + + If ``self`` has a mixed or discrete distribution, this returns + the weighted continuous part of `self`'s distribution instead of the pdf, + i.e., the marginal pdf integrate to the weight of the continuous part. + + Parameters + ---------- + x : ``pandas.DataFrame`` or 2D ``np.ndarray`` + representing :math:`x`, as above + + Returns + ------- + ``pd.DataFrame`` with same columns and index as ``self`` + containing :math:`p_{X_{ij}}(x_{ij})`, as above + """ + distr_type = self.get_tag("distr:measuretype", "mixed", raise_error=False) + if distr_type == "continuous": + return self._coerce_to_self_index_df(0, flatten=False) + + return self._boilerplate("_pmf", x=x) + + def log_pmf(self, x): + r"""Logarithmic probability mass function. + + Numerically more stable than calling pmf and then taking logartihms. + + Let :math:`X` be a random variables with the distribution of ``self``, + taking values in `(N, n)` ``DataFrame``-s + Let :math:`x\in \mathbb{R}^{N\times n}`. + By :math:`m_{X_{ij}}`, denote the marginal pdf of :math:`X` at the + :math:`(i,j)`-th entry, i.e., + :math:`m_{X_{ij}}(x_{ij}) = \mathbb{P}(X_{ij} = x_{ij})`. + + The output of this method, for input ``x`` representing :math:`x`, + is a ``DataFrame`` with same columns and indices as ``self``, + and entries :math:`\log m_{X_{ij}}(x_{ij})`. + + If ``self`` has a mixed or discrete distribution, this returns + the weighted continuous part of `self`'s distribution instead of the pdf, + i.e., the marginal pdf integrate to the weight of the continuous part. + + Parameters + ---------- + x : ``pandas.DataFrame`` or 2D ``np.ndarray`` + representing :math:`x`, as above + + Returns + ------- + ``pd.DataFrame`` with same columns and index as ``self`` + containing :math:`\log m_{X_{ij}}(x_{ij})`, as above + """ + distr_type = self.get_tag("distr:measuretype", "mixed", raise_error=False) + if distr_type == "continuous": + return self._coerce_to_self_index_df(-np.inf, flatten=False) + + return self._boilerplate("_log_pmf", x=x) + + def cdf(self, x): + r"""Cumulative distribution function. + + Let :math:`X` be a random variables with the distribution of ``self``, + taking values in ``(N, n)`` ``DataFrame``-s + Let :math:`x\in \mathbb{R}^{N\times n}`. + By :math:`F_{X_{ij}}`, denote the marginal cdf of :math:`X` at the + :math:`(i,j)`-th entry, + i.e., :math:`F_{X_{ij}}(t) = \mathbb{P}(X_{ij} \leq t)`. + + The output of this method, for input ``x`` representing :math:`x`, + is a ``DataFrame`` with same columns and indices as ``self``, + and entries :math:`F_{X_{ij}}(x_{ij})`. + + Parameters + ---------- + x : ``pandas.DataFrame`` or 2D ``np.ndarray`` + representing :math:`x`, as above + + Returns + ------- + ``pd.DataFrame`` with same columns and index as ``self`` + containing :math:`F_{X_{ij}}(x_{ij})`, as above + """ + x = np.array(x) + return self._boilerplate("_cdf", x=x) + + def ppf(self, p): + r"""Quantile function = percent point function = inverse cdf. + + Let :math:`X` be a random variables with the distribution of ``self``, + taking values in ``(N, n)`` ``DataFrame``-s + Let :math:`x\in \mathbb{R}^{N\times n}`. + By :math:`F_{X_{ij}}`, denote the marginal cdf of :math:`X` at the + :math:`(i,j)`-th entry. + + The output of this method, for input ``p`` representing :math:`p`, + is a ``DataFrame`` with same columns and indices as ``self``, + and entries :math:`F^{-1}_{X_{ij}}(p_{ij})`. + + Parameters + ---------- + p : ``pandas.DataFrame`` or 2D np.ndarray + representing :math:`p`, as above + + Returns + ------- + ``pd.DataFrame`` with same columns and index as ``self`` + containing :math:`F_{X_{ij}}(x_{ij})`, as above + """ + p = np.array(p) + return self._boilerplate("_ppf", p=p) + + def energy(self, x=None): + r"""Energy of self, w.r.t. self or a constant frame x. + + Let :math:`X, Y` be i.i.d. random variables with the distribution of ``self``. + + If ``x`` is ``None``, returns :math:`\mathbb{E}[|X-Y|]` (per row), + "self-energy". + If ``x`` is passed, returns :math:`\mathbb{E}[|X-x|]` (per row), "energy wrt x". + + The CRPS is related to energy: + it holds that + :math:`\mbox{CRPS}(\mbox{self}, y)` = `self.energy(y) - 0.5 * self.energy()`. + + Parameters + ---------- + x : None or pd.DataFrame, optional, default=None + if ``pd.DataFrame``, must have same rows and columns as ``self`` + + Returns + ------- + ``pd.DataFrame`` with same rows as ``self``, single column ``"energy"`` + each row contains one float, self-energy/energy as described above. + """ + if x is None: + return self._boilerplate("_energy_self", columns=["energy"]) + x = np.array(x) + return self._boilerplate("_energy_x", x=x, columns=["energy"]) diff --git a/skpro/distributions/histogram.py b/skpro/distributions/histogram.py new file mode 100644 index 000000000..dc2b854a0 --- /dev/null +++ b/skpro/distributions/histogram.py @@ -0,0 +1,618 @@ +# copyright: skpro developers, BSD-3-Clause License (see LICENSE file) +"""Histogram distribution.""" + +__author__ = ["ShreeshaM07"] + +import numpy as np +import pandas as pd + +from skpro.distributions.base import BaseArrayDistribution + + +class Histogram(BaseArrayDistribution): + """Histogram Probability Distribution. + + The histogram probability distribution is parameterized + by the bins and bin densities. + + Parameters + ---------- + bins : tuple(float,float,int) or numpy.array of float 1D or 2D list of size m x n + 1. tuple(first bin's start point, last bin's end point, number of bins) + Used when bin widths are equal. + example: bins:(0,4,4), + 2. array has the bin boundaries with 1st element the first bin's + starting point and rest are the bin ending points of all bins + example: bins:[0, 1, 2, 3, 4], + 3. 2D list of size m x n containing m*n float numpy.arrays or tuple like case 1. + example : "bins": [ + [[0, 1, 2, 3, 4], [5, 5.5, 5.8, 6.5, 7, 7.5]], + [(2, 12, 5), [0, 1, 2, 3, 4]], + [[1.5, 2.5, 3.1, 4, 5.4], [-4, -2, -1.5, 5, 10]], + ] + bin_mass: array of float 1D or 2D list of size m x n containing + 1. Array has the mass of the bins or area of the bins. + example: bin_mass:[0.1, 0.2, 0.3, 0.4], + Note: len(bin_mass) will be (len(bins)-1). + Note: Sum of all the bin_mass must be 1. + 2. 2D list of size m x n containing m*n float numpy.arrays satisfying case 1. + example : "bin_mass": [ + [[0.1, 0.2, 0.3, 0.4], [0.25, 0.1, 0, 0.4, 0.25]], + [[0.1, 0.2, 0.4, 0.2, 0.1], [0.4, 0.3, 0.2, 0.1]], + [[0.06, 0.15, 0.09, 0.7], [0.4, 0.15, 0.325, 0.125]], + ] + index : pd.Index, optional, default = RangeIndex + columns : pd.Index, optional, default = RangeIndex + """ + + _tags = { + "authors": ["ShreeshaM07"], + "capabilities:approx": ["pdfnorm", "energy"], + "capabilities:exact": ["mean", "var", "pdf", "log_pdf", "cdf", "ppf"], + "distr:measuretype": "continuous", + "distr:paramtype": "parametric", + "broadcast_init": "on", + } + + def _convert_tuple_to_array(self, bins): + bins_to_list = (bins[0], bins[1], bins[2]) + bins = [] + bin_width = (bins_to_list[1] - bins_to_list[0]) / bins_to_list[2] + for b in range(bins_to_list[2]): + bins.append(bins_to_list[0] + b * bin_width) + bins.append(bins_to_list[1]) + return bins + + def _check_single_array_distr(self, bins, bin_mass): + all1Ds = ( + isinstance(bins[0], float) + or isinstance(bins[0], np.floating) + or isinstance(bins[0], int) + or isinstance(bins[0], np.integer) + ) + all1Ds = ( + all1Ds + and isinstance(bin_mass[0], int) + or isinstance(bin_mass[0], np.integer) + or isinstance(bin_mass[0], float) + or isinstance(bin_mass[0], np.floating) + and np.array(bin_mass).ndim == 1 + ) + return all1Ds + + def __init__(self, bins, bin_mass, index=None, columns=None): + if isinstance(bins, tuple): + bins = self._convert_tuple_to_array(bins) + self.bins = np.array(bins) + self.bin_mass = np.array(bin_mass) + elif self._check_single_array_distr(bins, bin_mass): + self.bins = np.array(bins) + self.bin_mass = np.array(bin_mass) + else: + # convert the bins into a list + for i in range(len(bins)): + for j in range(len(bins[i])): + if isinstance(bins[i][j], tuple): + bins[i][j] = self._convert_tuple_to_array(bins[i][j]) + bins[i][j] = np.array(bins[i][j]) + bin_mass[i][j] = np.array(bin_mass[i][j]) + self.bins = bins + self.bin_mass = bin_mass + + super().__init__(index=index, columns=columns) + + def _energy_self(self): + r"""Energy of self, w.r.t. self. + + :math:`\mathbb{E}[|X-Y|]`, where :math:`X, Y` are i.i.d. copies of self. + + Private method, to be implemented by subclasses. + + Returns + ------- + 2D np.ndarray, same shape as ``self`` + energy values w.r.t. the given points + """ + bins = self.bins + bin_mass = self.bin_mass + energy_arr = [] + from numpy.lib.stride_tricks import sliding_window_view + + if self._check_single_array_distr(bins, bin_mass): + bins_hist = bins + bin_mass_hist = bin_mass + win_centre_bins = 0.5 * np.sum( + sliding_window_view(bins_hist, window_shape=2), axis=1 + ) + expected_value = 0 + for i in range(len(bin_mass_hist)): + for j in range(len(bin_mass_hist)): + expected_value += ( + bin_mass_hist[i] + * bin_mass_hist[j] + * abs(win_centre_bins[i] - win_centre_bins[j]) + ) + energy_arr = expected_value + return expected_value + + for row in range(len(bins)): + energy_arr_row = [] + for col in range(len(bins[0])): + bins_hist = bins[row][col] + bin_mass_hist = bin_mass[row][col] + win_centre_bins = 0.5 * np.sum( + sliding_window_view(bins_hist, window_shape=2), axis=1 + ) + expected_value = 0 + for i in range(len(bin_mass_hist)): + for j in range(len(bin_mass_hist)): + expected_value += ( + bin_mass_hist[i] + * bin_mass_hist[j] + * abs(win_centre_bins[i] - win_centre_bins[j]) + ) + energy_arr_row.append(expected_value) + energy_arr.append(energy_arr_row) + energy_arr = np.array(energy_arr) + if energy_arr.ndim > 0: + energy_arr = np.sum(energy_arr, axis=1) + return energy_arr + + def _energy_x(self, x): + r"""Energy of self, w.r.t. a constant frame x. + + :math:`\mathbb{E}[|X-x|]`, where :math:`X` is a copy of self, + and :math:`x` is a constant. + + Private method, to be implemented by subclasses. + + Parameters + ---------- + x : 2D np.ndarray, same shape as ``self`` + values to compute energy w.r.t. to + + Returns + ------- + 2D np.ndarray, same shape as ``self`` + energy values w.r.t. the given points + """ + bins = self.bins + bin_mass = self.bin_mass + energy_arr = [] + mean = self._mean() + cdf = self._cdf(x) + pdf = self._pdf(x) + x = np.array(x) + from numpy.lib.stride_tricks import sliding_window_view + + if self._check_single_array_distr(bins, bin_mass): + bins_hist = np.array(bins) + bin_mass_hist = np.array(bin_mass) + X = x + is_outside = X < bins_hist[0] or X > bins_hist[-1] + if is_outside: + energy_arr = abs(mean - X) + return energy_arr + else: + # consider X lies in kth bin + # so kth bin's start index is + k_1_bins = np.where(X >= bins_hist)[0][-1] + win_sum_bins = np.sum( + sliding_window_view(bins_hist, window_shape=2), axis=1 + ) + # upto kth bin excluding kth + X_upto_k = X * cdf - 0.5 * np.dot( + win_sum_bins[:k_1_bins], bin_mass_hist[:k_1_bins] + ) + # if X is in last bin + if k_1_bins >= len(bin_mass_hist) - 1: + energy_arr = X_upto_k + return energy_arr + # in the kth bin + X_in_k = ( + 0.5 + * pdf + * ( + bins_hist[k_1_bins] ** 2 + + bins_hist[k_1_bins + 1] ** 2 + - 2 * X**2 + ) + ) + # after kth bin excluding kth + X_after_k = 0.5 * np.dot( + win_sum_bins[k_1_bins + 1 :], bin_mass_hist[k_1_bins + 1 :] + ) - X * (1 - cdf) + energy_arr = X_upto_k + X_in_k + X_after_k + return energy_arr + + for i in range(len(bins)): + energy_arr_row = [] + for j in range(len(bins[0])): + bins_hist = bins[i][j] + bin_mass_hist = bin_mass[i][j] + X = x[i][j] + is_outside = X < bins_hist[0] or X > bins_hist[-1] + if is_outside: + energy_arr_row.append(abs(mean[i][j] - X)) + else: + # consider X lies in kth bin + # so kth bin's start index is + k_1_bins = np.where(X >= bins_hist)[0][-1] + win_sum_bins = np.sum( + sliding_window_view(bins_hist, window_shape=2), axis=1 + ) + # upto kth bin excluding kth + X_upto_k = X * cdf[i][j] - 0.5 * np.dot( + win_sum_bins[:k_1_bins], bin_mass_hist[:k_1_bins] + ) + # if X is in last bin + if k_1_bins >= len(bin_mass_hist) - 1: + energy_arr_row.append(X_upto_k) + continue + # in the kth bin + X_in_k = ( + 0.5 + * pdf[i][j] + * ( + bins_hist[k_1_bins] ** 2 + + bins_hist[k_1_bins + 1] ** 2 + - 2 * X**2 + ) + ) + # after kth bin excluding kth + X_after_k = 0.5 * np.dot( + win_sum_bins[k_1_bins + 1 :], bin_mass_hist[k_1_bins + 1 :] + ) - X * (1 - cdf[i][j]) + energy_arr_row.append(X_upto_k + X_in_k + X_after_k) + energy_arr.append(energy_arr_row) + energy_arr = np.array(energy_arr) + if energy_arr.ndim > 0: + energy_arr = np.sum(energy_arr, axis=1) + return energy_arr + + def _mean(self): + """Return expected value of the distribution. + + Returns + ------- + float, sum(bin_mass)/range(bins) + expected value of distribution (entry-wise) + """ + bins = self.bins + bin_mass = self.bin_mass + mean = [] + from numpy.lib.stride_tricks import sliding_window_view + + if self._check_single_array_distr(bins, bin_mass): + bins = np.array(bins) + bin_mass = np.array(bin_mass) + win_sum_bins = np.sum(sliding_window_view(bins, window_shape=2), axis=1) + mean = 0.5 * np.dot(win_sum_bins, bin_mass) + return mean + + for i in range(len(bins)): + mean_row = [] + for j in range(len(bins[0])): + bins_hist = bins[i][j] + bin_mass_hist = bin_mass[i][j] + win_sum_bins = np.sum( + sliding_window_view(bins_hist, window_shape=2), axis=1 + ) + mean_row.append(0.5 * np.dot(win_sum_bins, bin_mass_hist)) + mean.append(mean_row) + return np.array(mean) + + def _var(self): + r"""Return element/entry-wise variance of the distribution. + + Returns + ------- + 2D np.ndarray, same shape as ``self`` + variance of the distribution (entry-wise) + """ + bins = self.bins + bin_mass = self.bin_mass + var = [] + mean = self._mean() + from numpy.lib.stride_tricks import sliding_window_view + + if self._check_single_array_distr(bins, bin_mass): + bins = np.array(bins) + bin_mass = np.array(bin_mass) + win_sum_bins = np.sum(sliding_window_view(bins, window_shape=2), axis=1) + win_prod_bins = np.prod(sliding_window_view(bins, window_shape=2), axis=1) + var = np.dot(bin_mass / 3, (win_sum_bins**2 - win_prod_bins)) - mean**2 + return var + + for i in range(len(bins)): + var_row = [] + for j in range(len(bins[0])): + bins_hist = bins[i][j] + bin_mass_hist = bin_mass[i][j] + win_sum_bins = np.sum( + sliding_window_view(bins_hist, window_shape=2), axis=1 + ) + win_prod_bins = np.prod( + sliding_window_view(bins_hist, window_shape=2), axis=1 + ) + var_row.append( + np.dot(bin_mass_hist / 3, (win_sum_bins**2 - win_prod_bins)) + - mean[i][j] ** 2 + ) + var.append(var_row) + var = np.array(var) + return var + + def _pdf(self, x): + """Probability density function. + + Parameters + ---------- + x : 2D np.ndarray, same shape as ``self`` + values to evaluate the pdf at + + Returns + ------- + 2D np.ndarray, same shape as ``self`` + pdf values at the given points + """ + bin_mass = self.bin_mass + bins = self.bins + pdf = [] + if self._check_single_array_distr(bins, bin_mass): + bins = np.array(bins) + bin_mass = np.array(bin_mass) + bin_width = np.diff(bins) + pdf_arr = bin_mass / bin_width + X = x + if len(np.where(X < bins)[0]) and len(np.where(X >= bins)[0]): + pdf = pdf_arr[min(np.where(X < bins)[0]) - 1] + else: + pdf = 0 + return pdf + # bins_hist contains the bins edges of each histogram + for i in range(len(bins)): + pdf_row = [] + for j in range(len(bins[i])): + bins_hist = bins[i][j] + bin_mass_hist = bin_mass[i][j] + bin_width = np.diff(bins_hist) + pdf_arr = bin_mass_hist / bin_width + X = x[i][j] + if len(np.where(X < bins_hist)[0]) and len(np.where(X >= bins_hist)[0]): + pdf_row.append(pdf_arr[min(np.where(X < bins_hist)[0]) - 1]) + else: + pdf_row.append(0) + pdf.append(pdf_row) + pdf = np.array(pdf) + return pdf + + def _log_pdf(self, x): + """Logarithmic probability density function. + + Parameters + ---------- + x : 2D np.ndarray, same shape as ``self`` + values to evaluate the pdf at + + Returns + ------- + 2D np.ndarray, same shape as ``self`` + log pdf values at the given points + """ + bin_mass = self.bin_mass + bins = self.bins + lpdf = [] + + from warnings import warn + + if self._check_single_array_distr(bins, bin_mass): + bins = np.array(bins) + bin_mass = np.array(bin_mass) + bin_width = np.diff(bins) + if 0 in bin_mass: + warn( + "Zero values detected in bin_mass. These values", + "will be replaced with a small positive constant.", + ) + small_value = 1e-100 + bin_mass = np.where(bin_mass == 0, small_value, bin_mass) + lpdf_arr = np.log(bin_mass / bin_width) + X = x + if len(np.where(X < bins)[0]) and len(np.where(X >= bins)[0]): + lpdf = lpdf_arr[min(np.where(X < bins)[0]) - 1] + else: + lpdf = 0 + return lpdf + + x = np.array(x) + for i in range(len(bins)): + lpdf_row = [] + for j in range(len(bins[0])): + X = x[i][j] + bins_hist = bins[i][j] + bin_mass_hist = bin_mass[i][j] + bin_width = np.diff(bins_hist) + if 0 in bin_mass_hist: + warn("0 value detected in bin_mass is replaced by a positive const") + small_value = 1e-100 + bin_mass_hist = np.where( + bin_mass_hist == 0, small_value, bin_mass_hist + ) + lpdf_arr = np.log(bin_mass_hist / bin_width) + if len(np.where(X < bins_hist)[0]) and len(np.where(X >= bins_hist)[0]): + lpdf_row.append(lpdf_arr[min(np.where(X < bins_hist)[0]) - 1]) + else: + lpdf_row.append(0) + lpdf.append(lpdf_row) + lpdf = np.array(lpdf) + return lpdf + + def _cdf(self, x): + """Cumulative distribution function. + + Parameters + ---------- + x : 2D np.ndarray, same shape as ``self`` + values to evaluate the cdf at + + Returns + ------- + 2D np.ndarray, same shape as ``self`` + cdf values at the given points + """ + bins = self.bins + bin_mass = self.bin_mass + cdf = [] + pdf = self._pdf(x) + + if self._check_single_array_distr(bins, bin_mass): + bins = np.array(bins) + bin_mass = np.array(bin_mass) + # cum_bin_index is an array of all indices + # of the bins or bin edges that are less than X. + X = x + cum_bin_index = np.where(X >= bins)[0] + if len(cum_bin_index) == len(bins): + cdf = 1 + elif len(cum_bin_index) > 1: + cdf = np.cumsum(bin_mass)[-2] + pdf * (X - bins[cum_bin_index[-1]]) + + elif len(cum_bin_index) == 0: + cdf = 0 + elif len(cum_bin_index) == 1: + cdf = pdf * (X - bins[cum_bin_index[-1]]) + + return cdf + + x = np.array(x) + + for i in range(len(bins)): + cdf_row = [] + for j in range(len(bins[0])): + X = x[i][j] + bins_hist = bins[i][j] + bin_mass_hist = bin_mass[i][j] + cum_sum_mass = np.cumsum(bin_mass_hist) + # cum_bin_index is an array of all indices + # of the bins or bin edges that are less than X. + cum_bin_index = np.where(X >= bins_hist)[0] + if len(cum_bin_index) == len(bins_hist): + cdf_row.append(1) + elif len(cum_bin_index) > 1: + cdf_row.append( + cum_sum_mass[cum_bin_index[-2]] + + pdf[i][j] * (X - bins_hist[cum_bin_index[-1]]) + ) + elif len(cum_bin_index) == 0: + cdf_row.append(0) + elif len(cum_bin_index) == 1: + cdf_row.append(pdf[i][j] * (X - bins_hist[cum_bin_index[-1]])) + cdf.append(cdf_row) + cdf = np.array(cdf) + return cdf + + def _ppf(self, p): + """Quantile function = percent point function = inverse cdf. + + Parameters + ---------- + p : 2D np.ndarray, same shape as ``self`` + values to evaluate the ppf at + + Returns + ------- + 2D np.ndarray, same shape as ``self`` + ppf values at the given points + """ + bins = self.bins + bin_mass = self.bin_mass + ppf = [] + p = np.array(p) + + if self._check_single_array_distr(bins, bin_mass): + bins = np.array(bins) + bin_mass = np.array(bin_mass) + cum_sum_mass = np.cumsum(bin_mass) + # print(cum_sum_mass) + pdf_bins = [] + for bin in bins: + pdf_bins.append(self._pdf(bin)) + P = p + cum_bin_index_P = np.where(P >= cum_sum_mass)[0] + if P < 0 or P > 1: + X = np.NaN + elif len(cum_bin_index_P) == 0: + X = bins[0] + P / pdf_bins[len(cum_bin_index_P)] + elif len(cum_bin_index_P) > 0: + if P - cum_sum_mass[cum_bin_index_P[-1]] > 0: + X = ( + bins[cum_bin_index_P[-1] + 1] + + (P - cum_sum_mass[cum_bin_index_P[-1]]) + / pdf_bins[len(cum_bin_index_P)] + ) + else: + X = bins[cum_bin_index_P[-1] + 1] + + return X + + for i in range(len(bins)): + ppf_row = [] + for j in range(len(bins[0])): + P = p[i][j] + bins_hist = bins[i][j] + bin_mass_hist = bin_mass[i][j] + cum_sum_mass = np.cumsum(bin_mass_hist) + # manually finding pdf of 1D array at all bin edges + pdf_bins = [] + bin_width = np.diff(bins_hist) + pdf_arr = bin_mass_hist / bin_width + for bh in bins_hist: + if len(np.where(bh < bins_hist)[0]) and len( + np.where(bh >= bins_hist)[0] + ): + pdf_bins.append(pdf_arr[min(np.where(bh < bins_hist)[0]) - 1]) + else: + pdf_bins.append(0) + pdf_bins = np.array(pdf_bins) + # find a way to calculate pdf for 1D array ... + cum_bin_index_P = np.where(P >= cum_sum_mass)[0] + if P < 0 or P > 1: + ppf_row.append(np.NaN) + elif len(cum_bin_index_P) == 0: + X = bins_hist[0] + P / pdf_bins[len(cum_bin_index_P)] + ppf_row.append(round(X, 4)) + elif len(cum_bin_index_P) > 0: + if P - cum_sum_mass[cum_bin_index_P[-1]] > 0: + X = ( + bins_hist[cum_bin_index_P[-1] + 1] + + (P - cum_sum_mass[cum_bin_index_P[-1]]) + / pdf_bins[len(cum_bin_index_P)] + ) + else: + X = bins_hist[cum_bin_index_P[-1] + 1] + ppf_row.append(round(X, 4)) + ppf.append(ppf_row) + ppf = np.array(ppf) + return ppf + + @classmethod + def get_test_params(cls, parameter_set="default"): + """Return testing parameter settings for the estimator.""" + # array case examples + params1 = { + "bins": [ + [[0, 1, 2, 3, 4], [5, 5.5, 5.8, 6.5, 7, 7.5]], + [(2, 12, 5), [0, 1, 2, 3, 4]], + [[1.5, 2.5, 3.1, 4, 5.4], [-4, -2, -1.5, 5, 10]], + ], + "bin_mass": [ + [[0.1, 0.2, 0.3, 0.4], [0.25, 0.1, 0, 0.4, 0.25]], + [[0.1, 0.2, 0.4, 0.2, 0.1], [0.4, 0.3, 0.2, 0.1]], + [[0.06, 0.15, 0.09, 0.7], [0.4, 0.15, 0.325, 0.125]], + ], + "index": pd.Index(np.arange(3)), + "columns": pd.Index(np.arange(2)), + } + + return [params1] diff --git a/skpro/distributions/tests/test_all_distrs.py b/skpro/distributions/tests/test_all_distrs.py index b9c2fe23b..1fed1c5c9 100644 --- a/skpro/distributions/tests/test_all_distrs.py +++ b/skpro/distributions/tests/test_all_distrs.py @@ -167,7 +167,7 @@ def test_methods_p(self, object_instance, method, shuffled): else: p = np_unif - res = getattr(object_instance, method)(p) + res = getattr(d, method)(p) _check_output_format(res, d, method) From bd5b6b50fef248e17b7aecf14d543614750c6b99 Mon Sep 17 00:00:00 2001 From: ShreeshaM07 Date: Tue, 11 Jun 2024 23:14:39 +0530 Subject: [PATCH 02/12] updated docstring to look better --- skpro/distributions/histogram.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/skpro/distributions/histogram.py b/skpro/distributions/histogram.py index dc2b854a0..8d05276bd 100644 --- a/skpro/distributions/histogram.py +++ b/skpro/distributions/histogram.py @@ -20,27 +20,37 @@ class Histogram(BaseArrayDistribution): bins : tuple(float,float,int) or numpy.array of float 1D or 2D list of size m x n 1. tuple(first bin's start point, last bin's end point, number of bins) Used when bin widths are equal. + ``` example: bins:(0,4,4), + ``` 2. array has the bin boundaries with 1st element the first bin's starting point and rest are the bin ending points of all bins + ``` example: bins:[0, 1, 2, 3, 4], + ``` 3. 2D list of size m x n containing m*n float numpy.arrays or tuple like case 1. + ``` example : "bins": [ [[0, 1, 2, 3, 4], [5, 5.5, 5.8, 6.5, 7, 7.5]], [(2, 12, 5), [0, 1, 2, 3, 4]], [[1.5, 2.5, 3.1, 4, 5.4], [-4, -2, -1.5, 5, 10]], ] + ``` bin_mass: array of float 1D or 2D list of size m x n containing 1. Array has the mass of the bins or area of the bins. + ``` example: bin_mass:[0.1, 0.2, 0.3, 0.4], + ``` Note: len(bin_mass) will be (len(bins)-1). Note: Sum of all the bin_mass must be 1. 2. 2D list of size m x n containing m*n float numpy.arrays satisfying case 1. + ``` example : "bin_mass": [ [[0.1, 0.2, 0.3, 0.4], [0.25, 0.1, 0, 0.4, 0.25]], [[0.1, 0.2, 0.4, 0.2, 0.1], [0.4, 0.3, 0.2, 0.1]], [[0.06, 0.15, 0.09, 0.7], [0.4, 0.15, 0.325, 0.125]], ] + ``` index : pd.Index, optional, default = RangeIndex columns : pd.Index, optional, default = RangeIndex """ From 6dd7474a2ba3c2463a529a4396e6168e945cf317 Mon Sep 17 00:00:00 2001 From: ShreeshaM07 Date: Tue, 11 Jun 2024 23:45:14 +0530 Subject: [PATCH 03/12] update docstring and BaseArrayDistribution in distribution.rst --- docs/source/api_reference/distributions.rst | 1 + skpro/distributions/histogram.py | 55 +++++++++------------ 2 files changed, 25 insertions(+), 31 deletions(-) diff --git a/docs/source/api_reference/distributions.rst b/docs/source/api_reference/distributions.rst index 4157488d0..8351ccd92 100644 --- a/docs/source/api_reference/distributions.rst +++ b/docs/source/api_reference/distributions.rst @@ -22,6 +22,7 @@ Base :template: class.rst BaseDistribution + BaseArrayDistribution Parametric distributions ------------------------ diff --git a/skpro/distributions/histogram.py b/skpro/distributions/histogram.py index 8d05276bd..06dfb0e3b 100644 --- a/skpro/distributions/histogram.py +++ b/skpro/distributions/histogram.py @@ -10,47 +10,40 @@ class Histogram(BaseArrayDistribution): - """Histogram Probability Distribution. + r"""Histogram Probability Distribution. The histogram probability distribution is parameterized by the bins and bin densities. Parameters ---------- - bins : tuple(float,float,int) or numpy.array of float 1D or 2D list of size m x n + bins : tuple(float, float, int) or numpy.array of float 1D or 2D list of size m x n 1. tuple(first bin's start point, last bin's end point, number of bins) - Used when bin widths are equal. - ``` - example: bins:(0,4,4), - ``` + Used when bin widths are equal. + example: bins:(0,4,4) 2. array has the bin boundaries with 1st element the first bin's - starting point and rest are the bin ending points of all bins - ``` - example: bins:[0, 1, 2, 3, 4], - ``` - 3. 2D list of size m x n containing m*n float numpy.arrays or tuple like case 1. - ``` - example : "bins": [ - [[0, 1, 2, 3, 4], [5, 5.5, 5.8, 6.5, 7, 7.5]], - [(2, 12, 5), [0, 1, 2, 3, 4]], - [[1.5, 2.5, 3.1, 4, 5.4], [-4, -2, -1.5, 5, 10]], - ] - ``` - bin_mass: array of float 1D or 2D list of size m x n containing + starting point and rest are the bin ending points of all bins + example: bins:[0, 1, 2, 3, 4] + 3. 2D list of size m x n containing m*n float numpy.arrays or tuple + like case 1. + example: + bins: [ + [[0, 1, 2, 3, 4], [5, 5.5, 5.8, 6.5, 7, 7.5]], + [(2, 12, 5), [0, 1, 2, 3, 4]], + [[1.5, 2.5, 3.1, 4, 5.4], [-4, -2, -1.5, 5, 10]] + ] + bin_mass : array of float 1D or 2D list of size m x n 1. Array has the mass of the bins or area of the bins. - ``` - example: bin_mass:[0.1, 0.2, 0.3, 0.4], - ``` - Note: len(bin_mass) will be (len(bins)-1). - Note: Sum of all the bin_mass must be 1. + example: bin_mass:[0.1, 0.2, 0.3, 0.4] + Note: `len(bin_mass)` will be `(len(bins)-1)`. + Note: Sum of all the `bin_mass` must be `1`. 2. 2D list of size m x n containing m*n float numpy.arrays satisfying case 1. - ``` - example : "bin_mass": [ - [[0.1, 0.2, 0.3, 0.4], [0.25, 0.1, 0, 0.4, 0.25]], - [[0.1, 0.2, 0.4, 0.2, 0.1], [0.4, 0.3, 0.2, 0.1]], - [[0.06, 0.15, 0.09, 0.7], [0.4, 0.15, 0.325, 0.125]], - ] - ``` + example: + bin_mass: [ + [[0.1, 0.2, 0.3, 0.4], [0.25, 0.1, 0, 0.4, 0.25]], + [[0.1, 0.2, 0.4, 0.2, 0.1], [0.4, 0.3, 0.2, 0.1]], + [[0.06, 0.15, 0.09, 0.7], [0.4, 0.15, 0.325, 0.125]] + ] index : pd.Index, optional, default = RangeIndex columns : pd.Index, optional, default = RangeIndex """ From 3a79d42223a46b6cb8640205753e3f2e07d70a69 Mon Sep 17 00:00:00 2001 From: ShreeshaM07 Date: Tue, 11 Jun 2024 23:56:16 +0530 Subject: [PATCH 04/12] docstring --- skpro/distributions/histogram.py | 31 ++++++++++++++++++------------- 1 file changed, 18 insertions(+), 13 deletions(-) diff --git a/skpro/distributions/histogram.py b/skpro/distributions/histogram.py index 06dfb0e3b..5b5694a39 100644 --- a/skpro/distributions/histogram.py +++ b/skpro/distributions/histogram.py @@ -20,30 +20,35 @@ class Histogram(BaseArrayDistribution): bins : tuple(float, float, int) or numpy.array of float 1D or 2D list of size m x n 1. tuple(first bin's start point, last bin's end point, number of bins) Used when bin widths are equal. - example: bins:(0,4,4) + example: + bins:(0,4,4) 2. array has the bin boundaries with 1st element the first bin's starting point and rest are the bin ending points of all bins - example: bins:[0, 1, 2, 3, 4] + example: + bins:[0, 1, 2, 3, 4] 3. 2D list of size m x n containing m*n float numpy.arrays or tuple like case 1. example: - bins: [ - [[0, 1, 2, 3, 4], [5, 5.5, 5.8, 6.5, 7, 7.5]], - [(2, 12, 5), [0, 1, 2, 3, 4]], - [[1.5, 2.5, 3.1, 4, 5.4], [-4, -2, -1.5, 5, 10]] - ] + bins: + [ + [[0, 1, 2, 3, 4], [5, 5.5, 5.8, 6.5, 7, 7.5]], + [(2, 12, 5), [0, 1, 2, 3, 4]], + [[1.5, 2.5, 3.1, 4, 5.4], [-4, -2, -1.5, 5, 10]] + ] bin_mass : array of float 1D or 2D list of size m x n 1. Array has the mass of the bins or area of the bins. - example: bin_mass:[0.1, 0.2, 0.3, 0.4] + example: + bin_mass:[0.1, 0.2, 0.3, 0.4] Note: `len(bin_mass)` will be `(len(bins)-1)`. Note: Sum of all the `bin_mass` must be `1`. 2. 2D list of size m x n containing m*n float numpy.arrays satisfying case 1. example: - bin_mass: [ - [[0.1, 0.2, 0.3, 0.4], [0.25, 0.1, 0, 0.4, 0.25]], - [[0.1, 0.2, 0.4, 0.2, 0.1], [0.4, 0.3, 0.2, 0.1]], - [[0.06, 0.15, 0.09, 0.7], [0.4, 0.15, 0.325, 0.125]] - ] + bin_mass: + [ + [[0.1, 0.2, 0.3, 0.4], [0.25, 0.1, 0, 0.4, 0.25]], + [[0.1, 0.2, 0.4, 0.2, 0.1], [0.4, 0.3, 0.2, 0.1]], + [[0.06, 0.15, 0.09, 0.7], [0.4, 0.15, 0.325, 0.125]] + ] index : pd.Index, optional, default = RangeIndex columns : pd.Index, optional, default = RangeIndex """ From a8358d483aa6ff86cb1145d9a157563b6bc12a79 Mon Sep 17 00:00:00 2001 From: ShreeshaM07 Date: Wed, 12 Jun 2024 00:05:30 +0530 Subject: [PATCH 05/12] docstring --- skpro/distributions/histogram.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/skpro/distributions/histogram.py b/skpro/distributions/histogram.py index 5b5694a39..f79ddd6ac 100644 --- a/skpro/distributions/histogram.py +++ b/skpro/distributions/histogram.py @@ -10,7 +10,7 @@ class Histogram(BaseArrayDistribution): - r"""Histogram Probability Distribution. + """Histogram Probability Distribution. The histogram probability distribution is parameterized by the bins and bin densities. From 108176ed695a8ba3b4cb362b606130a36c6aa46b Mon Sep 17 00:00:00 2001 From: ShreeshaM07 Date: Wed, 12 Jun 2024 01:19:24 +0530 Subject: [PATCH 06/12] singular array cdf --- skpro/distributions/histogram.py | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/skpro/distributions/histogram.py b/skpro/distributions/histogram.py index f79ddd6ac..711997a01 100644 --- a/skpro/distributions/histogram.py +++ b/skpro/distributions/histogram.py @@ -419,10 +419,7 @@ def _log_pdf(self, x): bin_mass = np.array(bin_mass) bin_width = np.diff(bins) if 0 in bin_mass: - warn( - "Zero values detected in bin_mass. These values", - "will be replaced with a small positive constant.", - ) + warn("0 value detected in bin_mass is replaced by a positive const") small_value = 1e-100 bin_mass = np.where(bin_mass == 0, small_value, bin_mass) lpdf_arr = np.log(bin_mass / bin_width) @@ -475,21 +472,23 @@ def _cdf(self, x): pdf = self._pdf(x) if self._check_single_array_distr(bins, bin_mass): - bins = np.array(bins) - bin_mass = np.array(bin_mass) + X = x + bins_hist = bins + bin_mass_hist = bin_mass + cum_sum_mass = np.cumsum(bin_mass_hist) # cum_bin_index is an array of all indices # of the bins or bin edges that are less than X. - X = x - cum_bin_index = np.where(X >= bins)[0] - if len(cum_bin_index) == len(bins): + cum_bin_index = np.where(X >= bins_hist)[0] + if len(cum_bin_index) == len(bins_hist): cdf = 1 elif len(cum_bin_index) > 1: - cdf = np.cumsum(bin_mass)[-2] + pdf * (X - bins[cum_bin_index[-1]]) - + cdf = cum_sum_mass[cum_bin_index[-2]] + pdf * ( + X - bins_hist[cum_bin_index[-1]] + ) elif len(cum_bin_index) == 0: cdf = 0 elif len(cum_bin_index) == 1: - cdf = pdf * (X - bins[cum_bin_index[-1]]) + cdf_row = pdf * (X - bins_hist[cum_bin_index[-1]]) return cdf From 1fb3b1107da3742d03b6802f3ec776c929faf0ce Mon Sep 17 00:00:00 2001 From: ShreeshaM07 Date: Wed, 12 Jun 2024 12:31:41 +0530 Subject: [PATCH 07/12] cdf single arr --- skpro/distributions/histogram.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/skpro/distributions/histogram.py b/skpro/distributions/histogram.py index 711997a01..9c919b63c 100644 --- a/skpro/distributions/histogram.py +++ b/skpro/distributions/histogram.py @@ -488,7 +488,7 @@ def _cdf(self, x): elif len(cum_bin_index) == 0: cdf = 0 elif len(cum_bin_index) == 1: - cdf_row = pdf * (X - bins_hist[cum_bin_index[-1]]) + cdf = pdf * (X - bins_hist[cum_bin_index[-1]]) return cdf From bfac335e75a7e2cc7b4917a77d19d4d6eb6d6ffa Mon Sep 17 00:00:00 2001 From: ShreeshaM07 Date: Thu, 13 Jun 2024 23:50:32 +0530 Subject: [PATCH 08/12] stopped override of public methods --- skpro/distributions/base/_base_array.py | 221 ------------------------ 1 file changed, 221 deletions(-) diff --git a/skpro/distributions/base/_base_array.py b/skpro/distributions/base/_base_array.py index 3d12eb07d..c44dba6d4 100644 --- a/skpro/distributions/base/_base_array.py +++ b/skpro/distributions/base/_base_array.py @@ -244,224 +244,3 @@ def _get_bc_params_dict( # print(kwargs_as_np,shape,is_scalar) return kwargs_as_np, shape, is_scalar return kwargs_as_np - - def pdf(self, x): - r"""Probability density function. - - Let :math:`X` be a random variables with the distribution of ``self``, - taking values in ``(N, n)`` ``DataFrame``-s - Let :math:`x\in \mathbb{R}^{N\times n}`. - By :math:`p_{X_{ij}}`, denote the marginal pdf of :math:`X` at the - :math:`(i,j)`-th entry. - - The output of this method, for input ``x`` representing :math:`x`, - is a ``DataFrame`` with same columns and indices as ``self``, - and entries :math:`p_{X_{ij}}(x_{ij})`. - - If ``self`` has a mixed or discrete distribution, this returns - the weighted continuous part of `self`'s distribution instead of the pdf, - i.e., the marginal pdf integrate to the weight of the continuous part. - - Parameters - ---------- - x : ``pandas.DataFrame`` or 2D ``np.ndarray`` - representing :math:`x`, as above - - Returns - ------- - ``pd.DataFrame`` with same columns and index as ``self`` - containing :math:`p_{X_{ij}}(x_{ij})`, as above - """ - distr_type = self.get_tag("distr:measuretype", "mixed", raise_error=False) - x = np.array(x) - if distr_type == "discrete": - return self._coerce_to_self_index_df(0, flatten=False) - - return self._boilerplate("_pdf", x=x) - - def log_pdf(self, x): - r"""Logarithmic probability density function. - - Numerically more stable than calling pdf and then taking logartihms. - - Let :math:`X` be a random variables with the distribution of ``self``, - taking values in `(N, n)` ``DataFrame``-s - Let :math:`x\in \mathbb{R}^{N\times n}`. - By :math:`p_{X_{ij}}`, denote the marginal pdf of :math:`X` at the - :math:`(i,j)`-th entry. - - The output of this method, for input ``x`` representing :math:`x`, - is a ``DataFrame`` with same columns and indices as ``self``, - and entries :math:`\log p_{X_{ij}}(x_{ij})`. - - If ``self`` has a mixed or discrete distribution, this returns - the weighted continuous part of `self`'s distribution instead of the pdf, - i.e., the marginal pdf integrate to the weight of the continuous part. - - Parameters - ---------- - x : ``pandas.DataFrame`` or 2D ``np.ndarray`` - representing :math:`x`, as above - - Returns - ------- - ``pd.DataFrame`` with same columns and index as ``self`` - containing :math:`\log p_{X_{ij}}(x_{ij})`, as above - """ - distr_type = self.get_tag("distr:measuretype", "mixed", raise_error=False) - x = np.array(x) - if distr_type == "discrete": - return self._coerce_to_self_index_df(-np.inf, flatten=False) - - return self._boilerplate("_log_pdf", x=x) - - def pmf(self, x): - r"""Probability mass function. - - Let :math:`X` be a random variables with the distribution of ``self``, - taking values in ``(N, n)`` ``DataFrame``-s - Let :math:`x\in \mathbb{R}^{N\times n}`. - By :math:`m_{X_{ij}}`, denote the marginal mass of :math:`X` at the - :math:`(i,j)`-th entry, i.e., - :math:`m_{X_{ij}}(x_{ij}) = \mathbb{P}(X_{ij} = x_{ij})`. - - The output of this method, for input ``x`` representing :math:`x`, - is a ``DataFrame`` with same columns and indices as ``self``, - and entries :math:`m_{X_{ij}}(x_{ij})`. - - If ``self`` has a mixed or discrete distribution, this returns - the weighted continuous part of `self`'s distribution instead of the pdf, - i.e., the marginal pdf integrate to the weight of the continuous part. - - Parameters - ---------- - x : ``pandas.DataFrame`` or 2D ``np.ndarray`` - representing :math:`x`, as above - - Returns - ------- - ``pd.DataFrame`` with same columns and index as ``self`` - containing :math:`p_{X_{ij}}(x_{ij})`, as above - """ - distr_type = self.get_tag("distr:measuretype", "mixed", raise_error=False) - if distr_type == "continuous": - return self._coerce_to_self_index_df(0, flatten=False) - - return self._boilerplate("_pmf", x=x) - - def log_pmf(self, x): - r"""Logarithmic probability mass function. - - Numerically more stable than calling pmf and then taking logartihms. - - Let :math:`X` be a random variables with the distribution of ``self``, - taking values in `(N, n)` ``DataFrame``-s - Let :math:`x\in \mathbb{R}^{N\times n}`. - By :math:`m_{X_{ij}}`, denote the marginal pdf of :math:`X` at the - :math:`(i,j)`-th entry, i.e., - :math:`m_{X_{ij}}(x_{ij}) = \mathbb{P}(X_{ij} = x_{ij})`. - - The output of this method, for input ``x`` representing :math:`x`, - is a ``DataFrame`` with same columns and indices as ``self``, - and entries :math:`\log m_{X_{ij}}(x_{ij})`. - - If ``self`` has a mixed or discrete distribution, this returns - the weighted continuous part of `self`'s distribution instead of the pdf, - i.e., the marginal pdf integrate to the weight of the continuous part. - - Parameters - ---------- - x : ``pandas.DataFrame`` or 2D ``np.ndarray`` - representing :math:`x`, as above - - Returns - ------- - ``pd.DataFrame`` with same columns and index as ``self`` - containing :math:`\log m_{X_{ij}}(x_{ij})`, as above - """ - distr_type = self.get_tag("distr:measuretype", "mixed", raise_error=False) - if distr_type == "continuous": - return self._coerce_to_self_index_df(-np.inf, flatten=False) - - return self._boilerplate("_log_pmf", x=x) - - def cdf(self, x): - r"""Cumulative distribution function. - - Let :math:`X` be a random variables with the distribution of ``self``, - taking values in ``(N, n)`` ``DataFrame``-s - Let :math:`x\in \mathbb{R}^{N\times n}`. - By :math:`F_{X_{ij}}`, denote the marginal cdf of :math:`X` at the - :math:`(i,j)`-th entry, - i.e., :math:`F_{X_{ij}}(t) = \mathbb{P}(X_{ij} \leq t)`. - - The output of this method, for input ``x`` representing :math:`x`, - is a ``DataFrame`` with same columns and indices as ``self``, - and entries :math:`F_{X_{ij}}(x_{ij})`. - - Parameters - ---------- - x : ``pandas.DataFrame`` or 2D ``np.ndarray`` - representing :math:`x`, as above - - Returns - ------- - ``pd.DataFrame`` with same columns and index as ``self`` - containing :math:`F_{X_{ij}}(x_{ij})`, as above - """ - x = np.array(x) - return self._boilerplate("_cdf", x=x) - - def ppf(self, p): - r"""Quantile function = percent point function = inverse cdf. - - Let :math:`X` be a random variables with the distribution of ``self``, - taking values in ``(N, n)`` ``DataFrame``-s - Let :math:`x\in \mathbb{R}^{N\times n}`. - By :math:`F_{X_{ij}}`, denote the marginal cdf of :math:`X` at the - :math:`(i,j)`-th entry. - - The output of this method, for input ``p`` representing :math:`p`, - is a ``DataFrame`` with same columns and indices as ``self``, - and entries :math:`F^{-1}_{X_{ij}}(p_{ij})`. - - Parameters - ---------- - p : ``pandas.DataFrame`` or 2D np.ndarray - representing :math:`p`, as above - - Returns - ------- - ``pd.DataFrame`` with same columns and index as ``self`` - containing :math:`F_{X_{ij}}(x_{ij})`, as above - """ - p = np.array(p) - return self._boilerplate("_ppf", p=p) - - def energy(self, x=None): - r"""Energy of self, w.r.t. self or a constant frame x. - - Let :math:`X, Y` be i.i.d. random variables with the distribution of ``self``. - - If ``x`` is ``None``, returns :math:`\mathbb{E}[|X-Y|]` (per row), - "self-energy". - If ``x`` is passed, returns :math:`\mathbb{E}[|X-x|]` (per row), "energy wrt x". - - The CRPS is related to energy: - it holds that - :math:`\mbox{CRPS}(\mbox{self}, y)` = `self.energy(y) - 0.5 * self.energy()`. - - Parameters - ---------- - x : None or pd.DataFrame, optional, default=None - if ``pd.DataFrame``, must have same rows and columns as ``self`` - - Returns - ------- - ``pd.DataFrame`` with same rows as ``self``, single column ``"energy"`` - each row contains one float, self-energy/energy as described above. - """ - if x is None: - return self._boilerplate("_energy_self", columns=["energy"]) - x = np.array(x) - return self._boilerplate("_energy_x", x=x, columns=["energy"]) From aadb992f38f5c64b2ff096a21278888b08d72b35 Mon Sep 17 00:00:00 2001 From: ShreeshaM07 Date: Thu, 13 Jun 2024 23:55:50 +0530 Subject: [PATCH 09/12] removed test_all_distrs.py --- skpro/distributions/tests/test_all_distrs.py | 316 ------------------- 1 file changed, 316 deletions(-) delete mode 100644 skpro/distributions/tests/test_all_distrs.py diff --git a/skpro/distributions/tests/test_all_distrs.py b/skpro/distributions/tests/test_all_distrs.py deleted file mode 100644 index 1fed1c5c9..000000000 --- a/skpro/distributions/tests/test_all_distrs.py +++ /dev/null @@ -1,316 +0,0 @@ -"""Tests for BaseDistribution API points.""" -# copyright: skpro developers, BSD-3-Clause License (see LICENSE file) -# adapted from sktime - -__author__ = ["fkiraly", "Alex-JG3"] - -import numpy as np -import pandas as pd -import pytest -from skbase.testing import QuickTester - -from skpro.datatypes import check_is_mtype -from skpro.tests.test_all_estimators import BaseFixtureGenerator, PackageConfig -from skpro.utils.index import random_ss_ix - - -class DistributionFixtureGenerator(BaseFixtureGenerator): - """Fixture generator for probability distributions. - - Fixtures parameterized - ---------------------- - object_class: object inheriting from BaseObject - ranges over object classes not excluded by EXCLUDE_OBJECTS, EXCLUDED_TESTS - object_instance: instance of object inheriting from BaseObject - ranges over object classes not excluded by EXCLUDE_OBJECTS, EXCLUDED_TESTS - instances are generated by create_test_instance class method - """ - - object_type_filter = "distribution" - - -def _has_capability(distr, method): - """Check whether distr has capability of method. - - Parameters - ---------- - distr : BaseDistribution object - method : str - method name to check - - Returns - ------- - whether distr has capability method, according to tags - capabilities:approx and capabilities:exact - """ - approx_methods = distr.get_tag("capabilities:approx") - exact_methods = distr.get_tag("capabilities:exact") - return method in approx_methods or method in exact_methods - - -METHODS_SCALAR = ["mean", "var", "energy"] -METHODS_SCALAR_POS = ["var", "energy"] # result always non-negative? -METHODS_X = ["energy", "pdf", "log_pdf", "pmf", "log_pmf", "cdf"] -METHODS_X_POS = ["energy", "pdf", "pmf", "cdf", "surv", "haz"] # result non-negative? -METHODS_P = ["ppf"] -METHODS_ROWWISE = ["energy"] # results in one column - - -class TestAllDistributions(PackageConfig, DistributionFixtureGenerator, QuickTester): - """Module level tests for all skpro parameter fitters.""" - - # TEMPORARY skip for CyclicBoosting and QPD classes - # due to silent failures on main, se #190 - exclude_objects = ["QPD_B"] - # remove this when fixing failures to re-enable testing - - def test_shape(self, object_instance): - """Test index, columns, len and shape of distribution.""" - d = object_instance - - assert hasattr(d, "shape") - assert isinstance(d.shape, tuple) - assert len(d.shape) in [0, 2] - - if len(d.shape) == 2: - assert all(isinstance(n, int) for n in d.shape) - - assert isinstance(d.index, pd.Index) - assert isinstance(d.columns, pd.Index) - - assert d.shape[0] == len(d.index) - assert d.shape[1] == len(d.columns) - - assert isinstance(len(d), int) - - if len(d.shape) == 2: - assert len(d) == d.shape[0] - else: - assert len(d) == 1 - - assert hasattr(d, "ndim") - assert d.ndim == len(d.shape) - - @pytest.mark.parametrize("shuffled", [False, True]) - def test_sample(self, object_instance, shuffled): - """Test sample expected return.""" - d = object_instance - - if shuffled: - d = _shuffle_distr(d) - - res = d.sample() - - if d.ndim > 0: - assert d.shape == res.shape - assert (res.index == d.index).all() - assert (res.columns == d.columns).all() - else: # d.ndim = 0 - assert np.isscalar(res) - - res_panel = d.sample(3) - if d.ndim > 0: - dummy_panel = pd.concat([res, res, res], keys=range(3)) - else: - dummy_panel = pd.DataFrame(index=range(3), columns=range(1)) - assert dummy_panel.shape == res_panel.shape - assert (res_panel.index == dummy_panel.index).all() - assert (res_panel.columns == dummy_panel.columns).all() - - @pytest.mark.parametrize("shuffled", [False, True]) - @pytest.mark.parametrize("method", METHODS_SCALAR, ids=METHODS_SCALAR) - def test_methods_scalar(self, object_instance, method, shuffled): - """Test expected return of scalar methods.""" - if not _has_capability(object_instance, method): - return None - - d = object_instance - if shuffled: - d = _shuffle_distr(d) - - res = getattr(d, method)() - - _check_output_format(res, d, method) - - @pytest.mark.parametrize("shuffled", [False, True]) - @pytest.mark.parametrize("method", METHODS_X, ids=METHODS_X) - def test_methods_x(self, object_instance, method, shuffled): - """Test expected return of methods that take sample-like argument.""" - if not _has_capability(object_instance, method): - return None - - d = object_instance - - if shuffled: - d = _shuffle_distr(d) - - x = d.sample() - res = getattr(d, method)(x) - - _check_output_format(res, d, method) - - @pytest.mark.parametrize("shuffled", [False, True]) - @pytest.mark.parametrize("method", METHODS_P, ids=METHODS_P) - def test_methods_p(self, object_instance, method, shuffled): - """Test expected return of methods that take percentage-like argument.""" - if not _has_capability(object_instance, method): - return None - - d = object_instance - - if shuffled: - d = _shuffle_distr(d) - - np_unif = np.random.uniform(size=d.shape) - if d.ndim > 0: - p = pd.DataFrame(np_unif, index=d.index, columns=d.columns) - else: - p = np_unif - - res = getattr(d, method)(p) - - _check_output_format(res, d, method) - - @pytest.mark.parametrize("q", [0.7, [0.1, 0.3, 0.9]]) - def test_quantile(self, object_instance, q): - """Test expected return of quantile method.""" - if not _has_capability(object_instance, "ppf"): - return None - - d = object_instance - - def _check_quantile_output(obj, q): - assert check_is_mtype( - obj, "pred_quantiles", "Proba", msg_return_dict="list" - ) - if d.ndim == 0: - expected_index = pd.RangeIndex(1) - vars = [d.__class__.__name__] - else: - expected_index = d.index - vars = d.columns - - assert (obj.index == expected_index).all() - - if not isinstance(q, list): - q = [q] - expected_columns = pd.MultiIndex.from_product([vars, q]) - assert (obj.columns == expected_columns).all() - - res = d.quantile(q) - _check_quantile_output(res, q) - - @pytest.mark.parametrize("subset_row", [True, False]) - @pytest.mark.parametrize("subset_col", [True, False]) - def test_subsetting(self, object_instance, subset_row, subset_col): - """Test subsetting of distribution.""" - d = object_instance - if d.ndim == 0: # no subsetting to test if example is scalar - return None - - if subset_row: - ix_loc = random_ss_ix(d.index, 3) - ix_iloc = d.index.get_indexer(ix_loc) - else: - ix_loc = d.index - ix_iloc = pd.RangeIndex(len(d.index)) - - if subset_col: - iy_loc = random_ss_ix(d.columns, 1) - iy_iloc = d.columns.get_indexer(iy_loc) - else: - iy_loc = d.columns - iy_iloc = pd.RangeIndex(len(d.columns)) - - res_loc = d.loc[ix_loc, iy_loc] - - assert isinstance(res_loc, type(d)) - assert res_loc.shape == (len(ix_loc), len(iy_loc)) - assert (res_loc.index == ix_loc).all() - assert (res_loc.columns == iy_loc).all() - - res_iloc = d.iloc[ix_iloc, iy_iloc] - - assert isinstance(res_iloc, type(d)) - assert res_iloc.shape == (len(ix_iloc), len(iy_iloc)) - assert (res_iloc.index == ix_loc).all() - assert (res_iloc.columns == iy_loc).all() - - def test_log_pdf_and_pdf(self, object_instance): - """Test that the log of the pdf and log_pdf function are similar.""" - d = object_instance - capabilities_exact = d.get_tags()["capabilities:exact"] - - if "log_pdf" not in capabilities_exact or "pdf" not in capabilities_exact: - return - x = d.sample() - pdf = d.pdf(x) - log_pdf = d.log_pdf(x) - assert np.allclose(np.log(pdf), log_pdf) - - def test_log_pmf_and_pmf(self, object_instance): - """Test that the log of the pmf and log_pmf function are similar.""" - d = object_instance - capabilities_exact = d.get_tags()["capabilities:exact"] - - if "log_pmf" not in capabilities_exact or "pmf" not in capabilities_exact: - return - x = d.sample() - pmf = d.pmf(x) - log_pmf = d.log_pmf(x) - assert np.allclose(np.log(pmf), log_pmf) - - def test_ppf_and_cdf(self, object_instance): - """Test that the ppf is the inverse of the cdf.""" - d = object_instance - capabilities_exact = d.get_tags()["capabilities:exact"] - - if "ppf" not in capabilities_exact or "cdf" not in capabilities_exact: - return - x = d.sample() - x_approx = d.ppf(d.cdf(x)) - if d.ndim > 0: - assert np.allclose(x.values, x_approx.values) - else: - assert np.allclose(x, x_approx) - - -def _check_output_format(res, dist, method): - """Check output format expectations for BaseDistribution tests.""" - if dist.shape == (): # scalar distribution case - # check if numpy float - assert np.isscalar(res) - assert np.isreal(res) - if method in METHODS_SCALAR_POS or method in METHODS_X_POS: - assert res >= 0 - return None - - # array distribution case - if method in METHODS_ROWWISE: - exp_shape = (dist.shape[0], 1) - else: - exp_shape = dist.shape - assert res.shape == exp_shape - assert (res.index == dist.index).all() - if method not in METHODS_ROWWISE: - assert (res.columns == dist.columns).all() - - if method in METHODS_SCALAR_POS or method in METHODS_X_POS: - assert (res >= 0).all().all() - - if isinstance(res, pd.DataFrame): - assert res.apply(pd.api.types.is_numeric_dtype).all() - elif isinstance(res, pd.Series): - assert pd.api.types.is_numeric_dtype(res) - else: - raise TypeError("res must be a pandas DataFrame or Series.") - - -def _shuffle_distr(d): - """Shuffle distribution row index.""" - if d.shape == (): # nothing to shuffle if scalar - return d - # shuffle rows otherwise - shuffled_df = pd.DataFrame(d.index).sample(frac=1) - shuffled_index = pd.Index(shuffled_df.values.flatten()) - return d.loc[shuffled_index] From b575912806f1f79e4893387f5189e259be3793ee Mon Sep 17 00:00:00 2001 From: ShreeshaM07 Date: Thu, 13 Jun 2024 23:57:30 +0530 Subject: [PATCH 10/12] Revert "stopped override of public methods" This reverts commit bfac335e75a7e2cc7b4917a77d19d4d6eb6d6ffa. --- skpro/distributions/base/_base_array.py | 221 ++++++++++++++++++++++++ 1 file changed, 221 insertions(+) diff --git a/skpro/distributions/base/_base_array.py b/skpro/distributions/base/_base_array.py index c44dba6d4..3d12eb07d 100644 --- a/skpro/distributions/base/_base_array.py +++ b/skpro/distributions/base/_base_array.py @@ -244,3 +244,224 @@ def _get_bc_params_dict( # print(kwargs_as_np,shape,is_scalar) return kwargs_as_np, shape, is_scalar return kwargs_as_np + + def pdf(self, x): + r"""Probability density function. + + Let :math:`X` be a random variables with the distribution of ``self``, + taking values in ``(N, n)`` ``DataFrame``-s + Let :math:`x\in \mathbb{R}^{N\times n}`. + By :math:`p_{X_{ij}}`, denote the marginal pdf of :math:`X` at the + :math:`(i,j)`-th entry. + + The output of this method, for input ``x`` representing :math:`x`, + is a ``DataFrame`` with same columns and indices as ``self``, + and entries :math:`p_{X_{ij}}(x_{ij})`. + + If ``self`` has a mixed or discrete distribution, this returns + the weighted continuous part of `self`'s distribution instead of the pdf, + i.e., the marginal pdf integrate to the weight of the continuous part. + + Parameters + ---------- + x : ``pandas.DataFrame`` or 2D ``np.ndarray`` + representing :math:`x`, as above + + Returns + ------- + ``pd.DataFrame`` with same columns and index as ``self`` + containing :math:`p_{X_{ij}}(x_{ij})`, as above + """ + distr_type = self.get_tag("distr:measuretype", "mixed", raise_error=False) + x = np.array(x) + if distr_type == "discrete": + return self._coerce_to_self_index_df(0, flatten=False) + + return self._boilerplate("_pdf", x=x) + + def log_pdf(self, x): + r"""Logarithmic probability density function. + + Numerically more stable than calling pdf and then taking logartihms. + + Let :math:`X` be a random variables with the distribution of ``self``, + taking values in `(N, n)` ``DataFrame``-s + Let :math:`x\in \mathbb{R}^{N\times n}`. + By :math:`p_{X_{ij}}`, denote the marginal pdf of :math:`X` at the + :math:`(i,j)`-th entry. + + The output of this method, for input ``x`` representing :math:`x`, + is a ``DataFrame`` with same columns and indices as ``self``, + and entries :math:`\log p_{X_{ij}}(x_{ij})`. + + If ``self`` has a mixed or discrete distribution, this returns + the weighted continuous part of `self`'s distribution instead of the pdf, + i.e., the marginal pdf integrate to the weight of the continuous part. + + Parameters + ---------- + x : ``pandas.DataFrame`` or 2D ``np.ndarray`` + representing :math:`x`, as above + + Returns + ------- + ``pd.DataFrame`` with same columns and index as ``self`` + containing :math:`\log p_{X_{ij}}(x_{ij})`, as above + """ + distr_type = self.get_tag("distr:measuretype", "mixed", raise_error=False) + x = np.array(x) + if distr_type == "discrete": + return self._coerce_to_self_index_df(-np.inf, flatten=False) + + return self._boilerplate("_log_pdf", x=x) + + def pmf(self, x): + r"""Probability mass function. + + Let :math:`X` be a random variables with the distribution of ``self``, + taking values in ``(N, n)`` ``DataFrame``-s + Let :math:`x\in \mathbb{R}^{N\times n}`. + By :math:`m_{X_{ij}}`, denote the marginal mass of :math:`X` at the + :math:`(i,j)`-th entry, i.e., + :math:`m_{X_{ij}}(x_{ij}) = \mathbb{P}(X_{ij} = x_{ij})`. + + The output of this method, for input ``x`` representing :math:`x`, + is a ``DataFrame`` with same columns and indices as ``self``, + and entries :math:`m_{X_{ij}}(x_{ij})`. + + If ``self`` has a mixed or discrete distribution, this returns + the weighted continuous part of `self`'s distribution instead of the pdf, + i.e., the marginal pdf integrate to the weight of the continuous part. + + Parameters + ---------- + x : ``pandas.DataFrame`` or 2D ``np.ndarray`` + representing :math:`x`, as above + + Returns + ------- + ``pd.DataFrame`` with same columns and index as ``self`` + containing :math:`p_{X_{ij}}(x_{ij})`, as above + """ + distr_type = self.get_tag("distr:measuretype", "mixed", raise_error=False) + if distr_type == "continuous": + return self._coerce_to_self_index_df(0, flatten=False) + + return self._boilerplate("_pmf", x=x) + + def log_pmf(self, x): + r"""Logarithmic probability mass function. + + Numerically more stable than calling pmf and then taking logartihms. + + Let :math:`X` be a random variables with the distribution of ``self``, + taking values in `(N, n)` ``DataFrame``-s + Let :math:`x\in \mathbb{R}^{N\times n}`. + By :math:`m_{X_{ij}}`, denote the marginal pdf of :math:`X` at the + :math:`(i,j)`-th entry, i.e., + :math:`m_{X_{ij}}(x_{ij}) = \mathbb{P}(X_{ij} = x_{ij})`. + + The output of this method, for input ``x`` representing :math:`x`, + is a ``DataFrame`` with same columns and indices as ``self``, + and entries :math:`\log m_{X_{ij}}(x_{ij})`. + + If ``self`` has a mixed or discrete distribution, this returns + the weighted continuous part of `self`'s distribution instead of the pdf, + i.e., the marginal pdf integrate to the weight of the continuous part. + + Parameters + ---------- + x : ``pandas.DataFrame`` or 2D ``np.ndarray`` + representing :math:`x`, as above + + Returns + ------- + ``pd.DataFrame`` with same columns and index as ``self`` + containing :math:`\log m_{X_{ij}}(x_{ij})`, as above + """ + distr_type = self.get_tag("distr:measuretype", "mixed", raise_error=False) + if distr_type == "continuous": + return self._coerce_to_self_index_df(-np.inf, flatten=False) + + return self._boilerplate("_log_pmf", x=x) + + def cdf(self, x): + r"""Cumulative distribution function. + + Let :math:`X` be a random variables with the distribution of ``self``, + taking values in ``(N, n)`` ``DataFrame``-s + Let :math:`x\in \mathbb{R}^{N\times n}`. + By :math:`F_{X_{ij}}`, denote the marginal cdf of :math:`X` at the + :math:`(i,j)`-th entry, + i.e., :math:`F_{X_{ij}}(t) = \mathbb{P}(X_{ij} \leq t)`. + + The output of this method, for input ``x`` representing :math:`x`, + is a ``DataFrame`` with same columns and indices as ``self``, + and entries :math:`F_{X_{ij}}(x_{ij})`. + + Parameters + ---------- + x : ``pandas.DataFrame`` or 2D ``np.ndarray`` + representing :math:`x`, as above + + Returns + ------- + ``pd.DataFrame`` with same columns and index as ``self`` + containing :math:`F_{X_{ij}}(x_{ij})`, as above + """ + x = np.array(x) + return self._boilerplate("_cdf", x=x) + + def ppf(self, p): + r"""Quantile function = percent point function = inverse cdf. + + Let :math:`X` be a random variables with the distribution of ``self``, + taking values in ``(N, n)`` ``DataFrame``-s + Let :math:`x\in \mathbb{R}^{N\times n}`. + By :math:`F_{X_{ij}}`, denote the marginal cdf of :math:`X` at the + :math:`(i,j)`-th entry. + + The output of this method, for input ``p`` representing :math:`p`, + is a ``DataFrame`` with same columns and indices as ``self``, + and entries :math:`F^{-1}_{X_{ij}}(p_{ij})`. + + Parameters + ---------- + p : ``pandas.DataFrame`` or 2D np.ndarray + representing :math:`p`, as above + + Returns + ------- + ``pd.DataFrame`` with same columns and index as ``self`` + containing :math:`F_{X_{ij}}(x_{ij})`, as above + """ + p = np.array(p) + return self._boilerplate("_ppf", p=p) + + def energy(self, x=None): + r"""Energy of self, w.r.t. self or a constant frame x. + + Let :math:`X, Y` be i.i.d. random variables with the distribution of ``self``. + + If ``x`` is ``None``, returns :math:`\mathbb{E}[|X-Y|]` (per row), + "self-energy". + If ``x`` is passed, returns :math:`\mathbb{E}[|X-x|]` (per row), "energy wrt x". + + The CRPS is related to energy: + it holds that + :math:`\mbox{CRPS}(\mbox{self}, y)` = `self.energy(y) - 0.5 * self.energy()`. + + Parameters + ---------- + x : None or pd.DataFrame, optional, default=None + if ``pd.DataFrame``, must have same rows and columns as ``self`` + + Returns + ------- + ``pd.DataFrame`` with same rows as ``self``, single column ``"energy"`` + each row contains one float, self-energy/energy as described above. + """ + if x is None: + return self._boilerplate("_energy_self", columns=["energy"]) + x = np.array(x) + return self._boilerplate("_energy_x", x=x, columns=["energy"]) From 489da2a60240aee0b71aaea871d5fd2a079cc955 Mon Sep 17 00:00:00 2001 From: ShreeshaM07 Date: Fri, 14 Jun 2024 00:00:08 +0530 Subject: [PATCH 11/12] Revert "stopped override of public methods" This reverts commit f9290590c65eb99dcf60ea36a2e6fbbbfac7f9c3. --- skpro/distributions/base/_base_array.py | 221 ------------- skpro/distributions/tests/test_all_distrs.py | 316 +++++++++++++++++++ 2 files changed, 316 insertions(+), 221 deletions(-) create mode 100644 skpro/distributions/tests/test_all_distrs.py diff --git a/skpro/distributions/base/_base_array.py b/skpro/distributions/base/_base_array.py index 3d12eb07d..c44dba6d4 100644 --- a/skpro/distributions/base/_base_array.py +++ b/skpro/distributions/base/_base_array.py @@ -244,224 +244,3 @@ def _get_bc_params_dict( # print(kwargs_as_np,shape,is_scalar) return kwargs_as_np, shape, is_scalar return kwargs_as_np - - def pdf(self, x): - r"""Probability density function. - - Let :math:`X` be a random variables with the distribution of ``self``, - taking values in ``(N, n)`` ``DataFrame``-s - Let :math:`x\in \mathbb{R}^{N\times n}`. - By :math:`p_{X_{ij}}`, denote the marginal pdf of :math:`X` at the - :math:`(i,j)`-th entry. - - The output of this method, for input ``x`` representing :math:`x`, - is a ``DataFrame`` with same columns and indices as ``self``, - and entries :math:`p_{X_{ij}}(x_{ij})`. - - If ``self`` has a mixed or discrete distribution, this returns - the weighted continuous part of `self`'s distribution instead of the pdf, - i.e., the marginal pdf integrate to the weight of the continuous part. - - Parameters - ---------- - x : ``pandas.DataFrame`` or 2D ``np.ndarray`` - representing :math:`x`, as above - - Returns - ------- - ``pd.DataFrame`` with same columns and index as ``self`` - containing :math:`p_{X_{ij}}(x_{ij})`, as above - """ - distr_type = self.get_tag("distr:measuretype", "mixed", raise_error=False) - x = np.array(x) - if distr_type == "discrete": - return self._coerce_to_self_index_df(0, flatten=False) - - return self._boilerplate("_pdf", x=x) - - def log_pdf(self, x): - r"""Logarithmic probability density function. - - Numerically more stable than calling pdf and then taking logartihms. - - Let :math:`X` be a random variables with the distribution of ``self``, - taking values in `(N, n)` ``DataFrame``-s - Let :math:`x\in \mathbb{R}^{N\times n}`. - By :math:`p_{X_{ij}}`, denote the marginal pdf of :math:`X` at the - :math:`(i,j)`-th entry. - - The output of this method, for input ``x`` representing :math:`x`, - is a ``DataFrame`` with same columns and indices as ``self``, - and entries :math:`\log p_{X_{ij}}(x_{ij})`. - - If ``self`` has a mixed or discrete distribution, this returns - the weighted continuous part of `self`'s distribution instead of the pdf, - i.e., the marginal pdf integrate to the weight of the continuous part. - - Parameters - ---------- - x : ``pandas.DataFrame`` or 2D ``np.ndarray`` - representing :math:`x`, as above - - Returns - ------- - ``pd.DataFrame`` with same columns and index as ``self`` - containing :math:`\log p_{X_{ij}}(x_{ij})`, as above - """ - distr_type = self.get_tag("distr:measuretype", "mixed", raise_error=False) - x = np.array(x) - if distr_type == "discrete": - return self._coerce_to_self_index_df(-np.inf, flatten=False) - - return self._boilerplate("_log_pdf", x=x) - - def pmf(self, x): - r"""Probability mass function. - - Let :math:`X` be a random variables with the distribution of ``self``, - taking values in ``(N, n)`` ``DataFrame``-s - Let :math:`x\in \mathbb{R}^{N\times n}`. - By :math:`m_{X_{ij}}`, denote the marginal mass of :math:`X` at the - :math:`(i,j)`-th entry, i.e., - :math:`m_{X_{ij}}(x_{ij}) = \mathbb{P}(X_{ij} = x_{ij})`. - - The output of this method, for input ``x`` representing :math:`x`, - is a ``DataFrame`` with same columns and indices as ``self``, - and entries :math:`m_{X_{ij}}(x_{ij})`. - - If ``self`` has a mixed or discrete distribution, this returns - the weighted continuous part of `self`'s distribution instead of the pdf, - i.e., the marginal pdf integrate to the weight of the continuous part. - - Parameters - ---------- - x : ``pandas.DataFrame`` or 2D ``np.ndarray`` - representing :math:`x`, as above - - Returns - ------- - ``pd.DataFrame`` with same columns and index as ``self`` - containing :math:`p_{X_{ij}}(x_{ij})`, as above - """ - distr_type = self.get_tag("distr:measuretype", "mixed", raise_error=False) - if distr_type == "continuous": - return self._coerce_to_self_index_df(0, flatten=False) - - return self._boilerplate("_pmf", x=x) - - def log_pmf(self, x): - r"""Logarithmic probability mass function. - - Numerically more stable than calling pmf and then taking logartihms. - - Let :math:`X` be a random variables with the distribution of ``self``, - taking values in `(N, n)` ``DataFrame``-s - Let :math:`x\in \mathbb{R}^{N\times n}`. - By :math:`m_{X_{ij}}`, denote the marginal pdf of :math:`X` at the - :math:`(i,j)`-th entry, i.e., - :math:`m_{X_{ij}}(x_{ij}) = \mathbb{P}(X_{ij} = x_{ij})`. - - The output of this method, for input ``x`` representing :math:`x`, - is a ``DataFrame`` with same columns and indices as ``self``, - and entries :math:`\log m_{X_{ij}}(x_{ij})`. - - If ``self`` has a mixed or discrete distribution, this returns - the weighted continuous part of `self`'s distribution instead of the pdf, - i.e., the marginal pdf integrate to the weight of the continuous part. - - Parameters - ---------- - x : ``pandas.DataFrame`` or 2D ``np.ndarray`` - representing :math:`x`, as above - - Returns - ------- - ``pd.DataFrame`` with same columns and index as ``self`` - containing :math:`\log m_{X_{ij}}(x_{ij})`, as above - """ - distr_type = self.get_tag("distr:measuretype", "mixed", raise_error=False) - if distr_type == "continuous": - return self._coerce_to_self_index_df(-np.inf, flatten=False) - - return self._boilerplate("_log_pmf", x=x) - - def cdf(self, x): - r"""Cumulative distribution function. - - Let :math:`X` be a random variables with the distribution of ``self``, - taking values in ``(N, n)`` ``DataFrame``-s - Let :math:`x\in \mathbb{R}^{N\times n}`. - By :math:`F_{X_{ij}}`, denote the marginal cdf of :math:`X` at the - :math:`(i,j)`-th entry, - i.e., :math:`F_{X_{ij}}(t) = \mathbb{P}(X_{ij} \leq t)`. - - The output of this method, for input ``x`` representing :math:`x`, - is a ``DataFrame`` with same columns and indices as ``self``, - and entries :math:`F_{X_{ij}}(x_{ij})`. - - Parameters - ---------- - x : ``pandas.DataFrame`` or 2D ``np.ndarray`` - representing :math:`x`, as above - - Returns - ------- - ``pd.DataFrame`` with same columns and index as ``self`` - containing :math:`F_{X_{ij}}(x_{ij})`, as above - """ - x = np.array(x) - return self._boilerplate("_cdf", x=x) - - def ppf(self, p): - r"""Quantile function = percent point function = inverse cdf. - - Let :math:`X` be a random variables with the distribution of ``self``, - taking values in ``(N, n)`` ``DataFrame``-s - Let :math:`x\in \mathbb{R}^{N\times n}`. - By :math:`F_{X_{ij}}`, denote the marginal cdf of :math:`X` at the - :math:`(i,j)`-th entry. - - The output of this method, for input ``p`` representing :math:`p`, - is a ``DataFrame`` with same columns and indices as ``self``, - and entries :math:`F^{-1}_{X_{ij}}(p_{ij})`. - - Parameters - ---------- - p : ``pandas.DataFrame`` or 2D np.ndarray - representing :math:`p`, as above - - Returns - ------- - ``pd.DataFrame`` with same columns and index as ``self`` - containing :math:`F_{X_{ij}}(x_{ij})`, as above - """ - p = np.array(p) - return self._boilerplate("_ppf", p=p) - - def energy(self, x=None): - r"""Energy of self, w.r.t. self or a constant frame x. - - Let :math:`X, Y` be i.i.d. random variables with the distribution of ``self``. - - If ``x`` is ``None``, returns :math:`\mathbb{E}[|X-Y|]` (per row), - "self-energy". - If ``x`` is passed, returns :math:`\mathbb{E}[|X-x|]` (per row), "energy wrt x". - - The CRPS is related to energy: - it holds that - :math:`\mbox{CRPS}(\mbox{self}, y)` = `self.energy(y) - 0.5 * self.energy()`. - - Parameters - ---------- - x : None or pd.DataFrame, optional, default=None - if ``pd.DataFrame``, must have same rows and columns as ``self`` - - Returns - ------- - ``pd.DataFrame`` with same rows as ``self``, single column ``"energy"`` - each row contains one float, self-energy/energy as described above. - """ - if x is None: - return self._boilerplate("_energy_self", columns=["energy"]) - x = np.array(x) - return self._boilerplate("_energy_x", x=x, columns=["energy"]) diff --git a/skpro/distributions/tests/test_all_distrs.py b/skpro/distributions/tests/test_all_distrs.py new file mode 100644 index 000000000..1fed1c5c9 --- /dev/null +++ b/skpro/distributions/tests/test_all_distrs.py @@ -0,0 +1,316 @@ +"""Tests for BaseDistribution API points.""" +# copyright: skpro developers, BSD-3-Clause License (see LICENSE file) +# adapted from sktime + +__author__ = ["fkiraly", "Alex-JG3"] + +import numpy as np +import pandas as pd +import pytest +from skbase.testing import QuickTester + +from skpro.datatypes import check_is_mtype +from skpro.tests.test_all_estimators import BaseFixtureGenerator, PackageConfig +from skpro.utils.index import random_ss_ix + + +class DistributionFixtureGenerator(BaseFixtureGenerator): + """Fixture generator for probability distributions. + + Fixtures parameterized + ---------------------- + object_class: object inheriting from BaseObject + ranges over object classes not excluded by EXCLUDE_OBJECTS, EXCLUDED_TESTS + object_instance: instance of object inheriting from BaseObject + ranges over object classes not excluded by EXCLUDE_OBJECTS, EXCLUDED_TESTS + instances are generated by create_test_instance class method + """ + + object_type_filter = "distribution" + + +def _has_capability(distr, method): + """Check whether distr has capability of method. + + Parameters + ---------- + distr : BaseDistribution object + method : str + method name to check + + Returns + ------- + whether distr has capability method, according to tags + capabilities:approx and capabilities:exact + """ + approx_methods = distr.get_tag("capabilities:approx") + exact_methods = distr.get_tag("capabilities:exact") + return method in approx_methods or method in exact_methods + + +METHODS_SCALAR = ["mean", "var", "energy"] +METHODS_SCALAR_POS = ["var", "energy"] # result always non-negative? +METHODS_X = ["energy", "pdf", "log_pdf", "pmf", "log_pmf", "cdf"] +METHODS_X_POS = ["energy", "pdf", "pmf", "cdf", "surv", "haz"] # result non-negative? +METHODS_P = ["ppf"] +METHODS_ROWWISE = ["energy"] # results in one column + + +class TestAllDistributions(PackageConfig, DistributionFixtureGenerator, QuickTester): + """Module level tests for all skpro parameter fitters.""" + + # TEMPORARY skip for CyclicBoosting and QPD classes + # due to silent failures on main, se #190 + exclude_objects = ["QPD_B"] + # remove this when fixing failures to re-enable testing + + def test_shape(self, object_instance): + """Test index, columns, len and shape of distribution.""" + d = object_instance + + assert hasattr(d, "shape") + assert isinstance(d.shape, tuple) + assert len(d.shape) in [0, 2] + + if len(d.shape) == 2: + assert all(isinstance(n, int) for n in d.shape) + + assert isinstance(d.index, pd.Index) + assert isinstance(d.columns, pd.Index) + + assert d.shape[0] == len(d.index) + assert d.shape[1] == len(d.columns) + + assert isinstance(len(d), int) + + if len(d.shape) == 2: + assert len(d) == d.shape[0] + else: + assert len(d) == 1 + + assert hasattr(d, "ndim") + assert d.ndim == len(d.shape) + + @pytest.mark.parametrize("shuffled", [False, True]) + def test_sample(self, object_instance, shuffled): + """Test sample expected return.""" + d = object_instance + + if shuffled: + d = _shuffle_distr(d) + + res = d.sample() + + if d.ndim > 0: + assert d.shape == res.shape + assert (res.index == d.index).all() + assert (res.columns == d.columns).all() + else: # d.ndim = 0 + assert np.isscalar(res) + + res_panel = d.sample(3) + if d.ndim > 0: + dummy_panel = pd.concat([res, res, res], keys=range(3)) + else: + dummy_panel = pd.DataFrame(index=range(3), columns=range(1)) + assert dummy_panel.shape == res_panel.shape + assert (res_panel.index == dummy_panel.index).all() + assert (res_panel.columns == dummy_panel.columns).all() + + @pytest.mark.parametrize("shuffled", [False, True]) + @pytest.mark.parametrize("method", METHODS_SCALAR, ids=METHODS_SCALAR) + def test_methods_scalar(self, object_instance, method, shuffled): + """Test expected return of scalar methods.""" + if not _has_capability(object_instance, method): + return None + + d = object_instance + if shuffled: + d = _shuffle_distr(d) + + res = getattr(d, method)() + + _check_output_format(res, d, method) + + @pytest.mark.parametrize("shuffled", [False, True]) + @pytest.mark.parametrize("method", METHODS_X, ids=METHODS_X) + def test_methods_x(self, object_instance, method, shuffled): + """Test expected return of methods that take sample-like argument.""" + if not _has_capability(object_instance, method): + return None + + d = object_instance + + if shuffled: + d = _shuffle_distr(d) + + x = d.sample() + res = getattr(d, method)(x) + + _check_output_format(res, d, method) + + @pytest.mark.parametrize("shuffled", [False, True]) + @pytest.mark.parametrize("method", METHODS_P, ids=METHODS_P) + def test_methods_p(self, object_instance, method, shuffled): + """Test expected return of methods that take percentage-like argument.""" + if not _has_capability(object_instance, method): + return None + + d = object_instance + + if shuffled: + d = _shuffle_distr(d) + + np_unif = np.random.uniform(size=d.shape) + if d.ndim > 0: + p = pd.DataFrame(np_unif, index=d.index, columns=d.columns) + else: + p = np_unif + + res = getattr(d, method)(p) + + _check_output_format(res, d, method) + + @pytest.mark.parametrize("q", [0.7, [0.1, 0.3, 0.9]]) + def test_quantile(self, object_instance, q): + """Test expected return of quantile method.""" + if not _has_capability(object_instance, "ppf"): + return None + + d = object_instance + + def _check_quantile_output(obj, q): + assert check_is_mtype( + obj, "pred_quantiles", "Proba", msg_return_dict="list" + ) + if d.ndim == 0: + expected_index = pd.RangeIndex(1) + vars = [d.__class__.__name__] + else: + expected_index = d.index + vars = d.columns + + assert (obj.index == expected_index).all() + + if not isinstance(q, list): + q = [q] + expected_columns = pd.MultiIndex.from_product([vars, q]) + assert (obj.columns == expected_columns).all() + + res = d.quantile(q) + _check_quantile_output(res, q) + + @pytest.mark.parametrize("subset_row", [True, False]) + @pytest.mark.parametrize("subset_col", [True, False]) + def test_subsetting(self, object_instance, subset_row, subset_col): + """Test subsetting of distribution.""" + d = object_instance + if d.ndim == 0: # no subsetting to test if example is scalar + return None + + if subset_row: + ix_loc = random_ss_ix(d.index, 3) + ix_iloc = d.index.get_indexer(ix_loc) + else: + ix_loc = d.index + ix_iloc = pd.RangeIndex(len(d.index)) + + if subset_col: + iy_loc = random_ss_ix(d.columns, 1) + iy_iloc = d.columns.get_indexer(iy_loc) + else: + iy_loc = d.columns + iy_iloc = pd.RangeIndex(len(d.columns)) + + res_loc = d.loc[ix_loc, iy_loc] + + assert isinstance(res_loc, type(d)) + assert res_loc.shape == (len(ix_loc), len(iy_loc)) + assert (res_loc.index == ix_loc).all() + assert (res_loc.columns == iy_loc).all() + + res_iloc = d.iloc[ix_iloc, iy_iloc] + + assert isinstance(res_iloc, type(d)) + assert res_iloc.shape == (len(ix_iloc), len(iy_iloc)) + assert (res_iloc.index == ix_loc).all() + assert (res_iloc.columns == iy_loc).all() + + def test_log_pdf_and_pdf(self, object_instance): + """Test that the log of the pdf and log_pdf function are similar.""" + d = object_instance + capabilities_exact = d.get_tags()["capabilities:exact"] + + if "log_pdf" not in capabilities_exact or "pdf" not in capabilities_exact: + return + x = d.sample() + pdf = d.pdf(x) + log_pdf = d.log_pdf(x) + assert np.allclose(np.log(pdf), log_pdf) + + def test_log_pmf_and_pmf(self, object_instance): + """Test that the log of the pmf and log_pmf function are similar.""" + d = object_instance + capabilities_exact = d.get_tags()["capabilities:exact"] + + if "log_pmf" not in capabilities_exact or "pmf" not in capabilities_exact: + return + x = d.sample() + pmf = d.pmf(x) + log_pmf = d.log_pmf(x) + assert np.allclose(np.log(pmf), log_pmf) + + def test_ppf_and_cdf(self, object_instance): + """Test that the ppf is the inverse of the cdf.""" + d = object_instance + capabilities_exact = d.get_tags()["capabilities:exact"] + + if "ppf" not in capabilities_exact or "cdf" not in capabilities_exact: + return + x = d.sample() + x_approx = d.ppf(d.cdf(x)) + if d.ndim > 0: + assert np.allclose(x.values, x_approx.values) + else: + assert np.allclose(x, x_approx) + + +def _check_output_format(res, dist, method): + """Check output format expectations for BaseDistribution tests.""" + if dist.shape == (): # scalar distribution case + # check if numpy float + assert np.isscalar(res) + assert np.isreal(res) + if method in METHODS_SCALAR_POS or method in METHODS_X_POS: + assert res >= 0 + return None + + # array distribution case + if method in METHODS_ROWWISE: + exp_shape = (dist.shape[0], 1) + else: + exp_shape = dist.shape + assert res.shape == exp_shape + assert (res.index == dist.index).all() + if method not in METHODS_ROWWISE: + assert (res.columns == dist.columns).all() + + if method in METHODS_SCALAR_POS or method in METHODS_X_POS: + assert (res >= 0).all().all() + + if isinstance(res, pd.DataFrame): + assert res.apply(pd.api.types.is_numeric_dtype).all() + elif isinstance(res, pd.Series): + assert pd.api.types.is_numeric_dtype(res) + else: + raise TypeError("res must be a pandas DataFrame or Series.") + + +def _shuffle_distr(d): + """Shuffle distribution row index.""" + if d.shape == (): # nothing to shuffle if scalar + return d + # shuffle rows otherwise + shuffled_df = pd.DataFrame(d.index).sample(frac=1) + shuffled_index = pd.Index(shuffled_df.values.flatten()) + return d.loc[shuffled_index] From f6f5f2b2f556d20188cef35789cf92b6c1803316 Mon Sep 17 00:00:00 2001 From: ShreeshaM07 Date: Fri, 14 Jun 2024 00:27:39 +0530 Subject: [PATCH 12/12] modified docstring --- skpro/distributions/histogram.py | 49 ++++++++++++++++---------------- 1 file changed, 25 insertions(+), 24 deletions(-) diff --git a/skpro/distributions/histogram.py b/skpro/distributions/histogram.py index 9c919b63c..573dfe3f8 100644 --- a/skpro/distributions/histogram.py +++ b/skpro/distributions/histogram.py @@ -13,42 +13,43 @@ class Histogram(BaseArrayDistribution): """Histogram Probability Distribution. The histogram probability distribution is parameterized - by the bins and bin densities. + by the ``bins`` and ``bin_mass``. Parameters ---------- bins : tuple(float, float, int) or numpy.array of float 1D or 2D list of size m x n 1. tuple(first bin's start point, last bin's end point, number of bins) - Used when bin widths are equal. - example: - bins:(0,4,4) + Used when ``bin widths`` are ``equal``. + example: + bins:(0,4,4) 2. array has the bin boundaries with 1st element the first bin's starting point and rest are the bin ending points of all bins - example: - bins:[0, 1, 2, 3, 4] + example: + bins:[0, 1, 2, 3, 4] 3. 2D list of size m x n containing m*n float numpy.arrays or tuple - like case 1. - example: - bins: - [ - [[0, 1, 2, 3, 4], [5, 5.5, 5.8, 6.5, 7, 7.5]], - [(2, 12, 5), [0, 1, 2, 3, 4]], - [[1.5, 2.5, 3.1, 4, 5.4], [-4, -2, -1.5, 5, 10]] - ] + like ``case 1``. + example: + bins: + [ + [[0, 1, 2, 3, 4], [5, 5.5, 5.8, 6.5, 7, 7.5]], + [(2, 12, 5), [0, 1, 2, 3, 4]], + [[1.5, 2.5, 3.1, 4, 5.4], [-4, -2, -1.5, 5, 10]] + ] bin_mass : array of float 1D or 2D list of size m x n 1. Array has the mass of the bins or area of the bins. example: bin_mass:[0.1, 0.2, 0.3, 0.4] - Note: `len(bin_mass)` will be `(len(bins)-1)`. - Note: Sum of all the `bin_mass` must be `1`. - 2. 2D list of size m x n containing m*n float numpy.arrays satisfying case 1. - example: - bin_mass: - [ - [[0.1, 0.2, 0.3, 0.4], [0.25, 0.1, 0, 0.4, 0.25]], - [[0.1, 0.2, 0.4, 0.2, 0.1], [0.4, 0.3, 0.2, 0.1]], - [[0.06, 0.15, 0.09, 0.7], [0.4, 0.15, 0.325, 0.125]] - ] + ``Note``: ``len(bin_mass)`` will be ``(len(bins)-1)``. + ``Note``: ``sum(bin_mass)`` must be ``1``. + 2. 2D list of size m x n containing m*n float numpy.arrays + each satisfying ``case 1``. + example: + bin_mass: + [ + [[0.1, 0.2, 0.3, 0.4], [0.25, 0.1, 0, 0.4, 0.25]], + [[0.1, 0.2, 0.4, 0.2, 0.1], [0.4, 0.3, 0.2, 0.1]], + [[0.06, 0.15, 0.09, 0.7], [0.4, 0.15, 0.325, 0.125]] + ] index : pd.Index, optional, default = RangeIndex columns : pd.Index, optional, default = RangeIndex """