Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] log-normal AFT model from lifelines #258

Merged
merged 4 commits into from
Apr 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/api_reference/survival.rst
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ Accelerated failure time models
:toctree: auto_generated/
:template: class.rst

AFTLogNormal
AFTWeibull

Generalized additive survival models
Expand Down
3 changes: 2 additions & 1 deletion skpro/survival/aft/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Module containing accelerated failure time models."""
# copyright: skpro developers, BSD-3-Clause License (see LICENSE file)

__all__ = ["AFTWeibull"]
__all__ = ["AFTLogNormal", "AFTWeibull"]

from skpro.survival.aft._aft_lifelines_lognormal import AFTLogNormal
from skpro.survival.aft._aft_lifelines_weibull import AFTWeibull
209 changes: 209 additions & 0 deletions skpro/survival/aft/_aft_lifelines_lognormal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
"""Interface adapter to lifelines Log-Normal AFT model."""
# copyright: skpro developers, BSD-3-Clause License (see LICENSE file)

__author__ = ["fkiraly"]

import numpy as np

from skpro.distributions.lognormal import LogNormal
from skpro.survival.adapters.lifelines import _LifelinesAdapter
from skpro.survival.base import BaseSurvReg


class AFTLogNormal(_LifelinesAdapter, BaseSurvReg):
r"""Log-Normal AFT model, from lifelines.

Direct interface to ``lifelines.fitters.LogNormalAFTFitter``,
by ``CamDavidsonPilon``.

This class implements a Log-Normal AFT model. The model has parametric form, with
:math:`\mu(x) = \exp\left(\beta_0 + \beta_1x_1 + ... + \beta_n x_n \right)`,
and optionally,
:math:`\sigma(y) = \exp\left(\alpha_0 + \alpha_1 y_1 + ... + \alpha_m y_m \right)`,

with predictive distribution being Log-Normal, with
mean :math:`\mu(x)` and standard deviation :math:`\sigma(y)`.

Parameters
----------
mu_cols: pd.Index or coercible, optional, default=None
Columns of the input data frame to be used as covariates for
the mean parameter :math:`\mu`.
If None, all columns are used.

sd_cols: string "all", pd.Index or coercible, optional, default=None
Columns of the input data frame to be used as covariates for
the standard deviation parameter :math:`\sigma`.
If None, no covariates are used, the standard deviation parameter
is estimated as a constant. If "all", all columns are used.

fit_intercept: boolean, optional (default=True)
Whether to fit an intercept term in the model.

alpha: float, optional (default=0.05)
the level in the confidence intervals around the estimated survival function,
for computation of ``confidence_intervals_`` fitted parameter.

penalizer: float or array, optional (default=0.0)
the penalizer coefficient to the size of the coefficients.
See ``l1_ratio``. Must be equal to or greater than 0.
Alternatively, penalizer is an array equal in size to the number of parameters,
with penalty coefficients for specific variables. For
example, ``penalizer=0.01 * np.ones(p)`` is the same as ``penalizer=0.01``

l1_ratio: float, optional (default=0.0)
how much of the penalizer should be attributed to an l1 penalty
(otherwise an l2 penalty). The penalty function looks like
``penalizer * l1_ratio * ||w||_1 + 0.5 * penalizer * (1 - l1_ratio) * ||w||^2_2`` # noqa E501

Attributes
----------
params_ : DataFrame
The estimated coefficients
confidence_intervals_ : DataFrame
The lower and upper confidence intervals for the coefficients
durations: Series
The event_observed variable provided
event_observed: Series
The event_observed variable provided
weights: Series
The event_observed variable provided
variance_matrix_ : DataFrame
The variance matrix of the coefficients
standard_errors_: Series
the standard errors of the estimates
score_: float
the concordance index of the model.
"""

_tags = {"authors": ["CamDavidsonPilon", "fkiraly"]}
# CamDavidsonPilon, credit for interfaced estimator

def __init__(
self,
mu_cols=None,
sd_cols=None,
fit_intercept: bool = True,
alpha: float = 0.05,
penalizer: float = 0.0,
l1_ratio: float = 0.0,
):
self.mu_cols = mu_cols
self.sd_cols = sd_cols
self.alpha = alpha
self.penalizer = penalizer
self.l1_ratio = l1_ratio
self.fit_intercept = fit_intercept

super().__init__()

if mu_cols is not None:
self.X_col_subset = mu_cols

def _get_lifelines_class(self):
"""Getter of the lifelines class to be used for the adapter."""
from lifelines.fitters.weibull_aft_fitter import WeibullAFTFitter

return WeibullAFTFitter

def _get_lifelines_object(self):
"""Abstract method to initialize lifelines object.

The default initializes result of _get_lifelines_class
with self.get_params.
"""
cls = self._get_lifelines_class()
params = self.get_params()
params.pop("mu_cols", None)
params.pop("sd_cols", None)
return cls(**params)

def _add_extra_fit_args(self, X, y, C=None):
"""Get extra arguments for the fit method.

Parameters
----------
X : pd.DataFrame
Training features
y: pd.DataFrame
Training labels
C: pd.DataFrame, optional (default=None)
Censoring information for survival analysis.
fit_args: dict, optional (default=None)
Existing arguments for the fit method, from the adapter.

Returns
-------
dict
Extra arguments for the fit method.
"""
if self.mu_cols is not None:
if self.mu_cols == "all":
return {"ancillary": True}
else:
return {"ancillary": X[self.mu_cols]}
else:
return {}

def _predict_proba(self, X):
"""Predict_proba method adapter.

Parameters
----------
X : pd.DataFrame
Features to predict on.

Returns
-------
skpro Empirical distribution
"""
if self.sd_cols == "all":
ancillary = X
elif self.sd_cols is not None:
ancillary = X[self.sd_cols]
else:
ancillary = None

if self.mu_cols is not None:
df = X[self.mu_cols]
else:
df = X

lifelines_est = getattr(self, self._estimator_attr)
ll_pred_proba = lifelines_est._prep_inputs_for_prediction_and_return_scores

mu, sigma = ll_pred_proba(df, ancillary)
mu = np.expand_dims(mu, axis=1)
sigma = np.expand_dims(sigma, axis=1)

dist = LogNormal(mu=mu, sigma=sigma, index=X.index, columns=self._y_cols)
return dist

@classmethod
def get_test_params(cls, parameter_set="default"):
"""Return testing parameter settings for the estimator.

Parameters
----------
parameter_set : str, default="default"
Name of the set of test parameters to return, for use in tests. If no
special parameters are defined for a value, will return `"default"` set.

Returns
-------
params : dict or list of dict, default = {}
Parameters to create testing instances of the class
Each dict are parameters to construct an "interesting" test instance, i.e.,
`MyClass(**params)` or `MyClass(**params[i])` creates a valid test instance.
`create_test_instance` uses the first (or only) dictionary in `params`
"""
params1 = {}

params2 = {
"sd_cols": "all",
"fit_intercept": False,
"alpha": 0.1,
"penalizer": 0.1,
"l1_ratio": 0.1,
}
return [params1, params2]
Loading