From cae1b42f8f1a7af5650a9490f2f0bb29ac5a3706 Mon Sep 17 00:00:00 2001 From: advancedxy <807537+advancedxy@users.noreply.github.com> Date: Wed, 11 Dec 2024 00:07:00 +0800 Subject: [PATCH] [FEAT]: Support intersect all and except distinct/all in DataFrame --- daft/daft/__init__.pyi | 1 + daft/dataframe/dataframe.py | 88 +++++ daft/logical/builder.py | 12 + src/daft-core/src/array/ops/list.rs | 46 ++- src/daft-core/src/series/ops/list.rs | 13 +- src/daft-functions/src/list/list_fill.rs | 63 ++++ src/daft-functions/src/list/mod.rs | 2 + src/daft-logical-plan/src/builder.rs | 15 +- src/daft-logical-plan/src/ops/mod.rs | 2 +- .../src/ops/set_operations.rs | 331 +++++++++++++++--- tests/conftest.py | 14 +- tests/dataframe/test_intersect.py | 47 --- tests/dataframe/test_set_ops.py | 112 ++++++ 13 files changed, 643 insertions(+), 103 deletions(-) create mode 100644 src/daft-functions/src/list/list_fill.rs delete mode 100644 tests/dataframe/test_intersect.py create mode 100644 tests/dataframe/test_set_ops.py diff --git a/daft/daft/__init__.pyi b/daft/daft/__init__.pyi index 6860f72491..4e007c7034 100644 --- a/daft/daft/__init__.pyi +++ b/daft/daft/__init__.pyi @@ -1634,6 +1634,7 @@ class LogicalPlanBuilder: ) -> LogicalPlanBuilder: ... def concat(self, other: LogicalPlanBuilder) -> LogicalPlanBuilder: ... def intersect(self, other: LogicalPlanBuilder, is_all: bool) -> LogicalPlanBuilder: ... + def except_(self, other: LogicalPlanBuilder, is_all: bool) -> LogicalPlanBuilder: ... def add_monotonically_increasing_id(self, column_name: str | None) -> LogicalPlanBuilder: ... def table_write( self, diff --git a/daft/dataframe/dataframe.py b/daft/dataframe/dataframe.py index ca23c73cbb..64d72d2ba5 100644 --- a/daft/dataframe/dataframe.py +++ b/daft/dataframe/dataframe.py @@ -2542,6 +2542,94 @@ def intersect(self, other: "DataFrame") -> "DataFrame": builder = self._builder.intersect(other._builder) return DataFrame(builder) + @DataframePublicAPI + def intersect_all(self, other: "DataFrame") -> "DataFrame": + """Returns the intersection of two DataFrames, including duplicates. + + Example: + >>> import daft + >>> df1 = daft.from_pydict({"a": [1, 2, 2], "b": [4, 6, 6]}) + >>> df2 = daft.from_pydict({"a": [1, 1, 2, 2], "b": [4, 4, 6, 6]}) + >>> df1.intersect_all(df2).collect() + ╭───────┬───────╮ + │ a ┆ b │ + │ --- ┆ --- │ + │ Int64 ┆ Int64 │ + ╞═══════╪═══════╡ + │ 1 ┆ 4 │ + ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤ + │ 2 ┆ 6 │ + ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤ + │ 2 ┆ 6 │ + ╰───────┴───────╯ + + (Showing first 3 of 3 rows) + + Args: + other (DataFrame): DataFrame to intersect with + + Returns: + DataFrame: DataFrame with the intersection of the two DataFrames, including duplicates + """ + builder = self._builder.intersect_all(other._builder) + return DataFrame(builder) + + @DataframePublicAPI + def except_distinct(self, other: "DataFrame") -> "DataFrame": + """Returns the set difference of two DataFrames. + + Example: + >>> import daft + >>> df1 = daft.from_pydict({"a": [1, 2, 3], "b": [4, 5, 6]}) + >>> df2 = daft.from_pydict({"a": [1, 2, 3], "b": [4, 8, 6]}) + >>> df1.except_distinct(df2).collect() + ╭───────┬───────╮ + │ a ┆ b │ + │ --- ┆ --- │ + │ Int64 ┆ Int64 │ + ╞═══════╪═══════╡ + │ 2 ┆ 5 │ + ╰───────┴───────╯ + + (Showing first 1 of 1 rows) + + Args: + other (DataFrame): DataFrame to except with + + Returns: + DataFrame: DataFrame with the set difference of the two DataFrames + """ + builder = self._builder.except_distinct(other._builder) + return DataFrame(builder) + + @DataframePublicAPI + def except_all(self, other: "DataFrame") -> "DataFrame": + """Returns the set difference of two DataFrames, considering duplicates. + + Example: + >>> import daft + >>> df1 = daft.from_pydict({"a": [1, 1, 2, 2], "b": [4, 4, 6, 6]}) + >>> df2 = daft.from_pydict({"a": [1, 2, 2], "b": [4, 6, 6]}) + >>> df1.except_all(df2).collect() + ╭───────┬───────╮ + │ a ┆ b │ + │ --- ┆ --- │ + │ Int64 ┆ Int64 │ + ╞═══════╪═══════╡ + │ 1 ┆ 4 │ + ╰───────┴───────╯ + + (Showing first 1 of 1 rows) + + Args: + other (DataFrame): DataFrame to except with + + Returns: + DataFrame: DataFrame with the set difference of the two DataFrames, considering duplicates + """ + builder = self._builder.except_all(other._builder) + return DataFrame(builder) + def _materialize_results(self) -> None: """Materializes the results of for this DataFrame and hold a pointer to the results.""" context = get_context() diff --git a/daft/logical/builder.py b/daft/logical/builder.py index b7316a0a80..c205ede871 100644 --- a/daft/logical/builder.py +++ b/daft/logical/builder.py @@ -279,6 +279,18 @@ def intersect(self, other: LogicalPlanBuilder) -> LogicalPlanBuilder: builder = self._builder.intersect(other._builder, False) return LogicalPlanBuilder(builder) + def intersect_all(self, other: LogicalPlanBuilder) -> LogicalPlanBuilder: + builder = self._builder.intersect(other._builder, True) + return LogicalPlanBuilder(builder) + + def except_distinct(self, other: LogicalPlanBuilder) -> LogicalPlanBuilder: + builder = self._builder.except_(other._builder, False) + return LogicalPlanBuilder(builder) + + def except_all(self, other: LogicalPlanBuilder) -> LogicalPlanBuilder: + builder = self._builder.except_(other._builder, True) + return LogicalPlanBuilder(builder) + def add_monotonically_increasing_id(self, column_name: str | None) -> LogicalPlanBuilder: builder = self._builder.add_monotonically_increasing_id(column_name) return LogicalPlanBuilder(builder) diff --git a/src/daft-core/src/array/ops/list.rs b/src/daft-core/src/array/ops/list.rs index 2e60efa550..4e3463ac24 100644 --- a/src/daft-core/src/array/ops/list.rs +++ b/src/daft-core/src/array/ops/list.rs @@ -1,6 +1,6 @@ use std::{iter::repeat, sync::Arc}; -use arrow2::offset::OffsetsBuffer; +use arrow2::offset::{Offsets, OffsetsBuffer}; use common_error::DaftResult; use indexmap::{ map::{raw_entry_v1::RawEntryMut, RawEntryApiV1}, @@ -255,6 +255,31 @@ fn list_sort_helper_fixed_size( .collect() } +fn general_list_fill_helper(element: &Series, num_array: &Int64Array) -> DaftResult> { + let num_iter = create_iter(num_array, element.len()); + let mut result = vec![]; + let element_data = element.as_physical()?; + for (row_index, num) in num_iter.enumerate() { + let list_arr = if element.is_valid(row_index) { + let mut list_growable = make_growable( + element.name(), + element.data_type(), + vec![&element_data], + false, + num as usize, + ); + for _ in 0..num { + list_growable.extend(0, row_index, 1); + } + list_growable.build()? + } else { + Series::full_null(element.name(), element.data_type(), num as usize) + }; + result.push(list_arr); + } + Ok(result) +} + impl ListArray { pub fn value_counts(&self) -> DaftResult { struct IndexRef { @@ -625,6 +650,25 @@ impl ListArray { self.validity().cloned(), )) } + + pub fn list_fill(elem: &Series, num_array: &Int64Array) -> DaftResult { + let generated = general_list_fill_helper(elem, num_array)?; + let generated_refs: Vec<&Series> = generated.iter().collect(); + let lengths = generated.iter().map(|arr| arr.len()); + let offsets = Offsets::try_from_lengths(lengths)?; + let flat_child = if generated_refs.is_empty() { + // when there's no output, we should create an empty series + Series::empty(elem.name(), elem.data_type()) + } else { + Series::concat(&generated_refs)? + }; + Ok(Self::new( + elem.field().to_list_field()?, + flat_child, + offsets.into(), + None, + )) + } } impl FixedSizeListArray { diff --git a/src/daft-core/src/series/ops/list.rs b/src/daft-core/src/series/ops/list.rs index c066fc463e..30b21e15ef 100644 --- a/src/daft-core/src/series/ops/list.rs +++ b/src/daft-core/src/series/ops/list.rs @@ -2,8 +2,9 @@ use common_error::{DaftError, DaftResult}; use daft_schema::field::Field; use crate::{ + array::ListArray, datatypes::{DataType, UInt64Array, Utf8Array}, - prelude::CountMode, + prelude::{CountMode, Int64Array}, series::{IntoSeries, Series}, }; @@ -217,4 +218,14 @@ impl Series { ))), } } + + /// Given a series of data T, repeat each data T with num times to create a list, returns + /// a series of repeated list. + /// # Example + /// ```txt + /// repeat([1, 2, 3], [2, 0, 1]) --> [[1, 1], [], [3]] + /// ``` + pub fn list_fill(&self, num: &Int64Array) -> DaftResult { + ListArray::list_fill(self, num).map(|arr| arr.into_series()) + } } diff --git a/src/daft-functions/src/list/list_fill.rs b/src/daft-functions/src/list/list_fill.rs new file mode 100644 index 0000000000..34e3118c56 --- /dev/null +++ b/src/daft-functions/src/list/list_fill.rs @@ -0,0 +1,63 @@ +use common_error::{DaftError, DaftResult}; +use daft_core::{ + datatypes::{DataType, Field}, + prelude::{Schema, Series}, +}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct ListFill {} + +#[typetag::serde] +impl ScalarUDF for ListFill { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &'static str { + "fill" + } + + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { + match inputs { + [n, elem] => { + let num_field = n.to_field(schema)?; + let elem_field = elem.to_field(schema)?; + if !num_field.dtype.is_integer() { + return Err(DaftError::TypeError(format!( + "Expected num field to be of numeric type, received: {}", + num_field.dtype + ))); + } + elem_field.to_list_field() + } + _ => Err(DaftError::SchemaMismatch(format!( + "Expected 2 input args, got {}", + inputs.len() + ))), + } + } + + fn evaluate(&self, inputs: &[Series]) -> DaftResult { + match inputs { + [num, elem] => { + let num = num.cast(&DataType::Int64)?; + let num_array = num.i64()?; + elem.list_fill(num_array) + } + _ => Err(DaftError::ValueError(format!( + "Expected 2 input args, got {}", + inputs.len() + ))), + } + } +} + +#[must_use] +pub fn list_fill(n: ExprRef, elem: ExprRef) -> ExprRef { + ScalarFunction::new(ListFill {}, vec![n, elem]).into() +} diff --git a/src/daft-functions/src/list/mod.rs b/src/daft-functions/src/list/mod.rs index 504176207d..e57c32ce7d 100644 --- a/src/daft-functions/src/list/mod.rs +++ b/src/daft-functions/src/list/mod.rs @@ -3,6 +3,7 @@ mod count; mod explode; mod get; mod join; +mod list_fill; mod max; mod mean; mod min; @@ -17,6 +18,7 @@ pub use count::{list_count as count, ListCount}; pub use explode::{explode, Explode}; pub use get::{list_get as get, ListGet}; pub use join::{list_join as join, ListJoin}; +pub use list_fill::list_fill; pub use max::{list_max as max, ListMax}; pub use mean::{list_mean as mean, ListMean}; pub use min::{list_min as min, ListMin}; diff --git a/src/daft-logical-plan/src/builder.rs b/src/daft-logical-plan/src/builder.rs index 38921e71fe..5b61281f58 100644 --- a/src/daft-logical-plan/src/builder.rs +++ b/src/daft-logical-plan/src/builder.rs @@ -482,9 +482,17 @@ impl LogicalPlanBuilder { pub fn intersect(&self, other: &Self, is_all: bool) -> DaftResult { let logical_plan: LogicalPlan = ops::Intersect::try_new(self.plan.clone(), other.plan.clone(), is_all)? - .to_optimized_join()?; + .to_logical_plan()?; Ok(self.with_new_plan(logical_plan)) } + + pub fn except(&self, other: &Self, is_all: bool) -> DaftResult { + let logical_plan: LogicalPlan = + ops::Except::try_new(self.plan.clone(), other.plan.clone(), is_all)? + .to_logical_plan()?; + Ok(self.with_new_plan(logical_plan)) + } + pub fn union(&self, other: &Self, is_all: bool) -> DaftResult { let logical_plan: LogicalPlan = ops::Union::try_new(self.plan.clone(), other.plan.clone(), is_all)? @@ -861,6 +869,11 @@ impl PyLogicalPlanBuilder { Ok(self.builder.intersect(&other.builder, is_all)?.into()) } + #[pyo3(name = "except_")] + pub fn except(&self, other: &Self, is_all: bool) -> DaftResult { + Ok(self.builder.except(&other.builder, is_all)?.into()) + } + pub fn add_monotonically_increasing_id(&self, column_name: Option<&str>) -> PyResult { Ok(self .builder diff --git a/src/daft-logical-plan/src/ops/mod.rs b/src/daft-logical-plan/src/ops/mod.rs index e70c5c98d8..c042c04e7a 100644 --- a/src/daft-logical-plan/src/ops/mod.rs +++ b/src/daft-logical-plan/src/ops/mod.rs @@ -30,7 +30,7 @@ pub use pivot::Pivot; pub use project::Project; pub use repartition::Repartition; pub use sample::Sample; -pub use set_operations::{Intersect, Union}; +pub use set_operations::{Except, Intersect, Union}; pub use sink::Sink; pub use sort::Sort; pub use source::Source; diff --git a/src/daft-logical-plan/src/ops/set_operations.rs b/src/daft-logical-plan/src/ops/set_operations.rs index 65ceb807b8..06f5ace631 100644 --- a/src/daft-logical-plan/src/ops/set_operations.rs +++ b/src/daft-logical-plan/src/ops/set_operations.rs @@ -1,14 +1,94 @@ use std::sync::Arc; use common_error::DaftError; -use daft_core::{join::JoinType, utils::supertype::get_supertype}; -use daft_dsl::col; -use daft_schema::field::Field; +use daft_core::{count_mode::CountMode, join::JoinType, utils::supertype::get_supertype}; +use daft_dsl::{col, lit, null_lit, ExprRef}; +use daft_functions::list::{explode, list_fill}; +use daft_schema::{dtype::DataType, field::Field, schema::SchemaRef}; use snafu::ResultExt; -use super::{Concat, Distinct, Project}; +use super::{Aggregate, Concat, Distinct, Filter, Project}; use crate::{logical_plan, logical_plan::CreationSnafu, LogicalPlan}; +fn build_union_all_internal( + lhs: Arc, + rhs: Arc, + left_v_cols: Vec, + right_v_cols: Vec, +) -> logical_plan::Result { + let left_with_v_col = Project::try_new(lhs, left_v_cols)?; + let right_with_v_col = Project::try_new(rhs, right_v_cols)?; + Union::try_new(left_with_v_col.into(), right_with_v_col.into(), true)?.to_logical_plan() +} + +fn intersect_or_except_plan( + lhs: Arc, + rhs: Arc, + join_type: JoinType, +) -> logical_plan::Result { + let left_on = lhs + .schema() + .fields + .keys() + .map(|k| col(k.clone())) + .collect::>(); + let left_on_size = left_on.len(); + let right_on = rhs + .schema() + .fields + .keys() + .map(|k| col(k.clone())) + .collect::>(); + let join = logical_plan::Join::try_new( + lhs, + rhs, + left_on, + right_on, + Some(vec![true; left_on_size]), + join_type, + None, + None, + None, + false, + ); + join.map(|j| Distinct::new(j.into()).into()) +} + +fn check_structurally_equal( + lhs: SchemaRef, + rhs: SchemaRef, + operation: &str, +) -> logical_plan::Result<()> { + if lhs.len() != rhs.len() { + return Err(DaftError::SchemaMismatch(format!( + "Both schemas must have the same num of fields to {}, \ + but got[lhs: {} v.s rhs: {}], lhs schema: {}, rhs schema: {}", + operation, + lhs.len(), + rhs.len(), + lhs, + rhs + ))) + .context(CreationSnafu); + } + // lhs and rhs should have the same type for each field + // TODO: Support nested types recursively + if lhs + .fields + .values() + .zip(rhs.fields.values()) + .any(|(l, r)| l.dtype != r.dtype) + { + return Err(DaftError::SchemaMismatch(format!( + "Both schemas should have the same type for each field to {}, \ + but got lhs schema: {}, rhs schema: {}", + operation, lhs, rhs + ))) + .context(CreationSnafu); + } + Ok(()) +} + #[derive(Clone, Debug, PartialEq, Eq, Hash)] pub struct Intersect { // Upstream nodes. @@ -25,36 +105,12 @@ impl Intersect { ) -> logical_plan::Result { let lhs_schema = lhs.schema(); let rhs_schema = rhs.schema(); - if lhs_schema.len() != rhs_schema.len() { - return Err(DaftError::SchemaMismatch(format!( - "Both plans must have the same num of fields to intersect, \ - but got[lhs: {} v.s rhs: {}], lhs schema: {}, rhs schema: {}", - lhs_schema.len(), - rhs_schema.len(), - lhs_schema, - rhs_schema - ))) - .context(CreationSnafu); - } - // lhs and rhs should have the same type for each field to intersect - if lhs_schema - .fields - .values() - .zip(rhs_schema.fields.values()) - .any(|(l, r)| l.dtype != r.dtype) - { - return Err(DaftError::SchemaMismatch(format!( - "Both plans' schemas should have the same type for each field to intersect, \ - but got lhs schema: {}, rhs schema: {}", - lhs_schema, rhs_schema - ))) - .context(CreationSnafu); - } + check_structurally_equal(lhs_schema, rhs_schema, "intersect")?; Ok(Self { lhs, rhs, is_all }) } - /// intersect distinct could be represented as a semi join + distinct - /// the following intersect operator: + /// intersect operations could be represented by other logical plans + /// for intersect distinct, it could be represented as a semi join + distinct /// ```sql /// select a1, a2 from t1 intersect select b1, b2 from t2 /// ``` @@ -63,40 +119,104 @@ impl Intersect { /// select distinct a1, a2 from t1 left semi join t2 /// on t1.a1 <> t2.b1 and t1.a2 <> t2.b2 /// ``` + /// + /// for intersect all, it could be represented as group by + explode + /// ```sql + /// select a1 from t1 intersect all select a1 from t2 + /// ``` + /// is the same as: + /// ```sql + /// select a1 + /// from ( + /// select explode(list_fill(min_count, a1)) as a1 + /// from ( + /// select a1, if_else(v_l_cnt > v_r_cnt, v_r_cnt, v_l_cnt) as min_count + /// from ( + /// select count(v_col_l) as v_l_cnt, count(v_col_r) as v_r_cnt, a1 + /// from ( + /// select true as v_col_l, null as v_col_r, a1 from t1 + /// union all + /// select null as v_col_l, true as v_col_r, a1 from t2 + /// ) as union_all + /// group by a1 + /// ) + /// where v_l_cnt >= 1 and v_r_cnt >= 1 + /// ) + /// ) + /// ``` /// TODO: Move this logical to logical optimization rules - pub(crate) fn to_optimized_join(&self) -> logical_plan::Result { + pub(crate) fn to_logical_plan(&self) -> logical_plan::Result { if self.is_all { - Err(logical_plan::Error::CreationError { - source: DaftError::InternalError("intersect all is not supported yet".to_string()), - }) - } else { - let left_on = self + let left_cols = self .lhs .schema() .fields .keys() .map(|k| col(k.clone())) - .collect(); - let right_on = self + .collect::>(); + // project the right cols to have the same name as the left cols + let right_cols = self .rhs .schema() .fields .keys() .map(|k| col(k.clone())) - .collect(); - let join = logical_plan::Join::try_new( + .zip(left_cols.iter()) + .map(|(r, l)| r.alias(l.name())) + .collect::>(); + let virtual_col_l = "__v_col_l"; + let virtual_col_r = "__v_col_r"; + let left_v_cols = vec![ + lit(true).alias(virtual_col_l), + null_lit().cast(&DataType::Boolean).alias(virtual_col_r), + ]; + let right_v_cols = vec![ + null_lit().cast(&DataType::Boolean).alias(virtual_col_l), + lit(true).alias(virtual_col_r), + ]; + let left_v_cols = [left_v_cols, left_cols.clone()].concat(); + let right_v_cols = [right_v_cols, right_cols].concat(); + let union_all = build_union_all_internal( self.lhs.clone(), self.rhs.clone(), - left_on, - right_on, - Some(vec![true; self.lhs.schema().fields.len()]), - JoinType::Semi, - None, - None, - None, - false, - ); - join.map(|j| logical_plan::Distinct::new(j.into()).into()) + left_v_cols, + right_v_cols, + )?; + let one_lit = lit(1); + let left_v_cnt = col(virtual_col_l) + .count(CountMode::Valid) + .alias("__v_l_cnt"); + let right_v_cnt = col(virtual_col_r) + .count(CountMode::Valid) + .alias("__v_r_cnt"); + let count_name = "__min_count"; + let min_count = col("__v_l_cnt") + .gt(col("__v_r_cnt")) + .if_else(col("__v_r_cnt"), col("__v_l_cnt")) + .alias(count_name); + let aggregate_plan = Aggregate::try_new( + union_all.into(), + vec![left_v_cnt, right_v_cnt], + left_cols.clone(), + )?; + let filter_plan = Filter::try_new( + aggregate_plan.into(), + col("__v_l_cnt") + .gt_eq(one_lit.clone()) + .and(col("__v_r_cnt").gt_eq(one_lit)), + )?; + let min_count_plan = Project::try_new( + filter_plan.into(), + [vec![min_count], left_cols.clone()].concat(), + )?; + let fill_and_explodes = left_cols + .iter() + .map(|column| explode(list_fill(col(count_name), column.clone()))) + .collect::>(); + let project_plan = Project::try_new(min_count_plan.into(), fill_and_explodes)?; + Ok(project_plan.into()) + } else { + intersect_or_except_plan(self.lhs.clone(), self.rhs.clone(), JoinType::Semi) } } @@ -200,3 +320,112 @@ impl Union { res } } + +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub struct Except { + // Upstream nodes. + pub lhs: Arc, + pub rhs: Arc, + pub is_all: bool, +} +impl Except { + pub(crate) fn try_new( + lhs: Arc, + rhs: Arc, + is_all: bool, + ) -> logical_plan::Result { + let lhs_schema = lhs.schema(); + let rhs_schema = rhs.schema(); + check_structurally_equal(lhs_schema, rhs_schema, "except")?; + Ok(Self { lhs, rhs, is_all }) + } + + /// except could be represented by other logical plans + /// for except distinct, it could be represented as a anti join + /// ```sql + /// select a1, a2 from t1 except select b1, b2 from t2 + /// ``` + /// is the same as: + /// ```sql + /// select distinct a1, a2 from t1 left anti join t2 + /// on t1.a1 <> t2.b1 and t1.a2 <> t2.b2 + /// ``` + /// + /// for except all, it could be represented as group by + explode + /// ```sql + /// select a1 from t1 except all select a1 from t2 + /// ``` + /// is the same as: + /// ```sql + /// select a1 + /// from ( + /// select explode(list_fill(sum, a1)) as a1 + /// from ( + /// select sum(v_col) as sum, a1 + /// from ( + /// select 1 as v_col, a1 from t1 + /// union all + /// select -1 as v_col, a1 from t2 + /// ) union_all + /// group by a1 + /// ) + /// where sum > 0 + /// ) + /// ``` + /// TODO: Move this logical to logical optimization rules + pub(crate) fn to_logical_plan(&self) -> logical_plan::Result { + if self.is_all { + let left_cols = self + .lhs + .schema() + .fields + .keys() + .map(|k| col(k.clone())) + .collect::>(); + // project the right cols to have the same name as the left cols + let right_cols = self + .rhs + .schema() + .fields + .keys() + .map(|k| col(k.clone())) + .zip(left_cols.iter()) + .map(|(r, l)| r.alias(l.name())) + .collect::>(); + let virtual_col = "__v_col"; + let left_v_cols = vec![lit(1).alias(virtual_col)]; + let right_v_cols = vec![lit(-1).alias(virtual_col)]; + let left_v_cols = [left_v_cols, left_cols.clone()].concat(); + let right_v_cols = [right_v_cols, right_cols].concat(); + let union_all = build_union_all_internal( + self.lhs.clone(), + self.rhs.clone(), + left_v_cols, + right_v_cols, + )?; + let sum_name = "__sum"; + let sum = col(virtual_col).sum().alias(sum_name); + let aggregate_plan = + Aggregate::try_new(union_all.into(), vec![sum], left_cols.clone())?; + let filter_plan = Filter::try_new(aggregate_plan.into(), col(sum_name).gt(lit(0)))?; + let fill_and_explodes = left_cols + .iter() + .map(|column| explode(list_fill(col(sum_name), column.clone()))) + .collect::>(); + let project_plan = Project::try_new(filter_plan.into(), fill_and_explodes)?; + Ok(project_plan.into()) + } else { + intersect_or_except_plan(self.lhs.clone(), self.rhs.clone(), JoinType::Anti) + } + } + + pub fn multiline_display(&self) -> Vec { + let mut res = vec![]; + if self.is_all { + res.push("Except All:".to_string()); + } else { + res.push("Except:".to_string()); + } + res + } +} diff --git a/tests/conftest.py b/tests/conftest.py index 74e8e1e771..ad5d761f28 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,7 @@ from __future__ import annotations import uuid -from typing import Literal +from typing import Any, Literal import pandas as pd import pyarrow as pa @@ -184,6 +184,18 @@ def assert_df_equals( raise +def check_answer(df: daft.DataFrame, expected_answer: dict[str, Any], is_sorted: bool = False): + daft_df = df.to_pandas() + expected_df = daft.from_pydict(expected_answer).to_pandas() + # when this is an empty result, no need to check data types. + check_dtype = not expected_df.empty + if is_sorted: + assert_df_equals(daft_df, expected_df, assert_ordering=True, check_dtype=check_dtype) + else: + sort_keys = df.column_names + assert_df_equals(daft_df, expected_df, sort_key=sort_keys, assert_ordering=False, check_dtype=check_dtype) + + @pytest.fixture( scope="function", params=[1, None] if get_tests_daft_runner_name() == "native" else [None], diff --git a/tests/dataframe/test_intersect.py b/tests/dataframe/test_intersect.py deleted file mode 100644 index 59e9d9a79e..0000000000 --- a/tests/dataframe/test_intersect.py +++ /dev/null @@ -1,47 +0,0 @@ -from __future__ import annotations - -import daft -from daft import col - - -def test_simple_intersect(make_df): - df1 = make_df({"foo": [1, 2, 3]}) - df2 = make_df({"bar": [2, 3, 4]}) - result = df1.intersect(df2).sort(by="foo") - assert result.to_pydict() == {"foo": [2, 3]} - - -def test_intersect_with_duplicate(make_df): - df1 = make_df({"foo": [1, 2, 2, 3]}) - df2 = make_df({"bar": [2, 3, 3]}) - result = df1.intersect(df2).sort(by="foo") - assert result.to_pydict() == {"foo": [2, 3]} - - -def test_self_intersect(make_df): - df = make_df({"foo": [1, 2, 3]}) - result = df.intersect(df).sort(by="foo") - assert result.to_pydict() == {"foo": [1, 2, 3]} - - -def test_intersect_empty(make_df): - df1 = make_df({"foo": [1, 2, 3]}) - df2 = make_df({"bar": []}).select(col("bar").cast(daft.DataType.int64())) - result = df1.intersect(df2) - assert result.to_pydict() == {"foo": []} - - -def test_intersect_with_nulls(make_df): - df1 = make_df({"foo": [1, 2, None]}) - df1_without_mull = make_df({"foo": [1, 2]}) - df2 = make_df({"bar": [2, 3, None]}) - df2_without_null = make_df({"bar": [2, 3]}) - - result = df1.intersect(df2).sort(by="foo") - assert result.to_pydict() == {"foo": [2, None]} - - result = df1_without_mull.intersect(df2) - assert result.to_pydict() == {"foo": [2]} - - result = df1.intersect(df2_without_null) - assert result.to_pydict() == {"foo": [2]} diff --git a/tests/dataframe/test_set_ops.py b/tests/dataframe/test_set_ops.py new file mode 100644 index 0000000000..6509b0f67a --- /dev/null +++ b/tests/dataframe/test_set_ops.py @@ -0,0 +1,112 @@ +from __future__ import annotations + +from typing import Any + +import pytest + +import daft +from daft import DataFrame, col +from tests.conftest import check_answer + + +def helper(make_df, op: str, left: dict[str, Any], right: dict[str, Any], expected: dict[str, Any]): + df1 = make_df(left) + df2 = make_df(right) + df_helper(op, df1, df2, expected) + + +def df_helper(op: str, df1: DataFrame, df2: DataFrame, expected: dict[str, Any]): + if op == "intersect": + result = df1.intersect(df2) + elif op == "except_distinct": + result = df1.except_distinct(df2) + elif op == "intersect_all": + result = df1.intersect_all(df2) + else: + result = df1.except_all(df2) + check_answer(result, expected) + + +@pytest.mark.parametrize( + "op, left, right, expected", + [ + ("intersect", {"foo": [1, 2, 3]}, {"bar": [2, 3, 4]}, {"foo": [2, 3]}), + ("intersect_all", {"foo": [1, 2, 2]}, {"bar": [2, 2, 4]}, {"foo": [2, 2]}), + ("except_distinct", {"foo": [1, 2, 3]}, {"bar": [2, 3, 4]}, {"foo": [1]}), + ("except_all", {"foo": [1, 2, 2]}, {"bar": [2, 4]}, {"foo": [1, 2]}), + ], +) +def test_simple_intersect_or_except(make_df, op, left, right, expected): + helper(make_df, op, left, right, expected) + + +@pytest.mark.parametrize( + "op, left, right, expected", + [ + ("intersect", {"foo": [1, 2, 2, 3]}, {"bar": [2, 3, 3]}, {"foo": [2, 3]}), + ("intersect_all", {"foo": [1, 2, 2, 3]}, {"bar": [2, 3, 3]}, {"foo": [2, 3]}), + ("except_distinct", {"foo": [1, 2, 2, 3]}, {"bar": [2, 3, 3]}, {"foo": [1]}), + ("except_all", {"foo": [1, 2, 2, 3]}, {"bar": [2, 3, 3]}, {"foo": [1, 2]}), + ], +) +def test_with_duplicate(make_df, op, left, right, expected): + helper(make_df, op, left, right, expected) + + +@pytest.mark.parametrize( + "op, df, expected", + [ + ("intersect", {"foo": [1, 2, 3]}, {"foo": [1, 2, 3]}), + ("intersect_all", {"foo": [1, 2, 3]}, {"foo": [1, 2, 3]}), + ("except_distinct", {"foo": [1, 2, 3]}, {"foo": []}), + ("except_all", {"foo": [1, 2, 2]}, {"foo": []}), + ], +) +def test_with_self(make_df, op, df, expected): + df = make_df(df) + df_helper(op, df, df, expected) + + +@pytest.mark.parametrize( + "op, left, expected", + [ + ("intersect", {"foo": [1, 2, 3]}, {"foo": []}), + ("intersect_all", {"foo": [1, 2, 3]}, {"foo": []}), + ("except_distinct", {"foo": [1, 2, 3]}, {"foo": [1, 2, 3]}), + ("except_all", {"foo": [1, 2, 2]}, {"foo": [1, 2, 2]}), + ], +) +def test_with_empty(make_df, op, left, expected): + df1 = make_df(left) + df2 = make_df({"bar": []}).select(col("bar").cast(daft.DataType.int64())) + df_helper(op, df1, df2, expected) + + +@pytest.mark.parametrize( + "op, left, right, expected", + [ + ("intersect", {"foo": [1, 2, None]}, {"foo": [2, 3, None]}, {"foo": [2, None]}), + ("intersect_all", {"foo": [1, 2, None]}, {"foo": [2, 3, None]}, {"foo": [2, None]}), + ("intersect", {"foo": [1, 2]}, {"foo": [2, 3, None]}, {"foo": [2]}), + ("intersect_all", {"foo": [1, 2]}, {"foo": [2, 3, None]}, {"foo": [2]}), + ("intersect", {"foo": [1, 2, None]}, {"foo": [2, 3]}, {"foo": [2]}), + ("intersect_all", {"foo": [1, 2, None]}, {"foo": [2, 3]}, {"foo": [2]}), + ], +) +def test_intersect_with_nulls(make_df, op, left, right, expected): + helper(make_df, op, left, right, expected) + + +@pytest.mark.parametrize( + "op, left, right, expected", + [ + ("except_distinct", {"foo": [1, 2, None]}, {"foo": [2, 3, None]}, {"foo": [1]}), + ("except_all", {"foo": [1, 2, None]}, {"foo": [2, 3, None]}, {"foo": [1]}), + ("except_distinct", {"foo": [1, 2]}, {"foo": [2, 3, None]}, {"foo": [1]}), + ("except_all", {"foo": [1, 2]}, {"foo": [2, 3, None]}, {"foo": [1]}), + ("except_distinct", {"foo": [1, 2, None]}, {"foo": [2, 3]}, {"foo": [1, None]}), + ("except_all", {"foo": [1, 2, None]}, {"foo": [2, 3]}, {"foo": [1, None]}), + ], +) +def test_except_with_nulls(make_df, op, left, right, expected): + helper(make_df, op, left, right, expected)