Skip to content

Commit

Permalink
fix: SQL COUNT(DISTINCT x) should not include NULL values (pola-rs#…
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-beedie authored Jul 30, 2024
1 parent 82b6388 commit f58aa39
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 1 deletion.
4 changes: 3 additions & 1 deletion crates/polars-sql/src/functions.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::ops::Sub;

use polars_core::chunked_array::ops::{SortMultipleOptions, SortOptions};
use polars_core::export::regex;
use polars_core::prelude::{polars_bail, polars_err, DataType, PolarsResult, Schema, TimeUnit};
Expand Down Expand Up @@ -1573,7 +1575,7 @@ impl SQLFunctionVisitor<'_> {
(true, [FunctionArgExpr::Expr(sql_expr)]) => {
let expr = parse_sql_expr(sql_expr, self.ctx, self.active_schema)?;
let expr = self.apply_window_spec(expr, &self.func.over)?;
Ok(expr.n_unique())
Ok(expr.clone().n_unique().sub(expr.null_count().gt(lit(0))))
},
_ => self.not_supported_error(),
}
Expand Down
54 changes: 54 additions & 0 deletions py-polars/tests/unit/sql/test_miscellaneous.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,60 @@ def test_any_all() -> None:
}


def test_count() -> None:
df = pl.DataFrame(
{
"a": [1, 2, 3, 4, 5],
"b": [1, 1, 22, 22, 333],
"c": [1, 1, None, None, 2],
}
)
res = df.sql(
"""
SELECT
-- count
COUNT(a) AS count_a,
COUNT(b) AS count_b,
COUNT(c) AS count_c,
COUNT(*) AS count_star,
COUNT(NULL) AS count_null,
-- count distinct
COUNT(DISTINCT a) AS count_unique_a,
COUNT(DISTINCT b) AS count_unique_b,
COUNT(DISTINCT c) AS count_unique_c,
COUNT(DISTINCT NULL) AS count_unique_null,
FROM self
""",
)
assert res.to_dict(as_series=False) == {
"count_a": [5],
"count_b": [5],
"count_c": [3],
"count_star": [5],
"count_null": [0],
"count_unique_a": [5],
"count_unique_b": [3],
"count_unique_c": [2],
"count_unique_null": [0],
}

df = pl.DataFrame({"x": [None, None, None]})
res = df.sql(
"""
SELECT
COUNT(x) AS count_x,
COUNT(*) AS count_star,
COUNT(DISTINCT x) AS count_unique_x
FROM self
"""
)
assert res.to_dict(as_series=False) == {
"count_x": [0],
"count_star": [3],
"count_unique_x": [0],
}


def test_distinct() -> None:
df = pl.DataFrame(
{
Expand Down

0 comments on commit f58aa39

Please sign in to comment.