From 40e4b4827a011ae075505a4f2fc11311ba6325de Mon Sep 17 00:00:00 2001 From: Andrew Gazelka Date: Wed, 20 Nov 2024 02:26:15 -0800 Subject: [PATCH] [FEAT] connect: add parquet support --- Cargo.lock | 1 + Cargo.toml | 1 + src/daft-connect/Cargo.toml | 1 + src/daft-connect/src/lib.rs | 5 +- src/daft-connect/src/op/execute.rs | 1 + src/daft-connect/src/op/execute/write.rs | 206 ++++++++++++++++++ .../src/translation/logical_plan.rs | 4 +- .../src/translation/logical_plan/read.rs | 29 +++ .../logical_plan/read/data_source.rs | 42 ++++ tests/connect/test_write.py | 36 +++ 10 files changed, 323 insertions(+), 3 deletions(-) create mode 100644 src/daft-connect/src/op/execute/write.rs create mode 100644 src/daft-connect/src/translation/logical_plan/read.rs create mode 100644 src/daft-connect/src/translation/logical_plan/read/data_source.rs create mode 100644 tests/connect/test_write.py diff --git a/Cargo.lock b/Cargo.lock index a040694975..83169134b3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1933,6 +1933,7 @@ version = "0.3.0-dev0" dependencies = [ "arrow2", "common-daft-config", + "common-file-formats", "daft-core", "daft-dsl", "daft-local-execution", diff --git a/Cargo.toml b/Cargo.toml index be1146166a..18532c92b9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -194,6 +194,7 @@ chrono-tz = "0.8.4" comfy-table = "7.1.1" common-daft-config = {path = "src/common/daft-config"} common-error = {path = "src/common/error", default-features = false} +common-file-formats = {path = "src/common/file-formats"} daft-core = {path = "src/daft-core"} daft-dsl = {path = "src/daft-dsl"} daft-hash = {path = "src/daft-hash"} diff --git a/src/daft-connect/Cargo.toml b/src/daft-connect/Cargo.toml index 0d657d1c5f..326fac04ad 100644 --- a/src/daft-connect/Cargo.toml +++ b/src/daft-connect/Cargo.toml @@ -1,6 +1,7 @@ [dependencies] arrow2 = {workspace = true} common-daft-config = {workspace = true} +common-file-formats = {workspace = true} daft-core = {workspace = true} daft-dsl = {workspace = true} daft-local-execution = {workspace = true} diff --git a/src/daft-connect/src/lib.rs b/src/daft-connect/src/lib.rs index 15c7442c77..1ef94fa159 100644 --- a/src/daft-connect/src/lib.rs +++ b/src/daft-connect/src/lib.rs @@ -167,8 +167,9 @@ impl SparkConnectService for DaftSparkConnectService { CommandType::RegisterFunction(_) => { unimplemented_err!("RegisterFunction not implemented") } - CommandType::WriteOperation(_) => { - unimplemented_err!("WriteOperation not implemented") + CommandType::WriteOperation(op) => { + let result = session.handle_write_command(op, operation).await?; + return Ok(Response::new(result)) } CommandType::CreateDataframeView(_) => { unimplemented_err!("CreateDataframeView not implemented") diff --git a/src/daft-connect/src/op/execute.rs b/src/daft-connect/src/op/execute.rs index fba3cc850d..41baf88b09 100644 --- a/src/daft-connect/src/op/execute.rs +++ b/src/daft-connect/src/op/execute.rs @@ -11,6 +11,7 @@ use uuid::Uuid; use crate::{DaftSparkConnectService, Session}; mod root; +mod write; pub type ExecuteStream = ::ExecutePlanStream; diff --git a/src/daft-connect/src/op/execute/write.rs b/src/daft-connect/src/op/execute/write.rs new file mode 100644 index 0000000000..7566b33273 --- /dev/null +++ b/src/daft-connect/src/op/execute/write.rs @@ -0,0 +1,206 @@ +use std::{collections::HashMap, future::ready}; + +use common_daft_config::DaftExecutionConfig; +use common_file_formats::FileFormat; +use eyre::{bail, WrapErr}; +use futures::stream; +use spark_connect::{ + write_operation::{SaveMode, SaveType}, + ExecutePlanResponse, Relation, WriteOperation, +}; +use tokio_util::sync::CancellationToken; +use tonic::{codegen::tokio_stream::wrappers::ReceiverStream, Status}; +use tracing::warn; + +use crate::{ + invalid_argument_err, + op::execute::{ExecuteStream, PlanIds}, + session::Session, + translation, +}; + +impl Session { + pub async fn handle_write_command( + &self, + operation: WriteOperation, + operation_id: String, + ) -> Result { + use futures::{StreamExt, TryStreamExt}; + + let context = PlanIds { + session: self.client_side_session_id().to_string(), + server_side_session: self.server_side_session_id().to_string(), + operation: operation_id, + }; + + let finished = context.finished(); + + // operation: WriteOperation { + // input: Some( + // Relation { + // common: Some( + // RelationCommon { + // source_info: "", + // plan_id: Some( + // 0, + // ), + // origin: None, + // }, + // ), + // rel_type: Some( + // Range( + // Range { + // start: Some( + // 0, + // ), + // end: 10, + // step: 1, + // num_partitions: None, + // }, + // ), + // ), + // }, + // ), + // source: Some( + // "parquet", + // ), + // mode: Unspecified, + // sort_column_names: [], + // partitioning_columns: [], + // bucket_by: None, + // options: {}, + // clustering_columns: [], + // save_type: Some( + // Path( + // "/var/folders/zy/g1zccty96bg_frmz9x0198zh0000gn/T/tmpxki7yyr0/test.parquet", + // ), + // ), + // } + + let (tx, rx) = tokio::sync::mpsc::channel::>(16); + std::thread::spawn(move || { + let result = (|| -> eyre::Result<()> { + let WriteOperation { + input, + source, + mode, + sort_column_names, + partitioning_columns, + bucket_by, + options, + clustering_columns, + save_type, + } = operation; + + let Some(input) = input else { + bail!("Input is required"); + }; + + let Some(source) = source else { + bail!("Source is required"); + }; + + if source != "parquet" { + bail!("Unsupported source: {source}; only parquet is supported"); + } + + let Ok(mode) = SaveMode::try_from(mode) else { + bail!("Invalid save mode: {mode}"); + }; + + if !sort_column_names.is_empty() { + // todo(completeness): implement sort + warn!( + "Ignoring sort_column_names: {sort_column_names:?} (not yet implemented)" + ); + } + + if !partitioning_columns.is_empty() { + // todo(completeness): implement partitioning + warn!("Ignoring partitioning_columns: {partitioning_columns:?} (not yet implemented)"); + } + + if let Some(bucket_by) = bucket_by { + // todo(completeness): implement bucketing + warn!("Ignoring bucket_by: {bucket_by:?} (not yet implemented)"); + } + + if !options.is_empty() { + // todo(completeness): implement options + warn!("Ignoring options: {options:?} (not yet implemented)"); + } + + if !clustering_columns.is_empty() { + // todo(completeness): implement clustering + warn!( + "Ignoring clustering_columns: {clustering_columns:?} (not yet implemented)" + ); + } + + match mode { + SaveMode::Unspecified => {} + SaveMode::Append => {} + SaveMode::Overwrite => {} + SaveMode::ErrorIfExists => {} + SaveMode::Ignore => {} + } + + let Some(save_type) = save_type else { + return bail!("Save type is required"); + }; + + let path = match save_type { + SaveType::Path(path) => path, + SaveType::Table(table) => { + let name = table.table_name; + bail!("Tried to write to table {name} but it is not yet implemented. Try to write to a path instead."); + } + }; + + let plan = translation::to_logical_plan(input)?; + + let plan = plan + .table_write(&path, FileFormat::Parquet, None, None, None) + .wrap_err("Failed to create table write plan")?; + + let logical_plan = plan.build(); + let physical_plan = daft_local_plan::translate(&logical_plan)?; + + let cfg = DaftExecutionConfig::default(); + + // "hot" flow not a "cold" flow + let iterator = daft_local_execution::run_local( + &physical_plan, + HashMap::new(), + cfg.into(), + None, + CancellationToken::new(), // todo: maybe implement cancelling + )?; + + for _ignored in iterator { + + } + + // this is so we make sure the operation is actually done + // before we return + // + // an example where this is important is if we write to a parquet file + // and then read immediately after, we need to wait for the write to finish + + Ok(()) + })(); + + if let Err(e) = result { + tx.blocking_send(Err(e)).unwrap(); + } + }); + + let stream = ReceiverStream::new(rx); + + let stream = stream + .map_err(|e| Status::internal(format!("Error in Daft server: {e:?}"))) + .chain(stream::once(ready(Ok(finished)))); + + Ok(Box::pin(stream)) + } +} diff --git a/src/daft-connect/src/translation/logical_plan.rs b/src/daft-connect/src/translation/logical_plan.rs index 947e0cd0d3..6ab9ca587f 100644 --- a/src/daft-connect/src/translation/logical_plan.rs +++ b/src/daft-connect/src/translation/logical_plan.rs @@ -3,10 +3,11 @@ use eyre::{bail, Context}; use spark_connect::{relation::RelType, Relation}; use tracing::warn; -use crate::translation::logical_plan::{aggregate::aggregate, project::project, range::range}; +use crate::translation::logical_plan::{aggregate::aggregate, project::project, range::range, read::read}; mod aggregate; mod project; +mod read; mod range; pub fn to_logical_plan(relation: Relation) -> eyre::Result { @@ -24,6 +25,7 @@ pub fn to_logical_plan(relation: Relation) -> eyre::Result { RelType::Aggregate(a) => { aggregate(*a).wrap_err("Failed to apply aggregate to logical plan") } + RelType::Read(r) => read(r).wrap_err("Failed to apply table read to logical plan"), plan => bail!("Unsupported relation type: {plan:?}"), } } diff --git a/src/daft-connect/src/translation/logical_plan/read.rs b/src/daft-connect/src/translation/logical_plan/read.rs new file mode 100644 index 0000000000..199d77da4b --- /dev/null +++ b/src/daft-connect/src/translation/logical_plan/read.rs @@ -0,0 +1,29 @@ +use daft_logical_plan::LogicalPlanBuilder; +use eyre::{bail, WrapErr}; +use spark_connect::read::ReadType; +use tracing::warn; + +mod data_source; + +pub fn read(read: spark_connect::Read) -> eyre::Result { + let spark_connect::Read { + is_streaming, + read_type, + } = read; + + warn!("Ignoring is_streaming: {is_streaming}"); + + let Some(read_type) = read_type else { + bail!("Read type is required"); + }; + + match read_type { + ReadType::NamedTable(table) => { + let name = table.unparsed_identifier; + bail!("Tried to read from table {name} but it is not yet implemented. Try to read from a path instead."); + } + ReadType::DataSource(source) => { + data_source::data_source(source).wrap_err("Failed to create data source") + } + } +} diff --git a/src/daft-connect/src/translation/logical_plan/read/data_source.rs b/src/daft-connect/src/translation/logical_plan/read/data_source.rs new file mode 100644 index 0000000000..0a9a14c494 --- /dev/null +++ b/src/daft-connect/src/translation/logical_plan/read/data_source.rs @@ -0,0 +1,42 @@ +use daft_logical_plan::LogicalPlanBuilder; +use daft_scan::builder::ParquetScanBuilder; +use eyre::{bail, ensure, WrapErr}; +use tracing::warn; + +pub fn data_source(data_source: spark_connect::read::DataSource) -> eyre::Result { + let spark_connect::read::DataSource { + format, + schema, + options, + paths, + predicates, + } = data_source; + + let Some(format) = format else { + bail!("Format is required"); + }; + + if format != "parquet" { + bail!("Unsupported format: {format}; only parquet is supported"); + } + + ensure!(!paths.is_empty(), "Paths are required"); + + if let Some(schema) = schema { + warn!("Ignoring schema: {schema:?}; not yet implemented"); + } + + if !options.is_empty() { + warn!("Ignoring options: {options:?}; not yet implemented"); + } + + if !predicates.is_empty() { + warn!("Ignoring predicates: {predicates:?}; not yet implemented"); + } + + let builder = ParquetScanBuilder::new(paths) + .finish() + .wrap_err("Failed to create parquet scan builder")?; + + Ok(builder) +} diff --git a/tests/connect/test_write.py b/tests/connect/test_write.py new file mode 100644 index 0000000000..75eae16571 --- /dev/null +++ b/tests/connect/test_write.py @@ -0,0 +1,36 @@ +from __future__ import annotations + +import tempfile +import shutil +import os + + +def test_write_parquet(spark_session): + # Create a temporary directory + temp_dir = tempfile.mkdtemp() + try: + # Create DataFrame from range(10) + df = spark_session.range(10) + + # Write DataFrame to parquet directory + parquet_dir = os.path.join(temp_dir, "test.parquet") + df.write.parquet(parquet_dir) + + # List all files in the parquet directory + parquet_files = [f for f in os.listdir(parquet_dir) if f.endswith('.parquet')] + print(f"Parquet files in directory: {parquet_files}") + + # Assert there is at least one parquet file + assert len(parquet_files) > 0, "Expected at least one parquet file to be written" + + # Read back from the parquet directory (not specific file) + df_read = spark_session.read.parquet(parquet_dir) + + # Verify the data is unchanged + df_pandas = df.toPandas() + df_read_pandas = df_read.toPandas() + assert df_pandas["id"].equals(df_read_pandas["id"]), "Data should be unchanged after write/read" + + finally: + # Clean up temp directory + shutil.rmtree(temp_dir)