diff --git a/src/daft-connect/src/translation/logical_plan.rs b/src/daft-connect/src/translation/logical_plan.rs index 4073334ff4..65355f8043 100644 --- a/src/daft-connect/src/translation/logical_plan.rs +++ b/src/daft-connect/src/translation/logical_plan.rs @@ -5,7 +5,7 @@ use spark_connect::{relation::RelType, Limit, Relation}; use tracing::warn; use crate::translation::logical_plan::{ - aggregate::aggregate, drop::drop, filter::filter, local_relation::local_relation, + aggregate::aggregate, drop::drop, filter::filter, join::join, local_relation::local_relation, project::project, range::range, read::read, to_df::to_df, with_columns::with_columns, with_columns_renamed::with_columns_renamed, }; @@ -13,6 +13,7 @@ use crate::translation::logical_plan::{ mod aggregate; mod drop; mod filter; +mod join; mod local_relation; mod project; mod range; @@ -84,6 +85,9 @@ pub async fn to_logical_plan(relation: Relation) -> eyre::Result { RelType::Read(r) => read(r) .await .wrap_err("Failed to apply read to logical plan"), + RelType::Join(j) => join(*j) + .await + .wrap_err("Failed to apply join to logical plan"), RelType::Drop(d) => drop(*d) .await .wrap_err("Failed to apply drop to logical plan"), diff --git a/src/daft-connect/src/translation/logical_plan/join.rs b/src/daft-connect/src/translation/logical_plan/join.rs new file mode 100644 index 0000000000..89c4345ee8 --- /dev/null +++ b/src/daft-connect/src/translation/logical_plan/join.rs @@ -0,0 +1,102 @@ +use eyre::{bail, WrapErr}; +use spark_connect::join::JoinType; +use tracing::warn; + +use crate::translation::{to_logical_plan, Plan}; + +pub async fn join(join: spark_connect::Join) -> eyre::Result { + let spark_connect::Join { + left, + right, + join_condition, + join_type, + using_columns, + join_data_type, + } = join; + + let Some(left) = left else { + bail!("Left side of join is required"); + }; + + let Some(right) = right else { + bail!("Right side of join is required"); + }; + + if let Some(join_condition) = join_condition { + bail!("Join conditions are not yet supported; use using_columns (join keys) instead; got {join_condition:?}"); + } + + let join_type = JoinType::try_from(join_type) + .wrap_err_with(|| format!("Invalid join type: {join_type:?}"))?; + + let join_type = to_daft_join_type(join_type)?; + + let using_columns_exprs: Vec<_> = using_columns + .iter() + .map(|s| daft_dsl::col(s.as_str())) + .collect(); + + if let Some(join_data_type) = join_data_type { + warn!("Ignoring join data type {join_data_type:?} for join; not yet implemented"); + } + + let mut left = Box::pin(to_logical_plan(*left)).await?; + let right = Box::pin(to_logical_plan(*right)).await?; + + left.psets.partitions.extend(right.psets.partitions); + + let builder = match join_type { + JoinTypeInfo::Cross => { + left.builder.cross_join(&right.builder, None, None)? // todo(correctness): is this correct? + } + JoinTypeInfo::Regular(join_type) => { + left.builder.join( + &right.builder, + // join_conditions.clone(), // todo(correctness): is this correct? + // join_conditions, // todo(correctness): is this correct? + using_columns_exprs.clone(), + using_columns_exprs, + join_type, + None, + None, + None, + false, // todo(correctness): we want join keys or not + )? + } + }; + + let result = Plan { + builder, + psets: left.psets, + }; + + Ok(result) +} + +enum JoinTypeInfo { + Regular(daft_core::join::JoinType), + Cross, +} + +impl From for JoinTypeInfo { + fn from(join_type: daft_logical_plan::JoinType) -> Self { + Self::Regular(join_type) + } +} + +fn to_daft_join_type(join_type: JoinType) -> eyre::Result { + match join_type { + JoinType::Unspecified => { + bail!("Join type must be specified; got Unspecified") + } + JoinType::Inner => Ok(daft_core::join::JoinType::Inner.into()), + JoinType::FullOuter => { + bail!("Full outer joins not yet supported") // todo(completeness): add support for full outer joins if it is not already implemented + } + JoinType::LeftOuter => Ok(daft_core::join::JoinType::Left.into()), // todo(correctness): is this correct? + JoinType::RightOuter => Ok(daft_core::join::JoinType::Right.into()), + JoinType::LeftAnti => Ok(daft_core::join::JoinType::Anti.into()), // todo(correctness): is this correct? + JoinType::LeftSemi => bail!("Left semi joins not yet supported"), // todo(completeness): add support for left semi joins if it is not already implemented + JoinType::Cross => Ok(JoinTypeInfo::Cross), + } +} diff --git a/tests/connect/test_join.py b/tests/connect/test_join.py new file mode 100644 index 0000000000..86afb975a5 --- /dev/null +++ b/tests/connect/test_join.py @@ -0,0 +1,57 @@ +from __future__ import annotations + +from pyspark.sql.functions import col + + +def test_join(spark_session): + # Create two DataFrames with overlapping IDs + df1 = spark_session.range(5) + df2 = spark_session.range(3, 7) + + # Perform inner join on 'id' column + joined_df = df1.join(df2, "id", "inner") + + # Verify join results using collect() + joined_ids = {row.id for row in joined_df.select("id").collect()} + assert joined_ids == {3, 4}, "Inner join should only contain IDs 3 and 4" + + # Test left outer join + left_joined_df = df1.join(df2, "id", "left") + left_joined_ids = {row.id for row in left_joined_df.select("id").collect()} + assert left_joined_ids == {0, 1, 2, 3, 4}, "Left join should keep all rows from left DataFrame" + + # Test right outer join + right_joined_df = df1.join(df2, "id", "right") + right_joined_ids = {row.id for row in right_joined_df.select("id").collect()} + assert right_joined_ids == {3, 4, 5, 6}, "Right join should keep all rows from right DataFrame" + + + +def test_cross_join(spark_session): + # Create two small DataFrames to demonstrate cross join + # df_left: [0, 1] + # df_right: [10, 11] + # Expected result will be all combinations: + # id1 id2 + # 0 10 + # 0 11 + # 1 10 + # 1 11 + df_left = spark_session.range(2) + df_right = spark_session.range(10, 12).withColumnRenamed("id", "id2") + + # Perform cross join - this creates cartesian product of both DataFrames + cross_joined_df = df_left.crossJoin(df_right) + + # Convert to pandas for easier verification + result_df = cross_joined_df.toPandas() + + # Verify we get all 4 combinations (2 x 2 = 4 rows) + assert len(result_df) == 4, "Cross join should produce 4 rows (2x2 cartesian product)" + + # Verify all expected combinations exist + expected_combinations = {(0, 10), (0, 11), (1, 10), (1, 11)} + actual_combinations = {(row["id"], row["id2"]) for _, row in result_df.iterrows()} + assert actual_combinations == expected_combinations, "Cross join should contain all possible combinations" + +