diff --git a/bluecast/monitoring/data_monitoring.py b/bluecast/monitoring/data_monitoring.py index 3cfc4003..8ed6bbf2 100644 --- a/bluecast/monitoring/data_monitoring.py +++ b/bluecast/monitoring/data_monitoring.py @@ -26,11 +26,12 @@ class DataDrift: This is suitable for batch models and not recommended for online models. """ - def __init__(self): + def __init__(self, random_state=25): self.kolmogorov_smirnov_flags: Dict[str, bool] = {} self.population_stability_index_values: Dict[str, float] = {} self.population_stability_index_flags: Dict[str, Any] = {} self.adversarial_auc_score: float = 0.0 + self.random_state = random_state def kolmogorov_smirnov_test( self, @@ -222,7 +223,11 @@ def qqplot_two_samples( plt.close() def adversarial_validation( - self, df: pd.DataFrame, df_new: pd.DataFrame, cat_columns: Optional[List] + self, + df: pd.DataFrame, + df_new: pd.DataFrame, + cat_columns: Optional[List], + sample_to_same_size: bool = True, ) -> float: """ Perform adversarial validation to check if the new data is similar to the training data. @@ -232,8 +237,15 @@ def adversarial_validation( :param df: Baseline DataFrame that is the point of comparison. :param df_new: New DataFrame to compare against the baseline. :param cat_columns: (Optional) List with names of categorical columns. + :param sample_to_same_size: If any of the two DataFrames is larger, subsample it to not impact AUC score due + to imbalance. :return: Auc score that indicates similarity. """ + if sample_to_same_size: + min_size = min(len(df.index), len(df_new.index)) + df = df.sample(n=min_size, random_state=self.random_state) + df_new = df_new.sample(n=min_size, random_state=self.random_state) + # add the train/test labels df["AV_label"] = 0 df_new["AV_label"] = 1 diff --git a/dist/bluecast-1.6.2-py3-none-any.whl b/dist/bluecast-1.6.2-py3-none-any.whl index 5ad36e7b..b03ad9f9 100644 Binary files a/dist/bluecast-1.6.2-py3-none-any.whl and b/dist/bluecast-1.6.2-py3-none-any.whl differ diff --git a/dist/bluecast-1.6.2.tar.gz b/dist/bluecast-1.6.2.tar.gz index 5afffef4..6243331a 100644 Binary files a/dist/bluecast-1.6.2.tar.gz and b/dist/bluecast-1.6.2.tar.gz differ