Skip to content

Commit

Permalink
add explicit type casts while merge
Browse files Browse the repository at this point in the history
Signed-off-by: Igor Vitenko <[email protected]>
  • Loading branch information
t1g0rz committed Dec 10, 2024
1 parent 2501f2f commit a31f7ee
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 3 deletions.
8 changes: 5 additions & 3 deletions crates/core/src/operations/merge/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ use datafusion::physical_planner::{ExtensionPlanner, PhysicalPlanner};
use datafusion::{
execution::context::SessionState,
physical_plan::ExecutionPlan,
prelude::{DataFrame, SessionContext},
prelude::{cast, DataFrame, SessionContext},
};
use datafusion_common::tree_node::{Transformed, TreeNode};
use datafusion_common::{Column, DFSchema, ScalarValue, TableReference};
Expand Down Expand Up @@ -990,8 +990,10 @@ async fn execute(
.end()?;

let name = "__delta_rs_c_".to_owned() + delta_field.name();
write_projection
.push(Expr::Column(Column::from_name(name.clone())).alias(delta_field.name()));
write_projection.push(cast(
Expr::Column(Column::from_name(name.clone())).alias(delta_field.name()),
delta_field.data_type().try_into()?,
));
new_columns.push((name, case));
}

Expand Down
29 changes: 29 additions & 0 deletions python/tests/test_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pathlib

import pyarrow as pa
import pyarrow.parquet as pq
import pytest

from deltalake import DeltaTable, write_deltalake
Expand Down Expand Up @@ -1120,3 +1121,31 @@ def test_merge_non_nullable(tmp_path):
target_alias="t",
predicate="s.id = t.id",
).when_matched_update_all().when_not_matched_insert_all().execute()


def test_merge_when_wrong_but_castable_type_passed_while_merge(
    tmp_path: pathlib.Path, sample_table: pa.Table
):
    """Check that a merge keeps the target's ``price`` type in the written file.

    The target table is seeded from ``sample_table``. The merge source then
    supplies ``price`` as strings, and its rows are inserted through
    ``when_not_matched_insert_all``. Afterwards the Parquet file referenced by
    the first add action is read back, and its ``price`` field type must equal
    the type of ``sample_table["price"]``.
    """
    # Seed the target table and open it for merging.
    write_deltalake(tmp_path, sample_table, mode="append")
    dt = DeltaTable(tmp_path)

    # ``price`` is passed as strings here rather than with the type used by
    # ``sample_table``; the merge is expected to cope with that.
    price_as_text = pa.array(["1", "2"], pa.string())
    mismatched_source = pa.table(
        {
            "id": pa.array(["7", "8"]),
            "price": price_as_text,
            "sold": pa.array([1, 2], pa.int32()),
            "deleted": pa.array([False, False]),
        }
    )

    merge_builder = dt.merge(
        source=mismatched_source,
        predicate="t.id = s.id",
        source_alias="s",
        target_alias="t",
    )
    merge_builder.when_not_matched_insert_all().execute()

    # NOTE(review): only the first add action is inspected. Presumably that is
    # the file produced by the merge -- confirm the ordering of add actions.
    add_actions = dt.get_add_actions().to_pandas()
    first_file = tmp_path / add_actions["path"].iloc[0]
    written_schema = pq.read_table(first_file).schema

    expected_price_type = sample_table["price"].type
    assert written_schema.field("price").type == expected_price_type

0 comments on commit a31f7ee

Please sign in to comment.