Skip to content

Commit

Permalink
ENH apply TableVectorizer processing on joining columns of Joiner (#972)
Browse files Browse the repository at this point in the history
  • Loading branch information
jeromedockes authored Jun 25, 2024
1 parent aa7836c commit 48d6ef1
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 10 deletions.
5 changes: 5 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,11 @@ Minor changes
value for the parameter `pca_components`.
:pr:`956` by :user:`Guillaume Lemaitre <glemaitre>`.

* :class:`Joiner` now performs some preprocessing (the same as done by the
:class:`TableVectorizer`, e.g. trying to parse dates, converting pandas object
columns with mixed types to a single type) on the joining columns before
vectorizing them. :pr:`972` by :user:`Jérôme Dockès <jeromedockes>`.

skrub release 0.1.0
===================

Expand Down
26 changes: 16 additions & 10 deletions skrub/_joiner.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from . import _selectors as s
from ._check_input import CheckInputDataFrame
from ._datetime_encoder import DatetimeEncoder
from ._table_vectorizer import TableVectorizer
from ._to_str import ToStr
from ._wrap_transformer import wrap_transformer

Expand Down Expand Up @@ -59,28 +60,33 @@ def _make_vectorizer(table, string_encoder, rescale):
In addition if `rescale` is `True`, a StandardScaler is applied to
numeric and datetime columns.
"""
# TODO: add Skrubber before ColumnTransformer
# TODO: remove use of ColumnTransformer
skrubber = TableVectorizer(
datetime="passthrough",
low_cardinality="passthrough",
high_cardinality="passthrough",
numeric="passthrough",
)
table = skrubber.fit_transform(table)
cols = skrubber.kind_to_columns_
transformers = [
(clone(string_encoder), c) for c in (s.string() | s.categorical()).expand(table)
(clone(string_encoder), c)
for c in cols["high_cardinality"] + cols["low_cardinality"]
]
num_columns = s.numeric().expand(table)
if num_columns:
if cols["numeric"]:
transformers.append(
(StandardScaler() if rescale else "passthrough", num_columns)
(StandardScaler() if rescale else "passthrough", cols["numeric"])
)
dt_columns = s.any_date().expand(table)
if dt_columns:
if cols["datetime"]:
transformers.append(
(
make_pipeline(
wrap_transformer(_DATETIME_ENCODER, s.all()),
StandardScaler() if rescale else "passthrough",
),
dt_columns,
cols["datetime"],
)
)
return make_column_transformer(*transformers)
return make_pipeline(skrubber, make_column_transformer(*transformers))


class Joiner(TransformerMixin, BaseEstimator):
Expand Down
9 changes: 9 additions & 0 deletions skrub/tests/test_joiner.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,3 +155,12 @@ def test_fit_transform_datetimes(df_module):
df = df_module.make_dataframe({"A": values})
joiner = Joiner(df, key="A", suffix="_")
joiner.fit_transform(df)


def test_preprocessing(df_module):
    """Check which auxiliary row a ``Joiner`` keyed on "date" selects.

    The main table holds the single date string "2021-10-01". The auxiliary
    table holds "2021-09-17" (with ``v="A"``) and "2012-10-01" (with
    ``v="B"``). The joined ``v_`` column is expected to contain "A", i.e. the
    2021-09-17 row is the one matched.

    NOTE(review): presumably this passes only when the key column is parsed
    as dates before matching, since "2012-10-01" is the closer candidate when
    the values are compared as plain strings -- confirm against the Joiner
    implementation.
    """
    main = df_module.make_dataframe({"date": ["2021-10-01"], "v": ["A"]})
    aux = df_module.make_dataframe(
        {"date": ["2021-09-17", "2012-10-01"], "v": ["A", "B"]}
    )
    # add_match_info=False is passed through to the Joiner unchanged.
    joiner = Joiner(aux, key="date", suffix="_", add_match_info=False)
    joined = joiner.fit_transform(main)
    # The auxiliary "v" column is read back under the suffixed name "v_".
    matched_values = ns.to_list(ns.col(joined, "v_"))
    assert matched_values == ["A"]

0 comments on commit 48d6ef1

Please sign in to comment.