From ca4d3f794a376616557110ab12f6a115f17735f9 Mon Sep 17 00:00:00 2001
From: Cory Grinstead
Date: Wed, 18 Dec 2024 16:35:18 -0600
Subject: [PATCH] feat(connect): df.show (#3560)

depends on https://github.com/Eventual-Inc/Daft/pull/3554

[see here for proper diff](https://github.com/universalmind303/Daft/compare/refactor-lp-3...universalmind303:Daft:connect_show?expand=1)
---
 src/daft-connect/src/session.rs               |   1 -
 .../src/translation/logical_plan.rs           | 105 +++++++++++++++++-
 src/daft-dsl/src/lit.rs                       |   6 +
 tests/connect/test_show.py                    |   9 ++
 4 files changed, 116 insertions(+), 5 deletions(-)
 create mode 100644 tests/connect/test_show.py

diff --git a/src/daft-connect/src/session.rs b/src/daft-connect/src/session.rs
index 30f827ba9e..7de8d5851b 100644
--- a/src/daft-connect/src/session.rs
+++ b/src/daft-connect/src/session.rs
@@ -28,7 +28,6 @@ impl Session {
     pub fn new(id: String) -> Self {
         let server_side_session_id = Uuid::new_v4();
         let server_side_session_id = server_side_session_id.to_string();
-
         Self {
             config_values: Default::default(),
             id,
diff --git a/src/daft-connect/src/translation/logical_plan.rs b/src/daft-connect/src/translation/logical_plan.rs
index 15eb495502..439f5bd551 100644
--- a/src/daft-connect/src/translation/logical_plan.rs
+++ b/src/daft-connect/src/translation/logical_plan.rs
@@ -1,7 +1,21 @@
+use std::sync::Arc;
+
+use common_daft_config::DaftExecutionConfig;
+use daft_core::prelude::Schema;
+use daft_dsl::LiteralValue;
+use daft_local_execution::NativeExecutor;
 use daft_logical_plan::LogicalPlanBuilder;
-use daft_micropartition::partitioning::InMemoryPartitionSetCache;
+use daft_micropartition::{
+    partitioning::{
+        InMemoryPartitionSetCache, MicroPartitionSet, PartitionCacheEntry, PartitionMetadata,
+        PartitionSet, PartitionSetCache,
+    },
+    MicroPartition,
+};
+use daft_table::Table;
 use eyre::{bail, Context};
-use spark_connect::{relation::RelType, Limit, Relation};
+use futures::TryStreamExt;
+use spark_connect::{relation::RelType, Limit, Relation, ShowString};
 use tracing::warn;
 
 mod aggregate;
@@ -22,6 +36,35 @@ impl SparkAnalyzer<'_> {
     pub fn new(pset: &InMemoryPartitionSetCache) -> SparkAnalyzer {
         SparkAnalyzer { psets: pset }
     }
+    pub fn create_in_memory_scan(
+        &self,
+        plan_id: usize,
+        schema: Arc<Schema>,
+        tables: Vec<Table>,
+    ) -> eyre::Result<LogicalPlanBuilder> {
+        let partition_key = uuid::Uuid::new_v4().to_string();
+
+        let pset = Arc::new(MicroPartitionSet::from_tables(plan_id, tables)?);
+
+        let PartitionMetadata {
+            num_rows,
+            size_bytes,
+        } = pset.metadata();
+        let num_partitions = pset.num_partitions();
+
+        self.psets.put_partition_set(&partition_key, &pset);
+
+        let cache_entry = PartitionCacheEntry::new_rust(partition_key.clone(), pset);
+
+        Ok(LogicalPlanBuilder::in_memory_scan(
+            &partition_key,
+            cache_entry,
+            schema,
+            num_partitions,
+            size_bytes,
+            num_rows,
+        )?)
+    }
 
     pub async fn to_logical_plan(&self, relation: Relation) -> eyre::Result<LogicalPlanBuilder> {
         let Some(common) = relation.common else {
@@ -78,12 +121,18 @@ impl SparkAnalyzer<'_> {
                 .filter(*f)
                 .await
                 .wrap_err("Failed to apply filter to logical plan"),
+            RelType::ShowString(ss) => {
+                let Some(plan_id) = common.plan_id else {
+                    bail!("Plan ID is required for ShowString");
+                };
+                self.show_string(plan_id, *ss)
+                    .await
+                    .wrap_err("Failed to show string")
+            }
             plan => bail!("Unsupported relation type: {plan:?}"),
         }
     }
-}
 
-impl SparkAnalyzer<'_> {
     async fn limit(&self, limit: Limit) -> eyre::Result<LogicalPlanBuilder> {
         let Limit { input, limit } = limit;
 
@@ -96,4 +145,52 @@ impl SparkAnalyzer<'_> {
         plan.limit(i64::from(limit), false)
             .wrap_err("Failed to apply limit to logical plan")
     }
+
+    /// Right now this just naively applies a limit to the logical plan.
+    /// In the future, we want this to more closely match our Daft implementation.
+    async fn show_string(
+        &self,
+        plan_id: i64,
+        show_string: ShowString,
+    ) -> eyre::Result<LogicalPlanBuilder> {
+        let ShowString {
+            input,
+            num_rows,
+            truncate: _,
+            vertical,
+        } = show_string;
+
+        if vertical {
+            bail!("Vertical show string is not supported");
+        }
+
+        let Some(input) = input else {
+            bail!("input must be set");
+        };
+
+        let plan = Box::pin(self.to_logical_plan(*input)).await?;
+        let plan = plan.limit(num_rows as i64, true)?;
+
+        let optimized_plan = tokio::task::spawn_blocking(move || plan.optimize())
+            .await
+            .unwrap()?;
+
+        let cfg = Arc::new(DaftExecutionConfig::default());
+        let native_executor = NativeExecutor::from_logical_plan_builder(&optimized_plan)?;
+        let result_stream = native_executor.run(self.psets, cfg, None)?.into_stream();
+        let batch = result_stream.try_collect::<Vec<_>>().await?;
+        let single_batch = MicroPartition::concat(batch)?;
+        let tbls = single_batch.get_tables()?;
+        let tbl = Table::concat(&tbls)?;
+        let output = tbl.to_comfy_table(None).to_string();
+
+        let s = LiteralValue::Utf8(output)
+            .into_single_value_series()?
+            .rename("show_string");
+
+        let tbl = Table::from_nonempty_columns(vec![s])?;
+        let schema = tbl.schema.clone();
+
+        self.create_in_memory_scan(plan_id as _, schema, vec![tbl])
+    }
 }
diff --git a/src/daft-dsl/src/lit.rs b/src/daft-dsl/src/lit.rs
index 1d86442aef..c1c7ce81c3 100644
--- a/src/daft-dsl/src/lit.rs
+++ b/src/daft-dsl/src/lit.rs
@@ -444,6 +444,12 @@ pub fn null_lit() -> ExprRef {
     Arc::new(Expr::Literal(LiteralValue::Null))
 }
 
+impl LiteralValue {
+    pub fn into_single_value_series(self) -> DaftResult<Series> {
+        literals_to_series(&[self])
+    }
+}
+
 /// Convert a slice of literals to a series.
 /// This function will return an error if the literals are not all the same type
 pub fn literals_to_series(values: &[LiteralValue]) -> DaftResult<Series> {
diff --git a/tests/connect/test_show.py b/tests/connect/test_show.py
new file mode 100644
index 0000000000..a463d5de72
--- /dev/null
+++ b/tests/connect/test_show.py
@@ -0,0 +1,9 @@
+from __future__ import annotations
+
+
+def test_show(spark_session):
+    df = spark_session.range(10)
+    try:
+        df.show()
+    except Exception as e:
+        assert False, e
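
To try the new code path end to end, here is a minimal client-side sketch. It assumes a Daft Connect server is already listening at `sc://localhost:15002` (the endpoint and port are assumptions; substitute your server's address):

```python
from pyspark.sql import SparkSession

# Connect over the Spark Connect protocol; the endpoint is an assumption.
spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()

df = spark.range(10)

# Sends a ShowString relation to the server, which limits the plan to the
# first num_rows rows, executes it natively, renders the result with
# comfy-table, and returns the rendered text as a single "show_string" cell
# for the client to print.
df.show()
```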
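Since the translation bails on vertical show strings, a follow-up test could pin down that error path. A sketch, assuming the same `spark_session` pytest fixture used by `tests/connect/test_show.py`, and assuming only that the server-side `bail!` surfaces as some client-side exception (the exact exception type is deliberately not asserted):

```python
from __future__ import annotations

import pytest


def test_show_vertical_unsupported(spark_session):
    # The server rejects ShowString { vertical: true }, so the call must fail.
    df = spark_session.range(10)
    with pytest.raises(Exception):
        df.show(vertical=True)
```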
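Note also that `truncate` is destructured but ignored (`truncate: _`), while `num_rows` is honored via an eager `limit`. From the client, under the session set up above, that implies:

```python
# num_rows is applied server-side as an eager limit before execution.
spark.range(100).show(5)

# truncate is currently a no-op on the server, so these render identically.
spark.range(10).show(truncate=True)
spark.range(10).show(truncate=False)
```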