diff --git a/skrub/_select_cols.py b/skrub/_select_cols.py index 9fead261e..b189ce717 100644 --- a/skrub/_select_cols.py +++ b/skrub/_select_cols.py @@ -15,11 +15,13 @@ def _check_columns(df, columns): example ``selector = SelectCols(["A", "B"]).fit(None)``, as the fit data is not used for anything else than this check. """ + if isinstance(columns, str): + columns = [columns] if not hasattr(df, "columns"): - return + return columns diff = set(columns) - set(df.columns) if not diff: - return + return columns raise ValueError( f"The following columns were not found in the input DataFrame: {diff}" ) @@ -33,8 +35,10 @@ class SelectCols(TransformerMixin, BaseEstimator): Parameters ---------- - cols : list of str - The columns to select. + cols : list of str or str + The columns to select. A single column name can be passed as a ``str``: + ``"col_name"`` is the same as ``["col_name"]`` + Examples -------- @@ -91,9 +95,9 @@ def transform(self, X): The input DataFrame ``X`` after selecting only the columns listed in ``self.cols`` (in the provided order). """ - _check_columns(X, self.cols) + cols = _check_columns(X, self.cols) namespace, _ = get_df_namespace(X) - return namespace.select(X, self.cols) + return namespace.select(X, cols) class DropCols(TransformerMixin, BaseEstimator): @@ -104,8 +108,9 @@ class DropCols(TransformerMixin, BaseEstimator): Parameters ---------- - cols : list of str - The columns to drop. + cols : list of str or str + The columns to drop. A single column name can be passed as a ``str``: + ``"col_name"`` is the same as ``["col_name"]``. Examples -------- @@ -162,6 +167,6 @@ def transform(self, X): The input DataFrame ``X`` after dropping the columns listed in ``self.cols``. """ - _check_columns(X, self.cols) + cols = _check_columns(X, self.cols) namespace, _ = get_df_namespace(X) - return namespace.select(X, [c for c in X.columns if c not in self.cols]) + return namespace.select(X, [c for c in X.columns if c not in cols]) diff --git a/skrub/tests/test_select_cols.py b/skrub/tests/test_select_cols.py index ee77af692..0e36866be 100644 --- a/skrub/tests/test_select_cols.py +++ b/skrub/tests/test_select_cols.py @@ -1,4 +1,5 @@ import pandas +import pandas.testing import pytest from skrub import DropCols, SelectCols @@ -25,6 +26,18 @@ def test_select_cols(df): assert list(out["C"]) == ["x", "y"] +def test_select_single_col(df): + out_1 = SelectCols("A").fit_transform(df) + out_2 = SelectCols(["A"]).fit_transform(df) + pandas.testing.assert_frame_equal(pandas.DataFrame(out_1), pandas.DataFrame(out_2)) + + +def test_fit_select_cols_without_x(df): + selector = SelectCols(["C", "A"]).fit(None) + out = selector.transform(df) + assert list(out.columns) == ["C", "A"] + + def test_select_missing_cols(df): selector = SelectCols(["X", "A"]) with pytest.raises(ValueError, match="not found"): @@ -43,6 +56,18 @@ def test_drop_cols(df): assert list(out["B"]) == [10, 20] +def test_drop_single_col(df): + out_1 = DropCols("A").fit_transform(df) + out_2 = DropCols(["A"]).fit_transform(df) + pandas.testing.assert_frame_equal(pandas.DataFrame(out_1), pandas.DataFrame(out_2)) + + +def test_fit_drop_cols_without_x(df): + selector = DropCols(["C", "A"]).fit(None) + out = selector.transform(df) + assert list(out.columns) == ["B"] + + def test_drop_missing_cols(df): selector = DropCols(["X", "A"]) with pytest.raises(ValueError, match="not found"):