Skip to content

Commit

Permalink
[FEAT] connect: support basic column operations
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewgazelka committed Nov 27, 2024
1 parent b6eee0b commit 2c7437d
Show file tree
Hide file tree
Showing 5 changed files with 284 additions and 6 deletions.
2 changes: 1 addition & 1 deletion src/daft-connect/src/translation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ mod literal;
mod logical_plan;
mod schema;

pub use datatype::to_spark_datatype;
pub use datatype::{to_daft_datatype, to_spark_datatype};
pub use expr::to_daft_expr;
pub use literal::to_daft_literal;
pub use logical_plan::to_logical_plan;
Expand Down
154 changes: 153 additions & 1 deletion src/daft-connect/src/translation/datatype.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use daft_schema::dtype::DataType;
use daft_schema::{dtype::DataType, field::Field, time_unit::TimeUnit};
use eyre::{bail, ensure, WrapErr};
use spark_connect::data_type::Kind;
use tracing::warn;

Expand Down Expand Up @@ -112,3 +113,154 @@ pub fn to_spark_datatype(datatype: &DataType) -> spark_connect::DataType {
_ => unimplemented!("Unsupported datatype: {datatype:?}"),
}
}

// todo(test): add tests for this esp in Python
pub fn to_daft_datatype(datatype: &spark_connect::DataType) -> eyre::Result<DataType> {
let Some(kind) = &datatype.kind else {
bail!("Datatype is required");

Check warning on line 120 in src/daft-connect/src/translation/datatype.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/translation/datatype.rs#L120

Added line #L120 was not covered by tests
};

let type_variation_err = "Custom type variation reference not supported";

match kind {
Kind::Null(value) => {
ensure!(value.type_variation_reference == 0, type_variation_err);
Ok(DataType::Null)

Check warning on line 128 in src/daft-connect/src/translation/datatype.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/translation/datatype.rs#L126-L128

Added lines #L126 - L128 were not covered by tests
}
Kind::Binary(value) => {
ensure!(value.type_variation_reference == 0, type_variation_err);
Ok(DataType::Binary)

Check warning on line 132 in src/daft-connect/src/translation/datatype.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/translation/datatype.rs#L130-L132

Added lines #L130 - L132 were not covered by tests
}
Kind::Boolean(value) => {
ensure!(value.type_variation_reference == 0, type_variation_err);
Ok(DataType::Boolean)

Check warning on line 136 in src/daft-connect/src/translation/datatype.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/translation/datatype.rs#L134-L136

Added lines #L134 - L136 were not covered by tests
}
Kind::Byte(value) => {
ensure!(value.type_variation_reference == 0, type_variation_err);
Ok(DataType::Int8)

Check warning on line 140 in src/daft-connect/src/translation/datatype.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/translation/datatype.rs#L138-L140

Added lines #L138 - L140 were not covered by tests
}
Kind::Short(value) => {
ensure!(value.type_variation_reference == 0, type_variation_err);
Ok(DataType::Int16)

Check warning on line 144 in src/daft-connect/src/translation/datatype.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/translation/datatype.rs#L142-L144

Added lines #L142 - L144 were not covered by tests
}
Kind::Integer(value) => {
ensure!(value.type_variation_reference == 0, type_variation_err);
Ok(DataType::Int32)

Check warning on line 148 in src/daft-connect/src/translation/datatype.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/translation/datatype.rs#L146-L148

Added lines #L146 - L148 were not covered by tests
}
Kind::Long(value) => {
ensure!(value.type_variation_reference == 0, type_variation_err);
Ok(DataType::Int64)

Check warning on line 152 in src/daft-connect/src/translation/datatype.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/translation/datatype.rs#L150-L152

Added lines #L150 - L152 were not covered by tests
}
Kind::Float(value) => {
ensure!(value.type_variation_reference == 0, type_variation_err);
Ok(DataType::Float32)

Check warning on line 156 in src/daft-connect/src/translation/datatype.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/translation/datatype.rs#L154-L156

Added lines #L154 - L156 were not covered by tests
}
Kind::Double(value) => {
ensure!(value.type_variation_reference == 0, type_variation_err);
Ok(DataType::Float64)

Check warning on line 160 in src/daft-connect/src/translation/datatype.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/translation/datatype.rs#L158-L160

Added lines #L158 - L160 were not covered by tests
}
Kind::Decimal(value) => {
ensure!(value.type_variation_reference == 0, type_variation_err);

Check warning on line 163 in src/daft-connect/src/translation/datatype.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/translation/datatype.rs#L162-L163

Added lines #L162 - L163 were not covered by tests

let Some(precision) = value.precision else {
bail!("Decimal precision is required");

Check warning on line 166 in src/daft-connect/src/translation/datatype.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/translation/datatype.rs#L165-L166

Added lines #L165 - L166 were not covered by tests
};

let Some(scale) = value.scale else {
bail!("Decimal scale is required");

Check warning on line 170 in src/daft-connect/src/translation/datatype.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/translation/datatype.rs#L169-L170

Added lines #L169 - L170 were not covered by tests
};

let precision = usize::try_from(precision)
.wrap_err("Decimal precision must be a non-negative integer")?;

Check warning on line 174 in src/daft-connect/src/translation/datatype.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/translation/datatype.rs#L173-L174

Added lines #L173 - L174 were not covered by tests

let scale =
usize::try_from(scale).wrap_err("Decimal scale must be a non-negative integer")?;

Check warning on line 177 in src/daft-connect/src/translation/datatype.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/translation/datatype.rs#L176-L177

Added lines #L176 - L177 were not covered by tests

Ok(DataType::Decimal128(precision, scale))

Check warning on line 179 in src/daft-connect/src/translation/datatype.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/translation/datatype.rs#L179

Added line #L179 was not covered by tests
}
Kind::String(value) => {
ensure!(value.type_variation_reference == 0, type_variation_err);
Ok(DataType::Utf8)
}
Kind::Char(value) => {
ensure!(value.type_variation_reference == 0, type_variation_err);
Ok(DataType::Utf8)

Check warning on line 187 in src/daft-connect/src/translation/datatype.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/translation/datatype.rs#L185-L187

Added lines #L185 - L187 were not covered by tests
}
Kind::VarChar(value) => {
ensure!(value.type_variation_reference == 0, type_variation_err);
Ok(DataType::Utf8)

Check warning on line 191 in src/daft-connect/src/translation/datatype.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/translation/datatype.rs#L189-L191

Added lines #L189 - L191 were not covered by tests
}
Kind::Date(value) => {
ensure!(value.type_variation_reference == 0, type_variation_err);
Ok(DataType::Date)

Check warning on line 195 in src/daft-connect/src/translation/datatype.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/translation/datatype.rs#L193-L195

Added lines #L193 - L195 were not covered by tests
}
Kind::Timestamp(value) => {
ensure!(value.type_variation_reference == 0, type_variation_err);

Check warning on line 198 in src/daft-connect/src/translation/datatype.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/translation/datatype.rs#L197-L198

Added lines #L197 - L198 were not covered by tests

// todo(?): is this correct?

Ok(DataType::Timestamp(TimeUnit::Microseconds, None))

Check warning on line 202 in src/daft-connect/src/translation/datatype.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/translation/datatype.rs#L202

Added line #L202 was not covered by tests
}
Kind::TimestampNtz(value) => {
ensure!(value.type_variation_reference == 0, type_variation_err);

Check warning on line 205 in src/daft-connect/src/translation/datatype.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/translation/datatype.rs#L204-L205

Added lines #L204 - L205 were not covered by tests

// todo(?): is this correct?

Ok(DataType::Timestamp(TimeUnit::Microseconds, None))

Check warning on line 209 in src/daft-connect/src/translation/datatype.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/translation/datatype.rs#L209

Added line #L209 was not covered by tests
}
Kind::CalendarInterval(_) => bail!("Calendar interval type not supported"),
Kind::YearMonthInterval(_) => bail!("Year-month interval type not supported"),
Kind::DayTimeInterval(_) => bail!("Day-time interval type not supported"),
Kind::Array(value) => {
ensure!(value.type_variation_reference == 0, type_variation_err);
let element_type = to_daft_datatype(
value
.element_type
.as_ref()
.ok_or_else(|| eyre::eyre!("Array element type is required"))?,
)?;
Ok(DataType::List(Box::new(element_type)))

Check warning on line 222 in src/daft-connect/src/translation/datatype.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/translation/datatype.rs#L211-L222

Added lines #L211 - L222 were not covered by tests
}
Kind::Struct(value) => {
ensure!(value.type_variation_reference == 0, type_variation_err);
let fields = value
.fields
.iter()
.map(|f| {
let field_type = to_daft_datatype(
f.data_type
.as_ref()
.ok_or_else(|| eyre::eyre!("Struct field type is required"))?,
)?;
Ok(Field::new(&f.name, field_type))
})
.collect::<eyre::Result<Vec<_>>>()?;
Ok(DataType::Struct(fields))

Check warning on line 238 in src/daft-connect/src/translation/datatype.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/translation/datatype.rs#L224-L238

Added lines #L224 - L238 were not covered by tests
}
Kind::Map(value) => {
ensure!(value.type_variation_reference == 0, type_variation_err);
let key_type = to_daft_datatype(
value
.key_type
.as_ref()
.ok_or_else(|| eyre::eyre!("Map key type is required"))?,
)?;
let value_type = to_daft_datatype(
value
.value_type
.as_ref()
.ok_or_else(|| eyre::eyre!("Map value type is required"))?,
)?;

Check warning on line 253 in src/daft-connect/src/translation/datatype.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/translation/datatype.rs#L240-L253

Added lines #L240 - L253 were not covered by tests

let map = DataType::Map {
key: Box::new(key_type),
value: Box::new(value_type),
};

Ok(map)

Check warning on line 260 in src/daft-connect/src/translation/datatype.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/translation/datatype.rs#L255-L260

Added lines #L255 - L260 were not covered by tests
}
Kind::Variant(_) => bail!("Variant type not supported"),
Kind::Udt(_) => bail!("User-defined type not supported"),
Kind::Unparsed(_) => bail!("Unparsed type not supported"),

Check warning on line 264 in src/daft-connect/src/translation/datatype.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/translation/datatype.rs#L262-L264

Added lines #L262 - L264 were not covered by tests
}
}
68 changes: 64 additions & 4 deletions src/daft-connect/src/translation/expr.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
use std::sync::Arc;

