Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MRG] FEA Add interpolation join #742

Merged
merged 49 commits into from
Nov 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
49 commits
Select commit Hold shift + click to select a range
4ccccd7
add interpolation join
jeromedockes Sep 20, 2023
82caaa3
changelog
jeromedockes Sep 20, 2023
f1e31c0
details
jeromedockes Sep 20, 2023
9f4aa84
add the "on" convenience parameter
jeromedockes Sep 21, 2023
fa09bf1
add "suffix" parameter
jeromedockes Sep 21, 2023
0c3b741
blank line at end of docstring
jeromedockes Sep 21, 2023
86ba0a3
blank line in docstring
jeromedockes Sep 21, 2023
83c2942
missing reset_index()
jeromedockes Sep 22, 2023
a978c92
Merge remote-tracking branch 'upstream/main' into interpolation_join
jeromedockes Sep 22, 2023
0e32a46
improve example
jeromedockes Sep 22, 2023
9b51875
Apply suggestions from code review
jeromedockes Sep 25, 2023
fb1c809
review comments on example
jeromedockes Sep 25, 2023
ab0f0f6
sklearn imports + TransformerMixin
jeromedockes Sep 25, 2023
91c82bc
rename fit params X, y
jeromedockes Sep 25, 2023
70a5081
remove target_columns parameter from _fit
jeromedockes Sep 25, 2023
180b888
prefer __getitem__ to .loc
jeromedockes Sep 25, 2023
33c6733
add some docstrings & comments
jeromedockes Sep 25, 2023
9ad0b4e
use _safe_tags rather than _get_tags
jeromedockes Sep 25, 2023
3048db5
use default verbose
jeromedockes Sep 25, 2023
9137367
apply renaming decided in skrub meeting
jeromedockes Sep 26, 2023
5fb3451
Merge remote-tracking branch 'upstream/main' into interpolation_join
jeromedockes Sep 26, 2023
31cf759
simplify index handling in concatenation
jeromedockes Sep 26, 2023
b03e603
Update examples/08_interpolation_join.py
jeromedockes Sep 26, 2023
2f98a9a
address review comments
jeromedockes Sep 26, 2023
1a5cae3
remove vectorizer param, always vectorize keys
jeromedockes Sep 26, 2023
efac580
blank line at the end of docstring
jeromedockes Sep 26, 2023
e00d7c2
rename InterpolationJoin → InterpolationJoiner
jeromedockes Sep 26, 2023
14679e1
rename interpolation_join module
jeromedockes Sep 26, 2023
d9ee6dd
address review
jeromedockes Sep 26, 2023
49f5ccf
restore the vectorizer parameter
jeromedockes Sep 28, 2023
9a5061e
allow controlling how estimator exceptions should be handled
jeromedockes Sep 29, 2023
c66ae32
improve n_jobs description and default value
jeromedockes Sep 29, 2023
7969c83
Merge remote-tracking branch 'upstream/main' into interpolation_join
jeromedockes Oct 9, 2023
9c7d714
use MinHashEncoder in InterpolationJoiner
jeromedockes Oct 9, 2023
a5e7331
Merge remote-tracking branch 'upstream/main' into interpolation_join
jeromedockes Oct 13, 2023
dff2c78
call plt.show() in example
jeromedockes Oct 13, 2023
0cb2a1a
rename example (08 already taken now)
jeromedockes Oct 13, 2023
d855d8a
Apply suggestions from code review
jeromedockes Oct 16, 2023
acd6ab1
add doctest setup
jeromedockes Oct 16, 2023
71657f6
Merge remote-tracking branch 'upstream/main' into interpolation_join
jeromedockes Nov 2, 2023
c859ace
Merge remote-tracking branch 'upstream/main' into interpolation_join
jeromedockes Nov 2, 2023
38dafb1
fix transform after change in tablevectorizer
jeromedockes Nov 2, 2023
81b79b1
use checks from join_utils
jeromedockes Nov 2, 2023
9bda6d1
improve example and docstring
jeromedockes Nov 2, 2023
d446421
Merge remote-tracking branch 'upstream/main' into interpolation_join
jeromedockes Nov 10, 2023
7f2cf94
apply same handling of default estimators as in TableVectorizer
jeromedockes Nov 10, 2023
6f8e0dd
use default datetimeencoder params
jeromedockes Nov 10, 2023
94a091f
add test
jeromedockes Nov 10, 2023
16dde72
add note on minhash vs gap encoding
jeromedockes Nov 10, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ development and backward compatibility is not ensured.
Major changes
-------------

* :class:`InterpolationJoiner` was added to join two tables by using
machine-learning to infer the matching rows from the second table.
:pr:`742` by :user:`Jérôme Dockès <jeromedockes>`.

