diff --git a/Cargo.lock b/Cargo.lock
index 34b37bc81f..fd01681f0b 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1980,6 +1980,7 @@ dependencies = [
  "async-stream",
  "color-eyre",
  "common-daft-config",
+ "common-file-formats",
  "daft-core",
  "daft-dsl",
  "daft-local-execution",
diff --git a/Cargo.toml b/Cargo.toml
index e8e9992e2c..b6f5284a60 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -194,6 +194,7 @@ chrono-tz = "0.10.0"
 comfy-table = "7.1.1"
 common-daft-config = {path = "src/common/daft-config"}
 common-error = {path = "src/common/error", default-features = false}
+common-file-formats = {path = "src/common/file-formats"}
 common-runtime = {path = "src/common/runtime", default-features = false}
 daft-core = {path = "src/daft-core"}
 daft-dsl = {path = "src/daft-dsl"}
diff --git a/src/daft-connect/Cargo.toml b/src/daft-connect/Cargo.toml
index ab59be1796..b1d1f63052 100644
--- a/src/daft-connect/Cargo.toml
+++ b/src/daft-connect/Cargo.toml
@@ -3,6 +3,7 @@ arrow2 = {workspace = true, features = ["io_json_integration"]}
 async-stream = "0.3.6"
 color-eyre = "0.6.3"
 common-daft-config = {workspace = true}
+common-file-formats = {workspace = true}
 daft-core = {workspace = true}
 daft-dsl = {workspace = true}
 daft-local-execution = {workspace = true}
diff --git a/src/daft-connect/src/lib.rs b/src/daft-connect/src/lib.rs
index efc861b986..439a74dc57 100644
--- a/src/daft-connect/src/lib.rs
+++ b/src/daft-connect/src/lib.rs
@@ -142,14 +142,10 @@ impl DaftSparkConnectService {
 #[tonic::async_trait]
 impl SparkConnectService for DaftSparkConnectService {
     type ExecutePlanStream = std::pin::Pin<
-        Box<
-            dyn futures::Stream<Item = Result<ExecutePlanResponse, Status>> + Send + Sync + 'static,
-        >,
+        Box<dyn futures::Stream<Item = Result<ExecutePlanResponse, Status>> + Send + 'static>,
     >;
     type ReattachExecuteStream = std::pin::Pin<
-        Box<
-            dyn futures::Stream<Item = Result<ExecutePlanResponse, Status>> + Send + Sync + 'static,
-        >,
+        Box<dyn futures::Stream<Item = Result<ExecutePlanResponse, Status>> + Send + 'static>,
     >;
 
     #[tracing::instrument(skip_all)]
@@ -190,8 +186,9 @@ impl SparkConnectService for DaftSparkConnectService {
             CommandType::RegisterFunction(_) => {
                 unimplemented_err!("RegisterFunction not implemented")
             }
-            CommandType::WriteOperation(_) => {
-                unimplemented_err!("WriteOperation not implemented")
+            CommandType::WriteOperation(op) => {
+                let result = session.handle_write_command(op, operation).await?;
+                return Ok(Response::new(result));
             }
             CommandType::CreateDataframeView(_) => {
                 unimplemented_err!("CreateDataframeView not implemented")
             }
@@ -305,7 +302,7 @@ impl SparkConnectService for DaftSparkConnectService {
             return Err(Status::invalid_argument("op_type is required to be root"));
         };
 
-        let result = match translation::relation_to_schema(relation) {
+        let result = match translation::relation_to_schema(relation).await {
             Ok(schema) => schema,
             Err(e) => {
                 return invalid_argument_err!(
diff --git a/src/daft-connect/src/op/execute.rs b/src/daft-connect/src/op/execute.rs
index fba3cc850d..41baf88b09 100644
--- a/src/daft-connect/src/op/execute.rs
+++ b/src/daft-connect/src/op/execute.rs
@@ -11,6 +11,7 @@ use uuid::Uuid;
 use crate::{DaftSparkConnectService, Session};
 
 mod root;
+mod write;
 
 pub type ExecuteStream = <DaftSparkConnectService as SparkConnectService>::ExecutePlanStream;
 
diff --git a/src/daft-connect/src/op/execute/root.rs b/src/daft-connect/src/op/execute/root.rs
index fcd1f41bb9..d00071522e 100644
--- a/src/daft-connect/src/op/execute/root.rs
+++ b/src/daft-connect/src/op/execute/root.rs
@@ -32,7 +32,7 @@ impl Session {
         let (tx, rx) = tokio::sync::mpsc::channel::<eyre::Result<ExecutePlanResponse>>(1);
         tokio::spawn(async move {
             let execution_fut = async {
-                let Plan { builder, psets } = translation::to_logical_plan(command)?;
+                let Plan { builder, psets } = translation::to_logical_plan(command).await?;
                 let optimized_plan = builder.optimize()?;
                 let cfg = Arc::new(DaftExecutionConfig::default());
                 let native_executor = NativeExecutor::from_logical_plan_builder(&optimized_plan)?;
diff --git a/src/daft-connect/src/op/execute/write.rs b/src/daft-connect/src/op/execute/write.rs
new file mode 100644
index 0000000000..44696f8164
--- /dev/null
+++ b/src/daft-connect/src/op/execute/write.rs
@@ -0,0 +1,145 @@
+use std::future::ready;
+
+use common_daft_config::DaftExecutionConfig;
+use common_file_formats::FileFormat;
+use daft_local_execution::NativeExecutor;
+use eyre::{bail, WrapErr};
+use spark_connect::{
+    write_operation::{SaveMode, SaveType},
+    WriteOperation,
+};
+use tonic::Status;
+use tracing::warn;
+
+use crate::{
+    op::execute::{ExecuteStream, PlanIds},
+    session::Session,
+    translation,
+};
+
+impl Session {
+    pub async fn handle_write_command(
+        &self,
+        operation: WriteOperation,
+        operation_id: String,
+    ) -> Result<ExecuteStream, Status> {
+        use futures::StreamExt;
+
+        let context = PlanIds {
+            session: self.client_side_session_id().to_string(),
+            server_side_session: self.server_side_session_id().to_string(),
+            operation: operation_id,
+        };
+
+        let finished = context.finished();
+
+        let result = async move {
+            let WriteOperation {
+                input,
+                source,
+                mode,
+                sort_column_names,
+                partitioning_columns,
+                bucket_by,
+                options,
+                clustering_columns,
+                save_type,
+            } = operation;
+
+            let Some(input) = input else {
+                bail!("Input is required");
+            };
+
+            let Some(source) = source else {
+                bail!("Source is required");
+            };
+
+            if source != "parquet" {
+                bail!("Unsupported source: {source}; only parquet is supported");
+            }
+
+            let Ok(mode) = SaveMode::try_from(mode) else {
+                bail!("Invalid save mode: {mode}");
+            };
+
+            if !sort_column_names.is_empty() {
+                // todo(completeness): implement sort
+                warn!("Ignoring sort_column_names: {sort_column_names:?} (not yet implemented)");
+            }
+
+            if !partitioning_columns.is_empty() {
+                // todo(completeness): implement partitioning
+                warn!(
+                    "Ignoring partitioning_columns: {partitioning_columns:?} (not yet implemented)"
+                );
+            }
+
+            if let Some(bucket_by) = bucket_by {
+                // todo(completeness): implement bucketing
+                warn!("Ignoring bucket_by: {bucket_by:?} (not yet implemented)");
+            }
+
+            if !options.is_empty() {
+                // todo(completeness): implement options
+                warn!("Ignoring options: {options:?} (not yet implemented)");
+            }
+
+            if !clustering_columns.is_empty() {
+                // todo(completeness): implement clustering
+                warn!("Ignoring clustering_columns: {clustering_columns:?} (not yet implemented)");
+            }
+
+            match mode {
+                SaveMode::Unspecified => {}
+                SaveMode::Append => {}
+                SaveMode::Overwrite => {}
+                SaveMode::ErrorIfExists => {}
+                SaveMode::Ignore => {}
+            }
+
+            let Some(save_type) = save_type else {
+                bail!("Save type is required");
+            };
+
+            let path = match save_type {
+                SaveType::Path(path) => path,
+                SaveType::Table(table) => {
+                    let name = table.table_name;
+                    bail!("Tried to write to table {name} but it is not yet implemented. Try to write to a path instead.");
+                }
+            };
+
+            let mut plan = translation::to_logical_plan(input).await?;
+
+            plan.builder = plan
+                .builder
+                .table_write(&path, FileFormat::Parquet, None, None, None)
+                .wrap_err("Failed to create table write plan")?;
+
+            let optimized_plan = plan.builder.optimize()?;
+            let cfg = DaftExecutionConfig::default();
+            let native_executor = NativeExecutor::from_logical_plan_builder(&optimized_plan)?;
+            let mut result_stream = native_executor
+                .run(plan.psets, cfg.into(), None)?
+                .into_stream();
+
+            // this is so we make sure the operation is actually done
+            // before we return
+            //
+            // an example where this is important is if we write to a parquet file
+            // and then read immediately after, we need to wait for the write to finish
+            while let Some(_result) = result_stream.next().await {}
+
+            Ok(())
+        };
+
+        use futures::TryFutureExt;
+
+        let result = result.map_err(|e| Status::internal(format!("Error in Daft server: {e:?}")));
+
+        let future = result.and_then(|()| ready(Ok(finished)));
+        let stream = futures::stream::once(future);
+
+        Ok(Box::pin(stream))
+    }
+}
diff --git a/src/daft-connect/src/translation/logical_plan.rs b/src/daft-connect/src/translation/logical_plan.rs
index 0af833c54d..9cc96172ea 100644
--- a/src/daft-connect/src/translation/logical_plan.rs
+++ b/src/daft-connect/src/translation/logical_plan.rs
@@ -6,13 +6,14 @@ use tracing::warn;
 
 use crate::translation::logical_plan::{
     aggregate::aggregate, local_relation::local_relation, project::project, range::range,
-    to_df::to_df, with_columns::with_columns,
+    read::read, to_df::to_df, with_columns::with_columns,
 };
 
 mod aggregate;
 mod local_relation;
 mod project;
 mod range;
+mod read;
 mod to_df;
 mod with_columns;
 
@@ -39,7 +40,7 @@ impl From<LogicalPlanBuilder> for Plan {
     }
 }
 
-pub fn to_logical_plan(relation: Relation) -> eyre::Result<Plan> {
+pub async fn to_logical_plan(relation: Relation) -> eyre::Result<Plan> {
     if let Some(common) = relation.common {
         if common.origin.is_some() {
             warn!("Ignoring common metadata for relation: {common:?}; not yet implemented");
@@ -51,31 +52,40 @@ pub fn to_logical_plan(relation: Relation) -> eyre::Result<Plan> {
     };
 
     match rel_type {
-        RelType::Limit(l) => limit(*l).wrap_err("Failed to apply limit to logical plan"),
+        RelType::Limit(l) => limit(*l)
+            .await
+            .wrap_err("Failed to apply limit to logical plan"),
         RelType::Range(r) => range(r).wrap_err("Failed to apply range to logical plan"),
-        RelType::Project(p) => project(*p).wrap_err("Failed to apply project to logical plan"),
-        RelType::Aggregate(a) => {
-            aggregate(*a).wrap_err("Failed to apply aggregate to logical plan")
-        }
+        RelType::Project(p) => project(*p)
+            .await
+            .wrap_err("Failed to apply project to logical plan"),
+        RelType::Aggregate(a) => aggregate(*a)
+            .await
+            .wrap_err("Failed to apply aggregate to logical plan"),
         RelType::WithColumns(w) => {
             with_columns(*w).wrap_err("Failed to apply with_columns to logical plan")
         }
-        RelType::ToDf(t) => to_df(*t).wrap_err("Failed to apply to_df to logical plan"),
+        RelType::ToDf(t) => to_df(*t)
+            .await
+            .wrap_err("Failed to apply to_df to logical plan"),
         RelType::LocalRelation(l) => {
             local_relation(l).wrap_err("Failed to apply local_relation to logical plan")
         }
+        RelType::Read(r) => read(r)
+            .await
+            .wrap_err("Failed to apply read to logical plan"),
         plan => bail!("Unsupported relation type: {plan:?}"),
     }
 }
 
-fn limit(limit: Limit) -> eyre::Result<Plan> {
+async fn limit(limit: Limit) -> eyre::Result<Plan> {
     let Limit { input, limit } = limit;
 
     let Some(input) = input else {
        bail!("input must be set");
     };
 
-    let mut plan = to_logical_plan(*input)?;
+    let mut plan = Box::pin(to_logical_plan(*input)).await?;
     plan.builder = plan.builder.limit(i64::from(limit), false)?; // todo: eager or no
     Ok(plan)
 }
diff --git a/src/daft-connect/src/translation/logical_plan/aggregate.rs b/src/daft-connect/src/translation/logical_plan/aggregate.rs
index 5fbf9d6a9b..3687f191f8 100644
--- a/src/daft-connect/src/translation/logical_plan/aggregate.rs
+++ b/src/daft-connect/src/translation/logical_plan/aggregate.rs
@@ -3,7 +3,7 @@ use spark_connect::aggregate::GroupType;
 
 use crate::translation::{logical_plan::Plan, to_daft_expr, to_logical_plan};
 
-pub fn aggregate(aggregate: spark_connect::Aggregate) -> eyre::Result<Plan> {
+pub async fn aggregate(aggregate: spark_connect::Aggregate) -> eyre::Result<Plan> {
     let spark_connect::Aggregate {
         input,
         group_type,
@@ -17,7 +17,7 @@ pub fn aggregate(aggregate: spark_connect::Aggregate) -> eyre::Result<Plan> {
         bail!("input is required");
     };
 
-    let mut plan = to_logical_plan(*input)?;
+    let mut plan = Box::pin(to_logical_plan(*input)).await?;
 
     let group_type = GroupType::try_from(group_type)
         .wrap_err_with(|| format!("Invalid group type: {group_type:?}"))?;
diff --git a/src/daft-connect/src/translation/logical_plan/project.rs b/src/daft-connect/src/translation/logical_plan/project.rs
index b5c1a136ec..af03c8dc2e 100644
--- a/src/daft-connect/src/translation/logical_plan/project.rs
+++ b/src/daft-connect/src/translation/logical_plan/project.rs
@@ -8,14 +8,14 @@ use spark_connect::Project;
 
 use crate::translation::{logical_plan::Plan, to_daft_expr, to_logical_plan};
 
-pub fn project(project: Project) -> eyre::Result<Plan> {
+pub async fn project(project: Project) -> eyre::Result<Plan> {
     let Project { input, expressions } = project;
 
     let Some(input) = input else {
         bail!("Project input is required");
     };
 
-    let mut plan = to_logical_plan(*input)?;
+    let mut plan = Box::pin(to_logical_plan(*input)).await?;
 
     let daft_exprs: Vec<_> = expressions.iter().map(to_daft_expr).try_collect()?;
     plan.builder = plan.builder.select(daft_exprs)?;
diff --git a/src/daft-connect/src/translation/logical_plan/read.rs b/src/daft-connect/src/translation/logical_plan/read.rs
new file mode 100644
index 0000000000..fc8a834fbb
--- /dev/null
+++ b/src/daft-connect/src/translation/logical_plan/read.rs
@@ -0,0 +1,32 @@
+use eyre::{bail, WrapErr};
+use spark_connect::read::ReadType;
+use tracing::warn;
+
+use crate::translation::Plan;
+
+mod data_source;
+
+pub async fn read(read: spark_connect::Read) -> eyre::Result<Plan> {
+    let spark_connect::Read {
+        is_streaming,
+        read_type,
+    } = read;
+
+    warn!("Ignoring is_streaming: {is_streaming}");
+
+    let Some(read_type) = read_type else {
+        bail!("Read type is required");
+    };
+
+    let builder = match read_type {
+        ReadType::NamedTable(table) => {
+            let name = table.unparsed_identifier;
+            bail!("Tried to read from table {name} but it is not yet implemented. Try to read from a path instead.");
+        }
+        ReadType::DataSource(source) => data_source::data_source(source)
+            .await
+            .wrap_err("Failed to create data source"),
+    }?;
+
+    Ok(Plan::from(builder))
+}
diff --git a/src/daft-connect/src/translation/logical_plan/read/data_source.rs b/src/daft-connect/src/translation/logical_plan/read/data_source.rs
new file mode 100644
index 0000000000..90164dd0bd
--- /dev/null
+++ b/src/daft-connect/src/translation/logical_plan/read/data_source.rs
@@ -0,0 +1,45 @@
+use daft_logical_plan::LogicalPlanBuilder;
+use daft_scan::builder::ParquetScanBuilder;
+use eyre::{bail, ensure, WrapErr};
+use tracing::warn;
+
+pub async fn data_source(
+    data_source: spark_connect::read::DataSource,
+) -> eyre::Result<LogicalPlanBuilder> {
+    let spark_connect::read::DataSource {
+        format,
+        schema,
+        options,
+        paths,
+        predicates,
+    } = data_source;
+
+    let Some(format) = format else {
+        bail!("Format is required");
+    };
+
+    if format != "parquet" {
+        bail!("Unsupported format: {format}; only parquet is supported");
+    }
+
+    ensure!(!paths.is_empty(), "Paths are required");
+
+    if let Some(schema) = schema {
+        warn!("Ignoring schema: {schema:?}; not yet implemented");
+    }
+
+    if !options.is_empty() {
+        warn!("Ignoring options: {options:?}; not yet implemented");
+    }
+
+    if !predicates.is_empty() {
+        warn!("Ignoring predicates: {predicates:?}; not yet implemented");
+    }
+
+    let builder = ParquetScanBuilder::new(paths)
+        .finish()
+        .await
+        .wrap_err("Failed to create parquet scan builder")?;
+
+    Ok(builder)
+}
diff --git a/src/daft-connect/src/translation/logical_plan/to_df.rs b/src/daft-connect/src/translation/logical_plan/to_df.rs
index 63ad58f10f..c2a355a1e5 100644
--- a/src/daft-connect/src/translation/logical_plan/to_df.rs
+++ b/src/daft-connect/src/translation/logical_plan/to_df.rs
@@ -2,7 +2,7 @@ use eyre::{bail, WrapErr};
 
 use crate::translation::{logical_plan::Plan, to_logical_plan};
 
-pub fn to_df(to_df: spark_connect::ToDf) -> eyre::Result<Plan> {
+pub async fn to_df(to_df: spark_connect::ToDf) -> eyre::Result<Plan> {
     let spark_connect::ToDf {
         input,
         column_names,
@@ -12,8 +12,9 @@ pub fn to_df(to_df: spark_connect::ToDf) -> eyre::Result<Plan> {
         bail!("Input is required");
     };
 
-    let mut plan =
-        to_logical_plan(*input).wrap_err("Failed to translate relation to logical plan")?;
+    let mut plan = Box::pin(to_logical_plan(*input))
+        .await
+        .wrap_err("Failed to translate relation to logical plan")?;
 
     let column_names: Vec<_> = column_names
         .iter()
diff --git a/src/daft-connect/src/translation/schema.rs b/src/daft-connect/src/translation/schema.rs
index a43d165b60..1868eaeb2d 100644
--- a/src/daft-connect/src/translation/schema.rs
+++ b/src/daft-connect/src/translation/schema.rs
@@ -7,14 +7,14 @@ use tracing::warn;
 
 use crate::translation::{to_logical_plan, to_spark_datatype};
 
 #[tracing::instrument(skip_all)]
-pub fn relation_to_schema(input: Relation) -> eyre::Result<DataType> {
+pub async fn relation_to_schema(input: Relation) -> eyre::Result<DataType> {
     if let Some(common) = &input.common {
         if common.origin.is_some() {
             warn!("Ignoring common metadata for relation: {common:?}; not yet implemented");
         }
     }
 
-    let plan = to_logical_plan(input)?;
+    let plan = Box::pin(to_logical_plan(input)).await?;
 
     let result = plan.builder.schema();
diff --git a/tests/connect/test_parquet.py b/tests/connect/test_parquet.py
new file mode 100644
index 0000000000..b356254fdf
--- /dev/null
+++ b/tests/connect/test_parquet.py
@@ -0,0 +1,36 @@
+from __future__ import annotations
+
+import os
+import shutil
+import tempfile
+
+
+def test_write_parquet(spark_session):
+    # Create a temporary directory
+    temp_dir = tempfile.mkdtemp()
+    try:
+        # Create DataFrame from range(10)
+        df = spark_session.range(10)
+
+        # Write DataFrame to parquet directory
+        parquet_dir = os.path.join(temp_dir, "test.parquet")
+        df.write.parquet(parquet_dir)
+
+        # List all files in the parquet directory
+        parquet_files = [f for f in os.listdir(parquet_dir) if f.endswith(".parquet")]
+        print(f"Parquet files in directory: {parquet_files}")
+
+        # Assert there is at least one parquet file
+        assert len(parquet_files) > 0, "Expected at least one parquet file to be written"
+
+        # Read back from the parquet directory (not specific file)
+        df_read = spark_session.read.parquet(parquet_dir)
+
+        # Verify the data is unchanged
+        df_pandas = df.toPandas()
+        df_read_pandas = df_read.toPandas()
+        assert df_pandas["id"].equals(df_read_pandas["id"]), "Data should be unchanged after write/read"
+
+    finally:
+        # Clean up temp directory
+        shutil.rmtree(temp_dir)