use eyre::{bail, Context};
use spark_connect::{expression as spark_expr, Expression};
use spark_connect::{
expression as spark_expr,
expression::{
cast::{CastToType, EvalMode},
sort_order::{NullOrdering, SortDirection},
},
Expression,
};
use tracing::warn;
use unresolved_function::unresolved_to_daft_expr;

use crate::translation::to_daft_literal;
use crate::translation::{to_daft_datatype, to_daft_literal};

mod unresolved_function;

Expand Down Expand Up @@ -69,11 +76,64 @@ pub fn to_daft_expr(expression: &Expression) -> eyre::Result<daft_dsl::ExprRef>

Ok(child.alias(name))
}
spark_expr::ExprType::Cast(_) => bail!("Cast expressions not yet supported"),
spark_expr::ExprType::Cast(c) => {
// Cast { expr: Some(Expression { common: None, expr_type: Some(UnresolvedAttribute(UnresolvedAttribute { unparsed_identifier: "id", plan_id: None, is_metadata_column: None })) }), eval_mode: Unspecified, cast_to_type: Some(Type(DataType { kind: Some(String(String { type_variation_reference: 0, collation: "" })) })) }
// thread 'tokio-runtime-worker' panicked at src/daft-connect/src/trans
println!("got cast {c:?}");
let spark_expr::Cast {
expr,
eval_mode,
cast_to_type,
} = &**c;

let Some(expr) = expr else {
bail!("Cast expression is required");

Check warning on line 90 in src/daft-connect/src/translation/expr.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/translation/expr.rs#L90

Added line #L90 was not covered by tests
};

let expr = to_daft_expr(expr)?;

let Some(cast_to_type) = cast_to_type else {
bail!("Cast to type is required");

Check warning on line 96 in src/daft-connect/src/translation/expr.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/translation/expr.rs#L96

Added line #L96 was not covered by tests
};

let data_type = match cast_to_type {
CastToType::Type(kind) => to_daft_datatype(kind).wrap_err_with(|| {
format!("Failed to convert spark datatype to daft datatype: {kind:?}")

Check warning on line 101 in src/daft-connect/src/translation/expr.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/translation/expr.rs#L101

Added line #L101 was not covered by tests
})?,
CastToType::TypeStr(s) => {
bail!("Cast to type string not yet supported; tried to cast to {s}");

Check warning on line 104 in src/daft-connect/src/translation/expr.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/translation/expr.rs#L103-L104

Added lines #L103 - L104 were not covered by tests
}
};

let eval_mode = EvalMode::try_from(*eval_mode)
.wrap_err_with(|| format!("Invalid cast eval mode: {eval_mode}"))?;

warn!("Ignoring cast eval mode: {eval_mode:?}");

Ok(expr.cast(&data_type))
}
spark_expr::ExprType::UnresolvedRegex(_) => {
bail!("Unresolved regex expressions not yet supported")
}
spark_expr::ExprType::SortOrder(_) => bail!("Sort order expressions not yet supported"),
spark_expr::ExprType::SortOrder(s) => {
let spark_expr::SortOrder {
child,
direction,
null_ordering,
} = &**s;

Check warning on line 123 in src/daft-connect/src/translation/expr.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/translation/expr.rs#L118-L123

Added lines #L118 - L123 were not covered by tests

let Some(_child) = child else {
bail!("Sort order child is required");

Check warning on line 126 in src/daft-connect/src/translation/expr.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/translation/expr.rs#L125-L126

Added lines #L125 - L126 were not covered by tests
};

let _sort_direction = SortDirection::try_from(*direction)
.wrap_err_with(|| format!("Invalid sort direction: {direction}"))?;

Check warning on line 130 in src/daft-connect/src/translation/expr.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/translation/expr.rs#L129-L130

Added lines #L129 - L130 were not covered by tests

let _sort_nulls = NullOrdering::try_from(*null_ordering)
.wrap_err_with(|| format!("Invalid sort nulls: {null_ordering}"))?;

Check warning on line 133 in src/daft-connect/src/translation/expr.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/translation/expr.rs#L132-L133

Added lines #L132 - L133 were not covered by tests

bail!("Sort order expressions not yet supported");

Check warning on line 135 in src/daft-connect/src/translation/expr.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/translation/expr.rs#L135

Added line #L135 was not covered by tests
}
spark_expr::ExprType::LambdaFunction(_) => {
bail!("Lambda function expressions not yet supported")
}
Expand Down
28 changes: 28 additions & 0 deletions src/daft-connect/src/translation/expr/unresolved_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ pub fn unresolved_to_daft_expr(f: &UnresolvedFunction) -> eyre::Result<daft_dsl:

