Commit 2cd7a98 (parent b3e96bc): 1 changed file with 21 additions and 175 deletions.
@@ -1,190 +1,36 @@
 import pytest
 from pyspark.sql import functions as F
 from pyspark.sql.types import StructType, StructField, IntegerType, StringType


-@pytest.fixture
-def sample_df(spark_session):
-    """Create a sample DataFrame with various test cases."""
-    schema = StructType([
-        StructField("id", IntegerType(), False),
-        StructField("value_a", IntegerType(), True),
-        StructField("value_b", IntegerType(), True),
-        StructField("text_a", StringType(), True),
-        StructField("text_b", StringType(), True),
-    ])
-
-    data = [
-        (1, 10, 20, "apple", "banana"),
-        (2, 20, 20, "apple", "apple"),
-        (3, 30, 20, "cherry", "banana"),
-        (4, None, 20, None, "banana"),
-        (5, 50, None, "date", None),
-    ]
-
-    return spark_session.createDataFrame(data, schema)
+def test_numeric_equals(spark_session):
+    """Test numeric equality comparison with NULL handling."""
+    data = [(1, 10), (2, None)]
+    df = spark_session.createDataFrame(data, ["id", "value"])
+
+    result = df.withColumn("equals_20", F.col("value") == F.lit(20)).collect()
+
+    assert result[0].equals_20 is False  # 10 == 20
+    assert result[1].equals_20 is None   # NULL == 20
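The new test above pins down standard PySpark behavior: comparing a NULL value with == yields NULL, not False. Because of this three-valued logic, rows whose comparison evaluates to NULL are dropped by a filter built from either the comparison or its negation. Below is a minimal standalone sketch of that behavior against vanilla PySpark, assuming a plain local SparkSession (the tests themselves rely on a spark_session pytest fixture):

from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.master("local[1]").appName("null-compare-demo").getOrCreate()

df = spark.createDataFrame([(1, 10), (2, None)], ["id", "value"])

# value = 10   -> equals_20 is False
# value = NULL -> equals_20 is NULL (not False)
df.withColumn("equals_20", F.col("value") == F.lit(20)).show()

# A NULL comparison is neither true nor false, so the NULL row is excluded
# by both the positive filter and its negation.
assert df.filter(F.col("value") == 20).count() == 0
assert df.filter(F.col("value") != 20).count() == 1  # only id=1 (10 != 20)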
-@pytest.mark.parametrize("operator,func_name,expected_results", [
-    ("==", "equals", [
-        (1, False),  # 10 == 20
-        (2, True),   # 20 == 20
-        (3, False),  # 30 == 20
-        (4, None),   # NULL == 20
-        (5, None),   # 50 == NULL
-    ]),
-    ("!=", "not_equals", [
-        (1, True),   # 10 != 20
-        (2, False),  # 20 != 20
-        (3, True),   # 30 != 20
-        (4, None),   # NULL != 20
-        (5, None),   # 50 != NULL
-    ]),
-    (">", "greater_than", [
-        (1, False),  # 10 > 20
-        (2, False),  # 20 > 20
-        (3, True),   # 30 > 20
-        (4, None),   # NULL > 20
-        (5, None),   # 50 > NULL
-    ]),
-    (">=", "greater_than_equals", [
-        (1, False),  # 10 >= 20
-        (2, True),   # 20 >= 20
-        (3, True),   # 30 >= 20
-        (4, None),   # NULL >= 20
-        (5, None),   # 50 >= NULL
-    ]),
-    ("<", "less_than", [
-        (1, True),   # 10 < 20
-        (2, False),  # 20 < 20
-        (3, False),  # 30 < 20
-        (4, None),   # NULL < 20
-        (5, None),   # 50 < NULL
-    ]),
-    ("<=", "less_than_equals", [
-        (1, True),   # 10 <= 20
-        (2, True),   # 20 <= 20
-        (3, False),  # 30 <= 20
-        (4, None),   # NULL <= 20
-        (5, None),   # 50 <= NULL
-    ]),
-])
-def test_numeric_comparisons(sample_df, operator, func_name, expected_results):
-    """
-    Test various numeric comparison operations with NULL handling.
-    Tests both direct column comparisons and literal value comparisons.
-    """
-    # Test column to column comparison
-    result_df = sample_df.withColumn(
-        f"result_{func_name}_col",
-        eval(f"F.col('value_a') {operator} F.col('value_b')")
-    ).orderBy("id")
-
-    # Test column to literal comparison
-    result_df = result_df.withColumn(
-        f"result_{func_name}_lit",
-        eval(f"F.col('value_a') {operator} F.lit(20)")
-    )
-
-    # Collect results and compare row by row
-    actual_results_col = [(row['id'], row[f'result_{func_name}_col'])
-                          for row in result_df.collect()]
-    actual_results_lit = [(row['id'], row[f'result_{func_name}_lit'])
-                          for row in result_df.collect()]
-
-    # Compare column results
-    for expected, actual in zip(expected_results, actual_results_col):
-        assert expected == actual, (
-            f"Column comparison {operator} failed for id {expected[0]}. "
-            f"Expected {expected[1]}, got {actual[1]}"
-        )
-
-    # Compare literal results (same expected results since we're comparing with 20)
-    for expected, actual in zip(expected_results, actual_results_lit):
-        assert expected == actual, (
-            f"Literal comparison {operator} failed for id {expected[0]}. "
-            f"Expected {expected[1]}, got {actual[1]}"
-        )
+def test_string_equals(spark_session):
+    """Test string equality comparison with NULL handling."""
+    data = [(1, "apple"), (2, None)]
+    df = spark_session.createDataFrame(data, ["id", "text"])
+
+    result = df.withColumn("equals_banana", F.col("text") == F.lit("banana")).collect()
+
+    assert result[0].equals_banana is False  # apple == banana
+    assert result[1].equals_banana is None   # NULL == banana


-@pytest.mark.parametrize("operator,func_name,expected_results", [
-    ("==", "equals", [
-        (1, False),  # apple == banana
-        (2, True),   # apple == apple
-        (3, False),  # cherry == banana
-        (4, None),   # NULL == banana
-        (5, None),   # date == NULL
-    ]),
-    ("!=", "not_equals", [
-        (1, True),   # apple != banana
-        (2, False),  # apple != apple
-        (3, True),   # cherry != banana
-        (4, None),   # NULL != banana
-        (5, None),   # date != NULL
-    ]),
-])
-def test_string_comparisons(sample_df, operator, func_name, expected_results):
-    """
-    Test string comparison operations with NULL handling.
-    Tests both direct column comparisons and literal value comparisons.
-    """
-    # Test column to column comparison
-    result_df = sample_df.withColumn(
-        f"result_{func_name}_col",
-        eval(f"F.col('text_a') {operator} F.col('text_b')")
-    ).orderBy("id")
-
-    # Test column to literal comparison
-    result_df = result_df.withColumn(
-        f"result_{func_name}_lit",
-        eval(f"F.col('text_a') {operator} F.lit('banana')")
-    )
-
-    # Collect results and compare row by row
-    actual_results_col = [(row['id'], row[f'result_{func_name}_col'])
-                          for row in result_df.collect()]
-    actual_results_lit = [(row['id'], row[f'result_{func_name}_lit'])
-                          for row in result_df.collect()]
-
-    # Compare column results
-    for expected, actual in zip(expected_results, actual_results_col):
-        assert expected == actual, (
-            f"Column comparison {operator} failed for id {expected[0]}. "
-            f"Expected {expected[1]}, got {actual[1]}"
-        )
-
-    # Compare literal results
-    for expected, actual in zip(expected_results, actual_results_lit):
-        assert expected == actual, (
-            f"Literal comparison {operator} failed for id {expected[0]}. "
-            f"Expected {expected[1]}, got {actual[1]}"
-        )
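An aside on the parametrized tests removed above: they built each comparison expression with eval over an operator string. PySpark Column objects overload Python's comparison operators, so equivalent coverage can be parametrized without eval by passing callables from the standard operator module. A hypothetical sketch (not part of this commit), reusing the spark_session fixture that the surrounding tests assume:

import operator

import pytest
from pyspark.sql import functions as F


@pytest.mark.parametrize("op,expected", [
    (operator.eq, [False, None]),  # 10 == 20 -> False, NULL == 20 -> NULL
    (operator.ne, [True, None]),   # 10 != 20 -> True,  NULL != 20 -> NULL
    (operator.gt, [False, None]),  # 10 >  20 -> False, NULL >  20 -> NULL
    (operator.le, [True, None]),   # 10 <= 20 -> True,  NULL <= 20 -> NULL
])
def test_numeric_comparisons_without_eval(spark_session, op, expected):
    df = spark_session.createDataFrame([(1, 10), (2, None)], ["id", "value"])
    rows = df.withColumn("result", op(F.col("value"), F.lit(20))).orderBy("id").collect()
    assert [row.result for row in rows] == expected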
 @pytest.mark.skip(reason="We believe null-safe equals are not yet implemented")
-def test_null_safe_equals(sample_df):
-    """
-    Test null-safe equality comparison using the <=> operator.
-    This operator treats NULL = NULL as TRUE.
-    """
-    result_df = sample_df.withColumn(
-        "result_null_safe_equals",
-        F.col("value_a").eqNullSafe(F.col("value_b"))
-    ).orderBy("id")
-
-    actual_results = [(row['id'], row['result_null_safe_equals'])
-                      for row in result_df.collect()]
-
-    expected_results = [
-        (1, False),  # 10 <=> 20
-        (2, True),   # 20 <=> 20
-        (3, False),  # 30 <=> 20
-        (4, False),  # NULL <=> 20
-        (5, False),  # 50 <=> NULL
-    ]
-
-    for expected, actual in zip(expected_results, actual_results):
-        assert expected == actual, (
-            f"Null-safe equals failed for id {expected[0]}. "
-            f"Expected {expected[1]}, got {actual[1]}"
-        )
+def test_null_safe_equals(spark_session):
+    """Test null-safe equality comparison."""
+    data = [(1, 10), (2, None)]
+    df = spark_session.createDataFrame(data, ["id", "value"])
+
+    result = df.withColumn("null_safe_equals", F.col("value").eqNullSafe(F.lit(10))).collect()
+
+    assert result[0].null_safe_equals is True   # 10 <=> 10
+    assert result[1].null_safe_equals is False  # NULL <=> 10
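For reference while the test above remains skipped: in standard PySpark, eqNullSafe (SQL's <=> operator) never returns NULL and treats two NULLs as equal, which matches the removed docstring's note that NULL = NULL is TRUE under this operator. A minimal sketch of that contract against vanilla PySpark, assuming a plain local SparkSession rather than the spark_session fixture:

from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.master("local[1]").appName("eq-null-safe-demo").getOrCreate()

df = spark.createDataFrame(
    [(1, 10, 10), (2, 10, None), (3, None, None)],
    ["id", "a", "b"],
)

rows = (
    df.withColumn("plain_eq", F.col("a") == F.col("b"))
      .withColumn("null_safe_eq", F.col("a").eqNullSafe(F.col("b")))
      .orderBy("id")
      .collect()
)

assert [r.plain_eq for r in rows] == [True, None, None]       # == propagates NULL
assert [r.null_safe_eq for r in rows] == [True, False, True]  # <=> never returns NULL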