diff --git a/py-polars/polars/selectors.py b/py-polars/polars/selectors.py index a994347dd7e8..ad1c88debe91 100644 --- a/py-polars/polars/selectors.py +++ b/py-polars/polars/selectors.py @@ -1,6 +1,5 @@ from __future__ import annotations -import re from datetime import timezone from functools import reduce from operator import or_ @@ -9,7 +8,7 @@ from polars import functions as F from polars._utils.deprecation import deprecate_nonkeyword_arguments from polars._utils.parse_expr_input import _parse_inputs_as_iterable -from polars._utils.various import is_column +from polars._utils.various import is_column, re_escape from polars.datatypes import ( FLOAT_DTYPES, INTEGER_DTYPES, @@ -287,7 +286,7 @@ def __repr__(self) -> str: op = set_ops[selector_name] return "({})".format(f" {op} ".join(repr(p) for p in params.values())) else: - str_params = ",".join( + str_params = ", ".join( (repr(v)[1:-1] if k.startswith("*") else f"{k}={v!r}") for k, v in (params or {}).items() ).rstrip(",") @@ -310,8 +309,10 @@ def __rsub__(self, other: Any) -> SelectorType | Expr: # type: ignore[override] def __and__(self, other: Any) -> SelectorType | Expr: # type: ignore[override] if is_column(other): colname = other.meta.output_name() - if self._attrs["name"] == "by_name": - return by_name(*self._attrs["params"]["*names"], colname) + if self._attrs["name"] == "by_name" and ( + params := self._attrs["params"] + ).get("require_all", True): + return by_name(*params["*names"], colname) other = by_name(colname) if is_selector(other): return _selector_proxy_( @@ -337,8 +338,10 @@ def __or__(self, other: Any) -> SelectorType | Expr: # type: ignore[override] def __rand__(self, other: Any) -> SelectorType | Expr: # type: ignore[override] if is_column(other): colname = other.meta.output_name() - if self._attrs["name"] == "by_name": - return by_name(colname, *self._attrs["params"]["*names"]) + if self._attrs["name"] == "by_name" and ( + params := self._attrs["params"] + ).get("require_all", True): + 
return by_name(colname, *params["*names"]) other = by_name(colname) return self.as_expr().__rand__(other) @@ -400,7 +403,7 @@ def as_expr(self) -> Expr: def _re_string(string: str | Collection[str], *, escape: bool = True) -> str: """Return escaped regex, potentially representing multiple string fragments.""" if isinstance(string, str): - rx = f"{re.escape(string)}" if escape else string + rx = f"{re_escape(string)}" if escape else string else: strings: list[str] = [] for st in string: @@ -408,7 +411,7 @@ def _re_string(string: str | Collection[str], *, escape: bool = True) -> str: strings.extend(st) else: strings.append(st) - rx = "|".join((re.escape(x) if escape else x) for x in strings) + rx = "|".join((re_escape(x) if escape else x) for x in strings) return f"({rx})" @@ -732,7 +735,7 @@ def by_index(*indices: int | range | Sequence[int | range]) -> SelectorType: ) -def by_name(*names: str | Collection[str]) -> SelectorType: +def by_name(*names: str | Collection[str], require_all: bool = True) -> SelectorType: """ Select all columns matching the given names. @@ -740,6 +743,13 @@ def by_name(*names: str | Collection[str]) -> SelectorType: ---------- *names One or more names of columns to select. + require_all + Whether to match *all* names (the default) or *any* of the names. + + Notes + ----- + Matching columns are returned in the order in which they are declared in + the selector, not the underlying schema order (this applies when `require_all=True`; with `require_all=False` the names are matched as a regex and columns come back in schema order). 
See Also -------- @@ -771,6 +781,19 @@ def by_name(*names: str | Collection[str]) -> SelectorType: │ y ┆ 456 │ └─────┴─────┘ + Match *any* of the given columns by name: + + >>> df.select(cs.by_name("baz", "moose", "foo", "bear", require_all=False)) + shape: (2, 2) + ┌─────┬─────┐ + │ foo ┆ baz │ + │ --- ┆ --- │ + │ str ┆ f64 │ + ╞═════╪═════╡ + │ x ┆ 2.0 │ + │ y ┆ 5.5 │ + └─────┴─────┘ + Match all columns *except* for those given: >>> df.select(~cs.by_name("foo", "bar")) @@ -798,8 +821,16 @@ def by_name(*names: str | Collection[str]) -> SelectorType: msg = f"invalid name: {nm!r}" raise TypeError(msg) + selector_params: dict[str, Any] = {"*names": all_names} + match_cols: list[str] | str = all_names + if not require_all: + match_cols = f"^({'|'.join(re_escape(nm) for nm in all_names)})$" + selector_params["require_all"] = require_all + return _selector_proxy_( - F.col(all_names), name="by_name", parameters={"*names": all_names} + F.col(match_cols), + name="by_name", + parameters=selector_params, ) diff --git a/py-polars/tests/unit/test_selectors.py b/py-polars/tests/unit/test_selectors.py index 16766d5cfb6e..ca5112c6ccbf 100644 --- a/py-polars/tests/unit/test_selectors.py +++ b/py-polars/tests/unit/test_selectors.py @@ -130,14 +130,22 @@ def test_selector_by_name(df: pl.DataFrame) -> None: assert df.select(cs.by_name()).columns == [] assert df.select(cs.by_name([])).columns == [] + selected_cols = df.select( + cs.by_name("???", "fgg", "!!!", require_all=False) + ).columns + assert selected_cols == ["fgg"] + # check "by_name & col" - for selector_expr in ( - cs.by_name("abc", "cde") & pl.col("ghi"), - pl.col("abc") & cs.by_name("cde", "ghi"), + for selector_expr, expected in ( + (cs.by_name("abc", "cde") & pl.col("ghi"), ["abc", "cde", "ghi"]), + (pl.col("ghi") & cs.by_name("cde", "abc"), ["ghi", "cde", "abc"]), ): - assert df.select(selector_expr).columns == ["abc", "cde", "ghi"] + assert df.select(selector_expr).columns == expected # expected errors + with 
pytest.raises(ColumnNotFoundError, match="xxx"): + df.select(cs.by_name("xxx", "fgg", "!!!")) + with pytest.raises(ColumnNotFoundError): df.select(cs.by_name("stroopwafel")) @@ -494,7 +502,16 @@ def test_selector_repr() -> None: assert_repr_equals(~cs.starts_with("a", "b"), "~cs.starts_with('a', 'b')") assert_repr_equals(cs.float() | cs.by_name("x"), "(cs.float() | cs.by_name('x'))") assert_repr_equals( - cs.integer() & cs.matches("z"), "(cs.integer() & cs.matches(pattern='z'))" + cs.integer() & cs.matches("z"), + "(cs.integer() & cs.matches(pattern='z'))", + ) + assert_repr_equals( + cs.by_name("baz", "moose", "foo", "bear"), + "cs.by_name('baz', 'moose', 'foo', 'bear')", + ) + assert_repr_equals( + cs.by_name("baz", "moose", "foo", "bear", require_all=False), + "cs.by_name('baz', 'moose', 'foo', 'bear', require_all=False)", ) assert_repr_equals( cs.temporal() | cs.by_dtype(pl.String) & cs.string(include_categorical=False),