Skip to content

Commit

Permalink
add another test
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewgazelka committed Dec 4, 2024
1 parent 837c3ed commit 0d0dbe5
Show file tree
Hide file tree
Showing 7 changed files with 159 additions and 126 deletions.
2 changes: 1 addition & 1 deletion daft/dataframe/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@

ManyColumnsInputType = Union[ColumnInputType, Iterable[ColumnInputType]]


def to_logical_plan_builder(*parts: MicroPartition) -> LogicalPlanBuilder:
"""Creates a Daft DataFrame from a single Table.
Expand Down Expand Up @@ -503,7 +504,6 @@ def _from_pandas(cls, data: Union["pandas.DataFrame", List["pandas.DataFrame"]])
data_micropartitions = [MicroPartition.from_pandas(df) for df in data]
return cls._from_tables(*data_micropartitions)


@classmethod
def _from_tables(cls, *parts: MicroPartition) -> "DataFrame":
"""Creates a Daft DataFrame from a single Table.
Expand Down
2 changes: 1 addition & 1 deletion src/daft-connect/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@ daft-schema = {workspace = true}
daft-table = {workspace = true}
daft-writers = {workspace = true}
dashmap = "6.1.0"
derive_more = {workspace = true}
eyre = "0.6.12"
futures = "0.3.31"
itertools = {workspace = true}
pyo3 = {workspace = true, optional = true}
serde_json = {workspace = true}
derive_more = {workspace = true}
spark-connect = {workspace = true}
tokio = {version = "1.40.0", features = ["full"]}
tonic = "0.12.3"
Expand Down
5 changes: 2 additions & 3 deletions src/daft-connect/src/op/execute/root.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
use std::{collections::HashMap, future::ready, sync::Arc};
use std::future::ready;
use std::{future::ready, sync::Arc};

use common_daft_config::DaftExecutionConfig;
use daft_local_execution::NativeExecutor;
use futures::stream;
use spark_connect::{ExecutePlanResponse, Plan, Relation};
use spark_connect::{ExecutePlanResponse, Relation};
use tonic::{codegen::tokio_stream::wrappers::ReceiverStream, Status};

