feat: add more unresolved functions
andrewgazelka committed Dec 19, 2024
1 parent ae74c10 · commit 8766efc
Showing 2 changed files with 212 additions and 1 deletion.
23 changes: 22 additions & 1 deletion src/daft-connect/src/translation/expr/unresolved_function.rs
@@ -34,10 +34,13 @@ pub fn unresolved_to_daft_expr(f: &UnresolvedFunction) -> eyre::Result<daft_dsl::ExprRef> {
.wrap_err("Failed to handle >= function"),
"%" => handle_binary_op(arguments, daft_dsl::Operator::Modulus)
.wrap_err("Failed to handle % function"),
"==" => handle_binary_op(arguments, daft_dsl::Operator::Eq)
.wrap_err("Failed to handle == function"),
"not" => not(arguments).wrap_err("Failed to handle not function"),
"sum" => handle_sum(arguments).wrap_err("Failed to handle sum function"),
"isnotnull" => handle_isnotnull(arguments).wrap_err("Failed to handle isnotnull function"),
"isnull" => handle_isnull(arguments).wrap_err("Failed to handle isnull function"),
n => bail!("Unresolved function {n} not yet supported"),
n => bail!("Unresolved function {n:?} not yet supported"),
}
}

@@ -53,6 +56,24 @@ pub fn handle_sum(arguments: Vec<daft_dsl::ExprRef>) -> eyre::Result<daft_dsl::ExprRef> {
     Ok(arg.sum())
 }
 
+pub fn one(arguments: Vec<daft_dsl::ExprRef>) -> eyre::Result<daft_dsl::ExprRef> {
+    let arguments: [daft_dsl::ExprRef; 1] = match arguments.try_into() {
+        Ok(arguments) => arguments,
+        Err(arguments) => {
+            bail!("requires exactly one argument; got {arguments:?}");
+        }
+    };
+
+    let [arg] = arguments;
+
+    Ok(arg)
+}
+
+pub fn not(arguments: Vec<daft_dsl::ExprRef>) -> eyre::Result<daft_dsl::ExprRef> {
+    let arg = one(arguments)?;
+    Ok(arg.not())
+}
+
 pub fn handle_binary_op(
     arguments: Vec<daft_dsl::ExprRef>,
     op: daft_dsl::Operator,
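The hunk cuts off before handle_binary_op's body. A minimal sketch of a plausible implementation, following the same arity-checking pattern as the one helper above (the Expr::BinaryOp variant and its field names are assumptions, not something this diff confirms; the sketch also reuses the file's existing eyre imports):

pub fn handle_binary_op(
    arguments: Vec<daft_dsl::ExprRef>,
    op: daft_dsl::Operator,
) -> eyre::Result<daft_dsl::ExprRef> {
    // Destructure into exactly two arguments, mirroring the `one` helper.
    let [left, right]: [daft_dsl::ExprRef; 2] = match arguments.try_into() {
        Ok(arguments) => arguments,
        Err(arguments) => bail!("requires exactly two arguments; got {arguments:?}"),
    };

    // Assumed shape: a BinaryOp expression variant holding `op`, `left`, and `right`.
    Ok(daft_dsl::Expr::BinaryOp { op, left, right }.into())
}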
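The isnull and isnotnull handlers dispatched in the match above also live outside this hunk. A hedged reconstruction built on the same one helper; the is_null and not_null expression methods are assumed from Daft's DSL naming and may differ:

pub fn handle_isnull(arguments: Vec<daft_dsl::ExprRef>) -> eyre::Result<daft_dsl::ExprRef> {
    // Assumed: daft_dsl expressions expose an `is_null` helper.
    let arg = one(arguments)?;
    Ok(arg.is_null())
}

pub fn handle_isnotnull(arguments: Vec<daft_dsl::ExprRef>) -> eyre::Result<daft_dsl::ExprRef> {
    // Assumed: daft_dsl expressions expose a `not_null` helper.
    let arg = one(arguments)?;
    Ok(arg.not_null())
}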
190 changes: 190 additions & 0 deletions tests/connect/test_unresolved.py
@@ -0,0 +1,190 @@
import pytest
from pyspark.sql import functions as F
from pyspark.sql.types import StructType, StructField, IntegerType, StringType


@pytest.fixture
def sample_df(spark_session):
"""Create a sample DataFrame with various test cases."""
schema = StructType([
StructField("id", IntegerType(), False),
StructField("value_a", IntegerType(), True),
StructField("value_b", IntegerType(), True),
StructField("text_a", StringType(), True),
StructField("text_b", StringType(), True),
])

data = [
(1, 10, 20, "apple", "banana"),
(2, 20, 20, "apple", "apple"),
(3, 30, 20, "cherry", "banana"),
(4, None, 20, None, "banana"),
(5, 50, None, "date", None),
]

return spark_session.createDataFrame(data, schema)


@pytest.mark.parametrize("operator,func_name,expected_results", [
("==", "equals", [
(1, False), # 10 == 20
(2, True), # 20 == 20
(3, False), # 30 == 20
(4, None), # NULL == 20
(5, None), # 50 == NULL
]),
("!=", "not_equals", [
(1, True), # 10 != 20
(2, False), # 20 != 20
(3, True), # 30 != 20
(4, None), # NULL != 20
(5, None), # 50 != NULL
]),
(">", "greater_than", [
(1, False), # 10 > 20
(2, False), # 20 > 20
(3, True), # 30 > 20
(4, None), # NULL > 20
(5, None), # 50 > NULL
]),
(">=", "greater_than_equals", [
(1, False), # 10 >= 20
(2, True), # 20 >= 20
(3, True), # 30 >= 20
(4, None), # NULL >= 20
(5, None), # 50 >= NULL
]),
("<", "less_than", [
(1, True), # 10 < 20
(2, False), # 20 < 20
(3, False), # 30 < 20
(4, None), # NULL < 20
(5, None), # 50 < NULL
]),
("<=", "less_than_equals", [
(1, True), # 10 <= 20
(2, True), # 20 <= 20
(3, False), # 30 <= 20
(4, None), # NULL <= 20
(5, None), # 50 <= NULL
]),
])
def test_numeric_comparisons(sample_df, operator, func_name, expected_results):
"""
Test various numeric comparison operations with NULL handling.
Tests both direct column comparisons and literal value comparisons.
"""
# Test column to column comparison
result_df = sample_df.withColumn(
f"result_{func_name}_col",
eval(f"F.col('value_a') {operator} F.col('value_b')")
).orderBy("id")

# Test column to literal comparison
result_df = result_df.withColumn(
f"result_{func_name}_lit",
eval(f"F.col('value_a') {operator} F.lit(20)")
)

# Collect results and compare row by row
actual_results_col = [(row['id'], row[f'result_{func_name}_col'])
for row in result_df.collect()]
actual_results_lit = [(row['id'], row[f'result_{func_name}_lit'])
for row in result_df.collect()]

# Compare column results
for expected, actual in zip(expected_results, actual_results_col):
assert expected == actual, (
f"Column comparison {operator} failed for id {expected[0]}. "
f"Expected {expected[1]}, got {actual[1]}"
)

# Compare literal results (same expected results since we're comparing with 20)
for expected, actual in zip(expected_results, actual_results_lit):
assert expected == actual, (
f"Literal comparison {operator} failed for id {expected[0]}. "
f"Expected {expected[1]}, got {actual[1]}"
)


@pytest.mark.parametrize("operator,func_name,expected_results", [
("==", "equals", [
(1, False), # apple == banana
(2, True), # apple == apple
(3, False), # cherry == banana
(4, None), # NULL == banana
(5, None), # date == NULL
]),
("!=", "not_equals", [
(1, True), # apple != banana
(2, False), # apple != apple
(3, True), # cherry != banana
(4, None), # NULL != banana
(5, None), # date != NULL
]),
])
def test_string_comparisons(sample_df, operator, func_name, expected_results):
"""
Test string comparison operations with NULL handling.
Tests both direct column comparisons and literal value comparisons.
"""
# Test column to column comparison
result_df = sample_df.withColumn(
f"result_{func_name}_col",
eval(f"F.col('text_a') {operator} F.col('text_b')")
).orderBy("id")

# Test column to literal comparison
result_df = result_df.withColumn(
f"result_{func_name}_lit",
eval(f"F.col('text_a') {operator} F.lit('banana')")
)

# Collect results and compare row by row
actual_results_col = [(row['id'], row[f'result_{func_name}_col'])
for row in result_df.collect()]
actual_results_lit = [(row['id'], row[f'result_{func_name}_lit'])
for row in result_df.collect()]

# Compare column results
for expected, actual in zip(expected_results, actual_results_col):
assert expected == actual, (
f"Column comparison {operator} failed for id {expected[0]}. "
f"Expected {expected[1]}, got {actual[1]}"
)

# Compare literal results
for expected, actual in zip(expected_results, actual_results_lit):
assert expected == actual, (
f"Literal comparison {operator} failed for id {expected[0]}. "
f"Expected {expected[1]}, got {actual[1]}"
)


@pytest.mark.skip(reason="We believe null-safe equals are not yet implemented")
def test_null_safe_equals(sample_df):
"""
Test null-safe equality comparison using the <=> operator.
This operator treats NULL = NULL as TRUE.
"""
result_df = sample_df.withColumn(
"result_null_safe_equals",
F.col("value_a").eqNullSafe(F.col("value_b"))
).orderBy("id")

actual_results = [(row['id'], row['result_null_safe_equals'])
for row in result_df.collect()]

expected_results = [
(1, False), # 10 <=> 20
(2, True), # 20 <=> 20
(3, False), # 30 <=> 20
(4, False), # NULL <=> 20
(5, False), # 50 <=> NULL
]

for expected, actual in zip(expected_results, actual_results):
assert expected == actual, (
f"Null-safe equals failed for id {expected[0]}. "
f"Expected {expected[1]}, got {actual[1]}"
)
