Skip to content

Commit

Permalink
fix(polars): fix polars std/var to properly handle sample/`popu…
Browse files Browse the repository at this point in the history
…lation`
  • Loading branch information
jcrist committed Jul 23, 2024
1 parent 8717629 commit f83d84f
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 20 deletions.
60 changes: 40 additions & 20 deletions ibis/backends/polars/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -721,39 +721,59 @@ def struct_column(op, **kw):
ops.All: "all",
ops.Any: "any",
ops.ApproxMedian: "median",
ops.Arbitrary: "first",
ops.Count: "count",
ops.CountDistinct: "n_unique",
ops.First: "first",
ops.Last: "last",
ops.Max: "max",
ops.Mean: "mean",
ops.Median: "median",
ops.Min: "min",
ops.StandardDev: "std",
ops.Sum: "sum",
ops.Variance: "var",
}

for reduction in _reductions.keys():

@translate.register(reduction)
def reduction(op, **kw):
args = [
translate(arg, **kw)
for name, arg in zip(op.argnames, op.args)
if name not in ("where", "how")
]
def execute_reduction(op, **kw):
arg = translate(op.arg, **kw)

if op.where is not None:
arg = arg.filter(translate(op.where, **kw))

method = _reductions[type(op)]

return getattr(arg, method)()


for cls in _reductions:
translate.register(cls, execute_reduction)


@translate.register(ops.First)
@translate.register(ops.Last)
@translate.register(ops.Arbitrary)
def execute_first_last(op, **kw):
arg = translate(op.arg, **kw)

# polars doesn't ignore nulls by default for these methods
predicate = arg.is_not_null()
if op.where is not None:
predicate &= translate(op.where, **kw)

arg = arg.filter(predicate)

return arg.last() if isinstance(op, ops.Last) else arg.first()

agg = _reductions[type(op)]

predicates = [arg.is_not_null() for arg in args]
if (where := op.where) is not None:
predicates.append(translate(where, **kw))
@translate.register(ops.StandardDev)
@translate.register(ops.Variance)
def execute_std_var(op, **kw):
arg = translate(op.arg, **kw)

if op.where is not None:
arg = arg.filter(translate(op.where, **kw))

method = "std" if isinstance(op, ops.StandardDev) else "var"
ddof = 0 if op.how == "pop" else 1

first, *rest = args
method = operator.methodcaller(agg, *rest)
return method(first.filter(reduce(operator.and_, predicates)))
return getattr(arg, method)(ddof=ddof)


@translate.register(ops.Mode)
Expand Down
6 changes: 6 additions & 0 deletions ibis/backends/tests/test_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,6 +606,12 @@ def test_reduction_ops(
ibis_cond,
pandas_cond,
):
# Operate on a subset of the data, since aggregations like var/std with
# sample/population can be too numerically similar for a larger number of
# rows.
alltypes = alltypes.filter(alltypes.id < 1550)
df = df[df.id < 1550]

expr = alltypes.agg(tmp=result_fn(alltypes, ibis_cond(alltypes))).tmp
result = expr.execute().squeeze()
expected = expected_fn(df, pandas_cond(df))
Expand Down

0 comments on commit f83d84f

Please sign in to comment.