* Pipelines including :class:`TableVectorizer` can now be grid-searched, since
we can now call `set_params` on the default transformers of :class:`TableVectorizer`.
:pr:`814` by :user:`Vincent Maladiere <Vincent-Maladiere>`
Expand Down
7 changes: 7 additions & 0 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,13 @@ This page lists all available functions and classes of `skrub`.
AggTarget


.. autosummary::
:toctree: generated/
:template: class.rst
:nosignatures:

InterpolationJoiner

.. raw:: html

<h2>Column selection in a pipeline</h2>
Expand Down
179 changes: 179 additions & 0 deletions examples/09_interpolation_join.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
"""
Interpolation join: infer missing rows when joining two tables
==============================================================

We illustrate the :class:`~skrub.InterpolationJoiner`, which is a type of join where values from the second table are inferred with machine-learning, rather than looked up in the table.
It is useful when exact matches are not available but we have rows that are close enough to make an educated guess -- in this sense it is a generalization of a :func:`~skrub.fuzzy_join`.

The :class:`~skrub.InterpolationJoiner` is therefore a transformer that adds the outputs of one or more machine-learning models as new columns to the table it operates on.

In this example we want our transformer to add weather data (temperature, rain, etc.) to the table it operates on.
We have a table containing information about commercial flights, and we want to add information about the weather at the time and place where each flight took off.
This could be useful to predict delays -- flights are often delayed by bad weather.

We have a table of weather data containing, at many weather stations, measurements such as temperature, rain and snow at many time points.
Unfortunately, our weather stations are not inside the airports, and the measurements are not timed according to the flight schedule.
Therefore, a simple equi-join would not yield any matching pair of rows from our two tables.
Instead, we use the :class:`~skrub.InterpolationJoiner` to *infer* the temperature at the airport at take-off time.
We train supervised machine-learning models using the weather table, then query them with the times and locations in the flights table.

"""

######################################################################
# Load weather data
# -----------------
# We join the table containing the measurements to the table that contains the weather stations’ latitude and longitude.
# We subsample these large tables for the example to run faster.

from skrub.datasets import fetch_figshare

weather = fetch_figshare("41771457").X
weather = weather.sample(100_000, random_state=0, ignore_index=True)
stations = fetch_figshare("41710524").X
weather = stations.merge(weather, on="ID")[
["LATITUDE", "LONGITUDE", "YEAR/MONTH/DAY", "TMAX", "PRCP", "SNOW"]
]

######################################################################
# The ``'TMAX'`` is in tenths of degree Celsius -- a ``'TMAX'`` of 297 means the maximum temperature that day was 29.7℃.
# We convert it to degrees for readability

weather["TMAX"] /= 10

######################################################################
# InterpolationJoiner with a ground truth: joining the weather table on itself
# ----------------------------------------------------------------------------
# As a first simple example, we apply the :class:`~skrub.InterpolationJoiner` in a situation where the ground truth is known.
# We split the weather table in half and join the second half on the first half.
# Thus, the values from the right side table of the join are inferred, whereas the corresponding columns from the left side contain the ground truth and we can compare them.

n_main = weather.shape[0] // 2
main_table = weather.iloc[:n_main]
main_table.head()

######################################################################
aux_table = weather.iloc[n_main:]
aux_table.head()


######################################################################
# Joining the tables
# ------------------
# Now we join our two tables and check how well the :class:`~skrub.InterpolationJoiner` can reconstruct the matching rows that are missing from the right side table.
# To avoid clashes in the column names, we use the ``suffix`` parameter to append ``"predicted"`` to the right side table column names.

from skrub import InterpolationJoiner

joiner = InterpolationJoiner(
aux_table,
key=["LATITUDE", "LONGITUDE", "YEAR/MONTH/DAY"],
suffix="_predicted",
).fit(main_table)
join = joiner.transform(main_table)
join.head()

######################################################################
# Comparing the estimated values to the ground truth
# --------------------------------------------------

from matplotlib import pyplot as plt

join = join.sample(2000, random_state=0, ignore_index=True)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This entire section is so helpful for discovering what features to join that I wonder if we should make a plot utils based on this in a subsequent PR. WDYT?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks, yes I think that's a good idea. skrub is likely to be used in some rather complex pipelines so it would be very useful for users if we could help with inspection, debugging etc

fig, axes = plt.subplots(
3,
1,
figsize=(5, 9),
gridspec_kw={"height_ratios": [1.0, 0.5, 0.5]},
layout="compressed",
)
for ax, col in zip(axes.ravel(), ["TMAX", "PRCP", "SNOW"]):
ax.scatter(
join[col].values,
join[f"{col}_predicted"].values,
alpha=0.1,
)
ax.set_aspect(1)
ax.set_xlabel(f"true {col}")
ax.set_ylabel(f"predicted {col}")
plt.show()

