Skip to content

Commit b04e56f

Browse files
[BUG]: joins with duplicate column names and qualified table expansion (#3074)
- closes #3070 - closes #3072 - closes #3073
1 parent 5795adc commit b04e56f

File tree

13 files changed

+338
-61
lines changed

13 files changed

+338
-61
lines changed

daft/daft/__init__.pyi

+2
Original file line numberDiff line numberDiff line change
@@ -1726,6 +1726,8 @@ class LogicalPlanBuilder:
17261726
right_on: list[PyExpr],
17271727
join_type: JoinType,
17281728
strategy: JoinStrategy | None = None,
1729+
join_prefix: str | None = None,
1730+
join_suffix: str | None = None,
17291731
) -> LogicalPlanBuilder: ...
17301732
def concat(self, other: LogicalPlanBuilder) -> LogicalPlanBuilder: ...
17311733
def add_monotonically_increasing_id(self, column_name: str | None) -> LogicalPlanBuilder: ...

daft/dataframe/dataframe.py

+44-1
Original file line numberDiff line numberDiff line change
@@ -1688,10 +1688,13 @@ def join(
16881688
right_on: Optional[Union[List[ColumnInputType], ColumnInputType]] = None,
16891689
how: str = "inner",
16901690
strategy: Optional[str] = None,
1691+
prefix: Optional[str] = None,
1692+
suffix: Optional[str] = None,
16911693
) -> "DataFrame":
16921694
"""Column-wise join of the current DataFrame with an ``other`` DataFrame, similar to a SQL ``JOIN``
16931695
1694-
If the two DataFrames have duplicate non-join key column names, "right." will be prepended to the conflicting right columns.
1696+
If the two DataFrames have duplicate non-join key column names, "right." will be prepended to the conflicting right columns. You can change the behavior by passing either (or both) `prefix` or `suffix` to the function.
1697+
If `prefix` is passed, it will be prepended to the conflicting right columns. If `suffix` is passed, it will be appended to the conflicting right columns.
16951698
16961699
.. NOTE::
16971700
Although self joins are supported, we currently duplicate the logical plan for the right side
@@ -1716,6 +1719,42 @@ def join(
17161719
<BLANKLINE>
17171720
(Showing first 2 of 2 rows)
17181721
1722+
>>> import daft
1723+
>>> from daft import col
1724+
>>> df1 = daft.from_pydict({ "a": ["w", "x", "y"], "b": [1, 2, 3] })
1725+
>>> df2 = daft.from_pydict({ "a": ["x", "y", "z"], "b": [20, 30, 40] })
1726+
>>> joined_df = df1.join(df2, left_on=[col("a"), col("b")], right_on=[col("a"), col("b")/10], prefix="right_")
1727+
>>> joined_df.show()
1728+
╭──────┬───────┬─────────╮
1729+
│ a ┆ b ┆ right_b │
1730+
│ --- ┆ --- ┆ --- │
1731+
│ Utf8 ┆ Int64 ┆ Int64 │
1732+
╞══════╪═══════╪═════════╡
1733+
│ x ┆ 2 ┆ 20 │
1734+
├╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┤
1735+
│ y ┆ 3 ┆ 30 │
1736+
╰──────┴───────┴─────────╯
1737+
<BLANKLINE>
1738+
(Showing first 2 of 2 rows)
1739+
1740+
>>> import daft
1741+
>>> from daft import col
1742+
>>> df1 = daft.from_pydict({ "a": ["w", "x", "y"], "b": [1, 2, 3] })
1743+
>>> df2 = daft.from_pydict({ "a": ["x", "y", "z"], "b": [20, 30, 40] })
1744+
>>> joined_df = df1.join(df2, left_on=[col("a"), col("b")], right_on=[col("a"), col("b")/10], suffix="_right")
1745+
>>> joined_df.show()
1746+
╭──────┬───────┬─────────╮
1747+
│ a ┆ b ┆ b_right │
1748+
│ --- ┆ --- ┆ --- │
1749+
│ Utf8 ┆ Int64 ┆ Int64 │
1750+
╞══════╪═══════╪═════════╡
1751+
│ x ┆ 2 ┆ 20 │
1752+
├╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┤
1753+
│ y ┆ 3 ┆ 30 │
1754+
╰──────┴───────┴─────────╯
1755+
<BLANKLINE>
1756+
(Showing first 2 of 2 rows)
1757+
17191758
Args:
17201759
other (DataFrame): the right DataFrame to join on.
17211760
on (Optional[Union[List[ColumnInputType], ColumnInputType]], optional): key or keys to join on [use if the keys on the left and right side match.]. Defaults to None.
@@ -1724,6 +1763,8 @@ def join(
17241763
how (str, optional): what type of join to perform; currently "inner", "left", "right", "outer", "anti", and "semi" are supported. Defaults to "inner".
17251764
strategy (Optional[str]): The join strategy (algorithm) to use; currently "hash", "sort_merge", "broadcast", and None are supported, where None
17261765
chooses the join strategy automatically during query optimization. The default is None.
1766+
suffix (Optional[str], optional): Suffix to add to the column names in case of a name collision. Defaults to "".
1767+
prefix (Optional[str], optional): Prefix to add to the column names in case of a name collision. Defaults to "right.".
17271768
17281769
Raises:
17291770
ValueError: if `on` is passed in and `left_on` or `right_on` is not None.
@@ -1756,6 +1797,8 @@ def join(
17561797
right_on=right_exprs,
17571798
how=join_type,
17581799
strategy=join_strategy,
1800+
join_prefix=prefix,
1801+
join_suffix=suffix,
17591802
)
17601803
return DataFrame(builder)
17611804

daft/logical/builder.py

+4
Original file line numberDiff line numberDiff line change
@@ -255,13 +255,17 @@ def join( # type: ignore[override]
255255
right_on: list[Expression],
256256
how: JoinType = JoinType.Inner,
257257
strategy: JoinStrategy | None = None,
258+
join_suffix: str | None = None,
259+
join_prefix: str | None = None,
258260
) -> LogicalPlanBuilder:
259261
builder = self._builder.join(
260262
right._builder,
261263
[expr._expr for expr in left_on],
262264
[expr._expr for expr in right_on],
263265
how,
264266
strategy,
267+
join_suffix,
268+
join_prefix,
265269
)
266270
return LogicalPlanBuilder(builder)
267271

src/daft-plan/src/builder.rs

+10-1
Original file line numberDiff line numberDiff line change
@@ -443,13 +443,16 @@ impl LogicalPlanBuilder {
443443
Ok(self.with_new_plan(pivot_logical_plan))
444444
}
445445

446+
#[allow(clippy::too_many_arguments)]
446447
pub fn join<Right: Into<LogicalPlanRef>>(
447448
&self,
448449
right: Right,
449450
left_on: Vec<ExprRef>,
450451
right_on: Vec<ExprRef>,
451452
join_type: JoinType,
452453
join_strategy: Option<JoinStrategy>,
454+
join_suffix: Option<&str>,
455+
join_prefix: Option<&str>,
453456
) -> DaftResult<Self> {
454457
let logical_plan: LogicalPlan = logical_ops::Join::try_new(
455458
self.plan.clone(),
@@ -458,6 +461,8 @@ impl LogicalPlanBuilder {
458461
right_on,
459462
join_type,
460463
join_strategy,
464+
join_suffix,
465+
join_prefix,
461466
)?
462467
.into();
463468
Ok(self.with_new_plan(logical_plan))
@@ -868,14 +873,16 @@ impl PyLogicalPlanBuilder {
868873
)?
869874
.into())
870875
}
871-
876+
#[allow(clippy::too_many_arguments)]
872877
pub fn join(
873878
&self,
874879
right: &Self,
875880
left_on: Vec<PyExpr>,
876881
right_on: Vec<PyExpr>,
877882
join_type: JoinType,
878883
join_strategy: Option<JoinStrategy>,
884+
join_suffix: Option<&str>,
885+
join_prefix: Option<&str>,
879886
) -> PyResult<Self> {
880887
Ok(self
881888
.builder
@@ -885,6 +892,8 @@ impl PyLogicalPlanBuilder {
885892
pyexprs_to_exprs(right_on),
886893
join_type,
887894
join_strategy,
895+
join_suffix,
896+
join_prefix,
888897
)?
889898
.into())
890899
}

src/daft-plan/src/display.rs

+4
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,8 @@ mod test {
166166
vec![col("id")],
167167
JoinType::Inner,
168168
None,
169+
None,
170+
None,
169171
)?
170172
.filter(col("first_name").eq(lit("hello")))?
171173
.select(vec![col("first_name")])?
@@ -237,6 +239,8 @@ Project1 --> Limit0
237239
vec![col("id")],
238240
JoinType::Inner,
239241
None,
242+
None,
243+
None,
240244
)?
241245
.filter(col("first_name").eq(lit("hello")))?
242246
.select(vec![col("first_name")])?

src/daft-plan/src/logical_ops/join.rs

+11-1
Original file line numberDiff line numberDiff line change
@@ -47,13 +47,16 @@ impl std::hash::Hash for Join {
4747
}
4848

4949
impl Join {
50+
#[allow(clippy::too_many_arguments)]
5051
pub(crate) fn try_new(
5152
left: Arc<LogicalPlan>,
5253
right: Arc<LogicalPlan>,
5354
left_on: Vec<ExprRef>,
5455
right_on: Vec<ExprRef>,
5556
join_type: JoinType,
5657
join_strategy: Option<JoinStrategy>,
58+
join_suffix: Option<&str>,
59+
join_prefix: Option<&str>,
5760
) -> logical_plan::Result<Self> {
5861
let (left_on, _) = resolve_exprs(left_on, &left.schema(), false).context(CreationSnafu)?;
5962
let (right_on, _) =
@@ -124,7 +127,14 @@ impl Join {
124127
} else {
125128
let mut new_name = name.clone();
126129
while names_so_far.contains(&new_name) {
127-
new_name = format!("right.{}", new_name);
130+
if let Some(prefix) = join_prefix {
131+
new_name = format!("{}{}", prefix, new_name);
132+
} else if join_suffix.is_none() {
133+
new_name = format!("right.{}", new_name);
134+
}
135+
if let Some(suffix) = join_suffix {
136+
new_name = format!("{}{}", new_name, suffix);
137+
}
128138
}
129139
names_so_far.insert(new_name.clone());
130140

src/daft-plan/src/logical_optimization/rules/push_down_filter.rs

+37-3
Original file line numberDiff line numberDiff line change
@@ -674,6 +674,8 @@ mod tests {
674674
join_on.clone(),
675675
how,
676676
None,
677+
None,
678+
None,
677679
)?
678680
.filter(pred.clone())?
679681
.build();
@@ -686,7 +688,15 @@ mod tests {
686688
left_scan_plan.filter(pred)?
687689
};
688690
let expected = expected_left_filter_scan
689-
.join(&right_scan_plan, join_on.clone(), join_on, how, None)?
691+
.join(
692+
&right_scan_plan,
693+
join_on.clone(),
694+
join_on,
695+
how,
696+
None,
697+
None,
698+
None,
699+
)?
690700
.build();
691701
assert_optimized_plan_eq(plan, expected)?;
692702
Ok(())
@@ -720,6 +730,8 @@ mod tests {
720730
join_on.clone(),
721731
how,
722732
None,
733+
None,
734+
None,
723735
)?
724736
.filter(pred.clone())?
725737
.build();
@@ -738,6 +750,8 @@ mod tests {
738750
join_on,
739751
how,
740752
None,
753+
None,
754+
None,
741755
)?
742756
.build();
743757
assert_optimized_plan_eq(plan, expected)?;
@@ -785,6 +799,8 @@ mod tests {
785799
join_on.clone(),
786800
how,
787801
None,
802+
None,
803+
None,
788804
)?
789805
.filter(pred.clone())?
790806
.build();
@@ -811,6 +827,8 @@ mod tests {
811827
join_on,
812828
how,
813829
None,
830+
None,
831+
None,
814832
)?
815833
.build();
816834
assert_optimized_plan_eq(plan, expected)?;
@@ -835,7 +853,15 @@ mod tests {
835853
let join_on = vec![col("b")];
836854
let pred = col("a").lt(lit(2));
837855
let plan = left_scan_plan
838-
.join(&right_scan_plan, join_on.clone(), join_on, how, None)?
856+
.join(
857+
&right_scan_plan,
858+
join_on.clone(),
859+
join_on,
860+
how,
861+
None,
862+
None,
863+
None,
864+
)?
839865
.filter(pred)?
840866
.build();
841867
// should not push down filter
@@ -862,7 +888,15 @@ mod tests {
862888
let join_on = vec![col("b")];
863889
let pred = col("c").lt(lit(2.0));
864890
let plan = left_scan_plan
865-
.join(&right_scan_plan, join_on.clone(), join_on, how, None)?
891+
.join(
892+
&right_scan_plan,
893+
join_on.clone(),
894+
join_on,
895+
how,
896+
None,
897+
None,
898+
None,
899+
)?
866900
.filter(pred)?
867901
.build();
868902
// should not push down filter

src/daft-plan/src/logical_plan.rs

+10-1
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,16 @@ impl LogicalPlan {
264264
[input1, input2] => match self {
265265
Self::Source(_) => panic!("Source nodes don't have children, with_new_children() should never be called for Source ops"),
266266
Self::Concat(_) => Self::Concat(Concat::try_new(input1.clone(), input2.clone()).unwrap()),
267-
Self::Join(Join { left_on, right_on, join_type, join_strategy, .. }) => Self::Join(Join::try_new(input1.clone(), input2.clone(), left_on.clone(), right_on.clone(), *join_type, *join_strategy).unwrap()),
267+
Self::Join(Join { left_on, right_on, join_type, join_strategy, .. }) => Self::Join(Join::try_new(
268+
input1.clone(),
269+
input2.clone(),
270+
left_on.clone(),
271+
right_on.clone(),
272+
*join_type,
273+
*join_strategy,
274+
None, // The suffix is already eagerly computed in the constructor
275+
None // the prefix is already eagerly computed in the constructor
276+
).unwrap()),
268277
_ => panic!("Logical op {} has one input, but got two", self),
269278
},
270279
_ => panic!("Logical ops should never have more than 2 inputs, but got: {}", children.len())

src/daft-plan/src/physical_planner/translate.rs

+2
Original file line numberDiff line numberDiff line change
@@ -1220,6 +1220,8 @@ mod tests {
12201220
vec![col("a"), col("b")],
12211221
JoinType::Inner,
12221222
Some(JoinStrategy::Hash),
1223+
None,
1224+
None,
12231225
)?
12241226
.build();
12251227
logical_to_physical(logical_plan, cfg)

src/daft-sql/src/lib.rs

+2
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,8 @@ mod tests {
261261
vec![col("id")],
262262
JoinType::Inner,
263263
None,
264+
None,
265+
None,
264266
)?
265267
.select(vec![col("*")])?
266268
.build();

0 commit comments

Comments
 (0)