Skip to content

Commit

Permalink
Merge pull request #42 from biolab/dev
Browse files Browse the repository at this point in the history
[ENH] use new interface for testing and scoring learners
  • Loading branch information
JakaKokosar authored Mar 23, 2022
2 parents 2416ce1 + 7ec99f3 commit 15ce516
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 18 deletions.
10 changes: 6 additions & 4 deletions orangecontrib/survival_analysis/evaluation/scoring.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from lifelines.utils import concordance_index
from Orange.data import DiscreteVariable, ContinuousVariable
from Orange.data import DiscreteVariable, ContinuousVariable, Domain
from Orange.evaluation.scoring import Score
from orangecontrib.survival_analysis.widgets.data import get_survival_endpoints
from orangecontrib.survival_analysis.widgets.data import get_survival_endpoints, contains_survival_endpoints

__all__ = ['ConcordanceIndex']

Expand All @@ -11,8 +11,10 @@ class SurvivalScorer(Score, abstract=True):
ContinuousVariable,
DiscreteVariable,
)
is_built_in = False
problem_type = 'time_to_event'

@staticmethod
def is_compatible(domain: Domain) -> bool:
    """Return what ``contains_survival_endpoints`` reports for ``domain``."""
    has_endpoints = contains_survival_endpoints(domain)
    return has_endpoints


class ConcordanceIndex(SurvivalScorer):
Expand Down
10 changes: 5 additions & 5 deletions orangecontrib/survival_analysis/modeling/cox.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
from orangecontrib.survival_analysis.widgets.data import (
contains_survival_endpoints,
get_survival_endpoints,
MISSING_SURVIVAL_DATA,
)


def to_data_frame(table: Table) -> pd.DataFrame:
columns = table.domain.attributes + table.domain.class_vars
df = pd.DataFrame({col.name: table.get_column_view(col)[0] for col in columns})
Expand Down Expand Up @@ -47,14 +47,14 @@ def __call__(self, data, ret=Model.Value):
class CoxRegressionLearner(Learner):
__returns__ = CoxRegressionModel
supports_multiclass = True
learner_adequacy_err_msg = 'Survival variables expected. Use As Survival Data widget.'

# Build the learner; `preprocessors` is handed to the base-class initializer.
# Extra keyword arguments are accepted and end up inside `self.params` only.
def __init__(self, preprocessors=None, **kwargs):
# vars() without an argument behaves like locals(): this snapshots the
# constructor's local namespace (self, preprocessors, kwargs, ...).
# Keep it as the first statement -- any local defined before it would also
# be captured in `params`.
self.params = vars()
super().__init__(preprocessors=preprocessors)

def check_learner_adequacy(self, domain):
return len(domain.class_vars) == 2
def incompatibility_reason(self, domain):
    """Explain why ``domain`` cannot be used with this learner.

    Returns ``MISSING_SURVIVAL_DATA`` when ``contains_survival_endpoints``
    rejects the domain, and ``None`` when the domain is acceptable.
    """
    if contains_survival_endpoints(domain):
        # Spell out the "compatible" result instead of falling off the end,
        # since callers compare the return value against None.
        return None
    return MISSING_SURVIVAL_DATA

def fit_storage(self, data):
    """Fit on ``data`` by delegating to ``self.fit`` and return its result."""
    fitted = self.fit(data)
    return fitted
Expand All @@ -67,7 +67,7 @@ def _fit_model(self, data):

def fit(self, data):
if not contains_survival_endpoints(data.domain):
raise ValueError(self.learner_adequacy_err_msg)
raise ValueError(MISSING_SURVIVAL_DATA)
time_var, event_var = get_survival_endpoints(data.domain)

df = to_data_frame(data)
Expand Down
18 changes: 11 additions & 7 deletions orangecontrib/survival_analysis/widgets/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,15 @@
from Orange.data import Table, Domain, Variable


TIME_VAR = 'time'
EVENT_VAR = 'event'
TIME_TO_EVENT_VAR = '_time_to_event_var'
TIME_VAR: str = 'time'
EVENT_VAR: str = 'event'
TIME_TO_EVENT_VAR: str = '_time_to_event_var'

# Error/Warning messages related to survival data tables.
# Shown as a warning when rows containing missing values are dropped.
MISSING_ROWS: str = 'Rows with missing values detected. They will be omitted.'
# Shown/raised when a table carries no survival endpoints.
# One fragment per line: adjacent literals on a single line read like a
# missing comma or operator.
MISSING_SURVIVAL_DATA: str = (
    'No survival data detected. '
    'Use the "As Survival Data" widget or consult the documentation.'
)


def contains_survival_endpoints(domain: Domain):
Expand Down Expand Up @@ -35,15 +41,13 @@ def get_survival_endpoints(domain: Domain) -> Tuple[Optional[Variable], Optional

def check_survival_data(f):
"""Check for survival data."""
error_msg = 'No survival data detected. Use the "As Survival Data" widget or consult the documentation.'
warning_msg = 'Rows with missing values detected. They will be omitted.'

@wraps(f)
def wrapper(widget, data: Table, *args, **kwargs):
widget.Error.add_message('missing_survival_data', UnboundMsg(error_msg))
widget.Error.add_message('missing_survival_data', UnboundMsg(MISSING_SURVIVAL_DATA))
widget.Error.missing_survival_data.clear()

widget.Warning.add_message('missing_values_detected', UnboundMsg(warning_msg))
widget.Warning.add_message('missing_values_detected', UnboundMsg(MISSING_ROWS))
widget.Warning.missing_values_detected.clear()

if data is not None and isinstance(data, Table):
Expand Down
5 changes: 3 additions & 2 deletions orangecontrib/survival_analysis/widgets/owcoxregression.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,9 @@ def check_data(self):
self.Error.sparse_not_supported.clear()
if self.data is not None and self.learner is not None:
self.Error.data_error.clear()
if not self.learner.check_learner_adequacy(self.data.domain):
self.Error.data_error(self.learner.learner_adequacy_err_msg)
incompatibility_reason = self.learner.incompatibility_reason(self.data.domain)
if incompatibility_reason is not None:
self.Error.data_error(incompatibility_reason)
elif not len(self.data):
self.Error.data_error("Dataset is empty.")
elif self.data.X.size == 0:
Expand Down

0 comments on commit 15ce516

Please sign in to comment.