From 69b2c290cbb0489b86f312b2102e4389e0893d26 Mon Sep 17 00:00:00 2001 From: Andrew Gazelka Date: Wed, 20 Nov 2024 03:27:48 -0800 Subject: [PATCH] [FEAT] (WIP) connect: createDataFrame --- Cargo.lock | 64 ++++- Cargo.toml | 1 + daft/dataframe/dataframe.py | 28 ++ src/arrow2/src/io/ipc/read/common.rs | 2 +- src/arrow2/src/io/ipc/read/mod.rs | 2 +- src/daft-connect/Cargo.toml | 9 +- src/daft-connect/src/lib.rs | 22 +- src/daft-connect/src/op/execute/root.rs | 11 +- src/daft-connect/src/translation.rs | 4 +- src/daft-connect/src/translation/datatype.rs | 3 + .../src/translation/datatype/codec.rs | 244 ++++++++++++++++++ .../src/translation/logical_plan.rs | 37 ++- .../src/translation/logical_plan/aggregate.rs | 13 +- .../logical_plan/local_relation.rs | 201 +++++++++++++++ .../src/translation/logical_plan/project.rs | 10 +- .../src/translation/logical_plan/range.rs | 6 +- .../src/translation/logical_plan/to_df.rs | 29 +++ src/daft-connect/src/translation/schema.rs | 2 +- src/daft-local-execution/src/pipeline.rs | 4 +- src/daft-logical-plan/src/builder.rs | 2 +- tests/connect/conftest.py | 1 + tests/connect/test_create_df.py | 35 +++ 22 files changed, 679 insertions(+), 51 deletions(-) create mode 100644 src/daft-connect/src/translation/datatype/codec.rs create mode 100644 src/daft-connect/src/translation/logical_plan/local_relation.rs create mode 100644 src/daft-connect/src/translation/logical_plan/to_df.rs create mode 100644 tests/connect/test_create_df.py diff --git a/Cargo.lock b/Cargo.lock index 690ac121e8..b7f8831220 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -10,9 +10,9 @@ checksum = "8b5ace29ee3216de37c0546865ad08edef58b0f9e76838ed8959a84a990e58c5" [[package]] name = "addr2line" -version = "0.22.0" +version = "0.21.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6e4503c46a5c0c7844e948c9a4d6acd9f50cccb4de1c48eb9e291ea17470c678" +checksum = "8a30b2e23b9e17a9f90641c7ab1549cd9b44f296d3ccbf309d2863cfe398a0cb" dependencies = [ "gimli", ] @@ -983,9 +983,9 @@ dependencies = [ [[package]] name = "backtrace" -version = "0.3.73" +version = "0.3.71" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5cc23269a4f8976d0a4d2e7109211a419fe30e8d88d677cd60b6bc79c5732e0a" +checksum = "26b05800d2e817c8b3b4b54abd461726265fa9789ae34330622f2db9ee696f9d" dependencies = [ "addr2line", "cc", @@ -1391,6 +1391,33 @@ dependencies = [ "cc", ] +[[package]] +name = "color-eyre" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "55146f5e46f237f7423d74111267d4597b59b0dad0ffaf7303bce9945d843ad5" +dependencies = [ + "backtrace", + "color-spantrace", + "eyre", + "indenter", + "once_cell", + "owo-colors", + "tracing-error", +] + +[[package]] +name = "color-spantrace" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd6be1b2a7e382e2b98b43b2adcca6bb0e465af0bdd38123873ae61eb17a72c2" +dependencies = [ + "once_cell", + "owo-colors", + "tracing-core", + "tracing-error", +] + [[package]] name = "color_quant" version = "1.1.0" @@ -1951,18 +1978,23 @@ version = "0.3.0-dev0" dependencies = [ "arrow2", "async-stream", + "color-eyre", "common-daft-config", "daft-core", "daft-dsl", "daft-local-execution", "daft-logical-plan", + "daft-micropartition", "daft-scan", "daft-schema", "daft-table", "dashmap", + "derive_more", "eyre", "futures", + "itertools 0.11.0", "pyo3", + "serde_json", "spark-connect", "tokio", "tonic", @@ -3128,9 +3160,9 @@ dependencies = [ [[package]] name = "gimli" 
-version = "0.29.0" +version = "0.28.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "40ecd4077b5ae9fd2e9e169b102c6c330d0605168eb0e8bf79952b256dbefffd" +checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253" [[package]] name = "glob" @@ -4428,9 +4460,9 @@ dependencies = [ [[package]] name = "object" -version = "0.36.3" +version = "0.32.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "27b64972346851a39438c60b341ebc01bba47464ae329e55cf343eb93964efd9" +checksum = "a6a622008b6e321afc04970976f62ee297fdbaa6f95318ca343e3eebb9648441" dependencies = [ "memchr", ] @@ -4536,6 +4568,12 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" +[[package]] +name = "owo-colors" +version = "3.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c1b04fb49957986fdce4d6ee7a65027d55d4b6d2265e5848bbb507b58ccfdb6f" + [[package]] name = "parking" version = "2.2.0" @@ -6589,6 +6627,16 @@ dependencies = [ "valuable", ] +[[package]] +name = "tracing-error" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b1581020d7a273442f5b45074a6a57d5757ad0a47dac0e9f0bd57b81936f3db" +dependencies = [ + "tracing", + "tracing-subscriber", +] + [[package]] name = "tracing-log" version = "0.2.0" diff --git a/Cargo.toml b/Cargo.toml index 67334d8b0d..e8e9992e2c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -200,6 +200,7 @@ daft-dsl = {path = "src/daft-dsl"} daft-hash = {path = "src/daft-hash"} daft-local-execution = {path = "src/daft-local-execution"} daft-logical-plan = {path = "src/daft-logical-plan"} +daft-micropartition = {path = "src/daft-micropartition"} daft-scan = {path = "src/daft-scan"} daft-schema = {path = "src/daft-schema"} daft-table = {path = "src/daft-table"} diff --git a/daft/dataframe/dataframe.py b/daft/dataframe/dataframe.py index 0961fe0b0a..bdb4cb81ff 100644 --- a/daft/dataframe/dataframe.py +++ b/daft/dataframe/dataframe.py @@ -63,6 +63,34 @@ ManyColumnsInputType = Union[ColumnInputType, Iterable[ColumnInputType]] +def to_logical_plan_builder(*parts: MicroPartition) -> LogicalPlanBuilder: + """Creates a Daft DataFrame from a single Table. + + Args: + parts: The Tables that we wish to convert into a Daft DataFrame. + + Returns: + DataFrame: Daft DataFrame created from the provided Table. + """ + if not parts: + raise ValueError("Can't create a DataFrame from an empty list of tables.") + + result_pset = LocalPartitionSet() + + for i, part in enumerate(parts): + result_pset.set_partition_from_table(i, part) + + context = get_context() + cache_entry = context.get_or_create_runner().put_partition_set_into_cache(result_pset) + size_bytes = result_pset.size_bytes() + num_rows = len(result_pset) + + assert size_bytes is not None, "In-memory data should always have non-None size in bytes" + return LogicalPlanBuilder.from_in_memory_scan( + cache_entry, parts[0].schema(), result_pset.num_partitions(), size_bytes, num_rows=num_rows + ) + + class DataFrame: """A Daft DataFrame is a table of data. It has columns, where each column has a type and the same number of items (rows) as all other columns. 
diff --git a/src/arrow2/src/io/ipc/read/common.rs b/src/arrow2/src/io/ipc/read/common.rs index dc91c40257..bcfabf6760 100644 --- a/src/arrow2/src/io/ipc/read/common.rs +++ b/src/arrow2/src/io/ipc/read/common.rs @@ -87,7 +87,7 @@ pub fn read_record_batch( file_size: u64, scratch: &mut Vec, ) -> Result>> { - assert_eq!(fields.len(), ipc_schema.fields.len()); + assert_eq!(fields.len(), ipc_schema.fields.len(), "IPC schema fields and Arrow schema fields must be the same length"); let buffers = batch .buffers() .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferBuffers(err)))? diff --git a/src/arrow2/src/io/ipc/read/mod.rs b/src/arrow2/src/io/ipc/read/mod.rs index 887cf7b362..25c0eafae0 100644 --- a/src/arrow2/src/io/ipc/read/mod.rs +++ b/src/arrow2/src/io/ipc/read/mod.rs @@ -42,4 +42,4 @@ pub type Dictionaries = AHashMap>; pub(crate) type Node<'a> = arrow_format::ipc::FieldNodeRef<'a>; pub(crate) type IpcBuffer<'a> = arrow_format::ipc::BufferRef<'a>; pub(crate) type Compression<'a> = arrow_format::ipc::BodyCompressionRef<'a>; -pub(crate) type Version = arrow_format::ipc::MetadataVersion; +pub type Version = arrow_format::ipc::MetadataVersion; diff --git a/src/daft-connect/Cargo.toml b/src/daft-connect/Cargo.toml index 9651106968..f06a9dc9af 100644 --- a/src/daft-connect/Cargo.toml +++ b/src/daft-connect/Cargo.toml @@ -1,18 +1,23 @@ [dependencies] -arrow2 = {workspace = true} +arrow2 = {workspace = true, features = ["io_json_integration"]} async-stream = "0.3.6" +color-eyre = "0.6.3" common-daft-config = {workspace = true} daft-core = {workspace = true} daft-dsl = {workspace = true} daft-local-execution = {workspace = true} daft-logical-plan = {workspace = true} +daft-micropartition = {workspace = true} daft-scan = {workspace = true} daft-schema = {workspace = true} daft-table = {workspace = true} dashmap = "6.1.0" +derive_more = {workspace = true} eyre = "0.6.12" futures = "0.3.31" +itertools = {workspace = true} pyo3 = {workspace = true, optional = true} +serde_json = {workspace = true} spark-connect = {workspace = true} tokio = {version = "1.40.0", features = ["full"]} tonic = "0.12.3" @@ -20,7 +25,7 @@ tracing = {workspace = true} uuid = {version = "1.10.0", features = ["v4"]} [features] -python = ["dep:pyo3", "common-daft-config/python", "daft-local-execution/python", "daft-logical-plan/python", "daft-scan/python", "daft-table/python", "daft-dsl/python", "daft-schema/python", "daft-core/python"] +python = ["dep:pyo3", "common-daft-config/python", "daft-local-execution/python", "daft-logical-plan/python", "daft-scan/python", "daft-table/python", "daft-dsl/python", "daft-schema/python", "daft-core/python", "daft-micropartition/python"] [lints] workspace = true diff --git a/src/daft-connect/src/lib.rs b/src/daft-connect/src/lib.rs index a5c842e1cb..efc861b986 100644 --- a/src/daft-connect/src/lib.rs +++ b/src/daft-connect/src/lib.rs @@ -5,7 +5,6 @@ #![feature(iter_from_coroutine)] #![feature(stmt_expr_attributes)] #![feature(try_trait_v2_residual)] -#![deny(clippy::print_stdout)] use dashmap::DashMap; use eyre::Context; @@ -23,7 +22,7 @@ use spark_connect::{ ReleaseExecuteResponse, ReleaseSessionRequest, ReleaseSessionResponse, }; use tonic::{transport::Server, Request, Response, Status}; -use tracing::{debug, info}; +use tracing::info; use uuid::Uuid; use crate::session::Session; @@ -325,8 +324,6 @@ impl SparkConnectService for DaftSparkConnectService { result: Some(analyze_plan_response::Result::Schema(schema)), }; - debug!("response: {response:#?}"); - Ok(Response::new(response)) } 
_ => unimplemented_err!("Analyze plan operation is not yet implemented"), @@ -346,7 +343,6 @@ impl SparkConnectService for DaftSparkConnectService { &self, _request: Request, ) -> Result, Status> { - println!("got interrupt"); unimplemented_err!("interrupt operation is not yet implemented") } @@ -361,9 +357,19 @@ impl SparkConnectService for DaftSparkConnectService { #[tracing::instrument(skip_all)] async fn release_execute( &self, - _request: Request, + request: Request, ) -> Result, Status> { - unimplemented_err!("release_execute operation is not yet implemented") + let request = request.into_inner(); + + let session = self.get_session(&request.session_id)?; + + let response = ReleaseExecuteResponse { + session_id: session.client_side_session_id().to_string(), + server_side_session_id: session.server_side_session_id().to_string(), + operation_id: None, // todo: set but not strictly required + }; + + Ok(Response::new(response)) } #[tracing::instrument(skip_all)] @@ -371,7 +377,6 @@ impl SparkConnectService for DaftSparkConnectService { &self, _request: Request, ) -> Result, Status> { - println!("got release session"); unimplemented_err!("release_session operation is not yet implemented") } @@ -380,7 +385,6 @@ impl SparkConnectService for DaftSparkConnectService { &self, _request: Request, ) -> Result, Status> { - println!("got fetch error details"); unimplemented_err!("fetch_error_details operation is not yet implemented") } } diff --git a/src/daft-connect/src/op/execute/root.rs b/src/daft-connect/src/op/execute/root.rs index 4f765243c8..fcd1f41bb9 100644 --- a/src/daft-connect/src/op/execute/root.rs +++ b/src/daft-connect/src/op/execute/root.rs @@ -1,4 +1,4 @@ -use std::{collections::HashMap, future::ready, sync::Arc}; +use std::{future::ready, sync::Arc}; use common_daft_config::DaftExecutionConfig; use daft_local_execution::NativeExecutor; @@ -10,6 +10,7 @@ use crate::{ op::execute::{ExecuteStream, PlanIds}, session::Session, translation, + translation::Plan, }; impl Session { @@ -31,13 +32,11 @@ impl Session { let (tx, rx) = tokio::sync::mpsc::channel::>(1); tokio::spawn(async move { let execution_fut = async { - let plan = translation::to_logical_plan(command)?; - let optimized_plan = plan.optimize()?; + let Plan { builder, psets } = translation::to_logical_plan(command)?; + let optimized_plan = builder.optimize()?; let cfg = Arc::new(DaftExecutionConfig::default()); let native_executor = NativeExecutor::from_logical_plan_builder(&optimized_plan)?; - let mut result_stream = native_executor - .run(HashMap::new(), cfg, None)? 
- .into_stream(); + let mut result_stream = native_executor.run(psets, cfg, None)?.into_stream(); while let Some(result) = result_stream.next().await { let result = result?; diff --git a/src/daft-connect/src/translation.rs b/src/daft-connect/src/translation.rs index a03fe113a7..8b61b93f98 100644 --- a/src/daft-connect/src/translation.rs +++ b/src/daft-connect/src/translation.rs @@ -6,8 +6,8 @@ mod literal; mod logical_plan; mod schema; -pub use datatype::{to_daft_datatype, to_spark_datatype}; +pub use datatype::{deser_spark_datatype, to_daft_datatype, to_spark_datatype}; pub use expr::to_daft_expr; pub use literal::to_daft_literal; -pub use logical_plan::to_logical_plan; +pub use logical_plan::{to_logical_plan, Plan}; pub use schema::relation_to_schema; diff --git a/src/daft-connect/src/translation/datatype.rs b/src/daft-connect/src/translation/datatype.rs index d6e51250c7..722e66c3b3 100644 --- a/src/daft-connect/src/translation/datatype.rs +++ b/src/daft-connect/src/translation/datatype.rs @@ -3,6 +3,9 @@ use eyre::{bail, ensure, WrapErr}; use spark_connect::data_type::Kind; use tracing::warn; +mod codec; +pub use codec::deser as deser_spark_datatype; + pub fn to_spark_datatype(datatype: &DataType) -> spark_connect::DataType { match datatype { DataType::Null => spark_connect::DataType { diff --git a/src/daft-connect/src/translation/datatype/codec.rs b/src/daft-connect/src/translation/datatype/codec.rs new file mode 100644 index 0000000000..50f2d94a02 --- /dev/null +++ b/src/daft-connect/src/translation/datatype/codec.rs @@ -0,0 +1,244 @@ +use color_eyre::Help; +use eyre::{bail, ensure, eyre}; +use serde_json::Value; +use spark_connect::data_type::Kind; +use tracing::warn; + +#[derive(Debug)] +enum TypeTag { + Null, + Binary, + Boolean, + Byte, + Short, + Integer, + Long, + Float, + Double, + Decimal, + String, + Char, + VarChar, + Date, + Timestamp, + TimestampNtz, + CalendarInterval, + YearMonthInterval, + DayTimeInterval, + Array, + Struct, + Map, + Variant, + Udt, + Unparsed, +} + +fn remove_type(input: &mut serde_json::Map) -> eyre::Result { + let Some(r#type) = input.remove("type") else { + bail!("missing type"); + }; + + let Value::String(r#type) = r#type else { + bail!("expected type to be string; instead got {:?}", r#type); + }; + + let result = match r#type.as_str() { + "null" => TypeTag::Null, + "binary" => TypeTag::Binary, + "boolean" => TypeTag::Boolean, + "byte" => TypeTag::Byte, + "short" => TypeTag::Short, + "integer" => TypeTag::Integer, + "long" => TypeTag::Long, + "float" => TypeTag::Float, + "double" => TypeTag::Double, + "decimal" => TypeTag::Decimal, + "string" => TypeTag::String, + "char" => TypeTag::Char, + "varchar" => TypeTag::VarChar, + "date" => TypeTag::Date, + "timestamp" => TypeTag::Timestamp, + "timestamp_ntz" => TypeTag::TimestampNtz, + "calendar_interval" => TypeTag::CalendarInterval, + "yearmonthinterval" => TypeTag::YearMonthInterval, + "daytimeinterval" => TypeTag::DayTimeInterval, + "array" => TypeTag::Array, + "struct" => TypeTag::Struct, + "map" => TypeTag::Map, + "variant" => TypeTag::Variant, + "udt" => TypeTag::Udt, + "unparsed" => TypeTag::Unparsed, + other => bail!("unsupported type: {other}"), + }; + + Ok(result) +} + +pub fn deser(value: Value) -> eyre::Result { + let Value::Object(input) = value else { + bail!("expected object; instead got {:?}", value); + }; + + deser_helper(input) +} + +fn deser_helper( + mut input: serde_json::Map, +) -> eyre::Result { + // 
{"fields":[{"metadata":{},"name":"id","nullable":true,"type":"long"}],"type":"struct"} + + let kind = remove_type(&mut input)?; + + let result = match kind { + TypeTag::Null => Ok(Kind::Null(spark_connect::data_type::Null { + type_variation_reference: 0, + })), + TypeTag::Binary => Ok(Kind::Binary(spark_connect::data_type::Binary { + type_variation_reference: 0, + })), + TypeTag::Boolean => Ok(Kind::Boolean(spark_connect::data_type::Boolean { + type_variation_reference: 0, + })), + TypeTag::Byte => Ok(Kind::Byte(spark_connect::data_type::Byte { + type_variation_reference: 0, + })), + TypeTag::Short => Ok(Kind::Short(spark_connect::data_type::Short { + type_variation_reference: 0, + })), + TypeTag::Integer => Ok(Kind::Integer(spark_connect::data_type::Integer { + type_variation_reference: 0, + })), + TypeTag::Long => Ok(Kind::Long(spark_connect::data_type::Long { + type_variation_reference: 0, + })), + TypeTag::Float => Ok(Kind::Float(spark_connect::data_type::Float { + type_variation_reference: 0, + })), + TypeTag::Double => Ok(Kind::Double(spark_connect::data_type::Double { + type_variation_reference: 0, + })), + TypeTag::Decimal => Ok(Kind::Decimal(spark_connect::data_type::Decimal { + scale: None, + precision: None, + type_variation_reference: 0, + })), + TypeTag::String => Ok(Kind::String(spark_connect::data_type::String { + type_variation_reference: 0, + collation: String::new(), + })), + TypeTag::Char => Ok(Kind::Char(spark_connect::data_type::Char { + type_variation_reference: 0, + length: 1, + })), + TypeTag::VarChar => Ok(Kind::VarChar(spark_connect::data_type::VarChar { + type_variation_reference: 0, + length: 0, + })), + TypeTag::Date => Ok(Kind::Date(spark_connect::data_type::Date { + type_variation_reference: 0, + })), + TypeTag::Timestamp => Ok(Kind::Timestamp(spark_connect::data_type::Timestamp { + type_variation_reference: 0, + })), + TypeTag::TimestampNtz => Ok(Kind::TimestampNtz(spark_connect::data_type::TimestampNtz { + type_variation_reference: 0, + })), + TypeTag::CalendarInterval => Ok(Kind::CalendarInterval( + spark_connect::data_type::CalendarInterval { + type_variation_reference: 0, + }, + )), + TypeTag::YearMonthInterval => Ok(Kind::YearMonthInterval( + spark_connect::data_type::YearMonthInterval { + type_variation_reference: 0, + start_field: None, + end_field: None, + }, + )), + TypeTag::DayTimeInterval => Ok(Kind::DayTimeInterval( + spark_connect::data_type::DayTimeInterval { + type_variation_reference: 0, + start_field: None, + end_field: None, + }, + )), + TypeTag::Array => Err(eyre!("Array type not supported")) + .suggestion("Wait until we support arrays in Spark Connect"), + TypeTag::Struct => deser_struct(input), + TypeTag::Map => Err(eyre!("Map type not supported")) + .suggestion("Wait until we support maps in Spark Connect"), + TypeTag::Variant => Ok(Kind::Variant(spark_connect::data_type::Variant { + type_variation_reference: 0, + })), + TypeTag::Udt => bail!("UDT type not supported"), + TypeTag::Unparsed => bail!("Unparsed type not supported"), + }?; + + let result = spark_connect::DataType { kind: Some(result) }; + + Ok(result) +} + +fn deser_struct( + mut object: serde_json::Map, +) -> eyre::Result { + // {"fields":[{"metadata":{},"name":"id","nullable":true,"type":"long"}]} + + let Some(fields) = object.remove("fields") else { + bail!("missing fields"); + }; + + ensure!(object.is_empty(), "unexpected fields: {object:?}"); + + let Value::Array(fields) = fields else { + bail!("expected fields to be array"); + }; + + let fields: Vec<_> = 
fields.into_iter().map(deser_struct_field).try_collect()?; + + Ok(Kind::Struct(spark_connect::data_type::Struct { + fields, + type_variation_reference: 0, + })) +} + +fn deser_struct_field( + field: serde_json::Value, +) -> eyre::Result { + // {"metadata":{},"name":"id","nullable":true,"type":"long"} + + let Value::Object(mut object) = field else { + bail!("expected object"); + }; + + let Some(metadata) = object.remove("metadata") else { + bail!("missing metadata"); + }; + + warn!("ignoring metadata: {metadata:?}"); + + let Some(name) = object.remove("name") else { + bail!("missing name"); + }; + + let Value::String(name) = name else { + bail!("expected name to be string; instead got {:?}", name); + }; + + let Some(nullable) = object.remove("nullable") else { + bail!("missing nullable"); + }; + + let Value::Bool(nullable) = nullable else { + bail!("expected nullable to be bool; instead got {:?}", nullable); + }; + + let inner = deser_helper(object)?; + + Ok(spark_connect::data_type::StructField { + name, + data_type: Some(inner), + nullable, + metadata: None, + }) +} diff --git a/src/daft-connect/src/translation/logical_plan.rs b/src/daft-connect/src/translation/logical_plan.rs index 93c9e9bd4a..59f8e0890e 100644 --- a/src/daft-connect/src/translation/logical_plan.rs +++ b/src/daft-connect/src/translation/logical_plan.rs @@ -1,15 +1,39 @@ +use std::{collections::HashMap, sync::Arc}; + use daft_logical_plan::LogicalPlanBuilder; +use daft_micropartition::MicroPartition; +use derive_more::Constructor; use eyre::{bail, Context}; use spark_connect::{relation::RelType, Limit, Relation}; use tracing::warn; -use crate::translation::logical_plan::{aggregate::aggregate, project::project, range::range}; +use crate::translation::logical_plan::{ + aggregate::aggregate, local_relation::local_relation, project::project, range::range, + to_df::to_df, +}; mod aggregate; +mod local_relation; mod project; mod range; +mod to_df; + +#[derive(Constructor)] +pub struct Plan { + pub builder: LogicalPlanBuilder, + pub psets: HashMap>>, +} + +impl From for Plan { + fn from(builder: LogicalPlanBuilder) -> Self { + Self { + builder, + psets: HashMap::new(), + } + } +} -pub fn to_logical_plan(relation: Relation) -> eyre::Result { +pub fn to_logical_plan(relation: Relation) -> eyre::Result { if let Some(common) = relation.common { warn!("Ignoring common metadata for relation: {common:?}; not yet implemented"); }; @@ -25,18 +49,23 @@ pub fn to_logical_plan(relation: Relation) -> eyre::Result { RelType::Aggregate(a) => { aggregate(*a).wrap_err("Failed to apply aggregate to logical plan") } + RelType::ToDf(t) => to_df(*t).wrap_err("Failed to apply to_df to logical plan"), + RelType::LocalRelation(l) => { + local_relation(l).wrap_err("Failed to apply local_relation to logical plan") + } plan => bail!("Unsupported relation type: {plan:?}"), } } -fn limit(limit: Limit) -> eyre::Result { +fn limit(limit: Limit) -> eyre::Result { let Limit { input, limit } = limit; let Some(input) = input else { bail!("input must be set"); }; - let plan = to_logical_plan(*input)?.limit(i64::from(limit), false)?; // todo: eager or no + let mut plan = to_logical_plan(*input)?; + plan.builder = plan.builder.limit(i64::from(limit), false)?; // todo: eager or no Ok(plan) } diff --git a/src/daft-connect/src/translation/logical_plan/aggregate.rs b/src/daft-connect/src/translation/logical_plan/aggregate.rs index 193ca4d088..5fbf9d6a9b 100644 --- a/src/daft-connect/src/translation/logical_plan/aggregate.rs +++ 
b/src/daft-connect/src/translation/logical_plan/aggregate.rs @@ -1,10 +1,9 @@ -use daft_logical_plan::LogicalPlanBuilder; use eyre::{bail, WrapErr}; use spark_connect::aggregate::GroupType; -use crate::translation::{to_daft_expr, to_logical_plan}; +use crate::translation::{logical_plan::Plan, to_daft_expr, to_logical_plan}; -pub fn aggregate(aggregate: spark_connect::Aggregate) -> eyre::Result { +pub fn aggregate(aggregate: spark_connect::Aggregate) -> eyre::Result { let spark_connect::Aggregate { input, group_type, @@ -18,7 +17,7 @@ pub fn aggregate(aggregate: spark_connect::Aggregate) -> eyre::Result eyre::Result eyre::Result { + #[cfg(not(feature = "python"))] + { + bail!("LocalRelation plan is only supported in Python mode"); + } + + #[cfg(feature = "python")] + { + use daft_micropartition::{python::PyMicroPartition, MicroPartition}; + use pyo3::{types::PyAnyMethods, Python}; + let spark_connect::LocalRelation { data, schema } = plan; + + let Some(data) = data else { + bail!("Data is required but was not provided in the LocalRelation plan.") + }; + + let Some(schema) = schema else { + bail!("Schema is required but was not provided in the LocalRelation plan.") + }; + + let schema: serde_json::Value = serde_json::from_str(&schema).wrap_err_with(|| { + format!("Failed to parse schema string into JSON format: {schema}") + })?; + + // spark schema + let schema = deser_spark_datatype(schema)?; + + // daft schema + let schema = to_daft_datatype(&schema)?; + + // should be of type struct + let daft_schema::dtype::DataType::Struct(daft_fields) = &schema else { + bail!("schema must be struct") + }; + + let daft_schema = daft_schema::schema::Schema::new(daft_fields.clone()) + .wrap_err("Could not create schema")?; + + let daft_schema = Arc::new(daft_schema); + + let arrow_fields: Vec<_> = daft_fields + .iter() + .map(|daft_field| daft_field.to_arrow()) + .try_collect()?; + + let mut dict_idx = 0; + + let ipc_fields: Vec<_> = daft_fields + .iter() + .map(|field| { + let required_dictionary = field.dtype == DaftDataType::Utf8; + + let dictionary_id = match required_dictionary { + true => { + let res = dict_idx; + dict_idx += 1; + Some(res) + } + false => None, + }; + + // For integer columns, we don't need dictionary encoding + IpcField { + fields: vec![], // No nested fields for primitive types + dictionary_id, + } + }) + .collect(); + + let schema = arrow2::datatypes::Schema::from(arrow_fields); + + let little_endian = true; + let version = Version::V5; + + let tables = { + let metadata = StreamMetadata { + schema, + version, + ipc_schema: IpcSchema { + fields: ipc_fields, + is_little_endian: little_endian, + }, + }; + + let reader = Cursor::new(&data); + let reader = StreamReader::new(reader, metadata, None); + + let chunks = reader.map(|value| match value { + Ok(StreamState::Some(chunk)) => Ok(chunk.arrays().to_vec()), + Ok(StreamState::Waiting) => { + bail!("StreamReader is waiting for data, but a chunk was expected.") + } + Err(e) => bail!("Error occurred while reading chunk from StreamReader: {e}"), + }); + + // todo: eek + let chunks = chunks.skip(1); + + let mut tables = Vec::new(); + + for (idx, chunk) in chunks.enumerate() { + let chunk = chunk.wrap_err_with(|| format!("chunk {idx} is invalid"))?; + + let mut columns = Vec::with_capacity(daft_schema.fields.len()); + let mut num_rows = Vec::with_capacity(daft_schema.fields.len()); + + for (array, (_, daft_field)) in itertools::zip_eq(chunk, &daft_schema.fields) { + // Note: Cloning field and array; consider optimizing to avoid unnecessary 
clones. + let field = daft_field.clone(); + let field_ref = Arc::new(field); + let series = Series::from_arrow(field_ref, array) + .wrap_err("Failed to create Series from Arrow array.")?; + + num_rows.push(series.len()); + columns.push(series); + } + + ensure!( + num_rows.iter().all_equal(), + "Mismatch in row counts across columns; all columns must have the same number of rows." + ); + + let Some(&num_rows) = num_rows.first() else { + bail!("No columns were found; at least one column is required.") + }; + + let table = Table::new_with_size(daft_schema.clone(), columns, num_rows) + .wrap_err("Failed to create Table from columns and schema.")?; + + tables.push(table); + } + tables + }; + + // Note: Verify if the Daft schema used here matches the schema of the table. + let micro_partition = MicroPartition::new_loaded(daft_schema, Arc::new(tables), None); + let micro_partition = Arc::new(micro_partition); + + let plan = Python::with_gil(|py| { + // Convert MicroPartition to a logical plan using Python interop. + let py_micropartition = py + .import_bound(pyo3::intern!(py, "daft.table"))? + .getattr(pyo3::intern!(py, "MicroPartition"))? + .getattr(pyo3::intern!(py, "_from_pymicropartition"))? + .call1((PyMicroPartition::from(micro_partition.clone()),))?; + + // ERROR: 2: AttributeError: 'daft.daft.PySchema' object has no attribute '_schema' + let py_plan_builder = py + .import_bound(pyo3::intern!(py, "daft.dataframe.dataframe"))? + .getattr(pyo3::intern!(py, "to_logical_plan_builder"))? + .call1((py_micropartition,))?; + + let py_plan_builder = py_plan_builder.getattr(pyo3::intern!(py, "_builder"))?; + + let plan: PyLogicalPlanBuilder = py_plan_builder.extract()?; + + Ok::<_, eyre::Error>(plan.builder) + })?; + + let cache_key = grab_singular_cache_key(&plan)?; + + let mut psets = HashMap::new(); + psets.insert(cache_key, vec![micro_partition]); + + let plan = Plan::new(plan, psets); + + Ok(plan) + } +} + +fn grab_singular_cache_key(plan: &LogicalPlanBuilder) -> eyre::Result { + let plan = &*plan.plan; + + let LogicalPlan::Source(Source { source_info, .. }) = plan else { + bail!("Expected a source plan"); + }; + + let SourceInfo::InMemory(InMemoryInfo { cache_key, .. }) = &**source_info else { + bail!("Expected an in-memory source"); + }; + + Ok(cache_key.clone()) +} diff --git a/src/daft-connect/src/translation/logical_plan/project.rs b/src/daft-connect/src/translation/logical_plan/project.rs index 3096b7f313..b5c1a136ec 100644 --- a/src/daft-connect/src/translation/logical_plan/project.rs +++ b/src/daft-connect/src/translation/logical_plan/project.rs @@ -3,24 +3,22 @@ //! TL;DR: Project is Spark's equivalent of SQL SELECT - it selects columns, renames them via aliases, //! and creates new columns from expressions. 
Example: `df.select(col("id").alias("my_number"))` -use daft_logical_plan::LogicalPlanBuilder; use eyre::bail; use spark_connect::Project; -use crate::translation::{to_daft_expr, to_logical_plan}; +use crate::translation::{logical_plan::Plan, to_daft_expr, to_logical_plan}; -pub fn project(project: Project) -> eyre::Result { +pub fn project(project: Project) -> eyre::Result { let Project { input, expressions } = project; let Some(input) = input else { bail!("Project input is required"); }; - let plan = to_logical_plan(*input)?; + let mut plan = to_logical_plan(*input)?; let daft_exprs: Vec<_> = expressions.iter().map(to_daft_expr).try_collect()?; - - let plan = plan.select(daft_exprs)?; + plan.builder = plan.builder.select(daft_exprs)?; Ok(plan) } diff --git a/src/daft-connect/src/translation/logical_plan/range.rs b/src/daft-connect/src/translation/logical_plan/range.rs index e11fef26cb..ff15e0cacb 100644 --- a/src/daft-connect/src/translation/logical_plan/range.rs +++ b/src/daft-connect/src/translation/logical_plan/range.rs @@ -2,7 +2,9 @@ use daft_logical_plan::LogicalPlanBuilder; use eyre::{ensure, Context}; use spark_connect::Range; -pub fn range(range: Range) -> eyre::Result { +use crate::translation::logical_plan::Plan; + +pub fn range(range: Range) -> eyre::Result { #[cfg(not(feature = "python"))] { use eyre::bail; @@ -50,6 +52,6 @@ pub fn range(range: Range) -> eyre::Result { }) .wrap_err("Failed to create range scan")?; - Ok(plan) + Ok(plan.into()) } } diff --git a/src/daft-connect/src/translation/logical_plan/to_df.rs b/src/daft-connect/src/translation/logical_plan/to_df.rs new file mode 100644 index 0000000000..63ad58f10f --- /dev/null +++ b/src/daft-connect/src/translation/logical_plan/to_df.rs @@ -0,0 +1,29 @@ +use eyre::{bail, WrapErr}; + +use crate::translation::{logical_plan::Plan, to_logical_plan}; + +pub fn to_df(to_df: spark_connect::ToDf) -> eyre::Result { + let spark_connect::ToDf { + input, + column_names, + } = to_df; + + let Some(input) = input else { + bail!("Input is required"); + }; + + let mut plan = + to_logical_plan(*input).wrap_err("Failed to translate relation to logical plan")?; + + let column_names: Vec<_> = column_names + .iter() + .map(|s| daft_dsl::col(s.as_str())) + .collect(); + + plan.builder = plan + .builder + .select(column_names) + .wrap_err("Failed to add columns to logical plan")?; + + Ok(plan) +} diff --git a/src/daft-connect/src/translation/schema.rs b/src/daft-connect/src/translation/schema.rs index 1b242428d2..6eaad72a0c 100644 --- a/src/daft-connect/src/translation/schema.rs +++ b/src/daft-connect/src/translation/schema.rs @@ -14,7 +14,7 @@ pub fn relation_to_schema(input: Relation) -> eyre::Result { let plan = to_logical_plan(input)?; - let result = plan.schema(); + let result = plan.builder.schema(); let fields: eyre::Result> = result .fields diff --git a/src/daft-local-execution/src/pipeline.rs b/src/daft-local-execution/src/pipeline.rs index c931614ff3..ded430d3bc 100644 --- a/src/daft-local-execution/src/pipeline.rs +++ b/src/daft-local-execution/src/pipeline.rs @@ -102,7 +102,9 @@ pub fn physical_plan_to_pipeline( scan_task_source.arced().into() } LocalPhysicalPlan::InMemoryScan(InMemoryScan { info, .. 
}) => { - let partitions = psets.get(&info.cache_key).expect("Cache key not found"); + let partitions = psets + .get(&info.cache_key) + .unwrap_or_else(|| panic!("Cache key not found: {:?}", info.cache_key)); InMemorySource::new(partitions.clone(), info.source_schema.clone()) .arced() .into() diff --git a/src/daft-logical-plan/src/builder.rs b/src/daft-logical-plan/src/builder.rs index f40a55ed4f..0daaf79460 100644 --- a/src/daft-logical-plan/src/builder.rs +++ b/src/daft-logical-plan/src/builder.rs @@ -653,7 +653,7 @@ impl LogicalPlanBuilder { /// as possible, converting pyo3 wrapper type arguments into their underlying Rust-native types /// (e.g. PySchema -> Schema). #[cfg_attr(feature = "python", pyclass(name = "LogicalPlanBuilder"))] -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct PyLogicalPlanBuilder { // Internal logical plan builder. pub builder: LogicalPlanBuilder, diff --git a/tests/connect/conftest.py b/tests/connect/conftest.py index 60c5ae9986..7f6b05a27a 100644 --- a/tests/connect/conftest.py +++ b/tests/connect/conftest.py @@ -12,6 +12,7 @@ def spark_session(): This fixture is available to all test files and creates a single Spark session for the entire test suite run. """ + from daft.daft import connect_start # Start Daft Connect server diff --git a/tests/connect/test_create_df.py b/tests/connect/test_create_df.py new file mode 100644 index 0000000000..187f4fbc5a --- /dev/null +++ b/tests/connect/test_create_df.py @@ -0,0 +1,35 @@ +from __future__ import annotations + + +def test_create_df(spark_session): + # Create simple DataFrame with single column + data = [(1,), (2,), (3,)] + df = spark_session.createDataFrame(data, ["id"]) + + # Convert to pandas and verify + df_pandas = df.toPandas() + assert len(df_pandas) == 3, "DataFrame should have 3 rows" + assert list(df_pandas["id"]) == [1, 2, 3], "DataFrame should contain expected values" + + # Create DataFrame with float column + float_data = [(1.1,), (2.2,), (3.3,)] + df_float = spark_session.createDataFrame(float_data, ["value"]) + df_float_pandas = df_float.toPandas() + assert len(df_float_pandas) == 3, "Float DataFrame should have 3 rows" + assert list(df_float_pandas["value"]) == [1.1, 2.2, 3.3], "Float DataFrame should contain expected values" + + # Create DataFrame with two numeric columns + two_col_data = [(1, 10), (2, 20), (3, 30)] + df_two = spark_session.createDataFrame(two_col_data, ["num1", "num2"]) + df_two_pandas = df_two.toPandas() + assert len(df_two_pandas) == 3, "Two-column DataFrame should have 3 rows" + assert list(df_two_pandas["num1"]) == [1, 2, 3], "First number column should contain expected values" + assert list(df_two_pandas["num2"]) == [10, 20, 30], "Second number column should contain expected values" + + # now do boolean + print("now testing boolean") + boolean_data = [(True,), (False,), (True,)] + df_boolean = spark_session.createDataFrame(boolean_data, ["value"]) + df_boolean_pandas = df_boolean.toPandas() + assert len(df_boolean_pandas) == 3, "Boolean DataFrame should have 3 rows" + assert list(df_boolean_pandas["value"]) == [True, False, True], "Boolean DataFrame should contain expected values"
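
The test above exercises createDataFrame end to end. For reference, the payload the client ships in spark.connect.LocalRelation.data is an Arrow IPC stream, which local_relation.rs above decodes with arrow2's StreamReader; a minimal sketch of that round trip on the client side, assuming pyarrow (which the Spark Connect client uses for serialization):

    import pyarrow as pa

    # Serialize one record batch into Arrow IPC streaming format -- roughly the
    # bytes that end up in LocalRelation.data.
    batch = pa.RecordBatch.from_pydict({"id": [1, 2, 3]})
    sink = pa.BufferOutputStream()
    with pa.ipc.new_stream(sink, batch.schema) as writer:
        writer.write_batch(batch)
    ipc_bytes = sink.getvalue().to_pybytes()

    # Decoding the same stream here shows the batches the server-side
    # StreamReader will walk when building the MicroPartition.
    reader = pa.ipc.open_stream(ipc_bytes)
    table = reader.read_all()
    assert table.column("id").to_pylist() == [1, 2, 3]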