Add warn vs exception option to all detectors, fixes #68
kwinkunks committed Sep 23, 2023
1 parent 96ac82d commit 48efa8e
Showing 1 changed file with 84 additions and 23 deletions.
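In use, the new option looks something like this: a minimal sketch with invented data, relying only on the constructor signatures shown in the diff below. With the default warn=True the detectors keep issuing warnings; with warn=False the same condition raises a ValueError.

import warnings
import numpy as np
from sklearn.pipeline import make_pipeline
from redflag.sklearn import OutlierDetector

rng = np.random.default_rng(42)
X = rng.normal(size=(200, 3))
X[:5] += 20  # plant a few gross outliers

# Default behaviour: suspicious data is reported as a warning and passes through.
pipe = make_pipeline(OutlierDetector())
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    pipe.fit_transform(X)
for w in caught:
    print(w.message)

# New behaviour: warn=False turns the same check into an exception.
strict = make_pipeline(OutlierDetector(warn=False))
try:
    strict.fit_transform(X)
except ValueError as e:
    print("Raised:", e)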
107 changes: 84 additions & 23 deletions src/redflag/sklearn.py
@@ -52,25 +52,34 @@ def formatwarning(message, *args, **kwargs):

class BaseRedflagDetector(BaseEstimator, TransformerMixin):

def __init__(self, func, warning, **kwargs):
def __init__(self, func, message, warn=True, **kwargs):
self.func = lambda X: func(X, **kwargs)
self.warning = warning
self.message = message
self.warn = warn

def fit(self, X, y=None):
X = check_array(X)

positive = [i for i, feature in enumerate(X.T) if self.func(feature)]
if n := len(positive):
pos = ', '.join(str(i) for i in positive)
warnings.warn(f"🚩 Feature{'' if n == 1 else 's'} {pos} {'has' if n == 1 else 'have'} samples that {self.warning}.")
message = f"🚩 Feature{'' if n == 1 else 's'} {pos} {'has' if n == 1 else 'have'} samples that {self.message}."
if self.warn:
warnings.warn(message)
else:
raise ValueError(message)

if y is not None:
y_ = np.asarray(y)
if y_.ndim == 1:
y_ = y_.reshape(-1, 1)
for i, target in enumerate(y_.T):
if is_continuous(target) and self.func(target):
warnings.warn(f"🚩 Target {i} has samples that {self.warning}.")
message = f"🚩 Target {i} has samples that {self.message}."
if self.warn:
warnings.warn(message)
else:
raise ValueError(message)

return self
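Detectors built on this base class only need to pass the new flag through to super().__init__(). A hypothetical subclass (the has_negative check and the NegativeDetector name are invented for illustration) might look like this:

import numpy as np
from redflag.sklearn import BaseRedflagDetector

def has_negative(a):
    """Per-feature check: True if any value in the 1D array is negative."""
    return bool((np.asarray(a) < 0).any())

class NegativeDetector(BaseRedflagDetector):
    """Hypothetical detector: flag features containing negative values."""
    def __init__(self, warn=True):
        super().__init__(has_negative, "are negative", warn=warn)

X = np.array([[ 1.0, 2.0],
              [-1.0, 3.0],
              [ 4.0, 5.0]])

NegativeDetector().fit(X)                # warns: "🚩 Feature 0 has samples that are negative."
try:
    NegativeDetector(warn=False).fit(X)  # same condition, raised instead of warned
except ValueError as e:
    print(e)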

@@ -204,10 +213,11 @@ class MultivariateOutlierDetector(BaseEstimator, TransformerMixin):
[-0.55581573, -2.01881162],
[-0.90942756, 0.36922933]])
"""
def __init__(self, p=0.99, threshold=None, factor=1):
def __init__(self, p=0.99, threshold=None, factor=1, warn=True):
self.p = p if threshold is None else None
self.threshold = threshold
self.factor = factor
self.warn = warn

def fit(self, X, y=None):
return self
@@ -225,7 +235,11 @@ def transform(self, X, y=None):
outliers = has_outliers(X, p=self.p, threshold=self.threshold, factor=self.factor)

if outliers:
warnings.warn(f"🚩 Dataset has more multivariate outlier samples than expected.")
message = f"🚩 Dataset has more multivariate outlier samples than expected."
if self.warn:
warnings.warn(message)
else:
raise ValueError(message)

if (y is not None) and is_continuous(y):
if np.asarray(y).ndim == 1:
@@ -235,7 +249,11 @@ def transform(self, X, y=None):
y_ = y
kind = 'multivariate'
if has_outliers(y_, p=self.p, threshold=self.threshold, factor=self.factor):
warnings.warn(f"🚩 Target has more {kind} outlier samples than expected.")
message = f"🚩 Target has more {kind} outlier samples than expected."
if self.warn:
warnings.warn(message)
else:
raise ValueError(message)

return X

@@ -325,10 +343,11 @@ def transform(self, X, y=None):
positive = np.where(W > self.threshold)[0]
if n := positive.size:
pos = ', '.join(str(i) for i in positive)
message = f"🚩 Feature{'s' if n > 1 else ''} {pos} {'have distributions that are' if n > 1 else 'has a distribution that is'} different from training."
if self.warn:
warnings.warn(f"🚩 Feature{'s' if n > 1 else ''} {pos} {'have distributions that are' if n > 1 else 'has a distribution that is'} different from training.")
warnings.warn(message)
else:
raise ValueError(f"🚩 Feature{'s' if n > 1 else ''} {pos} {'have distributions that are' if n > 1 else 'has a distribution that is'} different from training.")
raise ValueError(message)

return X

@@ -354,17 +373,20 @@ def fit_transform(self, X, y=None):

class OutlierDetector(BaseEstimator, TransformerMixin):

def __init__(self, p=0.99, threshold=None, factor=1.0):
def __init__(self, p=0.99, threshold=None, factor=1.0, warn=True):
"""
Constructor for the class.
Args:
p (float): The confidence level.
threshold (float): The threshold for the Wasserstein distance.
factor (float): Multiplier for the expected outliers.
warn (bool): Whether to issue a warning (True) or raise an exception (False).
"""
self.threshold = threshold
self.p = p if threshold is None else None
self.factor = factor
self.warn = warn

def _actual_vs_expected(self, z, n, d):
"""
@@ -410,7 +432,11 @@ def fit(self, X, y=None):

self.outliers_, expected = self._actual_vs_expected(z, n, d)
if self.outliers_.size > expected:
warnings.warn(f"🚩 There are more outliers than expected in the training data ({self.outliers_.size} vs {expected}).")
message = f"🚩 There are more outliers than expected in the training data ({self.outliers_.size} vs {expected})."
if self.warn:
warnings.warn(message)
else:
raise ValueError(message)

return self

@@ -440,7 +466,11 @@ def transform(self, X, y=None):

actual, expected = self._actual_vs_expected(z, n, d)
if actual.size > expected:
warnings.warn(f"🚩 There are more outliers than expected in the data ({actual.size} vs {expected}).")
message = f"🚩 There are more outliers than expected in the data ({actual.size} vs {expected})."
if self.warn:
warnings.warn(message)
else:
raise ValueError(message)

return X

@@ -466,7 +496,7 @@ def fit_transform(self, X, y=None):

class ImbalanceDetector(BaseEstimator, TransformerMixin):

def __init__(self, method='id', threshold=0.4, classes=None):
def __init__(self, method='id', threshold=0.4, classes=None, warn=True):
"""
Constructor for the class.
@@ -495,6 +525,7 @@ def __init__(self, method='id', threshold=0.4, classes=None):
self.method = method
self.threshold = threshold
self.classes = classes
self.warn = warn

def fit(self, X, y=None):
"""
@@ -531,9 +562,17 @@ def fit(self, X, y=None):
imbalanced = (len(self.minority_classes_) > 0) and (imbalance > self.threshold)

if imbalanced and self.method == 'id':
warnings.warn(f"🚩 The labels are imbalanced by more than the threshold ({imbalance:0.3f} > {self.threshold:0.3f}). See self.minority_classes_ for the minority classes.")
if imbalanced and self.method == 'ir':
warnings.warn(f"🚩 The labels are imbalanced by more than the threshold ({imbalance:0.1f} > {self.threshold:0.1f}). See self.minority_classes_ for the minority classes.")
message = f"🚩 The labels are imbalanced by more than the threshold ({imbalance:0.3f} > {self.threshold:0.3f}). See self.minority_classes_ for the minority classes."
elif imbalanced and self.method == 'ir':
message = f"🚩 The labels are imbalanced by more than the threshold ({imbalance:0.1f} > {self.threshold:0.1f}). See self.minority_classes_ for the minority classes."
else:
message = None

if message is not None:
if self.warn:
warnings.warn(message)
else:
raise ValueError(message)

return self
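The same switch applies to the imbalance check. A short sketch with deliberately skewed labels (the data are invented; whether the check fires depends on the method and threshold):

import numpy as np
from redflag.sklearn import ImbalanceDetector

X = np.random.default_rng(0).normal(size=(100, 2))
y = np.array([0] * 90 + [1] * 10)   # heavily skewed labels

ImbalanceDetector().fit(X, y)        # warns if the imbalance exceeds the threshold

detector = ImbalanceDetector(warn=False)
try:
    detector.fit(X, y)               # same check, raised instead
except ValueError as e:
    print(e)
print(detector.minority_classes_)    # e.g. [1]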

@@ -554,7 +593,7 @@ def transform(self, X, y=None):

class ImbalanceComparator(BaseEstimator, TransformerMixin):

def __init__(self, method='id', threshold=0.4, min_class_diff=1, classes=None):
def __init__(self, method='id', threshold=0.4, min_class_diff=1, classes=None, warn=True):
"""
Args:
method (str): The method to use for imbalance detection. In general,
@@ -584,6 +623,7 @@ def __init__(self, method='id', threshold=0.4, min_class_diff=1, classes=None):
self.threshold = threshold
self.min_class_diff = min_class_diff
self.classes = classes
self.warn = warn

def fit(self, X, y=None):
"""
@@ -649,15 +689,27 @@ def transform(self, X, y=None):

# Check if there's a different *number* of minority classes.
if diff >= self.min_class_diff:
warnings.warn(f"🚩 There is a different number of minority classes ({len(min_classes)}) compared to the training data ({len(self.minority_classes_)}).")
message = f"🚩 There is a different number of minority classes ({len(min_classes)}) compared to the training data ({len(self.minority_classes_)})."
if self.warn:
warnings.warn(message)
else:
raise ValueError(message)

# Check if there's the same number but the minority classes have changed.
if set(min_classes) != set(self.minority_classes_):
warnings.warn(f"🚩 The minority classes ({', '.join(str(c) for c in set(min_classes))}) are different from those in the training data ({', '.join(str(c) for c in set(self.minority_classes_))}).")
message = f"🚩 The minority classes ({', '.join(str(c) for c in set(min_classes))}) are different from those in the training data ({', '.join(str(c) for c in set(self.minority_classes_))})."
if self.warn:
warnings.warn(message)
else:
raise ValueError(message)

# Check if the imbalance metric has changed.
if abs(imbalance - self.imbalance_) >= self.threshold:
warnings.warn(f"🚩 The imbalance metric ({imbalance}) is different from that of the training data ({self.imbalance_}).")
message = f"🚩 The imbalance metric ({imbalance}) is different from that of the training data ({self.imbalance_})."
if self.warn:
warnings.warn(message)
else:
raise ValueError(message)

return check_array(X)
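The comparator is fitted on the training labels and then checks new labels against them; calling the methods directly keeps this sketch short (the data are invented, and the checks fire only if the comparison exceeds the configured differences):

import numpy as np
from redflag.sklearn import ImbalanceComparator

X_train = np.random.default_rng(1).normal(size=(100, 2))
y_train = np.array([0] * 50 + [1] * 50)   # balanced labels at training time

X_new = np.random.default_rng(2).normal(size=(100, 2))
y_new = np.array([0] * 90 + [1] * 10)     # skewed labels later on

comp = ImbalanceComparator(warn=False)
comp.fit(X_train, y_train)
try:
    comp.transform(X_new, y_new)          # likely raises: the minority classes have changed
except ValueError as e:
    print(e)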

@@ -683,7 +735,7 @@ def fit_transform(self, X, y=None):

class ImportanceDetector(BaseEstimator, TransformerMixin):

def __init__(self, threshold=None, random_state=None):
def __init__(self, threshold=None, random_state=None, warn=True):
"""
Constructor for the class.
@@ -697,6 +749,7 @@ def __init__(self, threshold=None, random_state=None):

self.threshold = threshold
self.random_state = random_state
self.warn = warn

def fit(self, X, y=None):
"""
Expand All @@ -721,15 +774,23 @@ def fit(self, X, y=None):

if (m := len(most_important)) <= 2 and (m < M):
most_str = ', '.join(str(i) for i in sorted(most_important))
warnings.warn(f"🚩 Feature{'' if m == 1 else 's'} {most_str} {'has' if m == 1 else 'have'} very high importance; check for leakage.")
message = f"🚩 Feature{'' if m == 1 else 's'} {most_str} {'has' if m == 1 else 'have'} very high importance; check for leakage."
if self.warn:
warnings.warn(message)
else:
raise ValueError(message)
return self

# Don't do this check if there were high-importance features (infer that the others are low.)
least_important = least_important_features(importances, threshold=self.threshold)

if (m := len(least_important)) > 0:
least_str = ', '.join(str(i) for i in sorted(least_important))
warnings.warn(f"🚩 Feature{'' if m == 1 else 's'} {least_str} {'has' if m == 1 else 'have'} low importance; check for relevance.")
message = f"🚩 Feature{'' if m == 1 else 's'} {least_str} {'has' if m == 1 else 'have'} low importance; check for relevance."
if self.warn:
warnings.warn(message)
else:
raise ValueError(message)

return self
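And the same pattern for the importance checks. In this sketch the first column is a copy of the target, so the detector should flag it as suspiciously important; the data are invented and the exact outcome depends on the model fitted internally:

import numpy as np
from redflag.sklearn import ImportanceDetector

rng = np.random.default_rng(0)
y = rng.integers(0, 2, size=200)
X = np.column_stack([y, rng.normal(size=(200, 2))])   # feature 0 leaks the target

ImportanceDetector().fit(X, y)             # warns about feature 0's very high importance
try:
    ImportanceDetector(warn=False).fit(X, y)
except ValueError as e:
    print(e)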