match function_name.as_str() {
"count" => handle_count(arguments).wrap_err("Failed to handle count function"),
"isnotnull" => handle_isnotnull(arguments).wrap_err("Failed to handle isnotnull function"),
"isnull" => handle_isnull(arguments).wrap_err("Failed to handle isnull function"),
n => bail!("Unresolved function {n} not yet supported"),
}
}
Expand All @@ -42,3 +44,29 @@ pub fn handle_count(arguments: Vec<daft_dsl::ExprRef>) -> eyre::Result<daft_dsl:

Ok(count)
}

/// Translates Spark's `isnull` function into a Daft null-check expression.
///
/// # Errors
///
/// Returns an error if `arguments` does not contain exactly one expression; the
/// offending argument list is included in the message.
pub fn handle_isnull(arguments: Vec<daft_dsl::ExprRef>) -> eyre::Result<daft_dsl::ExprRef> {
    // Converting the Vec into a one-element array enforces the arity; on failure the
    // original Vec is handed back so it can be reported.
    let [arg]: [daft_dsl::ExprRef; 1] = match arguments.try_into() {
        Ok(single) => single,
        Err(arguments) => {
            bail!("requires exactly one argument; got {arguments:?}");
        }
    };

    Ok(arg.is_null())
}

/// Translates Spark's `isnotnull` function into a Daft not-null-check expression.
///
/// # Errors
///
/// Returns an error if `arguments` does not contain exactly one expression; the
/// offending argument list is included in the message.
pub fn handle_isnotnull(arguments: Vec<daft_dsl::ExprRef>) -> eyre::Result<daft_dsl::ExprRef> {
    // Converting the Vec into a one-element array enforces the arity; on failure the
    // original Vec is handed back so it can be reported.
    let [arg]: [daft_dsl::ExprRef; 1] = match arguments.try_into() {
        Ok(single) => single,
        Err(arguments) => {
            bail!("requires exactly one argument; got {arguments:?}");
        }
    };

    Ok(arg.not_null())
}
38 changes: 38 additions & 0 deletions tests/connect/test_basic_column.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from __future__ import annotations

