Skip to content

Commit

Permalink
fix: fix pyarrow len aggregation behaviour for depth-1 exprs (#1589)
Browse files Browse the repository at this point in the history
  • Loading branch information
AlessandroMiola authored Dec 14, 2024
1 parent 104a17d commit 3eec7a7
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 60 deletions.
50 changes: 23 additions & 27 deletions narwhals/_arrow/group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,33 +20,28 @@
from narwhals._arrow.expr import ArrowExpr
from narwhals._arrow.typing import IntoArrowExpr

POLARS_TO_ARROW_AGGREGATIONS = {
"sum": "sum",
"mean": "mean",
"median": "approximate_median",
"max": "max",
"min": "min",
"std": "stddev",
"var": "variance", # currently unused, we don't have `var` yet
"len": "count",
"n_unique": "count_distinct",
"count": "count",
}


def polars_to_arrow_aggregations() -> (
    dict[str, tuple[str, pc.VarianceOptions | pc.CountOptions | None]]
):
    """Map polars aggregation names to their pyarrow counterparts.

    Returns:
        A dict from the polars function name to a tuple of the pyarrow
        compute function name and the options object (or ``None``)
        needed to match polars behaviour.
    """
    # Imported inside the function: pyarrow is an optional backend
    # dependency, so it must not be imported at module import time.
    import pyarrow.compute as pc

    return {
        "sum": ("sum", None),
        "mean": ("mean", None),
        "median": ("approximate_median", None),
        "max": ("max", None),
        "min": ("min", None),
        # polars std/var default to ddof=1 (sample statistics).
        "std": ("stddev", pc.VarianceOptions(ddof=1)),
        "var": (
            "variance",
            pc.VarianceOptions(ddof=1),
        ),  # currently unused, we don't have `var` yet
        # polars `len` counts nulls too, whereas `count` only counts
        # valid (non-null) values — hence the different CountOptions.
        "len": ("count", pc.CountOptions(mode="all")),
        "n_unique": ("count_distinct", pc.CountOptions(mode="all")),
        "count": ("count", pc.CountOptions(mode="only_valid")),
    }


class ArrowGroupBy:
Expand Down Expand Up @@ -139,7 +134,7 @@ def agg_arrow(
if not (
is_simple_aggregation(expr)
and remove_prefix(expr._function_name, "col->")
in POLARS_TO_ARROW_AGGREGATIONS
in polars_to_arrow_aggregations()
):
all_simple_aggs = False
break
Expand Down Expand Up @@ -170,9 +165,10 @@ def agg_arrow(
raise AssertionError(msg)

function_name = remove_prefix(expr._function_name, "col->")
function_name = POLARS_TO_ARROW_AGGREGATIONS.get(function_name, function_name)
function_name, option = polars_to_arrow_aggregations().get(
function_name, (function_name, None)
)

option = get_function_name_option(function_name)
for root_name, output_name in zip(expr._root_names, expr._output_names):
simple_aggregations[output_name] = (
(root_name, function_name, option),
Expand Down
53 changes: 20 additions & 33 deletions tests/group_by_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,11 +97,26 @@ def test_group_by_iter(constructor_eager: ConstructorEager) -> None:
assert sorted(keys) == sorted(expected_keys)


@pytest.mark.parametrize(
    ("expr", "expected"),
    [
        (nw.col("b").sum(), {"a": [1, 2], "b": [3, 3]}),
        (nw.col("b").mean(), {"a": [1, 2], "b": [1.5, 3]}),
        (nw.col("b").max(), {"a": [1, 2], "b": [2, 3]}),
        (nw.col("b").min(), {"a": [1, 2], "b": [1, 3]}),
        (nw.col("b").std(), {"a": [1, 2], "b": [0.707107, None]}),
        # `len` counts nulls in the group, `count` does not: group a=1
        # has 3 rows but only 2 valid values of "b".
        (nw.col("b").len(), {"a": [1, 2], "b": [3, 1]}),
        (nw.col("b").n_unique(), {"a": [1, 2], "b": [3, 1]}),
        (nw.col("b").count(), {"a": [1, 2], "b": [2, 1]}),
    ],
)
def test_group_by_depth_1_agg(
    constructor: Constructor,
    expr: nw.Expr,
    expected: dict[str, list[int | float]],
) -> None:
    """Check depth-1 aggregations over groups containing a null value."""
    data = {"a": [1, 1, 1, 2], "b": [1, None, 2, 3]}
    result = nw.from_native(constructor(data)).group_by("a").agg(expr).sort("a")
    assert_equal_data(result, expected)


Expand All @@ -117,26 +132,6 @@ def test_group_by_median(constructor: Constructor) -> None:
assert_equal_data(result, expected)


def test_group_by_n_unique(constructor: Constructor) -> None:
    """Count distinct values of "b" within each "a" group."""
    df = nw.from_native(constructor(data))
    result = df.group_by("a").agg(nw.col("b").n_unique()).sort("a")
    expected = {"a": [1, 3], "b": [1, 1]}
    assert_equal_data(result, expected)


def test_group_by_std(constructor: Constructor) -> None:
    """Sample standard deviation (ddof=1) within each group."""
    data = {"a": [1, 1, 2, 2], "b": [5, 4, 3, 2]}
    df = nw.from_native(constructor(data))
    result = df.group_by("a").agg(nw.col("b").std()).sort("a")
    expected = {"a": [1, 2], "b": [0.707107, 0.707107]}
    assert_equal_data(result, expected)


def test_group_by_n_unique_w_missing(
constructor: Constructor, request: pytest.FixtureRequest
) -> None:
Expand Down Expand Up @@ -357,11 +352,3 @@ def test_group_by_shift_raises(
ValueError, match=".*(failed to aggregate|Non-trivial complex aggregation found)"
):
df.group_by("b").agg(nw.col("a").shift(1))


def test_group_by_count(constructor: Constructor) -> None:
    """`count` skips nulls: group a=1 has 3 rows but only 2 valid "b" values."""
    data = {"a": [1, 1, 1, 2], "b": [1, None, 2, 3]}
    result = (
        nw.from_native(constructor(data))
        .group_by("a")
        .agg(nw.col("b").count())
        .sort("a")
    )
    expected = {"a": [1, 2], "b": [2, 1]}
    assert_equal_data(result, expected)

0 comments on commit 3eec7a7

Please sign in to comment.