Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

docs: Add example to showcase EstimatorReport and CrossValidationReport #1156

Merged
merged 133 commits into from
Jan 22, 2025
Merged
Show file tree
Hide file tree
Changes from 127 commits
Commits
Show all changes
133 commits
Select commit Hold shift + click to select a range
76d72c6
feat: Revamp CrossValidationReport to use EstimatorReport
glemaitre Jan 10, 2025
98d94d7
Merge remote-tracking branch 'origin/main' into cross_validation_repo…
glemaitre Jan 10, 2025
8353b00
add progress bar
glemaitre Jan 10, 2025
b695666
iter
glemaitre Jan 11, 2025
4f8aa1c
make sure to use shared memory
glemaitre Jan 11, 2025
5a42157
fix progress bar glitches
glemaitre Jan 11, 2025
e4accc8
fix attribute sorting
glemaitre Jan 11, 2025
5ace5be
iter
glemaitre Jan 11, 2025
c6333ca
add metric accessor class
glemaitre Jan 11, 2025
82f64cb
iter
glemaitre Jan 11, 2025
b47786e
add custom metric
glemaitre Jan 11, 2025
fe110ff
iter
glemaitre Jan 11, 2025
35e01ea
roc curve
glemaitre Jan 11, 2025
4d0f16f
iter
glemaitre Jan 11, 2025
d46199e
iter
glemaitre Jan 11, 2025
ae11e23
some cache optimization
glemaitre Jan 12, 2025
e3a8a59
hash computation optimization
glemaitre Jan 12, 2025
68755aa
provide a way to send pos_label
glemaitre Jan 12, 2025
3175047
iter
glemaitre Jan 12, 2025
0256166
iter
glemaitre Jan 12, 2025
8e6a532
iter
glemaitre Jan 12, 2025
2635f56
more parallelism
glemaitre Jan 12, 2025
23fd1aa
bug
glemaitre Jan 14, 2025
3ae4da8
Merge remote-tracking branch 'origin/main' into cross_validation_4
glemaitre Jan 14, 2025
7dfe4a4
convert class decorator to function decorator
glemaitre Jan 14, 2025
31a9ba3
simplify progress bar
glemaitre Jan 14, 2025
071daa5
feat: Allow for nested progress bar
glemaitre Jan 14, 2025
8414c20
actually add the file to git
glemaitre Jan 14, 2025
5238fdc
more comment
glemaitre Jan 14, 2025
3a97df2
chore: More readable version of iterating
glemaitre Jan 14, 2025
57d7eb0
document simplify
glemaitre Jan 14, 2025
1199cb3
Merge branch 'progress_bar' into cross_validation_report_3
glemaitre Jan 14, 2025
c2637b2
Merge branch 'chore_cache_predictions_2' into cross_validation_report_3
glemaitre Jan 14, 2025
dc1e119
conflict
glemaitre Jan 14, 2025
3a1cd06
fix: Fix the error message regarding immutable attribute
glemaitre Jan 14, 2025
376f77b
Merge branch 'fix_err_message_mutability' into cross_validation_report_3
glemaitre Jan 14, 2025
05af3fb
fix messages
glemaitre Jan 14, 2025
c8f87f7
feat: Expose private API to to optimize cache optimization by passing…
glemaitre Jan 14, 2025
ae292a8
tests: Check for error message with invalid strings
glemaitre Jan 14, 2025
ef8d1fe
iter
glemaitre Jan 14, 2025
6d770cc
add documentation as suggested per auguste
glemaitre Jan 14, 2025
8f9cbfe
add test to handle pos_label in scorer
glemaitre Jan 14, 2025
79ec786
iter
glemaitre Jan 14, 2025
08146ae
Merge branch 'cache_optimization' into cross_validation_report_6
glemaitre Jan 14, 2025
0915f32
chore: Differentiate __repr__ and help for report and accessors
glemaitre Jan 14, 2025
8680166
another assert for help method name
glemaitre Jan 14, 2025
cf03f94
rename clean_cache to clear_cache
glemaitre Jan 14, 2025
57f12b0
Merge branch 'is/1103' into cross_validation_report_3
glemaitre Jan 14, 2025
a6c5ccb
Merge remote-tracking branch 'origin/main' into cross_validation_repo…
glemaitre Jan 14, 2025
332a692
Merge remote-tracking branch 'origin/main' into cross_validation_repo…
glemaitre Jan 14, 2025
cce35d4
test general behaviour and attribute of the report
glemaitre Jan 15, 2025
aa18678
check the caching mechanism
glemaitre Jan 15, 2025
ad4cec8
check help and repr for metrics and plot accessors
glemaitre Jan 15, 2025
7f44249
add test metrics binary
glemaitre Jan 15, 2025
ef98506
fix typo
glemaitre Jan 15, 2025
d6b4ab3
fix typo
glemaitre Jan 15, 2025
0ea179a
add test for single metrics
glemaitre Jan 15, 2025
c9dfc8a
covert report
glemaitre Jan 15, 2025
aab38f6
Merge remote-tracking branch 'origin/main' into cross_validation_repo…
glemaitre Jan 15, 2025
1d58817
iter
glemaitre Jan 15, 2025
88f76cc
more tests
glemaitre Jan 15, 2025
a06e3ff
more tests
glemaitre Jan 15, 2025
d575dca
check scoring_names
glemaitre Jan 15, 2025
51f77a3
scorer error name
glemaitre Jan 15, 2025
772bb2a
iter
glemaitre Jan 15, 2025
c5afcd0
more tests
glemaitre Jan 15, 2025
bab0e20
iter
glemaitre Jan 15, 2025
a881c56
update pos_label documentation
glemaitre Jan 15, 2025
fe01241
update pos_label documentation
glemaitre Jan 15, 2025
467699c
api: Prepend with name of variable that we modify
glemaitre Jan 15, 2025
44055f8
Merge branch 'is/1118' into cross_validation_report_3
glemaitre Jan 15, 2025
8b4f8c4
move new convention _
glemaitre Jan 15, 2025
a69ac60
documentation reporter
glemaitre Jan 15, 2025
4e5ffb6
documentation
glemaitre Jan 16, 2025
83635d0
revert example
glemaitre Jan 16, 2025
be5e826
fix
glemaitre Jan 16, 2025
c6426ed
Merge remote-tracking branch 'origin/main' into cross_validation_repo…
glemaitre Jan 16, 2025
ee01f4f
Merge remote-tracking branch 'origin/main' into cross_validation_repo…
glemaitre Jan 16, 2025
f6b66a9
iter
glemaitre Jan 16, 2025
13efcd8
cross-validation binary default
glemaitre Jan 16, 2025
9d50c0b
cross_validation multiclass defaults
glemaitre Jan 16, 2025
26169da
iter
glemaitre Jan 16, 2025
21ee6fc
more tests
glemaitre Jan 16, 2025
0df2b42
more tests
glemaitre Jan 16, 2025
28b6fea
iter
glemaitre Jan 16, 2025
2e933cf
roc tests
glemaitre Jan 16, 2025
293a77d
prediction error plot tests
glemaitre Jan 16, 2025
0b9a083
tests
glemaitre Jan 16, 2025
29cad73
add examples
glemaitre Jan 16, 2025
c1c9e92
rename cv to cv_splitter
glemaitre Jan 17, 2025
088e4e3
Update skore/src/skore/sklearn/_cross_validation/report.py
glemaitre Jan 17, 2025
b4523de
Update skore/src/skore/sklearn/_cross_validation/report.py
glemaitre Jan 17, 2025
9fa5f17
Update skore/src/skore/sklearn/_cross_validation/report.py
glemaitre Jan 17, 2025
95478c6
Update skore/src/skore/sklearn/_estimator/report.py
glemaitre Jan 17, 2025
c8eb0a4
Update skore/src/skore/sklearn/_cross_validation/__init__.py
glemaitre Jan 17, 2025
87c0d19
chore: Harmonize `message` to `note` in notes-related methods (#1143)
augustebaum Jan 17, 2025
3f6be36
fix(UI): Item card actions are now aligned (#1145)
rouk1 Jan 17, 2025
715cafc
feat: 404 page when skore-UI as not been built (#1142)
rouk1 Jan 17, 2025
574cb65
fix(UI): Ellipsis long item name (#1147)
rouk1 Jan 17, 2025
c598b86
ci: Fix timeout by upgrading scikit-learn to latest bugfix (#1141)
glemaitre Jan 17, 2025
94d43b4
iter
glemaitre Jan 17, 2025
25bb9ba
order matter
glemaitre Jan 17, 2025
3810c8d
Merge remote-tracking branch 'origin/main' into cross_validation_repo…
glemaitre Jan 18, 2025
2c68c9e
fix: Cache the pos_label=None when calling cache_predictions
glemaitre Jan 18, 2025
fe6bfa0
Merge branch 'bug_pos_label_none_cache' into cross_validation_report_3
glemaitre Jan 18, 2025
e989de0
iter
glemaitre Jan 18, 2025
c13cd12
docs: Add example to show details regarding EstimatorReport and Cross…
glemaitre Jan 18, 2025
ab6baf4
iter
glemaitre Jan 18, 2025
ec8dd8f
end cache example
glemaitre Jan 19, 2025
3303c02
Merge branch 'main' into new_documentation
glemaitre Jan 20, 2025
3b9f814
iter
glemaitre Jan 19, 2025
9d3d867
Merge remote-tracking branch 'origin/main' into new_documentation
glemaitre Jan 20, 2025
42ecc07
call reporter -> report
glemaitre Jan 20, 2025
6e7e0f5
wip
glemaitre Jan 20, 2025
a24c7eb
wip
glemaitre Jan 20, 2025
24f33f0
fix: Make reports pickable out-of-the-box
glemaitre Jan 20, 2025
db9cec6
iter
glemaitre Jan 20, 2025
7a3384e
Merge branch 'make_report_pickable' into new_documentation
glemaitre Jan 20, 2025
218f7b7
iter
glemaitre Jan 20, 2025
5b4811f
Merge branch 'main' into new_documentation
glemaitre Jan 21, 2025
d3f6eb4
rework example
glemaitre Jan 21, 2025
d720d03
iter
glemaitre Jan 22, 2025
9172cc6
tigh_layout
glemaitre Jan 22, 2025
fbdb288
add sentence-transformers
glemaitre Jan 22, 2025
2e4036d
iter
glemaitre Jan 22, 2025
2d67184
Merge remote-tracking branch 'origin/main' into new_documentation
glemaitre Jan 22, 2025
a9ab8bc
remove temp file
glemaitre Jan 22, 2025
1f02bed
Update examples/use_cases/plot_employee_salaries.py
glemaitre Jan 22, 2025
f9c39fe
Update examples/use_cases/plot_employee_salaries.py
glemaitre Jan 22, 2025
ddf28f1
Update examples/use_cases/plot_employee_salaries.py
glemaitre Jan 22, 2025
f2cf83a
Update examples/use_cases/plot_employee_salaries.py
glemaitre Jan 22, 2025
fad060e
timeout
glemaitre Jan 22, 2025
c03d198
iter
glemaitre Jan 22, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions examples/technical_details/README.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
Technical details
-----------------

These examples shows some technical details at the core of `skore` to better understand
some of the mechanic under the hood.
239 changes: 239 additions & 0 deletions examples/technical_details/plot_cache_mechanism.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,239 @@
"""
===============
Cache mechanism
===============

This example shows how :class:`~skore.EstimatorReport` and
:class:`~skore.CrossValidationReport` use caching to speed up computations.
"""

# %%
#
# First, we load a dataset from `skrub`. Our goal is to predict if a company paid a
# physician. The ultimate goal is to detect potential conflict of interest when it comes
# to the actual problem that we want to solve.
from skrub.datasets import fetch_open_payments

dataset = fetch_open_payments()
df = dataset.X
y = dataset.y

# %%
from skrub import TableReport

TableReport(df)

# %%
#
# The dataset has over 70,000 records with only categorical features. Some categories
# are not well-defined. We use `skrub` to create a simple predictive model that handles
# this.
from skrub import tabular_learner

model = tabular_learner("classifier")
model

# %%
#
# This model handles all types of data: numbers, categories, dates, and missing values.
# Let's train it on part of our dataset.
from skore import train_test_split

X_train, X_test, y_train, y_test = train_test_split(df, y, random_state=42)

# %%
#
# Let's explore how :class:`~skore.EstimatorReport` uses caching to speed up
# predictions. We start by training the model:
from skore import EstimatorReport

report = EstimatorReport(
model, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test
)
report.help()

# %%
#
# Let's compute the accuracy on our test set and measure how long it takes:
import time

start = time.time()
result = report.metrics.accuracy()
end = time.time()
result

# %%
print(f"Time taken: {end - start:.2f} seconds")

# %%
#
# For comparison, here's how scikit-learn computes the same accuracy score:
from sklearn.metrics import accuracy_score

start = time.time()
result = accuracy_score(report.y_test, report.estimator_.predict(report.X_test))
end = time.time()
result

# %%
print(f"Time taken: {end - start:.2f} seconds")

# %%
#
# Both approaches take similar time. Now watch what happens when we compute accuracy
# again:
start = time.time()
result = report.metrics.accuracy()
end = time.time()
result

# %%
print(f"Time taken: {end - start:.2f} seconds")

# %%
#
# The second calculation is instant! This happens because the report saves previous
# calculations in its cache. Let's look inside the cache:
report._cache

# %%
#
# The cache stores predictions by type and data source. This means metrics that use
# the same type of predictions will be faster. Let's try the precision metric:
start = time.time()
result = report.metrics.precision()
end = time.time()
result

# %%
print(f"Time taken: {end - start:.2f} seconds")

# %%
# We observe that it takes only a few milliseconds to compute the precision because we
# don't need to re-compute the predictions and only have to compute the precision
# metric itself. Since the predictions are the bottleneck in terms of time, we observe
# an interesting speedup.
#
# We can pre-compute all predictions at once using parallel processing:
report.cache_predictions(n_jobs=2)

# %%
#
# Now all possible predictions are stored. Any metric calculation will be much faster,
# even on different data (like the training set):
start = time.time()
result = report.metrics.log_loss(data_source="train")
end = time.time()
result

# %%
print(f"Time taken: {end - start:.2f} seconds")

# %%
#
# The report can also work with external data. We use `data_source="X_y"` to indicate
# that we want to pass those external data.
start = time.time()
result = report.metrics.log_loss(data_source="X_y", X=X_test, y=y_test)
end = time.time()
result

# %%
print(f"Time taken: {end - start:.2f} seconds")

# %%
#
# The first calculation is slower than when using the internal train or test sets
# because it needs to compute a hash of the new data for later retrieval. Let's
# calculate it again:
start = time.time()
result = report.metrics.log_loss(data_source="X_y", X=X_test, y=y_test)
end = time.time()
result

# %%
print(f"Time taken: {end - start:.2f} seconds")

# %%
#
# Much faster! The remaining time is related to the hash computation. Let's compute the
# ROC AUC on the same data:
start = time.time()
result = report.metrics.roc_auc(data_source="X_y", X=X_test, y=y_test)
end = time.time()
result

# %%
print(f"Time taken: {end - start:.2f} seconds")

# %%
# We observe that the computation is already efficient because it boils down to two
# computations: the hash of the data and the ROC-AUC metric. We save a lot of time
# because we don't need to re-compute the predictions.
#
# The cache also speeds up plots. Let's create a ROC curve:
import matplotlib.pyplot as plt

start = time.time()
display = report.metrics.plot.roc(pos_label="allowed")
end = time.time()
plt.tight_layout()

# %%
print(f"Time taken: {end - start:.2f} seconds")

# %%
#
# The second plot is instant because it uses cached data:
start = time.time()
display = report.metrics.plot.roc(pos_label="allowed")
end = time.time()
plt.tight_layout()

# %%
print(f"Time taken: {end - start:.2f} seconds")

# %%
#
# We only use the cache to retrieve the `display` object and not directly the matplotlib
# figure. It means that you can still customize the cached plot before displaying it:
display.plot(roc_curve_kwargs={"color": "tab:orange"})
plt.tight_layout()

# %%
#
# Be aware that you can clear the cache if you want to:
report.clear_cache()
report._cache

# %%
#
# It means that nothing is stored anymore in the cache.
#
# :class:`~skore.CrossValidationReport` uses the same caching system for each fold
# in cross-validation by leveraging the previous :class:`~skore.EstimatorReport`:
from skore import CrossValidationReport

report = CrossValidationReport(model, X=df, y=y, cv_splitter=5, n_jobs=2)
report.help()

# %%
#
# We can pre-compute all predictions at once using parallel processing:
report.cache_predictions(n_jobs=2)

# %%
#
# Now all possible predictions are stored. Any metric calculation will be much faster,
# even on different data as we show for the :class:`~skore.EstimatorReport`.
start = time.time()
result = report.metrics.report_metrics(aggregate=["mean", "std"])
end = time.time()
result

# %%
print(f"Time taken: {end - start:.2f} seconds")

# %%
#
# So we observe the same type of behaviour as we previously exposed.
9 changes: 9 additions & 0 deletions examples/use_cases/README.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
End-to-end data science use cases
---------------------------------

These examples show `skore` in action on real use case. We aimed at showing `skore`
ability to:

- be compatible with `scikit-learn`
- reduce boilerplate code for some standard *de facto* data science analysis
- speed-up exploration by optimizing some internal computation
Loading
Loading