From 3127c81dd6eb466913fcef296f4a9d714eaf3ceb Mon Sep 17 00:00:00 2001 From: Ion Koutsouris <15728914+ion-elgreco@users.noreply.github.com> Date: Sat, 7 Sep 2024 13:30:46 +0200 Subject: [PATCH] fix: escaped columns in dataskippingstatscolumns --- crates/core/src/writer/stats.rs | 22 ++++++++++++++++++++++ python/tests/test_writer.py | 8 +++----- 2 files changed, 25 insertions(+), 5 deletions(-) diff --git a/crates/core/src/writer/stats.rs b/crates/core/src/writer/stats.rs index 28a089ae1c..e4b93a54f5 100644 --- a/crates/core/src/writer/stats.rs +++ b/crates/core/src/writer/stats.rs @@ -5,6 +5,7 @@ use std::{collections::HashMap, ops::AddAssign}; use delta_kernel::expressions::Scalar; use indexmap::IndexMap; +use itertools::Itertools; use parquet::file::metadata::ParquetMetaData; use parquet::format::FileMetaData; use parquet::schema::types::{ColumnDescriptor, SchemaDescriptor}; @@ -130,8 +131,29 @@ fn stats_from_metadata( let mut min_values: HashMap = HashMap::new(); let mut max_values: HashMap = HashMap::new(); let mut null_count: HashMap = HashMap::new(); + let dialect = sqlparser::dialect::GenericDialect {}; let idx_to_iterate = if let Some(stats_cols) = stats_columns { + let stats_cols = stats_cols + .into_iter() + .map(|v| { + match sqlparser::parser::Parser::new(&dialect) + .try_with_sql(v) + .map_err(|e| DeltaTableError::generic(e.to_string()))? 
+ .parse_multipart_identifier() + { + Ok(parts) => Ok(parts.into_iter().map(|v| v.value).join(".")), + Err(e) => { + return Err(DeltaWriterError::DeltaTable( + DeltaTableError::GenericError { + source: Box::new(e), + }, + )) + } + } + }) + .collect::<Result<Vec<String>, DeltaWriterError>>()?; + schema_descriptor .columns() .iter() diff --git a/python/tests/test_writer.py b/python/tests/test_writer.py index 1534d42789..0186500032 100644 --- a/python/tests/test_writer.py +++ b/python/tests/test_writer.py @@ -1699,8 +1699,7 @@ def _check_stats(dt: DeltaTable): _check_stats(dt) -@pytest.mark.parametrize("engine", ["pyarrow", "rust"]) -def test_write_stats_columns_stats_provided(tmp_path: pathlib.Path, engine): +def test_write_stats_columns_stats_provided(tmp_path: pathlib.Path): def _check_stats(dt: DeltaTable): add_actions_table = dt.get_add_actions(flatten=True) stats = add_actions_table.to_pylist()[0] @@ -1726,15 +1725,14 @@ def _check_stats(dt: DeltaTable): tmp_path, data, mode="append", - engine=engine, - configuration={"delta.dataSkippingStatsColumns": "foo,baz"}, + configuration={"delta.dataSkippingStatsColumns": "foo,`baz`"}, ) dt = DeltaTable(tmp_path) _check_stats(dt) # Check if it properly takes skippingNumIndexCols from the config in the table - write_deltalake(tmp_path, data, mode="overwrite", engine=engine) + write_deltalake(tmp_path, data, mode="overwrite") dt = DeltaTable(tmp_path) assert dt.version() == 1