Skip to content

Commit

Permalink
fix: Incorrectly preserved sorted flag when concatenating sorted seri…
Browse files Browse the repository at this point in the history
…es containing nulls (#15082)
  • Loading branch information
nameexhaustion authored Mar 15, 2024
1 parent 4933040 commit b9a5603
Show file tree
Hide file tree
Showing 3 changed files with 195 additions and 46 deletions.
27 changes: 19 additions & 8 deletions crates/polars-core/src/chunked_array/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,11 @@ impl<T: PolarsDataType> ChunkedArray<T> {
self.bit_settings.contains(Settings::SORTED_DSC)
}

/// Whether `self` is sorted in any direction.
pub(crate) fn is_sorted_any(&self) -> bool {
self.is_sorted_ascending_flag() || self.is_sorted_descending_flag()
}

pub fn unset_fast_explode_list(&mut self) {
self.bit_settings.remove(Settings::FAST_EXPLODE_LIST)
}
Expand Down Expand Up @@ -224,10 +229,7 @@ impl<T: PolarsDataType> ChunkedArray<T> {
None
}
// We now know there is at least 1 non-null item in the array, and self.len() > 0
else if matches!(
self.is_sorted_flag(),
IsSorted::Ascending | IsSorted::Descending
) {
else if self.is_sorted_any() {
let out = if unsafe { self.downcast_get_unchecked(0).is_null_unchecked(0) } {
// nulls are all at the start
self.null_count()
Expand All @@ -236,6 +238,12 @@ impl<T: PolarsDataType> ChunkedArray<T> {
0
};

debug_assert!(
// If we are lucky this catches something.
unsafe { self.get_unchecked(out) }.is_some(),
"incorrect sorted flag"
);

Some(out)
} else {
first_non_null(self.iter_validities())
Expand All @@ -248,10 +256,7 @@ impl<T: PolarsDataType> ChunkedArray<T> {
None
}
// We now know there is at least 1 non-null item in the array, and self.len() > 0
else if matches!(
self.is_sorted_flag(),
IsSorted::Ascending | IsSorted::Descending
) {
else if self.is_sorted_any() {
let out = if unsafe { self.downcast_get_unchecked(0).is_null_unchecked(0) } {
// nulls are all at the start
self.len() - 1
Expand All @@ -260,6 +265,12 @@ impl<T: PolarsDataType> ChunkedArray<T> {
self.len() - self.null_count() - 1
};

debug_assert!(
// If we are lucky this catches something.
unsafe { self.get_unchecked(out) }.is_some(),
"incorrect sorted flag"
);

Some(out)
} else {
last_non_null(self.iter_validities(), self.len())
Expand Down
107 changes: 69 additions & 38 deletions crates/polars-core/src/chunked_array/ops/append.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,51 +19,82 @@ where
T: PolarsDataType,
for<'a> T::Physical<'a>: TotalOrd,
{
// TODO: attempt to maintain sortedness better in case of nulls.
// Note: Do not call (first|last)_non_null on an array here before checking
// it is sorted, otherwise it will lead to quadratic behavior.
let sorted_flag = match (
ca.null_count() != ca.len(),
other.null_count() != other.len(),
) {
(false, false) => IsSorted::Ascending,
(false, true) => {
if
// lhs is empty, just take sorted flag from rhs
ca.is_empty()
|| (
// lhs is non-empty and all-null, so rhs must have nulls ordered first
other.is_sorted_any() && 1 + other.last_non_null().unwrap() == other.len()
)
{
other.is_sorted_flag()
} else {
IsSorted::Not
}
},
(true, false) => {
if
// rhs is empty, just take sorted flag from lhs
other.is_empty()
|| (
// rhs is non-empty and all-null, so lhs must have nulls ordered last
ca.is_sorted_any() && ca.first_non_null().unwrap() == 0
)
{
ca.is_sorted_flag()
} else {
IsSorted::Not
}
},
(true, true) => {
// both arrays have non-null values
if !ca.is_sorted_any()
|| !other.is_sorted_any()
|| ca.is_sorted_flag() != other.is_sorted_flag()
{
IsSorted::Not
} else {
let l_idx = ca.last_non_null().unwrap();
let r_idx = other.first_non_null().unwrap();

// If either is empty, copy the sorted flag from the other.
if ca.is_empty() {
ca.set_sorted_flag(other.is_sorted_flag());
return;
}
if other.is_empty() {
return;
}
let l_val = unsafe { ca.value_unchecked(l_idx) };
let r_val = unsafe { other.value_unchecked(r_idx) };

// Both need to be sorted, in the same order, if the order is maintained.
// TODO: rework sorted flags, ascending and descending are not mutually
// exclusive for all-equal/all-null arrays.
let ls = ca.is_sorted_flag();
let rs = other.is_sorted_flag();
if ls != rs || ls == IsSorted::Not || rs == IsSorted::Not {
ca.set_sorted_flag(IsSorted::Not);
return;
}
let keep_sorted =
// check null positions
// lhs does not end in nulls
(1 + l_idx == ca.len())
// rhs does not start with nulls
&& (r_idx == 0)
// if there are nulls, they are all on one end
&& !(ca.first_non_null().unwrap() != 0 && 1 + other.last_non_null().unwrap() != other.len());

let keep_sorted = keep_sorted
// compare values
&& if ca.is_sorted_ascending_flag() {
l_val.tot_le(&r_val)
} else {
l_val.tot_ge(&r_val)
};

// Check the order is maintained.
let still_sorted = {
// To prevent potential quadratic append behavior we do not find
// the last non-null element in ca.
if let Some(left) = ca.last() {
if let Some(right_idx) = other.first_non_null() {
let right = other.get(right_idx).unwrap();
if ca.is_sorted_ascending_flag() {
left.tot_le(&right)
if keep_sorted {
ca.is_sorted_flag()
} else {
left.tot_ge(&right)
IsSorted::Not
}
} else {
// Right is only nulls, trivially sorted.
true
}
} else {
// Last element in left is null, pessimistically assume not sorted.
false
}
},
};
if !still_sorted {
ca.set_sorted_flag(IsSorted::Not);
}

ca.set_sorted_flag(sorted_flag);
}

impl<T> ChunkedArray<T>
Expand Down
107 changes: 107 additions & 0 deletions py-polars/tests/unit/operations/test_sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -796,3 +796,110 @@ def test_sorted_flag_14552() -> None:

a = pl.concat([a, a], rechunk=False)
assert not a.join(a, on="a", how="left")["a"].flags["SORTED_ASC"]


def test_sorted_flag_concat_15072() -> None:
def is_sorted_any(s: pl.Series) -> bool:
return s.flags["SORTED_ASC"] or s.flags["SORTED_DESC"]

def is_not_sorted(s: pl.Series) -> bool:
return not is_sorted_any(s)

# Both all-null
a = pl.Series("x", [None, None], dtype=pl.Int8)
b = pl.Series("x", [None, None], dtype=pl.Int8)
assert pl.concat((a, b)).flags["SORTED_ASC"]

# left all-null, right 0 < null_count < len
a = pl.Series("x", [None, None], dtype=pl.Int8)
b = pl.Series("x", [1, 2, 1, None], dtype=pl.Int8)

out = pl.concat((a, b.sort()))
assert out.to_list() == [None, None, None, 1, 1, 2]
assert out.flags["SORTED_ASC"]

out = pl.concat((a, b.sort(descending=True)))
assert out.to_list() == [None, None, None, 2, 1, 1]
assert out.flags["SORTED_DESC"]

out = pl.concat((a, b.sort(nulls_last=True)))
assert out.to_list() == [None, None, 1, 1, 2, None]
assert is_not_sorted(out)

out = pl.concat((a, b.sort(nulls_last=True, descending=True)))
assert out.to_list() == [None, None, 2, 1, 1, None]
assert is_not_sorted(out)

# left 0 < null_count < len, right all-null
a = pl.Series("x", [1, 2, 1, None], dtype=pl.Int8)
b = pl.Series("x", [None, None], dtype=pl.Int8)

out = pl.concat((a.sort(), b))
assert out.to_list() == [None, 1, 1, 2, None, None]
assert is_not_sorted(out)

out = pl.concat((a.sort(descending=True), b))
assert out.to_list() == [None, 2, 1, 1, None, None]
assert is_not_sorted(out)

out = pl.concat((a.sort(nulls_last=True), b))
assert out.to_list() == [1, 1, 2, None, None, None]
assert out.flags["SORTED_ASC"]

out = pl.concat((a.sort(nulls_last=True, descending=True), b))
assert out.to_list() == [2, 1, 1, None, None, None]
assert out.flags["SORTED_DESC"]

# both 0 < null_count < len
assert pl.concat(
(
pl.Series([None, 1]).set_sorted(),
pl.Series([2]).set_sorted(),
)
).flags["SORTED_ASC"]

assert is_not_sorted(
pl.concat(
(
pl.Series([None, 1]).set_sorted(),
pl.Series([2, None]).set_sorted(),
)
)
)

assert pl.concat(
(
pl.Series([None, 2]).set_sorted(descending=True),
pl.Series([1]).set_sorted(descending=True),
)
).flags["SORTED_DESC"]

assert is_not_sorted(
pl.concat(
(
pl.Series([None, 2]).set_sorted(descending=True),
pl.Series([1, None]).set_sorted(descending=True),
)
)
)

# Concat with empty series
s = pl.Series([None, 1]).set_sorted()

out = pl.concat((s.clear(), s))
assert_series_equal(out, s)
assert out.flags["SORTED_ASC"]

out = pl.concat((s, s.clear()))
assert_series_equal(out, s)
assert out.flags["SORTED_ASC"]

s = pl.Series([1, None]).set_sorted()

out = pl.concat((s.clear(), s))
assert_series_equal(out, s)
assert out.flags["SORTED_ASC"]

out = pl.concat((s, s.clear()))
assert_series_equal(out, s)
assert out.flags["SORTED_ASC"]

0 comments on commit b9a5603

Please sign in to comment.