Skip to content

Commit

Permalink
[FEAT] connect: add df.{intersection,union}
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewgazelka committed Dec 19, 2024
1 parent ae74c10 commit 1e96247
Show file tree
Hide file tree
Showing 4 changed files with 126 additions and 0 deletions.
4 changes: 4 additions & 0 deletions src/daft-connect/src/translation/logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ mod local_relation;
mod project;
mod range;
mod read;
mod set_op;
mod to_df;
mod with_columns;

Expand Down Expand Up @@ -129,6 +130,9 @@ impl SparkAnalyzer<'_> {
.await
.wrap_err("Failed to show string")
}
RelType::SetOp(s) => set_op(*s)
.await
.wrap_err("Failed to apply set_op to logical plan"),
plan => bail!("Unsupported relation type: {plan:?}"),
}
}
Expand Down
65 changes: 65 additions & 0 deletions src/daft-connect/src/translation/logical_plan/set_op.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
use eyre::{bail, Context};
use spark_connect::set_operation::SetOpType;
use tracing::warn;

use crate::translation::{to_logical_plan, Plan};

pub async fn set_op(set_op: spark_connect::SetOperation) -> eyre::Result<Plan> {
let spark_connect::SetOperation {
left_input,
right_input,
set_op_type,
is_all,
by_name,
allow_missing_columns,
} = set_op;

let Some(left_input) = left_input else {
bail!("Left input is required");
};

let Some(right_input) = right_input else {
bail!("Right input is required");
};

let set_op = SetOpType::try_from(set_op_type)
.wrap_err_with(|| format!("Invalid set operation type: {set_op_type}"))?;

if let Some(by_name) = by_name {
warn!("Ignoring by_name: {by_name}");
}

if let Some(allow_missing_columns) = allow_missing_columns {
warn!("Ignoring allow_missing_columns: {allow_missing_columns}");
}

let mut left = Box::pin(to_logical_plan(*left_input)).await?;
let right = Box::pin(to_logical_plan(*right_input)).await?;

left.psets.partitions.extend(right.psets.partitions);

let is_all = is_all.unwrap_or(false);

let builder = match set_op {
SetOpType::Unspecified => {
bail!("Unspecified set operation is not supported");
}
SetOpType::Intersect => left
.builder
.intersect(&right.builder, is_all)
.wrap_err("Failed to apply intersect to logical plan"),
SetOpType::Union => left
.builder
.union(&right.builder, is_all)
.wrap_err("Failed to apply union to logical plan"),
SetOpType::Except => {
bail!("Except set operation is not supported");
}
}?;

// we merged left and right psets
Ok(Plan {
builder,
psets: left.psets,
})
}
21 changes: 21 additions & 0 deletions tests/connect/test_intersection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from __future__ import annotations


def test_intersection(spark_session):
# Create ranges using Spark - with overlap
range1 = spark_session.range(7) # Creates DataFrame with numbers 0 to 6
range2 = spark_session.range(3, 10) # Creates DataFrame with numbers 3 to 9

# Intersect the two ranges
intersected = range1.intersect(range2)

# Collect results
results = intersected.collect()

# Verify the DataFrame has expected values
# Intersection should only include overlapping values once
assert len(results) == 4, "DataFrame should have 4 rows (overlapping values 3,4,5,6)"

# Check that all expected values are present
values = [row.id for row in results]
assert sorted(values) == [3, 4, 5, 6], "Values should match expected overlapping sequence"
36 changes: 36 additions & 0 deletions tests/connect/test_union.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from __future__ import annotations


def test_union(spark_session):
# Create ranges using Spark - with overlap
range1 = spark_session.range(7) # Creates DataFrame with numbers 0 to 6
range2 = spark_session.range(3, 10) # Creates DataFrame with numbers 3 to 9

# Union the two ranges
unioned = range1.union(range2)

# Collect results
results = unioned.collect()

# Verify the DataFrame has expected values
# Union includes duplicates, so length should be sum of both ranges
assert len(results) == 14, "DataFrame should have 14 rows (7 + 7)"

# Check that all expected values are present, including duplicates
values = [row.id for row in results]
assert sorted(values) == [
0,
1,
2,
3,
3,
4,
4,
5,
5,
6,
6,
7,
8,
9,
], "Values should match expected sequence with duplicates"

0 comments on commit 1e96247

Please sign in to comment.