[FEAT] connect: support basic column operations
andrewgazelka committed Nov 20, 2024
1 parent cdcd749 commit 89e89e8
Showing 5 changed files with 282 additions and 6 deletions.
2 changes: 1 addition & 1 deletion src/daft-connect/src/translation.rs
@@ -6,7 +6,7 @@ mod literal;
mod logical_plan;
mod schema;

pub use datatype::to_spark_datatype;
pub use datatype::{to_daft_datatype, to_spark_datatype};
pub use expr::to_daft_expr;
pub use literal::to_daft_literal;
pub use logical_plan::to_logical_plan;
154 changes: 153 additions & 1 deletion src/daft-connect/src/translation/datatype.rs
@@ -1,4 +1,5 @@
use daft_schema::dtype::DataType;
use daft_schema::{dtype::DataType, field::Field, time_unit::TimeUnit};
use eyre::{bail, ensure, WrapErr};
use spark_connect::data_type::Kind;
use tracing::warn;

@@ -112,3 +113,154 @@ pub fn to_spark_datatype(datatype: &DataType) -> spark_connect::DataType {
_ => unimplemented!("Unsupported datatype: {datatype:?}"),
}
}

// todo(test): add tests for this, especially in Python (see the sketch after this diff)
pub fn to_daft_datatype(datatype: &spark_connect::DataType) -> eyre::Result<DataType> {
let Some(kind) = &datatype.kind else {
bail!("Datatype is required");
};

let type_variation_err = "Custom type variation reference not supported";

match kind {
Kind::Null(value) => {
ensure!(value.type_variation_reference == 0, type_variation_err);
Ok(DataType::Null)
}
Kind::Binary(value) => {
ensure!(value.type_variation_reference == 0, type_variation_err);
Ok(DataType::Binary)
}
Kind::Boolean(value) => {
ensure!(value.type_variation_reference == 0, type_variation_err);
Ok(DataType::Boolean)
}
Kind::Byte(value) => {
ensure!(value.type_variation_reference == 0, type_variation_err);
Ok(DataType::Int8)
}
Kind::Short(value) => {
ensure!(value.type_variation_reference == 0, type_variation_err);
Ok(DataType::Int16)
}
Kind::Integer(value) => {
ensure!(value.type_variation_reference == 0, type_variation_err);
Ok(DataType::Int32)
}
Kind::Long(value) => {
ensure!(value.type_variation_reference == 0, type_variation_err);
Ok(DataType::Int64)
}
Kind::Float(value) => {
ensure!(value.type_variation_reference == 0, type_variation_err);
Ok(DataType::Float32)
}
Kind::Double(value) => {
ensure!(value.type_variation_reference == 0, type_variation_err);
Ok(DataType::Float64)
}
Kind::Decimal(value) => {
ensure!(value.type_variation_reference == 0, type_variation_err);

let Some(precision) = value.precision else {
bail!("Decimal precision is required");
};

let Some(scale) = value.scale else {
bail!("Decimal scale is required");
};

let precision = usize::try_from(precision)
.wrap_err("Decimal precision must be a non-negative integer")?;

let scale =
usize::try_from(scale).wrap_err("Decimal scale must be a non-negative integer")?;

Ok(DataType::Decimal128(precision, scale))
}
Kind::String(value) => {
ensure!(value.type_variation_reference == 0, type_variation_err);
Ok(DataType::Utf8)
}
// Char and VarChar carry a length parameter that Daft's Utf8 does not
// model; the length is dropped in this mapping.
Kind::Char(value) => {
ensure!(value.type_variation_reference == 0, type_variation_err);
Ok(DataType::Utf8)
}
Kind::VarChar(value) => {
ensure!(value.type_variation_reference == 0, type_variation_err);
Ok(DataType::Utf8)
}
Kind::Date(value) => {
ensure!(value.type_variation_reference == 0, type_variation_err);
Ok(DataType::Date)
}
Kind::Timestamp(value) => {
ensure!(value.type_variation_reference == 0, type_variation_err);

// Spark's TIMESTAMP is microsecond-precision and carries the session
// time zone; mapping it to a timezone-less Daft timestamp drops that
// zone. todo(?): revisit whether the time zone should be preserved.
Ok(DataType::Timestamp(TimeUnit::Microseconds, None))
}
Kind::TimestampNtz(value) => {
ensure!(value.type_variation_reference == 0, type_variation_err);

// TIMESTAMP_NTZ has no time zone by definition, so a timezone-less
// microsecond timestamp is a faithful mapping.
Ok(DataType::Timestamp(TimeUnit::Microseconds, None))
}
Kind::CalendarInterval(_) => bail!("Calendar interval type not supported"),
Kind::YearMonthInterval(_) => bail!("Year-month interval type not supported"),
Kind::DayTimeInterval(_) => bail!("Day-time interval type not supported"),
Kind::Array(value) => {
ensure!(value.type_variation_reference == 0, type_variation_err);
let element_type = to_daft_datatype(
value
.element_type
.as_ref()
.ok_or_else(|| eyre::eyre!("Array element type is required"))?,
)?;
Ok(DataType::List(Box::new(element_type)))
}
Kind::Struct(value) => {
ensure!(value.type_variation_reference == 0, type_variation_err);
let fields = value
.fields
.iter()
.map(|f| {
let field_type = to_daft_datatype(
f.data_type
.as_ref()
.ok_or_else(|| eyre::eyre!("Struct field type is required"))?,
)?;
Ok(Field::new(&f.name, field_type))
})
.collect::<eyre::Result<Vec<_>>>()?;
Ok(DataType::Struct(fields))
}
Kind::Map(value) => {
ensure!(value.type_variation_reference == 0, type_variation_err);
let key_type = to_daft_datatype(
value
.key_type
.as_ref()
.ok_or_else(|| eyre::eyre!("Map key type is required"))?,
)?;
let value_type = to_daft_datatype(
value
.value_type
.as_ref()
.ok_or_else(|| eyre::eyre!("Map value type is required"))?,
)?;

let map = DataType::Map {
key: Box::new(key_type),
value: Box::new(value_type),
};

Ok(map)
}
Kind::Variant(_) => bail!("Variant type not supported"),
Kind::Udt(_) => bail!("User-defined type not supported"),
Kind::Unparsed(_) => bail!("Unparsed type not supported"),
}
}
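The todo above asks for Python-side coverage of to_daft_datatype. A minimal sketch of what such a test could look like, reusing the spark_session fixture that the test file below relies on; which Spark types round-trip cleanly through the new conversion is an assumption, not something this commit asserts:

from pyspark.sql.functions import col
from pyspark.sql.types import DoubleType, LongType, StringType


def test_datatype_conversion(spark_session):
    # Each cast sends a Spark DataType over the wire, exercising
    # to_daft_datatype on the server side.
    df = spark_session.range(5)
    for spark_type in [StringType(), LongType(), DoubleType()]:
        casted = df.select(col("id").cast(spark_type))
        assert casted.schema.fields[0].dataType == spark_type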
68 changes: 64 additions & 4 deletions src/daft-connect/src/translation/expr.rs
@@ -1,11 +1,18 @@
use std::sync::Arc;

use eyre::{bail, Context};
use spark_connect::{expression as spark_expr, Expression};
use spark_connect::{
expression as spark_expr,
expression::{
cast::{CastToType, EvalMode},
sort_order::{NullOrdering, SortDirection},
},
Expression,
};
use tracing::warn;
use unresolved_function::unresolved_to_daft_expr;

use crate::translation::to_daft_literal;
use crate::translation::{to_daft_datatype, to_daft_literal};

mod unresolved_function;

@@ -69,11 +76,64 @@ pub fn to_daft_expr(expression: &Expression) -> eyre::Result<daft_dsl::ExprRef>

Ok(child.alias(name))
}
spark_expr::ExprType::Cast(_) => bail!("Cast expressions not yet supported"),
spark_expr::ExprType::Cast(c) => {
let spark_expr::Cast {
expr,
eval_mode,
cast_to_type,
} = &**c;

let Some(expr) = expr else {
bail!("Cast expression is required");
};

let expr = to_daft_expr(expr)?;

let Some(cast_to_type) = cast_to_type else {
bail!("Cast to type is required");
};

let data_type = match cast_to_type {
CastToType::Type(kind) => to_daft_datatype(kind).wrap_err_with(|| {
format!("Failed to convert spark datatype to daft datatype: {kind:?}")
})?,
CastToType::TypeStr(s) => {
bail!("Cast to type string not yet supported; tried to cast to {s}");
}
};

let eval_mode = EvalMode::try_from(*eval_mode)
.wrap_err_with(|| format!("Invalid cast eval mode: {eval_mode}"))?;

warn!("Ignoring cast eval mode: {eval_mode:?}");

Ok(expr.cast(&data_type))
}
spark_expr::ExprType::UnresolvedRegex(_) => {
bail!("Unresolved regex expressions not yet supported")
}
spark_expr::ExprType::SortOrder(_) => bail!("Sort order expressions not yet supported"),
spark_expr::ExprType::SortOrder(s) => {
let spark_expr::SortOrder {
child,
direction,
null_ordering,
} = &**s;

let Some(_child) = child else {
bail!("Sort order child is required");
};

// Validate the direction and null ordering so malformed plans fail with
// a precise error, even though translation is not implemented yet.
let _sort_direction = SortDirection::try_from(*direction)
.wrap_err_with(|| format!("Invalid sort direction: {direction}"))?;

let _sort_nulls = NullOrdering::try_from(*null_ordering)
.wrap_err_with(|| format!("Invalid sort null ordering: {null_ordering}"))?;

bail!("Sort order expressions not yet supported");
}
spark_expr::ExprType::LambdaFunction(_) => {
bail!("Lambda function expressions not yet supported")
}
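With the Cast arm in place, a PySpark cast now flows through to_daft_datatype and into Daft's cast; eval_mode (LEGACY, ANSI, TRY) is parsed but deliberately ignored for now, as the warn! above notes. A sketch of the client-side call this arm translates, assuming a connected spark_session and that Daft renders integer-to-string casts the same way Spark does:

from pyspark.sql.functions import col
from pyspark.sql.types import StringType


def test_cast_to_string(spark_session):
    df = spark_session.range(3)
    # Arrives server-side as ExprType::Cast with CastToType::Type(string).
    out = df.select(col("id").cast(StringType()).alias("id_str"))
    assert [row.id_str for row in out.collect()] == ["0", "1", "2"]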
28 changes: 28 additions & 0 deletions src/daft-connect/src/translation/expr/unresolved_function.rs
@@ -24,6 +24,8 @@ pub fn unresolved_to_daft_expr(f: &UnresolvedFunction) -> eyre::Result<daft_dsl::ExprRef> {

match function_name.as_str() {
"count" => handle_count(arguments).wrap_err("Failed to handle count function"),
"isnotnull" => handle_isnotnull(arguments).wrap_err("Failed to handle isnotnull function"),
"isnull" => handle_isnull(arguments).wrap_err("Failed to handle isnull function"),
n => bail!("Unresolved function {n} not yet supported"),
}
}
@@ -42,3 +44,29 @@ pub fn handle_count(arguments: Vec<daft_dsl::ExprRef>) -> eyre::Result<daft_dsl::ExprRef> {

Ok(count)
}

pub fn handle_isnull(arguments: Vec<daft_dsl::ExprRef>) -> eyre::Result<daft_dsl::ExprRef> {
let arguments: [daft_dsl::ExprRef; 1] = match arguments.try_into() {
Ok(arguments) => arguments,
Err(arguments) => {
bail!("requires exactly one argument; got {arguments:?}");
}
};

let [arg] = arguments;

Ok(arg.is_null())
}

pub fn handle_isnotnull(arguments: Vec<daft_dsl::ExprRef>) -> eyre::Result<daft_dsl::ExprRef> {
let arguments: [daft_dsl::ExprRef; 1] = match arguments.try_into() {
Ok(arguments) => arguments,
Err(arguments) => {
bail!("requires exactly one argument; got {arguments:?}");
}
};

let [arg] = arguments;

Ok(arg.not_null())
}
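The two handlers above back PySpark's Column.isNull and Column.isNotNull, which Spark Connect ships as the single-argument unresolved functions isnull and isnotnull. A usage sketch, assuming a connected spark_session:

from pyspark.sql.functions import col


def test_null_predicates(spark_session):
    df = spark_session.range(4)
    # Both calls arrive server-side as the unresolved functions
    # isnotnull / isnull with exactly one argument.
    out = df.select(
        col("id").isNotNull().alias("present"),
        col("id").isNull().alias("missing"),
    ).toPandas()
    assert out["present"].all()
    assert not out["missing"].any()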
36 changes: 36 additions & 0 deletions tests/connect/test_basic_column.py
@@ -0,0 +1,36 @@
from __future__ import annotations

from pyspark.sql.functions import col
from pyspark.sql.types import StringType


def test_column_operations(spark_session):
    # Create a DataFrame from range(10)
    df = spark_session.range(10)

    # Test desc() ordering
    df_attr = df.sort(col("id").desc())
    assert df_attr.toPandas()["id"].iloc[0] == 9, "desc should sort in descending order"

    # Test __getitem__
    # df_item = df.select(col("id")[0])
    # assert df_item.toPandas()["id"].iloc[0] == 0, "getitem should return first element"

    # Test alias
    df_alias = df.select(col("id").alias("my_number"))
    assert "my_number" in df_alias.columns, "alias should rename column"
    assert df_alias.toPandas()["my_number"].equals(df.toPandas()["id"]), "data should be unchanged"

    # Test cast
    df_cast = df.select(col("id").cast(StringType()))
    assert df_cast.schema.fields[0].dataType == StringType(), "cast should change data type"

    # Test isNotNull/isNull
    df_null = df.select(col("id").isNotNull().alias("not_null"), col("id").isNull().alias("is_null"))
    assert df_null.toPandas()["not_null"].iloc[0], "isNotNull should be True for non-null values"
    assert not df_null.toPandas()["is_null"].iloc[0], "isNull should be False for non-null values"

    # Test name
    df_name = df.select(col("id").name("renamed_id"))
    assert "renamed_id" in df_name.columns, "name should rename column"
    assert df_name.toPandas()["renamed_id"].equals(df.toPandas()["id"]), "data should be unchanged"
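
The test assumes a spark_session pytest fixture defined elsewhere (e.g. a conftest.py under tests/connect). A minimal sketch of such a fixture, assuming the Daft connect server is already listening on localhost:15002; both the port and the startup mechanism are assumptions:

import pytest
from pyspark.sql import SparkSession


@pytest.fixture(scope="session")
def spark_session():
    # Spark Connect client pointed at a locally running Daft connect server;
    # the URL is an assumption for illustration.
    session = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()
    yield session
    session.stop()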
