diff --git a/src/redflag/sklearn.py b/src/redflag/sklearn.py index b9da9ec..f755afe 100644 --- a/src/redflag/sklearn.py +++ b/src/redflag/sklearn.py @@ -52,9 +52,10 @@ def formatwarning(message, *args, **kwargs): class BaseRedflagDetector(BaseEstimator, TransformerMixin): - def __init__(self, func, warning, **kwargs): + def __init__(self, func, message, warn=True, **kwargs): self.func = lambda X: func(X, **kwargs) - self.warning = warning + self.message = message + self.warn = warn def fit(self, X, y=None): X = check_array(X) @@ -62,7 +63,11 @@ def fit(self, X, y=None): positive = [i for i, feature in enumerate(X.T) if self.func(feature)] if n := len(positive): pos = ', '.join(str(i) for i in positive) - warnings.warn(f"🚩 Feature{'' if n == 1 else 's'} {pos} {'has' if n == 1 else 'have'} samples that {self.warning}.") + message = f"🚩 Feature{'' if n == 1 else 's'} {pos} {'has' if n == 1 else 'have'} samples that {self.message}." + if self.warn: + warnings.warn(message) + else: + raise ValueError(message) if y is not None: y_ = np.asarray(y) @@ -70,7 +75,11 @@ def fit(self, X, y=None): y_ = y_.reshape(-1, 1) for i, target in enumerate(y_.T): if is_continuous(target) and self.func(target): - warnings.warn(f"🚩 Target {i} has samples that {self.warning}.") + message = f"🚩 Target {i} has samples that {self.message}." + if self.warn: + warnings.warn(message) + else: + raise ValueError(message) return self @@ -204,10 +213,11 @@ class MultivariateOutlierDetector(BaseEstimator, TransformerMixin): [-0.55581573, -2.01881162], [-0.90942756, 0.36922933]]) """ - def __init__(self, p=0.99, threshold=None, factor=1): + def __init__(self, p=0.99, threshold=None, factor=1, warn=True): self.p = p if threshold is None else None self.threshold = threshold self.factor = factor + self.warn = warn def fit(self, X, y=None): return self @@ -225,7 +235,11 @@ def transform(self, X, y=None): outliers = has_outliers(X, p=self.p, threshold=self.threshold, factor=self.factor) if outliers: - warnings.warn(f"🚩 Dataset has more multivariate outlier samples than expected.") + message = f"🚩 Dataset has more multivariate outlier samples than expected." + if self.warn: + warnings.warn(message) + else: + raise ValueError(message) if (y is not None) and is_continuous(y): if np.asarray(y).ndim == 1: @@ -235,7 +249,11 @@ def transform(self, X, y=None): y_ = y kind = 'multivariate' if has_outliers(y_, p=self.p, threshold=self.threshold, factor=self.factor): - warnings.warn(f"🚩 Target has more {kind} outlier samples than expected.") + message = f"🚩 Target has more {kind} outlier samples than expected." + if self.warn: + warnings.warn(message) + else: + raise ValueError(message) return X @@ -325,10 +343,11 @@ def transform(self, X, y=None): positive = np.where(W > self.threshold)[0] if n := positive.size: pos = ', '.join(str(i) for i in positive) + message = f"🚩 Feature{'s' if n > 1 else ''} {pos} {'have distributions that are' if n > 1 else 'has a distribution that is'} different from training." if self.warn: - warnings.warn(f"🚩 Feature{'s' if n > 1 else ''} {pos} {'have distributions that are' if n > 1 else 'has a distribution that is'} different from training.") + warnings.warn(message) else: - raise ValueError(f"🚩 Feature{'s' if n > 1 else ''} {pos} {'have distributions that are' if n > 1 else 'has a distribution that is'} different from training.") + raise ValueError(message) return X @@ -354,17 +373,20 @@ def fit_transform(self, X, y=None): class OutlierDetector(BaseEstimator, TransformerMixin): - def __init__(self, p=0.99, threshold=None, factor=1.0): + def __init__(self, p=0.99, threshold=None, factor=1.0, warn=True): """ Constructor for the class. Args: p (float): The confidence level. threshold (float): The threshold for the Wasserstein distance. + factor (float): Multiplier for the expected outliers. + warn (bool): Whether to raise a warning or raise an error. """ self.threshold = threshold self.p = p if threshold is None else None self.factor = factor + self.warn = warn def _actual_vs_expected(self, z, n, d): """ @@ -410,7 +432,11 @@ def fit(self, X, y=None): self.outliers_, expected = self._actual_vs_expected(z, n, d) if self.outliers_.size > expected: - warnings.warn(f"🚩 There are more outliers than expected in the training data ({self.outliers_.size} vs {expected}).") + message = f"🚩 There are more outliers than expected in the training data ({self.outliers_.size} vs {expected})." + if self.warn: + warnings.warn(message) + else: + raise ValueError(message) return self @@ -440,7 +466,11 @@ def transform(self, X, y=None): actual, expected = self._actual_vs_expected(z, n, d) if actual.size > expected: - warnings.warn(f"🚩 There are more outliers than expected in the data ({actual.size} vs {expected}).") + message = f"🚩 There are more outliers than expected in the data ({actual.size} vs {expected})." + if self.warn: + warnings.warn(message) + else: + raise ValueError(message) return X @@ -466,7 +496,7 @@ def fit_transform(self, X, y=None): class ImbalanceDetector(BaseEstimator, TransformerMixin): - def __init__(self, method='id', threshold=0.4, classes=None): + def __init__(self, method='id', threshold=0.4, classes=None, warn=True): """ Constructor for the class. @@ -495,6 +525,7 @@ def __init__(self, method='id', threshold=0.4, classes=None): self.method = method self.threshold = threshold self.classes = classes + self.warn = warn def fit(self, X, y=None): """ @@ -531,9 +562,17 @@ def fit(self, X, y=None): imbalanced = (len(self.minority_classes_) > 0) and (imbalance > self.threshold) if imbalanced and self.method == 'id': - warnings.warn(f"🚩 The labels are imbalanced by more than the threshold ({imbalance:0.3f} > {self.threshold:0.3f}). See self.minority_classes_ for the minority classes.") - if imbalanced and self.method == 'ir': - warnings.warn(f"🚩 The labels are imbalanced by more than the threshold ({imbalance:0.1f} > {self.threshold:0.1f}). See self.minority_classes_ for the minority classes.") + message = f"🚩 The labels are imbalanced by more than the threshold ({imbalance:0.3f} > {self.threshold:0.3f}). See self.minority_classes_ for the minority classes." + elif imbalanced and self.method == 'ir': + message = f"🚩 The labels are imbalanced by more than the threshold ({imbalance:0.1f} > {self.threshold:0.1f}). See self.minority_classes_ for the minority classes." + else: + message = None + + if message is not None: + if self.warn: + warnings.warn(message) + else: + raise ValueError(message) return self @@ -554,7 +593,7 @@ def transform(self, X, y=None): class ImbalanceComparator(BaseEstimator, TransformerMixin): - def __init__(self, method='id', threshold=0.4, min_class_diff=1, classes=None): + def __init__(self, method='id', threshold=0.4, min_class_diff=1, classes=None, warn=True): """ Args: method (str): The method to use for imbalance detection. In general, @@ -584,6 +623,7 @@ def __init__(self, method='id', threshold=0.4, min_class_diff=1, classes=None): self.threshold = threshold self.min_class_diff = min_class_diff self.classes = classes + self.warn = warn def fit(self, X, y=None): """ @@ -649,15 +689,27 @@ def transform(self, X, y=None): # Check if there's a different *number* of minority classes. if diff >= self.min_class_diff: - warnings.warn(f"🚩 There is a different number of minority classes ({len(min_classes)}) compared to the training data ({len(self.minority_classes_)}).") + message = f"🚩 There is a different number of minority classes ({len(min_classes)}) compared to the training data ({len(self.minority_classes_)})." + if self.warn: + warnings.warn(message) + else: + raise ValueError(message) # Check if there's the same number but the minority classes have changed. if set(min_classes) != set(self.minority_classes_): - warnings.warn(f"🚩 The minority classes ({', '.join(str(c) for c in set(min_classes))}) are different from those in the training data ({', '.join(str(c) for c in set(self.minority_classes_))}).") + message = f"🚩 The minority classes ({', '.join(str(c) for c in set(min_classes))}) are different from those in the training data ({', '.join(str(c) for c in set(self.minority_classes_))})." + if self.warn: + warnings.warn(message) + else: + raise ValueError(message) # Check if the imbalance metric has changed. if abs(imbalance - self.imbalance_) >= self.threshold: - warnings.warn(f"🚩 The imbalance metric ({imbalance}) is different from that of the training data ({self.imbalance_}).") + message = f"🚩 The imbalance metric ({imbalance}) is different from that of the training data ({self.imbalance_})." + if self.warn: + warnings.warn(message) + else: + raise ValueError(message) return check_array(X) @@ -683,7 +735,7 @@ def fit_transform(self, X, y=None): class ImportanceDetector(BaseEstimator, TransformerMixin): - def __init__(self, threshold=None, random_state=None): + def __init__(self, threshold=None, random_state=None, warn=True): """ Constructor for the class. @@ -697,6 +749,7 @@ def __init__(self, threshold=None, random_state=None): self.threshold = threshold self.random_state = random_state + self.warn = warn def fit(self, X, y=None): """ @@ -721,7 +774,11 @@ def fit(self, X, y=None): if (m := len(most_important)) <= 2 and (m < M): most_str = ', '.join(str(i) for i in sorted(most_important)) - warnings.warn(f"🚩 Feature{'' if m == 1 else 's'} {most_str} {'has' if m == 1 else 'have'} very high importance; check for leakage.") + message = f"🚩 Feature{'' if m == 1 else 's'} {most_str} {'has' if m == 1 else 'have'} very high importance; check for leakage." + if self.warn: + warnings.warn(message) + else: + raise ValueError(message) return self # Don't do this check if there were high-importance features (infer that the others are low.) @@ -729,7 +786,11 @@ def fit(self, X, y=None): if (m := len(least_important)) > 0: least_str = ', '.join(str(i) for i in sorted(least_important)) - warnings.warn(f"🚩 Feature{'' if m == 1 else 's'} {least_str} {'has' if m == 1 else 'have'} low importance; check for relevance.") + message = f"🚩 Feature{'' if m == 1 else 's'} {least_str} {'has' if m == 1 else 'have'} low importance; check for relevance." + if self.warn: + warnings.warn(message) + else: + raise ValueError(message) return self