Skip to content

Commit

Permalink
[FEAT]: connect: df.distinct()
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewgazelka committed Dec 19, 2024
1 parent 1e96247 commit f7382ac
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 0 deletions.
6 changes: 6 additions & 0 deletions src/daft-connect/src/translation/logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@ use futures::TryStreamExt;
use spark_connect::{relation::RelType, Limit, Relation, ShowString};
use tracing::warn;

use crate::translation::logical_plan::{deduplicate::deduplicate, set_op::set_op};

mod aggregate;
mod deduplicate;
mod drop;
mod filter;
mod local_relation;
Expand Down Expand Up @@ -133,6 +136,9 @@ impl SparkAnalyzer<'_> {
RelType::SetOp(s) => set_op(*s)
.await
.wrap_err("Failed to apply set_op to logical plan"),
RelType::Deduplicate(d) => deduplicate(*d)
.await
.wrap_err("Failed to apply deduplicate to logical plan"),
plan => bail!("Unsupported relation type: {plan:?}"),
}
}
Expand Down
41 changes: 41 additions & 0 deletions src/daft-connect/src/translation/logical_plan/deduplicate.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
use eyre::{bail, ensure, WrapErr};
use tracing::warn;

use crate::translation::{to_logical_plan, Plan};

pub async fn deduplicate(deduplicate: spark_connect::Deduplicate) -> eyre::Result<Plan> {
let spark_connect::Deduplicate {
input,
column_names,
all_columns_as_keys,
within_watermark,
} = deduplicate;

let Some(input) = input else {
bail!("Input is required");
};

if !column_names.is_empty() {
warn!("Ignoring column_names: {column_names:?}; not yet implemented");
}

let all_columns_as_keys = all_columns_as_keys.unwrap_or(false);

ensure!(
all_columns_as_keys,
"only implemented for all_columns_as_keys=true"
);

if let Some(within_watermark) = within_watermark {
warn!("Ignoring within_watermark: {within_watermark:?}; not yet implemented");
}

let mut plan = Box::pin(to_logical_plan(*input)).await?;

plan.builder = plan
.builder
.distinct()
.wrap_err("Failed to apply distinct to logical plan")?;

Ok(plan)
}
21 changes: 21 additions & 0 deletions tests/connect/test_distinct.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from __future__ import annotations


def test_distinct(spark_session):
# Create ranges using Spark - with overlap
range1 = spark_session.range(7) # Creates DataFrame with numbers 0 to 6
range2 = spark_session.range(3, 10) # Creates DataFrame with numbers 3 to 9

# Union the two ranges and get distinct values
unioned = range1.union(range2).distinct()

# Collect results
results = unioned.collect()

# Verify the DataFrame has expected values
# Distinct removes duplicates, so length should be 10 (0-9)
assert len(results) == 10, "DataFrame should have 10 unique rows"

# Check that all expected values are present, with no duplicates
values = [row.id for row in results]
assert sorted(values) == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], "Values should match expected sequence without duplicates"

0 comments on commit f7382ac

Please sign in to comment.