From 5889075fb2486d311a3c3d59a45302df43624a0e Mon Sep 17 00:00:00 2001 From: ShreeshaM07 <120820143+ShreeshaM07@users.noreply.github.com> Date: Thu, 16 May 2024 20:10:35 +0530 Subject: [PATCH 1/2] [MNT] `Exponential` dist in `NGBoostRegressor`, `NGBoostSurvival` (#332) This adds exponential distribution to `NGBoostRegressor` and `NGBoostSurvival`. Also refactors distribution adapter logic to a common location. --- .../adapters/ngboost/_ngboost_proba.py | 65 ++++++++++++++++++- skpro/regression/ensemble/_ngboost.py | 57 ++++------------ skpro/survival/ensemble/_ngboost_surv.py | 55 ++++------------ 3 files changed, 88 insertions(+), 89 deletions(-) diff --git a/skpro/regression/adapters/ngboost/_ngboost_proba.py b/skpro/regression/adapters/ngboost/_ngboost_proba.py index 1de3c3830..2d620b095 100644 --- a/skpro/regression/adapters/ngboost/_ngboost_proba.py +++ b/skpro/regression/adapters/ngboost/_ngboost_proba.py @@ -32,7 +32,7 @@ def _dist_to_ngboost_instance(self, dist, survival=False): ------- NGBoost Distribution object. """ - from ngboost.distns import Laplace, LogNormal, Normal, Poisson, T + from ngboost.distns import Exponential, Laplace, LogNormal, Normal, Poisson, T ngboost_dists = { "Normal": Normal, @@ -40,6 +40,7 @@ def _dist_to_ngboost_instance(self, dist, survival=False): "TDistribution": T, "Poisson": Poisson, "LogNormal": LogNormal, + "Exponential": Exponential, } # default Normal distribution dist_ngboost = Normal @@ -52,6 +53,66 @@ def _dist_to_ngboost_instance(self, dist, survival=False): return dist_ngboost + def _ngb_skpro_dist_params( + self, + pred_dist, + index, + columns, + **kwargs, + ): + import numpy as np + + # The returned values of the Distributions from NGBoost + # are different. So based on that they are split into these + # categories of loc,scale,mu and s. + # Distribution type | Parameters + # ------------------|----------- + # Normal | loc = mean, scale = standard deviation + # TDistribution | loc = mean, scale = standard deviation + # Poisson | mu = mean + # LogNormal | s = standard deviation, scale = exp(mean) + # | (see scipy.stats.lognorm) + # Laplace | loc = mean, scale = scale parameter + # Exponential | scale = 1/rate + # Normal, Laplace, TDistribution and Poisson have not yet + # been implemented for Survival analysis. + + dist_params = { + "Normal": ["loc", "scale"], + "Laplace": ["loc", "scale"], + "TDistribution": ["loc", "scale"], + "Poisson": ["mu"], + "LogNormal": ["scale", "s"], + "Exponential": ["scale"], + } + + skpro_params = { + "Normal": ["mu", "sigma"], + "Laplace": ["mu", "scale"], + "TDistribution": ["mu", "sigma"], + "Poisson": ["mu"], + "LogNormal": ["mu", "sigma"], + "Exponential": ["rate"], + } + + if self.dist in dist_params and self.dist in skpro_params: + ngboost_params = dist_params[self.dist] + skp_params = skpro_params[self.dist] + for ngboost_param, skp_param in zip(ngboost_params, skp_params): + kwargs[skp_param] = pred_dist.params[ngboost_param] + if self.dist == "LogNormal" and ngboost_param == "scale": + kwargs[skp_param] = np.log(pred_dist.params[ngboost_param]) + if self.dist == "Exponential" and ngboost_param == "scale": + kwargs[skp_param] = 1 / pred_dist.params[ngboost_param] + + kwargs[skp_param] = self._check_y(y=kwargs[skp_param]) + # returns a tuple so taking only first index of the tuple + kwargs[skp_param] = kwargs[skp_param][0] + kwargs["index"] = index + kwargs["columns"] = columns + + return kwargs + def _ngb_dist_to_skpro(self, **kwargs): """Convert NGBoost distribution object to skpro BaseDistribution object. @@ -64,6 +125,7 @@ def _ngb_dist_to_skpro(self, **kwargs): skpro_dist (skpro.distributions.BaseDistribution): Converted skpro distribution object. """ + from skpro.distributions.exponential import Exponential from skpro.distributions.laplace import Laplace from skpro.distributions.lognormal import LogNormal from skpro.distributions.normal import Normal @@ -76,6 +138,7 @@ def _ngb_dist_to_skpro(self, **kwargs): "TDistribution": TDistribution, "Poisson": Poisson, "LogNormal": LogNormal, + "Exponential": Exponential, } skpro_dist = None diff --git a/skpro/regression/ensemble/_ngboost.py b/skpro/regression/ensemble/_ngboost.py index 4a09871f9..5abc3af6e 100644 --- a/skpro/regression/ensemble/_ngboost.py +++ b/skpro/regression/ensemble/_ngboost.py @@ -3,8 +3,6 @@ __author__ = ["ShreeshaM07"] -import numpy as np - from skpro.regression.adapters.ngboost._ngboost_proba import NGBoostAdapter from skpro.regression.base import BaseProbaRegressor @@ -28,6 +26,7 @@ class NGBoostRegressor(BaseProbaRegressor, NGBoostAdapter): 3. "LogNormal" 4. "Poisson" 5. "TDistribution" + 6. "Exponential" score : string , default = "LogScore" A score from ngboost.scores for LogScore rule to compare probabilistic @@ -227,49 +226,13 @@ def _predict_proba(self, X): """ X = self._check_X(X) - # The returned values of the Distributions from NGBoost - # are different. So based on that they are split into these - # categories of loc,scale,mu and s. - # Distribution type | Parameters - # ------------------|----------- - # Normal | loc = mean, scale = standard deviation - # TDistribution | loc = mean, scale = standard deviation - # Poisson | mu = mean - # LogNormal | s = standard deviation, scale = exp(mean) - # | (see scipy.stats.lognorm) - # Laplace | loc = mean, scale = scale parameter - - dist_params = { - "Normal": ["loc", "scale"], - "Laplace": ["loc", "scale"], - "TDistribution": ["loc", "scale"], - "Poisson": ["mu"], - "LogNormal": ["scale", "s"], - } - - skpro_params = { - "Normal": ["mu", "sigma"], - "Laplace": ["mu", "scale"], - "TDistribution": ["mu", "sigma"], - "Poisson": ["mu"], - "LogNormal": ["mu", "sigma"], - } - kwargs = {} + pred_dist = self._pred_dist(X) + index = X.index + columns = self._y_cols - if self.dist in dist_params and self.dist in skpro_params: - ngboost_params = dist_params[self.dist] - skp_params = skpro_params[self.dist] - for ngboost_param, skp_param in zip(ngboost_params, skp_params): - kwargs[skp_param] = self._pred_dist(X).params[ngboost_param] - if self.dist == "LogNormal" and ngboost_param == "scale": - kwargs[skp_param] = np.log(self._pred_dist(X).params[ngboost_param]) - - kwargs[skp_param] = self._check_y(y=kwargs[skp_param]) - # returns a tuple so taking only first index of the tuple - kwargs[skp_param] = kwargs[skp_param][0] - kwargs["index"] = X.index - kwargs["columns"] = self._y_cols + # Convert NGBoost Distribution return params into a dict + kwargs = self._ngb_skpro_dist_params(pred_dist, index, columns, **kwargs) # Convert NGBoost Distribution to skpro BaseDistribution pred_dist = self._ngb_dist_to_skpro(**kwargs) @@ -317,4 +280,10 @@ def get_test_params(cls, parameter_set="default"): "verbose": False, } - return [params1, params2, params3, params4, params5, params6] + params7 = { + "dist": "Exponential", + "n_estimators": 800, + "verbose_eval": 50, + } + + return [params1, params2, params3, params4, params5, params6, params7] diff --git a/skpro/survival/ensemble/_ngboost_surv.py b/skpro/survival/ensemble/_ngboost_surv.py index 9a95f741f..6e9ad62b0 100644 --- a/skpro/survival/ensemble/_ngboost_surv.py +++ b/skpro/survival/ensemble/_ngboost_surv.py @@ -25,6 +25,7 @@ class NGBoostSurvival(BaseSurvReg, NGBoostAdapter): A distribution from ngboost.distns, e.g. LogNormal Available distribution types 1. "LogNormal" + 2. "Exponential" score : string , default = "LogScore" rule to compare probabilistic predictions PĚ‚ to the observed data y. A score from ngboost.scores, e.g. LogScore @@ -228,51 +229,13 @@ def _predict_proba(self, X): """ X = self._check_X(X) - # The returned values of the Distributions from NGBoost - # are different. So based on that they are split into these - # categories of loc,scale,mu and s. - # Distribution type | Parameters - # ------------------|----------- - # Normal | loc = mean, scale = standard deviation - # TDistribution | loc = mean, scale = standard deviation - # Poisson | mu = mean - # LogNormal | s = standard deviation, scale = exp(mean) - # | (see scipy.stats.lognorm) - # Laplace | loc = mean, scale = scale parameter - # Normal, Laplace, TDistribution and Poisson have not yet - # been implemented for Survival analysis. - - dist_params = { - "Normal": ["loc", "scale"], - "Laplace": ["loc", "scale"], - "TDistribution": ["loc", "scale"], - "Poisson": ["mu"], - "LogNormal": ["scale", "s"], - } - - skpro_params = { - "Normal": ["mu", "sigma"], - "Laplace": ["mu", "scale"], - "TDistribution": ["mu", "sigma"], - "Poisson": ["mu"], - "LogNormal": ["mu", "sigma"], - } - kwargs = {} + pred_dist = self._pred_dist(X) + index = X.index + columns = self._y_cols - if self.dist in dist_params and self.dist in skpro_params: - ngboost_params = dist_params[self.dist] - skp_params = skpro_params[self.dist] - for ngboost_param, skp_param in zip(ngboost_params, skp_params): - kwargs[skp_param] = self._pred_dist(X).params[ngboost_param] - if self.dist == "LogNormal" and ngboost_param == "scale": - kwargs[skp_param] = np.log(self._pred_dist(X).params[ngboost_param]) - - kwargs[skp_param] = self._check_y(y=kwargs[skp_param]) - # returns a tuple so taking only first index of the tuple - kwargs[skp_param] = kwargs[skp_param][0] - kwargs["index"] = X.index - kwargs["columns"] = self._y_cols + # Convert NGBoost Distribution return params into a dict + kwargs = self._ngb_skpro_dist_params(pred_dist, index, columns, **kwargs) # Convert NGBoost Distribution to skpro BaseDistribution pred_dist = self._ngb_dist_to_skpro(**kwargs) @@ -306,5 +269,9 @@ def get_test_params(cls, parameter_set="default"): "n_estimators": 800, "minibatch_frac": 0.8, } + params4 = { + "dist": "Exponential", + "n_estimators": 600, + } - return [params1, params2, params3] + return [params1, params2, params3, params4] From dad8793497bd83a3f248495319794b190eaad87b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Thu, 16 May 2024 15:42:52 +0100 Subject: [PATCH 2/2] [ENH] native implementation of Johnson QPD family, explicit pdf (#327) This PR reworks the family of QPD family of distributions for efficiency and to allow removal of the newly introduced dependency `findiff` in #232. The dependency `findiff` was introduced for approximation of `pdf`, but in fact it is unnecessary as the `pdf` can be analytically derived by applying the chain rule. True, it has to be applied three or four times, but it's still the chain rule... efficiency and accuracy gains are significant, and it helps us avoid computing numerical derivatives for all entries in a large matrix, together with the now unnecessary `findiff` dependency. Makes the following changes: * refactoring of the three QPD distributions tp use `skpro` machinery: * use of the `skpro` native parameter broadcasting system instead of ad-hoc broadcasting * use of the `skpro` native approximation for `mean`, `var`, instead of three copies of similar (and partially duplicative) approximation inside the distributions * refactoring between the three QPD distributions with the end of simplification * refactoring QPD parameter computation into a single, fully vectorized function, `_prep_qpd_vars` * clean room reimplementation of `cdf`, `ppf` of the three distributions based on the `cyclic_boosting` reference * new implementation of `pdf`, as derivative of `cdf` As side effects of the rework: * all parameters now broadcast in numpy-like fashion, including `alpha`, `lower`, `upper`, which previously was not possible * the distributions can be 2D with more than 1 column, which previously was not possible * `version` (the base distribution) can now be an arbitrary continuous `scipy` distribution * `pdf` is numerically exact * the distributions do not have soft dependencies anymore Regarding the relation to `cyclic_boosting`: * this is clean room reimplementation and credit is given, so I hope this is fine license-wise - @felixwick? * this is the result of trying to remove the `findiff` dependency for computing the `pdf` from the `cdf` that was introduced in #232, as well as cleanup before release. I ended up simplifying a lot, ending up here. In this sense, the work of @setoguchi-naoki was crucial in arriving at this point. * I would have no issue at all with you moving the improved code into `cyclic_boosting`. We can even restore the dependency and maintain the distribution logic in `cyclic_boosting` if that were your preference, e.g., for ownership reasons. --- skpro/distributions/qpd.py | 819 +++++++++++--------------- skpro/distributions/tests/test_qpd.py | 17 + 2 files changed, 358 insertions(+), 478 deletions(-) diff --git a/skpro/distributions/qpd.py b/skpro/distributions/qpd.py index dc724c524..4b7c7c4c0 100644 --- a/skpro/distributions/qpd.py +++ b/skpro/distributions/qpd.py @@ -9,13 +9,8 @@ "setoguchi-naoki", ] # interface only. Cyclic boosting authors in cyclic_boosting package -import typing -import warnings from typing import Sequence -if typing.TYPE_CHECKING: - from cyclic_boosting.quantile_matching import J_QPD_S, J_QPD_B - import numpy as np import pandas as pd from scipy.stats import logistic, norm @@ -85,7 +80,6 @@ class QPD_Johnson(_DelegatedDistribution): # -------------- "authors": ["setoguchi-naoki", "felix-wick", "fkiraly"], "maintainers": ["setoguchi-naoki"], - "python_dependencies": ["cyclic_boosting>=1.4.0", "findiff"], # estimator tags # -------------- "capabilities:approx": ["pdfnorm", "energy"], @@ -158,9 +152,9 @@ def get_test_params(cls, parameter_set="default"): params2 = { "alpha": 0.1, "version": "normal", - "qv_low": [0.2, 0.2, 0.2], - "qv_median": [0.5, 0.5, 0.5], - "qv_high": [0.8, 0.8, 0.8], + "qv_low": [[-0.3], [-0.2], [-0.1]], + "qv_median": [[-0.1], [0.0], [0.1]], + "qv_high": [[0.2], [0.3], [0.4]], "index": pd.Index([1, 2, 5]), "columns": pd.Index(["a"]), } @@ -231,14 +225,22 @@ class QPD_S(BaseDistribution): _tags = { # packaging info # -------------- - "authors": ["setoguchi-naoki", "felix-wick"], + "authors": ["setoguchi-naoki", "felix-wick", "fkiraly"], "maintainers": ["setoguchi-naoki"], - "python_dependencies": ["cyclic_boosting>=1.4.0", "findiff"], # estimator tags # -------------- "capabilities:approx": ["pdfnorm", "energy"], "capabilities:exact": ["mean", "var", "cdf", "ppf", "pdf"], "distr:measuretype": "continuous", + "broadcast_init": "on", + "broadcast_params": [ + "alpha", + "qv_low", + "qv_median", + "qv_high", + "lower", + "upper", + ], } def __init__( @@ -263,138 +265,78 @@ def __init__( self.index = index self.columns = columns - from cyclic_boosting.quantile_matching import J_QPD_S + super().__init__(index=index, columns=columns) - qv_low, qv_median, qv_high = _prep_qpd_params(qv_low, qv_median, qv_high) + # precompute parameters for methods + phi = _resolve_phi(version) + self.phi = phi - if index is None: - index = pd.RangeIndex(qv_low.shape[0]) - self.index = index + qpd_params = _prep_qpd_vars(phi=phi, mode="S", **self._bc_params) + self._qpd_params = qpd_params - if columns is None: - columns = pd.RangeIndex(1) - self.columns = columns + def _ppf(self, p: np.ndarray): + """Quantile function = percent point function = inverse cdf.""" + lower = self._bc_params["lower"] + delta = self._qpd_params["delta"] + kappa = self._qpd_params["kappa"] + c = self._qpd_params["c"] + n = self._qpd_params["n"] + theta = self._qpd_params["theta"] - self._shape = (qv_low.shape[0], 1) + phi = self.phi - if version == "normal": - self.phi = norm() - elif version == "logistic": - self.phi = logistic() - else: - raise Exception("Invalid version.") - - if (np.any(qv_low > qv_median)) or np.any(qv_high < qv_median): - warnings.warn( - "The SPT values are not monotonically increasing, " - "each SPT is sorted by value", - stacklevel=2, - ) - idx = np.where((qv_low > qv_median), True, False) + np.where( - (qv_high < qv_median), True, False - ) - un_orderd_idx = np.argwhere(idx > 0).tolist() - warnings.warn(f"sorted index {un_orderd_idx}", stacklevel=2) - for idx in un_orderd_idx: - low, mid, high = sorted([qv_low[idx], qv_median[idx], qv_high[idx]]) - qv_low[idx] = low - qv_median[idx] = mid - qv_high[idx] = high - - self.qpd = J_QPD_S( - alpha=alpha, - qv_low=qv_low, - qv_median=qv_median, - qv_high=qv_high, - l=self.lower, - version=version, - ) - super().__init__(index=index, columns=columns) + in_sinh = np.arcsinh(phi.ppf(p) * delta) + np.arcsinh(n * c * delta) + in_exp = kappa * np.sinh(in_sinh) + ppf_arr = lower + theta * np.exp(in_exp) - def _mean(self): - """Return expected value of the distribution. - - Please set the upper and lower limits of the random variable correctly. - - Returns - ------- - pd.DataFrame with same rows, columns as `self` - expected value of distribution (entry-wise) - """ - params = self.get_params(deep=False) - lower = params["lower"] - upper = params["upper"] - index = params["index"] - x = np.linspace(lower, upper, num=int(1e3)) - cdf = self.qpd.cdf(x) - if cdf.ndim < 2: - cdf = cdf[:, np.newaxis] - loc = exp_func(x, cdf.T, index.shape[0]) - return loc - - def _var(self): - """Return element/entry-wise variance of the distribution. - - Please set the upper and lower limits of the random variable correctly. - - Returns - ------- - pd.DataFrame with same rows, columns as `self` - variance of distribution (entry-wise) - """ - params = self.get_params(deep=False) - lower = params["lower"] - upper = params["upper"] - index = params["index"] - mean = self.mean().values - x = np.linspace(lower, upper, num=int(1e3)) - cdf = self.qpd.cdf(x) - if cdf.ndim < 2: - cdf = cdf[:, np.newaxis] - var = var_func(x, mean, cdf.T, index.shape[0]) - return var + return ppf_arr def _pdf(self, x: np.ndarray): - """Probability density function. + """Probability density function.""" + lower = self._bc_params["lower"] + delta = self._qpd_params["delta"] + kappa = self._qpd_params["kappa"] + c = self._qpd_params["c"] + n = self._qpd_params["n"] + theta = self._qpd_params["theta"] - this fucntion transform cdf to pdf - because j-qpd's pdf calculation is bit complex - """ - return pdf_func(x, self.qpd) + phi = self.phi - def _ppf(self, p: np.ndarray): - """Quantile function = percent point function = inverse cdf.""" - params = self.get_params(deep=False) - index = params["index"] - columns = params["columns"] - qv_low = params["qv_low"] - p_unique = np.unique(p) # de-broadcast - ppf_all = ppf_func(p_unique, self.qpd) - ppf_map = np.tile(p_unique, (qv_low.size, 1)).T - ppf = np.zeros((index.shape[0], len(columns))) - for r in range(p.shape[0]): - for c in range(p.shape[1]): - t = np.where(ppf_map[:, c] == p[r][c]) - ppf_part = ppf_all[t][c] - ppf[r][c] = ppf_part - return ppf + # we work through the chain rule for the entire nested expression in cdf + x_ = (x - lower) / theta + x_der = 1 / theta + + in_arcsinh = np.log(x_) / kappa + in_arcsinh_der = x_der / (kappa * x_) + + in_sinh = np.arcsinh(in_arcsinh) - np.arcsinh(n * c * delta) + in_sinh_der = arcsinh_der(in_arcsinh) * in_arcsinh_der + + in_cdf = np.sinh(in_sinh) / delta + in_cdf_der = np.cosh(in_sinh) * in_sinh_der / delta + + # cdf_arr = phi.cdf(in_cdf) + cdf_arr_der = phi.pdf(in_cdf) * in_cdf_der + + pdf_arr = cdf_arr_der + return pdf_arr def _cdf(self, x: np.ndarray): """Cumulative distribution function.""" - params = self.get_params(deep=False) - index = params["index"] - columns = params["columns"] - qv_low = params["qv_low"] - x_unique = np.unique(x) # de-broadcast - cdf_all = cdf_func(x_unique, self.qpd) - cdf_map = np.tile(x_unique, (qv_low.size, 1)).T - cdf = np.zeros((index.shape[0], len(columns))) - for r in range(x.shape[0]): - for c in range(x.shape[1]): - t = np.where(cdf_map[:, c] == x[r][c]) - cdf_part = cdf_all[t][c] - cdf[r][c] = cdf_part - return cdf + lower = self._bc_params["lower"] + delta = self._qpd_params["delta"] + kappa = self._qpd_params["kappa"] + c = self._qpd_params["c"] + n = self._qpd_params["n"] + theta = self._qpd_params["theta"] + + phi = self.phi + + in_arcsinh = np.log((x - lower) / theta) / kappa + in_sinh = np.arcsinh(in_arcsinh) - np.arcsinh(n * c * delta) + cdf_arr = phi.cdf(np.sinh(in_sinh) / delta) + + return cdf_arr @classmethod def get_test_params(cls, parameter_set="default"): @@ -412,9 +354,9 @@ def get_test_params(cls, parameter_set="default"): params2 = { "alpha": 0.2, "version": "normal", - "qv_low": [-0.3, -0.3, -0.3], - "qv_median": [0.0, 0.0, 0.0], - "qv_high": [0.3, 0.3, 0.3], + "qv_low": [[-0.3], [-0.2], [-0.1]], + "qv_median": [[-0.1], [0.0], [0.1]], + "qv_high": [[0.2], [0.3], [0.4]], "lower": -0.5, "index": pd.RangeIndex(3), "columns": pd.Index(["a"]), @@ -433,7 +375,7 @@ class QPD_B(BaseDistribution): Parameters ---------- - alpha : float + alpha : float or array_like[float] lower quantile of SPT (upper is ``1 - alpha``) qv_low : float or array_like[float] quantile function value of ``alpha`` @@ -441,11 +383,11 @@ class QPD_B(BaseDistribution): quantile function value of quantile 0.5 qv_high : float or array_like[float] quantile function value of quantile ``1 - alpha`` - lower : float + lower : float or array_like[float] lower bound of semi-bounded range. This is used when estimating QPD and calculating expectation and variance - upper : float + upper : float or array_like[float] upper bound of semi-bounded range. This is used when estimating QPD and calculating expectation and variance @@ -471,14 +413,22 @@ class QPD_B(BaseDistribution): _tags = { # packaging info # -------------- - "authors": ["setoguchi-naoki", "felix-wick"], + "authors": ["setoguchi-naoki", "felix-wick", "fkiraly"], "maintainers": ["setoguchi-naoki"], - "python_dependencies": ["cyclic_boosting>=1.4.0", "findiff"], # estimator tags # -------------- "capabilities:approx": ["pdfnorm", "energy"], "capabilities:exact": ["mean", "var", "cdf", "ppf", "pdf"], "distr:measuretype": "continuous", + "broadcast_init": "on", + "broadcast_params": [ + "alpha", + "qv_low", + "qv_median", + "qv_high", + "lower", + "upper", + ], } def __init__( @@ -503,137 +453,80 @@ def __init__( self.index = index self.columns = columns - from cyclic_boosting.quantile_matching import J_QPD_B + super().__init__(index=index, columns=columns) - qv_low, qv_median, qv_high = _prep_qpd_params(qv_low, qv_median, qv_high) + # precompute parameters for methods + phi = _resolve_phi(version) + self.phi = phi - if index is None: - index = pd.RangeIndex(qv_low.shape[0]) - self.index = index + qpd_params = _prep_qpd_vars(phi=phi, mode="B", **self._bc_params) + self._qpd_params = qpd_params - if columns is None: - columns = pd.RangeIndex(1) - self.columns = columns + def _ppf(self, p: np.ndarray): + """Quantile function = percent point function = inverse cdf.""" + lower = self._bc_params["lower"] + rnge = self._qpd_params["rnge"] + delta = self._qpd_params["delta"] + kappa = self._qpd_params["kappa"] + c = self._qpd_params["c"] + n = self._qpd_params["n"] + xi = self._qpd_params["xi"] - if version == "normal": - self.phi = norm() - elif version == "logistic": - self.phi = logistic() - else: - raise Exception("Invalid version.") - - if (np.any(qv_low > qv_median)) or np.any(qv_high < qv_median): - warnings.warn( - "The SPT values are not monotonically increasing, " - "each SPT is sorted by value", - stacklevel=2, - ) - idx = np.where((qv_low > qv_median), True, False) + np.where( - (qv_high < qv_median), True, False - ) - un_orderd_idx = np.argwhere(idx > 0).tolist() - warnings.warn(f"sorted index {un_orderd_idx}", stacklevel=2) - for idx in un_orderd_idx: - low, mid, high = sorted([qv_low[idx], qv_median[idx], qv_high[idx]]) - qv_low[idx] = low - qv_median[idx] = mid - qv_high[idx] = high - - self.qpd = J_QPD_B( - alpha=alpha, - qv_low=qv_low, - qv_median=qv_median, - qv_high=qv_high, - l=self.lower, - u=self.upper, - version=version, - ) - super().__init__(index=index, columns=columns) + phi = self.phi - def _mean(self): - """Return expected value of the distribution. - - Please set the upper and lower limits of the random variable correctly. - - Returns - ------- - pd.DataFrame with same rows, columns as `self` - expected value of distribution (entry-wise) - """ - params = self.get_params(deep=False) - lower = params["lower"] - upper = params["upper"] - index = params["index"] - x = np.linspace(lower, upper, num=int(1e3)) - cdf = self.qpd.cdf(x) - if cdf.ndim < 2: - cdf = cdf[:, np.newaxis] - loc = exp_func(x, cdf.T, index.shape[0]) - return loc - - def _var(self): - """Return element/entry-wise variance of the distribution. - - Please set the upper and lower limits of the random variable correctly. - - Returns - ------- - pd.DataFrame with same rows, columns as `self` - variance of distribution (entry-wise) - """ - params = self.get_params(deep=False) - lower = params["lower"] - upper = params["upper"] - index = params["index"] - mean = self.mean().values - x = np.linspace(lower, upper, num=int(1e3)) - cdf = self.qpd.cdf(x) - if cdf.ndim < 2: - cdf = cdf[:, np.newaxis] - var = var_func(x, mean, cdf.T, index.shape[0]) - return var + in_cdf = xi + kappa * np.sinh(delta * (phi.ppf(p) + n * c)) + ppf_arr = lower + rnge * phi.cdf(in_cdf) + return ppf_arr def _pdf(self, x: np.ndarray): - """Probability density function. + """Probability density function.""" + lower = self._bc_params["lower"] + rnge = self._qpd_params["rnge"] + delta = self._qpd_params["delta"] + kappa = self._qpd_params["kappa"] + c = self._qpd_params["c"] + n = self._qpd_params["n"] + xi = self._qpd_params["xi"] - this fucntion transform cdf to pdf - because j-qpd's pdf calculation is bit complex - """ - return pdf_func(x, self.qpd) + phi = self.phi - def _ppf(self, p: np.ndarray): - """Quantile function = percent point function = inverse cdf.""" - params = self.get_params(deep=False) - index = params["index"] - columns = params["columns"] - qv_low = params["qv_low"] - p_unique = np.unique(p) # de-broadcast - ppf_all = ppf_func(p_unique, self.qpd) - ppf_map = np.tile(p_unique, (qv_low.size, 1)).T - ppf = np.zeros((index.shape[0], len(columns))) - for r in range(p.shape[0]): - for c in range(p.shape[1]): - t = np.where(ppf_map[:, c] == p[r][c]) - ppf_part = ppf_all[t][c] - ppf[r][c] = ppf_part - return ppf + # we work through the chain rule for the entire nested expression in cdf + x_ = (x - lower) / rnge + x_der = 1 / rnge + + phi_ppf = phi.ppf(x_) + # derivative of ppf at z is 1 / pdf(ppf(z)) + phi_ppf_der = x_der / phi.pdf(phi.ppf(x_)) + + in_arcsinh = (phi_ppf - xi) / kappa + in_arcsinh_der = phi_ppf_der / kappa + + in_cdf = np.arcsinh(in_arcsinh) / delta - n * c + in_cdf_der = arcsinh_der(in_arcsinh) * in_arcsinh_der / delta + + # cdf_arr = phi.cdf(in_cdf) + cdf_arr_der = phi.pdf(in_cdf) * in_cdf_der + + pdf_arr = cdf_arr_der + return pdf_arr def _cdf(self, x: np.ndarray): """Cumulative distribution function.""" - params = self.get_params(deep=False) - index = params["index"] - columns = params["columns"] - qv_low = params["qv_low"] - x_unique = np.unique(x) # de-broadcast - cdf_all = cdf_func(x_unique, self.qpd) - cdf_map = np.tile(x_unique, (qv_low.size, 1)).T - cdf = np.zeros((index.shape[0], len(columns))) - for r in range(x.shape[0]): - for c in range(x.shape[1]): - t = np.where(cdf_map[:, c] == x[r][c]) - cdf_part = cdf_all[t][c] - cdf[r][c] = cdf_part - return cdf + lower = self._bc_params["lower"] + rnge = self._qpd_params["rnge"] + delta = self._qpd_params["delta"] + kappa = self._qpd_params["kappa"] + c = self._qpd_params["c"] + n = self._qpd_params["n"] + xi = self._qpd_params["xi"] + + phi = self.phi + + phi_ppf = phi.ppf((x - lower) / rnge) + in_cdf = np.arcsinh((phi_ppf - xi) / kappa) / delta - n * c + cdf_arr = phi.cdf(in_cdf) + + return cdf_arr @classmethod def get_test_params(cls, parameter_set="default"): @@ -652,9 +545,9 @@ def get_test_params(cls, parameter_set="default"): params2 = { "alpha": 0.2, "version": "normal", - "qv_low": [-0.3, -0.3, -0.3], - "qv_median": [0.0, 0.0, 0.0], - "qv_high": [0.3, 0.3, 0.3], + "qv_low": [[-0.3], [-0.2], [-0.1]], + "qv_median": [[-0.1], [0.0], [0.1]], + "qv_high": [[0.2], [0.3], [0.4]], "lower": -0.5, "upper": 0.5, "index": pd.RangeIndex(3), @@ -712,14 +605,22 @@ class QPD_U(BaseDistribution): _tags = { # packaging info # -------------- - "authors": ["setoguchi-naoki", "felix-wick"], + "authors": ["setoguchi-naoki", "felix-wick", "fkiraly"], "maintainers": ["setoguchi-naoki"], - "python_dependencies": ["cyclic_boosting>=1.4.0", "findiff"], # estimator tags # -------------- "capabilities:approx": ["pdfnorm", "energy"], "capabilities:exact": ["mean", "var", "cdf", "ppf", "pdf"], "distr:measuretype": "continuous", + "broadcast_init": "on", + "broadcast_params": [ + "alpha", + "qv_low", + "qv_median", + "qv_high", + "lower", + "upper", + ], } def __init__( @@ -747,146 +648,65 @@ def __init__( self.index = index self.columns = columns - from cyclic_boosting.quantile_matching import J_QPD_extended_U + super().__init__(index=index, columns=columns) - qv_low, qv_median, qv_high = _prep_qpd_params(qv_low, qv_median, qv_high) + # precompute parameters for methods + phi = _resolve_phi(version) + self.phi = phi - if index is None: - index = pd.RangeIndex(qv_low.shape[0]) - self.index = index + qpd_params = _prep_qpd_vars(phi=phi, mode="U", **self._bc_params) + self._qpd_params = qpd_params - if columns is None: - columns = pd.RangeIndex(1) - self.columns = columns + def _ppf(self, p: np.ndarray): + """Quantile function = percent point function = inverse cdf.""" + alpha = self._bc_params["alpha"] + xi = self._qpd_params["xi"] + gamma = self._qpd_params["gamma"] + delta = self._qpd_params["delta"] + kappa = self._qpd_params["kappa"] - if version == "normal": - self.phi = norm() - elif version == "logistic": - self.phi = logistic() - else: - raise Exception("Invalid version.") - - if (np.any(qv_low > qv_median)) or np.any(qv_high < qv_median): - warnings.warn( - "The SPT values are not monotonically increasing, " - "each SPT is sorted by value", - stacklevel=2, - ) - idx = np.where((qv_low > qv_median), True, False) + np.where( - (qv_high < qv_median), True, False - ) - un_orderd_idx = np.argwhere(idx > 0).tolist() - warnings.warn(f"sorted index {un_orderd_idx}", stacklevel=2) - for idx in un_orderd_idx: - low, mid, high = sorted([qv_low[idx], qv_median[idx], qv_high[idx]]) - qv_low[idx] = low - qv_median[idx] = mid - qv_high[idx] = high - - iter = np.nditer(qv_low, flags=["c_index"]) - for _i in iter: - jqpd = J_QPD_extended_U( - alpha=alpha, - qv_low=qv_low[iter.index], - qv_median=qv_median[iter.index], - qv_high=qv_high[iter.index], - version=version, - shape=dist_shape, - ) - self.qpd.append(jqpd) + phi = self.phi - super().__init__(index=index, columns=columns) + width = phi.ppf(1 - alpha) + qs = phi.ppf(p) / width - def _mean(self): - """Return expected value of the distribution. - - Please set the upper and lower limits of the random variable correctly. - - Returns - ------- - pd.DataFrame with same rows, columns as `self` - expected value of distribution (entry-wise) - """ - params = self.get_params(deep=False) - lower = params["lower"] - upper = params["upper"] - index = params["index"] - cdf_arr = [] - x = np.linspace(lower, upper, num=int(1e3)) - for qpd in self.qpd: - cdf_arr.append(qpd.cdf(x)) - cdf = np.asarray(cdf_arr) - if cdf.ndim < 2: - cdf = cdf[:, np.newaxis] - loc = exp_func(x, cdf, index.shape[0]) - return loc - - def _var(self): - """Return element/entry-wise variance of the distribution. - - Please set the upper and lower limits of the random variable correctly. - - Returns - ------- - pd.DataFrame with same rows, columns as `self` - variance of distribution (entry-wise) - """ - params = self.get_params(deep=False) - lower = params["lower"] - upper = params["upper"] - index = params["index"] - mean = self.mean().values - cdf_list = [] - x = np.linspace(lower, upper, num=int(1e3)) - for qpd in self.qpd: - cdf_list.append(qpd.cdf(x)) - cdf = np.asarray(cdf_list) - if cdf.ndim < 2: - cdf = cdf[:, np.newaxis] - var = var_func(x, mean, cdf, index.shape[0]) - return var + ppf_arr = xi + kappa * np.sinh((qs - gamma) / delta) + return ppf_arr def _pdf(self, x: np.ndarray): - """Probability density function. + """Probability density function.""" + alpha = self._bc_params["alpha"] + xi = self._qpd_params["xi"] + gamma = self._qpd_params["gamma"] + delta = self._qpd_params["delta"] + kappa = self._qpd_params["kappa"] - this fucntion transform cdf to pdf - because j-qpd's pdf calculation is bit complex - """ - return pdf_func(x, self.qpd) + phi = self.phi - def _ppf(self, p: np.ndarray): - """Quantile function = percent point function = inverse cdf.""" - params = self.get_params(deep=False) - index = params["index"] - columns = params["columns"] - qv_low = params["qv_low"] - p_unique = np.unique(p) # de-broadcast - ppf_all = ppf_func(p_unique, self.qpd) - ppf_map = np.tile(p_unique, (qv_low.size, 1)).T - ppf = np.zeros((index.shape[0], len(columns))) - for r in range(p.shape[0]): - for c in range(p.shape[1]): - t = np.where(ppf_map[:, c] == p[r][c]) - ppf_part = ppf_all[t][c] - ppf[r][c] = ppf_part - return ppf + width = phi.ppf(1 - alpha) + + qs = gamma + delta * np.arcsinh((x - xi) / kappa) + qs_der = delta * arcsinh_der((x - xi) / kappa) / kappa + + # cdf_arr = phi.cdf(qs * width) + pdf_arr = phi.pdf(qs * width) * qs_der + return pdf_arr def _cdf(self, x: np.ndarray): """Cumulative distribution function.""" - params = self.get_params(deep=False) - index = params["index"] - columns = params["columns"] - qv_low = params["qv_low"] - x_unique = np.unique(x) # de-broadcast - cdf_all = cdf_func(x_unique, self.qpd) - cdf_map = np.tile(x_unique, (qv_low.size, 1)).T - cdf = np.zeros((index.shape[0], len(columns))) - for r in range(x.shape[0]): - for c in range(x.shape[1]): - t = np.where(cdf_map[:, c] == x[r][c]) - cdf_part = cdf_all[t][c] - cdf[r][c] = cdf_part - return cdf + alpha = self._bc_params["alpha"] + xi = self._qpd_params["xi"] + gamma = self._qpd_params["gamma"] + delta = self._qpd_params["delta"] + kappa = self._qpd_params["kappa"] + + phi = self.phi + + width = phi.ppf(1 - alpha) + qs = gamma + delta * np.arcsinh((x - xi) / kappa) + + cdf_arr = phi.cdf(qs * width) + return cdf_arr @classmethod def get_test_params(cls, parameter_set="default"): @@ -903,99 +723,142 @@ def get_test_params(cls, parameter_set="default"): params2 = { "alpha": 0.2, "version": "normal", - "qv_low": [-0.3, -0.3, -0.3], - "qv_median": [0.0, 0.0, 0.0], - "qv_high": [0.3, 0.3, 0.3], + "qv_low": [[-0.3], [-0.2], [-0.1]], + "qv_median": [[-0.1], [0.0], [0.1]], + "qv_high": [[0.2], [0.3], [0.4]], "index": pd.RangeIndex(3), "columns": pd.Index(["a"]), } return [params1, params2] -def calc_pdf(cdf: np.ndarray) -> np.ndarray: - """Return pdf value for all samples.""" - from findiff import FinDiff - - dx = 1e-6 - derivative = FinDiff(1, dx, 1) - pdf = np.asarray(derivative(cdf)) - return pdf - - -def exp_func(x: np.ndarray, cdf: np.ndarray, size: int): - """Return Expectation.""" - pdf = calc_pdf(cdf) - x = np.tile(x, (size, 1)) - loc = np.trapz(x * pdf, x, dx=1e-6, axis=1) - return loc - - -def var_func(x: np.ndarray, mu: np.ndarray, cdf: np.ndarray, size: int): - """Return Variance.""" - pdf = calc_pdf(cdf) - x = np.tile(x, (size, 1)) - var = np.trapz(((x - mu) ** 2) * pdf, x, dx=1e-6, axis=1) - return var - - -def pdf_func(x: np.ndarray, qpd: J_QPD_S | J_QPD_B | list): - """Return pdf value.""" - pdf = np.zeros_like(x) - for r in range(x.shape[0]): - for c in range(x.shape[1]): - element = x[r][c] - x0 = np.linspace(element, element + 1e-3, num=3) - if isinstance(qpd, list): - cdf = np.asarray([func.cdf(x0) for func in qpd]) - cdf = cdf.reshape(cdf.shape[0], -1) - else: - cdf = qpd.cdf(x0) - if cdf.ndim < 2: - for _ in range(2 - cdf.ndim): - cdf = cdf[:, np.newaxis] - cdf = cdf.T - pdf_part = calc_pdf(cdf) - pdf[r][c] = pdf_part[0][0] - return pdf - - -def ppf_func(x: np.ndarray, qpd: J_QPD_S | J_QPD_B | list): - """Return ppf value.""" - if isinstance(qpd, list): - ppf = np.asarray([func.ppf(x) for func in qpd]) - ppf = ppf.reshape(ppf.shape[0], -1) - else: - ppf = qpd.ppf(x) - if ppf.ndim < 2: - for _ in range(2 - ppf.ndim): - ppf = ppf[np.newaxis] - ppf = ppf.T - return ppf - - -def cdf_func(x: np.ndarray, qpd: J_QPD_S | J_QPD_B | list): - """Return cdf value.""" - if isinstance(qpd, list): - cdf = np.asarray([func.cdf(x) for func in qpd]) - cdf = cdf.reshape(cdf.shape[0], -1) +def _resolve_phi(phi): + """Resolve base distribution.""" + if phi == "normal": + return norm() + elif phi == "logistic": + return logistic() else: - cdf = qpd.cdf(x) - if cdf.ndim < 2: - for _ in range(2 - cdf.ndim): - cdf = cdf[np.newaxis] - cdf = cdf.T - return cdf - - -def _prep_qpd_params(qv_low, qv_median, qv_high): - """Prepare parameters for Johnson Quantile-Parameterized Distributions.""" - qv = [qv_low, qv_median, qv_high] - for i, instance in enumerate(qv): - if isinstance(instance, float): - qv[i] = np.array([qv[i]]) - elif isinstance(instance, Sequence): - qv[i] = np.asarray(qv[i]) - qv_low = qv[0].flatten() - qv_median = qv[1].flatten() - qv_high = qv[2].flatten() - return qv_low, qv_median, qv_high + return phi + + +def _prep_qpd_vars( + alpha, + qv_low, + qv_median, + qv_high, + lower, + upper, + phi, + mode="B", + **kwargs, +): + """Prepare parameters for Johnson Quantile-Parameterized Distributions. + + Parameters + ---------- + alpha : 2D np.array + lower quantile of SPT (upper is ``1 - alpha``) + qv_low : 2D np.array + quantile function value of ``alpha`` + qv_median : 2D np.array + quantile function value of quantile 0.5 + qv_high : 2D np.array + quantile function value of quantile ``1 - alpha`` + lower : 2D np.array + lower bound of range. + upper : 2D np.array + upper bound of range. + phi : scipy.stats.rv_continuous + base distribution + mode : str + options are ``B`` (default) or ``S`` + B = bounded mode, S = lower semi-bounded mode + """ + c = phi.ppf(1 - alpha) + + if mode == "U": + lower = 0 + + qll = qv_low - lower + qml = qv_median - lower + qhl = qv_high - lower + + if mode == "B": + rnge = upper - lower + + def tfun(x): + return phi.ppf(x / rnge) + + elif mode == "S": + tfun = np.log + elif mode == "U": + + def tfun(x): + return x + + L = tfun(qll) + H = tfun(qhl) + B = tfun(qml) + HL = H - L + BL = B - L + HB = H - B + LH2B = L + H - 2 * B + + HBL = np.where(BL < HB, BL, HB) + + n = np.where(LH2B > 0, 1, -1) + n = np.where(LH2B == 0, 0, n) + + if mode in ["B", "U"]: + xi = np.where(LH2B > 0, L, H) + xi = np.where(LH2B == 0, B, xi) + if mode == "S": + theta = np.where(LH2B > 0, qll, qhl) + theta = np.where(LH2B == 0, qml, theta) + if mode == "U": + theta = np.where(LH2B > 0, BL / HL, HB / HL) + + if mode in ["B", "S"]: + in_arccosh = HL / (2 * HBL) + delta_unn = np.arccosh(in_arccosh) + if mode == "S": + delta_unn = np.sinh(delta_unn) + delta = delta_unn / c + elif mode == "U": + delta = 1.0 / np.arccosh(1 / (2.0 * theta)) + delta = np.where(LH2B == 0, 1, delta) + + if mode == "B": + kappa = HL / np.sinh(2 * delta * c) + elif mode == "S": + kappa = HBL / (delta * c) + elif mode == "U": + kappa = HL / np.sinh(2.0 / delta) + kappa = np.where(LH2B == 0, HB, kappa) + + params = { + "c": c, + "L": L, + "H": H, + "B": B, + "n": n, + "delta": delta, + "kappa": kappa, + } + + if mode == "S": + params["theta"] = theta + if mode == "B": + params["rnge"] = rnge + if mode in ["B", "U"]: + params["xi"] = xi + if mode == "U": + params["gamma"] = -np.sign(LH2B) + + return params + + +def arcsinh_der(x): + """Return derivative of arcsinh.""" + return 1 / np.sqrt(1 + x**2) diff --git a/skpro/distributions/tests/test_qpd.py b/skpro/distributions/tests/test_qpd.py index 78c418c1a..0bf070516 100644 --- a/skpro/distributions/tests/test_qpd.py +++ b/skpro/distributions/tests/test_qpd.py @@ -1,5 +1,6 @@ """Tests for quantile-parameterized distributions.""" +import numpy as np import pytest from skpro.distributions.qpd import QPD_B, QPD_S, QPD_U @@ -24,6 +25,22 @@ def test_qpd_b_simple_use(): qpd.mean() +def test_qpd_b_pdf(): + """Test pdf of qpd with bounded mode.""" + # these parameters should produce a uniform on -0.5, 0.5 + qpd_linear = QPD_B( + alpha=0.2, + qv_low=-0.3, + qv_median=0, + qv_high=0.3, + lower=-0.5, + upper=0.5, + ) + x = np.linspace(-0.45, 0.45, 100) + pdf_vals = [qpd_linear.pdf(x_) for x_ in x] + np.testing.assert_allclose(pdf_vals, 1.0, rtol=1e-5) + + @pytest.mark.skipif( not run_test_for_class(QPD_S), reason="run test only if softdeps are present and incrementally (if requested)",