######################################################################
# We see that in this case the interpolation join works well for the temperature, but not precipitation nor snow.
# So we will only add the temperature to our flights table.

aux_table = aux_table.drop(["PRCP", "SNOW"], axis=1)

######################################################################
# Loading the flights table
# -------------------------
# We load the flights table and join it to the airports table using the flights’ ``'Origin'`` which refers to the departure airport’s IATA code.
# We use only a subset to speed up the example.

flights = fetch_figshare("41771418").X[["Year_Month_DayofMonth", "Origin", "ArrDelay"]]
flights = flights.sample(20_000, random_state=0, ignore_index=True)
airports = fetch_figshare("41710257").X[["iata", "airport", "state", "lat", "long"]]
flights = flights.merge(airports, left_on="Origin", right_on="iata")
# printing the first row is more readable than the head() when we have many columns
flights.iloc[0]

######################################################################
# Joining the flights and weather data
# ------------------------------------
# As before, we initialize our join transformer with the weather table.
# Then, we use it to transform the flights table -- it adds a ``'TMAX'`` column containing the predicted maximum daily temperature.
#

joiner = InterpolationJoiner(
aux_table,
main_key=["lat", "long", "Year_Month_DayofMonth"],
aux_key=["LATITUDE", "LONGITUDE", "YEAR/MONTH/DAY"],
)
join = joiner.fit_transform(flights)
join.head()

######################################################################
# Sanity checks
# -------------
# This time we do not have a ground truth for the temperatures.
# We can perform a few basic sanity checks.

state_temperatures = join.groupby("state")["TMAX"].mean().sort_values()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These sanity checks are simple but also very helpful, cf my suggestion of plot utils above.


######################################################################
# States with the lowest average predicted temperatures: Alaska, Montana, North Dakota, Washington, Minnesota.
state_temperatures.head()

######################################################################
# States with the highest predicted temperatures: Puerto Rico, Virgin Islands, Hawaii, Florida, Louisiana.
state_temperatures.tail()

######################################################################
# Higher latitudes (farther up north) are colder -- the airports in this dataset are in the United States.
fig, ax = plt.subplots()
ax.scatter(join["lat"], join["TMAX"])
ax.set_xlabel("Latitude (higher is farther north)")
ax.set_ylabel("TMAX")
plt.show()

######################################################################
# Winter months are colder than spring -- in the north hemisphere January is colder than April
#

import seaborn as sns

join["month"] = join["Year_Month_DayofMonth"].dt.strftime("%m %B")
plt.figure(layout="constrained")
sns.barplot(data=join.sort_values(by="month"), y="month", x="TMAX")
plt.show()

######################################################################
# Of course these checks do not guarantee that the inferred values in our ``join`` table’s ``'TMAX'`` column are accurate.
# But at least the :class:`~skrub.InterpolationJoiner` seems to have learned a few reasonable trends from its training table.


######################################################################
# Conclusion
# ----------
# We have seen how to fit an :class:`~skrub.InterpolationJoiner` transformer: we give it a table (the weather data) and a set of matching columns (here date, latitude, longitude) and it learns to predict the other columns’ values (such as the max daily temperature).
# Then, it transforms tables by *predicting* values that a matching row would contain, rather than by searching for an actual match.
# It is a generalization of the :func:`~skrub.fuzzy_join`, as :func:`~skrub.fuzzy_join` is the same thing as an :class:`~skrub.InterpolationJoiner` where the estimators are 1-nearest-neighbor estimators.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you think we should make another example later where we display several InterpolateJoiner with different classifier and regressor?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes I think that would be great! the first step will be finding a slightly easier dataset, where we see a bigger benefit of joining an extra table on a downstream task -- otherwise we won't see any difference between the different classifiers and regressors

2 changes: 2 additions & 0 deletions skrub/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from ._deduplicate import compute_ngram_distance, deduplicate
from ._fuzzy_join import fuzzy_join
from ._gap_encoder import GapEncoder
from ._interpolation_joiner import InterpolationJoiner
from ._joiner import Joiner
from ._minhash_encoder import MinHashEncoder
from ._select_cols import DropCols, SelectCols
Expand All @@ -27,6 +28,7 @@
"Joiner",
"fuzzy_join",
"GapEncoder",
"InterpolationJoiner",
"MinHashEncoder",
"SimilarityEncoder",
"TableVectorizer",
Expand Down
Loading
Loading