diff --git a/crates/deltalake-core/src/operations/optimize.rs b/crates/deltalake-core/src/operations/optimize.rs
index ef8905e0c9..11a883eef6 100644
--- a/crates/deltalake-core/src/operations/optimize.rs
+++ b/crates/deltalake-core/src/operations/optimize.rs
@@ -40,6 +40,7 @@ use serde::{Deserialize, Serialize};
 
 use super::transaction::{commit, PROTOCOL};
 use super::writer::{PartitionWriter, PartitionWriterConfig};
+use super::write::cast_record_batch;
 use crate::errors::{DeltaResult, DeltaTableError};
 use crate::kernel::{Action, Remove};
 use crate::logstore::LogStoreRef;
@@ -439,7 +440,9 @@ impl MergePlan {
         let mut read_stream = read_stream.await?;
 
         while let Some(maybe_batch) = read_stream.next().await {
-            let batch = maybe_batch?;
+            let mut batch = maybe_batch?;
+
+            batch = cast_record_batch(&batch, task_parameters.file_schema.clone(), false)?;
             partial_metrics.num_batches += 1;
             writer.write(&batch).await.map_err(DeltaTableError::from)?;
         }
@@ -528,6 +531,7 @@ impl MergePlan {
     async fn read_zorder(
         files: MergeBin,
         context: Arc<zorder::ZOrderExecContext>,
+        file_schema: ArrowSchemaRef,
     ) -> Result<BoxStream<'static, Result<RecordBatch, ParquetError>>, DeltaTableError> {
         use datafusion::prelude::{col, ParquetReadOptions};
         use datafusion_expr::expr::ScalarUDF;
@@ -539,7 +543,7 @@ impl MergePlan {
             .collect_vec();
         let df = context
             .ctx
-            .read_parquet(locations, ParquetReadOptions::default())
+            .read_parquet(locations, ParquetReadOptions::default().schema(&file_schema))
             .await?;
 
         let original_columns = df
@@ -644,7 +648,7 @@ impl MergePlan {
                 let log_store = log_store.clone();
                 futures::stream::iter(bins)
                     .map(move |(partition, files)| {
-                        let batch_stream = Self::read_zorder(files.clone(), exec_context.clone());
+                        let batch_stream = Self::read_zorder(files.clone(), exec_context.clone(), task_parameters.file_schema.clone());
                         let rewrite_result = tokio::task::spawn(Self::rewrite_files(
                             task_parameters.clone(),
                             partition,
diff --git a/crates/deltalake-core/src/operations/write.rs b/crates/deltalake-core/src/operations/write.rs
index 7f61e46c5c..09c1e12cbc 100644
--- a/crates/deltalake-core/src/operations/write.rs
+++ b/crates/deltalake-core/src/operations/write.rs
@@ -625,7 +625,8 @@ fn cast_record_batch_columns(
         .collect::<Result<Vec<_>, _>>()
 }
 
-fn cast_record_batch(
+/// Cast a record batch to a new `target_schema` by casting each column array
+pub fn cast_record_batch(
     batch: &RecordBatch,
     target_schema: ArrowSchemaRef,
     safe: bool,
diff --git a/python/tests/test_optimize.py b/python/tests/test_optimize.py
index 4b746b1434..07de6920d2 100644
--- a/python/tests/test_optimize.py
+++ b/python/tests/test_optimize.py
@@ -7,9 +7,16 @@
 from deltalake import DeltaTable, write_deltalake
 
 
+@pytest.mark.parametrize("engine", ["pyarrow", "rust"])
 @pytest.mark.parametrize("use_relative", [True, False])
+@pytest.mark.parametrize("large_dtypes", [True, False])
 def test_optimize_run_table(
-    tmp_path: pathlib.Path, sample_data: pa.Table, monkeypatch, use_relative: bool
+    tmp_path: pathlib.Path,
+    sample_data: pa.Table,
+    monkeypatch,
+    use_relative: bool,
+    large_dtypes: bool,
+    engine,
 ):
     if use_relative:
         monkeypatch.chdir(tmp_path)  # Make tmp_path the working directory
@@ -18,9 +25,15 @@ def test_optimize_run_table(
     else:
         table_path = str(tmp_path)
 
-    write_deltalake(table_path, sample_data, mode="append")
-    write_deltalake(table_path, sample_data, mode="append")
-    write_deltalake(table_path, sample_data, mode="append")
+    write_deltalake(
+        table_path, sample_data, mode="append", engine=engine, large_dtypes=large_dtypes
+    )
+    write_deltalake(
+        table_path, sample_data, mode="append", engine=engine, large_dtypes=large_dtypes
+    )
+    write_deltalake(
+        table_path, sample_data, mode="append", engine=engine, large_dtypes=large_dtypes
+    )
 
     dt = DeltaTable(table_path)
     old_version = dt.version()
@@ -31,22 +44,31 @@ def test_optimize_run_table(
     assert dt.version() == old_version + 1
 
 
+@pytest.mark.parametrize("engine", ["pyarrow", "rust"])
+@pytest.mark.parametrize("large_dtypes", [True, False])
 def test_z_order_optimize(
-    tmp_path: pathlib.Path,
-    sample_data: pa.Table,
+    tmp_path: pathlib.Path, sample_data: pa.Table, large_dtypes: bool, engine
 ):
-    write_deltalake(tmp_path, sample_data, mode="append")
-    write_deltalake(tmp_path, sample_data, mode="append")
-    write_deltalake(tmp_path, sample_data, mode="append")
+    write_deltalake(
+        tmp_path, sample_data, mode="append", large_dtypes=large_dtypes, engine=engine
+    )
+    write_deltalake(
+        tmp_path, sample_data, mode="append", large_dtypes=large_dtypes, engine=engine
+    )
+    write_deltalake(
+        tmp_path, sample_data, mode="append", large_dtypes=large_dtypes, engine=engine
+    )
 
     dt = DeltaTable(tmp_path)
     old_version = dt.version()
+    total_files = len(dt.file_uris())
 
     dt.optimize.z_order(["date32", "timestamp"])
 
     last_action = dt.history(1)[0]
     assert last_action["operation"] == "OPTIMIZE"
     assert dt.version() == old_version + 1
+    assert len(dt.file_uris()) == 1
 
 
 def test_optimize_min_commit_interval(
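
The core of the Rust change is that OPTIMIZE now casts every batch it reads back to the table's file schema before rewriting it, so files whose Parquet footers carry "large" Arrow types (for example `LargeUtf8` from pyarrow with `large_dtypes=True`) line up with the schema the partition writer expects. The snippet below is a minimal standalone sketch of that casting step only: it uses the arrow-rs crates (`arrow_array`, `arrow_cast`, `arrow_schema`) directly rather than delta-rs internals, and `cast_to_schema` is a hypothetical helper standing in for `cast_record_batch`, not the crate's implementation.

```rust
use std::sync::Arc;

use arrow_array::{ArrayRef, LargeStringArray, RecordBatch};
use arrow_cast::cast;
use arrow_schema::{ArrowError, DataType, Field, Schema, SchemaRef};

/// Cast every column of `batch` to the data type declared by `target`.
/// Sketch of the idea behind `cast_record_batch`; error handling and
/// nested/struct columns are ignored here.
fn cast_to_schema(batch: &RecordBatch, target: SchemaRef) -> Result<RecordBatch, ArrowError> {
    let columns = batch
        .columns()
        .iter()
        .zip(target.fields())
        .map(|(column, field)| cast(column.as_ref(), field.data_type()))
        .collect::<Result<Vec<ArrayRef>, _>>()?;
    RecordBatch::try_new(target, columns)
}

fn main() -> Result<(), ArrowError> {
    // A batch carrying a LargeUtf8 column, as a large_dtypes writer would produce.
    let source = Arc::new(Schema::new(vec![Field::new(
        "value",
        DataType::LargeUtf8,
        true,
    )]));
    let batch = RecordBatch::try_new(
        source,
        vec![Arc::new(LargeStringArray::from(vec!["a", "b"])) as ArrayRef],
    )?;

    // The file schema used during optimize expects plain Utf8.
    let target = Arc::new(Schema::new(vec![Field::new("value", DataType::Utf8, true)]));
    let casted = cast_to_schema(&batch, target)?;
    assert_eq!(casted.schema().field(0).data_type(), &DataType::Utf8);
    Ok(())
}
```

The new tests exercise both directions of this mismatch by parametrizing `write_deltalake` over `engine` and `large_dtypes`, then checking that compaction and Z-ordering still commit a new table version.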