Skip to content

Commit

Permalink
feat: expressify lower_bound and upper_bound in is_between (#1672)
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli authored Dec 29, 2024
1 parent b6d1eeb commit 2dd4480
Show file tree
Hide file tree
Showing 11 changed files with 54 additions and 130 deletions.
1 change: 0 additions & 1 deletion docs/how_it_works.md
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,6 @@ In order to tell whether an aggregation is simple, Narwhals uses the private `_d
```python exec="1" result="python" session="pandas_impl" source="above"
print(pn.col("a").mean())
print((pn.col("a") + 1).mean())
print(pn.mean("a"))
```

For simple aggregations, Narwhals can just look at `_depth` and `function_name` and figure out
Expand Down
25 changes: 0 additions & 25 deletions narwhals/_arrow/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,31 +348,6 @@ def concat(
result_table, backend_version=self._backend_version, version=self._version
)

def sum(self: Self, *column_names: str) -> ArrowExpr:
return ArrowExpr.from_column_names(
*column_names, backend_version=self._backend_version, version=self._version
).sum()

def mean(self: Self, *column_names: str) -> ArrowExpr:
return ArrowExpr.from_column_names(
*column_names, backend_version=self._backend_version, version=self._version
).mean()

def median(self: Self, *column_names: str) -> ArrowExpr:
return ArrowExpr.from_column_names(
*column_names, backend_version=self._backend_version, version=self._version
).median()

def max(self: Self, *column_names: str) -> ArrowExpr:
return ArrowExpr.from_column_names(
*column_names, backend_version=self._backend_version, version=self._version
).max()

def min(self: Self, *column_names: str) -> ArrowExpr:
return ArrowExpr.from_column_names(
*column_names, backend_version=self._backend_version, version=self._version
).min()

@property
def selectors(self: Self) -> ArrowSelectorNamespace:
return ArrowSelectorNamespace(
Expand Down
6 changes: 6 additions & 0 deletions narwhals/_arrow/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,6 +487,12 @@ def is_between(
import pyarrow.compute as pc

ser = self._native_series
_, lower_bound = broadcast_and_extract_native(
self, lower_bound, self._backend_version
)
_, upper_bound = broadcast_and_extract_native(
self, upper_bound, self._backend_version
)
if closed == "left":
ge = pc.greater_equal(ser, lower_bound)
lt = pc.less(ser, upper_bound)
Expand Down
25 changes: 0 additions & 25 deletions narwhals/_dask/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,31 +91,6 @@ def convert_if_dtype(
kwargs={},
)

def min(self, *column_names: str) -> DaskExpr:
return DaskExpr.from_column_names(
*column_names, backend_version=self._backend_version, version=self._version
).min()

def max(self, *column_names: str) -> DaskExpr:
return DaskExpr.from_column_names(
*column_names, backend_version=self._backend_version, version=self._version
).max()

def mean(self, *column_names: str) -> DaskExpr:
return DaskExpr.from_column_names(
*column_names, backend_version=self._backend_version, version=self._version
).mean()

def median(self, *column_names: str) -> DaskExpr:
return DaskExpr.from_column_names(
*column_names, backend_version=self._backend_version, version=self._version
).median()

def sum(self, *column_names: str) -> DaskExpr:
return DaskExpr.from_column_names(
*column_names, backend_version=self._backend_version, version=self._version
).sum()

def len(self) -> DaskExpr:
import dask.dataframe as dd
import pandas as pd
Expand Down
41 changes: 0 additions & 41 deletions narwhals/_pandas_like/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,47 +169,6 @@ def _lit_pandas_series(df: PandasLikeDataFrame) -> PandasLikeSeries:
kwargs={},
)

# --- reduction ---
def sum(self, *column_names: str) -> PandasLikeExpr:
return PandasLikeExpr.from_column_names(
*column_names,
implementation=self._implementation,
backend_version=self._backend_version,
version=self._version,
).sum()

def mean(self, *column_names: str) -> PandasLikeExpr:
return PandasLikeExpr.from_column_names(
*column_names,
implementation=self._implementation,
backend_version=self._backend_version,
version=self._version,
).mean()

def median(self, *column_names: str) -> PandasLikeExpr:
return PandasLikeExpr.from_column_names(
*column_names,
implementation=self._implementation,
backend_version=self._backend_version,
version=self._version,
).median()

def max(self, *column_names: str) -> PandasLikeExpr:
return PandasLikeExpr.from_column_names(
*column_names,
implementation=self._implementation,
backend_version=self._backend_version,
version=self._version,
).max()

def min(self, *column_names: str) -> PandasLikeExpr:
return PandasLikeExpr.from_column_names(
*column_names,
implementation=self._implementation,
backend_version=self._backend_version,
version=self._version,
).min()

def len(self) -> PandasLikeExpr:
return PandasLikeExpr(
lambda df: [
Expand Down
11 changes: 10 additions & 1 deletion narwhals/_pandas_like/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,8 @@ def is_between(
self, lower_bound: Any, upper_bound: Any, closed: str = "both"
) -> PandasLikeSeries:
ser = self._native_series
_, lower_bound = broadcast_align_and_extract_native(self, lower_bound)
_, upper_bound = broadcast_align_and_extract_native(self, upper_bound)
if closed == "left":
res = ser.ge(lower_bound) & ser.lt(upper_bound)
elif closed == "right":
Expand All @@ -273,7 +275,14 @@ def is_between(
res = ser.ge(lower_bound) & ser.le(upper_bound)
else: # pragma: no cover
raise AssertionError
return self._from_native_series(res)
return self._from_native_series(
rename(
res,
ser.name,
implementation=self._implementation,
backend_version=self._backend_version,
)
)

def is_in(self, other: Any) -> PandasLikeSeries:
ser = self._native_series
Expand Down
22 changes: 0 additions & 22 deletions narwhals/_polars/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,17 +128,6 @@ def lit(self: Self, value: Any, dtype: DType | None = None) -> PolarsExpr:
pl.lit(value), version=self._version, backend_version=self._backend_version
)

def mean(self: Self, *column_names: str) -> PolarsExpr:
import polars as pl

from narwhals._polars.expr import PolarsExpr

return PolarsExpr(
pl.mean([*column_names]), # type: ignore[arg-type]
version=self._version,
backend_version=self._backend_version,
)

def mean_horizontal(self: Self, *exprs: IntoPolarsExpr) -> PolarsExpr:
import polars as pl

Expand All @@ -160,17 +149,6 @@ def mean_horizontal(self: Self, *exprs: IntoPolarsExpr) -> PolarsExpr:
backend_version=self._backend_version,
)

def median(self: Self, *column_names: str) -> PolarsExpr:
import polars as pl

from narwhals._polars.expr import PolarsExpr

return PolarsExpr(
pl.median([*column_names]), # type: ignore[arg-type]
version=self._version,
backend_version=self._backend_version,
)

def concat_str(
self,
exprs: Iterable[IntoPolarsExpr],
Expand Down
21 changes: 13 additions & 8 deletions narwhals/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1666,7 +1666,10 @@ def sort(self, *, descending: bool = False, nulls_last: bool = False) -> Self:

# --- transform ---
def is_between(
self, lower_bound: Any, upper_bound: Any, closed: str = "both"
self,
lower_bound: Any | IntoExpr,
upper_bound: Any | IntoExpr,
closed: str = "both",
) -> Self:
"""Check if this expression is between the given lower and upper bounds.
Expand Down Expand Up @@ -1724,7 +1727,9 @@ def is_between(
"""
return self.__class__(
lambda plx: self._to_compliant_expr(plx).is_between(
lower_bound, upper_bound, closed
extract_compliant(plx, lower_bound),
extract_compliant(plx, upper_bound),
closed,
)
)

Expand Down Expand Up @@ -6049,7 +6054,7 @@ def col(*names: str | Iterable[str]) -> Expr:
"""Creates an expression that references one or more columns by their name(s).
Arguments:
names: Name(s) of the columns to use in the aggregation function.
names: Name(s) of the columns to use.
Returns:
A new expression.
Expand Down Expand Up @@ -6308,7 +6313,7 @@ def sum(*columns: str) -> Expr:
----
a: [[3]]
"""
return Expr(lambda plx: plx.sum(*columns))
return Expr(lambda plx: plx.col(*columns).sum())


def mean(*columns: str) -> Expr:
Expand Down Expand Up @@ -6359,7 +6364,7 @@ def mean(*columns: str) -> Expr:
----
a: [[4]]
"""
return Expr(lambda plx: plx.mean(*columns))
return Expr(lambda plx: plx.col(*columns).mean())


def median(*columns: str) -> Expr:
Expand Down Expand Up @@ -6411,7 +6416,7 @@ def median(*columns: str) -> Expr:
----
a: [[4]]
"""
return Expr(lambda plx: plx.median(*columns))
return Expr(lambda plx: plx.col(*columns).median())


def min(*columns: str) -> Expr:
Expand Down Expand Up @@ -6462,7 +6467,7 @@ def min(*columns: str) -> Expr:
----
b: [[5]]
"""
return Expr(lambda plx: plx.min(*columns))
return Expr(lambda plx: plx.col(*columns).min())


def max(*columns: str) -> Expr:
Expand Down Expand Up @@ -6513,7 +6518,7 @@ def max(*columns: str) -> Expr:
----
a: [[2]]
"""
return Expr(lambda plx: plx.max(*columns))
return Expr(lambda plx: plx.col(*columns).max())


def sum_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr:
Expand Down
8 changes: 6 additions & 2 deletions narwhals/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -2138,7 +2138,7 @@ def fill_null(
)

def is_between(
self, lower_bound: Any, upper_bound: Any, closed: str = "both"
self, lower_bound: Any | Self, upper_bound: Any | Self, closed: str = "both"
) -> Self:
"""Get a boolean mask of the values that are between the given lower/upper bounds.
Expand Down Expand Up @@ -2189,7 +2189,11 @@ def is_between(
]
"""
return self._from_compliant_series(
self._compliant_series.is_between(lower_bound, upper_bound, closed=closed)
self._compliant_series.is_between(
self._extract_native(lower_bound),
self._extract_native(upper_bound),
closed=closed,
)
)

def n_unique(self) -> int:
Expand Down
2 changes: 1 addition & 1 deletion narwhals/stable/v1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2321,7 +2321,7 @@ def col(*names: str | Iterable[str]) -> Expr:
"""Creates an expression that references one or more columns by their name(s).
Arguments:
names: Name(s) of the columns to use in the aggregation function.
names: Name(s) of the columns to use.
Returns:
A new expression.
Expand Down
22 changes: 18 additions & 4 deletions tests/expr_and_series/is_between_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,6 @@
from tests.utils import ConstructorEager
from tests.utils import assert_equal_data

data = {
"a": [1, 4, 2, 5],
}


@pytest.mark.parametrize(
("closed", "expected"),
Expand All @@ -22,12 +18,21 @@
],
)
def test_is_between(constructor: Constructor, closed: str, expected: list[bool]) -> None:
data = {"a": [1, 4, 2, 5]}
df = nw.from_native(constructor(data))
result = df.select(nw.col("a").is_between(1, 5, closed=closed))
expected_dict = {"a": expected}
assert_equal_data(result, expected_dict)


def test_is_between_expressified(constructor: Constructor) -> None:
data = {"a": [1, 4, 2, 5], "b": [0, 5, 2, 4], "c": [9, 9, 9, 9]}
df = nw.from_native(constructor(data))
result = df.select(nw.col("a").is_between(nw.col("b") * 0.9, nw.col("c") - 1))
expected_dict = {"a": [True, False, True, True]}
assert_equal_data(result, expected_dict)


@pytest.mark.parametrize(
("closed", "expected"),
[
Expand All @@ -40,7 +45,16 @@ def test_is_between(constructor: Constructor, closed: str, expected: list[bool])
def test_is_between_series(
constructor_eager: ConstructorEager, closed: str, expected: list[bool]
) -> None:
data = {"a": [1, 4, 2, 5]}
df = nw.from_native(constructor_eager(data), eager_only=True)
result = df.with_columns(a=df["a"].is_between(1, 5, closed=closed))
expected_dict = {"a": expected}
assert_equal_data(result, expected_dict)


def test_is_between_expressified_series(constructor_eager: ConstructorEager) -> None:
data = {"a": [1, 4, 2, 5], "b": [0, 5, 2, 4], "c": [9, 9, 9, 9]}
df = nw.from_native(constructor_eager(data), eager_only=True)
result = df["a"].is_between(df["b"], df["c"]).to_frame()
expected_dict = {"a": [True, False, True, True]}
assert_equal_data(result, expected_dict)

0 comments on commit 2dd4480

Please sign in to comment.