feat: Design of EstimatorReport #997

Merged · 146 commits · Jan 10, 2025

Commits
fd94974
feat: Use friendly verbose and colorish
glemaitre Dec 15, 2024
1a3d4a6
limit size
glemaitre Dec 15, 2024
fe075a4
tweak bold effect
glemaitre Dec 15, 2024
505f2b4
iter
glemaitre Dec 19, 2024
c563d45
test: complete tests for new arg in cli
MarieS-WiMLDS Dec 19, 2024
801795c
iter
glemaitre Dec 19, 2024
376099c
use context manager as a more explicit way to configure the logger
glemaitre Dec 20, 2024
d427f2d
TST add a couple of quick tests for the logger context manager
glemaitre Dec 20, 2024
ce273b6
Merge remote-tracking branch 'glemaitre/is/959' into model_report
glemaitre Dec 20, 2024
9cdeacd
feat: EstimatorReport
glemaitre Dec 20, 2024
56b1821
iter
glemaitre Dec 21, 2024
acb51e0
fix: Use estimator whenever possible to detect the ML task
glemaitre Dec 21, 2024
92b4e6c
iter
glemaitre Dec 21, 2024
2982798
tests
glemaitre Dec 21, 2024
66a3fd3
tests
glemaitre Dec 21, 2024
4b7124b
DOC add some docstring
glemaitre Dec 21, 2024
7b7c9c8
EXA add an example to present the feature
glemaitre Dec 21, 2024
8de79ae
iter
glemaitre Dec 21, 2024
cf0c865
add metrics
glemaitre Dec 21, 2024
b1cd767
allow to pass a new set of data to the metrics
glemaitre Dec 22, 2024
97577b4
iter
glemaitre Dec 22, 2024
a9d67b4
TST add test for individual metrics
glemaitre Dec 22, 2024
597212b
TST add test for the default scoring in
glemaitre Dec 22, 2024
19fd6d8
TST add test for passing scoring kwargs in report_metrics
glemaitre Dec 22, 2024
eeba764
TST add check that we properly hit the cache with arbitrary keywords
glemaitre Dec 22, 2024
f845db6
FEA add support for an arbitrary metric
glemaitre Dec 22, 2024
9f219f2
improve example
glemaitre Dec 22, 2024
eaeb072
check name and add test with joblib hash
glemaitre Dec 22, 2024
0533abf
allow to add a custom metric in the reporting
glemaitre Dec 22, 2024
6d22b33
mainly refactor the help
glemaitre Dec 22, 2024
7c42cf9
iter
glemaitre Dec 22, 2024
9197928
fix bug according to test
glemaitre Dec 22, 2024
7594551
fix docstring check
glemaitre Dec 22, 2024
c94635f
add the EstimatorReport to the API doc
glemaitre Dec 22, 2024
be4b265
use literal option by default
glemaitre Dec 22, 2024
b491d4e
iter
glemaitre Dec 22, 2024
13297b5
add stubs for solving the problem of auto-completion
glemaitre Dec 22, 2024
4ac4237
only check the cache
glemaitre Dec 22, 2024
bb69f8f
TST add test for help and repr of accessor
glemaitre Dec 23, 2024
a3d97a0
use pos_label instead of positive_class and test for plotting
glemaitre Dec 23, 2024
152fc21
add rich repr and help for display
glemaitre Dec 23, 2024
f56f19f
iter
glemaitre Dec 23, 2024
003039f
TST for the plot repr and help function
glemaitre Dec 23, 2024
36f7a2c
iter
glemaitre Dec 23, 2024
41a67a0
add precision recall curve
glemaitre Dec 23, 2024
c87ea6e
TST add more test for the estimator report displays
glemaitre Dec 23, 2024
70fb720
iter
glemaitre Dec 23, 2024
7258f75
rename X_val and y_val to X_test and y_test to simplify
glemaitre Dec 23, 2024
2cbba2f
use a single constructor
glemaitre Dec 23, 2024
914fb1a
Merge remote-tracking branch 'glemaitre/_find_ml_task_estimator_base'…
glemaitre Dec 23, 2024
259794b
accept external data
glemaitre Dec 23, 2024
0f93945
fix docstring
glemaitre Dec 23, 2024
f463285
iter
glemaitre Dec 23, 2024
d811bc6
add data_source with test
glemaitre Jan 3, 2025
1705252
Merge remote-tracking branch 'origin/main' into model_report
glemaitre Jan 3, 2025
aab9bf4
bring the cache to the external data by computing a hash
glemaitre Jan 3, 2025
ad28319
use agg backend
glemaitre Jan 3, 2025
b99d78c
Merge remote-tracking branch 'origin/main' into model_report
glemaitre Jan 3, 2025
4b23afb
expose .plot under the metrics accessor
glemaitre Jan 3, 2025
2df34d1
rename plot accessors
glemaitre Jan 3, 2025
ce4811c
iter
glemaitre Jan 3, 2025
8609881
iter
glemaitre Jan 3, 2025
6cb3b08
small refactoring for plotting
glemaitre Jan 4, 2025
4f9b6fc
commit refactoring
glemaitre Jan 4, 2025
ea18335
add multiclass ovr roc curve
glemaitre Jan 4, 2025
8b61fa0
update classification support for plots
glemaitre Jan 4, 2025
9ea0c6e
check as well for regression
glemaitre Jan 4, 2025
ef3937d
add a module to test display
glemaitre Jan 5, 2025
54d9fab
more test roc curve
glemaitre Jan 5, 2025
3119809
check error message
glemaitre Jan 5, 2025
6824bef
add test for the kwargs
glemaitre Jan 5, 2025
2a55aec
check chance kwargs
glemaitre Jan 5, 2025
c3e86cb
more doc and remove sample_weight for the moment
glemaitre Jan 5, 2025
6745a3d
modify pr curve and align roc curve
glemaitre Jan 5, 2025
efaa8b8
iter
glemaitre Jan 6, 2025
3646154
Merge remote-tracking branch 'origin/main' into model_report
glemaitre Jan 6, 2025
aaffeb2
add test precision recall curve binary
glemaitre Jan 6, 2025
54b8692
docstring for tests
glemaitre Jan 6, 2025
84bfba1
add test for multiclass precision recall curve
glemaitre Jan 6, 2025
7dad8b1
add test for args
glemaitre Jan 6, 2025
a48fb16
fix file
glemaitre Jan 6, 2025
0613274
iter
glemaitre Jan 6, 2025
0df833d
fix bug with default metric
glemaitre Jan 6, 2025
4eee6ab
more coverage
glemaitre Jan 6, 2025
4c5b10c
improve coverage in EstimatorReport
glemaitre Jan 6, 2025
b52cffa
do not cover the cross-validation for the moment but raise an error
glemaitre Jan 6, 2025
1a73cf6
update outdated setter
glemaitre Jan 6, 2025
a23fbcb
check other data_source display
glemaitre Jan 6, 2025
94e4edd
check plot kwargs
glemaitre Jan 6, 2025
2235230
more coverage for precision recall display
glemaitre Jan 6, 2025
906058f
test providing axis in displays
glemaitre Jan 6, 2025
5f7e303
add test for plotting utils
glemaitre Jan 6, 2025
38182b0
modify example
glemaitre Jan 6, 2025
58d2a33
improve menu accessor
glemaitre Jan 6, 2025
84b5ae3
Merge branch 'main' into model_report
glemaitre Jan 6, 2025
97caf68
iter
glemaitre Jan 6, 2025
4fe9660
fix
glemaitre Jan 6, 2025
6c5eb2c
docstring fix
glemaitre Jan 6, 2025
c549a7f
new function to precompute the cache
glemaitre Jan 7, 2025
59e3ab8
test the cache_predictions function
glemaitre Jan 7, 2025
c64f259
add plotting error display
glemaitre Jan 7, 2025
0f9d368
Update skore/src/skore/sklearn/_estimator.py
glemaitre Jan 7, 2025
5746811
Apply suggestions from code review
glemaitre Jan 7, 2025
4bc0358
brier does not support multiclass
glemaitre Jan 7, 2025
b728222
iter
glemaitre Jan 7, 2025
b677035
docstring model
glemaitre Jan 7, 2025
3b1bea4
use unicode visual clue
glemaitre Jan 7, 2025
1cbbb17
iter
glemaitre Jan 7, 2025
b91a672
iter
glemaitre Jan 7, 2025
3b5c2eb
integration test
glemaitre Jan 7, 2025
05ff4f3
more test
glemaitre Jan 7, 2025
bc92165
simplify API for naming with non-default value needed
glemaitre Jan 7, 2025
096fd93
improve color help
glemaitre Jan 7, 2025
e2a3a74
iter
glemaitre Jan 7, 2025
aee7270
improve repr of the displays
glemaitre Jan 7, 2025
f627593
add estimator name in repr
glemaitre Jan 7, 2025
22f10c3
refactor in plots
glemaitre Jan 8, 2025
14da688
more refactoring
glemaitre Jan 8, 2025
1859092
refactor detecting bug
glemaitre Jan 8, 2025
63e3d1c
make pos_label and average consistent
glemaitre Jan 8, 2025
1a3bea6
improve consistency documentation
glemaitre Jan 8, 2025
11d338e
small refactor test
glemaitre Jan 8, 2025
e46ad58
more refactor tests
glemaitre Jan 8, 2025
bb63a82
refactor
glemaitre Jan 8, 2025
1e7f909
fix
glemaitre Jan 8, 2025
e6d2a69
split stubs file
glemaitre Jan 8, 2025
e8721d6
update doc
glemaitre Jan 8, 2025
334d121
use accessor from pandas to ease doc building
glemaitre Jan 8, 2025
7fba2ff
add noqa F401 to avoid removing import
glemaitre Jan 8, 2025
9c1abd0
do not inject do and overwrite instead
glemaitre Jan 8, 2025
50b7ce7
use the register_accessor for sub-accessor
glemaitre Jan 8, 2025
249e062
first draft for accessor documentation
glemaitre Jan 8, 2025
5a44fc7
add sphinx_autosummary_accessors as a dependency
glemaitre Jan 8, 2025
ad4f408
Update examples/model_evaluation/plot_estimator_report.py
glemaitre Jan 9, 2025
dad7211
Update skore/src/skore/sklearn/_estimator/base.py
glemaitre Jan 9, 2025
8cc9f4e
Update skore/src/skore/sklearn/_estimator/report.py
glemaitre Jan 9, 2025
4b8be36
rewrap
glemaitre Jan 9, 2025
ad20a1c
add legend in the help
glemaitre Jan 9, 2025
682801c
make matplotlib and pandas a dependency
glemaitre Jan 9, 2025
d519bd7
remove unnecessary __init__.py
glemaitre Jan 9, 2025
65a0f5d
vendor the accessor
glemaitre Jan 9, 2025
ebb7ab9
iter
glemaitre Jan 9, 2025
cb5f210
add attributes
glemaitre Jan 9, 2025
b8d4610
check that we support X_y without passing original dataset
glemaitre Jan 9, 2025
866be82
compute brier score for both labels
glemaitre Jan 9, 2025
82f6332
simplify the brier score
glemaitre Jan 9, 2025
384 changes: 384 additions & 0 deletions examples/model_evaluation/plot_estimator_report.py
@@ -0,0 +1,384 @@
"""
============================================
Get insights from any scikit-learn estimator
============================================

This example shows how the :class:`skore.EstimatorReport` class can be used to
quickly get insights from any scikit-learn estimator.
"""

# %%
#
# TODO: we need to describe the aim of this classification problem.
from skrub.datasets import fetch_open_payments

dataset = fetch_open_payments()
df = dataset.X
y = dataset.y

# %%
from skrub import TableReport

TableReport(df)

# %%
TableReport(y.to_frame())

# %%
# Looking at the distribution of the target, we observe that this classification
# task is quite imbalanced. This means that we have to be careful when selecting a
# set of statistical metrics to evaluate the classification performance of our
# predictive model. In addition, we see that the class labels are not specified by
# the integers 0 and 1 but instead by the strings "allowed" and "disallowed".
#
# For our application, the label of interest is "allowed".
pos_label, neg_label = "allowed", "disallowed"
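
# %%
# We can quantify this imbalance directly. A quick check, assuming `y` is a pandas
# Series as returned by :func:`skrub.datasets.fetch_open_payments`:
y.value_counts(normalize=True)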

# %%
# Before training a predictive model, we need to split our dataset into a training
# and a test set.
from skore import train_test_split

X_train, X_test, y_train, y_test = train_test_split(df, y, random_state=42)

# %%
# TODO: we have a perfect case to show useful feature of the `train_test_split`
# function from `skore`.
#
# Now, we need to define a predictive model. Fortunately, `skrub` provides a
# convenient function, :func:`skrub.tabular_learner`, for building strong baseline
# predictive models with a single line of code. As its feature engineering is
# generic, it does not provide handcrafted, task-specific features, but it still
# offers a good starting point.
#
# So let's create a classifier for our task and fit it on the training set.
from skrub import tabular_learner

estimator = tabular_learner("classifier").fit(X_train, y_train)
estimator

# %%
#
# Introducing the :class:`skore.EstimatorReport` class
# ----------------------------------------------------
#
# Now, we would like to get some insights from our predictive model. One way to do
# so is to use the :class:`skore.EstimatorReport` class. Its constructor will detect
# that our estimator is already fitted and will not fit it again.
from skore import EstimatorReport

reporter = EstimatorReport(
    estimator, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test
)
reporter

# %%
#
# Once the reporter is created, we get some information regarding the tools
# available for getting insights from our specific model on our specific task.
#
# You can get similar information by calling the :meth:`~skore.EstimatorReport.help`
# method.
reporter.help()

# %%
#
# Be aware that you can access the help for each individual sub-accessor. For instance:
reporter.metrics.help()

# %%
reporter.metrics.plot.help()

# %%
#
# Metrics computation with aggressive caching
# -------------------------------------------
#
# At this point, we might be interested in having a first look at the statistical
# performance of our model on the test set that we provided. We can access any of
# the metrics displayed above individually. Since we are greedy, we want several
# metrics at once, so we will use the
# :meth:`~skore.EstimatorReport.metrics.report_metrics` method.
import time

start = time.time()
metric_report = reporter.metrics.report_metrics(pos_label=pos_label)
end = time.time()
metric_report

# %%
print(f"Time taken to compute the metrics: {end - start:.2f} seconds")

# %%
#
# An interesting feature provided by the :class:`skore.EstimatorReport` is the
# caching mechanism. Indeed, once a dataset is large enough, computing the
# predictions of a model is no longer cheap. For instance, on our smallish dataset,
# it took a couple of seconds to compute the metrics. The reporter caches the
# predictions, so if you compute a metric again, or an alternative metric that
# requires the same predictions, it will be faster. Let's check by requesting the
# same metrics report again.

start = time.time()
metric_report = reporter.metrics.report_metrics(pos_label=pos_label)
end = time.time()
metric_report

# %%
print(f"Time taken to compute the metrics: {end - start:.2f} seconds")

# %%
#
# Since we obtain a pandas dataframe, we can also use the plotting interface of
# pandas.
import matplotlib.pyplot as plt

ax = metric_report.T.plot.barh()
ax.set_title("Metrics report")
plt.tight_layout()

# %%
#
# Whenever we compute a metric, we check whether the predictions are available in
# the cache and reload them if they are. For instance, let's compute the log loss.

start = time.time()
log_loss = reporter.metrics.log_loss()
end = time.time()
log_loss

# %%
print(f"Time taken to compute the log loss: {end - start:.2f} seconds")

# %%
#
# We can show that, without the initial cache, it would have taken more time to
# compute the log loss.
reporter.clean_cache()

start = time.time()
log_loss = reporter.metrics.log_loss()
end = time.time()
log_loss

# %%
print(f"Time taken to compute the log loss: {end - start:.2f} seconds")

# %%
#
# By default, the metrics are computed on the test set. However, if a training set
# is provided, we can also compute the metrics on it by specifying the
# `data_source` parameter.
reporter.metrics.log_loss(data_source="train")
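
# %%
#
# Comparing the train and test scores side by side is a quick way to spot over- or
# under-fitting. A small sketch, assuming that `data_source` also accepts the
# literal "test", its default value:
for source in ("train", "test"):
    print(f"Log loss on the {source} set: {reporter.metrics.log_loss(data_source=source)}")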

# %%
#
# If we are interested in computing the metrics on a completely new set of data, we
# can pass `data_source="X_y"`. In that case, we also need to provide the `X` and
# `y` parameters.

start = time.time()
metric_report = reporter.metrics.report_metrics(
    data_source="X_y", X=X_test, y=y_test, pos_label=pos_label
)
end = time.time()
metric_report

# %%
print(f"Time taken to compute the metrics: {end - start:.2f} seconds")

# %%
#
# As in the other case, we rely on the cache to avoid recomputing the predictions.
# Internally, we compute a hash of the input data to be sure that we can hit the cache
# in a consistent way.

# %%
start = time.time()
metric_report = reporter.metrics.report_metrics(
    data_source="X_y", X=X_test, y=y_test, pos_label=pos_label
)
end = time.time()
metric_report

# %%
print(f"Time taken to compute the metrics: {end - start:.2f} seconds")

# %%
#
# .. warning::
#     In this last example, we rely on computing the hash of the input data.
#     Therefore, there is a trade-off: the computation of the hash is not free and
#     it might be faster to compute the predictions instead.
#
# Be aware that you can also benefit from the caching mechanism with your own custom
# metrics. The only requirement is that your metric function takes `y_true` and
# `y_pred` as its first two positional arguments; it can take any other arguments
# after those. Let's see an example.


def operational_decision_cost(y_true, y_pred, amount):
    mask_true_positive = (y_true == pos_label) & (y_pred == pos_label)
    mask_true_negative = (y_true == neg_label) & (y_pred == neg_label)
    mask_false_positive = (y_true == neg_label) & (y_pred == pos_label)
    mask_false_negative = (y_true == pos_label) & (y_pred == neg_label)
    # FIXME: we need to make sense of the cost sensitive part with the right naming
    fraudulent_refuse = mask_true_positive.sum() * 50
    fraudulent_accept = -amount[mask_false_negative].sum()
    legitimate_refuse = mask_false_positive.sum() * -5
    legitimate_accept = (amount[mask_true_negative] * 0.02).sum()
    return fraudulent_refuse + fraudulent_accept + legitimate_refuse + legitimate_accept


# %%
#
# In our use case, we have an operational decision to make that translates the
# classification outcome into a cost. It translates the confusion matrix into a cost
# matrix based on an amount linked to each sample in the dataset. Here, we randomly
# generate some amounts as an illustration.
import numpy as np

rng = np.random.default_rng(42)
amount = rng.integers(low=100, high=1000, size=len(y_test))

# %%
#
# Let's first make sure that the `predict` method has been called and its result
# cached. Computing the accuracy metric triggers a call to `predict`.
reporter.metrics.accuracy()
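
# %%
#
# Before handing this function to the reporter, we can sanity-check it by calling it
# directly. This is a quick sketch outside of the reporter API; `y_pred` is
# recomputed here rather than taken from the cache.
y_pred = estimator.predict(X_test)
operational_decision_cost(y_test, y_pred, amount)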

# %%
#
# We can now compute the cost of our operational decision.
start = time.time()
cost = reporter.metrics.custom_metric(
    metric_function=operational_decision_cost,
    metric_name="Operational Decision Cost",
    response_method="predict",
    amount=amount,
)
end = time.time()
cost

# %%
print(f"Time taken to compute the cost: {end - start:.2f} seconds")

# %%
#
# Let's now clean the cache and see how much slower the computation is without it.
reporter.clean_cache()

# %%
start = time.time()
cost = reporter.metrics.custom_metric(
    metric_function=operational_decision_cost,
    metric_name="Operational Decision Cost",
    response_method="predict",
    amount=amount,
)
end = time.time()
cost

# %%
print(f"Time taken to compute the cost: {end - start:.2f} seconds")

# %%
#
# We observe that the caching works as expected. It is really handy because you can
# compute additional metrics without having to recompute the predictions.
reporter.metrics.report_metrics(
    scoring=["precision", "recall", operational_decision_cost],
    pos_label=pos_label,
    scoring_kwargs={
        "amount": amount,
        "response_method": "predict",
        "metric_name": "Operational Decision Cost",
    },
)

# %%
#
# It could happen that you are interested in providing several custom metrics that
# do not necessarily share the same parameters. In this more complex case, we
# require you to provide a scorer using the :func:`sklearn.metrics.make_scorer`
# function.
from sklearn.metrics import make_scorer, f1_score

f1_scorer = make_scorer(
    f1_score,
    response_method="predict",
    metric_name="F1 Score",
    pos_label=pos_label,
)
operational_decision_cost_scorer = make_scorer(
    operational_decision_cost,
    response_method="predict",
    metric_name="Operational Decision Cost",
    amount=amount,
)
reporter.metrics.report_metrics(scoring=[f1_scorer, operational_decision_cost_scorer])

# %%
#
# Effortless one-liner plotting
# -----------------------------
#
# The :class:`skore.EstimatorReport` class also provides a plotting interface that
# produces the most common plots out of the box. As for the metrics, we only expose
# the set of plots that is meaningful for the provided estimator.
reporter.metrics.plot.help()

# %%
#
# Let's start by plotting the ROC curve for our binary classification task.
display = reporter.metrics.plot.roc(pos_label=pos_label)
plt.tight_layout()

# %%
#
# The plot functionality is built upon the scikit-learn display objects. We return
# those displays (slightly modified to improve the UI) in case you want to tweak
# some of the plot properties. You can have a quick look at the available attributes
# and methods by calling the `help` method or simply by printing the display.
display

# %%
display.help()

# %%
display.plot()
display.ax_.set_title("Example of a ROC curve")
display.figure_
plt.tight_layout()

# %%
#
# Similarly to the metrics, we aggressively use caching to avoid recomputing the
# predictions of the model. We also cache the plot display object by detecting
# whether the input parameters are the same as in the previous call. Let's
# demonstrate the kind of performance gain we can get.
start = time.time()
# we already triggered the computation of the predictions in a previous call
reporter.metrics.plot.roc(pos_label=pos_label)
plt.tight_layout()
end = time.time()

# %%
print(f"Time taken to compute the ROC curve: {end - start:.2f} seconds")

# %%
#
# Now, let's clean the cache and check if we get a slowdown.
reporter.clean_cache()

# %%
start = time.time()
reporter.metrics.plot.roc(pos_label=pos_label)
plt.tight_layout()
end = time.time()

# %%
print(f"Time taken to compute the ROC curve: {end - start:.2f} seconds")

# %%
# As expected, since we need to recompute the predictions, it takes more time.