diff --git a/src/daft-connect/src/translation/logical_plan.rs b/src/daft-connect/src/translation/logical_plan.rs index 59bdc41d4c..e574ee8673 100644 --- a/src/daft-connect/src/translation/logical_plan.rs +++ b/src/daft-connect/src/translation/logical_plan.rs @@ -5,11 +5,13 @@ use spark_connect::{relation::RelType, Limit, Relation}; use tracing::warn; use crate::translation::logical_plan::{ - aggregate::aggregate, local_relation::local_relation, project::project, range::range, - read::read, set_op::set_op, to_df::to_df, with_columns::with_columns, + aggregate::aggregate, deduplicate::deduplicate, local_relation::local_relation, + project::project, range::range, read::read, set_op::set_op, to_df::to_df, + with_columns::with_columns, }; mod aggregate; +mod deduplicate; mod local_relation; mod project; mod range; @@ -78,6 +80,9 @@ pub async fn to_logical_plan(relation: Relation) -> eyre::Result { RelType::SetOp(s) => set_op(*s) .await .wrap_err("Failed to apply set_op to logical plan"), + RelType::Deduplicate(d) => deduplicate(*d) + .await + .wrap_err("Failed to apply deduplicate to logical plan"), plan => bail!("Unsupported relation type: {plan:?}"), } } diff --git a/src/daft-connect/src/translation/logical_plan/deduplicate.rs b/src/daft-connect/src/translation/logical_plan/deduplicate.rs new file mode 100644 index 0000000000..512b81620c --- /dev/null +++ b/src/daft-connect/src/translation/logical_plan/deduplicate.rs @@ -0,0 +1,41 @@ +use eyre::{bail, ensure, WrapErr}; +use tracing::warn; + +use crate::translation::{to_logical_plan, Plan}; + +pub async fn deduplicate(deduplicate: spark_connect::Deduplicate) -> eyre::Result { + let spark_connect::Deduplicate { + input, + column_names, + all_columns_as_keys, + within_watermark, + } = deduplicate; + + let Some(input) = input else { + bail!("Input is required"); + }; + + if !column_names.is_empty() { + warn!("Ignoring column_names: {column_names:?}; not yet implemented"); + } + + let all_columns_as_keys = all_columns_as_keys.unwrap_or(false); + + ensure!( + all_columns_as_keys, + "only implemented for all_columns_as_keys=true" + ); + + if let Some(within_watermark) = within_watermark { + warn!("Ignoring within_watermark: {within_watermark:?}; not yet implemented"); + } + + let mut plan = Box::pin(to_logical_plan(*input)).await?; + + plan.builder = plan + .builder + .distinct() + .wrap_err("Failed to apply distinct to logical plan")?; + + Ok(plan) +} diff --git a/tests/connect/test_distinct.py b/tests/connect/test_distinct.py new file mode 100644 index 0000000000..9e4b861d36 --- /dev/null +++ b/tests/connect/test_distinct.py @@ -0,0 +1,21 @@ +from __future__ import annotations + + +def test_distinct(spark_session): + # Create ranges using Spark - with overlap + range1 = spark_session.range(7) # Creates DataFrame with numbers 0 to 6 + range2 = spark_session.range(3, 10) # Creates DataFrame with numbers 3 to 9 + + # Union the two ranges and get distinct values + unioned = range1.union(range2).distinct() + + # Collect results + results = unioned.collect() + + # Verify the DataFrame has expected values + # Distinct removes duplicates, so length should be 10 (0-9) + assert len(results) == 10, "DataFrame should have 10 unique rows" + + # Check that all expected values are present, with no duplicates + values = [row.id for row in results] + assert sorted(values) == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], "Values should match expected sequence without duplicates"