Skip to content

Commit

Permalink
[FEAT] connect: add more (internally P2) column operations
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewgazelka committed Dec 4, 2024
1 parent 3ee5757 commit 9fd596b
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/daft-connect/src/translation/expr/unresolved_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ pub fn unresolved_to_daft_expr(f: &UnresolvedFunction) -> eyre::Result<daft_dsl:
.wrap_err("Failed to handle <= function"),
">=" => handle_binary_op(arguments, daft_dsl::Operator::GtEq)
.wrap_err("Failed to handle >= function"),
"and" => handle_binary_op(arguments, daft_dsl::Operator::And)
.wrap_err("Failed to handle and function"),
"isnotnull" => handle_isnotnull(arguments).wrap_err("Failed to handle isnotnull function"),
"isnull" => handle_isnull(arguments).wrap_err("Failed to handle isnull function"),
n => bail!("Unresolved function {n} not yet supported"),
Expand Down
57 changes: 57 additions & 0 deletions tests/connect/test_basic_column.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,60 @@ def test_column_name(spark_session):
# df = spark_session.range(10)
# df_item = df.select(col("id")[0])
# assert df_item.toPandas()["id"].iloc[0] == 0, "getitem should return first element"


def test_column_astype(spark_session):
df = spark_session.range(10)
df_astype = df.select(col("id").astype(StringType()))
assert df_astype.schema.fields[0].dataType == StringType(), "astype should change data type"


def test_column_between(spark_session):
df = spark_session.range(10)
df_between = df.select(col("id").between(3, 6).alias("in_range"))
assert df_between.toPandas()["in_range"].tolist() == [False, False, False, True, True, True, True, False, False, False]


# TODO: Uncomment when string operations are implemented
# def test_column_string_ops(spark_session):
# df_str = spark_session.createDataFrame([("hello",), ("world",)], ["text"])
# df_contains = df_str.select(col("text").contains("o").alias("has_o"))
# assert df_contains.toPandas()["has_o"].tolist() == [True, True]
# df_startswith = df_str.select(col("text").startswith("h").alias("starts_h"))
# assert df_startswith.toPandas()["starts_h"].tolist() == [True, False]
# df_endswith = df_str.select(col("text").endswith("d").alias("ends_d"))
# assert df_endswith.toPandas()["ends_d"].tolist() == [False, True]
# df_substr = df_str.select(col("text").substr(1, 2).alias("first_two"))
# assert df_substr.toPandas()["first_two"].tolist() == ["he", "wo"]


# TODO: Uncomment when struct operations are implemented
# def test_column_struct_ops(spark_session):
# df_struct = spark_session.createDataFrame([
# ({"a": 1, "b": 2},),
# ({"a": 3, "b": 4},)
# ], ["data"])
# df_getfield = df_struct.select(col("data").getField("a").alias("a_val"))
# assert df_getfield.toPandas()["a_val"].tolist() == [1, 3]
# df_dropfields = df_struct.select(col("data").dropFields("a").alias("no_a"))
# assert "a" not in df_dropfields.toPandas()["no_a"][0]
# df_withfield = df_struct.select(col("data").withField("c", col("data.a") + 10).alias("with_c"))
# assert df_withfield.toPandas()["with_c"][0]["c"] == 11


# TODO: Uncomment when array operations are implemented
# def test_column_array_ops(spark_session):
# df_array = spark_session.createDataFrame([([1, 2, 3],), ([4, 5, 6],)], ["numbers"])
# df_getitem = df_array.select(col("numbers").getItem(0).alias("first"))
# assert df_getitem.toPandas()["first"].tolist() == [1, 4]


# TODO: Uncomment when when/otherwise operations are implemented
# def test_column_case_when(spark_session):
# df = spark_session.range(10)
# df_case = df.select(
# col("id").when(col("id") < 5, "low")
# .otherwise("high")
# .alias("category")
# )
# assert df_case.toPandas()["category"].tolist() == ["low"] * 5 + ["high"] * 5

0 comments on commit 9fd596b

Please sign in to comment.