From 26131ff7b1ab4a29060034eca3b0c426224c5b65 Mon Sep 17 00:00:00 2001 From: Guillaume VIGNAL Date: Tue, 17 Sep 2024 16:38:43 +0200 Subject: [PATCH] Add interaction plot in the report --- .github/workflows/main.yml | 2 +- README.md | 2 +- eurybia/core/smartdrift.py | 9 +++++---- eurybia/core/smartplotter.py | 2 +- eurybia/report/generation.py | 16 ++++++++++++++-- eurybia/report/project_report.py | 12 ++++++------ eurybia/report/properties.py | 11 ++++++++--- pyproject.toml | 1 + 8 files changed, 37 insertions(+), 18 deletions(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 0fd2009..1c94806 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -8,7 +8,7 @@ jobs: strategy: max-parallel: 1 matrix: - python-version: ["3.9", "3.10", "3.11"] + python-version: ["3.9", "3.10", "3.11", "3.12"] steps: - uses: actions/checkout@v2 - name: Set up Python ${{ matrix.python-version }} diff --git a/README.md b/README.md index 89dbf5b..85e594f 100644 --- a/README.md +++ b/README.md @@ -107,7 +107,7 @@ sd.generate_report( ## 🛠 Installation -Eurybia is intended to work with Python versions 3.9 to 3.11. Installation can be done with pip: +Eurybia is intended to work with Python versions 3.9 to 3.12. Installation can be done with pip: ``` pip install eurybia diff --git a/eurybia/core/smartdrift.py b/eurybia/core/smartdrift.py index 3387c8e..b6d85dc 100644 --- a/eurybia/core/smartdrift.py +++ b/eurybia/core/smartdrift.py @@ -10,7 +10,6 @@ import shutil import tempfile from pathlib import Path -from typing import Dict import catboost import pandas as pd @@ -199,12 +198,12 @@ def __init__( def compile( self, full_validation=False, - ignore_cols: list = list(), + ignore_cols: list = None, sampling=True, sample_size=100000, datadrift_file=None, date_compile_auc=None, - hyperparameter: Dict = catboost_hyperparameter_init.copy(), + hyperparameter: dict = catboost_hyperparameter_init.copy(), attr_importance="feature_importances_", ): r""" @@ -237,6 +236,8 @@ def compile( >>> SD.compile() """ + if ignore_cols is None: + ignore_cols = [] if datadrift_file is not None: self.datadrift_file = datadrift_file if hyperparameter is not None: @@ -468,7 +469,7 @@ def _analyze_consistency(self, full_validation=False, ignore_cols: list = list() and will not be analyzed: \n {err_dtypes}""" ) # Feature values - err_mods: Dict[str, Dict] = {} + err_mods: dict[str, dict] = {} if full_validation is True: invalid_cols = ignore_cols + new_cols + removed_cols + err_dtypes for column in self.df_baseline.columns: diff --git a/eurybia/core/smartplotter.py b/eurybia/core/smartplotter.py index 1fcd83e..930ef71 100644 --- a/eurybia/core/smartplotter.py +++ b/eurybia/core/smartplotter.py @@ -597,7 +597,7 @@ def generate_modeldrift_data( if data_modeldrift is None: data_modeldrift = self.smartdrift.data_modeldrift if data_modeldrift is None: - raise Exception( + raise ValueError( """You should run the add_data_modeldrift method before displaying model drift performances. For more information see the documentation""" ) diff --git a/eurybia/report/generation.py b/eurybia/report/generation.py index 82476ca..ea62d08 100644 --- a/eurybia/report/generation.py +++ b/eurybia/report/generation.py @@ -232,7 +232,6 @@ def get_data_drift_panel(dr: DriftReport) -> pn.Column: pn.pane.Markdown("### Univariate analysis"), pn.pane.Markdown(report_text["Data drift"]["07"]), ] - contribution_figures, contribution_labels = dr.display_model_contribution() distribution_figures, labels, distribution_tables = dr.display_dataset_analysis(global_analysis=False)["univariate"] distribution_plots_blocks = get_select_plots( @@ -262,6 +261,9 @@ def get_data_drift_panel(dr: DriftReport) -> pn.Column: max_gauge=0.2, ) blocks += [pn.pane.Plotly(js_fig)] + + contribution_figures, contribution_labels = dr.display_model_contribution() + blocks += [ pn.pane.Markdown("## Feature contribution on data drift's detection"), pn.pane.Markdown(report_text["Data drift"]["09"]), @@ -273,14 +275,24 @@ def get_data_drift_panel(dr: DriftReport) -> pn.Column: figures=contribution_figures, ) blocks += contribution_plots_blocks + + fig_02 = dr.explainer.plot.top_interactions_plot(nb_top_interactions=10) + fig_02.update_layout(width=1240) + blocks += [ + pn.pane.Markdown("## Feature interaction on data drift's detection"), + pn.pane.Markdown(report_text["Data drift"]["10"]), + pn.pane.Plotly(fig_02), + ] + if dr.smartdrift.historical_auc is not None: fig = dr.smartdrift.plot.generate_historical_datadrift_metric() fig.update_layout(width=1240) blocks += [ pn.pane.Markdown("## Historical Data drift"), - pn.pane.Markdown(report_text["Data drift"]["10"]), + pn.pane.Markdown(report_text["Data drift"]["11"]), pn.pane.Plotly(fig), ] + return pn.Column(*blocks, name="Data drift", styles=dict(display="none"), css_classes=["data-drift"]) diff --git a/eurybia/report/project_report.py b/eurybia/report/project_report.py index e7b838c..624df0c 100644 --- a/eurybia/report/project_report.py +++ b/eurybia/report/project_report.py @@ -5,7 +5,7 @@ import copy import logging import os -from typing import Dict, Optional, Union +from typing import Optional, Union import jinja2 import pandas as pd @@ -36,11 +36,11 @@ class DriftReport: Attributes ---------- smartdrift: object - SmartDrift object + SmartDrift object explainer : shapash.explainer.smart_explainer.SmartExplainer - A shapash SmartExplainer object that has already be compiled + A shapash SmartExplainer object that has already be compiled title_story : str - Report title + Report title metadata : dict Information about the project (author, description, ...) df_predict : pd.DataFrame @@ -56,7 +56,7 @@ def __init__( smartdrift: SmartDrift, explainer: SmartExplainer, project_info_file: Optional[str] = None, - config_report: Optional[Dict] = None, + config_report: Optional[dict] = None, ): """ Parameters @@ -253,7 +253,7 @@ def display_model_contribution(self): c_list = self.explainer._classes if multiclass else [1] # list just used for multiclass plot_list = [] labels = [] - for index_label, label in enumerate(c_list): # Iterating over all labels in multiclass case + for label in c_list: # Iterating over all labels in multiclass case for feature in self.features_imp_list: fig = self.explainer.plot.contribution_plot(feature, label=label, max_points=200) plot_list.append(fig) diff --git a/eurybia/report/properties.py b/eurybia/report/properties.py index 06e8c66..a8ae7d3 100644 --- a/eurybia/report/properties.py +++ b/eurybia/report/properties.py @@ -1,6 +1,6 @@ -from typing import Any, Dict +from typing import Any -report_text: Dict[str, Any] = { +report_text: dict[str, Any] = { "Index": { "01": "- Project information: report context and information", "02": "- Consistency Analysis: highlighting differences between the two datasets", @@ -77,7 +77,12 @@ "This representation constitutes a support to understand the drift " "when the analysis of the dataset is unclear." ), - "10": ("Line chart showing the metrics evolution of the datadrift classifier over the given period of time."), + "10": ( + "This graph represents the interactions between couple of variable to the data drift detection. " + "This representation constitutes a support to understand the drift " + "when the analysis of the dataset is unclear." + ), + "11": ("Line chart showing the metrics evolution of the datadrift classifier over the given period of time."), }, "Model drift": { "01": ( diff --git a/pyproject.toml b/pyproject.toml index ca1734e..005aa80 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,6 +19,7 @@ classifiers = [ "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", "License :: OSI Approved :: Apache Software License", "Operating System :: OS Independent", ]