diff --git a/src/daft-connect/src/lib.rs b/src/daft-connect/src/lib.rs
index 369bfe8e47..d541d4d8dd 100644
--- a/src/daft-connect/src/lib.rs
+++ b/src/daft-connect/src/lib.rs
@@ -24,7 +24,7 @@ use spark_connect::{
 use tonic::{transport::Server, Request, Response, Status};
 use tracing::info;
 use uuid::Uuid;
-
+use spark_connect::analyze_plan_request::explain::ExplainMode;
 use crate::session::Session;
 
 mod config;
@@ -142,10 +142,10 @@ impl DaftSparkConnectService {
 #[tonic::async_trait]
 impl SparkConnectService for DaftSparkConnectService {
     type ExecutePlanStream = std::pin::Pin<
-        Box<dyn Stream<Item = Result<ExecutePlanResponse, Status>> + Send + 'static>,
+        Box<dyn Stream<Item = Result<ExecutePlanResponse, Status>> + Send + 'static>,
     >;
     type ReattachExecuteStream = std::pin::Pin<
-        Box<dyn Stream<Item = Result<ExecutePlanResponse, Status>> + Send + 'static>,
+        Box<dyn Stream<Item = Result<ExecutePlanResponse, Status>> + Send + 'static>,
     >;
 
     #[tracing::instrument(skip_all)]
@@ -282,6 +282,8 @@ impl SparkConnectService for DaftSparkConnectService {
         use spark_connect::analyze_plan_request::*;
 
         let request = request.into_inner();
+        let session = self.get_session(&request.session_id)?;
+
         let AnalyzePlanRequest {
             session_id,
             analyze,
@@ -323,6 +325,35 @@
 
                 Ok(Response::new(response))
             }
+            Analyze::Explain(explain) => {
+                let Explain { plan, explain_mode } = explain;
+
+                let explain_mode = ExplainMode::try_from(explain_mode)
+                    .map_err(|_| Status::invalid_argument("Invalid Explain Mode"))?;
+
+                let Some(plan) = plan else {
+                    return invalid_argument_err!("Plan is required");
+                };
+
+                let Some(plan) = plan.op_type else {
+                    return invalid_argument_err!("Op Type is required");
+                };
+
+                let OpType::Root(relation) = plan else {
+                    return invalid_argument_err!("Plan operation is required");
+                };
+
+                let result = match session.handle_explain_command(relation, explain_mode).await {
+                    Ok(result) => result,
+                    Err(e) => return Err(Status::internal(format!("Error in Daft server: {e:?}"))),
+                };
+
+                Ok(Response::new(result))
+            }
+            op => {
+                info!("{op:#?}");
+                unimplemented_err!("Analyze plan operation is not yet implemented")
+            }
             Analyze::DdlParse(DdlParse { ddl_string }) => {
                 let daft_schema = match daft_sql::sql_schema(&ddl_string) {
                     Ok(daft_schema) => daft_schema,
diff --git a/src/daft-connect/src/op.rs b/src/daft-connect/src/op.rs
index 2e8bdddf98..4e012a6c30 100644
--- a/src/daft-connect/src/op.rs
+++ b/src/daft-connect/src/op.rs
@@ -1 +1,2 @@
+pub mod analyze;
 pub mod execute;
diff --git a/src/daft-connect/src/op/analyze.rs b/src/daft-connect/src/op/analyze.rs
new file mode 100644
index 0000000000..5bb85bd194
--- /dev/null
+++ b/src/daft-connect/src/op/analyze.rs
@@ -0,0 +1,57 @@
+use std::pin::Pin;
+
+use spark_connect::{
+    analyze_plan_request::explain::ExplainMode, analyze_plan_response, AnalyzePlanResponse,
+    Relation,
+};
+use tonic::Status;
+
+use crate::{session::Session, translation};
+
+/// Boxed stream of analyze-plan responses, as handed back to the gRPC layer.
+pub type AnalyzeStream =
+    Pin<Box<dyn futures::Stream<Item = Result<AnalyzePlanResponse, Status>> + Send + Sync>>;
+
+/// Identifiers tying an analyze response back to the originating session.
+pub struct PlanIds {
+    session: String,
+    server_side_session: String,
+}
+
+impl PlanIds {
+    /// Wrap an analyze result in a response stamped with this session's ids.
+    pub fn response(&self, result: analyze_plan_response::Result) -> AnalyzePlanResponse {
+        AnalyzePlanResponse {
+            session_id: self.session.to_string(),
+            server_side_session_id: self.server_side_session.to_string(),
+            result: Some(result),
+        }
+    }
+}
+
+impl Session {
+    /// Produce an `Explain` analyze response for `command`.
+    ///
+    /// The relation is translated to a logical plan, optimized, and rendered
+    /// as a string. `_mode` (simple/extended/...) is currently ignored.
+    pub async fn handle_explain_command(
+        &self,
+        command: Relation,
+        _mode: ExplainMode,
+    ) -> eyre::Result<AnalyzePlanResponse> {
+        let context = PlanIds {
+            session: self.client_side_session_id().to_string(),
+            server_side_session: self.server_side_session_id().to_string(),
+        };
+
+        let plan = translation::to_logical_plan(command)?;
+        let optimized_plan = plan.optimize()?.build();
+
+        // todo: what do we want this to display
+        let explain_string = optimized_plan.to_string();
+
+        let explain = analyze_plan_response::Explain { explain_string };
+
+        Ok(context.response(analyze_plan_response::Result::Explain(explain)))
+    }
+}
diff --git a/src/daft-connect/src/op/execute.rs b/src/daft-connect/src/op/execute.rs
index 41baf88b09..3b517b043a 100644
--- a/src/daft-connect/src/op/execute.rs
+++ b/src/daft-connect/src/op/execute.rs
@@ -15,7 +15,7 @@ mod write;
 
 pub type ExecuteStream = <DaftSparkConnectService as SparkConnectService>::ExecutePlanStream;
 
-pub struct PlanIds {
+struct PlanIds {
     session: String,
     server_side_session: String,
     operation: String,
diff --git a/tests/connect/test_explain.py b/tests/connect/test_explain.py
new file mode 100644
index 0000000000..3b5574dec3
--- /dev/null
+++ b/tests/connect/test_explain.py
@@ -0,0 +1,17 @@
+from __future__ import annotations
+
+
+def test_explain(spark_session, capsys):
+    # Create ranges using Spark - with overlap
+    range1 = spark_session.range(7)  # Creates DataFrame with numbers 0 to 6
+    range2 = spark_session.range(3, 10)  # Creates DataFrame with numbers 3 to 9
+
+    # Union the two ranges
+    unioned = range1.union(range2)
+
+    # explain() prints the plan to stdout and returns None, so capture the output.
+    unioned.explain(extended=True)
+    explain_output = capsys.readouterr().out
+
+    # Verify explain output contains expected elements
+    assert explain_output, "explain() should print a non-empty plan"