From 49d0dacdc50ce5941818b7506cd3440ac28757b1 Mon Sep 17 00:00:00 2001 From: Andrew Gazelka Date: Wed, 20 Nov 2024 20:56:54 -0800 Subject: [PATCH] [FEAT] connect: support `sample` --- .../src/translation/logical_plan.rs | 4 +- .../src/translation/logical_plan/sample.rs | 42 +++++++++++++++++++ tests/connect/test_sample.py | 18 ++++++++ 3 files changed, 63 insertions(+), 1 deletion(-) create mode 100644 src/daft-connect/src/translation/logical_plan/sample.rs create mode 100644 tests/connect/test_sample.py diff --git a/src/daft-connect/src/translation/logical_plan.rs b/src/daft-connect/src/translation/logical_plan.rs index 53a0cfc923..26804832a4 100644 --- a/src/daft-connect/src/translation/logical_plan.rs +++ b/src/daft-connect/src/translation/logical_plan.rs @@ -4,13 +4,14 @@ use spark_connect::{relation::RelType, Limit, Relation}; use tracing::warn; use crate::translation::logical_plan::{ - aggregate::aggregate, project::project, range::range, set_op::set_op, + aggregate::aggregate, project::project, range::range, sample::sample, set_op::set_op, with_columns::with_columns, }; mod aggregate; mod project; mod range; +mod sample; mod set_op; mod with_columns; @@ -33,6 +34,7 @@ pub fn to_logical_plan(relation: Relation) -> eyre::Result { RelType::WithColumns(w) => { with_columns(*w).wrap_err("Failed to apply with_columns to logical plan") } + RelType::Sample(s) => sample(*s).wrap_err("Failed to apply sample to logical plan"), RelType::SetOp(s) => set_op(*s).wrap_err("Failed to apply set_op to logical plan"), plan => bail!("Unsupported relation type: {plan:?}"), } diff --git a/src/daft-connect/src/translation/logical_plan/sample.rs b/src/daft-connect/src/translation/logical_plan/sample.rs new file mode 100644 index 0000000000..f6a714addd --- /dev/null +++ b/src/daft-connect/src/translation/logical_plan/sample.rs @@ -0,0 +1,42 @@ +use eyre::{bail, WrapErr}; +use tracing::warn; + +use crate::translation::to_logical_plan; + +pub fn sample( + sample: spark_connect::Sample, +) -> eyre::Result { + let spark_connect::Sample { + input, + lower_bound, + upper_bound, + with_replacement, + seed, + deterministic_order, + } = sample; + + let Some(input) = input else { + bail!("Input is required"); + }; + + let plan = to_logical_plan(*input)?; + + // Calculate fraction from bounds + // todo: is this correct? + let fraction = upper_bound - lower_bound; + + let with_replacement = with_replacement.unwrap_or(false); + + // we do not care about sign change + let seed = seed.map(|seed| seed as u64); + + if deterministic_order { + warn!("Deterministic order is not yet supported"); + } + + let plan = plan + .sample(fraction, with_replacement, seed) + .wrap_err("Failed to apply sample to logical plan")?; + + Ok(plan) +} diff --git a/tests/connect/test_sample.py b/tests/connect/test_sample.py new file mode 100644 index 0000000000..c7bd4df86e --- /dev/null +++ b/tests/connect/test_sample.py @@ -0,0 +1,18 @@ +from __future__ import annotations + + +def test_sample(spark_session): + # Create a range DataFrame + df = spark_session.range(100) + + # Test sample with fraction + sampled_df = df.sample(fraction=0.1, seed=42) + sampled_rows = sampled_df.collect() + + # Verify sample size is roughly 10% of original + sample_size = len(sampled_rows) + assert 5 <= sample_size <= 15, f"Sample size {sample_size} should be roughly 10 rows" + + # Verify sampled values are within original range + for row in sampled_rows: + assert 0 <= row["id"] < 100, f"Sampled value {row['id']} outside valid range"