diff --git a/Cargo.lock b/Cargo.lock
index fd22dcfa10..09b6dfd699 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1931,6 +1931,7 @@ dependencies = [
  "arrow2",
  "async-stream",
  "common-daft-config",
+ "common-file-formats",
  "daft-core",
  "daft-dsl",
  "daft-local-execution",
diff --git a/Cargo.toml b/Cargo.toml
index 67334d8b0d..4af4918398 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -194,6 +194,7 @@ chrono-tz = "0.10.0"
 comfy-table = "7.1.1"
 common-daft-config = {path = "src/common/daft-config"}
 common-error = {path = "src/common/error", default-features = false}
+common-file-formats = {path = "src/common/file-formats"}
 common-runtime = {path = "src/common/runtime", default-features = false}
 daft-core = {path = "src/daft-core"}
 daft-dsl = {path = "src/daft-dsl"}
diff --git a/src/daft-connect/Cargo.toml b/src/daft-connect/Cargo.toml
index 9651106968..7e085df7f5 100644
--- a/src/daft-connect/Cargo.toml
+++ b/src/daft-connect/Cargo.toml
@@ -2,6 +2,7 @@
 arrow2 = {workspace = true}
 async-stream = "0.3.6"
 common-daft-config = {workspace = true}
+common-file-formats = {workspace = true}
 daft-core = {workspace = true}
 daft-dsl = {workspace = true}
 daft-local-execution = {workspace = true}
diff --git a/src/daft-connect/src/lib.rs b/src/daft-connect/src/lib.rs
index 70171ad0d4..1fef44b619 100644
--- a/src/daft-connect/src/lib.rs
+++ b/src/daft-connect/src/lib.rs
@@ -142,14 +142,10 @@ impl DaftSparkConnectService {
 #[tonic::async_trait]
 impl SparkConnectService for DaftSparkConnectService {
     type ExecutePlanStream = std::pin::Pin<
-        Box<
-            dyn futures::Stream<Item = Result<ExecutePlanResponse, Status>> + Send + Sync + 'static,
-        >,
+        Box<dyn futures::Stream<Item = Result<ExecutePlanResponse, Status>> + Send + 'static>,
     >;
     type ReattachExecuteStream = std::pin::Pin<
-        Box<
-            dyn futures::Stream<Item = Result<ExecutePlanResponse, Status>> + Send + Sync + 'static,
-        >,
+        Box<dyn futures::Stream<Item = Result<ExecutePlanResponse, Status>> + Send + 'static>,
     >;
 
     #[tracing::instrument(skip_all)]
@@ -190,8 +186,9 @@ impl SparkConnectService for DaftSparkConnectService {
             CommandType::RegisterFunction(_) => {
                 unimplemented_err!("RegisterFunction not implemented")
             }
-            CommandType::WriteOperation(_) => {
-                unimplemented_err!("WriteOperation not implemented")
+            CommandType::WriteOperation(op) => {
+                let result = session.handle_write_command(op, operation).await?;
+                return Ok(Response::new(result));
             }
             CommandType::CreateDataframeView(_) => {
                 unimplemented_err!("CreateDataframeView not implemented")
@@ -305,7 +302,7 @@ impl SparkConnectService for DaftSparkConnectService {
             return Err(Status::invalid_argument("op_type is required to be root"));
         };
 
-        let result = match translation::relation_to_schema(relation) {
+        let result = match translation::relation_to_schema(relation).await {
             Ok(schema) => schema,
             Err(e) => {
                 return invalid_argument_err!(
diff --git a/src/daft-connect/src/op/execute.rs b/src/daft-connect/src/op/execute.rs
index fba3cc850d..41baf88b09 100644
--- a/src/daft-connect/src/op/execute.rs
+++ b/src/daft-connect/src/op/execute.rs
@@ -11,6 +11,7 @@ use uuid::Uuid;
 
 use crate::{DaftSparkConnectService, Session};
 
 mod root;
+mod write;
 
 pub type ExecuteStream = <DaftSparkConnectService as SparkConnectService>::ExecutePlanStream;
diff --git a/src/daft-connect/src/op/execute/root.rs b/src/daft-connect/src/op/execute/root.rs
index 1e1fac147b..c3e1db22a4 100644
--- a/src/daft-connect/src/op/execute/root.rs
+++ b/src/daft-connect/src/op/execute/root.rs
@@ -31,7 +31,7 @@ impl Session {
         let (tx, rx) = tokio::sync::mpsc::channel::<eyre::Result<ExecutePlanResponse>>(1);
         tokio::spawn(async move {
             let execution_fut = async {
-                let plan = translation::to_logical_plan(command)?;
+                let plan = translation::to_logical_plan(command).await?;
                 let optimized_plan = plan.optimize()?;
                 let cfg = DaftExecutionConfig::default();
                 let native_executor = NativeExecutor::from_logical_plan_builder(&optimized_plan)?;
diff --git a/src/daft-connect/src/op/execute/write.rs b/src/daft-connect/src/op/execute/write.rs
new file mode 100644
index 0000000000..9781fac770
--- /dev/null
+++ b/src/daft-connect/src/op/execute/write.rs
@@ -0,0 +1,144 @@
+use std::{collections::HashMap, future::ready};
+
+use common_daft_config::DaftExecutionConfig;
+use common_file_formats::FileFormat;
+use daft_local_execution::NativeExecutor;
+use eyre::{bail, WrapErr};
+use spark_connect::{
+    write_operation::{SaveMode, SaveType},
+    WriteOperation,
+};
+use tonic::Status;
+use tracing::warn;
+
+use crate::{
+    op::execute::{ExecuteStream, PlanIds},
+    session::Session,
+    translation,
+};
+
+impl Session {
+    pub async fn handle_write_command(
+        &self,
+        operation: WriteOperation,
+        operation_id: String,
+    ) -> Result<ExecuteStream, Status> {
+        use futures::StreamExt;
+
+        let context = PlanIds {
+            session: self.client_side_session_id().to_string(),
+            server_side_session: self.server_side_session_id().to_string(),
+            operation: operation_id,
+        };
+
+        let finished = context.finished();
+
+        let result = async move {
+            let WriteOperation {
+                input,
+                source,
+                mode,
+                sort_column_names,
+                partitioning_columns,
+                bucket_by,
+                options,
+                clustering_columns,
+                save_type,
+            } = operation;
+
+            let Some(input) = input else {
+                bail!("Input is required");
+            };
+
+            let Some(source) = source else {
+                bail!("Source is required");
+            };
+
+            if source != "parquet" {
+                bail!("Unsupported source: {source}; only parquet is supported");
+            }
+
+            let Ok(mode) = SaveMode::try_from(mode) else {
+                bail!("Invalid save mode: {mode}");
+            };
+
+            if !sort_column_names.is_empty() {
+                // todo(completeness): implement sort
+                warn!("Ignoring sort_column_names: {sort_column_names:?} (not yet implemented)");
+            }
+
+            if !partitioning_columns.is_empty() {
+                // todo(completeness): implement partitioning
+                warn!(
+                    "Ignoring partitioning_columns: {partitioning_columns:?} (not yet implemented)"
+                );
+            }
+
+            if let Some(bucket_by) = bucket_by {
+                // todo(completeness): implement bucketing
+                warn!("Ignoring bucket_by: {bucket_by:?} (not yet implemented)");
+            }
+
+            if !options.is_empty() {
+                // todo(completeness): implement options
+                warn!("Ignoring options: {options:?} (not yet implemented)");
+            }
+
+            if !clustering_columns.is_empty() {
+                // todo(completeness): implement clustering
+                warn!("Ignoring clustering_columns: {clustering_columns:?} (not yet implemented)");
+            }
+
+            match mode {
+                SaveMode::Unspecified => {}
+                SaveMode::Append => {}
+                SaveMode::Overwrite => {}
+                SaveMode::ErrorIfExists => {}
+                SaveMode::Ignore => {}
+            }
+
+            let Some(save_type) = save_type else {
+                bail!("Save type is required");
+            };
+
+            let path = match save_type {
+                SaveType::Path(path) => path,
+                SaveType::Table(table) => {
+                    let name = table.table_name;
+                    bail!("Tried to write to table {name} but it is not yet implemented. Try to write to a path instead.");
+                }
+            };
+
+            let plan = translation::to_logical_plan(input).await?;
+
+            let plan = plan
+                .table_write(&path, FileFormat::Parquet, None, None, None)
+                .wrap_err("Failed to create table write plan")?;
+
+            let optimized_plan = plan.optimize()?;
+            let cfg = DaftExecutionConfig::default();
+            let native_executor = NativeExecutor::from_logical_plan_builder(&optimized_plan)?;
+            let mut result_stream = native_executor
+                .run(HashMap::new(), cfg.into(), None)?
+                .into_stream();
+
+            // this is so we make sure the operation is actually done
+            // before we return
+            //
+            // an example where this is important is if we write to a parquet file
+            // and then read immediately after, we need to wait for the write to finish
+            while let Some(_result) = result_stream.next().await {}
+
+            Ok(())
+        };
+
+        use futures::TryFutureExt;
+
+        let result = result.map_err(|e| Status::internal(format!("Error in Daft server: {e:?}")));
+
+        let future = result.and_then(|_| ready(Ok(finished)));
+        let stream = futures::stream::once(future);
+
+        Ok(Box::pin(stream))
+    }
+}
diff --git a/src/daft-connect/src/translation/logical_plan.rs b/src/daft-connect/src/translation/logical_plan.rs
index 93c9e9bd4a..e168676d1f 100644
--- a/src/daft-connect/src/translation/logical_plan.rs
+++ b/src/daft-connect/src/translation/logical_plan.rs
@@ -3,13 +3,16 @@ use eyre::{bail, Context};
 use spark_connect::{relation::RelType, Limit, Relation};
 use tracing::warn;
 
-use crate::translation::logical_plan::{aggregate::aggregate, project::project, range::range};
+use crate::translation::logical_plan::{
+    aggregate::aggregate, project::project, range::range, read::read,
+};
 
 mod aggregate;
 mod project;
 mod range;
+mod read;
 
-pub fn to_logical_plan(relation: Relation) -> eyre::Result<LogicalPlanBuilder> {
+pub async fn to_logical_plan(relation: Relation) -> eyre::Result<LogicalPlanBuilder> {
     if let Some(common) = relation.common {
         warn!("Ignoring common metadata for relation: {common:?}; not yet implemented");
     };
@@ -19,24 +22,33 @@ pub fn to_logical_plan(relation: Relation) -> eyre::Result<LogicalPlanBuilder> {
     };
 
     match rel_type {
-        RelType::Limit(l) => limit(*l).wrap_err("Failed to apply limit to logical plan"),
+        RelType::Limit(l) => limit(*l)
+            .await
+            .wrap_err("Failed to apply limit to logical plan"),
         RelType::Range(r) => range(r).wrap_err("Failed to apply range to logical plan"),
-        RelType::Project(p) => project(*p).wrap_err("Failed to apply project to logical plan"),
-        RelType::Aggregate(a) => {
-            aggregate(*a).wrap_err("Failed to apply aggregate to logical plan")
-        }
+        RelType::Project(p) => project(*p)
+            .await
+            .wrap_err("Failed to apply project to logical plan"),
+        RelType::Aggregate(a) => aggregate(*a)
+            .await
+            .wrap_err("Failed to apply aggregate to logical plan"),
+        RelType::Read(r) => read(r)
+            .await
+            .wrap_err("Failed to apply table read to logical plan"),
         plan => bail!("Unsupported relation type: {plan:?}"),
     }
 }
 
-fn limit(limit: Limit) -> eyre::Result<LogicalPlanBuilder> {
+async fn limit(limit: Limit) -> eyre::Result<LogicalPlanBuilder> {
     let Limit { input, limit } = limit;
 
     let Some(input) = input else {
         bail!("input must be set");
     };
 
-    let plan = to_logical_plan(*input)?.limit(i64::from(limit), false)?; // todo: eager or no
+    let plan = Box::pin(to_logical_plan(*input))
+        .await?
+        .limit(i64::from(limit), false)?; // todo: eager or no
 
     Ok(plan)
 }
diff --git a/src/daft-connect/src/translation/logical_plan/aggregate.rs b/src/daft-connect/src/translation/logical_plan/aggregate.rs
index 193ca4d088..a9500cc308 100644
--- a/src/daft-connect/src/translation/logical_plan/aggregate.rs
+++ b/src/daft-connect/src/translation/logical_plan/aggregate.rs
@@ -4,7 +4,7 @@ use spark_connect::aggregate::GroupType;
 
 use crate::translation::{to_daft_expr, to_logical_plan};
 
-pub fn aggregate(aggregate: spark_connect::Aggregate) -> eyre::Result<LogicalPlanBuilder> {
+pub async fn aggregate(aggregate: spark_connect::Aggregate) -> eyre::Result<LogicalPlanBuilder> {
     let spark_connect::Aggregate {
         input,
         group_type,
@@ -18,7 +18,7 @@ pub fn aggregate(aggregate: spark_connect::Aggregate) -> eyre::Result<LogicalPlanBuilder> {
         bail!("input is required");
     };
 
-    let plan = to_logical_plan(*input)?;
+    let plan = Box::pin(to_logical_plan(*input)).await?;
diff --git a/src/daft-connect/src/translation/logical_plan/project.rs b/src/daft-connect/src/translation/logical_plan/project.rs
--- a/src/daft-connect/src/translation/logical_plan/project.rs
+++ b/src/daft-connect/src/translation/logical_plan/project.rs
@@ -5,12 +5,12 @@
 use crate::translation::{to_daft_expr, to_logical_plan};
 
-pub fn project(project: Project) -> eyre::Result<LogicalPlanBuilder> {
+pub async fn project(project: Project) -> eyre::Result<LogicalPlanBuilder> {
     let Project { input, expressions } = project;
 
     let Some(input) = input else {
         bail!("Project input is required");
     };
 
-    let plan = to_logical_plan(*input)?;
+    let plan = Box::pin(to_logical_plan(*input)).await?;
 
     let daft_exprs: Vec<_> = expressions.iter().map(to_daft_expr).try_collect()?;
diff --git a/src/daft-connect/src/translation/logical_plan/read.rs b/src/daft-connect/src/translation/logical_plan/read.rs
new file mode 100644
index 0000000000..af7aed29dc
--- /dev/null
+++ b/src/daft-connect/src/translation/logical_plan/read.rs
@@ -0,0 +1,29 @@
+use daft_logical_plan::LogicalPlanBuilder;
+use eyre::{bail, WrapErr};
+use spark_connect::read::ReadType;
+use tracing::warn;
+
+mod data_source;
+
+pub async fn read(read: spark_connect::Read) -> eyre::Result<LogicalPlanBuilder> {
+    let spark_connect::Read {
+        is_streaming,
+        read_type,
+    } = read;
+
+    warn!("Ignoring is_streaming: {is_streaming}");
+
+    let Some(read_type) = read_type else {
+        bail!("Read type is required");
+    };
+
+    match read_type {
+        ReadType::NamedTable(table) => {
+            let name = table.unparsed_identifier;
+            bail!("Tried to read from table {name} but it is not yet implemented. Try to read from a path instead.");
+        }
+        ReadType::DataSource(source) => data_source::data_source(source)
+            .await
+            .wrap_err("Failed to create data source"),
+    }
+}
diff --git a/src/daft-connect/src/translation/logical_plan/read/data_source.rs b/src/daft-connect/src/translation/logical_plan/read/data_source.rs
new file mode 100644
index 0000000000..90164dd0bd
--- /dev/null
+++ b/src/daft-connect/src/translation/logical_plan/read/data_source.rs
@@ -0,0 +1,45 @@
+use daft_logical_plan::LogicalPlanBuilder;
+use daft_scan::builder::ParquetScanBuilder;
+use eyre::{bail, ensure, WrapErr};
+use tracing::warn;
+
+pub async fn data_source(
+    data_source: spark_connect::read::DataSource,
+) -> eyre::Result<LogicalPlanBuilder> {
+    let spark_connect::read::DataSource {
+        format,
+        schema,
+        options,
+        paths,
+        predicates,
+    } = data_source;
+
+    let Some(format) = format else {
+        bail!("Format is required");
+    };
+
+    if format != "parquet" {
+        bail!("Unsupported format: {format}; only parquet is supported");
+    }
+
+    ensure!(!paths.is_empty(), "Paths are required");
+
+    if let Some(schema) = schema {
+        warn!("Ignoring schema: {schema:?}; not yet implemented");
+    }
+
+    if !options.is_empty() {
+        warn!("Ignoring options: {options:?}; not yet implemented");
+    }
+
+    if !predicates.is_empty() {
+        warn!("Ignoring predicates: {predicates:?}; not yet implemented");
+    }
+
+    let builder = ParquetScanBuilder::new(paths)
+        .finish()
+        .await
+        .wrap_err("Failed to create parquet scan builder")?;
+
+    Ok(builder)
+}
diff --git a/src/daft-connect/src/translation/schema.rs b/src/daft-connect/src/translation/schema.rs
index 1b242428d2..20b3e4de74 100644
--- a/src/daft-connect/src/translation/schema.rs
+++ b/src/daft-connect/src/translation/schema.rs
@@ -7,12 +7,12 @@ use tracing::warn;
 
 use crate::translation::{to_logical_plan, to_spark_datatype};
 
 #[tracing::instrument(skip_all)]
-pub fn relation_to_schema(input: Relation) -> eyre::Result<DataType> {
+pub async fn relation_to_schema(input: Relation) -> eyre::Result<DataType> {
     if input.common.is_some() {
         warn!("We do not currently look at common fields");
     }
 
-    let plan = to_logical_plan(input)?;
+    let plan = Box::pin(to_logical_plan(input)).await?;
 
     let result = plan.schema();
diff --git a/tests/connect/test_parquet.py b/tests/connect/test_parquet.py
new file mode 100644
index 0000000000..b356254fdf
--- /dev/null
+++ b/tests/connect/test_parquet.py
@@ -0,0 +1,36 @@
+from __future__ import annotations
+
+import os
+import shutil
+import tempfile
+
+
+def test_write_parquet(spark_session):
+    # Create a temporary directory
+    temp_dir = tempfile.mkdtemp()
+    try:
+        # Create DataFrame from range(10)
+        df = spark_session.range(10)
+
+        # Write DataFrame to parquet directory
+        parquet_dir = os.path.join(temp_dir, "test.parquet")
+        df.write.parquet(parquet_dir)
+
+        # List all files in the parquet directory
+        parquet_files = [f for f in os.listdir(parquet_dir) if f.endswith(".parquet")]
+        print(f"Parquet files in directory: {parquet_files}")
+
+        # Assert there is at least one parquet file
+        assert len(parquet_files) > 0, "Expected at least one parquet file to be written"
+
+        # Read back from the parquet directory (not specific file)
+        df_read = spark_session.read.parquet(parquet_dir)
+
+        # Verify the data is unchanged
+        df_pandas = df.toPandas()
+        df_read_pandas = df_read.toPandas()
+        assert df_pandas["id"].equals(df_read_pandas["id"]), "Data should be unchanged after write/read"
+
+    finally:
+        # Clean up temp directory
+        shutil.rmtree(temp_dir)