-
Notifications
You must be signed in to change notification settings - Fork 47
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[ENH] adapter to
lifelines
, most distributional survival regressors…
… interfaced (#247) Adds an adapter to `lifelines`, exposing array-like survival functions as `Empirical` distributions in `predict_proba`. Adds all models from `lifelines` which are capable of full distributional predictions: * `AalenAdditiveFitter` * `CoxPHFitter` * `WeibullAFTFitter` Remaining AFT fitters require distributions not yet merged.
- Loading branch information
Showing
11 changed files
with
897 additions
and
48 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
"""Common utilities for adapters.""" | ||
|
||
import numpy as np | ||
|
||
|
||
def _clip_surv(surv_arr): | ||
"""Clips improper survival function values to proper range. | ||
Enforces: values are in [0, 1] and are monotonically decreasing. | ||
First clips to [0, 1], then enforces monotonicity, by replacing | ||
any value with minimum of itself and any previous values. | ||
Parameters | ||
---------- | ||
surv_arr : 2D np.ndarray | ||
Survival function values. | ||
index 0 is instance index. | ||
index 1 is time index, increasing. | ||
Returns | ||
------- | ||
surv_arr_clipped : 2D np.ndarray | ||
Clipped survival function values. | ||
surv_arr_diff : 2D np.ndarray, same shape as surv_arr_clipped. | ||
Difference of clipped survival function values. | ||
Same as np.diff(surv_arr, axis=1, prepend=1, append=0), | ||
then summing the last two columns to become one column. | ||
Returned to avoid recomputation, if needed later in context. | ||
clipped : boolean | ||
Whether clipping was needed. | ||
""" | ||
too_large = surv_arr > 1 | ||
too_small = surv_arr < 0 | ||
|
||
surv_arr[too_large] = 1 | ||
surv_arr[too_small] = 0 | ||
|
||
surv_arr_diff = _surv_diff(surv_arr) | ||
|
||
# avoid iterative minimization if no further clipping is needed | ||
if not (surv_arr_diff > 0).any(): | ||
clipped = too_large.any() or too_small.any() | ||
return surv_arr, surv_arr_diff, clipped | ||
|
||
# enforce monotonicity | ||
# iterating from left to right ensures values are replaced | ||
# with minimum of itself and all values to the left | ||
for i in range(1, surv_arr.shape[1]): | ||
surv_arr[:, i] = np.minimum(surv_arr[:, i], surv_arr[:, i - 1]) | ||
|
||
surv_arr_diff = _surv_diff(surv_arr) | ||
|
||
return surv_arr, surv_arr_diff, True | ||
|
||
|
||
def _surv_diff(surv_arr): | ||
"""Compute difference of survival function values. | ||
Parameters | ||
---------- | ||
surv_arr : 2D np.ndarray | ||
Survival function values. | ||
index 0 is instance index. | ||
index 1 is time index, increasing. | ||
Returns | ||
------- | ||
surv_arr_diff : 2D np.ndarray, same shape as surv_arr | ||
Difference of survival function values. | ||
Same as np.diff(surv_arr, axis=1, prepend=1, append=0), | ||
then summing the last two columns to become one column | ||
""" | ||
surv_arr_diff = np.diff(surv_arr, axis=1, prepend=1, append=0) | ||
|
||
surv_arr_diff[:, -2] = surv_arr_diff[:, -2] + surv_arr_diff[:, -1] | ||
surv_arr_diff = surv_arr_diff[:, :-1] | ||
|
||
return surv_arr_diff | ||
|
||
|
||
def _get_fitted_params_default_safe(obj=None): | ||
"""Obtain fitted params of object, per sklearn convention. | ||
Same as _get_fitted_params_default, but with exception handling. | ||
This is since in sksurv, feature_importances_ is a property | ||
and may raise an exception if the estimator does not have it. | ||
Parameters | ||
---------- | ||
obj : any object | ||
Returns | ||
------- | ||
fitted_params : dict with str keys | ||
fitted parameters, keyed by names of fitted parameter | ||
""" | ||
# default retrieves all self attributes ending in "_" | ||
# and returns them with keys that have the "_" removed | ||
# | ||
# get all attributes ending in "_", exclude any that start with "_" (private) | ||
fitted_params = [ | ||
attr for attr in dir(obj) if attr.endswith("_") and not attr.startswith("_") | ||
] | ||
|
||
def hasattr_safe(obj, attr): | ||
try: | ||
if hasattr(obj, attr): | ||
getattr(obj, attr) | ||
return True | ||
except Exception: | ||
return False | ||
|
||
# remove the "_" at the end | ||
fitted_param_dict = { | ||
p[:-1]: getattr(obj, p) for p in fitted_params if hasattr_safe(obj, p) | ||
} | ||
|
||
return fitted_param_dict |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,196 @@ | ||
# copyright: sktime developers, BSD-3-Clause License (see LICENSE file) | ||
"""Implements adapter for lifelines models.""" | ||
|
||
__all__ = ["_LifelinesAdapter"] | ||
__author__ = ["fkiraly"] | ||
|
||
from warnings import warn | ||
|
||
import numpy as np | ||
import pandas as pd | ||
|
||
from skpro.distributions.empirical import Empirical | ||
from skpro.survival.adapters._common import _clip_surv, _get_fitted_params_default_safe | ||
from skpro.utils.sklearn import prep_skl_df | ||
|
||
|
||
class _LifelinesAdapter: | ||
"""Mixin adapter class for lifelines models.""" | ||
|
||
_tags = { | ||
# packaging info | ||
# -------------- | ||
"authors": ["fkiraly"], | ||
"python_dependencies": ["lifelines"], | ||
"license_type": "permissive", | ||
# capability tags | ||
# --------------- | ||
"X_inner_mtype": "pd_DataFrame_Table", | ||
"y_inner_mtype": "pd_DataFrame_Table", | ||
"C_inner_mtype": "pd_DataFrame_Table", | ||
"capability:multioutput": False, | ||
} | ||
|
||
# defines the name of the attribute containing the lifelines estimator | ||
_estimator_attr = "_estimator" | ||
|
||
def _get_lifelines_class(self): | ||
"""Abstract method to get lifelines class. | ||
should import and return lifelines class | ||
""" | ||
# from lifelines import LifelinesClass | ||
# | ||
# return LifelinesClass | ||
raise NotImplementedError("abstract method") | ||
|
||
def _get_lifelines_object(self): | ||
"""Abstract method to initialize lifelines object. | ||
The default initializes result of _get_lifelines_class | ||
with self.get_params. | ||
""" | ||
cls = self._get_lifelines_class() | ||
return cls(**self.get_params()) | ||
|
||
def _init_lifelines_object(self): | ||
"""Abstract method to initialize lifelines object and set to _estimator_attr. | ||
The default writes the return of _get_lifelines_object to | ||
the attribute of self with name _estimator_attr | ||
""" | ||
cls = self._get_lifelines_object() | ||
setattr(self, self._estimator_attr, cls) | ||
return getattr(self, self._estimator_attr) | ||
|
||
def _get_extra_fit_args(self, X, y, C=None): | ||
"""Get extra arguments for the fit method. | ||
Parameters | ||
---------- | ||
X : pd.DataFrame | ||
Training features | ||
y: pd.DataFrame | ||
Training labels | ||
C: pd.DataFrame, optional (default=None) | ||
Censoring information for survival analysis. | ||
Returns | ||
------- | ||
dict | ||
Extra arguments for the fit method. | ||
""" | ||
return {} | ||
|
||
def _fit(self, X, y, C=None): | ||
"""Fit estimator training data. | ||
Parameters | ||
---------- | ||
X : pd.DataFrame | ||
Training features | ||
y: pd.DataFrame | ||
Training labels | ||
C: pd.DataFrame, optional (default=None) | ||
Censoring information for survival analysis. | ||
Returns | ||
------- | ||
self: reference to self | ||
Fitted estimator. | ||
""" | ||
lifelines_est = self._init_lifelines_object() | ||
|
||
# input conversion | ||
X = X.astype("float") # lifelines insists on float dtype | ||
X = prep_skl_df(X) | ||
|
||
if hasattr(self, "X_col_subset"): | ||
X = X[self.X_col_subset] | ||
|
||
to_concat = [X, y] | ||
|
||
if C is not None: | ||
C_col = 1 - C.copy() # lifelines uses 1 for uncensored, 0 for censored | ||
C_col.columns = ["__C"] | ||
to_concat.append(C_col) | ||
|
||
df = pd.concat(to_concat, axis=1) | ||
|
||
self._y_cols = y.columns # remember column names for later | ||
y_name = y.columns[0] | ||
|
||
fit_args = { | ||
"df": df, | ||
"duration_col": y_name, | ||
} | ||
if C is not None: | ||
fit_args["event_col"] = "__C" | ||
|
||
fit_args.update(self._get_extra_fit_args(X, y, C)) | ||
|
||
# fit lifelines estimator | ||
lifelines_est.fit(**fit_args) | ||
|
||
# write fitted params to self | ||
# some fitted parameters are properties and may raise exceptions | ||
# for example, AIC_ and AIC_partial_ of CoxPHFitter | ||
# to avoid this, we use a safe getter | ||
lifelines_fitted_params = _get_fitted_params_default_safe(lifelines_est) | ||
for k, v in lifelines_fitted_params.items(): | ||
setattr(self, f"{k}_", v) | ||
|
||
return self | ||
|
||
def _predict_proba(self, X): | ||
"""Predict_proba method adapter. | ||
Parameters | ||
---------- | ||
X : pd.DataFrame | ||
Features to predict on. | ||
Returns | ||
------- | ||
skpro Empirical distribution | ||
""" | ||
lifelines_est = getattr(self, self._estimator_attr) | ||
|
||
# input conversion | ||
X = X.astype("float") # lifelines insists on float dtype | ||
X = prep_skl_df(X) | ||
|
||
# predict on X | ||
lifelines_survf = lifelines_est.predict_survival_function(X) | ||
|
||
times = lifelines_survf.index | ||
|
||
nt = len(times) | ||
mi = pd.MultiIndex.from_product([X.index, range(nt)]).swaplevel() | ||
|
||
times_val = np.repeat(times, repeats=len(X)) | ||
times_df = pd.DataFrame(times_val, index=mi, columns=self._y_cols) | ||
|
||
lifelines_survf_t = np.transpose(lifelines_survf.values) | ||
_, lifelines_survf_t_diff, clipped = _clip_surv(lifelines_survf_t) | ||
|
||
if clipped: | ||
warn( | ||
f"Warning from {self.__class__.__name__}: " | ||
f"Interfaced lifelines class {lifelines_est.__class__.__name__} " | ||
"produced improper survival function predictions, i.e., " | ||
"not monotonically decreasing or not in [0, 1]. " | ||
"skpro has clipped the predictions to enforce proper range and " | ||
"valid predictive distributions. " | ||
"However, predictions may still be unreliable.", | ||
stacklevel=2, | ||
) | ||
|
||
weights = -lifelines_survf_t_diff.flatten() | ||
weights_df = pd.Series(weights, index=mi) | ||
|
||
dist = Empirical( | ||
spl=times_df, weights=weights_df, index=X.index, columns=self._y_cols | ||
) | ||
|
||
return dist |
Oops, something went wrong.