From d103573177ba5c7266ff4fca1bf71c0870a3cc4e Mon Sep 17 00:00:00 2001 From: Kev Wang Date: Tue, 10 Dec 2024 17:31:26 -0800 Subject: [PATCH] fix: boolean and/or expressions with null (#3544) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Resolves #3512 A fun truth table ($V_{x}$ is the validity $x$, $L$ is left, and $R$ is right): | $L$ | $V_{L}$ | $R$ | $V_{R}$ | $∧$ | $V_{∧}$ | $∨$ | $V_{∨}$ | | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | | ⬜️ | ⬜️ | ⬜️ | ⬜️ | | ⬜️ | | ⬜️ | | ⬜️ | ⬜️ | ⬜️ | ✅ | ⬜️ | ✅ | | ⬜️ | | ⬜️ | ⬜️ | ✅ | ⬜️ | | ⬜️ | | ⬜️ | | ⬜️ | ⬜️ | ✅ | ✅ | | ⬜️ | ✅ | ✅ | | ⬜️ | ✅ | ⬜️ | ⬜️ | ⬜️ | ✅ | | ⬜️ | | ⬜️ | ✅ | ⬜️ | ✅ | ⬜️ | ✅ | ⬜️ | ✅ | | ⬜️ | ✅ | ✅ | ⬜️ | ⬜️ | ✅ | | ⬜️ | | ⬜️ | ✅ | ✅ | ✅ | ⬜️ | ✅ | ✅ | ✅ | | ✅ | ⬜️ | ⬜️ | ⬜️ | | ⬜️ | | ⬜️ | | ✅ | ⬜️ | ⬜️ | ✅ | ⬜️ | ✅ | | ⬜️ | | ✅ | ⬜️ | ✅ | ⬜️ | | ⬜️ | | ⬜️ | | ✅ | ⬜️ | ✅ | ✅ | | ⬜️ | ✅ | ✅ | | ✅ | ✅ | ⬜️ | ⬜️ | | ⬜️ | ✅ | ✅ | | ✅ | ✅ | ⬜️ | ✅ | ⬜️ | ✅ | ✅ | ✅ | | ✅ | ✅ | ✅ | ⬜️ | | ⬜️ | ✅ | ✅ | | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | --- src/daft-core/src/array/ops/comparison.rs | 140 ++++++++++++++++------ tests/series/test_comparisons.py | 116 +++++++++++++----- tests/table/test_between.py | 4 +- 3 files changed, 194 insertions(+), 66 deletions(-) diff --git a/src/daft-core/src/array/ops/comparison.rs b/src/daft-core/src/array/ops/comparison.rs index 0a48557cdf..2b9f855286 100644 --- a/src/daft-core/src/array/ops/comparison.rs +++ b/src/daft-core/src/array/ops/comparison.rs @@ -601,13 +601,47 @@ impl Not for &BooleanArray { impl DaftLogical<&Self> for BooleanArray { type Output = DaftResult; fn and(&self, rhs: &Self) -> Self::Output { + // When performing a logical AND with a NULL value: + // - If the non-null value is false, the result is false (not null) + // - If the non-null value is true, the result is null + fn and_with_null(name: &str, arr: &BooleanArray) -> BooleanArray { + let values = arr.as_arrow().values(); + + let new_validity = match arr.as_arrow().validity() { + None => values.not(), + Some(validity) => arrow2::bitmap::and(&values.not(), validity), + }; + + BooleanArray::from(( + name, + arrow2::array::BooleanArray::new( + arrow2::datatypes::DataType::Boolean, + values.clone(), + Some(new_validity), + ), + )) + } + match (self.len(), rhs.len()) { (x, y) if x == y => { - let validity = - arrow_bitmap_and_helper(self.as_arrow().validity(), rhs.as_arrow().validity()); - - let result_bitmap = - arrow2::bitmap::and(self.as_arrow().values(), rhs.as_arrow().values()); + let l_values = self.as_arrow().values(); + let r_values = rhs.as_arrow().values(); + + // (false & NULL) should be false, compute validity to ensure that + let validity = match (self.as_arrow().validity(), rhs.as_arrow().validity()) { + (None, None) => None, + (None, Some(r_valid)) => Some(arrow2::bitmap::or(&l_values.not(), r_valid)), + (Some(l_valid), None) => Some(arrow2::bitmap::or(l_valid, &r_values.not())), + (Some(l_valid), Some(r_valid)) => Some(arrow2::bitmap::or( + &arrow2::bitmap::or( + &arrow2::bitmap::and(&l_values.not(), l_valid), + &arrow2::bitmap::and(&r_values.not(), r_valid), + ), + &arrow2::bitmap::and(l_valid, r_valid), + )), + }; + + let result_bitmap = arrow2::bitmap::and(l_values, r_values); Ok(Self::from(( self.name(), arrow2::array::BooleanArray::new( @@ -617,18 +651,18 @@ impl DaftLogical<&Self> for BooleanArray { ), ))) } - (l_size, 1) => { + (_, 1) => { if let Some(value) = rhs.get(0) { self.and(value) } else { - Ok(Self::full_null(self.name(), &DataType::Boolean, l_size)) + Ok(and_with_null(self.name(), self)) } } - (1, r_size) => { + (1, _) => { if let Some(value) = self.get(0) { rhs.and(value) } else { - Ok(Self::full_null(self.name(), &DataType::Boolean, r_size)) + Ok(and_with_null(self.name(), rhs)) } } (l, r) => Err(DaftError::ValueError(format!( @@ -640,13 +674,47 @@ impl DaftLogical<&Self> for BooleanArray { } fn or(&self, rhs: &Self) -> Self::Output { + // When performing a logical OR with a NULL value: + // - If the non-null value is false, the result is null + // - If the non-null value is true, the result is true (not null) + fn or_with_null(name: &str, arr: &BooleanArray) -> BooleanArray { + let values = arr.as_arrow().values(); + + let new_validity = match arr.as_arrow().validity() { + None => values.clone(), + Some(validity) => arrow2::bitmap::and(values, validity), + }; + + BooleanArray::from(( + name, + arrow2::array::BooleanArray::new( + arrow2::datatypes::DataType::Boolean, + values.clone(), + Some(new_validity), + ), + )) + } + match (self.len(), rhs.len()) { (x, y) if x == y => { - let validity = - arrow_bitmap_and_helper(self.as_arrow().validity(), rhs.as_arrow().validity()); - - let result_bitmap = - arrow2::bitmap::or(self.as_arrow().values(), rhs.as_arrow().values()); + let l_values = self.as_arrow().values(); + let r_values = rhs.as_arrow().values(); + + // (true | NULL) should be true, compute validity to ensure that + let validity = match (self.as_arrow().validity(), rhs.as_arrow().validity()) { + (None, None) => None, + (None, Some(r_valid)) => Some(arrow2::bitmap::or(l_values, r_valid)), + (Some(l_valid), None) => Some(arrow2::bitmap::or(l_valid, r_values)), + (Some(l_valid), Some(r_valid)) => Some(arrow2::bitmap::or( + &arrow2::bitmap::or( + &arrow2::bitmap::and(l_values, l_valid), + &arrow2::bitmap::and(r_values, r_valid), + ), + &arrow2::bitmap::and(l_valid, r_valid), + )), + }; + + let result_bitmap = arrow2::bitmap::or(l_values, r_values); Ok(Self::from(( self.name(), arrow2::array::BooleanArray::new( @@ -656,18 +724,18 @@ impl DaftLogical<&Self> for BooleanArray { ), ))) } - (l_size, 1) => { + (_, 1) => { if let Some(value) = rhs.get(0) { self.or(value) } else { - Ok(Self::full_null(self.name(), &DataType::Boolean, l_size)) + Ok(or_with_null(self.name(), self)) } } - (1, r_size) => { + (1, _) => { if let Some(value) = self.get(0) { rhs.or(value) } else { - Ok(Self::full_null(self.name(), &DataType::Boolean, r_size)) + Ok(or_with_null(self.name(), rhs)) } } (l, r) => Err(DaftError::ValueError(format!( @@ -756,33 +824,33 @@ impl DaftCompare<&Self> for NullArray { impl DaftLogical for BooleanArray { type Output = DaftResult; fn and(&self, rhs: bool) -> Self::Output { - let validity = self.as_arrow().validity(); if rhs { Ok(self.clone()) } else { - use arrow2::{array, bitmap::Bitmap, datatypes::DataType}; - let arrow_array = array::BooleanArray::new( - DataType::Boolean, - Bitmap::new_zeroed(self.len()), - validity.cloned(), - ); - Ok(Self::from((self.name(), arrow_array))) + Ok(Self::from(( + self.name(), + arrow2::array::BooleanArray::new( + arrow2::datatypes::DataType::Boolean, + arrow2::bitmap::Bitmap::new_zeroed(self.len()), + None, // false & x is always valid false for any x + ), + ))) } } fn or(&self, rhs: bool) -> Self::Output { - let validity = self.as_arrow().validity(); if rhs { - use arrow2::{array, bitmap::Bitmap, datatypes::DataType}; - let arrow_array = array::BooleanArray::new( - DataType::Boolean, - Bitmap::new_zeroed(self.len()).not(), - validity.cloned(), - ); - return Ok(Self::from((self.name(), arrow_array))); + Ok(Self::from(( + self.name(), + arrow2::array::BooleanArray::new( + arrow2::datatypes::DataType::Boolean, + arrow2::bitmap::Bitmap::new_trued(self.len()), + None, // true | x is always valid true for any x + ), + ))) + } else { + Ok(self.clone()) } - - Ok(self.clone()) } fn xor(&self, rhs: bool) -> Self::Output { diff --git a/tests/series/test_comparisons.py b/tests/series/test_comparisons.py index 37a52330c6..6fd016a3e0 100644 --- a/tests/series/test_comparisons.py +++ b/tests/series/test_comparisons.py @@ -10,7 +10,16 @@ from daft import DataType, Series -arrow_int_types = [pa.int8(), pa.uint8(), pa.int16(), pa.uint16(), pa.int32(), pa.uint32(), pa.int64(), pa.uint64()] +arrow_int_types = [ + pa.int8(), + pa.uint8(), + pa.int16(), + pa.uint16(), + pa.int32(), + pa.uint32(), + pa.int64(), + pa.uint64(), +] arrow_decimal_types = [pa.decimal128(20, 5), pa.decimal128(15, 9)] arrow_string_types = [pa.string(), pa.large_string()] arrow_float_types = [pa.float32(), pa.float64()] @@ -139,7 +148,8 @@ def test_comparisons_int_and_str_right_null_scalar(l_dtype, r_dtype) -> None: @pytest.mark.parametrize( - "l_dtype, r_dtype", itertools.product(arrow_int_types + arrow_float_types + arrow_decimal_types, repeat=2) + "l_dtype, r_dtype", + itertools.product(arrow_int_types + arrow_float_types + arrow_decimal_types, repeat=2), ) def test_comparisons_int_and_float(l_dtype, r_dtype) -> None: l_arrow = make_array([1, 2, 3, None, 5, None], type=l_dtype) @@ -168,7 +178,8 @@ def test_comparisons_int_and_float(l_dtype, r_dtype) -> None: @pytest.mark.parametrize( - "l_dtype, r_dtype", itertools.product(arrow_int_types + arrow_float_types + arrow_decimal_types, repeat=2) + "l_dtype, r_dtype", + itertools.product(arrow_int_types + arrow_float_types + arrow_decimal_types, repeat=2), ) def test_comparisons_int_and_float_right_scalar(l_dtype, r_dtype) -> None: l_arrow = make_array([1, 2, 3, None, 5, None], type=l_dtype) @@ -197,7 +208,8 @@ def test_comparisons_int_and_float_right_scalar(l_dtype, r_dtype) -> None: @pytest.mark.parametrize( - "l_dtype, r_dtype", itertools.product(arrow_int_types + arrow_float_types + arrow_decimal_types, repeat=2) + "l_dtype, r_dtype", + itertools.product(arrow_int_types + arrow_float_types + arrow_decimal_types, repeat=2), ) def test_comparisons_int_and_float_right_null_scalar(l_dtype, r_dtype) -> None: l_arrow = make_array([1, 2, 3, None, 5, None], type=l_dtype) @@ -226,39 +238,39 @@ def test_comparisons_int_and_float_right_null_scalar(l_dtype, r_dtype) -> None: def test_comparisons_boolean_array() -> None: - l_arrow = make_array([False, False, None, True, None]) - r_arrow = make_array([True, False, True, None, None]) + l_arrow = make_array([True, False, True, False, True, False, None, None, None]) + r_arrow = make_array([True, True, False, False, None, None, True, False, None]) # lt, eq, lt, None left = Series.from_arrow(l_arrow) right = Series.from_arrow(r_arrow) lt = (left < right).to_pylist() - assert lt == [True, False, None, None, None] + assert lt == [False, True, False, False, None, None, None, None, None] le = (left <= right).to_pylist() - assert le == [True, True, None, None, None] + assert le == [True, True, False, True, None, None, None, None, None] eq = (left == right).to_pylist() - assert eq == [False, True, None, None, None] + assert eq == [True, False, False, True, None, None, None, None, None] neq = (left != right).to_pylist() - assert neq == [True, False, None, None, None] + assert neq == [False, True, True, False, None, None, None, None, None] ge = (left >= right).to_pylist() - assert ge == [False, True, None, None, None] + assert ge == [True, False, True, True, None, None, None, None, None] gt = (left > right).to_pylist() - assert gt == [False, False, None, None, None] + assert gt == [False, False, True, False, None, None, None, None, None] _and = (left & right).to_pylist() - assert _and == [False, False, None, None, None] + assert _and == [True, False, False, False, None, False, None, False, None] _or = (left | right).to_pylist() - assert _or == [True, False, None, None, None] + assert _or == [True, True, True, False, True, None, True, None, None] _xor = (left ^ right).to_pylist() - assert _xor == [True, False, None, None, None] + assert _xor == [False, True, True, False, None, None, None, None, None] def test_comparisons_boolean_array_right_scalar() -> None: @@ -290,7 +302,7 @@ def test_comparisons_boolean_array_right_scalar() -> None: assert _and == [False, True, None] _or = (left | right).to_pylist() - assert _or == [True, True, None] + assert _or == [True, True, True] _xor = (left ^ right).to_pylist() assert _xor == [True, False, None] @@ -317,7 +329,7 @@ def test_comparisons_boolean_array_right_scalar() -> None: assert gt == [False, True, None] _and = (left & right).to_pylist() - assert _and == [False, False, None] + assert _and == [False, False, False] _or = (left | right).to_pylist() assert _or == [False, True, None] @@ -347,10 +359,10 @@ def test_comparisons_boolean_array_right_scalar() -> None: assert gt == [None, None, None] _and = (left & right).to_pylist() - assert _and == [None, None, None] + assert _and == [False, None, None] _or = (left | right).to_pylist() - assert _or == [None, None, None] + assert _or == [None, True, None] _xor = (left ^ right).to_pylist() assert _xor == [None, None, None] @@ -386,7 +398,7 @@ def test_comparisons_boolean_array_left_scalar() -> None: assert _and == [False, True, None] _oright = (left | right).to_pylist() - assert _oright == [True, True, None] + assert _oright == [True, True, True] _xoright = (left ^ right).to_pylist() assert _xoright == [True, False, None] @@ -490,8 +502,26 @@ def test_logical_ops_with_non_boolean() -> None: def test_comparisons_dates() -> None: from datetime import date - left = Series.from_pylist([date(2023, 1, 1), date(2023, 1, 2), date(2023, 1, 3), None, date(2023, 1, 5), None]) - right = Series.from_pylist([date(2023, 1, 1), date(2023, 1, 3), date(2023, 1, 1), date(2023, 1, 5), None, None]) + left = Series.from_pylist( + [ + date(2023, 1, 1), + date(2023, 1, 2), + date(2023, 1, 3), + None, + date(2023, 1, 5), + None, + ] + ) + right = Series.from_pylist( + [ + date(2023, 1, 1), + date(2023, 1, 3), + date(2023, 1, 1), + date(2023, 1, 5), + None, + None, + ] + ) # eq, lt, gt, None, None, None @@ -702,12 +732,42 @@ def __ge__(self, other): @pytest.mark.parametrize( ["op", "reflected_op", "expected", "expected_self"], [ - (operator.eq, operator.eq, [False, True, False, None, None], [True, True, True, True, None]), - (operator.ne, operator.ne, [True, False, True, None, None], [False, False, False, False, None]), - (operator.lt, operator.gt, [False, False, True, None, None], [False, False, False, False, None]), - (operator.gt, operator.lt, [True, False, False, None, None], [False, False, False, False, None]), - (operator.le, operator.ge, [False, True, True, None, None], [True, True, True, True, None]), - (operator.ge, operator.le, [True, True, False, None, None], [True, True, True, True, None]), + ( + operator.eq, + operator.eq, + [False, True, False, None, None], + [True, True, True, True, None], + ), + ( + operator.ne, + operator.ne, + [True, False, True, None, None], + [False, False, False, False, None], + ), + ( + operator.lt, + operator.gt, + [False, False, True, None, None], + [False, False, False, False, None], + ), + ( + operator.gt, + operator.lt, + [True, False, False, None, None], + [False, False, False, False, None], + ), + ( + operator.le, + operator.ge, + [False, True, True, None, None], + [True, True, True, True, None], + ), + ( + operator.ge, + operator.le, + [True, True, False, None, None], + [True, True, True, True, None], + ), ], ) def test_comparisons_pyobjects(op, reflected_op, expected, expected_self) -> None: diff --git a/tests/table/test_between.py b/tests/table/test_between.py index 75fc074672..989a29b895 100644 --- a/tests/table/test_between.py +++ b/tests/table/test_between.py @@ -37,7 +37,7 @@ pytest.param([1.0, 2.0, 3.0, 4.0], 1.0, 2.0, [True, True, False, False], id="FloatFloatFloat"), pytest.param([1.0, 2.0, 3.0, 4.0], 1, 2, [True, True, False, False], id="FloatIntInt"), pytest.param([1.0, 2.0, 3.0, 4.0], 1, 2.0, [True, True, False, False], id="FloatIntFloat"), - pytest.param([1.0, 2.0, 3.0, 4.0], None, 1, [None, None, None, None], id="FloatNullInt"), + pytest.param([1.0, 2.0, 3.0, 4.0], None, 1, [None, False, False, False], id="FloatNullInt"), pytest.param([1.0, 2.0, 3.0, 4.0], 1, None, [None, None, None, None], id="FloatIntNull"), pytest.param([1.0, 2.0, 3.0, 4.0], None, None, [None, None, None, None], id="FloatNullNull"), pytest.param([None, None, None, None], None, None, [None, None, None, None], id="NullNullNull"), @@ -164,7 +164,7 @@ def test_between_bad_input() -> None: pytest.param([1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0], 2.0, [True, True, False, False], id="FloatFloatFloat"), pytest.param([1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0], 2, [True, True, False, False], id="FloatIntInt"), pytest.param([1.0, 2.0, 3.0, 4.0], [1, 2, 3, 4], 2.0, [True, True, False, False], id="FloatIntFloat"), - pytest.param([1.0, 2.0, 3.0, 4.0], [None, None, None, None], 1, [None, None, None, None], id="FloatNullInt"), + pytest.param([1.0, 2.0, 3.0, 4.0], [None, None, None, None], 1, [None, False, False, False], id="FloatNullInt"), pytest.param([1.0, 2.0, 3.0, 4.0], [1, 2, 3, 4], None, [None, None, None, None], id="FloatIntNull"), pytest.param( [1.0, 2.0, 3.0, 4.0], [None, None, None, None], None, [None, None, None, None], id="FloatNullNull"