Skip to content

Commit

Permalink
allow passing single column as a string
Browse files Browse the repository at this point in the history
  • Loading branch information
jeromedockes committed Oct 25, 2023
1 parent 4846c17 commit cacbbe7
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 10 deletions.
25 changes: 15 additions & 10 deletions skrub/_select_cols.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,13 @@ def _check_columns(df, columns):
example ``selector = SelectCols(["A", "B"]).fit(None)``, as the fit data is
not used for anything else than this check.
"""
if isinstance(columns, str):
columns = [columns]
if not hasattr(df, "columns"):
return
return columns
diff = set(columns) - set(df.columns)
if not diff:
return
return columns
raise ValueError(
f"The following columns were not found in the input DataFrame: {diff}"
)
Expand All @@ -33,8 +35,10 @@ class SelectCols(TransformerMixin, BaseEstimator):
Parameters
----------
cols : list of str
The columns to select.
cols : list of str or str
The columns to select. A single column name can be passed as a ``str``:
``"col_name"`` is the same as ``["col_name"]``
Examples
--------
Expand Down Expand Up @@ -91,9 +95,9 @@ def transform(self, X):
The input DataFrame ``X`` after selecting only the columns listed
in ``self.cols`` (in the provided order).
"""
_check_columns(X, self.cols)
cols = _check_columns(X, self.cols)
namespace, _ = get_df_namespace(X)
return namespace.select(X, self.cols)
return namespace.select(X, cols)


class DropCols(TransformerMixin, BaseEstimator):
Expand All @@ -104,8 +108,9 @@ class DropCols(TransformerMixin, BaseEstimator):
Parameters
----------
cols : list of str
The columns to drop.
cols : list of str or str
The columns to drop. A single column name can be passed as a ``str``:
``"col_name"`` is the same as ``["col_name"]``.
Examples
--------
Expand Down Expand Up @@ -162,6 +167,6 @@ def transform(self, X):
The input DataFrame ``X`` after dropping the columns listed in
``self.cols``.
"""
_check_columns(X, self.cols)
cols = _check_columns(X, self.cols)
namespace, _ = get_df_namespace(X)
return namespace.select(X, [c for c in X.columns if c not in self.cols])
return namespace.select(X, [c for c in X.columns if c not in cols])
25 changes: 25 additions & 0 deletions skrub/tests/test_select_cols.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pandas
import pandas.testing
import pytest

from skrub import DropCols, SelectCols
Expand All @@ -25,6 +26,18 @@ def test_select_cols(df):
assert list(out["C"]) == ["x", "y"]


def test_select_single_col(df):
out_1 = SelectCols("A").fit_transform(df)
out_2 = SelectCols(["A"]).fit_transform(df)
pandas.testing.assert_frame_equal(pandas.DataFrame(out_1), pandas.DataFrame(out_2))


def test_fit_select_cols_without_x(df):
selector = SelectCols(["C", "A"]).fit(None)
out = selector.transform(df)
assert list(out.columns) == ["C", "A"]


def test_select_missing_cols(df):
selector = SelectCols(["X", "A"])
with pytest.raises(ValueError, match="not found"):
Expand All @@ -43,6 +56,18 @@ def test_drop_cols(df):
assert list(out["B"]) == [10, 20]


def test_drop_single_col(df):
out_1 = DropCols("A").fit_transform(df)
out_2 = DropCols(["A"]).fit_transform(df)
pandas.testing.assert_frame_equal(pandas.DataFrame(out_1), pandas.DataFrame(out_2))


def test_fit_drop_cols_without_x(df):
selector = DropCols(["C", "A"]).fit(None)
out = selector.transform(df)
assert list(out.columns) == ["B"]


def test_drop_missing_cols(df):
selector = DropCols(["X", "A"])
with pytest.raises(ValueError, match="not found"):
Expand Down

0 comments on commit cacbbe7

Please sign in to comment.