diff --git a/crates/core/src/operations/merge/mod.rs b/crates/core/src/operations/merge/mod.rs index 6be8c264ba..59bd28e400 100644 --- a/crates/core/src/operations/merge/mod.rs +++ b/crates/core/src/operations/merge/mod.rs @@ -44,7 +44,7 @@ use datafusion::physical_planner::{ExtensionPlanner, PhysicalPlanner}; use datafusion::{ execution::context::SessionState, physical_plan::ExecutionPlan, - prelude::{DataFrame, SessionContext}, + prelude::{cast, DataFrame, SessionContext}, }; use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{Column, DFSchema, ScalarValue, TableReference}; @@ -990,8 +990,10 @@ async fn execute( .end()?; let name = "__delta_rs_c_".to_owned() + delta_field.name(); - write_projection - .push(Expr::Column(Column::from_name(name.clone())).alias(delta_field.name())); + write_projection.push(cast( + Expr::Column(Column::from_name(name.clone())).alias(delta_field.name()), + delta_field.data_type().try_into()?, + )); new_columns.push((name, case)); } diff --git a/python/tests/test_merge.py b/python/tests/test_merge.py index e8416f6e5f..1f81e81142 100644 --- a/python/tests/test_merge.py +++ b/python/tests/test_merge.py @@ -3,6 +3,7 @@ import pathlib import pyarrow as pa +import pyarrow.parquet as pq import pytest from deltalake import DeltaTable, write_deltalake @@ -1120,3 +1121,31 @@ def test_merge_non_nullable(tmp_path): target_alias="t", predicate="s.id = t.id", ).when_matched_update_all().when_not_matched_insert_all().execute() + + +def test_merge_when_wrong_but_castable_type_passed_while_merge( + tmp_path: pathlib.Path, sample_table: pa.Table +): + write_deltalake(tmp_path, sample_table, mode="append") + + dt = DeltaTable(tmp_path) + + source_table = pa.table( + { + "id": pa.array(["7", "8"]), + "price": pa.array(["1", "2"], pa.string()), + "sold": pa.array([1, 2], pa.int32()), + "deleted": pa.array([False, False]), + } + ) + dt.merge( + source=source_table, + predicate="t.id = s.id", + source_alias="s", + target_alias="t", + ).when_not_matched_insert_all().execute() + + table_schema = pq.read_table( + tmp_path / dt.get_add_actions().column(0)[0].as_py() + ).schema + assert table_schema.field("price").type == sample_table["price"].type