From b576b8c5af7922efb67bba5d272ac8b7aa2b2e30 Mon Sep 17 00:00:00 2001
From: Andrew Gazelka
Date: Wed, 20 Nov 2024 20:56:54 -0800
Subject: [PATCH] [FEAT] connect: support `sample`

---
 .../src/translation/logical_plan.rs           |  4 +-
 .../src/translation/logical_plan/sample.rs    | 60 +++++++++++++++++++
 tests/connect/test_sample.py                  | 18 ++++++
 3 files changed, 81 insertions(+), 1 deletion(-)
 create mode 100644 src/daft-connect/src/translation/logical_plan/sample.rs
 create mode 100644 tests/connect/test_sample.py

diff --git a/src/daft-connect/src/translation/logical_plan.rs b/src/daft-connect/src/translation/logical_plan.rs
index e96d006c6e..add4a72bd1 100644
--- a/src/daft-connect/src/translation/logical_plan.rs
+++ b/src/daft-connect/src/translation/logical_plan.rs
@@ -9,13 +9,14 @@ use tracing::warn;
 
 use crate::translation::logical_plan::{
     aggregate::aggregate, local_relation::local_relation, project::project, range::range,
-    set_op::set_op, to_df::to_df, with_columns::with_columns,
+    sample::sample, set_op::set_op, to_df::to_df, with_columns::with_columns,
 };
 
 mod aggregate;
 mod local_relation;
 mod project;
 mod range;
+mod sample;
 mod set_op;
 mod to_df;
 mod with_columns;
@@ -58,6 +59,7 @@ pub fn to_logical_plan(relation: Relation) -> eyre::Result<Plan> {
         RelType::LocalRelation(l) => {
             local_relation(l).wrap_err("Failed to apply local_relation to logical plan")
         }
+        RelType::Sample(s) => sample(*s).wrap_err("Failed to apply sample to logical plan"),
         RelType::SetOp(s) => set_op(*s).wrap_err("Failed to apply set_op to logical plan"),
         plan => bail!("Unsupported relation type: {plan:?}"),
     }
diff --git a/src/daft-connect/src/translation/logical_plan/sample.rs b/src/daft-connect/src/translation/logical_plan/sample.rs
new file mode 100644
index 0000000000..af9728531d
--- /dev/null
+++ b/src/daft-connect/src/translation/logical_plan/sample.rs
@@ -0,0 +1,60 @@
+use eyre::{bail, WrapErr};
+use tracing::warn;
+
+use crate::translation::{to_logical_plan, Plan};
+
+/// Translates a Spark Connect `Sample` relation into a sampled logical plan.
+///
+/// The input relation is translated first; the result is then sampled with a
+/// fraction of `upper_bound - lower_bound`.
+///
+/// # Errors
+///
+/// Returns an error if the relation has no input, if the bounds are not finite
+/// or `lower_bound` exceeds `upper_bound`, or if translating the input or
+/// applying the sample fails.
+pub fn sample(sample: spark_connect::Sample) -> eyre::Result<Plan> {
+    let spark_connect::Sample {
+        input,
+        lower_bound,
+        upper_bound,
+        with_replacement,
+        seed,
+        deterministic_order,
+    } = sample;
+
+    let Some(input) = input else {
+        bail!("Input is required");
+    };
+
+    // Reject NaN/infinite or inverted bounds up front: they would otherwise turn
+    // into a meaningless sampling fraction.
+    if !lower_bound.is_finite() || !upper_bound.is_finite() || lower_bound > upper_bound {
+        bail!("Invalid sample bounds: lower_bound={lower_bound}, upper_bound={upper_bound}");
+    }
+
+    let mut plan = to_logical_plan(*input)?;
+
+    // The width of the [lower_bound, upper_bound) window is the fraction of rows to keep.
+    // NOTE(review): presumably `DataFrame.sample(fraction)` arrives as [0.0, fraction);
+    // windows coming from `randomSplit` would be sampled independently here, so the
+    // resulting splits may overlap -- confirm against the Spark Connect client.
+    let fraction = upper_bound - lower_bound;
+
+    // An unset flag is treated as sampling without replacement.
+    let with_replacement = with_replacement.unwrap_or(false);
+
+    // Only the bit pattern matters for a seed, so the sign change from the cast is fine.
+    let seed = seed.map(|seed| seed as u64);
+
+    if deterministic_order {
+        warn!("Deterministic order is not yet supported");
+    }
+
+    plan.builder = plan
+        .builder
+        .sample(fraction, with_replacement, seed)
+        .wrap_err("Failed to apply sample to logical plan")?;
+
+    Ok(plan)
+}
diff --git a/tests/connect/test_sample.py b/tests/connect/test_sample.py
new file mode 100644
index 0000000000..c7bd4df86e
--- /dev/null
+++ b/tests/connect/test_sample.py
@@ -0,0 +1,18 @@
+from __future__ import annotations
+
+
+def test_sample(spark_session):
+    # Create a range DataFrame
+    df = spark_session.range(100)
+
+    # Test sample with fraction
+    sampled_df = df.sample(fraction=0.1, seed=42)
+    sampled_rows = sampled_df.collect()
+
+    # Verify sample size is roughly 10% of original
+    sample_size = len(sampled_rows)
+    assert 5 <= sample_size <= 15, f"Sample size {sample_size} should be roughly 10 rows"
+
+    # Verify sampled values are within original range
+    for row in sampled_rows:
+        assert 0 <= row["id"] < 100, f"Sampled value {row['id']} outside valid range"