from pyspark.sql.functions import col
from pyspark.sql.types import StringType


def test_column_operations(spark_session):
    """Check alias, cast, isNotNull/isNull and name on a Column via the Spark session fixture."""
    # Base frame built from range(10); every check below selects from its "id" column.
    df = spark_session.range(10)

    # Test __getattr__
    # todo: https://github.com/Eventual-Inc/Daft/issues/3433
    # df_attr = df.select(col("id").desc())  # Fix: call desc() as method
    # assert df_attr.toPandas()["id"].iloc[0] == 9, "desc should sort in descending order"

    # Test __getitem__
    # todo: add extract value
    # df_item = df.select(col("id")[0])
    # assert df_item.toPandas()["id"].iloc[0] == 0, "getitem should return first element"

    # alias: the column is renamed and its values are untouched.
    aliased = df.select(col("id").alias("my_number"))
    assert "my_number" in aliased.columns, "alias should rename column"
    aliased_values = aliased.toPandas()["my_number"]
    source_values = df.toPandas()["id"]
    assert aliased_values.equals(source_values), "data should be unchanged"

    # cast: the resulting schema reports the requested type.
    as_string = df.select(col("id").cast(StringType()))
    first_field = as_string.schema.fields[0]
    assert first_field.dataType == StringType(), "cast should change data type"

    # isNotNull / isNull: inspect the first row of each flag column.
    null_flags = df.select(
        col("id").isNotNull().alias("not_null"),
        col("id").isNull().alias("is_null"),
    )
    first_not_null = null_flags.toPandas()["not_null"].iloc[0]
    assert first_not_null, "isNotNull should be True for non-null values"
    first_is_null = null_flags.toPandas()["is_null"].iloc[0]
    assert not first_is_null, "isNull should be False for non-null values"

    # name: behaves like alias — new column name, same values.
    renamed = df.select(col("id").name("renamed_id"))
    assert "renamed_id" in renamed.columns, "name should rename column"
    renamed_values = renamed.toPandas()["renamed_id"]
    baseline_values = df.toPandas()["id"]
    assert renamed_values.equals(baseline_values), "data should be unchanged"

0 comments on commit 2c7437d

Please sign in to comment.