diff --git a/.all-contributorsrc b/.all-contributorsrc
index 7d9c22ed..af77c211 100644
--- a/.all-contributorsrc
+++ b/.all-contributorsrc
@@ -205,6 +205,17 @@
         "code",
         "doc"
       ]
+    },
+    {
+      "login": "meh2135",
+      "name": "Mike Hankin",
+      "avatar_url": "https://avatars.githubusercontent.com/u/313774?v=4",
+      "profile": "https://github.com/meh2135",
+      "contributions": [
+        "bug",
+        "code",
+        "test"
+      ]
     }
   ]
 }
diff --git a/skpro/regression/base/_base.py b/skpro/regression/base/_base.py
index c89232f6..73561708 100644
--- a/skpro/regression/base/_base.py
+++ b/skpro/regression/base/_base.py
@@ -72,7 +72,7 @@ def __rmul__(self, other):
         else:
             return NotImplemented

-    def fit(self, X, y, C=None):
+    def fit(self, X, y, C=None, sample_weight=None):
         """Fit regressor to training data.

         Writes to self:
@@ -89,6 +89,8 @@
         C : ignored, optional (default=None)
             censoring information for survival analysis
             All probabilistic regressors assume data to be uncensored
+        sample_weight : pandas DataFrame, same shape as y, default=None
+            sample weights to fit regressor to

         Returns
         -------
@@ -112,13 +114,20 @@
         # set fitted flag to True
         self._is_fitted = True
-
-        if not capa_surv:
-            return self._fit(X_inner, y_inner)
+        if sample_weight is None:
+            if not capa_surv:
+                return self._fit(X_inner, y_inner)
+            else:
+                return self._fit(X_inner, y_inner, C=C_inner)
         else:
-            return self._fit(X_inner, y_inner, C=C_inner)
+            if not capa_surv:
+                return self._fit(X_inner, y_inner, sample_weight=sample_weight)
+            else:
+                return self._fit(
+                    X_inner, y_inner, C=C_inner, sample_weight=sample_weight
+                )

-    def _fit(self, X, y, C=None):
+    def _fit(self, X, y, C=None, sample_weight=None):
         """Fit regressor to training data.

         Writes to self:
@@ -130,6 +139,11 @@
             feature instances to fit regressor to
         y : pandas DataFrame, must be same length as X
             labels to fit regressor to
+        C : ignored, optional (default=None)
+            censoring information for survival analysis
+            All probabilistic regressors assume data to be uncensored
+        sample_weight : pandas DataFrame, same shape as y, default=None
+            sample weights to fit regressor to

         Returns
         -------
@@ -137,7 +151,7 @@
         """
         raise NotImplementedError

-    def update(self, X, y, C=None):
+    def update(self, X, y, C=None, sample_weight=None):
         """Update regressor with a new batch of training data.

         Only estimators with the ``capability:update`` tag (value ``True``)
@@ -159,6 +173,8 @@
         C : ignored, optional (default=None)
             censoring information for survival analysis
             All probabilistic regressors assume data to be uncensored
+        sample_weight : pandas DataFrame, same shape as y, default=None
+            sample weights to fit regressor to

         Returns
         -------
@@ -178,12 +194,20 @@
         if capa_surv:
             C_inner = check_ret["C_inner"]

-        if not capa_surv:
-            return self._update(X_inner, y_inner)
+        if sample_weight is None:
+            if not capa_surv:
+                return self._update(X_inner, y_inner)
+            else:
+                return self._update(X_inner, y_inner, C=C_inner)
         else:
-            return self._update(X_inner, y_inner, C=C_inner)
+            if not capa_surv:
+                return self._update(X_inner, y_inner, sample_weight=sample_weight)
+            else:
+                return self._update(
+                    X_inner, y_inner, C=C_inner, sample_weight=sample_weight
+                )

-    def _update(self, X, y, C=None):
+    def _update(self, X, y, C=None, sample_weight=None):
         """Update regressor with a new batch of training data.

         State required:
@@ -198,6 +222,11 @@
             feature instances to fit regressor to
         y : pandas DataFrame, must be same length as X
             labels to fit regressor to
+        C : ignored, optional (default=None)
+            censoring information for survival analysis
+            All probabilistic regressors assume data to be uncensored
+        sample_weight : pandas DataFrame, same shape as y, default=None
+            sample weights to fit regressor to

         Returns
         -------
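The extended `fit` signature is backwards compatible: omitting `sample_weight` routes through the old code path. A minimal usage sketch (not part of the diff; the data and estimator choice are illustrative, and the weights are shown as a 1D numpy array, which the inner sklearn estimators expect):

```python
import numpy as np
import pandas as pd
from sklearn.linear_model import LinearRegression

from skpro.regression.residual import ResidualDouble

# illustrative toy data
X = pd.DataFrame({"a": np.random.randn(100), "b": np.random.randn(100)})
y = pd.DataFrame({"y": X["a"] + 0.1 * np.random.randn(100)})

reg = ResidualDouble(estimator=LinearRegression())
# weights are passed through to the inner estimators' fit calls;
# sample_weight=None reproduces the previous behavior
reg.fit(X, y, sample_weight=np.ones(len(y)))
y_pred = reg.predict(X)
```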
diff --git a/skpro/regression/residual.py b/skpro/regression/residual.py
index 143c76eb..1af1167b 100644
--- a/skpro/regression/residual.py
+++ b/skpro/regression/residual.py
@@ -3,8 +3,11 @@

 __author__ = ["fkiraly"]

+import warnings
+
 import numpy as np
 import pandas as pd
+from scipy.special import gamma
 from sklearn import clone

 from skpro.regression.base import BaseProbaRegressor
@@ -12,6 +15,22 @@
 from skpro.utils.sklearn import prep_skl_df


+def half_t_correction(dof: float) -> float:
+    """Expected absolute value of a t-distributed variable with mu=0, sigma=1.
+
+    For X ~ t(dof, 0, sigma), the expected value of the absolute value is
+    ``2 * sigma * sqrt(dof) * gamma((dof + 1) / 2) /
+    (sqrt(pi) * (dof - 1) * gamma(dof / 2))``.
+    So E[|X|] / half_t_correction(dof) is an estimate of sigma.
+    """
+    return (
+        2
+        * np.sqrt(dof)
+        * gamma((dof + 1) / 2)
+        / (np.sqrt(np.pi) * (dof - 1) * gamma(dof / 2))
+    )
+
+
 class ResidualDouble(BaseProbaRegressor):
     """Residual double regressor.

@@ -154,7 +173,7 @@ def __init__(
         else:
             self.estimator_resid_ = clone(estimator_resid)

-    def _predict_residuals_cv(self, X, y, cv, est):
+    def _predict_residuals_cv(self, X, y, cv, est=None, sample_weight=None):
         """Predict out-of-sample residuals for y from X using cv.

         Parameters
         ----------
@@ -171,7 +190,8 @@
         y_pred : pandas DataFrame, same length as `X`, same columns as `y` in `fit`
             labels predicted for `X`
         """
-        est = self.estimator_resid_
+        if est is None:
+            est = self.estimator_resid_
         method = "predict"
         y_pred = y.copy()

@@ -179,12 +199,18 @@
             X_train = X.iloc[tr_idx]
             X_test = X.iloc[tt_idx]
             y_train = y[tr_idx]
-            fitted_est = clone(est).fit(X_train, y_train)
+            if sample_weight is None:
+                fitted_est = clone(est).fit(X_train, y_train)
+            else:
+                sample_weight_train = sample_weight[tr_idx]
+                fitted_est = clone(est).fit(
+                    X_train, y_train, sample_weight=sample_weight_train
+                )
             y_pred[tt_idx] = getattr(fitted_est, method)(X_test)

         return y_pred

-    def _fit(self, X, y):
+    def _fit(self, X, y, sample_weight=None):
         """Fit regressor to training data.

         Writes to self:
@@ -196,6 +222,8 @@
             feature instances to fit regressor to
         y : pandas DataFrame, must be same length as X
             labels to fit regressor to
+        sample_weight : pandas DataFrame, same length as X, default=None
+            sample weights to fit regressor to

         Returns
         -------
@@ -215,8 +243,10 @@
         # flatten column vector to 1D array to avoid sklearn complaints
         y = y.values
         y = flatten_to_1D_if_colvector(y)
-
-        est.fit(X, y)
+        if sample_weight is None:
+            est.fit(X, y)
+        else:
+            est.fit(X, y, sample_weight=sample_weight)

         if cv is None:
             y_pred = est.predict(X)
@@ -229,6 +259,13 @@
             resids = (y - y_pred) ** 2
         else:
             resids = residual_trafo.fit_transform(y - y_pred)
+            warnings.warn(
+                (
+                    "Arbitrary transforms will result in aberrant behavior in "
+                    "the predict_proba method."
+                ),
+                stacklevel=2,
+            )

         resids = flatten_to_1D_if_colvector(resids)
@@ -241,7 +278,10 @@
         # coerce X to pandas DataFrame with string column names
         X_r = prep_skl_df(X_r, copy_df=True)

-        est_r.fit(X_r, resids)
+        if sample_weight is None:
+            est_r.fit(X_r, resids)
+        else:
+            est_r.fit(X_r, resids, sample_weight=sample_weight)

         return self
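A quick numeric sanity check of the `half_t_correction` helper added above, before the `predict_proba` hunks below (a sketch, not part of the diff; assumes the patched skpro is importable):

```python
import numpy as np
from scipy import stats

from skpro.regression.residual import half_t_correction

# for X ~ t(df, loc=0, scale=sigma), E[|X|] / half_t_correction(df) ~ sigma
df, sigma = 5.0, 2.0
x = stats.t(df=df, loc=0.0, scale=sigma).rvs(size=500_000, random_state=42)
sigma_hat = np.abs(x).mean() / half_t_correction(df)
print(sigma_hat)  # close to 2.0, up to Monte Carlo error
```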
@@ -295,6 +335,7 @@
         est = self.estimator_
         est_r = self.estimator_resid_
         use_y_pred = self.use_y_pred
+        residual_trafo = self.residual_trafo
         distr_type = self.distr_type
         distr_loc_scale_name = self.distr_loc_scale_name
         distr_params = self.distr_params
@@ -307,6 +348,8 @@
         if distr_params is None:
             distr_params = {}
+        else:
+            distr_params = distr_params.copy()

         # predict location - this is the same as in _predict
         y_pred_loc = est.predict(X)
@@ -325,6 +368,19 @@
         X_r = prep_skl_df(X_r, copy_df=True)

         y_pred_scale = est_r.predict(X_r)
+        if residual_trafo == "absolute":
+            pass
+        elif residual_trafo == "squared":
+            y_pred_scale = np.sqrt(y_pred_scale)
+        else:
+            y_pred_scale = residual_trafo.inverse_transform(y_pred_scale)
+            warnings.warn(
+                (
+                    "Arbitrary residual transforms will result in unpredictable"
+                    " behavior."
+                ),
+                stacklevel=2,
+            )
         y_pred_scale = y_pred_scale.clip(min=min_scale)
         y_pred_scale = y_pred_scale.reshape(-1, n_cols)

@@ -335,17 +391,69 @@
             distr_type = Normal
             distr_loc_scale_name = ("mu", "sigma")
+            if residual_trafo == "absolute":
+                y_pred_scale = y_pred_scale / np.sqrt(2 / np.pi)
         elif distr_type == "Laplace":
             from skpro.distributions.laplace import Laplace

             distr_type = Laplace
             distr_loc_scale_name = ("mu", "scale")
-        elif distr_type in ["Cauchy", "t"]:
+            if residual_trafo == "squared":
+                y_pred_scale = y_pred_scale / np.sqrt(2.0)
+        elif distr_type == "t":
             from skpro.distributions.t import TDistribution

             distr_type = TDistribution
             distr_loc_scale_name = ("mu", "sigma")
+            # Extract degrees of freedom
+            df = distr_params["df"]
+            if residual_trafo == "absolute":
+                if df <= 1:
+                    warnings.warn(
+                        (
+                            "Both the t-distribution and the half t-distribution have "
+                            "no first moment for df<=1, so predict_proba will result "
+                            "in erratic behavior."
+                        ),
+                        stacklevel=2,
+                    )
+                y_pred_scale = y_pred_scale / half_t_correction(df)
+            elif residual_trafo == "squared":
+                if df <= 2:
+                    warnings.warn(
+                        (
+                            "t-distribution has no second moment for df <= 2, and no "
+                            "first moment for df <= 1, so predict_proba will result "
+                            "in erratic behavior."
+                        ),
+                        stacklevel=2,
+                    )
+                elif df <= 3:
+                    warnings.warn(
+                        (
+                            "Degrees of freedom less than 3 tend to yield poor"
+                            " results for squared residuals."
+                        ),
+                        stacklevel=2,
+                    )
+                y_pred_scale = y_pred_scale / np.sqrt(df / (df - 2))
+        elif distr_type == "Cauchy":
+            from skpro.distributions.t import TDistribution as CauchyDistribution
+
+            warnings.warn(
+                (
+                    "Cauchy distribution has no first or second moments, so "
+                    "predict_proba will result in erratic behavior."
+                ),
+                stacklevel=2,
+            )
+
+            distr_type = CauchyDistribution
+            distr_loc_scale_name = ("mu", "sigma")
+            distr_params = {"df": 1}
+        else:
+            raise NotImplementedError(f"distr_type {distr_type} not implemented")

         # collate all parameters for the distribution constructor
         # distribution params, if passed
         params = distr_params
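The divisors applied to `y_pred_scale` above are standard moment identities: for a Normal, E[|X|] = sigma * sqrt(2 / pi); for a Laplace with scale b, E[X^2] = 2 * b^2; for a t with df > 2, E[X^2] = sigma^2 * df / (df - 2); and for a t, E[|X|] = sigma * half_t_correction(df). A short numeric check of the first three (a sketch, not part of the diff):

```python
import numpy as np
from scipy import stats

sigma = 1.7
rng = np.random.default_rng(0)

# Normal, absolute residuals: E[|X|] = sigma * sqrt(2 / pi)
z = stats.norm(scale=sigma).rvs(10**6, random_state=rng)
print(np.abs(z).mean() / np.sqrt(2 / np.pi))  # ~ 1.7

# Laplace, squared residuals: sqrt(E[X^2]) = b * sqrt(2)
lap = stats.laplace(scale=sigma).rvs(10**6, random_state=rng)
print(np.sqrt((lap**2).mean()) / np.sqrt(2.0))  # ~ 1.7

# t, squared residuals: sqrt(E[X^2]) = sigma * sqrt(df / (df - 2))
tv = stats.t(df=5, scale=sigma).rvs(10**6, random_state=rng)
print(np.sqrt((tv**2).mean()) / np.sqrt(5 / (5 - 2)))  # ~ 1.7
```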
diff --git a/skpro/regression/tests/test_residual.py b/skpro/regression/tests/test_residual.py
new file mode 100644
index 00000000..f11a3510
--- /dev/null
+++ b/skpro/regression/tests/test_residual.py
@@ -0,0 +1,132 @@
+"""Tests ResidualDouble regressor for uniform quantiles when model is correct."""
+
+from typing import Dict, Literal, Optional
+
+import numpy as np
+import pandas as pd
+import pytest
+from scipy import stats
+from sklearn.linear_model import LinearRegression
+
+from skpro.regression.residual import ResidualDouble
+from skpro.tests.test_switch import run_test_for_class
+
+
+def held_out_cdf(
+    nn: int = 25_000,
+    distr_type: Literal["Laplace", "Normal", "t"] = "Laplace",
+    model: Literal["linear", "constant"] = "linear",
+    trafo: Literal["absolute", "squared"] = "absolute",
+    distr_params: Optional[Dict[str, float]] = None,
+    sample_weight: Optional[np.ndarray] = None,
+) -> pd.Series:
+    """Fit ResidualDouble to simulated data and return held-out CDF values."""
+    np.random.seed(42)
+    if distr_params is None:
+        distr_params = {}
+    else:
+        distr_params = distr_params.copy()
+    x_df = pd.DataFrame(
+        {"a": np.random.randn(nn), "b": np.random.randn(nn), "c": np.random.randn(nn)}
+    ).clip(-2, 2)
+    # DGP
+    if model == "linear":
+        loc_param_vec = pd.Series({"a": -1, "b": 1, "c": 0})
+        log_scale_param_vec = pd.Series({"a": 0, "b": 0.01, "c": 0.5})
+        loc_vec = x_df.dot(loc_param_vec)
+        log_scale_vec = x_df.dot(log_scale_param_vec).round(1)
+    else:
+        loc_vec = pd.Series(3.0, index=x_df.index)
+        log_scale_vec = pd.Series(0.0, index=x_df.index)
+
+    if distr_type == "Laplace":
+        dist_cls = stats.laplace
+    elif distr_type == "Normal":
+        dist_cls = stats.norm
+    elif distr_type == "t":
+        dist_cls = stats.t
+    else:
+        raise ValueError(f"Distribution {distr_type} not supported")
+    dist = dist_cls(loc=loc_vec, scale=np.exp(log_scale_vec), **distr_params)
+    y = pd.DataFrame(dist.rvs((2, nn)).T, index=x_df.index, columns=["y0", "y1"])
+    reg = ResidualDouble(
+        estimator=LinearRegression(),
+        estimator_resid=LinearRegression(),
+        distr_params=distr_params,
+        distr_type=distr_type,
+        residual_trafo=trafo,
+        # cv=KFold(n_splits=3),
+    )
+
+    reg.fit(x_df, y["y0"], sample_weight=sample_weight)
+    pred = reg.predict_proba(x_df)
+
+    cdf = pred.cdf(y[["y1"]])["y0"]
+    return cdf
+
+
+@pytest.mark.skipif(
+    not run_test_for_class(ResidualDouble),
+    reason="run test only if softdeps are present and incrementally (if requested)",
+)
+@pytest.mark.parametrize(
+    "distr_type,distr_params",
+    [
+        ("t", {"df": 5.1}),
+        ("t", {"df": 2.5}),
+        ("Laplace", None),
+        ("Normal", None),
+    ],
+)
+@pytest.mark.parametrize("trafo", ["absolute", "squared"])
+def test_residual_double_constant(distr_type, distr_params, trafo):
+    """Test validity of ResidualDouble regressor on a constant model."""
+    Q_BINS = 4
+    TOL_ALPHA = 0.001
+    np.random.seed(42)
+    # Should be uniform(0,1)
+    held_out_quantiles = held_out_cdf(
+        model="constant", distr_type=distr_type, distr_params=distr_params, trafo=trafo
+    )
+    # Counts of quantiles in bins
+    vc = pd.cut(held_out_quantiles, bins=np.linspace(0, 1, Q_BINS + 1)).value_counts()
+    # Expected counts under uniformity, vc.sum() / Q_BINS in every bin
+    e_vec = pd.Series(vc.sum() / Q_BINS, index=vc.index)
+    # Observed counts
+    o_vec = vc
+    # Chi-squared test
+    chsq = stats.chisquare(o_vec, e_vec, ddof=2)
+    # t with df < 3 and trafo="squared" does very badly,
+    # hence the lenient tolerance TOL_ALPHA
+    assert chsq.pvalue > TOL_ALPHA
tolerance + assert chsq.pvalue > TOL_ALPHA + + +@pytest.mark.skipif( + not run_test_for_class(ResidualDouble), + reason="run test only if softdeps are present and incrementally (if requested)", +) +def test_residual_double_sample_weight(): + """Test validity of ResidualDouble regressor on a constant model.""" + trafo = "absolute" + distr_type = "Laplace" + distr_params = None + Q_BINS = 4 + TOL_ALPHA = 0.001 + np.random.seed(42) + # Should be uniform(0,1) + held_out_quantiles = held_out_cdf( + model="constant", distr_type=distr_type, distr_params=distr_params, trafo=trafo + ) + # Counts of quantiles in bins + vc = pd.cut(held_out_quantiles, bins=np.linspace(0, 1, Q_BINS + 1)).value_counts() + # Expected counts under uniformity + e_vec = vc * vc.sum() / (Q_BINS * vc) + # Observed counts + o_vec = vc + # Chi-squared test + chsq = stats.chisquare(o_vec, e_vec, ddof=2) + # dist=1, ddf<3, trafo="squared" does very badly, hence the high tolerance + assert chsq.pvalue > TOL_ALPHA