diff --git a/src/daft-connect/src/translation/logical_plan.rs b/src/daft-connect/src/translation/logical_plan.rs index bf95827649..e96d006c6e 100644 --- a/src/daft-connect/src/translation/logical_plan.rs +++ b/src/daft-connect/src/translation/logical_plan.rs @@ -9,13 +9,14 @@ use tracing::warn; use crate::translation::logical_plan::{ aggregate::aggregate, local_relation::local_relation, project::project, range::range, - to_df::to_df, with_columns::with_columns, + set_op::set_op, to_df::to_df, with_columns::with_columns, }; mod aggregate; mod local_relation; mod project; mod range; +mod set_op; mod to_df; mod with_columns; @@ -57,6 +58,7 @@ pub fn to_logical_plan(relation: Relation) -> eyre::Result { RelType::LocalRelation(l) => { local_relation(l).wrap_err("Failed to apply local_relation to logical plan") } + RelType::SetOp(s) => set_op(*s).wrap_err("Failed to apply set_op to logical plan"), plan => bail!("Unsupported relation type: {plan:?}"), } } diff --git a/src/daft-connect/src/translation/logical_plan/set_op.rs b/src/daft-connect/src/translation/logical_plan/set_op.rs new file mode 100644 index 0000000000..077b5b6388 --- /dev/null +++ b/src/daft-connect/src/translation/logical_plan/set_op.rs @@ -0,0 +1,62 @@ +use eyre::{bail, Context}; +use spark_connect::set_operation::SetOpType; +use tracing::warn; + +use crate::translation::{to_logical_plan, Plan}; + +pub fn set_op(set_op: spark_connect::SetOperation) -> eyre::Result { + let spark_connect::SetOperation { + left_input, + right_input, + set_op_type, + is_all, + by_name, + allow_missing_columns, + } = set_op; + + let Some(left_input) = left_input else { + bail!("Left input is required"); + }; + + let Some(right_input) = right_input else { + bail!("Right input is required"); + }; + + let set_op = SetOpType::try_from(set_op_type) + .wrap_err_with(|| format!("Invalid set operation type: {set_op_type}"))?; + + if let Some(by_name) = by_name { + warn!("Ignoring by_name: {by_name}"); + } + + if let Some(allow_missing_columns) = allow_missing_columns { + warn!("Ignoring allow_missing_columns: {allow_missing_columns}"); + } + + let mut left = to_logical_plan(*left_input)?; + let right = to_logical_plan(*right_input)?; + + left.psets.extend(right.psets); + + let is_all = is_all.unwrap_or(false); + + let builder = match set_op { + SetOpType::Unspecified => { + bail!("Unspecified set operation is not supported"); + } + SetOpType::Intersect => left + .builder + .intersect(&right.builder, is_all) + .wrap_err("Failed to apply intersect to logical plan"), + SetOpType::Union => left + .builder + .union(&right.builder, is_all) + .wrap_err("Failed to apply union to logical plan"), + SetOpType::Except => { + bail!("Except set operation is not supported"); + } + }?; + + // we merged left and right psets + Ok(Plan::new(builder, left.psets)) +} diff --git a/tests/connect/test_intersection.py b/tests/connect/test_intersection.py new file mode 100644 index 0000000000..7944de5cae --- /dev/null +++ b/tests/connect/test_intersection.py @@ -0,0 +1,21 @@ +from __future__ import annotations + + +def test_intersection(spark_session): + # Create ranges using Spark - with overlap + range1 = spark_session.range(7) # Creates DataFrame with numbers 0 to 6 + range2 = spark_session.range(3, 10) # Creates DataFrame with numbers 3 to 9 + + # Intersect the two ranges + intersected = range1.intersect(range2) + + # Collect results + results = intersected.collect() + + # Verify the DataFrame has expected values + # Intersection should only include overlapping values once + assert len(results) == 4, "DataFrame should have 4 rows (overlapping values 3,4,5,6)" + + # Check that all expected values are present + values = [row.id for row in results] + assert sorted(values) == [3, 4, 5, 6], "Values should match expected overlapping sequence" diff --git a/tests/connect/test_union.py b/tests/connect/test_union.py new file mode 100644 index 0000000000..9ac235d9e5 --- /dev/null +++ b/tests/connect/test_union.py @@ -0,0 +1,21 @@ +from __future__ import annotations + + +def test_union(spark_session): + # Create ranges using Spark - with overlap + range1 = spark_session.range(7) # Creates DataFrame with numbers 0 to 6 + range2 = spark_session.range(3, 10) # Creates DataFrame with numbers 3 to 9 + + # Union the two ranges + unioned = range1.union(range2) + + # Collect results + results = unioned.collect() + + # Verify the DataFrame has expected values + # Union includes duplicates, so length should be sum of both ranges + assert len(results) == 14, "DataFrame should have 14 rows (7 + 7)" + + # Check that all expected values are present, including duplicates + values = [row.id for row in results] + assert sorted(values) == [0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 8, 9], "Values should match expected sequence with duplicates"