diff --git a/examples/model_evaluation/plot_estimator_report.py b/examples/model_evaluation/plot_estimator_report.py new file mode 100644 index 000000000..fb8fd30f2 --- /dev/null +++ b/examples/model_evaluation/plot_estimator_report.py @@ -0,0 +1,385 @@ +""" +============================================ +Get insights from any scikit-learn estimator +============================================ + +This example shows how the :class:`skore.EstimatorReport` class can be used to +quickly get insights from any scikit-learn estimator. +""" + +# %% +# +# TODO: we need to describe the aim of this classification problem. +from skrub.datasets import fetch_open_payments + +dataset = fetch_open_payments() +df = dataset.X +y = dataset.y + +# %% +from skrub import TableReport + +TableReport(df) + +# %% +TableReport(y.to_frame()) + +# %% +# Looking at the distributions of the target, we observe that this classification +# task is quite imbalanced. It means that we have to be careful when selecting a set +# of statistical metrics to evaluate the classification performance of our predictive +# model. In addition, we see that the class labels are not specified by an integer +# 0 or 1 but instead by a string "allowed" or "disallowed". +# +# For our application, the label of interest is "allowed". +pos_label, neg_label = "allowed", "disallowed" + +# %% +# Before training a predictive model, we need to split our dataset into a training +# and a validation set. +from skore import train_test_split + +X_train, X_test, y_train, y_test = train_test_split(df, y, random_state=42) + +# %% +# TODO: we have a perfect case to show useful feature of the `train_test_split` +# function from `skore`. +# +# Now, we need to define a predictive model. Hopefully, `skrub` provides a convenient +# function (:func:`skrub.tabular_learner`) when it comes to getting strong baseline +# predictive models with a single line of code. 
As its feature engineering is generic, +# it does not provide some handcrafted and tailored feature engineering but still +# provides a good starting point. +# +# So let's create a classifier for our task and fit it on the training set. +from skrub import tabular_learner + +estimator = tabular_learner("classifier").fit(X_train, y_train) +estimator + +# %% +# +# Introducing the :class:`skore.EstimatorReport` class +# ---------------------------------------------------- +# +# Now, we would be interested in getting some insights from our predictive model. +# One way is to use the :class:`skore.EstimatorReport` class. This constructor will +# detect that our estimator is already fitted and will not fit it again. +from skore import EstimatorReport + +reporter = EstimatorReport( + estimator, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test +) +reporter + +# %% +# +# Once the reporter is created, we get some information regarding the available tools +# allowing us to get some insights from our specific model on the specific task. +# +# You can get similar information if you call the :meth:`~skore.EstimatorReport.help` +# method. +reporter.help() + +# %% +# +# Be aware that you can access the help for each individual sub-accessor. For instance: +reporter.metrics.help() + +# %% +reporter.metrics.plot.help() + +# %% +# +# Metrics computation with aggressive caching +# ------------------------------------------- +# +# At this point, we might be interested in having a first look at the statistical +# performance of our model on the validation set that we provided. We can access it +# by calling any of the metrics displayed above. Since we are greedy, we want to get +# several metrics at once and we will use the +# :meth:`~skore.EstimatorReport.metrics.report_metrics` method. 
+import time + +start = time.time() +metric_report = reporter.metrics.report_metrics(pos_label=pos_label) +end = time.time() +metric_report + +# %% +print(f"Time taken to compute the metrics: {end - start:.2f} seconds") + +# %% +# +# An interesting feature provided by the :class:`skore.EstimatorReport` is the +# caching mechanism. Indeed, when we have a large enough dataset, computing the +# predictions for a model is not cheap anymore. For instance, on our smallish dataset, +# it took a couple of seconds to compute the metrics. The reporter will cache the +# predictions and if you are interested in computing a metric again or an alternative +# metric that requires the same predictions, it will be faster. Let's check by +# requesting the same metrics report again. + +start = time.time() +metric_report = reporter.metrics.report_metrics(pos_label=pos_label) +end = time.time() +metric_report + +# %% +print(f"Time taken to compute the metrics: {end - start:.2f} seconds") + +# %% +# +# Since we obtain a pandas dataframe, we can also use the plotting interface of +# pandas. +import matplotlib.pyplot as plt + +ax = metric_report.T.plot.barh() +ax.set_title("Metrics report") +plt.tight_layout() + +# %% +# +# Whenever computing a metric, we check if the predictions are available in the cache +# and reload them if available. So for instance, let's compute the log loss. + +start = time.time() +log_loss = reporter.metrics.log_loss() +end = time.time() +log_loss + +# %% +print(f"Time taken to compute the log loss: {end - start:.2f} seconds") + +# %% +# +# We can show that without initial cache, it would have taken more time to compute +# the log loss. +reporter.clean_cache() + +start = time.time() +log_loss = reporter.metrics.log_loss() +end = time.time() +log_loss + +# %% +print(f"Time taken to compute the log loss: {end - start:.2f} seconds") + +# %% +# +# By default, the metrics are computed on the test set. 
However, if a training set +# is provided, we can also compute the metrics by specifying the `data_source` +# parameter. +reporter.metrics.log_loss(data_source="train") + +# %% +# +# In the case where we are interested in computing the metrics on a completely new set +# of data, we can use the `data_source="X_y"` parameter. In addition, we need to provide +# the `X` and `y` parameters. + +start = time.time() +metric_report = reporter.metrics.report_metrics( + data_source="X_y", X=X_test, y=y_test, pos_label=pos_label +) +end = time.time() +metric_report + +# %% +print(f"Time taken to compute the metrics: {end - start:.2f} seconds") + +# %% +# +# As in the other case, we rely on the cache to avoid recomputing the predictions. +# Internally, we compute a hash of the input data to be sure that we can hit the cache +# in a consistent way. + +# %% +start = time.time() +metric_report = reporter.metrics.report_metrics( + data_source="X_y", X=X_test, y=y_test, pos_label=pos_label +) +end = time.time() +metric_report + +# %% +print(f"Time taken to compute the metrics: {end - start:.2f} seconds") + +# %% +# +# .. warning:: +# In this last example, we rely on computing the hash of the input data. Therefore, +# there is a trade-off: the computation of the hash is not free and it might be +# faster to compute the predictions instead. +# +# Be aware that you can also benefit from the caching mechanism with your own custom +# metrics. We only expect that you define your own metric function to take `y_true` +# and `y_pred` as the first two positional arguments. It can take any other arguments. +# Let's see an example. 
+ + +def operational_decision_cost(y_true, y_pred, amount): + mask_true_positive = (y_true == pos_label) & (y_pred == pos_label) + mask_true_negative = (y_true == neg_label) & (y_pred == neg_label) + mask_false_positive = (y_true == neg_label) & (y_pred == pos_label) + mask_false_negative = (y_true == pos_label) & (y_pred == neg_label) + # FIXME: we need to make sense of the cost sensitive part with the right naming + fraudulent_refuse = mask_true_positive.sum() * 50 + fraudulent_accept = -amount[mask_false_negative].sum() + legitimate_refuse = mask_false_positive.sum() * -5 + legitimate_accept = (amount[mask_true_negative] * 0.02).sum() + return fraudulent_refuse + fraudulent_accept + legitimate_refuse + legitimate_accept + + +# %% +# +# In our use case, we have an operational decision to make that translates the +# classification outcome into a cost. It translates the confusion matrix into a cost +# matrix based on some amount linked to each sample in the dataset that is provided to +# us. Here, we randomly generate some amounts as an illustration. +import numpy as np + +rng = np.random.default_rng(42) +amount = rng.integers(low=100, high=1000, size=len(y_test)) + +# %% +# +# Let's make sure that the `predict` method has been called and its result cached. +# We compute the accuracy metric to make sure that the `predict` method is called. +reporter.metrics.accuracy() + +# %% +# +# We can now compute the cost of our operational decision. +start = time.time() +cost = reporter.metrics.custom_metric( + metric_function=operational_decision_cost, + metric_name="Operational Decision Cost", + response_method="predict", + amount=amount, +) +end = time.time() +cost + +# %% +print(f"Time taken to compute the cost: {end - start:.2f} seconds") + +# %% +# +# Let's now clean the cache and check that it is slower. 
+reporter.clean_cache() + +# %% +start = time.time() +cost = reporter.metrics.custom_metric( + metric_function=operational_decision_cost, + metric_name="Operational Decision Cost", + response_method="predict", + amount=amount, +) +end = time.time() +cost + +# %% +print(f"Time taken to compute the cost: {end - start:.2f} seconds") + +# %% +# +# We observe that caching is working as expected. It is really handy because it means +# that you can compute some additional metrics without having to recompute the +# predictions. +reporter.metrics.report_metrics( + scoring=["precision", "recall", operational_decision_cost], + pos_label=pos_label, + scoring_kwargs={ + "amount": amount, + "response_method": "predict", + "metric_name": "Operational Decision Cost", + }, +) + +# %% +# +# It could happen that you are interested in providing several custom metrics which +# do not necessarily share the same parameters. In this more complex case, we will +# require you to provide a scorer using the :func:`sklearn.metrics.make_scorer` +# function. +from sklearn.metrics import make_scorer, f1_score + +f1_scorer = make_scorer( + f1_score, + response_method="predict", + metric_name="F1 Score", + pos_label=pos_label, +) +operational_decision_cost_scorer = make_scorer( + operational_decision_cost, + response_method="predict", + metric_name="Operational Decision Cost", + amount=amount, +) +reporter.metrics.report_metrics(scoring=[f1_scorer, operational_decision_cost_scorer]) + +# %% +# +# Effortless one-liner plotting +# ----------------------------- +# +# The :class:`skore.EstimatorReport` class also provides a plotting interface that +# allows you to plot *de facto* the most common plots. As for the metrics, we only +# provide the meaningful set of plots for the provided estimator. +reporter.metrics.plot.help() + +# %% +# +# Let's start by plotting the ROC curve for our binary classification task. 
+display = reporter.metrics.plot.roc(pos_label=pos_label) +plt.tight_layout() + +# %% +# +# The plot functionality is built upon the scikit-learn display objects. We return +# those displays (slightly modified to improve the UI) in case you want to tweak some +# of the plot properties. You can have a quick look at the available attributes and +# methods by calling the `help` method or simply by printing the display. +display + +# %% +display.help() + +# %% +display.plot() +display.ax_.set_title("Example of a ROC curve") +display.figure_ +plt.tight_layout() + +# %% +# +# Similarly to the metrics, we aggressively use the caching to avoid recomputing the +# predictions of the model. We also cache the plot display object by detecting if the +# input parameters are the same as the previous call. Let's demonstrate the kind of +# performance gain we can get. +start = time.time() +# we already triggered the computation of the predictions in a previous call +reporter.metrics.plot.roc(pos_label=pos_label) +plt.tight_layout() +end = time.time() + +# %% +print(f"Time taken to compute the ROC curve: {end - start:.2f} seconds") + +# %% +# +# Now, let's clean the cache and check if we get a slowdown. +reporter.clean_cache() + +# %% +start = time.time() +reporter.metrics.plot.roc(pos_label=pos_label) +plt.tight_layout() +end = time.time() + +# %% +print(f"Time taken to compute the ROC curve: {end - start:.2f} seconds") + +# %% +# As expected, since we need to recompute the predictions, it takes more time. 
diff --git a/skore/pyproject.toml b/skore/pyproject.toml index 81eeac218..f12baa04c 100644 --- a/skore/pyproject.toml +++ b/skore/pyproject.toml @@ -8,6 +8,8 @@ dependencies = [ "diskcache", "fastapi", "numpy", + "pandas", + "matplotlib", "plotly>=5,<6", "pyarrow", "rich", @@ -66,8 +68,6 @@ artifacts = ["src/skore/ui/static/"] test = [ "altair>=5,<6", "httpx", - "matplotlib", - "pandas", "pillow", "plotly", "polars", @@ -85,13 +85,12 @@ test = [ sphinx = [ "IPython", "altair", - "matplotlib", "numpydoc", - "pandas", "polars", "kaleido", "pydata-sphinx-theme", "sphinx", + "sphinx_autosummary_accessors", "sphinx-design", "sphinx-gallery", "sphinx-copybutton", diff --git a/skore/src/skore/__init__.py b/skore/src/skore/__init__.py index 1f33543cf..135bd0947 100644 --- a/skore/src/skore/__init__.py +++ b/skore/src/skore/__init__.py @@ -6,11 +6,12 @@ from rich.theme import Theme from skore.project import Project, open -from skore.sklearn import CrossValidationReporter, train_test_split +from skore.sklearn import CrossValidationReporter, EstimatorReport, train_test_split from skore.utils._show_versions import show_versions __all__ = [ "CrossValidationReporter", + "EstimatorReport", "open", "Project", "show_versions", diff --git a/skore/src/skore/externals/_pandas_accessors.py b/skore/src/skore/externals/_pandas_accessors.py new file mode 100644 index 000000000..7dabdc6bc --- /dev/null +++ b/skore/src/skore/externals/_pandas_accessors.py @@ -0,0 +1,53 @@ +"""Pandas-like accessors. + +This code is copied from: +https://github.com/pandas-dev/pandas/blob/main/pandas/core/accessor.py + +It is used to register accessors for the skore classes. 
+""" + +from typing import final + + +class DirNamesMixin: + _accessors: set[str] = set() + _hidden_attrs: frozenset[str] = frozenset() + + @final + def _dir_deletions(self) -> set[str]: + return self._accessors | self._hidden_attrs + + def _dir_additions(self) -> set[str]: + return {accessor for accessor in self._accessors if hasattr(self, accessor)} + + def __dir__(self) -> list[str]: + rv = set(super().__dir__()) + rv = (rv - self._dir_deletions()) | self._dir_additions() + return sorted(rv) + + +class Accessor: + def __init__(self, name: str, accessor) -> None: + self._name = name + self._accessor = accessor + + def __get__(self, obj, cls): + if obj is None: + # we're accessing the attribute of the class, i.e., Dataset.geo + return self._accessor + return self._accessor(obj) + + +def _register_accessor(name, cls): + def decorator(accessor): + if hasattr(cls, name): + raise ValueError( + f"registration of accessor {accessor!r} under name " + f"{name!r} for type {cls!r} is overriding a preexisting " + f"attribute with the same name." 
+ ) + setattr(cls, name, Accessor(name, accessor)) + cls._accessors.add(name) + return accessor + + return decorator diff --git a/skore/src/skore/sklearn/__init__.py b/skore/src/skore/sklearn/__init__.py index 9331d60a8..eb3d2188f 100644 --- a/skore/src/skore/sklearn/__init__.py +++ b/skore/src/skore/sklearn/__init__.py @@ -1,9 +1,11 @@ """Enhance `sklearn` functions.""" +from skore.sklearn._estimator import EstimatorReport from skore.sklearn.cross_validation import CrossValidationReporter from skore.sklearn.train_test_split.train_test_split import train_test_split __all__ = [ "train_test_split", "CrossValidationReporter", + "EstimatorReport", ] diff --git a/skore/src/skore/sklearn/_estimator/__init__.py b/skore/src/skore/sklearn/_estimator/__init__.py new file mode 100644 index 000000000..ba7a3058a --- /dev/null +++ b/skore/src/skore/sklearn/_estimator/__init__.py @@ -0,0 +1,25 @@ +from skore.externals._pandas_accessors import _register_accessor +from skore.sklearn._estimator.metrics_accessor import ( + _MetricsAccessor, + _PlotMetricsAccessor, +) +from skore.sklearn._estimator.report import EstimatorReport + + +def register_estimator_report_accessor(name: str): + """Register an accessor for the EstimatorReport class.""" + return _register_accessor(name, EstimatorReport) + + +def register_metrics_accessor(name: str): + """Register an accessor for the EstimatorReport class.""" + return _register_accessor(name, _MetricsAccessor) + + +# add the plot accessor to the metrics accessor +register_metrics_accessor("plot")(_PlotMetricsAccessor) + +# add the metrics accessor to the estimator report +register_estimator_report_accessor("metrics")(_MetricsAccessor) + +__all__ = ["EstimatorReport"] diff --git a/skore/src/skore/sklearn/_estimator/__init__.pyi b/skore/src/skore/sklearn/_estimator/__init__.pyi new file mode 100644 index 000000000..f496ff1f3 --- /dev/null +++ b/skore/src/skore/sklearn/_estimator/__init__.pyi @@ -0,0 +1,3 @@ +from skore.sklearn._estimator.report 
import EstimatorReport + +__all__ = ["EstimatorReport"] diff --git a/skore/src/skore/sklearn/_estimator/base.py b/skore/src/skore/sklearn/_estimator/base.py new file mode 100644 index 000000000..fffb67c11 --- /dev/null +++ b/skore/src/skore/sklearn/_estimator/base.py @@ -0,0 +1,174 @@ +import inspect +from io import StringIO + +import joblib +from rich.console import Console, Group +from rich.panel import Panel +from rich.tree import Tree + +from skore.externals._sklearn_compat import is_clusterer + + +class _HelpMixin: + """Mixin class providing help for the `help` method and the `__repr__` method.""" + + def _get_methods_for_help(self): + """Get the methods to display in help.""" + methods = inspect.getmembers(self, predicate=inspect.ismethod) + filtered_methods = [] + for name, method in methods: + is_private_method = name.startswith("_") + # we cannot use `isinstance(method, classmethod)` because it is already + # transformed by the decorator `@classmethod`. + is_class_method = inspect.ismethod(method) and method.__self__ is type(self) + is_help_method = name == "help" + if not (is_private_method or is_class_method or is_help_method): + filtered_methods.append((name, method)) + return filtered_methods + + def _sort_methods_for_help(self, methods): + """Sort methods for help display.""" + return sorted(methods) + + def _format_method_name(self, name): + """Format method name for display.""" + return f"{name}(...)" + + def _get_method_description(self, method): + """Get the description for a method.""" + return ( + method.__doc__.split("\n")[0] + if method.__doc__ + else "No description available" + ) + + def _get_help_legend(self): + """Get the help legend.""" + return None + + def _create_help_panel(self): + """Create the help panel.""" + if self._get_help_legend(): + content = Group( + self._create_help_tree(), + f"\n\nLegend:\n{self._get_help_legend()}", + ) + else: + content = self._create_help_tree() + + return Panel( + content, + 
title=self._get_help_panel_title(), + expand=False, + border_style="orange1", + ) + + def help(self): + """Display available methods using rich.""" + from skore import console # avoid circular import + + console.print(self._create_help_panel()) + + def __repr__(self): + """Return a string representation using rich.""" + console = Console(file=StringIO(), force_terminal=False) + console.print(self._create_help_panel()) + return console.file.getvalue() + + +class _BaseAccessor(_HelpMixin): + """Base class for all accessors.""" + + def __init__(self, parent, icon): + self._parent = parent + self._icon = icon + + def _get_help_panel_title(self): + name = self.__class__.__name__.replace("_", "").replace("Accessor", "").lower() + return f"{self._icon} Available {name} methods" + + def _create_help_tree(self): + """Create a rich Tree with the available methods.""" + tree = Tree(self._get_help_tree_title()) + + methods = self._get_methods_for_help() + methods = self._sort_methods_for_help(methods) + + for name, method in methods: + displayed_name = self._format_method_name(name) + description = self._get_method_description(method) + tree.add(f".{displayed_name}".ljust(26) + f" - {description}") + + return tree + + def _get_X_y_and_data_source_hash(self, *, data_source, X=None, y=None): + """Get the requested dataset and mention if we should hash before caching. + + Parameters + ---------- + data_source : {"test", "train", "X_y"}, default="test" + The data source to use. + + - "test" : use the test set provided when creating the reporter. + - "train" : use the train set provided when creating the reporter. + - "X_y" : use the provided `X` and `y` to compute the metric. + + X : array-like of shape (n_samples, n_features) or None, default=None + The input data. + + y : array-like of shape (n_samples,) or None, default=None + The target data. + + Returns + ------- + X : array-like of shape (n_samples, n_features) + The requested dataset. 
+ + y : array-like of shape (n_samples,) + The requested target. + + data_source_hash : str or None + The hash of the data source. None when we are able to track the data, and + thus relying on X_train, y_train, X_test, y_test. + """ + is_cluster = is_clusterer(self._parent.estimator) + if data_source == "test": + if not (X is None and y is None): + raise ValueError("X and y must be None when data_source is test.") + if self._parent._X_test is None or ( + not is_cluster and self._parent._y_test is None + ): + missing_data = "X_test" if is_cluster else "X_test and y_test" + raise ValueError( + f"No {data_source} data (i.e. {missing_data}) were provided " + f"when creating the reporter. Please provide the {data_source} " + "data either when creating the reporter or by setting data_source " + "to 'X_y' and providing X and y." + ) + return self._parent._X_test, self._parent._y_test, None + elif data_source == "train": + if not (X is None and y is None): + raise ValueError("X and y must be None when data_source is train.") + if self._parent._X_train is None or ( + not is_cluster and self._parent._y_train is None + ): + missing_data = "X_train" if is_cluster else "X_train and y_train" + raise ValueError( + f"No {data_source} data (i.e. {missing_data}) were provided " + f"when creating the reporter. Please provide the {data_source} " + "data either when creating the reporter or by setting data_source " + "to 'X_y' and providing X and y." + ) + return self._parent._X_train, self._parent._y_train, None + elif data_source == "X_y": + if X is None or (not is_cluster and y is None): + missing_data = "X" if is_cluster else "X and y" + raise ValueError( + f"{missing_data} must be provided when data_source is X_y." + ) + return X, y, joblib.hash((X, y)) + else: + raise ValueError( + f"Invalid data source: {data_source}. Possible values are: " + "test, train, X_y." 
+ ) diff --git a/skore/src/skore/sklearn/_estimator/base.pyi b/skore/src/skore/sklearn/_estimator/base.pyi new file mode 100644 index 000000000..eca62e85e --- /dev/null +++ b/skore/src/skore/sklearn/_estimator/base.pyi @@ -0,0 +1,33 @@ +from typing import Any, Literal, Optional + +import numpy as np +from rich.panel import Panel +from rich.tree import Tree + +class _HelpMixin: + def _get_methods_for_help(self) -> list[tuple[str, Any]]: ... + def _sort_methods_for_help( + self, methods: list[tuple[str, Any]] + ) -> list[tuple[str, Any]]: ... + def _format_method_name(self, name: str) -> str: ... + def _get_method_description(self, method: Any) -> str: ... + def _create_help_panel(self) -> Panel: ... + def _get_help_panel_title(self) -> str: ... + def _create_help_tree(self) -> Tree: ... + def help(self) -> None: ... + def __repr__(self) -> str: ... + +class _BaseAccessor(_HelpMixin): + _parent: Any + _icon: str + + def __init__(self, parent: Any, icon: str) -> None: ... + def _get_help_panel_title(self) -> str: ... + def _create_help_tree(self) -> Tree: ... + def _get_X_y_and_data_source_hash( + self, + *, + data_source: Literal["test", "train", "X_y"], + X: Optional[np.ndarray] = None, + y: Optional[np.ndarray] = None, + ) -> tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[str]]: ... 
diff --git a/skore/src/skore/sklearn/_estimator/metrics_accessor.py b/skore/src/skore/sklearn/_estimator/metrics_accessor.py new file mode 100644 index 000000000..0f41595ac --- /dev/null +++ b/skore/src/skore/sklearn/_estimator/metrics_accessor.py @@ -0,0 +1,1096 @@ +import inspect +from functools import partial + +import joblib +import numpy as np +import pandas as pd +from sklearn import metrics +from sklearn.metrics._scorer import _BaseScorer +from sklearn.utils.metaestimators import available_if + +from skore.externals._pandas_accessors import DirNamesMixin +from skore.sklearn._estimator.base import _BaseAccessor +from skore.sklearn._plot import ( + PrecisionRecallCurveDisplay, + PredictionErrorDisplay, + RocCurveDisplay, +) +from skore.utils._accessor import _check_supported_ml_task + +############################################################################### +# Metrics accessor +############################################################################### + + +class _MetricsAccessor(_BaseAccessor, DirNamesMixin): + """Accessor for metrics-related operations. + + You can access this accessor using the `metrics` attribute. + """ + + _SCORE_OR_LOSS_ICONS = { + "accuracy": "(↗︎)", + "precision": "(↗︎)", + "recall": "(↗︎)", + "brier_score": "(↘︎)", + "roc_auc": "(↗︎)", + "log_loss": "(↘︎)", + "r2": "(↗︎)", + "rmse": "(↘︎)", + "report_metrics": "", + "custom_metric": "", + } + + def __init__(self, parent): + super().__init__(parent, icon="📏") + + # TODO: should build on the `add_scorers` function + def report_metrics( + self, + *, + data_source="test", + X=None, + y=None, + scoring=None, + pos_label=None, + scoring_kwargs=None, + ): + """Report a set of metrics for our estimator. + + Parameters + ---------- + data_source : {"test", "train", "X_y"}, default="test" + The data source to use. + + - "test" : use the test set provided when creating the reporter. + - "train" : use the train set provided when creating the reporter. 
+ - "X_y" : use the provided `X` and `y` to compute the metric. + + X : array-like of shape (n_samples, n_features), default=None + New data on which to compute the metric. By default, we use the validation + set provided when creating the reporter. + + y : array-like of shape (n_samples,), default=None + New target on which to compute the metric. By default, we use the target + provided when creating the reporter. + + scoring : list of str, callable, or scorer, default=None + The metrics to report. You can get the possible list of string by calling + `reporter.metrics.help()`. When passing a callable, it should take as + arguments `y_true`, `y_pred` as the two first arguments. Additional + arguments can be passed as keyword arguments and will be forwarded with + `scoring_kwargs`. If the callable API is too restrictive (e.g. need to pass + same parameter name with different values), you can use scikit-learn scorers + as provided by :func:`sklearn.metrics.make_scorer`. + + pos_label : int, default=None + The positive class. + + scoring_kwargs : dict, default=None + The keyword arguments to pass to the scoring functions. + + Returns + ------- + pd.DataFrame + The statistics for the metrics. 
+ """ + if scoring is None: + # Equivalent to _get_scorers_to_add + if self._parent._ml_task == "binary-classification": + scoring = ["precision", "recall", "roc_auc"] + if hasattr(self._parent._estimator, "predict_proba"): + scoring.append("brier_score") + elif self._parent._ml_task == "multiclass-classification": + scoring = ["precision", "recall"] + if hasattr(self._parent._estimator, "predict_proba"): + scoring += ["roc_auc", "log_loss"] + else: + scoring = ["r2", "rmse"] + + scores = [] + + for metric in scoring: + # NOTE: we have to check specifically for `_BaseScorer` first because this + # is also a callable but it has a special private API that we can leverage + if isinstance(metric, _BaseScorer): + # scorers have the advantage to have scoped defined kwargs + metric_fn = partial( + self.custom_metric, + metric_function=metric._score_func, + response_method=metric._response_method, + ) + # forward the additional parameters specific to the scorer + metrics_kwargs = {**metric._kwargs} + elif isinstance(metric, str) or callable(metric): + if isinstance(metric, str): + metric_fn = getattr(self, metric) + metrics_kwargs = {} + else: + metric_fn = partial(self.custom_metric, metric_function=metric) + if scoring_kwargs is None: + metrics_kwargs = {} + else: + # check if we should pass any parameters specific to the metric + # callable + metric_callable_params = inspect.signature(metric).parameters + metrics_kwargs = { + param: scoring_kwargs[param] + for param in metric_callable_params + if param in scoring_kwargs + } + metrics_params = inspect.signature(metric_fn).parameters + if scoring_kwargs is not None: + for param in metrics_params: + if param in scoring_kwargs: + metrics_kwargs[param] = scoring_kwargs[param] + if "pos_label" in metrics_params: + metrics_kwargs["pos_label"] = pos_label + else: + raise ValueError( + f"Invalid type of metric: {type(metric)} for {metric!r}" + ) + + scores.append( + metric_fn(data_source=data_source, X=X, y=y, **metrics_kwargs) 
+ ) + + has_multilevel = any( + isinstance(score, pd.DataFrame) and isinstance(score.columns, pd.MultiIndex) + for score in scores + ) + + if has_multilevel: + # Convert single-level dataframes to multi-level + for i, score in enumerate(scores): + if hasattr(score, "columns") and not isinstance( + score.columns, pd.MultiIndex + ): + name_index = ( + ["Metric", "Output"] + if self._parent._ml_task == "regression" + else ["Metric", "Class label"] + ) + scores[i].columns = pd.MultiIndex.from_tuples( + [(col, "") for col in score.columns], + names=name_index, + ) + + return pd.concat(scores, axis=1) + + def _compute_metric_scores( + self, + metric_fn, + X, + y_true, + *, + data_source="test", + response_method, + pos_label=None, + metric_name=None, + **metric_kwargs, + ): + X, y_true, data_source_hash = self._get_X_y_and_data_source_hash( + data_source=data_source, X=X, y=y_true + ) + + y_pred = self._parent._get_cached_response_values( + estimator_hash=self._parent._hash, + estimator=self._parent.estimator, + X=X, + response_method=response_method, + pos_label=pos_label, + data_source=data_source, + data_source_hash=data_source_hash, + ) + cache_key = (self._parent._hash, metric_fn.__name__, data_source) + if data_source_hash: + cache_key += (data_source_hash,) + + metric_params = inspect.signature(metric_fn).parameters + if "pos_label" in metric_params: + cache_key += (pos_label,) + if metric_kwargs != {}: + # we need to enforce the order of the parameter for a specific metric + # to make sure that we hit the cache in a consistent way + ordered_metric_kwargs = sorted(metric_kwargs.keys()) + cache_key += tuple( + ( + joblib.hash(metric_kwargs[key]) + if isinstance(metric_kwargs[key], np.ndarray) + else metric_kwargs[key] + ) + for key in ordered_metric_kwargs + ) + + if cache_key in self._parent._cache: + score = self._parent._cache[cache_key] + else: + metric_params = inspect.signature(metric_fn).parameters + kwargs = {**metric_kwargs} + if "pos_label" in 
metric_params: + kwargs.update(pos_label=pos_label) + + score = metric_fn(y_true, y_pred, **kwargs) + self._parent._cache[cache_key] = score + + score = np.array([score]) if not isinstance(score, np.ndarray) else score + metric_name = metric_name or metric_fn.__name__ + + if self._parent._ml_task in [ + "binary-classification", + "multiclass-classification", + ]: + if len(score) == 1: + columns = pd.Index([metric_name], name="Metric") + else: + classes = self._parent._estimator.classes_ + columns = pd.MultiIndex.from_arrays( + [[metric_name] * len(classes), classes], + names=["Metric", "Class label"], + ) + score = score.reshape(1, -1) + elif self._parent._ml_task == "regression": + if len(score) == 1: + columns = pd.Index([metric_name], name="Metric") + else: + columns = pd.MultiIndex.from_arrays( + [ + [metric_name] * len(score), + [f"#{i}" for i in range(len(score))], + ], + names=["Metric", "Output"], + ) + score = score.reshape(1, -1) + else: + # FIXME: clusterer would fall here. + columns = None + return pd.DataFrame(score, columns=columns, index=[self._parent.estimator_name]) + + @available_if( + _check_supported_ml_task( + supported_ml_tasks=["binary-classification", "multiclass-classification"] + ) + ) + def accuracy(self, *, data_source="test", X=None, y=None): + """Compute the accuracy score. + + Parameters + ---------- + data_source : {"test", "train", "X_y"}, default="test" + The data source to use. + + - "test" : use the test set provided when creating the reporter. + - "train" : use the train set provided when creating the reporter. + - "X_y" : use the provided `X` and `y` to compute the metric. + + X : array-like of shape (n_samples, n_features), default=None + New data on which to compute the metric. By default, we use the validation + set provided when creating the reporter. + + y : array-like of shape (n_samples,), default=None + New target on which to compute the metric. By default, we use the target + provided when creating the reporter. 
@available_if(
    _check_supported_ml_task(
        supported_ml_tasks=["binary-classification", "multiclass-classification"]
    )
)
def precision(
    self, *, data_source="test", X=None, y=None, average=None, pos_label=None
):
    """Compute the precision score.

    Parameters
    ----------
    data_source : {"test", "train", "X_y"}, default="test"
        The data source to use.

        - "test" : use the test set provided when creating the reporter.
        - "train" : use the train set provided when creating the reporter.
        - "X_y" : use the provided `X` and `y` to compute the metric.

    X : array-like of shape (n_samples, n_features), default=None
        New data on which to compute the metric. By default, we use the validation
        set provided when creating the reporter.

    y : array-like of shape (n_samples,), default=None
        New target on which to compute the metric. By default, we use the target
        provided when creating the reporter.

    average : {"binary", "macro", "micro", "weighted", "samples"} or None, \
            default=None
        Used with multiclass problems.
        If `None`, the metrics for each class are returned. Otherwise, this
        determines the type of averaging performed on the data:

        - "binary": Only report results for the class specified by `pos_label`.
          This is applicable only if targets (`y_{true,pred}`) are binary.
        - "micro": Calculate metrics globally by counting the total true positives,
          false negatives and false positives.
        - "macro": Calculate metrics for each label, and find their unweighted
          mean. This does not take label imbalance into account.
        - "weighted": Calculate metrics for each label, and find their average
          weighted by support (the number of true instances for each label). This
          alters 'macro' to account for label imbalance; it can result in an F-score
          that is not between precision and recall.
        - "samples": Calculate metrics for each instance, and find their average
          (only meaningful for multilabel classification where this differs from
          :func:`accuracy_score`).

        .. note::
            On a binary classification task, as soon as `pos_label` is given,
            `average` is replaced by `"binary"` whatever value was passed, so
            only the statistics of the positive class are reported.

    pos_label : int or str, default=None
        The positive class.

    Returns
    -------
    pd.DataFrame
        The precision score.
    """
    is_binary_task = self._parent._ml_task == "binary-classification"
    if is_binary_task and pos_label is not None:
        # an explicit positive class means the caller only wants its statistics
        average = "binary"

    column_name = f"Precision {self._SCORE_OR_LOSS_ICONS['precision']}"
    return self._compute_metric_scores(
        metrics.precision_score,
        X=X,
        y_true=y,
        data_source=data_source,
        response_method="predict",
        pos_label=pos_label,
        metric_name=column_name,
        average=average,
    )
+ + average : {"binary","macro", "micro", "weighted", "samples"} or None, \ + default=None + Used with multiclass problems. + If `None`, the metrics for each class are returned. Otherwise, this + determines the type of averaging performed on the data: + + - "binary": Only report results for the class specified by `pos_label`. + This is applicable only if targets (`y_{true,pred}`) are binary. + - "micro": Calculate metrics globally by counting the total true positives, + false negatives and false positives. + - "macro": Calculate metrics for each label, and find their unweighted + mean. This does not take label imbalance into account. + - "weighted": Calculate metrics for each label, and find their average + weighted by support (the number of true instances for each label). This + alters 'macro' to account for label imbalance; it can result in an F-score + that is not between precision and recall. Weighted recall is equal to + accuracy. + - "samples": Calculate metrics for each instance, and find their average + (only meaningful for multilabel classification where this differs from + :func:`accuracy_score`). + + .. note:: + If `pos_label` is specified and `average` is None, then we report + only the statistics of the positive class (i.e. equivalent to + `average="binary"`). + + pos_label : int, default=None + The positive class. + + Returns + ------- + pd.DataFrame + The recall score. 
+ """ + if self._parent._ml_task == "binary-classification" and pos_label is not None: + # if `pos_label` is specified by our user, then we can safely report only + # the statistics of the positive class + average = "binary" + + return self._compute_metric_scores( + metrics.recall_score, + X=X, + y_true=y, + data_source=data_source, + response_method="predict", + pos_label=pos_label, + metric_name=f"Recall {self._SCORE_OR_LOSS_ICONS['recall']}", + average=average, + ) + + @available_if( + _check_supported_ml_task(supported_ml_tasks=["binary-classification"]) + ) + def brier_score(self, *, data_source="test", X=None, y=None): + """Compute the Brier score. + + Parameters + ---------- + data_source : {"test", "train", "X_y"}, default="test" + The data source to use. + + - "test" : use the test set provided when creating the reporter. + - "train" : use the train set provided when creating the reporter. + - "X_y" : use the provided `X` and `y` to compute the metric. + + X : array-like of shape (n_samples, n_features), default=None + New data on which to compute the metric. By default, we use the validation + set provided when creating the reporter. + + y : array-like of shape (n_samples,), default=None + New target on which to compute the metric. By default, we use the target + provided when creating the reporter. + + Returns + ------- + pd.DataFrame + The Brier score. + """ + # The Brier score in scikit-learn request `pos_label` to ensure that the + # integral encoding of `y_true` corresponds to the probabilities of the + # `pos_label`. Since we get the predictions with `get_response_method`, we + # can pass any `pos_label`, they will lead to the same result. 
@available_if(
    _check_supported_ml_task(
        supported_ml_tasks=["binary-classification", "multiclass-classification"]
    )
)
def roc_auc(
    self, *, data_source="test", X=None, y=None, average=None, multi_class="ovr"
):
    """Compute the ROC AUC score.

    Parameters
    ----------
    data_source : {"test", "train", "X_y"}, default="test"
        The data source to use.

        - "test" : use the test set provided when creating the reporter.
        - "train" : use the train set provided when creating the reporter.
        - "X_y" : use the provided `X` and `y` to compute the metric.

    X : array-like of shape (n_samples, n_features), default=None
        New data on which to compute the metric. By default, we use the validation
        set provided when creating the reporter.

    y : array-like of shape (n_samples,), default=None
        New target on which to compute the metric. By default, we use the target
        provided when creating the reporter.

    average : {"macro", "micro", "weighted", "samples"} or None, \
            default=None
        Average to compute the ROC AUC score in a multiclass setting. By default,
        no average is computed. Otherwise, this determines the type of averaging
        performed on the data. The value is forwarded unchanged to
        `metrics.roc_auc_score`.

        - "micro": Calculate metrics globally by considering each element of
          the label indicator matrix as a label.
        - "macro": Calculate metrics for each label, and find their unweighted
          mean. This does not take label imbalance into account.
        - "weighted": Calculate metrics for each label, and find their average,
          weighted by support (the number of true instances for each label).
        - "samples": Calculate metrics for each instance, and find their
          average.

        .. note::
            Multiclass ROC AUC currently only handles the "macro" and
            "weighted" averages. For multiclass targets, `average=None` is only
            implemented for `multi_class="ovr"` and `average="micro"` is only
            implemented for `multi_class="ovr"`.

    multi_class : {"raise", "ovr", "ovo"}, default="ovr"
        The multi-class strategy to use.

        - "raise": Raise an error if the data is multiclass.
        - "ovr": Stands for One-vs-rest. Computes the AUC of each class against the
          rest. This treats the multiclass case in the same way as the multilabel
          case. Sensitive to class imbalance even when `average == "macro"`,
          because class imbalance affects the composition of each of the "rest"
          groupings.
        - "ovo": Stands for One-vs-one. Computes the average AUC of all possible
          pairwise combinations of classes. Insensitive to class imbalance when
          `average == "macro"`.

    Returns
    -------
    pd.DataFrame
        The ROC AUC score.
    """
    # candidate estimator methods, listed in order of preference
    candidate_methods = ["predict_proba", "decision_function"]
    column_name = f"ROC AUC {self._SCORE_OR_LOSS_ICONS['roc_auc']}"
    return self._compute_metric_scores(
        metrics.roc_auc_score,
        X=X,
        y_true=y,
        data_source=data_source,
        response_method=candidate_methods,
        metric_name=column_name,
        average=average,
        multi_class=multi_class,
    )
+ """ + return self._compute_metric_scores( + metrics.log_loss, + X=X, + y_true=y, + data_source=data_source, + response_method="predict_proba", + metric_name=f"Log loss {self._SCORE_OR_LOSS_ICONS['log_loss']}", + ) + + @available_if(_check_supported_ml_task(supported_ml_tasks=["regression"])) + def r2(self, *, data_source="test", X=None, y=None, multioutput="raw_values"): + """Compute the R² score. + + Parameters + ---------- + X : array-like of shape (n_samples, n_features), default=None + New data on which to compute the metric. By default, we use the validation + set provided when creating the reporter. + + y : array-like of shape (n_samples,), default=None + New target on which to compute the metric. By default, we use the target + provided when creating the reporter. + + multioutput : {"raw_values", "uniform_average"} or array-like of shape \ + (n_outputs,), default="raw_values" + Defines aggregating of multiple output values. Array-like value defines + weights used to average errors. The other possible values are: + + - "raw_values": Returns a full set of errors in case of multioutput input. + - "uniform_average": Errors of all outputs are averaged with uniform weight. + + By default, no averaging is done. + + Returns + ------- + pd.DataFrame + The R² score. + """ + return self._compute_metric_scores( + metrics.r2_score, + X=X, + y_true=y, + data_source=data_source, + response_method="predict", + metric_name=f"R² {self._SCORE_OR_LOSS_ICONS['r2']}", + multioutput=multioutput, + ) + + @available_if(_check_supported_ml_task(supported_ml_tasks=["regression"])) + def rmse(self, *, data_source="test", X=None, y=None, multioutput="raw_values"): + """Compute the root mean squared error. + + Parameters + ---------- + X : array-like of shape (n_samples, n_features), default=None + New data on which to compute the metric. By default, we use the validation + set provided when creating the reporter. 
+ + y : array-like of shape (n_samples,), default=None + New target on which to compute the metric. By default, we use the target + provided when creating the reporter. + + multioutput : {"raw_values", "uniform_average"} or array-like of shape \ + (n_outputs,), default="raw_values" + Defines aggregating of multiple output values. Array-like value defines + weights used to average errors. The other possible values are: + + - "raw_values": Returns a full set of errors in case of multioutput input. + - "uniform_average": Errors of all outputs are averaged with uniform weight. + + By default, no averaging is done. + + Returns + ------- + pd.DataFrame + The root mean squared error. + """ + return self._compute_metric_scores( + metrics.root_mean_squared_error, + X=X, + y_true=y, + data_source=data_source, + response_method="predict", + metric_name=f"RMSE {self._SCORE_OR_LOSS_ICONS['rmse']}", + multioutput=multioutput, + ) + + def custom_metric( + self, + metric_function, + response_method, + *, + metric_name=None, + data_source="test", + X=None, + y=None, + **kwargs, + ): + """Compute a custom metric. + + It brings some flexibility to compute any desired metric. However, we need to + follow some rules: + + - `metric_function` should take `y_true` and `y_pred` as the first two + positional arguments. + - `response_method` corresponds to the estimator's method to be invoked to get + the predictions. It can be a string or a list of strings to defined in which + order the methods should be invoked. + + Parameters + ---------- + metric_function : callable + The metric function to be computed. The expected signature is + `metric_function(y_true, y_pred, **kwargs)`. + + response_method : str or list of str + The estimator's method to be invoked to get the predictions. The possible + values are: `predict`, `predict_proba`, `predict_log_proba`, and + `decision_function`. + + metric_name : str, default=None + The name of the metric. 
def _sort_methods_for_help(self, methods):
    """Override sort method for metrics-specific ordering.

    Regular metrics come first in alphabetical order, followed by
    `custom_metric` and finally `report_metrics`.

    Parameters
    ----------
    methods : list of tuple
        `(name, method)` pairs; only the name is used for ordering.

    Returns
    -------
    list of tuple
        A new sorted list.
    """
    # Methods missing from this table get priority 0 and are listed first.
    priorities = {"custom_metric": 1, "report_metrics": 2}
    return sorted(methods, key=lambda item: (priorities.get(item[0], 0), item[0]))


def _format_method_name(self, name):
    """Override format method for metrics-specific naming.

    The name is padded and, when the metric has a "higher/lower is better"
    icon, followed by that icon wrapped in colour markup. The final padding
    depends on the markup that was added.
    """
    # icon -> (markup colour, final width including the markup tags)
    icon_markup = {"(↗︎)": ("cyan", 43), "(↘︎)": ("orange1", 49)}

    icon = self._SCORE_OR_LOSS_ICONS[name]
    padded_name = f"{name}(...)".ljust(22)

    markup = icon_markup.get(icon)
    if markup is None:
        return padded_name.ljust(29)

    colour, width = markup
    return f"{padded_name}[{colour}]{icon}[/{colour}]".ljust(width)


def _get_methods_for_help(self):
    """Override to exclude the plot accessor from methods list."""
    return [
        (name, method)
        for name, method in super()._get_methods_for_help()
        if name != "plot"
    ]


def _create_help_tree(self):
    """Override to include plot methods in a separate branch."""
    tree = super()._create_help_tree()

    # The plot accessor gets its own branch, filled with its own methods.
    branch = tree.add("[bold cyan].plot 🎨[/bold cyan]")
    for name, method in self.plot._sort_methods_for_help(
        self.plot._get_methods_for_help()
    ):
        label = f".{self.plot._format_method_name(name)}".ljust(27)
        branch.add(label + f"- {self.plot._get_method_description(method)}")

    return tree


def _get_help_panel_title(self):
    """Return the title of the help panel, prefixed by the accessor icon."""
    return f"[bold cyan]{self._icon} Available metrics methods[/bold cyan]"


def _get_help_legend(self):
    """Return the legend explaining the two arrow icons."""
    return (
        "[cyan](↗︎)[/cyan] higher is better [orange1](↘︎)[/orange1] lower is better"
    )


def _get_help_tree_title(self):
    """Return the label of the root of the help tree."""
    return "[bold cyan]reporter.metrics[/bold cyan]"
class _PlotMetricsAccessor(_BaseAccessor):
    """Plotting methods for the metrics accessor."""

    def __init__(self, parent):
        # `parent` is the metrics accessor; the base class receives the object
        # that the metrics accessor itself is attached to.
        super().__init__(parent._parent, icon="🎨")
        self._metrics_parent = parent

    def _get_display(
        self,
        *,
        X,
        y,
        data_source,
        response_method,
        display_class,
        display_kwargs,
        display_plot_kwargs,
    ):
        """Get the display from the cache or compute it.

        A cached display is plotted again with `display_plot_kwargs` before
        being returned; a new display is built with
        `display_class._from_predictions` and stored in the parent's cache.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            The data.

        y : array-like of shape (n_samples,)
            The target.

        data_source : {"test", "train", "X_y"}, default="test"
            The data source to use.

            - "test" : use the test set provided when creating the reporter.
            - "train" : use the train set provided when creating the reporter.
            - "X_y" : use the provided `X` and `y` to compute the metric.

        response_method : str
            The response method.

        display_class : class
            The display class.

        display_kwargs : dict
            The display kwargs used by `display_class._from_predictions`. Their
            values are part of the cache key.

        display_plot_kwargs : dict
            The display kwargs used by `display.plot`.

        Returns
        -------
        display : display_class
            The display.
        """
        X, y, data_source_hash = self._get_X_y_and_data_source_hash(
            data_source=data_source, X=X, y=y
        )

        # The data is identified by its hash when there is one, and by the name
        # of the data source otherwise.
        cache_key = (
            self._parent._hash,
            display_class.__name__,
            *display_kwargs.values(),
            data_source_hash if data_source_hash else data_source,
        )

        if cache_key in self._parent._cache:
            cached_display = self._parent._cache[cache_key]
            cached_display.plot(**display_plot_kwargs)
            return cached_display

        y_pred = self._parent._get_cached_response_values(
            estimator_hash=self._parent._hash,
            estimator=self._parent.estimator,
            X=X,
            response_method=response_method,
            data_source=data_source,
            data_source_hash=data_source_hash,
            pos_label=display_kwargs.get("pos_label", None),
        )

        new_display = display_class._from_predictions(
            y,
            y_pred,
            estimator=self._parent.estimator,
            estimator_name=self._parent.estimator_name,
            ml_task=self._parent._ml_task,
            data_source=data_source,
            **display_kwargs,
            **display_plot_kwargs,
        )
        self._parent._cache[cache_key] = new_display
        return new_display

    @available_if(
        _check_supported_ml_task(
            supported_ml_tasks=["binary-classification", "multiclass-classification"]
        )
    )
    def roc(self, *, data_source="test", X=None, y=None, pos_label=None, ax=None):
        """Plot the ROC curve.

        Parameters
        ----------
        data_source : {"test", "train", "X_y"}, default="test"
            The data source to use.

            - "test" : use the test set provided when creating the reporter.
            - "train" : use the train set provided when creating the reporter.
            - "X_y" : use the provided `X` and `y` to compute the metric.

        X : array-like of shape (n_samples, n_features), default=None
            New data on which to compute the metric. By default, we use the validation
            set provided when creating the reporter.

        y : array-like of shape (n_samples,), default=None
            New target on which to compute the metric. By default, we use the target
            provided when creating the reporter.

        pos_label : str, default=None
            The positive class.

        ax : matplotlib.axes.Axes, default=None
            The axes to plot on.

        Returns
        -------
        RocCurveDisplay
            The ROC curve display.
        """
        return self._get_display(
            X=X,
            y=y,
            data_source=data_source,
            response_method=("predict_proba", "decision_function"),
            display_class=RocCurveDisplay,
            display_kwargs={"pos_label": pos_label},
            display_plot_kwargs={
                "ax": ax,
                "plot_chance_level": True,
                "despine": True,
            },
        )

    @available_if(
        _check_supported_ml_task(
            supported_ml_tasks=["binary-classification", "multiclass-classification"]
        )
    )
    def precision_recall(
        self,
        *,
        data_source="test",
        X=None,
        y=None,
        pos_label=None,
        ax=None,
    ):
        """Plot the precision-recall curve.

        Parameters
        ----------
        data_source : {"test", "train", "X_y"}, default="test"
            The data source to use.

            - "test" : use the test set provided when creating the reporter.
            - "train" : use the train set provided when creating the reporter.
            - "X_y" : use the provided `X` and `y` to compute the metric.

        X : array-like of shape (n_samples, n_features), default=None
            New data on which to compute the metric. By default, we use the validation
            set provided when creating the reporter.

        y : array-like of shape (n_samples,), default=None
            New target on which to compute the metric. By default, we use the target
            provided when creating the reporter.

        pos_label : str, default=None
            The positive class.

        ax : matplotlib.axes.Axes, default=None
            The axes to plot on.

        Returns
        -------
        PrecisionRecallCurveDisplay
            The precision-recall curve display.
        """
        return self._get_display(
            X=X,
            y=y,
            data_source=data_source,
            response_method=("predict_proba", "decision_function"),
            display_class=PrecisionRecallCurveDisplay,
            display_kwargs={"pos_label": pos_label},
            display_plot_kwargs={
                "ax": ax,
                "plot_chance_level": False,
                "despine": True,
            },
        )

    @available_if(_check_supported_ml_task(supported_ml_tasks=["regression"]))
    def prediction_error(
        self,
        *,
        data_source="test",
        X=None,
        y=None,
        ax=None,
        kind="residual_vs_predicted",
        subsample=1_000,
    ):
        """Plot the prediction error of a regression model.

        Parameters
        ----------
        data_source : {"test", "train", "X_y"}, default="test"
            The data source to use.

            - "test" : use the test set provided when creating the reporter.
            - "train" : use the train set provided when creating the reporter.
            - "X_y" : use the provided `X` and `y` to compute the metric.

        X : array-like of shape (n_samples, n_features), default=None
            New data on which to compute the metric. By default, we use the validation
            set provided when creating the reporter.

        y : array-like of shape (n_samples,), default=None
            New target on which to compute the metric. By default, we use the target
            provided when creating the reporter.

        ax : matplotlib axes, default=None
            Axes object to plot on. If `None`, a new figure and axes is
            created.

        kind : {"actual_vs_predicted", "residual_vs_predicted"}, \
                default="residual_vs_predicted"
            The type of plot to draw:

            - "actual_vs_predicted" draws the observed values (y-axis) vs.
              the predicted values (x-axis).
            - "residual_vs_predicted" draws the residuals, i.e. difference
              between observed and predicted values, (y-axis) vs. the predicted
              values (x-axis).

        subsample : float, int or None, default=1_000
            Sampling the samples to be shown on the scatter plot. If `float`,
            it should be between 0 and 1 and represents the proportion of the
            original dataset. If `int`, it represents the number of samples
            displayed on the scatter plot. If `None`, no subsampling will be
            applied. By default, 1,000 samples or less will be displayed.

        Returns
        -------
        PredictionErrorDisplay
            The prediction error display.
        """
        return self._get_display(
            X=X,
            y=y,
            data_source=data_source,
            response_method="predict",
            display_class=PredictionErrorDisplay,
            display_kwargs={"kind": kind, "subsample": subsample},
            display_plot_kwargs={"ax": ax},
        )

    def _get_help_panel_title(self):
        """Return the title of the help panel, prefixed by the accessor icon."""
        return f"[bold cyan]{self._icon} Available plot methods[/bold cyan]"

    def _get_help_tree_title(self):
        """Return the label of the root of the help tree."""
        return "[bold cyan]reporter.metrics.plot[/bold cyan]"
class _MetricsAccessor(_BaseAccessor):
    # Type stub of the metrics accessor; signatures mirror the implementation
    # in `metrics_accessor.py`.

    # Maps a metric method name to its "higher/lower is better" icon.
    _SCORE_OR_LOSS_ICONS: dict[str, str]
    # Sub-accessor exposing the plotting methods.
    plot: _PlotMetricsAccessor

    def _compute_metric_scores(
        self,
        metric_fn: Callable,
        X: Optional[np.ndarray],
        y_true: Optional[np.ndarray],
        *,
        data_source: Literal["test", "train", "X_y"] = "test",
        response_method: Union[str, list[str]],
        pos_label: Optional[Union[str, int]] = None,
        metric_name: Optional[str] = None,
        **metric_kwargs: Any,
    ) -> pd.DataFrame: ...
    def report_metrics(
        self,
        *,
        data_source: Literal["test", "train", "X_y"] = "test",
        X: Optional[np.ndarray] = None,
        y: Optional[np.ndarray] = None,
        scoring: Optional[Union[list[str], Callable]] = None,
        pos_label: Optional[Union[str, int]] = None,
        scoring_kwargs: Optional[dict[str, Any]] = None,
    ) -> pd.DataFrame: ...
    def accuracy(
        self,
        *,
        data_source: Literal["test", "train", "X_y"] = "test",
        X: Optional[np.ndarray] = None,
        y: Optional[np.ndarray] = None,
    ) -> pd.DataFrame: ...
    def precision(
        self,
        *,
        data_source: Literal["test", "train", "X_y"] = "test",
        X: Optional[np.ndarray] = None,
        y: Optional[np.ndarray] = None,
        average: Optional[
            Literal["binary", "micro", "macro", "weighted", "samples"]
        ] = None,
        pos_label: Optional[Union[str, int]] = None,
    ) -> pd.DataFrame: ...
    def recall(
        self,
        *,
        data_source: Literal["test", "train", "X_y"] = "test",
        X: Optional[np.ndarray] = None,
        y: Optional[np.ndarray] = None,
        average: Optional[
            Literal["binary", "micro", "macro", "weighted", "samples"]
        ] = None,
        pos_label: Optional[Union[str, int]] = None,
    ) -> pd.DataFrame: ...
    # The implementation takes no `pos_label`: it always uses the last entry
    # of the estimator's `classes_`.
    def brier_score(
        self,
        *,
        data_source: Literal["test", "train", "X_y"] = "test",
        X: Optional[np.ndarray] = None,
        y: Optional[np.ndarray] = None,
    ) -> pd.DataFrame: ...
    # `average` is forwarded as-is to `roc_auc_score`, which has no "auto" option.
    def roc_auc(
        self,
        *,
        data_source: Literal["test", "train", "X_y"] = "test",
        X: Optional[np.ndarray] = None,
        y: Optional[np.ndarray] = None,
        average: Optional[Literal["micro", "macro", "weighted", "samples"]] = None,
        multi_class: Literal["raise", "ovr", "ovo"] = "ovr",
    ) -> pd.DataFrame: ...
    def log_loss(
        self,
        *,
        data_source: Literal["test", "train", "X_y"] = "test",
        X: Optional[np.ndarray] = None,
        y: Optional[np.ndarray] = None,
    ) -> pd.DataFrame: ...
    def r2(
        self,
        *,
        data_source: Literal["test", "train", "X_y"] = "test",
        X: Optional[np.ndarray] = None,
        y: Optional[np.ndarray] = None,
        multioutput: Union[
            Literal["raw_values", "uniform_average"], np.ndarray
        ] = "raw_values",
    ) -> pd.DataFrame: ...
    def rmse(
        self,
        *,
        data_source: Literal["test", "train", "X_y"] = "test",
        X: Optional[np.ndarray] = None,
        y: Optional[np.ndarray] = None,
        multioutput: Union[
            Literal["raw_values", "uniform_average"], np.ndarray
        ] = "raw_values",
    ) -> pd.DataFrame: ...
    def custom_metric(
        self,
        metric_function: Callable,
        response_method: Union[str, list[str]],
        *,
        metric_name: Optional[str] = None,
        data_source: Literal["test", "train", "X_y"] = "test",
        X: Optional[np.ndarray] = None,
        y: Optional[np.ndarray] = None,
        **kwargs: Any,
    ) -> pd.DataFrame: ...
+ + Parameters + ---------- + estimator : estimator object + Estimator to make report from. + + fit : {"auto", True, False}, default="auto" + Whether to fit the estimator on the training data. If "auto", the estimator + is fitted only if the training data is provided. + + X_train : {array-like, sparse matrix} of shape (n_samples, n_features) or \ + None + Training data. + + y_train : array-like of shape (n_samples,) or (n_samples, n_outputs) or None + Training target. + + X_test : {array-like, sparse matrix} of shape (n_samples, n_features) or None + Testing data. It should have the same structure as the training data. + + y_test : array-like of shape (n_samples,) or (n_samples, n_outputs) or None + Testing target. + + Attributes + ---------- + metrics : _MetricsAccessor + Accessor for metrics-related operations. + + Examples + -------- + >>> from sklearn.datasets import make_classification + >>> from sklearn.model_selection import train_test_split + >>> from sklearn.linear_model import LogisticRegression + >>> X, y = make_classification(random_state=42) + >>> X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42) + >>> estimator = LogisticRegression().fit(X_train, y_train) + >>> from skore import EstimatorReport + >>> report = EstimatorReport(estimator, X_test=X_test, y_test=y_test) + """ + + _ACCESSOR_CONFIG = { + "metrics": {"icon": "📏", "name": "metrics"}, + # Add other accessors as they're implemented + # "inspection": {"icon": "🔍", "name": "inspection"}, + # "linting": {"icon": "✔️", "name": "linting"}, + } + + @staticmethod + def _fit_estimator(estimator, X_train, y_train): + if X_train is None or (y_train is None and not is_clusterer(estimator)): + raise ValueError( + "The training data is required to fit the estimator. " + "Please provide both X_train and y_train." 
+ ) + return clone(estimator).fit(X_train, y_train) + + def __init__( + self, + estimator, + *, + fit="auto", + X_train=None, + y_train=None, + X_test=None, + y_test=None, + ): + if fit == "auto": + try: + check_is_fitted(estimator) + self._estimator = estimator + except NotFittedError: + self._estimator = self._fit_estimator(estimator, X_train, y_train) + elif fit is True: + self._estimator = self._fit_estimator(estimator, X_train, y_train) + else: # fit is False + self._estimator = estimator + + # private storage to be able to invalidate the cache when the user alters + # those attributes + self._X_train = X_train + self._y_train = y_train + self._X_test = X_test + self._y_test = y_test + + self._initialize_state() + + def _initialize_state(self): + """Initialize/reset the random number generator, hash, and cache.""" + self._rng = np.random.default_rng(time.time_ns()) + self._hash = self._rng.integers( + low=np.iinfo(np.int64).min, high=np.iinfo(np.int64).max + ) + self._cache = {} + self._ml_task = _find_ml_task(self._y_test, estimator=self._estimator) + + # NOTE: + # For the moment, we do not allow to alter the estimator and the training data. + # For the validation set, we allow it and we invalidate the cache. + + def clean_cache(self): + """Clean the cache.""" + self._cache = {} + + def cache_predictions(self, response_methods="auto", n_jobs=None): + """Force caching of estimator's predictions. + + Parameters + ---------- + response_methods : "auto" or list of str, default="auto" + The response methods to precompute. If "auto", the response methods are + inferred from the ml task: for classification we compute the response of + the `predict_proba`, `decision_function` and `predict` methods; for + regression we compute the response of the `predict` method. + + n_jobs : int or None, default=None + The number of jobs to run in parallel. None means 1 unless in a + joblib.parallel_backend context. -1 means using all processors. 
+ """ + if self._ml_task in ("binary-classification", "multiclass-classification"): + if response_methods == "auto": + response_methods = ("predict",) + if hasattr(self._estimator, "predict_proba"): + response_methods = ("predict_proba",) + if hasattr(self._estimator, "decision_function"): + response_methods = ("decision_function",) + pos_labels = self._estimator.classes_ + else: + if response_methods == "auto": + response_methods = ("predict",) + pos_labels = [None] + + data_sources = ("test",) + Xs = (self._X_test,) + if self._X_train is not None: + data_sources = ("train",) + Xs = (self._X_train,) + + parallel = joblib.Parallel(n_jobs=n_jobs, return_as="generator_unordered") + generator = parallel( + joblib.delayed(self._get_cached_response_values)( + estimator_hash=self._hash, + estimator=self._estimator, + X=X, + response_method=response_method, + pos_label=pos_label, + data_source=data_source, + ) + for response_method, pos_label, data_source, X in product( + response_methods, pos_labels, data_sources, Xs + ) + ) + # trigger the computation + list( + track( + generator, + total=len(response_methods) * len(pos_labels) * len(data_sources), + description="Caching predictions", + ) + ) + + @property + def estimator(self): + return self._estimator + + @estimator.setter + def estimator(self, value): + raise AttributeError( + "The estimator attribute is immutable. " + "Call the constructor of {self.__class__.__name__} to create a new report." + ) + + @property + def X_train(self): + return self._X_train + + @X_train.setter + def X_train(self, value): + raise AttributeError( + "The X_train attribute is immutable. " + "Please use the `from_unfitted_estimator` method to create a new report." + ) + + @property + def y_train(self): + return self._y_train + + @y_train.setter + def y_train(self, value): + raise AttributeError( + "The y_train attribute is immutable. " + "Please use the `from_unfitted_estimator` method to create a new report." 
+ ) + + @property + def X_test(self): + return self._X_test + + @X_test.setter + def X_test(self, value): + self._X_test = value + self._initialize_state() + + @property + def y_test(self): + return self._y_test + + @y_test.setter + def y_test(self, value): + self._y_test = value + self._initialize_state() + + @property + def estimator_name(self): + if isinstance(self._estimator, Pipeline): + name = self._estimator[-1].__class__.__name__ + else: + name = self._estimator.__class__.__name__ + return name + + def _get_cached_response_values( + self, + *, + estimator_hash, + estimator, + X, + response_method, + pos_label=None, + data_source="test", + data_source_hash=None, + ): + """Compute or load from local cache the response values. + + Parameters + ---------- + estimator_hash : int + A hash associated with the estimator such that we can retrieve the data from + the cache. + + estimator : estimator object + The estimator. + + X : {array-like, sparse matrix} of shape (n_samples, n_features) + The data. + + response_method : str + The response method. + + pos_label : str, default=None + The positive label. + + data_source : {"test", "train", "X_y"}, default="test" + The data source to use. + + - "test" : use the test set provided when creating the reporter. + - "train" : use the train set provided when creating the reporter. + - "X_y" : use the provided `X` and `y` to compute the metric. + + data_source_hash : int or None + The hash of the data source when `data_source` is "X_y". + + Returns + ------- + array-like of shape (n_samples,) or (n_samples, n_outputs) + The response values. 
+ """ + prediction_method = _check_response_method(estimator, response_method).__name__ + if prediction_method in ("predict_proba", "decision_function"): + # pos_label is only important in classification and with probabilities + # and decision functions + cache_key = (estimator_hash, pos_label, prediction_method, data_source) + else: + cache_key = (estimator_hash, prediction_method, data_source) + + if data_source == "X_y": + data_source_hash = joblib.hash(X) + cache_key += (data_source_hash,) + + if cache_key in self._cache: + return self._cache[cache_key] + + predictions, _ = _get_response_values( + estimator, + X=X, + response_method=prediction_method, + pos_label=pos_label, + return_response_method_used=False, + ) + self._cache[cache_key] = predictions + + return predictions + + #################################################################################### + # Methods related to the help tree + #################################################################################### + + def _get_help_panel_title(self): + return ( + f"[bold cyan]📓 Tools to diagnose estimator " + f"{self.estimator_name}[/bold cyan]" + ) + + def _get_help_legend(self): + return ( + "[cyan](↗︎)[/cyan] higher is better [orange1](↘︎)[/orange1] lower is better" + ) + + def _get_attributes_for_help(self): + """Get the public attributes to display in help.""" + attributes = [] + xy_attributes = [] + + for name in dir(self): + # Skip private attributes, callables, and accessors + if ( + name.startswith("_") + or callable(getattr(self, name)) + or isinstance(getattr(self, name), _BaseAccessor) + ): + continue + + # Group X and y attributes separately + value = getattr(self, name) + if name.startswith(("X_", "y_")): + if value is not None: # Only include non-None X/y attributes + xy_attributes.append(name) + else: + attributes.append(name) + + # Sort X/y attributes to keep them grouped + xy_attributes.sort() + attributes.sort() + + # Return X/y attributes first, followed by other 
attributes + return xy_attributes + attributes + + def _create_help_tree(self): + """Create a rich Tree with the available tools and accessor methods.""" + tree = Tree("reporter") + + # Add accessor methods first + for accessor_attr, config in self._ACCESSOR_CONFIG.items(): + accessor = getattr(self, accessor_attr) + branch = tree.add( + f"[bold cyan].{config['name']} {config['icon']}[/bold cyan]" + ) + + # Add main accessor methods first + methods = accessor._get_methods_for_help() + methods = accessor._sort_methods_for_help(methods) + + # Add methods + for name, method in methods: + displayed_name = accessor._format_method_name(name) + description = accessor._get_method_description(method) + branch.add(f".{displayed_name} - {description}") + + # Add sub-accessors after main methods + for sub_attr, sub_obj in inspect.getmembers(accessor): + if isinstance(sub_obj, _BaseAccessor) and not sub_attr.startswith("_"): + sub_branch = branch.add( + f"[bold cyan].{sub_attr} {sub_obj._icon}[/bold cyan]" + ) + + # Add sub-accessor methods + sub_methods = sub_obj._get_methods_for_help() + sub_methods = sub_obj._sort_methods_for_help(sub_methods) + + for name, method in sub_methods: + displayed_name = sub_obj._format_method_name(name) + description = sub_obj._get_method_description(method) + sub_branch.add(f".{displayed_name.ljust(25)} - {description}") + + # Add base methods + base_methods = self._get_methods_for_help() + base_methods = self._sort_methods_for_help(base_methods) + + for name, method in base_methods: + description = self._get_method_description(method) + tree.add(f".{name}(...)".ljust(34) + f" - {description}") + + # Add attributes section + attributes = self._get_attributes_for_help() + if attributes: + attr_branch = tree.add("[bold cyan]Attributes[/bold cyan]") + for attr in attributes: + attr_branch.add(f".{attr}") + + return tree diff --git a/skore/src/skore/sklearn/_estimator/report.pyi b/skore/src/skore/sklearn/_estimator/report.pyi new file mode 100644 
index 000000000..74c4c215d --- /dev/null +++ b/skore/src/skore/sklearn/_estimator/report.pyi @@ -0,0 +1,71 @@ +from typing import Any, Literal, Optional, Union + +import numpy as np +from sklearn.base import BaseEstimator + +from skore.sklearn._estimator.base import _HelpMixin +from skore.sklearn._estimator.metrics_accessor import _MetricsAccessor + +class EstimatorReport(_HelpMixin): + _ACCESSOR_CONFIG: dict[str, dict[str, str]] + _estimator: BaseEstimator + _X_train: Optional[np.ndarray] + _y_train: Optional[np.ndarray] + _X_test: Optional[np.ndarray] + _y_test: Optional[np.ndarray] + _rng: np.random.Generator + _hash: int + _cache: dict[Any, Any] + _ml_task: str + metrics: _MetricsAccessor + + @staticmethod + def _fit_estimator( + estimator: BaseEstimator, X_train: np.ndarray, y_train: Optional[np.ndarray] + ) -> BaseEstimator: ... + def __init__( + self, + estimator: BaseEstimator, + *, + fit: Literal["auto", True, False] = "auto", + X_train: Optional[np.ndarray] = None, + y_train: Optional[np.ndarray] = None, + X_test: Optional[np.ndarray] = None, + y_test: Optional[np.ndarray] = None, + ) -> None: ... + def _initialize_state(self) -> None: ... + def clean_cache(self) -> None: ... + def cache_predictions( + self, + response_methods: Union[Literal["auto"], list[str]] = "auto", + n_jobs: Optional[int] = None, + ) -> None: ... + @property + def estimator(self) -> BaseEstimator: ... + @property + def X_train(self) -> Optional[np.ndarray]: ... + @property + def y_train(self) -> Optional[np.ndarray]: ... + @property + def X_test(self) -> Optional[np.ndarray]: ... + @X_test.setter + def X_test(self, value: Optional[np.ndarray]) -> None: ... + @property + def y_test(self) -> Optional[np.ndarray]: ... + @y_test.setter + def y_test(self, value: Optional[np.ndarray]) -> None: ... + @property + def estimator_name(self) -> str: ... 
+ def _get_cached_response_values( + self, + *, + estimator_hash: int, + estimator: BaseEstimator, + X: np.ndarray, + response_method: Union[str, list[str]], + pos_label: Optional[Union[str, int]] = None, + data_source: Literal["test", "train", "X_y"] = "test", + data_source_hash: Optional[str] = None, + ) -> np.ndarray: ... + def _get_help_panel_title(self) -> str: ... + def _create_help_tree(self) -> Any: ... # Returns rich.tree.Tree diff --git a/skore/src/skore/sklearn/_estimator/utils.py b/skore/src/skore/sklearn/_estimator/utils.py new file mode 100644 index 000000000..578fab610 --- /dev/null +++ b/skore/src/skore/sklearn/_estimator/utils.py @@ -0,0 +1,19 @@ +from sklearn.pipeline import Pipeline + + +def _check_supported_estimator(supported_estimators): + def check(accessor): + estimator = accessor._parent.estimator + if isinstance(estimator, Pipeline): + estimator = estimator.steps[-1][1] + supported_estimator = isinstance(estimator, supported_estimators) + + if not supported_estimator: + raise AttributeError( + f"The {estimator.__class__.__name__} estimator is not supported " + "by the function called." 
+ ) + + return True + + return check diff --git a/skore/src/skore/sklearn/_plot/__init__.py b/skore/src/skore/sklearn/_plot/__init__.py new file mode 100644 index 000000000..7f39733e4 --- /dev/null +++ b/skore/src/skore/sklearn/_plot/__init__.py @@ -0,0 +1,9 @@ +from skore.sklearn._plot.precision_recall_curve import PrecisionRecallCurveDisplay +from skore.sklearn._plot.prediction_error import PredictionErrorDisplay +from skore.sklearn._plot.roc_curve import RocCurveDisplay + +__all__ = [ + "RocCurveDisplay", + "PrecisionRecallCurveDisplay", + "PredictionErrorDisplay", +] diff --git a/skore/src/skore/sklearn/_plot/precision_recall_curve.py b/skore/src/skore/sklearn/_plot/precision_recall_curve.py new file mode 100644 index 000000000..96881362b --- /dev/null +++ b/skore/src/skore/sklearn/_plot/precision_recall_curve.py @@ -0,0 +1,511 @@ +from collections import Counter + +from sklearn.metrics import average_precision_score, precision_recall_curve +from sklearn.preprocessing import LabelBinarizer + +from skore.sklearn._plot.utils import ( + HelpDisplayMixin, + _ClassifierCurveDisplayMixin, + _despine_matplotlib_axis, + _validate_style_kwargs, +) + + +class PrecisionRecallCurveDisplay(HelpDisplayMixin, _ClassifierCurveDisplayMixin): + """Precision Recall visualization. + + An instance of this class is should created by + `EstimatorReport.metrics.plot.precision_recall()`. You should not create an + instance of this class directly. + + + Parameters + ---------- + precision : dict of list of ndarray + Precision values. The structure is: + + - for binary classification: + - the key is the positive label. + - the value is a list of `ndarray`, each `ndarray` being the precision. + - for multiclass classification: + - the key is the class of interest in an OvR fashion. + - the value is a list of `ndarray`, each `ndarray` being the precision. + + recall : dict of list of ndarray + Recall values. 
The structure is: + + - for binary classification: + - the key is the positive label. + - the value is a list of `ndarray`, each `ndarray` being the recall. + - for multiclass classification: + - the key is the class of interest in an OvR fashion. + - the value is a list of `ndarray`, each `ndarray` being the recall. + + average_precision : dict of list of float + Average precision. The structure is: + + - for binary classification: + - the key is the positive label. + - the value is a list of `float`, each `float` being the average + precision. + - for multiclass classification: + - the key is the class of interest in an OvR fashion. + - the value is a list of `float`, each `float` being the average + precision. + + prevalence : dict of list of float + The prevalence of the positive label. The structure is: + + - for binary classification: + - the key is the positive label. + - the value is a list of `float`, each `float` being the prevalence. + - for multiclass classification: + - the key is the class of interest in an OvR fashion. + - the value is a list of `float`, each `float` being the prevalence. + + estimator_name : str + Name of the estimator. + + pos_label : int, float, bool or str, default=None + The class considered as the positive class. If None, the class will not + be shown in the legend. + + data_source : {"train", "test", "X_y"}, default=None + The data source used to compute the precision recall curve. + + Attributes + ---------- + ax_ : matplotlib Axes + Axes with precision recall curve. + + figure_ : matplotlib Figure + Figure containing the curve. + + lines_ : list of matplotlib Artist + Precision recall curve. + + chance_levels_ : matplotlib Artist or None + The chance level line. It is `None` if the chance level is not plotted. 
+ """ + + def __init__( + self, + precision, + recall, + *, + average_precision, + prevalence, + estimator_name, + pos_label=None, + data_source=None, + ): + self.precision = precision + self.recall = recall + self.average_precision = average_precision + self.prevalence = prevalence + self.estimator_name = estimator_name + self.pos_label = pos_label + self.data_source = data_source + + def plot( + self, + ax=None, + *, + estimator_name=None, + pr_curve_kwargs=None, + plot_chance_level=False, + chance_level_kwargs=None, + despine=True, + ): + """Plot visualization. + + Extra keyword arguments will be passed to matplotlib's `plot`. + + Parameters + ---------- + ax : Matplotlib Axes, default=None + Axes object to plot on. If `None`, a new figure and axes is + created. + + estimator_name : str, default=None + Name of the estimator used to plot the precision-recall curve. If + `None`, we use the inferred name from the estimator. + + plot_chance_level : bool, default=True + Whether to plot the chance level. The chance level is the prevalence + of the positive label computed from the data passed during + :meth:`from_estimator` or :meth:`from_predictions` call. + + pr_curve_kwargs : dict or list of dict, default=None + Keyword arguments to be passed to matplotlib's `plot` for rendering + the precision-recall curve(s). + + chance_level_kwargs : dict or list of dict, default=None + Keyword arguments to be passed to matplotlib's `plot` for rendering + the chance level line. + + despine : bool, default=True + Whether to remove the top and right spines from the plot. + + Returns + ------- + display : PrecisionRecallCurveDisplay + Object that stores computed values. + + Notes + ----- + The average precision (cf. :func:`~sklearn.metrics.average_precision_score`) + in scikit-learn is computed without any interpolation. To be consistent + with this metric, the precision-recall curve is plotted without any + interpolation as well (step-wise style). 
+ + You can change this style by passing the keyword argument + `drawstyle="default"`. However, the curve will not be strictly + consistent with the reported average precision. + """ + self.ax_, self.figure_, estimator_name = self._validate_plot_params( + ax=ax, estimator_name=estimator_name + ) + + self.lines_ = [] + self.chance_levels_ = [] + if len(self.precision) == 1: # binary-classification + if len(self.precision[self.pos_label]) == 1: # single-split + if pr_curve_kwargs is None: + pr_curve_kwargs = {} + elif isinstance(pr_curve_kwargs, list): + if len(pr_curve_kwargs) > 1: + raise ValueError( + "You intend to plot a single precision-recall curve and " + "provide multiple precision-recall curve keyword " + "arguments. Provide a single dictionary or a list with " + "a single dictionary." + ) + pr_curve_kwargs = pr_curve_kwargs[0] + + precision = self.precision[self.pos_label][0] + recall = self.recall[self.pos_label][0] + average_precision = self.average_precision[self.pos_label][0] + prevalence = self.prevalence[self.pos_label][0] + + default_line_kwargs = {"drawstyle": "steps-post"} + if average_precision is not None and self.data_source in ( + "train", + "test", + ): + default_line_kwargs["label"] = ( + f"{self.data_source.title()} set " + f"(AP = {average_precision:0.2f})" + ) + elif average_precision is not None: # data_source in (None, "X_y") + default_line_kwargs["label"] = f"AP = {average_precision:0.2f}" + + line_kwargs = _validate_style_kwargs( + default_line_kwargs, pr_curve_kwargs + ) + + (line_,) = self.ax_.plot(recall, precision, **line_kwargs) + self.lines_.append(line_) + + if plot_chance_level: + default_chance_level_line_kwargs = { + "label": f"Chance level (AP = {prevalence:0.2f})", + "color": "k", + "linestyle": "--", + } + + if chance_level_kwargs is None: + chance_level_kwargs = {} + elif isinstance(chance_level_kwargs, list): + if len(chance_level_kwargs) > 1: + raise ValueError( + "You intend to plot a single chance level line and " + 
"provide multiple chance level line keyword " + "arguments. Provide a single dictionary or a list " + "with a single dictionary." + ) + chance_level_kwargs = chance_level_kwargs[0] + + chance_level_line_kwargs = _validate_style_kwargs( + default_chance_level_line_kwargs, chance_level_kwargs + ) + + (chance_level_,) = self.ax_.plot( + (0, 1), (prevalence, prevalence), **chance_level_line_kwargs + ) + self.chance_levels_.append(chance_level_) + else: + self.chance_levels_ = None + else: # cross-validation + raise NotImplementedError( + "We don't support yet cross-validation" + ) # pragma: no cover + + info_pos_label = ( + f"\n(Positive label: {self.pos_label})" + if self.pos_label is not None + else "" + ) + else: # multiclass-classification + info_pos_label = None # irrelevant for multiclass + if pr_curve_kwargs is None: + pr_curve_kwargs = [{}] * len(self.precision) + elif isinstance(pr_curve_kwargs, list): + if len(pr_curve_kwargs) != len(self.precision): + raise ValueError( + "You intend to plot multiple precision-recall curves. We " + "expect `pr_curve_kwargs` to be a list of dictionaries with " + "the same length as the number of precision-recall curves. " + "Got " + f"{len(pr_curve_kwargs)} instead of " + f"{len(self.precision)}." + ) + else: + raise ValueError( + "You intend to plot multiple precision-recall curves. We expect " + "`pr_curve_kwargs` to be a list of dictionaries of " + f"{len(self.precision)} elements. Got {pr_curve_kwargs!r} instead." + ) + + if plot_chance_level: + if chance_level_kwargs is None: + chance_level_kwargs = [{}] * len(self.precision) + elif isinstance(chance_level_kwargs, list): + if len(chance_level_kwargs) != len(self.precision): + raise ValueError( + "You intend to plot multiple precision-recall curves. We " + "expect `chance_level_kwargs` to be a list of dictionaries " + "with the same length as the number of precision-recall " + "curves. Got " + f"{len(chance_level_kwargs)} instead of " + f"{len(self.precision)}." 
+ ) + else: + raise ValueError( + "You intend to plot multiple precision-recall curves. We " + "expect `chance_level_kwargs` to be a list of dictionaries of " + f"{len(self.precision)} elements. Got {chance_level_kwargs!r} " + "instead." + ) + + for class_idx, class_ in enumerate(self.precision): + precision_class = self.precision[class_] + recall_class = self.recall[class_] + average_precision_class = self.average_precision[class_] + prevalence_class = self.prevalence[class_] + pr_curve_kwargs_class = pr_curve_kwargs[class_idx] + + if len(precision_class) == 1: # single-split + precision = precision_class[0] + recall = recall_class[0] + average_precision = average_precision_class[0] + prevalence = prevalence_class[0] + + default_line_kwargs = {"drawstyle": "steps-post"} + if average_precision is not None and self.data_source in ( + "train", + "test", + ): + default_line_kwargs["label"] = ( + f"{str(class_).title()} - {self.data_source} set " + f"(AP = {average_precision:0.2f})" + ) + elif average_precision is not None: # data_source in (None, "X_y") + default_line_kwargs["label"] = ( + f"{str(class_).title()} AP = {average_precision:0.2f}" + ) + + line_kwargs = _validate_style_kwargs( + default_line_kwargs, pr_curve_kwargs_class + ) + + (line_,) = self.ax_.plot(recall, precision, **line_kwargs) + self.lines_.append(line_) + + if plot_chance_level: + chance_level_kwargs_class = chance_level_kwargs[class_idx] + + default_chance_level_line_kwargs = { + "label": ( + f"Chance level - {str(class_).title()} " + f"(AP = {prevalence:0.2f})" + ), + "color": "k", + "linestyle": "--", + } + + chance_level_line_kwargs = _validate_style_kwargs( + default_chance_level_line_kwargs, chance_level_kwargs_class + ) + + (chance_level_,) = self.ax_.plot( + (0, 1), (prevalence, prevalence), **chance_level_line_kwargs + ) + self.chance_levels_.append(chance_level_) + else: + self.chance_levels_ = None + else: # cross-validation + raise NotImplementedError( + "We don't support yet 
cross-validation" + ) # pragma: no cover + + xlabel = "Recall" + ylabel = "Precision" + if info_pos_label: + xlabel += info_pos_label + ylabel += info_pos_label + + self.ax_.set( + xlabel=xlabel, + xlim=(-0.01, 1.01), + ylabel=ylabel, + ylim=(-0.01, 1.01), + aspect="equal", + ) + + if despine: + _despine_matplotlib_axis(self.ax_) + + self.ax_.legend(loc="lower left", title=estimator_name) + + @classmethod + def _from_predictions( + cls, + y_true, + y_pred, + *, + estimator, + estimator_name, + ml_task, + data_source=None, + pos_label=None, + drop_intermediate=False, + ax=None, + pr_curve_kwargs=None, + plot_chance_level=False, + chance_level_kwargs=None, + despine=True, + ): + """Plot precision-recall curve given binary class predictions. + + Parameters + ---------- + y_true : array-like of shape (n_samples,) + True binary labels. + + y_pred : array-like of shape (n_samples,) + Target scores, can either be probability estimates of the positive class, + confidence values, or non-thresholded measure of decisions (as returned by + “decision_function” on some classifiers). + + estimator : estimator instance + The estimator from which `y_pred` is obtained. + + estimator_name : str + Name of the estimator used to plot the precision-recall curve. + + ml_task : {"binary-classification", "multiclass-classification"} + The machine learning task. + + data_source : {"train", "test", "X_y"}, default=None + The data source used to compute the ROC curve. + + pos_label : int, float, bool or str, default=None + The class considered as the positive class when computing the + precision and recall metrics. + + drop_intermediate : bool, default=False + Whether to drop some suboptimal thresholds which would not appear + on a plotted precision-recall curve. This is useful in order to + create lighter precision-recall curves. + + ax : matplotlib axes, default=None + Axes object to plot on. If `None`, a new figure and axes is created. 
+ + pr_curve_kwargs : dict or list of dict, default=None + Keyword arguments to be passed to matplotlib's `plot` for rendering + the precision-recall curve(s). + + plot_chance_level : bool, default=False + Whether to plot the chance level. The chance level is the prevalence + of the positive label computed from the data passed during + :meth:`from_estimator` or :meth:`from_predictions` call. + + chance_level_kwargs : dict or list of dict, default=None + Keyword arguments to be passed to matplotlib's `plot` for rendering + the chance level line. + + despine : bool, default=True + Whether to remove the top and right spines from the plot. + + **kwargs : dict + Keyword arguments to be passed to matplotlib's `plot`. + + Returns + ------- + display : :class:`~sklearn.metrics.PrecisionRecallDisplay` + """ + pos_label_validated = cls._validate_from_predictions_params( + y_true, y_pred, ml_task=ml_task, pos_label=pos_label + ) + + if ml_task == "binary-classification": + precision, recall, _ = precision_recall_curve( + y_true, + y_pred, + pos_label=pos_label_validated, + drop_intermediate=drop_intermediate, + ) + average_precision = average_precision_score( + y_true, y_pred, pos_label=pos_label_validated + ) + + class_count = Counter(y_true) + prevalence = class_count[pos_label_validated] / sum(class_count.values()) + + precision = {pos_label_validated: [precision]} + recall = {pos_label_validated: [recall]} + average_precision = {pos_label_validated: [average_precision]} + prevalence = {pos_label_validated: [prevalence]} + else: # multiclass-classification + precision, recall, average_precision, prevalence = {}, {}, {}, {} + label_binarizer = LabelBinarizer().fit(estimator.classes_) + y_true_onehot = label_binarizer.transform(y_true) + for class_idx, class_ in enumerate(estimator.classes_): + precision_class, recall_class, _ = precision_recall_curve( + y_true_onehot[:, class_idx], + y_pred[:, class_idx], + pos_label=None, + drop_intermediate=drop_intermediate, + ) + 
average_precision_class = average_precision_score( + y_true_onehot[:, class_idx], y_pred[:, class_idx] + ) + class_count = Counter(y_true) + prevalence_class = class_count[class_] / sum(class_count.values()) + + precision[class_] = [precision_class] + recall[class_] = [recall_class] + average_precision[class_] = [average_precision_class] + prevalence[class_] = [prevalence_class] + + viz = cls( + precision=precision, + recall=recall, + average_precision=average_precision, + prevalence=prevalence, + estimator_name=estimator_name, + pos_label=pos_label_validated, + data_source=data_source, + ) + + viz.plot( + ax=ax, + estimator_name=estimator_name, + pr_curve_kwargs=pr_curve_kwargs, + plot_chance_level=plot_chance_level, + chance_level_kwargs=chance_level_kwargs, + despine=despine, + ) + + return viz diff --git a/skore/src/skore/sklearn/_plot/prediction_error.py b/skore/src/skore/sklearn/_plot/prediction_error.py new file mode 100644 index 000000000..392f6058b --- /dev/null +++ b/skore/src/skore/sklearn/_plot/prediction_error.py @@ -0,0 +1,318 @@ +import numbers + +import matplotlib.pyplot as plt +import numpy as np +from sklearn.utils.validation import check_random_state + +from skore.externals._sklearn_compat import _safe_indexing +from skore.sklearn._plot.utils import ( + HelpDisplayMixin, + _despine_matplotlib_axis, + _validate_style_kwargs, +) + + +class PredictionErrorDisplay(HelpDisplayMixin): + """Visualization of the prediction error of a regression model. + + This tool can display "residuals vs predicted" or "actual vs predicted" + using scatter plots to qualitatively assess the behavior of a regressor, + preferably on held-out data points. + + An instance of this class is should created by + `EstimatorReport.metrics.plot.prediction_error()`. + You should not create an instance of this class directly. + + Parameters + ---------- + ----------z + y_true : ndarray of shape (n_samples,) + True values. 
def plot(
    self,
    ax=None,
    *,
    estimator_name=None,
    kind="residual_vs_predicted",
    scatter_kwargs=None,
    line_kwargs=None,
    despine=True,
):
    """Plot visualization.

    Draw the scatter of the stored `y_true` / `y_pred` values together with
    the "perfect predictions" reference line. The created artists, the axes
    and the figure are stored on the display as `scatter_`, `line_`, `ax_`
    and `figure_`.

    Parameters
    ----------
    ax : matplotlib axes, default=None
        Axes object to plot on. If `None`, a new figure and axes is
        created.

    estimator_name : str, default=None
        Name of the estimator, used as the legend title. If `None`, the
        `estimator_name` stored on the display is used.

    kind : {"actual_vs_predicted", "residual_vs_predicted"}, \
            default="residual_vs_predicted"
        The type of plot to draw:

        - "actual_vs_predicted" draws the observed values (y-axis) vs.
          the predicted values (x-axis).
        - "residual_vs_predicted" draws the residuals, i.e. difference
          between observed and predicted values, (y-axis) vs. the predicted
          values (x-axis).

    scatter_kwargs : dict, default=None
        Dictionary with keywords passed to the `scatter` call made on the
        axes. They override the default color and alpha.

    line_kwargs : dict, default=None
        Dictionary with keywords passed to the `plot` call that draws the
        optimal line. They override the default color, alpha and linestyle.

    despine : bool, default=True
        Whether to remove the top and right spines from the plot.

    Returns
    -------
    None
        Nothing is returned; the results are stored on the display.

    Raises
    ------
    ValueError
        If `kind` is not one of the supported values.
    """
    expected_kind = ("actual_vs_predicted", "residual_vs_predicted")
    if kind not in expected_kind:
        raise ValueError(
            f"`kind` must be one of {', '.join(expected_kind)}. "
            f"Got {kind!r} instead."
        )

    if scatter_kwargs is None:
        scatter_kwargs = {}
    if line_kwargs is None:
        line_kwargs = {}

    default_scatter_kwargs = {"color": "tab:blue", "alpha": 0.8}
    default_line_kwargs = {"color": "black", "alpha": 0.7, "linestyle": "--"}

    # `_validate_style_kwargs` already returns the defaults overridden by the
    # user-provided values (with aliases such as "c" normalised), so a second
    # merge with the defaults is not needed.
    scatter_kwargs = _validate_style_kwargs(default_scatter_kwargs, scatter_kwargs)
    line_kwargs = _validate_style_kwargs(default_line_kwargs, line_kwargs)

    if self.data_source in ("train", "test"):
        scatter_label = f"{self.data_source.title()} set"
    else:
        scatter_label = "Data set"

    if estimator_name is None:
        estimator_name = self.estimator_name

    if ax is None:
        _, ax = plt.subplots()

    if kind == "actual_vs_predicted":
        max_value = max(np.max(self.y_true), np.max(self.y_pred))
        min_value = min(np.min(self.y_true), np.min(self.y_pred))

        # both axes share the same range so that the diagonal is the optimum
        x_range = (min_value, max_value)
        y_range = (min_value, max_value)

        self.line_ = ax.plot(
            [min_value, max_value],
            [min_value, max_value],
            label="Perfect predictions",
            **line_kwargs,
        )[0]

        xlabel, ylabel = "Predicted values", "Actual values"

        self.scatter_ = ax.scatter(
            self.y_pred, self.y_true, label=scatter_label, **scatter_kwargs
        )

        # force to have a squared axis
        ax.set_aspect("equal", adjustable="datalim")
        ax.set_xticks(np.linspace(min_value, max_value, num=5))
        ax.set_yticks(np.linspace(min_value, max_value, num=5))
    else:  # kind == "residual_vs_predicted"
        residuals = self.y_true - self.y_pred
        x_range = (np.min(self.y_pred), np.max(self.y_pred))
        y_range = (np.min(residuals), np.max(residuals))

        # a perfect model has residuals equal to zero: horizontal line at 0
        self.line_ = ax.plot(
            [x_range[0], x_range[1]],
            [0, 0],
            label="Perfect predictions",
            **line_kwargs,
        )[0]

        self.scatter_ = ax.scatter(
            self.y_pred, residuals, label=scatter_label, **scatter_kwargs
        )
        xlabel, ylabel = "Predicted values", "Residuals (actual - predicted)"

    ax.set(xlabel=xlabel, ylabel=ylabel)
    ax.legend(title=estimator_name)

    self.ax_ = ax
    self.figure_ = ax.figure

    if despine:
        # restrict the remaining spines to the range covered by the data
        _despine_matplotlib_axis(self.ax_, x_range=x_range, y_range=y_range)
+ - "residual_vs_predicted" draws the residuals, i.e. difference + between observed and predicted values, (y-axis) vs. the predicted + values (x-axis). + + subsample : float, int or None, default=1_000 + Sampling the samples to be shown on the scatter plot. If `float`, + it should be between 0 and 1 and represents the proportion of the + original dataset. If `int`, it represents the number of samples + display on the scatter plot. If `None`, no subsampling will be + applied. by default, 1000 samples or less will be displayed. + + random_state : int or RandomState, default=None + Controls the randomness when `subsample` is not `None`. + See :term:`Glossary ` for details. + + ax : matplotlib axes, default=None + Axes object to plot on. If `None`, a new figure and axes is + created. + + scatter_kwargs : dict, default=None + Dictionary with keywords passed to the `matplotlib.pyplot.scatter` + call. + + line_kwargs : dict, default=None + Dictionary with keyword passed to the `matplotlib.pyplot.plot` + call to draw the optimal line. + + despine : bool, default=True + Whether to remove the top and right spines from the plot. + + Returns + ------- + display : PredictionErrorDisplay + Object that stores the computed values. + """ + random_state = check_random_state(random_state) + + n_samples = len(y_true) + if isinstance(subsample, numbers.Integral): + if subsample <= 0: + raise ValueError( + f"When an integer, subsample={subsample} should be positive." + ) + elif isinstance(subsample, numbers.Real): + if subsample <= 0 or subsample >= 1: + raise ValueError( + f"When a floating-point, subsample={subsample} should" + " be in the (0, 1) range." 
+ ) + subsample = int(n_samples * subsample) + + if subsample is not None and subsample < n_samples: + indices = random_state.choice(np.arange(n_samples), size=subsample) + y_true = _safe_indexing(y_true, indices, axis=0) + y_pred = _safe_indexing(y_pred, indices, axis=0) + + viz = cls( + y_true=y_true, + y_pred=y_pred, + estimator_name=estimator_name, + data_source=data_source, + ) + + viz.plot( + ax=ax, + estimator_name=estimator_name, + kind=kind, + scatter_kwargs=scatter_kwargs, + line_kwargs=line_kwargs, + despine=despine, + ) + + return viz diff --git a/skore/src/skore/sklearn/_plot/roc_curve.py b/skore/src/skore/sklearn/_plot/roc_curve.py new file mode 100644 index 000000000..013d491e3 --- /dev/null +++ b/skore/src/skore/sklearn/_plot/roc_curve.py @@ -0,0 +1,399 @@ +from sklearn.metrics import auc, roc_curve +from sklearn.preprocessing import LabelBinarizer + +from skore.sklearn._plot.utils import ( + HelpDisplayMixin, + _ClassifierCurveDisplayMixin, + _despine_matplotlib_axis, + _validate_style_kwargs, +) + + +class RocCurveDisplay(HelpDisplayMixin, _ClassifierCurveDisplayMixin): + """ROC Curve visualization. + + An instance of this class is should created by `EstimatorReport.metrics.plot.roc()`. + You should not create an instance of this class directly. + + Parameters + ---------- + fpr : dict of list of ndarray + False positive rate. The structure is: + + - for binary classification: + - the key is the positive label. + - the value is a list of `ndarray`, each `ndarray` being the false + positive rate. + - for multiclass classification: + - the key is the class of interest in an OvR fashion. + - the value is a list of `ndarray`, each `ndarray` being the false + positive rate. + + tpr : dict of list of ndarray + True positive rate. The structure is: + + - for binary classification: + - the key is the positive label + - the value is a list of `ndarray`, each `ndarray` being the true + positive rate. 
def plot(
    self,
    ax=None,
    *,
    estimator_name=None,
    roc_curve_kwargs=None,
    plot_chance_level=True,
    chance_level_kwargs=None,
    despine=True,
):
    """Plot visualization.

    Draw one ROC curve (binary classification) or one ROC curve per class
    (multiclass classification) from the stored `fpr`, `tpr` and `roc_auc`.
    The axes, figure and artists are stored on the display as `ax_`,
    `figure_`, `lines_` and `chance_level_`.

    Parameters
    ----------
    ax : matplotlib axes, default=None
        Axes object to plot on. If `None`, a new figure and axes is
        created.

    estimator_name : str, default=None
        Name of the estimator, used as the legend title. If `None`, the
        `estimator_name` stored on the display is used.

    roc_curve_kwargs : dict or list of dict, default=None
        Keyword arguments to be passed to matplotlib's `plot` for rendering
        the ROC curve(s). For a single curve, a dictionary or a list with a
        single dictionary is accepted (an empty list is treated as no
        customization). For multiple curves, a list with one dictionary per
        curve is required.

    plot_chance_level : bool, default=True
        Whether to plot the chance level.

    chance_level_kwargs : dict, default=None
        Keyword arguments to be passed to matplotlib's `plot` for rendering
        the chance level line.

    despine : bool, default=True
        Whether to remove the top and right spines from the plot.

    Returns
    -------
    None
        Nothing is returned; the results are stored on the display.

    Raises
    ------
    ValueError
        If `roc_curve_kwargs` does not match the number of curves to plot.
    NotImplementedError
        If more than one split is stored for a curve (cross-validation).
    """
    self.ax_, self.figure_, estimator_name = self._validate_plot_params(
        ax=ax, estimator_name=estimator_name
    )

    self.lines_ = []
    if len(self.fpr) == 1:  # binary-classification
        if len(self.fpr[self.pos_label]) == 1:  # single-split
            if roc_curve_kwargs is None:
                roc_curve_kwargs = {}
            elif isinstance(roc_curve_kwargs, list):
                if len(roc_curve_kwargs) > 1:
                    raise ValueError(
                        "You intend to plot a single ROC curve and provide "
                        "multiple ROC curve keyword arguments. Provide a single "
                        "dictionary or a list with a single dictionary."
                    )
                # An empty list used to raise an IndexError on `[0]`; it is
                # now treated as "no custom style".
                roc_curve_kwargs = roc_curve_kwargs[0] if roc_curve_kwargs else {}

            fpr = self.fpr[self.pos_label][0]
            tpr = self.tpr[self.pos_label][0]
            roc_auc = self.roc_auc[self.pos_label][0]

            default_line_kwargs = {}
            if roc_auc is not None and self.data_source in ("train", "test"):
                default_line_kwargs["label"] = (
                    f"{self.data_source.title()} set (AUC = {roc_auc:0.2f})"
                )
            elif roc_auc is not None:  # data_source in (None, "X_y")
                default_line_kwargs["label"] = f"AUC = {roc_auc:0.2f}"

            line_kwargs = _validate_style_kwargs(
                default_line_kwargs, roc_curve_kwargs
            )

            (line_,) = self.ax_.plot(fpr, tpr, **line_kwargs)
            self.lines_.append(line_)
        else:  # cross-validation
            raise NotImplementedError(
                "We don't support yet cross-validation"
            )  # pragma: no cover

        info_pos_label = (
            f"\n(Positive label: {self.pos_label})"
            if self.pos_label is not None
            else ""
        )
    else:  # multiclass-classification
        info_pos_label = None  # irrelevant for multiclass
        if roc_curve_kwargs is None:
            roc_curve_kwargs = [{}] * len(self.fpr)
        elif isinstance(roc_curve_kwargs, list):
            if len(roc_curve_kwargs) != len(self.fpr):
                raise ValueError(
                    "You intend to plot multiple ROC curves. We expect "
                    "`roc_curve_kwargs` to be a list of dictionaries with the "
                    "same length as the number of ROC curves. Got "
                    f"{len(roc_curve_kwargs)} instead of "
                    f"{len(self.fpr)}."
                )
        else:
            raise ValueError(
                "You intend to plot multiple ROC curves. We expect "
                "`roc_curve_kwargs` to be a list of dictionaries of "
                f"{len(self.fpr)} elements. Got {roc_curve_kwargs!r} instead."
            )

        # one curve per class, drawn in the insertion order of `self.fpr`
        for class_idx, class_ in enumerate(self.fpr):
            fpr_class = self.fpr[class_]
            tpr_class = self.tpr[class_]
            roc_auc_class = self.roc_auc[class_]
            roc_curve_kwargs_class = roc_curve_kwargs[class_idx]

            if len(fpr_class) == 1:  # single-split
                fpr = fpr_class[0]
                tpr = tpr_class[0]
                roc_auc = roc_auc_class[0]

                default_line_kwargs = {}
                if roc_auc is not None and self.data_source in ("train", "test"):
                    default_line_kwargs["label"] = (
                        f"{str(class_).title()} - {self.data_source} "
                        f"set (AUC = {roc_auc:0.2f})"
                    )
                elif roc_auc is not None:  # data_source in (None, "X_y")
                    default_line_kwargs["label"] = (
                        f"{str(class_).title()} AUC = {roc_auc:0.2f}"
                    )

                line_kwargs = _validate_style_kwargs(
                    default_line_kwargs, roc_curve_kwargs_class
                )

                (line_,) = self.ax_.plot(fpr, tpr, **line_kwargs)
                self.lines_.append(line_)
            else:  # cross-validation
                raise NotImplementedError(
                    "We don't support yet cross-validation"
                )  # pragma: no cover

    default_chance_level_line_kw = {
        "label": "Chance level (AUC = 0.5)",
        "color": "k",
        "linestyle": "--",
    }

    if chance_level_kwargs is None:
        chance_level_kwargs = {}

    chance_level_kwargs = _validate_style_kwargs(
        default_chance_level_line_kw, chance_level_kwargs
    )

    xlabel = "False Positive Rate"
    ylabel = "True Positive Rate"
    if info_pos_label:
        # only set in the binary case when a positive label is known
        xlabel += info_pos_label
        ylabel += info_pos_label

    self.ax_.set(
        xlabel=xlabel,
        xlim=(-0.01, 1.01),
        ylabel=ylabel,
        ylim=(-0.01, 1.01),
        aspect="equal",
    )

    if plot_chance_level:
        (self.chance_level_,) = self.ax_.plot((0, 1), (0, 1), **chance_level_kwargs)
    else:
        self.chance_level_ = None

    if despine:
        _despine_matplotlib_axis(self.ax_)

    self.ax_.legend(loc="lower right", title=estimator_name)
class HelpDisplayMixin:
    """Mixin class to add help functionality to a class.

    The host class is expected to expose an `estimator_name` attribute, which
    is used in the title of the help panel.
    """

    def _get_attributes_for_help(self):
        """Get the attributes ending with '_' to display in help."""
        return sorted(
            f".{name}"
            for name in dir(self)
            if name.endswith("_") and not name.startswith("_")
        )

    def _ordered_attributes_for_help(self):
        """Get the help attributes with `.figure_` and `.ax_` listed first.

        Unlike an unconditional `list.remove`, this does not fail when the
        display has not been plotted yet and those attributes are missing.
        """
        attributes = self._get_attributes_for_help()
        leading = [attr for attr in (".figure_", ".ax_") if attr in attributes]
        return leading + [attr for attr in attributes if attr not in leading]

    def _get_methods_for_help(self):
        """Get the public methods to display in help."""
        methods = inspect.getmembers(self, predicate=inspect.ismethod)
        filtered_methods = []
        for name, method in methods:
            is_private = name.startswith("_")
            # a method bound to the class itself is a class method
            is_class_method = inspect.ismethod(method) and method.__self__ is type(self)
            is_help_method = name == "help"
            if not (is_private or is_class_method or is_help_method):
                filtered_methods.append((f".{name}(...)", method))
        return sorted(filtered_methods)

    def _create_help_tree(self):
        """Create a rich Tree with attributes and methods."""
        tree = Tree("display")

        attr_branch = tree.add("[bold cyan] Attributes[/bold cyan]")
        for attr in self._ordered_attributes_for_help():
            attr_branch.add(attr)

        method_branch = tree.add("[bold cyan]Methods[/bold cyan]")
        for name, method in self._get_methods_for_help():
            # only the first line of the docstring is shown
            description = (
                method.__doc__.split("\n")[0]
                if method.__doc__
                else "No description available"
            )
            method_branch.add(f"{name} - {description}")

        return tree

    def _create_help_panel(self):
        """Wrap the help tree in a rich Panel titled with the display name."""
        return Panel(
            self._create_help_tree(),
            title=(
                f"[bold cyan]📊 {self.__class__.__name__} for {self.estimator_name}"
                "[/bold cyan]"
            ),
            border_style="orange1",
            expand=False,
        )

    def help(self):
        """Display available attributes and methods using rich."""
        from skore import console  # avoid circular import

        console.print(self._create_help_panel())

    def __repr__(self):
        """Return a string representation using rich."""
        # render into an in-memory buffer instead of the terminal
        console = Console(file=StringIO(), force_terminal=False)
        console.print(self._create_help_panel())
        return console.file.getvalue()
def _check_supported_ml_task(supported_ml_tasks):
    """Build a checker validating the ML task of an accessor's parent.

    Parameters
    ----------
    supported_ml_tasks : iterable of str
        Task names (or fragments of task names) that are accepted. A task is
        accepted as soon as one of these strings is contained in it.

    Returns
    -------
    check : callable
        Function taking an accessor and returning `True` when the task stored
        in `accessor._parent._ml_task` is supported. It raises an
        `AttributeError` otherwise.
    """

    def check(accessor):
        ml_task = accessor._parent._ml_task
        for candidate in supported_ml_tasks:
            # containment test: a fragment of the task name is enough
            if candidate in ml_task:
                return True

        raise AttributeError(
            f"The {ml_task} task is not a supported task by "
            f"function called. The supported tasks are {supported_ml_tasks}."
        )

    return check
random_state=42) + report = EstimatorReport( + estimator, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test + ) + display = getattr(report.metrics.plot, plot_func)() + + display.help() + captured = capsys.readouterr() + assert f"📊 {display.__class__.__name__}" in captured.out + + +@pytest.mark.parametrize( + "plot_func, estimator, dataset", + [ + ("roc", LogisticRegression(), make_classification(random_state=42)), + ( + "precision_recall", + LogisticRegression(), + make_classification(random_state=42), + ), + ("prediction_error", LinearRegression(), make_regression(random_state=42)), + ], +) +def test_display_repr(pyplot, plot_func, estimator, dataset): + """Check that __repr__ returns a string starting with the expected prefix.""" + X_train, X_test, y_train, y_test = train_test_split(*dataset, random_state=42) + report = EstimatorReport( + estimator, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test + ) + display = getattr(report.metrics.plot, plot_func)() + + repr_str = repr(display) + assert f"📊 {display.__class__.__name__}" in repr_str + + +@pytest.mark.parametrize( + "plot_func, estimator, dataset", + [ + ("roc", LogisticRegression(), make_classification(random_state=42)), + ( + "precision_recall", + LogisticRegression(), + make_classification(random_state=42), + ), + ("prediction_error", LinearRegression(), make_regression(random_state=42)), + ], +) +def test_display_provide_ax(pyplot, plot_func, estimator, dataset): + """Check that we can provide an ax to the plot method.""" + X_train, X_test, y_train, y_test = train_test_split(*dataset, random_state=42) + report = EstimatorReport( + estimator, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test + ) + display = getattr(report.metrics.plot, plot_func)() + + _, ax = pyplot.subplots() + display.plot(ax=ax) + assert display.ax_ is ax diff --git a/skore/tests/unit/sklearn/plot/test_precision_recall_curve.py b/skore/tests/unit/sklearn/plot/test_precision_recall_curve.py new 
file mode 100644 index 000000000..ec42899f6 --- /dev/null +++ b/skore/tests/unit/sklearn/plot/test_precision_recall_curve.py @@ -0,0 +1,238 @@ +import matplotlib as mpl +import pytest +from sklearn.datasets import make_classification +from sklearn.linear_model import LogisticRegression +from sklearn.model_selection import train_test_split +from skore import EstimatorReport +from skore.sklearn._plot import PrecisionRecallCurveDisplay + + +@pytest.fixture +def binary_classification_data(): + X, y = make_classification(random_state=42) + X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42) + return LogisticRegression().fit(X_train, y_train), X_train, X_test, y_train, y_test + + +@pytest.fixture +def multiclass_classification_data(): + X, y = make_classification(n_classes=3, n_clusters_per_class=1, random_state=42) + X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42) + return LogisticRegression().fit(X_train, y_train), X_train, X_test, y_train, y_test + + +def test_precision_recall_curve_display_binary_classification( + pyplot, binary_classification_data +): + """Check the attributes and default plotting behaviour of the + precision-recall curve plot with binary data. 
+ """ + estimator, X_train, X_test, y_train, y_test = binary_classification_data + report = EstimatorReport( + estimator, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test + ) + display = report.metrics.plot.precision_recall() + assert isinstance(display, PrecisionRecallCurveDisplay) + + # check the structure of the attributes + for attr_name in ("precision", "recall", "average_precision", "prevalence"): + assert isinstance(getattr(display, attr_name), dict) + assert len(getattr(display, attr_name)) == 1 + + attr = getattr(display, attr_name) + assert list(attr.keys()) == [estimator.classes_[1]] + assert list(attr.keys()) == [display.pos_label] + assert isinstance(attr[estimator.classes_[1]], list) + assert len(attr[estimator.classes_[1]]) == 1 + + assert isinstance(display.lines_, list) + assert len(display.lines_) == 1 + precision_recall_curve_mpl = display.lines_[0] + assert isinstance(precision_recall_curve_mpl, mpl.lines.Line2D) + assert ( + precision_recall_curve_mpl.get_label() + == f"Test set (AP = {display.average_precision[estimator.classes_[1]][0]:0.2f})" + ) + assert precision_recall_curve_mpl.get_color() == "#1f77b4" # tab:blue in hex + + assert display.chance_levels_ is None + + assert isinstance(display.ax_, mpl.axes.Axes) + legend = display.ax_.get_legend() + assert legend.get_title().get_text() == estimator.__class__.__name__ + assert len(legend.get_texts()) == 1 + + assert display.ax_.get_xlabel() == "Recall\n(Positive label: 1)" + assert display.ax_.get_ylabel() == "Precision\n(Positive label: 1)" + assert display.ax_.get_adjustable() == "box" + assert display.ax_.get_aspect() in ("equal", 1.0) + assert display.ax_.get_xlim() == display.ax_.get_ylim() == (-0.01, 1.01) + + +def test_precision_recall_curve_display_data_source(pyplot, binary_classification_data): + """Check that we can pass the `data_source` argument to the precision-recall + curve plot. 
def test_precision_recall_curve_display_multiclass_classification(
    pyplot, multiclass_classification_data
):
    """Check the attributes and default plotting behaviour of the precision-recall
    curve plot with multiclass data.
    """
    estimator, X_train, X_test, y_train, y_test = multiclass_classification_data
    report = EstimatorReport(
        estimator, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test
    )
    display = report.metrics.plot.precision_recall()
    assert isinstance(display, PrecisionRecallCurveDisplay)

    # check the structure of the attributes
    for attr_name in ("precision", "recall", "average_precision", "prevalence"):
        assert isinstance(getattr(display, attr_name), dict)
        assert len(getattr(display, attr_name)) == len(estimator.classes_)

        attr = getattr(display, attr_name)
        for class_label in estimator.classes_:
            assert isinstance(attr[class_label], list)
            assert len(attr[class_label]) == 1

    assert isinstance(display.lines_, list)
    assert len(display.lines_) == len(estimator.classes_)
    default_colors = ["#1f77b4", "#ff7f0e", "#2ca02c"]
    # `lines_` is a list: index it by position, not by class label, so the
    # check does not silently depend on the labels being 0..n-1
    for class_idx, (class_label, expected_color) in enumerate(
        zip(estimator.classes_, default_colors)
    ):
        precision_recall_curve_mpl = display.lines_[class_idx]
        assert isinstance(precision_recall_curve_mpl, mpl.lines.Line2D)
        assert precision_recall_curve_mpl.get_label() == (
            f"{str(class_label).title()} - test set "
            f"(AP = {display.average_precision[class_label][0]:0.2f})"
        )
        assert precision_recall_curve_mpl.get_color() == expected_color

    assert display.chance_levels_ is None

    assert isinstance(display.ax_, mpl.axes.Axes)
    legend = display.ax_.get_legend()
    assert legend.get_title().get_text() == estimator.__class__.__name__
    assert len(legend.get_texts()) == 3

    assert display.ax_.get_xlabel() == "Recall"
    assert display.ax_.get_ylabel() == "Precision"
    assert display.ax_.get_adjustable() == "box"
    assert display.ax_.get_aspect() in ("equal", 1.0)
    assert display.ax_.get_xlim() == display.ax_.get_ylim() == (-0.01, 1.01)
def test_precision_recall_curve_display_plot_error_wrong_pr_curve_kwargs(
    pyplot, binary_classification_data, multiclass_classification_data
):
    """Check that we raise a proper error message when passing an inappropriate
    value for the `pr_curve_kwargs` or `chance_level_kwargs` argument.
    """
    # binary case: a single curve is plotted, so a list of two dicts is rejected
    estimator, X_train, X_test, y_train, y_test = binary_classification_data
    report = EstimatorReport(
        estimator, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test
    )
    display = report.metrics.plot.precision_recall()
    err_msg = (
        "You intend to plot a single precision-recall curve and provide multiple "
        "precision-recall curve keyword arguments"
    )
    with pytest.raises(ValueError, match=err_msg):
        display.plot(pr_curve_kwargs=[{}, {}])

    # same rule for the chance level line in the binary case
    err_msg = (
        "You intend to plot a single chance level line and provide multiple chance "
        "level line keyword arguments"
    )
    with pytest.raises(ValueError, match=err_msg):
        display.plot(plot_chance_level=True, chance_level_kwargs=[{}, {}])

    # multiclass case: both a two-element list and a bare dict are rejected
    estimator, X_train, X_test, y_train, y_test = multiclass_classification_data
    report = EstimatorReport(
        estimator, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test
    )
    display = report.metrics.plot.precision_recall()
    err_msg = "You intend to plot multiple precision-recall curves."
    with pytest.raises(ValueError, match=err_msg):
        display.plot(pr_curve_kwargs=[{}, {}])

    with pytest.raises(ValueError, match=err_msg):
        display.plot(pr_curve_kwargs={})

    # the chance level keyword arguments follow the same rule in multiclass
    err_msg = (
        "You intend to plot multiple precision-recall curves. We expect "
        "`chance_level_kwargs` to be a list"
    )
    with pytest.raises(ValueError, match=err_msg):
        display.plot(plot_chance_level=True, chance_level_kwargs=[{}, {}])

    with pytest.raises(ValueError, match=err_msg):
        display.plot(plot_chance_level=True, chance_level_kwargs={})
def test_prediction_error_display_regression(pyplot, regression_data):
    """Verify the default prediction error display built from a regression report.

    Covers the stored arrays and data source, the reference line, the scatter
    artist, the legend and the axis labels.
    """
    regressor, X_train, X_test, y_train, y_test = regression_data
    display = EstimatorReport(
        regressor, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test
    ).metrics.plot.prediction_error()
    assert isinstance(display, PredictionErrorDisplay)

    # stored data: test targets and the matching predictions
    for stored in (display.y_true, display.y_pred):
        assert isinstance(stored, np.ndarray)
    np.testing.assert_allclose(display.y_true, y_test)
    np.testing.assert_allclose(display.y_pred, regressor.predict(X_test))
    assert display.data_source == "test"

    reference_line = display.line_
    assert isinstance(reference_line, mpl.lines.Line2D)
    assert reference_line.get_label() == "Perfect predictions"
    assert reference_line.get_color() == "black"

    assert isinstance(display.scatter_, mpl.collections.PathCollection)

    axes = display.ax_
    assert isinstance(axes, mpl.axes.Axes)
    legend = axes.get_legend()
    assert legend.get_title().get_text() == regressor.__class__.__name__
    assert len(legend.get_texts()) == 2

    assert axes.get_xlabel() == "Predicted values"
    assert axes.get_ylabel() == "Residuals (actual - predicted)"
def test_prediction_error_display_data_source(pyplot, regression_data):
    """Verify the scatter label produced for each non-default `data_source`."""
    estimator, X_train, X_test, y_train, y_test = regression_data
    report = EstimatorReport(
        estimator, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test
    )

    # (keyword arguments of the plot call, expected scatter label)
    cases = (
        ({"data_source": "train"}, "Train set"),
        ({"data_source": "X_y", "X": X_train, "y": y_train}, "Data set"),
    )
    for plot_kwargs, scatter_label in cases:
        display = report.metrics.plot.prediction_error(**plot_kwargs)
        assert display.line_.get_label() == "Perfect predictions"
        assert display.scatter_.get_label() == scatter_label
@pytest.fixture
def binary_classification_data():
    """Return a fitted `LogisticRegression` with its train/test split.

    The tuple layout is `(estimator, X_train, X_test, y_train, y_test)`.
    """
    features, target = make_classification(random_state=42)
    X_train, X_test, y_train, y_test = train_test_split(
        features, target, random_state=42
    )
    classifier = LogisticRegression()
    classifier.fit(X_train, y_train)
    return classifier, X_train, X_test, y_train, y_test
def test_roc_curve_display_multiclass_classification(
    pyplot, multiclass_classification_data
):
    """Check the attributes and default plotting behaviour of the ROC curve plot with
    multiclass data.

    One curve is expected per class. Curves are paired with classes by their
    position in `display.lines_` rather than by using the class label as a list
    index, so the check does not depend on the labels being `0..n_classes-1`.
    """
    estimator, X_train, X_test, y_train, y_test = multiclass_classification_data
    report = EstimatorReport(
        estimator, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test
    )
    display = report.metrics.plot.roc()
    assert isinstance(display, RocCurveDisplay)

    # check the structure of the attributes: one single-element list per class
    for attr_name in ("fpr", "tpr", "roc_auc"):
        attr = getattr(display, attr_name)
        assert isinstance(attr, dict)
        assert len(attr) == len(estimator.classes_)
        for class_label in estimator.classes_:
            assert isinstance(attr[class_label], list)
            assert len(attr[class_label]) == 1

    assert isinstance(display.lines_, list)
    assert len(display.lines_) == len(estimator.classes_)
    default_colors = ["#1f77b4", "#ff7f0e", "#2ca02c"]
    # `zip` stops at the shortest input; make sure no class is silently skipped
    assert len(default_colors) == len(estimator.classes_)
    for class_label, roc_curve_mpl, expected_color in zip(
        estimator.classes_, display.lines_, default_colors
    ):
        assert isinstance(roc_curve_mpl, mpl.lines.Line2D)
        assert roc_curve_mpl.get_label() == (
            f"{str(class_label).title()} - test set "
            f"(AUC = {display.roc_auc[class_label][0]:0.2f})"
        )
        assert roc_curve_mpl.get_color() == expected_color

    # the chance level line is drawn by default
    assert isinstance(display.chance_level_, mpl.lines.Line2D)
    assert display.chance_level_.get_label() == "Chance level (AUC = 0.5)"
    assert display.chance_level_.get_color() == "k"

    assert isinstance(display.ax_, mpl.axes.Axes)
    legend = display.ax_.get_legend()
    assert legend.get_title().get_text() == estimator.__class__.__name__
    # one legend entry per class curve plus the chance level
    assert len(legend.get_texts()) == 4

    assert display.ax_.get_xlabel() == "False Positive Rate"
    assert display.ax_.get_ylabel() == "True Positive Rate"
    assert display.ax_.get_adjustable() == "box"
    assert display.ax_.get_aspect() in ("equal", 1.0)
    assert display.ax_.get_xlim() == display.ax_.get_ylim() == (-0.01, 1.01)
def test_roc_curve_display_plot_error_wrong_roc_curve_kwargs(
    pyplot, binary_classification_data, multiclass_classification_data
):
    """Verify the errors raised for ill-shaped `roc_curve_kwargs` values.

    A binary display rejects a list of several dicts; a multiclass display
    rejects both a two-element list and a bare dict.
    """

    def _roc_display(data):
        # build a report from a fixture tuple and return its ROC display
        estimator, X_train, X_test, y_train, y_test = data
        report = EstimatorReport(
            estimator, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test
        )
        return report.metrics.plot.roc()

    display = _roc_display(binary_classification_data)
    err_msg = (
        "You intend to plot a single ROC curve and provide multiple ROC curve "
        "keyword arguments"
    )
    with pytest.raises(ValueError, match=err_msg):
        display.plot(roc_curve_kwargs=[{}, {}])

    display = _roc_display(multiclass_classification_data)
    err_msg = "You intend to plot multiple ROC curves."
    for bad_kwargs in ([{}, {}], {}):
        with pytest.raises(ValueError, match=err_msg):
            display.plot(roc_curve_kwargs=bad_kwargs)
2}, + {"linestyle": "dashed"}, + {"color": "blue", "linewidth": 2, "linestyle": "dashed"}, + ), + ( + {"color": "blue", "linestyle": "solid"}, + {"c": "red", "ls": "dashed"}, + {"color": "red", "linestyle": "dashed"}, + ), + ( + {"label": "xxx", "color": "k", "linestyle": "--"}, + {"ls": "-."}, + {"label": "xxx", "color": "k", "linestyle": "-."}, + ), + ({}, {}, {}), + ( + {}, + { + "ls": "dashed", + "c": "red", + "ec": "black", + "fc": "yellow", + "lw": 2, + "mec": "green", + "mfcalt": "blue", + "ms": 5, + }, + { + "linestyle": "dashed", + "color": "red", + "edgecolor": "black", + "facecolor": "yellow", + "linewidth": 2, + "markeredgecolor": "green", + "markerfacecoloralt": "blue", + "markersize": 5, + }, + ), + ], +) +def test_validate_style_kwargs(default_kwargs, user_kwargs, expected): + """Check the behaviour of `validate_style_kwargs` with various type of entries.""" + result = _validate_style_kwargs(default_kwargs, user_kwargs) + assert result == expected, ( + "The validation of style keywords does not provide the expected results: " + f"Got {result} instead of {expected}." 
@pytest.mark.parametrize(
    "default_kwargs, user_kwargs",
    [
        ({}, {"ls": 2, "linestyle": 3}),
        ({}, {"c": "r", "color": "blue"}),
    ],
)
def test_validate_style_kwargs_error(default_kwargs, user_kwargs):
    """Verify that `_validate_style_kwargs` raises `TypeError` when the user passes
    both a short alias and its long name (e.g. `ls` and `linestyle`)."""
    with pytest.raises(TypeError):
        _validate_style_kwargs(default_kwargs, user_kwargs)
@pytest.fixture
def binary_classification_data_pipeline():
    """Return a fitted scaler + logistic regression pipeline and the held-out data.

    The tuple layout is `(fitted_pipeline, X_test, y_test)`.
    """
    features, target = make_classification(random_state=42)
    X_train, X_test, y_train, y_test = train_test_split(
        features, target, test_size=0.2, random_state=42
    )
    steps = [("scaler", StandardScaler()), ("clf", LogisticRegression())]
    pipeline = Pipeline(steps)
    pipeline.fit(X_train, y_train)
    return pipeline, X_test, y_test
def test_check_supported_estimator():
    """Exercise `_check_supported_estimator` on supported and unsupported estimators.

    The check receives an accessor exposing `_parent.estimator`; a bare
    `LogisticRegression` and a pipeline ending with one are accepted, while a
    `RandomForestClassifier` raises `AttributeError`.
    """

    class MockParent:
        def __init__(self, estimator):
            self.estimator = estimator

    class MockAccessor:
        def __init__(self, parent):
            self._parent = parent

    def _accessor_for(estimator):
        # wrap the estimator the way the check expects to find it
        return MockAccessor(MockParent(estimator))

    check = _check_supported_estimator((LogisticRegression,))

    supported_estimators = (
        LogisticRegression(),
        Pipeline([("clf", LogisticRegression())]),
    )
    for supported in supported_estimators:
        assert check(_accessor_for(supported))

    unsupported_accessor = _accessor_for(RandomForestClassifier())
    err_msg = (
        "The RandomForestClassifier estimator is not supported by the function called."
    )
    with pytest.raises(AttributeError, match=err_msg):
        check(unsupported_accessor)
" + with pytest.raises(ValueError, match=err_msg): + EstimatorReport(estimator, fit=fit) + + +@pytest.mark.parametrize("fit", [True, "auto"]) +def test_estimator_report_from_unfitted_estimator(fit): + """Check the general behaviour of passing an unfitted estimator and training + data.""" + X, y = make_regression(random_state=42) + X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42) + estimator = LinearRegression() + report = EstimatorReport( + estimator, + fit=fit, + X_train=X_train, + y_train=y_train, + X_test=X_test, + y_test=y_test, + ) + + check_is_fitted(report.estimator) + assert report.estimator is not estimator # the estimator should be cloned + + assert report.X_train is X_train + assert report.y_train is y_train + assert report.X_test is X_test + assert report.y_test is y_test + + err_msg = "attribute is immutable" + with pytest.raises(AttributeError, match=err_msg): + report.estimator = LinearRegression() + with pytest.raises(AttributeError, match=err_msg): + report.X_train = X_train + with pytest.raises(AttributeError, match=err_msg): + report.y_train = y_train + + +@pytest.mark.parametrize("fit", [False, "auto"]) +def test_estimator_report_from_fitted_estimator(binary_classification_data, fit): + """Check the general behaviour of passing an already fitted estimator without + refitting it.""" + estimator, X, y = binary_classification_data + report = EstimatorReport(estimator, fit=fit, X_test=X, y_test=y) + + assert report.estimator is estimator # we should not clone the estimator + assert report.X_train is None + assert report.y_train is None + assert report.X_test is X + assert report.y_test is y + + err_msg = "attribute is immutable" + with pytest.raises(AttributeError, match=err_msg): + report.estimator = LinearRegression() + with pytest.raises(AttributeError, match=err_msg): + report.X_train = X + with pytest.raises(AttributeError, match=err_msg): + report.y_train = y + + +def 
def test_estimator_report_invalidate_cache_data(binary_classification_data):
    """Verify that reassigning `X_test` or `y_test` empties the report cache."""
    estimator, X_test, y_test = binary_classification_data
    report = EstimatorReport(estimator, X_test=X_test, y_test=y_test)

    # pretend something was cached, then replace the test features
    report._cache["mocking"] = "mocking"
    report.X_test = None
    assert report._cache == {}

    # same scenario when replacing the test target
    report._cache["mocking"] = "mocking"
    report.y_test = None
    assert report._cache == {}
@pytest.mark.parametrize(
    "fixture_name", ["binary_classification_data", "regression_data"]
)
def test_estimator_report_cache_predictions(request, fixture_name):
    """Verify that `cache_predictions` fills the cache and that a second call keeps
    the same set of keys."""
    estimator, X_test, y_test = request.getfixturevalue(fixture_name)
    # the held-out data doubles as training data so both sources are available
    report = EstimatorReport(
        estimator, X_train=X_test, y_train=y_test, X_test=X_test, y_test=y_test
    )
    assert report._cache == {}

    report.cache_predictions()
    assert report._cache != {}
    snapshot = deepcopy(report._cache)

    report.cache_predictions()
    # the second call must leave exactly the same keys in the cache
    assert report._cache.keys() == snapshot.keys()
def test_estimator_report_plot_roc(pyplot, binary_classification_data):
    """Check that the ROC plot method works.

    The `pyplot` fixture is requested because calling the plot method draws a
    figure, as in the other plotting tests of this module.
    """
    estimator, X_test, y_test = binary_classification_data
    report = EstimatorReport(estimator, X_test=X_test, y_test=y_test)
    assert isinstance(report.metrics.plot.roc(), RocCurveDisplay)
getattr(report.metrics.plot, display)() + assert display_first_call is display_second_call + + +@pytest.mark.parametrize("display", ["roc", "precision_recall"]) +def test_estimator_report_display_binary_classification_external_data( + pyplot, binary_classification_data, display +): + """General behaviour of the function creating display on binary classification + when passing external data. + """ + estimator, X_test, y_test = binary_classification_data + report = EstimatorReport(estimator) + assert hasattr(report.metrics.plot, display) + display_first_call = getattr(report.metrics.plot, display)( + data_source="X_y", X=X_test, y=y_test + ) + assert report._cache != {} + display_second_call = getattr(report.metrics.plot, display)( + data_source="X_y", X=X_test, y=y_test + ) + assert display_first_call is display_second_call + + +@pytest.mark.parametrize("display", ["prediction_error"]) +def test_estimator_report_display_regression_external_data( + pyplot, regression_data, display +): + """General behaviour of the function creating display on regression when passing + external data. 
+ """ + estimator, X_test, y_test = regression_data + report = EstimatorReport(estimator) + assert hasattr(report.metrics.plot, display) + display_first_call = getattr(report.metrics.plot, display)( + data_source="X_y", X=X_test, y=y_test + ) + assert report._cache != {} + display_second_call = getattr(report.metrics.plot, display)( + data_source="X_y", X=X_test, y=y_test + ) + assert display_first_call is display_second_call + + +@pytest.mark.parametrize("display", ["roc", "precision_recall"]) +def test_estimator_report_display_binary_classification_switching_data_source( + pyplot, binary_classification_data, display +): + """Check that we don't hit the cache when switching the data source.""" + estimator, X_test, y_test = binary_classification_data + report = EstimatorReport( + estimator, X_train=X_test, y_train=y_test, X_test=X_test, y_test=y_test + ) + assert hasattr(report.metrics.plot, display) + display_first_call = getattr(report.metrics.plot, display)(data_source="test") + assert report._cache != {} + display_second_call = getattr(report.metrics.plot, display)(data_source="train") + assert display_first_call is not display_second_call + display_third_call = getattr(report.metrics.plot, display)( + data_source="X_y", X=X_test, y=y_test + ) + assert display_first_call is not display_third_call + assert display_second_call is not display_third_call + + +@pytest.mark.parametrize("display", ["prediction_error"]) +def test_estimator_report_display_regression_switching_data_source( + pyplot, regression_data, display +): + """Check that we don't hit the cache when switching the data source.""" + estimator, X_test, y_test = regression_data + report = EstimatorReport( + estimator, X_train=X_test, y_train=y_test, X_test=X_test, y_test=y_test + ) + assert hasattr(report.metrics.plot, display) + display_first_call = getattr(report.metrics.plot, display)(data_source="test") + assert report._cache != {} + display_second_call = getattr(report.metrics.plot, 
display)(data_source="train") + assert display_first_call is not display_second_call + display_third_call = getattr(report.metrics.plot, display)( + data_source="X_y", X=X_test, y=y_test + ) + assert display_first_call is not display_third_call + assert display_second_call is not display_third_call + + +######################################################################################## +# Check the metrics methods +######################################################################################## + + +def test_estimator_report_metrics_help(capsys, binary_classification_data): + """Check that the help method writes to the console.""" + estimator, X_test, y_test = binary_classification_data + report = EstimatorReport(estimator, X_test=X_test, y_test=y_test) + + report.metrics.help() + captured = capsys.readouterr() + assert "📏 Available metrics methods" in captured.out + + +def test_estimator_report_metrics_repr(binary_classification_data): + """Check that __repr__ returns a string starting with the expected prefix.""" + estimator, X_test, y_test = binary_classification_data + report = EstimatorReport(estimator, X_test=X_test, y_test=y_test) + + repr_str = repr(report.metrics) + assert "📏 Available metrics methods" in repr_str + + +@pytest.mark.parametrize( + "metric", ["accuracy", "precision", "recall", "brier_score", "roc_auc", "log_loss"] +) +def test_estimator_report_metrics_binary_classification( + binary_classification_data, metric +): + """Check the behaviour of the metrics methods available for binary + classification. 
+ """ + estimator, X_test, y_test = binary_classification_data + report = EstimatorReport(estimator, X_test=X_test, y_test=y_test) + assert hasattr(report.metrics, metric) + result = getattr(report.metrics, metric)() + assert isinstance(result, pd.DataFrame) + # check that we hit the cache + result_with_cache = getattr(report.metrics, metric)() + pd.testing.assert_frame_equal(result, result_with_cache) + + # check that something was written to the cache + assert report._cache != {} + report.clean_cache() + + # check that passing data from outside the report works and that the results + # don't come from the cache + result_external_data = getattr(report.metrics, metric)( + data_source="X_y", X=X_test, y=y_test + ) + assert isinstance(result_external_data, pd.DataFrame) + pd.testing.assert_frame_equal(result, result_external_data) + assert report._cache != {} + + +@pytest.mark.parametrize("metric", ["r2", "rmse"]) +def test_estimator_report_metrics_regression(regression_data, metric): + """Check the behaviour of the metrics methods available for regression.""" + estimator, X_test, y_test = regression_data + report = EstimatorReport(estimator, X_test=X_test, y_test=y_test) + assert hasattr(report.metrics, metric) + result = getattr(report.metrics, metric)() + assert isinstance(result, pd.DataFrame) + # check that we hit the cache + result_with_cache = getattr(report.metrics, metric)() + pd.testing.assert_frame_equal(result, result_with_cache) + + # check that something was written to the cache + assert report._cache != {} + report.clean_cache() + + # check that passing data from outside the report works and that the results + # don't come from the cache + result_external_data = getattr(report.metrics, metric)( + data_source="X_y", X=X_test, y=y_test + ) + assert isinstance(result_external_data, pd.DataFrame) + pd.testing.assert_frame_equal(result, result_external_data) + assert report._cache != {} + + +def _normalize_metric_name(column): + """Helper to normalize
the metric name present in a pandas column that could be + a multi-index or single-index.""" + # if we have a multi-index, then the metric name is on level 0 + s = column[0] if isinstance(column, tuple) else column + # Keep only ASCII letters (drops spaces, underscores, digits, ...) + return re.sub(r"[^a-zA-Z]", "", s.lower()) + + +def _check_results_report_metrics(result, expected_metrics, expected_nb_stats): + assert isinstance(result, pd.DataFrame) + assert len(result.columns) == expected_nb_stats + + normalized_expected = { + _normalize_metric_name(metric) for metric in expected_metrics + } + for column in result.columns: + normalized_column = _normalize_metric_name(column) + matches = [ + metric for metric in normalized_expected if metric == normalized_column + ] + assert len(matches) == 1, ( + f"No match found for column '{column}' in expected metrics: " + f" {expected_metrics}" + ) + + +@pytest.mark.parametrize("pos_label, nb_stats", [(None, 2), (1, 1)]) +def test_estimator_report_report_metrics_binary( + binary_classification_data, binary_classification_data_svc, pos_label, nb_stats +): + """Check the behaviour of the `report_metrics` method with binary + classification. We test both with an SVC that does not support `predict_proba` and a + RandomForestClassifier that does. + """ + estimator, X_test, y_test = binary_classification_data + report = EstimatorReport(estimator, X_test=X_test, y_test=y_test) + result = report.metrics.report_metrics(pos_label=pos_label) + expected_metrics = ("precision", "recall", "roc_auc", "brier_score") + # depending on `pos_label`, we report a stat for each class or not for + # precision and recall + expected_nb_stats = 2 * nb_stats + 2 + _check_results_report_metrics(result, expected_metrics, expected_nb_stats) + + # Repeat the same experiment where the target labels are not [0, 1] but + # ["neg", "pos"]. We check that we don't get any error.
+ target_names = np.array(["neg", "pos"], dtype=object) + pos_label_name = target_names[pos_label] if pos_label is not None else pos_label + y_test = target_names[y_test] + estimator = clone(estimator).fit(X_test, y_test) + report = EstimatorReport(estimator, X_test=X_test, y_test=y_test) + result = report.metrics.report_metrics(pos_label=pos_label_name) + expected_metrics = ("precision", "recall", "roc_auc", "brier_score") + # depending on `pos_label`, we report a stat for each class or not for + # precision and recall + expected_nb_stats = 2 * nb_stats + 2 + _check_results_report_metrics(result, expected_metrics, expected_nb_stats) + + estimator, X_test, y_test = binary_classification_data_svc + report = EstimatorReport(estimator, X_test=X_test, y_test=y_test) + result = report.metrics.report_metrics(pos_label=pos_label) + expected_metrics = ("precision", "recall", "roc_auc") + # depending on `pos_label`, we report a stat for each class or not for + # precision and recall + expected_nb_stats = 2 * nb_stats + 1 + _check_results_report_metrics(result, expected_metrics, expected_nb_stats) + + +def test_estimator_report_report_metrics_multiclass( + multiclass_classification_data, multiclass_classification_data_svc +): + """Check the behaviour of the `report_metrics` method with multiclass + classification.
+ """ + estimator, X_test, y_test = multiclass_classification_data + report = EstimatorReport(estimator, X_test=X_test, y_test=y_test) + result = report.metrics.report_metrics() + expected_metrics = ("precision", "recall", "roc_auc", "log_loss") + # since we are not averaging by default, we report 3 statistics for + # precision, recall and roc_auc + expected_nb_stats = 3 * 3 + 1 + _check_results_report_metrics(result, expected_metrics, expected_nb_stats) + + estimator, X_test, y_test = multiclass_classification_data_svc + report = EstimatorReport(estimator, X_test=X_test, y_test=y_test) + result = report.metrics.report_metrics() + expected_metrics = ("precision", "recall") + # since we are not averaging by default, we report 3 statistics for + # precision and recall + expected_nb_stats = 3 * 2 + _check_results_report_metrics(result, expected_metrics, expected_nb_stats) + + +def test_estimator_report_report_metrics_regression(regression_data): + """Check the behaviour of the `report_metrics` method with regression.""" + estimator, X_test, y_test = regression_data + report = EstimatorReport(estimator, X_test=X_test, y_test=y_test) + result = report.metrics.report_metrics() + expected_metrics = ("r2", "rmse") + _check_results_report_metrics(result, expected_metrics, len(expected_metrics)) + + +def test_estimator_report_report_metrics_scoring_kwargs( + regression_multioutput_data, multiclass_classification_data +): + """Check the behaviour of the `report_metrics` method with scoring kwargs.""" + estimator, X_test, y_test = regression_multioutput_data + report = EstimatorReport(estimator, X_test=X_test, y_test=y_test) + assert hasattr(report.metrics, "report_metrics") + result = report.metrics.report_metrics(scoring_kwargs={"multioutput": "raw_values"}) + assert result.shape == (1, 4) + assert isinstance(result.columns, pd.MultiIndex) + assert result.columns.names == ["Metric", "Output"] + + estimator, X_test, y_test = multiclass_classification_data + report = 
EstimatorReport(estimator, X_test=X_test, y_test=y_test) + assert hasattr(report.metrics, "report_metrics") + result = report.metrics.report_metrics(scoring_kwargs={"average": None}) + assert result.shape == (1, 10) + assert isinstance(result.columns, pd.MultiIndex) + assert result.columns.names == ["Metric", "Class label"] + + +def test_estimator_report_interaction_cache_metrics(regression_multioutput_data): + """Check that the cache takes into account the 'kwargs' of a metric.""" + estimator, X_test, y_test = regression_multioutput_data + report = EstimatorReport(estimator, X_test=X_test, y_test=y_test) + + # The underlying metrics will call `_compute_metric_scores` which takes arbitrary + # kwargs apart from `pos_label`. Let's pass an arbitrary kwarg and make sure it is + # part of the cache. + multioutput = "raw_values" + result_r2_raw_values = report.metrics.r2(multioutput=multioutput) + should_raise = True + for cached_key in report._cache: + if any(item == multioutput for item in cached_key): + should_raise = False + break + assert ( + not should_raise + ), f"The value {multioutput} should be stored in one of the cache keys" + assert result_r2_raw_values.shape == (1, 2) + + multioutput = "uniform_average" + result_r2_uniform_average = report.metrics.r2(multioutput=multioutput) + should_raise = True + for cached_key in report._cache: + if any(item == multioutput for item in cached_key): + should_raise = False + break + assert ( + not should_raise + ), f"The value {multioutput} should be stored in one of the cache keys" + assert result_r2_uniform_average.shape == (1, 1) + + +def test_estimator_report_custom_metric(regression_data): + """Check the behaviour of the `custom_metric` computation in the report.""" + estimator, X_test, y_test = regression_data + report = EstimatorReport(estimator, X_test=X_test, y_test=y_test) + + def custom_metric(y_true, y_pred, threshold=0.5): + residuals = y_true - y_pred + return np.mean(np.where(residuals < threshold,
residuals, 1)) + + threshold = 1 + result = report.metrics.custom_metric( + metric_function=custom_metric, + metric_name="Custom Metric", + response_method="predict", + threshold=threshold, + ) + should_raise = True + for cached_key in report._cache: + if any(item == threshold for item in cached_key): + should_raise = False + break + assert ( + not should_raise + ), f"The value {threshold} should be stored in one of the cache keys" + + assert result.columns.tolist() == ["Custom Metric"] + assert result.to_numpy()[0, 0] == pytest.approx( + custom_metric(y_test, estimator.predict(X_test), threshold) + ) + + threshold = 100 + result = report.metrics.custom_metric( + metric_function=custom_metric, + metric_name="Custom Metric", + response_method="predict", + threshold=threshold, + ) + should_raise = True + for cached_key in report._cache: + if any(item == threshold for item in cached_key): + should_raise = False + break + assert ( + not should_raise + ), f"The value {threshold} should be stored in one of the cache keys" + + assert result.columns.tolist() == ["Custom Metric"] + assert result.to_numpy()[0, 0] == pytest.approx( + custom_metric(y_test, estimator.predict(X_test), threshold) + ) + + +def test_estimator_report_custom_function_kwargs_numpy_array(regression_data): + """Check that we are able to store a hash of a numpy array in the cache when they + are passed as kwargs. 
+ """ + estimator, X_test, y_test = regression_data + report = EstimatorReport(estimator, X_test=X_test, y_test=y_test) + weights = np.ones_like(y_test) * 2 + hash_weights = joblib.hash(weights) + + def custom_metric(y_true, y_pred, some_weights): + return np.mean((y_true - y_pred) * some_weights) + + result = report.metrics.custom_metric( + metric_function=custom_metric, + metric_name="Custom Metric", + response_method="predict", + some_weights=weights, + ) + should_raise = True + for cached_key in report._cache: + if any(item == hash_weights for item in cached_key): + should_raise = False + break + assert ( + not should_raise + ), "The hash of the weights should be stored in one of the cache keys" + + assert result.columns.tolist() == ["Custom Metric"] + assert result.to_numpy()[0, 0] == pytest.approx( + custom_metric(y_test, estimator.predict(X_test), weights) + ) + + +def test_estimator_report_report_metrics_with_custom_metric(regression_data): + """Check that we can pass a custom metric with specific kwargs into + `report_metrics`.""" + estimator, X_test, y_test = regression_data + report = EstimatorReport(estimator, X_test=X_test, y_test=y_test) + weights = np.ones_like(y_test) * 2 + + def custom_metric(y_true, y_pred, some_weights): + return np.mean((y_true - y_pred) * some_weights) + + result = report.metrics.report_metrics( + scoring=["r2", custom_metric], + scoring_kwargs={"some_weights": weights, "response_method": "predict"}, + ) + assert result.shape == (1, 2) + np.testing.assert_allclose( + result.to_numpy(), + [ + [ + r2_score(y_test, estimator.predict(X_test)), + custom_metric(y_test, estimator.predict(X_test), weights), + ] + ], + ) + + +def test_estimator_report_report_metrics_with_scorer(regression_data): + """Check that we can pass scikit-learn scorer with different parameters to + the `report_metrics` method.""" + estimator, X_test, y_test = regression_data + report = EstimatorReport(estimator, X_test=X_test, y_test=y_test) + weights = 
np.ones_like(y_test) * 2 + + def custom_metric(y_true, y_pred, some_weights): + return np.mean((y_true - y_pred) * some_weights) + + median_absolute_error_scorer = make_scorer( + median_absolute_error, response_method="predict" + ) + custom_metric_scorer = make_scorer( + custom_metric, response_method="predict", some_weights=weights + ) + result = report.metrics.report_metrics( + scoring=[r2_score, median_absolute_error_scorer, custom_metric_scorer], + scoring_kwargs={"response_method": "predict"}, # only dispatched to r2_score + ) + assert result.shape == (1, 3) + np.testing.assert_allclose( + result.to_numpy(), + [ + [ + r2_score(y_test, estimator.predict(X_test)), + median_absolute_error(y_test, estimator.predict(X_test)), + custom_metric(y_test, estimator.predict(X_test), weights), + ] + ], + ) + + +def test_estimator_report_report_metrics_invalid_metric_type(regression_data): + """Check that we raise the expected error message if an invalid metric is passed.""" + estimator, X_test, y_test = regression_data + report = EstimatorReport(estimator, X_test=X_test, y_test=y_test) + + err_msg = re.escape("Invalid type of metric: for 1") + with pytest.raises(ValueError, match=err_msg): + report.metrics.report_metrics(scoring=[1]) + + +def test_estimator_report_get_X_y_and_data_source_hash_error(): + """Check that we raise the proper error in `_get_X_y_and_data_source_hash`.""" + X, y = make_classification(n_classes=2, random_state=42) + X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42) + + estimator = LogisticRegression().fit(X_train, y_train) + report = EstimatorReport(estimator) + + err_msg = re.escape( + "Invalid data source: unknown. Possible values are: " "test, train, X_y." + ) + with pytest.raises(ValueError, match=err_msg): + report.metrics.log_loss(data_source="unknown") + + for data_source in ("train", "test"): + err_msg = re.escape( + f"No {data_source} data (i.e.
X_{data_source} and y_{data_source}) were " + f"provided when creating the reporter. Please provide the {data_source} " + "data either when creating the reporter or by setting data_source to " + "'X_y' and providing X and y." + ) + with pytest.raises(ValueError, match=err_msg): + report.metrics.log_loss(data_source=data_source) + + report = EstimatorReport( + estimator, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test + ) + + for data_source in ("train", "test"): + err_msg = f"X and y must be None when data_source is {data_source}." + with pytest.raises(ValueError, match=err_msg): + report.metrics.log_loss(data_source=data_source, X=X_test, y=y_test) + + err_msg = "X and y must be provided." + with pytest.raises(ValueError, match=err_msg): + report.metrics.log_loss(data_source="X_y") + + # FIXME: once we choose some basic metrics for clustering, then we don't need to + # use `custom_metric` for them. + estimator = KMeans().fit(X_train) + report = EstimatorReport(estimator, X_test=X_test) + err_msg = "X must be provided." + with pytest.raises(ValueError, match=err_msg): + report.metrics.custom_metric( + rand_score, response_method="predict", data_source="X_y" + ) + + report = EstimatorReport(estimator) + for data_source in ("train", "test"): + err_msg = re.escape( + f"No {data_source} data (i.e. X_{data_source}) were provided when " + f"creating the reporter. Please provide the {data_source} data either " + f"when creating the reporter or by setting data_source to 'X_y' and " + f"providing X and y." 
+ ) + with pytest.raises(ValueError, match=err_msg): + report.metrics.custom_metric( + rand_score, response_method="predict", data_source=data_source + ) + + +@pytest.mark.parametrize("data_source", ("train", "test", "X_y")) +def test_estimator_report_get_X_y_and_data_source_hash(data_source): + """Check the general behaviour of `_get_X_y_and_data_source_hash`.""" + X, y = make_classification(n_classes=2, random_state=42) + X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42) + + estimator = LogisticRegression() + report = EstimatorReport( + estimator, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test + ) + + kwargs = {"X": X_test, "y": y_test} if data_source == "X_y" else {} + X, y, data_source_hash = report.metrics._get_X_y_and_data_source_hash( + data_source=data_source, **kwargs + ) + + if data_source == "train": + assert X is X_train + assert y is y_train + assert data_source_hash is None + elif data_source == "test": + assert X is X_test + assert y is y_test + assert data_source_hash is None + elif data_source == "X_y": + assert X is X_test + assert y is y_test + assert data_source_hash == joblib.hash((X_test, y_test)) diff --git a/skore/tests/unit/utils/test_accessors.py b/skore/tests/unit/utils/test_accessors.py new file mode 100644 index 000000000..7e7a8953b --- /dev/null +++ b/skore/tests/unit/utils/test_accessors.py @@ -0,0 +1,60 @@ +import pytest +from skore.externals._pandas_accessors import DirNamesMixin, _register_accessor +from skore.utils._accessor import _check_supported_ml_task + + +def test_register_accessor(): + """Test that an accessor is properly registered and accessible on a class + instance.
+ """ + + class ParentClass(DirNamesMixin): + pass + + def register_parent_class_accessor(name: str): + """Register an accessor for the ParentClass class.""" + return _register_accessor(name, ParentClass) + + @register_parent_class_accessor("accessor") + class _Accessor: + def __init__(self, parent): + self._parent = parent + + def func(self): + return True + + obj = ParentClass() + assert hasattr(obj, "accessor") + assert isinstance(obj.accessor, _Accessor) + assert obj.accessor.func() + + +def test_check_supported_ml_task(): + """Test that ML task validation accepts supported tasks and rejects unsupported + ones. + """ + + class MockParent: + def __init__(self, ml_task): + self._ml_task = ml_task + + class MockAccessor: + def __init__(self, parent): + self._parent = parent + + parent = MockParent("binary-classification") + accessor = MockAccessor(parent) + check = _check_supported_ml_task( + ["binary-classification", "multiclass-classification"] + ) + assert check(accessor) + + parent = MockParent("multiclass-classification") + accessor = MockAccessor(parent) + assert check(accessor) + + parent = MockParent("regression") + accessor = MockAccessor(parent) + err_msg = "The regression task is not a supported task by function called." + with pytest.raises(AttributeError, match=err_msg): + check(accessor) diff --git a/sphinx/_templates/autosummary/accessor.rst b/sphinx/_templates/autosummary/accessor.rst new file mode 100644 index 000000000..145ca83dd --- /dev/null +++ b/sphinx/_templates/autosummary/accessor.rst @@ -0,0 +1,5 @@ +{{ objname | escape | underline(line="=") }} + +.. currentmodule:: {{ module.split('.')[0] }} + +.. 
autoaccessor:: {{ (module.split('.')[1:] + [objname]) | join('.') }} diff --git a/sphinx/_templates/autosummary/accessor_attribute.rst b/sphinx/_templates/autosummary/accessor_attribute.rst new file mode 100644 index 000000000..c2769d66d --- /dev/null +++ b/sphinx/_templates/autosummary/accessor_attribute.rst @@ -0,0 +1,5 @@ +{{ objname | escape | underline(line="=") }} + +.. currentmodule:: {{ module.split('.')[0] }} + +.. autoaccessorattribute:: {{ (module.split('.')[1:] + [objname]) | join('.') }} diff --git a/sphinx/_templates/autosummary/accessor_callable.rst b/sphinx/_templates/autosummary/accessor_callable.rst new file mode 100644 index 000000000..261adfdf1 --- /dev/null +++ b/sphinx/_templates/autosummary/accessor_callable.rst @@ -0,0 +1,5 @@ +{{ objname | escape | underline(line="=") }} + +.. currentmodule:: {{ module.split('.')[0] }} + +.. autoaccessorcallable:: {{ (module.split('.')[1:] + [objname]) | join('.') }}.__call__ diff --git a/sphinx/_templates/autosummary/accessor_method.rst b/sphinx/_templates/autosummary/accessor_method.rst new file mode 100644 index 000000000..5c116571d --- /dev/null +++ b/sphinx/_templates/autosummary/accessor_method.rst @@ -0,0 +1,5 @@ +{{ objname | escape | underline(line="=") }} + +.. currentmodule:: {{ module.split('.')[0] }} + +.. autoaccessormethod:: {{ (module.split('.')[1:] + [objname]) | join('.') }} diff --git a/sphinx/api.rst b/sphinx/api.rst index 4c3b2535c..e1a519531 100644 --- a/sphinx/api.rst +++ b/sphinx/api.rst @@ -36,3 +36,59 @@ These functions and classes enhance scikit-learn's ones. train_test_split CrossValidationReporter item.cross_validation_item.CrossValidationItem + +Report for a single estimator +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The class :class:`EstimatorReport` provides a reporter that lets you inspect and +evaluate a scikit-learn estimator in an interactive way. The functionalities of the +reporter are accessible through accessors. + +..
autosummary:: + :toctree: generated/ + :template: base.rst + :caption: Reporting for a single estimator + + EstimatorReport + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: autosummary/accessor_method.rst + + EstimatorReport.help + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: autosummary/accessor.rst + + EstimatorReport.metrics + +Metrics +""""""" + +The `metrics` accessor helps you to evaluate the statistical performance of your +estimator. In addition, we provide a sub-accessor `plot`, to get the common +performance metric representations. + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: autosummary/accessor_method.rst + + EstimatorReport.metrics.help + EstimatorReport.metrics.report_metrics + EstimatorReport.metrics.custom_metric + EstimatorReport.metrics.accuracy + EstimatorReport.metrics.brier_score + EstimatorReport.metrics.log_loss + EstimatorReport.metrics.precision + EstimatorReport.metrics.r2 + EstimatorReport.metrics.recall + EstimatorReport.metrics.rmse + EstimatorReport.metrics.roc_auc + EstimatorReport.metrics.plot.help + EstimatorReport.metrics.plot.precision_recall + EstimatorReport.metrics.plot.prediction_error + EstimatorReport.metrics.plot.roc diff --git a/sphinx/conf.py b/sphinx/conf.py index b22954ba3..7d554848e 100644 --- a/sphinx/conf.py +++ b/sphinx/conf.py @@ -7,6 +7,7 @@ # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information import os +import sphinx_autosummary_accessors from sphinx_gallery.sorting import ExplicitOrder project = "skore" @@ -27,11 +28,16 @@ "sphinx_gallery.gen_gallery", "sphinx_copybutton", "sphinx_tabs.tabs", + "sphinx_autosummary_accessors", ] -templates_path = ["_templates"] exclude_patterns = ["build", "Thumbs.db", ".DS_Store"] +# The reST default role (used for this markup: `text`) to use for all +# documents. +default_role = "literal" + # Add any paths that contain templates here, relative to this directory. 
+autosummary_generate = True # generate stubs for all classes templates_path = ["_templates"] # -- Options for HTML output ------------------------------------------------- @@ -107,7 +113,7 @@ # Use :html_theme.sidebar_secondary.remove: for file-wide removal "secondary_sidebar_items": { "**": ["page-toc", "sourcelink", "sg_download_links", "sg_launcher_links"], - "index": [], # hide secondary sidebar items for the landing page + "index": [], # hide secondary sidebar items for the landing page "install": [], }, "external_links": [