diff --git a/src/daft-connect/src/lib.rs b/src/daft-connect/src/lib.rs
index 70171ad0d4..60965c01d1 100644
--- a/src/daft-connect/src/lib.rs
+++ b/src/daft-connect/src/lib.rs
@@ -328,7 +328,10 @@ impl SparkConnectService for DaftSparkConnectService {
 
                 Ok(Response::new(response))
             }
-            _ => unimplemented_err!("Analyze plan operation is not yet implemented"),
+            op => {
+                println!("{op:#?}");
+                unimplemented_err!("Analyze plan operation is not yet implemented")
+            },
         }
     }
 
diff --git a/src/daft-connect/src/op.rs b/src/daft-connect/src/op.rs
index 2e8bdddf98..4e012a6c30 100644
--- a/src/daft-connect/src/op.rs
+++ b/src/daft-connect/src/op.rs
@@ -1 +1,2 @@
 pub mod execute;
+pub mod analyze;
diff --git a/src/daft-connect/src/op/analyze.rs b/src/daft-connect/src/op/analyze.rs
new file mode 100644
index 0000000000..f7f738134e
--- /dev/null
+++ b/src/daft-connect/src/op/analyze.rs
@@ -0,0 +1,70 @@
+use std::pin::Pin;
+
+use futures::Stream;
+use spark_connect::{analyze_plan_response, AnalyzePlanResponse, Relation};
+use tonic::{codegen::tokio_stream::wrappers::ReceiverStream, Status};
+
+use crate::{session::Session, translation};
+
+pub type AnalyzeStream =
+    Pin<Box<dyn Stream<Item = Result<AnalyzePlanResponse, Status>> + Send + Sync>>;
+
+pub struct PlanIds {
+    session: String,
+    server_side_session: String,
+    operation: String,
+}
+
+impl PlanIds {
+    pub fn response(&self, result: analyze_plan_response::Result) -> AnalyzePlanResponse {
+        AnalyzePlanResponse {
+            session_id: self.session.to_string(),
+            server_side_session_id: self.server_side_session.to_string(),
+            result: Some(result),
+        }
+    }
+}
+
+impl Session {
+    pub async fn handle_explain_command(
+        &self,
+        command: Relation,
+        operation_id: String,
+    ) -> Result<AnalyzeStream, Status> {
+        use futures::TryStreamExt;
+
+        let context = PlanIds {
+            session: self.client_side_session_id().to_string(),
+            server_side_session: self.server_side_session_id().to_string(),
+            operation: operation_id,
+        };
+
+        let (tx, rx) = tokio::sync::mpsc::channel::<eyre::Result<AnalyzePlanResponse>>(16);
+        std::thread::spawn(move || {
+            let result = (|| -> eyre::Result<()> {
+                // Explain only needs the translated plan; nothing is executed here.
+                let plan = translation::to_logical_plan(command)?;
+                let logical_plan = plan.build();
+
+                // Assumes `LogicalPlan::repr_ascii` (or an equivalent plan-formatting
+                // helper) is available to render the plan as text.
+                let explain = analyze_plan_response::Explain {
+                    explain_string: logical_plan.repr_ascii(false),
+                };
+
+                let response = context.response(analyze_plan_response::Result::Explain(explain));
+                tx.blocking_send(Ok(response)).unwrap();
+                Ok(())
+            })();
+
+            if let Err(e) = result {
+                tx.blocking_send(Err(e)).unwrap();
+            }
+        });
+
+        let stream = ReceiverStream::new(rx)
+            .map_err(|e| Status::internal(format!("Error in Daft server: {e:?}")));
+
+        Ok(Box::pin(stream))
+    }
+}
diff --git a/src/daft-connect/src/op/execute.rs b/src/daft-connect/src/op/execute.rs
index fba3cc850d..7ec6a44324 100644
--- a/src/daft-connect/src/op/execute.rs
+++ b/src/daft-connect/src/op/execute.rs
@@ -14,7 +14,7 @@ mod root;
 
 pub type ExecuteStream = <DaftSparkConnectService as SparkConnectService>::ExecutePlanStream;
 
-pub struct PlanIds {
+struct PlanIds {
     session: String,
     server_side_session: String,
     operation: String,
diff --git a/tests/connect/test_explain.py b/tests/connect/test_explain.py
new file mode 100644
index 0000000000..fd8c9e02c4
--- /dev/null
+++ b/tests/connect/test_explain.py
@@ -0,0 +1,28 @@
+from __future__ import annotations
+
+import io
+from contextlib import redirect_stdout
+
+
+def test_explain(spark_session):
+    # Create ranges using Spark - with overlap
+    range1 = spark_session.range(7)  # Creates DataFrame with numbers 0 to 6
+    range2 = spark_session.range(3, 10)  # Creates DataFrame with numbers 3 to 9
+
+    # Union the two ranges
+    unioned = range1.union(range2)
+
+    # Get the explain plan; DataFrame.explain() prints to stdout and returns None,
+    # so capture the printed output instead of using the return value
+    buffer = io.StringIO()
+    with redirect_stdout(buffer):
+        unioned.explain(extended=True)
+    explain_str = buffer.getvalue()
+
+    # Verify explain output contains expected elements
+    assert "Union" in explain_str, "Explain plan should contain Union operation"
+    assert "Range" in explain_str, "Explain plan should contain Range operations"
+
+    # Check that both range operations are present
+    assert "(0, 7" in explain_str, "First range parameters should be in explain plan"
+    assert "(3, 10" in explain_str, "Second range parameters should be in explain plan"