Skip to content

Commit b87e0a3

Browse files
feat(connect): add more unresolved functions (#3618)
1 parent f6002f9 commit b87e0a3

File tree

2 files changed

+80
-15
lines changed

2 files changed

+80
-15
lines changed

src/daft-connect/src/translation/expr/unresolved_function.rs

+32-15
Original file line numberDiff line numberDiff line change
@@ -23,22 +23,20 @@ pub fn unresolved_to_daft_expr(f: &UnresolvedFunction) -> eyre::Result<daft_dsl:
2323
}
2424

2525
match function_name.as_str() {
26-
"count" => handle_count(arguments).wrap_err("Failed to handle count function"),
27-
"<" => handle_binary_op(arguments, daft_dsl::Operator::Lt)
28-
.wrap_err("Failed to handle < function"),
29-
">" => handle_binary_op(arguments, daft_dsl::Operator::Gt)
30-
.wrap_err("Failed to handle > function"),
31-
"<=" => handle_binary_op(arguments, daft_dsl::Operator::LtEq)
32-
.wrap_err("Failed to handle <= function"),
33-
">=" => handle_binary_op(arguments, daft_dsl::Operator::GtEq)
34-
.wrap_err("Failed to handle >= function"),
35-
"%" => handle_binary_op(arguments, daft_dsl::Operator::Modulus)
36-
.wrap_err("Failed to handle % function"),
37-
"sum" => handle_sum(arguments).wrap_err("Failed to handle sum function"),
38-
"isnotnull" => handle_isnotnull(arguments).wrap_err("Failed to handle isnotnull function"),
39-
"isnull" => handle_isnull(arguments).wrap_err("Failed to handle isnull function"),
40-
n => bail!("Unresolved function {n} not yet supported"),
26+
"%" => handle_binary_op(arguments, daft_dsl::Operator::Modulus),
27+
"<" => handle_binary_op(arguments, daft_dsl::Operator::Lt),
28+
"<=" => handle_binary_op(arguments, daft_dsl::Operator::LtEq),
29+
"==" => handle_binary_op(arguments, daft_dsl::Operator::Eq),
30+
">" => handle_binary_op(arguments, daft_dsl::Operator::Gt),
31+
">=" => handle_binary_op(arguments, daft_dsl::Operator::GtEq),
32+
"count" => handle_count(arguments),
33+
"isnotnull" => handle_isnotnull(arguments),
34+
"isnull" => handle_isnull(arguments),
35+
"not" => not(arguments),
36+
"sum" => handle_sum(arguments),
37+
n => bail!("Unresolved function {n:?} not yet supported"),
4138
}
39+
.wrap_err_with(|| format!("Failed to handle function {function_name:?}"))
4240
}
4341

4442
pub fn handle_sum(arguments: Vec<daft_dsl::ExprRef>) -> eyre::Result<daft_dsl::ExprRef> {
@@ -53,6 +51,25 @@ pub fn handle_sum(arguments: Vec<daft_dsl::ExprRef>) -> eyre::Result<daft_dsl::E
5351
Ok(arg.sum())
5452
}
5553

54+
/// If the arguments are exactly one, return it. Otherwise, return an error.
55+
pub fn to_single(arguments: Vec<daft_dsl::ExprRef>) -> eyre::Result<daft_dsl::ExprRef> {
56+
let arguments: [daft_dsl::ExprRef; 1] = match arguments.try_into() {
57+
Ok(arguments) => arguments,
58+
Err(arguments) => {
59+
bail!("requires exactly one argument; got {arguments:?}");
60+
}
61+
};
62+
63+
let [arg] = arguments;
64+
65+
Ok(arg)
66+
}
67+
68+
/// Logical NOT over a single argument expression.
///
/// Errors if `arguments` does not contain exactly one expression.
pub fn not(arguments: Vec<daft_dsl::ExprRef>) -> eyre::Result<daft_dsl::ExprRef> {
    // Validate arity via `to_single`, then negate the lone expression.
    to_single(arguments).map(|arg| arg.not())
}
72+
5673
pub fn handle_binary_op(
5774
arguments: Vec<daft_dsl::ExprRef>,
5875
op: daft_dsl::Operator,

tests/connect/test_unresolved.py

+48
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import pytest
2+
from pyspark.sql import functions as F
3+
4+
5+
def test_numeric_equals(spark_session):
    """Test numeric equality comparison with NULL handling."""
    rows = [(1, 10), (2, None)]
    frame = spark_session.createDataFrame(rows, ["id", "value"])

    collected = frame.withColumn("equals_20", F.col("value") == F.lit(20)).collect()

    # 10 == 20 is False; NULL == 20 propagates NULL per SQL comparison semantics.
    assert collected[0].equals_20 is False
    assert collected[1].equals_20 is None
14+
15+
16+
def test_string_equals(spark_session):
    """Test string equality comparison with NULL handling."""
    rows = [(1, "apple"), (2, None)]
    frame = spark_session.createDataFrame(rows, ["id", "text"])

    collected = frame.withColumn("equals_banana", F.col("text") == F.lit("banana")).collect()

    # "apple" == "banana" is False; NULL == "banana" propagates NULL.
    assert collected[0].equals_banana is False
    assert collected[1].equals_banana is None
25+
26+
27+
@pytest.mark.skip(reason="We believe null-safe equals are not yet implemented")
def test_null_safe_equals(spark_session):
    """Test null-safe equality comparison."""
    rows = [(1, 10), (2, None)]
    frame = spark_session.createDataFrame(rows, ["id", "value"])

    collected = frame.withColumn("null_safe_equals", F.col("value").eqNullSafe(F.lit(10))).collect()

    # eqNullSafe (<=>) treats NULL as comparable: 10 <=> 10 is True,
    # NULL <=> 10 is False rather than NULL.
    assert collected[0].null_safe_equals is True
    assert collected[1].null_safe_equals is False
37+
38+
39+
def test_not(spark_session):
    """Test logical NOT operation with NULL handling."""
    rows = [(True,), (False,), (None,)]
    frame = spark_session.createDataFrame(rows, ["value"])

    collected = frame.withColumn("not_value", ~F.col("value")).collect()

    # NOT inverts booleans and propagates NULL.
    expected = [False, True, None]
    for row, want in zip(collected, expected):
        assert row.not_value is want

0 commit comments

Comments
 (0)