diff --git a/CHANGES.rst b/CHANGES.rst index c1cd568d4..dd6726821 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -15,6 +15,10 @@ development and backward compatibility is not ensured. Major changes ------------- +* :class:`InterpolationJoiner` was added to join two tables, using + machine learning to infer the matching rows from the second table. + :pr:`742` by :user:`Jérôme Dockès `. + * Pipelines including :class:`TableVectorizer` can now be grid-searched, since we can now call `set_params` on the default transformers of :class:`TableVectorizer`. :pr:`814` by :user:`Vincent Maladiere ` diff --git a/doc/api.rst b/doc/api.rst index 730fe37cb..7cdd0689a 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -36,6 +36,13 @@ This page lists all available functions and classes of `skrub`. AggTarget +.. autosummary:: + :toctree: generated/ + :template: class.rst + :nosignatures: + + InterpolationJoiner + .. raw:: html

Column selection in a pipeline

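Before the new example file, a minimal usage sketch of the class being added (the toy tables mirror the class docstring's example further down; this sketch is illustrative, not part of the diff)::

    import pandas as pd
    from sklearn.neighbors import KNeighborsRegressor

    from skrub import InterpolationJoiner

    # The table we want to enrich, and the table used to train the interpolators.
    buildings = pd.DataFrame(
        {"latitude": [1.0, 2.0], "longitude": [1.0, 2.0], "n_stories": [3, 7]}
    )
    annual_avg_temp = pd.DataFrame(
        {
            "latitude": [1.2, 0.9, 1.9, 1.7, 5.0],
            "longitude": [0.8, 1.1, 1.8, 1.8, 5.0],
            "avg_temp": [10.0, 11.0, 15.0, 16.0, 20.0],
        }
    )

    # Rather than looking up exact (latitude, longitude) matches, a regressor
    # trained on annual_avg_temp predicts an avg_temp value for each building.
    join = InterpolationJoiner(
        annual_avg_temp,
        key=["latitude", "longitude"],
        regressor=KNeighborsRegressor(2),
    ).fit_transform(buildings)
    print(join["avg_temp"].tolist())  # [10.5, 15.5]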
diff --git a/examples/09_interpolation_join.py b/examples/09_interpolation_join.py new file mode 100644 index 000000000..d452b113d --- /dev/null +++ b/examples/09_interpolation_join.py @@ -0,0 +1,179 @@ +""" +Interpolation join: infer missing rows when joining two tables +============================================================== + +We illustrate the :class:`~skrub.InterpolationJoiner`, which is a type of join where values from the second table are inferred with machine learning, rather than looked up in the table itself. +It is useful when exact matches are not available but we have rows that are close enough to make an educated guess -- in this sense it is a generalization of :func:`~skrub.fuzzy_join`. + +The :class:`~skrub.InterpolationJoiner` is therefore a transformer that adds the outputs of one or more machine-learning models as new columns to the table it operates on. + +In this example we want our transformer to add weather data (temperature, rain, etc.) to the table it operates on. +We have a table containing information about commercial flights, and we want to add information about the weather at the time and place where each flight took off. +This could be useful to predict delays -- flights are often delayed by bad weather. + +We have a table of weather data containing measurements such as temperature, rain and snow, taken at many weather stations and time points. +Unfortunately, our weather stations are not inside the airports, and the measurements are not timed according to the flight schedule. +Therefore, a simple equi-join would not yield any matching pair of rows from our two tables. +Instead, we use the :class:`~skrub.InterpolationJoiner` to *infer* the temperature at the airport at take-off time. +We train supervised machine-learning models using the weather table, then query them with the times and locations in the flights table. + +""" + +###################################################################### +# Load weather data +# ----------------- +# We join the table containing the measurements to the table that contains the weather stations’ latitude and longitude. +# We subsample these large tables for the example to run faster. + +from skrub.datasets import fetch_figshare + +weather = fetch_figshare("41771457").X +weather = weather.sample(100_000, random_state=0, ignore_index=True) +stations = fetch_figshare("41710524").X +weather = stations.merge(weather, on="ID")[ + ["LATITUDE", "LONGITUDE", "YEAR/MONTH/DAY", "TMAX", "PRCP", "SNOW"] +] + +###################################################################### +# The ``'TMAX'`` column is in tenths of a degree Celsius -- a ``'TMAX'`` of 297 means the maximum temperature that day was 29.7℃. +# We convert it to degrees Celsius for readability. + +weather["TMAX"] /= 10 + +###################################################################### +# InterpolationJoiner with a ground truth: joining the weather table on itself +# ---------------------------------------------------------------------------- +# As a first simple example, we apply the :class:`~skrub.InterpolationJoiner` in a situation where the ground truth is known. +# We split the weather table in half and join the second half onto the first half. +# Thus, the values from the right side table of the join are inferred, whereas the corresponding columns from the left side contain the ground truth, so we can compare them.
+ +n_main = weather.shape[0] // 2 +main_table = weather.iloc[:n_main] +main_table.head() + +###################################################################### +aux_table = weather.iloc[n_main:] +aux_table.head() + + +###################################################################### +# Joining the tables +# ------------------ +# Now we join our two tables and check how well the :class:`~skrub.InterpolationJoiner` can reconstruct the matching rows that are missing from the right side table. +# To avoid clashes in the column names, we use the ``suffix`` parameter to append ``"_predicted"`` to the right side table column names. + +from skrub import InterpolationJoiner + +joiner = InterpolationJoiner( + aux_table, + key=["LATITUDE", "LONGITUDE", "YEAR/MONTH/DAY"], + suffix="_predicted", +).fit(main_table) +join = joiner.transform(main_table) +join.head() + +###################################################################### +# Comparing the estimated values to the ground truth +# -------------------------------------------------- + +from matplotlib import pyplot as plt + +join = join.sample(2000, random_state=0, ignore_index=True) +fig, axes = plt.subplots( + 3, + 1, + figsize=(5, 9), + gridspec_kw={"height_ratios": [1.0, 0.5, 0.5]}, + layout="compressed", +) +for ax, col in zip(axes.ravel(), ["TMAX", "PRCP", "SNOW"]): + ax.scatter( + join[col].values, + join[f"{col}_predicted"].values, + alpha=0.1, + ) + ax.set_aspect(1) + ax.set_xlabel(f"true {col}") + ax.set_ylabel(f"predicted {col}") +plt.show() + +###################################################################### +# We see that in this case the interpolation join works well for the temperature, but not for precipitation or snow. +# So we will only add the temperature to our flights table. + +aux_table = aux_table.drop(["PRCP", "SNOW"], axis=1) + +###################################################################### +# Loading the flights table +# ------------------------- +# We load the flights table and join it to the airports table using the flights’ ``'Origin'`` column, which contains the departure airport’s IATA code. +# We use only a subset to speed up the example. + +flights = fetch_figshare("41771418").X[["Year_Month_DayofMonth", "Origin", "ArrDelay"]] +flights = flights.sample(20_000, random_state=0, ignore_index=True) +airports = fetch_figshare("41710257").X[["iata", "airport", "state", "lat", "long"]] +flights = flights.merge(airports, left_on="Origin", right_on="iata") +# printing the first row is more readable than head() when we have many columns +flights.iloc[0] + +###################################################################### +# Joining the flights and weather data +# ------------------------------------ +# As before, we initialize our join transformer with the weather table. +# Then, we use it to transform the flights table -- it adds a ``'TMAX'`` column containing the predicted maximum daily temperature. +# + +joiner = InterpolationJoiner( + aux_table, + main_key=["lat", "long", "Year_Month_DayofMonth"], + aux_key=["LATITUDE", "LONGITUDE", "YEAR/MONTH/DAY"], +) +join = joiner.fit_transform(flights) +join.head() + +###################################################################### +# Sanity checks +# ------------- +# This time we do not have a ground truth for the temperatures. +# We can perform a few basic sanity checks.
+ +state_temperatures = join.groupby("state")["TMAX"].mean().sort_values() + +###################################################################### +# States with the lowest average predicted temperatures: Alaska, Montana, North Dakota, Washington, Minnesota. +state_temperatures.head() + +###################################################################### +# States with the highest predicted temperatures: Puerto Rico, Virgin Islands, Hawaii, Florida, Louisiana. +state_temperatures.tail() + +###################################################################### +# Higher latitudes (farther north) are colder -- the airports in this dataset are in the United States. +fig, ax = plt.subplots() +ax.scatter(join["lat"], join["TMAX"]) +ax.set_xlabel("Latitude (higher is farther north)") +ax.set_ylabel("TMAX") +plt.show() + +###################################################################### +# Winter months are colder than spring -- in the northern hemisphere, January is colder than April. +# + +import seaborn as sns + +join["month"] = join["Year_Month_DayofMonth"].dt.strftime("%m %B") +plt.figure(layout="constrained") +sns.barplot(data=join.sort_values(by="month"), y="month", x="TMAX") +plt.show() + +###################################################################### +# Of course, these checks do not guarantee that the inferred values in our ``join`` table’s ``'TMAX'`` column are accurate. +# But at least the :class:`~skrub.InterpolationJoiner` seems to have learned a few reasonable trends from its training table. + + +###################################################################### +# Conclusion +# ---------- +# We have seen how to fit an :class:`~skrub.InterpolationJoiner` transformer: we give it a table (the weather data) and a set of matching columns (here date, latitude, longitude) and it learns to predict the other columns’ values (such as the max daily temperature). +# Then, it transforms tables by *predicting* values that a matching row would contain, rather than by searching for an actual match. +# It is a generalization of :func:`~skrub.fuzzy_join`: indeed, :func:`~skrub.fuzzy_join` is equivalent to an :class:`~skrub.InterpolationJoiner` whose estimators are 1-nearest-neighbor estimators, as the sketch below illustrates.
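######################################################################
# A quick, illustrative sketch of that equivalence (reusing ``main_table`` and
# ``aux_table`` from the first part of this example; a sketch, not a rigorous
# comparison): with 1-nearest-neighbor estimators, the
# :class:`~skrub.InterpolationJoiner` copies each value from the single
# closest row of ``aux_table``, which is what a fuzzy join does.

from sklearn.neighbors import KNeighborsRegressor

nn_joiner = InterpolationJoiner(
    aux_table,
    key=["LATITUDE", "LONGITUDE", "YEAR/MONTH/DAY"],
    suffix="_nn",
    regressor=KNeighborsRegressor(n_neighbors=1),
)
nn_join = nn_joiner.fit_transform(main_table)
nn_join.head()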
diff --git a/skrub/__init__.py b/skrub/__init__.py index e2fccecff..1cd022176 100644 --- a/skrub/__init__.py +++ b/skrub/__init__.py @@ -9,6 +9,7 @@ from ._deduplicate import compute_ngram_distance, deduplicate from ._fuzzy_join import fuzzy_join from ._gap_encoder import GapEncoder +from ._interpolation_joiner import InterpolationJoiner from ._joiner import Joiner from ._minhash_encoder import MinHashEncoder from ._select_cols import DropCols, SelectCols @@ -27,6 +28,7 @@ "Joiner", "fuzzy_join", "GapEncoder", + "InterpolationJoiner", "MinHashEncoder", "SimilarityEncoder", "TableVectorizer", diff --git a/skrub/_interpolation_joiner.py b/skrub/_interpolation_joiner.py new file mode 100644 index 000000000..92e9bf30b --- /dev/null +++ b/skrub/_interpolation_joiner.py @@ -0,0 +1,441 @@ +import warnings + +import joblib +import numpy as np +import pandas as pd +from sklearn.base import BaseEstimator, TransformerMixin, clone +from sklearn.ensemble import ( + HistGradientBoostingClassifier, + HistGradientBoostingRegressor, +) +from sklearn.utils._tags import _safe_tags + +from skrub import _join_utils, _utils +from skrub._minhash_encoder import MinHashEncoder +from skrub._table_vectorizer import TableVectorizer + +DEFAULT_VECTORIZER = TableVectorizer(high_cardinality_transformer=MinHashEncoder()) +DEFAULT_REGRESSOR = HistGradientBoostingRegressor() +DEFAULT_CLASSIFIER = HistGradientBoostingClassifier() + + +class InterpolationJoiner(TransformerMixin, BaseEstimator): + """Join with a table augmented by machine-learning predictions. + + This is similar to a usual equi-join, but instead of looking for actual + rows in the right table that satisfy the join condition, we estimate what + those rows would contain if they existed in the table. + + Suppose we want to join a table ``buildings(latitude, longitude, n_stories)`` + with a table ``annual_avg_temp(latitude, longitude, avg_temp)``. Our annual + average temperature table may not contain data for the exact latitude and + longitude of our buildings. However, we can interpolate what we need from + the data points it does contain. Using ``annual_avg_temp``, we train a + model to predict the temperature, given the latitude and longitude. Then, + we use this model to estimate the values we want to add to our + ``buildings`` table. In a way we are joining ``buildings`` to a virtual + table, in which rows for any (latitude, longitude) location are inferred, + rather than retrieved, when requested. This is done with:: + + InterpolationJoiner( + annual_avg_temp, key=["latitude", "longitude"] + ).fit_transform(buildings) + + Parameters + ---------- + aux_table : DataFrame + The (auxiliary) table to be joined to the `main_table` (which is the + argument of ``transform``). ``aux_table`` is used to train a model that + takes as inputs the contents of the columns listed in ``aux_key``, and + predicts the contents of the other columns. In the example above, we + want our transformer to add temperature data to the table it is + operating on. Therefore, ``aux_table`` is the ``annual_avg_temp`` + table. + + main_key : list of str, or str + The columns in the main table used for joining. The main table is the + argument of ``transform``, to which we add information inferred using + ``aux_table``. The column names listed in ``main_key`` will provide the + inputs (features) of the interpolators at prediction (joining) time. In + the example above, ``main_key`` is ``["latitude", "longitude"]``, which + refer to columns in the ``buildings`` table.
When joining on a single + column, we can pass its name rather than a list: ``"latitude"`` is + equivalent to ``["latitude"]``. + + aux_key : list of str, or str + The columns in ``aux_table`` used for joining. Their number and types + must match those of the ``main_key`` columns in the main table. These + columns provide the features for the estimators to be fitted. As for + ``main_key``, it is possible to pass a string when using a single + column. + + key : list of str, or str + Column names to use for both `main_key` and `aux_key`, when they are + the same. Provide either `key` (only) or both `main_key` and `aux_key`. + + suffix : str + Suffix to append to the ``aux_table``'s column names. You can use it + to avoid duplicate column names in the join. + + regressor : scikit-learn regressor + Model used to predict the numerical columns of ``aux_table``. + + classifier : scikit-learn classifier + Model used to predict the categorical (string) columns of ``aux_table``. + + vectorizer : scikit-learn transformer that can operate on a DataFrame + Used to transform the feature columns before passing them to the + scikit-learn estimators. This is useful if we are joining on columns + that need some transformation, such as dates or strings representing + high-cardinality categories. By default we use a ``MinHashEncoder`` to + vectorize text columns. This is because the ``MinHashEncoder`` is very + fast and usually gives good results with tree-based downstream learners, + such as the gradient-boosted trees used by default for ``regressor`` + and ``classifier``. If you replace the default regressor and classifier + with models such as nearest neighbors or linear models, consider + passing ``vectorizer=TableVectorizer()``, which will encode text with a + ``GapEncoder`` rather than a ``MinHashEncoder``. + + n_jobs : int or None + Number of jobs to run in parallel. ``None`` means 1 unless in a + ``joblib.parallel_backend`` context. -1 means using all processors. + Depending on the estimators used and the contents of ``aux_table``, + several estimators may need to be fitted -- for example one for + continuous outputs (regressor) and one for categorical outputs + (classifier), or one for each column when the provided estimators do + not support multi-output tasks. Fitting and querying these estimators + can be done in parallel. + + on_estimator_failure : "warn", "raise" or "pass" + How to handle exceptions raised when fitting one of the estimators + (regressors and classifiers) or querying them for a prediction. If + "raise", exceptions are propagated. If "pass": (i) if an exception is + raised during ``fit``, the corresponding columns are ignored -- they + will not appear in the join; and (ii) if an exception is raised during + ``transform``, the corresponding columns will be filled with nulls. + Columns are filled with nulls during ``transform`` rather than dropped + so that the output always has the same shape. If "warn" (the default), + behave like "pass" but issue a warning. + + Attributes + ---------- + vectorizer_ : scikit-learn transformer + The transformer used to vectorize the feature columns. + + estimators_ : list of dicts + The estimators used to infer values to be joined. Each entry in this + list is a dictionary with keys ``"estimator"`` (the fitted estimator) + and ``"columns"`` (the list of columns in ``aux_table`` that it is + trained to predict). + + See Also + -------- + Joiner : + Works in a similar way but instead of inferring values, picks the + closest row from the auxiliary table.
+ + Examples + -------- + >>> import pandas as pd + >>> from skrub import InterpolationJoiner + >>> buildings = pd.DataFrame( + ... {"latitude": [1.0, 2.0], "longitude": [1.0, 2.0], "n_stories": [3, 7]} + ... ) + >>> annual_avg_temp = pd.DataFrame( + ... { + ... "latitude": [1.2, 0.9, 1.9, 1.7, 5.0], + ... "longitude": [0.8, 1.1, 1.8, 1.8, 5.0], + ... "avg_temp": [10.0, 11.0, 15.0, 16.0, 20.0], + ... } + ... ) + + >>> buildings + latitude longitude n_stories + 0 1.0 1.0 3 + 1 2.0 2.0 7 + + >>> annual_avg_temp + latitude longitude avg_temp + 0 1.2 0.8 10.0 + 1 0.9 1.1 11.0 + 2 1.9 1.8 15.0 + 3 1.7 1.8 16.0 + 4 5.0 5.0 20.0 + + >>> from sklearn.neighbors import KNeighborsRegressor + + >>> InterpolationJoiner( + ... annual_avg_temp, + ... key=["latitude", "longitude"], + ... regressor=KNeighborsRegressor(2), + ... ).fit_transform(buildings) + latitude longitude n_stories avg_temp + 0 1.0 1.0 3 10.5 + 1 2.0 2.0 7 15.5 + """ + + def __init__( + self, + aux_table, + *, + main_key=None, + aux_key=None, + key=None, + suffix="", + regressor=DEFAULT_REGRESSOR, + classifier=DEFAULT_CLASSIFIER, + vectorizer=DEFAULT_VECTORIZER, + n_jobs=None, + on_estimator_failure="warn", + ): + self.aux_table = aux_table + self.main_key = main_key + self.aux_key = aux_key + self.key = key + self.suffix = suffix + self.regressor = _utils.clone_if_default(regressor, DEFAULT_REGRESSOR) + self.classifier = _utils.clone_if_default(classifier, DEFAULT_CLASSIFIER) + self.vectorizer = _utils.clone_if_default(vectorizer, DEFAULT_VECTORIZER) + self.n_jobs = n_jobs + self.on_estimator_failure = on_estimator_failure + + def fit(self, X, y=None): + """Fit estimators to the `aux_table` provided during initialization. + + `X` and `y` are mostly for scikit-learn compatibility. + + Parameters + ---------- + X : array-like or None + The main table to which ``self.aux_table`` could be joined. If `X` + is not ``None``, an error is raised if any of the matching columns + listed in ``self.main_key`` (or ``self.key``) is missing from `X`. + + y : array-like + Ignored; only exists for compatibility with scikit-learn. + + Returns + ------- + self : InterpolationJoiner + Returns self.
+ """ + del y + self._check_inputs() + if X is not None: + _join_utils.check_missing_columns(X, self._main_key, "'X' (the main table)") + key_values = self.vectorizer_.fit_transform(self.aux_table[self._aux_key]) + estimators = self._get_estimator_assignments() + fit_results = joblib.Parallel(self.n_jobs)( + joblib.delayed(_fit)( + key_values, + self.aux_table[assignment["columns"]], + assignment["estimator"], + propagate_exceptions=(self.on_estimator_failure == "raise"), + ) + for assignment in estimators + ) + fit_results = self._check_fit_results(fit_results) + for res in fit_results: + del res["failed"] + self.estimators_ = fit_results + return self + + def _check_inputs(self): + self.vectorizer_ = clone(self.vectorizer) + self.classifier_ = clone(self.classifier) + self.regressor_ = clone(self.regressor) + self._main_key, self._aux_key = _join_utils.check_key( + self.main_key, self.aux_key, self.key + ) + _join_utils.check_missing_columns(self.aux_table, self._aux_key, "'aux_table'") + + def _check_fit_results(self, results): + successful_results = [res for res in results if not res["failed"]] + if self.on_estimator_failure == "pass": + return successful_results + failed_columns = [] + for res in results: + if res["failed"]: + failed_columns.extend(res["columns"]) + if not failed_columns: + return successful_results + warnings.warn( + "Estimators failed to be fitted for the following" + f" columns:\n{failed_columns}" + ) + return successful_results + + def transform(self, X): + """Transform a table by joining inferred values to it. + + The values of the `main_key` columns in `X` (the main table) are used + to predict likely values for the contents of a matching row in + `self.aux_table` (the auxiliary table). + + Parameters + ---------- + X : DataFrame + The (main) table to transform. + + Returns + ------- + join : DataFrame + The result of the join between `X` and inferred rows from + ``self.aux_table``. 
+ """ + main_table = X + _join_utils.check_missing_columns( + main_table, self._main_key, "'X' (the main table)" + ) + key_values = self.vectorizer_.transform( + main_table[self._main_key].set_axis(self._aux_key, axis="columns") + ) + prediction_results = joblib.Parallel(self.n_jobs)( + joblib.delayed(_predict)( + key_values, + assignment["columns"], + assignment["estimator"], + propagate_exceptions=(self.on_estimator_failure == "raise"), + ) + for assignment in self.estimators_ + ) + prediction_results = self._check_prediction_results(prediction_results) + predictions = [res["predictions"] for res in prediction_results] + predictions = _add_column_name_suffix(predictions, self.suffix) + for part in predictions: + part.index = main_table.index + return pd.concat([main_table] + predictions, axis=1) + + def _check_prediction_results(self, results): + checked_results = [] + failed_columns = [] + for res in results: + new_res = dict(**res) + if res["failed"]: + if set(res["columns"]).issubset( + self.aux_table.select_dtypes("number").columns.values + ): + dtype = float + else: + dtype = object + pred = pd.DataFrame( + columns=res["columns"], + index=np.arange(res["shape"][0]), + dtype=dtype, + ) + new_res["predictions"] = pred + failed_columns.extend(res["columns"]) + checked_results.append(new_res) + if not failed_columns: + return checked_results + if self.on_estimator_failure == "pass": + return checked_results + warnings.warn( + "Prediction failed for the following columns; output will be filled with" + f" nulls:\n{failed_columns}" + ) + return checked_results + + def _get_estimator_assignments(self): + """Identify column groups to be predicted together and assign them an estimator. + + In many cases, a single estimator cannot handle all the target columns. + This function groups columns that can be handled together and returns a + list of dictionaries, each with keys "columns" and "estimator". + + Regression and classification targets are always handled separately. + + Any column with missing values is handled separately from the rest. + This is due to the fact that missing values in the columns we are + trying to predict have to be dropped, and the corresponding rows may + have valid values in the other columns. + + When the estimator does not handle multi-output, an estimator is fitted + separately to each column. + """ + aux_table = self.aux_table.drop(self._aux_key, axis=1) + assignments = [] + regression_table = aux_table.select_dtypes("number") + assignments.extend( + _get_assignments_for_estimator(regression_table, self.regressor_) + ) + classification_table = aux_table.select_dtypes(["object", "string", "category"]) + assignments.extend( + _get_assignments_for_estimator(classification_table, self.classifier_) + ) + return assignments + + +def _get_assignments_for_estimator(table, estimator): + """Get the groups of columns assigned to a single estimator. + + (which is either the regressor or the classifier).""" + + # If the complete set of columns that have to be predicted with this + # estimator is empty (eg the estimator is the regressor and there are no + # numerical columns), return an empty list -- no columns are assigned to + # that estimator. 
+ if table.empty: + return [] + if not _handles_multioutput(estimator): + return [{"columns": [col], "estimator": estimator} for col in table.columns] + columns_with_nulls = table.columns[table.isnull().any()] + assignments = [ + {"columns": [col], "estimator": estimator} for col in columns_with_nulls + ] + columns_without_nulls = list(set(table.columns).difference(columns_with_nulls)) + if columns_without_nulls: + assignments.append({"columns": columns_without_nulls, "estimator": estimator}) + return assignments + + +def _handles_multioutput(estimator): + return _safe_tags(estimator).get("multioutput", False) + + +def _fit(key_values, target_table, estimator, propagate_exceptions): + estimator = clone(estimator) + kept_rows = target_table.notnull().all(axis=1).to_numpy() + key_values = key_values[kept_rows] + Y = target_table.to_numpy()[kept_rows] + + # Estimators that expect a single output issue a DataConversionWarning if + # passed a column vector rather than a 1-D array + if len(target_table.columns) == 1: + Y = Y.ravel() + failed = False + try: + estimator.fit(key_values, Y) + except Exception: + if propagate_exceptions: + raise + failed = True + estimator = None + return {"columns": target_table.columns, "estimator": estimator, "failed": failed} + + +def _predict(key_values, columns, estimator, propagate_exceptions): + failed = False + try: + Y_values = estimator.predict(key_values) + except Exception: + if propagate_exceptions: + raise + failed = True + if failed: + predictions = None + else: + predictions = pd.DataFrame(data=Y_values, columns=columns) + return { + "predictions": predictions, + "failed": failed, + "columns": columns, + "shape": (key_values.shape[0], len(columns)), + } + + +def _add_column_name_suffix(dataframes, suffix): + if suffix == "": + return dataframes + renamed = [] + for df in dataframes: + renamed.append(df.rename(columns={c: f"{c}{suffix}" for c in df.columns})) + return renamed diff --git a/skrub/_table_vectorizer.py b/skrub/_table_vectorizer.py index 7832ca0e1..af66605e8 100644 --- a/skrub/_table_vectorizer.py +++ b/skrub/_table_vectorizer.py @@ -21,7 +21,7 @@ from sklearn.utils.validation import check_is_fitted from skrub import DatetimeEncoder, GapEncoder -from skrub._utils import parse_astype_error_message +from skrub._utils import clone_if_default, parse_astype_error_message HIGH_CARDINALITY_TRANSFORMER = GapEncoder(n_components=30) LOW_CARDINALITY_TRANSFORMER = OneHotEncoder( @@ -150,10 +150,6 @@ def _replace_missing_in_cat_col(ser: pd.Series, value: str = "missing") -> pd.Se return ser -def _clone_if_default(transformer, default_transformer): - return clone(transformer) if transformer is default_transformer else transformer - - def _clone_during_fit(transformer, remainder, n_jobs): if isinstance(transformer, sklearn.base.TransformerMixin): return _propagate_n_jobs(clone(transformer), n_jobs) @@ -449,13 +445,13 @@ def __init__( verbose_feature_names_out=False, ): self.cardinality_threshold = cardinality_threshold - self.low_cardinality_transformer = _clone_if_default( + self.low_cardinality_transformer = clone_if_default( low_cardinality_transformer, LOW_CARDINALITY_TRANSFORMER ) - self.high_cardinality_transformer = _clone_if_default( + self.high_cardinality_transformer = clone_if_default( high_cardinality_transformer, HIGH_CARDINALITY_TRANSFORMER ) - self.datetime_transformer = _clone_if_default( + self.datetime_transformer = clone_if_default( datetime_transformer, DATETIME_TRANSFORMER ) self.numerical_transformer = numerical_transformer diff
--git a/skrub/_utils.py b/skrub/_utils.py index 1d7e5e162..72b89e295 100644 --- a/skrub/_utils.py +++ b/skrub/_utils.py @@ -6,6 +6,7 @@ import numpy as np from numpy.typing import NDArray +from sklearn.base import clone from sklearn.utils import check_array @@ -157,3 +158,7 @@ def atleast_2d_or_none(x): # 1d array else: return [x] + + +def clone_if_default(estimator, default_estimator): + return clone(estimator) if estimator is default_estimator else estimator diff --git a/skrub/tests/test_interpolation_join.py b/skrub/tests/test_interpolation_join.py new file mode 100644 index 000000000..feee706c8 --- /dev/null +++ b/skrub/tests/test_interpolation_join.py @@ -0,0 +1,259 @@ +import pandas as pd +import pytest +from numpy.testing import assert_array_equal +from sklearn.base import BaseEstimator, TransformerMixin +from sklearn.dummy import DummyClassifier, DummyRegressor +from sklearn.neighbors import KNeighborsClassifier, KNeighborsRegressor + +from skrub import InterpolationJoiner + + +@pytest.fixture +def buildings(): + return pd.DataFrame( + {"latitude": [1.0, 2.0], "longitude": [1.0, 2.0], "n_stories": [3, 7]} + ) + + +@pytest.fixture +def weather(): + return pd.DataFrame( + { + "latitude": [1.2, 0.9, 1.9, 1.7, 5.0, 5.0], + "longitude": [0.8, 1.1, 1.8, 1.8, 5.0, 5.0], + "avg_temp": [10.0, 11.0, 15.0, 16.0, 20.0, None], + "climate": ["A", "A", "B", "B", "C", "C"], + } + ) + + +@pytest.mark.parametrize("key", [["latitude", "longitude"], "latitude"]) +@pytest.mark.parametrize("with_nulls", [False, True]) +def test_interpolation_join(buildings, weather, key, with_nulls): + if not with_nulls: + weather = weather.fillna(0.0) + transformed = InterpolationJoiner( + weather, + key=key, + regressor=KNeighborsRegressor(2), + classifier=KNeighborsClassifier(2), + ).fit_transform(buildings) + assert_array_equal(transformed["avg_temp"].values, [10.5, 15.5]) + assert_array_equal(transformed["climate"].values, ["A", "B"]) + + +def test_vectorizer(): + main = pd.DataFrame({"A": [0, 1]}) + aux = pd.DataFrame({"A": [11, 110], "B": [1, 0]}) + + class Vectorizer(TransformerMixin, BaseEstimator): + def fit(self, X): + return self + + def transform(self, X): + return X % 10 + + join = InterpolationJoiner( + aux, + key="A", + regressor=KNeighborsRegressor(1), + vectorizer=Vectorizer(), + ).fit_transform(main) + assert_array_equal(join["B"], [0, 1]) + + +def test_no_multioutput(buildings, weather): + transformed = InterpolationJoiner( + weather, + main_key=("latitude", "longitude"), + aux_key=("latitude", "longitude"), + ).fit_transform(buildings) + assert transformed.shape == (2, 5) + + +def test_condition_choice(): + main = pd.DataFrame({"A": [0, 1, 2]}) + aux = pd.DataFrame({"A": [0, 1, 2], "rB": [2, 0, 1], "C": [10, 11, 12]}) + join = InterpolationJoiner( + aux, key="A", regressor=KNeighborsRegressor(1) + ).fit_transform(main) + assert_array_equal(join["C"].values, [10, 11, 12]) + + join = InterpolationJoiner( + aux, main_key="A", aux_key="rB", regressor=KNeighborsRegressor(1) + ).fit_transform(main) + assert_array_equal(join["C"].values, [11, 12, 10]) + + with pytest.raises(ValueError, match="Must pass EITHER"): + join = InterpolationJoiner( + aux, main_key="A", regressor=KNeighborsRegressor(1) + ).fit(None) + + with pytest.raises(ValueError, match="Can only pass"): + join = InterpolationJoiner( + aux, key="A", main_key="A", regressor=KNeighborsRegressor(1) + ).fit(None) + + with pytest.raises(ValueError, match="Can only pass"): + join = InterpolationJoiner( + aux, key="A", main_key="A", aux_key="A", 
regressor=KNeighborsRegressor(1) + ).fit(None) + + +def test_suffix(): + df = pd.DataFrame({"A": [0, 1], "B": [0, 1]}) + join = InterpolationJoiner( + df, key="A", suffix="_aux", regressor=KNeighborsRegressor(1) + ).fit_transform(df) + assert_array_equal(join.columns, ["A", "B", "B_aux"]) + + +def test_mismatched_indexes(): + main = pd.DataFrame({"A": [0, 1]}, index=[1, 0]) + aux = pd.DataFrame({"A": [0, 1], "B": [10, 11]}) + join = InterpolationJoiner( + aux, key="A", regressor=KNeighborsRegressor(1) + ).fit_transform(main) + assert_array_equal(join["B"].values, [10, 11]) + assert_array_equal(join.index.values, [1, 0]) + + +def test_fit_on_none(): + # X is hardly used in fit so it should be ok to fit without a main table + aux = pd.DataFrame({"A": [0, 1], "B": [10, 11]}) + joiner = InterpolationJoiner(aux, key="A", regressor=KNeighborsRegressor(1)).fit( + None + ) + main = pd.DataFrame({"A": [0, 1]}, index=[1, 0]) + join = joiner.transform(main) + assert_array_equal(join["B"].values, [10, 11]) + assert_array_equal(join.index.values, [1, 0]) + + +def test_join_on_date(): + sales = pd.DataFrame({"date": ["2023-09-20", "2023-09-29"], "n": [10, 15]}) + temp = pd.DataFrame( + {"date": ["2023-09-09", "2023-10-01", "2024-09-21"], "temp": [-10, 10, 30]} + ) + transformed = ( + InterpolationJoiner( + temp, + main_key="date", + aux_key="date", + regressor=KNeighborsRegressor(1), + ) + .set_params(vectorizer__datetime_transformer__resolution=None) + .fit_transform(sales) + ) + assert_array_equal(transformed["temp"].values, [-10, 10]) + + +class FailFit(DummyClassifier): + def fit(self, X, y): + raise ValueError("FailFit failed") + + +def test_fit_failures(buildings, weather): + weather["climate"] = "A" + joiner = InterpolationJoiner( + weather, + key=["latitude", "longitude"], + regressor=KNeighborsRegressor(2), + classifier=FailFit(), + on_estimator_failure="pass", + ) + join = joiner.fit_transform(buildings) + assert_array_equal(join["avg_temp"].values, [10.5, 15.5]) + assert join.shape == (2, 4) + + joiner = InterpolationJoiner( + weather, + key=["latitude", "longitude"], + regressor=KNeighborsRegressor(2), + classifier=FailFit(), + on_estimator_failure="warn", + ) + with pytest.warns(UserWarning, match="(?s)Estimators failed.*climate"): + join = joiner.fit_transform(buildings) + assert_array_equal(join["avg_temp"].values, [10.5, 15.5]) + assert join.shape == (2, 4) + + joiner = InterpolationJoiner( + weather, + key=["latitude", "longitude"], + regressor=KNeighborsRegressor(2), + classifier=FailFit(), + on_estimator_failure="raise", + ) + with pytest.raises(ValueError, match="FailFit failed"): + join = joiner.fit_transform(buildings) + + +class FailPredict(DummyClassifier): + def predict(self, X): + raise ValueError("FailPredict failed") + + +def test_transform_failures(buildings, weather): + joiner = InterpolationJoiner( + weather, + key=["latitude", "longitude"], + regressor=KNeighborsRegressor(2), + classifier=FailPredict(), + on_estimator_failure="pass", + ) + join = joiner.fit_transform(buildings) + assert_array_equal(join["avg_temp"].values, [10.5, 15.5]) + assert join["climate"].isnull().all() + assert join["climate"].dtype == object + assert join.shape == (2, 5) + + joiner = InterpolationJoiner( + weather, + key=["latitude", "longitude"], + regressor=KNeighborsRegressor(2), + classifier=FailPredict(), + on_estimator_failure="warn", + ) + with pytest.warns(UserWarning, match="(?s)Prediction failed.*climate"): + join = joiner.fit_transform(buildings) + 
assert_array_equal(join["avg_temp"].values, [10.5, 15.5]) + assert join["climate"].isnull().all() + assert join["climate"].dtype == object + assert join.shape == (2, 5) + + joiner = InterpolationJoiner( + weather, + key=["latitude", "longitude"], + regressor=KNeighborsRegressor(2), + classifier=FailPredict(), + on_estimator_failure="raise", + ) + with pytest.raises(Exception, match="FailPredict failed"): + join = joiner.fit_transform(buildings) + + +def test_transform_failures_dtype(buildings, weather): + joiner = InterpolationJoiner( + weather, + key=["latitude", "longitude"], + regressor=FailPredict(), + classifier=DummyClassifier(), + on_estimator_failure="pass", + ) + join = joiner.fit_transform(buildings) + assert join["avg_temp"].isnull().all() + assert join["avg_temp"].dtype == "float64" + assert join.shape == (2, 5) + + joiner = InterpolationJoiner( + weather, + key=["latitude", "longitude"], + regressor=DummyRegressor(), + classifier=FailPredict(), + on_estimator_failure="pass", + ) + join = joiner.fit_transform(buildings) + assert join["climate"].isnull().all() + assert join["climate"].dtype == object + assert join.shape == (2, 5)
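One more test that could round out the suite -- a sketch, not part of this PR, assuming sequential and parallel runs are deterministic for these estimators:

def test_n_jobs_matches_sequential(buildings, weather):
    # Sketch: running with n_jobs=2 should produce the same join as the
    # default sequential run, since joblib only parallelizes the
    # per-estimator fit and predict calls.
    weather = weather.fillna(0.0)
    kwargs = dict(
        key=["latitude", "longitude"],
        regressor=KNeighborsRegressor(2),
        classifier=KNeighborsClassifier(2),
    )
    sequential = InterpolationJoiner(weather, **kwargs).fit_transform(buildings)
    parallel = InterpolationJoiner(weather, n_jobs=2, **kwargs).fit_transform(buildings)
    pd.testing.assert_frame_equal(sequential, parallel)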