use crate::{
Expand Down
262 changes: 143 additions & 119 deletions src/daft-connect/src/translation/logical_plan/local_relation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,164 +10,188 @@ use daft_logical_plan::{
logical_plan::Source, InMemoryInfo, LogicalPlan, LogicalPlanBuilder, PyLogicalPlanBuilder,
SourceInfo,
};
use daft_micropartition::{python::PyMicroPartition, MicroPartition};
use daft_schema::dtype::DaftDataType;
use daft_table::Table;
use eyre::{bail, ensure, WrapErr};
use itertools::Itertools;
use pyo3::{types::PyAnyMethods, Python};
use tracing::debug;

use crate::translation::{deser_spark_datatype, logical_plan::Plan, to_daft_datatype};

pub fn local_relation(plan: spark_connect::LocalRelation) -> eyre::Result<Plan> {
let spark_connect::LocalRelation { data, schema } = plan;
#[cfg(not(feature = "python"))]
{
bail!("LocalRelation plan is only supported in Python mode");
}

#[cfg(feature = "python")]
{
use daft_micropartition::{python::PyMicroPartition, MicroPartition};
use pyo3::{types::PyAnyMethods, Python};
let spark_connect::LocalRelation { data, schema } = plan;

let Some(data) = data else {
bail!("Data is required but was not provided in the LocalRelation plan.")
};

let Some(data) = data else {
bail!("Data is required but was not provided in the LocalRelation plan.")
};
let Some(schema) = schema else {
bail!("Schema is required but was not provided in the LocalRelation plan.")
};

let Some(schema) = schema else {
bail!("Schema is required but was not provided in the LocalRelation plan.")
};
let schema: serde_json::Value = serde_json::from_str(&schema).wrap_err_with(|| {
format!("Failed to parse schema string into JSON format: {schema}")
})?;

let schema: serde_json::Value = serde_json::from_str(&schema)
.wrap_err_with(|| format!("Failed to parse schema string into JSON format: {schema}"))?;
debug!("schema JSON {schema}");

debug!("schema JSON {schema}");
// spark schema
let schema = deser_spark_datatype(schema)?;

// spark schema
let schema = deser_spark_datatype(schema)?;
// daft schema
let schema = to_daft_datatype(&schema)?;

// daft schema
let schema = to_daft_datatype(&schema)?;
// should be of type struct
let daft_schema::dtype::DataType::Struct(daft_fields) = &schema else {
bail!("schema must be struct")
};

// should be of type struct
let daft_schema::dtype::DataType::Struct(daft_fields) = &schema else {
bail!("schema must be struct")
};
let daft_schema = daft_schema::schema::Schema::new(daft_fields.clone())
.wrap_err("Could not create schema")?;

let daft_schema = Arc::new(daft_schema);

let arrow_fields: Vec<_> = daft_fields
.iter()
.map(|daft_field| daft_field.to_arrow())
.try_collect()?;

let mut dict_idx = 0;

let ipc_fields: Vec<_> = daft_fields
.iter()
.map(|field| {
let required_dictionary = field.dtype == DaftDataType::Utf8;

let dictionary_id = match required_dictionary {
true => {
let res = dict_idx;
dict_idx += 1;
debug!("using dictionary id {res}");
Some(res)
}
false => None,
};

// For integer columns, we don't need dictionary encoding
IpcField {
fields: vec![], // No nested fields for primitive types
dictionary_id,
}
})
.collect();

let schema = arrow2::datatypes::Schema::from(arrow_fields);
debug!("schema {schema:?}");

let little_endian = true;
let version = MetadataVersion::V5;

let tables = {
let metadata = StreamMetadata {
schema,
version,
ipc_schema: IpcSchema {
fields: ipc_fields,
is_little_endian: little_endian,
},
};

let daft_schema = daft_schema::schema::Schema::new(daft_fields.clone())
.wrap_err("Could not create schema")?;
let reader = Cursor::new(&data);
let reader = StreamReader::new(reader, metadata, None);

let daft_schema = Arc::new(daft_schema);
let chunks = reader.map(|value| match value {
Ok(StreamState::Some(chunk)) => Ok(chunk.arrays().to_vec()),
Ok(StreamState::Waiting) => {
bail!("StreamReader is waiting for data, but a chunk was expected.")
}
Err(e) => bail!("Error occurred while reading chunk from StreamReader: {e}"),
});

let arrow_fields: Vec<_> = daft_fields
.iter()
.map(|daft_field| daft_field.to_arrow())
.try_collect()?;
// todo: eek
let chunks = chunks.skip(1);

let ipc_fields: Vec<_> = daft_fields
.iter()
.map(|_| {
// For integer columns, we don't need dictionary encoding
IpcField {
fields: vec![], // No nested fields for primitive types
dictionary_id: None, // No dictionary encoding
}
})
.collect();

let schema = arrow2::datatypes::Schema::from(arrow_fields);
debug!("schema {schema:?}");

let little_endian = true;
let version = MetadataVersion::V5;

let tables = {
let metadata = StreamMetadata {
schema,
version,
ipc_schema: IpcSchema {
fields: ipc_fields,
is_little_endian: little_endian,
},
};
let mut tables = Vec::new();

let reader = Cursor::new(&data);
let reader = StreamReader::new(reader, metadata, None);
for (idx, chunk) in chunks.enumerate() {
let chunk = chunk.wrap_err_with(|| format!("chunk {idx} is invalid"))?;

let chunks = reader.map(|value| match value {
Ok(StreamState::Some(chunk)) => Ok(chunk.arrays().to_vec()),
Ok(StreamState::Waiting) => {
bail!("StreamReader is waiting for data, but a chunk was expected.")
}
Err(e) => bail!("Error occurred while reading chunk from StreamReader: {e}"),
});
let mut columns = Vec::with_capacity(daft_schema.fields.len());
let mut num_rows = Vec::with_capacity(daft_schema.fields.len());

// todo: eek
let chunks = chunks.skip(1);
for (array, (_, daft_field)) in itertools::zip_eq(chunk, &daft_schema.fields) {
// Note: Cloning field and array; consider optimizing to avoid unnecessary clones.
let field = daft_field.clone();
let array = array.clone();

let mut tables = Vec::new();
let field_ref = Arc::new(field);
let series = Series::from_arrow(field_ref, array)
.wrap_err("Failed to create Series from Arrow array.")?;

for (idx, chunk) in chunks.enumerate() {
let chunk = chunk.wrap_err_with(|| format!("chunk {idx} is invalid"))?;
num_rows.push(series.len());
columns.push(series);
}

let mut columns = Vec::with_capacity(daft_schema.fields.len());
let mut num_rows = Vec::with_capacity(daft_schema.fields.len());
ensure!(
num_rows.iter().all_equal(),
"Mismatch in row counts across columns; all columns must have the same number of rows."
);

for (array, (_, daft_field)) in itertools::zip_eq(chunk, &daft_schema.fields) {
// Note: Cloning field and array; consider optimizing to avoid unnecessary clones.
let field = daft_field.clone();
let array = array.clone();
let Some(&num_rows) = num_rows.first() else {
bail!("No columns were found; at least one column is required.")
};

let field_ref = Arc::new(field);
let series = Series::from_arrow(field_ref, array)
.wrap_err("Failed to create Series from Arrow array.")?;
let table = Table::new_with_size(daft_schema.clone(), columns, num_rows)
.wrap_err("Failed to create Table from columns and schema.")?;

num_rows.push(series.len());
columns.push(series);
tables.push(table);
}
tables
};

ensure!(
num_rows.iter().all_equal(),
"Mismatch in row counts across columns; all columns must have the same number of rows."
);

let Some(&num_rows) = num_rows.first() else {
bail!("No columns were found; at least one column is required.")
};

let table = Table::new_with_size(daft_schema.clone(), columns, num_rows)
.wrap_err("Failed to create Table from columns and schema.")?;

tables.push(table);
}
tables
};

// Note: Verify if the Daft schema used here matches the schema of the table.
let micro_partition = MicroPartition::new_loaded(daft_schema, Arc::new(tables), None);
let micro_partition = Arc::new(micro_partition);
// Note: Verify if the Daft schema used here matches the schema of the table.
let micro_partition = MicroPartition::new_loaded(daft_schema, Arc::new(tables), None);
let micro_partition = Arc::new(micro_partition);

let plan = Python::with_gil(|py| {
// Convert MicroPartition to a logical plan using Python interop.
let py_micropartition = py
.import_bound(pyo3::intern!(py, "daft.table"))?
.getattr(pyo3::intern!(py, "MicroPartition"))?
.getattr(pyo3::intern!(py, "_from_pymicropartition"))?
.call1((PyMicroPartition::from(micro_partition.clone()),))?;
let plan = Python::with_gil(|py| {
// Convert MicroPartition to a logical plan using Python interop.
let py_micropartition = py
.import_bound(pyo3::intern!(py, "daft.table"))?
.getattr(pyo3::intern!(py, "MicroPartition"))?
.getattr(pyo3::intern!(py, "_from_pymicropartition"))?
.call1((PyMicroPartition::from(micro_partition.clone()),))?;

// ERROR: 2: AttributeError: 'daft.daft.PySchema' object has no attribute '_schema'
let py_plan_builder = py
.import_bound(pyo3::intern!(py, "daft.dataframe.dataframe"))?
.getattr(pyo3::intern!(py, "to_logical_plan_builder"))?
.call1((py_micropartition,))?;
// ERROR: 2: AttributeError: 'daft.daft.PySchema' object has no attribute '_schema'
let py_plan_builder = py
.import_bound(pyo3::intern!(py, "daft.dataframe.dataframe"))?
.getattr(pyo3::intern!(py, "to_logical_plan_builder"))?
.call1((py_micropartition,))?;

let py_plan_builder = py_plan_builder.getattr(pyo3::intern!(py, "_builder"))?;
let py_plan_builder = py_plan_builder.getattr(pyo3::intern!(py, "_builder"))?;

let plan: PyLogicalPlanBuilder = py_plan_builder.extract()?;
let plan: PyLogicalPlanBuilder = py_plan_builder.extract()?;

Ok::<_, eyre::Error>(plan.builder)
})?;
Ok::<_, eyre::Error>(plan.builder)
})?;

let cache_key = grab_singular_cache_key(&plan)?;
let cache_key = grab_singular_cache_key(&plan)?;

let mut psets = HashMap::new();
psets.insert(cache_key, vec![micro_partition]);
let mut psets = HashMap::new();
psets.insert(cache_key, vec![micro_partition]);

let plan = Plan::new(plan, psets);
let plan = Plan::new(plan, psets);

Ok(plan)
Ok(plan)
}
}

fn grab_singular_cache_key(plan: &LogicalPlanBuilder) -> eyre::Result<String> {
Expand Down
4 changes: 4 additions & 0 deletions tests/connect/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@ def spark_session():
This fixture is available to all test files and creates a single
Spark session for the entire test suite run.
"""

from daft.daft import connect_start
from daft.logging import setup_debug_logger

setup_debug_logger()

# Start Daft Connect server
server = connect_start()
Expand Down
8 changes: 8 additions & 0 deletions tests/connect/test_create_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,11 @@ def test_create_df(spark_session):
assert len(df_two_pandas) == 3, "Two-column DataFrame should have 3 rows"
assert list(df_two_pandas["num1"]) == [1, 2, 3], "First number column should contain expected values"
assert list(df_two_pandas["num2"]) == [10, 20, 30], "Second number column should contain expected values"

# now do boolean
print("now testing boolean")
boolean_data = [(True,), (False,), (True,)]
df_boolean = spark_session.createDataFrame(boolean_data, ["value"])
df_boolean_pandas = df_boolean.toPandas()
assert len(df_boolean_pandas) == 3, "Boolean DataFrame should have 3 rows"
assert list(df_boolean_pandas["value"]) == [True, False, True], "Boolean DataFrame should contain expected values"
2 changes: 0 additions & 2 deletions tests/connect/test_distinct.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from __future__ import annotations

from pyspark.sql.functions import col


def test_distinct(spark_session):
# Create DataFrame with duplicates
Expand Down

0 comments on commit 0d0dbe5

Please sign in to comment.