From 6508d07a9e4888e8d0277c8ea3b6cf10d45a16c5 Mon Sep 17 00:00:00 2001 From: Andrew Gazelka Date: Wed, 20 Nov 2024 16:11:14 -0800 Subject: [PATCH] [FEAT] connect: add `df.{intersection,union}` --- .../src/translation/logical_plan.rs | 7 +- .../src/translation/logical_plan/set_op.rs | 65 +++++++++++++++++++ tests/connect/test_intersection.py | 21 ++++++ tests/connect/test_union.py | 36 ++++++++++ 4 files changed, 128 insertions(+), 1 deletion(-) create mode 100644 src/daft-connect/src/translation/logical_plan/set_op.rs create mode 100644 tests/connect/test_intersection.py create mode 100644 tests/connect/test_union.py diff --git a/src/daft-connect/src/translation/logical_plan.rs b/src/daft-connect/src/translation/logical_plan.rs index b6097d17ad..152a6f9510 100644 --- a/src/daft-connect/src/translation/logical_plan.rs +++ b/src/daft-connect/src/translation/logical_plan.rs @@ -6,7 +6,8 @@ use tracing::warn; use crate::translation::logical_plan::{ aggregate::aggregate, drop::drop, filter::filter, local_relation::local_relation, - project::project, range::range, read::read, to_df::to_df, with_columns::with_columns, + project::project, range::range, read::read, set_op::set_op, to_df::to_df, + with_columns::with_columns, }; mod aggregate; @@ -16,6 +17,7 @@ mod local_relation; mod project; mod range; mod read; +mod set_op; mod to_df; mod with_columns; @@ -82,6 +84,9 @@ pub async fn to_logical_plan(relation: Relation) -> eyre::Result { RelType::Drop(d) => drop(*d) .await .wrap_err("Failed to apply drop to logical plan"), + RelType::SetOp(s) => set_op(*s) + .await + .wrap_err("Failed to apply set_op to logical plan"), plan => bail!("Unsupported relation type: {plan:?}"), } } diff --git a/src/daft-connect/src/translation/logical_plan/set_op.rs b/src/daft-connect/src/translation/logical_plan/set_op.rs new file mode 100644 index 0000000000..066e4e44c5 --- /dev/null +++ b/src/daft-connect/src/translation/logical_plan/set_op.rs @@ -0,0 +1,65 @@ +use eyre::{bail, Context}; +use spark_connect::set_operation::SetOpType; +use tracing::warn; + +use crate::translation::{to_logical_plan, Plan}; + +pub async fn set_op(set_op: spark_connect::SetOperation) -> eyre::Result { + let spark_connect::SetOperation { + left_input, + right_input, + set_op_type, + is_all, + by_name, + allow_missing_columns, + } = set_op; + + let Some(left_input) = left_input else { + bail!("Left input is required"); + }; + + let Some(right_input) = right_input else { + bail!("Right input is required"); + }; + + let set_op = SetOpType::try_from(set_op_type) + .wrap_err_with(|| format!("Invalid set operation type: {set_op_type}"))?; + + if let Some(by_name) = by_name { + warn!("Ignoring by_name: {by_name}"); + } + + if let Some(allow_missing_columns) = allow_missing_columns { + warn!("Ignoring allow_missing_columns: {allow_missing_columns}"); + } + + let mut left = Box::pin(to_logical_plan(*left_input)).await?; + let right = Box::pin(to_logical_plan(*right_input)).await?; + + left.psets.partitions.extend(right.psets.partitions); + + let is_all = is_all.unwrap_or(false); + + let builder = match set_op { + SetOpType::Unspecified => { + bail!("Unspecified set operation is not supported"); + } + SetOpType::Intersect => left + .builder + .intersect(&right.builder, is_all) + .wrap_err("Failed to apply intersect to logical plan"), + SetOpType::Union => left + .builder + .union(&right.builder, is_all) + .wrap_err("Failed to apply union to logical plan"), + SetOpType::Except => { + bail!("Except set operation is not supported"); + } + }?; + + // we merged left and right psets + Ok(Plan { + builder, + psets: left.psets, + }) +} diff --git a/tests/connect/test_intersection.py b/tests/connect/test_intersection.py new file mode 100644 index 0000000000..200f391f39 --- /dev/null +++ b/tests/connect/test_intersection.py @@ -0,0 +1,21 @@ +from __future__ import annotations + + +def test_intersection(spark_session): + # Create ranges using Spark - with overlap + range1 = spark_session.range(7) # Creates DataFrame with numbers 0 to 6 + range2 = spark_session.range(3, 10) # Creates DataFrame with numbers 3 to 9 + + # Intersect the two ranges + intersected = range1.intersect(range2) + + # Collect results + results = intersected.collect() + + # Verify the DataFrame has expected values + # Intersection should only include overlapping values once + assert len(results) == 4, "DataFrame should have 4 rows (overlapping values 3,4,5,6)" + + # Check that all expected values are present + values = [row.id for row in results] + assert sorted(values) == [3, 4, 5, 6], "Values should match expected overlapping sequence" diff --git a/tests/connect/test_union.py b/tests/connect/test_union.py new file mode 100644 index 0000000000..34157fd2c1 --- /dev/null +++ b/tests/connect/test_union.py @@ -0,0 +1,36 @@ +from __future__ import annotations + + +def test_union(spark_session): + # Create ranges using Spark - with overlap + range1 = spark_session.range(7) # Creates DataFrame with numbers 0 to 6 + range2 = spark_session.range(3, 10) # Creates DataFrame with numbers 3 to 9 + + # Union the two ranges + unioned = range1.union(range2) + + # Collect results + results = unioned.collect() + + # Verify the DataFrame has expected values + # Union includes duplicates, so length should be sum of both ranges + assert len(results) == 14, "DataFrame should have 14 rows (7 + 7)" + + # Check that all expected values are present, including duplicates + values = [row.id for row in results] + assert sorted(values) == [ + 0, + 1, + 2, + 3, + 3, + 4, + 4, + 5, + 5, + 6, + 6, + 7, + 8, + 9, + ], "Values should match expected sequence with duplicates"