[MRG] FEA Add interpolation join #742
""" | ||
Interpolation join: infer missing rows when joining two tables | ||
============================================================== | ||
|
||
We illustrate the :class:`~skrub.InterpolationJoiner`, which is a type of join where values from the second table are inferred with machine-learning, rather than looked up in the table. | ||
It is useful when exact matches are not available but we have rows that are close enough to make an educated guess -- in this sense it is a generalization of a :func:`~skrub.fuzzy_join`. | ||
|
||
The :class:`~skrub.InterpolationJoiner` is therefore a transformer that adds the outputs of one or more machine-learning models as new columns to the table it operates on. | ||
|
||
In this example we want our transformer to add weather data (temperature, rain, etc.) to the table it operates on. | ||
We have a table containing information about commercial flights, and we want to add information about the weather at the time and place where each flight took off. | ||
This could be useful to predict delays -- flights are often delayed by bad weather. | ||
|
||
We have a table of weather data containing, at many weather stations, measurements such as temperature, rain and snow at many time points. | ||
Unfortunately, our weather stations are not inside the airports, and the measurements are not timed according to the flight schedule. | ||
Therefore, a simple equi-join would not yield any matching pair of rows from our two tables. | ||
Instead, we use the :class:`~skrub.InterpolationJoiner` to *infer* the temperature at the airport at take-off time. | ||
We train supervised machine-learning models using the weather table, then query them with the times and locations in the flights table. | ||
|
||
""" | ||

######################################################################
# Load weather data
# -----------------
# We join the table containing the measurements to the table that contains the weather stations’ latitude and longitude.
# We subsample these large tables for the example to run faster.

from skrub.datasets import fetch_figshare

weather = fetch_figshare("41771457").X
weather = weather.sample(100_000, random_state=0, ignore_index=True)
stations = fetch_figshare("41710524").X
weather = stations.merge(weather, on="ID")[
    ["LATITUDE", "LONGITUDE", "YEAR/MONTH/DAY", "TMAX", "PRCP", "SNOW"]
]

######################################################################
# The ``'TMAX'`` column is in tenths of a degree Celsius -- a ``'TMAX'`` of 297 means the maximum temperature that day was 29.7℃.
# We convert it to degrees Celsius for readability.

weather["TMAX"] /= 10

######################################################################
# InterpolationJoiner with a ground truth: joining the weather table on itself
# -----------------------------------------------------------------------------
# As a first simple example, we apply the :class:`~skrub.InterpolationJoiner` in a situation where the ground truth is known.
# We split the weather table in half and join the second half on the first half.
# Thus, the values from the right-side table of the join are inferred, whereas the corresponding columns from the left side contain the ground truth, so we can compare them.

n_main = weather.shape[0] // 2
main_table = weather.iloc[:n_main]
main_table.head()

######################################################################
aux_table = weather.iloc[n_main:]
aux_table.head()


######################################################################
# Joining the tables
# ------------------
# Now we join our two tables and check how well the :class:`~skrub.InterpolationJoiner` can reconstruct the matching rows that are missing from the right-side table.
# To avoid clashes in the column names, we use the ``suffix`` parameter to append ``"_predicted"`` to the right-side table column names.

from skrub import InterpolationJoiner

joiner = InterpolationJoiner(
    aux_table,
    key=["LATITUDE", "LONGITUDE", "YEAR/MONTH/DAY"],
    suffix="_predicted",
).fit(main_table)
join = joiner.transform(main_table)
join.head()

######################################################################
# Comparing the estimated values to the ground truth
# --------------------------------------------------

from matplotlib import pyplot as plt

join = join.sample(2000, random_state=0, ignore_index=True)
fig, axes = plt.subplots(
    3,
    1,
    figsize=(5, 9),
    gridspec_kw={"height_ratios": [1.0, 0.5, 0.5]},
    layout="compressed",
)
for ax, col in zip(axes.ravel(), ["TMAX", "PRCP", "SNOW"]):
    ax.scatter(
        join[col].values,
        join[f"{col}_predicted"].values,
        alpha=0.1,
    )
    ax.set_aspect(1)
    ax.set_xlabel(f"true {col}")
    ax.set_ylabel(f"predicted {col}")
plt.show()
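
######################################################################
# We can also quantify this agreement with the R² score of each predicted column against its ground-truth counterpart -- a rough check on the 2,000 sampled rows, ignoring missing values.

from sklearn.metrics import r2_score

for col in ["TMAX", "PRCP", "SNOW"]:
    # keep only rows where both the true and the predicted value are present
    pair = join[[col, f"{col}_predicted"]].dropna()
    score = r2_score(pair[col], pair[f"{col}_predicted"])
    print(f"{col}: R² = {score:.2f}")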

######################################################################
# We see that in this case the interpolation join works well for the temperature, but not for precipitation or snow.
# So we will only add the temperature to our flights table.

aux_table = aux_table.drop(["PRCP", "SNOW"], axis=1)

######################################################################
# Loading the flights table
# -------------------------
# We load the flights table and join it to the airports table using the flights’ ``'Origin'`` column, which contains the departure airport’s IATA code.
# We use only a subset to speed up the example.

flights = fetch_figshare("41771418").X[["Year_Month_DayofMonth", "Origin", "ArrDelay"]]
flights = flights.sample(20_000, random_state=0, ignore_index=True)
airports = fetch_figshare("41710257").X[["iata", "airport", "state", "lat", "long"]]
flights = flights.merge(airports, left_on="Origin", right_on="iata")
# printing the first row is more readable than head() when we have many columns
flights.iloc[0]

######################################################################
# Joining the flights and weather data
# ------------------------------------
# As before, we initialize our join transformer with the weather table.
# Then, we use it to transform the flights table -- it adds a ``'TMAX'`` column containing the predicted maximum daily temperature.

joiner = InterpolationJoiner(
    aux_table,
    main_key=["lat", "long", "Year_Month_DayofMonth"],
    aux_key=["LATITUDE", "LONGITUDE", "YEAR/MONTH/DAY"],
)
join = joiner.fit_transform(flights)
join.head()

######################################################################
# Sanity checks
# -------------
# This time we do not have a ground truth for the temperatures, so we perform a few basic sanity checks instead.

state_temperatures = join.groupby("state")["TMAX"].mean().sort_values()

######################################################################
# States with the lowest average predicted temperatures: Alaska, Montana, North Dakota, Washington, Minnesota.
state_temperatures.head()

######################################################################
# States with the highest average predicted temperatures: Puerto Rico, Virgin Islands, Hawaii, Florida, Louisiana.
state_temperatures.tail()
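
######################################################################
# We can also check that the predicted temperatures stay within the range of values seen in the weather table used for training.

print("training range:", aux_table["TMAX"].min(), "to", aux_table["TMAX"].max())
print("predicted range:", join["TMAX"].min(), "to", join["TMAX"].max())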

######################################################################
# Higher latitudes (farther north) are colder -- the airports in this dataset are in the United States.
fig, ax = plt.subplots()
ax.scatter(join["lat"], join["TMAX"])
ax.set_xlabel("Latitude (higher is farther north)")
ax.set_ylabel("TMAX")
plt.show()

######################################################################
# Winter months are colder than spring -- in the northern hemisphere, January is colder than April.

import seaborn as sns

join["month"] = join["Year_Month_DayofMonth"].dt.strftime("%m %B")
plt.figure(layout="constrained")
sns.barplot(data=join.sort_values(by="month"), y="month", x="TMAX")
plt.show()

######################################################################
# Of course these checks do not guarantee that the inferred values in our ``join`` table’s ``'TMAX'`` column are accurate.
# But at least the :class:`~skrub.InterpolationJoiner` seems to have learned a few reasonable trends from its training table.


######################################################################
# Conclusion
# ----------
# We have seen how to fit an :class:`~skrub.InterpolationJoiner` transformer: we give it a table (the weather data) and a set of matching columns (here date, latitude, longitude), and it learns to predict the other columns’ values (such as the maximum daily temperature).
# Then, it transforms tables by *predicting* the values that a matching row would contain, rather than by searching for an actual match.
# It is a generalization of :func:`~skrub.fuzzy_join`: a :func:`~skrub.fuzzy_join` is equivalent to an :class:`~skrub.InterpolationJoiner` whose estimators are 1-nearest-neighbor models.
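
######################################################################
# As a sketch of that equivalence, we could rerun the weather self-join with a 1-nearest-neighbor estimator.
# This assumes the joiner lets us override its default model through a ``regressor`` parameter; check the installed skrub version's API before relying on it.

from sklearn.neighbors import KNeighborsRegressor

nn_joiner = InterpolationJoiner(
    aux_table,
    key=["LATITUDE", "LONGITUDE", "YEAR/MONTH/DAY"],
    suffix="_nn_predicted",
    # a 1-nearest-neighbor regressor looks up the closest matching row, like a fuzzy join
    regressor=KNeighborsRegressor(n_neighbors=1),
)
nn_join = nn_joiner.fit_transform(main_table)
nn_join.head()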

Review comments

"These sanity checks are simple but also very helpful, cf my suggestion of plot utils above."

"Do you think we should make another example later where we display several …"

Reply: "yes I think that would be great! the first step will be finding a slightly easier dataset, where we see a bigger benefit of joining an extra table on a downstream task -- otherwise we won't see any difference between the different classifiers and regressors"

"This entire section is so helpful for discovering what features to join that I wonder if we should make a plot utils based on this in a subsequent PR. WDYT?"

Reply: "thanks, yes I think that's a good idea. skrub is likely to be used in some rather complex pipelines so it would be very useful for users if we could help with inspection, debugging etc"