Skip to content

Commit

Permalink
add "suffix" parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
jeromedockes committed Sep 21, 2023
1 parent 9f4aa84 commit fa09bf1
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 0 deletions.
16 changes: 16 additions & 0 deletions skrub/_interpolation_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ class InterpolationJoin(base.BaseEstimator):
Column names to use both `left_on` and `right_on`, when they are the
same. Provide either `on` (only) or both `left_on` and `right_on`.
suffix : str
Suffix to append to the ``right_table``'s column names. You can use it
to avoid duplicate column names in the join.
regressor : scikit-learn regressor
Model used to predict the numerical columns of ``right_table``.
Expand Down Expand Up @@ -129,6 +133,7 @@ def __init__(
left_on=None,
right_on=None,
on=None,
suffix="",
regressor=ensemble.HistGradientBoostingRegressor(),
classifier=ensemble.HistGradientBoostingClassifier(),
vectorizer=TableVectorizer(),
Expand All @@ -138,6 +143,7 @@ def __init__(
self.left_on = left_on
self.right_on = right_on
self.on = on
self.suffix = suffix
self.regressor = regressor
self.classifier = classifier
self.vectorizer = vectorizer
Expand Down Expand Up @@ -221,6 +227,7 @@ def transform(self, left_table):
)
for assignment in self.estimators_
)
interpolated_parts = _add_column_name_suffix(interpolated_parts, self.suffix)
return pd.concat([left_table] + interpolated_parts, axis=1)

def fit_transform(self, left_table):
Expand Down Expand Up @@ -309,3 +316,12 @@ def _fit(X_values, right_table, target_columns, estimator):
def _predict(X_values, columns, estimator):
Y_values = estimator.predict(X_values)
return pd.DataFrame(data=Y_values, columns=columns)


def _add_column_name_suffix(dataframes, suffix):
if suffix == "":
return dataframes
renamed = []
for df in dataframes:
renamed.append(df.rename(columns={c: f"{c}{suffix}" for c in df.columns}))
return renamed
8 changes: 8 additions & 0 deletions skrub/tests/test_interpolation_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,14 @@ def test_condition_choice():
).fit()


def test_suffix():
df = pd.DataFrame({"A": [0, 1], "B": [0, 1]})
join = InterpolationJoin(
df, on="A", suffix="_right", regressor=KNeighborsRegressor(1)
).fit_transform(df)
assert (join.columns == ["A", "B", "B_right"]).all()


# expected to fail until we have a way to get the timestamp (only) from a date
# with the tablevectorizer
@pytest.mark.xfail
Expand Down

0 comments on commit fa09bf1

Please sign in to comment.