diff --git a/tests/connect/test_get_attr.py b/tests/connect/test_get_attr.py new file mode 100644 index 0000000000..c48f114091 --- /dev/null +++ b/tests/connect/test_get_attr.py @@ -0,0 +1,15 @@ +from __future__ import annotations + + +def test_get_attr(spark_session): + # Create DataFrame from range(10) + df = spark_session.range(10) + + # Get column using df[...] + # df.get_attr("id") is equivalent to df["id"] + df_col = df["id"] + + # Check that column values match expected range + values = df.select(df_col).collect() # Changed to select column first + assert len(values) == 10 + assert [row[0] for row in values] == list(range(10)) # Need to extract values from Row objects