Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(python): Add require_all parameter to the by_name column selector #15028

Merged
merged 2 commits into from
May 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 42 additions & 11 deletions py-polars/polars/selectors.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import re
from datetime import timezone
from functools import reduce
from operator import or_
Expand All @@ -9,7 +8,7 @@
from polars import functions as F
from polars._utils.deprecation import deprecate_nonkeyword_arguments
from polars._utils.parse_expr_input import _parse_inputs_as_iterable
from polars._utils.various import is_column
from polars._utils.various import is_column, re_escape
from polars.datatypes import (
FLOAT_DTYPES,
INTEGER_DTYPES,
Expand Down Expand Up @@ -287,7 +286,7 @@ def __repr__(self) -> str:
op = set_ops[selector_name]
return "({})".format(f" {op} ".join(repr(p) for p in params.values()))
else:
str_params = ",".join(
str_params = ", ".join(
(repr(v)[1:-1] if k.startswith("*") else f"{k}={v!r}")
for k, v in (params or {}).items()
).rstrip(",")
Expand All @@ -310,8 +309,10 @@ def __rsub__(self, other: Any) -> SelectorType | Expr: # type: ignore[override]
def __and__(self, other: Any) -> SelectorType | Expr: # type: ignore[override]
if is_column(other):
colname = other.meta.output_name()
if self._attrs["name"] == "by_name":
return by_name(*self._attrs["params"]["*names"], colname)
if self._attrs["name"] == "by_name" and (
params := self._attrs["params"]
).get("require_all", True):
return by_name(*params["*names"], colname)
other = by_name(colname)
if is_selector(other):
return _selector_proxy_(
Expand All @@ -337,8 +338,10 @@ def __or__(self, other: Any) -> SelectorType | Expr: # type: ignore[override]
def __rand__(self, other: Any) -> SelectorType | Expr: # type: ignore[override]
if is_column(other):
colname = other.meta.output_name()
if self._attrs["name"] == "by_name":
return by_name(colname, *self._attrs["params"]["*names"])
if self._attrs["name"] == "by_name" and (
params := self._attrs["params"]
).get("require_all", True):
return by_name(colname, *params["*names"])
other = by_name(colname)
return self.as_expr().__rand__(other)

Expand Down Expand Up @@ -400,15 +403,15 @@ def as_expr(self) -> Expr:
def _re_string(string: str | Collection[str], *, escape: bool = True) -> str:
"""Return escaped regex, potentially representing multiple string fragments."""
if isinstance(string, str):
rx = f"{re.escape(string)}" if escape else string
rx = f"{re_escape(string)}" if escape else string
else:
strings: list[str] = []
for st in string:
if isinstance(st, Collection) and not isinstance(st, str): # type: ignore[redundant-expr]
strings.extend(st)
else:
strings.append(st)
rx = "|".join((re.escape(x) if escape else x) for x in strings)
rx = "|".join((re_escape(x) if escape else x) for x in strings)
return f"({rx})"


Expand Down Expand Up @@ -732,14 +735,21 @@ def by_index(*indices: int | range | Sequence[int | range]) -> SelectorType:
)


def by_name(*names: str | Collection[str]) -> SelectorType:
def by_name(*names: str | Collection[str], require_all: bool = True) -> SelectorType:
"""
Select all columns matching the given names.

Parameters
----------
*names
One or more names of columns to select.
require_all
Whether to match *all* names (the default) or *any* of the names.

Notes
-----
Matching columns are returned in the order in which they are declared in
the selector, not the underlying schema order.

See Also
--------
Expand Down Expand Up @@ -771,6 +781,19 @@ def by_name(*names: str | Collection[str]) -> SelectorType:
│ y   ┆ 456 │
└─────┴─────┘

Match *any* of the given columns by name:

>>> df.select(cs.by_name("baz", "moose", "foo", "bear", require_all=False))
shape: (2, 2)
┌─────┬─────┐
│ foo ┆ baz │
│ --- ┆ --- │
│ str ┆ f64 │
╞═════╪═════╡
│ x   ┆ 2.0 │
│ y   ┆ 5.5 │
└─────┴─────┘

Match all columns *except* for those given:

>>> df.select(~cs.by_name("foo", "bar"))
Expand Down Expand Up @@ -798,8 +821,16 @@ def by_name(*names: str | Collection[str]) -> SelectorType:
msg = f"invalid name: {nm!r}"
raise TypeError(msg)

selector_params: dict[str, Any] = {"*names": all_names}
match_cols: list[str] | str = all_names
if not require_all:
match_cols = f"^({'|'.join(re_escape(nm) for nm in all_names)})$"
selector_params["require_all"] = require_all

return _selector_proxy_(
F.col(all_names), name="by_name", parameters={"*names": all_names}
F.col(match_cols),
name="by_name",
parameters=selector_params,
)


Expand Down
27 changes: 22 additions & 5 deletions py-polars/tests/unit/test_selectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,14 +130,22 @@ def test_selector_by_name(df: pl.DataFrame) -> None:
assert df.select(cs.by_name()).columns == []
assert df.select(cs.by_name([])).columns == []

selected_cols = df.select(
cs.by_name("???", "fgg", "!!!", require_all=False)
).columns
assert selected_cols == ["fgg"]

# check "by_name & col"
for selector_expr in (
cs.by_name("abc", "cde") & pl.col("ghi"),
pl.col("abc") & cs.by_name("cde", "ghi"),
for selector_expr, expected in (
(cs.by_name("abc", "cde") & pl.col("ghi"), ["abc", "cde", "ghi"]),
(pl.col("ghi") & cs.by_name("cde", "abc"), ["ghi", "cde", "abc"]),
):
assert df.select(selector_expr).columns == ["abc", "cde", "ghi"]
assert df.select(selector_expr).columns == expected

# expected errors
with pytest.raises(ColumnNotFoundError, match="xxx"):
df.select(cs.by_name("xxx", "fgg", "!!!"))

with pytest.raises(ColumnNotFoundError):
df.select(cs.by_name("stroopwafel"))

Expand Down Expand Up @@ -494,7 +502,16 @@ def test_selector_repr() -> None:
assert_repr_equals(~cs.starts_with("a", "b"), "~cs.starts_with('a', 'b')")
assert_repr_equals(cs.float() | cs.by_name("x"), "(cs.float() | cs.by_name('x'))")
assert_repr_equals(
cs.integer() & cs.matches("z"), "(cs.integer() & cs.matches(pattern='z'))"
cs.integer() & cs.matches("z"),
"(cs.integer() & cs.matches(pattern='z'))",
)
assert_repr_equals(
cs.by_name("baz", "moose", "foo", "bear"),
"cs.by_name('baz', 'moose', 'foo', 'bear')",
)
assert_repr_equals(
cs.by_name("baz", "moose", "foo", "bear", require_all=False),
"cs.by_name('baz', 'moose', 'foo', 'bear', require_all=False)",
)
assert_repr_equals(
cs.temporal() | cs.by_dtype(pl.String) & cs.string(include_categorical=False),
Expand Down