diff --git a/crates/core/src/operations/merge/filter.rs b/crates/core/src/operations/merge/filter.rs index 5346c1b6a9..0745c55830 100644 --- a/crates/core/src/operations/merge/filter.rs +++ b/crates/core/src/operations/merge/filter.rs @@ -675,7 +675,7 @@ mod tests { Arc::new(arrow::array::StringArray::from(vec![ "2023-07-04", "2023-07-05", - "2023-07-05" + "2023-07-05", ])), ], ) @@ -702,12 +702,10 @@ mod tests { relation: Some(target_name.clone()), name: "id".to_owned(), })) - .and( - col("modified".to_owned()) - .in_list(vec![ - lit("2023-07-05"), lit("2023-07-06"), lit("2023-07-07") - ], false), - ); + .and(col("modified".to_owned()).in_list( + vec![lit("2023-07-05"), lit("2023-07-06"), lit("2023-07-07")], + false, + )); let pred = try_construct_early_filter( join_predicate, @@ -734,11 +732,15 @@ mod tests { col(Column { relation: None, name: "modified".to_owned(), - }).in_list(vec![ - Expr::Literal(ScalarValue::Utf8(Some("2023-07-05".to_string()))), - Expr::Literal(ScalarValue::Utf8(Some("2023-07-06".to_string()))), - Expr::Literal(ScalarValue::Utf8(Some("2023-07-07".to_string()))) - ], false), + }) + .in_list( + vec![ + Expr::Literal(ScalarValue::Utf8(Some("2023-07-05".to_string()))), + Expr::Literal(ScalarValue::Utf8(Some("2023-07-06".to_string()))), + Expr::Literal(ScalarValue::Utf8(Some("2023-07-07".to_string()))), + ], + false, + ), ); assert_eq!(pred.unwrap(), filter); } @@ -760,7 +762,7 @@ mod tests { Arc::new(arrow::array::StringArray::from(vec![ "2023-07-04", "2023-07-05", - "2023-07-05" + "2023-07-05", ])), ], ) @@ -787,9 +789,8 @@ mod tests { relation: Some(target_name.clone()), name: "id".to_owned(), })) - .and( - col("modified".to_owned()) - .in_list(vec![ + .and(col("modified".to_owned()).in_list( + vec![ col(Column { relation: Some(target_name.clone()), name: "id".to_owned(), @@ -797,9 +798,10 @@ mod tests { col(Column { relation: Some(target_name.clone()), name: "modified".to_owned(), - }) - ], false), - ); + }), + ], + false, + )); let pred = try_construct_early_filter( join_predicate, @@ -826,16 +828,20 @@ mod tests { col(Column { relation: None, name: "modified".to_owned(), - }).in_list(vec![ - col(Column { - relation: Some(target_name.clone()), - name: "id".to_owned(), - }), - col(Column { - relation: Some(target_name.clone()), - name: "modified".to_owned(), - }) - ], false), + }) + .in_list( + vec![ + col(Column { + relation: Some(target_name.clone()), + name: "id".to_owned(), + }), + col(Column { + relation: Some(target_name.clone()), + name: "modified".to_owned(), + }), + ], + false, + ), ); assert_eq!(pred.unwrap(), filter); } @@ -857,7 +863,7 @@ mod tests { Arc::new(arrow::array::StringArray::from(vec![ "2023-07-04", "2023-07-05", - "2023-07-05" + "2023-07-05", ])), ], ) @@ -884,9 +890,8 @@ mod tests { relation: Some(target_name.clone()), name: "id".to_owned(), })) - .and( - ident("source.id") - .in_list(vec![ + .and(ident("source.id").in_list( + vec![ col(Column { relation: Some(target_name.clone()), name: "id".to_owned(), @@ -894,9 +899,10 @@ mod tests { col(Column { relation: Some(target_name.clone()), name: "modified".to_owned(), - }) - ], false), - ); + }), + ], + false, + )); let pred = try_construct_early_filter( join_predicate, @@ -919,8 +925,8 @@ mod tests { Expr::Literal(ScalarValue::Utf8(Some("A".to_string()))), Expr::Literal(ScalarValue::Utf8(Some("C".to_string()))), ) - .and( - ident("source.id").in_list(vec![ + .and(ident("source.id").in_list( + vec![ col(Column { relation: Some(target_name.clone()), name: "id".to_owned(), @@ -928,9 +934,10 @@ mod tests { col(Column { relation: Some(target_name.clone()), name: "modified".to_owned(), - }) - ], false), - ); + }), + ], + false, + )); assert_eq!(pred.unwrap(), filter); } } diff --git a/crates/core/src/operations/merge/mod.rs b/crates/core/src/operations/merge/mod.rs index f799ecfe86..fbe255cdbc 100644 --- a/crates/core/src/operations/merge/mod.rs +++ b/crates/core/src/operations/merge/mod.rs @@ -1968,6 +1968,115 @@ mod tests { assert_batches_sorted_eq!(&expected, &actual); } + #[tokio::test] + async fn test_merge_partitions_with_in() { + /* Validate the join predicate works with table partitions */ + let schema = get_arrow_schema(&None); + let table = setup_table(Some(vec!["modified"])).await; + + let table = write_data(table, &schema).await; + assert_eq!(table.version(), 1); + assert_eq!(table.get_files_count(), 2); + + let ctx = SessionContext::new(); + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(arrow::array::StringArray::from(vec!["B", "C", "X"])), + Arc::new(arrow::array::Int32Array::from(vec![10, 20, 30])), + Arc::new(arrow::array::StringArray::from(vec![ + "2021-02-02", + "2023-07-04", + "2023-07-04", + ])), + ], + ) + .unwrap(); + let source = ctx.read_batch(batch).unwrap(); + + let (table, metrics) = DeltaOps(table) + .merge( + source, + col("target.id") + .eq(col("source.id")) + .and(col("target.id").in_list( + vec![ + col("source.id"), + col("source.modified"), + col("source.value"), + ], + false, + )) + .and(col("target.modified").in_list(vec![lit("2021-02-02")], false)), + ) + .with_source_alias("source") + .with_target_alias("target") + .when_matched_update(|update| { + update + .update("value", col("source.value")) + .update("modified", col("source.modified")) + }) + .unwrap() + .when_not_matched_by_source_update(|update| { + update + .predicate(col("target.value").eq(lit(1))) + .update("value", col("target.value") + lit(1)) + }) + .unwrap() + .when_not_matched_by_source_update(|update| { + update + .predicate(col("target.modified").eq(lit("2021-02-01"))) + .update("value", col("target.value") - lit(1)) + }) + .unwrap() + .when_not_matched_insert(|insert| { + insert + .set("id", col("source.id")) + .set("value", col("source.value")) + .set("modified", col("source.modified")) + }) + .unwrap() + .await + .unwrap(); + + assert_eq!(table.version(), 2); + assert!(table.get_files_count() >= 3); + assert!(metrics.num_target_files_added >= 3); + assert_eq!(metrics.num_target_files_removed, 2); + assert_eq!(metrics.num_target_rows_copied, 1); + assert_eq!(metrics.num_target_rows_updated, 3); + assert_eq!(metrics.num_target_rows_inserted, 2); + assert_eq!(metrics.num_target_rows_deleted, 0); + assert_eq!(metrics.num_output_rows, 6); + assert_eq!(metrics.num_source_rows, 3); + + let commit_info = table.history(None).await.unwrap(); + let last_commit = &commit_info[0]; + let parameters = last_commit.operation_parameters.clone().unwrap(); + assert!(!parameters.contains_key("predicate")); + assert_eq!( + parameters["mergePredicate"], + "target.id = source.id AND \ + target.id IN (source.id, source.modified, source.value) AND \ + target.modified IN ('2021-02-02')" + ); + + let expected = vec![ + "+----+-------+------------+", + "| id | value | modified |", + "+----+-------+------------+", + "| A | 2 | 2021-02-01 |", + "| B | 9 | 2021-02-01 |", + "| B | 10 | 2021-02-02 |", + "| C | 20 | 2023-07-04 |", + "| D | 100 | 2021-02-02 |", + "| X | 30 | 2023-07-04 |", + "+----+-------+------------+", + ]; + let actual = get_data(&table).await; + assert_batches_sorted_eq!(&expected, &actual); + } + #[tokio::test] async fn test_merge_delete_matched() { // Validate behaviours of match delete