From b96186fdef1ff410663ec8fce41186c018f8e09a Mon Sep 17 00:00:00 2001 From: Oleks V Date: Wed, 10 Jul 2024 08:09:51 -0700 Subject: [PATCH 01/59] Introduce `resources_err!` error macro (#11374) --- datafusion/common/src/error.rs | 3 +++ datafusion/execution/src/disk_manager.rs | 6 +++--- datafusion/execution/src/memory_pool/pool.rs | 5 +++-- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/datafusion/common/src/error.rs b/datafusion/common/src/error.rs index b1fdb652af481..9be662ca283e6 100644 --- a/datafusion/common/src/error.rs +++ b/datafusion/common/src/error.rs @@ -553,6 +553,9 @@ make_error!(config_err, config_datafusion_err, Configuration); // Exposes a macro to create `DataFusionError::Substrait` with optional backtrace make_error!(substrait_err, substrait_datafusion_err, Substrait); +// Exposes a macro to create `DataFusionError::ResourcesExhausted` with optional backtrace +make_error!(resources_err, resources_datafusion_err, ResourcesExhausted); + // Exposes a macro to create `DataFusionError::SQL` with optional backtrace #[macro_export] macro_rules! sql_datafusion_err { diff --git a/datafusion/execution/src/disk_manager.rs b/datafusion/execution/src/disk_manager.rs index cca25c7c3e885..c98d7e5579f0f 100644 --- a/datafusion/execution/src/disk_manager.rs +++ b/datafusion/execution/src/disk_manager.rs @@ -18,7 +18,7 @@ //! Manages files generated during query execution, files are //! hashed among the directories listed in RuntimeConfig::local_dirs. -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::{resources_datafusion_err, DataFusionError, Result}; use log::debug; use parking_lot::Mutex; use rand::{thread_rng, Rng}; @@ -119,9 +119,9 @@ impl DiskManager { ) -> Result { let mut guard = self.local_dirs.lock(); let local_dirs = guard.as_mut().ok_or_else(|| { - DataFusionError::ResourcesExhausted(format!( + resources_datafusion_err!( "Memory Exhausted while {request_description} (DiskManager is disabled)" - )) + ) })?; // Create a temporary directory if needed diff --git a/datafusion/execution/src/memory_pool/pool.rs b/datafusion/execution/src/memory_pool/pool.rs index 4a491630fe205..fd7724f3076c4 100644 --- a/datafusion/execution/src/memory_pool/pool.rs +++ b/datafusion/execution/src/memory_pool/pool.rs @@ -16,7 +16,7 @@ // under the License. 
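// Editor's note: a minimal, self-contained sketch (not part of the patch) of how
// the two macros generated by the `make_error!` invocation above are used;
// `reserve`/`as_error` and their arguments are hypothetical names for illustration.
use datafusion_common::{resources_datafusion_err, resources_err, DataFusionError, Result};

fn reserve(used: usize, limit: usize) -> Result<()> {
    if used > limit {
        // `resources_err!` expands to Err(DataFusionError::ResourcesExhausted(..))
        return resources_err!("used {used} of {limit} bytes");
    }
    Ok(())
}

fn as_error(used: usize, limit: usize) -> DataFusionError {
    // `resources_datafusion_err!` yields the bare error value, as in the
    // DiskManager change above
    resources_datafusion_err!("used {used} of {limit} bytes")
}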
use crate::memory_pool::{MemoryConsumer, MemoryPool, MemoryReservation}; -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::{resources_datafusion_err, DataFusionError, Result}; use log::debug; use parking_lot::Mutex; use std::sync::atomic::{AtomicUsize, Ordering}; @@ -231,12 +231,13 @@ impl MemoryPool for FairSpillPool { } } +#[inline(always)] fn insufficient_capacity_err( reservation: &MemoryReservation, additional: usize, available: usize, ) -> DataFusionError { - DataFusionError::ResourcesExhausted(format!("Failed to allocate additional {} bytes for {} with {} bytes already allocated - maximum available is {}", additional, reservation.registration.consumer.name, reservation.size, available)) + resources_datafusion_err!("Failed to allocate additional {} bytes for {} with {} bytes already allocated - maximum available is {}", additional, reservation.registration.consumer.name, reservation.size, available) } #[cfg(test)] From 585504a31fd7d9a44c97f3f19af42bace08b8cc3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E6=9E=97=E4=BC=9F?= Date: Thu, 11 Jul 2024 00:58:53 +0800 Subject: [PATCH 02/59] Enable clone_on_ref_ptr clippy lint on common (#11384) --- datafusion/common-runtime/src/lib.rs | 2 + datafusion/common/src/dfschema.rs | 12 +-- datafusion/common/src/hash_utils.rs | 19 ++-- datafusion/common/src/lib.rs | 2 + datafusion/common/src/scalar/mod.rs | 97 ++++++++++--------- .../common/src/scalar/struct_builder.rs | 2 +- datafusion/common/src/utils/mod.rs | 5 +- 7 files changed, 77 insertions(+), 62 deletions(-) diff --git a/datafusion/common-runtime/src/lib.rs b/datafusion/common-runtime/src/lib.rs index e8624163f2240..8145bb110464e 100644 --- a/datafusion/common-runtime/src/lib.rs +++ b/datafusion/common-runtime/src/lib.rs @@ -14,6 +14,8 @@ // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. 
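// Editor's note: an illustrative, standalone sketch (not part of the patch) of
// what the `clone_on_ref_ptr` lint enabled below enforces; both spellings only
// bump the Arc's reference count, the lint just makes the cheap clone explicit.
use std::sync::Arc;

fn share(schema: &Arc<str>) -> Arc<str> {
    // let copy = schema.clone();   // rejected once clippy::clone_on_ref_ptr is denied
    let copy = Arc::clone(schema); // allowed: visibly a refcount bump, not a deep copy
    assert!(Arc::ptr_eq(schema, &copy));
    copy
}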
+// Make cheap clones clear: https://github.com/apache/datafusion/issues/11143 +#![deny(clippy::clone_on_ref_ptr)] pub mod common; diff --git a/datafusion/common/src/dfschema.rs b/datafusion/common/src/dfschema.rs index 3c2cc89fc0142..7598cbc4d86a0 100644 --- a/datafusion/common/src/dfschema.rs +++ b/datafusion/common/src/dfschema.rs @@ -211,7 +211,7 @@ impl DFSchema { schema: &SchemaRef, ) -> Result { let dfschema = Self { - inner: schema.clone(), + inner: Arc::clone(schema), field_qualifiers: qualifiers, functional_dependencies: FunctionalDependencies::empty(), }; @@ -311,7 +311,7 @@ impl DFSchema { }; if !duplicated_field { // self.inner.fields.push(field.clone()); - schema_builder.push(field.clone()); + schema_builder.push(Arc::clone(field)); qualifiers.push(qualifier.cloned()); } } @@ -1276,7 +1276,7 @@ mod tests { let arrow_schema_ref = Arc::new(arrow_schema.clone()); let df_schema = DFSchema { - inner: arrow_schema_ref.clone(), + inner: Arc::clone(&arrow_schema_ref), field_qualifiers: vec![None; arrow_schema_ref.fields.len()], functional_dependencies: FunctionalDependencies::empty(), }; @@ -1284,7 +1284,7 @@ mod tests { { let arrow_schema = arrow_schema.clone(); - let arrow_schema_ref = arrow_schema_ref.clone(); + let arrow_schema_ref = Arc::clone(&arrow_schema_ref); assert_eq!(df_schema, arrow_schema.to_dfschema().unwrap()); assert_eq!(df_schema, arrow_schema_ref.to_dfschema().unwrap()); @@ -1292,7 +1292,7 @@ mod tests { { let arrow_schema = arrow_schema.clone(); - let arrow_schema_ref = arrow_schema_ref.clone(); + let arrow_schema_ref = Arc::clone(&arrow_schema_ref); assert_eq!(df_schema_ref, arrow_schema.to_dfschema_ref().unwrap()); assert_eq!(df_schema_ref, arrow_schema_ref.to_dfschema_ref().unwrap()); @@ -1322,7 +1322,7 @@ mod tests { let schema = Arc::new(Schema::new(vec![a_field, b_field])); let df_schema = DFSchema { - inner: schema.clone(), + inner: Arc::clone(&schema), field_qualifiers: vec![None; schema.fields.len()], functional_dependencies: FunctionalDependencies::empty(), }; diff --git a/datafusion/common/src/hash_utils.rs b/datafusion/common/src/hash_utils.rs index c972536c4d23e..c8adae34f6455 100644 --- a/datafusion/common/src/hash_utils.rs +++ b/datafusion/common/src/hash_utils.rs @@ -244,7 +244,7 @@ fn hash_list_array( where OffsetSize: OffsetSizeTrait, { - let values = array.values().clone(); + let values = Arc::clone(array.values()); let offsets = array.value_offsets(); let nulls = array.nulls(); let mut values_hashes = vec![0u64; values.len()]; @@ -274,7 +274,7 @@ fn hash_fixed_list_array( random_state: &RandomState, hashes_buffer: &mut [u64], ) -> Result<()> { - let values = array.values().clone(); + let values = Arc::clone(array.values()); let value_len = array.value_length(); let offset_size = value_len as usize / array.len(); let nulls = array.nulls(); @@ -622,19 +622,19 @@ mod tests { vec![ ( Arc::new(Field::new("bool", DataType::Boolean, false)), - boolarr.clone() as ArrayRef, + Arc::clone(&boolarr) as ArrayRef, ), ( Arc::new(Field::new("i32", DataType::Int32, false)), - i32arr.clone() as ArrayRef, + Arc::clone(&i32arr) as ArrayRef, ), ( Arc::new(Field::new("i32", DataType::Int32, false)), - i32arr.clone() as ArrayRef, + Arc::clone(&i32arr) as ArrayRef, ), ( Arc::new(Field::new("bool", DataType::Boolean, false)), - boolarr.clone() as ArrayRef, + Arc::clone(&boolarr) as ArrayRef, ), ], Buffer::from(&[0b001011]), @@ -710,7 +710,12 @@ mod tests { let random_state = RandomState::with_seeds(0, 0, 0, 0); let mut one_col_hashes = vec![0; strings1.len()]; - 
create_hashes(&[dict_array.clone()], &random_state, &mut one_col_hashes).unwrap(); + create_hashes( + &[Arc::clone(&dict_array) as ArrayRef], + &random_state, + &mut one_col_hashes, + ) + .unwrap(); let mut two_col_hashes = vec![0; strings1.len()]; create_hashes( diff --git a/datafusion/common/src/lib.rs b/datafusion/common/src/lib.rs index c275152642f0e..8cd64e7d16a26 100644 --- a/datafusion/common/src/lib.rs +++ b/datafusion/common/src/lib.rs @@ -14,6 +14,8 @@ // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. +// Make cheap clones clear: https://github.com/apache/datafusion/issues/11143 +#![deny(clippy::clone_on_ref_ptr)] mod column; mod dfschema; diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs index 26e03a3b9893e..c8f21788cbbdf 100644 --- a/datafusion/common/src/scalar/mod.rs +++ b/datafusion/common/src/scalar/mod.rs @@ -1758,8 +1758,11 @@ impl ScalarValue { if let Some(DataType::FixedSizeList(f, l)) = first_non_null_data_type { for array in arrays.iter_mut() { if array.is_null(0) { - *array = - Arc::new(FixedSizeListArray::new_null(f.clone(), l, 1)); + *array = Arc::new(FixedSizeListArray::new_null( + Arc::clone(&f), + l, + 1, + )); } } } @@ -3298,16 +3301,16 @@ impl TryFrom<&DataType> for ScalarValue { ), // `ScalaValue::List` contains single element `ListArray`. DataType::List(field_ref) => ScalarValue::List(Arc::new( - GenericListArray::new_null(field_ref.clone(), 1), + GenericListArray::new_null(Arc::clone(field_ref), 1), )), // `ScalarValue::LargeList` contains single element `LargeListArray`. DataType::LargeList(field_ref) => ScalarValue::LargeList(Arc::new( - GenericListArray::new_null(field_ref.clone(), 1), + GenericListArray::new_null(Arc::clone(field_ref), 1), )), // `ScalaValue::FixedSizeList` contains single element `FixedSizeList`. 
DataType::FixedSizeList(field_ref, fixed_length) => { ScalarValue::FixedSizeList(Arc::new(FixedSizeListArray::new_null( - field_ref.clone(), + Arc::clone(field_ref), *fixed_length, 1, ))) @@ -3746,11 +3749,11 @@ mod tests { let expected = StructArray::from(vec![ ( Arc::new(Field::new("b", DataType::Boolean, false)), - boolean.clone() as ArrayRef, + Arc::clone(&boolean) as ArrayRef, ), ( Arc::new(Field::new("c", DataType::Int32, false)), - int.clone() as ArrayRef, + Arc::clone(&int) as ArrayRef, ), ]); @@ -3792,11 +3795,11 @@ mod tests { let struct_array = StructArray::from(vec![ ( Arc::new(Field::new("b", DataType::Boolean, false)), - boolean.clone() as ArrayRef, + Arc::clone(&boolean) as ArrayRef, ), ( Arc::new(Field::new("c", DataType::Int32, false)), - int.clone() as ArrayRef, + Arc::clone(&int) as ArrayRef, ), ]); let sv = ScalarValue::Struct(Arc::new(struct_array)); @@ -3810,11 +3813,11 @@ mod tests { let struct_array = StructArray::from(vec![ ( Arc::new(Field::new("b", DataType::Boolean, false)), - boolean.clone() as ArrayRef, + Arc::clone(&boolean) as ArrayRef, ), ( Arc::new(Field::new("c", DataType::Int32, false)), - int.clone() as ArrayRef, + Arc::clone(&int) as ArrayRef, ), ]); @@ -3846,7 +3849,7 @@ mod tests { fn test_to_array_of_size_for_fsl() { let values = Int32Array::from_iter([Some(1), None, Some(2)]); let field = Arc::new(Field::new("item", DataType::Int32, true)); - let arr = FixedSizeListArray::new(field.clone(), 3, Arc::new(values), None); + let arr = FixedSizeListArray::new(Arc::clone(&field), 3, Arc::new(values), None); let sv = ScalarValue::FixedSizeList(Arc::new(arr)); let actual_arr = sv .to_array_of_size(2) @@ -3932,13 +3935,13 @@ mod tests { fn test_iter_to_array_fixed_size_list() { let field = Arc::new(Field::new("item", DataType::Int32, true)); let f1 = Arc::new(FixedSizeListArray::new( - field.clone(), + Arc::clone(&field), 3, Arc::new(Int32Array::from(vec![1, 2, 3])), None, )); let f2 = Arc::new(FixedSizeListArray::new( - field.clone(), + Arc::clone(&field), 3, Arc::new(Int32Array::from(vec![4, 5, 6])), None, @@ -3946,7 +3949,7 @@ mod tests { let f_nulls = Arc::new(FixedSizeListArray::new_null(field, 1, 1)); let scalars = vec![ - ScalarValue::FixedSizeList(f_nulls.clone()), + ScalarValue::FixedSizeList(Arc::clone(&f_nulls)), ScalarValue::FixedSizeList(f1), ScalarValue::FixedSizeList(f2), ScalarValue::FixedSizeList(f_nulls), @@ -4780,7 +4783,7 @@ mod tests { let inner_field = Arc::new(Field::new("item", DataType::Int32, true)); // Test for List - let data_type = &DataType::List(inner_field.clone()); + let data_type = &DataType::List(Arc::clone(&inner_field)); let scalar: ScalarValue = data_type.try_into().unwrap(); let expected = ScalarValue::List( new_null_array(data_type, 1) @@ -4792,7 +4795,7 @@ mod tests { assert!(expected.is_null()); // Test for LargeList - let data_type = &DataType::LargeList(inner_field.clone()); + let data_type = &DataType::LargeList(Arc::clone(&inner_field)); let scalar: ScalarValue = data_type.try_into().unwrap(); let expected = ScalarValue::LargeList( new_null_array(data_type, 1) @@ -4804,7 +4807,7 @@ mod tests { assert!(expected.is_null()); // Test for FixedSizeList(5) - let data_type = &DataType::FixedSizeList(inner_field.clone(), 5); + let data_type = &DataType::FixedSizeList(Arc::clone(&inner_field), 5); let scalar: ScalarValue = data_type.try_into().unwrap(); let expected = ScalarValue::FixedSizeList( new_null_array(data_type, 1) @@ -5212,35 +5215,35 @@ mod tests { let field_f = Arc::new(Field::new("f", DataType::Int64, 
false)); let field_d = Arc::new(Field::new( "D", - DataType::Struct(vec![field_e.clone(), field_f.clone()].into()), + DataType::Struct(vec![Arc::clone(&field_e), Arc::clone(&field_f)].into()), false, )); let struct_array = StructArray::from(vec![ ( - field_e.clone(), + Arc::clone(&field_e), Arc::new(Int16Array::from(vec![2])) as ArrayRef, ), ( - field_f.clone(), + Arc::clone(&field_f), Arc::new(Int64Array::from(vec![3])) as ArrayRef, ), ]); let struct_array = StructArray::from(vec![ ( - field_a.clone(), + Arc::clone(&field_a), Arc::new(Int32Array::from(vec![23])) as ArrayRef, ), ( - field_b.clone(), + Arc::clone(&field_b), Arc::new(BooleanArray::from(vec![false])) as ArrayRef, ), ( - field_c.clone(), + Arc::clone(&field_c), Arc::new(StringArray::from(vec!["Hello"])) as ArrayRef, ), - (field_d.clone(), Arc::new(struct_array) as ArrayRef), + (Arc::clone(&field_d), Arc::new(struct_array) as ArrayRef), ]); let scalar = ScalarValue::Struct(Arc::new(struct_array)); @@ -5250,26 +5253,26 @@ mod tests { let expected = Arc::new(StructArray::from(vec![ ( - field_a.clone(), + Arc::clone(&field_a), Arc::new(Int32Array::from(vec![23, 23])) as ArrayRef, ), ( - field_b.clone(), + Arc::clone(&field_b), Arc::new(BooleanArray::from(vec![false, false])) as ArrayRef, ), ( - field_c.clone(), + Arc::clone(&field_c), Arc::new(StringArray::from(vec!["Hello", "Hello"])) as ArrayRef, ), ( - field_d.clone(), + Arc::clone(&field_d), Arc::new(StructArray::from(vec![ ( - field_e.clone(), + Arc::clone(&field_e), Arc::new(Int16Array::from(vec![2, 2])) as ArrayRef, ), ( - field_f.clone(), + Arc::clone(&field_f), Arc::new(Int64Array::from(vec![3, 3])) as ArrayRef, ), ])) as ArrayRef, @@ -5348,26 +5351,26 @@ mod tests { let expected = Arc::new(StructArray::from(vec![ ( - field_a.clone(), + Arc::clone(&field_a), Arc::new(Int32Array::from(vec![23, 7, -1000])) as ArrayRef, ), ( - field_b.clone(), + Arc::clone(&field_b), Arc::new(BooleanArray::from(vec![false, true, true])) as ArrayRef, ), ( - field_c.clone(), + Arc::clone(&field_c), Arc::new(StringArray::from(vec!["Hello", "World", "!!!!!"])) as ArrayRef, ), ( - field_d.clone(), + Arc::clone(&field_d), Arc::new(StructArray::from(vec![ ( - field_e.clone(), + Arc::clone(&field_e), Arc::new(Int16Array::from(vec![2, 4, 6])) as ArrayRef, ), ( - field_f.clone(), + Arc::clone(&field_f), Arc::new(Int64Array::from(vec![3, 5, 7])) as ArrayRef, ), ])) as ArrayRef, @@ -5431,11 +5434,11 @@ mod tests { let array = as_struct_array(&array).unwrap(); let expected = StructArray::from(vec![ ( - field_a.clone(), + Arc::clone(&field_a), Arc::new(StringArray::from(vec!["First", "Second", "Third"])) as ArrayRef, ), ( - field_primitive_list.clone(), + Arc::clone(&field_primitive_list), Arc::new(ListArray::from_iter_primitive::(vec![ Some(vec![Some(1), Some(2), Some(3)]), Some(vec![Some(4), Some(5)]), @@ -6195,18 +6198,18 @@ mod tests { let struct_value = vec![ ( - fields[0].clone(), + Arc::clone(&fields[0]), Arc::new(UInt64Array::from(vec![Some(1)])) as ArrayRef, ), ( - fields[1].clone(), + Arc::clone(&fields[1]), Arc::new(StructArray::from(vec![ ( - fields_b[0].clone(), + Arc::clone(&fields_b[0]), Arc::new(UInt64Array::from(vec![Some(2)])) as ArrayRef, ), ( - fields_b[1].clone(), + Arc::clone(&fields_b[1]), Arc::new(UInt64Array::from(vec![Some(3)])) as ArrayRef, ), ])) as ArrayRef, @@ -6215,19 +6218,19 @@ mod tests { let struct_value_with_nulls = vec![ ( - fields[0].clone(), + Arc::clone(&fields[0]), Arc::new(UInt64Array::from(vec![Some(1)])) as ArrayRef, ), ( - fields[1].clone(), + 
Arc::clone(&fields[1]), Arc::new(StructArray::from(( vec![ ( - fields_b[0].clone(), + Arc::clone(&fields_b[0]), Arc::new(UInt64Array::from(vec![Some(2)])) as ArrayRef, ), ( - fields_b[1].clone(), + Arc::clone(&fields_b[1]), Arc::new(UInt64Array::from(vec![Some(3)])) as ArrayRef, ), ], diff --git a/datafusion/common/src/scalar/struct_builder.rs b/datafusion/common/src/scalar/struct_builder.rs index b1a34e4a61d01..4a6a8f0289a7d 100644 --- a/datafusion/common/src/scalar/struct_builder.rs +++ b/datafusion/common/src/scalar/struct_builder.rs @@ -144,7 +144,7 @@ impl IntoFieldRef for FieldRef { impl IntoFieldRef for &FieldRef { fn into_field_ref(self) -> FieldRef { - self.clone() + Arc::clone(self) } } diff --git a/datafusion/common/src/utils/mod.rs b/datafusion/common/src/utils/mod.rs index dd7b80333cf81..8264b48725929 100644 --- a/datafusion/common/src/utils/mod.rs +++ b/datafusion/common/src/utils/mod.rs @@ -245,7 +245,10 @@ pub fn evaluate_partition_ranges( end: num_rows, }] } else { - let cols: Vec<_> = partition_columns.iter().map(|x| x.values.clone()).collect(); + let cols: Vec<_> = partition_columns + .iter() + .map(|x| Arc::clone(&x.values)) + .collect(); partition(&cols)?.ranges() }) } From 6038f4cfac536dbb54ea2761828f7344a23b94f0 Mon Sep 17 00:00:00 2001 From: wiedld Date: Wed, 10 Jul 2024 11:21:01 -0700 Subject: [PATCH 03/59] Track parquet writer encoding memory usage on MemoryPool (#11345) * feat(11344): track memory used for non-parallel writes * feat(11344): track memory usage during parallel writes * test(11344): create bounded stream for testing * test(11344): test ParquetSink memory reservation * feat(11344): track bytes in file writer * refactor(11344): tweak the ordering to add col bytes to rg_reservation, before selecting shrinking for data bytes flushed * refactor: move each col_reservation and rg_reservation to match the parallelized call stack for col vs rg * test(11344): add memory_limit enforcement test for parquet sink * chore: cleanup to remove unnecessary reservation management steps * fix: fix CI test failure due to file extension rename --- .../src/datasource/file_format/parquet.rs | 165 ++++++++++++++++-- datafusion/core/src/test_util/mod.rs | 36 ++++ datafusion/core/tests/memory_limit/mod.rs | 25 +++ 3 files changed, 216 insertions(+), 10 deletions(-) diff --git a/datafusion/core/src/datasource/file_format/parquet.rs b/datafusion/core/src/datasource/file_format/parquet.rs index 27d783cd89b5f..694c949285374 100644 --- a/datafusion/core/src/datasource/file_format/parquet.rs +++ b/datafusion/core/src/datasource/file_format/parquet.rs @@ -48,6 +48,7 @@ use datafusion_common::{ DEFAULT_PARQUET_EXTENSION, }; use datafusion_common_runtime::SpawnedTask; +use datafusion_execution::memory_pool::{MemoryConsumer, MemoryPool, MemoryReservation}; use datafusion_execution::TaskContext; use datafusion_physical_expr::expressions::{MaxAccumulator, MinAccumulator}; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortRequirement}; @@ -749,9 +750,13 @@ impl DataSink for ParquetSink { parquet_props.writer_options().clone(), ) .await?; + let mut reservation = + MemoryConsumer::new(format!("ParquetSink[{}]", path)) + .register(context.memory_pool()); file_write_tasks.spawn(async move { while let Some(batch) = rx.recv().await { writer.write(&batch).await?; + reservation.try_resize(writer.memory_size())?; } let file_metadata = writer .close() @@ -771,6 +776,7 @@ impl DataSink for ParquetSink { let schema = self.get_writer_schema(); let props = parquet_props.clone(); let 
parallel_options_clone = parallel_options.clone(); + let pool = Arc::clone(context.memory_pool()); file_write_tasks.spawn(async move { let file_metadata = output_single_parquet_file_parallelized( writer, @@ -778,6 +784,7 @@ impl DataSink for ParquetSink { schema, props.writer_options(), parallel_options_clone, + pool, ) .await?; Ok((path, file_metadata)) @@ -818,14 +825,16 @@ impl DataSink for ParquetSink { async fn column_serializer_task( mut rx: Receiver, mut writer: ArrowColumnWriter, -) -> Result { + mut reservation: MemoryReservation, +) -> Result<(ArrowColumnWriter, MemoryReservation)> { while let Some(col) = rx.recv().await { writer.write(&col)?; + reservation.try_resize(writer.memory_size())?; } - Ok(writer) + Ok((writer, reservation)) } -type ColumnWriterTask = SpawnedTask>; +type ColumnWriterTask = SpawnedTask>; type ColSender = Sender; /// Spawns a parallel serialization task for each column @@ -835,6 +844,7 @@ fn spawn_column_parallel_row_group_writer( schema: Arc, parquet_props: Arc, max_buffer_size: usize, + pool: &Arc, ) -> Result<(Vec, Vec)> { let schema_desc = arrow_to_parquet_schema(&schema)?; let col_writers = get_column_writers(&schema_desc, &parquet_props, &schema)?; @@ -848,7 +858,13 @@ fn spawn_column_parallel_row_group_writer( mpsc::channel::(max_buffer_size); col_array_channels.push(send_array); - let task = SpawnedTask::spawn(column_serializer_task(recieve_array, writer)); + let reservation = + MemoryConsumer::new("ParquetSink(ArrowColumnWriter)").register(pool); + let task = SpawnedTask::spawn(column_serializer_task( + recieve_array, + writer, + reservation, + )); col_writer_tasks.push(task); } @@ -864,7 +880,7 @@ struct ParallelParquetWriterOptions { /// This is the return type of calling [ArrowColumnWriter].close() on each column /// i.e. the Vec of encoded columns which can be appended to a row group -type RBStreamSerializeResult = Result<(Vec, usize)>; +type RBStreamSerializeResult = Result<(Vec, MemoryReservation, usize)>; /// Sends the ArrowArrays in passed [RecordBatch] through the channels to their respective /// parallel column serializers. 
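// Editor's note: a standalone sketch of the reservation protocol this patch
// threads through the writers — register a consumer on the pool, resize as
// buffered bytes grow, free on close. Assumes a GreedyMemoryPool and toy sizes;
// the consumer name is illustrative only.
use std::sync::Arc;
use datafusion_execution::memory_pool::{GreedyMemoryPool, MemoryConsumer, MemoryPool};

fn main() -> datafusion_common::Result<()> {
    let pool: Arc<dyn MemoryPool> = Arc::new(GreedyMemoryPool::new(1024));
    let mut reservation = MemoryConsumer::new("ExampleWriter").register(&pool);

    reservation.try_resize(512)?; // writer now buffers 512 encoded bytes
    assert_eq!(pool.reserved(), 512);

    reservation.try_resize(0)?; // buffer flushed to the object store
    reservation.free(); // writer closed
    assert_eq!(pool.reserved(), 0);
    Ok(())
}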
@@ -895,16 +911,22 @@ async fn send_arrays_to_col_writers( fn spawn_rg_join_and_finalize_task( column_writer_tasks: Vec, rg_rows: usize, + pool: &Arc, ) -> SpawnedTask { + let mut rg_reservation = + MemoryConsumer::new("ParquetSink(SerializedRowGroupWriter)").register(pool); + SpawnedTask::spawn(async move { let num_cols = column_writer_tasks.len(); let mut finalized_rg = Vec::with_capacity(num_cols); for task in column_writer_tasks.into_iter() { - let writer = task.join_unwind().await?; + let (writer, _col_reservation) = task.join_unwind().await?; + let encoded_size = writer.get_estimated_total_bytes(); + rg_reservation.grow(encoded_size); finalized_rg.push(writer.close()?); } - Ok((finalized_rg, rg_rows)) + Ok((finalized_rg, rg_reservation, rg_rows)) }) } @@ -922,6 +944,7 @@ fn spawn_parquet_parallel_serialization_task( schema: Arc, writer_props: Arc, parallel_options: ParallelParquetWriterOptions, + pool: Arc, ) -> SpawnedTask> { SpawnedTask::spawn(async move { let max_buffer_rb = parallel_options.max_buffered_record_batches_per_stream; @@ -931,6 +954,7 @@ fn spawn_parquet_parallel_serialization_task( schema.clone(), writer_props.clone(), max_buffer_rb, + &pool, )?; let mut current_rg_rows = 0; @@ -957,6 +981,7 @@ fn spawn_parquet_parallel_serialization_task( let finalize_rg_task = spawn_rg_join_and_finalize_task( column_writer_handles, max_row_group_rows, + &pool, ); serialize_tx.send(finalize_rg_task).await.map_err(|_| { @@ -973,6 +998,7 @@ fn spawn_parquet_parallel_serialization_task( schema.clone(), writer_props.clone(), max_buffer_rb, + &pool, )?; } } @@ -981,8 +1007,11 @@ fn spawn_parquet_parallel_serialization_task( drop(col_array_channels); // Handle leftover rows as final rowgroup, which may be smaller than max_row_group_rows if current_rg_rows > 0 { - let finalize_rg_task = - spawn_rg_join_and_finalize_task(column_writer_handles, current_rg_rows); + let finalize_rg_task = spawn_rg_join_and_finalize_task( + column_writer_handles, + current_rg_rows, + &pool, + ); serialize_tx.send(finalize_rg_task).await.map_err(|_| { DataFusionError::Internal( @@ -1002,9 +1031,13 @@ async fn concatenate_parallel_row_groups( schema: Arc, writer_props: Arc, mut object_store_writer: Box, + pool: Arc, ) -> Result { let merged_buff = SharedBuffer::new(INITIAL_BUFFER_BYTES); + let mut file_reservation = + MemoryConsumer::new("ParquetSink(SerializedFileWriter)").register(&pool); + let schema_desc = arrow_to_parquet_schema(schema.as_ref())?; let mut parquet_writer = SerializedFileWriter::new( merged_buff.clone(), @@ -1015,15 +1048,20 @@ async fn concatenate_parallel_row_groups( while let Some(task) = serialize_rx.recv().await { let result = task.join_unwind().await; let mut rg_out = parquet_writer.next_row_group()?; - let (serialized_columns, _cnt) = result?; + let (serialized_columns, mut rg_reservation, _cnt) = result?; for chunk in serialized_columns { chunk.append_to_row_group(&mut rg_out)?; + rg_reservation.free(); + let mut buff_to_flush = merged_buff.buffer.try_lock().unwrap(); + file_reservation.try_resize(buff_to_flush.len())?; + if buff_to_flush.len() > BUFFER_FLUSH_BYTES { object_store_writer .write_all(buff_to_flush.as_slice()) .await?; buff_to_flush.clear(); + file_reservation.try_resize(buff_to_flush.len())?; // will set to zero } } rg_out.close()?; @@ -1034,6 +1072,7 @@ async fn concatenate_parallel_row_groups( object_store_writer.write_all(final_buff.as_slice()).await?; object_store_writer.shutdown().await?; + file_reservation.free(); Ok(file_metadata) } @@ -1048,6 +1087,7 @@ async fn 
output_single_parquet_file_parallelized( output_schema: Arc, parquet_props: &WriterProperties, parallel_options: ParallelParquetWriterOptions, + pool: Arc, ) -> Result { let max_rowgroups = parallel_options.max_parallel_row_groups; // Buffer size of this channel limits maximum number of RowGroups being worked on in parallel @@ -1061,12 +1101,14 @@ async fn output_single_parquet_file_parallelized( output_schema.clone(), arc_props.clone(), parallel_options, + Arc::clone(&pool), ); let file_metadata = concatenate_parallel_row_groups( serialize_rx, output_schema.clone(), arc_props.clone(), object_store_writer, + pool, ) .await?; @@ -1158,8 +1200,10 @@ mod tests { use super::super::test_util::scan_format; use crate::datasource::listing::{ListingTableUrl, PartitionedFile}; use crate::physical_plan::collect; + use crate::test_util::bounded_stream; use std::fmt::{Display, Formatter}; use std::sync::atomic::{AtomicUsize, Ordering}; + use std::time::Duration; use super::*; @@ -2177,4 +2221,105 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn parquet_sink_write_memory_reservation() -> Result<()> { + async fn test_memory_reservation(global: ParquetOptions) -> Result<()> { + let field_a = Field::new("a", DataType::Utf8, false); + let field_b = Field::new("b", DataType::Utf8, false); + let schema = Arc::new(Schema::new(vec![field_a, field_b])); + let object_store_url = ObjectStoreUrl::local_filesystem(); + + let file_sink_config = FileSinkConfig { + object_store_url: object_store_url.clone(), + file_groups: vec![PartitionedFile::new("/tmp".to_string(), 1)], + table_paths: vec![ListingTableUrl::parse("file:///")?], + output_schema: schema.clone(), + table_partition_cols: vec![], + overwrite: true, + keep_partition_by_columns: false, + }; + let parquet_sink = Arc::new(ParquetSink::new( + file_sink_config, + TableParquetOptions { + key_value_metadata: std::collections::HashMap::from([ + ("my-data".to_string(), Some("stuff".to_string())), + ("my-data-bool-key".to_string(), None), + ]), + global, + ..Default::default() + }, + )); + + // create data + let col_a: ArrayRef = Arc::new(StringArray::from(vec!["foo", "bar"])); + let col_b: ArrayRef = Arc::new(StringArray::from(vec!["baz", "baz"])); + let batch = + RecordBatch::try_from_iter(vec![("a", col_a), ("b", col_b)]).unwrap(); + + // create task context + let task_context = build_ctx(object_store_url.as_ref()); + assert_eq!( + task_context.memory_pool().reserved(), + 0, + "no bytes are reserved yet" + ); + + let mut write_task = parquet_sink.write_all( + Box::pin(RecordBatchStreamAdapter::new( + schema, + bounded_stream(batch, 1000), + )), + &task_context, + ); + + // incrementally poll and check for memory reservation + let mut reserved_bytes = 0; + while futures::poll!(&mut write_task).is_pending() { + reserved_bytes += task_context.memory_pool().reserved(); + tokio::time::sleep(Duration::from_micros(1)).await; + } + assert!( + reserved_bytes > 0, + "should have bytes reserved during write" + ); + assert_eq!( + task_context.memory_pool().reserved(), + 0, + "no leaking byte reservation" + ); + + Ok(()) + } + + let write_opts = ParquetOptions { + allow_single_file_parallelism: false, + ..Default::default() + }; + test_memory_reservation(write_opts) + .await + .expect("should track for non-parallel writes"); + + let row_parallel_write_opts = ParquetOptions { + allow_single_file_parallelism: true, + maximum_parallel_row_group_writers: 10, + maximum_buffered_record_batches_per_stream: 1, + ..Default::default() + }; + 
test_memory_reservation(row_parallel_write_opts) + .await + .expect("should track for row-parallel writes"); + + let col_parallel_write_opts = ParquetOptions { + allow_single_file_parallelism: true, + maximum_parallel_row_group_writers: 1, + maximum_buffered_record_batches_per_stream: 2, + ..Default::default() + }; + test_memory_reservation(col_parallel_write_opts) + .await + .expect("should track for column-parallel writes"); + + Ok(()) + } } diff --git a/datafusion/core/src/test_util/mod.rs b/datafusion/core/src/test_util/mod.rs index 059fa8fc6da77..ba0509f3f51ac 100644 --- a/datafusion/core/src/test_util/mod.rs +++ b/datafusion/core/src/test_util/mod.rs @@ -366,3 +366,39 @@ pub fn register_unbounded_file_with_ordering( ctx.register_table(table_name, Arc::new(StreamTable::new(Arc::new(config))))?; Ok(()) } + +struct BoundedStream { + limit: usize, + count: usize, + batch: RecordBatch, +} + +impl Stream for BoundedStream { + type Item = Result; + + fn poll_next( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + if self.count >= self.limit { + return Poll::Ready(None); + } + self.count += 1; + Poll::Ready(Some(Ok(self.batch.clone()))) + } +} + +impl RecordBatchStream for BoundedStream { + fn schema(&self) -> SchemaRef { + self.batch.schema() + } +} + +/// Creates an bounded stream for testing purposes. +pub fn bounded_stream(batch: RecordBatch, limit: usize) -> SendableRecordBatchStream { + Box::pin(BoundedStream { + count: 0, + limit, + batch, + }) +} diff --git a/datafusion/core/tests/memory_limit/mod.rs b/datafusion/core/tests/memory_limit/mod.rs index f61ee5d9ab984..f7402357d1c76 100644 --- a/datafusion/core/tests/memory_limit/mod.rs +++ b/datafusion/core/tests/memory_limit/mod.rs @@ -31,6 +31,7 @@ use datafusion_physical_expr::{LexOrdering, PhysicalSortExpr}; use futures::StreamExt; use std::any::Any; use std::sync::{Arc, OnceLock}; +use tokio::fs::File; use datafusion::datasource::streaming::StreamingTable; use datafusion::datasource::{MemTable, TableProvider}; @@ -323,6 +324,30 @@ async fn oom_recursive_cte() { .await } +#[tokio::test] +async fn oom_parquet_sink() { + let dir = tempfile::tempdir().unwrap(); + let path = dir.into_path().join("test.parquet"); + let _ = File::create(path.clone()).await.unwrap(); + + TestCase::new() + .with_query(format!( + " + COPY (select * from t) + TO '{}' + STORED AS PARQUET OPTIONS (compression 'uncompressed'); + ", + path.to_string_lossy() + )) + .with_expected_errors(vec![ + // TODO: update error handling in ParquetSink + "Unable to send array to writer!", + ]) + .with_memory_limit(200_000) + .run() + .await +} + /// Run the query with the specified memory limit, /// and verifies the expected errors are returned #[derive(Clone, Debug)] From 32cb3c5a54bd0297d473792c8a3b0e7fd51c2e3b Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 10 Jul 2024 14:21:44 -0400 Subject: [PATCH 04/59] Minor: remove clones and unnecessary Arcs in `from_substrait_rex` (#11337) --- .../substrait/src/logical_plan/consumer.rs | 146 +++++++----------- 1 file changed, 59 insertions(+), 87 deletions(-) diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index 89a6dde51e42c..a4f7242024754 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -411,11 +411,11 @@ pub async fn from_substrait_rel( from_substrait_rex(ctx, e, input.clone().schema(), extensions) .await?; // if the expression is WindowFunction, wrap in a Window 
relation - if let Expr::WindowFunction(_) = x.as_ref() { + if let Expr::WindowFunction(_) = &x { // Adding the same expression here and in the project below // works because the project's builder uses columnize_expr(..) // to transform it into a column reference - input = input.window(vec![x.as_ref().clone()])? + input = input.window(vec![x.clone()])? } // Ensure the expression has a unique display name, so that project's // validate_unique_names doesn't fail @@ -426,12 +426,12 @@ pub async fn from_substrait_rel( new_name = format!("{}__temp__{}", name, i); i += 1; } - names.insert(new_name.clone()); if new_name != name { - exprs.push(x.as_ref().clone().alias(new_name.clone())); + exprs.push(x.alias(new_name.clone())); } else { - exprs.push(x.as_ref().clone()); + exprs.push(x); } + names.insert(new_name); } input.project(exprs)?.build() } else { @@ -447,7 +447,7 @@ pub async fn from_substrait_rel( let expr = from_substrait_rex(ctx, condition, input.schema(), extensions) .await?; - input.filter(expr.as_ref().clone())?.build() + input.filter(expr)?.build() } else { not_impl_err!("Filter without an condition is not valid") } @@ -499,7 +499,7 @@ pub async fn from_substrait_rel( let x = from_substrait_rex(ctx, e, input.schema(), extensions) .await?; - group_expr.push(x.as_ref().clone()); + group_expr.push(x); } } _ => { @@ -514,7 +514,7 @@ pub async fn from_substrait_rel( extensions, ) .await?; - grouping_set.push(x.as_ref().clone()); + grouping_set.push(x); } grouping_sets.push(grouping_set); } @@ -532,9 +532,7 @@ pub async fn from_substrait_rel( let filter = match &m.filter { Some(fil) => Some(Box::new( from_substrait_rex(ctx, fil, input.schema(), extensions) - .await? - .as_ref() - .clone(), + .await?, )), None => None, }; @@ -931,7 +929,7 @@ pub async fn from_substrait_sorts( }; let (asc, nulls_first) = asc_nullfirst.unwrap(); sorts.push(Expr::Sort(Sort { - expr: Box::new(expr.as_ref().clone()), + expr: Box::new(expr), asc, nulls_first, })); @@ -949,7 +947,7 @@ pub async fn from_substrait_rex_vec( let mut expressions: Vec = vec![]; for expr in exprs { let expression = from_substrait_rex(ctx, expr, input_schema, extensions).await?; - expressions.push(expression.as_ref().clone()); + expressions.push(expression); } Ok(expressions) } @@ -969,7 +967,7 @@ pub async fn from_substrait_func_args( } _ => not_impl_err!("Function argument non-Value type not supported"), }; - args.push(arg_expr?.as_ref().clone()); + args.push(arg_expr?); } Ok(args) } @@ -1028,17 +1026,15 @@ pub async fn from_substrait_rex( e: &Expression, input_schema: &DFSchema, extensions: &HashMap, -) -> Result> { +) -> Result { match &e.rex_type { Some(RexType::SingularOrList(s)) => { let substrait_expr = s.value.as_ref().unwrap(); let substrait_list = s.options.as_ref(); - Ok(Arc::new(Expr::InList(InList { + Ok(Expr::InList(InList { expr: Box::new( from_substrait_rex(ctx, substrait_expr, input_schema, extensions) - .await? - .as_ref() - .clone(), + .await?, ), list: from_substrait_rex_vec( ctx, @@ -1048,11 +1044,11 @@ pub async fn from_substrait_rex( ) .await?, negated: false, - }))) + })) + } + Some(RexType::Selection(field_ref)) => { + Ok(from_substrait_field_reference(field_ref, input_schema)?) 
} - Some(RexType::Selection(field_ref)) => Ok(Arc::new( - from_substrait_field_reference(field_ref, input_schema)?, - )), Some(RexType::IfThen(if_then)) => { // Parse `ifs` // If the first element does not have a `then` part, then we can assume it's a base expression @@ -1069,9 +1065,7 @@ pub async fn from_substrait_rex( input_schema, extensions, ) - .await? - .as_ref() - .clone(), + .await?, )); continue; } @@ -1084,9 +1078,7 @@ pub async fn from_substrait_rex( input_schema, extensions, ) - .await? - .as_ref() - .clone(), + .await?, ), Box::new( from_substrait_rex( @@ -1095,27 +1087,22 @@ pub async fn from_substrait_rex( input_schema, extensions, ) - .await? - .as_ref() - .clone(), + .await?, ), )); } // Parse `else` let else_expr = match &if_then.r#else { Some(e) => Some(Box::new( - from_substrait_rex(ctx, e, input_schema, extensions) - .await? - .as_ref() - .clone(), + from_substrait_rex(ctx, e, input_schema, extensions).await?, )), None => None, }; - Ok(Arc::new(Expr::Case(Case { + Ok(Expr::Case(Case { expr, when_then_expr, else_expr, - }))) + })) } Some(RexType::ScalarFunction(f)) => { let Some(fn_name) = extensions.get(&f.function_reference) else { @@ -1133,8 +1120,9 @@ pub async fn from_substrait_rex( // try to first match the requested function into registered udfs, then built-in ops // and finally built-in expressions if let Some(func) = ctx.state().scalar_functions().get(fn_name) { - Ok(Arc::new(Expr::ScalarFunction( - expr::ScalarFunction::new_udf(func.to_owned(), args), + Ok(Expr::ScalarFunction(expr::ScalarFunction::new_udf( + func.to_owned(), + args, ))) } else if let Some(op) = name_to_op(fn_name) { if f.arguments.len() < 2 { @@ -1147,17 +1135,14 @@ pub async fn from_substrait_rex( // In those cases we iterate through all the arguments, applying the binary expression against them all let combined_expr = args .into_iter() - .fold(None, |combined_expr: Option>, arg: Expr| { + .fold(None, |combined_expr: Option, arg: Expr| { Some(match combined_expr { - Some(expr) => Arc::new(Expr::BinaryExpr(BinaryExpr { - left: Box::new( - Arc::try_unwrap(expr) - .unwrap_or_else(|arc: Arc| (*arc).clone()), - ), // Avoid cloning if possible + Some(expr) => Expr::BinaryExpr(BinaryExpr { + left: Box::new(expr), op, right: Box::new(arg), - })), - None => Arc::new(arg), + }), + None => arg, }) }) .unwrap(); @@ -1171,10 +1156,10 @@ pub async fn from_substrait_rex( } Some(RexType::Literal(lit)) => { let scalar_value = from_substrait_literal_without_names(lit)?; - Ok(Arc::new(Expr::Literal(scalar_value))) + Ok(Expr::Literal(scalar_value)) } Some(RexType::Cast(cast)) => match cast.as_ref().r#type.as_ref() { - Some(output_type) => Ok(Arc::new(Expr::Cast(Cast::new( + Some(output_type) => Ok(Expr::Cast(Cast::new( Box::new( from_substrait_rex( ctx, @@ -1182,12 +1167,10 @@ pub async fn from_substrait_rex( input_schema, extensions, ) - .await? 
- .as_ref() - .clone(), + .await?, ), from_substrait_type_without_names(output_type)?, - )))), + ))), None => substrait_err!("Cast expression without output type is not allowed"), }, Some(RexType::WindowFunction(window)) => { @@ -1232,7 +1215,7 @@ pub async fn from_substrait_rex( } } }; - Ok(Arc::new(Expr::WindowFunction(expr::WindowFunction { + Ok(Expr::WindowFunction(expr::WindowFunction { fun, args: from_substrait_func_args( ctx, @@ -1255,7 +1238,7 @@ pub async fn from_substrait_rex( from_substrait_bound(&window.upper_bound, false)?, ), null_treatment: None, - }))) + })) } Some(RexType::Subquery(subquery)) => match &subquery.as_ref().subquery_type { Some(subquery_type) => match subquery_type { @@ -1270,7 +1253,7 @@ pub async fn from_substrait_rex( from_substrait_rel(ctx, haystack_expr, extensions) .await?; let outer_refs = haystack_expr.all_out_ref_exprs(); - Ok(Arc::new(Expr::InSubquery(InSubquery { + Ok(Expr::InSubquery(InSubquery { expr: Box::new( from_substrait_rex( ctx, @@ -1278,16 +1261,14 @@ pub async fn from_substrait_rex( input_schema, extensions, ) - .await? - .as_ref() - .clone(), + .await?, ), subquery: Subquery { subquery: Arc::new(haystack_expr), outer_ref_columns: outer_refs, }, negated: false, - }))) + })) } else { substrait_err!("InPredicate Subquery type must have a Haystack expression") } @@ -1301,10 +1282,10 @@ pub async fn from_substrait_rex( ) .await?; let outer_ref_columns = plan.all_out_ref_exprs(); - Ok(Arc::new(Expr::ScalarSubquery(Subquery { + Ok(Expr::ScalarSubquery(Subquery { subquery: Arc::new(plan), outer_ref_columns, - }))) + })) } SubqueryType::SetPredicate(predicate) => { match predicate.predicate_op() { @@ -1318,13 +1299,13 @@ pub async fn from_substrait_rex( ) .await?; let outer_ref_columns = plan.all_out_ref_exprs(); - Ok(Arc::new(Expr::Exists(Exists::new( + Ok(Expr::Exists(Exists::new( Subquery { subquery: Arc::new(plan), outer_ref_columns, }, false, - )))) + ))) } other_type => substrait_err!( "unimplemented type {:?} for set predicate", @@ -1337,7 +1318,7 @@ pub async fn from_substrait_rex( } }, None => { - substrait_err!("Subquery experssion without SubqueryType is not allowed") + substrait_err!("Subquery expression without SubqueryType is not allowed") } }, _ => not_impl_err!("unsupported rex_type"), @@ -2001,7 +1982,7 @@ impl BuiltinExprBuilder { f: &ScalarFunction, input_schema: &DFSchema, extensions: &HashMap, - ) -> Result> { + ) -> Result { match self.expr_name.as_str() { "like" => { Self::build_like_expr(ctx, false, f, input_schema, extensions).await @@ -2026,17 +2007,15 @@ impl BuiltinExprBuilder { f: &ScalarFunction, input_schema: &DFSchema, extensions: &HashMap, - ) -> Result> { + ) -> Result { if f.arguments.len() != 1 { return substrait_err!("Expect one argument for {fn_name} expr"); } let Some(ArgType::Value(expr_substrait)) = &f.arguments[0].arg_type else { return substrait_err!("Invalid arguments type for {fn_name} expr"); }; - let arg = from_substrait_rex(ctx, expr_substrait, input_schema, extensions) - .await? 
- .as_ref() - .clone(); + let arg = + from_substrait_rex(ctx, expr_substrait, input_schema, extensions).await?; let arg = Box::new(arg); let expr = match fn_name { @@ -2053,7 +2032,7 @@ impl BuiltinExprBuilder { _ => return not_impl_err!("Unsupported builtin expression: {}", fn_name), }; - Ok(Arc::new(expr)) + Ok(expr) } async fn build_like_expr( @@ -2062,7 +2041,7 @@ impl BuiltinExprBuilder { f: &ScalarFunction, input_schema: &DFSchema, extensions: &HashMap, - ) -> Result> { + ) -> Result { let fn_name = if case_insensitive { "ILIKE" } else { "LIKE" }; if f.arguments.len() != 2 && f.arguments.len() != 3 { return substrait_err!("Expect two or three arguments for `{fn_name}` expr"); @@ -2071,18 +2050,13 @@ impl BuiltinExprBuilder { let Some(ArgType::Value(expr_substrait)) = &f.arguments[0].arg_type else { return substrait_err!("Invalid arguments type for `{fn_name}` expr"); }; - let expr = from_substrait_rex(ctx, expr_substrait, input_schema, extensions) - .await? - .as_ref() - .clone(); + let expr = + from_substrait_rex(ctx, expr_substrait, input_schema, extensions).await?; let Some(ArgType::Value(pattern_substrait)) = &f.arguments[1].arg_type else { return substrait_err!("Invalid arguments type for `{fn_name}` expr"); }; let pattern = - from_substrait_rex(ctx, pattern_substrait, input_schema, extensions) - .await? - .as_ref() - .clone(); + from_substrait_rex(ctx, pattern_substrait, input_schema, extensions).await?; // Default case: escape character is Literal(Utf8(None)) let escape_char = if f.arguments.len() == 3 { @@ -2093,9 +2067,7 @@ impl BuiltinExprBuilder { let escape_char_expr = from_substrait_rex(ctx, escape_char_substrait, input_schema, extensions) - .await? - .as_ref() - .clone(); + .await?; match escape_char_expr { Expr::Literal(ScalarValue::Utf8(escape_char_string)) => { @@ -2112,12 +2084,12 @@ impl BuiltinExprBuilder { None }; - Ok(Arc::new(Expr::Like(Like { + Ok(Expr::Like(Like { negated: false, expr: Box::new(expr), pattern: Box::new(pattern), escape_char, case_insensitive, - }))) + })) } } From cc7484e0b73fe0b36e5f76741399c95e5e7ff1c7 Mon Sep 17 00:00:00 2001 From: June <61218022+itsjunetime@users.noreply.github.com> Date: Wed, 10 Jul 2024 13:11:48 -0600 Subject: [PATCH 05/59] Minor: Change no-statement error message to be clearer (#11394) * Change no-statement error message to be clearer and add tests for said change * Run fmt to pass CI --- .../core/src/execution/session_state.rs | 2 +- datafusion/core/tests/sql/sql_api.rs | 34 +++++++++++++++++++ 2 files changed, 35 insertions(+), 1 deletion(-) diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index c123ebb22ecb2..60745076c2427 100644 --- a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -555,7 +555,7 @@ impl SessionState { } let statement = statements.pop_front().ok_or_else(|| { DataFusionError::NotImplemented( - "The context requires a statement!".to_string(), + "No SQL statements were provided in the query string".to_string(), ) })?; Ok(statement) diff --git a/datafusion/core/tests/sql/sql_api.rs b/datafusion/core/tests/sql/sql_api.rs index 4a6424fc24b62..e7c40d2c8aa88 100644 --- a/datafusion/core/tests/sql/sql_api.rs +++ b/datafusion/core/tests/sql/sql_api.rs @@ -113,6 +113,40 @@ async fn unsupported_statement_returns_error() { ctx.sql_with_options(sql, options).await.unwrap(); } +#[tokio::test] +async fn empty_statement_returns_error() { + let ctx = SessionContext::new(); + ctx.sql("CREATE TABLE test (x 
int)").await.unwrap(); + + let state = ctx.state(); + + // Give it an empty string which contains no statements + let plan_res = state.create_logical_plan("").await; + assert_eq!( + plan_res.unwrap_err().strip_backtrace(), + "This feature is not implemented: No SQL statements were provided in the query string" + ); +} + +#[tokio::test] +async fn multiple_statements_returns_error() { + let ctx = SessionContext::new(); + ctx.sql("CREATE TABLE test (x int)").await.unwrap(); + + let state = ctx.state(); + + // Give it a string that contains multiple statements + let plan_res = state + .create_logical_plan( + "INSERT INTO test (x) VALUES (1); INSERT INTO test (x) VALUES (2)", + ) + .await; + assert_eq!( + plan_res.unwrap_err().strip_backtrace(), + "This feature is not implemented: The context currently only supports a single SQL statement" + ); +} + #[tokio::test] async fn ddl_can_not_be_planned_by_session_state() { let ctx = SessionContext::new(); From d3f63728d222cc5cf30cf03a12ec9a0b41399b18 Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Thu, 11 Jul 2024 07:32:03 +0800 Subject: [PATCH 06/59] Change `array_agg` to return `null` on no input rather than empty list (#11299) * change array agg semantic for empty result Signed-off-by: jayzhan211 * return null Signed-off-by: jayzhan211 * fix test Signed-off-by: jayzhan211 * fix order sensitive Signed-off-by: jayzhan211 * fix test Signed-off-by: jayzhan211 * add more test Signed-off-by: jayzhan211 * fix null Signed-off-by: jayzhan211 * fix multi-phase case Signed-off-by: jayzhan211 * add comment Signed-off-by: jayzhan211 * cleanup Signed-off-by: jayzhan211 * fix clone Signed-off-by: jayzhan211 --------- Signed-off-by: jayzhan211 --- datafusion/common/src/scalar/mod.rs | 10 ++ datafusion/core/tests/dataframe/mod.rs | 2 +- datafusion/core/tests/sql/aggregates.rs | 2 +- datafusion/expr/src/aggregate_function.rs | 2 +- .../physical-expr/src/aggregate/array_agg.rs | 17 +- .../src/aggregate/array_agg_distinct.rs | 11 +- .../src/aggregate/array_agg_ordered.rs | 12 +- .../physical-expr/src/aggregate/build_in.rs | 4 +- .../sqllogictest/test_files/aggregate.slt | 155 +++++++++++++----- 9 files changed, 161 insertions(+), 54 deletions(-) diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs index c8f21788cbbdf..6c03e8698e80b 100644 --- a/datafusion/common/src/scalar/mod.rs +++ b/datafusion/common/src/scalar/mod.rs @@ -1984,6 +1984,16 @@ impl ScalarValue { Self::new_list(values, data_type, true) } + /// Create ListArray with Null with specific data type + /// + /// - new_null_list(i32, nullable, 1): `ListArray[NULL]` + pub fn new_null_list(data_type: DataType, nullable: bool, null_len: usize) -> Self { + let data_type = DataType::List(Field::new_list_field(data_type, nullable).into()); + Self::List(Arc::new(ListArray::from(ArrayData::new_null( + &data_type, null_len, + )))) + } + /// Converts `IntoIterator` where each element has type corresponding to /// `data_type`, to a [`ListArray`]. 
/// diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index 2d1904d9e1667..f1d57c44293be 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -1388,7 +1388,7 @@ async fn unnest_with_redundant_columns() -> Result<()> { let expected = vec![ "Projection: shapes.shape_id [shape_id:UInt32]", " Unnest: lists[shape_id2] structs[] [shape_id:UInt32, shape_id2:UInt32;N]", - " Aggregate: groupBy=[[shapes.shape_id]], aggr=[[ARRAY_AGG(shapes.shape_id) AS shape_id2]] [shape_id:UInt32, shape_id2:List(Field { name: \"item\", data_type: UInt32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} })]", + " Aggregate: groupBy=[[shapes.shape_id]], aggr=[[ARRAY_AGG(shapes.shape_id) AS shape_id2]] [shape_id:UInt32, shape_id2:List(Field { name: \"item\", data_type: UInt32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} });N]", " TableScan: shapes projection=[shape_id] [shape_id:UInt32]", ]; diff --git a/datafusion/core/tests/sql/aggregates.rs b/datafusion/core/tests/sql/aggregates.rs index e503b74992c3f..86032dc9bc963 100644 --- a/datafusion/core/tests/sql/aggregates.rs +++ b/datafusion/core/tests/sql/aggregates.rs @@ -37,7 +37,7 @@ async fn csv_query_array_agg_distinct() -> Result<()> { Schema::new(vec![Field::new_list( "ARRAY_AGG(DISTINCT aggregate_test_100.c2)", Field::new("item", DataType::UInt32, false), - false + true ),]) ); diff --git a/datafusion/expr/src/aggregate_function.rs b/datafusion/expr/src/aggregate_function.rs index 23e98714dfa4c..3cae78eaed9b6 100644 --- a/datafusion/expr/src/aggregate_function.rs +++ b/datafusion/expr/src/aggregate_function.rs @@ -118,7 +118,7 @@ impl AggregateFunction { pub fn nullable(&self) -> Result { match self { AggregateFunction::Max | AggregateFunction::Min => Ok(true), - AggregateFunction::ArrayAgg => Ok(false), + AggregateFunction::ArrayAgg => Ok(true), } } } diff --git a/datafusion/physical-expr/src/aggregate/array_agg.rs b/datafusion/physical-expr/src/aggregate/array_agg.rs index 634a0a0179037..38a9738029335 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg.rs +++ b/datafusion/physical-expr/src/aggregate/array_agg.rs @@ -71,7 +71,7 @@ impl AggregateExpr for ArrayAgg { &self.name, // This should be the same as return type of AggregateFunction::ArrayAgg Field::new("item", self.input_data_type.clone(), self.nullable), - false, + true, )) } @@ -86,7 +86,7 @@ impl AggregateExpr for ArrayAgg { Ok(vec![Field::new_list( format_state_name(&self.name, "array_agg"), Field::new("item", self.input_data_type.clone(), self.nullable), - false, + true, )]) } @@ -137,8 +137,11 @@ impl Accumulator for ArrayAggAccumulator { return Ok(()); } assert!(values.len() == 1, "array_agg can only take 1 param!"); + let val = Arc::clone(&values[0]); - self.values.push(val); + if val.len() > 0 { + self.values.push(val); + } Ok(()) } @@ -162,13 +165,15 @@ impl Accumulator for ArrayAggAccumulator { fn evaluate(&mut self) -> Result { // Transform Vec to ListArr - let element_arrays: Vec<&dyn Array> = self.values.iter().map(|a| a.as_ref()).collect(); if element_arrays.is_empty() { - let arr = ScalarValue::new_list(&[], &self.datatype, self.nullable); - return Ok(ScalarValue::List(arr)); + return Ok(ScalarValue::new_null_list( + self.datatype.clone(), + self.nullable, + 1, + )); } let concated_array = arrow::compute::concat(&element_arrays)?; diff --git a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs 
b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs index a59d85e84a203..368d11d7421ab 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs +++ b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs @@ -75,7 +75,7 @@ impl AggregateExpr for DistinctArrayAgg { &self.name, // This should be the same as return type of AggregateFunction::ArrayAgg Field::new("item", self.input_data_type.clone(), self.nullable), - false, + true, )) } @@ -90,7 +90,7 @@ impl AggregateExpr for DistinctArrayAgg { Ok(vec![Field::new_list( format_state_name(&self.name, "distinct_array_agg"), Field::new("item", self.input_data_type.clone(), self.nullable), - false, + true, )]) } @@ -165,6 +165,13 @@ impl Accumulator for DistinctArrayAggAccumulator { fn evaluate(&mut self) -> Result { let values: Vec = self.values.iter().cloned().collect(); + if values.is_empty() { + return Ok(ScalarValue::new_null_list( + self.datatype.clone(), + self.nullable, + 1, + )); + } let arr = ScalarValue::new_list(&values, &self.datatype, self.nullable); Ok(ScalarValue::List(arr)) } diff --git a/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs b/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs index a64d97637c3bf..d44811192f667 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs +++ b/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs @@ -92,7 +92,7 @@ impl AggregateExpr for OrderSensitiveArrayAgg { &self.name, // This should be the same as return type of AggregateFunction::ArrayAgg Field::new("item", self.input_data_type.clone(), self.nullable), - false, + true, )) } @@ -111,7 +111,7 @@ impl AggregateExpr for OrderSensitiveArrayAgg { let mut fields = vec![Field::new_list( format_state_name(&self.name, "array_agg"), Field::new("item", self.input_data_type.clone(), self.nullable), - false, // This should be the same as field() + true, // This should be the same as field() )]; let orderings = ordering_fields(&self.ordering_req, &self.order_by_data_types); fields.push(Field::new_list( @@ -309,6 +309,14 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator { } fn evaluate(&mut self) -> Result { + if self.values.is_empty() { + return Ok(ScalarValue::new_null_list( + self.datatypes[0].clone(), + self.nullable, + 1, + )); + } + let values = self.values.clone(); let array = if self.reverse { ScalarValue::new_list_from_iter( diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs index d4cd3d51d1744..68c9b4859f1f8 100644 --- a/datafusion/physical-expr/src/aggregate/build_in.rs +++ b/datafusion/physical-expr/src/aggregate/build_in.rs @@ -147,7 +147,7 @@ mod tests { Field::new_list( "c1", Field::new("item", data_type.clone(), true), - false, + true, ), result_agg_phy_exprs.field().unwrap() ); @@ -167,7 +167,7 @@ mod tests { Field::new_list( "c1", Field::new("item", data_type.clone(), true), - false, + true, ), result_agg_phy_exprs.field().unwrap() ); diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index e891093c81560..7dd1ea82b3275 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -1694,7 +1694,7 @@ SELECT array_agg(c13) FROM (SELECT * FROM aggregate_test_100 ORDER BY c13 LIMIT query ? SELECT array_agg(c13) FROM (SELECT * FROM aggregate_test_100 LIMIT 0) test ---- -[] +NULL # csv_query_array_agg_one query ? 
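# Editor's note (not part of the diff): patch 06 makes ARRAY_AGG return NULL
# rather than an empty list when its input has no rows, matching DuckDB; the
# rewritten expectations in the hunks below, together with the `new_null_list`
# helper added to scalar/mod.rs above, encode that contract.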
@@ -1753,31 +1753,12 @@ NULL 4 29 1.260869565217 123 -117 23 NULL 5 -194 -13.857142857143 118 -101 14 NULL NULL 781 7.81 125 -117 100 -# TODO: array_agg_distinct output is non-deterministic -- rewrite with array_sort(list_sort) -# unnest is also not available, so manually unnesting via CROSS JOIN -# additional count(1) forces array_agg_distinct instead of array_agg over aggregated by c2 data -# +# select with count to forces array_agg_distinct function, since single distinct expression is converted to group by by optimizer # csv_query_array_agg_distinct -query III -WITH indices AS ( - SELECT 1 AS idx UNION ALL - SELECT 2 AS idx UNION ALL - SELECT 3 AS idx UNION ALL - SELECT 4 AS idx UNION ALL - SELECT 5 AS idx -) -SELECT data.arr[indices.idx] as element, array_length(data.arr) as array_len, dummy -FROM ( - SELECT array_agg(distinct c2) as arr, count(1) as dummy FROM aggregate_test_100 -) data - CROSS JOIN indices -ORDER BY 1 ----- -1 5 100 -2 5 100 -3 5 100 -4 5 100 -5 5 100 +query ?I +SELECT array_sort(array_agg(distinct c2)), count(1) FROM aggregate_test_100 +---- +[1, 2, 3, 4, 5] 100 # aggregate_time_min_and_max query TT @@ -2732,6 +2713,16 @@ SELECT COUNT(DISTINCT c1) FROM test # TODO: aggregate_with_alias +# test_approx_percentile_cont_decimal_support +query TI +SELECT c1, approx_percentile_cont(c2, cast(0.85 as decimal(10,2))) apc FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 +---- +a 4 +b 5 +c 4 +d 4 +e 4 + # array_agg_zero query ? SELECT ARRAY_AGG([]) @@ -2744,28 +2735,114 @@ SELECT ARRAY_AGG([1]) ---- [[1]] -# test_approx_percentile_cont_decimal_support -query TI -SELECT c1, approx_percentile_cont(c2, cast(0.85 as decimal(10,2))) apc FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 +# test array_agg with no row qualified +statement ok +create table t(a int, b float, c bigint) as values (1, 1.2, 2); + +# returns NULL, follows DuckDB's behaviour +query ? +select array_agg(a) from t where a > 2; ---- -a 4 -b 5 -c 4 -d 4 -e 4 +NULL +query ? +select array_agg(b) from t where b > 3.1; +---- +NULL -# array_agg_zero query ? -SELECT ARRAY_AGG([]); +select array_agg(c) from t where c > 3; ---- -[[]] +NULL -# array_agg_one +query ?I +select array_agg(c), count(1) from t where c > 3; +---- +NULL 0 + +# returns 0 rows if group by is applied, follows DuckDB's behaviour query ? -SELECT ARRAY_AGG([1]); +select array_agg(a) from t where a > 3 group by a; ---- -[[1]] + +query ?I +select array_agg(a), count(1) from t where a > 3 group by a; +---- + +# returns NULL, follows DuckDB's behaviour +query ? +select array_agg(distinct a) from t where a > 3; +---- +NULL + +query ?I +select array_agg(distinct a), count(1) from t where a > 3; +---- +NULL 0 + +# returns 0 rows if group by is applied, follows DuckDB's behaviour +query ? +select array_agg(distinct a) from t where a > 3 group by a; +---- + +query ?I +select array_agg(distinct a), count(1) from t where a > 3 group by a; +---- + +# test order sensitive array agg +query ? +select array_agg(a order by a) from t where a > 3; +---- +NULL + +query ? +select array_agg(a order by a) from t where a > 3 group by a; +---- + +query ?I +select array_agg(a order by a), count(1) from t where a > 3 group by a; +---- + +statement ok +drop table t; + +# test with no values +statement ok +create table t(a int, b float, c bigint); + +query ? +select array_agg(a) from t; +---- +NULL + +query ? +select array_agg(b) from t; +---- +NULL + +query ? 
+select array_agg(c) from t;
+----
+NULL
+
+query ?I
+select array_agg(distinct a), count(1) from t;
+----
+NULL 0
+
+query ?I
+select array_agg(distinct b), count(1) from t;
+----
+NULL 0
+
+query ?I
+select array_agg(distinct c), count(1) from t;
+----
+NULL 0
+
+statement ok
+drop table t;
+
 # array_agg_i32
 statement ok

From 7a23ea9bce32dc8ae195caa8ca052673031c06c9 Mon Sep 17 00:00:00 2001
From: Jonah Gao
Date: Thu, 11 Jul 2024 09:38:15 +0800
Subject: [PATCH 07/59] Minor: return "not supported" for `COUNT DISTINCT` with multiple arguments (#11391)

* Minor: return "not supported" for COUNT DISTINCT with multiple arguments

* update condition
---
 datafusion/functions-aggregate/src/count.rs      | 6 +++++-
 datafusion/sqllogictest/test_files/aggregate.slt | 4 ++++
 2 files changed, 9 insertions(+), 1 deletion(-)

diff --git a/datafusion/functions-aggregate/src/count.rs b/datafusion/functions-aggregate/src/count.rs
index bd0155df0271b..0a667d35dce5e 100644
--- a/datafusion/functions-aggregate/src/count.rs
+++ b/datafusion/functions-aggregate/src/count.rs
@@ -37,7 +37,7 @@ use arrow::{
     buffer::BooleanBuffer,
 };
 use datafusion_common::{
-    downcast_value, internal_err, DataFusionError, Result, ScalarValue,
+    downcast_value, internal_err, not_impl_err, DataFusionError, Result, ScalarValue,
 };
 use datafusion_expr::function::StateFieldsArgs;
 use datafusion_expr::{
@@ -138,6 +138,10 @@ impl AggregateUDFImpl for Count {
             return Ok(Box::new(CountAccumulator::new()));
         }
 
+        if acc_args.input_exprs.len() > 1 {
+            return not_impl_err!("COUNT DISTINCT with multiple arguments");
+        }
+
         let data_type = acc_args.input_type;
         Ok(match data_type {
             // try and use a specialized accumulator if possible, otherwise fall back to generic accumulator
diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt
index 7dd1ea82b3275..6fafc0a74110c 100644
--- a/datafusion/sqllogictest/test_files/aggregate.slt
+++ b/datafusion/sqllogictest/test_files/aggregate.slt
@@ -2019,6 +2019,10 @@ SELECT count(c1, c2) FROM test
 ----
 3
 
+# count(distinct) with multiple arguments
+query error DataFusion error: This feature is not implemented: COUNT DISTINCT with multiple arguments
+SELECT count(distinct c1, c2) FROM test
+
 # count_null
 query III
 SELECT count(null), count(null, null), count(distinct null) FROM test

From 2413155a3ed808285e31421a8b6aac23b8abdb91 Mon Sep 17 00:00:00 2001
From: Andy Grove
Date: Thu, 11 Jul 2024 08:56:47 -0600
Subject: [PATCH 08/59] feat: Add `fail_on_overflow` option to `BinaryExpr` (#11400)

* update tests

* update tests

* add rustdoc

* update PartialEq impl

* fix

* address feedback about improving api
---
 datafusion/core/src/physical_planner.rs       |   4 +-
 .../physical-expr/src/expressions/binary.rs   | 126 +++++++++++++++++-
 2 files changed, 121 insertions(+), 9 deletions(-)

diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs
index 6aad4d5755320..d2bc334ec3248 100644
--- a/datafusion/core/src/physical_planner.rs
+++ b/datafusion/core/src/physical_planner.rs
@@ -2312,7 +2312,7 @@ mod tests {
         // verify that the plan correctly casts u8 to i64
         // the cast from u8 to i64 for literal will be simplified, and get lit(int64(5))
         // the cast here is implicit so has CastOptions with safe=true
-        let expected = "BinaryExpr { left: Column { name: \"c7\", index: 2 }, op: Lt, right: Literal { value: Int64(5) } }";
+        let expected = "BinaryExpr { left: Column { name: \"c7\", index: 2 }, op: Lt, right: Literal { value: Int64(5) }, fail_on_overflow: 
false }"; assert!(format!("{exec_plan:?}").contains(expected)); Ok(()) } @@ -2551,7 +2551,7 @@ mod tests { let execution_plan = plan(&logical_plan).await?; // verify that the plan correctly adds cast from Int64(1) to Utf8, and the const will be evaluated. - let expected = "expr: [(BinaryExpr { left: BinaryExpr { left: Column { name: \"c1\", index: 0 }, op: Eq, right: Literal { value: Utf8(\"a\") } }, op: Or, right: BinaryExpr { left: Column { name: \"c1\", index: 0 }, op: Eq, right: Literal { value: Utf8(\"1\") } } }"; + let expected = "expr: [(BinaryExpr { left: BinaryExpr { left: Column { name: \"c1\", index: 0 }, op: Eq, right: Literal { value: Utf8(\"a\") }, fail_on_overflow: false }, op: Or, right: BinaryExpr { left: Column { name: \"c1\", index: 0 }, op: Eq, right: Literal { value: Utf8(\"1\") }, fail_on_overflow: false }, fail_on_overflow: false }"; let actual = format!("{execution_plan:?}"); assert!(actual.contains(expected), "{}", actual); diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs index c153ead9639fe..c34dcdfb75988 100644 --- a/datafusion/physical-expr/src/expressions/binary.rs +++ b/datafusion/physical-expr/src/expressions/binary.rs @@ -53,6 +53,8 @@ pub struct BinaryExpr { left: Arc, op: Operator, right: Arc, + /// Specifies whether an error is returned on overflow or not + fail_on_overflow: bool, } impl BinaryExpr { @@ -62,7 +64,22 @@ impl BinaryExpr { op: Operator, right: Arc, ) -> Self { - Self { left, op, right } + Self { + left, + op, + right, + fail_on_overflow: false, + } + } + + /// Create new binary expression with explicit fail_on_overflow value + pub fn with_fail_on_overflow(self, fail_on_overflow: bool) -> Self { + Self { + left: self.left, + op: self.op, + right: self.right, + fail_on_overflow, + } } /// Get the left side of the binary expression @@ -273,8 +290,11 @@ impl PhysicalExpr for BinaryExpr { } match self.op { + Operator::Plus if self.fail_on_overflow => return apply(&lhs, &rhs, add), Operator::Plus => return apply(&lhs, &rhs, add_wrapping), + Operator::Minus if self.fail_on_overflow => return apply(&lhs, &rhs, sub), Operator::Minus => return apply(&lhs, &rhs, sub_wrapping), + Operator::Multiply if self.fail_on_overflow => return apply(&lhs, &rhs, mul), Operator::Multiply => return apply(&lhs, &rhs, mul_wrapping), Operator::Divide => return apply(&lhs, &rhs, div), Operator::Modulo => return apply(&lhs, &rhs, rem), @@ -327,11 +347,10 @@ impl PhysicalExpr for BinaryExpr { self: Arc, children: Vec>, ) -> Result> { - Ok(Arc::new(BinaryExpr::new( - Arc::clone(&children[0]), - self.op, - Arc::clone(&children[1]), - ))) + Ok(Arc::new( + BinaryExpr::new(Arc::clone(&children[0]), self.op, Arc::clone(&children[1])) + .with_fail_on_overflow(self.fail_on_overflow), + )) } fn evaluate_bounds(&self, children: &[&Interval]) -> Result { @@ -496,7 +515,12 @@ impl PartialEq for BinaryExpr { fn eq(&self, other: &dyn Any) -> bool { down_cast_any_ref(other) .downcast_ref::() - .map(|x| self.left.eq(&x.left) && self.op == x.op && self.right.eq(&x.right)) + .map(|x| { + self.left.eq(&x.left) + && self.op == x.op + && self.right.eq(&x.right) + && self.fail_on_overflow.eq(&x.fail_on_overflow) + }) .unwrap_or(false) } } @@ -661,6 +685,7 @@ mod tests { use datafusion_common::plan_datafusion_err; use datafusion_expr::type_coercion::binary::get_input_types; + use datafusion_physical_expr_common::expressions::column::Column; /// Performs a binary operation, applying any type coercion necessary fn binary_op( @@ 
-4008,4 +4033,91 @@ mod tests { .unwrap(); assert_eq!(&casted, &dictionary); } + + #[test] + fn test_add_with_overflow() -> Result<()> { + // create test data + let l = Arc::new(Int32Array::from(vec![1, i32::MAX])); + let r = Arc::new(Int32Array::from(vec![2, 1])); + let schema = Arc::new(Schema::new(vec![ + Field::new("l", DataType::Int32, false), + Field::new("r", DataType::Int32, false), + ])); + let batch = RecordBatch::try_new(schema, vec![l, r])?; + + // create expression + let expr = BinaryExpr::new( + Arc::new(Column::new("l", 0)), + Operator::Plus, + Arc::new(Column::new("r", 1)), + ) + .with_fail_on_overflow(true); + + // evaluate expression + let result = expr.evaluate(&batch); + assert!(result + .err() + .unwrap() + .to_string() + .contains("Overflow happened on: 2147483647 + 1")); + Ok(()) + } + + #[test] + fn test_subtract_with_overflow() -> Result<()> { + // create test data + let l = Arc::new(Int32Array::from(vec![1, i32::MIN])); + let r = Arc::new(Int32Array::from(vec![2, 1])); + let schema = Arc::new(Schema::new(vec![ + Field::new("l", DataType::Int32, false), + Field::new("r", DataType::Int32, false), + ])); + let batch = RecordBatch::try_new(schema, vec![l, r])?; + + // create expression + let expr = BinaryExpr::new( + Arc::new(Column::new("l", 0)), + Operator::Minus, + Arc::new(Column::new("r", 1)), + ) + .with_fail_on_overflow(true); + + // evaluate expression + let result = expr.evaluate(&batch); + assert!(result + .err() + .unwrap() + .to_string() + .contains("Overflow happened on: -2147483648 - 1")); + Ok(()) + } + + #[test] + fn test_mul_with_overflow() -> Result<()> { + // create test data + let l = Arc::new(Int32Array::from(vec![1, i32::MAX])); + let r = Arc::new(Int32Array::from(vec![2, 2])); + let schema = Arc::new(Schema::new(vec![ + Field::new("l", DataType::Int32, false), + Field::new("r", DataType::Int32, false), + ])); + let batch = RecordBatch::try_new(schema, vec![l, r])?; + + // create expression + let expr = BinaryExpr::new( + Arc::new(Column::new("l", 0)), + Operator::Multiply, + Arc::new(Column::new("r", 1)), + ) + .with_fail_on_overflow(true); + + // evaluate expression + let result = expr.evaluate(&batch); + assert!(result + .err() + .unwrap() + .to_string() + .contains("Overflow happened on: 2147483647 * 2")); + Ok(()) + } } From ed65c11065f74d72995619450d5325234aba0b5d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E6=9E=97=E4=BC=9F?= Date: Thu, 11 Jul 2024 22:58:20 +0800 Subject: [PATCH 09/59] Enable clone_on_ref_ptr clippy lint on sql (#11380) --- datafusion/sql/examples/sql.rs | 2 +- datafusion/sql/src/cte.rs | 2 +- datafusion/sql/src/expr/mod.rs | 2 +- datafusion/sql/src/lib.rs | 2 ++ datafusion/sql/src/statement.rs | 4 ++-- datafusion/sql/tests/common/mod.rs | 2 +- 6 files changed, 8 insertions(+), 6 deletions(-) diff --git a/datafusion/sql/examples/sql.rs b/datafusion/sql/examples/sql.rs index aee4cf5a38ed3..1b92a7e116b16 100644 --- a/datafusion/sql/examples/sql.rs +++ b/datafusion/sql/examples/sql.rs @@ -119,7 +119,7 @@ fn create_table_source(fields: Vec) -> Arc { impl ContextProvider for MyContextProvider { fn get_table_source(&self, name: TableReference) -> Result> { match self.tables.get(name.table()) { - Some(table) => Ok(table.clone()), + Some(table) => Ok(Arc::clone(table)), _ => plan_err!("Table not found: {}", name.table()), } } diff --git a/datafusion/sql/src/cte.rs b/datafusion/sql/src/cte.rs index 0035dcda6ed7d..3dfe00e3c5e0b 100644 --- a/datafusion/sql/src/cte.rs +++ b/datafusion/sql/src/cte.rs @@ -144,7 +144,7 @@ impl<'a, 
S: ContextProvider> SqlToRel<'a, S> { // as the input to the recursive term let work_table_plan = LogicalPlanBuilder::scan( cte_name.to_string(), - work_table_source.clone(), + Arc::clone(&work_table_source), None, )? .build()?; diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index 0546a101fcb25..859842e212be7 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -981,7 +981,7 @@ mod tests { impl ContextProvider for TestContextProvider { fn get_table_source(&self, name: TableReference) -> Result> { match self.tables.get(name.table()) { - Some(table) => Ok(table.clone()), + Some(table) => Ok(Arc::clone(table)), _ => plan_err!("Table not found: {}", name.table()), } } diff --git a/datafusion/sql/src/lib.rs b/datafusion/sql/src/lib.rs index 1040cc61c702b..eb5fec7a3c8bb 100644 --- a/datafusion/sql/src/lib.rs +++ b/datafusion/sql/src/lib.rs @@ -14,6 +14,8 @@ // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. +// Make cheap clones clear: https://github.com/apache/datafusion/issues/11143 +#![deny(clippy::clone_on_ref_ptr)] //! This module provides: //! diff --git a/datafusion/sql/src/statement.rs b/datafusion/sql/src/statement.rs index 6cdb2f959cd88..1acfac79acc0b 100644 --- a/datafusion/sql/src/statement.rs +++ b/datafusion/sql/src/statement.rs @@ -870,12 +870,12 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { self.context_provider.get_table_source(table_ref.clone())?; let plan = LogicalPlanBuilder::scan(table_name, table_source, None)?.build()?; - let input_schema = plan.schema().clone(); + let input_schema = Arc::clone(plan.schema()); (plan, input_schema, Some(table_ref)) } CopyToSource::Query(query) => { let plan = self.query_to_plan(query, &mut PlannerContext::new())?; - let input_schema = plan.schema().clone(); + let input_schema = Arc::clone(plan.schema()); (plan, input_schema, None) } }; diff --git a/datafusion/sql/tests/common/mod.rs b/datafusion/sql/tests/common/mod.rs index f5caaefb3ea08..b8d8bd12d28bb 100644 --- a/datafusion/sql/tests/common/mod.rs +++ b/datafusion/sql/tests/common/mod.rs @@ -258,6 +258,6 @@ impl TableSource for EmptyTable { } fn schema(&self) -> SchemaRef { - self.table_schema.clone() + Arc::clone(&self.table_schema) } } From 0b2eb50c0f980562a6c009f541c4dbd5831b5fe1 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Thu, 11 Jul 2024 10:58:53 -0400 Subject: [PATCH 10/59] Move configuration information out of example usage page (#11300) --- datafusion/core/src/lib.rs | 6 + docs/source/index.rst | 8 +- docs/source/library-user-guide/index.md | 21 ++- docs/source/user-guide/crate-configuration.md | 146 ++++++++++++++++++ docs/source/user-guide/example-usage.md | 129 ---------------- 5 files changed, 177 insertions(+), 133 deletions(-) create mode 100644 docs/source/user-guide/crate-configuration.md diff --git a/datafusion/core/src/lib.rs b/datafusion/core/src/lib.rs index f5805bc069825..63dbe824c2314 100644 --- a/datafusion/core/src/lib.rs +++ b/datafusion/core/src/lib.rs @@ -620,6 +620,12 @@ doc_comment::doctest!( user_guide_example_usage ); +#[cfg(doctest)] +doc_comment::doctest!( + "../../../docs/source/user-guide/crate-configuration.md", + user_guide_crate_configuration +); + #[cfg(doctest)] doc_comment::doctest!( "../../../docs/source/user-guide/configs.md", diff --git a/docs/source/index.rst b/docs/source/index.rst index d491df04f7fe7..8fbff208f5617 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ 
-41,13 +41,16 @@ DataFusion offers SQL and Dataframe APIs, excellent CSV, Parquet, JSON, and Avro, extensive customization, and a great community. -To get started with examples, see the `example usage`_ section of the user guide and the `datafusion-examples`_ directory. +To get started, see -See the `developer’s guide`_ for contributing and `communication`_ for getting in touch with us. +* The `example usage`_ section of the user guide and the `datafusion-examples`_ directory. +* The `library user guide`_ for examples of using DataFusion's extension APIs +* The `developer’s guide`_ for contributing and `communication`_ for getting in touch with us. .. _example usage: user-guide/example-usage.html .. _datafusion-examples: https://github.com/apache/datafusion/tree/main/datafusion-examples .. _developer’s guide: contributor-guide/index.html#developer-s-guide +.. _library user guide: library-user-guide/index.html .. _communication: contributor-guide/communication.html .. _toc.asf-links: @@ -80,6 +83,7 @@ See the `developer’s guide`_ for contributing and `communication`_ for getting user-guide/introduction user-guide/example-usage + user-guide/crate-configuration user-guide/cli/index user-guide/dataframe user-guide/expressions diff --git a/docs/source/library-user-guide/index.md b/docs/source/library-user-guide/index.md index 47257e0c926e7..fd126a1120edf 100644 --- a/docs/source/library-user-guide/index.md +++ b/docs/source/library-user-guide/index.md @@ -19,8 +19,25 @@ # Introduction -The library user guide explains how to use the DataFusion library as a dependency in your Rust project. Please check out the user-guide for more details on how to use DataFusion's SQL and DataFrame APIs, or the contributor guide for details on how to contribute to DataFusion. +The library user guide explains how to use the DataFusion library as a +dependency in your Rust project and customize its behavior using its extension APIs. -If you haven't reviewed the [architecture section in the docs][docs], it's a useful place to get the lay of the land before starting down a specific path. +Please check out the [user guide] for getting started using +DataFusion's SQL and DataFrame APIs, or the [contributor guide] +for details on how to contribute to DataFusion. +If you haven't reviewed the [architecture section in the docs][docs], it's a +useful place to get the lay of the land before starting down a specific path. + +DataFusion is designed to be extensible at all points, including + +- [x] User Defined Functions (UDFs) +- [x] User Defined Aggregate Functions (UDAFs) +- [x] User Defined Table Source (`TableProvider`) for tables +- [x] User Defined `Optimizer` passes (plan rewrites) +- [x] User Defined `LogicalPlan` nodes +- [x] User Defined `ExecutionPlan` nodes + +[user guide]: ../user-guide/example-usage.md +[contributor guide]: ../contributor-guide/index.md [docs]: https://docs.rs/datafusion/latest/datafusion/#architecture diff --git a/docs/source/user-guide/crate-configuration.md b/docs/source/user-guide/crate-configuration.md new file mode 100644 index 0000000000000..0587d06a39191 --- /dev/null +++ b/docs/source/user-guide/crate-configuration.md @@ -0,0 +1,146 @@ + + +# Crate Configuration + +This section contains information on how to configure DataFusion in your Rust +project. See the [Configuration Settings] section for a list of options that +control DataFusion's behavior. 
+
+[configuration settings]: configs.md
+
+## Add latest non-published DataFusion dependency
+
+DataFusion changes are published to `crates.io` according to the [release schedule](https://github.com/apache/datafusion/blob/main/dev/release/README.md#release-process)
+
+If you would like to test out DataFusion changes which are merged but not yet
+published, Cargo supports adding a dependency directly from a GitHub branch:
+
+```toml
+datafusion = { git = "https://github.com/apache/datafusion", branch = "main"}
+```
+
+It also works at the package level
+
+```toml
+datafusion-common = { git = "https://github.com/apache/datafusion", branch = "main", package = "datafusion-common"}
+```
+
+And with features
+
+```toml
+datafusion = { git = "https://github.com/apache/datafusion", branch = "main", default-features = false, features = ["unicode_expressions"] }
+```
+
+More on [Cargo dependencies](https://doc.rust-lang.org/cargo/reference/specifying-dependencies.html#specifying-dependencies)
+
+## Optimized Configuration
+
+For an optimized build, several steps are required. First, use the below in your `Cargo.toml`. It is
+worth noting that using the settings in the `[profile.release]` section will significantly increase the build time.
+
+```toml
+[dependencies]
+datafusion = { version = "22.0" }
+tokio = { version = "^1.0", features = ["rt-multi-thread"] }
+snmalloc-rs = "0.3"
+
+[profile.release]
+lto = true
+codegen-units = 1
+```
+
+Then, in `main.rs`, update the memory allocator with the below after your imports:
+
+```rust ,ignore
+use datafusion::prelude::*;
+
+#[global_allocator]
+static ALLOC: snmalloc_rs::SnMalloc = snmalloc_rs::SnMalloc;
+
+#[tokio::main]
+async fn main() -> datafusion::error::Result<()> {
+    Ok(())
+}
+```
+
+Based on the instruction set architecture you are building on, you will want to configure the `target-cpu` as well, ideally
+with `native` or at least `avx2`.
+
+```shell
+RUSTFLAGS='-C target-cpu=native' cargo run --release
+```
+
+## Enable backtraces
+
+By default, DataFusion returns errors as a plain message. There is an option to enable more verbose details about the error,
+like an error backtrace. To enable a backtrace, you need to add the DataFusion `backtrace` feature to your `Cargo.toml` file:
+
+```toml
+datafusion = { version = "31.0.0", features = ["backtrace"]}
+```
+
+Set environment [variables](https://doc.rust-lang.org/std/backtrace/index.html#environment-variables)
+
+```bash
+RUST_BACKTRACE=1 ./target/debug/datafusion-cli
+DataFusion CLI v31.0.0
+> select row_numer() over (partition by a order by a) from (select 1 a);
+Error during planning: Invalid function 'row_numer'.
+Did you mean 'ROW_NUMBER'?
+
+backtrace: 0: std::backtrace_rs::backtrace::libunwind::trace
+   at /rustc/5680fa18feaa87f3ff04063800aec256c3d4b4be/library/std/src/../../backtrace/src/backtrace/libunwind.rs:93:5
+ 1: std::backtrace_rs::backtrace::trace_unsynchronized
+   at /rustc/5680fa18feaa87f3ff04063800aec256c3d4b4be/library/std/src/../../backtrace/src/backtrace/mod.rs:66:5
+ 2: std::backtrace::Backtrace::create
+   at /rustc/5680fa18feaa87f3ff04063800aec256c3d4b4be/library/std/src/backtrace.rs:332:13
+ 3: std::backtrace::Backtrace::capture
+   at /rustc/5680fa18feaa87f3ff04063800aec256c3d4b4be/library/std/src/backtrace.rs:298:9
+ 4: datafusion_common::error::DataFusionError::get_back_trace
+   at /datafusion/datafusion/common/src/error.rs:436:30
+ 5: datafusion_sql::expr::function::>::sql_function_to_expr
+ ............
+```
+
+The backtraces are useful when debugging code. If there is a test in `datafusion/core/src/physical_planner.rs`
+
+```
+#[tokio::test]
+async fn test_get_backtrace_for_failed_code() -> Result<()> {
+    let ctx = SessionContext::new();
+
+    let sql = "
+        select row_numer() over (partition by a order by a) from (select 1 a);
+    ";
+
+    let _ = ctx.sql(sql).await?.collect().await?;
+
+    Ok(())
+}
+```
+
+To obtain a backtrace:
+
+```bash
+cargo build --features=backtrace
+RUST_BACKTRACE=1 cargo test --features=backtrace --package datafusion --lib -- physical_planner::tests::test_get_backtrace_for_failed_code --exact --nocapture
+```
+
+Note: The backtrace is wrapped in system calls, so some steps at the top of the backtrace can be ignored
diff --git a/docs/source/user-guide/example-usage.md b/docs/source/user-guide/example-usage.md
index 7dbd4045e75bd..813dbb1bc02ae 100644
--- a/docs/source/user-guide/example-usage.md
+++ b/docs/source/user-guide/example-usage.md
@@ -33,29 +33,6 @@ datafusion = "latest_version"
 tokio = { version = "1.0", features = ["rt-multi-thread"] }
 ```
 
-## Add latest non published DataFusion dependency
-
-DataFusion changes are published to `crates.io` according to [release schedule](https://github.com/apache/datafusion/blob/main/dev/release/README.md#release-process)
-In case if it is required to test out DataFusion changes which are merged but yet to be published, Cargo supports adding dependency directly to GitHub branch
-
-```toml
-datafusion = { git = "https://github.com/apache/datafusion", branch = "main"}
-```
-
-Also it works on the package level
-
-```toml
-datafusion-common = { git = "https://github.com/apache/datafusion", branch = "main", package = "datafusion-common"}
-```
-
-And with features
-
-```toml
-datafusion = { git = "https://github.com/apache/datafusion", branch = "main", default-features = false, features = ["unicode_expressions"] }
-```
-
-More on [Cargo dependencies](https://doc.rust-lang.org/cargo/reference/specifying-dependencies.html#specifying-dependencies)
-
 ## Run a SQL query against data stored in a CSV
 
 ```rust
@@ -201,109 +178,3 @@ async fn main() -> datafusion::error::Result<()> {
 | 1 | 2      |
 +---+--------+
 ```
-
-## Extensibility
-
-DataFusion is designed to be extensible at all points. To that end, you can provide your own custom:
-
-- [x] User Defined Functions (UDFs)
-- [x] User Defined Aggregate Functions (UDAFs)
-- [x] User Defined Table Source (`TableProvider`) for tables
-- [x] User Defined `Optimizer` passes (plan rewrites)
-- [x] User Defined `LogicalPlan` nodes
-- [x] User Defined `ExecutionPlan` nodes
-
-## Optimized Configuration
-
-For an optimized build several steps are required. First, use the below in your `Cargo.toml`. It is
-worth noting that using the settings in the `[profile.release]` section will significantly increase the build time.
-
-```toml
-[dependencies]
-datafusion = { version = "22.0" }
-tokio = { version = "^1.0", features = ["rt-multi-thread"] }
-snmalloc-rs = "0.3"
-
-[profile.release]
-lto = true
-codegen-units = 1
-```
-
-Then, in `main.rs.` update the memory allocator with the below after your imports:
-
-```rust ,ignore
-use datafusion::prelude::*;
-
-#[global_allocator]
-static ALLOC: snmalloc_rs::SnMalloc = snmalloc_rs::SnMalloc;
-
-#[tokio::main]
-async fn main() -> datafusion::error::Result<()> {
-    Ok(())
-}
-```
-
-Based on the instruction set architecture you are building on you will want to configure the `target-cpu` as well, ideally
-with `native` or at least `avx2`.
- -```shell -RUSTFLAGS='-C target-cpu=native' cargo run --release -``` - -## Enable backtraces - -By default Datafusion returns errors as a plain message. There is option to enable more verbose details about the error, -like error backtrace. To enable a backtrace you need to add Datafusion `backtrace` feature to your `Cargo.toml` file: - -```toml -datafusion = { version = "31.0.0", features = ["backtrace"]} -``` - -Set environment [variables](https://doc.rust-lang.org/std/backtrace/index.html#environment-variables) - -```bash -RUST_BACKTRACE=1 ./target/debug/datafusion-cli -DataFusion CLI v31.0.0 -> select row_number() over (partition by a order by a) from (select 1 a); -Error during planning: Invalid function 'row_number'. -Did you mean 'ROW_NUMBER'? - -backtrace: 0: std::backtrace_rs::backtrace::libunwind::trace - at /rustc/5680fa18feaa87f3ff04063800aec256c3d4b4be/library/std/src/../../backtrace/src/backtrace/libunwind.rs:93:5 - 1: std::backtrace_rs::backtrace::trace_unsynchronized - at /rustc/5680fa18feaa87f3ff04063800aec256c3d4b4be/library/std/src/../../backtrace/src/backtrace/mod.rs:66:5 - 2: std::backtrace::Backtrace::create - at /rustc/5680fa18feaa87f3ff04063800aec256c3d4b4be/library/std/src/backtrace.rs:332:13 - 3: std::backtrace::Backtrace::capture - at /rustc/5680fa18feaa87f3ff04063800aec256c3d4b4be/library/std/src/backtrace.rs:298:9 - 4: datafusion_common::error::DataFusionError::get_back_trace - at /datafusion/datafusion/common/src/error.rs:436:30 - 5: datafusion_sql::expr::function::>::sql_function_to_expr - ............ -``` - -The backtraces are useful when debugging code. If there is a test in `datafusion/core/src/physical_planner.rs` - -``` -#[tokio::test] -async fn test_get_backtrace_for_failed_code() -> Result<()> { - let ctx = SessionContext::new(); - - let sql = " - select row_number() over (partition by a order by a) from (select 1 a); - "; - - let _ = ctx.sql(sql).await?.collect().await?; - - Ok(()) -} -``` - -To obtain a backtrace: - -```bash -cargo build --features=backtrace -RUST_BACKTRACE=1 cargo test --features=backtrace --package datafusion --lib -- physical_planner::tests::test_get_backtrace_for_failed_code --exact --nocapture -``` - -Note: The backtrace wrapped into systems calls, so some steps on top of the backtrace can be ignored From faa1e98fc4bec6040c8de07d6c19973e572ad62d Mon Sep 17 00:00:00 2001 From: Arttu Date: Thu, 11 Jul 2024 18:07:53 +0200 Subject: [PATCH 11/59] reuse a single function to create the tpch test contexts (#11396) --- .../tests/cases/consumer_integration.rs | 207 ++++++------------ 1 file changed, 62 insertions(+), 145 deletions(-) diff --git a/datafusion/substrait/tests/cases/consumer_integration.rs b/datafusion/substrait/tests/cases/consumer_integration.rs index 6133c239873b2..10c1319b903b5 100644 --- a/datafusion/substrait/tests/cases/consumer_integration.rs +++ b/datafusion/substrait/tests/cases/consumer_integration.rs @@ -32,151 +32,22 @@ mod tests { use std::io::BufReader; use substrait::proto::Plan; - async fn register_csv( - ctx: &SessionContext, - table_name: &str, - file_path: &str, - ) -> Result<()> { - ctx.register_csv(table_name, file_path, CsvReadOptions::default()) - .await - } - - async fn create_context_tpch1() -> Result { - let ctx = SessionContext::new(); - register_csv( - &ctx, - "FILENAME_PLACEHOLDER_0", - "tests/testdata/tpch/lineitem.csv", - ) - .await?; - Ok(ctx) - } - - async fn create_context_tpch2() -> Result { - let ctx = SessionContext::new(); - - let registrations = vec![ - ("FILENAME_PLACEHOLDER_0", 
"tests/testdata/tpch/part.csv"), - ("FILENAME_PLACEHOLDER_1", "tests/testdata/tpch/supplier.csv"), - ("FILENAME_PLACEHOLDER_2", "tests/testdata/tpch/partsupp.csv"), - ("FILENAME_PLACEHOLDER_3", "tests/testdata/tpch/nation.csv"), - ("FILENAME_PLACEHOLDER_4", "tests/testdata/tpch/region.csv"), - ("FILENAME_PLACEHOLDER_5", "tests/testdata/tpch/partsupp.csv"), - ("FILENAME_PLACEHOLDER_6", "tests/testdata/tpch/supplier.csv"), - ("FILENAME_PLACEHOLDER_7", "tests/testdata/tpch/nation.csv"), - ("FILENAME_PLACEHOLDER_8", "tests/testdata/tpch/region.csv"), - ]; - - for (table_name, file_path) in registrations { - register_csv(&ctx, table_name, file_path).await?; - } - - Ok(ctx) - } - - async fn create_context_tpch3() -> Result { - let ctx = SessionContext::new(); - - let registrations = vec![ - ("FILENAME_PLACEHOLDER_0", "tests/testdata/tpch/customer.csv"), - ("FILENAME_PLACEHOLDER_1", "tests/testdata/tpch/orders.csv"), - ("FILENAME_PLACEHOLDER_2", "tests/testdata/tpch/lineitem.csv"), - ]; - - for (table_name, file_path) in registrations { - register_csv(&ctx, table_name, file_path).await?; - } - - Ok(ctx) - } - - async fn create_context_tpch4() -> Result { - let ctx = SessionContext::new(); - - let registrations = vec![ - ("FILENAME_PLACEHOLDER_0", "tests/testdata/tpch/orders.csv"), - ("FILENAME_PLACEHOLDER_1", "tests/testdata/tpch/lineitem.csv"), - ]; - - for (table_name, file_path) in registrations { - register_csv(&ctx, table_name, file_path).await?; - } - - Ok(ctx) - } - - async fn create_context_tpch5() -> Result { - let ctx = SessionContext::new(); - - let registrations = vec![ - ("FILENAME_PLACEHOLDER_0", "tests/testdata/tpch/customer.csv"), - ("FILENAME_PLACEHOLDER_1", "tests/testdata/tpch/orders.csv"), - ("FILENAME_PLACEHOLDER_2", "tests/testdata/tpch/lineitem.csv"), - ("FILENAME_PLACEHOLDER_3", "tests/testdata/tpch/supplier.csv"), - ("NATION", "tests/testdata/tpch/nation.csv"), - ("REGION", "tests/testdata/tpch/region.csv"), - ]; - - for (table_name, file_path) in registrations { - register_csv(&ctx, table_name, file_path).await?; - } - - Ok(ctx) - } - - async fn create_context_tpch6() -> Result { - let ctx = SessionContext::new(); - - let registrations = - vec![("FILENAME_PLACEHOLDER_0", "tests/testdata/tpch/lineitem.csv")]; - - for (table_name, file_path) in registrations { - register_csv(&ctx, table_name, file_path).await?; - } - - Ok(ctx) - } - // missing context for query 7,8,9 - - async fn create_context_tpch10() -> Result { + async fn create_context(files: Vec<(&str, &str)>) -> Result { let ctx = SessionContext::new(); - - let registrations = vec![ - ("FILENAME_PLACEHOLDER_0", "tests/testdata/tpch/customer.csv"), - ("FILENAME_PLACEHOLDER_1", "tests/testdata/tpch/orders.csv"), - ("FILENAME_PLACEHOLDER_2", "tests/testdata/tpch/lineitem.csv"), - ("FILENAME_PLACEHOLDER_3", "tests/testdata/tpch/nation.csv"), - ]; - - for (table_name, file_path) in registrations { - register_csv(&ctx, table_name, file_path).await?; + for (table_name, file_path) in files { + ctx.register_csv(table_name, file_path, CsvReadOptions::default()) + .await?; } - - Ok(ctx) - } - - async fn create_context_tpch11() -> Result { - let ctx = SessionContext::new(); - - let registrations = vec![ - ("FILENAME_PLACEHOLDER_0", "tests/testdata/tpch/partsupp.csv"), - ("FILENAME_PLACEHOLDER_1", "tests/testdata/tpch/supplier.csv"), - ("FILENAME_PLACEHOLDER_2", "tests/testdata/tpch/nation.csv"), - ("FILENAME_PLACEHOLDER_3", "tests/testdata/tpch/partsupp.csv"), - ("FILENAME_PLACEHOLDER_4", "tests/testdata/tpch/supplier.csv"), - 
("FILENAME_PLACEHOLDER_5", "tests/testdata/tpch/nation.csv"), - ]; - - for (table_name, file_path) in registrations { - register_csv(&ctx, table_name, file_path).await?; - } - Ok(ctx) } #[tokio::test] async fn tpch_test_1() -> Result<()> { - let ctx = create_context_tpch1().await?; + let ctx = create_context(vec![( + "FILENAME_PLACEHOLDER_0", + "tests/testdata/tpch/lineitem.csv", + )]) + .await?; let path = "tests/testdata/tpch_substrait_plans/query_1.json"; let proto = serde_json::from_reader::<_, Plan>(BufReader::new( File::open(path).expect("file not found"), @@ -200,7 +71,18 @@ mod tests { #[tokio::test] async fn tpch_test_2() -> Result<()> { - let ctx = create_context_tpch2().await?; + let ctx = create_context(vec![ + ("FILENAME_PLACEHOLDER_0", "tests/testdata/tpch/part.csv"), + ("FILENAME_PLACEHOLDER_1", "tests/testdata/tpch/supplier.csv"), + ("FILENAME_PLACEHOLDER_2", "tests/testdata/tpch/partsupp.csv"), + ("FILENAME_PLACEHOLDER_3", "tests/testdata/tpch/nation.csv"), + ("FILENAME_PLACEHOLDER_4", "tests/testdata/tpch/region.csv"), + ("FILENAME_PLACEHOLDER_5", "tests/testdata/tpch/partsupp.csv"), + ("FILENAME_PLACEHOLDER_6", "tests/testdata/tpch/supplier.csv"), + ("FILENAME_PLACEHOLDER_7", "tests/testdata/tpch/nation.csv"), + ("FILENAME_PLACEHOLDER_8", "tests/testdata/tpch/region.csv"), + ]) + .await?; let path = "tests/testdata/tpch_substrait_plans/query_2.json"; let proto = serde_json::from_reader::<_, Plan>(BufReader::new( File::open(path).expect("file not found"), @@ -242,7 +124,12 @@ mod tests { #[tokio::test] async fn tpch_test_3() -> Result<()> { - let ctx = create_context_tpch3().await?; + let ctx = create_context(vec![ + ("FILENAME_PLACEHOLDER_0", "tests/testdata/tpch/customer.csv"), + ("FILENAME_PLACEHOLDER_1", "tests/testdata/tpch/orders.csv"), + ("FILENAME_PLACEHOLDER_2", "tests/testdata/tpch/lineitem.csv"), + ]) + .await?; let path = "tests/testdata/tpch_substrait_plans/query_3.json"; let proto = serde_json::from_reader::<_, Plan>(BufReader::new( File::open(path).expect("file not found"), @@ -267,7 +154,11 @@ mod tests { #[tokio::test] async fn tpch_test_4() -> Result<()> { - let ctx = create_context_tpch4().await?; + let ctx = create_context(vec![ + ("FILENAME_PLACEHOLDER_0", "tests/testdata/tpch/orders.csv"), + ("FILENAME_PLACEHOLDER_1", "tests/testdata/tpch/lineitem.csv"), + ]) + .await?; let path = "tests/testdata/tpch_substrait_plans/query_4.json"; let proto = serde_json::from_reader::<_, Plan>(BufReader::new( File::open(path).expect("file not found"), @@ -289,7 +180,15 @@ mod tests { #[tokio::test] async fn tpch_test_5() -> Result<()> { - let ctx = create_context_tpch5().await?; + let ctx = create_context(vec![ + ("FILENAME_PLACEHOLDER_0", "tests/testdata/tpch/customer.csv"), + ("FILENAME_PLACEHOLDER_1", "tests/testdata/tpch/orders.csv"), + ("FILENAME_PLACEHOLDER_2", "tests/testdata/tpch/lineitem.csv"), + ("FILENAME_PLACEHOLDER_3", "tests/testdata/tpch/supplier.csv"), + ("NATION", "tests/testdata/tpch/nation.csv"), + ("REGION", "tests/testdata/tpch/region.csv"), + ]) + .await?; let path = "tests/testdata/tpch_substrait_plans/query_5.json"; let proto = serde_json::from_reader::<_, Plan>(BufReader::new( File::open(path).expect("file not found"), @@ -319,7 +218,11 @@ mod tests { #[tokio::test] async fn tpch_test_6() -> Result<()> { - let ctx = create_context_tpch6().await?; + let ctx = create_context(vec![( + "FILENAME_PLACEHOLDER_0", + "tests/testdata/tpch/lineitem.csv", + )]) + .await?; let path = "tests/testdata/tpch_substrait_plans/query_6.json"; let proto = 
serde_json::from_reader::<_, Plan>(BufReader::new( File::open(path).expect("file not found"), @@ -338,7 +241,13 @@ mod tests { // TODO: missing plan 7, 8, 9 #[tokio::test] async fn tpch_test_10() -> Result<()> { - let ctx = create_context_tpch10().await?; + let ctx = create_context(vec![ + ("FILENAME_PLACEHOLDER_0", "tests/testdata/tpch/customer.csv"), + ("FILENAME_PLACEHOLDER_1", "tests/testdata/tpch/orders.csv"), + ("FILENAME_PLACEHOLDER_2", "tests/testdata/tpch/lineitem.csv"), + ("FILENAME_PLACEHOLDER_3", "tests/testdata/tpch/nation.csv"), + ]) + .await?; let path = "tests/testdata/tpch_substrait_plans/query_10.json"; let proto = serde_json::from_reader::<_, Plan>(BufReader::new( File::open(path).expect("file not found"), @@ -365,7 +274,15 @@ mod tests { #[tokio::test] async fn tpch_test_11() -> Result<()> { - let ctx = create_context_tpch11().await?; + let ctx = create_context(vec![ + ("FILENAME_PLACEHOLDER_0", "tests/testdata/tpch/partsupp.csv"), + ("FILENAME_PLACEHOLDER_1", "tests/testdata/tpch/supplier.csv"), + ("FILENAME_PLACEHOLDER_2", "tests/testdata/tpch/nation.csv"), + ("FILENAME_PLACEHOLDER_3", "tests/testdata/tpch/partsupp.csv"), + ("FILENAME_PLACEHOLDER_4", "tests/testdata/tpch/supplier.csv"), + ("FILENAME_PLACEHOLDER_5", "tests/testdata/tpch/nation.csv"), + ]) + .await?; let path = "tests/testdata/tpch_substrait_plans/query_11.json"; let proto = serde_json::from_reader::<_, Plan>(BufReader::new( File::open(path).expect("file not found"), From 6692382f22f04542534bba0183cf0682fd932da1 Mon Sep 17 00:00:00 2001 From: Marco Neumann Date: Thu, 11 Jul 2024 18:17:03 +0200 Subject: [PATCH 12/59] refactor: change error type for "no statement" (#11411) Amends #11394 (sorry, I should have reviewed that). While reporting "not implemented" for "multiple statements" seems reasonable, I think the user should get a plan error (which roughly translates to "invalid argument") if they don't provide any statement. I don't see any reasonable way to support "no statement" ever, hence "not implemented" seems like a wrong promise. 
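
For illustration, here is a minimal sketch of the user-visible difference (it mirrors the updated test in `datafusion/core/tests/sql/sql_api.rs` below; the `state` variable is assumed to be a `SessionState`):

```rust
// Sketch only: planning an empty query string now yields a plan error
// ("invalid argument") instead of a "not implemented" error.
let err = state.create_logical_plan("").await.unwrap_err();
assert_eq!(
    err.strip_backtrace(),
    "Error during planning: No SQL statements were provided in the query string"
);
```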
--- datafusion/core/src/execution/session_state.rs | 4 +--- datafusion/core/tests/sql/sql_api.rs | 2 +- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index 60745076c2427..dbfba9ea93521 100644 --- a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -554,9 +554,7 @@ impl SessionState { ); } let statement = statements.pop_front().ok_or_else(|| { - DataFusionError::NotImplemented( - "No SQL statements were provided in the query string".to_string(), - ) + plan_datafusion_err!("No SQL statements were provided in the query string") })?; Ok(statement) } diff --git a/datafusion/core/tests/sql/sql_api.rs b/datafusion/core/tests/sql/sql_api.rs index e7c40d2c8aa88..48f4a66b65dcf 100644 --- a/datafusion/core/tests/sql/sql_api.rs +++ b/datafusion/core/tests/sql/sql_api.rs @@ -124,7 +124,7 @@ async fn empty_statement_returns_error() { let plan_res = state.create_logical_plan("").await; assert_eq!( plan_res.unwrap_err().strip_backtrace(), - "This feature is not implemented: No SQL statements were provided in the query string" + "Error during planning: No SQL statements were provided in the query string" ); } From f284e3bb73e089abc0c06b3314014522411bf1da Mon Sep 17 00:00:00 2001 From: Chunchun Ye <14298407+appletreeisyellow@users.noreply.github.com> Date: Thu, 11 Jul 2024 11:17:09 -0500 Subject: [PATCH 13/59] feat: add UDF to_local_time() (#11347) * feat: add UDF `to_local_time()` * chore: support column value in array * chore: lint * chore: fix conversion for us, ms, and s * chore: add more tests for daylight savings time * chore: add function description * refactor: update tests and add examples in description * chore: add description and example * chore: doc chore: doc chore: doc chore: doc chore: doc * chore: stop copying * chore: fix typo * chore: mention that the offset varies based on daylight savings time * refactor: parse timezone once and update examples in description * refactor: replace map..concat with flat_map * chore: add hard code timestamp value in test chore: doc chore: doc * chore: handle errors and remove panics * chore: move some test to slt * chore: clone time_value * chore: typo --------- Co-authored-by: Andrew Lamb --- datafusion/functions/src/datetime/mod.rs | 11 +- .../functions/src/datetime/to_local_time.rs | 564 ++++++++++++++++++ .../sqllogictest/test_files/timestamps.slt | 177 ++++++ 3 files changed, 751 insertions(+), 1 deletion(-) create mode 100644 datafusion/functions/src/datetime/to_local_time.rs diff --git a/datafusion/functions/src/datetime/mod.rs b/datafusion/functions/src/datetime/mod.rs index 9c2f80856bf86..a7e9827d6ca69 100644 --- a/datafusion/functions/src/datetime/mod.rs +++ b/datafusion/functions/src/datetime/mod.rs @@ -32,6 +32,7 @@ pub mod make_date; pub mod now; pub mod to_char; pub mod to_date; +pub mod to_local_time; pub mod to_timestamp; pub mod to_unixtime; @@ -50,6 +51,7 @@ make_udf_function!( make_udf_function!(now::NowFunc, NOW, now); make_udf_function!(to_char::ToCharFunc, TO_CHAR, to_char); make_udf_function!(to_date::ToDateFunc, TO_DATE, to_date); +make_udf_function!(to_local_time::ToLocalTimeFunc, TO_LOCAL_TIME, to_local_time); make_udf_function!(to_unixtime::ToUnixtimeFunc, TO_UNIXTIME, to_unixtime); make_udf_function!(to_timestamp::ToTimestampFunc, TO_TIMESTAMP, to_timestamp); make_udf_function!( @@ -108,7 +110,13 @@ pub mod expr_fn { ),( now, "returns the current timestamp in 
nanoseconds, using the same value for all instances of now() in same statement",
-    ),(
+    ),
+    (
+        to_local_time,
+        "converts a timezone-aware timestamp to local time (with no offset or timezone information), i.e. strips off the timezone from the timestamp",
+        args,
+    ),
+    (
         to_unixtime,
         "converts a string and optional formats to a Unixtime",
         args,
@@ -277,6 +285,7 @@ pub fn functions() -> Vec<Arc<ScalarUDF>> {
         now(),
         to_char(),
         to_date(),
+        to_local_time(),
         to_unixtime(),
         to_timestamp(),
         to_timestamp_seconds(),
diff --git a/datafusion/functions/src/datetime/to_local_time.rs b/datafusion/functions/src/datetime/to_local_time.rs
new file mode 100644
index 0000000000000..c84d1015bd7ee
--- /dev/null
+++ b/datafusion/functions/src/datetime/to_local_time.rs
@@ -0,0 +1,564 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::any::Any;
+use std::ops::Add;
+use std::sync::Arc;
+
+use arrow::array::timezone::Tz;
+use arrow::array::{Array, ArrayRef, PrimitiveBuilder};
+use arrow::datatypes::DataType::Timestamp;
+use arrow::datatypes::{
+    ArrowTimestampType, DataType, TimestampMicrosecondType, TimestampMillisecondType,
+    TimestampNanosecondType, TimestampSecondType,
+};
+use arrow::datatypes::{
+    TimeUnit,
+    TimeUnit::{Microsecond, Millisecond, Nanosecond, Second},
+};
+
+use chrono::{DateTime, MappedLocalTime, Offset, TimeDelta, TimeZone, Utc};
+use datafusion_common::cast::as_primitive_array;
+use datafusion_common::{exec_err, DataFusionError, Result, ScalarValue};
+use datafusion_expr::TypeSignature::Exact;
+use datafusion_expr::{
+    ColumnarValue, ScalarUDFImpl, Signature, Volatility, TIMEZONE_WILDCARD,
+};
+
+/// A UDF function that converts a timezone-aware timestamp to local time (with no offset or
+/// timezone information). In other words, this function strips off the timezone from the timestamp,
+/// while keeping the display value of the timestamp the same.
+#[derive(Debug)]
+pub struct ToLocalTimeFunc {
+    signature: Signature,
+}
+
+impl Default for ToLocalTimeFunc {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl ToLocalTimeFunc {
+    pub fn new() -> Self {
+        let base_sig = |array_type: TimeUnit| {
+            [
+                Exact(vec![Timestamp(array_type, None)]),
+                Exact(vec![Timestamp(array_type, Some(TIMEZONE_WILDCARD.into()))]),
+            ]
+        };
+
+        let full_sig = [Nanosecond, Microsecond, Millisecond, Second]
+            .into_iter()
+            .flat_map(base_sig)
+            .collect::<Vec<_>>();
+
+        Self {
+            signature: Signature::one_of(full_sig, Volatility::Immutable),
+        }
+    }
+
+    fn to_local_time(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
+        if args.len() != 1 {
+            return exec_err!(
+                "to_local_time function requires 1 argument, got {}",
+                args.len()
+            );
+        }
+
+        let time_value = &args[0];
+        let arg_type = time_value.data_type();
+        match arg_type {
+            DataType::Timestamp(_, None) => {
+                // if no timezone specified, just return the input
+                Ok(time_value.clone())
+            }
+            // If it has a timezone, adjust the underlying time value. The current time value
+            // is stored as i64 in UTC, even though the timezone may not be in UTC. Therefore,
+            // we need to adjust the time value to the local time. See [`adjust_to_local_time`]
+            // for more details.
+            //
+            // Then remove the timezone in return type, i.e. return None
+            DataType::Timestamp(_, Some(timezone)) => {
+                let tz: Tz = timezone.parse()?;
+
+                match time_value {
+                    ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(
+                        Some(ts),
+                        Some(_),
+                    )) => {
+                        let adjusted_ts =
+                            adjust_to_local_time::<TimestampNanosecondType>(*ts, tz)?;
+                        Ok(ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(
+                            Some(adjusted_ts),
+                            None,
+                        )))
+                    }
+                    ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond(
+                        Some(ts),
+                        Some(_),
+                    )) => {
+                        let adjusted_ts =
+                            adjust_to_local_time::<TimestampMicrosecondType>(*ts, tz)?;
+                        Ok(ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond(
+                            Some(adjusted_ts),
+                            None,
+                        )))
+                    }
+                    ColumnarValue::Scalar(ScalarValue::TimestampMillisecond(
+                        Some(ts),
+                        Some(_),
+                    )) => {
+                        let adjusted_ts =
+                            adjust_to_local_time::<TimestampMillisecondType>(*ts, tz)?;
+                        Ok(ColumnarValue::Scalar(ScalarValue::TimestampMillisecond(
+                            Some(adjusted_ts),
+                            None,
+                        )))
+                    }
+                    ColumnarValue::Scalar(ScalarValue::TimestampSecond(
+                        Some(ts),
+                        Some(_),
+                    )) => {
+                        let adjusted_ts =
+                            adjust_to_local_time::<TimestampSecondType>(*ts, tz)?;
+                        Ok(ColumnarValue::Scalar(ScalarValue::TimestampSecond(
+                            Some(adjusted_ts),
+                            None,
+                        )))
+                    }
+                    ColumnarValue::Array(array) => {
+                        fn transform_array<T: ArrowTimestampType>(
+                            array: &ArrayRef,
+                            tz: Tz,
+                        ) -> Result<ColumnarValue> {
+                            let mut builder = PrimitiveBuilder::<T>::new();
+
+                            let primitive_array = as_primitive_array::<T>(array)?;
+                            for ts_opt in primitive_array.iter() {
+                                match ts_opt {
+                                    None => builder.append_null(),
+                                    Some(ts) => {
+                                        let adjusted_ts: i64 =
+                                            adjust_to_local_time::<T>(ts, tz)?;
+                                        builder.append_value(adjusted_ts)
+                                    }
+                                }
+                            }
+
+                            Ok(ColumnarValue::Array(Arc::new(builder.finish())))
+                        }
+
+                        match array.data_type() {
+                            Timestamp(_, None) => {
+                                // if no timezone specified, just return the input
+                                Ok(time_value.clone())
+                            }
+                            Timestamp(Nanosecond, Some(_)) => {
+                                transform_array::<TimestampNanosecondType>(array, tz)
+                            }
+                            Timestamp(Microsecond, Some(_)) => {
+                                transform_array::<TimestampMicrosecondType>(array, tz)
+                            }
+                            Timestamp(Millisecond, Some(_)) => {
+                                transform_array::<TimestampMillisecondType>(array, tz)
+                            }
+                            Timestamp(Second, Some(_)) => {
+                                transform_array::<TimestampSecondType>(array, tz)
+                            }
+                            _ => {
+                                exec_err!("to_local_time function requires timestamp argument in array, got {:?}", array.data_type())
+                            }
+                        }
+                    }
+                    _ => {
+                        exec_err!(
+                            "to_local_time function requires timestamp argument, got {:?}",
+                            time_value.data_type()
+                        )
+                    }
+                }
+            }
+            _ => {
+                exec_err!(
+                    "to_local_time function requires timestamp argument, got {:?}",
+                    arg_type
+                )
+            }
+        }
+    }
+}
+
+/// This function converts a timestamp with a timezone to a timestamp without a timezone.
+/// The display value of the adjusted timestamp remains the same, but the underlying timestamp
+/// representation is adjusted according to the relative timezone offset to UTC.
+///
+/// This function uses chrono to handle daylight saving time changes.
+///
+/// For example,
+///
+/// ```text
+/// '2019-03-31T01:00:00Z'::timestamp at time zone 'Europe/Brussels'
+/// ```
+///
+/// is displayed as follows in datafusion-cli:
+///
+/// ```text
+/// 2019-03-31T01:00:00+01:00
+/// ```
+///
+/// and is represented in DataFusion as:
+///
+/// ```text
+/// TimestampNanosecond(Some(1_553_990_400_000_000_000), Some("Europe/Brussels"))
+/// ```
+///
+/// To strip off the timezone while keeping the display value the same, we need to
+/// adjust the underlying timestamp with the timezone offset value using `adjust_to_local_time()`
+///
+/// ```text
+/// adjust_to_local_time(1_553_990_400_000_000_000, "Europe/Brussels") --> 1_553_994_000_000_000_000
+/// ```
+///
+/// The difference between `1_553_990_400_000_000_000` and `1_553_994_000_000_000_000` is
+/// `3600_000_000_000` ns, which corresponds to 1 hour. This matches with the timezone
+/// offset for "Europe/Brussels" for this date.
+///
+/// Note that the offset varies with daylight savings time (DST), which makes this tricky! For
+/// example, timezone "Europe/Brussels" has a 2-hour offset during DST and a 1-hour offset
+/// when DST ends.
+///
+/// Consequently, DataFusion can represent the timestamp in local time (with no offset or
+/// timezone information) as
+///
+/// ```text
+/// TimestampNanosecond(Some(1_553_994_000_000_000_000), None)
+/// ```
+///
+/// which is displayed as follows in datafusion-cli:
+///
+/// ```text
+/// 2019-03-31T01:00:00
+/// ```
+///
+/// See `test_adjust_to_local_time()` for an example
+fn adjust_to_local_time<T: ArrowTimestampType>(ts: i64, tz: Tz) -> Result<i64> {
+    fn convert_timestamp<F>(ts: i64, converter: F) -> Result<DateTime<Utc>>
+    where
+        F: Fn(i64) -> MappedLocalTime<DateTime<Utc>>,
+    {
+        match converter(ts) {
+            MappedLocalTime::Ambiguous(earliest, latest) => exec_err!(
+                "Ambiguous timestamp. Do you mean {:?} or {:?}",
+                earliest,
+                latest
+            ),
+            MappedLocalTime::None => exec_err!(
+                "The local time does not exist because there is a gap in the local time."
+            ),
+            MappedLocalTime::Single(date_time) => Ok(date_time),
+        }
+    }
+
+    let date_time = match T::UNIT {
+        Nanosecond => Utc.timestamp_nanos(ts),
+        Microsecond => convert_timestamp(ts, |ts| Utc.timestamp_micros(ts))?,
+        Millisecond => convert_timestamp(ts, |ts| Utc.timestamp_millis_opt(ts))?,
+        Second => convert_timestamp(ts, |ts| Utc.timestamp_opt(ts, 0))?,
+    };
+
+    let offset_seconds: i64 = tz
+        .offset_from_utc_datetime(&date_time.naive_utc())
+        .fix()
+        .local_minus_utc() as i64;
+
+    let adjusted_date_time = date_time.add(
+        // This should not fail under normal circumstances as the
+        // maximum possible offset is 26 hours (93,600 seconds)
+        TimeDelta::try_seconds(offset_seconds)
+            .ok_or(DataFusionError::Internal("Offset seconds should be less than i64::MAX / 1_000 or greater than -i64::MAX / 1_000".to_string()))?,
+    );
+
+    // convert the naive datetime back to i64
+    match T::UNIT {
+        Nanosecond => adjusted_date_time.timestamp_nanos_opt().ok_or(
+            DataFusionError::Internal(
+                "Failed to convert DateTime to timestamp in nanoseconds. This error may occur if the date is out of range. 
The supported date ranges are between 1677-09-21T00:12:43.145224192 and 2262-04-11T23:47:16.854775807".to_string(), + ), + ), + Microsecond => Ok(adjusted_date_time.timestamp_micros()), + Millisecond => Ok(adjusted_date_time.timestamp_millis()), + Second => Ok(adjusted_date_time.timestamp()), + } +} + +impl ScalarUDFImpl for ToLocalTimeFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "to_local_time" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + if arg_types.len() != 1 { + return exec_err!( + "to_local_time function requires 1 argument, got {:?}", + arg_types.len() + ); + } + + match &arg_types[0] { + Timestamp(Nanosecond, _) => Ok(Timestamp(Nanosecond, None)), + Timestamp(Microsecond, _) => Ok(Timestamp(Microsecond, None)), + Timestamp(Millisecond, _) => Ok(Timestamp(Millisecond, None)), + Timestamp(Second, _) => Ok(Timestamp(Second, None)), + _ => exec_err!( + "The to_local_time function can only accept timestamp as the arg, got {:?}", arg_types[0] + ), + } + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + if args.len() != 1 { + return exec_err!( + "to_local_time function requires 1 argument, got {:?}", + args.len() + ); + } + + self.to_local_time(args) + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use arrow::array::{types::TimestampNanosecondType, TimestampNanosecondArray}; + use arrow::compute::kernels::cast_utils::string_to_timestamp_nanos; + use arrow::datatypes::{DataType, TimeUnit}; + use chrono::NaiveDateTime; + use datafusion_common::ScalarValue; + use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + + use super::{adjust_to_local_time, ToLocalTimeFunc}; + + #[test] + fn test_adjust_to_local_time() { + let timestamp_str = "2020-03-31T13:40:00"; + let tz: arrow::array::timezone::Tz = + "America/New_York".parse().expect("Invalid timezone"); + + let timestamp = timestamp_str + .parse::() + .unwrap() + .and_local_timezone(tz) // this is in a local timezone + .unwrap() + .timestamp_nanos_opt() + .unwrap(); + + let expected_timestamp = timestamp_str + .parse::() + .unwrap() + .and_utc() // this is in UTC + .timestamp_nanos_opt() + .unwrap(); + + let res = adjust_to_local_time::(timestamp, tz).unwrap(); + assert_eq!(res, expected_timestamp); + } + + #[test] + fn test_to_local_time_scalar() { + let timezone = Some("Europe/Brussels".into()); + let timestamps_with_timezone = vec![ + ( + ScalarValue::TimestampNanosecond( + Some(1_123_123_000_000_000_000), + timezone.clone(), + ), + ScalarValue::TimestampNanosecond(Some(1_123_130_200_000_000_000), None), + ), + ( + ScalarValue::TimestampMicrosecond( + Some(1_123_123_000_000_000), + timezone.clone(), + ), + ScalarValue::TimestampMicrosecond(Some(1_123_130_200_000_000), None), + ), + ( + ScalarValue::TimestampMillisecond( + Some(1_123_123_000_000), + timezone.clone(), + ), + ScalarValue::TimestampMillisecond(Some(1_123_130_200_000), None), + ), + ( + ScalarValue::TimestampSecond(Some(1_123_123_000), timezone), + ScalarValue::TimestampSecond(Some(1_123_130_200), None), + ), + ]; + + for (input, expected) in timestamps_with_timezone { + test_to_local_time_helper(input, expected); + } + } + + #[test] + fn test_timezone_with_daylight_savings() { + let timezone_str = "America/New_York"; + let tz: arrow::array::timezone::Tz = + timezone_str.parse().expect("Invalid timezone"); + + // Test data: + // ( + // the string display of the input timestamp, + // the i64 representation of the timestamp before 
adjustment in nanosecond, + // the i64 representation of the timestamp after adjustment in nanosecond, + // ) + let test_cases = vec![ + ( + // DST time + "2020-03-31T13:40:00", + 1_585_676_400_000_000_000, + 1_585_662_000_000_000_000, + ), + ( + // End of DST + "2020-11-04T14:06:40", + 1_604_516_800_000_000_000, + 1_604_498_800_000_000_000, + ), + ]; + + for ( + input_timestamp_str, + expected_input_timestamp, + expected_adjusted_timestamp, + ) in test_cases + { + let input_timestamp = input_timestamp_str + .parse::() + .unwrap() + .and_local_timezone(tz) // this is in a local timezone + .unwrap() + .timestamp_nanos_opt() + .unwrap(); + assert_eq!(input_timestamp, expected_input_timestamp); + + let expected_timestamp = input_timestamp_str + .parse::() + .unwrap() + .and_utc() // this is in UTC + .timestamp_nanos_opt() + .unwrap(); + assert_eq!(expected_timestamp, expected_adjusted_timestamp); + + let input = ScalarValue::TimestampNanosecond( + Some(input_timestamp), + Some(timezone_str.into()), + ); + let expected = + ScalarValue::TimestampNanosecond(Some(expected_timestamp), None); + test_to_local_time_helper(input, expected) + } + } + + fn test_to_local_time_helper(input: ScalarValue, expected: ScalarValue) { + let res = ToLocalTimeFunc::new() + .invoke(&[ColumnarValue::Scalar(input)]) + .unwrap(); + match res { + ColumnarValue::Scalar(res) => { + assert_eq!(res, expected); + } + _ => panic!("unexpected return type"), + } + } + + #[test] + fn test_to_local_time_timezones_array() { + let cases = [ + ( + vec![ + "2020-09-08T00:00:00", + "2020-09-08T01:00:00", + "2020-09-08T02:00:00", + "2020-09-08T03:00:00", + "2020-09-08T04:00:00", + ], + None::>, + vec![ + "2020-09-08T00:00:00", + "2020-09-08T01:00:00", + "2020-09-08T02:00:00", + "2020-09-08T03:00:00", + "2020-09-08T04:00:00", + ], + ), + ( + vec![ + "2020-09-08T00:00:00", + "2020-09-08T01:00:00", + "2020-09-08T02:00:00", + "2020-09-08T03:00:00", + "2020-09-08T04:00:00", + ], + Some("+01:00".into()), + vec![ + "2020-09-08T00:00:00", + "2020-09-08T01:00:00", + "2020-09-08T02:00:00", + "2020-09-08T03:00:00", + "2020-09-08T04:00:00", + ], + ), + ]; + + cases.iter().for_each(|(source, _tz_opt, expected)| { + let input = source + .iter() + .map(|s| Some(string_to_timestamp_nanos(s).unwrap())) + .collect::(); + let right = expected + .iter() + .map(|s| Some(string_to_timestamp_nanos(s).unwrap())) + .collect::(); + let result = ToLocalTimeFunc::new() + .invoke(&[ColumnarValue::Array(Arc::new(input))]) + .unwrap(); + if let ColumnarValue::Array(result) = result { + assert_eq!( + result.data_type(), + &DataType::Timestamp(TimeUnit::Nanosecond, None) + ); + let left = arrow::array::cast::as_primitive_array::< + TimestampNanosecondType, + >(&result); + assert_eq!(left, &right); + } else { + panic!("unexpected column type"); + } + }); + } +} diff --git a/datafusion/sqllogictest/test_files/timestamps.slt b/datafusion/sqllogictest/test_files/timestamps.slt index 2216dbfa5fd58..f4e492649b9f8 100644 --- a/datafusion/sqllogictest/test_files/timestamps.slt +++ b/datafusion/sqllogictest/test_files/timestamps.slt @@ -2844,3 +2844,180 @@ select arrow_cast('2024-06-17T13:00:00', 'Timestamp(Nanosecond, Some("UTC"))') - query error select arrow_cast('2024-06-17T13:00:00', 'Timestamp(Nanosecond, Some("+00:00"))') - arrow_cast('2024-06-17T12:00:00', 'Timestamp(Microsecond, Some("+01:00"))'); + +########## +## Test to_local_time function +########## + +# invalid number of arguments -- no argument +statement error +select to_local_time(); + +# invalid number of 
arguments -- more than 1 argument +statement error +select to_local_time('2024-04-01T00:00:20Z'::timestamp, 'some string'); + +# invalid argument data type +statement error DataFusion error: Execution error: The to_local_time function can only accept timestamp as the arg, got Utf8 +select to_local_time('2024-04-01T00:00:20Z'); + +# invalid timezone +statement error DataFusion error: Arrow error: Parser error: Invalid timezone "Europe/timezone": failed to parse timezone +select to_local_time('2024-04-01T00:00:20Z'::timestamp AT TIME ZONE 'Europe/timezone'); + +# valid query +query P +select to_local_time('2024-04-01T00:00:20Z'::timestamp); +---- +2024-04-01T00:00:20 + +query P +select to_local_time('2024-04-01T00:00:20Z'::timestamp AT TIME ZONE '+05:00'); +---- +2024-04-01T00:00:20 + +query P +select to_local_time('2024-04-01T00:00:20Z'::timestamp AT TIME ZONE 'Europe/Brussels'); +---- +2024-04-01T00:00:20 + +query PTPT +select + time, + arrow_typeof(time) as type, + to_local_time(time) as to_local_time, + arrow_typeof(to_local_time(time)) as to_local_time_type +from ( + select '2024-04-01T00:00:20Z'::timestamp AT TIME ZONE 'Europe/Brussels' as time +); +---- +2024-04-01T00:00:20+02:00 Timestamp(Nanosecond, Some("Europe/Brussels")) 2024-04-01T00:00:20 Timestamp(Nanosecond, None) + +# use to_local_time() in date_bin() +query P +select date_bin(interval '1 day', to_local_time('2024-04-01T00:00:20Z'::timestamp AT TIME ZONE 'Europe/Brussels')); +---- +2024-04-01T00:00:00 + +query P +select date_bin(interval '1 day', to_local_time('2024-04-01T00:00:20Z'::timestamp AT TIME ZONE 'Europe/Brussels')) AT TIME ZONE 'Europe/Brussels'; +---- +2024-04-01T00:00:00+02:00 + +# test using to_local_time() on array values +statement ok +create table t AS +VALUES + ('2024-01-01T00:00:01Z'), + ('2024-02-01T00:00:01Z'), + ('2024-03-01T00:00:01Z'), + ('2024-04-01T00:00:01Z'), + ('2024-05-01T00:00:01Z'), + ('2024-06-01T00:00:01Z'), + ('2024-07-01T00:00:01Z'), + ('2024-08-01T00:00:01Z'), + ('2024-09-01T00:00:01Z'), + ('2024-10-01T00:00:01Z'), + ('2024-11-01T00:00:01Z'), + ('2024-12-01T00:00:01Z') +; + +statement ok +create view t_utc as +select column1::timestamp AT TIME ZONE 'UTC' as "column1" +from t; + +statement ok +create view t_timezone as +select column1::timestamp AT TIME ZONE 'Europe/Brussels' as "column1" +from t; + +query PPT +select column1, to_local_time(column1::timestamp), arrow_typeof(to_local_time(column1::timestamp)) from t_utc; +---- +2024-01-01T00:00:01Z 2024-01-01T00:00:01 Timestamp(Nanosecond, None) +2024-02-01T00:00:01Z 2024-02-01T00:00:01 Timestamp(Nanosecond, None) +2024-03-01T00:00:01Z 2024-03-01T00:00:01 Timestamp(Nanosecond, None) +2024-04-01T00:00:01Z 2024-04-01T00:00:01 Timestamp(Nanosecond, None) +2024-05-01T00:00:01Z 2024-05-01T00:00:01 Timestamp(Nanosecond, None) +2024-06-01T00:00:01Z 2024-06-01T00:00:01 Timestamp(Nanosecond, None) +2024-07-01T00:00:01Z 2024-07-01T00:00:01 Timestamp(Nanosecond, None) +2024-08-01T00:00:01Z 2024-08-01T00:00:01 Timestamp(Nanosecond, None) +2024-09-01T00:00:01Z 2024-09-01T00:00:01 Timestamp(Nanosecond, None) +2024-10-01T00:00:01Z 2024-10-01T00:00:01 Timestamp(Nanosecond, None) +2024-11-01T00:00:01Z 2024-11-01T00:00:01 Timestamp(Nanosecond, None) +2024-12-01T00:00:01Z 2024-12-01T00:00:01 Timestamp(Nanosecond, None) + +query PPT +select column1, to_local_time(column1), arrow_typeof(to_local_time(column1)) from t_utc; +---- +2024-01-01T00:00:01Z 2024-01-01T00:00:01 Timestamp(Nanosecond, None) +2024-02-01T00:00:01Z 2024-02-01T00:00:01 Timestamp(Nanosecond, 
None) +2024-03-01T00:00:01Z 2024-03-01T00:00:01 Timestamp(Nanosecond, None) +2024-04-01T00:00:01Z 2024-04-01T00:00:01 Timestamp(Nanosecond, None) +2024-05-01T00:00:01Z 2024-05-01T00:00:01 Timestamp(Nanosecond, None) +2024-06-01T00:00:01Z 2024-06-01T00:00:01 Timestamp(Nanosecond, None) +2024-07-01T00:00:01Z 2024-07-01T00:00:01 Timestamp(Nanosecond, None) +2024-08-01T00:00:01Z 2024-08-01T00:00:01 Timestamp(Nanosecond, None) +2024-09-01T00:00:01Z 2024-09-01T00:00:01 Timestamp(Nanosecond, None) +2024-10-01T00:00:01Z 2024-10-01T00:00:01 Timestamp(Nanosecond, None) +2024-11-01T00:00:01Z 2024-11-01T00:00:01 Timestamp(Nanosecond, None) +2024-12-01T00:00:01Z 2024-12-01T00:00:01 Timestamp(Nanosecond, None) + +query PPT +select column1, to_local_time(column1), arrow_typeof(to_local_time(column1)) from t_timezone; +---- +2024-01-01T00:00:01+01:00 2024-01-01T00:00:01 Timestamp(Nanosecond, None) +2024-02-01T00:00:01+01:00 2024-02-01T00:00:01 Timestamp(Nanosecond, None) +2024-03-01T00:00:01+01:00 2024-03-01T00:00:01 Timestamp(Nanosecond, None) +2024-04-01T00:00:01+02:00 2024-04-01T00:00:01 Timestamp(Nanosecond, None) +2024-05-01T00:00:01+02:00 2024-05-01T00:00:01 Timestamp(Nanosecond, None) +2024-06-01T00:00:01+02:00 2024-06-01T00:00:01 Timestamp(Nanosecond, None) +2024-07-01T00:00:01+02:00 2024-07-01T00:00:01 Timestamp(Nanosecond, None) +2024-08-01T00:00:01+02:00 2024-08-01T00:00:01 Timestamp(Nanosecond, None) +2024-09-01T00:00:01+02:00 2024-09-01T00:00:01 Timestamp(Nanosecond, None) +2024-10-01T00:00:01+02:00 2024-10-01T00:00:01 Timestamp(Nanosecond, None) +2024-11-01T00:00:01+01:00 2024-11-01T00:00:01 Timestamp(Nanosecond, None) +2024-12-01T00:00:01+01:00 2024-12-01T00:00:01 Timestamp(Nanosecond, None) + +# combine to_local_time() with date_bin() +query P +select date_bin(interval '1 day', to_local_time(column1)) AT TIME ZONE 'Europe/Brussels' as date_bin from t_utc; +---- +2024-01-01T00:00:00+01:00 +2024-02-01T00:00:00+01:00 +2024-03-01T00:00:00+01:00 +2024-04-01T00:00:00+02:00 +2024-05-01T00:00:00+02:00 +2024-06-01T00:00:00+02:00 +2024-07-01T00:00:00+02:00 +2024-08-01T00:00:00+02:00 +2024-09-01T00:00:00+02:00 +2024-10-01T00:00:00+02:00 +2024-11-01T00:00:00+01:00 +2024-12-01T00:00:00+01:00 + +query P +select date_bin(interval '1 day', to_local_time(column1)) AT TIME ZONE 'Europe/Brussels' as date_bin from t_timezone; +---- +2024-01-01T00:00:00+01:00 +2024-02-01T00:00:00+01:00 +2024-03-01T00:00:00+01:00 +2024-04-01T00:00:00+02:00 +2024-05-01T00:00:00+02:00 +2024-06-01T00:00:00+02:00 +2024-07-01T00:00:00+02:00 +2024-08-01T00:00:00+02:00 +2024-09-01T00:00:00+02:00 +2024-10-01T00:00:00+02:00 +2024-11-01T00:00:00+01:00 +2024-12-01T00:00:00+01:00 + +statement ok +drop table t; + +statement ok +drop view t_utc; + +statement ok +drop view t_timezone; From 1e9f0e1d650f0549e6a8f7d6971b7373fae5199c Mon Sep 17 00:00:00 2001 From: Mohamed Abdeen <83442793+MohamedAbdeen21@users.noreply.github.com> Date: Thu, 11 Jul 2024 19:20:10 +0300 Subject: [PATCH 14/59] Implement prettier SQL unparsing (more human readable) (#11186) * initial prettier unparse * bug fix * handling minus and divide * cleaning references and comments * moved tests * Update precedence of BETWEEN * rerun CI * Change precedence to match PGSQLs * more pretty unparser tests * Update operator precedence to match latest PGSQL * directly prettify expr_to_sql * handle IS operator * correct IS precedence * update unparser tests * update unparser example * update more unparser examples * add with_pretty builder to unparser --- .../examples/parse_sql_expr.rs 
| 9 + datafusion-examples/examples/plan_to_sql.rs | 18 +- datafusion/expr/src/operator.rs | 24 +- datafusion/sql/src/unparser/expr.rs | 230 ++++++++++++++---- datafusion/sql/src/unparser/mod.rs | 15 +- datafusion/sql/tests/cases/plan_to_sql.rs | 99 +++++++- 6 files changed, 319 insertions(+), 76 deletions(-) diff --git a/datafusion-examples/examples/parse_sql_expr.rs b/datafusion-examples/examples/parse_sql_expr.rs index a1fc5d269a043..e23e5accae397 100644 --- a/datafusion-examples/examples/parse_sql_expr.rs +++ b/datafusion-examples/examples/parse_sql_expr.rs @@ -153,5 +153,14 @@ async fn round_trip_parse_sql_expr_demo() -> Result<()> { assert_eq!(sql, round_trip_sql); + // enable pretty-unparsing. This make the output more human-readable + // but can be problematic when passed to other SQL engines due to + // difference in precedence rules between DataFusion and target engines. + let unparser = Unparser::default().with_pretty(true); + + let pretty = "int_col < 5 OR double_col = 8"; + let pretty_round_trip_sql = unparser.expr_to_sql(&parsed_expr)?.to_string(); + assert_eq!(pretty, pretty_round_trip_sql); + Ok(()) } diff --git a/datafusion-examples/examples/plan_to_sql.rs b/datafusion-examples/examples/plan_to_sql.rs index bd708fe52bc15..f719a33fb6249 100644 --- a/datafusion-examples/examples/plan_to_sql.rs +++ b/datafusion-examples/examples/plan_to_sql.rs @@ -31,9 +31,9 @@ use datafusion_sql::unparser::{plan_to_sql, Unparser}; /// 1. [`simple_expr_to_sql_demo`]: Create a simple expression [`Exprs`] with /// fluent API and convert to sql suitable for passing to another database /// -/// 2. [`simple_expr_to_sql_demo_no_escape`] Create a simple expression -/// [`Exprs`] with fluent API and convert to sql without escaping column names -/// more suitable for displaying to humans. +/// 2. [`simple_expr_to_pretty_sql_demo`] Create a simple expression +/// [`Exprs`] with fluent API and convert to sql without extra parentheses, +/// suitable for displaying to humans /// /// 3. [`simple_expr_to_sql_demo_escape_mysql_style`]" Create a simple /// expression [`Exprs`] with fluent API and convert to sql escaping column @@ -49,6 +49,7 @@ use datafusion_sql::unparser::{plan_to_sql, Unparser}; async fn main() -> Result<()> { // See how to evaluate expressions simple_expr_to_sql_demo()?; + simple_expr_to_pretty_sql_demo()?; simple_expr_to_sql_demo_escape_mysql_style()?; simple_plan_to_sql_demo().await?; round_trip_plan_to_sql_demo().await?; @@ -64,6 +65,17 @@ fn simple_expr_to_sql_demo() -> Result<()> { Ok(()) } +/// DataFusioon can remove parentheses when converting an expression to SQL. +/// Note that output is intended for humans, not for other SQL engines, +/// as difference in precedence rules can cause expressions to be parsed differently. 
+fn simple_expr_to_pretty_sql_demo() -> Result<()> { + let expr = col("a").lt(lit(5)).or(col("a").eq(lit(8))); + let unparser = Unparser::default().with_pretty(true); + let sql = unparser.expr_to_sql(&expr)?.to_string(); + assert_eq!(sql, r#"a < 5 OR a = 8"#); + Ok(()) +} + /// DataFusion can convert expressions to SQL without escaping column names using /// using a custom dialect and an explicit unparser fn simple_expr_to_sql_demo_escape_mysql_style() -> Result<()> { diff --git a/datafusion/expr/src/operator.rs b/datafusion/expr/src/operator.rs index a10312e234460..9bb8c48d6c71f 100644 --- a/datafusion/expr/src/operator.rs +++ b/datafusion/expr/src/operator.rs @@ -218,29 +218,23 @@ impl Operator { } /// Get the operator precedence - /// use as a reference + /// use as a reference pub fn precedence(&self) -> u8 { match self { Operator::Or => 5, Operator::And => 10, - Operator::NotEq - | Operator::Eq - | Operator::Lt - | Operator::LtEq - | Operator::Gt - | Operator::GtEq => 20, - Operator::Plus | Operator::Minus => 30, - Operator::Multiply | Operator::Divide | Operator::Modulo => 40, + Operator::Eq | Operator::NotEq | Operator::LtEq | Operator::GtEq => 15, + Operator::Lt | Operator::Gt => 20, + Operator::LikeMatch + | Operator::NotLikeMatch + | Operator::ILikeMatch + | Operator::NotILikeMatch => 25, Operator::IsDistinctFrom | Operator::IsNotDistinctFrom | Operator::RegexMatch | Operator::RegexNotMatch | Operator::RegexIMatch | Operator::RegexNotIMatch - | Operator::LikeMatch - | Operator::ILikeMatch - | Operator::NotLikeMatch - | Operator::NotILikeMatch | Operator::BitwiseAnd | Operator::BitwiseOr | Operator::BitwiseShiftLeft @@ -248,7 +242,9 @@ impl Operator { | Operator::BitwiseXor | Operator::StringConcat | Operator::AtArrow - | Operator::ArrowAt => 0, + | Operator::ArrowAt => 30, + Operator::Plus | Operator::Minus => 40, + Operator::Multiply | Operator::Divide | Operator::Modulo => 45, } } } diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 198186934c84b..e0d05c400cb09 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -30,8 +30,8 @@ use arrow_array::{Date32Array, Date64Array, PrimitiveArray}; use arrow_schema::DataType; use sqlparser::ast::Value::SingleQuotedString; use sqlparser::ast::{ - self, Expr as AstExpr, Function, FunctionArg, Ident, Interval, TimezoneInfo, - UnaryOperator, + self, BinaryOperator, Expr as AstExpr, Function, FunctionArg, Ident, Interval, + TimezoneInfo, UnaryOperator, }; use datafusion_common::{ @@ -101,8 +101,21 @@ pub fn expr_to_unparsed(expr: &Expr) -> Result { unparser.expr_to_unparsed(expr) } +const LOWEST: &BinaryOperator = &BinaryOperator::Or; +// closest precedence we have to IS operator is BitwiseAnd (any other) in PG docs +// (https://www.postgresql.org/docs/7.2/sql-precedence.html) +const IS: &BinaryOperator = &BinaryOperator::BitwiseAnd; + impl Unparser<'_> { pub fn expr_to_sql(&self, expr: &Expr) -> Result { + let mut root_expr = self.expr_to_sql_inner(expr)?; + if self.pretty { + root_expr = self.remove_unnecessary_nesting(root_expr, LOWEST, LOWEST); + } + Ok(root_expr) + } + + fn expr_to_sql_inner(&self, expr: &Expr) -> Result { match expr { Expr::InList(InList { expr, @@ -111,10 +124,10 @@ impl Unparser<'_> { }) => { let list_expr = list .iter() - .map(|e| self.expr_to_sql(e)) + .map(|e| self.expr_to_sql_inner(e)) .collect::>>()?; Ok(ast::Expr::InList { - expr: Box::new(self.expr_to_sql(expr)?), + expr: Box::new(self.expr_to_sql_inner(expr)?), list: list_expr, 
negated: *negated, }) @@ -128,7 +141,7 @@ impl Unparser<'_> { if matches!(e, Expr::Wildcard { qualifier: None }) { Ok(FunctionArg::Unnamed(ast::FunctionArgExpr::Wildcard)) } else { - self.expr_to_sql(e).map(|e| { + self.expr_to_sql_inner(e).map(|e| { FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(e)) }) } @@ -157,9 +170,9 @@ impl Unparser<'_> { low, high, }) => { - let sql_parser_expr = self.expr_to_sql(expr)?; - let sql_low = self.expr_to_sql(low)?; - let sql_high = self.expr_to_sql(high)?; + let sql_parser_expr = self.expr_to_sql_inner(expr)?; + let sql_low = self.expr_to_sql_inner(low)?; + let sql_high = self.expr_to_sql_inner(high)?; Ok(ast::Expr::Nested(Box::new(self.between_op_to_sql( sql_parser_expr, *negated, @@ -169,8 +182,8 @@ impl Unparser<'_> { } Expr::Column(col) => self.col_to_sql(col), Expr::BinaryExpr(BinaryExpr { left, op, right }) => { - let l = self.expr_to_sql(left.as_ref())?; - let r = self.expr_to_sql(right.as_ref())?; + let l = self.expr_to_sql_inner(left.as_ref())?; + let r = self.expr_to_sql_inner(right.as_ref())?; let op = self.op_to_sql(op)?; Ok(ast::Expr::Nested(Box::new(self.binary_op_to_sql(l, r, op)))) @@ -182,21 +195,21 @@ impl Unparser<'_> { }) => { let conditions = when_then_expr .iter() - .map(|(w, _)| self.expr_to_sql(w)) + .map(|(w, _)| self.expr_to_sql_inner(w)) .collect::>>()?; let results = when_then_expr .iter() - .map(|(_, t)| self.expr_to_sql(t)) + .map(|(_, t)| self.expr_to_sql_inner(t)) .collect::>>()?; let operand = match expr.as_ref() { - Some(e) => match self.expr_to_sql(e) { + Some(e) => match self.expr_to_sql_inner(e) { Ok(sql_expr) => Some(Box::new(sql_expr)), Err(_) => None, }, None => None, }; let else_result = match else_expr.as_ref() { - Some(e) => match self.expr_to_sql(e) { + Some(e) => match self.expr_to_sql_inner(e) { Ok(sql_expr) => Some(Box::new(sql_expr)), Err(_) => None, }, @@ -211,7 +224,7 @@ impl Unparser<'_> { }) } Expr::Cast(Cast { expr, data_type }) => { - let inner_expr = self.expr_to_sql(expr)?; + let inner_expr = self.expr_to_sql_inner(expr)?; Ok(ast::Expr::Cast { kind: ast::CastKind::Cast, expr: Box::new(inner_expr), @@ -220,7 +233,7 @@ impl Unparser<'_> { }) } Expr::Literal(value) => Ok(self.scalar_to_sql(value)?), - Expr::Alias(Alias { expr, name: _, .. }) => self.expr_to_sql(expr), + Expr::Alias(Alias { expr, name: _, .. 
}) => self.expr_to_sql_inner(expr), Expr::WindowFunction(WindowFunction { fun, args, @@ -255,7 +268,7 @@ impl Unparser<'_> { window_name: None, partition_by: partition_by .iter() - .map(|e| self.expr_to_sql(e)) + .map(|e| self.expr_to_sql_inner(e)) .collect::>>()?, order_by, window_frame: Some(ast::WindowFrame { @@ -296,8 +309,8 @@ impl Unparser<'_> { case_insensitive: _, }) => Ok(ast::Expr::Like { negated: *negated, - expr: Box::new(self.expr_to_sql(expr)?), - pattern: Box::new(self.expr_to_sql(pattern)?), + expr: Box::new(self.expr_to_sql_inner(expr)?), + pattern: Box::new(self.expr_to_sql_inner(pattern)?), escape_char: escape_char.map(|c| c.to_string()), }), Expr::AggregateFunction(agg) => { @@ -305,7 +318,7 @@ impl Unparser<'_> { let args = self.function_args_to_sql(&agg.args)?; let filter = match &agg.filter { - Some(filter) => Some(Box::new(self.expr_to_sql(filter)?)), + Some(filter) => Some(Box::new(self.expr_to_sql_inner(filter)?)), None => None, }; Ok(ast::Expr::Function(Function { @@ -339,7 +352,7 @@ impl Unparser<'_> { Ok(ast::Expr::Subquery(sub_query)) } Expr::InSubquery(insubq) => { - let inexpr = Box::new(self.expr_to_sql(insubq.expr.as_ref())?); + let inexpr = Box::new(self.expr_to_sql_inner(insubq.expr.as_ref())?); let sub_statement = self.plan_to_sql(insubq.subquery.subquery.as_ref())?; let sub_query = if let ast::Statement::Query(inner_query) = sub_statement @@ -377,38 +390,38 @@ impl Unparser<'_> { nulls_first: _, }) => plan_err!("Sort expression should be handled by expr_to_unparsed"), Expr::IsNull(expr) => { - Ok(ast::Expr::IsNull(Box::new(self.expr_to_sql(expr)?))) - } - Expr::IsNotNull(expr) => { - Ok(ast::Expr::IsNotNull(Box::new(self.expr_to_sql(expr)?))) + Ok(ast::Expr::IsNull(Box::new(self.expr_to_sql_inner(expr)?))) } + Expr::IsNotNull(expr) => Ok(ast::Expr::IsNotNull(Box::new( + self.expr_to_sql_inner(expr)?, + ))), Expr::IsTrue(expr) => { - Ok(ast::Expr::IsTrue(Box::new(self.expr_to_sql(expr)?))) - } - Expr::IsNotTrue(expr) => { - Ok(ast::Expr::IsNotTrue(Box::new(self.expr_to_sql(expr)?))) + Ok(ast::Expr::IsTrue(Box::new(self.expr_to_sql_inner(expr)?))) } + Expr::IsNotTrue(expr) => Ok(ast::Expr::IsNotTrue(Box::new( + self.expr_to_sql_inner(expr)?, + ))), Expr::IsFalse(expr) => { - Ok(ast::Expr::IsFalse(Box::new(self.expr_to_sql(expr)?))) - } - Expr::IsNotFalse(expr) => { - Ok(ast::Expr::IsNotFalse(Box::new(self.expr_to_sql(expr)?))) - } - Expr::IsUnknown(expr) => { - Ok(ast::Expr::IsUnknown(Box::new(self.expr_to_sql(expr)?))) - } - Expr::IsNotUnknown(expr) => { - Ok(ast::Expr::IsNotUnknown(Box::new(self.expr_to_sql(expr)?))) - } + Ok(ast::Expr::IsFalse(Box::new(self.expr_to_sql_inner(expr)?))) + } + Expr::IsNotFalse(expr) => Ok(ast::Expr::IsNotFalse(Box::new( + self.expr_to_sql_inner(expr)?, + ))), + Expr::IsUnknown(expr) => Ok(ast::Expr::IsUnknown(Box::new( + self.expr_to_sql_inner(expr)?, + ))), + Expr::IsNotUnknown(expr) => Ok(ast::Expr::IsNotUnknown(Box::new( + self.expr_to_sql_inner(expr)?, + ))), Expr::Not(expr) => { - let sql_parser_expr = self.expr_to_sql(expr)?; + let sql_parser_expr = self.expr_to_sql_inner(expr)?; Ok(AstExpr::UnaryOp { op: UnaryOperator::Not, expr: Box::new(sql_parser_expr), }) } Expr::Negative(expr) => { - let sql_parser_expr = self.expr_to_sql(expr)?; + let sql_parser_expr = self.expr_to_sql_inner(expr)?; Ok(AstExpr::UnaryOp { op: UnaryOperator::Minus, expr: Box::new(sql_parser_expr), @@ -432,7 +445,7 @@ impl Unparser<'_> { }) } Expr::TryCast(TryCast { expr, data_type }) => { - let inner_expr = self.expr_to_sql(expr)?; + let 
inner_expr = self.expr_to_sql_inner(expr)?; Ok(ast::Expr::Cast { kind: ast::CastKind::TryCast, expr: Box::new(inner_expr), @@ -449,7 +462,7 @@ impl Unparser<'_> { .iter() .map(|set| { set.iter() - .map(|e| self.expr_to_sql(e)) + .map(|e| self.expr_to_sql_inner(e)) .collect::>>() }) .collect::>>()?; @@ -460,7 +473,7 @@ impl Unparser<'_> { let expr_ast_sets = cube .iter() .map(|e| { - let sql = self.expr_to_sql(e)?; + let sql = self.expr_to_sql_inner(e)?; Ok(vec![sql]) }) .collect::>>()?; @@ -470,7 +483,7 @@ impl Unparser<'_> { let expr_ast_sets: Vec> = rollup .iter() .map(|e| { - let sql = self.expr_to_sql(e)?; + let sql = self.expr_to_sql_inner(e)?; Ok(vec![sql]) }) .collect::>>()?; @@ -603,6 +616,88 @@ impl Unparser<'_> { } } + /// Given an expression of the form `((a + b) * (c * d))`, + /// the parenthesing is redundant if the precedence of the nested expression is already higher + /// than the surrounding operators' precedence. The above expression would become + /// `(a + b) * c * d`. + /// + /// Also note that when fetching the precedence of a nested expression, we ignore other nested + /// expressions, so precedence of expr `(a * (b + c))` equals `*` and not `+`. + fn remove_unnecessary_nesting( + &self, + expr: ast::Expr, + left_op: &BinaryOperator, + right_op: &BinaryOperator, + ) -> ast::Expr { + match expr { + ast::Expr::Nested(nested) => { + let surrounding_precedence = self + .sql_op_precedence(left_op) + .max(self.sql_op_precedence(right_op)); + + let inner_precedence = self.inner_precedence(&nested); + + let not_associative = + matches!(left_op, BinaryOperator::Minus | BinaryOperator::Divide); + + if inner_precedence == surrounding_precedence && not_associative { + ast::Expr::Nested(Box::new( + self.remove_unnecessary_nesting(*nested, LOWEST, LOWEST), + )) + } else if inner_precedence >= surrounding_precedence { + self.remove_unnecessary_nesting(*nested, left_op, right_op) + } else { + ast::Expr::Nested(Box::new( + self.remove_unnecessary_nesting(*nested, LOWEST, LOWEST), + )) + } + } + ast::Expr::BinaryOp { left, op, right } => ast::Expr::BinaryOp { + left: Box::new(self.remove_unnecessary_nesting(*left, left_op, &op)), + right: Box::new(self.remove_unnecessary_nesting(*right, &op, right_op)), + op, + }, + ast::Expr::IsTrue(expr) => ast::Expr::IsTrue(Box::new( + self.remove_unnecessary_nesting(*expr, left_op, IS), + )), + ast::Expr::IsNotTrue(expr) => ast::Expr::IsNotTrue(Box::new( + self.remove_unnecessary_nesting(*expr, left_op, IS), + )), + ast::Expr::IsFalse(expr) => ast::Expr::IsFalse(Box::new( + self.remove_unnecessary_nesting(*expr, left_op, IS), + )), + ast::Expr::IsNotFalse(expr) => ast::Expr::IsNotFalse(Box::new( + self.remove_unnecessary_nesting(*expr, left_op, IS), + )), + ast::Expr::IsNull(expr) => ast::Expr::IsNull(Box::new( + self.remove_unnecessary_nesting(*expr, left_op, IS), + )), + ast::Expr::IsNotNull(expr) => ast::Expr::IsNotNull(Box::new( + self.remove_unnecessary_nesting(*expr, left_op, IS), + )), + ast::Expr::IsUnknown(expr) => ast::Expr::IsUnknown(Box::new( + self.remove_unnecessary_nesting(*expr, left_op, IS), + )), + ast::Expr::IsNotUnknown(expr) => ast::Expr::IsNotUnknown(Box::new( + self.remove_unnecessary_nesting(*expr, left_op, IS), + )), + _ => expr, + } + } + + fn inner_precedence(&self, expr: &ast::Expr) -> u8 { + match expr { + ast::Expr::Nested(_) | ast::Expr::Identifier(_) | ast::Expr::Value(_) => 100, + ast::Expr::BinaryOp { op, .. 
} => self.sql_op_precedence(op), + // closest precedence we currently have to Between is PGLikeMatch + // (https://www.postgresql.org/docs/7.2/sql-precedence.html) + ast::Expr::Between { .. } => { + self.sql_op_precedence(&ast::BinaryOperator::PGLikeMatch) + } + _ => 0, + } + } + pub(super) fn between_op_to_sql( &self, expr: ast::Expr, @@ -618,6 +713,48 @@ impl Unparser<'_> { } } + fn sql_op_precedence(&self, op: &BinaryOperator) -> u8 { + match self.sql_to_op(op) { + Ok(op) => op.precedence(), + Err(_) => 0, + } + } + + fn sql_to_op(&self, op: &BinaryOperator) -> Result { + match op { + ast::BinaryOperator::Eq => Ok(Operator::Eq), + ast::BinaryOperator::NotEq => Ok(Operator::NotEq), + ast::BinaryOperator::Lt => Ok(Operator::Lt), + ast::BinaryOperator::LtEq => Ok(Operator::LtEq), + ast::BinaryOperator::Gt => Ok(Operator::Gt), + ast::BinaryOperator::GtEq => Ok(Operator::GtEq), + ast::BinaryOperator::Plus => Ok(Operator::Plus), + ast::BinaryOperator::Minus => Ok(Operator::Minus), + ast::BinaryOperator::Multiply => Ok(Operator::Multiply), + ast::BinaryOperator::Divide => Ok(Operator::Divide), + ast::BinaryOperator::Modulo => Ok(Operator::Modulo), + ast::BinaryOperator::And => Ok(Operator::And), + ast::BinaryOperator::Or => Ok(Operator::Or), + ast::BinaryOperator::PGRegexMatch => Ok(Operator::RegexMatch), + ast::BinaryOperator::PGRegexIMatch => Ok(Operator::RegexIMatch), + ast::BinaryOperator::PGRegexNotMatch => Ok(Operator::RegexNotMatch), + ast::BinaryOperator::PGRegexNotIMatch => Ok(Operator::RegexNotIMatch), + ast::BinaryOperator::PGILikeMatch => Ok(Operator::ILikeMatch), + ast::BinaryOperator::PGNotLikeMatch => Ok(Operator::NotLikeMatch), + ast::BinaryOperator::PGLikeMatch => Ok(Operator::LikeMatch), + ast::BinaryOperator::PGNotILikeMatch => Ok(Operator::NotILikeMatch), + ast::BinaryOperator::BitwiseAnd => Ok(Operator::BitwiseAnd), + ast::BinaryOperator::BitwiseOr => Ok(Operator::BitwiseOr), + ast::BinaryOperator::BitwiseXor => Ok(Operator::BitwiseXor), + ast::BinaryOperator::PGBitwiseShiftRight => Ok(Operator::BitwiseShiftRight), + ast::BinaryOperator::PGBitwiseShiftLeft => Ok(Operator::BitwiseShiftLeft), + ast::BinaryOperator::StringConcat => Ok(Operator::StringConcat), + ast::BinaryOperator::AtArrow => Ok(Operator::AtArrow), + ast::BinaryOperator::ArrowAt => Ok(Operator::ArrowAt), + _ => not_impl_err!("unsupported operation: {op:?}"), + } + } + fn op_to_sql(&self, op: &Operator) -> Result { match op { Operator::Eq => Ok(ast::BinaryOperator::Eq), @@ -1538,6 +1675,7 @@ mod tests { Ok(()) } + #[test] fn custom_dialect() -> Result<()> { let dialect = CustomDialect::new(Some('\'')); diff --git a/datafusion/sql/src/unparser/mod.rs b/datafusion/sql/src/unparser/mod.rs index fbbed4972b173..e5ffbc8a212ab 100644 --- a/datafusion/sql/src/unparser/mod.rs +++ b/datafusion/sql/src/unparser/mod.rs @@ -29,11 +29,23 @@ pub mod dialect; pub struct Unparser<'a> { dialect: &'a dyn Dialect, + pretty: bool, } impl<'a> Unparser<'a> { pub fn new(dialect: &'a dyn Dialect) -> Self { - Self { dialect } + Self { + dialect, + pretty: false, + } + } + + /// Allow unparser to remove parenthesis according to the precedence rules of DataFusion. + /// This might make it invalid SQL for other SQL query engines with different precedence + /// rules, even if its valid for DataFusion. 
+ pub fn with_pretty(mut self, pretty: bool) -> Self { + self.pretty = pretty; + self } } @@ -41,6 +53,7 @@ impl<'a> Default for Unparser<'a> { fn default() -> Self { Self { dialect: &DefaultDialect {}, + pretty: false, } } } diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index 374403d853f92..91295b2e8aae9 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -104,26 +104,26 @@ fn roundtrip_statement() -> Result<()> { "select id, count(*) as cnt from (select p1.id as id from person p1 inner join person p2 on p1.id=p2.id) group by id", "select id, count(*), first_name from person group by first_name, id", "select id, sum(age), first_name from person group by first_name, id", - "select id, count(*), first_name - from person + "select id, count(*), first_name + from person where id!=3 and first_name=='test' - group by first_name, id + group by first_name, id having count(*)>5 and count(*)<10 order by count(*)", - r#"select id, count("First Name") as count_first_name, "Last Name" + r#"select id, count("First Name") as count_first_name, "Last Name" from person_quoted_cols where id!=3 and "First Name"=='test' - group by "Last Name", id + group by "Last Name", id having count_first_name>5 and count_first_name<10 order by count_first_name, "Last Name""#, r#"select p.id, count("First Name") as count_first_name, - "Last Name", sum(qp.id/p.id - (select sum(id) from person_quoted_cols) ) / (select count(*) from person) + "Last Name", sum(qp.id/p.id - (select sum(id) from person_quoted_cols) ) / (select count(*) from person) from (select id, "First Name", "Last Name" from person_quoted_cols) qp inner join (select * from person) p on p.id = qp.id - where p.id!=3 and "First Name"=='test' and qp.id in + where p.id!=3 and "First Name"=='test' and qp.id in (select id from (select id, count(*) from person group by id having count(*) > 0)) - group by "Last Name", p.id + group by "Last Name", p.id having count_first_name>5 and count_first_name<10 order by count_first_name, "Last Name""#, r#"SELECT j1_string as string FROM j1 @@ -134,12 +134,12 @@ fn roundtrip_statement() -> Result<()> { SELECT j2_string as string FROM j2 ORDER BY string DESC LIMIT 10"#, - "SELECT id, count(*) over (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), - last_name, sum(id) over (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), + "SELECT id, count(*) over (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), + last_name, sum(id) over (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), first_name from person", - r#"SELECT id, count(distinct id) over (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), + r#"SELECT id, count(distinct id) over (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), sum(id) OVER (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) from person"#, - "SELECT id, sum(id) OVER (PARTITION BY first_name ROWS BETWEEN 5 PRECEDING AND 2 FOLLOWING) from person", + "SELECT id, sum(id) OVER (PARTITION BY first_name ROWS BETWEEN 5 PRECEDING AND 2 FOLLOWING) from person", ]; // For each test sql string, we transform as follows: @@ -314,3 +314,78 @@ fn test_table_references_in_plan_to_sql() { "SELECT \"table\".id, \"table\".\"value\" FROM \"table\"", ); } + +#[test] +fn test_pretty_roundtrip() -> Result<()> { + let schema = Schema::new(vec![ + 
Field::new("id", DataType::Utf8, false), + Field::new("age", DataType::Utf8, false), + ]); + + let df_schema = DFSchema::try_from(schema)?; + + let context = MockContextProvider::default(); + let sql_to_rel = SqlToRel::new(&context); + + let unparser = Unparser::default().with_pretty(true); + + let sql_to_pretty_unparse = vec![ + ("((id < 5) OR (age = 8))", "id < 5 OR age = 8"), + ("((id + 5) * (age * 8))", "(id + 5) * age * 8"), + ("(3 + (5 * 6) * 3)", "3 + 5 * 6 * 3"), + ("((3 * (5 + 6)) * 3)", "3 * (5 + 6) * 3"), + ("((3 AND (5 OR 6)) * 3)", "(3 AND (5 OR 6)) * 3"), + ("((3 + (5 + 6)) * 3)", "(3 + 5 + 6) * 3"), + ("((3 + (5 + 6)) + 3)", "3 + 5 + 6 + 3"), + ("3 + 5 + 6 + 3", "3 + 5 + 6 + 3"), + ("3 + (5 + (6 + 3))", "3 + 5 + 6 + 3"), + ("3 + ((5 + 6) + 3)", "3 + 5 + 6 + 3"), + ("(3 + 5) + (6 + 3)", "3 + 5 + 6 + 3"), + ("((3 + 5) + (6 + 3))", "3 + 5 + 6 + 3"), + ( + "((id > 10) OR (age BETWEEN 10 AND 20))", + "id > 10 OR age BETWEEN 10 AND 20", + ), + ( + "((id > 10) * (age BETWEEN 10 AND 20))", + "(id > 10) * (age BETWEEN 10 AND 20)", + ), + ("id - (age - 8)", "id - (age - 8)"), + ("((id - age) - 8)", "id - age - 8"), + ("(id OR (age - 8))", "id OR age - 8"), + ("(id / (age - 8))", "id / (age - 8)"), + ("((id / age) * 8)", "id / age * 8"), + ("((age + 10) < 20) IS TRUE", "(age + 10 < 20) IS TRUE"), + ( + "(20 > (age + 5)) IS NOT FALSE", + "(20 > age + 5) IS NOT FALSE", + ), + ("(true AND false) IS FALSE", "(true AND false) IS FALSE"), + ("true AND (false IS FALSE)", "true AND false IS FALSE"), + ]; + + for (sql, pretty) in sql_to_pretty_unparse.iter() { + let sql_expr = Parser::new(&GenericDialect {}) + .try_with_sql(sql)? + .parse_expr()?; + let expr = + sql_to_rel.sql_to_expr(sql_expr, &df_schema, &mut PlannerContext::new())?; + let round_trip_sql = unparser.expr_to_sql(&expr)?.to_string(); + assert_eq!(pretty.to_string(), round_trip_sql); + + // verify that the pretty string parses to the same underlying Expr + let pretty_sql_expr = Parser::new(&GenericDialect {}) + .try_with_sql(pretty)? 
+ .parse_expr()?; + + let pretty_expr = sql_to_rel.sql_to_expr( + pretty_sql_expr, + &df_schema, + &mut PlannerContext::new(), + )?; + + assert_eq!(expr.to_string(), pretty_expr.to_string()); + } + + Ok(()) +} From e19dd2d0b91f30b97fd68da894137987c1318b18 Mon Sep 17 00:00:00 2001 From: Chunchun Ye <14298407+appletreeisyellow@users.noreply.github.com> Date: Thu, 11 Jul 2024 11:21:51 -0500 Subject: [PATCH 15/59] Add `to_local_time()` in function reference docs (#11401) * chore: add document for `to_local_time()` * chore: feedback Co-authored-by: Andrew Lamb --------- Co-authored-by: Andrew Lamb --- .../source/user-guide/sql/scalar_functions.md | 65 ++++++++++++++++++- 1 file changed, 64 insertions(+), 1 deletion(-) diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index d636726b45fe1..d2e012cf4093d 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -1480,6 +1480,7 @@ contains(string, search_string) - [make_date](#make_date) - [to_char](#to_char) - [to_date](#to_date) +- [to_local_time](#to_local_time) - [to_timestamp](#to_timestamp) - [to_timestamp_millis](#to_timestamp_millis) - [to_timestamp_micros](#to_timestamp_micros) @@ -1710,7 +1711,7 @@ to_char(expression, format) #### Example ``` -> > select to_char('2023-03-01'::date, '%d-%m-%Y'); +> select to_char('2023-03-01'::date, '%d-%m-%Y'); +----------------------------------------------+ | to_char(Utf8("2023-03-01"),Utf8("%d-%m-%Y")) | +----------------------------------------------+ @@ -1771,6 +1772,68 @@ to_date(expression[, ..., format_n]) Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/to_date.rs) +### `to_local_time` + +Converts a timestamp with a timezone to a timestamp without a timezone (with no offset or +timezone information). This function handles daylight saving time changes. + +``` +to_local_time(expression) +``` + +#### Arguments + +- **expression**: Time expression to operate on. Can be a constant, column, or function. 
+ +#### Example + +``` +> SELECT to_local_time('2024-04-01T00:00:20Z'::timestamp); ++---------------------------------------------+ +| to_local_time(Utf8("2024-04-01T00:00:20Z")) | ++---------------------------------------------+ +| 2024-04-01T00:00:20 | ++---------------------------------------------+ + +> SELECT to_local_time('2024-04-01T00:00:20Z'::timestamp AT TIME ZONE 'Europe/Brussels'); ++---------------------------------------------+ +| to_local_time(Utf8("2024-04-01T00:00:20Z")) | ++---------------------------------------------+ +| 2024-04-01T00:00:20 | ++---------------------------------------------+ + +> SELECT + time, + arrow_typeof(time) as type, + to_local_time(time) as to_local_time, + arrow_typeof(to_local_time(time)) as to_local_time_type +FROM ( + SELECT '2024-04-01T00:00:20Z'::timestamp AT TIME ZONE 'Europe/Brussels' AS time +); ++---------------------------+------------------------------------------------+---------------------+-----------------------------+ +| time | type | to_local_time | to_local_time_type | ++---------------------------+------------------------------------------------+---------------------+-----------------------------+ +| 2024-04-01T00:00:20+02:00 | Timestamp(Nanosecond, Some("Europe/Brussels")) | 2024-04-01T00:00:20 | Timestamp(Nanosecond, None) | ++---------------------------+------------------------------------------------+---------------------+-----------------------------+ + +# combine `to_local_time()` with `date_bin()` to bin on boundaries in the timezone rather +# than UTC boundaries + +> SELECT date_bin(interval '1 day', to_local_time('2024-04-01T00:00:20Z'::timestamp AT TIME ZONE 'Europe/Brussels')) AS date_bin; ++---------------------+ +| date_bin | ++---------------------+ +| 2024-04-01T00:00:00 | ++---------------------+ + +> SELECT date_bin(interval '1 day', to_local_time('2024-04-01T00:00:20Z'::timestamp AT TIME ZONE 'Europe/Brussels')) AT TIME ZONE 'Europe/Brussels' AS date_bin_with_timezone; ++---------------------------+ +| date_bin_with_timezone | ++---------------------------+ +| 2024-04-01T00:00:00+02:00 | ++---------------------------+ +``` + ### `to_timestamp` Converts a value to a timestamp (`YYYY-MM-DDT00:00:00Z`). 
From 4402a1a9dd8ebec1640b2fa807781a2701407672 Mon Sep 17 00:00:00 2001 From: Dharan Aditya Date: Thu, 11 Jul 2024 21:52:06 +0530 Subject: [PATCH 16/59] Move `overlay` planning to`ExprPlanner` (#11398) * move overlay to expr planner * typo --- datafusion/expr/src/planner.rs | 7 ++++++ datafusion/functions/src/core/planner.rs | 6 +++++ datafusion/functions/src/string/mod.rs | 1 - datafusion/sql/src/expr/mod.rs | 28 ++++++++++++------------ 4 files changed, 27 insertions(+), 15 deletions(-) diff --git a/datafusion/expr/src/planner.rs b/datafusion/expr/src/planner.rs index aeb8ed8372b76..2f13923b1f10a 100644 --- a/datafusion/expr/src/planner.rs +++ b/datafusion/expr/src/planner.rs @@ -161,6 +161,13 @@ pub trait ExprPlanner: Send + Sync { ) -> Result>> { Ok(PlannerResult::Original(args)) } + + /// Plans an overlay expression eg `overlay(str PLACING substr FROM pos [FOR count])` + /// + /// Returns origin expression arguments if not possible + fn plan_overlay(&self, args: Vec) -> Result>> { + Ok(PlannerResult::Original(args)) + } } /// An operator with two arguments to plan diff --git a/datafusion/functions/src/core/planner.rs b/datafusion/functions/src/core/planner.rs index 748b598d292fe..63eaa9874c2b9 100644 --- a/datafusion/functions/src/core/planner.rs +++ b/datafusion/functions/src/core/planner.rs @@ -56,4 +56,10 @@ impl ExprPlanner for CoreFunctionPlanner { ), ))) } + + fn plan_overlay(&self, args: Vec) -> Result>> { + Ok(PlannerResult::Planned(Expr::ScalarFunction( + ScalarFunction::new_udf(crate::string::overlay(), args), + ))) + } } diff --git a/datafusion/functions/src/string/mod.rs b/datafusion/functions/src/string/mod.rs index 5bf372c29f2d5..9a19151a85e26 100644 --- a/datafusion/functions/src/string/mod.rs +++ b/datafusion/functions/src/string/mod.rs @@ -182,7 +182,6 @@ pub fn functions() -> Vec> { lower(), ltrim(), octet_length(), - overlay(), repeat(), replace(), rtrim(), diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index 859842e212be7..062ef805fd9f8 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -193,7 +193,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } } - not_impl_err!("Extract not supported by UserDefinedExtensionPlanners: {extract_args:?}") + not_impl_err!("Extract not supported by ExprPlanner: {extract_args:?}") } SQLExpr::Array(arr) => self.sql_array_literal(arr.elem, schema), @@ -292,7 +292,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } } - not_impl_err!("GetFieldAccess not supported by UserDefinedExtensionPlanners: {field_access_expr:?}") + not_impl_err!( + "GetFieldAccess not supported by ExprPlanner: {field_access_expr:?}" + ) } SQLExpr::CompoundIdentifier(ids) => { @@ -657,7 +659,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { PlannerResult::Original(args) => create_struct_args = args, } } - not_impl_err!("Struct not supported by UserDefinedExtensionPlanners: {create_struct_args:?}") + not_impl_err!("Struct not supported by ExprPlanner: {create_struct_args:?}") } fn sql_position_to_expr( @@ -680,9 +682,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } } - not_impl_err!( - "Position not supported by UserDefinedExtensionPlanners: {position_args:?}" - ) + not_impl_err!("Position not supported by ExprPlanner: {position_args:?}") } fn try_plan_dictionary_literal( @@ -914,18 +914,12 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { schema: &DFSchema, planner_context: &mut PlannerContext, ) -> Result { - let fun = self - .context_provider - .get_function_meta("overlay") - .ok_or_else(|| 
{ - internal_datafusion_err!("Unable to find expected 'overlay' function") - })?; let arg = self.sql_expr_to_logical_expr(expr, schema, planner_context)?; let what_arg = self.sql_expr_to_logical_expr(overlay_what, schema, planner_context)?; let from_arg = self.sql_expr_to_logical_expr(overlay_from, schema, planner_context)?; - let args = match overlay_for { + let mut overlay_args = match overlay_for { Some(for_expr) => { let for_expr = self.sql_expr_to_logical_expr(*for_expr, schema, planner_context)?; @@ -933,7 +927,13 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } None => vec![arg, what_arg, from_arg], }; - Ok(Expr::ScalarFunction(ScalarFunction::new_udf(fun, args))) + for planner in self.planners.iter() { + match planner.plan_overlay(overlay_args)? { + PlannerResult::Planned(expr) => return Ok(expr), + PlannerResult::Original(args) => overlay_args = args, + } + } + not_impl_err!("Overlay not supported by ExprPlanner: {overlay_args:?}") } } From d314ced8090cb599fd7808d7df41699e46ac956e Mon Sep 17 00:00:00 2001 From: Marko Grujic Date: Thu, 11 Jul 2024 18:22:20 +0200 Subject: [PATCH 17/59] Coerce types for all union children plans when eliminating nesting (#11386) --- .../optimizer/src/eliminate_nested_union.rs | 13 +++++++------ datafusion/sqllogictest/test_files/union.slt | 15 +++++++++++++++ 2 files changed, 22 insertions(+), 6 deletions(-) diff --git a/datafusion/optimizer/src/eliminate_nested_union.rs b/datafusion/optimizer/src/eliminate_nested_union.rs index c8ae937e128a6..cc8cf1f56c184 100644 --- a/datafusion/optimizer/src/eliminate_nested_union.rs +++ b/datafusion/optimizer/src/eliminate_nested_union.rs @@ -60,7 +60,8 @@ impl OptimizerRule for EliminateNestedUnion { let inputs = inputs .into_iter() .flat_map(extract_plans_from_union) - .collect::>(); + .map(|plan| coerce_plan_expr_for_schema(&plan, &schema)) + .collect::>>()?; Ok(Transformed::yes(LogicalPlan::Union(Union { inputs: inputs.into_iter().map(Arc::new).collect_vec(), @@ -74,7 +75,8 @@ impl OptimizerRule for EliminateNestedUnion { .into_iter() .map(extract_plan_from_distinct) .flat_map(extract_plans_from_union) - .collect::>(); + .map(|plan| coerce_plan_expr_for_schema(&plan, &schema)) + .collect::>>()?; Ok(Transformed::yes(LogicalPlan::Distinct(Distinct::All( Arc::new(LogicalPlan::Union(Union { @@ -95,10 +97,9 @@ impl OptimizerRule for EliminateNestedUnion { fn extract_plans_from_union(plan: Arc) -> Vec { match unwrap_arc(plan) { - LogicalPlan::Union(Union { inputs, schema }) => inputs - .into_iter() - .map(|plan| coerce_plan_expr_for_schema(&plan, &schema).unwrap()) - .collect::>(), + LogicalPlan::Union(Union { inputs, .. 
}) => { + inputs.into_iter().map(unwrap_arc).collect::>() + } plan => vec![plan], } } diff --git a/datafusion/sqllogictest/test_files/union.slt b/datafusion/sqllogictest/test_files/union.slt index 7b91e97e4a3e2..5ede68a42aae6 100644 --- a/datafusion/sqllogictest/test_files/union.slt +++ b/datafusion/sqllogictest/test_files/union.slt @@ -135,6 +135,21 @@ SELECT SUM(d) FROM ( ---- 5 +# three way union with aggregate and type coercion +query II rowsort +SELECT c1, SUM(c2) FROM ( + SELECT 1 as c1, 1::int as c2 + UNION + SELECT 2 as c1, 2::int as c2 + UNION + SELECT 3 as c1, COALESCE(3::int, 0) as c2 +) as a +GROUP BY c1 +---- +1 1 +2 2 +3 3 + # union_all_with_count statement ok CREATE table t as SELECT 1 as a From 4bed04e4e312a0b125306944aee94a93c2ff6c4f Mon Sep 17 00:00:00 2001 From: Georgi Krastev Date: Thu, 11 Jul 2024 19:26:46 +0300 Subject: [PATCH 18/59] Add customizable equality and hash functions to UDFs (#11392) * Add customizable equality and hash functions to UDFs * Improve equals and hash_value documentation * Add tests for parameterized UDFs --- .../user_defined/user_defined_aggregates.rs | 79 ++++++++++- .../user_defined_scalar_functions.rs | 128 +++++++++++++++++- datafusion/expr/src/udaf.rs | 73 ++++++++-- datafusion/expr/src/udf.rs | 62 +++++++-- datafusion/expr/src/udwf.rs | 69 ++++++++-- 5 files changed, 367 insertions(+), 44 deletions(-) diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs b/datafusion/core/tests/user_defined/user_defined_aggregates.rs index d591c662d8774..96de865b6554a 100644 --- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs +++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs @@ -18,14 +18,19 @@ //! This module contains end to end demonstrations of creating //! 
user defined aggregate functions -use arrow::{array::AsArray, datatypes::Fields}; -use arrow_array::{types::UInt64Type, Int32Array, PrimitiveArray, StructArray}; -use arrow_schema::Schema; +use std::hash::{DefaultHasher, Hash, Hasher}; use std::sync::{ atomic::{AtomicBool, Ordering}, Arc, }; +use arrow::{array::AsArray, datatypes::Fields}; +use arrow_array::{ + types::UInt64Type, Int32Array, PrimitiveArray, StringArray, StructArray, +}; +use arrow_schema::Schema; + +use datafusion::dataframe::DataFrame; use datafusion::datasource::MemTable; use datafusion::test_util::plan_and_collect; use datafusion::{ @@ -45,8 +50,8 @@ use datafusion::{ }; use datafusion_common::{assert_contains, cast::as_primitive_array, exec_err}; use datafusion_expr::{ - create_udaf, function::AccumulatorArgs, AggregateUDFImpl, GroupsAccumulator, - SimpleAggregateUDF, + col, create_udaf, function::AccumulatorArgs, AggregateUDFImpl, GroupsAccumulator, + LogicalPlanBuilder, SimpleAggregateUDF, }; use datafusion_functions_aggregate::average::AvgAccumulator; @@ -377,6 +382,55 @@ async fn test_groups_accumulator() -> Result<()> { Ok(()) } +#[tokio::test] +async fn test_parameterized_aggregate_udf() -> Result<()> { + let batch = RecordBatch::try_from_iter([( + "text", + Arc::new(StringArray::from(vec!["foo"])) as ArrayRef, + )])?; + + let ctx = SessionContext::new(); + ctx.register_batch("t", batch)?; + let t = ctx.table("t").await?; + let signature = Signature::exact(vec![DataType::Utf8], Volatility::Immutable); + let udf1 = AggregateUDF::from(TestGroupsAccumulator { + signature: signature.clone(), + result: 1, + }); + let udf2 = AggregateUDF::from(TestGroupsAccumulator { + signature: signature.clone(), + result: 2, + }); + + let plan = LogicalPlanBuilder::from(t.into_optimized_plan()?) + .aggregate( + [col("text")], + [ + udf1.call(vec![col("text")]).alias("a"), + udf2.call(vec![col("text")]).alias("b"), + ], + )? + .build()?; + + assert_eq!( + format!("{plan:?}"), + "Aggregate: groupBy=[[t.text]], aggr=[[geo_mean(t.text) AS a, geo_mean(t.text) AS b]]\n TableScan: t projection=[text]" + ); + + let actual = DataFrame::new(ctx.state(), plan).collect().await?; + let expected = [ + "+------+---+---+", + "| text | a | b |", + "+------+---+---+", + "| foo | 1 | 2 |", + "+------+---+---+", + ]; + assert_batches_eq!(expected, &actual); + + ctx.deregister_table("t")?; + Ok(()) +} + /// Returns an context with a table "t" and the "first" and "time_sum" /// aggregate functions registered. /// @@ -735,6 +789,21 @@ impl AggregateUDFImpl for TestGroupsAccumulator { ) -> Result> { Ok(Box::new(self.clone())) } + + fn equals(&self, other: &dyn AggregateUDFImpl) -> bool { + if let Some(other) = other.as_any().downcast_ref::() { + self.result == other.result && self.signature == other.signature + } else { + false + } + } + + fn hash_value(&self) -> u64 { + let hasher = &mut DefaultHasher::new(); + self.signature.hash(hasher); + self.result.hash(hasher); + hasher.finish() + } } impl Accumulator for TestGroupsAccumulator { diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index 1733068debb96..5847952ae6a61 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -16,11 +16,20 @@ // under the License. 
use std::any::Any; +use std::hash::{DefaultHasher, Hash, Hasher}; use std::sync::Arc; use arrow::compute::kernels::numeric::add; -use arrow_array::{ArrayRef, Float32Array, Float64Array, Int32Array, RecordBatch}; +use arrow_array::builder::BooleanBuilder; +use arrow_array::cast::AsArray; +use arrow_array::{ + Array, ArrayRef, Float32Array, Float64Array, Int32Array, RecordBatch, StringArray, +}; use arrow_schema::{DataType, Field, Schema}; +use parking_lot::Mutex; +use regex::Regex; +use sqlparser::ast::Ident; + use datafusion::execution::context::{FunctionFactory, RegisterFunction, SessionState}; use datafusion::prelude::*; use datafusion::{execution::registry::FunctionRegistry, test_util}; @@ -37,8 +46,6 @@ use datafusion_expr::{ Volatility, }; use datafusion_functions_array::range::range_udf; -use parking_lot::Mutex; -use sqlparser::ast::Ident; /// test that casting happens on udfs. /// c11 is f32, but `custom_sqrt` requires f64. Casting happens but the logical plan and @@ -1021,6 +1028,121 @@ async fn create_scalar_function_from_sql_statement_postgres_syntax() -> Result<( Ok(()) } +#[derive(Debug)] +struct MyRegexUdf { + signature: Signature, + regex: Regex, +} + +impl MyRegexUdf { + fn new(pattern: &str) -> Self { + Self { + signature: Signature::exact(vec![DataType::Utf8], Volatility::Immutable), + regex: Regex::new(pattern).expect("regex"), + } + } + + fn matches(&self, value: Option<&str>) -> Option { + Some(self.regex.is_match(value?)) + } +} + +impl ScalarUDFImpl for MyRegexUdf { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "regex_udf" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, args: &[DataType]) -> Result { + if matches!(args, [DataType::Utf8]) { + Ok(DataType::Boolean) + } else { + plan_err!("regex_udf only accepts a Utf8 argument") + } + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + match args { + [ColumnarValue::Scalar(ScalarValue::Utf8(value))] => { + Ok(ColumnarValue::Scalar(ScalarValue::Boolean( + self.matches(value.as_deref()), + ))) + } + [ColumnarValue::Array(values)] => { + let mut builder = BooleanBuilder::with_capacity(values.len()); + for value in values.as_string::() { + builder.append_option(self.matches(value)) + } + Ok(ColumnarValue::Array(Arc::new(builder.finish()))) + } + _ => exec_err!("regex_udf only accepts a Utf8 arguments"), + } + } + + fn equals(&self, other: &dyn ScalarUDFImpl) -> bool { + if let Some(other) = other.as_any().downcast_ref::() { + self.regex.as_str() == other.regex.as_str() + } else { + false + } + } + + fn hash_value(&self) -> u64 { + let hasher = &mut DefaultHasher::new(); + self.regex.as_str().hash(hasher); + hasher.finish() + } +} + +#[tokio::test] +async fn test_parameterized_scalar_udf() -> Result<()> { + let batch = RecordBatch::try_from_iter([( + "text", + Arc::new(StringArray::from(vec!["foo", "bar", "foobar", "barfoo"])) as ArrayRef, + )])?; + + let ctx = SessionContext::new(); + ctx.register_batch("t", batch)?; + let t = ctx.table("t").await?; + let foo_udf = ScalarUDF::from(MyRegexUdf::new("fo{2}")); + let bar_udf = ScalarUDF::from(MyRegexUdf::new("[Bb]ar")); + + let plan = LogicalPlanBuilder::from(t.into_optimized_plan()?) + .filter( + foo_udf + .call(vec![col("text")]) + .and(bar_udf.call(vec![col("text")])), + )? + .filter(col("text").is_not_null())? 
+ .build()?; + + assert_eq!( + format!("{plan:?}"), + "Filter: t.text IS NOT NULL\n Filter: regex_udf(t.text) AND regex_udf(t.text)\n TableScan: t projection=[text]" + ); + + let actual = DataFrame::new(ctx.state(), plan).collect().await?; + let expected = [ + "+--------+", + "| text |", + "+--------+", + "| foobar |", + "| barfoo |", + "+--------+", + ]; + assert_batches_eq!(expected, &actual); + + ctx.deregister_table("t")?; + Ok(()) +} + fn create_udf_context() -> SessionContext { let ctx = SessionContext::new(); // register a custom UDF diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index 7a054abea75b3..1657e034fbe2b 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -17,6 +17,17 @@ //! [`AggregateUDF`]: User Defined Aggregate Functions +use std::any::Any; +use std::fmt::{self, Debug, Formatter}; +use std::hash::{DefaultHasher, Hash, Hasher}; +use std::sync::Arc; +use std::vec; + +use arrow::datatypes::{DataType, Field}; +use sqlparser::ast::NullTreatment; + +use datafusion_common::{exec_err, not_impl_err, plan_err, Result}; + use crate::expr::AggregateFunction; use crate::function::{ AccumulatorArgs, AggregateFunctionSimplification, StateFieldsArgs, @@ -26,13 +37,6 @@ use crate::utils::format_state_name; use crate::utils::AggregateOrderSensitivity; use crate::{Accumulator, Expr}; use crate::{AccumulatorFactoryFunction, ReturnTypeFunction, Signature}; -use arrow::datatypes::{DataType, Field}; -use datafusion_common::{exec_err, not_impl_err, plan_err, Result}; -use sqlparser::ast::NullTreatment; -use std::any::Any; -use std::fmt::{self, Debug, Formatter}; -use std::sync::Arc; -use std::vec; /// Logical representation of a user-defined [aggregate function] (UDAF). /// @@ -72,20 +76,19 @@ pub struct AggregateUDF { impl PartialEq for AggregateUDF { fn eq(&self, other: &Self) -> bool { - self.name() == other.name() && self.signature() == other.signature() + self.inner.equals(other.inner.as_ref()) } } impl Eq for AggregateUDF {} -impl std::hash::Hash for AggregateUDF { - fn hash(&self, state: &mut H) { - self.name().hash(state); - self.signature().hash(state); +impl Hash for AggregateUDF { + fn hash(&self, state: &mut H) { + self.inner.hash_value().hash(state) } } -impl std::fmt::Display for AggregateUDF { +impl fmt::Display for AggregateUDF { fn fmt(&self, f: &mut Formatter) -> fmt::Result { write!(f, "{}", self.name()) } @@ -280,7 +283,7 @@ where /// #[derive(Debug, Clone)] /// struct GeoMeanUdf { /// signature: Signature -/// }; +/// } /// /// impl GeoMeanUdf { /// fn new() -> Self { @@ -507,6 +510,33 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { fn coerce_types(&self, _arg_types: &[DataType]) -> Result> { not_impl_err!("Function {} does not implement coerce_types", self.name()) } + + /// Return true if this aggregate UDF is equal to the other. + /// + /// Allows customizing the equality of aggregate UDFs. + /// Must be consistent with [`Self::hash_value`] and follow the same rules as [`Eq`]: + /// + /// - reflexive: `a.equals(a)`; + /// - symmetric: `a.equals(b)` implies `b.equals(a)`; + /// - transitive: `a.equals(b)` and `b.equals(c)` implies `a.equals(c)`. + /// + /// By default, compares [`Self::name`] and [`Self::signature`]. + fn equals(&self, other: &dyn AggregateUDFImpl) -> bool { + self.name() == other.name() && self.signature() == other.signature() + } + + /// Returns a hash value for this aggregate UDF. + /// + /// Allows customizing the hash code of aggregate UDFs. 
Similarly to [`Hash`] and [`Eq`], + /// if [`Self::equals`] returns true for two UDFs, their `hash_value`s must be the same. + /// + /// By default, hashes [`Self::name`] and [`Self::signature`]. + fn hash_value(&self) -> u64 { + let hasher = &mut DefaultHasher::new(); + self.name().hash(hasher); + self.signature().hash(hasher); + hasher.finish() + } } pub enum ReversedUDAF { @@ -562,6 +592,21 @@ impl AggregateUDFImpl for AliasedAggregateUDFImpl { fn aliases(&self) -> &[String] { &self.aliases } + + fn equals(&self, other: &dyn AggregateUDFImpl) -> bool { + if let Some(other) = other.as_any().downcast_ref::() { + self.inner.equals(other.inner.as_ref()) && self.aliases == other.aliases + } else { + false + } + } + + fn hash_value(&self) -> u64 { + let hasher = &mut DefaultHasher::new(); + self.inner.hash_value().hash(hasher); + self.aliases.hash(hasher); + hasher.finish() + } } /// Implementation of [`AggregateUDFImpl`] that wraps the function style pointers diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index 68d3af6ace3c0..1fbb3cc584b34 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -19,8 +19,13 @@ use std::any::Any; use std::fmt::{self, Debug, Formatter}; +use std::hash::{DefaultHasher, Hash, Hasher}; use std::sync::Arc; +use arrow::datatypes::DataType; + +use datafusion_common::{not_impl_err, ExprSchema, Result}; + use crate::expr::create_name; use crate::interval_arithmetic::Interval; use crate::simplify::{ExprSimplifyResult, SimplifyInfo}; @@ -29,9 +34,6 @@ use crate::{ ColumnarValue, Expr, ReturnTypeFunction, ScalarFunctionImplementation, Signature, }; -use arrow::datatypes::DataType; -use datafusion_common::{not_impl_err, ExprSchema, Result}; - /// Logical representation of a Scalar User Defined Function. /// /// A scalar function produces a single row output for each row of input. This @@ -59,16 +61,15 @@ pub struct ScalarUDF { impl PartialEq for ScalarUDF { fn eq(&self, other: &Self) -> bool { - self.name() == other.name() && self.signature() == other.signature() + self.inner.equals(other.inner.as_ref()) } } impl Eq for ScalarUDF {} -impl std::hash::Hash for ScalarUDF { - fn hash(&self, state: &mut H) { - self.name().hash(state); - self.signature().hash(state); +impl Hash for ScalarUDF { + fn hash(&self, state: &mut H) { + self.inner.hash_value().hash(state) } } @@ -294,7 +295,7 @@ where /// #[derive(Debug)] /// struct AddOne { /// signature: Signature -/// }; +/// } /// /// impl AddOne { /// fn new() -> Self { @@ -540,6 +541,33 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { fn coerce_types(&self, _arg_types: &[DataType]) -> Result> { not_impl_err!("Function {} does not implement coerce_types", self.name()) } + + /// Return true if this scalar UDF is equal to the other. + /// + /// Allows customizing the equality of scalar UDFs. + /// Must be consistent with [`Self::hash_value`] and follow the same rules as [`Eq`]: + /// + /// - reflexive: `a.equals(a)`; + /// - symmetric: `a.equals(b)` implies `b.equals(a)`; + /// - transitive: `a.equals(b)` and `b.equals(c)` implies `a.equals(c)`. + /// + /// By default, compares [`Self::name`] and [`Self::signature`]. + fn equals(&self, other: &dyn ScalarUDFImpl) -> bool { + self.name() == other.name() && self.signature() == other.signature() + } + + /// Returns a hash value for this scalar UDF. + /// + /// Allows customizing the hash code of scalar UDFs. Similarly to [`Hash`] and [`Eq`], + /// if [`Self::equals`] returns true for two UDFs, their `hash_value`s must be the same. 
+ /// + /// By default, hashes [`Self::name`] and [`Self::signature`]. + fn hash_value(&self) -> u64 { + let hasher = &mut DefaultHasher::new(); + self.name().hash(hasher); + self.signature().hash(hasher); + hasher.finish() + } } /// ScalarUDF that adds an alias to the underlying function. It is better to @@ -557,7 +585,6 @@ impl AliasedScalarUDFImpl { ) -> Self { let mut aliases = inner.aliases().to_vec(); aliases.extend(new_aliases.into_iter().map(|s| s.to_string())); - Self { inner, aliases } } } @@ -586,6 +613,21 @@ impl ScalarUDFImpl for AliasedScalarUDFImpl { fn aliases(&self) -> &[String] { &self.aliases } + + fn equals(&self, other: &dyn ScalarUDFImpl) -> bool { + if let Some(other) = other.as_any().downcast_ref::() { + self.inner.equals(other.inner.as_ref()) && self.aliases == other.aliases + } else { + false + } + } + + fn hash_value(&self) -> u64 { + let hasher = &mut DefaultHasher::new(); + self.inner.hash_value().hash(hasher); + self.aliases.hash(hasher); + hasher.finish() + } } /// Implementation of [`ScalarUDFImpl`] that wraps the function style pointers diff --git a/datafusion/expr/src/udwf.rs b/datafusion/expr/src/udwf.rs index 70b44e5e307a4..1a6b21e3dd294 100644 --- a/datafusion/expr/src/udwf.rs +++ b/datafusion/expr/src/udwf.rs @@ -17,18 +17,22 @@ //! [`WindowUDF`]: User Defined Window Functions -use crate::{ - function::WindowFunctionSimplification, Expr, PartitionEvaluator, - PartitionEvaluatorFactory, ReturnTypeFunction, Signature, WindowFrame, -}; -use arrow::datatypes::DataType; -use datafusion_common::Result; +use std::hash::{DefaultHasher, Hash, Hasher}; use std::{ any::Any, fmt::{self, Debug, Display, Formatter}, sync::Arc, }; +use arrow::datatypes::DataType; + +use datafusion_common::Result; + +use crate::{ + function::WindowFunctionSimplification, Expr, PartitionEvaluator, + PartitionEvaluatorFactory, ReturnTypeFunction, Signature, WindowFrame, +}; + /// Logical representation of a user-defined window function (UDWF) /// A UDWF is different from a UDF in that it is stateful across batches. /// @@ -62,16 +66,15 @@ impl Display for WindowUDF { impl PartialEq for WindowUDF { fn eq(&self, other: &Self) -> bool { - self.name() == other.name() && self.signature() == other.signature() + self.inner.equals(other.inner.as_ref()) } } impl Eq for WindowUDF {} -impl std::hash::Hash for WindowUDF { - fn hash(&self, state: &mut H) { - self.name().hash(state); - self.signature().hash(state); +impl Hash for WindowUDF { + fn hash(&self, state: &mut H) { + self.inner.hash_value().hash(state) } } @@ -212,7 +215,7 @@ where /// #[derive(Debug, Clone)] /// struct SmoothIt { /// signature: Signature -/// }; +/// } /// /// impl SmoothIt { /// fn new() -> Self { @@ -296,6 +299,33 @@ pub trait WindowUDFImpl: Debug + Send + Sync { fn simplify(&self) -> Option { None } + + /// Return true if this window UDF is equal to the other. + /// + /// Allows customizing the equality of window UDFs. + /// Must be consistent with [`Self::hash_value`] and follow the same rules as [`Eq`]: + /// + /// - reflexive: `a.equals(a)`; + /// - symmetric: `a.equals(b)` implies `b.equals(a)`; + /// - transitive: `a.equals(b)` and `b.equals(c)` implies `a.equals(c)`. + /// + /// By default, compares [`Self::name`] and [`Self::signature`]. + fn equals(&self, other: &dyn WindowUDFImpl) -> bool { + self.name() == other.name() && self.signature() == other.signature() + } + + /// Returns a hash value for this window UDF. + /// + /// Allows customizing the hash code of window UDFs. 
Similarly to [`Hash`] and [`Eq`], + /// if [`Self::equals`] returns true for two UDFs, their `hash_value`s must be the same. + /// + /// By default, hashes [`Self::name`] and [`Self::signature`]. + fn hash_value(&self) -> u64 { + let hasher = &mut DefaultHasher::new(); + self.name().hash(hasher); + self.signature().hash(hasher); + hasher.finish() + } } /// WindowUDF that adds an alias to the underlying function. It is better to @@ -342,6 +372,21 @@ impl WindowUDFImpl for AliasedWindowUDFImpl { fn aliases(&self) -> &[String] { &self.aliases } + + fn equals(&self, other: &dyn WindowUDFImpl) -> bool { + if let Some(other) = other.as_any().downcast_ref::() { + self.inner.equals(other.inner.as_ref()) && self.aliases == other.aliases + } else { + false + } + } + + fn hash_value(&self) -> u64 { + let hasher = &mut DefaultHasher::new(); + self.inner.hash_value().hash(hasher); + self.aliases.hash(hasher); + hasher.finish() + } } /// Implementation of [`WindowUDFImpl`] that wraps the function style pointers From 5ba634aa4f6d3d4ed5eefbc15dba5448f4f30923 Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Fri, 12 Jul 2024 14:43:49 +0800 Subject: [PATCH 19/59] Implement ScalarFunction `MAKE_MAP` and `MAP` (#11361) * tmp * opt * modify test * add another version * implement make_map function * implement make_map function * implement map function * format and modify the doc * add benchmark for map function * add empty end-line * fix cargo check * update lock * upate lock * fix clippy * fmt and clippy * support FixedSizeList and LargeList * check type and handle null array in coerce_types * make array value throw todo error * fix clippy * simpify the error tests --- datafusion-cli/Cargo.lock | 1 + datafusion/functions/Cargo.toml | 7 +- datafusion/functions/benches/map.rs | 101 +++++++ datafusion/functions/src/core/map.rs | 312 +++++++++++++++++++++ datafusion/functions/src/core/mod.rs | 13 + datafusion/sqllogictest/test_files/map.slt | 112 ++++++++ 6 files changed, 545 insertions(+), 1 deletion(-) create mode 100644 datafusion/functions/benches/map.rs create mode 100644 datafusion/functions/src/core/map.rs diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 8af42cb43932e..7da9cc427c37d 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -1278,6 +1278,7 @@ name = "datafusion-functions" version = "40.0.0" dependencies = [ "arrow", + "arrow-buffer", "base64 0.22.1", "blake2", "blake3", diff --git a/datafusion/functions/Cargo.toml b/datafusion/functions/Cargo.toml index 884a66724c91e..b143080b19626 100644 --- a/datafusion/functions/Cargo.toml +++ b/datafusion/functions/Cargo.toml @@ -66,6 +66,7 @@ path = "src/lib.rs" [dependencies] arrow = { workspace = true } +arrow-buffer = { workspace = true } base64 = { version = "0.22", optional = true } blake2 = { version = "^0.10.2", optional = true } blake3 = { version = "1.0", optional = true } @@ -86,7 +87,6 @@ uuid = { version = "1.7", features = ["v4"], optional = true } [dev-dependencies] arrow = { workspace = true, features = ["test_utils"] } -arrow-buffer = { workspace = true } criterion = "0.5" rand = { workspace = true } rstest = { workspace = true } @@ -141,3 +141,8 @@ required-features = ["string_expressions"] harness = false name = "upper" required-features = ["string_expressions"] + +[[bench]] +harness = false +name = "map" +required-features = ["core_expressions"] diff --git a/datafusion/functions/benches/map.rs b/datafusion/functions/benches/map.rs new file mode 100644 index 0000000000000..cd863d0e33114 --- /dev/null 
+++ b/datafusion/functions/benches/map.rs @@ -0,0 +1,101 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +extern crate criterion; + +use arrow::array::{Int32Array, ListArray, StringArray}; +use arrow::datatypes::{DataType, Field}; +use arrow_buffer::{OffsetBuffer, ScalarBuffer}; +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use datafusion_common::ScalarValue; +use datafusion_expr::ColumnarValue; +use datafusion_functions::core::{make_map, map}; +use rand::prelude::ThreadRng; +use rand::Rng; +use std::sync::Arc; + +fn keys(rng: &mut ThreadRng) -> Vec { + let mut keys = vec![]; + for _ in 0..1000 { + keys.push(rng.gen_range(0..9999).to_string()); + } + keys +} + +fn values(rng: &mut ThreadRng) -> Vec { + let mut values = vec![]; + for _ in 0..1000 { + values.push(rng.gen_range(0..9999)); + } + values +} + +fn criterion_benchmark(c: &mut Criterion) { + c.bench_function("make_map_1000", |b| { + let mut rng = rand::thread_rng(); + let keys = keys(&mut rng); + let values = values(&mut rng); + let mut buffer = Vec::new(); + for i in 0..1000 { + buffer.push(ColumnarValue::Scalar(ScalarValue::Utf8(Some( + keys[i].clone(), + )))); + buffer.push(ColumnarValue::Scalar(ScalarValue::Int32(Some(values[i])))); + } + + b.iter(|| { + black_box( + make_map() + .invoke(&buffer) + .expect("map should work on valid values"), + ); + }); + }); + + c.bench_function("map_1000", |b| { + let mut rng = rand::thread_rng(); + let field = Arc::new(Field::new("item", DataType::Utf8, true)); + let offsets = OffsetBuffer::new(ScalarBuffer::from(vec![0, 1000])); + let key_list = ListArray::new( + field, + offsets, + Arc::new(StringArray::from(keys(&mut rng))), + None, + ); + let field = Arc::new(Field::new("item", DataType::Int32, true)); + let offsets = OffsetBuffer::new(ScalarBuffer::from(vec![0, 1000])); + let value_list = ListArray::new( + field, + offsets, + Arc::new(Int32Array::from(values(&mut rng))), + None, + ); + let keys = ColumnarValue::Scalar(ScalarValue::List(Arc::new(key_list))); + let values = ColumnarValue::Scalar(ScalarValue::List(Arc::new(value_list))); + + b.iter(|| { + black_box( + map() + .invoke(&[keys.clone(), values.clone()]) + .expect("map should work on valid values"), + ); + }); + }); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions/src/core/map.rs b/datafusion/functions/src/core/map.rs new file mode 100644 index 0000000000000..8a8a19d7af52b --- /dev/null +++ b/datafusion/functions/src/core/map.rs @@ -0,0 +1,312 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::any::Any;
+use std::collections::VecDeque;
+use std::sync::Arc;
+
+use arrow::array::{Array, ArrayData, ArrayRef, MapArray, StructArray};
+use arrow::compute::concat;
+use arrow::datatypes::{DataType, Field, SchemaBuilder};
+use arrow_buffer::{Buffer, ToByteSlice};
+
+use datafusion_common::{exec_err, internal_err, ScalarValue};
+use datafusion_common::{not_impl_err, Result};
+use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
+
+fn make_map(args: &[ColumnarValue]) -> Result<ColumnarValue> {
+    let (key, value): (Vec<_>, Vec<_>) = args
+        .chunks_exact(2)
+        .map(|chunk| {
+            if let ColumnarValue::Array(_) = chunk[0] {
+                return not_impl_err!("make_map does not support array keys");
+            }
+            if let ColumnarValue::Array(_) = chunk[1] {
+                return not_impl_err!("make_map does not support array values");
+            }
+            Ok((chunk[0].clone(), chunk[1].clone()))
+        })
+        .collect::<Result<Vec<_>>>()?
+        .into_iter()
+        .unzip();
+
+    let keys = ColumnarValue::values_to_arrays(&key)?;
+    let values = ColumnarValue::values_to_arrays(&value)?;
+
+    let keys: Vec<_> = keys.iter().map(|k| k.as_ref()).collect();
+    let values: Vec<_> = values.iter().map(|v| v.as_ref()).collect();
+
+    let key = match concat(&keys) {
+        Ok(key) => key,
+        Err(e) => return internal_err!("Error concatenating keys: {}", e),
+    };
+    let value = match concat(&values) {
+        Ok(value) => value,
+        Err(e) => return internal_err!("Error concatenating values: {}", e),
+    };
+    make_map_batch_internal(key, value)
+}
+
+fn make_map_batch(args: &[ColumnarValue]) -> Result<ColumnarValue> {
+    if args.len() != 2 {
+        return exec_err!(
+            "make_map requires exactly 2 arguments, got {} instead",
+            args.len()
+        );
+    }
+    let key = get_first_array_ref(&args[0])?;
+    let value = get_first_array_ref(&args[1])?;
+    make_map_batch_internal(key, value)
+}
+
+fn get_first_array_ref(columnar_value: &ColumnarValue) -> Result<ArrayRef> {
+    match columnar_value {
+        ColumnarValue::Scalar(value) => match value {
+            ScalarValue::List(array) => Ok(array.value(0).clone()),
+            ScalarValue::LargeList(array) => Ok(array.value(0).clone()),
+            ScalarValue::FixedSizeList(array) => Ok(array.value(0).clone()),
+            _ => exec_err!("Expected array, got {:?}", value),
+        },
+        ColumnarValue::Array(array) => exec_err!("Expected scalar, got {:?}", array),
+    }
+}
+
+fn make_map_batch_internal(keys: ArrayRef, values: ArrayRef) -> Result<ColumnarValue> {
+    if keys.null_count() > 0 {
+        return exec_err!("map key cannot be null");
+    }
+
+    if keys.len() != values.len() {
+        return exec_err!("map requires key and value lists to have the same length");
+    }
+
+    let key_field = Arc::new(Field::new("key", keys.data_type().clone(), false));
+    let value_field = Arc::new(Field::new("value", values.data_type().clone(), true));
+    let mut entry_struct_buffer: VecDeque<(Arc<Field>, ArrayRef)> = VecDeque::new();
+    let mut entry_offsets_buffer = VecDeque::new();
+    entry_offsets_buffer.push_back(0);
+
+    entry_struct_buffer.push_back((Arc::clone(&key_field), Arc::clone(&keys)));
+
entry_struct_buffer.push_back((Arc::clone(&value_field), Arc::clone(&values))); + entry_offsets_buffer.push_back(keys.len() as u32); + + let entry_struct: Vec<(Arc, ArrayRef)> = entry_struct_buffer.into(); + let entry_struct = StructArray::from(entry_struct); + + let map_data_type = DataType::Map( + Arc::new(Field::new( + "entries", + entry_struct.data_type().clone(), + false, + )), + false, + ); + + let entry_offsets: Vec = entry_offsets_buffer.into(); + let entry_offsets_buffer = Buffer::from(entry_offsets.to_byte_slice()); + + let map_data = ArrayData::builder(map_data_type) + .len(entry_offsets.len() - 1) + .add_buffer(entry_offsets_buffer) + .add_child_data(entry_struct.to_data()) + .build()?; + + Ok(ColumnarValue::Array(Arc::new(MapArray::from(map_data)))) +} + +#[derive(Debug)] +pub struct MakeMap { + signature: Signature, +} + +impl Default for MakeMap { + fn default() -> Self { + Self::new() + } +} + +impl MakeMap { + pub fn new() -> Self { + Self { + signature: Signature::user_defined(Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for MakeMap { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "make_map" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + if arg_types.is_empty() { + return exec_err!( + "make_map requires at least one pair of arguments, got 0 instead" + ); + } + if arg_types.len() % 2 != 0 { + return exec_err!( + "make_map requires an even number of arguments, got {} instead", + arg_types.len() + ); + } + + let key_type = &arg_types[0]; + let mut value_type = &arg_types[1]; + + for (i, chunk) in arg_types.chunks_exact(2).enumerate() { + if chunk[0].is_null() { + return exec_err!("make_map key cannot be null at position {}", i); + } + if &chunk[0] != key_type { + return exec_err!( + "make_map requires all keys to have the same type {}, got {} instead at position {}", + key_type, + chunk[0], + i + ); + } + + if !chunk[1].is_null() { + if value_type.is_null() { + value_type = &chunk[1]; + } else if &chunk[1] != value_type { + return exec_err!( + "map requires all values to have the same type {}, got {} instead at position {}", + value_type, + &chunk[1], + i + ); + } + } + } + + let mut result = Vec::new(); + for _ in 0..arg_types.len() / 2 { + result.push(key_type.clone()); + result.push(value_type.clone()); + } + + Ok(result) + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + let key_type = &arg_types[0]; + let mut value_type = &arg_types[1]; + + for chunk in arg_types.chunks_exact(2) { + if !chunk[1].is_null() && value_type.is_null() { + value_type = &chunk[1]; + } + } + + let mut builder = SchemaBuilder::new(); + builder.push(Field::new("key", key_type.clone(), false)); + builder.push(Field::new("value", value_type.clone(), true)); + let fields = builder.finish().fields; + Ok(DataType::Map( + Arc::new(Field::new("entries", DataType::Struct(fields), false)), + false, + )) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + make_map(args) + } +} + +#[derive(Debug)] +pub struct MapFunc { + signature: Signature, +} + +impl Default for MapFunc { + fn default() -> Self { + Self::new() + } +} + +impl MapFunc { + pub fn new() -> Self { + Self { + signature: Signature::variadic_any(Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for MapFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "map" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: 
&[DataType]) -> Result { + if arg_types.len() % 2 != 0 { + return exec_err!( + "map requires an even number of arguments, got {} instead", + arg_types.len() + ); + } + let mut builder = SchemaBuilder::new(); + builder.push(Field::new( + "key", + get_element_type(&arg_types[0])?.clone(), + false, + )); + builder.push(Field::new( + "value", + get_element_type(&arg_types[1])?.clone(), + true, + )); + let fields = builder.finish().fields; + Ok(DataType::Map( + Arc::new(Field::new("entries", DataType::Struct(fields), false)), + false, + )) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + make_map_batch(args) + } +} + +fn get_element_type(data_type: &DataType) -> Result<&DataType> { + match data_type { + DataType::List(element) => Ok(element.data_type()), + DataType::LargeList(element) => Ok(element.data_type()), + DataType::FixedSizeList(element, _) => Ok(element.data_type()), + _ => exec_err!( + "Expected list, large_list or fixed_size_list, got {:?}", + data_type + ), + } +} diff --git a/datafusion/functions/src/core/mod.rs b/datafusion/functions/src/core/mod.rs index 062a4a104d54a..31bce04beec1b 100644 --- a/datafusion/functions/src/core/mod.rs +++ b/datafusion/functions/src/core/mod.rs @@ -25,6 +25,7 @@ pub mod arrowtypeof; pub mod coalesce; pub mod expr_ext; pub mod getfield; +pub mod map; pub mod named_struct; pub mod nullif; pub mod nvl; @@ -42,6 +43,8 @@ make_udf_function!(r#struct::StructFunc, STRUCT, r#struct); make_udf_function!(named_struct::NamedStructFunc, NAMED_STRUCT, named_struct); make_udf_function!(getfield::GetFieldFunc, GET_FIELD, get_field); make_udf_function!(coalesce::CoalesceFunc, COALESCE, coalesce); +make_udf_function!(map::MakeMap, MAKE_MAP, make_map); +make_udf_function!(map::MapFunc, MAP, map); pub mod expr_fn { use datafusion_expr::{Expr, Literal}; @@ -78,6 +81,14 @@ pub mod expr_fn { coalesce, "Returns `coalesce(args...)`, which evaluates to the value of the first expr which is not NULL", args, + ),( + make_map, + "Returns a map created from the given keys and values pairs. This function isn't efficient for large maps. Use the `map` function instead.", + args, + ),( + map, + "Returns a map created from a key list and a value list", + args, )); #[doc = "Returns the value of the field with the given name from the struct"] @@ -96,5 +107,7 @@ pub fn functions() -> Vec> { named_struct(), get_field(), coalesce(), + make_map(), + map(), ] } diff --git a/datafusion/sqllogictest/test_files/map.slt b/datafusion/sqllogictest/test_files/map.slt index 417947dc6c89b..abf5b2ebbf98e 100644 --- a/datafusion/sqllogictest/test_files/map.slt +++ b/datafusion/sqllogictest/test_files/map.slt @@ -100,3 +100,115 @@ physical_plan statement ok drop table table_with_map; + +query ? +SELECT MAKE_MAP('POST', 41, 'HEAD', 33, 'PATCH', 30, 'OPTION', 29, 'GET', 27, 'PUT', 25, 'DELETE', 24) AS method_count; +---- +{POST: 41, HEAD: 33, PATCH: 30, OPTION: 29, GET: 27, PUT: 25, DELETE: 24} + +query I +SELECT MAKE_MAP('POST', 41, 'HEAD', 33)['POST']; +---- +41 + +query ? +SELECT MAKE_MAP('POST', 41, 'HEAD', 33, 'PATCH', null); +---- +{POST: 41, HEAD: 33, PATCH: } + +query ? +SELECT MAKE_MAP('POST', null, 'HEAD', 33, 'PATCH', null); +---- +{POST: , HEAD: 33, PATCH: } + +query ? +SELECT MAKE_MAP(1, null, 2, 33, 3, null); +---- +{1: , 2: 33, 3: } + +query ? 
+SELECT MAKE_MAP([1,2], ['a', 'b'], [3,4], ['b']); +---- +{[1, 2]: [a, b], [3, 4]: [b]} + +query error +SELECT MAKE_MAP('POST', 41, 'HEAD', 'ab', 'PATCH', 30); + +query error +SELECT MAKE_MAP('POST', 41, 'HEAD', 33, null, 30); + +query error +SELECT MAKE_MAP('POST', 41, 123, 33,'PATCH', 30); + +query error +SELECT MAKE_MAP() + +query error +SELECT MAKE_MAP('POST', 41, 'HEAD'); + +query ? +SELECT MAP(['POST', 'HEAD', 'PATCH'], [41, 33, 30]); +---- +{POST: 41, HEAD: 33, PATCH: 30} + +query ? +SELECT MAP(['POST', 'HEAD', 'PATCH'], [41, 33, null]); +---- +{POST: 41, HEAD: 33, PATCH: } + +query ? +SELECT MAP([[1,2], [3,4]], ['a', 'b']); +---- +{[1, 2]: a, [3, 4]: b} + +query error +SELECT MAP() + +query error DataFusion error: Execution error: map requires an even number of arguments, got 1 instead +SELECT MAP(['POST', 'HEAD']) + +query error DataFusion error: Execution error: Expected list, large_list or fixed_size_list, got Null +SELECT MAP(null, [41, 33, 30]); + +query error DataFusion error: Execution error: map requires key and value lists to have the same length +SELECT MAP(['POST', 'HEAD', 'PATCH'], [41, 33]); + +query error DataFusion error: Execution error: map key cannot be null +SELECT MAP(['POST', 'HEAD', null], [41, 33, 30]); + +query ? +SELECT MAP(make_array('POST', 'HEAD', 'PATCH'), make_array(41, 33, 30)); +---- +{POST: 41, HEAD: 33, PATCH: 30} + +query ? +SELECT MAP(arrow_cast(make_array('POST', 'HEAD', 'PATCH'), 'FixedSizeList(3, Utf8)'), arrow_cast(make_array(41, 33, 30), 'FixedSizeList(3, Int64)')); +---- +{POST: 41, HEAD: 33, PATCH: 30} + +query ? +SELECT MAP(arrow_cast(make_array('POST', 'HEAD', 'PATCH'), 'LargeList(Utf8)'), arrow_cast(make_array(41, 33, 30), 'LargeList(Int64)')); +---- +{POST: 41, HEAD: 33, PATCH: 30} + +statement ok +create table t as values +('a', 1, 'k1', 10, ['k1', 'k2'], [1, 2]), +('b', 2, 'k3', 30, ['k3'], [3]), +('d', 4, 'k5', 50, ['k5'], [5]); + +query error +SELECT make_map(column1, column2, column3, column4) FROM t; +# TODO: support array value +# ---- +# {a: 1, k1: 10} +# {b: 2, k3: 30} +# {d: 4, k5: 50} + +query error +SELECT map(column5, column6) FROM t; +# TODO: support array value +# ---- +# {k1:1, k2:2} +# {k3: 3} +# {k5: 5} From d542cbda8f17ba004de18bb107ecf1c8ec3266f6 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Fri, 12 Jul 2024 12:53:05 +0200 Subject: [PATCH 20/59] Improve `CommonSubexprEliminate` rule with surely and conditionally evaluated stats (#11357) * Improve `CommonSubexprEliminate` rule with surely and conditionally evaluated stats * remove expression tree hashing as no longer needed * address review comments * add negative tests --- datafusion/expr/src/expr.rs | 39 ++- .../optimizer/src/common_subexpr_eliminate.rs | 256 +++++++++++------- .../optimizer/src/optimize_projections/mod.rs | 10 +- datafusion/sqllogictest/test_files/cse.slt | 88 +++++- datafusion/sqllogictest/test_files/select.slt | 20 +- .../sqllogictest/test_files/tpch/q14.slt.part | 33 +-- 6 files changed, 298 insertions(+), 148 deletions(-) diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index ecece6dbfce7f..a344e621ddb12 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -17,7 +17,7 @@ //! 
Logical Expressions: [`Expr`] -use std::collections::HashSet; +use std::collections::{HashMap, HashSet}; use std::fmt::{self, Display, Formatter, Write}; use std::hash::{Hash, Hasher}; use std::mem; @@ -1380,7 +1380,7 @@ impl Expr { /// // refs contains "a" and "b" /// assert_eq!(refs.len(), 2); /// assert!(refs.contains(&Column::new_unqualified("a"))); - /// assert!(refs.contains(&Column::new_unqualified("b"))); + /// assert!(refs.contains(&Column::new_unqualified("b"))); /// ``` pub fn column_refs(&self) -> HashSet<&Column> { let mut using_columns = HashSet::new(); @@ -1401,6 +1401,41 @@ impl Expr { .expect("traversal is infallable"); } + /// Return all references to columns and their occurrence counts in the expression. + /// + /// # Example + /// ``` + /// # use std::collections::HashMap; + /// # use datafusion_common::Column; + /// # use datafusion_expr::col; + /// // For an expression `a + (b * a)` + /// let expr = col("a") + (col("b") * col("a")); + /// let mut refs = expr.column_refs_counts(); + /// // refs contains "a" and "b" + /// assert_eq!(refs.len(), 2); + /// assert_eq!(*refs.get(&Column::new_unqualified("a")).unwrap(), 2); + /// assert_eq!(*refs.get(&Column::new_unqualified("b")).unwrap(), 1); + /// ``` + pub fn column_refs_counts(&self) -> HashMap<&Column, usize> { + let mut map = HashMap::new(); + self.add_column_ref_counts(&mut map); + map + } + + /// Adds references to all columns and their occurrence counts in the expression to + /// the map. + /// + /// See [`Self::column_refs_counts`] for details + pub fn add_column_ref_counts<'a>(&'a self, map: &mut HashMap<&'a Column, usize>) { + self.apply(|expr| { + if let Expr::Column(col) = expr { + *map.entry(col).or_default() += 1; + } + Ok(TreeNodeRecursion::Continue) + }) + .expect("traversal is infallable"); + } + /// Returns true if there are any column references in this Expr pub fn any_column_refs(&self) -> bool { self.exists(|expr| Ok(matches!(expr, Expr::Column(_)))) diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 721987b917d4c..e4b36652974d7 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -33,12 +33,12 @@ use datafusion_common::tree_node::{ use datafusion_common::{ internal_datafusion_err, qualified_name, Column, DFSchema, DFSchemaRef, Result, }; -use datafusion_expr::expr::Alias; +use datafusion_expr::expr::{Alias, ScalarFunction}; use datafusion_expr::logical_plan::tree_node::unwrap_arc; use datafusion_expr::logical_plan::{ Aggregate, Filter, LogicalPlan, Projection, Sort, Window, }; -use datafusion_expr::{col, Expr, ExprSchemable}; +use datafusion_expr::{col, BinaryExpr, Case, Expr, ExprSchemable, Operator}; use indexmap::IndexMap; const CSE_PREFIX: &str = "__common_expr"; @@ -56,13 +56,9 @@ struct Identifier<'n> { } impl<'n> Identifier<'n> { - fn new(expr: &'n Expr, is_tree: bool, random_state: &RandomState) -> Self { + fn new(expr: &'n Expr, random_state: &RandomState) -> Self { let mut hasher = random_state.build_hasher(); - if is_tree { - expr.hash(&mut hasher); - } else { - expr.hash_node(&mut hasher); - } + expr.hash_node(&mut hasher); let hash = hasher.finish(); Self { hash, expr } } @@ -110,8 +106,9 @@ impl Hash for Identifier<'_> { /// ``` type IdArray<'n> = Vec<(usize, Option>)>; -/// A map that contains the number of occurrences of expressions by their identifiers. 
-type ExprStats<'n> = HashMap<Identifier<'n>, usize>;
+/// A map that contains the number of normal and conditional occurrences of expressions by
+/// their identifiers.
+type ExprStats<'n> = HashMap<Identifier<'n>, (usize, usize)>;
 
 /// A map that contains the common expressions and their alias extracted during the
 /// second, rewriting traversal.
@@ -200,6 +197,7 @@ impl CommonSubexprEliminate {
             expr_mask,
             random_state: &self.random_state,
             found_common: false,
+            conditional: false,
         };
         expr.visit(&mut visitor)?;
 
@@ -901,15 +899,17 @@ struct ExprIdentifierVisitor<'a, 'n> {
     random_state: &'a RandomState,
     // a flag to indicate that common expression found
     found_common: bool,
+    // if we are in a conditional branch. A conditional branch means that the expression
+    // might not be executed depending on the runtime values of other expressions, and
+    // thus can not be extracted as a common expression.
+    conditional: bool,
 }
 
 /// Record item that used when traversing an expression tree.
 enum VisitRecord<'n> {
     /// Marks the beginning of expression. It contains:
     /// - The post-order index assigned during the first, visiting traversal.
-    /// - A boolean flag if the record marks an expression subtree (not just a single
-    ///   node).
-    EnterMark(usize, bool),
 
     /// Marks an accumulated subexpression tree. It contains:
     /// - The accumulated identifier of a subexpression.
@@ -924,10 +924,6 @@ impl<'n> ExprIdentifierVisitor<'_, 'n> {
     /// Find the first `EnterMark` in the stack, and accumulates every `ExprItem` before
     /// it. Returns a tuple that contains:
     /// - The pre-order index of the expression we marked.
-    /// - A boolean flag if we marked an expression subtree (not just a single
-    ///   node). If true we didn't recurse into the node's children, so we need to
-    ///   calculate the hash of the marked expression tree (not just the node) and we
-    ///   need to validate the expression tree (not just the node).
     /// - The accumulated identifier of the children of the marked expression.
     /// - An accumulated boolean flag from the children of the marked expression if all
     ///   children are valid for subexpression elimination (i.e. it is safe to extract the
@@ -937,14 +933,14 @@ impl<'n> ExprIdentifierVisitor<'_, 'n> {
     /// information up from children to parents via `visit_stack` during the first,
     /// visiting traversal and no need to test the expression's validity beforehand with
     /// an extra traversal).
-    fn pop_enter_mark(&mut self) -> (usize, bool, Option<Identifier<'n>>, bool) {
+    fn pop_enter_mark(&mut self) -> (usize, Option<Identifier<'n>>, bool) {
         let mut expr_id = None;
         let mut is_valid = true;
 
         while let Some(item) = self.visit_stack.pop() {
             match item {
-                VisitRecord::EnterMark(down_index, is_tree) => {
-                    return (down_index, is_tree, expr_id, is_valid);
+                VisitRecord::EnterMark(down_index) => {
+                    return (down_index, expr_id, is_valid);
                 }
                 VisitRecord::ExprItem(sub_expr_id, sub_expr_is_valid) => {
                     expr_id = Some(sub_expr_id.combine(expr_id));
@@ -954,53 +950,112 @@ impl<'n> ExprIdentifierVisitor<'_, 'n> {
         }
         unreachable!("Enter mark should paired with node number");
    }
+
+    /// Save the current `conditional` status and run `f` with `conditional` set to true.
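+    ///
+    /// The previous status is restored once `f` returns, so nested conditional
+    /// scopes unwind correctly and siblings visited afterwards are counted as
+    /// surely evaluated again.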
+    fn conditionally<F: FnMut(&mut Self) -> Result<()>>(
+        &mut self,
+        mut f: F,
+    ) -> Result<()> {
+        let conditional = self.conditional;
+        self.conditional = true;
+        f(self)?;
+        self.conditional = conditional;
+
+        Ok(())
+    }
 }
 
 impl<'n> TreeNodeVisitor<'n> for ExprIdentifierVisitor<'_, 'n> {
     type Node = Expr;
 
     fn f_down(&mut self, expr: &'n Expr) -> Result<TreeNodeRecursion> {
-        // If an expression can short circuit its children then don't consider its
-        // children for CSE (https://github.com/apache/arrow-datafusion/issues/8814).
-        // This means that we don't recurse into its children, but handle the expression
-        // as a subtree when we calculate its identifier.
-        // TODO: consider surely executed children of "short circuited"s for CSE
-        let is_tree = expr.short_circuits();
-        let tnr = if is_tree {
-            TreeNodeRecursion::Jump
-        } else {
-            TreeNodeRecursion::Continue
-        };
-
         self.id_array.push((0, None));
         self.visit_stack
-            .push(VisitRecord::EnterMark(self.down_index, is_tree));
+            .push(VisitRecord::EnterMark(self.down_index));
         self.down_index += 1;
 
-        Ok(tnr)
+        // If an expression can short-circuit then some of its children might not be
+        // executed so count the occurrence of subexpressions as conditional in all
+        // children.
+        Ok(match expr {
+            // If we are already in a conditionally evaluated subtree then continue
+            // traversal.
+            _ if self.conditional => TreeNodeRecursion::Continue,
+
+            // In case of `ScalarFunction`s we don't know which children are surely
+            // executed so start visiting all children conditionally and stop the
+            // recursion with `TreeNodeRecursion::Jump`.
+            Expr::ScalarFunction(ScalarFunction { func, args })
+                if func.short_circuits() =>
+            {
+                self.conditionally(|visitor| {
+                    args.iter().try_for_each(|e| e.visit(visitor).map(|_| ()))
+                })?;
+
+                TreeNodeRecursion::Jump
+            }
+
+            // In case of `And` and `Or` the first child is surely executed, but we
+            // account subexpressions as conditional in the second.
+            Expr::BinaryExpr(BinaryExpr {
+                left,
+                op: Operator::And | Operator::Or,
+                right,
+            }) => {
+                left.visit(self)?;
+                self.conditionally(|visitor| right.visit(visitor).map(|_| ()))?;
+
+                TreeNodeRecursion::Jump
+            }
+
+            // In case of `Case` the optional base expression and the first when
+            // expressions are surely executed, but we account subexpressions as
+            // conditional in the others.
+            Expr::Case(Case {
+                expr,
+                when_then_expr,
+                else_expr,
+            }) => {
+                expr.iter().try_for_each(|e| e.visit(self).map(|_| ()))?;
+                when_then_expr.iter().take(1).try_for_each(|(when, then)| {
+                    when.visit(self)?;
+                    self.conditionally(|visitor| then.visit(visitor).map(|_| ()))
+                })?;
+                self.conditionally(|visitor| {
+                    when_then_expr.iter().skip(1).try_for_each(|(when, then)| {
+                        when.visit(visitor)?;
+                        then.visit(visitor).map(|_| ())
+                    })?;
+                    else_expr
+                        .iter()
+                        .try_for_each(|e| e.visit(visitor).map(|_| ()))
+                })?;
+
+                TreeNodeRecursion::Jump
+            }
+
+            // In case of non-short-circuit expressions continue the traversal.
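+            // Every child of such an expression is surely evaluated exactly when
+            // the expression itself is, so plain (unconditional) counting applies.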
+            _ => TreeNodeRecursion::Continue,
+        })
     }
 
     fn f_up(&mut self, expr: &'n Expr) -> Result<TreeNodeRecursion> {
-        let (down_index, is_tree, sub_expr_id, sub_expr_is_valid) = self.pop_enter_mark();
+        let (down_index, sub_expr_id, sub_expr_is_valid) = self.pop_enter_mark();
 
-        let (expr_id, is_valid) = if is_tree {
-            (
-                Identifier::new(expr, true, self.random_state),
-                !expr.is_volatile()?,
-            )
-        } else {
-            (
-                Identifier::new(expr, false, self.random_state).combine(sub_expr_id),
-                !expr.is_volatile_node() && sub_expr_is_valid,
-            )
-        };
+        let expr_id = Identifier::new(expr, self.random_state).combine(sub_expr_id);
+        let is_valid = !expr.is_volatile_node() && sub_expr_is_valid;
 
         self.id_array[down_index].0 = self.up_index;
         if is_valid && !self.expr_mask.ignores(expr) {
             self.id_array[down_index].1 = Some(expr_id);
-            let count = self.expr_stats.entry(expr_id).or_insert(0);
-            *count += 1;
-            if *count > 1 {
+            let (count, conditional_count) =
+                self.expr_stats.entry(expr_id).or_insert((0, 0));
+            if self.conditional {
+                *conditional_count += 1;
+            } else {
+                *count += 1;
+            }
+            if *count > 1 || (*count == 1 && *conditional_count > 0) {
                 self.found_common = true;
             }
         }
@@ -1039,51 +1094,40 @@ impl TreeNodeRewriter for CommonSubexprRewriter<'_, '_> {
             self.alias_counter += 1;
         }
 
-        // The `CommonSubexprRewriter` relies on `ExprIdentifierVisitor` to generate the
-        // `id_array`, which records the expr's identifier used to rewrite expr. So if we
-        // skip an expr in `ExprIdentifierVisitor`, we should skip it here, too.
-        let is_tree = expr.short_circuits();
-        let tnr = if is_tree {
-            TreeNodeRecursion::Jump
-        } else {
-            TreeNodeRecursion::Continue
-        };
-
         let (up_index, expr_id) = self.id_array[self.down_index];
         self.down_index += 1;
 
-        // skip `Expr`s without identifier (empty identifier).
-        let Some(expr_id) = expr_id else {
-            return Ok(Transformed::new(expr, false, tnr));
-        };
-
-        let count = self.expr_stats.get(&expr_id).unwrap();
-        if *count > 1 {
-            // step index to skip all sub-node (which has smaller series number).
-            while self.down_index < self.id_array.len()
-                && self.id_array[self.down_index].0 < up_index
-            {
-                self.down_index += 1;
-            }
+        // Handle `Expr`s with identifiers only
+        if let Some(expr_id) = expr_id {
+            let (count, conditional_count) = self.expr_stats.get(&expr_id).unwrap();
+            if *count > 1 || *count == 1 && *conditional_count > 0 {
+                // step index to skip all sub-node (which has smaller series number).
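+                // (the entire subtree is replaced by a single reference to the
+                // extracted common expression, so its sub-nodes need no rewriting)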
+ while self.down_index < self.id_array.len() + && self.id_array[self.down_index].0 < up_index + { + self.down_index += 1; + } - let expr_name = expr.display_name()?; - let (_, expr_alias) = self.common_exprs.entry(expr_id).or_insert_with(|| { - let expr_alias = self.alias_generator.next(CSE_PREFIX); - (expr, expr_alias) - }); + let expr_name = expr.display_name()?; + let (_, expr_alias) = + self.common_exprs.entry(expr_id).or_insert_with(|| { + let expr_alias = self.alias_generator.next(CSE_PREFIX); + (expr, expr_alias) + }); - // alias the expressions without an `Alias` ancestor node - let rewritten = if self.alias_counter > 0 { - col(expr_alias.clone()) - } else { - self.alias_counter += 1; - col(expr_alias.clone()).alias(expr_name) - }; + // alias the expressions without an `Alias` ancestor node + let rewritten = if self.alias_counter > 0 { + col(expr_alias.clone()) + } else { + self.alias_counter += 1; + col(expr_alias.clone()).alias(expr_name) + }; - Ok(Transformed::new(rewritten, true, TreeNodeRecursion::Jump)) - } else { - Ok(Transformed::new(expr, false, tnr)) + return Ok(Transformed::new(rewritten, true, TreeNodeRecursion::Jump)); + } } + + Ok(Transformed::no(expr)) } fn f_up(&mut self, expr: Expr) -> Result> { @@ -1685,7 +1729,7 @@ mod test { .unwrap(); let rule = CommonSubexprEliminate::new(); let optimized_plan = rule.rewrite(plan, &OptimizerContext::new()).unwrap(); - assert!(!optimized_plan.transformed); + assert!(optimized_plan.transformed); let optimized_plan = optimized_plan.data; let schema = optimized_plan.schema(); @@ -1837,22 +1881,29 @@ mod test { let table_scan = test_table_scan()?; let extracted_short_circuit = col("a").eq(lit(0)).or(col("b").eq(lit(0))); - let not_extracted_short_circuit_leg_1 = (col("a") + col("b")).eq(lit(0)); + let extracted_short_circuit_leg_1 = (col("a") + col("b")).eq(lit(0)); let not_extracted_short_circuit_leg_2 = (col("a") - col("b")).eq(lit(0)); + let extracted_short_circuit_leg_3 = (col("a") * col("b")).eq(lit(0)); let plan = LogicalPlanBuilder::from(table_scan.clone()) .project(vec![ extracted_short_circuit.clone().alias("c1"), extracted_short_circuit.alias("c2"), - not_extracted_short_circuit_leg_1.clone().alias("c3"), - not_extracted_short_circuit_leg_2.clone().alias("c4"), - not_extracted_short_circuit_leg_1 - .or(not_extracted_short_circuit_leg_2) + extracted_short_circuit_leg_1 + .clone() + .or(not_extracted_short_circuit_leg_2.clone()) + .alias("c3"), + extracted_short_circuit_leg_1 + .and(not_extracted_short_circuit_leg_2) + .alias("c4"), + extracted_short_circuit_leg_3 + .clone() + .or(extracted_short_circuit_leg_3.clone()) .alias("c5"), ])? 
.build()?;
 
-        let expected = "Projection: __common_expr_1 AS c1, __common_expr_1 AS c2, test.a + test.b = Int32(0) AS c3, test.a - test.b = Int32(0) AS c4, test.a + test.b = Int32(0) OR test.a - test.b = Int32(0) AS c5\
-        \n  Projection: test.a = Int32(0) OR test.b = Int32(0) AS __common_expr_1, test.a, test.b, test.c\
+        let expected = "Projection: __common_expr_1 AS c1, __common_expr_1 AS c2, __common_expr_2 OR test.a - test.b = Int32(0) AS c3, __common_expr_2 AND test.a - test.b = Int32(0) AS c4, __common_expr_3 OR __common_expr_3 AS c5\
+        \n  Projection: test.a = Int32(0) OR test.b = Int32(0) AS __common_expr_1, test.a + test.b = Int32(0) AS __common_expr_2, test.a * test.b = Int32(0) AS __common_expr_3, test.a, test.b, test.c\
         \n    TableScan: test";
 
         assert_optimized_plan_eq(expected, plan, None);
 
@@ -1888,10 +1939,12 @@ mod test {
         let table_scan = test_table_scan()?;
 
         let rand = rand_func().call(vec![]);
-        let not_extracted_volatile_short_circuit_2 =
-            rand.clone().eq(lit(0)).or(col("b").eq(lit(0)));
+        let extracted_short_circuit_leg_1 = col("a").eq(lit(0));
         let not_extracted_volatile_short_circuit_1 =
-            col("a").eq(lit(0)).or(rand.eq(lit(0)));
+            extracted_short_circuit_leg_1.or(rand.clone().eq(lit(0)));
+        let not_extracted_short_circuit_leg_2 = col("b").eq(lit(0));
+        let not_extracted_volatile_short_circuit_2 =
+            rand.eq(lit(0)).or(not_extracted_short_circuit_leg_2);
         let plan = LogicalPlanBuilder::from(table_scan.clone())
             .project(vec![
                 not_extracted_volatile_short_circuit_1.clone().alias("c1"),
@@ -1901,10 +1954,11 @@ mod test {
             ])?
             .build()?;
 
-        let expected = "Projection: test.a = Int32(0) OR random() = Int32(0) AS c1, test.a = Int32(0) OR random() = Int32(0) AS c2, random() = Int32(0) OR test.b = Int32(0) AS c3, random() = Int32(0) OR test.b = Int32(0) AS c4\
-        \n  TableScan: test";
+        let expected = "Projection: __common_expr_1 OR random() = Int32(0) AS c1, __common_expr_1 OR random() = Int32(0) AS c2, random() = Int32(0) OR test.b = Int32(0) AS c3, random() = Int32(0) OR test.b = Int32(0) AS c4\
+        \n  Projection: test.a = Int32(0) AS __common_expr_1, test.a, test.b, test.c\
+        \n    TableScan: test";
 
-        assert_non_optimized_plan_eq(expected, plan, None);
+        assert_optimized_plan_eq(expected, plan, None);
 
         Ok(())
     }
diff --git a/datafusion/optimizer/src/optimize_projections/mod.rs b/datafusion/optimizer/src/optimize_projections/mod.rs
index cae2a7b2cad2f..58c1ae297b02e 100644
--- a/datafusion/optimizer/src/optimize_projections/mod.rs
+++ b/datafusion/optimizer/src/optimize_projections/mod.rs
@@ -19,7 +19,7 @@
 
 mod required_indices;
 
-use std::collections::HashSet;
+use std::collections::{HashMap, HashSet};
 use std::sync::Arc;
 
 use crate::optimizer::ApplyOrder;
@@ -42,7 +42,6 @@ use datafusion_common::tree_node::{
     Transformed, TreeNode, TreeNodeIterator, TreeNodeRecursion,
 };
 use datafusion_expr::logical_plan::tree_node::unwrap_arc;
-use hashbrown::HashMap;
 
 /// Optimizer rule to prune unnecessary columns from intermediate schemas
 /// inside the [`LogicalPlan`]. This rule:
@@ -472,11 +471,8 @@ fn merge_consecutive_projections(proj: Projection) -> Result<Projection> {
     let mut column_referral_map = HashMap::<&Column, usize>::new();
-    for columns in expr.iter().map(|expr| expr.column_refs()) {
-        for col in columns.into_iter() {
-            *column_referral_map.entry(col).or_default() += 1;
-        }
-    }
+    expr.iter()
+        .for_each(|expr| expr.add_column_ref_counts(&mut column_referral_map));
 
     // If an expression is non-trivial and appears more than once, do not merge
    // them as consecutive projections will benefit from a compute-once approach.
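
A sketch of the extraction rule implemented above (the `should_extract` helper is
hypothetical, written here only for illustration; it is not code from the patch).
With `ExprStats` mapping an identifier to `(count, conditional_count)`, an
expression becomes a common subexpression when it is surely evaluated more than
once, or surely evaluated once and possibly evaluated again on a conditional path:

    // Hypothetical helper: `count` counts surely evaluated occurrences,
    // `conditional_count` those on conditional paths. Extraction pays off only
    // if the expression is computed at least once unconditionally.
    fn should_extract(count: usize, conditional_count: usize) -> bool {
        count > 1 || (count == 1 && conditional_count > 0)
    }

For example, in `(a = 1 OR random() = 0) AND a = 1` the first `a = 1` is surely
evaluated and the second only conditionally, giving `(1, 1)`, so `a = 1` is
extracted; in `CASE WHEN random() = 0 THEN a + 3 ELSE a + 3 END` both occurrences
are conditional, giving `(0, 2)`, so `a + 3` stays inline. Both cases appear in
the cse.slt changes below.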
diff --git a/datafusion/sqllogictest/test_files/cse.slt b/datafusion/sqllogictest/test_files/cse.slt index 3579c1c1635cb..19b47fa50e410 100644 --- a/datafusion/sqllogictest/test_files/cse.slt +++ b/datafusion/sqllogictest/test_files/cse.slt @@ -93,15 +93,16 @@ FROM t1 ---- logical_plan 01)Projection: __common_expr_1 AS c1, __common_expr_1 AS c2, __common_expr_2 AS c3, __common_expr_2 AS c4, __common_expr_3 AS c5, __common_expr_3 AS c6 -02)--Projection: t1.a = Float64(0) AND t1.b = Float64(0) AS __common_expr_1, t1.a = Float64(0) OR t1.b = Float64(0) AS __common_expr_2, CASE WHEN t1.a = Float64(0) THEN Int64(0) ELSE Int64(1) END AS __common_expr_3 -03)----TableScan: t1 projection=[a, b] +02)--Projection: __common_expr_4 AND t1.b = Float64(0) AS __common_expr_1, __common_expr_4 OR t1.b = Float64(0) AS __common_expr_2, CASE WHEN __common_expr_4 THEN Int64(0) ELSE Int64(1) END AS __common_expr_3 +03)----Projection: t1.a = Float64(0) AS __common_expr_4, t1.b +04)------TableScan: t1 projection=[a, b] physical_plan 01)ProjectionExec: expr=[__common_expr_1@0 as c1, __common_expr_1@0 as c2, __common_expr_2@1 as c3, __common_expr_2@1 as c4, __common_expr_3@2 as c5, __common_expr_3@2 as c6] -02)--ProjectionExec: expr=[a@0 = 0 AND b@1 = 0 as __common_expr_1, a@0 = 0 OR b@1 = 0 as __common_expr_2, CASE WHEN a@0 = 0 THEN 0 ELSE 1 END as __common_expr_3] -03)----MemoryExec: partitions=1, partition_sizes=[0] +02)--ProjectionExec: expr=[__common_expr_4@0 AND b@1 = 0 as __common_expr_1, __common_expr_4@0 OR b@1 = 0 as __common_expr_2, CASE WHEN __common_expr_4@0 THEN 0 ELSE 1 END as __common_expr_3] +03)----ProjectionExec: expr=[a@0 = 0 as __common_expr_4, b@1 as b] +04)------MemoryExec: partitions=1, partition_sizes=[0] # Common children of short-circuit expression -# TODO: consider surely executed children of "short circuited"s for CSE. i.e. 
`a = 0`, `a = 2`, `a = 4` should be extracted query TT EXPLAIN SELECT a = 0 AND b = 0 AS c1, @@ -121,14 +122,15 @@ EXPLAIN SELECT FROM t1 ---- logical_plan -01)Projection: t1.a = Float64(0) AND t1.b = Float64(0) AS c1, t1.a = Float64(0) AND t1.b = Float64(1) AS c2, t1.b = Float64(2) AND t1.a = Float64(1) AS c3, t1.b = Float64(3) AND t1.a = Float64(1) AS c4, t1.a = Float64(2) OR t1.b = Float64(4) AS c5, t1.a = Float64(2) OR t1.b = Float64(5) AS c6, t1.b = Float64(6) OR t1.a = Float64(3) AS c7, t1.b = Float64(7) OR t1.a = Float64(3) AS c8, CASE WHEN t1.a = Float64(4) THEN Int64(0) ELSE Int64(1) END AS c9, CASE WHEN t1.a = Float64(4) THEN Int64(0) ELSE Int64(2) END AS c10, CASE WHEN t1.b = Float64(8) THEN t1.a + Float64(1) ELSE Float64(0) END AS c11, CASE WHEN t1.b = Float64(9) THEN t1.a + Float64(1) ELSE Float64(0) END AS c12, CASE WHEN t1.b = Float64(10) THEN Float64(0) ELSE t1.a + Float64(2) END AS c13, CASE WHEN t1.b = Float64(11) THEN Float64(0) ELSE t1.a + Float64(2) END AS c14 -02)--TableScan: t1 projection=[a, b] +01)Projection: __common_expr_1 AND t1.b = Float64(0) AS c1, __common_expr_1 AND t1.b = Float64(1) AS c2, t1.b = Float64(2) AND t1.a = Float64(1) AS c3, t1.b = Float64(3) AND t1.a = Float64(1) AS c4, __common_expr_2 OR t1.b = Float64(4) AS c5, __common_expr_2 OR t1.b = Float64(5) AS c6, t1.b = Float64(6) OR t1.a = Float64(3) AS c7, t1.b = Float64(7) OR t1.a = Float64(3) AS c8, CASE WHEN __common_expr_3 THEN Int64(0) ELSE Int64(1) END AS c9, CASE WHEN __common_expr_3 THEN Int64(0) ELSE Int64(2) END AS c10, CASE WHEN t1.b = Float64(8) THEN t1.a + Float64(1) ELSE Float64(0) END AS c11, CASE WHEN t1.b = Float64(9) THEN t1.a + Float64(1) ELSE Float64(0) END AS c12, CASE WHEN t1.b = Float64(10) THEN Float64(0) ELSE t1.a + Float64(2) END AS c13, CASE WHEN t1.b = Float64(11) THEN Float64(0) ELSE t1.a + Float64(2) END AS c14 +02)--Projection: t1.a = Float64(0) AS __common_expr_1, t1.a = Float64(2) AS __common_expr_2, t1.a = Float64(4) AS __common_expr_3, t1.a, t1.b +03)----TableScan: t1 projection=[a, b] physical_plan -01)ProjectionExec: expr=[a@0 = 0 AND b@1 = 0 as c1, a@0 = 0 AND b@1 = 1 as c2, b@1 = 2 AND a@0 = 1 as c3, b@1 = 3 AND a@0 = 1 as c4, a@0 = 2 OR b@1 = 4 as c5, a@0 = 2 OR b@1 = 5 as c6, b@1 = 6 OR a@0 = 3 as c7, b@1 = 7 OR a@0 = 3 as c8, CASE WHEN a@0 = 4 THEN 0 ELSE 1 END as c9, CASE WHEN a@0 = 4 THEN 0 ELSE 2 END as c10, CASE WHEN b@1 = 8 THEN a@0 + 1 ELSE 0 END as c11, CASE WHEN b@1 = 9 THEN a@0 + 1 ELSE 0 END as c12, CASE WHEN b@1 = 10 THEN 0 ELSE a@0 + 2 END as c13, CASE WHEN b@1 = 11 THEN 0 ELSE a@0 + 2 END as c14] -02)--MemoryExec: partitions=1, partition_sizes=[0] +01)ProjectionExec: expr=[__common_expr_1@0 AND b@4 = 0 as c1, __common_expr_1@0 AND b@4 = 1 as c2, b@4 = 2 AND a@3 = 1 as c3, b@4 = 3 AND a@3 = 1 as c4, __common_expr_2@1 OR b@4 = 4 as c5, __common_expr_2@1 OR b@4 = 5 as c6, b@4 = 6 OR a@3 = 3 as c7, b@4 = 7 OR a@3 = 3 as c8, CASE WHEN __common_expr_3@2 THEN 0 ELSE 1 END as c9, CASE WHEN __common_expr_3@2 THEN 0 ELSE 2 END as c10, CASE WHEN b@4 = 8 THEN a@3 + 1 ELSE 0 END as c11, CASE WHEN b@4 = 9 THEN a@3 + 1 ELSE 0 END as c12, CASE WHEN b@4 = 10 THEN 0 ELSE a@3 + 2 END as c13, CASE WHEN b@4 = 11 THEN 0 ELSE a@3 + 2 END as c14] +02)--ProjectionExec: expr=[a@0 = 0 as __common_expr_1, a@0 = 2 as __common_expr_2, a@0 = 4 as __common_expr_3, a@0 as a, b@1 as b] +03)----MemoryExec: partitions=1, partition_sizes=[0] # Common children of volatile, short-circuit expression -# TODO: consider surely executed children of "short circuited"s for CSE. i.e. 
`a = 0`, `a = 2`, `a = 4` should be extracted query TT EXPLAIN SELECT a = 0 AND b = random() AS c1, @@ -148,11 +150,13 @@ EXPLAIN SELECT FROM t1 ---- logical_plan -01)Projection: t1.a = Float64(0) AND t1.b = random() AS c1, t1.a = Float64(0) AND t1.b = Float64(1) + random() AS c2, t1.b = Float64(2) + random() AND t1.a = Float64(1) AS c3, t1.b = Float64(3) + random() AND t1.a = Float64(1) AS c4, t1.a = Float64(2) OR t1.b = Float64(4) + random() AS c5, t1.a = Float64(2) OR t1.b = Float64(5) + random() AS c6, t1.b = Float64(6) + random() OR t1.a = Float64(3) AS c7, t1.b = Float64(7) + random() OR t1.a = Float64(3) AS c8, CASE WHEN t1.a = Float64(4) THEN random() ELSE Float64(1) END AS c9, CASE WHEN t1.a = Float64(4) THEN random() ELSE Float64(2) END AS c10, CASE WHEN t1.b = Float64(8) + random() THEN t1.a + Float64(1) ELSE Float64(0) END AS c11, CASE WHEN t1.b = Float64(9) + random() THEN t1.a + Float64(1) ELSE Float64(0) END AS c12, CASE WHEN t1.b = Float64(10) + random() THEN Float64(0) ELSE t1.a + Float64(2) END AS c13, CASE WHEN t1.b = Float64(11) + random() THEN Float64(0) ELSE t1.a + Float64(2) END AS c14 -02)--TableScan: t1 projection=[a, b] +01)Projection: __common_expr_1 AND t1.b = random() AS c1, __common_expr_1 AND t1.b = Float64(1) + random() AS c2, t1.b = Float64(2) + random() AND t1.a = Float64(1) AS c3, t1.b = Float64(3) + random() AND t1.a = Float64(1) AS c4, __common_expr_2 OR t1.b = Float64(4) + random() AS c5, __common_expr_2 OR t1.b = Float64(5) + random() AS c6, t1.b = Float64(6) + random() OR t1.a = Float64(3) AS c7, t1.b = Float64(7) + random() OR t1.a = Float64(3) AS c8, CASE WHEN __common_expr_3 THEN random() ELSE Float64(1) END AS c9, CASE WHEN __common_expr_3 THEN random() ELSE Float64(2) END AS c10, CASE WHEN t1.b = Float64(8) + random() THEN t1.a + Float64(1) ELSE Float64(0) END AS c11, CASE WHEN t1.b = Float64(9) + random() THEN t1.a + Float64(1) ELSE Float64(0) END AS c12, CASE WHEN t1.b = Float64(10) + random() THEN Float64(0) ELSE t1.a + Float64(2) END AS c13, CASE WHEN t1.b = Float64(11) + random() THEN Float64(0) ELSE t1.a + Float64(2) END AS c14 +02)--Projection: t1.a = Float64(0) AS __common_expr_1, t1.a = Float64(2) AS __common_expr_2, t1.a = Float64(4) AS __common_expr_3, t1.a, t1.b +03)----TableScan: t1 projection=[a, b] physical_plan -01)ProjectionExec: expr=[a@0 = 0 AND b@1 = random() as c1, a@0 = 0 AND b@1 = 1 + random() as c2, b@1 = 2 + random() AND a@0 = 1 as c3, b@1 = 3 + random() AND a@0 = 1 as c4, a@0 = 2 OR b@1 = 4 + random() as c5, a@0 = 2 OR b@1 = 5 + random() as c6, b@1 = 6 + random() OR a@0 = 3 as c7, b@1 = 7 + random() OR a@0 = 3 as c8, CASE WHEN a@0 = 4 THEN random() ELSE 1 END as c9, CASE WHEN a@0 = 4 THEN random() ELSE 2 END as c10, CASE WHEN b@1 = 8 + random() THEN a@0 + 1 ELSE 0 END as c11, CASE WHEN b@1 = 9 + random() THEN a@0 + 1 ELSE 0 END as c12, CASE WHEN b@1 = 10 + random() THEN 0 ELSE a@0 + 2 END as c13, CASE WHEN b@1 = 11 + random() THEN 0 ELSE a@0 + 2 END as c14] -02)--MemoryExec: partitions=1, partition_sizes=[0] +01)ProjectionExec: expr=[__common_expr_1@0 AND b@4 = random() as c1, __common_expr_1@0 AND b@4 = 1 + random() as c2, b@4 = 2 + random() AND a@3 = 1 as c3, b@4 = 3 + random() AND a@3 = 1 as c4, __common_expr_2@1 OR b@4 = 4 + random() as c5, __common_expr_2@1 OR b@4 = 5 + random() as c6, b@4 = 6 + random() OR a@3 = 3 as c7, b@4 = 7 + random() OR a@3 = 3 as c8, CASE WHEN __common_expr_3@2 THEN random() ELSE 1 END as c9, CASE WHEN __common_expr_3@2 THEN random() ELSE 2 END as c10, CASE WHEN b@4 = 8 + random() THEN a@3 
+ 1 ELSE 0 END as c11, CASE WHEN b@4 = 9 + random() THEN a@3 + 1 ELSE 0 END as c12, CASE WHEN b@4 = 10 + random() THEN 0 ELSE a@3 + 2 END as c13, CASE WHEN b@4 = 11 + random() THEN 0 ELSE a@3 + 2 END as c14] +02)--ProjectionExec: expr=[a@0 = 0 as __common_expr_1, a@0 = 2 as __common_expr_2, a@0 = 4 as __common_expr_3, a@0 as a, b@1 as b] +03)----MemoryExec: partitions=1, partition_sizes=[0] # Common volatile children of short-circuit expression query TT @@ -171,3 +175,59 @@ logical_plan physical_plan 01)ProjectionExec: expr=[a@0 = random() AND b@1 = 0 as c1, a@0 = random() AND b@1 = 1 as c2, a@0 = 2 + random() OR b@1 = 4 as c3, a@0 = 2 + random() OR b@1 = 5 as c4, CASE WHEN a@0 = 4 + random() THEN 0 ELSE 1 END as c5, CASE WHEN a@0 = 4 + random() THEN 0 ELSE 2 END as c6] 02)--MemoryExec: partitions=1, partition_sizes=[0] + +# Surely only once but also conditionally evaluated expressions +query TT +EXPLAIN SELECT + (a = 1 OR random() = 0) AND a = 1 AS c1, + (a = 2 AND random() = 0) OR a = 2 AS c2, + CASE WHEN a + 3 = 0 THEN a + 3 ELSE 0 END AS c3, + CASE WHEN a + 4 = 0 THEN 0 WHEN a + 4 THEN 0 ELSE 0 END AS c4, + CASE WHEN a + 5 = 0 THEN 0 WHEN random() = 0 THEN a + 5 ELSE 0 END AS c5, + CASE WHEN a + 6 = 0 THEN 0 ELSE a + 6 END AS c6 +FROM t1 +---- +logical_plan +01)Projection: (__common_expr_1 OR random() = Float64(0)) AND __common_expr_1 AS c1, __common_expr_2 AND random() = Float64(0) OR __common_expr_2 AS c2, CASE WHEN __common_expr_3 = Float64(0) THEN __common_expr_3 ELSE Float64(0) END AS c3, CASE WHEN __common_expr_4 = Float64(0) THEN Int64(0) WHEN CAST(__common_expr_4 AS Boolean) THEN Int64(0) ELSE Int64(0) END AS c4, CASE WHEN __common_expr_5 = Float64(0) THEN Float64(0) WHEN random() = Float64(0) THEN __common_expr_5 ELSE Float64(0) END AS c5, CASE WHEN __common_expr_6 = Float64(0) THEN Float64(0) ELSE __common_expr_6 END AS c6 +02)--Projection: t1.a = Float64(1) AS __common_expr_1, t1.a = Float64(2) AS __common_expr_2, t1.a + Float64(3) AS __common_expr_3, t1.a + Float64(4) AS __common_expr_4, t1.a + Float64(5) AS __common_expr_5, t1.a + Float64(6) AS __common_expr_6 +03)----TableScan: t1 projection=[a] +physical_plan +01)ProjectionExec: expr=[(__common_expr_1@0 OR random() = 0) AND __common_expr_1@0 as c1, __common_expr_2@1 AND random() = 0 OR __common_expr_2@1 as c2, CASE WHEN __common_expr_3@2 = 0 THEN __common_expr_3@2 ELSE 0 END as c3, CASE WHEN __common_expr_4@3 = 0 THEN 0 WHEN CAST(__common_expr_4@3 AS Boolean) THEN 0 ELSE 0 END as c4, CASE WHEN __common_expr_5@4 = 0 THEN 0 WHEN random() = 0 THEN __common_expr_5@4 ELSE 0 END as c5, CASE WHEN __common_expr_6@5 = 0 THEN 0 ELSE __common_expr_6@5 END as c6] +02)--ProjectionExec: expr=[a@0 = 1 as __common_expr_1, a@0 = 2 as __common_expr_2, a@0 + 3 as __common_expr_3, a@0 + 4 as __common_expr_4, a@0 + 5 as __common_expr_5, a@0 + 6 as __common_expr_6] +03)----MemoryExec: partitions=1, partition_sizes=[0] + +# Surely only once but also conditionally evaluated subexpressions +query TT +EXPLAIN SELECT + (a = 1 OR random() = 0) AND (a = 1 OR random() = 1) AS c1, + (a = 2 AND random() = 0) OR (a = 2 AND random() = 1) AS c2, + CASE WHEN a + 3 = 0 THEN a + 3 + random() ELSE 0 END AS c3, + CASE WHEN a + 4 = 0 THEN 0 ELSE a + 4 + random() END AS c4 +FROM t1 +---- +logical_plan +01)Projection: (__common_expr_1 OR random() = Float64(0)) AND (__common_expr_1 OR random() = Float64(1)) AS c1, __common_expr_2 AND random() = Float64(0) OR __common_expr_2 AND random() = Float64(1) AS c2, CASE WHEN __common_expr_3 = Float64(0) THEN 
__common_expr_3 + random() ELSE Float64(0) END AS c3, CASE WHEN __common_expr_4 = Float64(0) THEN Float64(0) ELSE __common_expr_4 + random() END AS c4 +02)--Projection: t1.a = Float64(1) AS __common_expr_1, t1.a = Float64(2) AS __common_expr_2, t1.a + Float64(3) AS __common_expr_3, t1.a + Float64(4) AS __common_expr_4 +03)----TableScan: t1 projection=[a] +physical_plan +01)ProjectionExec: expr=[(__common_expr_1@0 OR random() = 0) AND (__common_expr_1@0 OR random() = 1) as c1, __common_expr_2@1 AND random() = 0 OR __common_expr_2@1 AND random() = 1 as c2, CASE WHEN __common_expr_3@2 = 0 THEN __common_expr_3@2 + random() ELSE 0 END as c3, CASE WHEN __common_expr_4@3 = 0 THEN 0 ELSE __common_expr_4@3 + random() END as c4] +02)--ProjectionExec: expr=[a@0 = 1 as __common_expr_1, a@0 = 2 as __common_expr_2, a@0 + 3 as __common_expr_3, a@0 + 4 as __common_expr_4] +03)----MemoryExec: partitions=1, partition_sizes=[0] + +# Only conditionally evaluated expressions +query TT +EXPLAIN SELECT + (random() = 0 OR a = 1) AND a = 1 AS c1, + (random() = 0 AND a = 2) OR a = 2 AS c2, + CASE WHEN random() = 0 THEN a + 3 ELSE a + 3 END AS c3, + CASE WHEN random() = 0 THEN 0 WHEN a + 4 = 0 THEN a + 4 ELSE 0 END AS c4, + CASE WHEN random() = 0 THEN 0 WHEN a + 5 = 0 THEN 0 ELSE a + 5 END AS c5, + CASE WHEN random() = 0 THEN 0 WHEN random() = 0 THEN a + 6 ELSE a + 6 END AS c6 +FROM t1 +---- +logical_plan +01)Projection: (random() = Float64(0) OR t1.a = Float64(1)) AND t1.a = Float64(1) AS c1, random() = Float64(0) AND t1.a = Float64(2) OR t1.a = Float64(2) AS c2, CASE WHEN random() = Float64(0) THEN t1.a + Float64(3) ELSE t1.a + Float64(3) END AS c3, CASE WHEN random() = Float64(0) THEN Float64(0) WHEN t1.a + Float64(4) = Float64(0) THEN t1.a + Float64(4) ELSE Float64(0) END AS c4, CASE WHEN random() = Float64(0) THEN Float64(0) WHEN t1.a + Float64(5) = Float64(0) THEN Float64(0) ELSE t1.a + Float64(5) END AS c5, CASE WHEN random() = Float64(0) THEN Float64(0) WHEN random() = Float64(0) THEN t1.a + Float64(6) ELSE t1.a + Float64(6) END AS c6 +02)--TableScan: t1 projection=[a] +physical_plan +01)ProjectionExec: expr=[(random() = 0 OR a@0 = 1) AND a@0 = 1 as c1, random() = 0 AND a@0 = 2 OR a@0 = 2 as c2, CASE WHEN random() = 0 THEN a@0 + 3 ELSE a@0 + 3 END as c3, CASE WHEN random() = 0 THEN 0 WHEN a@0 + 4 = 0 THEN a@0 + 4 ELSE 0 END as c4, CASE WHEN random() = 0 THEN 0 WHEN a@0 + 5 = 0 THEN 0 ELSE a@0 + 5 END as c5, CASE WHEN random() = 0 THEN 0 WHEN random() = 0 THEN a@0 + 6 ELSE a@0 + 6 END as c6] +02)--MemoryExec: partitions=1, partition_sizes=[0] diff --git a/datafusion/sqllogictest/test_files/select.slt b/datafusion/sqllogictest/test_files/select.slt index f9baf8db69d5b..95f67245a981e 100644 --- a/datafusion/sqllogictest/test_files/select.slt +++ b/datafusion/sqllogictest/test_files/select.slt @@ -1504,21 +1504,25 @@ query TT EXPLAIN SELECT y > 0 and 1 / y < 1, x > 0 and y > 0 and 1 / y < 1 / x from t; ---- logical_plan -01)Projection: t.y > Int32(0) AND Int64(1) / CAST(t.y AS Int64) < Int64(1) AS t.y > Int64(0) AND Int64(1) / t.y < Int64(1), t.x > Int32(0) AND t.y > Int32(0) AND Int64(1) / CAST(t.y AS Int64) < Int64(1) / CAST(t.x AS Int64) AS t.x > Int64(0) AND t.y > Int64(0) AND Int64(1) / t.y < Int64(1) / t.x -02)--TableScan: t projection=[x, y] +01)Projection: __common_expr_1 AND Int64(1) / CAST(t.y AS Int64) < Int64(1) AS t.y > Int64(0) AND Int64(1) / t.y < Int64(1), t.x > Int32(0) AND __common_expr_1 AND Int64(1) / CAST(t.y AS Int64) < Int64(1) / CAST(t.x AS Int64) AS t.x > Int64(0) AND t.y > Int64(0) AND 
Int64(1) / t.y < Int64(1) / t.x +02)--Projection: t.y > Int32(0) AS __common_expr_1, t.x, t.y +03)----TableScan: t projection=[x, y] physical_plan -01)ProjectionExec: expr=[y@1 > 0 AND 1 / CAST(y@1 AS Int64) < 1 as t.y > Int64(0) AND Int64(1) / t.y < Int64(1), x@0 > 0 AND y@1 > 0 AND 1 / CAST(y@1 AS Int64) < 1 / CAST(x@0 AS Int64) as t.x > Int64(0) AND t.y > Int64(0) AND Int64(1) / t.y < Int64(1) / t.x] -02)--MemoryExec: partitions=1, partition_sizes=[1] +01)ProjectionExec: expr=[__common_expr_1@0 AND 1 / CAST(y@2 AS Int64) < 1 as t.y > Int64(0) AND Int64(1) / t.y < Int64(1), x@1 > 0 AND __common_expr_1@0 AND 1 / CAST(y@2 AS Int64) < 1 / CAST(x@1 AS Int64) as t.x > Int64(0) AND t.y > Int64(0) AND Int64(1) / t.y < Int64(1) / t.x] +02)--ProjectionExec: expr=[y@1 > 0 as __common_expr_1, x@0 as x, y@1 as y] +03)----MemoryExec: partitions=1, partition_sizes=[1] query TT EXPLAIN SELECT y = 0 or 1 / y < 1, x = 0 or y = 0 or 1 / y < 1 / x from t; ---- logical_plan -01)Projection: t.y = Int32(0) OR Int64(1) / CAST(t.y AS Int64) < Int64(1) AS t.y = Int64(0) OR Int64(1) / t.y < Int64(1), t.x = Int32(0) OR t.y = Int32(0) OR Int64(1) / CAST(t.y AS Int64) < Int64(1) / CAST(t.x AS Int64) AS t.x = Int64(0) OR t.y = Int64(0) OR Int64(1) / t.y < Int64(1) / t.x -02)--TableScan: t projection=[x, y] +01)Projection: __common_expr_1 OR Int64(1) / CAST(t.y AS Int64) < Int64(1) AS t.y = Int64(0) OR Int64(1) / t.y < Int64(1), t.x = Int32(0) OR __common_expr_1 OR Int64(1) / CAST(t.y AS Int64) < Int64(1) / CAST(t.x AS Int64) AS t.x = Int64(0) OR t.y = Int64(0) OR Int64(1) / t.y < Int64(1) / t.x +02)--Projection: t.y = Int32(0) AS __common_expr_1, t.x, t.y +03)----TableScan: t projection=[x, y] physical_plan -01)ProjectionExec: expr=[y@1 = 0 OR 1 / CAST(y@1 AS Int64) < 1 as t.y = Int64(0) OR Int64(1) / t.y < Int64(1), x@0 = 0 OR y@1 = 0 OR 1 / CAST(y@1 AS Int64) < 1 / CAST(x@0 AS Int64) as t.x = Int64(0) OR t.y = Int64(0) OR Int64(1) / t.y < Int64(1) / t.x] -02)--MemoryExec: partitions=1, partition_sizes=[1] +01)ProjectionExec: expr=[__common_expr_1@0 OR 1 / CAST(y@2 AS Int64) < 1 as t.y = Int64(0) OR Int64(1) / t.y < Int64(1), x@1 = 0 OR __common_expr_1@0 OR 1 / CAST(y@2 AS Int64) < 1 / CAST(x@1 AS Int64) as t.x = Int64(0) OR t.y = Int64(0) OR Int64(1) / t.y < Int64(1) / t.x] +02)--ProjectionExec: expr=[y@1 = 0 as __common_expr_1, x@0 as x, y@1 as y] +03)----MemoryExec: partitions=1, partition_sizes=[1] # due to the reason describe in https://github.com/apache/datafusion/issues/8927, # the following queries will fail diff --git a/datafusion/sqllogictest/test_files/tpch/q14.slt.part b/datafusion/sqllogictest/test_files/tpch/q14.slt.part index e56e463a617d7..3743c201ff2e5 100644 --- a/datafusion/sqllogictest/test_files/tpch/q14.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/q14.slt.part @@ -32,9 +32,9 @@ where and l_shipdate < date '1995-10-01'; ---- logical_plan -01)Projection: Float64(100) * CAST(sum(CASE WHEN part.p_type LIKE Utf8("PROMO%") THEN lineitem.l_extendedprice * Int64(1) - lineitem.l_discount ELSE Int64(0) END) AS Float64) / CAST(sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS Float64) AS promo_revenue -02)--Aggregate: groupBy=[[]], aggr=[[sum(CASE WHEN part.p_type LIKE Utf8("PROMO%") THEN lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount) ELSE Decimal128(Some(0),38,4) END) AS sum(CASE WHEN part.p_type LIKE Utf8("PROMO%") THEN lineitem.l_extendedprice * Int64(1) - lineitem.l_discount ELSE Int64(0) END), sum(lineitem.l_extendedprice * 
(Decimal128(Some(1),20,0) - lineitem.l_discount)) AS sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)]] -03)----Projection: lineitem.l_extendedprice, lineitem.l_discount, part.p_type +01)Projection: Float64(100) * CAST(sum(CASE WHEN part.p_type LIKE Utf8("PROMO%") THEN lineitem.l_extendedprice * Int64(1) - lineitem.l_discount ELSE Int64(0) END) AS Float64) / CAST(sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS Float64) AS promo_revenue +02)--Aggregate: groupBy=[[]], aggr=[[sum(CASE WHEN part.p_type LIKE Utf8("PROMO%") THEN __common_expr_1 ELSE Decimal128(Some(0),38,4) END) AS sum(CASE WHEN part.p_type LIKE Utf8("PROMO%") THEN lineitem.l_extendedprice * Int64(1) - lineitem.l_discount ELSE Int64(0) END), sum(__common_expr_1) AS sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)]] +03)----Projection: lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount) AS __common_expr_1, part.p_type 04)------Inner Join: lineitem.l_partkey = part.p_partkey 05)--------Projection: lineitem.l_partkey, lineitem.l_extendedprice, lineitem.l_discount 06)----------Filter: lineitem.l_shipdate >= Date32("1995-09-01") AND lineitem.l_shipdate < Date32("1995-10-01") @@ -44,19 +44,20 @@ physical_plan 01)ProjectionExec: expr=[100 * CAST(sum(CASE WHEN part.p_type LIKE Utf8("PROMO%") THEN lineitem.l_extendedprice * Int64(1) - lineitem.l_discount ELSE Int64(0) END)@0 AS Float64) / CAST(sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)@1 AS Float64) as promo_revenue] 02)--AggregateExec: mode=Final, gby=[], aggr=[sum(CASE WHEN part.p_type LIKE Utf8("PROMO%") THEN lineitem.l_extendedprice * Int64(1) - lineitem.l_discount ELSE Int64(0) END), sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)] 03)----CoalescePartitionsExec -04)------AggregateExec: mode=Partial, gby=[], aggr=[sum(CASE WHEN part.p_type LIKE Utf8("PROMO%") THEN lineitem.l_extendedprice * Int64(1) - lineitem.l_discount ELSE Int64(0) END), sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)] -05)--------CoalesceBatchesExec: target_batch_size=8192 -06)----------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(l_partkey@0, p_partkey@0)], projection=[l_extendedprice@1, l_discount@2, p_type@4] -07)------------CoalesceBatchesExec: target_batch_size=8192 -08)--------------RepartitionExec: partitioning=Hash([l_partkey@0], 4), input_partitions=4 -09)----------------ProjectionExec: expr=[l_partkey@0 as l_partkey, l_extendedprice@1 as l_extendedprice, l_discount@2 as l_discount] -10)------------------CoalesceBatchesExec: target_batch_size=8192 -11)--------------------FilterExec: l_shipdate@3 >= 1995-09-01 AND l_shipdate@3 < 1995-10-01 -12)----------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]}, projection=[l_partkey, l_extendedprice, l_discount, l_shipdate], has_header=false -13)------------CoalesceBatchesExec: target_batch_size=8192 -14)--------------RepartitionExec: partitioning=Hash([p_partkey@0], 4), input_partitions=4 -15)----------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -16)------------------CsvExec: file_groups={1 group: 
[[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/part.tbl]]}, projection=[p_partkey, p_type], has_header=false +04)------AggregateExec: mode=Partial, gby=[], aggr=[sum(CASE WHEN part.p_type LIKE Utf8("PROMO%") THEN lineitem.l_extendedprice * Int64(1) - lineitem.l_discount ELSE Int64(0) END), sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)] +05)--------ProjectionExec: expr=[l_extendedprice@0 * (Some(1),20,0 - l_discount@1) as __common_expr_1, p_type@2 as p_type] +06)----------CoalesceBatchesExec: target_batch_size=8192 +07)------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(l_partkey@0, p_partkey@0)], projection=[l_extendedprice@1, l_discount@2, p_type@4] +08)--------------CoalesceBatchesExec: target_batch_size=8192 +09)----------------RepartitionExec: partitioning=Hash([l_partkey@0], 4), input_partitions=4 +10)------------------ProjectionExec: expr=[l_partkey@0 as l_partkey, l_extendedprice@1 as l_extendedprice, l_discount@2 as l_discount] +11)--------------------CoalesceBatchesExec: target_batch_size=8192 +12)----------------------FilterExec: l_shipdate@3 >= 1995-09-01 AND l_shipdate@3 < 1995-10-01 +13)------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]}, projection=[l_partkey, l_extendedprice, l_discount, l_shipdate], has_header=false +14)--------------CoalesceBatchesExec: target_batch_size=8192 +15)----------------RepartitionExec: partitioning=Hash([p_partkey@0], 4), input_partitions=4 +16)------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +17)--------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/part.tbl]]}, projection=[p_partkey, p_type], has_header=false From 1dfac86a89750193491cf3e04917e37b92c64ffa Mon Sep 17 00:00:00 2001 From: wiedld Date: Fri, 12 Jul 2024 04:04:42 -0700 Subject: [PATCH 21/59] fix(11397): surface proper errors in ParquetSink (#11399) * fix(11397): do not surface errors for closed channels, and instead let the task join errors be surfaced * fix(11397): terminate early on channel send failure --- .../src/datasource/file_format/parquet.rs | 32 +++++++++---------- datafusion/core/tests/memory_limit/mod.rs | 4 +-- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/datafusion/core/src/datasource/file_format/parquet.rs b/datafusion/core/src/datasource/file_format/parquet.rs index 694c949285374..6271d8af37862 100644 --- a/datafusion/core/src/datasource/file_format/parquet.rs +++ b/datafusion/core/src/datasource/file_format/parquet.rs @@ -893,12 +893,12 @@ async fn send_arrays_to_col_writers( let mut next_channel = 0; for (array, field) in rb.columns().iter().zip(schema.fields()) { for c in compute_leaves(field, array)? { - col_array_channels[next_channel] - .send(c) - .await - .map_err(|_| { - DataFusionError::Internal("Unable to send array to writer!".into()) - })?; + // Do not surface error from closed channel (means something + // else hit an error, and the plan is shutting down). 
+            if col_array_channels[next_channel].send(c).await.is_err() {
+                return Ok(());
+            }
+
             next_channel += 1;
         }
     }
@@ -984,11 +984,11 @@ fn spawn_parquet_parallel_serialization_task(
                     &pool,
                 );
 
-                serialize_tx.send(finalize_rg_task).await.map_err(|_| {
-                    DataFusionError::Internal(
-                        "Unable to send closed RG to concat task!".into(),
-                    )
-                })?;
+                // Do not surface error from closed channel (means something
+                // else hit an error, and the plan is shutting down).
+                if serialize_tx.send(finalize_rg_task).await.is_err() {
+                    return Ok(());
+                }
 
                 current_rg_rows = 0;
                 rb = rb.slice(rows_left, rb.num_rows() - rows_left);
@@ -1013,11 +1013,11 @@ fn spawn_parquet_parallel_serialization_task(
                 &pool,
             );
 
-            serialize_tx.send(finalize_rg_task).await.map_err(|_| {
-                DataFusionError::Internal(
-                    "Unable to send closed RG to concat task!".into(),
-                )
-            })?;
+            // Do not surface error from closed channel (means something
+            // else hit an error, and the plan is shutting down).
+            if serialize_tx.send(finalize_rg_task).await.is_err() {
+                return Ok(());
+            }
         }
 
         Ok(())
diff --git a/datafusion/core/tests/memory_limit/mod.rs b/datafusion/core/tests/memory_limit/mod.rs
index f7402357d1c76..7ef24609e238d 100644
--- a/datafusion/core/tests/memory_limit/mod.rs
+++ b/datafusion/core/tests/memory_limit/mod.rs
@@ -340,8 +340,8 @@ async fn oom_parquet_sink() {
         path.to_string_lossy()
     ))
     .with_expected_errors(vec![
-        // TODO: update error handling in ParquetSink
-        "Unable to send array to writer!",
+        "Failed to allocate additional",
+        "for ParquetSink(ArrowColumnWriter)",
    ])
    .with_memory_limit(200_000)
    .run()

From 13ddbaf2f7220c26f443d097697d1380e63f6206 Mon Sep 17 00:00:00 2001
From: Andrew Lamb
Date: Fri, 12 Jul 2024 10:53:58 -0400
Subject: [PATCH 22/59] Minor: Add note about SQLancer fuzz testing to docs
 (#11430)

* Minor: Add note about SQLancer fuzz testing to docs

* prettier

---
 docs/source/contributor-guide/testing.md | 14 +++++++++++++-
 1 file changed, 13 insertions(+), 1 deletion(-)

diff --git a/docs/source/contributor-guide/testing.md b/docs/source/contributor-guide/testing.md
index 018cc6233c461..0f4461ab2c2c6 100644
--- a/docs/source/contributor-guide/testing.md
+++ b/docs/source/contributor-guide/testing.md
@@ -39,7 +39,7 @@ DataFusion's SQL implementation is tested using [sqllogictest](https://github.co
 Like similar systems such as [DuckDB](https://duckdb.org/dev/testing), DataFusion has chosen to trade off a slightly higher barrier to contribution for longer term maintainability.
 
-### Rust Integration Tests
+## Rust Integration Tests
 
 There are several tests of the public interface of the DataFusion library in the [tests](https://github.com/apache/datafusion/tree/main/datafusion/core/tests) directory.
 
@@ -49,6 +49,18 @@ You can run these tests individually using `cargo` as normal command such as
 cargo test -p datafusion --test parquet_exec
 ```
 
+## SQL "Fuzz" testing
+
+DataFusion uses [SQLancer] for "fuzz" testing: it generates random SQL
+queries and executes them against DataFusion to find bugs.
+
+The code is in the [datafusion-sqllancer] repository, and we welcome further
+contributions. Kudos to [@2010YOUY01] for the initial implementation.
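To make the idea concrete, below is a minimal sketch of a generate-execute-check fuzz loop against DataFusion. It is not how SQLancer works internally; the query shapes, the fixed seed, and the `fuzz_arithmetic` name are illustrative assumptions, and real fuzzers add result-comparison oracles on top of this loop.

```rust
use datafusion::prelude::*;
use rand::{rngs::StdRng, Rng, SeedableRng};

/// A toy fuzz loop: run random arithmetic queries and assert they succeed.
/// (Illustrative sketch only; SQLancer's generators are far more elaborate.)
async fn fuzz_arithmetic() -> datafusion::error::Result<()> {
    let ctx = SessionContext::new();
    let mut rng = StdRng::seed_from_u64(42); // fixed seed => reproducible failures
    for _ in 0..100 {
        let a: i64 = rng.gen_range(-100..100);
        let b: i64 = rng.gen_range(1..100); // avoid division by zero in this sketch
        let op = ["+", "-", "*", "/"][rng.gen_range(0..4)];
        let sql = format!("SELECT {a} {op} {b}");
        // A bug surfaces as a panic or an unexpected plan/execution error.
        let batches = ctx.sql(&sql).await?.collect().await?;
        assert_eq!(batches[0].num_rows(), 1, "query {sql} lost its row");
    }
    Ok(())
}
```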
+ +[sqlancer]: https://github.com/sqlancer/sqlancer +[datafusion-sqllancer]: https://github.com/datafusion-contrib/datafusion-sqllancer +[@2010youy01]: https://github.com/2010YOUY01 + ## Documentation Examples We use Rust [doctest] to verify examples from the documentation are correct and From c769a70dc1c746460b4c1369d4e42c4a78da9571 Mon Sep 17 00:00:00 2001 From: tmi Date: Fri, 12 Jul 2024 17:52:24 +0200 Subject: [PATCH 23/59] Trivial: use arrow csv writer's timestamp_tz_format (#11407) --- datafusion/common/src/file_options/csv_writer.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/datafusion/common/src/file_options/csv_writer.rs b/datafusion/common/src/file_options/csv_writer.rs index 5792cfdba9e0c..ae069079a68f8 100644 --- a/datafusion/common/src/file_options/csv_writer.rs +++ b/datafusion/common/src/file_options/csv_writer.rs @@ -63,6 +63,9 @@ impl TryFrom<&CsvOptions> for CsvWriterOptions { if let Some(v) = &value.timestamp_format { builder = builder.with_timestamp_format(v.into()) } + if let Some(v) = &value.timestamp_tz_format { + builder = builder.with_timestamp_tz_format(v.into()) + } if let Some(v) = &value.time_format { builder = builder.with_time_format(v.into()) } From a2a6458e420209c7125b08966c5726b5fd104195 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Fri, 12 Jul 2024 11:53:03 -0400 Subject: [PATCH 24/59] Minor: improve documentation for sql unparsing (#11395) --- datafusion/sql/src/lib.rs | 6 ++- datafusion/sql/src/unparser/expr.rs | 29 +++++++++---- datafusion/sql/src/unparser/mod.rs | 64 +++++++++++++++++++++++++++-- datafusion/sql/src/unparser/plan.rs | 24 ++++++++--- 4 files changed, 105 insertions(+), 18 deletions(-) diff --git a/datafusion/sql/src/lib.rs b/datafusion/sql/src/lib.rs index eb5fec7a3c8bb..f53cab5df8482 100644 --- a/datafusion/sql/src/lib.rs +++ b/datafusion/sql/src/lib.rs @@ -17,7 +17,7 @@ // Make cheap clones clear: https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] -//! This module provides: +//! This crate provides: //! //! 1. A SQL parser, [`DFParser`], that translates SQL query text into //! an abstract syntax tree (AST), [`Statement`]. @@ -25,10 +25,14 @@ //! 2. A SQL query planner [`SqlToRel`] that creates [`LogicalPlan`]s //! from [`Statement`]s. //! +//! 3. A SQL [`unparser`] that converts [`Expr`]s and [`LogicalPlan`]s +//! into SQL query text. +//! //! [`DFParser`]: parser::DFParser //! [`Statement`]: parser::Statement //! [`SqlToRel`]: planner::SqlToRel //! [`LogicalPlan`]: datafusion_expr::logical_plan::LogicalPlan +//! [`Expr`]: datafusion_expr::expr::Expr mod cte; mod expr; diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index e0d05c400cb09..eb149c819c8b0 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -72,21 +72,34 @@ impl Display for Unparsed { } } -/// Convert a DataFusion [`Expr`] to `sqlparser::ast::Expr` +/// Convert a DataFusion [`Expr`] to [`ast::Expr`] /// -/// This function is the opposite of `SqlToRel::sql_to_expr` and can -/// be used to, among other things, convert [`Expr`]s to strings. -/// Throws an error if [`Expr`] can not be represented by an `sqlparser::ast::Expr` +/// This function is the opposite of [`SqlToRel::sql_to_expr`] and can be used +/// to, among other things, convert [`Expr`]s to SQL strings. Such strings could +/// be used to pass filters or other expressions to another SQL engine. 
+///
+/// # Errors
+///
+/// Throws an error if [`Expr`] cannot be represented by an [`ast::Expr`]
+///
+/// # See Also
+///
+/// * [`Unparser`] for more control over the conversion to SQL
+/// * [`plan_to_sql`] for converting a [`LogicalPlan`] to SQL
 ///
 /// # Example
 /// ```
 /// use datafusion_expr::{col, lit};
 /// use datafusion_sql::unparser::expr_to_sql;
-/// let expr = col("a").gt(lit(4));
-/// let sql = expr_to_sql(&expr).unwrap();
-///
-/// assert_eq!(format!("{}", sql), "(a > 4)")
+/// let expr = col("a").gt(lit(4)); // form an expression `a > 4`
+/// let sql = expr_to_sql(&expr).unwrap(); // convert to ast::Expr
+/// // use the Display impl to convert to SQL text
+/// assert_eq!(sql.to_string(), "(a > 4)")
 /// ```
+///
+/// [`SqlToRel::sql_to_expr`]: crate::planner::SqlToRel::sql_to_expr
+/// [`plan_to_sql`]: crate::unparser::plan_to_sql
+/// [`LogicalPlan`]: datafusion_expr::logical_plan::LogicalPlan
 pub fn expr_to_sql(expr: &Expr) -> Result {
     let unparser = Unparser::default();
     unparser.expr_to_sql(expr)
diff --git a/datafusion/sql/src/unparser/mod.rs b/datafusion/sql/src/unparser/mod.rs
index e5ffbc8a212ab..83ae64ba238b0 100644
--- a/datafusion/sql/src/unparser/mod.rs
+++ b/datafusion/sql/src/unparser/mod.rs
@@ -15,6 +15,8 @@
 // specific language governing permissions and limitations
 // under the License.
 
+//! [`Unparser`] for converting `Expr` to SQL text
+
 mod ast;
 mod expr;
 mod plan;
@@ -27,6 +29,29 @@ pub use plan::plan_to_sql;
 use self::dialect::{DefaultDialect, Dialect};
 pub mod dialect;
 
+/// Convert a DataFusion [`Expr`] to [`sqlparser::ast::Expr`]
+///
+/// See [`expr_to_sql`] for background. `Unparser` allows greater control of
+/// the conversion, but with a more complicated API.
+///
+/// To get more human-readable output, see [`Self::with_pretty`]
+///
+/// # Example
+/// ```
+/// use datafusion_expr::{col, lit};
+/// use datafusion_sql::unparser::Unparser;
+/// let expr = col("a").gt(lit(4)); // form an expression `a > 4`
+/// let unparser = Unparser::default();
+/// let sql = unparser.expr_to_sql(&expr).unwrap(); // convert to AST
+/// // use the Display impl to convert to SQL text
+/// assert_eq!(sql.to_string(), "(a > 4)");
+/// // now convert to pretty sql
+/// let unparser = unparser.with_pretty(true);
+/// let sql = unparser.expr_to_sql(&expr).unwrap();
+/// assert_eq!(sql.to_string(), "a > 4"); // note lack of parenthesis
+/// ```
+///
+/// [`Expr`]: datafusion_expr::Expr
 pub struct Unparser<'a> {
     dialect: &'a dyn Dialect,
     pretty: bool,
@@ -40,9 +65,42 @@ impl<'a> Unparser<'a> {
         }
     }
 
-    /// Allow unparser to remove parenthesis according to the precedence rules of DataFusion.
-    /// This might make it invalid SQL for other SQL query engines with different precedence
-    /// rules, even if its valid for DataFusion.
+    /// Create pretty SQL output, better suited for human consumption
+    ///
+    /// See example on the struct level documentation
+    ///
+    /// # Pretty Output
+    ///
+    /// By default, `Unparser` generates SQL text that will parse back to the
+    /// same parsed [`Expr`], which is useful for creating machine readable
+    /// expressions to send to other systems. However, the resulting expressions are
+    /// not always nice to read for humans.
+    ///
+    /// For example
+    ///
+    /// ```sql
+    /// ((a + 4) > 5)
+    /// ```
+    ///
+    /// This method removes parenthesis according to the precedence rules of
+    /// DataFusion.
If the output is reparsed, the resulting [`Expr`] produces
+    /// the same value as the original in DataFusion, but with a potentially
+    /// different order of operations.
+    ///
+    /// Note that this setting may create invalid SQL for other SQL query
+    /// engines with different precedence rules.
+    ///
+    /// # Example
+    /// ```
+    /// use datafusion_expr::{col, lit};
+    /// use datafusion_sql::unparser::Unparser;
+    /// let expr = col("a").gt(lit(4)).and(col("b").lt(lit(5))); // form an expression `a > 4 AND b < 5`
+    /// let unparser = Unparser::default().with_pretty(true);
+    /// let sql = unparser.expr_to_sql(&expr).unwrap();
+    /// assert_eq!(sql.to_string(), "a > 4 AND b < 5"); // note lack of parenthesis
+    /// ```
+    ///
+    /// [`Expr`]: datafusion_expr::Expr
     pub fn with_pretty(mut self, pretty: bool) -> Self {
         self.pretty = pretty;
         self
diff --git a/datafusion/sql/src/unparser/plan.rs b/datafusion/sql/src/unparser/plan.rs
index 15137403c582d..41a8c968841b3 100644
--- a/datafusion/sql/src/unparser/plan.rs
+++ b/datafusion/sql/src/unparser/plan.rs
@@ -33,10 +33,18 @@ use super::{
     Unparser,
 };
 
-/// Convert a DataFusion [`LogicalPlan`] to `sqlparser::ast::Statement`
+/// Convert a DataFusion [`LogicalPlan`] to [`ast::Statement`]
 ///
-/// This function is the opposite of `SqlToRel::sql_statement_to_plan` and can
-/// be used to, among other things, convert `LogicalPlan`s to strings.
+/// This function is the opposite of [`SqlToRel::sql_statement_to_plan`] and can
+/// be used to, among other things, convert `LogicalPlan`s to SQL strings.
+///
+/// # Errors
+///
+/// This function returns an error if the plan cannot be converted to SQL.
+///
+/// # See Also
+///
+/// * [`expr_to_sql`] for converting [`Expr`], a single expression, to SQL
 ///
 /// # Example
 /// ```
@@ -47,16 +55,20 @@ use super::{
 ///     Field::new("id", DataType::Utf8, false),
 ///     Field::new("value", DataType::Utf8, false),
 /// ]);
+/// // Scan 'table' and select columns 'id' and 'value'
 /// let plan = table_scan(Some("table"), &schema, None)
 ///     .unwrap()
 ///     .project(vec![col("id"), col("value")])
 ///     .unwrap()
 ///     .build()
 ///     .unwrap();
-/// let sql = plan_to_sql(&plan).unwrap();
-///
-/// assert_eq!(format!("{}", sql), "SELECT \"table\".id, \"table\".\"value\" FROM \"table\"")
+/// let sql = plan_to_sql(&plan).unwrap(); // convert to AST
+/// // use the Display impl to convert to SQL text
+/// assert_eq!(sql.to_string(), "SELECT \"table\".id, \"table\".\"value\" FROM \"table\"")
 /// ```
+///
+/// [`SqlToRel::sql_statement_to_plan`]: crate::planner::SqlToRel::sql_statement_to_plan
+/// [`expr_to_sql`]: crate::unparser::expr_to_sql
 pub fn plan_to_sql(plan: &LogicalPlan) -> Result {
     let unparser = Unparser::default();
     unparser.plan_to_sql(plan)

From dc21a6c25893e7906da588debf18a8e5918b3b32 Mon Sep 17 00:00:00 2001
From: Andrew Lamb
Date: Fri, 12 Jul 2024 11:53:44 -0400
Subject: [PATCH 25/59] Minor: Consolidate specification doc sections (#11427)

---
 docs/source/contributor-guide/index.md        | 16 ----------------
 .../contributor-guide/specification/index.rst | 10 ++++++++++
 2 files changed, 10 insertions(+), 16 deletions(-)

diff --git a/docs/source/contributor-guide/index.md b/docs/source/contributor-guide/index.md
index 891277f647570..ad49b614c3341 100644
--- a/docs/source/contributor-guide/index.md
+++ b/docs/source/contributor-guide/index.md
@@ -134,19 +134,3 @@ The good thing about open code and open development is that any issues in one ch
 Pull requests will be marked with a `stale` label after 60 days of inactivity and then closed 7
days after that. Commenting on the PR will remove the `stale` label.
-
-## Specifications
-
-We formalize some DataFusion semantics and behaviors through specification
-documents. These specifications are useful to be used as references to help
-resolve ambiguities during development or code reviews.
-
-You are also welcome to propose changes to existing specifications or create
-new specifications as you see fit.
-
-Here is the list current active specifications:
-
-- [Output field name semantic](https://datafusion.apache.org/contributor-guide/specification/output-field-name-semantic.html)
-- [Invariants](https://datafusion.apache.org/contributor-guide/specification/invariants.html)
-
-All specifications are stored in the `docs/source/specification` folder.
diff --git a/docs/source/contributor-guide/specification/index.rst b/docs/source/contributor-guide/specification/index.rst
index bcd5a895c4d24..a34f0b19e4dea 100644
--- a/docs/source/contributor-guide/specification/index.rst
+++ b/docs/source/contributor-guide/specification/index.rst
@@ -18,6 +18,16 @@
 Specifications
 ==============
 
+We formalize some DataFusion semantics and behaviors through specification
+documents. These specifications are useful as references to help
+resolve ambiguities during development or code reviews.
+
+You are also welcome to propose changes to existing specifications or create
+new specifications as you see fit. All specifications are stored in the
+`docs/source/specification` folder. Here is the list of current active
+specifications:
+
+
 .. toctree::
    :maxdepth: 1

From b075ac471e6d27dfe40b6586a72070a9ec4751a9 Mon Sep 17 00:00:00 2001
From: Andrew Lamb
Date: Fri, 12 Jul 2024 15:27:16 -0400
Subject: [PATCH 26/59] Minor: consolidate doc roadmap pages (#11426)

---
 .../contributor-guide/quarterly_roadmap.md    | 96 -------------------
 docs/source/contributor-guide/roadmap.md      | 81 ++++++++++++++++
 docs/source/index.rst                         |  1 -
 3 files changed, 81 insertions(+), 97 deletions(-)
 delete mode 100644 docs/source/contributor-guide/quarterly_roadmap.md

diff --git a/docs/source/contributor-guide/quarterly_roadmap.md b/docs/source/contributor-guide/quarterly_roadmap.md
deleted file mode 100644
index ee82617225aa6..0000000000000
--- a/docs/source/contributor-guide/quarterly_roadmap.md
+++ /dev/null
@@ -1,96 +0,0 @@
--
-# Quarterly Roadmap
-
-A quarterly roadmap will be published to give the DataFusion community visibility into the priorities of the projects contributors. This roadmap is not binding.
-
-## 2023 Q4
-
-- Improve data output (`COPY`, `INSERT` and DataFrame) output capability [#6569](https://github.com/apache/datafusion/issues/6569)
-- Implementation of `ARRAY` types and related functions [#6980](https://github.com/apache/datafusion/issues/6980)
-- Write an industrial paper about DataFusion for SIGMOD [#6782](https://github.com/apache/datafusion/issues/6782)
-
-## 2022 Q2
-
-### DataFusion Core
-
-- IO Improvements
-  - Reading, registering, and writing more file formats from both DataFrame API and SQL
-  - Additional options for IO including partitioning and metadata support
-- Work Scheduling
-  - Improve predictability, observability and performance of IO and CPU-bound work
-  - Develop a more explicit story for managing parallelism during plan execution
-- Memory Management
-  - Add more operators for memory limited execution
-- Performance
-  - Incorporate row-format into operators such as aggregate
-  - Add row-format benchmarks
-  - Explore JIT-compiling complex expressions
-  - Explore LLVM for JIT, with inline Rust functions as the primary goal
-  - Improve performance of Sort and Merge using Row Format / JIT expressions
-- Documentation
-  - General improvements to DataFusion website
-  - Publish design documents
-- Streaming
-  - Create `StreamProvider` trait
-
-### Ballista
-
-- Make production ready
-  - Shuffle file cleanup
-  - Fill functional gaps between DataFusion and Ballista
-  - Improve task scheduling and data exchange efficiency
-  - Better error handling
-    - Task failure
-    - Executor lost
-    - Schedule restart
-  - Improve monitoring and logging
-  - Auto scaling support
-- Support for multi-scheduler deployments. Initially for resiliency and fault tolerance but ultimately to support sharding for scalability and more efficient caching.
-- Executor deployment grouping based on resource allocation
-
-### Extensions ([datafusion-contrib](https://github.com/datafusion-contrib))
-
-#### [DataFusion-Python](https://github.com/datafusion-contrib/datafusion-python)
-
-- Add missing functionality to DataFrame and SessionContext
-- Improve documentation
-
-#### [DataFusion-S3](https://github.com/datafusion-contrib/datafusion-objectstore-s3)
-
-- Create Python bindings to use with datafusion-python
-
-#### [DataFusion-Tui](https://github.com/datafusion-contrib/datafusion-tui)
-
-- Create multiple SQL editors
-- Expose more Context and query metadata
-- Support new data sources
-  - BigTable, HDFS, HTTP APIs
-
-#### [DataFusion-BigTable](https://github.com/datafusion-contrib/datafusion-bigtable)
-
-- Python binding to use with datafusion-python
-- Timestamp range predicate pushdown
-- Multi-threaded partition aware execution
-- Production ready Rust SDK
-
-#### [DataFusion-Streams](https://github.com/datafusion-contrib/datafusion-streams)
-
-- Create experimental implementation of `StreamProvider` trait
diff --git a/docs/source/contributor-guide/roadmap.md b/docs/source/contributor-guide/roadmap.md
index a6d78d9311aa4..3d9c1ee371fe6 100644
--- a/docs/source/contributor-guide/roadmap.md
+++ b/docs/source/contributor-guide/roadmap.md
@@ -43,3 +43,84 @@ start a conversation using a github issue or the
 make review efficient and avoid surprises.
 
 [The current list of `EPIC`s can be found here](https://github.com/apache/datafusion/issues?q=is%3Aissue+is%3Aopen+epic).
+
+# Quarterly Roadmap
+
+A quarterly roadmap will be published to give the DataFusion community
+visibility into the priorities of the project's contributors.
This roadmap is not +binding and we would welcome any/all contributions to help keep this list up to +date. + +## 2023 Q4 + +- Improve data output (`COPY`, `INSERT` and DataFrame) output capability [#6569](https://github.com/apache/datafusion/issues/6569) +- Implementation of `ARRAY` types and related functions [#6980](https://github.com/apache/datafusion/issues/6980) +- Write an industrial paper about DataFusion for SIGMOD [#6782](https://github.com/apache/datafusion/issues/6782) + +## 2022 Q2 + +### DataFusion Core + +- IO Improvements + - Reading, registering, and writing more file formats from both DataFrame API and SQL + - Additional options for IO including partitioning and metadata support +- Work Scheduling + - Improve predictability, observability and performance of IO and CPU-bound work + - Develop a more explicit story for managing parallelism during plan execution +- Memory Management + - Add more operators for memory limited execution +- Performance + - Incorporate row-format into operators such as aggregate + - Add row-format benchmarks + - Explore JIT-compiling complex expressions + - Explore LLVM for JIT, with inline Rust functions as the primary goal + - Improve performance of Sort and Merge using Row Format / JIT expressions +- Documentation + - General improvements to DataFusion website + - Publish design documents +- Streaming + - Create `StreamProvider` trait + +### Ballista + +- Make production ready + - Shuffle file cleanup + - Fill functional gaps between DataFusion and Ballista + - Improve task scheduling and data exchange efficiency + - Better error handling + - Task failure + - Executor lost + - Schedule restart + - Improve monitoring and logging + - Auto scaling support +- Support for multi-scheduler deployments. Initially for resiliency and fault tolerance but ultimately to support sharding for scalability and more efficient caching. 
+- Executor deployment grouping based on resource allocation
+
+### Extensions ([datafusion-contrib](https://github.com/datafusion-contrib))
+
+### [DataFusion-Python](https://github.com/datafusion-contrib/datafusion-python)
+
+- Add missing functionality to DataFrame and SessionContext
+- Improve documentation
+
+### [DataFusion-S3](https://github.com/datafusion-contrib/datafusion-objectstore-s3)
+
+- Create Python bindings to use with datafusion-python
+
+### [DataFusion-Tui](https://github.com/datafusion-contrib/datafusion-tui)
+
+- Create multiple SQL editors
+- Expose more Context and query metadata
+- Support new data sources
+  - BigTable, HDFS, HTTP APIs
+
+### [DataFusion-BigTable](https://github.com/datafusion-contrib/datafusion-bigtable)
+
+- Python binding to use with datafusion-python
+- Timestamp range predicate pushdown
+- Multi-threaded partition aware execution
+- Production ready Rust SDK
+
+### [DataFusion-Streams](https://github.com/datafusion-contrib/datafusion-streams)
+
+- Create experimental implementation of `StreamProvider` trait
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 8fbff208f5617..ca6905c434f35 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -121,7 +121,6 @@ To get started, see
    contributor-guide/testing
    contributor-guide/howtos
    contributor-guide/roadmap
-   contributor-guide/quarterly_roadmap
    contributor-guide/governance
    contributor-guide/inviting
    contributor-guide/specification/index

From d5367f3ff5ed506e824a04c68120194deb68a908 Mon Sep 17 00:00:00 2001
From: Georgi Krastev
Date: Fri, 12 Jul 2024 22:34:35 +0300
Subject: [PATCH 27/59] Avoid calling shutdown after failed write of AsyncWrite
 (#249) (#250) (#11415)

in `serialize_rb_stream_to_object_store`
---
 .../file_format/write/orchestration.rs        | 27 +++++++++++--------
 1 file changed, 16 insertions(+), 11 deletions(-)

diff --git a/datafusion/core/src/datasource/file_format/write/orchestration.rs b/datafusion/core/src/datasource/file_format/write/orchestration.rs
index a62b5715aeb3b..8bd0dae9f5a48 100644
--- a/datafusion/core/src/datasource/file_format/write/orchestration.rs
+++ b/datafusion/core/src/datasource/file_format/write/orchestration.rs
@@ -42,15 +42,20 @@ use tokio::task::JoinSet;
 type WriterType = Box;
 type SerializerType = Arc;
 
-/// Serializes a single data stream in parallel and writes to an ObjectStore
-/// concurrently. Data order is preserved. In the event of an error,
-/// the ObjectStore writer is returned to the caller in addition to an error,
-/// so that the caller may handle aborting failed writes.
+/// Serializes a single data stream in parallel and writes to an ObjectStore concurrently.
+/// Data order is preserved.
+///
+/// In the event of a non-IO error which does not involve the ObjectStore writer,
+/// the writer is returned to the caller in addition to the error,
+/// so that failed writes may be aborted.
+///
+/// In the event of an IO error involving the ObjectStore writer,
+/// the writer is dropped to avoid calling further methods on it which might panic.
pub(crate) async fn serialize_rb_stream_to_object_store( mut data_rx: Receiver, serializer: Arc, mut writer: WriterType, -) -> std::result::Result<(WriterType, u64), (WriterType, DataFusionError)> { +) -> std::result::Result<(WriterType, u64), (Option, DataFusionError)> { let (tx, mut rx) = mpsc::channel::>>(100); let serialize_task = SpawnedTask::spawn(async move { @@ -82,7 +87,7 @@ pub(crate) async fn serialize_rb_stream_to_object_store( Ok(_) => (), Err(e) => { return Err(( - writer, + None, DataFusionError::Execution(format!( "Error writing to object store: {e}" )), @@ -93,12 +98,12 @@ pub(crate) async fn serialize_rb_stream_to_object_store( } Ok(Err(e)) => { // Return the writer along with the error - return Err((writer, e)); + return Err((Some(writer), e)); } Err(e) => { // Handle task panic or cancellation return Err(( - writer, + Some(writer), DataFusionError::Execution(format!( "Serialization task panicked or was cancelled: {e}" )), @@ -109,10 +114,10 @@ pub(crate) async fn serialize_rb_stream_to_object_store( match serialize_task.join().await { Ok(Ok(_)) => (), - Ok(Err(e)) => return Err((writer, e)), + Ok(Err(e)) => return Err((Some(writer), e)), Err(_) => { return Err(( - writer, + Some(writer), internal_datafusion_err!("Unknown error writing to object store"), )) } @@ -153,7 +158,7 @@ pub(crate) async fn stateless_serialize_and_write_files( row_count += cnt; } Err((writer, e)) => { - finished_writers.push(writer); + finished_writers.extend(writer); any_errors = true; triggering_error = Some(e); } From 02335ebe2dd36081e22ed2d8ab46287c6d950a5c Mon Sep 17 00:00:00 2001 From: kamille Date: Sat, 13 Jul 2024 03:50:22 +0800 Subject: [PATCH 28/59] Short term way to make `AggregateStatistics` still work when min/max is converted to udaf (#11261) * impl the short term solution. * add todos. 
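The diff that follows recognizes a MIN/MAX/COUNT aggregate in either of its two representations: the legacy concrete physical expression type, or the new UDAF-based wrapper identified by its registered name. A minimal sketch of that try-both-downcasts pattern, using simplified stand-in types (`Min` and `UdafExpr` are illustrative placeholders here, not the real DataFusion types):

```rust
use std::any::Any;

// Stand-ins for the two representations an aggregate may have.
trait AggregateExpr: Any {
    fn as_any(&self) -> &dyn Any;
}
struct Min; // legacy built-in expression type
struct UdafExpr {
    name: String, // UDAF-based expression, identified by registered name
}
impl AggregateExpr for Min {
    fn as_any(&self) -> &dyn Any {
        self
    }
}
impl AggregateExpr for UdafExpr {
    fn as_any(&self) -> &dyn Any {
        self
    }
}

/// True if `expr` is a MIN aggregate in either representation.
fn is_min(expr: &dyn AggregateExpr) -> bool {
    expr.as_any().is::<Min>()
        || expr
            .as_any()
            .downcast_ref::<UdafExpr>()
            .map_or(false, |e| e.name == "min")
}

fn main() {
    assert!(is_min(&Min));
    assert!(is_min(&UdafExpr { name: "min".into() }));
    assert!(!is_min(&UdafExpr { name: "max".into() }));
}
```

Matching on a name string is what makes this a short-term fix; as the TODOs in the patch note, the check ultimately belongs on `AggregateUDFImpl` itself.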
--- .../aggregate_statistics.rs | 136 +++++++++++------- 1 file changed, 85 insertions(+), 51 deletions(-) diff --git a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs index 7e9aec9e5e4c4..66067d8cb5c42 100644 --- a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs +++ b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs @@ -140,31 +140,29 @@ fn take_optimizable_column_and_table_count( stats: &Statistics, ) -> Option<(ScalarValue, String)> { let col_stats = &stats.column_statistics; - if let Some(agg_expr) = agg_expr.as_any().downcast_ref::() { - if agg_expr.fun().name() == "count" && !agg_expr.is_distinct() { - if let Precision::Exact(num_rows) = stats.num_rows { - let exprs = agg_expr.expressions(); - if exprs.len() == 1 { - // TODO optimize with exprs other than Column - if let Some(col_expr) = - exprs[0].as_any().downcast_ref::() - { - let current_val = &col_stats[col_expr.index()].null_count; - if let &Precision::Exact(val) = current_val { - return Some(( - ScalarValue::Int64(Some((num_rows - val) as i64)), - agg_expr.name().to_string(), - )); - } - } else if let Some(lit_expr) = - exprs[0].as_any().downcast_ref::() - { - if lit_expr.value() == &COUNT_STAR_EXPANSION { - return Some(( - ScalarValue::Int64(Some(num_rows as i64)), - agg_expr.name().to_string(), - )); - } + if is_non_distinct_count(agg_expr) { + if let Precision::Exact(num_rows) = stats.num_rows { + let exprs = agg_expr.expressions(); + if exprs.len() == 1 { + // TODO optimize with exprs other than Column + if let Some(col_expr) = + exprs[0].as_any().downcast_ref::() + { + let current_val = &col_stats[col_expr.index()].null_count; + if let &Precision::Exact(val) = current_val { + return Some(( + ScalarValue::Int64(Some((num_rows - val) as i64)), + agg_expr.name().to_string(), + )); + } + } else if let Some(lit_expr) = + exprs[0].as_any().downcast_ref::() + { + if lit_expr.value() == &COUNT_STAR_EXPANSION { + return Some(( + ScalarValue::Int64(Some(num_rows as i64)), + agg_expr.name().to_string(), + )); } } } @@ -182,26 +180,22 @@ fn take_optimizable_min( match *num_rows { 0 => { // MIN/MAX with 0 rows is always null - if let Some(casted_expr) = - agg_expr.as_any().downcast_ref::() - { + if is_min(agg_expr) { if let Ok(min_data_type) = - ScalarValue::try_from(casted_expr.field().unwrap().data_type()) + ScalarValue::try_from(agg_expr.field().unwrap().data_type()) { - return Some((min_data_type, casted_expr.name().to_string())); + return Some((min_data_type, agg_expr.name().to_string())); } } } value if value > 0 => { let col_stats = &stats.column_statistics; - if let Some(casted_expr) = - agg_expr.as_any().downcast_ref::() - { - if casted_expr.expressions().len() == 1 { + if is_min(agg_expr) { + let exprs = agg_expr.expressions(); + if exprs.len() == 1 { // TODO optimize with exprs other than Column - if let Some(col_expr) = casted_expr.expressions()[0] - .as_any() - .downcast_ref::() + if let Some(col_expr) = + exprs[0].as_any().downcast_ref::() { if let Precision::Exact(val) = &col_stats[col_expr.index()].min_value @@ -209,7 +203,7 @@ fn take_optimizable_min( if !val.is_null() { return Some(( val.clone(), - casted_expr.name().to_string(), + agg_expr.name().to_string(), )); } } @@ -232,26 +226,22 @@ fn take_optimizable_max( match *num_rows { 0 => { // MIN/MAX with 0 rows is always null - if let Some(casted_expr) = - agg_expr.as_any().downcast_ref::() - { + if is_max(agg_expr) { if let Ok(max_data_type) = - 
ScalarValue::try_from(casted_expr.field().unwrap().data_type()) + ScalarValue::try_from(agg_expr.field().unwrap().data_type()) { - return Some((max_data_type, casted_expr.name().to_string())); + return Some((max_data_type, agg_expr.name().to_string())); } } } value if value > 0 => { let col_stats = &stats.column_statistics; - if let Some(casted_expr) = - agg_expr.as_any().downcast_ref::() - { - if casted_expr.expressions().len() == 1 { + if is_max(agg_expr) { + let exprs = agg_expr.expressions(); + if exprs.len() == 1 { // TODO optimize with exprs other than Column - if let Some(col_expr) = casted_expr.expressions()[0] - .as_any() - .downcast_ref::() + if let Some(col_expr) = + exprs[0].as_any().downcast_ref::() { if let Precision::Exact(val) = &col_stats[col_expr.index()].max_value @@ -259,7 +249,7 @@ fn take_optimizable_max( if !val.is_null() { return Some(( val.clone(), - casted_expr.name().to_string(), + agg_expr.name().to_string(), )); } } @@ -273,6 +263,50 @@ fn take_optimizable_max( None } +// TODO: Move this check into AggregateUDFImpl +// https://github.com/apache/datafusion/issues/11153 +fn is_non_distinct_count(agg_expr: &dyn AggregateExpr) -> bool { + if let Some(agg_expr) = agg_expr.as_any().downcast_ref::() { + if agg_expr.fun().name() == "count" && !agg_expr.is_distinct() { + return true; + } + } + + false +} + +// TODO: Move this check into AggregateUDFImpl +// https://github.com/apache/datafusion/issues/11153 +fn is_min(agg_expr: &dyn AggregateExpr) -> bool { + if agg_expr.as_any().is::() { + return true; + } + + if let Some(agg_expr) = agg_expr.as_any().downcast_ref::() { + if agg_expr.fun().name() == "min" { + return true; + } + } + + false +} + +// TODO: Move this check into AggregateUDFImpl +// https://github.com/apache/datafusion/issues/11153 +fn is_max(agg_expr: &dyn AggregateExpr) -> bool { + if agg_expr.as_any().is::() { + return true; + } + + if let Some(agg_expr) = agg_expr.as_any().downcast_ref::() { + if agg_expr.fun().name() == "max" { + return true; + } + } + + false +} + #[cfg(test)] pub(crate) mod tests { use super::*; From bd25e26747a271752b7f46aa0970022525eff05b Mon Sep 17 00:00:00 2001 From: Lordworms <48054792+Lordworms@users.noreply.github.com> Date: Fri, 12 Jul 2024 12:51:01 -0700 Subject: [PATCH 29/59] Implement TPCH substrait integration test, support tpch_13, tpch_14, tpch_16 (#11405) optimize code --- .../tests/cases/consumer_integration.rs | 86 +- .../tpch_substrait_plans/query_13.json | 624 +++++++++ .../tpch_substrait_plans/query_14.json | 924 +++++++++++++ .../tpch_substrait_plans/query_16.json | 1175 +++++++++++++++++ 4 files changed, 2808 insertions(+), 1 deletion(-) create mode 100644 datafusion/substrait/tests/testdata/tpch_substrait_plans/query_13.json create mode 100644 datafusion/substrait/tests/testdata/tpch_substrait_plans/query_14.json create mode 100644 datafusion/substrait/tests/testdata/tpch_substrait_plans/query_16.json diff --git a/datafusion/substrait/tests/cases/consumer_integration.rs b/datafusion/substrait/tests/cases/consumer_integration.rs index 10c1319b903b5..c8130220ef4ae 100644 --- a/datafusion/substrait/tests/cases/consumer_integration.rs +++ b/datafusion/substrait/tests/cases/consumer_integration.rs @@ -40,7 +40,6 @@ mod tests { } Ok(ctx) } - #[tokio::test] async fn tpch_test_1() -> Result<()> { let ctx = create_context(vec![( @@ -314,4 +313,89 @@ mod tests { \n TableScan: FILENAME_PLACEHOLDER_2 projection=[n_nationkey, n_name, n_regionkey, n_comment]"); Ok(()) } + + // missing query 12 + #[tokio::test] + async fn 
tpch_test_13() -> Result<()> { + let ctx = create_context(vec![ + ("FILENAME_PLACEHOLDER_0", "tests/testdata/tpch/customer.csv"), + ("FILENAME_PLACEHOLDER_1", "tests/testdata/tpch/orders.csv"), + ]) + .await?; + let path = "tests/testdata/tpch_substrait_plans/query_13.json"; + let proto = serde_json::from_reader::<_, Plan>(BufReader::new( + File::open(path).expect("file not found"), + )) + .expect("failed to parse json"); + + let plan = from_substrait_plan(&ctx, &proto).await?; + let plan_str = format!("{:?}", plan); + assert_eq!(plan_str, "Projection: count(FILENAME_PLACEHOLDER_1.o_orderkey) AS C_COUNT, count(Int64(1)) AS CUSTDIST\ + \n Sort: count(Int64(1)) DESC NULLS FIRST, count(FILENAME_PLACEHOLDER_1.o_orderkey) DESC NULLS FIRST\ + \n Projection: count(FILENAME_PLACEHOLDER_1.o_orderkey), count(Int64(1))\ + \n Aggregate: groupBy=[[count(FILENAME_PLACEHOLDER_1.o_orderkey)]], aggr=[[count(Int64(1))]]\ + \n Projection: count(FILENAME_PLACEHOLDER_1.o_orderkey)\ + \n Aggregate: groupBy=[[FILENAME_PLACEHOLDER_0.c_custkey]], aggr=[[count(FILENAME_PLACEHOLDER_1.o_orderkey)]]\ + \n Projection: FILENAME_PLACEHOLDER_0.c_custkey, FILENAME_PLACEHOLDER_1.o_orderkey\ + \n Left Join: FILENAME_PLACEHOLDER_0.c_custkey = FILENAME_PLACEHOLDER_1.o_custkey Filter: NOT FILENAME_PLACEHOLDER_1.o_comment LIKE CAST(Utf8(\"%special%requests%\") AS Utf8)\ + \n TableScan: FILENAME_PLACEHOLDER_0 projection=[c_custkey, c_name, c_address, c_nationkey, c_phone, c_acctbal, c_mktsegment, c_comment]\ + \n TableScan: FILENAME_PLACEHOLDER_1 projection=[o_orderkey, o_custkey, o_orderstatus, o_totalprice, o_orderdate, o_orderpriority, o_clerk, o_shippriority, o_comment]"); + Ok(()) + } + + #[tokio::test] + async fn tpch_test_14() -> Result<()> { + let ctx = create_context(vec![ + ("FILENAME_PLACEHOLDER_0", "tests/testdata/tpch/lineitem.csv"), + ("FILENAME_PLACEHOLDER_1", "tests/testdata/tpch/part.csv"), + ]) + .await?; + let path = "tests/testdata/tpch_substrait_plans/query_14.json"; + let proto = serde_json::from_reader::<_, Plan>(BufReader::new( + File::open(path).expect("file not found"), + )) + .expect("failed to parse json"); + + let plan = from_substrait_plan(&ctx, &proto).await?; + let plan_str = format!("{:?}", plan); + assert_eq!(plan_str, "Projection: Decimal128(Some(10000),5,2) * sum(CASE WHEN FILENAME_PLACEHOLDER_1.p_type LIKE CAST(Utf8(\"PROMO%\") AS Utf8) THEN FILENAME_PLACEHOLDER_0.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_0.l_discount ELSE Decimal128(Some(0),19,0) END) / sum(FILENAME_PLACEHOLDER_0.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_0.l_discount) AS PROMO_REVENUE\ + \n Aggregate: groupBy=[[]], aggr=[[sum(CASE WHEN FILENAME_PLACEHOLDER_1.p_type LIKE CAST(Utf8(\"PROMO%\") AS Utf8) THEN FILENAME_PLACEHOLDER_0.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_0.l_discount ELSE Decimal128(Some(0),19,0) END), sum(FILENAME_PLACEHOLDER_0.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_0.l_discount)]]\ + \n Projection: CASE WHEN FILENAME_PLACEHOLDER_1.p_type LIKE CAST(Utf8(\"PROMO%\") AS Utf8) THEN FILENAME_PLACEHOLDER_0.l_extendedprice * (CAST(Int32(1) AS Decimal128(19, 0)) - FILENAME_PLACEHOLDER_0.l_discount) ELSE Decimal128(Some(0),19,0) END, FILENAME_PLACEHOLDER_0.l_extendedprice * (CAST(Int32(1) AS Decimal128(19, 0)) - FILENAME_PLACEHOLDER_0.l_discount)\ + \n Filter: FILENAME_PLACEHOLDER_0.l_partkey = FILENAME_PLACEHOLDER_1.p_partkey AND FILENAME_PLACEHOLDER_0.l_shipdate >= Date32(\"1995-09-01\") AND FILENAME_PLACEHOLDER_0.l_shipdate < CAST(Utf8(\"1995-10-01\") AS Date32)\ + \n Inner Join: 
Filter: Boolean(true)\ + \n TableScan: FILENAME_PLACEHOLDER_0 projection=[l_orderkey, l_partkey, l_suppkey, l_linenumber, l_quantity, l_extendedprice, l_discount, l_tax, l_returnflag, l_linestatus, l_shipdate, l_commitdate, l_receiptdate, l_shipinstruct, l_shipmode, l_comment]\ + \n TableScan: FILENAME_PLACEHOLDER_1 projection=[p_partkey, p_name, p_mfgr, p_brand, p_type, p_size, p_container, p_retailprice, p_comment]"); + Ok(()) + } + // query 15 is missing + #[tokio::test] + async fn tpch_test_16() -> Result<()> { + let ctx = create_context(vec![ + ("FILENAME_PLACEHOLDER_0", "tests/testdata/tpch/partsupp.csv"), + ("FILENAME_PLACEHOLDER_1", "tests/testdata/tpch/part.csv"), + ("FILENAME_PLACEHOLDER_2", "tests/testdata/tpch/supplier.csv"), + ]) + .await?; + let path = "tests/testdata/tpch_substrait_plans/query_16.json"; + let proto = serde_json::from_reader::<_, Plan>(BufReader::new( + File::open(path).expect("file not found"), + )) + .expect("failed to parse json"); + + let plan = from_substrait_plan(&ctx, &proto).await?; + let plan_str = format!("{:?}", plan); + assert_eq!(plan_str, "Projection: FILENAME_PLACEHOLDER_1.p_brand AS P_BRAND, FILENAME_PLACEHOLDER_1.p_type AS P_TYPE, FILENAME_PLACEHOLDER_1.p_size AS P_SIZE, count(DISTINCT FILENAME_PLACEHOLDER_0.ps_suppkey) AS SUPPLIER_CNT\ + \n Sort: count(DISTINCT FILENAME_PLACEHOLDER_0.ps_suppkey) DESC NULLS FIRST, FILENAME_PLACEHOLDER_1.p_brand ASC NULLS LAST, FILENAME_PLACEHOLDER_1.p_type ASC NULLS LAST, FILENAME_PLACEHOLDER_1.p_size ASC NULLS LAST\ + \n Aggregate: groupBy=[[FILENAME_PLACEHOLDER_1.p_brand, FILENAME_PLACEHOLDER_1.p_type, FILENAME_PLACEHOLDER_1.p_size]], aggr=[[count(DISTINCT FILENAME_PLACEHOLDER_0.ps_suppkey)]]\ + \n Projection: FILENAME_PLACEHOLDER_1.p_brand, FILENAME_PLACEHOLDER_1.p_type, FILENAME_PLACEHOLDER_1.p_size, FILENAME_PLACEHOLDER_0.ps_suppkey\ + \n Filter: FILENAME_PLACEHOLDER_1.p_partkey = FILENAME_PLACEHOLDER_0.ps_partkey AND FILENAME_PLACEHOLDER_1.p_brand != CAST(Utf8(\"Brand#45\") AS Utf8) AND NOT FILENAME_PLACEHOLDER_1.p_type LIKE CAST(Utf8(\"MEDIUM POLISHED%\") AS Utf8) AND (FILENAME_PLACEHOLDER_1.p_size = Int32(49) OR FILENAME_PLACEHOLDER_1.p_size = Int32(14) OR FILENAME_PLACEHOLDER_1.p_size = Int32(23) OR FILENAME_PLACEHOLDER_1.p_size = Int32(45) OR FILENAME_PLACEHOLDER_1.p_size = Int32(19) OR FILENAME_PLACEHOLDER_1.p_size = Int32(3) OR FILENAME_PLACEHOLDER_1.p_size = Int32(36) OR FILENAME_PLACEHOLDER_1.p_size = Int32(9)) AND NOT CAST(FILENAME_PLACEHOLDER_0.ps_suppkey IN () AS Boolean)\ + \n Subquery:\ + \n Projection: FILENAME_PLACEHOLDER_2.s_suppkey\ + \n Filter: FILENAME_PLACEHOLDER_2.s_comment LIKE CAST(Utf8(\"%Customer%Complaints%\") AS Utf8)\ + \n TableScan: FILENAME_PLACEHOLDER_2 projection=[s_suppkey, s_name, s_address, s_nationkey, s_phone, s_acctbal, s_comment]\ + \n Inner Join: Filter: Boolean(true)\ + \n TableScan: FILENAME_PLACEHOLDER_0 projection=[ps_partkey, ps_suppkey, ps_availqty, ps_supplycost, ps_comment]\ + \n TableScan: FILENAME_PLACEHOLDER_1 projection=[p_partkey, p_name, p_mfgr, p_brand, p_type, p_size, p_container, p_retailprice, p_comment]"); + Ok(()) + } } diff --git a/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_13.json b/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_13.json new file mode 100644 index 0000000000000..c88e61e78304e --- /dev/null +++ b/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_13.json @@ -0,0 +1,624 @@ +{ + "extensionUris": [ + { + "extensionUriAnchor": 4, + "uri": "/functions_aggregate_generic.yaml" + }, + { + 
"extensionUriAnchor": 1, + "uri": "/functions_boolean.yaml" + }, + { + "extensionUriAnchor": 3, + "uri": "/functions_string.yaml" + }, + { + "extensionUriAnchor": 2, + "uri": "/functions_comparison.yaml" + } + ], + "extensions": [ + { + "extensionFunction": { + "extensionUriReference": 1, + "functionAnchor": 0, + "name": "and:bool" + } + }, + { + "extensionFunction": { + "extensionUriReference": 2, + "functionAnchor": 1, + "name": "equal:any1_any1" + } + }, + { + "extensionFunction": { + "extensionUriReference": 1, + "functionAnchor": 2, + "name": "not:bool" + } + }, + { + "extensionFunction": { + "extensionUriReference": 3, + "functionAnchor": 3, + "name": "like:vchar_vchar" + } + }, + { + "extensionFunction": { + "extensionUriReference": 4, + "functionAnchor": 4, + "name": "count:opt_any" + } + }, + { + "extensionFunction": { + "extensionUriReference": 4, + "functionAnchor": 5, + "name": "count:opt" + } + } + ], + "relations": [ + { + "root": { + "input": { + "sort": { + "common": { + "direct": { + } + }, + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [ + 2, + 3 + ] + } + }, + "input": { + "aggregate": { + "common": { + "direct": { + } + }, + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [ + 2 + ] + } + }, + "input": { + "aggregate": { + "common": { + "direct": { + } + }, + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [ + 17, + 18 + ] + } + }, + "input": { + "join": { + "common": { + "direct": { + } + }, + "left": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": [ + "C_CUSTKEY", + "C_NAME", + "C_ADDRESS", + "C_NATIONKEY", + "C_PHONE", + "C_ACCTBAL", + "C_MKTSEGMENT", + "C_COMMENT" + ], + "struct": { + "types": [ + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "varchar": { + "length": 25, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "varchar": { + "length": 40, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "fixedChar": { + "length": 15, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fixedChar": { + "length": 10, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "varchar": { + "length": 117, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "local_files": { + "items": [ + { + "uri_file": "file://FILENAME_PLACEHOLDER_0", + "parquet": {} + } + ] + } + } + }, + "right": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": [ + "O_ORDERKEY", + "O_CUSTKEY", + "O_ORDERSTATUS", + "O_TOTALPRICE", + "O_ORDERDATE", + "O_ORDERPRIORITY", + "O_CLERK", + "O_SHIPPRIORITY", + "O_COMMENT" + ], + "struct": { + "types": [ + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "fixedChar": { + "length": 1, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": 
"NULLABILITY_NULLABLE" + } + }, + { + "date": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fixedChar": { + "length": 15, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fixedChar": { + "length": 15, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "varchar": { + "length": 79, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "local_files": { + "items": [ + { + "uri_file": "file://FILENAME_PLACEHOLDER_1", + "parquet": {} + } + ] + } + } + }, + "expression": { + "scalarFunction": { + "functionReference": 0, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "scalarFunction": { + "functionReference": 1, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + } + }, + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 9 + } + }, + "rootReference": { + } + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 2, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "scalarFunction": { + "functionReference": 3, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 16 + } + }, + "rootReference": { + } + } + } + }, + { + "value": { + "cast": { + "type": { + "varchar": { + "length": 79, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "input": { + "literal": { + "fixedChar": "%special%requests%", + "nullable": false, + "typeVariationReference": 0 + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_UNSPECIFIED" + } + } + } + ] + } + } + } + ] + } + } + } + ] + } + }, + "type": "JOIN_TYPE_LEFT" + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 8 + } + }, + "rootReference": { + } + } + } + ] + } + }, + "groupings": [ + { + "groupingExpressions": [ + { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + } + ] + } + ], + "measures": [ + { + "measure": { + "functionReference": 4, + "args": [], + "sorts": [], + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "invocation": "AGGREGATION_INVOCATION_ALL", + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + } + } + ] + } + } + ] + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + } + ] + } + }, + "groupings": [ + { + 
"groupingExpressions": [ + { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + } + ] + } + ], + "measures": [ + { + "measure": { + "functionReference": 5, + "args": [], + "sorts": [], + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "invocation": "AGGREGATION_INVOCATION_ALL", + "arguments": [] + } + } + ] + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + } + ] + } + }, + "sorts": [ + { + "expr": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + }, + "direction": "SORT_DIRECTION_DESC_NULLS_FIRST" + }, + { + "expr": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + }, + "direction": "SORT_DIRECTION_DESC_NULLS_FIRST" + } + ] + } + }, + "names": [ + "C_COUNT", + "CUSTDIST" + ] + } + } + ], + "expectedTypeUrls": [] +} diff --git a/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_14.json b/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_14.json new file mode 100644 index 0000000000000..380b71df8aacc --- /dev/null +++ b/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_14.json @@ -0,0 +1,924 @@ +{ + "extensionUris": [ + { + "extensionUriAnchor": 1, + "uri": "/functions_boolean.yaml" + }, + { + "extensionUriAnchor": 4, + "uri": "/functions_string.yaml" + }, + { + "extensionUriAnchor": 5, + "uri": "/functions_arithmetic_decimal.yaml" + }, + { + "extensionUriAnchor": 3, + "uri": "/functions_datetime.yaml" + }, + { + "extensionUriAnchor": 2, + "uri": "/functions_comparison.yaml" + } + ], + "extensions": [ + { + "extensionFunction": { + "extensionUriReference": 1, + "functionAnchor": 0, + "name": "and:bool" + } + }, + { + "extensionFunction": { + "extensionUriReference": 2, + "functionAnchor": 1, + "name": "equal:any1_any1" + } + }, + { + "extensionFunction": { + "extensionUriReference": 3, + "functionAnchor": 2, + "name": "gte:date_date" + } + }, + { + "extensionFunction": { + "extensionUriReference": 3, + "functionAnchor": 3, + "name": "lt:date_date" + } + }, + { + "extensionFunction": { + "extensionUriReference": 4, + "functionAnchor": 4, + "name": "like:vchar_vchar" + } + }, + { + "extensionFunction": { + "extensionUriReference": 5, + "functionAnchor": 5, + "name": "multiply:opt_decimal_decimal" + } + }, + { + "extensionFunction": { + "extensionUriReference": 5, + "functionAnchor": 6, + "name": "subtract:opt_decimal_decimal" + } + }, + { + "extensionFunction": { + "extensionUriReference": 5, + "functionAnchor": 7, + "name": "sum:opt_decimal" + } + }, + { + "extensionFunction": { + "extensionUriReference": 5, + "functionAnchor": 8, + "name": "divide:opt_decimal_decimal" + } + } + ], + "relations": [ + { + "root": { + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [ + 2 + ] + } + }, + "input": { + "aggregate": { + "common": { + "direct": { + } + }, + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [ + 25, + 26 + ] + } + }, + "input": { + "filter": { + "common": { + "direct": { + } + }, + "input": { + "join": { + "common": { + "direct": { + } + }, + "left": { + "read": { + "common": { + "direct": { + } + }, 
+ "baseSchema": { + "names": [ + "L_ORDERKEY", + "L_PARTKEY", + "L_SUPPKEY", + "L_LINENUMBER", + "L_QUANTITY", + "L_EXTENDEDPRICE", + "L_DISCOUNT", + "L_TAX", + "L_RETURNFLAG", + "L_LINESTATUS", + "L_SHIPDATE", + "L_COMMITDATE", + "L_RECEIPTDATE", + "L_SHIPINSTRUCT", + "L_SHIPMODE", + "L_COMMENT" + ], + "struct": { + "types": [ + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fixedChar": { + "length": 1, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fixedChar": { + "length": 1, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "date": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "date": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "date": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fixedChar": { + "length": 25, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fixedChar": { + "length": 10, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "varchar": { + "length": 44, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "local_files": { + "items": [ + { + "uri_file": "file://FILENAME_PLACEHOLDER_0", + "parquet": {} + } + ] + } + } + }, + "right": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": [ + "P_PARTKEY", + "P_NAME", + "P_MFGR", + "P_BRAND", + "P_TYPE", + "P_SIZE", + "P_CONTAINER", + "P_RETAILPRICE", + "P_COMMENT" + ], + "struct": { + "types": [ + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "varchar": { + "length": 55, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fixedChar": { + "length": 25, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fixedChar": { + "length": 10, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "varchar": { + "length": 25, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fixedChar": { + "length": 10, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "varchar": { + "length": 23, + "typeVariationReference": 0, + "nullability": 
"NULLABILITY_NULLABLE" + } + } + ], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "local_files": { + "items": [ + { + "uri_file": "file://FILENAME_PLACEHOLDER_1", + "parquet": {} + } + ] + } + } + }, + "expression": { + "literal": { + "boolean": true, + "nullable": false, + "typeVariationReference": 0 + } + }, + "type": "JOIN_TYPE_INNER" + } + }, + "condition": { + "scalarFunction": { + "functionReference": 0, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "scalarFunction": { + "functionReference": 1, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + } + }, + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 16 + } + }, + "rootReference": { + } + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 2, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 10 + } + }, + "rootReference": { + } + } + } + }, + { + "value": { + "literal": { + "date": 9374, + "nullable": false, + "typeVariationReference": 0 + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 3, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 10 + } + }, + "rootReference": { + } + } + } + }, + { + "value": { + "cast": { + "type": { + "date": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "fixedChar": "1995-10-01", + "nullable": false, + "typeVariationReference": 0 + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_UNSPECIFIED" + } + } + } + ] + } + } + } + ] + } + } + } + }, + "expressions": [ + { + "ifThen": { + "ifs": [ + { + "if": { + "scalarFunction": { + "functionReference": 4, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 20 + } + }, + "rootReference": { + } + } + } + }, + { + "value": { + "cast": { + "type": { + "varchar": { + "length": 25, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "input": { + "literal": { + "fixedChar": "PROMO%", + "nullable": false, + "typeVariationReference": 0 + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_UNSPECIFIED" + } + } + } + ] + } + }, + "then": { + "scalarFunction": { + "functionReference": 5, + "args": [], + "outputType": { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 5 + } + }, + "rootReference": { + } + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 6, + "args": [], + "outputType": { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": 
"NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "cast": { + "type": { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "input": { + "literal": { + "i32": 1, + "nullable": false, + "typeVariationReference": 0 + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_UNSPECIFIED" + } + } + }, + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 6 + } + }, + "rootReference": { + } + } + } + } + ] + } + } + } + ] + } + } + } + ], + "else": { + "literal": { + "decimal": { + "value": "AAAAAAAAAAAAAAAAAAAAAA==", + "precision": 19, + "scale": 0 + }, + "nullable": false, + "typeVariationReference": 0 + } + } + } + }, + { + "scalarFunction": { + "functionReference": 5, + "args": [], + "outputType": { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 5 + } + }, + "rootReference": { + } + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 6, + "args": [], + "outputType": { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "cast": { + "type": { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "input": { + "literal": { + "i32": 1, + "nullable": false, + "typeVariationReference": 0 + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_UNSPECIFIED" + } + } + }, + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 6 + } + }, + "rootReference": { + } + } + } + } + ] + } + } + } + ] + } + } + ] + } + }, + "groupings": [ + { + "groupingExpressions": [] + } + ], + "measures": [ + { + "measure": { + "functionReference": 7, + "args": [], + "sorts": [], + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "invocation": "AGGREGATION_INVOCATION_ALL", + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + } + } + ] + } + }, + { + "measure": { + "functionReference": 7, + "args": [], + "sorts": [], + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "invocation": "AGGREGATION_INVOCATION_ALL", + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + } + } + ] + } + } + ] + } + }, + "expressions": [ + { + "scalarFunction": { + "functionReference": 8, + "args": [], + "outputType": { + "decimal": { + "scale": 2, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "scalarFunction": { + "functionReference": 5, + "args": [], + "outputType": { + "decimal": { + "scale": 2, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "literal": { + "decimal": { + "value": "ECcAAAAAAAAAAAAAAAAAAA==", + "precision": 5, + "scale": 2 + }, + "nullable": false, + 
"typeVariationReference": 0 + } + } + }, + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + } + } + ] + } + } + }, + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + } + } + ] + } + } + ] + } + }, + "names": [ + "PROMO_REVENUE" + ] + } + } + ], + "expectedTypeUrls": [] +} diff --git a/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_16.json b/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_16.json new file mode 100644 index 0000000000000..f988aa7a76a26 --- /dev/null +++ b/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_16.json @@ -0,0 +1,1175 @@ +{ + "extensionUris": [ + { + "extensionUriAnchor": 4, + "uri": "/functions_aggregate_generic.yaml" + }, + { + "extensionUriAnchor": 1, + "uri": "/functions_boolean.yaml" + }, + { + "extensionUriAnchor": 3, + "uri": "/functions_string.yaml" + }, + { + "extensionUriAnchor": 2, + "uri": "/functions_comparison.yaml" + } + ], + "extensions": [ + { + "extensionFunction": { + "extensionUriReference": 1, + "functionAnchor": 0, + "name": "and:bool" + } + }, + { + "extensionFunction": { + "extensionUriReference": 2, + "functionAnchor": 1, + "name": "equal:any1_any1" + } + }, + { + "extensionFunction": { + "extensionUriReference": 2, + "functionAnchor": 2, + "name": "not_equal:any1_any1" + } + }, + { + "extensionFunction": { + "extensionUriReference": 1, + "functionAnchor": 3, + "name": "not:bool" + } + }, + { + "extensionFunction": { + "extensionUriReference": 3, + "functionAnchor": 4, + "name": "like:vchar_vchar" + } + }, + { + "extensionFunction": { + "extensionUriReference": 1, + "functionAnchor": 5, + "name": "or:bool" + } + }, + { + "extensionFunction": { + "extensionUriReference": 4, + "functionAnchor": 6, + "name": "count:opt_any" + } + } + ], + "relations": [ + { + "root": { + "input": { + "sort": { + "common": { + "direct": { + } + }, + "input": { + "aggregate": { + "common": { + "direct": { + } + }, + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [ + 14, + 15, + 16, + 17 + ] + } + }, + "input": { + "filter": { + "common": { + "direct": { + } + }, + "input": { + "join": { + "common": { + "direct": { + } + }, + "left": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": [ + "PS_PARTKEY", + "PS_SUPPKEY", + "PS_AVAILQTY", + "PS_SUPPLYCOST", + "PS_COMMENT" + ], + "struct": { + "types": [ + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "varchar": { + "length": 199, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "local_files": { + "items": [ + { + "uri_file": "file://FILENAME_PLACEHOLDER_0", + "parquet": {} + } + ] + } + } + }, + "right": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": [ + "P_PARTKEY", + "P_NAME", + "P_MFGR", + "P_BRAND", + "P_TYPE", + "P_SIZE", + "P_CONTAINER", + "P_RETAILPRICE", + "P_COMMENT" + ], + "struct": { + "types": [ + { + "i64": { + "typeVariationReference": 0, + 
"nullability": "NULLABILITY_REQUIRED" + } + }, + { + "varchar": { + "length": 55, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fixedChar": { + "length": 25, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fixedChar": { + "length": 10, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "varchar": { + "length": 25, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fixedChar": { + "length": 10, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "varchar": { + "length": 23, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "local_files": { + "items": [ + { + "uri_file": "file://FILENAME_PLACEHOLDER_1", + "parquet": {} + } + ] + } + } + }, + "expression": { + "literal": { + "boolean": true, + "nullable": false, + "typeVariationReference": 0 + } + }, + "type": "JOIN_TYPE_INNER" + } + }, + "condition": { + "scalarFunction": { + "functionReference": 0, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "scalarFunction": { + "functionReference": 1, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 5 + } + }, + "rootReference": { + } + } + } + }, + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 2, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 8 + } + }, + "rootReference": { + } + } + } + }, + { + "value": { + "cast": { + "type": { + "fixedChar": { + "length": 10, + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "fixedChar": "Brand#45", + "nullable": false, + "typeVariationReference": 0 + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_UNSPECIFIED" + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 3, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "scalarFunction": { + "functionReference": 4, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 9 + } + }, + "rootReference": { + } + } + } + }, + { + "value": { + "cast": { + "type": { + "varchar": { + "length": 25, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "input": { + "literal": { + "fixedChar": "MEDIUM POLISHED%", + "nullable": false, + "typeVariationReference": 0 + } + }, + 
"failureBehavior": "FAILURE_BEHAVIOR_UNSPECIFIED" + } + } + } + ] + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 5, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "scalarFunction": { + "functionReference": 1, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 10 + } + }, + "rootReference": { + } + } + } + }, + { + "value": { + "literal": { + "i32": 49, + "nullable": false, + "typeVariationReference": 0 + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 1, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 10 + } + }, + "rootReference": { + } + } + } + }, + { + "value": { + "literal": { + "i32": 14, + "nullable": false, + "typeVariationReference": 0 + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 1, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 10 + } + }, + "rootReference": { + } + } + } + }, + { + "value": { + "literal": { + "i32": 23, + "nullable": false, + "typeVariationReference": 0 + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 1, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 10 + } + }, + "rootReference": { + } + } + } + }, + { + "value": { + "literal": { + "i32": 45, + "nullable": false, + "typeVariationReference": 0 + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 1, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 10 + } + }, + "rootReference": { + } + } + } + }, + { + "value": { + "literal": { + "i32": 19, + "nullable": false, + "typeVariationReference": 0 + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 1, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 10 + } + }, + "rootReference": { + } + } + } + }, + { + "value": { + "literal": { + "i32": 3, + "nullable": false, + "typeVariationReference": 0 + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 1, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 10 + } + }, + "rootReference": { + } + } + } + }, + { + "value": { + "literal": { + "i32": 36, + "nullable": false, + "typeVariationReference": 0 + } + } 
+ } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 1, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 10 + } + }, + "rootReference": { + } + } + } + }, + { + "value": { + "literal": { + "i32": 9, + "nullable": false, + "typeVariationReference": 0 + } + } + } + ] + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 3, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [ + { + "value": { + "cast": { + "type": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "subquery": { + "inPredicate": { + "needles": [ + { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + } + ], + "haystack": { + "project": { + "common": { + "emit": { + "outputMapping": [ + 7 + ] + } + }, + "input": { + "filter": { + "common": { + "direct": { + } + }, + "input": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": [ + "S_SUPPKEY", + "S_NAME", + "S_ADDRESS", + "S_NATIONKEY", + "S_PHONE", + "S_ACCTBAL", + "S_COMMENT" + ], + "struct": { + "types": [ + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "fixedChar": { + "length": 25, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "varchar": { + "length": 40, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "fixedChar": { + "length": 15, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "varchar": { + "length": 101, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "local_files": { + "items": [ + { + "uri_file": "file://FILENAME_PLACEHOLDER_2", + "parquet": {} + } + ] + } + } + }, + "condition": { + "scalarFunction": { + "functionReference": 4, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 6 + } + }, + "rootReference": { + } + } + } + }, + { + "value": { + "cast": { + "type": { + "varchar": { + "length": 101, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "input": { + "literal": { + "fixedChar": "%Customer%Complaints%", + "nullable": false, + "typeVariationReference": 0 + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_UNSPECIFIED" + } + } + } + ] + } + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + } + ] + } + } + } + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_UNSPECIFIED" + } + } + } + ] + } + } + } + ] + } + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": { + "field": 8 + } + }, + "rootReference": { + } + } + }, + { + "selection": { + 
"directReference": { + "structField": { + "field": 9 + } + }, + "rootReference": { + } + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 10 + } + }, + "rootReference": { + } + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + } + ] + } + }, + "groupings": [ + { + "groupingExpressions": [ + { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "rootReference": { + } + } + } + ] + } + ], + "measures": [ + { + "measure": { + "functionReference": 6, + "args": [], + "sorts": [], + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "invocation": "AGGREGATION_INVOCATION_DISTINCT", + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 3 + } + }, + "rootReference": { + } + } + } + } + ] + } + } + ] + } + }, + "sorts": [ + { + "expr": { + "selection": { + "directReference": { + "structField": { + "field": 3 + } + }, + "rootReference": { + } + } + }, + "direction": "SORT_DIRECTION_DESC_NULLS_FIRST" + }, + { + "expr": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + }, + "direction": "SORT_DIRECTION_ASC_NULLS_LAST" + }, + { + "expr": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + }, + "direction": "SORT_DIRECTION_ASC_NULLS_LAST" + }, + { + "expr": { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "rootReference": { + } + } + }, + "direction": "SORT_DIRECTION_ASC_NULLS_LAST" + } + ] + } + }, + "names": [ + "P_BRAND", + "P_TYPE", + "P_SIZE", + "SUPPLIER_CNT" + ] + } + } + ], + "expectedTypeUrls": [] +} From 4d04a6ebbb0458495d2282df34e8b22001f3971d Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Fri, 12 Jul 2024 15:51:44 -0400 Subject: [PATCH 30/59] Minor: fix labeler rules (#11428) --- .github/workflows/dev_pr/labeler.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/dev_pr/labeler.yml b/.github/workflows/dev_pr/labeler.yml index 34a37948785b5..308abd1688a6d 100644 --- a/.github/workflows/dev_pr/labeler.yml +++ b/.github/workflows/dev_pr/labeler.yml @@ -17,11 +17,11 @@ development-process: - changed-files: - - any-glob-to-any-file: ['dev/**.*', '.github/**.*', 'ci/**.*', '.asf.yaml'] + - any-glob-to-any-file: ['dev/**/*', '.github/**/*', 'ci/**/*', '.asf.yaml'] documentation: - changed-files: - - any-glob-to-any-file: ['docs/**.*', 'README.md', './**/README.md', 'DEVELOPERS.md', 'datafusion/docs/**.*'] + - any-glob-to-any-file: ['docs/**/*', 'README.md', './**/README.md', 'DEVELOPERS.md', 'datafusion/docs/**/*'] sql: - changed-files: From 8f8df07c80aa66bb94d57c9619be93f9c3be92a9 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Fri, 12 Jul 2024 23:14:17 -0400 Subject: [PATCH 31/59] Minor: change internal error to not supported error for nested field access (#11446) --- datafusion/sql/src/expr/identifier.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/datafusion/sql/src/expr/identifier.rs b/datafusion/sql/src/expr/identifier.rs index 
d297b2e4df5b3..39736b1fbba59 100644
--- a/datafusion/sql/src/expr/identifier.rs
+++ b/datafusion/sql/src/expr/identifier.rs
@@ -18,8 +18,8 @@
 use crate::planner::{ContextProvider, PlannerContext, SqlToRel};
 use arrow_schema::Field;
 use datafusion_common::{
-    internal_err, plan_datafusion_err, Column, DFSchema, DataFusionError, Result,
-    ScalarValue, TableReference,
+    internal_err, not_impl_err, plan_datafusion_err, Column, DFSchema, DataFusionError,
+    Result, ScalarValue, TableReference,
 };
 use datafusion_expr::{expr::ScalarFunction, lit, Case, Expr};
 use sqlparser::ast::{Expr as SQLExpr, Ident};
@@ -118,7 +118,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
         // Though ideally once that support is in place, this code should work with it
         // TODO: remove when can support multiple nested identifiers
         if ids.len() > 5 {
-            return internal_err!("Unsupported compound identifier: {ids:?}");
+            return not_impl_err!("Compound identifier: {ids:?}");
         }
 
         let search_result = search_dfschema(&ids, schema);
@@ -127,7 +127,7 @@
             Some((field, qualifier, nested_names)) if !nested_names.is_empty() => {
                 // TODO: remove when can support multiple nested identifiers
                 if nested_names.len() > 1 {
-                    return internal_err!(
+                    return not_impl_err!(
                         "Nested identifiers not yet supported for column {}",
                         Column::from((qualifier, field)).quoted_flat_name()
                     );
@@ -154,7 +154,7 @@
             // return default where use all identifiers to not have a nested field
             // this len check is because at 5 identifiers will have to have a nested field
             if ids.len() == 5 {
-                internal_err!("Unsupported compound identifier: {ids:?}")
+                not_impl_err!("compound identifier: {ids:?}")
             } else {
                 // check the outer_query_schema and try to find a match
                 if let Some(outer) = planner_context.outer_query_schema() {
@@ -165,7 +165,7 @@
                         if !nested_names.is_empty() =>
                     {
                         // TODO: remove when can support nested identifiers for OuterReferenceColumn
-                        internal_err!(
+                        not_impl_err!(
                             "Nested identifiers are not yet supported for OuterReferenceColumn {}",
                             Column::from((qualifier, field)).quoted_flat_name()
                         )

From 9e4a4a1599b9def33f27a6f82dd32045038de296 Mon Sep 17 00:00:00 2001
From: Andrew Lamb
Date: Sat, 13 Jul 2024 05:33:51 -0400
Subject: [PATCH 32/59] Minor: change Datafusion --> DataFusion in docs (#11439)

* Minor: change Datafusion --> DataFusion in docs

* update expected

---
 datafusion-examples/README.md                              | 4 ++--
 datafusion-examples/examples/expr_api.rs                   | 2 +-
 datafusion/common/src/config.rs                            | 2 +-
 datafusion/core/src/dataframe/mod.rs                       | 2 +-
 datafusion/expr/src/signature.rs                           | 2 +-
 datafusion/optimizer/src/unwrap_cast_in_comparison.rs      | 2 +-
 datafusion/physical-expr/src/intervals/cp_solver.rs        | 2 +-
 datafusion/physical-plan/src/aggregates/mod.rs             | 2 +-
 datafusion/sql/src/parser.rs                               | 2 +-
 datafusion/sqllogictest/README.md                          | 2 +-
 datafusion/sqllogictest/test_files/information_schema.slt  | 2 +-
 datafusion/sqllogictest/test_files/window.slt              | 6 +++---
 docs/source/contributor-guide/inviting.md                  | 2 +-
 docs/source/user-guide/configs.md                          | 2 +-
 14 files changed, 17 insertions(+), 17 deletions(-)

diff --git a/datafusion-examples/README.md b/datafusion-examples/README.md
index 2696f74775cf3..da01f60b527d9 100644
--- a/datafusion-examples/README.md
+++ b/datafusion-examples/README.md
@@ -71,8 +71,8 @@ cargo run --example dataframe
 - [`parquet_index.rs`](examples/parquet_index.rs): Create a secondary index over several parquet files and use it to speed up queries
 - [`parquet_sql_multiple_files.rs`](examples/parquet_sql_multiple_files.rs): Build and run a query plan from a SQL statement against multiple local Parquet files
 - [`parquet_exec_visitor.rs`](examples/parquet_exec_visitor.rs): Extract statistics by visiting an ExecutionPlan after execution
-- [`parse_sql_expr.rs`](examples/parse_sql_expr.rs): Parse SQL text into Datafusion `Expr`.
-- [`plan_to_sql.rs`](examples/plan_to_sql.rs): Generate SQL from Datafusion `Expr` and `LogicalPlan`
+- [`parse_sql_expr.rs`](examples/parse_sql_expr.rs): Parse SQL text into DataFusion `Expr`.
+- [`plan_to_sql.rs`](examples/plan_to_sql.rs): Generate SQL from DataFusion `Expr` and `LogicalPlan`
 - [`pruning.rs`](examples/pruning.rs): Use pruning to rule out files based on statistics
 - [`query-aws-s3.rs`](examples/external_dependency/query-aws-s3.rs): Configure `object_store` and run a query against files stored in AWS S3
 - [`query-http-csv.rs`](examples/query-http-csv.rs): Configure `object_store` and run a query against files via HTTP
diff --git a/datafusion-examples/examples/expr_api.rs b/datafusion-examples/examples/expr_api.rs
index 43729a913e5d8..a5cf7011f8113 100644
--- a/datafusion-examples/examples/expr_api.rs
+++ b/datafusion-examples/examples/expr_api.rs
@@ -83,7 +83,7 @@ async fn main() -> Result<()> {
     Ok(())
 }
 
-/// Datafusion's `expr_fn` API makes it easy to create [`Expr`]s for the
+/// DataFusion's `expr_fn` API makes it easy to create [`Expr`]s for the
 /// full range of expression types such as aggregates and window functions.
 fn expr_fn_demo() -> Result<()> {
     // Let's say you want to call the "first_value" aggregate function
diff --git a/datafusion/common/src/config.rs b/datafusion/common/src/config.rs
index 1d2a9589adfc6..880f0119ce0da 100644
--- a/datafusion/common/src/config.rs
+++ b/datafusion/common/src/config.rs
@@ -309,7 +309,7 @@ config_namespace! {
         /// Currently experimental
         pub split_file_groups_by_statistics: bool, default = false
 
-        /// Should Datafusion keep the columns used for partition_by in the output RecordBatches
+        /// Should DataFusion keep the columns used for partition_by in the output RecordBatches
         pub keep_partition_by_columns: bool, default = false
     }
 }
diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs
index d0f2852a6e53a..05a08a6378930 100644
--- a/datafusion/core/src/dataframe/mod.rs
+++ b/datafusion/core/src/dataframe/mod.rs
@@ -1472,7 +1472,7 @@ impl DataFrame {
     ///
     /// The method supports case sensitive rename with wrapping column name into one of following symbols ( " or ' or ` )
     ///
-    /// Alternatively setting Datafusion param `datafusion.sql_parser.enable_ident_normalization` to `false` will enable
+    /// Alternatively setting DataFusion param `datafusion.sql_parser.enable_ident_normalization` to `false` will enable
     /// case sensitive rename without need to wrap column name into special symbols
     ///
     /// # Example
diff --git a/datafusion/expr/src/signature.rs b/datafusion/expr/src/signature.rs
index 33f643eb2dc2a..fba793dd229d3 100644
--- a/datafusion/expr/src/signature.rs
+++ b/datafusion/expr/src/signature.rs
@@ -93,7 +93,7 @@ pub enum TypeSignature {
     Variadic(Vec<DataType>),
     /// The acceptable signature and coercions rules to coerce arguments to this
     /// signature are special for this function. If this signature is specified,
-    /// Datafusion will call [`ScalarUDFImpl::coerce_types`] to prepare argument types.
+    /// DataFusion will call [`ScalarUDFImpl::coerce_types`] to prepare argument types.
     ///
     /// [`ScalarUDFImpl::coerce_types`]: crate::udf::ScalarUDFImpl::coerce_types
     UserDefined,
diff --git a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs
index 3447082525597..9941da9dd65e0 100644
--- a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs
+++ b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs
@@ -1080,7 +1080,7 @@ mod tests {
             ),
         };
 
-        // Datafusion ignores timezones for comparisons of ScalarValue
+        // DataFusion ignores timezones for comparisons of ScalarValue
        // so double check it here
        assert_eq!(lit_tz_none, lit_tz_utc);
 
diff --git a/datafusion/physical-expr/src/intervals/cp_solver.rs b/datafusion/physical-expr/src/intervals/cp_solver.rs
index fc4950ae4e7ca..f05ac3624b8e2 100644
--- a/datafusion/physical-expr/src/intervals/cp_solver.rs
+++ b/datafusion/physical-expr/src/intervals/cp_solver.rs
@@ -176,7 +176,7 @@ impl ExprIntervalGraphNode {
         &self.interval
     }
 
-    /// This function creates a DAEG node from DataFusion's [`ExprTreeNode`]
+    /// This function creates a DAEG node from DataFusion's [`ExprTreeNode`]
     /// object. Literals are created with definite, singleton intervals while
     /// any other expression starts with an indefinite interval ([-∞, ∞]).
     pub fn make_node(node: &ExprTreeNode<NodeIndex>, schema: &Schema) -> Result<Self> {
diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs
index 8caf10acf09b8..8bf808af3b5b8 100644
--- a/datafusion/physical-plan/src/aggregates/mod.rs
+++ b/datafusion/physical-plan/src/aggregates/mod.rs
@@ -324,7 +324,7 @@ impl AggregateExec {
     /// Create a new hash aggregate execution plan with the given schema.
     /// This constructor isn't part of the public API, it is used internally
-    /// by Datafusion to enforce schema consistency when re-creating
+    /// by DataFusion to enforce schema consistency when re-creating
     /// `AggregateExec`s inside optimization rules. Schema field names of an
     /// `AggregateExec` depends on the names of aggregate expressions. Since
     /// a rule may re-write aggregate expressions (e.g. reverse them) during
diff --git a/datafusion/sql/src/parser.rs b/datafusion/sql/src/parser.rs
index 5da7f71765096..8147092c34aba 100644
--- a/datafusion/sql/src/parser.rs
+++ b/datafusion/sql/src/parser.rs
@@ -253,7 +253,7 @@ fn ensure_not_set<T>(field: &Option<T>, name: &str) -> Result<(), ParserError> {
     Ok(())
 }
 
-/// Datafusion SQL Parser based on [`sqlparser`]
+/// DataFusion SQL Parser based on [`sqlparser`]
 ///
 /// Parses DataFusion's SQL dialect, often delegating to [`sqlparser`]'s [`Parser`].
 ///
diff --git a/datafusion/sqllogictest/README.md b/datafusion/sqllogictest/README.md
index 930df47967762..c7f04c0d762c1 100644
--- a/datafusion/sqllogictest/README.md
+++ b/datafusion/sqllogictest/README.md
@@ -225,7 +225,7 @@ query
 ```
 
-- `test_name`: Uniquely identify the test name (Datafusion only)
+- `test_name`: Uniquely identify the test name (DataFusion only)
 - `type_string`: A short string that specifies the number of result columns and
   the expected datatype of each result column. There is one character in the
   <type_string> for each result column. The character codes are:
   - 'B' - **B**oolean,
diff --git a/datafusion/sqllogictest/test_files/information_schema.slt b/datafusion/sqllogictest/test_files/information_schema.slt
index acd465a0c021f..95bea1223a9ce 100644
--- a/datafusion/sqllogictest/test_files/information_schema.slt
+++ b/datafusion/sqllogictest/test_files/information_schema.slt
@@ -257,7 +257,7 @@ datafusion.execution.batch_size 8192 Default batch size while creating new batch
 datafusion.execution.coalesce_batches true When set to true, record batches will be examined between each operator and small batches will be coalesced into larger batches. This is helpful when there are highly selective filters or joins that could produce tiny output batches. The target batch size is determined by the configuration setting
 datafusion.execution.collect_statistics false Should DataFusion collect statistics after listing files
 datafusion.execution.enable_recursive_ctes true Should DataFusion support recursive CTEs
-datafusion.execution.keep_partition_by_columns false Should Datafusion keep the columns used for partition_by in the output RecordBatches
+datafusion.execution.keep_partition_by_columns false Should DataFusion keep the columns used for partition_by in the output RecordBatches
 datafusion.execution.listing_table_ignore_subdirectory true Should sub directories be ignored when scanning directories for data files. Defaults to true (ignores subdirectories), consistent with Hive. Note that this setting does not affect reading partitioned tables (e.g. `/table/year=2021/month=01/data.parquet`).
 datafusion.execution.max_buffered_batches_per_output_file 2 This is the maximum number of RecordBatches buffered for each output file being worked. Higher values can potentially give faster write performance at the cost of higher peak memory consumption
 datafusion.execution.meta_fetch_concurrency 32 Number of files to read in parallel when inferring schema and statistics
diff --git a/datafusion/sqllogictest/test_files/window.slt b/datafusion/sqllogictest/test_files/window.slt
index 7f2e766aab915..a865a7ccbd8fb 100644
--- a/datafusion/sqllogictest/test_files/window.slt
+++ b/datafusion/sqllogictest/test_files/window.slt
@@ -2236,7 +2236,7 @@ SELECT SUM(c12) OVER(ORDER BY c1, c2 GROUPS BETWEEN 1 PRECEDING AND 1 FOLLOWING)
 7.728066219895 NULL
 
 # test_c9_rn_ordering_alias
-# These tests check whether Datafusion is aware of the ordering generated by the ROW_NUMBER() window function.
+# These tests check whether DataFusion is aware of the ordering generated by the ROW_NUMBER() window function.
 # Physical plan shouldn't have a SortExec after the BoundedWindowAggExec since the table after BoundedWindowAggExec is already ordered by rn1 ASC and c9 DESC.
 query TT
 EXPLAIN SELECT c9, rn1 FROM (SELECT c9,
@@ -2275,7 +2275,7 @@ SELECT c9, rn1 FROM (SELECT c9,
 145294611 5
 
 # test_c9_rn_ordering_alias_opposite_direction
-# These tests check whether Datafusion is aware of the ordering generated by the ROW_NUMBER() window function.
+# These tests check whether DataFusion is aware of the ordering generated by the ROW_NUMBER() window function.
 # Physical plan shouldn't have a SortExec after the BoundedWindowAggExec since the table after BoundedWindowAggExec is already ordered by rn1 ASC and c9 DESC.
query TT EXPLAIN SELECT c9, rn1 FROM (SELECT c9, @@ -2314,7 +2314,7 @@ SELECT c9, rn1 FROM (SELECT c9, 4076864659 5 # test_c9_rn_ordering_alias_opposite_direction2 -# These tests check whether Datafusion is aware of the ordering generated by the ROW_NUMBER() window function. +# These tests check whether DataFusion is aware of the ordering generated by the ROW_NUMBER() window function. # Physical plan _should_ have a SortExec after BoundedWindowAggExec since the table after BoundedWindowAggExec is ordered by rn1 ASC and c9 DESC, which is conflicting with the requirement rn1 DESC. query TT EXPLAIN SELECT c9, rn1 FROM (SELECT c9, diff --git a/docs/source/contributor-guide/inviting.md b/docs/source/contributor-guide/inviting.md index 967f417e6e9aa..4066dd9699eeb 100644 --- a/docs/source/contributor-guide/inviting.md +++ b/docs/source/contributor-guide/inviting.md @@ -59,7 +59,7 @@ the person. Here is an example: To: private@datafusion.apache.org Subject: [DISCUSS] $PERSONS_NAME for Committer -$PERSONS_NAME has been an active contributor to the Datafusion community for the +$PERSONS_NAME has been an active contributor to the DataFusion community for the last 6 months[1][2], helping others, answering questions, and improving the project's code. diff --git a/docs/source/user-guide/configs.md b/docs/source/user-guide/configs.md index 579088f991ef2..5130b0a56d0e9 100644 --- a/docs/source/user-guide/configs.md +++ b/docs/source/user-guide/configs.md @@ -86,7 +86,7 @@ Environment variables are read during `SessionConfig` initialisation so they mus | datafusion.execution.listing_table_ignore_subdirectory | true | Should sub directories be ignored when scanning directories for data files. Defaults to true (ignores subdirectories), consistent with Hive. Note that this setting does not affect reading partitioned tables (e.g. `/table/year=2021/month=01/data.parquet`). | | datafusion.execution.enable_recursive_ctes | true | Should DataFusion support recursive CTEs | | datafusion.execution.split_file_groups_by_statistics | false | Attempt to eliminate sorts by packing & sorting files with non-overlapping statistics into the same file groups. Currently experimental | -| datafusion.execution.keep_partition_by_columns | false | Should Datafusion keep the columns used for partition_by in the output RecordBatches | +| datafusion.execution.keep_partition_by_columns | false | Should DataFusion keep the columns used for partition_by in the output RecordBatches | | datafusion.optimizer.enable_distinct_aggregation_soft_limit | true | When set to true, the optimizer will push a limit operation into grouped aggregations which have no aggregate expressions, as a soft limit, emitting groups once the limit is reached, before all rows in the group are read. 
| | datafusion.optimizer.enable_round_robin_repartition | true | When set to true, the physical plan optimizer will try to add round robin repartitioning to increase parallelism to leverage more CPU cores | | datafusion.optimizer.enable_topk_aggregation | true | When set to true, the optimizer will attempt to perform limit operations during aggregations, if possible | From 08fa444aaa8513a60ede5c57d92f29e6156b91a8 Mon Sep 17 00:00:00 2001 From: Jonah Gao Date: Sat, 13 Jul 2024 17:34:45 +0800 Subject: [PATCH 33/59] fix: make sure JOIN ON expression is boolean type (#11423) * fix: make sure JOIN ON expression is boolean type * Applied to DataFrame * Update datafusion/optimizer/src/analyzer/type_coercion.rs Co-authored-by: Andrew Lamb --------- Co-authored-by: Andrew Lamb --- datafusion/core/src/dataframe/mod.rs | 31 +++++++++++++++++-- .../optimizer/src/analyzer/type_coercion.rs | 17 +++++++++- datafusion/sqllogictest/test_files/join.slt | 12 ++++++- 3 files changed, 55 insertions(+), 5 deletions(-) diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 05a08a6378930..c55b7c752765d 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -896,9 +896,8 @@ impl DataFrame { join_type: JoinType, on_exprs: impl IntoIterator, ) -> Result { - let expr = on_exprs.into_iter().reduce(Expr::and); let plan = LogicalPlanBuilder::from(self.plan) - .join_on(right.plan, join_type, expr)? + .join_on(right.plan, join_type, on_exprs)? .build()?; Ok(DataFrame { session_state: self.session_state, @@ -1694,7 +1693,7 @@ mod tests { use crate::test_util::{register_aggregate_csv, test_table, test_table_with_name}; use arrow::array::{self, Int32Array}; - use datafusion_common::{Constraint, Constraints}; + use datafusion_common::{Constraint, Constraints, ScalarValue}; use datafusion_common_runtime::SpawnedTask; use datafusion_expr::{ array_agg, cast, create_udf, expr, lit, BuiltInWindowFunction, @@ -2555,6 +2554,32 @@ mod tests { Ok(()) } + #[tokio::test] + async fn join_on_filter_datatype() -> Result<()> { + let left = test_table_with_name("a").await?.select_columns(&["c1"])?; + let right = test_table_with_name("b").await?.select_columns(&["c1"])?; + + // JOIN ON untyped NULL + let join = left.clone().join_on( + right.clone(), + JoinType::Inner, + Some(Expr::Literal(ScalarValue::Null)), + )?; + let expected_plan = "CrossJoin:\ + \n TableScan: a projection=[c1], full_filters=[Boolean(NULL)]\ + \n TableScan: b projection=[c1]"; + assert_eq!(expected_plan, format!("{:?}", join.into_optimized_plan()?)); + + // JOIN ON expression must be boolean type + let join = left.join_on(right, JoinType::Inner, Some(lit("TRUE")))?; + let expected = join.into_optimized_plan().unwrap_err(); + assert_eq!( + expected.strip_backtrace(), + "type_coercion\ncaused by\nError during planning: Join condition must be boolean type, but got Utf8" + ); + Ok(()) + } + #[tokio::test] async fn join_ambiguous_filter() -> Result<()> { let left = test_table_with_name("a") diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 3cab474df84e0..80a8c864e4311 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -127,7 +127,7 @@ impl<'a> TypeCoercionRewriter<'a> { Self { schema } } - /// Coerce join equality expressions + /// Coerce join equality expressions and join filter /// /// Joins must be treated specially as their equality expressions are 
stored /// as a parallel list of left and right expressions, rather than a single @@ -151,9 +151,24 @@ impl<'a> TypeCoercionRewriter<'a> { }) .collect::>>()?; + // Join filter must be boolean + join.filter = join + .filter + .map(|expr| self.coerce_join_filter(expr)) + .transpose()?; + Ok(LogicalPlan::Join(join)) } + fn coerce_join_filter(&self, expr: Expr) -> Result { + let expr_type = expr.get_type(self.schema)?; + match expr_type { + DataType::Boolean => Ok(expr), + DataType::Null => expr.cast_to(&DataType::Boolean, self.schema), + other => plan_err!("Join condition must be boolean type, but got {other:?}"), + } + } + fn coerce_binary_op( &self, left: Expr, diff --git a/datafusion/sqllogictest/test_files/join.slt b/datafusion/sqllogictest/test_files/join.slt index 12cb8b3985c76..efebba1779cf7 100644 --- a/datafusion/sqllogictest/test_files/join.slt +++ b/datafusion/sqllogictest/test_files/join.slt @@ -988,7 +988,6 @@ statement ok DROP TABLE department -# Test issue: https://github.com/apache/datafusion/issues/11269 statement ok CREATE TABLE t1 (v0 BIGINT) AS VALUES (-503661263); @@ -998,11 +997,22 @@ CREATE TABLE t2 (v0 DOUBLE) AS VALUES (-1.663563947387); statement ok CREATE TABLE t3 (v0 DOUBLE) AS VALUES (0.05112015193508901); +# Test issue: https://github.com/apache/datafusion/issues/11269 query RR SELECT t3.v0, t2.v0 FROM t1,t2,t3 WHERE t3.v0 >= t1.v0; ---- 0.051120151935 -1.663563947387 +# Test issue: https://github.com/apache/datafusion/issues/11414 +query IRR +SELECT * FROM t1 INNER JOIN t2 ON NULL RIGHT JOIN t3 ON TRUE; +---- +NULL NULL 0.051120151935 + +# ON expression must be boolean type +query error DataFusion error: type_coercion\ncaused by\nError during planning: Join condition must be boolean type, but got Utf8 +SELECT * FROM t1 INNER JOIN t2 ON 'TRUE' + statement ok DROP TABLE t1; From f5d88d1790eea85910ae5590a353ae17318f8401 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E6=9E=97=E4=BC=9F?= Date: Sun, 14 Jul 2024 05:44:32 +0800 Subject: [PATCH 34/59] Support serialization/deserialization for custom physical exprs in proto (#11387) * Add PhysicalExtensionExprNode * regen proto * Add ser/de extension expr logic * Add test and fix clippy lint --- datafusion/proto/proto/datafusion.proto | 7 + datafusion/proto/src/generated/pbjson.rs | 124 +++++++++++++++ datafusion/proto/src/generated/prost.rs | 12 +- .../proto/src/physical_plan/from_proto.rs | 8 + datafusion/proto/src/physical_plan/mod.rs | 16 ++ .../proto/src/physical_plan/to_proto.rs | 19 ++- .../tests/cases/roundtrip_physical_plan.rs | 147 +++++++++++++++++- 7 files changed, 330 insertions(+), 3 deletions(-) diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 345765b08be3c..9ef884531e320 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -836,6 +836,8 @@ message PhysicalExprNode { // was PhysicalDateTimeIntervalExprNode date_time_interval_expr = 17; PhysicalLikeExprNode like_expr = 18; + + PhysicalExtensionExprNode extension = 19; } } @@ -942,6 +944,11 @@ message PhysicalNegativeNode { PhysicalExprNode expr = 1; } +message PhysicalExtensionExprNode { + bytes expr = 1; + repeated PhysicalExprNode inputs = 2; +} + message FilterExecNode { PhysicalPlanNode input = 1; PhysicalExprNode expr = 2; diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 905f0d9849556..fa989480fad90 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ 
b/datafusion/proto/src/generated/pbjson.rs @@ -13543,6 +13543,9 @@ impl serde::Serialize for PhysicalExprNode { physical_expr_node::ExprType::LikeExpr(v) => { struct_ser.serialize_field("likeExpr", v)?; } + physical_expr_node::ExprType::Extension(v) => { + struct_ser.serialize_field("extension", v)?; + } } } struct_ser.end() @@ -13582,6 +13585,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalExprNode { "scalarUdf", "like_expr", "likeExpr", + "extension", ]; #[allow(clippy::enum_variant_names)] @@ -13602,6 +13606,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalExprNode { WindowExpr, ScalarUdf, LikeExpr, + Extension, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -13639,6 +13644,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalExprNode { "windowExpr" | "window_expr" => Ok(GeneratedField::WindowExpr), "scalarUdf" | "scalar_udf" => Ok(GeneratedField::ScalarUdf), "likeExpr" | "like_expr" => Ok(GeneratedField::LikeExpr), + "extension" => Ok(GeneratedField::Extension), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -13771,6 +13777,13 @@ impl<'de> serde::Deserialize<'de> for PhysicalExprNode { return Err(serde::de::Error::duplicate_field("likeExpr")); } expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::LikeExpr) +; + } + GeneratedField::Extension => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("extension")); + } + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::Extension) ; } } @@ -13783,6 +13796,117 @@ impl<'de> serde::Deserialize<'de> for PhysicalExprNode { deserializer.deserialize_struct("datafusion.PhysicalExprNode", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for PhysicalExtensionExprNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if !self.expr.is_empty() { + len += 1; + } + if !self.inputs.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalExtensionExprNode", len)?; + if !self.expr.is_empty() { + #[allow(clippy::needless_borrow)] + struct_ser.serialize_field("expr", pbjson::private::base64::encode(&self.expr).as_str())?; + } + if !self.inputs.is_empty() { + struct_ser.serialize_field("inputs", &self.inputs)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for PhysicalExtensionExprNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "expr", + "inputs", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Expr, + Inputs, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "expr" => Ok(GeneratedField::Expr), + "inputs" => Ok(GeneratedField::Inputs), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + 
deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = PhysicalExtensionExprNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.PhysicalExtensionExprNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut expr__ = None; + let mut inputs__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Expr => { + if expr__.is_some() { + return Err(serde::de::Error::duplicate_field("expr")); + } + expr__ = + Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) + ; + } + GeneratedField::Inputs => { + if inputs__.is_some() { + return Err(serde::de::Error::duplicate_field("inputs")); + } + inputs__ = Some(map_.next_value()?); + } + } + } + Ok(PhysicalExtensionExprNode { + expr: expr__.unwrap_or_default(), + inputs: inputs__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.PhysicalExtensionExprNode", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for PhysicalExtensionNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index b16d26ee6e1e0..8407e545fe650 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1218,7 +1218,7 @@ pub struct PhysicalExtensionNode { pub struct PhysicalExprNode { #[prost( oneof = "physical_expr_node::ExprType", - tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15, 16, 18" + tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15, 16, 18, 19" )] pub expr_type: ::core::option::Option, } @@ -1266,6 +1266,8 @@ pub mod physical_expr_node { ScalarUdf(super::PhysicalScalarUdfNode), #[prost(message, tag = "18")] LikeExpr(::prost::alloc::boxed::Box), + #[prost(message, tag = "19")] + Extension(super::PhysicalExtensionExprNode), } } #[allow(clippy::derive_partial_eq_without_eq)] @@ -1456,6 +1458,14 @@ pub struct PhysicalNegativeNode { } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] +pub struct PhysicalExtensionExprNode { + #[prost(bytes = "vec", tag = "1")] + pub expr: ::prost::alloc::vec::Vec, + #[prost(message, repeated, tag = "2")] + pub inputs: ::prost::alloc::vec::Vec, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] pub struct FilterExecNode { #[prost(message, optional, boxed, tag = "1")] pub input: ::core::option::Option<::prost::alloc::boxed::Box>, diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index e94bb3b8efcb4..52fbd5cbdcf64 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -394,6 +394,14 @@ pub fn parse_physical_expr( codec, )?, )), + ExprType::Extension(extension) => { + let inputs: Vec> = extension + .inputs + .iter() + .map(|e| parse_physical_expr(e, registry, input_schema, codec)) + .collect::>()?; + (codec.try_decode_expr(extension.expr.as_slice(), &inputs)?) 
as _ + } }; Ok(pexpr) diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index 56e702704798f..e5429945e97ef 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -2018,6 +2018,22 @@ pub trait PhysicalExtensionCodec: Debug + Send + Sync { fn try_encode_udf(&self, _node: &ScalarUDF, _buf: &mut Vec) -> Result<()> { Ok(()) } + + fn try_decode_expr( + &self, + _buf: &[u8], + _inputs: &[Arc], + ) -> Result> { + not_impl_err!("PhysicalExtensionCodec is not provided") + } + + fn try_encode_expr( + &self, + _node: Arc, + _buf: &mut Vec, + ) -> Result<()> { + not_impl_err!("PhysicalExtensionCodec is not provided") + } } #[derive(Debug)] diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index 5e982ad2afde8..9c95acc1dcf47 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -495,7 +495,24 @@ pub fn serialize_physical_expr( ))), }) } else { - internal_err!("physical_plan::to_proto() unsupported expression {value:?}") + let mut buf: Vec = vec![]; + match codec.try_encode_expr(Arc::clone(&value), &mut buf) { + Ok(_) => { + let inputs: Vec = value + .children() + .into_iter() + .map(|e| serialize_physical_expr(Arc::clone(e), codec)) + .collect::>()?; + Ok(protobuf::PhysicalExprNode { + expr_type: Some(protobuf::physical_expr_node::ExprType::Extension( + protobuf::PhysicalExtensionExprNode { expr: buf, inputs }, + )), + }) + } + Err(e) => internal_err!( + "Unsupported physical expr and extension codec failed with [{e}]. Expr: {value:?}" + ), + } } } diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index d8d85ace1a29e..2fcc65008fd8f 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -15,7 +15,10 @@ // specific language governing permissions and limitations // under the License. 
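Before the round-trip test below, it may help to see the contract the two new codec hooks establish: `try_encode_expr` writes an opaque payload for the node itself (children are serialized separately by `serialize_physical_expr`), and `try_decode_expr` receives that payload together with the already-deserialized children. The following is a minimal sketch only, not part of this patch: the tag constant, the codec name, and the choice of `IsNullExpr` (which is normally serialized natively and stands in here for a custom node) are all illustrative.

```rust
use std::sync::Arc;

use datafusion::error::Result;
use datafusion::execution::FunctionRegistry;
use datafusion::physical_expr::expressions::IsNullExpr;
use datafusion::physical_expr::PhysicalExpr;
use datafusion::physical_plan::ExecutionPlan;
use datafusion_common::{internal_err, not_impl_err};
use datafusion_proto::physical_plan::PhysicalExtensionCodec;

// Hypothetical marker identifying the one expression kind this codec handles.
const IS_NULL_TAG: &[u8] = b"is_null";

#[derive(Debug)]
struct IsNullCodec;

impl PhysicalExtensionCodec for IsNullCodec {
    // Plan-level hooks are required by the trait; this codec only handles
    // expressions, so both simply decline.
    fn try_decode(
        &self,
        _buf: &[u8],
        _inputs: &[Arc<dyn ExecutionPlan>],
        _registry: &dyn FunctionRegistry,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        not_impl_err!("this codec only handles expressions")
    }

    fn try_encode(&self, _node: Arc<dyn ExecutionPlan>, _buf: &mut Vec<u8>) -> Result<()> {
        not_impl_err!("this codec only handles expressions")
    }

    fn try_encode_expr(&self, node: Arc<dyn PhysicalExpr>, buf: &mut Vec<u8>) -> Result<()> {
        // Only the marker goes into the buffer; serialize_physical_expr
        // serializes the children itself and stores them alongside.
        if node.as_any().downcast_ref::<IsNullExpr>().is_some() {
            buf.extend_from_slice(IS_NULL_TAG);
            Ok(())
        } else {
            internal_err!("expression not handled by this codec")
        }
    }

    fn try_decode_expr(
        &self,
        buf: &[u8],
        inputs: &[Arc<dyn PhysicalExpr>],
    ) -> Result<Arc<dyn PhysicalExpr>> {
        // `inputs` holds the node's children, already rebuilt by
        // parse_physical_expr; only the payload needs interpreting here.
        if buf == IS_NULL_TAG {
            Ok(Arc::new(IsNullExpr::new(Arc::clone(&inputs[0]))))
        } else {
            internal_err!("unknown extension expression payload")
        }
    }
}
```

The test that follows exercises exactly this encode/decode pairing against a fully custom `PhysicalExpr`.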
+use arrow::array::RecordBatch; use std::any::Any; +use std::fmt::Display; +use std::hash::Hasher; use std::ops::Deref; use std::sync::Arc; use std::vec; @@ -38,6 +41,7 @@ use datafusion::datasource::physical_plan::{ }; use datafusion::execution::FunctionRegistry; use datafusion::logical_expr::{create_udf, JoinType, Operator, Volatility}; +use datafusion::physical_expr::aggregate::utils::down_cast_any_ref; use datafusion::physical_expr::expressions::Max; use datafusion::physical_expr::window::SlidingAggregateWindowExpr; use datafusion::physical_expr::{PhysicalSortRequirement, ScalarFunctionExpr}; @@ -75,7 +79,7 @@ use datafusion_common::file_options::csv_writer::CsvWriterOptions; use datafusion_common::file_options::json_writer::JsonWriterOptions; use datafusion_common::parsers::CompressionTypeVariant; use datafusion_common::stats::Precision; -use datafusion_common::{not_impl_err, plan_err, DataFusionError, Result}; +use datafusion_common::{internal_err, not_impl_err, plan_err, DataFusionError, Result}; use datafusion_expr::{ Accumulator, AccumulatorFactoryFunction, AggregateUDF, ColumnarValue, ScalarUDF, ScalarUDFImpl, Signature, SimpleAggregateUDF, WindowFrame, WindowFrameBound, @@ -658,6 +662,147 @@ async fn roundtrip_parquet_exec_with_table_partition_cols() -> Result<()> { roundtrip_test(ParquetExec::builder(scan_config).build_arc()) } +#[test] +fn roundtrip_parquet_exec_with_custom_predicate_expr() -> Result<()> { + let scan_config = FileScanConfig { + object_store_url: ObjectStoreUrl::local_filesystem(), + file_schema: Arc::new(Schema::new(vec![Field::new( + "col", + DataType::Utf8, + false, + )])), + file_groups: vec![vec![PartitionedFile::new( + "/path/to/file.parquet".to_string(), + 1024, + )]], + statistics: Statistics { + num_rows: Precision::Inexact(100), + total_byte_size: Precision::Inexact(1024), + column_statistics: Statistics::unknown_column(&Arc::new(Schema::new(vec![ + Field::new("col", DataType::Utf8, false), + ]))), + }, + projection: None, + limit: None, + table_partition_cols: vec![], + output_ordering: vec![], + }; + + #[derive(Debug, Hash, Clone)] + struct CustomPredicateExpr { + inner: Arc, + } + impl Display for CustomPredicateExpr { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "CustomPredicateExpr") + } + } + impl PartialEq for CustomPredicateExpr { + fn eq(&self, other: &dyn Any) -> bool { + down_cast_any_ref(other) + .downcast_ref::() + .map(|x| self.inner.eq(&x.inner)) + .unwrap_or(false) + } + } + impl PhysicalExpr for CustomPredicateExpr { + fn as_any(&self) -> &dyn Any { + self + } + + fn data_type(&self, _input_schema: &Schema) -> Result { + unreachable!() + } + + fn nullable(&self, _input_schema: &Schema) -> Result { + unreachable!() + } + + fn evaluate(&self, _batch: &RecordBatch) -> Result { + unreachable!() + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.inner] + } + + fn with_new_children( + self: Arc, + _children: Vec>, + ) -> Result> { + todo!() + } + + fn dyn_hash(&self, _state: &mut dyn Hasher) { + unreachable!() + } + } + + #[derive(Debug)] + struct CustomPhysicalExtensionCodec; + impl PhysicalExtensionCodec for CustomPhysicalExtensionCodec { + fn try_decode( + &self, + _buf: &[u8], + _inputs: &[Arc], + _registry: &dyn FunctionRegistry, + ) -> Result> { + unreachable!() + } + + fn try_encode( + &self, + _node: Arc, + _buf: &mut Vec, + ) -> Result<()> { + unreachable!() + } + + fn try_decode_expr( + &self, + buf: &[u8], + inputs: &[Arc], + ) -> Result> { + if buf == "CustomPredicateExpr".as_bytes() { + 
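+                // The buffer holds only the tag written by try_encode_expr;
+                // the child expression arrives already deserialized in `inputs`.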
Ok(Arc::new(CustomPredicateExpr { + inner: inputs[0].clone(), + })) + } else { + internal_err!("Not supported") + } + } + + fn try_encode_expr( + &self, + node: Arc, + buf: &mut Vec, + ) -> Result<()> { + if node + .as_ref() + .as_any() + .downcast_ref::() + .is_some() + { + buf.extend_from_slice("CustomPredicateExpr".as_bytes()); + Ok(()) + } else { + internal_err!("Not supported") + } + } + } + + let custom_predicate_expr = Arc::new(CustomPredicateExpr { + inner: Arc::new(Column::new("col", 1)), + }); + let exec_plan = ParquetExec::builder(scan_config) + .with_predicate(custom_predicate_expr) + .build_arc(); + + let ctx = SessionContext::new(); + roundtrip_test_and_return(exec_plan, &ctx, &CustomPhysicalExtensionCodec {})?; + Ok(()) +} + #[test] fn roundtrip_scalar_udf() -> Result<()> { let field_a = Field::new("a", DataType::Int64, false); From a43cf79bf0b133379ee6f2a236c025e59a5ef822 Mon Sep 17 00:00:00 2001 From: kf zheng <100595273+Kev1n8@users.noreply.github.com> Date: Sun, 14 Jul 2024 05:45:03 +0800 Subject: [PATCH 35/59] remove termtree dependency (#11416) * remove termtree dependency * impl Display for TopKHeap, replace uses of tree_print in tests * use to_string instead of format! --- datafusion/physical-plan/Cargo.toml | 1 - .../physical-plan/src/aggregates/topk/heap.rs | 86 ++++++++++++------- 2 files changed, 55 insertions(+), 32 deletions(-) diff --git a/datafusion/physical-plan/Cargo.toml b/datafusion/physical-plan/Cargo.toml index f5f756417ebf8..00fc81ebde978 100644 --- a/datafusion/physical-plan/Cargo.toml +++ b/datafusion/physical-plan/Cargo.toml @@ -66,7 +66,6 @@ tokio = { workspace = true } [dev-dependencies] rstest = { workspace = true } rstest_reuse = "0.7.0" -termtree = "0.5.0" tokio = { workspace = true, features = [ "rt-multi-thread", "fs", diff --git a/datafusion/physical-plan/src/aggregates/topk/heap.rs b/datafusion/physical-plan/src/aggregates/topk/heap.rs index 51593f5c28cef..81eadbc018b34 100644 --- a/datafusion/physical-plan/src/aggregates/topk/heap.rs +++ b/datafusion/physical-plan/src/aggregates/topk/heap.rs @@ -27,7 +27,7 @@ use datafusion_common::Result; use datafusion_physical_expr::aggregate::utils::adjust_output_array; use half::f16; use std::cmp::Ordering; -use std::fmt::{Debug, Formatter}; +use std::fmt::{Debug, Display, Formatter}; use std::sync::Arc; /// A custom version of `Ord` that only exists to we can implement it for the Values in our heap @@ -323,29 +323,53 @@ impl TopKHeap { } } - #[cfg(test)] - fn _tree_print(&self, idx: usize) -> Option> { - let hi = self.heap.get(idx)?; - match hi { - None => None, - Some(hi) => { - let label = - format!("val={:?} idx={}, bucket={}", hi.val, idx, hi.map_idx); - let left = self._tree_print(idx * 2 + 1); - let right = self._tree_print(idx * 2 + 2); - let children = left.into_iter().chain(right); - let me = termtree::Tree::new(label).with_leaves(children); - Some(me) + fn _tree_print( + &self, + idx: usize, + prefix: String, + is_tail: bool, + output: &mut String, + ) { + if let Some(Some(hi)) = self.heap.get(idx) { + let connector = if idx != 0 { + if is_tail { + "└── " + } else { + "├── " + } + } else { + "" + }; + output.push_str(&format!( + "{}{}val={:?} idx={}, bucket={}\n", + prefix, connector, hi.val, idx, hi.map_idx + )); + let new_prefix = if is_tail { "" } else { "│ " }; + let child_prefix = format!("{}{}", prefix, new_prefix); + + let left_idx = idx * 2 + 1; + let right_idx = idx * 2 + 2; + + let left_exists = left_idx < self.len; + let right_exists = right_idx < self.len; + + if left_exists { + 
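+                // The left child is rendered as the tail (is_tail = !right_exists)
+                // only when there is no right sibling, which keeps the
+                // └──/├── connectors aligned.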
self._tree_print(left_idx, child_prefix.clone(), !right_exists, output); + } + if right_exists { + self._tree_print(right_idx, child_prefix, true, output); } } } +} - #[cfg(test)] - fn tree_print(&self) -> String { - match self._tree_print(0) { - None => "".to_string(), - Some(root) => format!("{}", root), +impl Display for TopKHeap { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + let mut output = String::new(); + if self.heap.first().is_some() { + self._tree_print(0, String::new(), true, &mut output); } + write!(f, "{}", output) } } @@ -361,9 +385,9 @@ impl HeapItem { impl Debug for HeapItem { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { f.write_str("bucket=")?; - self.map_idx.fmt(f)?; + Debug::fmt(&self.map_idx, f)?; f.write_str(" val=")?; - self.val.fmt(f)?; + Debug::fmt(&self.val, f)?; f.write_str("\n")?; Ok(()) } @@ -462,7 +486,7 @@ mod tests { let mut heap = TopKHeap::new(10, false); heap.append_or_replace(1, 1, &mut map); - let actual = heap.tree_print(); + let actual = heap.to_string(); let expected = r#" val=1 idx=0, bucket=1 "#; @@ -482,7 +506,7 @@ val=1 idx=0, bucket=1 heap.append_or_replace(2, 2, &mut map); assert_eq!(map, vec![(2, 0), (1, 1)]); - let actual = heap.tree_print(); + let actual = heap.to_string(); let expected = r#" val=2 idx=0, bucket=2 └── val=1 idx=1, bucket=1 @@ -500,7 +524,7 @@ val=2 idx=0, bucket=2 heap.append_or_replace(1, 1, &mut map); heap.append_or_replace(2, 2, &mut map); heap.append_or_replace(3, 3, &mut map); - let actual = heap.tree_print(); + let actual = heap.to_string(); let expected = r#" val=3 idx=0, bucket=3 ├── val=1 idx=1, bucket=1 @@ -510,7 +534,7 @@ val=3 idx=0, bucket=3 let mut map = vec![]; heap.append_or_replace(0, 0, &mut map); - let actual = heap.tree_print(); + let actual = heap.to_string(); let expected = r#" val=2 idx=0, bucket=2 ├── val=1 idx=1, bucket=1 @@ -531,7 +555,7 @@ val=2 idx=0, bucket=2 heap.append_or_replace(2, 2, &mut map); heap.append_or_replace(3, 3, &mut map); heap.append_or_replace(4, 4, &mut map); - let actual = heap.tree_print(); + let actual = heap.to_string(); let expected = r#" val=4 idx=0, bucket=4 ├── val=3 idx=1, bucket=3 @@ -542,7 +566,7 @@ val=4 idx=0, bucket=4 let mut map = vec![]; heap.replace_if_better(1, 0, &mut map); - let actual = heap.tree_print(); + let actual = heap.to_string(); let expected = r#" val=4 idx=0, bucket=4 ├── val=1 idx=1, bucket=1 @@ -563,7 +587,7 @@ val=4 idx=0, bucket=4 heap.append_or_replace(1, 1, &mut map); heap.append_or_replace(2, 2, &mut map); - let actual = heap.tree_print(); + let actual = heap.to_string(); let expected = r#" val=2 idx=0, bucket=2 └── val=1 idx=1, bucket=1 @@ -584,7 +608,7 @@ val=2 idx=0, bucket=2 heap.append_or_replace(1, 1, &mut map); heap.append_or_replace(2, 2, &mut map); - let actual = heap.tree_print(); + let actual = heap.to_string(); let expected = r#" val=2 idx=0, bucket=2 └── val=1 idx=1, bucket=1 @@ -607,7 +631,7 @@ val=2 idx=0, bucket=2 heap.append_or_replace(1, 1, &mut map); heap.append_or_replace(2, 2, &mut map); - let actual = heap.tree_print(); + let actual = heap.to_string(); let expected = r#" val=2 idx=0, bucket=2 └── val=1 idx=1, bucket=1 @@ -616,7 +640,7 @@ val=2 idx=0, bucket=2 let numbers = vec![(0, 1), (1, 2)]; heap.renumber(numbers.as_slice()); - let actual = heap.tree_print(); + let actual = heap.to_string(); let expected = r#" val=2 idx=0, bucket=1 └── val=1 idx=1, bucket=2 From a7041feff32c2af09854c144a760d945e30fb38a Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Sun, 14 Jul 2024 05:47:47 +0800 Subject: 
[PATCH 36/59] Minor: Add an example for backtrace pretty print (#11450) * add the example for printing backtrace pretty * add empty end line * fix prettier * sync the usage example * Update docs/source/user-guide/crate-configuration.md Co-authored-by: Oleks V --------- Co-authored-by: Oleks V --- docs/source/user-guide/crate-configuration.md | 44 ++++++++++++++++++- 1 file changed, 43 insertions(+), 1 deletion(-) diff --git a/docs/source/user-guide/crate-configuration.md b/docs/source/user-guide/crate-configuration.md index 0587d06a39191..9d22e3403097f 100644 --- a/docs/source/user-guide/crate-configuration.md +++ b/docs/source/user-guide/crate-configuration.md @@ -121,7 +121,7 @@ backtrace: 0: std::backtrace_rs::backtrace::libunwind::trace The backtraces are useful when debugging code. If there is a test in `datafusion/core/src/physical_planner.rs` -``` +```rust #[tokio::test] async fn test_get_backtrace_for_failed_code() -> Result<()> { let ctx = SessionContext::new(); @@ -141,6 +141,48 @@ To obtain a backtrace: ```bash cargo build --features=backtrace RUST_BACKTRACE=1 cargo test --features=backtrace --package datafusion --lib -- physical_planner::tests::test_get_backtrace_for_failed_code --exact --nocapture + +running 1 test +Error: Plan("Invalid function 'row_numer'.\nDid you mean 'ROW_NUMBER'?\n\nbacktrace: 0: std::backtrace_rs::backtrace::libunwind::trace\n at /rustc/129f3b9964af4d4a709d1383930ade12dfe7c081/library/std/src/../../backtrace/src/backtrace/libunwind.rs:105:5\n 1: std::backtrace_rs::backtrace::trace_unsynchronized\n... ``` Note: The backtrace wrapped into systems calls, so some steps on top of the backtrace can be ignored + +To show the backtrace in a pretty-printed format use `eprintln!("{e}");`. + +```rust +#[tokio::test] +async fn test_get_backtrace_for_failed_code() -> Result<()> { + let ctx = SessionContext::new(); + + let sql = "select row_numer() over (partition by a order by a) from (select 1 a);"; + + let _ = match ctx.sql(sql).await { + Ok(result) => result.show().await?, + Err(e) => { + eprintln!("{e}"); + } + }; + + Ok(()) +} +``` + +Then run the test: + +```bash +$ RUST_BACKTRACE=1 cargo test --features=backtrace --package datafusion --lib -- physical_planner::tests::test_get_backtrace_for_failed_code --exact --nocapture + +running 1 test +Error during planning: Invalid function 'row_numer'. +Did you mean 'ROW_NUMBER'? + +backtrace: 0: std::backtrace_rs::backtrace::libunwind::trace + at /rustc/129f3b9964af4d4a709d1383930ade12dfe7c081/library/std/src/../../backtrace/src/backtrace/libunwind.rs:105:5 + 1: std::backtrace_rs::backtrace::trace_unsynchronized + at /rustc/129f3b9964af4d4a709d1383930ade12dfe7c081/library/std/src/../../backtrace/src/backtrace/mod.rs:66:5 + 2: std::backtrace::Backtrace::create + at /rustc/129f3b9964af4d4a709d1383930ade12dfe7c081/library/std/src/backtrace.rs:331:13 + 3: std::backtrace::Backtrace::capture + ... +``` From 84758062f808f97ba3b7e9d8a9d3839df4c39d98 Mon Sep 17 00:00:00 2001 From: Bruce Ritchie Date: Sun, 14 Jul 2024 15:00:31 -0400 Subject: [PATCH 37/59] Add SessionStateBuilder and extract out the registration of defaults (#11403) * Create a SessionStateBuilder and use it for creating anything but a basic SessionState. * Updated new_from_existing to take a reference to the existing SessionState and clone it. * Minor documentation update. * SessionStateDefaults improvements. * Reworked how SessionStateBuilder works from PR feedback. * Bug fix for missing array_expressions cfg feature. 
* Review feedback updates + doc fixes for SessionStateDefaults * Cargo fmt update. --- datafusion-cli/src/catalog.rs | 11 +- .../examples/custom_file_format.rs | 9 +- .../core/src/datasource/file_format/csv.rs | 7 +- datafusion/core/src/execution/context/mod.rs | 25 +- .../core/src/execution/session_state.rs | 965 ++++++++++++++---- datafusion/core/src/physical_planner.rs | 7 +- datafusion/core/src/test/object_store.rs | 8 +- datafusion/core/tests/dataframe/mod.rs | 19 +- datafusion/core/tests/memory_limit/mod.rs | 14 +- .../core/tests/parquet/file_statistics.rs | 6 +- datafusion/core/tests/sql/create_drop.rs | 13 +- .../tests/user_defined/user_defined_plan.rs | 11 +- .../tests/cases/roundtrip_logical_plan.rs | 8 +- .../tests/cases/roundtrip_logical_plan.rs | 13 +- 14 files changed, 884 insertions(+), 232 deletions(-) diff --git a/datafusion-cli/src/catalog.rs b/datafusion-cli/src/catalog.rs index c11eb3280c20f..b83f659756105 100644 --- a/datafusion-cli/src/catalog.rs +++ b/datafusion-cli/src/catalog.rs @@ -29,6 +29,7 @@ use datafusion::datasource::listing::{ use datafusion::datasource::TableProvider; use datafusion::error::Result; use datafusion::execution::context::SessionState; +use datafusion::execution::session_state::SessionStateBuilder; use async_trait::async_trait; use dirs::home_dir; @@ -162,6 +163,7 @@ impl SchemaProvider for DynamicFileSchemaProvider { .ok_or_else(|| plan_datafusion_err!("locking error"))? .read() .clone(); + let mut builder = SessionStateBuilder::from(state.clone()); let optimized_name = substitute_tilde(name.to_owned()); let table_url = ListingTableUrl::parse(optimized_name.as_str())?; let scheme = table_url.scheme(); @@ -178,13 +180,18 @@ impl SchemaProvider for DynamicFileSchemaProvider { // to any command options so the only choice is to use an empty collection match scheme { "s3" | "oss" | "cos" => { - state = state.add_table_options_extension(AwsOptions::default()); + if let Some(table_options) = builder.table_options() { + table_options.extensions.insert(AwsOptions::default()) + } } "gs" | "gcs" => { - state = state.add_table_options_extension(GcpOptions::default()) + if let Some(table_options) = builder.table_options() { + table_options.extensions.insert(GcpOptions::default()) + } } _ => {} }; + state = builder.build(); let store = get_object_store( &state, table_url.scheme(), diff --git a/datafusion-examples/examples/custom_file_format.rs b/datafusion-examples/examples/custom_file_format.rs index fe936418bce4a..bdb702375c945 100644 --- a/datafusion-examples/examples/custom_file_format.rs +++ b/datafusion-examples/examples/custom_file_format.rs @@ -22,6 +22,7 @@ use arrow::{ datatypes::UInt64Type, }; use arrow_schema::{DataType, Field, Schema, SchemaRef}; +use datafusion::execution::session_state::SessionStateBuilder; use datafusion::{ datasource::{ file_format::{ @@ -32,9 +33,9 @@ use datafusion::{ MemTable, }, error::Result, - execution::{context::SessionState, runtime_env::RuntimeEnv}, + execution::context::SessionState, physical_plan::ExecutionPlan, - prelude::{SessionConfig, SessionContext}, + prelude::SessionContext, }; use datafusion_common::{GetExt, Statistics}; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortRequirement}; @@ -176,9 +177,7 @@ impl GetExt for TSVFileFactory { #[tokio::main] async fn main() -> Result<()> { // Create a new context with the default configuration - let config = SessionConfig::new(); - let runtime = RuntimeEnv::default(); - let mut state = SessionState::new_with_config_rt(config, Arc::new(runtime)); + let mut 
state = SessionStateBuilder::new().with_default_features().build(); // Register the custom file format let file_format = Arc::new(TSVFileFactory::new()); diff --git a/datafusion/core/src/datasource/file_format/csv.rs b/datafusion/core/src/datasource/file_format/csv.rs index 92cb11e2b47a4..baeaf51fb56d1 100644 --- a/datafusion/core/src/datasource/file_format/csv.rs +++ b/datafusion/core/src/datasource/file_format/csv.rs @@ -632,6 +632,7 @@ mod tests { use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use datafusion_expr::{col, lit}; + use crate::execution::session_state::SessionStateBuilder; use chrono::DateTime; use object_store::local::LocalFileSystem; use object_store::path::Path; @@ -814,7 +815,11 @@ mod tests { let runtime = Arc::new(RuntimeEnv::new(RuntimeConfig::new()).unwrap()); let mut cfg = SessionConfig::new(); cfg.options_mut().catalog.has_header = true; - let session_state = SessionState::new_with_config_rt(cfg, runtime); + let session_state = SessionStateBuilder::new() + .with_config(cfg) + .with_runtime_env(runtime) + .with_default_features() + .build(); let integration = LocalFileSystem::new_with_prefix(arrow_test_data()).unwrap(); let path = Path::from("csv/aggregate_test_100.csv"); let csv = CsvFormat::default().with_has_header(true); diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 4b9e3e843341a..640a9b14a65f1 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -73,6 +73,7 @@ use object_store::ObjectStore; use parking_lot::RwLock; use url::Url; +use crate::execution::session_state::SessionStateBuilder; pub use datafusion_execution::config::SessionConfig; pub use datafusion_execution::TaskContext; pub use datafusion_expr::execution_props::ExecutionProps; @@ -294,7 +295,11 @@ impl SessionContext { /// all `SessionContext`'s should be configured with the /// same `RuntimeEnv`. 
pub fn new_with_config_rt(config: SessionConfig, runtime: Arc) -> Self { - let state = SessionState::new_with_config_rt(config, runtime); + let state = SessionStateBuilder::new() + .with_config(config) + .with_runtime_env(runtime) + .with_default_features() + .build(); Self::new_with_state(state) } @@ -315,7 +320,7 @@ impl SessionContext { } /// Creates a new `SessionContext` using the provided [`SessionState`] - #[deprecated(since = "32.0.0", note = "Use SessionState::new_with_state")] + #[deprecated(since = "32.0.0", note = "Use SessionContext::new_with_state")] pub fn with_state(state: SessionState) -> Self { Self::new_with_state(state) } @@ -1574,6 +1579,7 @@ mod tests { use datafusion_common_runtime::SpawnedTask; use crate::catalog::schema::SchemaProvider; + use crate::execution::session_state::SessionStateBuilder; use crate::physical_planner::PhysicalPlanner; use async_trait::async_trait; use tempfile::TempDir; @@ -1707,7 +1713,11 @@ mod tests { .set_str("datafusion.catalog.location", url.as_str()) .set_str("datafusion.catalog.format", "CSV") .set_str("datafusion.catalog.has_header", "true"); - let session_state = SessionState::new_with_config_rt(cfg, runtime); + let session_state = SessionStateBuilder::new() + .with_config(cfg) + .with_runtime_env(runtime) + .with_default_features() + .build(); let ctx = SessionContext::new_with_state(session_state); ctx.refresh_catalogs().await?; @@ -1733,9 +1743,12 @@ mod tests { #[tokio::test] async fn custom_query_planner() -> Result<()> { let runtime = Arc::new(RuntimeEnv::default()); - let session_state = - SessionState::new_with_config_rt(SessionConfig::new(), runtime) - .with_query_planner(Arc::new(MyQueryPlanner {})); + let session_state = SessionStateBuilder::new() + .with_config(SessionConfig::new()) + .with_runtime_env(runtime) + .with_default_features() + .with_query_planner(Arc::new(MyQueryPlanner {})) + .build(); let ctx = SessionContext::new_with_state(session_state); let df = ctx.sql("SELECT 1").await?; diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index dbfba9ea93521..75eef43454873 100644 --- a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -77,6 +77,8 @@ use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use datafusion_physical_plan::ExecutionPlan; use datafusion_sql::parser::{DFParser, Statement}; use datafusion_sql::planner::{ContextProvider, ParserOptions, PlannerContext, SqlToRel}; +use itertools::Itertools; +use log::{debug, info}; use sqlparser::ast::Expr as SQLExpr; use sqlparser::dialect::dialect_from_str; use std::collections::hash_map::Entry; @@ -89,9 +91,29 @@ use uuid::Uuid; /// Execution context for registering data sources and executing queries. /// See [`SessionContext`] for a higher level API. /// +/// Use the [`SessionStateBuilder`] to build a SessionState object. 
+/// +/// ``` +/// use datafusion::prelude::*; +/// # use datafusion::{error::Result, assert_batches_eq}; +/// # use datafusion::execution::session_state::SessionStateBuilder; +/// # use datafusion_execution::runtime_env::RuntimeEnv; +/// # use std::sync::Arc; +/// # #[tokio::main] +/// # async fn main() -> Result<()> { +/// let state = SessionStateBuilder::new() +/// .with_config(SessionConfig::new()) +/// .with_runtime_env(Arc::new(RuntimeEnv::default())) +/// .with_default_features() +/// .build(); +/// Ok(()) +/// # } +/// ``` +/// /// Note that there is no `Default` or `new()` for SessionState, /// to avoid accidentally running queries or other operations without passing through -/// the [`SessionConfig`] or [`RuntimeEnv`]. See [`SessionContext`]. +/// the [`SessionConfig`] or [`RuntimeEnv`]. See [`SessionStateBuilder`] and +/// [`SessionContext`]. /// /// [`SessionContext`]: crate::execution::context::SessionContext #[derive(Clone)] @@ -140,7 +162,6 @@ pub struct SessionState { table_factories: HashMap>, /// Runtime environment runtime_env: Arc, - /// [FunctionFactory] to support pluggable user defined function handler. /// /// It will be invoked on `CREATE FUNCTION` statements. @@ -153,6 +174,7 @@ impl Debug for SessionState { f.debug_struct("SessionState") .field("session_id", &self.session_id) .field("analyzer", &"...") + .field("expr_planners", &"...") .field("optimizer", &"...") .field("physical_optimizers", &"...") .field("query_planner", &"...") @@ -175,193 +197,56 @@ impl Debug for SessionState { impl SessionState { /// Returns new [`SessionState`] using the provided /// [`SessionConfig`] and [`RuntimeEnv`]. + #[deprecated(since = "40.0.0", note = "Use SessionStateBuilder")] pub fn new_with_config_rt(config: SessionConfig, runtime: Arc) -> Self { - let catalog_list = - Arc::new(MemoryCatalogProviderList::new()) as Arc; - Self::new_with_config_rt_and_catalog_list(config, runtime, catalog_list) + SessionStateBuilder::new() + .with_config(config) + .with_runtime_env(runtime) + .with_default_features() + .build() } /// Returns new [`SessionState`] using the provided /// [`SessionConfig`] and [`RuntimeEnv`]. 
- #[deprecated(since = "32.0.0", note = "Use SessionState::new_with_config_rt")] + #[deprecated(since = "32.0.0", note = "Use SessionStateBuilder")] pub fn with_config_rt(config: SessionConfig, runtime: Arc) -> Self { - Self::new_with_config_rt(config, runtime) + SessionStateBuilder::new() + .with_config(config) + .with_runtime_env(runtime) + .with_default_features() + .build() } /// Returns new [`SessionState`] using the provided /// [`SessionConfig`], [`RuntimeEnv`], and [`CatalogProviderList`] + #[deprecated(since = "40.0.0", note = "Use SessionStateBuilder")] pub fn new_with_config_rt_and_catalog_list( config: SessionConfig, runtime: Arc, catalog_list: Arc, ) -> Self { - let session_id = Uuid::new_v4().to_string(); - - // Create table_factories for all default formats - let mut table_factories: HashMap> = - HashMap::new(); - #[cfg(feature = "parquet")] - table_factories.insert("PARQUET".into(), Arc::new(DefaultTableFactory::new())); - table_factories.insert("CSV".into(), Arc::new(DefaultTableFactory::new())); - table_factories.insert("JSON".into(), Arc::new(DefaultTableFactory::new())); - table_factories.insert("NDJSON".into(), Arc::new(DefaultTableFactory::new())); - table_factories.insert("AVRO".into(), Arc::new(DefaultTableFactory::new())); - table_factories.insert("ARROW".into(), Arc::new(DefaultTableFactory::new())); - - if config.create_default_catalog_and_schema() { - let default_catalog = MemoryCatalogProvider::new(); - - default_catalog - .register_schema( - &config.options().catalog.default_schema, - Arc::new(MemorySchemaProvider::new()), - ) - .expect("memory catalog provider can register schema"); - - Self::register_default_schema( - &config, - &table_factories, - &runtime, - &default_catalog, - ); - - catalog_list.register_catalog( - config.options().catalog.default_catalog.clone(), - Arc::new(default_catalog), - ); - } - - let expr_planners: Vec> = vec![ - Arc::new(functions::core::planner::CoreFunctionPlanner::default()), - // register crate of array expressions (if enabled) - #[cfg(feature = "array_expressions")] - Arc::new(functions_array::planner::ArrayFunctionPlanner), - #[cfg(feature = "array_expressions")] - Arc::new(functions_array::planner::FieldAccessPlanner), - #[cfg(any( - feature = "datetime_expressions", - feature = "unicode_expressions" - ))] - Arc::new(functions::planner::UserDefinedFunctionPlanner), - ]; - - let mut new_self = SessionState { - session_id, - analyzer: Analyzer::new(), - expr_planners, - optimizer: Optimizer::new(), - physical_optimizers: PhysicalOptimizer::new(), - query_planner: Arc::new(DefaultQueryPlanner {}), - catalog_list, - table_functions: HashMap::new(), - scalar_functions: HashMap::new(), - aggregate_functions: HashMap::new(), - window_functions: HashMap::new(), - serializer_registry: Arc::new(EmptySerializerRegistry), - file_formats: HashMap::new(), - table_options: TableOptions::default_from_session_config(config.options()), - config, - execution_props: ExecutionProps::new(), - runtime_env: runtime, - table_factories, - function_factory: None, - }; - - #[cfg(feature = "parquet")] - if let Err(e) = - new_self.register_file_format(Arc::new(ParquetFormatFactory::new()), false) - { - log::info!("Unable to register default ParquetFormat: {e}") - }; - - if let Err(e) = - new_self.register_file_format(Arc::new(JsonFormatFactory::new()), false) - { - log::info!("Unable to register default JsonFormat: {e}") - }; - - if let Err(e) = - new_self.register_file_format(Arc::new(CsvFormatFactory::new()), false) - { - log::info!("Unable to 
register default CsvFormat: {e}") - }; - - if let Err(e) = - new_self.register_file_format(Arc::new(ArrowFormatFactory::new()), false) - { - log::info!("Unable to register default ArrowFormat: {e}") - }; - - if let Err(e) = - new_self.register_file_format(Arc::new(AvroFormatFactory::new()), false) - { - log::info!("Unable to register default AvroFormat: {e}") - }; - - // register built in functions - functions::register_all(&mut new_self) - .expect("can not register built in functions"); - - // register crate of array expressions (if enabled) - #[cfg(feature = "array_expressions")] - functions_array::register_all(&mut new_self) - .expect("can not register array expressions"); - - functions_aggregate::register_all(&mut new_self) - .expect("can not register aggregate functions"); - - new_self + SessionStateBuilder::new() + .with_config(config) + .with_runtime_env(runtime) + .with_catalog_list(catalog_list) + .with_default_features() + .build() } + /// Returns new [`SessionState`] using the provided /// [`SessionConfig`] and [`RuntimeEnv`]. - #[deprecated( - since = "32.0.0", - note = "Use SessionState::new_with_config_rt_and_catalog_list" - )] + #[deprecated(since = "32.0.0", note = "Use SessionStateBuilder")] pub fn with_config_rt_and_catalog_list( config: SessionConfig, runtime: Arc, catalog_list: Arc, ) -> Self { - Self::new_with_config_rt_and_catalog_list(config, runtime, catalog_list) - } - fn register_default_schema( - config: &SessionConfig, - table_factories: &HashMap>, - runtime: &Arc, - default_catalog: &MemoryCatalogProvider, - ) { - let url = config.options().catalog.location.as_ref(); - let format = config.options().catalog.format.as_ref(); - let (url, format) = match (url, format) { - (Some(url), Some(format)) => (url, format), - _ => return, - }; - let url = url.to_string(); - let format = format.to_string(); - - let url = Url::parse(url.as_str()).expect("Invalid default catalog location!"); - let authority = match url.host_str() { - Some(host) => format!("{}://{}", url.scheme(), host), - None => format!("{}://", url.scheme()), - }; - let path = &url.as_str()[authority.len()..]; - let path = object_store::path::Path::parse(path).expect("Can't parse path"); - let store = ObjectStoreUrl::parse(authority.as_str()) - .expect("Invalid default catalog url"); - let store = match runtime.object_store(store) { - Ok(store) => store, - _ => return, - }; - let factory = match table_factories.get(format.as_str()) { - Some(factory) => factory, - _ => return, - }; - let schema = - ListingSchemaProvider::new(authority, path, factory.clone(), store, format); - let _ = default_catalog - .register_schema("default", Arc::new(schema)) - .expect("Failed to register default schema"); + SessionStateBuilder::new() + .with_config(config) + .with_runtime_env(runtime) + .with_catalog_list(catalog_list) + .with_default_features() + .build() } pub(crate) fn resolve_table_ref( @@ -400,12 +285,14 @@ impl SessionState { }) } + #[deprecated(since = "40.0.0", note = "Use SessionStateBuilder")] /// Replace the random session id. pub fn with_session_id(mut self, session_id: String) -> Self { self.session_id = session_id; self } + #[deprecated(since = "40.0.0", note = "Use SessionStateBuilder")] /// override default query planner with `query_planner` pub fn with_query_planner( mut self, @@ -415,6 +302,7 @@ impl SessionState { self } + #[deprecated(since = "40.0.0", note = "Use SessionStateBuilder")] /// Override the [`AnalyzerRule`]s optimizer plan rules. 
pub fn with_analyzer_rules( mut self, @@ -424,6 +312,7 @@ impl SessionState { self } + #[deprecated(since = "40.0.0", note = "Use SessionStateBuilder")] /// Replace the entire list of [`OptimizerRule`]s used to optimize plans pub fn with_optimizer_rules( mut self, @@ -433,6 +322,7 @@ impl SessionState { self } + #[deprecated(since = "40.0.0", note = "Use SessionStateBuilder")] /// Replace the entire list of [`PhysicalOptimizerRule`]s used to optimize plans pub fn with_physical_optimizer_rules( mut self, @@ -452,6 +342,7 @@ impl SessionState { self } + #[deprecated(since = "40.0.0", note = "Use SessionStateBuilder")] /// Add `optimizer_rule` to the end of the list of /// [`OptimizerRule`]s used to rewrite queries. pub fn add_optimizer_rule( @@ -472,6 +363,7 @@ impl SessionState { self.optimizer.rules.push(optimizer_rule); } + #[deprecated(since = "40.0.0", note = "Use SessionStateBuilder")] /// Add `physical_optimizer_rule` to the end of the list of /// [`PhysicalOptimizerRule`]s used to rewrite queries. pub fn add_physical_optimizer_rule( @@ -482,6 +374,7 @@ impl SessionState { self } + #[deprecated(since = "40.0.0", note = "Use SessionStateBuilder")] /// Adds a new [`ConfigExtension`] to TableOptions pub fn add_table_options_extension( mut self, @@ -491,6 +384,7 @@ impl SessionState { self } + #[deprecated(since = "40.0.0", note = "Use SessionStateBuilder")] /// Registers a [`FunctionFactory`] to handle `CREATE FUNCTION` statements pub fn with_function_factory( mut self, @@ -505,6 +399,7 @@ impl SessionState { self.function_factory = Some(function_factory); } + #[deprecated(since = "40.0.0", note = "Use SessionStateBuilder")] /// Replace the extension [`SerializerRegistry`] pub fn with_serializer_registry( mut self, @@ -858,19 +753,20 @@ impl SessionState { &self.table_options } - /// Return mutable table opptions + /// Return mutable table options pub fn table_options_mut(&mut self) -> &mut TableOptions { &mut self.table_options } - /// Registers a [`ConfigExtension`] as a table option extention that can be + /// Registers a [`ConfigExtension`] as a table option extension that can be /// referenced from SQL statements executed against this context. pub fn register_table_options_extension(&mut self, extension: T) { self.table_options.extensions.insert(extension) } - /// Adds or updates a [FileFormatFactory] which can be used with COPY TO or CREATE EXTERNAL TABLE statements for reading - /// and writing files of custom formats. + /// Adds or updates a [FileFormatFactory] which can be used with COPY TO or + /// CREATE EXTERNAL TABLE statements for reading and writing files of custom + /// formats. pub fn register_file_format( &mut self, file_format: Arc, @@ -950,7 +846,7 @@ impl SessionState { ); } - /// Deregsiter a user defined table function + /// Deregister a user defined table function pub fn deregister_udtf( &mut self, name: &str, @@ -974,6 +870,733 @@ impl SessionState { } } +/// A builder to be used for building [`SessionState`]'s. Defaults will +/// be used for all values unless explicitly provided. 
+/// +/// See example on [`SessionState`] +pub struct SessionStateBuilder { + session_id: Option, + analyzer: Option, + expr_planners: Option>>, + optimizer: Option, + physical_optimizers: Option, + query_planner: Option>, + catalog_list: Option>, + table_functions: Option>>, + scalar_functions: Option>>, + aggregate_functions: Option>>, + window_functions: Option>>, + serializer_registry: Option>, + file_formats: Option>>, + config: Option, + table_options: Option, + execution_props: Option, + table_factories: Option>>, + runtime_env: Option>, + function_factory: Option>, + // fields to support convenience functions + analyzer_rules: Option>>, + optimizer_rules: Option>>, + physical_optimizer_rules: Option>>, +} + +impl SessionStateBuilder { + /// Returns a new [`SessionStateBuilder`] with no options set. + pub fn new() -> Self { + Self { + session_id: None, + analyzer: None, + expr_planners: None, + optimizer: None, + physical_optimizers: None, + query_planner: None, + catalog_list: None, + table_functions: None, + scalar_functions: None, + aggregate_functions: None, + window_functions: None, + serializer_registry: None, + file_formats: None, + table_options: None, + config: None, + execution_props: None, + table_factories: None, + runtime_env: None, + function_factory: None, + // fields to support convenience functions + analyzer_rules: None, + optimizer_rules: None, + physical_optimizer_rules: None, + } + } + + /// Returns a new [SessionStateBuilder] based on an existing [SessionState] + /// The session id for the new builder will be unset; all other fields will + /// be cloned from what is set in the provided session state + pub fn new_from_existing(existing: SessionState) -> Self { + Self { + session_id: None, + analyzer: Some(existing.analyzer), + expr_planners: Some(existing.expr_planners), + optimizer: Some(existing.optimizer), + physical_optimizers: Some(existing.physical_optimizers), + query_planner: Some(existing.query_planner), + catalog_list: Some(existing.catalog_list), + table_functions: Some(existing.table_functions), + scalar_functions: Some(existing.scalar_functions.into_values().collect_vec()), + aggregate_functions: Some( + existing.aggregate_functions.into_values().collect_vec(), + ), + window_functions: Some(existing.window_functions.into_values().collect_vec()), + serializer_registry: Some(existing.serializer_registry), + file_formats: Some(existing.file_formats.into_values().collect_vec()), + config: Some(existing.config), + table_options: Some(existing.table_options), + execution_props: Some(existing.execution_props), + table_factories: Some(existing.table_factories), + runtime_env: Some(existing.runtime_env), + function_factory: existing.function_factory, + + // fields to support convenience functions + analyzer_rules: None, + optimizer_rules: None, + physical_optimizer_rules: None, + } + } + + /// Set defaults for table_factories, file formats, expr_planners and builtin + /// scalar and aggregate functions. + pub fn with_default_features(mut self) -> Self { + self.table_factories = Some(SessionStateDefaults::default_table_factories()); + self.file_formats = Some(SessionStateDefaults::default_file_formats()); + self.expr_planners = Some(SessionStateDefaults::default_expr_planners()); + self.scalar_functions = Some(SessionStateDefaults::default_scalar_functions()); + self.aggregate_functions = + Some(SessionStateDefaults::default_aggregate_functions()); + self + } + + /// Set the session id. 
+    pub fn with_session_id(mut self, session_id: String) -> Self {
+        self.session_id = Some(session_id);
+        self
+    }
+
+    /// Set the [`AnalyzerRule`]s optimizer plan rules.
+    pub fn with_analyzer_rules(
+        mut self,
+        rules: Vec<Arc<dyn AnalyzerRule + Send + Sync>>,
+    ) -> Self {
+        self.analyzer = Some(Analyzer::with_rules(rules));
+        self
+    }
+
+    /// Add `analyzer_rule` to the end of the list of
+    /// [`AnalyzerRule`]s used to rewrite queries.
+    pub fn with_analyzer_rule(
+        mut self,
+        analyzer_rule: Arc<dyn AnalyzerRule + Send + Sync>,
+    ) -> Self {
+        let mut rules = self.analyzer_rules.unwrap_or_default();
+        rules.push(analyzer_rule);
+        self.analyzer_rules = Some(rules);
+        self
+    }
+
+    /// Set the [`OptimizerRule`]s used to optimize plans.
+    pub fn with_optimizer_rules(
+        mut self,
+        rules: Vec<Arc<dyn OptimizerRule + Send + Sync>>,
+    ) -> Self {
+        self.optimizer = Some(Optimizer::with_rules(rules));
+        self
+    }
+
+    /// Add `optimizer_rule` to the end of the list of
+    /// [`OptimizerRule`]s used to rewrite queries.
+    pub fn with_optimizer_rule(
+        mut self,
+        optimizer_rule: Arc<dyn OptimizerRule + Send + Sync>,
+    ) -> Self {
+        let mut rules = self.optimizer_rules.unwrap_or_default();
+        rules.push(optimizer_rule);
+        self.optimizer_rules = Some(rules);
+        self
+    }
+
+    /// Set the [`ExprPlanner`]s used to customize the behavior of the SQL planner.
+    pub fn with_expr_planners(
+        mut self,
+        expr_planners: Vec<Arc<dyn ExprPlanner>>,
+    ) -> Self {
+        self.expr_planners = Some(expr_planners);
+        self
+    }
+
+    /// Set the [`PhysicalOptimizerRule`]s used to optimize plans.
+    pub fn with_physical_optimizer_rules(
+        mut self,
+        physical_optimizers: Vec<Arc<dyn PhysicalOptimizerRule + Send + Sync>>,
+    ) -> Self {
+        self.physical_optimizers =
+            Some(PhysicalOptimizer::with_rules(physical_optimizers));
+        self
+    }
+
+    /// Add `physical_optimizer_rule` to the end of the list of
+    /// [`PhysicalOptimizerRule`]s used to rewrite queries.
+    pub fn with_physical_optimizer_rule(
+        mut self,
+        physical_optimizer_rule: Arc<dyn PhysicalOptimizerRule + Send + Sync>,
+    ) -> Self {
+        let mut rules = self.physical_optimizer_rules.unwrap_or_default();
+        rules.push(physical_optimizer_rule);
+        self.physical_optimizer_rules = Some(rules);
+        self
+    }
+
+    /// Set the [`QueryPlanner`]
+    pub fn with_query_planner(
+        mut self,
+        query_planner: Arc<dyn QueryPlanner + Send + Sync>,
+    ) -> Self {
+        self.query_planner = Some(query_planner);
+        self
+    }
+
+    /// Set the [`CatalogProviderList`]
+    pub fn with_catalog_list(
+        mut self,
+        catalog_list: Arc<dyn CatalogProviderList>,
+    ) -> Self {
+        self.catalog_list = Some(catalog_list);
+        self
+    }
+
+    /// Set the map of [`TableFunction`]s
+    pub fn with_table_functions(
+        mut self,
+        table_functions: HashMap<String, Arc<TableFunction>>,
+    ) -> Self {
+        self.table_functions = Some(table_functions);
+        self
+    }
+
+    /// Set the list of [`ScalarUDF`]s
+    pub fn with_scalar_functions(
+        mut self,
+        scalar_functions: Vec<Arc<ScalarUDF>>,
+    ) -> Self {
+        self.scalar_functions = Some(scalar_functions);
+        self
+    }
+
+    /// Set the list of [`AggregateUDF`]s
+    pub fn with_aggregate_functions(
+        mut self,
+        aggregate_functions: Vec<Arc<AggregateUDF>>,
+    ) -> Self {
+        self.aggregate_functions = Some(aggregate_functions);
+        self
+    }
+
+    /// Set the list of [`WindowUDF`]s
+    pub fn with_window_functions(
+        mut self,
+        window_functions: Vec<Arc<WindowUDF>>,
+    ) -> Self {
+        self.window_functions = Some(window_functions);
+        self
+    }
+
+    /// Set the [`SerializerRegistry`]
+    pub fn with_serializer_registry(
+        mut self,
+        serializer_registry: Arc<dyn SerializerRegistry>,
+    ) -> Self {
+        self.serializer_registry = Some(serializer_registry);
+        self
+    }
+
+    /// Set the list of [`FileFormatFactory`]s
+    pub fn with_file_formats(
+        mut self,
+        file_formats: Vec<Arc<dyn FileFormatFactory>>,
+    ) -> Self {
+        self.file_formats = Some(file_formats);
+        self
+    }
+
+    /// Set the [`SessionConfig`]
+    pub fn with_config(mut self, config: SessionConfig) -> Self
{ + self.config = Some(config); + self + } + + /// Set the [`TableOptions`] + pub fn with_table_options(mut self, table_options: TableOptions) -> Self { + self.table_options = Some(table_options); + self + } + + /// Set the [`ExecutionProps`] + pub fn with_execution_props(mut self, execution_props: ExecutionProps) -> Self { + self.execution_props = Some(execution_props); + self + } + + /// Set the map of [`TableProviderFactory`]s + pub fn with_table_factories( + mut self, + table_factories: HashMap>, + ) -> Self { + self.table_factories = Some(table_factories); + self + } + + /// Set the [`RuntimeEnv`] + pub fn with_runtime_env(mut self, runtime_env: Arc) -> Self { + self.runtime_env = Some(runtime_env); + self + } + + /// Set a [`FunctionFactory`] to handle `CREATE FUNCTION` statements + pub fn with_function_factory( + mut self, + function_factory: Option>, + ) -> Self { + self.function_factory = function_factory; + self + } + + /// Builds a [`SessionState`] with the current configuration. + /// + /// Note that there is an explicit option for enabling catalog and schema defaults + /// in [SessionConfig::create_default_catalog_and_schema] which if enabled + /// will be built here. + pub fn build(self) -> SessionState { + let Self { + session_id, + analyzer, + expr_planners, + optimizer, + physical_optimizers, + query_planner, + catalog_list, + table_functions, + scalar_functions, + aggregate_functions, + window_functions, + serializer_registry, + file_formats, + table_options, + config, + execution_props, + table_factories, + runtime_env, + function_factory, + analyzer_rules, + optimizer_rules, + physical_optimizer_rules, + } = self; + + let config = config.unwrap_or_default(); + let runtime_env = runtime_env.unwrap_or(Arc::new(RuntimeEnv::default())); + + let mut state = SessionState { + session_id: session_id.unwrap_or(Uuid::new_v4().to_string()), + analyzer: analyzer.unwrap_or_default(), + expr_planners: expr_planners.unwrap_or_default(), + optimizer: optimizer.unwrap_or_default(), + physical_optimizers: physical_optimizers.unwrap_or_default(), + query_planner: query_planner.unwrap_or(Arc::new(DefaultQueryPlanner {})), + catalog_list: catalog_list + .unwrap_or(Arc::new(MemoryCatalogProviderList::new()) + as Arc), + table_functions: table_functions.unwrap_or_default(), + scalar_functions: HashMap::new(), + aggregate_functions: HashMap::new(), + window_functions: HashMap::new(), + serializer_registry: serializer_registry + .unwrap_or(Arc::new(EmptySerializerRegistry)), + file_formats: HashMap::new(), + table_options: table_options + .unwrap_or(TableOptions::default_from_session_config(config.options())), + config, + execution_props: execution_props.unwrap_or_default(), + table_factories: table_factories.unwrap_or_default(), + runtime_env, + function_factory, + }; + + if let Some(file_formats) = file_formats { + for file_format in file_formats { + if let Err(e) = state.register_file_format(file_format, false) { + info!("Unable to register file format: {e}") + }; + } + } + + if let Some(scalar_functions) = scalar_functions { + scalar_functions.into_iter().for_each(|udf| { + let existing_udf = state.register_udf(udf); + if let Ok(Some(existing_udf)) = existing_udf { + debug!("Overwrote an existing UDF: {}", existing_udf.name()); + } + }); + } + + if let Some(aggregate_functions) = aggregate_functions { + aggregate_functions.into_iter().for_each(|udaf| { + let existing_udf = state.register_udaf(udaf); + if let Ok(Some(existing_udf)) = existing_udf { + debug!("Overwrote an existing UDF: {}", 
existing_udf.name()); + } + }); + } + + if let Some(window_functions) = window_functions { + window_functions.into_iter().for_each(|udwf| { + let existing_udf = state.register_udwf(udwf); + if let Ok(Some(existing_udf)) = existing_udf { + debug!("Overwrote an existing UDF: {}", existing_udf.name()); + } + }); + } + + if state.config.create_default_catalog_and_schema() { + let default_catalog = SessionStateDefaults::default_catalog( + &state.config, + &state.table_factories, + &state.runtime_env, + ); + + state.catalog_list.register_catalog( + state.config.options().catalog.default_catalog.clone(), + Arc::new(default_catalog), + ); + } + + if let Some(analyzer_rules) = analyzer_rules { + for analyzer_rule in analyzer_rules { + state.analyzer.rules.push(analyzer_rule); + } + } + + if let Some(optimizer_rules) = optimizer_rules { + for optimizer_rule in optimizer_rules { + state.optimizer.rules.push(optimizer_rule); + } + } + + if let Some(physical_optimizer_rules) = physical_optimizer_rules { + for physical_optimizer_rule in physical_optimizer_rules { + state + .physical_optimizers + .rules + .push(physical_optimizer_rule); + } + } + + state + } + + /// Returns the current session_id value + pub fn session_id(&self) -> &Option { + &self.session_id + } + + /// Returns the current analyzer value + pub fn analyzer(&mut self) -> &mut Option { + &mut self.analyzer + } + + /// Returns the current expr_planners value + pub fn expr_planners(&mut self) -> &mut Option>> { + &mut self.expr_planners + } + + /// Returns the current optimizer value + pub fn optimizer(&mut self) -> &mut Option { + &mut self.optimizer + } + + /// Returns the current physical_optimizers value + pub fn physical_optimizers(&mut self) -> &mut Option { + &mut self.physical_optimizers + } + + /// Returns the current query_planner value + pub fn query_planner(&mut self) -> &mut Option> { + &mut self.query_planner + } + + /// Returns the current catalog_list value + pub fn catalog_list(&mut self) -> &mut Option> { + &mut self.catalog_list + } + + /// Returns the current table_functions value + pub fn table_functions( + &mut self, + ) -> &mut Option>> { + &mut self.table_functions + } + + /// Returns the current scalar_functions value + pub fn scalar_functions(&mut self) -> &mut Option>> { + &mut self.scalar_functions + } + + /// Returns the current aggregate_functions value + pub fn aggregate_functions(&mut self) -> &mut Option>> { + &mut self.aggregate_functions + } + + /// Returns the current window_functions value + pub fn window_functions(&mut self) -> &mut Option>> { + &mut self.window_functions + } + + /// Returns the current serializer_registry value + pub fn serializer_registry(&mut self) -> &mut Option> { + &mut self.serializer_registry + } + + /// Returns the current file_formats value + pub fn file_formats(&mut self) -> &mut Option>> { + &mut self.file_formats + } + + /// Returns the current session_config value + pub fn config(&mut self) -> &mut Option { + &mut self.config + } + + /// Returns the current table_options value + pub fn table_options(&mut self) -> &mut Option { + &mut self.table_options + } + + /// Returns the current execution_props value + pub fn execution_props(&mut self) -> &mut Option { + &mut self.execution_props + } + + /// Returns the current table_factories value + pub fn table_factories( + &mut self, + ) -> &mut Option>> { + &mut self.table_factories + } + + /// Returns the current runtime_env value + pub fn runtime_env(&mut self) -> &mut Option> { + &mut self.runtime_env + } + + /// Returns the 
current function_factory value + pub fn function_factory(&mut self) -> &mut Option> { + &mut self.function_factory + } + + /// Returns the current analyzer_rules value + pub fn analyzer_rules( + &mut self, + ) -> &mut Option>> { + &mut self.analyzer_rules + } + + /// Returns the current optimizer_rules value + pub fn optimizer_rules( + &mut self, + ) -> &mut Option>> { + &mut self.optimizer_rules + } + + /// Returns the current physical_optimizer_rules value + pub fn physical_optimizer_rules( + &mut self, + ) -> &mut Option>> { + &mut self.physical_optimizer_rules + } +} + +impl Default for SessionStateBuilder { + fn default() -> Self { + Self::new() + } +} + +impl From for SessionStateBuilder { + fn from(state: SessionState) -> Self { + SessionStateBuilder::new_from_existing(state) + } +} + +/// Defaults that are used as part of creating a SessionState such as table providers, +/// file formats, registering of builtin functions, etc. +pub struct SessionStateDefaults {} + +impl SessionStateDefaults { + /// returns a map of the default [`TableProviderFactory`]s + pub fn default_table_factories() -> HashMap> { + let mut table_factories: HashMap> = + HashMap::new(); + #[cfg(feature = "parquet")] + table_factories.insert("PARQUET".into(), Arc::new(DefaultTableFactory::new())); + table_factories.insert("CSV".into(), Arc::new(DefaultTableFactory::new())); + table_factories.insert("JSON".into(), Arc::new(DefaultTableFactory::new())); + table_factories.insert("NDJSON".into(), Arc::new(DefaultTableFactory::new())); + table_factories.insert("AVRO".into(), Arc::new(DefaultTableFactory::new())); + table_factories.insert("ARROW".into(), Arc::new(DefaultTableFactory::new())); + + table_factories + } + + /// returns the default MemoryCatalogProvider + pub fn default_catalog( + config: &SessionConfig, + table_factories: &HashMap>, + runtime: &Arc, + ) -> MemoryCatalogProvider { + let default_catalog = MemoryCatalogProvider::new(); + + default_catalog + .register_schema( + &config.options().catalog.default_schema, + Arc::new(MemorySchemaProvider::new()), + ) + .expect("memory catalog provider can register schema"); + + Self::register_default_schema(config, table_factories, runtime, &default_catalog); + + default_catalog + } + + /// returns the list of default [`ExprPlanner`]s + pub fn default_expr_planners() -> Vec> { + let expr_planners: Vec> = vec![ + Arc::new(functions::core::planner::CoreFunctionPlanner::default()), + // register crate of array expressions (if enabled) + #[cfg(feature = "array_expressions")] + Arc::new(functions_array::planner::ArrayFunctionPlanner), + #[cfg(feature = "array_expressions")] + Arc::new(functions_array::planner::FieldAccessPlanner), + #[cfg(any( + feature = "datetime_expressions", + feature = "unicode_expressions" + ))] + Arc::new(functions::planner::UserDefinedFunctionPlanner), + ]; + + expr_planners + } + + /// returns the list of default [`ScalarUDF']'s + pub fn default_scalar_functions() -> Vec> { + let mut functions: Vec> = functions::all_default_functions(); + #[cfg(feature = "array_expressions")] + functions.append(&mut functions_array::all_default_array_functions()); + + functions + } + + /// returns the list of default [`AggregateUDF']'s + pub fn default_aggregate_functions() -> Vec> { + functions_aggregate::all_default_aggregate_functions() + } + + /// returns the list of default [`FileFormatFactory']'s + pub fn default_file_formats() -> Vec> { + let file_formats: Vec> = vec![ + #[cfg(feature = "parquet")] + Arc::new(ParquetFormatFactory::new()), + 
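+            // Parquet is only present when the "parquet" cargo feature is
+            // enabled; the factories below are always registered.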
Arc::new(JsonFormatFactory::new()), + Arc::new(CsvFormatFactory::new()), + Arc::new(ArrowFormatFactory::new()), + Arc::new(AvroFormatFactory::new()), + ]; + + file_formats + } + + /// registers all builtin functions - scalar, array and aggregate + pub fn register_builtin_functions(state: &mut SessionState) { + Self::register_scalar_functions(state); + Self::register_array_functions(state); + Self::register_aggregate_functions(state); + } + + /// registers all the builtin scalar functions + pub fn register_scalar_functions(state: &mut SessionState) { + functions::register_all(state).expect("can not register built in functions"); + } + + /// registers all the builtin array functions + pub fn register_array_functions(state: &mut SessionState) { + // register crate of array expressions (if enabled) + #[cfg(feature = "array_expressions")] + functions_array::register_all(state).expect("can not register array expressions"); + } + + /// registers all the builtin aggregate functions + pub fn register_aggregate_functions(state: &mut SessionState) { + functions_aggregate::register_all(state) + .expect("can not register aggregate functions"); + } + + /// registers the default schema + pub fn register_default_schema( + config: &SessionConfig, + table_factories: &HashMap<String, Arc<dyn TableProviderFactory>>, + runtime: &Arc<RuntimeEnv>, + default_catalog: &MemoryCatalogProvider, + ) { + let url = config.options().catalog.location.as_ref(); + let format = config.options().catalog.format.as_ref(); + let (url, format) = match (url, format) { + (Some(url), Some(format)) => (url, format), + _ => return, + }; + let url = url.to_string(); + let format = format.to_string(); + + let url = Url::parse(url.as_str()).expect("Invalid default catalog location!"); + let authority = match url.host_str() { + Some(host) => format!("{}://{}", url.scheme(), host), + None => format!("{}://", url.scheme()), + }; + let path = &url.as_str()[authority.len()..]; + let path = object_store::path::Path::parse(path).expect("Can't parse path"); + let store = ObjectStoreUrl::parse(authority.as_str()) + .expect("Invalid default catalog url"); + let store = match runtime.object_store(store) { + Ok(store) => store, + _ => return, + }; + let factory = match table_factories.get(format.as_str()) { + Some(factory) => factory, + _ => return, + }; + let schema = + ListingSchemaProvider::new(authority, path, factory.clone(), store, format); + let _ = default_catalog + .register_schema("default", Arc::new(schema)) + .expect("Failed to register default schema"); + } + + /// registers the default [`FileFormatFactory`]s + pub fn register_default_file_formats(state: &mut SessionState) { + let formats = SessionStateDefaults::default_file_formats(); + for format in formats { + if let Err(e) = state.register_file_format(format, false) { + log::info!("Unable to register default file format: {e}") + }; + } + } +} + struct SessionContextProvider<'a> { state: &'a SessionState, tables: HashMap<String, Arc<dyn TableSource>>, diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index d2bc334ec3248..efc83d8f6b5c2 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -2269,6 +2269,7 @@ mod tests { use crate::prelude::{SessionConfig, SessionContext}; use crate::test_util::{scan_empty, scan_empty_with_partitions}; + use crate::execution::session_state::SessionStateBuilder; use arrow::array::{ArrayRef, DictionaryArray, Int32Array}; use arrow::datatypes::{DataType, Field, Int32Type}; use datafusion_common::{assert_contains, DFSchemaRef, TableReference};
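The hunks that follow migrate call sites from the removed `SessionState::new_with_config_rt` constructor to the new builder. As a minimal sketch of the replacement pattern (using only calls that appear in this patch; the crate paths assume the `datafusion` facade crate):

use std::sync::Arc;

use datafusion::execution::context::SessionContext;
use datafusion::execution::runtime_env::RuntimeEnv;
use datafusion::execution::session_state::SessionStateBuilder;
use datafusion::prelude::SessionConfig;

fn main() {
    // Old form: SessionState::new_with_config_rt(config, runtime).
    // New form: each ingredient is explicit, and `with_default_features()`
    // pulls in the SessionStateDefaults (catalog, builtin functions,
    // file formats) defined above.
    let state = SessionStateBuilder::new()
        .with_config(SessionConfig::new().with_target_partitions(4))
        .with_runtime_env(Arc::new(RuntimeEnv::default()))
        .with_default_features()
        .build();

    // A state built this way plugs into a context exactly as before.
    let _ctx = SessionContext::new_with_state(state);
}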
@@ -2282,7 +2283,11 @@ mod tests { let runtime = Arc::new(RuntimeEnv::default()); let config = SessionConfig::new().with_target_partitions(4); let config = config.set_bool("datafusion.optimizer.skip_failed_rules", false); - SessionState::new_with_config_rt(config, runtime) + SessionStateBuilder::new() + .with_config(config) + .with_runtime_env(runtime) + .with_default_features() + .build() } async fn plan(logical_plan: &LogicalPlan) -> Result> { diff --git a/datafusion/core/src/test/object_store.rs b/datafusion/core/src/test/object_store.rs index bea6f7b9ceb7b..6c0a2fc7bec47 100644 --- a/datafusion/core/src/test/object_store.rs +++ b/datafusion/core/src/test/object_store.rs @@ -16,9 +16,8 @@ // under the License. //! Object store implementation used for testing use crate::execution::context::SessionState; +use crate::execution::session_state::SessionStateBuilder; use crate::prelude::SessionContext; -use datafusion_execution::config::SessionConfig; -use datafusion_execution::runtime_env::RuntimeEnv; use futures::FutureExt; use object_store::{memory::InMemory, path::Path, ObjectMeta, ObjectStore}; use std::sync::Arc; @@ -44,10 +43,7 @@ pub fn make_test_store_and_state(files: &[(&str, u64)]) -> (Arc, Sessi ( Arc::new(memory), - SessionState::new_with_config_rt( - SessionConfig::default(), - Arc::new(RuntimeEnv::default()), - ), + SessionStateBuilder::new().with_default_features().build(), ) } diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index f1d57c44293be..1b2a6770cf013 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -42,7 +42,8 @@ use url::Url; use datafusion::dataframe::{DataFrame, DataFrameWriteOptions}; use datafusion::datasource::MemTable; use datafusion::error::Result; -use datafusion::execution::context::{SessionContext, SessionState}; +use datafusion::execution::context::SessionContext; +use datafusion::execution::session_state::SessionStateBuilder; use datafusion::prelude::JoinType; use datafusion::prelude::{CsvReadOptions, ParquetReadOptions}; use datafusion::test_util::{parquet_test_data, populate_csv_partitions}; @@ -1544,7 +1545,11 @@ async fn unnest_non_nullable_list() -> Result<()> { async fn test_read_batches() -> Result<()> { let config = SessionConfig::new(); let runtime = Arc::new(RuntimeEnv::default()); - let state = SessionState::new_with_config_rt(config, runtime); + let state = SessionStateBuilder::new() + .with_config(config) + .with_runtime_env(runtime) + .with_default_features() + .build(); let ctx = SessionContext::new_with_state(state); let schema = Arc::new(Schema::new(vec![ @@ -1594,7 +1599,11 @@ async fn test_read_batches() -> Result<()> { async fn test_read_batches_empty() -> Result<()> { let config = SessionConfig::new(); let runtime = Arc::new(RuntimeEnv::default()); - let state = SessionState::new_with_config_rt(config, runtime); + let state = SessionStateBuilder::new() + .with_config(config) + .with_runtime_env(runtime) + .with_default_features() + .build(); let ctx = SessionContext::new_with_state(state); let batches = vec![]; @@ -1608,9 +1617,7 @@ async fn test_read_batches_empty() -> Result<()> { #[tokio::test] async fn consecutive_projection_same_schema() -> Result<()> { - let config = SessionConfig::new(); - let runtime = Arc::new(RuntimeEnv::default()); - let state = SessionState::new_with_config_rt(config, runtime); + let state = SessionStateBuilder::new().with_default_features().build(); let ctx = SessionContext::new_with_state(state); let schema = 
Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)])); diff --git a/datafusion/core/tests/memory_limit/mod.rs b/datafusion/core/tests/memory_limit/mod.rs index 7ef24609e238d..1d151f9fd3683 100644 --- a/datafusion/core/tests/memory_limit/mod.rs +++ b/datafusion/core/tests/memory_limit/mod.rs @@ -38,6 +38,7 @@ use datafusion::datasource::{MemTable, TableProvider}; use datafusion::execution::context::SessionState; use datafusion::execution::disk_manager::DiskManagerConfig; use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; +use datafusion::execution::session_state::SessionStateBuilder; use datafusion::physical_optimizer::join_selection::JoinSelection; use datafusion::physical_plan::stream::RecordBatchStreamAdapter; use datafusion::physical_plan::{ExecutionPlan, SendableRecordBatchStream}; @@ -459,13 +460,16 @@ impl TestCase { let runtime = RuntimeEnv::new(rt_config).unwrap(); // Configure execution - let state = SessionState::new_with_config_rt(config, Arc::new(runtime)); - let state = match scenario.rules() { - Some(rules) => state.with_physical_optimizer_rules(rules), - None => state, + let builder = SessionStateBuilder::new() + .with_config(config) + .with_runtime_env(Arc::new(runtime)) + .with_default_features(); + let builder = match scenario.rules() { + Some(rules) => builder.with_physical_optimizer_rules(rules), + None => builder, }; - let ctx = SessionContext::new_with_state(state); + let ctx = SessionContext::new_with_state(builder.build()); ctx.register_table("t", table).expect("registering table"); let query = query.expect("Test error: query not specified"); diff --git a/datafusion/core/tests/parquet/file_statistics.rs b/datafusion/core/tests/parquet/file_statistics.rs index 9f94a59a3e598..bf25b36f48e8b 100644 --- a/datafusion/core/tests/parquet/file_statistics.rs +++ b/datafusion/core/tests/parquet/file_statistics.rs @@ -35,6 +35,7 @@ use datafusion_execution::cache::cache_unit::{ use datafusion_execution::config::SessionConfig; use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv}; +use datafusion::execution::session_state::SessionStateBuilder; use tempfile::tempdir; #[tokio::test] @@ -167,10 +168,7 @@ async fn get_listing_table( ) -> ListingTable { let schema = opt .infer_schema( - &SessionState::new_with_config_rt( - SessionConfig::default(), - Arc::new(RuntimeEnv::default()), - ), + &SessionStateBuilder::new().with_default_features().build(), table_path, ) .await diff --git a/datafusion/core/tests/sql/create_drop.rs b/datafusion/core/tests/sql/create_drop.rs index 2174009b85573..83712053b9542 100644 --- a/datafusion/core/tests/sql/create_drop.rs +++ b/datafusion/core/tests/sql/create_drop.rs @@ -15,18 +15,14 @@ // specific language governing permissions and limitations // under the License. 
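The memory_limit change above threads the builder through a `match` so that physical optimizer rules are attached only when a scenario supplies them. A minimal sketch of that pattern (the `make_context` helper is hypothetical, and the import path for `PhysicalOptimizerRule` is an assumption, not part of this patch):

use std::sync::Arc;

use datafusion::execution::context::SessionContext;
use datafusion::execution::session_state::SessionStateBuilder;
use datafusion::physical_optimizer::PhysicalOptimizerRule;

// Each `with_*` call consumes and returns the builder, so optional
// configuration is expressed by rebinding the builder before `build()`.
fn make_context(
    rules: Option<Vec<Arc<dyn PhysicalOptimizerRule + Send + Sync>>>,
) -> SessionContext {
    let builder = SessionStateBuilder::new().with_default_features();
    let builder = match rules {
        Some(rules) => builder.with_physical_optimizer_rules(rules),
        None => builder,
    };
    SessionContext::new_with_state(builder.build())
}

fn main() {
    // No extra rules: behaves like a default state.
    let _ctx = make_context(None);
}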
-use datafusion::execution::context::SessionState; -use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; +use datafusion::execution::session_state::SessionStateBuilder; use datafusion::test_util::TestTableFactory; use super::*; #[tokio::test] async fn create_custom_table() -> Result<()> { - let cfg = RuntimeConfig::new(); - let env = RuntimeEnv::new(cfg).unwrap(); - let ses = SessionConfig::new(); - let mut state = SessionState::new_with_config_rt(ses, Arc::new(env)); + let mut state = SessionStateBuilder::new().with_default_features().build(); state .table_factories_mut() .insert("DELTATABLE".to_string(), Arc::new(TestTableFactory {})); @@ -45,10 +41,7 @@ async fn create_custom_table() -> Result<()> { #[tokio::test] async fn create_external_table_with_ddl() -> Result<()> { - let cfg = RuntimeConfig::new(); - let env = RuntimeEnv::new(cfg).unwrap(); - let ses = SessionConfig::new(); - let mut state = SessionState::new_with_config_rt(ses, Arc::new(env)); + let mut state = SessionStateBuilder::new().with_default_features().build(); state .table_factories_mut() .insert("MOCKTABLE".to_string(), Arc::new(TestTableFactory {})); diff --git a/datafusion/core/tests/user_defined/user_defined_plan.rs b/datafusion/core/tests/user_defined/user_defined_plan.rs index 38ed142cf922f..a44f522ba95ac 100644 --- a/datafusion/core/tests/user_defined/user_defined_plan.rs +++ b/datafusion/core/tests/user_defined/user_defined_plan.rs @@ -92,6 +92,7 @@ use datafusion::{ }; use async_trait::async_trait; +use datafusion::execution::session_state::SessionStateBuilder; use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::ScalarValue; @@ -290,10 +291,14 @@ async fn topk_plan() -> Result<()> { fn make_topk_context() -> SessionContext { let config = SessionConfig::new().with_target_partitions(48); let runtime = Arc::new(RuntimeEnv::default()); - let mut state = SessionState::new_with_config_rt(config, runtime) + let state = SessionStateBuilder::new() + .with_config(config) + .with_runtime_env(runtime) + .with_default_features() .with_query_planner(Arc::new(TopKQueryPlanner {})) - .add_optimizer_rule(Arc::new(TopKOptimizerRule {})); - state.add_analyzer_rule(Arc::new(MyAnalyzerRule {})); + .with_optimizer_rule(Arc::new(TopKOptimizerRule {})) + .with_analyzer_rule(Arc::new(MyAnalyzerRule {})) + .build(); SessionContext::new_with_state(state) } diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index f764a050a6cdd..d0209d811b7ce 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -39,8 +39,7 @@ use prost::Message; use datafusion::datasource::provider::TableProviderFactory; use datafusion::datasource::TableProvider; -use datafusion::execution::context::SessionState; -use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; +use datafusion::execution::session_state::SessionStateBuilder; use datafusion::execution::FunctionRegistry; use datafusion::functions_aggregate::count::count_udaf; use datafusion::functions_aggregate::expr_fn::{ @@ -202,10 +201,7 @@ async fn roundtrip_custom_tables() -> Result<()> { let mut table_factories: HashMap> = HashMap::new(); table_factories.insert("TESTTABLE".to_string(), Arc::new(TestTableFactory {})); - let cfg = RuntimeConfig::new(); - let env = RuntimeEnv::new(cfg).unwrap(); - let ses = SessionConfig::new(); - let mut 
state = SessionState::new_with_config_rt(ses, Arc::new(env)); + let mut state = SessionStateBuilder::new().with_default_features().build(); // replace factories *state.table_factories_mut() = table_factories; let ctx = SessionContext::new_with_state(state); diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index 2893b1a31a26c..5b2d0fbacaef0 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -28,7 +28,6 @@ use std::sync::Arc; use datafusion::arrow::datatypes::{DataType, Field, IntervalUnit, Schema, TimeUnit}; use datafusion::common::{not_impl_err, plan_err, DFSchema, DFSchemaRef}; use datafusion::error::Result; -use datafusion::execution::context::SessionState; use datafusion::execution::registry::SerializerRegistry; use datafusion::execution::runtime_env::RuntimeEnv; use datafusion::logical_expr::{ @@ -37,6 +36,7 @@ use datafusion::logical_expr::{ use datafusion::optimizer::simplify_expressions::expr_simplifier::THRESHOLD_INLINE_INLIST; use datafusion::prelude::*; +use datafusion::execution::session_state::SessionStateBuilder; use substrait::proto::extensions::simple_extension_declaration::MappingType; use substrait::proto::rel::RelType; use substrait::proto::{plan_rel, Plan, Rel}; @@ -1121,11 +1121,12 @@ async fn function_extension_info(sql: &str) -> Result<(Vec, Vec)> { } async fn create_context() -> Result { - let mut state = SessionState::new_with_config_rt( - SessionConfig::default(), - Arc::new(RuntimeEnv::default()), - ) - .with_serializer_registry(Arc::new(MockSerializerRegistry)); + let mut state = SessionStateBuilder::new() + .with_config(SessionConfig::default()) + .with_runtime_env(Arc::new(RuntimeEnv::default())) + .with_default_features() + .with_serializer_registry(Arc::new(MockSerializerRegistry)) + .build(); // register udaf for test, e.g. 
`sum()` datafusion_functions_aggregate::register_all(&mut state) From bfd815622f1fe2c84d6fab32596b83ffbe52a84a Mon Sep 17 00:00:00 2001 From: Lordworms <48054792+Lordworms@users.noreply.github.com> Date: Sun, 14 Jul 2024 12:06:14 -0700 Subject: [PATCH 38/59] integrate consumer tests, implement tpch query 18 to 22 (#11462) --- .../tests/cases/consumer_integration.rs | 191 ++ .../tpch_substrait_plans/query_18.json | 1128 ++++++++ .../tpch_substrait_plans/query_19.json | 2386 +++++++++++++++++ .../tpch_substrait_plans/query_20.json | 1273 +++++++++ .../tpch_substrait_plans/query_21.json | 1493 +++++++++++ .../tpch_substrait_plans/query_22.json | 2034 ++++++++++++++ 6 files changed, 8505 insertions(+) create mode 100644 datafusion/substrait/tests/testdata/tpch_substrait_plans/query_18.json create mode 100644 datafusion/substrait/tests/testdata/tpch_substrait_plans/query_19.json create mode 100644 datafusion/substrait/tests/testdata/tpch_substrait_plans/query_20.json create mode 100644 datafusion/substrait/tests/testdata/tpch_substrait_plans/query_21.json create mode 100644 datafusion/substrait/tests/testdata/tpch_substrait_plans/query_22.json diff --git a/datafusion/substrait/tests/cases/consumer_integration.rs b/datafusion/substrait/tests/cases/consumer_integration.rs index c8130220ef4ae..8fbcd721166e3 100644 --- a/datafusion/substrait/tests/cases/consumer_integration.rs +++ b/datafusion/substrait/tests/cases/consumer_integration.rs @@ -398,4 +398,195 @@ mod tests { \n TableScan: FILENAME_PLACEHOLDER_1 projection=[p_partkey, p_name, p_mfgr, p_brand, p_type, p_size, p_container, p_retailprice, p_comment]"); Ok(()) } + /// this test has some problem in json file internally, gonna fix it + #[ignore] + #[tokio::test] + async fn tpch_test_17() -> Result<()> { + let ctx = create_context(vec![ + ("FILENAME_PLACEHOLDER_0", "tests/testdata/tpch/lineitem.csv"), + ("FILENAME_PLACEHOLDER_1", "tests/testdata/tpch/part.csv"), + ("FILENAME_PLACEHOLDER_2", "tests/testdata/tpch/lineitem.csv"), + ]) + .await?; + let path = "tests/testdata/tpch_substrait_plans/query_17.json"; + let proto = serde_json::from_reader::<_, Plan>(BufReader::new( + File::open(path).expect("file not found"), + )) + .expect("failed to parse json"); + + let _plan = from_substrait_plan(&ctx, &proto).await?; + Ok(()) + } + + #[tokio::test] + async fn tpch_test_18() -> Result<()> { + let ctx = create_context(vec![ + ("FILENAME_PLACEHOLDER_0", "tests/testdata/tpch/customer.csv"), + ("FILENAME_PLACEHOLDER_1", "tests/testdata/tpch/orders.csv"), + ("FILENAME_PLACEHOLDER_2", "tests/testdata/tpch/lineitem.csv"), + ("FILENAME_PLACEHOLDER_3", "tests/testdata/tpch/lineitem.csv"), + ]) + .await?; + let path = "tests/testdata/tpch_substrait_plans/query_18.json"; + let proto = serde_json::from_reader::<_, Plan>(BufReader::new( + File::open(path).expect("file not found"), + )) + .expect("failed to parse json"); + + let plan = from_substrait_plan(&ctx, &proto).await?; + let plan_str = format!("{:?}", plan); + assert_eq!(plan_str, "Projection: FILENAME_PLACEHOLDER_0.c_name AS C_NAME, FILENAME_PLACEHOLDER_0.c_custkey AS C_CUSTKEY, FILENAME_PLACEHOLDER_1.o_orderkey AS O_ORDERKEY, FILENAME_PLACEHOLDER_1.o_orderdate AS O_ORDERDATE, FILENAME_PLACEHOLDER_1.o_totalprice AS O_TOTALPRICE, sum(FILENAME_PLACEHOLDER_2.l_quantity) AS EXPR$5\ + \n Limit: skip=0, fetch=100\ + \n Sort: FILENAME_PLACEHOLDER_1.o_totalprice DESC NULLS FIRST, FILENAME_PLACEHOLDER_1.o_orderdate ASC NULLS LAST\ + \n Aggregate: groupBy=[[FILENAME_PLACEHOLDER_0.c_name, 
FILENAME_PLACEHOLDER_0.c_custkey, FILENAME_PLACEHOLDER_1.o_orderkey, FILENAME_PLACEHOLDER_1.o_orderdate, FILENAME_PLACEHOLDER_1.o_totalprice]], aggr=[[sum(FILENAME_PLACEHOLDER_2.l_quantity)]]\ + \n Projection: FILENAME_PLACEHOLDER_0.c_name, FILENAME_PLACEHOLDER_0.c_custkey, FILENAME_PLACEHOLDER_1.o_orderkey, FILENAME_PLACEHOLDER_1.o_orderdate, FILENAME_PLACEHOLDER_1.o_totalprice, FILENAME_PLACEHOLDER_2.l_quantity\ + \n Filter: CAST(FILENAME_PLACEHOLDER_1.o_orderkey IN () AS Boolean) AND FILENAME_PLACEHOLDER_0.c_custkey = FILENAME_PLACEHOLDER_1.o_custkey AND FILENAME_PLACEHOLDER_1.o_orderkey = FILENAME_PLACEHOLDER_2.l_orderkey\ + \n Subquery:\ + \n Projection: FILENAME_PLACEHOLDER_3.l_orderkey\ + \n Filter: sum(FILENAME_PLACEHOLDER_3.l_quantity) > CAST(Int32(300) AS Decimal128(19, 0))\ + \n Aggregate: groupBy=[[FILENAME_PLACEHOLDER_3.l_orderkey]], aggr=[[sum(FILENAME_PLACEHOLDER_3.l_quantity)]]\ + \n Projection: FILENAME_PLACEHOLDER_3.l_orderkey, FILENAME_PLACEHOLDER_3.l_quantity\ + \n TableScan: FILENAME_PLACEHOLDER_3 projection=[l_orderkey, l_partkey, l_suppkey, l_linenumber, l_quantity, l_extendedprice, l_discount, l_tax, l_returnflag, l_linestatus, l_shipdate, l_commitdate, l_receiptdate, l_shipinstruct, l_shipmode, l_comment]\ + \n Inner Join: Filter: Boolean(true)\ + \n Inner Join: Filter: Boolean(true)\ + \n TableScan: FILENAME_PLACEHOLDER_0 projection=[c_custkey, c_name, c_address, c_nationkey, c_phone, c_acctbal, c_mktsegment, c_comment]\ + \n TableScan: FILENAME_PLACEHOLDER_1 projection=[o_orderkey, o_custkey, o_orderstatus, o_totalprice, o_orderdate, o_orderpriority, o_clerk, o_shippriority, o_comment]\ + \n TableScan: FILENAME_PLACEHOLDER_2 projection=[l_orderkey, l_partkey, l_suppkey, l_linenumber, l_quantity, l_extendedprice, l_discount, l_tax, l_returnflag, l_linestatus, l_shipdate, l_commitdate, l_receiptdate, l_shipinstruct, l_shipmode, l_comment]"); + Ok(()) + } + #[tokio::test] + async fn tpch_test_19() -> Result<()> { + let ctx = create_context(vec![ + ("FILENAME_PLACEHOLDER_0", "tests/testdata/tpch/lineitem.csv"), + ("FILENAME_PLACEHOLDER_1", "tests/testdata/tpch/part.csv"), + ]) + .await?; + let path = "tests/testdata/tpch_substrait_plans/query_19.json"; + let proto = serde_json::from_reader::<_, Plan>(BufReader::new( + File::open(path).expect("file not found"), + )) + .expect("failed to parse json"); + + let plan = from_substrait_plan(&ctx, &proto).await?; + let plan_str = format!("{:?}", plan); + assert_eq!(plan_str, "Aggregate: groupBy=[[]], aggr=[[sum(FILENAME_PLACEHOLDER_0.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_0.l_discount) AS REVENUE]]\n Projection: FILENAME_PLACEHOLDER_0.l_extendedprice * (CAST(Int32(1) AS Decimal128(19, 0)) - FILENAME_PLACEHOLDER_0.l_discount)\ + \n Filter: FILENAME_PLACEHOLDER_1.p_partkey = FILENAME_PLACEHOLDER_0.l_partkey AND FILENAME_PLACEHOLDER_1.p_brand = CAST(Utf8(\"Brand#12\") AS Utf8) AND (FILENAME_PLACEHOLDER_1.p_container = Utf8(\"SM CASE\") OR FILENAME_PLACEHOLDER_1.p_container = Utf8(\"SM BOX\") OR FILENAME_PLACEHOLDER_1.p_container = Utf8(\"SM PACK\") OR FILENAME_PLACEHOLDER_1.p_container = Utf8(\"SM PKG\")) AND FILENAME_PLACEHOLDER_0.l_quantity >= CAST(Int32(1) AS Decimal128(19, 0)) AND FILENAME_PLACEHOLDER_0.l_quantity <= CAST(Int32(1) + Int32(10) AS Decimal128(19, 0)) AND FILENAME_PLACEHOLDER_1.p_size >= Int32(1) AND FILENAME_PLACEHOLDER_1.p_size <= Int32(5) AND (FILENAME_PLACEHOLDER_0.l_shipmode = Utf8(\"AIR\") OR FILENAME_PLACEHOLDER_0.l_shipmode = Utf8(\"AIR REG\")) AND FILENAME_PLACEHOLDER_0.l_shipinstruct = 
CAST(Utf8(\"DELIVER IN PERSON\") AS Utf8) OR FILENAME_PLACEHOLDER_1.p_partkey = FILENAME_PLACEHOLDER_0.l_partkey AND FILENAME_PLACEHOLDER_1.p_brand = CAST(Utf8(\"Brand#23\") AS Utf8) AND (FILENAME_PLACEHOLDER_1.p_container = Utf8(\"MED BAG\") OR FILENAME_PLACEHOLDER_1.p_container = Utf8(\"MED BOX\") OR FILENAME_PLACEHOLDER_1.p_container = Utf8(\"MED PKG\") OR FILENAME_PLACEHOLDER_1.p_container = Utf8(\"MED PACK\")) AND FILENAME_PLACEHOLDER_0.l_quantity >= CAST(Int32(10) AS Decimal128(19, 0)) AND FILENAME_PLACEHOLDER_0.l_quantity <= CAST(Int32(10) + Int32(10) AS Decimal128(19, 0)) AND FILENAME_PLACEHOLDER_1.p_size >= Int32(1) AND FILENAME_PLACEHOLDER_1.p_size <= Int32(10) AND (FILENAME_PLACEHOLDER_0.l_shipmode = Utf8(\"AIR\") OR FILENAME_PLACEHOLDER_0.l_shipmode = Utf8(\"AIR REG\")) AND FILENAME_PLACEHOLDER_0.l_shipinstruct = CAST(Utf8(\"DELIVER IN PERSON\") AS Utf8) OR FILENAME_PLACEHOLDER_1.p_partkey = FILENAME_PLACEHOLDER_0.l_partkey AND FILENAME_PLACEHOLDER_1.p_brand = CAST(Utf8(\"Brand#34\") AS Utf8) AND (FILENAME_PLACEHOLDER_1.p_container = Utf8(\"LG CASE\") OR FILENAME_PLACEHOLDER_1.p_container = Utf8(\"LG BOX\") OR FILENAME_PLACEHOLDER_1.p_container = Utf8(\"LG PACK\") OR FILENAME_PLACEHOLDER_1.p_container = Utf8(\"LG PKG\")) AND FILENAME_PLACEHOLDER_0.l_quantity >= CAST(Int32(20) AS Decimal128(19, 0)) AND FILENAME_PLACEHOLDER_0.l_quantity <= CAST(Int32(20) + Int32(10) AS Decimal128(19, 0)) AND FILENAME_PLACEHOLDER_1.p_size >= Int32(1) AND FILENAME_PLACEHOLDER_1.p_size <= Int32(15) AND (FILENAME_PLACEHOLDER_0.l_shipmode = Utf8(\"AIR\") OR FILENAME_PLACEHOLDER_0.l_shipmode = Utf8(\"AIR REG\")) AND FILENAME_PLACEHOLDER_0.l_shipinstruct = CAST(Utf8(\"DELIVER IN PERSON\") AS Utf8)\ + \n Inner Join: Filter: Boolean(true)\ + \n TableScan: FILENAME_PLACEHOLDER_0 projection=[l_orderkey, l_partkey, l_suppkey, l_linenumber, l_quantity, l_extendedprice, l_discount, l_tax, l_returnflag, l_linestatus, l_shipdate, l_commitdate, l_receiptdate, l_shipinstruct, l_shipmode, l_comment]\ + \n TableScan: FILENAME_PLACEHOLDER_1 projection=[p_partkey, p_name, p_mfgr, p_brand, p_type, p_size, p_container, p_retailprice, p_comment]"); + Ok(()) + } + + #[tokio::test] + async fn tpch_test_20() -> Result<()> { + let ctx = create_context(vec![ + ("FILENAME_PLACEHOLDER_0", "tests/testdata/tpch/supplier.csv"), + ("FILENAME_PLACEHOLDER_1", "tests/testdata/tpch/nation.csv"), + ("FILENAME_PLACEHOLDER_2", "tests/testdata/tpch/partsupp.csv"), + ("FILENAME_PLACEHOLDER_3", "tests/testdata/tpch/part.csv"), + ("FILENAME_PLACEHOLDER_4", "tests/testdata/tpch/lineitem.csv"), + ]) + .await?; + let path = "tests/testdata/tpch_substrait_plans/query_20.json"; + let proto = serde_json::from_reader::<_, Plan>(BufReader::new( + File::open(path).expect("file not found"), + )) + .expect("failed to parse json"); + + let plan = from_substrait_plan(&ctx, &proto).await?; + let plan_str = format!("{:?}", plan); + assert_eq!(plan_str, "Projection: FILENAME_PLACEHOLDER_0.s_name AS S_NAME, FILENAME_PLACEHOLDER_0.s_address AS S_ADDRESS\ + \n Sort: FILENAME_PLACEHOLDER_0.s_name ASC NULLS LAST\ + \n Projection: FILENAME_PLACEHOLDER_0.s_name, FILENAME_PLACEHOLDER_0.s_address\ + \n Filter: CAST(FILENAME_PLACEHOLDER_0.s_suppkey IN () AS Boolean) AND FILENAME_PLACEHOLDER_0.s_nationkey = FILENAME_PLACEHOLDER_1.n_nationkey AND FILENAME_PLACEHOLDER_1.n_name = CAST(Utf8(\"CANADA\") AS Utf8)\ + \n Subquery:\ + \n Projection: FILENAME_PLACEHOLDER_2.ps_suppkey\ + \n Filter: CAST(FILENAME_PLACEHOLDER_2.ps_partkey IN () AS Boolean) AND 
CAST(FILENAME_PLACEHOLDER_2.ps_availqty AS Decimal128(19, 1)) > ()\ + \n Subquery:\ + \n Projection: FILENAME_PLACEHOLDER_3.p_partkey\ + \n Filter: FILENAME_PLACEHOLDER_3.p_name LIKE CAST(Utf8(\"forest%\") AS Utf8)\ + \n TableScan: FILENAME_PLACEHOLDER_3 projection=[p_partkey, p_name, p_mfgr, p_brand, p_type, p_size, p_container, p_retailprice, p_comment]\ + \n Subquery:\ + \n Projection: Decimal128(Some(5),2,1) * sum(FILENAME_PLACEHOLDER_4.l_quantity)\ + \n Aggregate: groupBy=[[]], aggr=[[sum(FILENAME_PLACEHOLDER_4.l_quantity)]]\ + \n Projection: FILENAME_PLACEHOLDER_4.l_quantity\ + \n Filter: FILENAME_PLACEHOLDER_4.l_partkey = FILENAME_PLACEHOLDER_4.l_orderkey AND FILENAME_PLACEHOLDER_4.l_suppkey = FILENAME_PLACEHOLDER_4.l_partkey AND FILENAME_PLACEHOLDER_4.l_shipdate >= CAST(Utf8(\"1994-01-01\") AS Date32) AND FILENAME_PLACEHOLDER_4.l_shipdate < CAST(Utf8(\"1995-01-01\") AS Date32)\ + \n TableScan: FILENAME_PLACEHOLDER_4 projection=[l_orderkey, l_partkey, l_suppkey, l_linenumber, l_quantity, l_extendedprice, l_discount, l_tax, l_returnflag, l_linestatus, l_shipdate, l_commitdate, l_receiptdate, l_shipinstruct, l_shipmode, l_comment]\ + \n TableScan: FILENAME_PLACEHOLDER_2 projection=[ps_partkey, ps_suppkey, ps_availqty, ps_supplycost, ps_comment]\ + \n Inner Join: Filter: Boolean(true)\ + \n TableScan: FILENAME_PLACEHOLDER_0 projection=[s_suppkey, s_name, s_address, s_nationkey, s_phone, s_acctbal, s_comment]\ + \n TableScan: FILENAME_PLACEHOLDER_1 projection=[n_nationkey, n_name, n_regionkey, n_comment]"); + Ok(()) + } + + #[tokio::test] + async fn tpch_test_21() -> Result<()> { + let ctx = create_context(vec![ + ("FILENAME_PLACEHOLDER_0", "tests/testdata/tpch/supplier.csv"), + ("FILENAME_PLACEHOLDER_1", "tests/testdata/tpch/lineitem.csv"), + ("FILENAME_PLACEHOLDER_2", "tests/testdata/tpch/orders.csv"), + ("FILENAME_PLACEHOLDER_3", "tests/testdata/tpch/nation.csv"), + ("FILENAME_PLACEHOLDER_4", "tests/testdata/tpch/lineitem.csv"), + ("FILENAME_PLACEHOLDER_5", "tests/testdata/tpch/lineitem.csv"), + ]) + .await?; + let path = "tests/testdata/tpch_substrait_plans/query_21.json"; + let proto = serde_json::from_reader::<_, Plan>(BufReader::new( + File::open(path).expect("file not found"), + )) + .expect("failed to parse json"); + + let plan = from_substrait_plan(&ctx, &proto).await?; + let plan_str = format!("{:?}", plan); + assert_eq!(plan_str, "Projection: FILENAME_PLACEHOLDER_0.s_name AS S_NAME, count(Int64(1)) AS NUMWAIT\ + \n Limit: skip=0, fetch=100\ + \n Sort: count(Int64(1)) DESC NULLS FIRST, FILENAME_PLACEHOLDER_0.s_name ASC NULLS LAST\ + \n Aggregate: groupBy=[[FILENAME_PLACEHOLDER_0.s_name]], aggr=[[count(Int64(1))]]\ + \n Projection: FILENAME_PLACEHOLDER_0.s_name\ + \n Filter: FILENAME_PLACEHOLDER_0.s_suppkey = FILENAME_PLACEHOLDER_1.l_suppkey AND FILENAME_PLACEHOLDER_2.o_orderkey = FILENAME_PLACEHOLDER_1.l_orderkey AND FILENAME_PLACEHOLDER_2.o_orderstatus = Utf8(\"F\") AND FILENAME_PLACEHOLDER_1.l_receiptdate > FILENAME_PLACEHOLDER_1.l_commitdate AND EXISTS () AND NOT EXISTS () AND FILENAME_PLACEHOLDER_0.s_nationkey = FILENAME_PLACEHOLDER_3.n_nationkey AND FILENAME_PLACEHOLDER_3.n_name = CAST(Utf8(\"SAUDI ARABIA\") AS Utf8)\ + \n Subquery:\ + \n Filter: FILENAME_PLACEHOLDER_4.l_orderkey = FILENAME_PLACEHOLDER_4.l_tax AND FILENAME_PLACEHOLDER_4.l_suppkey != FILENAME_PLACEHOLDER_4.l_linestatus\ + \n TableScan: FILENAME_PLACEHOLDER_4 projection=[l_orderkey, l_partkey, l_suppkey, l_linenumber, l_quantity, l_extendedprice, l_discount, l_tax, l_returnflag, l_linestatus, l_shipdate, 
l_commitdate, l_receiptdate, l_shipinstruct, l_shipmode, l_comment]\ + \n Subquery:\ + \n Filter: FILENAME_PLACEHOLDER_5.l_orderkey = FILENAME_PLACEHOLDER_5.l_tax AND FILENAME_PLACEHOLDER_5.l_suppkey != FILENAME_PLACEHOLDER_5.l_linestatus AND FILENAME_PLACEHOLDER_5.l_receiptdate > FILENAME_PLACEHOLDER_5.l_commitdate\ + \n TableScan: FILENAME_PLACEHOLDER_5 projection=[l_orderkey, l_partkey, l_suppkey, l_linenumber, l_quantity, l_extendedprice, l_discount, l_tax, l_returnflag, l_linestatus, l_shipdate, l_commitdate, l_receiptdate, l_shipinstruct, l_shipmode, l_comment]\ + \n Inner Join: Filter: Boolean(true)\ + \n Inner Join: Filter: Boolean(true)\ + \n Inner Join: Filter: Boolean(true)\ + \n TableScan: FILENAME_PLACEHOLDER_0 projection=[s_suppkey, s_name, s_address, s_nationkey, s_phone, s_acctbal, s_comment]\ + \n TableScan: FILENAME_PLACEHOLDER_1 projection=[l_orderkey, l_partkey, l_suppkey, l_linenumber, l_quantity, l_extendedprice, l_discount, l_tax, l_returnflag, l_linestatus, l_shipdate, l_commitdate, l_receiptdate, l_shipinstruct, l_shipmode, l_comment]\n TableScan: FILENAME_PLACEHOLDER_2 projection=[o_orderkey, o_custkey, o_orderstatus, o_totalprice, o_orderdate, o_orderpriority, o_clerk, o_shippriority, o_comment]\ + \n TableScan: FILENAME_PLACEHOLDER_3 projection=[n_nationkey, n_name, n_regionkey, n_comment]"); + Ok(()) + } + + #[tokio::test] + async fn tpch_test_22() -> Result<()> { + let ctx = create_context(vec![ + ("FILENAME_PLACEHOLDER_0", "tests/testdata/tpch/customer.csv"), + ("FILENAME_PLACEHOLDER_1", "tests/testdata/tpch/customer.csv"), + ("FILENAME_PLACEHOLDER_2", "tests/testdata/tpch/orders.csv"), + ]) + .await?; + let path = "tests/testdata/tpch_substrait_plans/query_22.json"; + let proto = serde_json::from_reader::<_, Plan>(BufReader::new( + File::open(path).expect("file not found"), + )) + .expect("failed to parse json"); + + let plan = from_substrait_plan(&ctx, &proto).await?; + let plan_str = format!("{:?}", plan); + assert_eq!(plan_str, "Projection: substr(FILENAME_PLACEHOLDER_0.c_phone,Int32(1),Int32(2)) AS CNTRYCODE, count(Int64(1)) AS NUMCUST, sum(FILENAME_PLACEHOLDER_0.c_acctbal) AS TOTACCTBAL\n Sort: substr(FILENAME_PLACEHOLDER_0.c_phone,Int32(1),Int32(2)) ASC NULLS LAST\ + \n Aggregate: groupBy=[[substr(FILENAME_PLACEHOLDER_0.c_phone,Int32(1),Int32(2))]], aggr=[[count(Int64(1)), sum(FILENAME_PLACEHOLDER_0.c_acctbal)]]\ + \n Projection: substr(FILENAME_PLACEHOLDER_0.c_phone, Int32(1), Int32(2)), FILENAME_PLACEHOLDER_0.c_acctbal\ + \n Filter: (substr(FILENAME_PLACEHOLDER_0.c_phone, Int32(1), Int32(2)) = CAST(Utf8(\"13\") AS Utf8) OR substr(FILENAME_PLACEHOLDER_0.c_phone, Int32(1), Int32(2)) = CAST(Utf8(\"31\") AS Utf8) OR substr(FILENAME_PLACEHOLDER_0.c_phone, Int32(1), Int32(2)) = CAST(Utf8(\"23\") AS Utf8) OR substr(FILENAME_PLACEHOLDER_0.c_phone, Int32(1), Int32(2)) = CAST(Utf8(\"29\") AS Utf8) OR substr(FILENAME_PLACEHOLDER_0.c_phone, Int32(1), Int32(2)) = CAST(Utf8(\"30\") AS Utf8) OR substr(FILENAME_PLACEHOLDER_0.c_phone, Int32(1), Int32(2)) = CAST(Utf8(\"18\") AS Utf8) OR substr(FILENAME_PLACEHOLDER_0.c_phone, Int32(1), Int32(2)) = CAST(Utf8(\"17\") AS Utf8)) AND FILENAME_PLACEHOLDER_0.c_acctbal > () AND NOT EXISTS ()\ + \n Subquery:\ + \n Aggregate: groupBy=[[]], aggr=[[avg(FILENAME_PLACEHOLDER_1.c_acctbal)]]\ + \n Projection: FILENAME_PLACEHOLDER_1.c_acctbal\ + \n Filter: FILENAME_PLACEHOLDER_1.c_acctbal > Decimal128(Some(0),3,2) AND (substr(FILENAME_PLACEHOLDER_1.c_phone, Int32(1), Int32(2)) = CAST(Utf8(\"13\") AS Utf8) OR 
substr(FILENAME_PLACEHOLDER_1.c_phone, Int32(1), Int32(2)) = CAST(Utf8(\"31\") AS Utf8) OR substr(FILENAME_PLACEHOLDER_1.c_phone, Int32(1), Int32(2)) = CAST(Utf8(\"23\") AS Utf8) OR substr(FILENAME_PLACEHOLDER_1.c_phone, Int32(1), Int32(2)) = CAST(Utf8(\"29\") AS Utf8) OR substr(FILENAME_PLACEHOLDER_1.c_phone, Int32(1), Int32(2)) = CAST(Utf8(\"30\") AS Utf8) OR substr(FILENAME_PLACEHOLDER_1.c_phone, Int32(1), Int32(2)) = CAST(Utf8(\"18\") AS Utf8) OR substr(FILENAME_PLACEHOLDER_1.c_phone, Int32(1), Int32(2)) = CAST(Utf8(\"17\") AS Utf8))\ + \n TableScan: FILENAME_PLACEHOLDER_1 projection=[c_custkey, c_name, c_address, c_nationkey, c_phone, c_acctbal, c_mktsegment, c_comment]\n Subquery:\ + \n Filter: FILENAME_PLACEHOLDER_2.o_custkey = FILENAME_PLACEHOLDER_2.o_orderkey\ + \n TableScan: FILENAME_PLACEHOLDER_2 projection=[o_orderkey, o_custkey, o_orderstatus, o_totalprice, o_orderdate, o_orderpriority, o_clerk, o_shippriority, o_comment]\ + \n TableScan: FILENAME_PLACEHOLDER_0 projection=[c_custkey, c_name, c_address, c_nationkey, c_phone, c_acctbal, c_mktsegment, c_comment]"); + Ok(()) + } } diff --git a/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_18.json b/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_18.json new file mode 100644 index 0000000000000..a4f0b25db9562 --- /dev/null +++ b/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_18.json @@ -0,0 +1,1128 @@ +{ + "extensionUris": [ + { + "extensionUriAnchor": 1, + "uri": "/functions_boolean.yaml" + }, + { + "extensionUriAnchor": 2, + "uri": "/functions_arithmetic_decimal.yaml" + }, + { + "extensionUriAnchor": 3, + "uri": "/functions_comparison.yaml" + } + ], + "extensions": [ + { + "extensionFunction": { + "extensionUriReference": 1, + "functionAnchor": 0, + "name": "and:bool" + } + }, + { + "extensionFunction": { + "extensionUriReference": 2, + "functionAnchor": 1, + "name": "sum:opt_decimal" + } + }, + { + "extensionFunction": { + "extensionUriReference": 3, + "functionAnchor": 2, + "name": "gt:any1_any1" + } + }, + { + "extensionFunction": { + "extensionUriReference": 3, + "functionAnchor": 3, + "name": "equal:any1_any1" + } + } + ], + "relations": [ + { + "root": { + "input": { + "fetch": { + "common": { + "direct": { + } + }, + "input": { + "sort": { + "common": { + "direct": { + } + }, + "input": { + "aggregate": { + "common": { + "direct": { + } + }, + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [ + 33, + 34, + 35, + 36, + 37, + 38 + ] + } + }, + "input": { + "filter": { + "common": { + "direct": { + } + }, + "input": { + "join": { + "common": { + "direct": { + } + }, + "left": { + "join": { + "common": { + "direct": { + } + }, + "left": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": [ + "C_CUSTKEY", + "C_NAME", + "C_ADDRESS", + "C_NATIONKEY", + "C_PHONE", + "C_ACCTBAL", + "C_MKTSEGMENT", + "C_COMMENT" + ], + "struct": { + "types": [ + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "varchar": { + "length": 25, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "varchar": { + "length": 40, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "fixedChar": { + "length": 15, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "decimal": { + "scale": 0, + "precision": 19, + 
"typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fixedChar": { + "length": 10, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "varchar": { + "length": 117, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "local_files": { + "items": [ + { + "uri_file": "file://FILENAME_PLACEHOLDER_0", + "parquet": {} + } + ] + } + } + }, + "right": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": [ + "O_ORDERKEY", + "O_CUSTKEY", + "O_ORDERSTATUS", + "O_TOTALPRICE", + "O_ORDERDATE", + "O_ORDERPRIORITY", + "O_CLERK", + "O_SHIPPRIORITY", + "O_COMMENT" + ], + "struct": { + "types": [ + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "fixedChar": { + "length": 1, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "date": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fixedChar": { + "length": 15, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fixedChar": { + "length": 15, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "varchar": { + "length": 79, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "local_files": { + "items": [ + { + "uri_file": "file://FILENAME_PLACEHOLDER_1", + "parquet": {} + } + ] + } + } + }, + "expression": { + "literal": { + "boolean": true, + "nullable": false, + "typeVariationReference": 0 + } + }, + "type": "JOIN_TYPE_INNER" + } + }, + "right": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": [ + "L_ORDERKEY", + "L_PARTKEY", + "L_SUPPKEY", + "L_LINENUMBER", + "L_QUANTITY", + "L_EXTENDEDPRICE", + "L_DISCOUNT", + "L_TAX", + "L_RETURNFLAG", + "L_LINESTATUS", + "L_SHIPDATE", + "L_COMMITDATE", + "L_RECEIPTDATE", + "L_SHIPINSTRUCT", + "L_SHIPMODE", + "L_COMMENT" + ], + "struct": { + "types": [ + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fixedChar": { + "length": 1, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + 
"fixedChar": { + "length": 1, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "date": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "date": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "date": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fixedChar": { + "length": 25, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fixedChar": { + "length": 10, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "varchar": { + "length": 44, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "local_files": { + "items": [ + { + "uri_file": "file://FILENAME_PLACEHOLDER_2", + "parquet": {} + } + ] + } + } + }, + "expression": { + "literal": { + "boolean": true, + "nullable": false, + "typeVariationReference": 0 + } + }, + "type": "JOIN_TYPE_INNER" + } + }, + "condition": { + "scalarFunction": { + "functionReference": 0, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [ + { + "value": { + "cast": { + "type": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "subquery": { + "inPredicate": { + "needles": [ + { + "selection": { + "directReference": { + "structField": { + "field": 8 + } + }, + "rootReference": { + } + } + } + ], + "haystack": { + "project": { + "common": { + "emit": { + "outputMapping": [ + 2 + ] + } + }, + "input": { + "filter": { + "common": { + "direct": { + } + }, + "input": { + "aggregate": { + "common": { + "direct": { + } + }, + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [ + 16, + 17 + ] + } + }, + "input": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": [ + "L_ORDERKEY", + "L_PARTKEY", + "L_SUPPKEY", + "L_LINENUMBER", + "L_QUANTITY", + "L_EXTENDEDPRICE", + "L_DISCOUNT", + "L_TAX", + "L_RETURNFLAG", + "L_LINESTATUS", + "L_SHIPDATE", + "L_COMMITDATE", + "L_RECEIPTDATE", + "L_SHIPINSTRUCT", + "L_SHIPMODE", + "L_COMMENT" + ], + "struct": { + "types": [ + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fixedChar": { + "length": 1, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fixedChar": { + "length": 1, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "date": { + "typeVariationReference": 0, + "nullability": 
"NULLABILITY_NULLABLE" + } + }, + { + "date": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "date": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fixedChar": { + "length": 25, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fixedChar": { + "length": 10, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "varchar": { + "length": 44, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "local_files": { + "items": [ + { + "uri_file": "file://FILENAME_PLACEHOLDER_3", + "parquet": {} + } + ] + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 4 + } + }, + "rootReference": { + } + } + } + ] + } + }, + "groupings": [ + { + "groupingExpressions": [ + { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + } + ] + } + ], + "measures": [ + { + "measure": { + "functionReference": 1, + "args": [], + "sorts": [], + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "invocation": "AGGREGATION_INVOCATION_ALL", + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + } + } + ] + } + } + ] + } + }, + "condition": { + "scalarFunction": { + "functionReference": 2, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + } + }, + { + "value": { + "cast": { + "type": { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "input": { + "literal": { + "i32": 300, + "nullable": false, + "typeVariationReference": 0 + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_UNSPECIFIED" + } + } + } + ] + } + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + } + ] + } + } + } + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_UNSPECIFIED" + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 3, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + } + }, + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 9 + } + }, + "rootReference": { + } + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 3, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 8 + } + }, + "rootReference": { + } + } + } + }, + { + "value": { + "selection": { + "directReference": { + 
"structField": { + "field": 17 + } + }, + "rootReference": { + } + } + } + } + ] + } + } + } + ] + } + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 8 + } + }, + "rootReference": { + } + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 12 + } + }, + "rootReference": { + } + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 11 + } + }, + "rootReference": { + } + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 21 + } + }, + "rootReference": { + } + } + } + ] + } + }, + "groupings": [ + { + "groupingExpressions": [ + { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "rootReference": { + } + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 3 + } + }, + "rootReference": { + } + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 4 + } + }, + "rootReference": { + } + } + } + ] + } + ], + "measures": [ + { + "measure": { + "functionReference": 1, + "args": [], + "sorts": [], + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "invocation": "AGGREGATION_INVOCATION_ALL", + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 5 + } + }, + "rootReference": { + } + } + } + } + ] + } + } + ] + } + }, + "sorts": [ + { + "expr": { + "selection": { + "directReference": { + "structField": { + "field": 4 + } + }, + "rootReference": { + } + } + }, + "direction": "SORT_DIRECTION_DESC_NULLS_FIRST" + }, + { + "expr": { + "selection": { + "directReference": { + "structField": { + "field": 3 + } + }, + "rootReference": { + } + } + }, + "direction": "SORT_DIRECTION_ASC_NULLS_LAST" + } + ] + } + }, + "offset": "0", + "count": "100" + } + }, + "names": [ + "C_NAME", + "C_CUSTKEY", + "O_ORDERKEY", + "O_ORDERDATE", + "O_TOTALPRICE", + "EXPR$5" + ] + } + } + ], + "expectedTypeUrls": [] +} diff --git a/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_19.json b/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_19.json new file mode 100644 index 0000000000000..356111a480f3b --- /dev/null +++ b/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_19.json @@ -0,0 +1,2386 @@ +{ + "extensionUris": [ + { + "extensionUriAnchor": 3, + "uri": "/functions_arithmetic.yaml" + }, + { + "extensionUriAnchor": 1, + "uri": "/functions_boolean.yaml" + }, + { + "extensionUriAnchor": 4, + "uri": "/functions_arithmetic_decimal.yaml" + }, + { + "extensionUriAnchor": 2, + "uri": "/functions_comparison.yaml" + } + ], + "extensions": [ + { + "extensionFunction": { + "extensionUriReference": 1, + "functionAnchor": 0, + "name": "or:bool" + } + }, + { + "extensionFunction": { + "extensionUriReference": 1, + "functionAnchor": 1, + "name": "and:bool" + } + }, + { + "extensionFunction": { + 
"extensionUriReference": 2, + "functionAnchor": 2, + "name": "equal:any1_any1" + } + }, + { + "extensionFunction": { + "extensionUriReference": 2, + "functionAnchor": 3, + "name": "gte:any1_any1" + } + }, + { + "extensionFunction": { + "extensionUriReference": 2, + "functionAnchor": 4, + "name": "lte:any1_any1" + } + }, + { + "extensionFunction": { + "extensionUriReference": 3, + "functionAnchor": 5, + "name": "add:opt_i32_i32" + } + }, + { + "extensionFunction": { + "extensionUriReference": 4, + "functionAnchor": 6, + "name": "multiply:opt_decimal_decimal" + } + }, + { + "extensionFunction": { + "extensionUriReference": 4, + "functionAnchor": 7, + "name": "subtract:opt_decimal_decimal" + } + }, + { + "extensionFunction": { + "extensionUriReference": 4, + "functionAnchor": 8, + "name": "sum:opt_decimal" + } + } + ], + "relations": [ + { + "root": { + "input": { + "aggregate": { + "common": { + "direct": { + } + }, + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [ + 25 + ] + } + }, + "input": { + "filter": { + "common": { + "direct": { + } + }, + "input": { + "join": { + "common": { + "direct": { + } + }, + "left": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": [ + "L_ORDERKEY", + "L_PARTKEY", + "L_SUPPKEY", + "L_LINENUMBER", + "L_QUANTITY", + "L_EXTENDEDPRICE", + "L_DISCOUNT", + "L_TAX", + "L_RETURNFLAG", + "L_LINESTATUS", + "L_SHIPDATE", + "L_COMMITDATE", + "L_RECEIPTDATE", + "L_SHIPINSTRUCT", + "L_SHIPMODE", + "L_COMMENT" + ], + "struct": { + "types": [ + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fixedChar": { + "length": 1, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fixedChar": { + "length": 1, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "date": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "date": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "date": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fixedChar": { + "length": 25, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fixedChar": { + "length": 10, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "varchar": { + "length": 44, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "local_files": { + "items": [ + { + "uri_file": "file://FILENAME_PLACEHOLDER_0", + "parquet": {} + } + ] + } + } + }, + "right": { + "read": { + 
"common": { + "direct": { + } + }, + "baseSchema": { + "names": [ + "P_PARTKEY", + "P_NAME", + "P_MFGR", + "P_BRAND", + "P_TYPE", + "P_SIZE", + "P_CONTAINER", + "P_RETAILPRICE", + "P_COMMENT" + ], + "struct": { + "types": [ + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "varchar": { + "length": 55, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fixedChar": { + "length": 25, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fixedChar": { + "length": 10, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "varchar": { + "length": 25, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fixedChar": { + "length": 10, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "varchar": { + "length": 23, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "local_files": { + "items": [ + { + "uri_file": "file://FILENAME_PLACEHOLDER_1", + "parquet": {} + } + ] + } + } + }, + "expression": { + "literal": { + "boolean": true, + "nullable": false, + "typeVariationReference": 0 + } + }, + "type": "JOIN_TYPE_INNER" + } + }, + "condition": { + "scalarFunction": { + "functionReference": 0, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "scalarFunction": { + "functionReference": 1, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "scalarFunction": { + "functionReference": 2, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 16 + } + }, + "rootReference": { + } + } + } + }, + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 2, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 19 + } + }, + "rootReference": { + } + } + } + }, + { + "value": { + "cast": { + "type": { + "fixedChar": { + "length": 10, + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "fixedChar": "Brand#12", + "nullable": false, + "typeVariationReference": 0 + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_UNSPECIFIED" + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 0, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "scalarFunction": { + "functionReference": 2, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + 
"nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 22 + } + }, + "rootReference": { + } + } + } + }, + { + "value": { + "literal": { + "fixedChar": "SM CASE", + "nullable": false, + "typeVariationReference": 0 + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 2, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 22 + } + }, + "rootReference": { + } + } + } + }, + { + "value": { + "literal": { + "fixedChar": "SM BOX", + "nullable": false, + "typeVariationReference": 0 + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 2, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 22 + } + }, + "rootReference": { + } + } + } + }, + { + "value": { + "literal": { + "fixedChar": "SM PACK", + "nullable": false, + "typeVariationReference": 0 + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 2, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 22 + } + }, + "rootReference": { + } + } + } + }, + { + "value": { + "literal": { + "fixedChar": "SM PKG", + "nullable": false, + "typeVariationReference": 0 + } + } + } + ] + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 3, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 4 + } + }, + "rootReference": { + } + } + } + }, + { + "value": { + "cast": { + "type": { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "input": { + "literal": { + "i32": 1, + "nullable": false, + "typeVariationReference": 0 + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_UNSPECIFIED" + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 4, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 4 + } + }, + "rootReference": { + } + } + } + }, + { + "value": { + "cast": { + "type": { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "input": { + "scalarFunction": { + "functionReference": 5, + "args": [], + "outputType": { + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [ + { + "value": { + "literal": { + "i32": 1, + "nullable": false, + "typeVariationReference": 0 + } + } + }, + { + "value": { + "literal": { + "i32": 10, + "nullable": false, + "typeVariationReference": 0 + } + } + } + ] + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_UNSPECIFIED" + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + 
"functionReference": 3, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 21 + } + }, + "rootReference": { + } + } + } + }, + { + "value": { + "literal": { + "i32": 1, + "nullable": false, + "typeVariationReference": 0 + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 4, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 21 + } + }, + "rootReference": { + } + } + } + }, + { + "value": { + "literal": { + "i32": 5, + "nullable": false, + "typeVariationReference": 0 + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 0, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "scalarFunction": { + "functionReference": 2, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 14 + } + }, + "rootReference": { + } + } + } + }, + { + "value": { + "literal": { + "fixedChar": "AIR", + "nullable": false, + "typeVariationReference": 0 + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 2, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 14 + } + }, + "rootReference": { + } + } + } + }, + { + "value": { + "literal": { + "fixedChar": "AIR REG", + "nullable": false, + "typeVariationReference": 0 + } + } + } + ] + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 2, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 13 + } + }, + "rootReference": { + } + } + } + }, + { + "value": { + "cast": { + "type": { + "fixedChar": { + "length": 25, + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "fixedChar": "DELIVER IN PERSON", + "nullable": false, + "typeVariationReference": 0 + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_UNSPECIFIED" + } + } + } + ] + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 1, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "scalarFunction": { + "functionReference": 2, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 16 + } + }, + "rootReference": { + } + } + } + }, + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 
2, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 19 + } + }, + "rootReference": { + } + } + } + }, + { + "value": { + "cast": { + "type": { + "fixedChar": { + "length": 10, + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "fixedChar": "Brand#23", + "nullable": false, + "typeVariationReference": 0 + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_UNSPECIFIED" + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 0, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "scalarFunction": { + "functionReference": 2, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 22 + } + }, + "rootReference": { + } + } + } + }, + { + "value": { + "literal": { + "fixedChar": "MED BAG", + "nullable": false, + "typeVariationReference": 0 + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 2, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 22 + } + }, + "rootReference": { + } + } + } + }, + { + "value": { + "literal": { + "fixedChar": "MED BOX", + "nullable": false, + "typeVariationReference": 0 + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 2, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 22 + } + }, + "rootReference": { + } + } + } + }, + { + "value": { + "literal": { + "fixedChar": "MED PKG", + "nullable": false, + "typeVariationReference": 0 + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 2, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 22 + } + }, + "rootReference": { + } + } + } + }, + { + "value": { + "literal": { + "fixedChar": "MED PACK", + "nullable": false, + "typeVariationReference": 0 + } + } + } + ] + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 3, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 4 + } + }, + "rootReference": { + } + } + } + }, + { + "value": { + "cast": { + "type": { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "input": { + "literal": { + "i32": 10, + "nullable": false, + "typeVariationReference": 0 + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_UNSPECIFIED" + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 4, + "args": [], + 
"outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 4 + } + }, + "rootReference": { + } + } + } + }, + { + "value": { + "cast": { + "type": { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "input": { + "scalarFunction": { + "functionReference": 5, + "args": [], + "outputType": { + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [ + { + "value": { + "literal": { + "i32": 10, + "nullable": false, + "typeVariationReference": 0 + } + } + }, + { + "value": { + "literal": { + "i32": 10, + "nullable": false, + "typeVariationReference": 0 + } + } + } + ] + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_UNSPECIFIED" + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 3, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 21 + } + }, + "rootReference": { + } + } + } + }, + { + "value": { + "literal": { + "i32": 1, + "nullable": false, + "typeVariationReference": 0 + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 4, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 21 + } + }, + "rootReference": { + } + } + } + }, + { + "value": { + "literal": { + "i32": 10, + "nullable": false, + "typeVariationReference": 0 + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 0, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "scalarFunction": { + "functionReference": 2, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 14 + } + }, + "rootReference": { + } + } + } + }, + { + "value": { + "literal": { + "fixedChar": "AIR", + "nullable": false, + "typeVariationReference": 0 + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 2, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 14 + } + }, + "rootReference": { + } + } + } + }, + { + "value": { + "literal": { + "fixedChar": "AIR REG", + "nullable": false, + "typeVariationReference": 0 + } + } + } + ] + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 2, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 13 + } + }, + "rootReference": { + } + } + } + }, + { + "value": { + "cast": { + "type": { + "fixedChar": { + "length": 25, + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, 
+ "input": { + "literal": { + "fixedChar": "DELIVER IN PERSON", + "nullable": false, + "typeVariationReference": 0 + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_UNSPECIFIED" + } + } + } + ] + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 1, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "scalarFunction": { + "functionReference": 2, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 16 + } + }, + "rootReference": { + } + } + } + }, + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 2, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 19 + } + }, + "rootReference": { + } + } + } + }, + { + "value": { + "cast": { + "type": { + "fixedChar": { + "length": 10, + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "fixedChar": "Brand#34", + "nullable": false, + "typeVariationReference": 0 + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_UNSPECIFIED" + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 0, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "scalarFunction": { + "functionReference": 2, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 22 + } + }, + "rootReference": { + } + } + } + }, + { + "value": { + "literal": { + "fixedChar": "LG CASE", + "nullable": false, + "typeVariationReference": 0 + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 2, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 22 + } + }, + "rootReference": { + } + } + } + }, + { + "value": { + "literal": { + "fixedChar": "LG BOX", + "nullable": false, + "typeVariationReference": 0 + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 2, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 22 + } + }, + "rootReference": { + } + } + } + }, + { + "value": { + "literal": { + "fixedChar": "LG PACK", + "nullable": false, + "typeVariationReference": 0 + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 2, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + 
"field": 22 + } + }, + "rootReference": { + } + } + } + }, + { + "value": { + "literal": { + "fixedChar": "LG PKG", + "nullable": false, + "typeVariationReference": 0 + } + } + } + ] + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 3, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 4 + } + }, + "rootReference": { + } + } + } + }, + { + "value": { + "cast": { + "type": { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "input": { + "literal": { + "i32": 20, + "nullable": false, + "typeVariationReference": 0 + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_UNSPECIFIED" + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 4, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 4 + } + }, + "rootReference": { + } + } + } + }, + { + "value": { + "cast": { + "type": { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "input": { + "scalarFunction": { + "functionReference": 5, + "args": [], + "outputType": { + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [ + { + "value": { + "literal": { + "i32": 20, + "nullable": false, + "typeVariationReference": 0 + } + } + }, + { + "value": { + "literal": { + "i32": 10, + "nullable": false, + "typeVariationReference": 0 + } + } + } + ] + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_UNSPECIFIED" + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 3, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 21 + } + }, + "rootReference": { + } + } + } + }, + { + "value": { + "literal": { + "i32": 1, + "nullable": false, + "typeVariationReference": 0 + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 4, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 21 + } + }, + "rootReference": { + } + } + } + }, + { + "value": { + "literal": { + "i32": 15, + "nullable": false, + "typeVariationReference": 0 + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 0, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "scalarFunction": { + "functionReference": 2, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 14 + } + }, + "rootReference": { + } + } + } + }, + { + "value": { + "literal": { + "fixedChar": "AIR", + "nullable": false, + "typeVariationReference": 0 + } + } + } + ] + } + } + }, + { + 
"value": { + "scalarFunction": { + "functionReference": 2, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 14 + } + }, + "rootReference": { + } + } + } + }, + { + "value": { + "literal": { + "fixedChar": "AIR REG", + "nullable": false, + "typeVariationReference": 0 + } + } + } + ] + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 2, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 13 + } + }, + "rootReference": { + } + } + } + }, + { + "value": { + "cast": { + "type": { + "fixedChar": { + "length": 25, + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "fixedChar": "DELIVER IN PERSON", + "nullable": false, + "typeVariationReference": 0 + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_UNSPECIFIED" + } + } + } + ] + } + } + } + ] + } + } + } + ] + } + } + } + }, + "expressions": [ + { + "scalarFunction": { + "functionReference": 6, + "args": [], + "outputType": { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 5 + } + }, + "rootReference": { + } + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 7, + "args": [], + "outputType": { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "cast": { + "type": { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "input": { + "literal": { + "i32": 1, + "nullable": false, + "typeVariationReference": 0 + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_UNSPECIFIED" + } + } + }, + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 6 + } + }, + "rootReference": { + } + } + } + } + ] + } + } + } + ] + } + } + ] + } + }, + "groupings": [ + { + "groupingExpressions": [] + } + ], + "measures": [ + { + "measure": { + "functionReference": 8, + "args": [], + "sorts": [], + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "invocation": "AGGREGATION_INVOCATION_ALL", + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + } + } + ] + } + } + ] + } + }, + "names": [ + "REVENUE" + ] + } + } + ], + "expectedTypeUrls": [] +} diff --git a/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_20.json b/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_20.json new file mode 100644 index 0000000000000..54a71fa553f89 --- /dev/null +++ b/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_20.json @@ -0,0 +1,1273 @@ +{ + "extensionUris": [ + { + "extensionUriAnchor": 1, + "uri": "/functions_boolean.yaml" + }, + { + "extensionUriAnchor": 2, + "uri": "/functions_string.yaml" + }, + { + "extensionUriAnchor": 5, + "uri": 
"/functions_arithmetic_decimal.yaml" + }, + { + "extensionUriAnchor": 4, + "uri": "/functions_datetime.yaml" + }, + { + "extensionUriAnchor": 3, + "uri": "/functions_comparison.yaml" + } + ], + "extensions": [ + { + "extensionFunction": { + "extensionUriReference": 1, + "functionAnchor": 0, + "name": "and:bool" + } + }, + { + "extensionFunction": { + "extensionUriReference": 2, + "functionAnchor": 1, + "name": "like:vchar_vchar" + } + }, + { + "extensionFunction": { + "extensionUriReference": 3, + "functionAnchor": 2, + "name": "gt:any1_any1" + } + }, + { + "extensionFunction": { + "extensionUriReference": 3, + "functionAnchor": 3, + "name": "equal:any1_any1" + } + }, + { + "extensionFunction": { + "extensionUriReference": 4, + "functionAnchor": 4, + "name": "gte:date_date" + } + }, + { + "extensionFunction": { + "extensionUriReference": 4, + "functionAnchor": 5, + "name": "lt:date_date" + } + }, + { + "extensionFunction": { + "extensionUriReference": 5, + "functionAnchor": 6, + "name": "sum:opt_decimal" + } + }, + { + "extensionFunction": { + "extensionUriReference": 5, + "functionAnchor": 7, + "name": "multiply:opt_decimal_decimal" + } + } + ], + "relations": [ + { + "root": { + "input": { + "sort": { + "common": { + "direct": {} + }, + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [ + 11, + 12 + ] + } + }, + "input": { + "filter": { + "common": { + "direct": {} + }, + "input": { + "join": { + "common": { + "direct": {} + }, + "left": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "S_SUPPKEY", + "S_NAME", + "S_ADDRESS", + "S_NATIONKEY", + "S_PHONE", + "S_ACCTBAL", + "S_COMMENT" + ], + "struct": { + "types": [ + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "fixedChar": { + "length": 25, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "varchar": { + "length": 40, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "fixedChar": { + "length": 15, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "varchar": { + "length": 101, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "local_files": { + "items": [ + { + "uri_file": "file://FILENAME_PLACEHOLDER_0", + "parquet": {} + } + ] + } + } + }, + "right": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "N_NATIONKEY", + "N_NAME", + "N_REGIONKEY", + "N_COMMENT" + ], + "struct": { + "types": [ + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "fixedChar": { + "length": 25, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "varchar": { + "length": 152, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "local_files": { + "items": [ + { + "uri_file": "file://FILENAME_PLACEHOLDER_1", + "parquet": {} + } + ] + } + } + }, + "expression": { + "literal": { + "boolean": true, + 
"nullable": false, + "typeVariationReference": 0 + } + }, + "type": "JOIN_TYPE_INNER" + } + }, + "condition": { + "scalarFunction": { + "functionReference": 0, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "cast": { + "type": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "input": { + "subquery": { + "inPredicate": { + "needles": [ + { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": {} + } + } + ], + "haystack": { + "project": { + "common": { + "emit": { + "outputMapping": [ + 5 + ] + } + }, + "input": { + "filter": { + "common": { + "direct": {} + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "PS_PARTKEY", + "PS_SUPPKEY", + "PS_AVAILQTY", + "PS_SUPPLYCOST", + "PS_COMMENT" + ], + "struct": { + "types": [ + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "varchar": { + "length": 199, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "local_files": { + "items": [ + { + "uri_file": "file://FILENAME_PLACEHOLDER_2", + "parquet": {} + } + ] + } + } + }, + "condition": { + "scalarFunction": { + "functionReference": 0, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "cast": { + "type": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "input": { + "subquery": { + "inPredicate": { + "needles": [ + { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": {} + } + } + ], + "haystack": { + "project": { + "common": { + "emit": { + "outputMapping": [ + 9 + ] + } + }, + "input": { + "filter": { + "common": { + "direct": {} + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "P_PARTKEY", + "P_NAME", + "P_MFGR", + "P_BRAND", + "P_TYPE", + "P_SIZE", + "P_CONTAINER", + "P_RETAILPRICE", + "P_COMMENT" + ], + "struct": { + "types": [ + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "varchar": { + "length": 55, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fixedChar": { + "length": 25, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fixedChar": { + "length": 10, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "varchar": { + "length": 25, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fixedChar": { + "length": 10, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + 
"varchar": { + "length": 23, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "local_files": { + "items": [ + { + "uri_file": "file://FILENAME_PLACEHOLDER_3", + "parquet": {} + } + ] + } + } + }, + "condition": { + "scalarFunction": { + "functionReference": 1, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": {} + } + } + }, + { + "value": { + "cast": { + "type": { + "varchar": { + "length": 55, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "input": { + "literal": { + "fixedChar": "forest%", + "nullable": false, + "typeVariationReference": 0 + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_UNSPECIFIED" + } + } + } + ] + } + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": {} + } + } + ] + } + } + } + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_UNSPECIFIED" + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 2, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "cast": { + "type": { + "decimal": { + "scale": 1, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "input": { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "rootReference": {} + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_UNSPECIFIED" + } + } + }, + { + "value": { + "subquery": { + "scalar": { + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [ + 1 + ] + } + }, + "input": { + "aggregate": { + "common": { + "direct": {} + }, + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [ + 16 + ] + } + }, + "input": { + "filter": { + "common": { + "direct": {} + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "L_ORDERKEY", + "L_PARTKEY", + "L_SUPPKEY", + "L_LINENUMBER", + "L_QUANTITY", + "L_EXTENDEDPRICE", + "L_DISCOUNT", + "L_TAX", + "L_RETURNFLAG", + "L_LINESTATUS", + "L_SHIPDATE", + "L_COMMITDATE", + "L_RECEIPTDATE", + "L_SHIPINSTRUCT", + "L_SHIPMODE", + "L_COMMENT" + ], + "struct": { + "types": [ + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fixedChar": { + "length": 1, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + 
} + }, + { + "fixedChar": { + "length": 1, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "date": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "date": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "date": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fixedChar": { + "length": 25, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fixedChar": { + "length": 10, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "varchar": { + "length": 44, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "local_files": { + "items": [ + { + "uri_file": "file://FILENAME_PLACEHOLDER_4", + "parquet": {} + } + ] + } + } + }, + "condition": { + "scalarFunction": { + "functionReference": 0, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "scalarFunction": { + "functionReference": 3, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": {} + } + } + }, + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "outerReference": { + "stepsOut": 1 + } + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 3, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "rootReference": {} + } + } + }, + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "outerReference": { + "stepsOut": 1 + } + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 4, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 10 + } + }, + "rootReference": {} + } + } + }, + { + "value": { + "cast": { + "type": { + "date": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "fixedChar": "1994-01-01", + "nullable": false, + "typeVariationReference": 0 + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_UNSPECIFIED" + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 5, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 10 + } + }, + "rootReference": {} + } + } + }, + { + "value": { + "cast": { + "type": { + "date": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "fixedChar": "1995-01-01", + "nullable": false, + "typeVariationReference": 0 + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_UNSPECIFIED" + } + } + } + ] + } + } + } + ] + } + } + } + }, + 
"expressions": [ + { + "selection": { + "directReference": { + "structField": { + "field": 4 + } + }, + "rootReference": {} + } + } + ] + } + }, + "groupings": [ + { + "groupingExpressions": [] + } + ], + "measures": [ + { + "measure": { + "functionReference": 6, + "args": [], + "sorts": [], + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "invocation": "AGGREGATION_INVOCATION_ALL", + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": {} + } + } + } + ] + } + } + ] + } + }, + "expressions": [ + { + "scalarFunction": { + "functionReference": 7, + "args": [], + "outputType": { + "decimal": { + "scale": 1, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "literal": { + "decimal": { + "value": "BQAAAAAAAAAAAAAAAAAAAA==", + "precision": 2, + "scale": 1 + }, + "nullable": false, + "typeVariationReference": 0 + } + } + }, + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": {} + } + } + } + ] + } + } + ] + } + } + } + } + } + } + ] + } + } + } + ] + } + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": {} + } + } + ] + } + } + } + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_UNSPECIFIED" + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 3, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 3 + } + }, + "rootReference": {} + } + } + }, + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 7 + } + }, + "rootReference": {} + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 3, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 8 + } + }, + "rootReference": {} + } + } + }, + { + "value": { + "cast": { + "type": { + "fixedChar": { + "length": 25, + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "fixedChar": "CANADA", + "nullable": false, + "typeVariationReference": 0 + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_UNSPECIFIED" + } + } + } + ] + } + } + } + ] + } + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "rootReference": {} + } + } + ] + } + }, + "sorts": [ + { + "expr": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": {} + } + }, + "direction": "SORT_DIRECTION_ASC_NULLS_LAST" + } + ] + } + }, + "names": [ + "S_NAME", + "S_ADDRESS" + ] + } + } + ], + "expectedTypeUrls": [] +} \ No newline at end of file diff --git a/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_21.json b/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_21.json new file mode 100644 index 
0000000000000..d35c1517228bc --- /dev/null +++ b/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_21.json @@ -0,0 +1,1493 @@ +{ + "extensionUris": [ + { + "extensionUriAnchor": 4, + "uri": "/functions_aggregate_generic.yaml" + }, + { + "extensionUriAnchor": 1, + "uri": "/functions_boolean.yaml" + }, + { + "extensionUriAnchor": 3, + "uri": "/functions_datetime.yaml" + }, + { + "extensionUriAnchor": 2, + "uri": "/functions_comparison.yaml" + } + ], + "extensions": [ + { + "extensionFunction": { + "extensionUriReference": 1, + "functionAnchor": 0, + "name": "and:bool" + } + }, + { + "extensionFunction": { + "extensionUriReference": 2, + "functionAnchor": 1, + "name": "equal:any1_any1" + } + }, + { + "extensionFunction": { + "extensionUriReference": 3, + "functionAnchor": 2, + "name": "gt:date_date" + } + }, + { + "extensionFunction": { + "extensionUriReference": 2, + "functionAnchor": 3, + "name": "not_equal:any1_any1" + } + }, + { + "extensionFunction": { + "extensionUriReference": 1, + "functionAnchor": 4, + "name": "not:bool" + } + }, + { + "extensionFunction": { + "extensionUriReference": 4, + "functionAnchor": 5, + "name": "count:opt" + } + } + ], + "relations": [ + { + "root": { + "input": { + "fetch": { + "common": { + "direct": { + } + }, + "input": { + "sort": { + "common": { + "direct": { + } + }, + "input": { + "aggregate": { + "common": { + "direct": { + } + }, + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [ + 36 + ] + } + }, + "input": { + "filter": { + "common": { + "direct": { + } + }, + "input": { + "join": { + "common": { + "direct": { + } + }, + "left": { + "join": { + "common": { + "direct": { + } + }, + "left": { + "join": { + "common": { + "direct": { + } + }, + "left": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": [ + "S_SUPPKEY", + "S_NAME", + "S_ADDRESS", + "S_NATIONKEY", + "S_PHONE", + "S_ACCTBAL", + "S_COMMENT" + ], + "struct": { + "types": [ + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "fixedChar": { + "length": 25, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "varchar": { + "length": 40, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "fixedChar": { + "length": 15, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "varchar": { + "length": 101, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "local_files": { + "items": [ + { + "uri_file": "file://FILENAME_PLACEHOLDER_0", + "parquet": {} + } + ] + } + } + }, + "right": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": [ + "L_ORDERKEY", + "L_PARTKEY", + "L_SUPPKEY", + "L_LINENUMBER", + "L_QUANTITY", + "L_EXTENDEDPRICE", + "L_DISCOUNT", + "L_TAX", + "L_RETURNFLAG", + "L_LINESTATUS", + "L_SHIPDATE", + "L_COMMITDATE", + "L_RECEIPTDATE", + "L_SHIPINSTRUCT", + "L_SHIPMODE", + "L_COMMENT" + ], + "struct": { + "types": [ + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + 
"i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fixedChar": { + "length": 1, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fixedChar": { + "length": 1, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "date": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "date": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "date": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fixedChar": { + "length": 25, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fixedChar": { + "length": 10, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "varchar": { + "length": 44, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "local_files": { + "items": [ + { + "uri_file": "file://FILENAME_PLACEHOLDER_1", + "parquet": {} + } + ] + } + } + }, + "expression": { + "literal": { + "boolean": true, + "nullable": false, + "typeVariationReference": 0 + } + }, + "type": "JOIN_TYPE_INNER" + } + }, + "right": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": [ + "O_ORDERKEY", + "O_CUSTKEY", + "O_ORDERSTATUS", + "O_TOTALPRICE", + "O_ORDERDATE", + "O_ORDERPRIORITY", + "O_CLERK", + "O_SHIPPRIORITY", + "O_COMMENT" + ], + "struct": { + "types": [ + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "fixedChar": { + "length": 1, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "date": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fixedChar": { + "length": 15, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fixedChar": { + "length": 15, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "varchar": { + "length": 79, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "local_files": { + "items": [ + { + "uri_file": "file://FILENAME_PLACEHOLDER_2", + "parquet": {} + } + ] + } + } + }, + "expression": { + "literal": { + "boolean": true, + "nullable": false, + "typeVariationReference": 0 + } + }, + "type": "JOIN_TYPE_INNER" + } + }, + "right": 
{ + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": [ + "N_NATIONKEY", + "N_NAME", + "N_REGIONKEY", + "N_COMMENT" + ], + "struct": { + "types": [ + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "fixedChar": { + "length": 25, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "varchar": { + "length": 152, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "local_files": { + "items": [ + { + "uri_file": "file://FILENAME_PLACEHOLDER_3", + "parquet": {} + } + ] + } + } + }, + "expression": { + "literal": { + "boolean": true, + "nullable": false, + "typeVariationReference": 0 + } + }, + "type": "JOIN_TYPE_INNER" + } + }, + "condition": { + "scalarFunction": { + "functionReference": 0, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "scalarFunction": { + "functionReference": 1, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + } + }, + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 9 + } + }, + "rootReference": { + } + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 1, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 23 + } + }, + "rootReference": { + } + } + } + }, + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 7 + } + }, + "rootReference": { + } + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 1, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 25 + } + }, + "rootReference": { + } + } + } + }, + { + "value": { + "literal": { + "fixedChar": "F", + "nullable": false, + "typeVariationReference": 0 + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 2, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 19 + } + }, + "rootReference": { + } + } + } + }, + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 18 + } + }, + "rootReference": { + } + } + } + } + ] + } + } + }, + { + "value": { + "subquery": { + "setPredicate": { + "predicateOp": "PREDICATE_OP_EXISTS", + "tuples": { + "filter": { + "common": { + "direct": { + } + }, + "input": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": [ + "L_ORDERKEY", + "L_PARTKEY", + "L_SUPPKEY", + "L_LINENUMBER", + "L_QUANTITY", + "L_EXTENDEDPRICE", + "L_DISCOUNT", + "L_TAX", + "L_RETURNFLAG", + "L_LINESTATUS", + "L_SHIPDATE", + 
"L_COMMITDATE", + "L_RECEIPTDATE", + "L_SHIPINSTRUCT", + "L_SHIPMODE", + "L_COMMENT" + ], + "struct": { + "types": [ + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fixedChar": { + "length": 1, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fixedChar": { + "length": 1, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "date": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "date": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "date": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fixedChar": { + "length": 25, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fixedChar": { + "length": 10, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "varchar": { + "length": 44, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "local_files": { + "items": [ + { + "uri_file": "file://FILENAME_PLACEHOLDER_4", + "parquet": {} + } + ] + } + } + }, + "condition": { + "scalarFunction": { + "functionReference": 0, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [ + { + "value": { + "scalarFunction": { + "functionReference": 1, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + } + }, + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 7 + } + }, + "outerReference": { + "stepsOut": 1 + } + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 3, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "rootReference": { + } + } + } + }, + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 9 + } + }, + "outerReference": { + "stepsOut": 1 + } + } + } + } + ] + } + } + } + ] + } + } + } + } + } + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 4, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + 
"arguments": [ + { + "value": { + "subquery": { + "setPredicate": { + "predicateOp": "PREDICATE_OP_EXISTS", + "tuples": { + "filter": { + "common": { + "direct": { + } + }, + "input": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": [ + "L_ORDERKEY", + "L_PARTKEY", + "L_SUPPKEY", + "L_LINENUMBER", + "L_QUANTITY", + "L_EXTENDEDPRICE", + "L_DISCOUNT", + "L_TAX", + "L_RETURNFLAG", + "L_LINESTATUS", + "L_SHIPDATE", + "L_COMMITDATE", + "L_RECEIPTDATE", + "L_SHIPINSTRUCT", + "L_SHIPMODE", + "L_COMMENT" + ], + "struct": { + "types": [ + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fixedChar": { + "length": 1, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fixedChar": { + "length": 1, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "date": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "date": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "date": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fixedChar": { + "length": 25, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fixedChar": { + "length": 10, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "varchar": { + "length": 44, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "local_files": { + "items": [ + { + "uri_file": "file://FILENAME_PLACEHOLDER_5", + "parquet": {} + } + ] + } + } + }, + "condition": { + "scalarFunction": { + "functionReference": 0, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "scalarFunction": { + "functionReference": 1, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + } + }, + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 7 + } + }, + "outerReference": { + "stepsOut": 1 + } + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 3, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + 
"field": 2 + } + }, + "rootReference": { + } + } + } + }, + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 9 + } + }, + "outerReference": { + "stepsOut": 1 + } + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 2, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 12 + } + }, + "rootReference": { + } + } + } + }, + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 11 + } + }, + "rootReference": { + } + } + } + } + ] + } + } + } + ] + } + } + } + } + } + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 1, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 3 + } + }, + "rootReference": { + } + } + } + }, + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 32 + } + }, + "rootReference": { + } + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 1, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 33 + } + }, + "rootReference": { + } + } + } + }, + { + "value": { + "cast": { + "type": { + "fixedChar": { + "length": 25, + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "fixedChar": "SAUDI ARABIA", + "nullable": false, + "typeVariationReference": 0 + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_UNSPECIFIED" + } + } + } + ] + } + } + } + ] + } + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + } + ] + } + }, + "groupings": [ + { + "groupingExpressions": [ + { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + } + ] + } + ], + "measures": [ + { + "measure": { + "functionReference": 5, + "args": [], + "sorts": [], + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "invocation": "AGGREGATION_INVOCATION_ALL", + "arguments": [] + } + } + ] + } + }, + "sorts": [ + { + "expr": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + }, + "direction": "SORT_DIRECTION_DESC_NULLS_FIRST" + }, + { + "expr": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + }, + "direction": "SORT_DIRECTION_ASC_NULLS_LAST" + } + ] + } + }, + "offset": "0", + "count": "100" + } + }, + "names": [ + "S_NAME", + "NUMWAIT" + ] + } + } + ], + "expectedTypeUrls": [] +} diff --git a/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_22.json b/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_22.json new file mode 100644 index 0000000000000..9eb37da8e18e8 --- /dev/null +++ b/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_22.json @@ -0,0 +1,2034 @@ +{ + "extensionUris": [ + { + "extensionUriAnchor": 5, + "uri": 
"/functions_aggregate_generic.yaml" + }, + { + "extensionUriAnchor": 1, + "uri": "/functions_boolean.yaml" + }, + { + "extensionUriAnchor": 3, + "uri": "/functions_string.yaml" + }, + { + "extensionUriAnchor": 4, + "uri": "/functions_arithmetic_decimal.yaml" + }, + { + "extensionUriAnchor": 2, + "uri": "/functions_comparison.yaml" + } + ], + "extensions": [ + { + "extensionFunction": { + "extensionUriReference": 1, + "functionAnchor": 0, + "name": "and:bool" + } + }, + { + "extensionFunction": { + "extensionUriReference": 1, + "functionAnchor": 1, + "name": "or:bool" + } + }, + { + "extensionFunction": { + "extensionUriReference": 2, + "functionAnchor": 2, + "name": "equal:any1_any1" + } + }, + { + "extensionFunction": { + "extensionUriReference": 3, + "functionAnchor": 3, + "name": "substring:fchar_i32_i32" + } + }, + { + "extensionFunction": { + "extensionUriReference": 2, + "functionAnchor": 4, + "name": "gt:any1_any1" + } + }, + { + "extensionFunction": { + "extensionUriReference": 4, + "functionAnchor": 5, + "name": "avg:opt_decimal" + } + }, + { + "extensionFunction": { + "extensionUriReference": 1, + "functionAnchor": 6, + "name": "not:bool" + } + }, + { + "extensionFunction": { + "extensionUriReference": 5, + "functionAnchor": 7, + "name": "count:opt" + } + }, + { + "extensionFunction": { + "extensionUriReference": 4, + "functionAnchor": 8, + "name": "sum:opt_decimal" + } + } + ], + "relations": [ + { + "root": { + "input": { + "sort": { + "common": { + "direct": { + } + }, + "input": { + "aggregate": { + "common": { + "direct": { + } + }, + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [ + 8, + 9 + ] + } + }, + "input": { + "filter": { + "common": { + "direct": { + } + }, + "input": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": [ + "C_CUSTKEY", + "C_NAME", + "C_ADDRESS", + "C_NATIONKEY", + "C_PHONE", + "C_ACCTBAL", + "C_MKTSEGMENT", + "C_COMMENT" + ], + "struct": { + "types": [ + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "varchar": { + "length": 25, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "varchar": { + "length": 40, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "fixedChar": { + "length": 15, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fixedChar": { + "length": 10, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "varchar": { + "length": 117, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "local_files": { + "items": [ + { + "uri_file": "file://FILENAME_PLACEHOLDER_0", + "parquet": {} + } + ] + } + } + }, + "condition": { + "scalarFunction": { + "functionReference": 0, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "scalarFunction": { + "functionReference": 1, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "scalarFunction": { + 
"functionReference": 2, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "scalarFunction": { + "functionReference": 3, + "args": [], + "outputType": { + "varchar": { + "length": 15, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 4 + } + }, + "rootReference": { + } + } + } + }, + { + "value": { + "literal": { + "i32": 1, + "nullable": false, + "typeVariationReference": 0 + } + } + }, + { + "value": { + "literal": { + "i32": 2, + "nullable": false, + "typeVariationReference": 0 + } + } + } + ] + } + } + }, + { + "value": { + "cast": { + "type": { + "varchar": { + "length": 15, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "input": { + "literal": { + "fixedChar": "13", + "nullable": false, + "typeVariationReference": 0 + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_UNSPECIFIED" + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 2, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "scalarFunction": { + "functionReference": 3, + "args": [], + "outputType": { + "varchar": { + "length": 15, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 4 + } + }, + "rootReference": { + } + } + } + }, + { + "value": { + "literal": { + "i32": 1, + "nullable": false, + "typeVariationReference": 0 + } + } + }, + { + "value": { + "literal": { + "i32": 2, + "nullable": false, + "typeVariationReference": 0 + } + } + } + ] + } + } + }, + { + "value": { + "cast": { + "type": { + "varchar": { + "length": 15, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "input": { + "literal": { + "fixedChar": "31", + "nullable": false, + "typeVariationReference": 0 + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_UNSPECIFIED" + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 2, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "scalarFunction": { + "functionReference": 3, + "args": [], + "outputType": { + "varchar": { + "length": 15, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 4 + } + }, + "rootReference": { + } + } + } + }, + { + "value": { + "literal": { + "i32": 1, + "nullable": false, + "typeVariationReference": 0 + } + } + }, + { + "value": { + "literal": { + "i32": 2, + "nullable": false, + "typeVariationReference": 0 + } + } + } + ] + } + } + }, + { + "value": { + "cast": { + "type": { + "varchar": { + "length": 15, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "input": { + "literal": { + "fixedChar": "23", + "nullable": false, + "typeVariationReference": 0 + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_UNSPECIFIED" + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 2, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": 
"NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "scalarFunction": { + "functionReference": 3, + "args": [], + "outputType": { + "varchar": { + "length": 15, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 4 + } + }, + "rootReference": { + } + } + } + }, + { + "value": { + "literal": { + "i32": 1, + "nullable": false, + "typeVariationReference": 0 + } + } + }, + { + "value": { + "literal": { + "i32": 2, + "nullable": false, + "typeVariationReference": 0 + } + } + } + ] + } + } + }, + { + "value": { + "cast": { + "type": { + "varchar": { + "length": 15, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "input": { + "literal": { + "fixedChar": "29", + "nullable": false, + "typeVariationReference": 0 + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_UNSPECIFIED" + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 2, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "scalarFunction": { + "functionReference": 3, + "args": [], + "outputType": { + "varchar": { + "length": 15, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 4 + } + }, + "rootReference": { + } + } + } + }, + { + "value": { + "literal": { + "i32": 1, + "nullable": false, + "typeVariationReference": 0 + } + } + }, + { + "value": { + "literal": { + "i32": 2, + "nullable": false, + "typeVariationReference": 0 + } + } + } + ] + } + } + }, + { + "value": { + "cast": { + "type": { + "varchar": { + "length": 15, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "input": { + "literal": { + "fixedChar": "30", + "nullable": false, + "typeVariationReference": 0 + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_UNSPECIFIED" + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 2, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "scalarFunction": { + "functionReference": 3, + "args": [], + "outputType": { + "varchar": { + "length": 15, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 4 + } + }, + "rootReference": { + } + } + } + }, + { + "value": { + "literal": { + "i32": 1, + "nullable": false, + "typeVariationReference": 0 + } + } + }, + { + "value": { + "literal": { + "i32": 2, + "nullable": false, + "typeVariationReference": 0 + } + } + } + ] + } + } + }, + { + "value": { + "cast": { + "type": { + "varchar": { + "length": 15, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "input": { + "literal": { + "fixedChar": "18", + "nullable": false, + "typeVariationReference": 0 + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_UNSPECIFIED" + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 2, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "scalarFunction": { + "functionReference": 3, + "args": [], + 
"outputType": { + "varchar": { + "length": 15, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 4 + } + }, + "rootReference": { + } + } + } + }, + { + "value": { + "literal": { + "i32": 1, + "nullable": false, + "typeVariationReference": 0 + } + } + }, + { + "value": { + "literal": { + "i32": 2, + "nullable": false, + "typeVariationReference": 0 + } + } + } + ] + } + } + }, + { + "value": { + "cast": { + "type": { + "varchar": { + "length": 15, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "input": { + "literal": { + "fixedChar": "17", + "nullable": false, + "typeVariationReference": 0 + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_UNSPECIFIED" + } + } + } + ] + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 4, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 5 + } + }, + "rootReference": { + } + } + } + }, + { + "value": { + "subquery": { + "scalar": { + "input": { + "aggregate": { + "common": { + "direct": { + } + }, + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [ + 8 + ] + } + }, + "input": { + "filter": { + "common": { + "direct": { + } + }, + "input": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": [ + "C_CUSTKEY", + "C_NAME", + "C_ADDRESS", + "C_NATIONKEY", + "C_PHONE", + "C_ACCTBAL", + "C_MKTSEGMENT", + "C_COMMENT" + ], + "struct": { + "types": [ + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "varchar": { + "length": 25, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "varchar": { + "length": 40, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "fixedChar": { + "length": 15, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fixedChar": { + "length": 10, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "varchar": { + "length": 117, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "local_files": { + "items": [ + { + "uri_file": "file://FILENAME_PLACEHOLDER_1", + "parquet": {} + } + ] + } + } + }, + "condition": { + "scalarFunction": { + "functionReference": 0, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "scalarFunction": { + "functionReference": 4, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 5 + } + }, + "rootReference": { + } + } + } + }, + { + "value": { + "literal": { + "decimal": { + "value": "AAAAAAAAAAAAAAAAAAAAAA==", + "precision": 3, + "scale": 2 + }, + "nullable": false, + "typeVariationReference": 0 + 
} + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 1, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "scalarFunction": { + "functionReference": 2, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "scalarFunction": { + "functionReference": 3, + "args": [], + "outputType": { + "varchar": { + "length": 15, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 4 + } + }, + "rootReference": { + } + } + } + }, + { + "value": { + "literal": { + "i32": 1, + "nullable": false, + "typeVariationReference": 0 + } + } + }, + { + "value": { + "literal": { + "i32": 2, + "nullable": false, + "typeVariationReference": 0 + } + } + } + ] + } + } + }, + { + "value": { + "cast": { + "type": { + "varchar": { + "length": 15, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "input": { + "literal": { + "fixedChar": "13", + "nullable": false, + "typeVariationReference": 0 + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_UNSPECIFIED" + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 2, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "scalarFunction": { + "functionReference": 3, + "args": [], + "outputType": { + "varchar": { + "length": 15, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 4 + } + }, + "rootReference": { + } + } + } + }, + { + "value": { + "literal": { + "i32": 1, + "nullable": false, + "typeVariationReference": 0 + } + } + }, + { + "value": { + "literal": { + "i32": 2, + "nullable": false, + "typeVariationReference": 0 + } + } + } + ] + } + } + }, + { + "value": { + "cast": { + "type": { + "varchar": { + "length": 15, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "input": { + "literal": { + "fixedChar": "31", + "nullable": false, + "typeVariationReference": 0 + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_UNSPECIFIED" + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 2, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "scalarFunction": { + "functionReference": 3, + "args": [], + "outputType": { + "varchar": { + "length": 15, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 4 + } + }, + "rootReference": { + } + } + } + }, + { + "value": { + "literal": { + "i32": 1, + "nullable": false, + "typeVariationReference": 0 + } + } + }, + { + "value": { + "literal": { + "i32": 2, + "nullable": false, + "typeVariationReference": 0 + } + } + } + ] + } + } + }, + { + "value": { + "cast": { + "type": { + "varchar": { + "length": 15, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "input": { + "literal": { + "fixedChar": "23", + "nullable": false, + 
"typeVariationReference": 0 + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_UNSPECIFIED" + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 2, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "scalarFunction": { + "functionReference": 3, + "args": [], + "outputType": { + "varchar": { + "length": 15, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 4 + } + }, + "rootReference": { + } + } + } + }, + { + "value": { + "literal": { + "i32": 1, + "nullable": false, + "typeVariationReference": 0 + } + } + }, + { + "value": { + "literal": { + "i32": 2, + "nullable": false, + "typeVariationReference": 0 + } + } + } + ] + } + } + }, + { + "value": { + "cast": { + "type": { + "varchar": { + "length": 15, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "input": { + "literal": { + "fixedChar": "29", + "nullable": false, + "typeVariationReference": 0 + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_UNSPECIFIED" + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 2, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "scalarFunction": { + "functionReference": 3, + "args": [], + "outputType": { + "varchar": { + "length": 15, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 4 + } + }, + "rootReference": { + } + } + } + }, + { + "value": { + "literal": { + "i32": 1, + "nullable": false, + "typeVariationReference": 0 + } + } + }, + { + "value": { + "literal": { + "i32": 2, + "nullable": false, + "typeVariationReference": 0 + } + } + } + ] + } + } + }, + { + "value": { + "cast": { + "type": { + "varchar": { + "length": 15, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "input": { + "literal": { + "fixedChar": "30", + "nullable": false, + "typeVariationReference": 0 + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_UNSPECIFIED" + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 2, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "scalarFunction": { + "functionReference": 3, + "args": [], + "outputType": { + "varchar": { + "length": 15, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 4 + } + }, + "rootReference": { + } + } + } + }, + { + "value": { + "literal": { + "i32": 1, + "nullable": false, + "typeVariationReference": 0 + } + } + }, + { + "value": { + "literal": { + "i32": 2, + "nullable": false, + "typeVariationReference": 0 + } + } + } + ] + } + } + }, + { + "value": { + "cast": { + "type": { + "varchar": { + "length": 15, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "input": { + "literal": { + "fixedChar": "18", + "nullable": false, + "typeVariationReference": 0 + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_UNSPECIFIED" + } + } + } + ] + } + } + }, + { + 
"value": { + "scalarFunction": { + "functionReference": 2, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "scalarFunction": { + "functionReference": 3, + "args": [], + "outputType": { + "varchar": { + "length": 15, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 4 + } + }, + "rootReference": { + } + } + } + }, + { + "value": { + "literal": { + "i32": 1, + "nullable": false, + "typeVariationReference": 0 + } + } + }, + { + "value": { + "literal": { + "i32": 2, + "nullable": false, + "typeVariationReference": 0 + } + } + } + ] + } + } + }, + { + "value": { + "cast": { + "type": { + "varchar": { + "length": 15, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "input": { + "literal": { + "fixedChar": "17", + "nullable": false, + "typeVariationReference": 0 + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_UNSPECIFIED" + } + } + } + ] + } + } + } + ] + } + } + } + ] + } + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": { + "field": 5 + } + }, + "rootReference": { + } + } + } + ] + } + }, + "groupings": [ + { + "groupingExpressions": [] + } + ], + "measures": [ + { + "measure": { + "functionReference": 5, + "args": [], + "sorts": [], + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "invocation": "AGGREGATION_INVOCATION_ALL", + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + } + } + ] + } + } + ] + } + } + } + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 6, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [ + { + "value": { + "subquery": { + "setPredicate": { + "predicateOp": "PREDICATE_OP_EXISTS", + "tuples": { + "filter": { + "common": { + "direct": { + } + }, + "input": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": [ + "O_ORDERKEY", + "O_CUSTKEY", + "O_ORDERSTATUS", + "O_TOTALPRICE", + "O_ORDERDATE", + "O_ORDERPRIORITY", + "O_CLERK", + "O_SHIPPRIORITY", + "O_COMMENT" + ], + "struct": { + "types": [ + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "fixedChar": { + "length": 1, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "date": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fixedChar": { + "length": 15, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fixedChar": { + "length": 15, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "varchar": { + "length": 79, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + 
"typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "local_files": { + "items": [ + { + "uri_file": "file://FILENAME_PLACEHOLDER_2", + "parquet": {} + } + ] + } + } + }, + "condition": { + "scalarFunction": { + "functionReference": 2, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + } + }, + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "outerReference": { + "stepsOut": 1 + } + } + } + } + ] + } + } + } + } + } + } + } + } + ] + } + } + } + ] + } + } + } + }, + "expressions": [ + { + "scalarFunction": { + "functionReference": 3, + "args": [], + "outputType": { + "varchar": { + "length": 15, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 4 + } + }, + "rootReference": { + } + } + } + }, + { + "value": { + "literal": { + "i32": 1, + "nullable": false, + "typeVariationReference": 0 + } + } + }, + { + "value": { + "literal": { + "i32": 2, + "nullable": false, + "typeVariationReference": 0 + } + } + } + ] + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 5 + } + }, + "rootReference": { + } + } + } + ] + } + }, + "groupings": [ + { + "groupingExpressions": [ + { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + } + ] + } + ], + "measures": [ + { + "measure": { + "functionReference": 7, + "args": [], + "sorts": [], + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "invocation": "AGGREGATION_INVOCATION_ALL", + "arguments": [] + } + }, + { + "measure": { + "functionReference": 8, + "args": [], + "sorts": [], + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "invocation": "AGGREGATION_INVOCATION_ALL", + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + } + } + ] + } + } + ] + } + }, + "sorts": [ + { + "expr": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + }, + "direction": "SORT_DIRECTION_ASC_NULLS_LAST" + } + ] + } + }, + "names": [ + "CNTRYCODE", + "NUMCUST", + "TOTACCTBAL" + ] + } + } + ], + "expectedTypeUrls": [] +} From d01301d2ee9ea6d8e22e002bdfb5cf7b6ff6bd75 Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Tue, 16 Jul 2024 01:55:49 +0800 Subject: [PATCH 39/59] Docs: Explain the usage of logical expressions for `create_aggregate_expr` (#11458) * doc: comment Signed-off-by: jayzhan211 * fmt Signed-off-by: jayzhan211 * fix Signed-off-by: jayzhan211 --------- Signed-off-by: jayzhan211 --- datafusion/physical-expr-common/src/aggregate/mod.rs | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/datafusion/physical-expr-common/src/aggregate/mod.rs b/datafusion/physical-expr-common/src/aggregate/mod.rs index 35666f199ace9..db4581a622acc 100644 --- a/datafusion/physical-expr-common/src/aggregate/mod.rs +++ b/datafusion/physical-expr-common/src/aggregate/mod.rs @@ 
-43,6 +43,14 @@ use datafusion_expr::utils::AggregateOrderSensitivity;

 /// Creates a physical expression of the UDAF, that includes all necessary type coercion.
 /// This function errors when `args` can't be coerced to a valid argument type of the UDAF.
+///
+/// `input_exprs` and `sort_exprs` are used to customize the `Accumulator`
+/// when its behavior depends on arguments such as the `ORDER BY` clause.
+///
+/// For example, a call to `ARRAY_AGG(x ORDER BY y)` would pass `x` via `input_exprs` and `y` via `sort_exprs`.
+///
+/// Both are handed to the `Accumulator` as the arguments in `AccumulatorArgs`;
+/// if you don't need them, it is fine to pass the empty slice `&[]`.
 #[allow(clippy::too_many_arguments)]
 pub fn create_aggregate_expr(
     fun: &AggregateUDF,

From 0965455486b7dcbd8c9a5efa8d2370ca5460bb9f Mon Sep 17 00:00:00 2001
From: kamille
Date: Tue, 16 Jul 2024 02:06:38 +0800
Subject: [PATCH 40/59] Return scalar result when all inputs are constants in `map` and `make_map` (#11461)

* return scalar result when all inputs are constants.

* support converting a map array to a scalar.

* disable const evaluation for the Map type until its hash calculation is implemented.

* add tests in map.slt.

* improve error return.

* fix error.

* remove unused import.

* remove duplicated testcase.

* remove inline.

---
 datafusion/common/src/scalar/mod.rs           |  5 +-
 datafusion/functions/src/core/map.rs          | 34 +++++++-
 .../simplify_expressions/expr_simplifier.rs   | 27 +++++-
 datafusion/sqllogictest/test_files/map.slt    | 84 +++++++++++++++++++
 4 files changed, 143 insertions(+), 7 deletions(-)

diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs
index 6c03e8698e80b..c891e85aa59bb 100644
--- a/datafusion/common/src/scalar/mod.rs
+++ b/datafusion/common/src/scalar/mod.rs
@@ -2678,7 +2678,10 @@ impl ScalarValue {
             DataType::Duration(TimeUnit::Nanosecond) => {
                 typed_cast!(array, index, DurationNanosecondArray, DurationNanosecond)?
             }
-
+            DataType::Map(_, _) => {
+                let a = array.slice(index, 1);
+                Self::Map(Arc::new(a.as_map().to_owned()))
+            }
             other => {
                 return _not_impl_err!(
                     "Can't create a scalar from array of type \"{other:?}\""
diff --git a/datafusion/functions/src/core/map.rs b/datafusion/functions/src/core/map.rs
index 8a8a19d7af52b..6626831c8034f 100644
--- a/datafusion/functions/src/core/map.rs
+++ b/datafusion/functions/src/core/map.rs
@@ -28,7 +28,21 @@ use datafusion_common::{exec_err, internal_err, ScalarValue};
 use datafusion_common::{not_impl_err, Result};
 use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};

+/// Check if we can evaluate the expr to a constant directly.
+///
+/// # Example
+/// ```sql
+/// SELECT make_map('type', 'test') from test
+/// ```
+/// We can evaluate the result of `make_map` directly.
+fn can_evaluate_to_const(args: &[ColumnarValue]) -> bool {
+    args.iter()
+        .all(|arg| matches!(arg, ColumnarValue::Scalar(_)))
+}
+
 fn make_map(args: &[ColumnarValue]) -> Result<ColumnarValue> {
+    let can_evaluate_to_const = can_evaluate_to_const(args);
+
     let (key, value): (Vec<_>, Vec<_>) = args
         .chunks_exact(2)
         .map(|chunk| {
@@ -58,7 +72,7 @@ fn make_map(args: &[ColumnarValue]) -> Result<ColumnarValue> {
         Ok(value) => value,
         Err(e) => return internal_err!("Error concatenating values: {}", e),
     };
-    make_map_batch_internal(key, value)
+    make_map_batch_internal(key, value, can_evaluate_to_const)
 }

 fn make_map_batch(args: &[ColumnarValue]) -> Result<ColumnarValue> {
@@ -68,9 +82,12 @@ fn make_map_batch(args: &[ColumnarValue]) -> Result<ColumnarValue> {
             args.len()
         );
     }
+
+    let can_evaluate_to_const = can_evaluate_to_const(args);
+
     let key = get_first_array_ref(&args[0])?;
     let value = get_first_array_ref(&args[1])?;
-    make_map_batch_internal(key, value)
+    make_map_batch_internal(key, value, can_evaluate_to_const)
 }

 fn get_first_array_ref(columnar_value: &ColumnarValue) -> Result<ArrayRef> {
@@ -85,7 +102,11 @@ fn get_first_array_ref(columnar_value: &ColumnarValue) -> Result<ArrayRef> {
     }
 }

-fn make_map_batch_internal(keys: ArrayRef, values: ArrayRef) -> Result<ColumnarValue> {
+fn make_map_batch_internal(
+    keys: ArrayRef,
+    values: ArrayRef,
+    can_evaluate_to_const: bool,
+) -> Result<ColumnarValue> {
     if keys.null_count() > 0 {
         return exec_err!("map key cannot be null");
     }
@@ -124,8 +145,13 @@ fn make_map_batch_internal(keys: ArrayRef, values: ArrayRef) -> Result ConstEvaluator<'a> {
             } else {
                 // Non-ListArray
                 match ScalarValue::try_from_array(&a, 0) {
-                    Ok(s) => ConstSimplifyResult::Simplified(s),
+                    Ok(s) => {
+                        // TODO: support this optimization for the `Map` type once hashing is implemented for it
+                        if matches!(&s, ScalarValue::Map(_)) {
+                            ConstSimplifyResult::SimplifyRuntimeError(
+                                DataFusionError::NotImplemented("Const evaluate for Map type is still not supported".to_string()),
+                                expr,
+                            )
+                        } else {
+                            ConstSimplifyResult::Simplified(s)
+                        }
+                    }
                     Err(err) => ConstSimplifyResult::SimplifyRuntimeError(err, expr),
                 }
             }
         }
-        ColumnarValue::Scalar(s) => ConstSimplifyResult::Simplified(s),
+        ColumnarValue::Scalar(s) => {
+            // TODO: support this optimization for the `Map` type once hashing is implemented for it
+            if matches!(&s, ScalarValue::Map(_)) {
+                ConstSimplifyResult::SimplifyRuntimeError(
+                    DataFusionError::NotImplemented(
+                        "Const evaluate for Map type is still not supported"
+                            .to_string(),
+                    ),
+                    expr,
+                )
+            } else {
+                ConstSimplifyResult::Simplified(s)
+            }
+        }
     }
 }

diff --git a/datafusion/sqllogictest/test_files/map.slt b/datafusion/sqllogictest/test_files/map.slt
index abf5b2ebbf98e..fb8917a5f4fee 100644
--- a/datafusion/sqllogictest/test_files/map.slt
+++ b/datafusion/sqllogictest/test_files/map.slt
@@ -212,3 +212,87 @@ SELECT map(column5, column6) FROM t;
 # {k1:1, k2:2}
 # {k3: 3}
 # {k5: 5}
+
+query ?
+SELECT MAKE_MAP('POST', 41, 'HEAD', 33, 'PATCH', 30, 'OPTION', 29, 'GET', 27, 'PUT', 25, 'DELETE', 24) AS method_count from t;
+----
+{POST: 41, HEAD: 33, PATCH: 30, OPTION: 29, GET: 27, PUT: 25, DELETE: 24}
+{POST: 41, HEAD: 33, PATCH: 30, OPTION: 29, GET: 27, PUT: 25, DELETE: 24}
+{POST: 41, HEAD: 33, PATCH: 30, OPTION: 29, GET: 27, PUT: 25, DELETE: 24}
+
+query I
+SELECT MAKE_MAP('POST', 41, 'HEAD', 33)['POST'] from t;
+----
+41
+41
+41
+
+query ?
+SELECT MAKE_MAP('POST', 41, 'HEAD', 33, 'PATCH', null) from t;
+----
+{POST: 41, HEAD: 33, PATCH: }
+{POST: 41, HEAD: 33, PATCH: }
+{POST: 41, HEAD: 33, PATCH: }
+
+query ?
+SELECT MAKE_MAP('POST', null, 'HEAD', 33, 'PATCH', null) from t;
+----
+{POST: , HEAD: 33, PATCH: }
+{POST: , HEAD: 33, PATCH: }
+{POST: , HEAD: 33, PATCH: }
+
+query ?
+SELECT MAKE_MAP(1, null, 2, 33, 3, null) from t;
+----
+{1: , 2: 33, 3: }
+{1: , 2: 33, 3: }
+{1: , 2: 33, 3: }
+
+query ?
+SELECT MAKE_MAP([1,2], ['a', 'b'], [3,4], ['b']) from t;
+----
+{[1, 2]: [a, b], [3, 4]: [b]}
+{[1, 2]: [a, b], [3, 4]: [b]}
+{[1, 2]: [a, b], [3, 4]: [b]}
+
+query ?
+SELECT MAP(['POST', 'HEAD', 'PATCH'], [41, 33, 30]) from t;
+----
+{POST: 41, HEAD: 33, PATCH: 30}
+{POST: 41, HEAD: 33, PATCH: 30}
+{POST: 41, HEAD: 33, PATCH: 30}
+
+query ?
+SELECT MAP(['POST', 'HEAD', 'PATCH'], [41, 33, null]) from t;
+----
+{POST: 41, HEAD: 33, PATCH: }
+{POST: 41, HEAD: 33, PATCH: }
+{POST: 41, HEAD: 33, PATCH: }
+
+query ?
+SELECT MAP([[1,2], [3,4]], ['a', 'b']) from t;
+----
+{[1, 2]: a, [3, 4]: b}
+{[1, 2]: a, [3, 4]: b}
+{[1, 2]: a, [3, 4]: b}
+
+query ?
+SELECT MAP(make_array('POST', 'HEAD', 'PATCH'), make_array(41, 33, 30)) from t;
+----
+{POST: 41, HEAD: 33, PATCH: 30}
+{POST: 41, HEAD: 33, PATCH: 30}
+{POST: 41, HEAD: 33, PATCH: 30}
+
+query ?
+SELECT MAP(arrow_cast(make_array('POST', 'HEAD', 'PATCH'), 'FixedSizeList(3, Utf8)'), arrow_cast(make_array(41, 33, 30), 'FixedSizeList(3, Int64)')) from t;
+----
+{POST: 41, HEAD: 33, PATCH: 30}
+{POST: 41, HEAD: 33, PATCH: 30}
+{POST: 41, HEAD: 33, PATCH: 30}
+
+query ?
+SELECT MAP(arrow_cast(make_array('POST', 'HEAD', 'PATCH'), 'LargeList(Utf8)'), arrow_cast(make_array(41, 33, 30), 'LargeList(Int64)')) from t;
+----
+{POST: 41, HEAD: 33, PATCH: 30}
+{POST: 41, HEAD: 33, PATCH: 30}
+{POST: 41, HEAD: 33, PATCH: 30}

From 7bd0e74aaa7aad3e436f01000fd4f973d5724f50 Mon Sep 17 00:00:00 2001
From: Alex Huang
Date: Tue, 16 Jul 2024 02:53:37 +0800
Subject: [PATCH 41/59] fix: `regexp_replace` fails when pattern or replacement is a scalar `NULL` (#11459)

* fix: regexp_replace fails when pattern or replacement is a scalar NULL

* chore

---
 .../functions/src/regex/regexpreplace.rs      | 31 +++++++++++++------
 datafusion/sqllogictest/test_files/regexp.slt | 10 ++++++
 2 files changed, 32 insertions(+), 9 deletions(-)

diff --git a/datafusion/functions/src/regex/regexpreplace.rs b/datafusion/functions/src/regex/regexpreplace.rs
index 201eebde22bb9..378b6ced076c3 100644
--- a/datafusion/functions/src/regex/regexpreplace.rs
+++ b/datafusion/functions/src/regex/regexpreplace.rs
@@ -282,22 +282,23 @@ pub fn regexp_replace(args: &[ArrayRef]) -> Result
 fn _regexp_replace_early_abort<T: OffsetSizeTrait>(
     input_array: &GenericStringArray<T>,
+    sz: usize,
 ) -> Result<ArrayRef> {
     // Mimicking the existing behavior of regexp_replace, if any of the scalar arguments
-    // are actually null, then the result will be an array of the same size but with nulls.
+    // are actually null, then the result will be an array of the same size as the first argument with all nulls.
     //
     // Also acts like an early abort mechanism when the input array is empty.
-    Ok(new_null_array(input_array.data_type(), input_array.len()))
+    Ok(new_null_array(input_array.data_type(), sz))
 }

 /// Get the first argument from the given string array.
 ///
 /// Note: If the array is empty or the first argument is null,
 /// then it calls the given early-abort function.
 macro_rules!
fetch_string_arg { - ($ARG:expr, $NAME:expr, $T:ident, $EARLY_ABORT:ident) => {{ + ($ARG:expr, $NAME:expr, $T:ident, $EARLY_ABORT:ident, $ARRAY_SIZE:expr) => {{ let array = as_generic_string_array::($ARG)?; if array.len() == 0 || array.is_null(0) { - return $EARLY_ABORT(array); + return $EARLY_ABORT(array, $ARRAY_SIZE); } else { array.value(0) } @@ -313,12 +314,24 @@ fn _regexp_replace_static_pattern_replace( args: &[ArrayRef], ) -> Result { let string_array = as_generic_string_array::(&args[0])?; - let pattern = fetch_string_arg!(&args[1], "pattern", T, _regexp_replace_early_abort); - let replacement = - fetch_string_arg!(&args[2], "replacement", T, _regexp_replace_early_abort); + let array_size = string_array.len(); + let pattern = fetch_string_arg!( + &args[1], + "pattern", + T, + _regexp_replace_early_abort, + array_size + ); + let replacement = fetch_string_arg!( + &args[2], + "replacement", + T, + _regexp_replace_early_abort, + array_size + ); let flags = match args.len() { 3 => None, - 4 => Some(fetch_string_arg!(&args[3], "flags", T, _regexp_replace_early_abort)), + 4 => Some(fetch_string_arg!(&args[3], "flags", T, _regexp_replace_early_abort, array_size)), other => { return exec_err!( "regexp_replace was called with {other} arguments. It requires at least 3 and at most 4." @@ -351,7 +364,7 @@ fn _regexp_replace_static_pattern_replace( let offsets = string_array.value_offsets(); (offsets[string_array.len()] - offsets[0]) .to_usize() - .unwrap() + .expect("Failed to convert usize") }); let mut new_offsets = BufferBuilder::::new(string_array.len() + 1); new_offsets.append(T::zero()); diff --git a/datafusion/sqllogictest/test_files/regexp.slt b/datafusion/sqllogictest/test_files/regexp.slt index fed7ac31712ce..f5349fc659f6a 100644 --- a/datafusion/sqllogictest/test_files/regexp.slt +++ b/datafusion/sqllogictest/test_files/regexp.slt @@ -309,6 +309,16 @@ SELECT regexp_replace(arrow_cast('foobar', 'Dictionary(Int32, Utf8)'), 'bar', 'x ---- fooxx +query TTT +select + regexp_replace(col, NULL, 'c'), + regexp_replace(col, 'a', NULL), + regexp_replace(col, 'a', 'c', NULL) +from (values ('a'), ('b')) as tbl(col); +---- +NULL NULL NULL +NULL NULL NULL + # multiline string query B SELECT 'foo\nbar\nbaz' ~ 'bar'; From f204869ff55bb3e39cf23fc0a34ebd5021e6773f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E6=9E=97=E4=BC=9F?= Date: Tue, 16 Jul 2024 02:54:10 +0800 Subject: [PATCH 42/59] Enable `clone_on_ref_ptr` clippy lint on functions* (#11468) * Enable clone_on_ref_ptr clippy lint on functions * Remove unnecessary Arc::clone --- .../functions-aggregate/src/correlation.rs | 21 +++++++++++----- .../functions-aggregate/src/first_last.rs | 16 ++++++------- datafusion/functions-aggregate/src/lib.rs | 2 ++ datafusion/functions-array/src/array_has.rs | 4 ++-- datafusion/functions-array/src/concat.rs | 2 +- datafusion/functions-array/src/flatten.rs | 4 ++-- datafusion/functions-array/src/lib.rs | 2 ++ datafusion/functions-array/src/resize.rs | 8 +++---- datafusion/functions-array/src/reverse.rs | 4 ++-- datafusion/functions-array/src/set_ops.rs | 12 +++++----- datafusion/functions-array/src/sort.rs | 2 +- datafusion/functions-array/src/string.rs | 2 +- datafusion/functions-array/src/utils.rs | 14 ++++++----- datafusion/functions/benches/concat.rs | 3 ++- datafusion/functions/benches/regx.rs | 24 ++++++++++++------- datafusion/functions/src/core/getfield.rs | 5 ++-- datafusion/functions/src/core/map.rs | 6 ++--- datafusion/functions/src/core/nvl.rs | 7 +++--- datafusion/functions/src/core/nvl2.rs | 
3 ++- datafusion/functions/src/core/struct.rs | 17 ++++--------- .../functions/src/datetime/date_part.rs | 2 +- .../functions/src/datetime/to_timestamp.rs | 21 ++++++++-------- datafusion/functions/src/lib.rs | 2 ++ datafusion/functions/src/math/abs.rs | 2 +- datafusion/functions/src/math/log.rs | 2 +- datafusion/functions/src/math/round.rs | 2 +- datafusion/functions/src/math/trunc.rs | 2 +- 27 files changed, 106 insertions(+), 85 deletions(-) diff --git a/datafusion/functions-aggregate/src/correlation.rs b/datafusion/functions-aggregate/src/correlation.rs index 10d5563086154..c2d7a89081d66 100644 --- a/datafusion/functions-aggregate/src/correlation.rs +++ b/datafusion/functions-aggregate/src/correlation.rs @@ -19,6 +19,7 @@ use std::any::Any; use std::fmt::Debug; +use std::sync::Arc; use arrow::compute::{and, filter, is_not_null}; use arrow::{ @@ -192,13 +193,21 @@ impl Accumulator for CorrelationAccumulator { fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { let states_c = [ - states[0].clone(), - states[1].clone(), - states[3].clone(), - states[5].clone(), + Arc::clone(&states[0]), + Arc::clone(&states[1]), + Arc::clone(&states[3]), + Arc::clone(&states[5]), + ]; + let states_s1 = [ + Arc::clone(&states[0]), + Arc::clone(&states[1]), + Arc::clone(&states[2]), + ]; + let states_s2 = [ + Arc::clone(&states[0]), + Arc::clone(&states[3]), + Arc::clone(&states[4]), ]; - let states_s1 = [states[0].clone(), states[1].clone(), states[2].clone()]; - let states_s2 = [states[0].clone(), states[3].clone(), states[4].clone()]; self.covar.merge_batch(&states_c)?; self.stddev1.merge_batch(&states_s1)?; diff --git a/datafusion/functions-aggregate/src/first_last.rs b/datafusion/functions-aggregate/src/first_last.rs index dd38e34872643..0e619bacef824 100644 --- a/datafusion/functions-aggregate/src/first_last.rs +++ b/datafusion/functions-aggregate/src/first_last.rs @@ -247,7 +247,7 @@ impl FirstValueAccumulator { .iter() .zip(self.ordering_req.iter()) .map(|(values, req)| SortColumn { - values: values.clone(), + values: Arc::clone(values), options: Some(req.options), }) .collect::>(); @@ -547,7 +547,7 @@ impl LastValueAccumulator { // Take the reverse ordering requirement. This enables us to // use "fetch = 1" to get the last value. SortColumn { - values: values.clone(), + values: Arc::clone(values), options: Some(!req.options), } }) @@ -676,7 +676,7 @@ fn convert_to_sort_cols( arrs.iter() .zip(sort_exprs.iter()) .map(|(item, sort_expr)| SortColumn { - values: item.clone(), + values: Arc::clone(item), options: Some(sort_expr.options), }) .collect::>() @@ -707,7 +707,7 @@ mod tests { for arr in arrs { // Once first_value is set, accumulator should remember it. // It shouldn't update first_value for each new batch - first_accumulator.update_batch(&[arr.clone()])?; + first_accumulator.update_batch(&[Arc::clone(&arr)])?; // last_value should be updated for each new batch. 
last_accumulator.update_batch(&[arr])?; } @@ -733,12 +733,12 @@ mod tests { let mut first_accumulator = FirstValueAccumulator::try_new(&DataType::Int64, &[], vec![], false)?; - first_accumulator.update_batch(&[arrs[0].clone()])?; + first_accumulator.update_batch(&[Arc::clone(&arrs[0])])?; let state1 = first_accumulator.state()?; let mut first_accumulator = FirstValueAccumulator::try_new(&DataType::Int64, &[], vec![], false)?; - first_accumulator.update_batch(&[arrs[1].clone()])?; + first_accumulator.update_batch(&[Arc::clone(&arrs[1])])?; let state2 = first_accumulator.state()?; assert_eq!(state1.len(), state2.len()); @@ -763,12 +763,12 @@ mod tests { let mut last_accumulator = LastValueAccumulator::try_new(&DataType::Int64, &[], vec![], false)?; - last_accumulator.update_batch(&[arrs[0].clone()])?; + last_accumulator.update_batch(&[Arc::clone(&arrs[0])])?; let state1 = last_accumulator.state()?; let mut last_accumulator = LastValueAccumulator::try_new(&DataType::Int64, &[], vec![], false)?; - last_accumulator.update_batch(&[arrs[1].clone()])?; + last_accumulator.update_batch(&[Arc::clone(&arrs[1])])?; let state2 = last_accumulator.state()?; assert_eq!(state1.len(), state2.len()); diff --git a/datafusion/functions-aggregate/src/lib.rs b/datafusion/functions-aggregate/src/lib.rs index 6ae2dfb3697ce..a3808a08b0074 100644 --- a/datafusion/functions-aggregate/src/lib.rs +++ b/datafusion/functions-aggregate/src/lib.rs @@ -14,6 +14,8 @@ // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. +// Make cheap clones clear: https://github.com/apache/datafusion/issues/11143 +#![deny(clippy::clone_on_ref_ptr)] //! Aggregate Function packages for [DataFusion]. //! diff --git a/datafusion/functions-array/src/array_has.rs b/datafusion/functions-array/src/array_has.rs index 136c6e7691207..bdda5a565947e 100644 --- a/datafusion/functions-array/src/array_has.rs +++ b/datafusion/functions-array/src/array_has.rs @@ -279,7 +279,7 @@ fn general_array_has_dispatch( let converter = RowConverter::new(vec![SortField::new(array.value_type())])?; - let element = sub_array.clone(); + let element = Arc::clone(sub_array); let sub_array = if comparison_type != ComparisonType::Single { as_generic_list_array::(sub_array)? } else { @@ -292,7 +292,7 @@ fn general_array_has_dispatch( let sub_arr_values = if comparison_type != ComparisonType::Single { converter.convert_columns(&[sub_arr])? } else { - converter.convert_columns(&[element.clone()])? + converter.convert_columns(&[Arc::clone(&element)])? 
}; let mut res = match comparison_type { diff --git a/datafusion/functions-array/src/concat.rs b/datafusion/functions-array/src/concat.rs index 330c50f5b055d..c52118d0a5e2b 100644 --- a/datafusion/functions-array/src/concat.rs +++ b/datafusion/functions-array/src/concat.rs @@ -249,7 +249,7 @@ pub(crate) fn array_concat_inner(args: &[ArrayRef]) -> Result { return not_impl_err!("Array is not type '{base_type:?}'."); } if !base_type.eq(&DataType::Null) { - new_args.push(arg.clone()); + new_args.push(Arc::clone(arg)); } } diff --git a/datafusion/functions-array/src/flatten.rs b/datafusion/functions-array/src/flatten.rs index a495c3ade96f3..2b383af3d456f 100644 --- a/datafusion/functions-array/src/flatten.rs +++ b/datafusion/functions-array/src/flatten.rs @@ -77,7 +77,7 @@ impl ScalarUDFImpl for Flatten { get_base_type(field.data_type()) } Null | List(_) | LargeList(_) => Ok(data_type.to_owned()), - FixedSizeList(field, _) => Ok(List(field.clone())), + FixedSizeList(field, _) => Ok(List(Arc::clone(field))), _ => exec_err!( "Not reachable, data_type should be List, LargeList or FixedSizeList" ), @@ -115,7 +115,7 @@ pub fn flatten_inner(args: &[ArrayRef]) -> Result { let flattened_array = flatten_internal::(list_arr.clone(), None)?; Ok(Arc::new(flattened_array) as ArrayRef) } - Null => Ok(args[0].clone()), + Null => Ok(Arc::clone(&args[0])), _ => { exec_err!("flatten does not support type '{array_type:?}'") } diff --git a/datafusion/functions-array/src/lib.rs b/datafusion/functions-array/src/lib.rs index 814127be806b1..9717d29883fd5 100644 --- a/datafusion/functions-array/src/lib.rs +++ b/datafusion/functions-array/src/lib.rs @@ -14,6 +14,8 @@ // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. +// Make cheap clones clear: https://github.com/apache/datafusion/issues/11143 +#![deny(clippy::clone_on_ref_ptr)] //! Array Functions for [DataFusion]. //! 
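The hunks in this patch are mechanical: every `.clone()` on a reference-counted pointer (including `ArrayRef`, which is `Arc<dyn Array>` in arrow-rs) becomes an explicit `Arc::clone(&...)`. As a minimal, self-contained sketch of what the `clone_on_ref_ptr` lint enforces — illustrative only, not part of the patch:

use std::sync::Arc;

fn main() {
    // `Arc<Vec<i64>>` stands in here for DataFusion's `ArrayRef = Arc<dyn Array>`.
    let values: Arc<Vec<i64>> = Arc::new(vec![1, 2, 3]);

    // Rejected under `#![deny(clippy::clone_on_ref_ptr)]`: `values.clone()`
    // compiles, but reads as if it might deep-copy the underlying data.
    // let shared = values.clone();

    // Preferred form: unmistakably a cheap reference-count increment.
    let shared = Arc::clone(&values);

    assert_eq!(Arc::strong_count(&values), 2);
    assert_eq!(shared[0], 1);
}

This is also why the hunks pass `&arr` rather than `arr`: `Arc::clone` takes a reference and only bumps the count, never copying the array data.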
diff --git a/datafusion/functions-array/src/resize.rs b/datafusion/functions-array/src/resize.rs index 078ec7766aac8..83c545a26eb24 100644 --- a/datafusion/functions-array/src/resize.rs +++ b/datafusion/functions-array/src/resize.rs @@ -67,8 +67,8 @@ impl ScalarUDFImpl for ArrayResize { fn return_type(&self, arg_types: &[DataType]) -> Result { match &arg_types[0] { - List(field) | FixedSizeList(field, _) => Ok(List(field.clone())), - LargeList(field) => Ok(LargeList(field.clone())), + List(field) | FixedSizeList(field, _) => Ok(List(Arc::clone(field))), + LargeList(field) => Ok(LargeList(Arc::clone(field))), _ => exec_err!( "Not reachable, data_type should be List, LargeList or FixedSizeList" ), @@ -92,7 +92,7 @@ pub(crate) fn array_resize_inner(arg: &[ArrayRef]) -> Result { let new_len = as_int64_array(&arg[1])?; let new_element = if arg.len() == 3 { - Some(arg[2].clone()) + Some(Arc::clone(&arg[2])) } else { None }; @@ -168,7 +168,7 @@ fn general_list_resize>( let data = mutable.freeze(); Ok(Arc::new(GenericListArray::::try_new( - field.clone(), + Arc::clone(field), OffsetBuffer::::new(offsets.into()), arrow_array::make_array(data), None, diff --git a/datafusion/functions-array/src/reverse.rs b/datafusion/functions-array/src/reverse.rs index b462be40209bc..581caf5daf2b8 100644 --- a/datafusion/functions-array/src/reverse.rs +++ b/datafusion/functions-array/src/reverse.rs @@ -93,7 +93,7 @@ pub fn array_reverse_inner(arg: &[ArrayRef]) -> Result { let array = as_large_list_array(&arg[0])?; general_array_reverse::(array, field) } - Null => Ok(arg[0].clone()), + Null => Ok(Arc::clone(&arg[0])), array_type => exec_err!("array_reverse does not support type '{array_type:?}'."), } } @@ -137,7 +137,7 @@ fn general_array_reverse>( let data = mutable.freeze(); Ok(Arc::new(GenericListArray::::try_new( - field.clone(), + Arc::clone(field), OffsetBuffer::::new(offsets.into()), arrow_array::make_array(data), Some(nulls.into()), diff --git a/datafusion/functions-array/src/set_ops.rs b/datafusion/functions-array/src/set_ops.rs index a843a175f3a08..1de9c264ddc2c 100644 --- a/datafusion/functions-array/src/set_ops.rs +++ b/datafusion/functions-array/src/set_ops.rs @@ -213,7 +213,7 @@ fn array_distinct_inner(args: &[ArrayRef]) -> Result { // handle null if args[0].data_type() == &Null { - return Ok(args[0].clone()); + return Ok(Arc::clone(&args[0])); } // handle for list & largelist @@ -314,7 +314,7 @@ fn generic_set_lists( offsets.push(last_offset + OffsetSize::usize_as(rows.len())); let arrays = converter.convert_rows(rows)?; let array = match arrays.first() { - Some(array) => array.clone(), + Some(array) => Arc::clone(array), None => { return internal_err!("{set_op}: failed to get array from rows"); } @@ -370,12 +370,12 @@ fn general_set_op( (List(field), List(_)) => { let array1 = as_list_array(&array1)?; let array2 = as_list_array(&array2)?; - generic_set_lists::(array1, array2, field.clone(), set_op) + generic_set_lists::(array1, array2, Arc::clone(field), set_op) } (LargeList(field), LargeList(_)) => { let array1 = as_large_list_array(&array1)?; let array2 = as_large_list_array(&array2)?; - generic_set_lists::(array1, array2, field.clone(), set_op) + generic_set_lists::(array1, array2, Arc::clone(field), set_op) } (data_type1, data_type2) => { internal_err!( @@ -426,7 +426,7 @@ fn general_array_distinct( offsets.push(last_offset + OffsetSize::usize_as(rows.len())); let arrays = converter.convert_rows(rows)?; let array = match arrays.first() { - Some(array) => array.clone(), + Some(array) => 
Arc::clone(array), None => { return internal_err!("array_distinct: failed to get array from rows") } @@ -437,7 +437,7 @@ fn general_array_distinct( let new_arrays_ref = new_arrays.iter().map(|v| v.as_ref()).collect::>(); let values = compute::concat(&new_arrays_ref)?; Ok(Arc::new(GenericListArray::::try_new( - field.clone(), + Arc::clone(field), offsets, values, None, diff --git a/datafusion/functions-array/src/sort.rs b/datafusion/functions-array/src/sort.rs index c82dbd37be04d..9c1ae507636c9 100644 --- a/datafusion/functions-array/src/sort.rs +++ b/datafusion/functions-array/src/sort.rs @@ -121,7 +121,7 @@ pub fn array_sort_inner(args: &[ArrayRef]) -> Result { let list_array = as_list_array(&args[0])?; let row_count = list_array.len(); if row_count == 0 { - return Ok(args[0].clone()); + return Ok(Arc::clone(&args[0])); } let mut array_lengths = vec![]; diff --git a/datafusion/functions-array/src/string.rs b/datafusion/functions-array/src/string.rs index d02c863db8b7e..2dc0a55e69519 100644 --- a/datafusion/functions-array/src/string.rs +++ b/datafusion/functions-array/src/string.rs @@ -381,7 +381,7 @@ pub(super) fn array_to_string_inner(args: &[ArrayRef]) -> Result { let delimiter = delimiters[0].unwrap(); let s = compute_array_to_string( &mut arg, - arr.clone(), + Arc::clone(arr), delimiter.to_string(), null_string, with_null_string, diff --git a/datafusion/functions-array/src/utils.rs b/datafusion/functions-array/src/utils.rs index 3ecccf3c87137..f396c3b22581c 100644 --- a/datafusion/functions-array/src/utils.rs +++ b/datafusion/functions-array/src/utils.rs @@ -105,7 +105,7 @@ pub(crate) fn align_array_dimensions( .zip(args_ndim.iter()) .map(|(array, ndim)| { if ndim < max_ndim { - let mut aligned_array = array.clone(); + let mut aligned_array = Arc::clone(&array); for _ in 0..(max_ndim - ndim) { let data_type = aligned_array.data_type().to_owned(); let array_lengths = vec![1; aligned_array.len()]; @@ -120,7 +120,7 @@ pub(crate) fn align_array_dimensions( } Ok(aligned_array) } else { - Ok(array.clone()) + Ok(Arc::clone(&array)) } }) .collect(); @@ -277,10 +277,12 @@ mod tests { Some(vec![Some(6), Some(7), Some(8)]), ])); - let array2d_1 = - Arc::new(array_into_list_array_nullable(array1d_1.clone())) as ArrayRef; - let array2d_2 = - Arc::new(array_into_list_array_nullable(array1d_2.clone())) as ArrayRef; + let array2d_1 = Arc::new(array_into_list_array_nullable( + Arc::clone(&array1d_1) as ArrayRef + )) as ArrayRef; + let array2d_2 = Arc::new(array_into_list_array_nullable( + Arc::clone(&array1d_2) as ArrayRef + )) as ArrayRef; let res = align_array_dimensions::(vec![ array1d_1.to_owned(), diff --git a/datafusion/functions/benches/concat.rs b/datafusion/functions/benches/concat.rs index e7b00a6d540ad..91c46ac775a8b 100644 --- a/datafusion/functions/benches/concat.rs +++ b/datafusion/functions/benches/concat.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. 
+use arrow::array::ArrayRef; use arrow::util::bench_util::create_string_array_with_len; use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; use datafusion_common::ScalarValue; @@ -26,7 +27,7 @@ fn create_args(size: usize, str_len: usize) -> Vec { let array = Arc::new(create_string_array_with_len::(size, 0.2, str_len)); let scalar = ScalarValue::Utf8(Some(", ".to_string())); vec![ - ColumnarValue::Array(array.clone()), + ColumnarValue::Array(Arc::clone(&array) as ArrayRef), ColumnarValue::Scalar(scalar), ColumnarValue::Array(array), ] diff --git a/datafusion/functions/benches/regx.rs b/datafusion/functions/benches/regx.rs index da4882381e76f..23d57f38efae2 100644 --- a/datafusion/functions/benches/regx.rs +++ b/datafusion/functions/benches/regx.rs @@ -83,8 +83,12 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { black_box( - regexp_like::(&[data.clone(), regex.clone(), flags.clone()]) - .expect("regexp_like should work on valid values"), + regexp_like::(&[ + Arc::clone(&data), + Arc::clone(®ex), + Arc::clone(&flags), + ]) + .expect("regexp_like should work on valid values"), ) }) }); @@ -97,8 +101,12 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { black_box( - regexp_match::(&[data.clone(), regex.clone(), flags.clone()]) - .expect("regexp_match should work on valid values"), + regexp_match::(&[ + Arc::clone(&data), + Arc::clone(®ex), + Arc::clone(&flags), + ]) + .expect("regexp_match should work on valid values"), ) }) }); @@ -115,10 +123,10 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { black_box( regexp_replace::(&[ - data.clone(), - regex.clone(), - replacement.clone(), - flags.clone(), + Arc::clone(&data), + Arc::clone(®ex), + Arc::clone(&replacement), + Arc::clone(&flags), ]) .expect("regexp_replace should work on valid values"), ) diff --git a/datafusion/functions/src/core/getfield.rs b/datafusion/functions/src/core/getfield.rs index b76da15c52ca1..2c2e36b91b13a 100644 --- a/datafusion/functions/src/core/getfield.rs +++ b/datafusion/functions/src/core/getfield.rs @@ -26,6 +26,7 @@ use datafusion_common::{ use datafusion_expr::{ColumnarValue, Expr, ExprSchemable}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use std::any::Any; +use std::sync::Arc; #[derive(Debug)] pub struct GetFieldFunc { @@ -151,7 +152,7 @@ impl ScalarUDFImpl for GetFieldFunc { } let arrays = ColumnarValue::values_to_arrays(args)?; - let array = arrays[0].clone(); + let array = Arc::clone(&arrays[0]); let name = match &args[1] { ColumnarValue::Scalar(name) => name, @@ -199,7 +200,7 @@ impl ScalarUDFImpl for GetFieldFunc { let as_struct_array = as_struct_array(&array)?; match as_struct_array.column_by_name(k) { None => exec_err!("get indexed field {k} not found in struct"), - Some(col) => Ok(ColumnarValue::Array(col.clone())), + Some(col) => Ok(ColumnarValue::Array(Arc::clone(col))), } } (DataType::Struct(_), name) => exec_err!( diff --git a/datafusion/functions/src/core/map.rs b/datafusion/functions/src/core/map.rs index 6626831c8034f..1834c7ac6060f 100644 --- a/datafusion/functions/src/core/map.rs +++ b/datafusion/functions/src/core/map.rs @@ -93,9 +93,9 @@ fn make_map_batch(args: &[ColumnarValue]) -> Result { fn get_first_array_ref(columnar_value: &ColumnarValue) -> Result { match columnar_value { ColumnarValue::Scalar(value) => match value { - ScalarValue::List(array) => Ok(array.value(0).clone()), - ScalarValue::LargeList(array) => Ok(array.value(0).clone()), - ScalarValue::FixedSizeList(array) => Ok(array.value(0).clone()), + 
ScalarValue::List(array) => Ok(array.value(0)), + ScalarValue::LargeList(array) => Ok(array.value(0)), + ScalarValue::FixedSizeList(array) => Ok(array.value(0)), _ => exec_err!("Expected array, got {:?}", value), }, ColumnarValue::Array(array) => exec_err!("Expected scalar, got {:?}", array), diff --git a/datafusion/functions/src/core/nvl.rs b/datafusion/functions/src/core/nvl.rs index 05515c6e925c8..a09224acefcdf 100644 --- a/datafusion/functions/src/core/nvl.rs +++ b/datafusion/functions/src/core/nvl.rs @@ -21,6 +21,7 @@ use arrow::compute::kernels::zip::zip; use arrow::datatypes::DataType; use datafusion_common::{internal_err, Result}; use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use std::sync::Arc; #[derive(Debug)] pub struct NVLFunc { @@ -101,13 +102,13 @@ fn nvl_func(args: &[ColumnarValue]) -> Result { } let (lhs_array, rhs_array) = match (&args[0], &args[1]) { (ColumnarValue::Array(lhs), ColumnarValue::Scalar(rhs)) => { - (lhs.clone(), rhs.to_array_of_size(lhs.len())?) + (Arc::clone(lhs), rhs.to_array_of_size(lhs.len())?) } (ColumnarValue::Array(lhs), ColumnarValue::Array(rhs)) => { - (lhs.clone(), rhs.clone()) + (Arc::clone(lhs), Arc::clone(rhs)) } (ColumnarValue::Scalar(lhs), ColumnarValue::Array(rhs)) => { - (lhs.to_array_of_size(rhs.len())?, rhs.clone()) + (lhs.to_array_of_size(rhs.len())?, Arc::clone(rhs)) } (ColumnarValue::Scalar(lhs), ColumnarValue::Scalar(rhs)) => { let mut current_value = lhs; diff --git a/datafusion/functions/src/core/nvl2.rs b/datafusion/functions/src/core/nvl2.rs index 573ac72425fb4..1144dc0fb7c56 100644 --- a/datafusion/functions/src/core/nvl2.rs +++ b/datafusion/functions/src/core/nvl2.rs @@ -24,6 +24,7 @@ use datafusion_expr::{ type_coercion::binary::comparison_coercion, ColumnarValue, ScalarUDFImpl, Signature, Volatility, }; +use std::sync::Arc; #[derive(Debug)] pub struct NVL2Func { @@ -112,7 +113,7 @@ fn nvl2_func(args: &[ColumnarValue]) -> Result { .iter() .map(|arg| match arg { ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(len), - ColumnarValue::Array(array) => Ok(array.clone()), + ColumnarValue::Array(array) => Ok(Arc::clone(array)), }) .collect::>>()?; let to_apply = is_not_null(&args[0])?; diff --git a/datafusion/functions/src/core/struct.rs b/datafusion/functions/src/core/struct.rs index 9d4b2e4a0b8b6..c3dee8b1ccb40 100644 --- a/datafusion/functions/src/core/struct.rs +++ b/datafusion/functions/src/core/struct.rs @@ -40,7 +40,7 @@ fn array_struct(args: &[ArrayRef]) -> Result { arg.data_type().clone(), true, )), - arg.clone(), + Arc::clone(arg), )) }) .collect::>>()?; @@ -121,30 +121,21 @@ mod tests { as_struct_array(&struc).expect("failed to initialize function struct"); assert_eq!( &Int64Array::from(vec![1]), - result - .column_by_name("c0") - .unwrap() - .clone() + Arc::clone(result.column_by_name("c0").unwrap()) .as_any() .downcast_ref::() .unwrap() ); assert_eq!( &Int64Array::from(vec![2]), - result - .column_by_name("c1") - .unwrap() - .clone() + Arc::clone(result.column_by_name("c1").unwrap()) .as_any() .downcast_ref::() .unwrap() ); assert_eq!( &Int64Array::from(vec![3]), - result - .column_by_name("c2") - .unwrap() - .clone() + Arc::clone(result.column_by_name("c2").unwrap()) .as_any() .downcast_ref::() .unwrap() diff --git a/datafusion/functions/src/datetime/date_part.rs b/datafusion/functions/src/datetime/date_part.rs index 4906cdc9601d3..e1efb4811ec0d 100644 --- a/datafusion/functions/src/datetime/date_part.rs +++ b/datafusion/functions/src/datetime/date_part.rs @@ -123,7 +123,7 @@ impl 
ScalarUDFImpl for DatePartFunc { let is_scalar = matches!(array, ColumnarValue::Scalar(_)); let array = match array { - ColumnarValue::Array(array) => array.clone(), + ColumnarValue::Array(array) => Arc::clone(array), ColumnarValue::Scalar(scalar) => scalar.to_array()?, }; diff --git a/datafusion/functions/src/datetime/to_timestamp.rs b/datafusion/functions/src/datetime/to_timestamp.rs index 4cb91447f3867..cbb6f37603d27 100644 --- a/datafusion/functions/src/datetime/to_timestamp.rs +++ b/datafusion/functions/src/datetime/to_timestamp.rs @@ -16,6 +16,7 @@ // under the License. use std::any::Any; +use std::sync::Arc; use arrow::datatypes::DataType::Timestamp; use arrow::datatypes::TimeUnit::{Microsecond, Millisecond, Nanosecond, Second}; @@ -387,7 +388,7 @@ impl ScalarUDFImpl for ToTimestampNanosFunc { /// the timezone if it exists. fn return_type_for(arg: &DataType, unit: TimeUnit) -> DataType { match arg { - Timestamp(_, Some(tz)) => Timestamp(unit, Some(tz.clone())), + Timestamp(_, Some(tz)) => Timestamp(unit, Some(Arc::clone(tz))), _ => Timestamp(unit, None), } } @@ -794,10 +795,10 @@ mod tests { Arc::new(sec_builder.finish().with_timezone("UTC")) as ArrayRef; let arrays = &[ - ColumnarValue::Array(nanos_timestamps.clone()), - ColumnarValue::Array(millis_timestamps.clone()), - ColumnarValue::Array(micros_timestamps.clone()), - ColumnarValue::Array(sec_timestamps.clone()), + ColumnarValue::Array(Arc::clone(&nanos_timestamps)), + ColumnarValue::Array(Arc::clone(&millis_timestamps)), + ColumnarValue::Array(Arc::clone(µs_timestamps)), + ColumnarValue::Array(Arc::clone(&sec_timestamps)), ]; for udf in &udfs { @@ -836,11 +837,11 @@ mod tests { let i64_timestamps = Arc::new(i64_builder.finish()) as ArrayRef; let arrays = &[ - ColumnarValue::Array(nanos_timestamps.clone()), - ColumnarValue::Array(millis_timestamps.clone()), - ColumnarValue::Array(micros_timestamps.clone()), - ColumnarValue::Array(sec_timestamps.clone()), - ColumnarValue::Array(i64_timestamps.clone()), + ColumnarValue::Array(Arc::clone(&nanos_timestamps)), + ColumnarValue::Array(Arc::clone(&millis_timestamps)), + ColumnarValue::Array(Arc::clone(µs_timestamps)), + ColumnarValue::Array(Arc::clone(&sec_timestamps)), + ColumnarValue::Array(Arc::clone(&i64_timestamps)), ]; for udf in &udfs { diff --git a/datafusion/functions/src/lib.rs b/datafusion/functions/src/lib.rs index 433a4f90d95b7..b1c55c843f71d 100644 --- a/datafusion/functions/src/lib.rs +++ b/datafusion/functions/src/lib.rs @@ -14,6 +14,8 @@ // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. +// Make cheap clones clear: https://github.com/apache/datafusion/issues/11143 +#![deny(clippy::clone_on_ref_ptr)] //! Function packages for [DataFusion]. //! 
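For readers skimming this patch: `clippy::clone_on_ref_ptr` (enabled in lib.rs above) is about readability rather than performance. A minimal, self-contained sketch using only standard-library types (not DataFusion code) of what the lint steers toward:

use std::sync::Arc;

fn main() {
    // Stand-in for a shared, immutable buffer such as an ArrayRef.
    let data: Arc<Vec<i32>> = Arc::new(vec![1, 2, 3]);

    // `data.clone()` performs the same cheap reference-count bump, but the
    // method syntax reads like a deep copy; the lint flags that form.
    let alias = Arc::clone(&data);

    // Both handles point at the same allocation.
    assert_eq!(Arc::strong_count(&data), 2);
    assert_eq!(*alias, vec![1, 2, 3]);
}

The mechanical `x.clone()` to `Arc::clone(&x)` rewrites throughout this patch therefore change no behavior, only how clearly the cost reads at each call site.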
diff --git a/datafusion/functions/src/math/abs.rs b/datafusion/functions/src/math/abs.rs index 6d07b14f866e3..f7a17f0caf947 100644 --- a/datafusion/functions/src/math/abs.rs +++ b/datafusion/functions/src/math/abs.rs @@ -91,7 +91,7 @@ fn create_abs_function(input_data_type: &DataType) -> Result | DataType::UInt8 | DataType::UInt16 | DataType::UInt32 - | DataType::UInt64 => Ok(|args: &Vec| Ok(args[0].clone())), + | DataType::UInt64 => Ok(|args: &Vec| Ok(Arc::clone(&args[0]))), // Decimal types DataType::Decimal128(_, _) => Ok(make_decimal_abs_function!(Decimal128Array)), diff --git a/datafusion/functions/src/math/log.rs b/datafusion/functions/src/math/log.rs index 0791561539e1e..ea424c14749e8 100644 --- a/datafusion/functions/src/math/log.rs +++ b/datafusion/functions/src/math/log.rs @@ -109,7 +109,7 @@ impl ScalarUDFImpl for LogFunc { let mut x = &args[0]; if args.len() == 2 { x = &args[1]; - base = ColumnarValue::Array(args[0].clone()); + base = ColumnarValue::Array(Arc::clone(&args[0])); } // note in f64::log params order is different than in sql. e.g in sql log(base, x) == f64::log(x, base) let arr: ArrayRef = match args[0].data_type() { diff --git a/datafusion/functions/src/math/round.rs b/datafusion/functions/src/math/round.rs index 71ab7c1b43502..89554a76febba 100644 --- a/datafusion/functions/src/math/round.rs +++ b/datafusion/functions/src/math/round.rs @@ -111,7 +111,7 @@ pub fn round(args: &[ArrayRef]) -> Result { let mut decimal_places = ColumnarValue::Scalar(ScalarValue::Int64(Some(0))); if args.len() == 2 { - decimal_places = ColumnarValue::Array(args[1].clone()); + decimal_places = ColumnarValue::Array(Arc::clone(&args[1])); } match args[0].data_type() { diff --git a/datafusion/functions/src/math/trunc.rs b/datafusion/functions/src/math/trunc.rs index f980e583365f7..3344438454c4b 100644 --- a/datafusion/functions/src/math/trunc.rs +++ b/datafusion/functions/src/math/trunc.rs @@ -117,7 +117,7 @@ fn trunc(args: &[ArrayRef]) -> Result { let precision = if args.len() == 1 { ColumnarValue::Scalar(Int64(Some(0))) } else { - ColumnarValue::Array(args[1].clone()) + ColumnarValue::Array(Arc::clone(&args[1])) }; match args[0].data_type() { From 2837e02b7ec7dfbca576451e63db25b84ed2c97d Mon Sep 17 00:00:00 2001 From: Eduard Karacharov Date: Tue, 16 Jul 2024 13:23:48 +0300 Subject: [PATCH 43/59] minor: split repartition time and send time metrics (#11440) --- datafusion/physical-plan/src/repartition/mod.rs | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/datafusion/physical-plan/src/repartition/mod.rs b/datafusion/physical-plan/src/repartition/mod.rs index 3d4d3058393e6..e5c506403ff66 100644 --- a/datafusion/physical-plan/src/repartition/mod.rs +++ b/datafusion/physical-plan/src/repartition/mod.rs @@ -261,6 +261,7 @@ impl BatchPartitioner { num_partitions: partitions, hash_buffer, } => { + // Tracking time required for distributing indexes across output partitions let timer = self.timer.timer(); let arrays = exprs @@ -282,6 +283,11 @@ impl BatchPartitioner { .append_value(index as u64); } + // Finished building index-arrays for output partitions + timer.done(); + + // Borrowing partitioner timer to prevent moving `self` to closure + let partitioner_timer = &self.timer; let it = indices .into_iter() .enumerate() @@ -290,6 +296,9 @@ impl BatchPartitioner { (!indices.is_empty()).then_some((partition, indices)) }) .map(move |(partition, indices)| { + // Tracking time required for repartitioned batches construction + let _timer = partitioner_timer.timer(); + // 
Produce batches based on indices
                        let columns = batch
                            .columns()
@@ -303,9 +312,6 @@ impl BatchPartitioner {
                     let batch =
                         RecordBatch::try_new(batch.schema(), columns).unwrap();

-                    // bind timer so it drops w/ this iterator
-                    let _ = &timer;
-
                     Ok((partition, batch))
                 });

From 133128840ca3dbea200dcfe84050cb7b82bf94a8 Mon Sep 17 00:00:00 2001
From: Andrew Lamb
Date: Tue, 16 Jul 2024 07:19:25 -0400
Subject: [PATCH 44/59] Docs: Document creating new extension APIs (#11425)

* Docs: Document creating new extension APIs

* fix

* Add clarification about extension APIs. Thanks @ozankabak

* Apply suggestions from code review

Co-authored-by: Mehmet Ozan Kabak

* Add a paragraph on datafusion-contrib

* prettier

---------

Co-authored-by: Mehmet Ozan Kabak
---
 datafusion/core/src/lib.rs | 2 +-
 docs/source/contributor-guide/architecture.md | 74 +++++++++++++++++++
 2 files changed, 75 insertions(+), 1 deletion(-)

diff --git a/datafusion/core/src/lib.rs b/datafusion/core/src/lib.rs
index 63dbe824c2314..81c1c4629a3ad 100644
--- a/datafusion/core/src/lib.rs
+++ b/datafusion/core/src/lib.rs
@@ -174,7 +174,7 @@
 //!
 //! DataFusion is designed to be highly extensible, so you can
 //! start with a working, full featured engine, and then
-//! specialize any behavior for their usecase. For example,
+//! specialize any behavior for your use case. For example,
 //! some projects may add custom [`ExecutionPlan`] operators, or create their own
 //! query language that directly creates [`LogicalPlan`] rather than using the
 //! built in SQL planner, [`SqlToRel`].
diff --git a/docs/source/contributor-guide/architecture.md b/docs/source/contributor-guide/architecture.md
index 68541f8777689..55c8a1d980df5 100644
--- a/docs/source/contributor-guide/architecture.md
+++ b/docs/source/contributor-guide/architecture.md
@@ -25,3 +25,77 @@ possible. You can find the most up to date version in the [source code].

 [crates.io documentation]: https://docs.rs/datafusion/latest/datafusion/index.html#architecture
 [source code]: https://github.com/apache/datafusion/blob/main/datafusion/core/src/lib.rs
+
+## Forks vs Extension APIs
+
+DataFusion is a fast-moving project, which results in frequent internal changes.
+This benefits DataFusion by allowing it to evolve and respond quickly to
+requests, but also means that maintaining a fork with major modifications
+sometimes requires non-trivial work.
+
+The public API (what is accessible if you use the DataFusion releases from
+crates.io) is typically much more stable (though it does change from release to
+release as well).
+
+Thus, rather than forks, we recommend using one of the many extension APIs (such
+as `TableProvider`, `OptimizerRule`, or `ExecutionPlan`) to customize
+DataFusion. If you cannot do what you want with the existing APIs, we would
+welcome you working with us to add new APIs to enable your use case, as
+described in the next section.
+
+## `datafusion-contrib`
+
+While DataFusion comes with enough features "out of the box" to quickly start
+with a working system, it can't include every useful feature (e.g.
+`TableProvider`s for all data formats). The [`datafusion-contrib`] project
+contains a collection of community-maintained extensions that are not part of
+the core DataFusion project, and not under Apache Software Foundation governance
+but may be useful to others in the community. If you are interested in adding a
+feature to DataFusion, a new extension in `datafusion-contrib` is likely a good
Please [contact] us via github issue, slack, or Discord and +we'll gladly set up a new repository for your extension. + +[`datafusion-contrib`]: https://github.com/datafusion-contrib +[contact]: ../contributor-guide/communication.md + +## Creating new Extension APIs + +DataFusion aims to be a general-purpose query engine, and thus the core crates +contain features that are useful for a wide range of use cases. Use case specific +functionality (such as very specific time series or stream processing features) +are typically implemented using the extension APIs. + +If have a use case that is not covered by the existing APIs, we would love to +work with you to design a new general purpose API. There are often others who are +interested in similar extensions and the act of defining the API often improves +the code overall for everyone. + +Extension APIs that provide "safe" default behaviors are more likely to be +suitable for inclusion in DataFusion, while APIs that require major changes to +built-in operators are less likely. For example, it might make less sense +to add an API to support a stream processing feature if that would result in +slower performance for built-in operators. It may still make sense to add +extension APIs for such features, but leave implementation of such operators in +downstream projects. + +The process to create a new extension API is typically: + +- Look for an existing issue describing what you want to do, and file one if it + doesn't yet exist. +- Discuss what the API would look like. Feel free to ask contributors (via `@` + mentions) for feedback (you can find such people by looking at the most + recently changed PRs and issues) +- Prototype the new API, typically by adding an example (in + `datafusion-examples` or refactoring existing code) to show how it would work +- Create a PR with the new API, and work with the community to get it merged + +Some benefits of using an example based approach are + +- Any future API changes will also keep your example going ensuring no + regression in functionality +- There will be a blue print of any needed changes to your code if the APIs do change + (just look at what changed in your example) + +An example of this process was [creating a SQL Extension Planning API]. 
+
+[creating a sql extension planning api]: https://github.com/apache/datafusion/issues/11207

From 5f0993cf58a1c004c88120eea974554666332213 Mon Sep 17 00:00:00 2001
From: Andrew Lamb
Date: Tue, 16 Jul 2024 10:15:36 -0400
Subject: [PATCH 45/59] Minor: rename `row_groups.rs` to `row_group_filter.rs` (#11481)

---
 datafusion/core/src/datasource/physical_plan/parquet/mod.rs | 2 +-
 datafusion/core/src/datasource/physical_plan/parquet/opener.rs | 2 +-
 .../parquet/{row_groups.rs => row_group_filter.rs} | 0
 3 files changed, 2 insertions(+), 2 deletions(-)
 rename datafusion/core/src/datasource/physical_plan/parquet/{row_groups.rs => row_group_filter.rs} (100%)

diff --git a/datafusion/core/src/datasource/physical_plan/parquet/mod.rs b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs
index 9d5c64719e759..ed0fc5f0169ee 100644
--- a/datafusion/core/src/datasource/physical_plan/parquet/mod.rs
+++ b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs
@@ -53,7 +53,7 @@ mod opener;
 mod page_filter;
 mod reader;
 mod row_filter;
-mod row_groups;
+mod row_group_filter;
 mod statistics;
 mod writer;

diff --git a/datafusion/core/src/datasource/physical_plan/parquet/opener.rs b/datafusion/core/src/datasource/physical_plan/parquet/opener.rs
index 36335863032c1..c97b0282626a7 100644
--- a/datafusion/core/src/datasource/physical_plan/parquet/opener.rs
+++ b/datafusion/core/src/datasource/physical_plan/parquet/opener.rs
@@ -18,7 +18,7 @@
 //! [`ParquetOpener`] for opening Parquet files

 use crate::datasource::physical_plan::parquet::page_filter::PagePruningPredicate;
-use crate::datasource::physical_plan::parquet::row_groups::RowGroupAccessPlanFilter;
+use crate::datasource::physical_plan::parquet::row_group_filter::RowGroupAccessPlanFilter;
 use crate::datasource::physical_plan::parquet::{
     row_filter, should_enable_page_index, ParquetAccessPlan,
 };
diff --git a/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs b/datafusion/core/src/datasource/physical_plan/parquet/row_group_filter.rs
similarity index 100%
rename from datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs
rename to datafusion/core/src/datasource/physical_plan/parquet/row_group_filter.rs

From 55b792a608be881439fd20dafdd803478346186a Mon Sep 17 00:00:00 2001
From: yfu
Date: Wed, 17 Jul 2024 00:17:32 +1000
Subject: [PATCH 46/59] Support alternate formats for unparsing `datetime` to
 `timestamp` and `interval` (#11466)

* Unparser rule for datetime cast (#10)

* use timestamp as the identifier for date64

* rename

* implement CustomDialectBuilder

* fix

* dialect with interval style (#11)

---------

Co-authored-by: Phillip LeBlanc

* fmt

* clippy

* doc

* Update datafusion/sql/src/unparser/expr.rs

Co-authored-by: Andrew Lamb

* update the doc for CustomDialectBuilder

* fix doc test

---------

Co-authored-by: Phillip LeBlanc
Co-authored-by: Andrew Lamb
---
 datafusion-examples/examples/plan_to_sql.rs | 6 +-
 datafusion/sql/src/unparser/dialect.rs | 140 ++++++++
 datafusion/sql/src/unparser/expr.rs | 339 ++++++++++++++++----
 3 files changed, 420 insertions(+), 65 deletions(-)

diff --git a/datafusion-examples/examples/plan_to_sql.rs b/datafusion-examples/examples/plan_to_sql.rs
index f719a33fb6249..8ea7c2951223d 100644
--- a/datafusion-examples/examples/plan_to_sql.rs
+++ b/datafusion-examples/examples/plan_to_sql.rs
@@ -19,7 +19,7 @@
 use datafusion::error::Result;

 use datafusion::prelude::*;
 use datafusion::sql::unparser::expr_to_sql;
-use datafusion_sql::unparser::dialect::CustomDialect;
+use datafusion_sql::unparser::dialect::CustomDialectBuilder;
 use datafusion_sql::unparser::{plan_to_sql, Unparser};

 /// This example demonstrates the programmatic construction of SQL strings using
@@ -80,7 +80,9 @@ fn simple_expr_to_pretty_sql_demo() -> Result<()> {
 /// using a custom dialect and an explicit unparser
 fn simple_expr_to_sql_demo_escape_mysql_style() -> Result<()> {
     let expr = col("a").lt(lit(5)).or(col("a").eq(lit(8)));
-    let dialect = CustomDialect::new(Some('`'));
+    let dialect = CustomDialectBuilder::new()
+        .with_identifier_quote_style('`')
+        .build();
     let unparser = Unparser::new(&dialect);
     let sql = unparser.expr_to_sql(&expr)?.to_string();
     assert_eq!(sql, r#"((`a` < 5) OR (`a` = 8))"#);
diff --git a/datafusion/sql/src/unparser/dialect.rs b/datafusion/sql/src/unparser/dialect.rs
index e8cbde0585666..eca2eb4fd0ec7 100644
--- a/datafusion/sql/src/unparser/dialect.rs
+++ b/datafusion/sql/src/unparser/dialect.rs
@@ -35,7 +35,33 @@ pub trait Dialect {
     fn supports_nulls_first_in_sort(&self) -> bool {
         true
     }
+
+    // Does the dialect use TIMESTAMP to represent Date64 rather than DATETIME?
+    // E.g. Trino, Athena and Dremio do not have a DATETIME data type
+    fn use_timestamp_for_date64(&self) -> bool {
+        false
+    }
+
+    fn interval_style(&self) -> IntervalStyle {
+        IntervalStyle::PostgresVerbose
+    }
 }
+
+/// `IntervalStyle` to use for unparsing
+///
+///
+/// Different DBMSs follow different standards; popular ones are:
+/// postgres_verbose: '2 years 15 months 100 weeks 99 hours 123456789 milliseconds', which is
+/// compatible with the arrow display format, as well as with DuckDB
+/// the SQL standard format is '1-2' for year-month, or '1 10:10:10.123456' for day-time
+///
+#[derive(Clone, Copy)]
+pub enum IntervalStyle {
+    PostgresVerbose,
+    SQLStandard,
+    MySQL,
+}
+
 pub struct DefaultDialect {}

 impl Dialect for DefaultDialect {
@@ -57,6 +83,10 @@ impl Dialect for PostgreSqlDialect {
     fn identifier_quote_style(&self, _: &str) -> Option {
         Some('"')
     }
+
+    fn interval_style(&self) -> IntervalStyle {
+        IntervalStyle::PostgresVerbose
+    }
 }

 pub struct MySqlDialect {}
@@ -69,6 +99,10 @@ impl Dialect for MySqlDialect {
     fn supports_nulls_first_in_sort(&self) -> bool {
         false
     }
+
+    fn interval_style(&self) -> IntervalStyle {
+        IntervalStyle::MySQL
+    }
 }

 pub struct SqliteDialect {}
@@ -81,12 +115,29 @@ impl Dialect for SqliteDialect {

 pub struct CustomDialect {
     identifier_quote_style: Option,
+    supports_nulls_first_in_sort: bool,
+    use_timestamp_for_date64: bool,
+    interval_style: IntervalStyle,
+}
+
+impl Default for CustomDialect {
+    fn default() -> Self {
+        Self {
+            identifier_quote_style: None,
+            supports_nulls_first_in_sort: true,
+            use_timestamp_for_date64: false,
+            interval_style: IntervalStyle::SQLStandard,
+        }
+    }
 }

 impl CustomDialect {
+    // Create a `CustomDialect` directly
+    #[deprecated(note = "please use `CustomDialectBuilder` instead")]
     pub fn new(identifier_quote_style: Option) -> Self {
         Self {
             identifier_quote_style,
+            ..Default::default()
         }
     }
 }
@@ -95,4 +146,93 @@ impl Dialect for CustomDialect {
     fn identifier_quote_style(&self, _: &str) -> Option {
         self.identifier_quote_style
     }
+
+    fn supports_nulls_first_in_sort(&self) -> bool {
+        self.supports_nulls_first_in_sort
+    }
+
+    fn use_timestamp_for_date64(&self) -> bool {
+        self.use_timestamp_for_date64
+    }
+
+    fn interval_style(&self) -> IntervalStyle {
+        self.interval_style
+    }
+}
+
+/// `CustomDialectBuilder` to build `CustomDialect` using the builder pattern
+///
+///
+/// # Examples
+///
+/// Building a custom dialect with all default options set in
+/// CustomDialectBuilder::new() but with `use_timestamp_for_date64` overridden to `true`
+///
+/// ```
+/// use datafusion_sql::unparser::dialect::CustomDialectBuilder;
+/// let dialect = CustomDialectBuilder::new()
+///     .with_use_timestamp_for_date64(true)
+///     .build();
+/// ```
+pub struct CustomDialectBuilder {
+    identifier_quote_style: Option,
+    supports_nulls_first_in_sort: bool,
+    use_timestamp_for_date64: bool,
+    interval_style: IntervalStyle,
+}
+
+impl Default for CustomDialectBuilder {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl CustomDialectBuilder {
+    pub fn new() -> Self {
+        Self {
+            identifier_quote_style: None,
+            supports_nulls_first_in_sort: true,
+            use_timestamp_for_date64: false,
+            interval_style: IntervalStyle::PostgresVerbose,
+        }
+    }
+
+    pub fn build(self) -> CustomDialect {
+        CustomDialect {
+            identifier_quote_style: self.identifier_quote_style,
+            supports_nulls_first_in_sort: self.supports_nulls_first_in_sort,
+            use_timestamp_for_date64: self.use_timestamp_for_date64,
+            interval_style: self.interval_style,
+        }
+    }
+
+    /// Customize the dialect with a specific identifier quote style, e.g. '`', '"'
+    pub fn with_identifier_quote_style(mut self, identifier_quote_style: char) -> Self {
+        self.identifier_quote_style = Some(identifier_quote_style);
+        self
+    }
+
+    /// Customize the dialect to support `NULLS FIRST` in `ORDER BY` clauses
+    pub fn with_supports_nulls_first_in_sort(
+        mut self,
+        supports_nulls_first_in_sort: bool,
+    ) -> Self {
+        self.supports_nulls_first_in_sort = supports_nulls_first_in_sort;
+        self
+    }
+
+    /// Customize the dialect to use TIMESTAMP when casting Date64 rather than DATETIME
+    pub fn with_use_timestamp_for_date64(
+        mut self,
+        use_timestamp_for_date64: bool,
+    ) -> Self {
+        self.use_timestamp_for_date64 = use_timestamp_for_date64;
+        self
+    }
+
+    /// Customize the dialect with a specific interval style listed in `IntervalStyle`
+    pub fn with_interval_style(mut self, interval_style: IntervalStyle) -> Self {
+        self.interval_style = interval_style;
+        self
+    }
+}
diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs
index eb149c819c8b0..6b7775ee3d4db 100644
--- a/datafusion/sql/src/unparser/expr.rs
+++ b/datafusion/sql/src/unparser/expr.rs
@@ -43,6 +43,7 @@ use datafusion_expr::{
     Between, BinaryExpr, Case, Cast, Expr, GroupingSet, Like, Operator, TryCast,
 };

+use super::dialect::IntervalStyle;
 use super::Unparser;

 /// DataFusion's Exprs can represent either an `Expr` or an `OrderByExpr`
@@ -541,6 +542,14 @@ impl Unparser<'_> {
         }
     }

+    fn ast_type_for_date64_in_cast(&self) -> ast::DataType {
+        if self.dialect.use_timestamp_for_date64() {
+            ast::DataType::Timestamp(None, ast::TimezoneInfo::None)
+        } else {
+            ast::DataType::Datetime(None)
+        }
+    }
+
     fn col_to_sql(&self, col: &Column) -> Result {
         if let Some(table_ref) = &col.relation {
             let mut id = table_ref.to_vec();
@@ -1003,7 +1012,7 @@ impl Unparser<'_> {
                     expr: Box::new(ast::Expr::Value(ast::Value::SingleQuotedString(
                         datetime.to_string(),
                     ))),
-                    data_type: ast::DataType::Datetime(None),
+                    data_type: self.ast_type_for_date64_in_cast(),
                     format: None,
                 })
             }
@@ -1055,22 +1064,7 @@ impl Unparser<'_> {
             ScalarValue::IntervalYearMonth(Some(_))
             | ScalarValue::IntervalDayTime(Some(_))
             | ScalarValue::IntervalMonthDayNano(Some(_)) => {
-                let wrap_array = v.to_array()?;
-                let Some(result) = array_value_to_string(&wrap_array, 0).ok() else {
-                    return internal_err!(
-                        "Unable to convert interval scalar value to string"
-                    );
-                };
-                let interval = Interval {
-                    value: Box::new(ast::Expr::Value(SingleQuotedString(
-                        result.to_uppercase(),
-                    ))),
-                    leading_field: None,
-                    leading_precision: None,
-                    last_field: None,
-                    fractional_seconds_precision: None,
-                };
-                Ok(ast::Expr::Interval(interval))
+                self.interval_scalar_to_sql(v)
             }
             ScalarValue::IntervalYearMonth(None) => {
                 Ok(ast::Expr::Value(ast::Value::Null))
@@ -1108,6 +1102,123 @@ impl Unparser<'_> {
         }
     }

+    fn interval_scalar_to_sql(&self, v: &ScalarValue) -> Result {
+        match self.dialect.interval_style() {
+            IntervalStyle::PostgresVerbose => {
+                let wrap_array = v.to_array()?;
+                let Some(result) = array_value_to_string(&wrap_array, 0).ok() else {
+                    return internal_err!(
+                        "Unable to convert interval scalar value to string"
+                    );
+                };
+                let interval = Interval {
+                    value: Box::new(ast::Expr::Value(SingleQuotedString(
+                        result.to_uppercase(),
+                    ))),
+                    leading_field: None,
+                    leading_precision: None,
+                    last_field: None,
+                    fractional_seconds_precision: None,
+                };
+                Ok(ast::Expr::Interval(interval))
+            }
+            // If the interval style is SQLStandard, implement a simple unparse logic
+            IntervalStyle::SQLStandard => match v {
+                ScalarValue::IntervalYearMonth(v) => {
+                    let Some(v) = v else {
+                        return Ok(ast::Expr::Value(ast::Value::Null));
+                    };
+                    let interval = Interval {
+                        value: Box::new(ast::Expr::Value(
+                            ast::Value::SingleQuotedString(v.to_string()),
+                        )),
+                        leading_field: Some(ast::DateTimeField::Month),
+                        leading_precision: None,
+                        last_field: None,
+                        fractional_seconds_precision: None,
+                    };
+                    Ok(ast::Expr::Interval(interval))
+                }
+                ScalarValue::IntervalDayTime(v) => {
+                    let Some(v) = v else {
+                        return Ok(ast::Expr::Value(ast::Value::Null));
+                    };
+                    let days = v.days;
+                    let secs = v.milliseconds / 1_000;
+                    let mins = secs / 60;
+                    let hours = mins / 60;
+
+                    let secs = secs - (mins * 60);
+                    let mins = mins - (hours * 60);
+
+                    let millis = v.milliseconds % 1_000;
+                    let interval = Interval {
+                        value: Box::new(ast::Expr::Value(
+                            ast::Value::SingleQuotedString(format!(
+                                "{days} {hours}:{mins}:{secs}.{millis:03}"
+                            )),
+                        )),
+                        leading_field: Some(ast::DateTimeField::Day),
+                        leading_precision: None,
+                        last_field: Some(ast::DateTimeField::Second),
+                        fractional_seconds_precision: None,
+                    };
+                    Ok(ast::Expr::Interval(interval))
+                }
+                ScalarValue::IntervalMonthDayNano(v) => {
+                    let Some(v) = v else {
+                        return Ok(ast::Expr::Value(ast::Value::Null));
+                    };
+
+                    if v.months >= 0 && v.days == 0 && v.nanoseconds == 0 {
+                        let interval = Interval {
+                            value: Box::new(ast::Expr::Value(
+                                ast::Value::SingleQuotedString(v.months.to_string()),
+                            )),
+                            leading_field: Some(ast::DateTimeField::Month),
+                            leading_precision: None,
+                            last_field: None,
+                            fractional_seconds_precision: None,
+                        };
+                        Ok(ast::Expr::Interval(interval))
+                    } else if v.months == 0
+                        && v.days >= 0
+                        && v.nanoseconds % 1_000_000 == 0
+                    {
+                        let days = v.days;
+                        let secs = v.nanoseconds / 1_000_000_000;
+                        let mins = secs / 60;
+                        let hours = mins / 60;
+
+                        let secs = secs - (mins * 60);
+                        let mins = mins - (hours * 60);
+
+                        let millis = (v.nanoseconds % 1_000_000_000) / 1_000_000;
+
+                        let interval = Interval {
+                            value: Box::new(ast::Expr::Value(
+                                ast::Value::SingleQuotedString(format!(
+                                    "{days} {hours}:{mins}:{secs}.{millis:03}"
+                                )),
+                            )),
+                            leading_field: Some(ast::DateTimeField::Day),
+                            leading_precision: None,
+                            last_field: Some(ast::DateTimeField::Second),
+                            fractional_seconds_precision: None,
+                        };
+                        Ok(ast::Expr::Interval(interval))
+                    } else {
+                        not_impl_err!("Unsupported IntervalMonthDayNano scalar with both Month and DayTime for IntervalStyle::SQLStandard")
+                    }
+                }
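+                // Worked example for the day-time arm above (hypothetical
+                // input, not one of this patch's tests): IntervalDayTime
+                // { days: 1, milliseconds: 45_296_789 } gives secs = 45_296,
+                // mins = 754, hours = 12, then secs = 56, mins = 34,
+                // millis = 789, so the literal unparses as
+                // INTERVAL '1 12:34:56.789' DAY TO SECOND.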
+ _ => Ok(ast::Expr::Value(ast::Value::Null)), + }, + IntervalStyle::MySQL => { + not_impl_err!("Unsupported interval scalar for IntervalStyle::MySQL") + } + } + } + fn arrow_dtype_to_ast_dtype(&self, data_type: &DataType) -> Result { match data_type { DataType::Null => { @@ -1136,7 +1247,7 @@ impl Unparser<'_> { Ok(ast::DataType::Timestamp(None, tz_info)) } DataType::Date32 => Ok(ast::DataType::Date), - DataType::Date64 => Ok(ast::DataType::Datetime(None)), + DataType::Date64 => Ok(self.ast_type_for_date64_in_cast()), DataType::Time32(_) => { not_impl_err!("Unsupported DataType: conversion: {data_type:?}") } @@ -1232,7 +1343,7 @@ mod tests { use datafusion_functions_aggregate::count::count_udaf; use datafusion_functions_aggregate::expr_fn::sum; - use crate::unparser::dialect::CustomDialect; + use crate::unparser::dialect::{CustomDialect, CustomDialectBuilder}; use super::*; @@ -1595,46 +1706,7 @@ mod tests { ), (col("need-quoted").eq(lit(1)), r#"("need-quoted" = 1)"#), (col("need quoted").eq(lit(1)), r#"("need quoted" = 1)"#), - ( - interval_month_day_nano_lit( - "1 YEAR 1 MONTH 1 DAY 3 HOUR 10 MINUTE 20 SECOND", - ), - r#"INTERVAL '0 YEARS 13 MONS 1 DAYS 3 HOURS 10 MINS 20.000000000 SECS'"#, - ), - ( - interval_month_day_nano_lit("1.5 MONTH"), - r#"INTERVAL '0 YEARS 1 MONS 15 DAYS 0 HOURS 0 MINS 0.000000000 SECS'"#, - ), - ( - interval_month_day_nano_lit("-3 MONTH"), - r#"INTERVAL '0 YEARS -3 MONS 0 DAYS 0 HOURS 0 MINS 0.000000000 SECS'"#, - ), - ( - interval_month_day_nano_lit("1 MONTH") - .add(interval_month_day_nano_lit("1 DAY")), - r#"(INTERVAL '0 YEARS 1 MONS 0 DAYS 0 HOURS 0 MINS 0.000000000 SECS' + INTERVAL '0 YEARS 0 MONS 1 DAYS 0 HOURS 0 MINS 0.000000000 SECS')"#, - ), - ( - interval_month_day_nano_lit("1 MONTH") - .sub(interval_month_day_nano_lit("1 DAY")), - r#"(INTERVAL '0 YEARS 1 MONS 0 DAYS 0 HOURS 0 MINS 0.000000000 SECS' - INTERVAL '0 YEARS 0 MONS 1 DAYS 0 HOURS 0 MINS 0.000000000 SECS')"#, - ), - ( - interval_datetime_lit("10 DAY 1 HOUR 10 MINUTE 20 SECOND"), - r#"INTERVAL '0 YEARS 0 MONS 10 DAYS 1 HOURS 10 MINS 20.000 SECS'"#, - ), - ( - interval_datetime_lit("10 DAY 1.5 HOUR 10 MINUTE 20 SECOND"), - r#"INTERVAL '0 YEARS 0 MONS 10 DAYS 1 HOURS 40 MINS 20.000 SECS'"#, - ), - ( - interval_year_month_lit("1 YEAR 1 MONTH"), - r#"INTERVAL '1 YEARS 1 MONS 0 DAYS 0 HOURS 0 MINS 0.00 SECS'"#, - ), - ( - interval_year_month_lit("1.5 YEAR 1 MONTH"), - r#"INTERVAL '1 YEARS 7 MONS 0 DAYS 0 HOURS 0 MINS 0.00 SECS'"#, - ), + // See test_interval_scalar_to_expr for interval literals ( (col("a") + col("b")).gt(Expr::Literal(ScalarValue::Decimal128( Some(100123), @@ -1690,8 +1762,10 @@ mod tests { } #[test] - fn custom_dialect() -> Result<()> { - let dialect = CustomDialect::new(Some('\'')); + fn custom_dialect_with_identifier_quote_style() -> Result<()> { + let dialect = CustomDialectBuilder::new() + .with_identifier_quote_style('\'') + .build(); let unparser = Unparser::new(&dialect); let expr = col("a").gt(lit(4)); @@ -1706,8 +1780,8 @@ mod tests { } #[test] - fn custom_dialect_none() -> Result<()> { - let dialect = CustomDialect::new(None); + fn custom_dialect_without_identifier_quote_style() -> Result<()> { + let dialect = CustomDialect::default(); let unparser = Unparser::new(&dialect); let expr = col("a").gt(lit(4)); @@ -1720,4 +1794,143 @@ mod tests { Ok(()) } + + #[test] + fn custom_dialect_use_timestamp_for_date64() -> Result<()> { + for (use_timestamp_for_date64, identifier) in + [(false, "DATETIME"), (true, "TIMESTAMP")] + { + let dialect = CustomDialectBuilder::new() + 
.with_use_timestamp_for_date64(use_timestamp_for_date64)
+                .build();
+            let unparser = Unparser::new(&dialect);
+
+            let expr = Expr::Cast(Cast {
+                expr: Box::new(col("a")),
+                data_type: DataType::Date64,
+            });
+            let ast = unparser.expr_to_sql(&expr)?;
+
+            let actual = format!("{}", ast);
+
+            let expected = format!(r#"CAST(a AS {identifier})"#);
+            assert_eq!(actual, expected);
+        }
+        Ok(())
+    }
+
+    #[test]
+    fn custom_dialect_supports_nulls_first_in_sort() -> Result<()> {
+        let tests: Vec<(Expr, &str, bool)> = vec![
+            (col("a").sort(true, true), r#"a ASC NULLS FIRST"#, true),
+            (col("a").sort(true, true), r#"a ASC"#, false),
+        ];
+
+        for (expr, expected, supports_nulls_first_in_sort) in tests {
+            let dialect = CustomDialectBuilder::new()
+                .with_supports_nulls_first_in_sort(supports_nulls_first_in_sort)
+                .build();
+            let unparser = Unparser::new(&dialect);
+            let ast = unparser.expr_to_unparsed(&expr)?;
+
+            let actual = format!("{}", ast);
+
+            assert_eq!(actual, expected);
+        }
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_interval_scalar_to_expr() {
+        let tests = [
+            (
+                interval_month_day_nano_lit("1 MONTH"),
+                IntervalStyle::SQLStandard,
+                "INTERVAL '1' MONTH",
+            ),
+            (
+                interval_month_day_nano_lit("1.5 DAY"),
+                IntervalStyle::SQLStandard,
+                "INTERVAL '1 12:0:0.000' DAY TO SECOND",
+            ),
+            (
+                interval_month_day_nano_lit("1.51234 DAY"),
+                IntervalStyle::SQLStandard,
+                "INTERVAL '1 12:17:46.176' DAY TO SECOND",
+            ),
+            (
+                interval_datetime_lit("1.51234 DAY"),
+                IntervalStyle::SQLStandard,
+                "INTERVAL '1 12:17:46.176' DAY TO SECOND",
+            ),
+            (
+                interval_year_month_lit("1 YEAR"),
+                IntervalStyle::SQLStandard,
+                "INTERVAL '12' MONTH",
+            ),
+            (
+                interval_month_day_nano_lit(
+                    "1 YEAR 1 MONTH 1 DAY 3 HOUR 10 MINUTE 20 SECOND",
+                ),
+                IntervalStyle::PostgresVerbose,
+                r#"INTERVAL '0 YEARS 13 MONS 1 DAYS 3 HOURS 10 MINS 20.000000000 SECS'"#,
+            ),
+            (
+                interval_month_day_nano_lit("1.5 MONTH"),
+                IntervalStyle::PostgresVerbose,
+                r#"INTERVAL '0 YEARS 1 MONS 15 DAYS 0 HOURS 0 MINS 0.000000000 SECS'"#,
+            ),
+            (
+                interval_month_day_nano_lit("-3 MONTH"),
+                IntervalStyle::PostgresVerbose,
+                r#"INTERVAL '0 YEARS -3 MONS 0 DAYS 0 HOURS 0 MINS 0.000000000 SECS'"#,
+            ),
+            (
+                interval_month_day_nano_lit("1 MONTH")
+                    .add(interval_month_day_nano_lit("1 DAY")),
+                IntervalStyle::PostgresVerbose,
+                r#"(INTERVAL '0 YEARS 1 MONS 0 DAYS 0 HOURS 0 MINS 0.000000000 SECS' + INTERVAL '0 YEARS 0 MONS 1 DAYS 0 HOURS 0 MINS 0.000000000 SECS')"#,
+            ),
+            (
+                interval_month_day_nano_lit("1 MONTH")
+                    .sub(interval_month_day_nano_lit("1 DAY")),
+                IntervalStyle::PostgresVerbose,
+                r#"(INTERVAL '0 YEARS 1 MONS 0 DAYS 0 HOURS 0 MINS 0.000000000 SECS' - INTERVAL '0 YEARS 0 MONS 1 DAYS 0 HOURS 0 MINS 0.000000000 SECS')"#,
+            ),
+            (
+                interval_datetime_lit("10 DAY 1 HOUR 10 MINUTE 20 SECOND"),
+                IntervalStyle::PostgresVerbose,
+                r#"INTERVAL '0 YEARS 0 MONS 10 DAYS 1 HOURS 10 MINS 20.000 SECS'"#,
+            ),
+            (
+                interval_datetime_lit("10 DAY 1.5 HOUR 10 MINUTE 20 SECOND"),
+                IntervalStyle::PostgresVerbose,
+                r#"INTERVAL '0 YEARS 0 MONS 10 DAYS 1 HOURS 40 MINS 20.000 SECS'"#,
+            ),
+            (
+                interval_year_month_lit("1 YEAR 1 MONTH"),
+                IntervalStyle::PostgresVerbose,
+                r#"INTERVAL '1 YEARS 1 MONS 0 DAYS 0 HOURS 0 MINS 0.00 SECS'"#,
+            ),
+            (
+                interval_year_month_lit("1.5 YEAR 1 MONTH"),
+                IntervalStyle::PostgresVerbose,
+                r#"INTERVAL '1 YEARS 7 MONS 0 DAYS 0 HOURS 0 MINS 0.00 SECS'"#,
+            ),
+        ];
+
+        for (value, style, expected) in tests {
+            let dialect = CustomDialectBuilder::new()
+                .with_interval_style(style)
+                .build();
+            let unparser =
Unparser::new(&dialect); + + let ast = unparser.expr_to_sql(&value).expect("to be unparsed"); + + let actual = format!("{ast}"); + + assert_eq!(actual, expected); + } + } } From f11bdf08b2fea5465d3b120dce4e49c7d0ff45ae Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 16 Jul 2024 08:50:09 -0600 Subject: [PATCH 47/59] add criterion benchmark for CaseExpr (#11482) --- datafusion/physical-expr/Cargo.toml | 4 + datafusion/physical-expr/benches/case_when.rs | 94 +++++++++++++++++++ 2 files changed, 98 insertions(+) create mode 100644 datafusion/physical-expr/benches/case_when.rs diff --git a/datafusion/physical-expr/Cargo.toml b/datafusion/physical-expr/Cargo.toml index d8dbe636d90cf..067617a697a98 100644 --- a/datafusion/physical-expr/Cargo.toml +++ b/datafusion/physical-expr/Cargo.toml @@ -77,3 +77,7 @@ tokio = { workspace = true, features = ["rt-multi-thread"] } [[bench]] harness = false name = "in_list" + +[[bench]] +harness = false +name = "case_when" diff --git a/datafusion/physical-expr/benches/case_when.rs b/datafusion/physical-expr/benches/case_when.rs new file mode 100644 index 0000000000000..9cc7bdc465fb5 --- /dev/null +++ b/datafusion/physical-expr/benches/case_when.rs @@ -0,0 +1,94 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
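+
+// Benchmarks two common CASE WHEN shapes over a 1000-row batch whose string
+// column is ~14% null (every 7th row): `CASE WHEN c1 <= 500 THEN 1 ELSE 0 END`
+// (scalar arms) and `CASE WHEN c1 <= 500 THEN c2 ELSE NULL END` (column arm).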
+ +use arrow::datatypes::{Field, Schema}; +use arrow::record_batch::RecordBatch; +use arrow_array::builder::{Int32Builder, StringBuilder}; +use arrow_schema::DataType; +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use datafusion_common::ScalarValue; +use datafusion_expr::Operator; +use datafusion_physical_expr::expressions::{BinaryExpr, CaseExpr}; +use datafusion_physical_expr_common::expressions::column::Column; +use datafusion_physical_expr_common::expressions::Literal; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use std::sync::Arc; + +fn make_col(name: &str, index: usize) -> Arc { + Arc::new(Column::new(name, index)) +} + +fn make_lit_i32(n: i32) -> Arc { + Arc::new(Literal::new(ScalarValue::Int32(Some(n)))) +} + +fn criterion_benchmark(c: &mut Criterion) { + // create input data + let mut c1 = Int32Builder::new(); + let mut c2 = StringBuilder::new(); + for i in 0..1000 { + c1.append_value(i); + if i % 7 == 0 { + c2.append_null(); + } else { + c2.append_value(&format!("string {i}")); + } + } + let c1 = Arc::new(c1.finish()); + let c2 = Arc::new(c2.finish()); + let schema = Schema::new(vec![ + Field::new("c1", DataType::Int32, true), + Field::new("c2", DataType::Utf8, true), + ]); + let batch = RecordBatch::try_new(Arc::new(schema), vec![c1, c2]).unwrap(); + + // use same predicate for all benchmarks + let predicate = Arc::new(BinaryExpr::new( + make_col("c1", 0), + Operator::LtEq, + make_lit_i32(500), + )); + + // CASE WHEN expr THEN 1 ELSE 0 END + c.bench_function("case_when: scalar or scalar", |b| { + let expr = Arc::new( + CaseExpr::try_new( + None, + vec![(predicate.clone(), make_lit_i32(1))], + Some(make_lit_i32(0)), + ) + .unwrap(), + ); + b.iter(|| black_box(expr.evaluate(black_box(&batch)).unwrap())) + }); + + // CASE WHEN expr THEN col ELSE null END + c.bench_function("case_when: column or null", |b| { + let expr = Arc::new( + CaseExpr::try_new( + None, + vec![(predicate.clone(), make_col("c2", 1))], + Some(Arc::new(Literal::new(ScalarValue::Utf8(None)))), + ) + .unwrap(), + ); + b.iter(|| black_box(expr.evaluate(black_box(&batch)).unwrap())) + }); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); From ccb4baf0fc6b4dee983bb29f2282b9c19510a481 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Tue, 16 Jul 2024 15:52:35 -0400 Subject: [PATCH 48/59] Initial support for `StringView`, merge changes from `string-view` development branch (#11402) * Update `string-view` branch to arrow-rs main (#10966) * Pin to arrow main * Fix clippy with latest arrow * Uncomment test that needs new arrow-rs to work * Update datafusion-cli Cargo.lock * Update Cargo.lock * tapelo * feat: Implement equality = and inequality <> support for StringView (#10985) * feat: Implement equality = and inequality <> support for StringView * chore: Add tests for the StringView * chore * chore: Update tests for NULL * fix: Used build_array_string! 
* chore: Update string_coercion function to handle Utf8View type in binary.rs * chore: add tests * chore: ci * Add more StringView comparison test coverage (#10997) * Add more StringView comparison test coverage * add reference * Add another test showing casting on columns works correctly * feat: Implement equality = and inequality <> support for BinaryView (#11004) * feat: Implement equality = and inequality <> support for BinaryView Signed-off-by: Chojan Shang * chore: make fmt happy Signed-off-by: Chojan Shang --------- Signed-off-by: Chojan Shang * Implement support for LargeString and LargeBinary for StringView and BinaryView (#11034) * implement large binary * add tests for large string * better comments for string coercion * Improve filter predicates with `Utf8View` literals (#11043) * refactor: Improve type coercion logic in TypeCoercionRewriter * refactor: Improve type coercion logic in TypeCoercionRewriter * chore * chore: Update test * refactor: Improve type coercion logic in TypeCoercionRewriter * refactor: Remove unused import and update code formatting in unwrap_cast_in_comparison.rs * Remove arrow-patch --------- Signed-off-by: Chojan Shang Co-authored-by: Alex Huang Co-authored-by: Chojan Shang Co-authored-by: Xiangpeng Hao --- datafusion/common/src/scalar/mod.rs | 8 +- datafusion/expr/src/type_coercion/binary.rs | 36 +- .../src/unwrap_cast_in_comparison.rs | 26 +- .../sqllogictest/test_files/binary_view.slt | 202 +++++++++++ .../sqllogictest/test_files/string_view.slt | 326 ++++++++++++++++++ 5 files changed, 566 insertions(+), 32 deletions(-) create mode 100644 datafusion/sqllogictest/test_files/binary_view.slt create mode 100644 datafusion/sqllogictest/test_files/string_view.slt diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs index c891e85aa59bb..38f70e4c1466c 100644 --- a/datafusion/common/src/scalar/mod.rs +++ b/datafusion/common/src/scalar/mod.rs @@ -1682,8 +1682,10 @@ impl ScalarValue { DataType::UInt16 => build_array_primitive!(UInt16Array, UInt16), DataType::UInt32 => build_array_primitive!(UInt32Array, UInt32), DataType::UInt64 => build_array_primitive!(UInt64Array, UInt64), + DataType::Utf8View => build_array_string!(StringViewArray, Utf8View), DataType::Utf8 => build_array_string!(StringArray, Utf8), DataType::LargeUtf8 => build_array_string!(LargeStringArray, LargeUtf8), + DataType::BinaryView => build_array_string!(BinaryViewArray, BinaryView), DataType::Binary => build_array_string!(BinaryArray, Binary), DataType::LargeBinary => build_array_string!(LargeBinaryArray, LargeBinary), DataType::Date32 => build_array_primitive!(Date32Array, Date32), @@ -1841,8 +1843,6 @@ impl ScalarValue { | DataType::Time64(TimeUnit::Millisecond) | DataType::Map(_, _) | DataType::RunEndEncoded(_, _) - | DataType::Utf8View - | DataType::BinaryView | DataType::ListView(_) | DataType::LargeListView(_) => { return _internal_err!( @@ -5695,16 +5695,12 @@ mod tests { DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), ); - // needs https://github.com/apache/arrow-rs/issues/5893 - /* check_scalar_cast(ScalarValue::Utf8(None), DataType::Utf8View); check_scalar_cast(ScalarValue::from("foo"), DataType::Utf8View); check_scalar_cast( ScalarValue::from("larger than 12 bytes string"), DataType::Utf8View, ); - - */ } // mimics how casting work on scalar values by `casting` `scalar` to `desired_type` diff --git a/datafusion/expr/src/type_coercion/binary.rs b/datafusion/expr/src/type_coercion/binary.rs index 4f79f3fa2b220..70139aaa4a0cc 
100644 --- a/datafusion/expr/src/type_coercion/binary.rs +++ b/datafusion/expr/src/type_coercion/binary.rs @@ -919,16 +919,21 @@ fn string_concat_internal_coercion( } } -/// Coercion rules for string types (Utf8/LargeUtf8): If at least one argument is -/// a string type and both arguments can be coerced into a string type, coerce -/// to string type. +/// Coercion rules for string view types (Utf8/LargeUtf8/Utf8View): +/// If at least one argument is a string view, we coerce to string view +/// based on the observation that StringArray to StringViewArray is cheap but not vice versa. +/// +/// Between Utf8 and LargeUtf8, we coerce to LargeUtf8. fn string_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { use arrow::datatypes::DataType::*; match (lhs_type, rhs_type) { + // If Utf8View is in any side, we coerce to Utf8View. + (Utf8View, Utf8View | Utf8 | LargeUtf8) | (Utf8 | LargeUtf8, Utf8View) => { + Some(Utf8View) + } + // Then, if LargeUtf8 is in any side, we coerce to LargeUtf8. + (LargeUtf8, Utf8 | LargeUtf8) | (Utf8, LargeUtf8) => Some(LargeUtf8), (Utf8, Utf8) => Some(Utf8), - (LargeUtf8, Utf8) => Some(LargeUtf8), - (Utf8, LargeUtf8) => Some(LargeUtf8), - (LargeUtf8, LargeUtf8) => Some(LargeUtf8), _ => None, } } @@ -975,15 +980,26 @@ fn binary_to_string_coercion( } } -/// Coercion rules for binary types (Binary/LargeBinary): If at least one argument is +/// Coercion rules for binary types (Binary/LargeBinary/BinaryView): If at least one argument is /// a binary type and both arguments can be coerced into a binary type, coerce /// to binary type. fn binary_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { use arrow::datatypes::DataType::*; match (lhs_type, rhs_type) { - (Binary | Utf8, Binary) | (Binary, Utf8) => Some(Binary), - (LargeBinary | Binary | Utf8 | LargeUtf8, LargeBinary) - | (LargeBinary, Binary | Utf8 | LargeUtf8) => Some(LargeBinary), + // If BinaryView is in any side, we coerce to BinaryView. 
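+        // e.g. comparing a BinaryView column with a Utf8 literal hits the
+        // first arm below: both sides coerce to BinaryView, the cheap
+        // direction (a view over existing buffers rather than a copy).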
+        (BinaryView, BinaryView | Binary | LargeBinary | Utf8 | LargeUtf8 | Utf8View)
+        | (LargeBinary | Binary | Utf8 | LargeUtf8 | Utf8View, BinaryView) => {
+            Some(BinaryView)
+        }
+        // Prefer LargeBinary over Binary
+        (LargeBinary | Binary | Utf8 | LargeUtf8 | Utf8View, LargeBinary)
+        | (LargeBinary, Binary | Utf8 | LargeUtf8 | Utf8View) => Some(LargeBinary),
+
+        // If Utf8View or LargeUtf8 is present, the result needs to be LargeBinary
+        (Utf8View | LargeUtf8, Binary) | (Binary, Utf8View | LargeUtf8) => {
+            Some(LargeBinary)
+        }
+        (Binary, Utf8) | (Utf8, Binary) => Some(Binary),
         _ => None,
     }
 }
diff --git a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs
index 9941da9dd65e0..7238dd5bbd97e 100644
--- a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs
+++ b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs
@@ -33,7 +33,7 @@
 use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter};
 use datafusion_common::{internal_err, DFSchema, DFSchemaRef, Result, ScalarValue};
 use datafusion_expr::expr::{BinaryExpr, Cast, InList, TryCast};
 use datafusion_expr::utils::merge_schema;
-use datafusion_expr::{lit, Expr, ExprSchemable, LogicalPlan, Operator};
+use datafusion_expr::{lit, Expr, ExprSchemable, LogicalPlan};

 /// [`UnwrapCastInComparison`] attempts to remove casts from
 /// comparisons to literals ([`ScalarValue`]s) by applying the casts
@@ -146,7 +146,7 @@ impl TreeNodeRewriter for UnwrapCastExprRewriter {
                 };
                 is_supported_type(&left_type)
                     && is_supported_type(&right_type)
-                    && is_comparison_op(op)
+                    && op.is_comparison_operator()
             } =>
             {
                 match (left.as_mut(), right.as_mut()) {
@@ -262,18 +262,6 @@ impl TreeNodeRewriter for UnwrapCastExprRewriter {
     }
 }

-fn is_comparison_op(op: &Operator) -> bool {
-    matches!(
-        op,
-        Operator::Eq
-            | Operator::NotEq
-            | Operator::Gt
-            | Operator::GtEq
-            | Operator::Lt
-            | Operator::LtEq
-    )
-}
-
 /// Returns true if [UnwrapCastExprRewriter] supports this data type
 fn is_supported_type(data_type: &DataType) -> bool {
     is_supported_numeric_type(data_type)
@@ -300,7 +288,10 @@ fn is_supported_numeric_type(data_type: &DataType) -> bool {

 /// Returns true if [UnwrapCastExprRewriter] supports casting this value as a string
 fn is_supported_string_type(data_type: &DataType) -> bool {
-    matches!(data_type, DataType::Utf8 | DataType::LargeUtf8)
+    matches!(
+        data_type,
+        DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View
+    )
 }

 /// Returns true if [UnwrapCastExprRewriter] supports casting this value as a dictionary
@@ -473,12 +464,15 @@ fn try_cast_string_literal(
     target_type: &DataType,
 ) -> Option {
     let string_value = match lit_value {
-        ScalarValue::Utf8(s) | ScalarValue::LargeUtf8(s) => s.clone(),
+        ScalarValue::Utf8(s) | ScalarValue::LargeUtf8(s) | ScalarValue::Utf8View(s) => {
+            s.clone()
+        }
         _ => return None,
     };
     let scalar_value = match target_type {
         DataType::Utf8 => ScalarValue::Utf8(string_value),
         DataType::LargeUtf8 => ScalarValue::LargeUtf8(string_value),
+        DataType::Utf8View => ScalarValue::Utf8View(string_value),
         _ => return None,
     };
     Some(scalar_value)
diff --git a/datafusion/sqllogictest/test_files/binary_view.slt b/datafusion/sqllogictest/test_files/binary_view.slt
new file mode 100644
index 0000000000000..de0f0bea7ffb5
--- /dev/null
+++ b/datafusion/sqllogictest/test_files/binary_view.slt
@@ -0,0 +1,202 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +######## +## Test setup +######## + +statement ok +create table test_source as values + ('Andrew', 'X'), + ('Xiangpeng', 'Xiangpeng'), + ('Raphael', 'R'), + (NULL, 'R') +; + +# Table with the different combination of column types +statement ok +CREATE TABLE test AS +SELECT + arrow_cast(column1, 'Utf8') as column1_utf8, + arrow_cast(column2, 'Utf8') as column2_utf8, + arrow_cast(column1, 'Binary') AS column1_binary, + arrow_cast(column2, 'Binary') AS column2_binary, + arrow_cast(column1, 'LargeBinary') AS column1_large_binary, + arrow_cast(column2, 'LargeBinary') AS column2_large_binary, + arrow_cast(arrow_cast(column1, 'Binary'), 'BinaryView') AS column1_binaryview, + arrow_cast(arrow_cast(column2, 'Binary'), 'BinaryView') AS column2_binaryview, + arrow_cast(column1, 'Dictionary(Int32, Binary)') AS column1_dict, + arrow_cast(column2, 'Dictionary(Int32, Binary)') AS column2_dict +FROM test_source; + +statement ok +drop table test_source + +######## +## BinaryView to BinaryView +######## + +# BinaryView scalar to BinaryView scalar + +query BBBB +SELECT + arrow_cast(arrow_cast('NULL', 'Binary'), 'BinaryView') = arrow_cast(arrow_cast('Andrew', 'Binary'), 'BinaryView') AS comparison1, + arrow_cast(arrow_cast('NULL', 'Binary'), 'BinaryView') <> arrow_cast(arrow_cast('Andrew', 'Binary'), 'BinaryView') AS comparison2, + arrow_cast(arrow_cast('Andrew', 'Binary'), 'BinaryView') = arrow_cast(arrow_cast('Andrew', 'Binary'), 'BinaryView') AS comparison3, + arrow_cast(arrow_cast('Xiangpeng', 'Binary'), 'BinaryView') <> arrow_cast(arrow_cast('Andrew', 'Binary'), 'BinaryView') AS comparison4; +---- +false true true true + + +# BinaryView column to BinaryView column comparison as filters + +query TT +select column1_utf8, column2_utf8 from test where column1_binaryview = column2_binaryview; +---- +Xiangpeng Xiangpeng + +query TT +select column1_utf8, column2_utf8 from test where column1_binaryview <> column2_binaryview; +---- +Andrew X +Raphael R + +# BinaryView column to BinaryView column +query TTBB +select + column1_utf8, column2_utf8, + column1_binaryview = column2_binaryview, + column1_binaryview <> column2_binaryview +from test; +---- +Andrew X false true +Xiangpeng Xiangpeng true false +Raphael R false true +NULL R NULL NULL + +# BinaryView column to BinaryView scalar comparison +query TTBBBB +select + column1_utf8, column2_utf8, + column1_binaryview = arrow_cast(arrow_cast('Andrew', 'Binary'), 'BinaryView'), + arrow_cast(arrow_cast('Andrew', 'Binary'), 'BinaryView') = column1_binaryview, + column1_binaryview <> arrow_cast(arrow_cast('Andrew', 'Binary'), 'BinaryView'), + arrow_cast(arrow_cast('Andrew', 'Binary'), 'BinaryView') <> column1_binaryview +from test; +---- +Andrew X true true false false +Xiangpeng Xiangpeng false false true true +Raphael R false false true true +NULL R NULL NULL NULL NULL + +######## +## 
BinaryView to Binary +######## + +# test BinaryViewArray with Binary columns +query TTBBBB +select + column1_utf8, column2_utf8, + column1_binaryview = column2_binary, + column2_binary = column1_binaryview, + column1_binaryview <> column2_binary, + column2_binary <> column1_binaryview +from test; +---- +Andrew X false false true true +Xiangpeng Xiangpeng true true false false +Raphael R false false true true +NULL R NULL NULL NULL NULL + +# test BinaryViewArray with LargeBinary columns +query TTBBBB +select + column1_utf8, column2_utf8, + column1_binaryview = column2_large_binary, + column2_large_binary = column1_binaryview, + column1_binaryview <> column2_large_binary, + column2_large_binary <> column1_binaryview +from test; +---- +Andrew X false false true true +Xiangpeng Xiangpeng true true false false +Raphael R false false true true +NULL R NULL NULL NULL NULL + +# BinaryView column to Binary scalar +query TTBBBB +select + column1_utf8, column2_utf8, + column1_binaryview = arrow_cast('Andrew', 'Binary'), + arrow_cast('Andrew', 'Binary') = column1_binaryview, + column1_binaryview <> arrow_cast('Andrew', 'Binary'), + arrow_cast('Andrew', 'Binary') <> column1_binaryview +from test; +---- +Andrew X true true false false +Xiangpeng Xiangpeng false false true true +Raphael R false false true true +NULL R NULL NULL NULL NULL + +# BinaryView column to LargeBinary scalar +query TTBBBB +select + column1_utf8, column2_utf8, + column1_binaryview = arrow_cast('Andrew', 'LargeBinary'), + arrow_cast('Andrew', 'LargeBinary') = column1_binaryview, + column1_binaryview <> arrow_cast('Andrew', 'LargeBinary'), + arrow_cast('Andrew', 'LargeBinary') <> column1_binaryview +from test; +---- +Andrew X true true false false +Xiangpeng Xiangpeng false false true true +Raphael R false false true true +NULL R NULL NULL NULL NULL + +# Binary column to BinaryView scalar +query TTBBBB +select + column1_utf8, column2_utf8, + column1_binary = arrow_cast(arrow_cast('Andrew', 'Binary'), 'BinaryView'), + arrow_cast(arrow_cast('Andrew', 'Binary'), 'BinaryView') = column1_binary, + column1_binary <> arrow_cast(arrow_cast('Andrew', 'Binary'), 'BinaryView'), + arrow_cast(arrow_cast('Andrew', 'Binary'), 'BinaryView') <> column1_binary +from test; +---- +Andrew X true true false false +Xiangpeng Xiangpeng false false true true +Raphael R false false true true +NULL R NULL NULL NULL NULL + + +# LargeBinary column to BinaryView scalar +query TTBBBB +select + column1_utf8, column2_utf8, + column1_large_binary = arrow_cast(arrow_cast('Andrew', 'Binary'), 'BinaryView'), + arrow_cast(arrow_cast('Andrew', 'Binary'), 'BinaryView') = column1_large_binary, + column1_large_binary <> arrow_cast(arrow_cast('Andrew', 'Binary'), 'BinaryView'), + arrow_cast(arrow_cast('Andrew', 'Binary'), 'BinaryView') <> column1_large_binary +from test; +---- +Andrew X true true false false +Xiangpeng Xiangpeng false false true true +Raphael R false false true true +NULL R NULL NULL NULL NULL + +statement ok +drop table test; \ No newline at end of file diff --git a/datafusion/sqllogictest/test_files/string_view.slt b/datafusion/sqllogictest/test_files/string_view.slt new file mode 100644 index 0000000000000..3ba4e271c2f64 --- /dev/null +++ b/datafusion/sqllogictest/test_files/string_view.slt @@ -0,0 +1,326 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +######## +## Test setup +######## + +statement ok +create table test_source as values + ('Andrew', 'X'), + ('Xiangpeng', 'Xiangpeng'), + ('Raphael', 'R'), + (NULL, 'R') +; + +# Table with the different combination of column types +statement ok +create table test as +SELECT + arrow_cast(column1, 'Utf8') as column1_utf8, + arrow_cast(column2, 'Utf8') as column2_utf8, + arrow_cast(column1, 'LargeUtf8') as column1_large_utf8, + arrow_cast(column2, 'LargeUtf8') as column2_large_utf8, + arrow_cast(column1, 'Utf8View') as column1_utf8view, + arrow_cast(column2, 'Utf8View') as column2_utf8view, + arrow_cast(column1, 'Dictionary(Int32, Utf8)') as column1_dict, + arrow_cast(column2, 'Dictionary(Int32, Utf8)') as column2_dict +FROM test_source; + +statement ok +drop table test_source + +######## +## StringView to StringView +######## + +# StringView scalar to StringView scalar + +query BBBB +select + arrow_cast('NULL', 'Utf8View') = arrow_cast('Andrew', 'Utf8View'), + arrow_cast('NULL', 'Utf8View') <> arrow_cast('Andrew', 'Utf8View'), + arrow_cast('Andrew', 'Utf8View') = arrow_cast('Andrew', 'Utf8View'), + arrow_cast('Xiangpeng', 'Utf8View') <> arrow_cast('Andrew', 'Utf8View'); +---- +false true true true + + +# StringView column to StringView column comparison as filters + +query TT +select column1_utf8, column2_utf8 from test where column1_utf8view = column2_utf8view; +---- +Xiangpeng Xiangpeng + +query TT +select column1_utf8, column2_utf8 from test where column1_utf8view <> column2_utf8view; +---- +Andrew X +Raphael R + +# StringView column to StringView column +query TTBB +select + column1_utf8, column2_utf8, + column1_utf8view = column2_utf8view, + column1_utf8view <> column2_utf8view +from test; +---- +Andrew X false true +Xiangpeng Xiangpeng true false +Raphael R false true +NULL R NULL NULL + +# StringView column to StringView scalar comparison +query TTBBBB +select + column1_utf8, column2_utf8, + column1_utf8view = arrow_cast('Andrew', 'Utf8View'), + arrow_cast('Andrew', 'Utf8View') = column1_utf8view, + column1_utf8view <> arrow_cast('Andrew', 'Utf8View'), + arrow_cast('Andrew', 'Utf8View') <> column1_utf8view +from test; +---- +Andrew X true true false false +Xiangpeng Xiangpeng false false true true +Raphael R false false true true +NULL R NULL NULL NULL NULL + +######## +## StringView to String +######## + +# test StringViewArray with Utf8 columns +query TTBBBB +select + column1_utf8, column2_utf8, + column1_utf8view = column2_utf8, + column2_utf8 = column1_utf8view, + column1_utf8view <> column2_utf8, + column2_utf8 <> column1_utf8view +from test; +---- +Andrew X false false true true +Xiangpeng Xiangpeng true true false false +Raphael R false false true true +NULL R NULL NULL NULL NULL + +# test StringViewArray with LargeUtf8 columns +query TTBBBB +select + column1_utf8, column2_utf8, + column1_utf8view = column2_large_utf8, + column2_large_utf8 = column1_utf8view, + column1_utf8view <> column2_large_utf8, + 
column2_large_utf8 <> column1_utf8view +from test; +---- +Andrew X false false true true +Xiangpeng Xiangpeng true true false false +Raphael R false false true true +NULL R NULL NULL NULL NULL + + +# StringView column to String scalar +query TTBBBB +select + column1_utf8, column2_utf8, + column1_utf8view = arrow_cast('Andrew', 'Utf8'), + arrow_cast('Andrew', 'Utf8') = column1_utf8view, + column1_utf8view <> arrow_cast('Andrew', 'Utf8'), + arrow_cast('Andrew', 'Utf8') <> column1_utf8view +from test; +---- +Andrew X true true false false +Xiangpeng Xiangpeng false false true true +Raphael R false false true true +NULL R NULL NULL NULL NULL + +# StringView column to LargeString scalar +query TTBBBB +select + column1_utf8, column2_utf8, + column1_utf8view = arrow_cast('Andrew', 'LargeUtf8'), + arrow_cast('Andrew', 'LargeUtf8') = column1_utf8view, + column1_utf8view <> arrow_cast('Andrew', 'LargeUtf8'), + arrow_cast('Andrew', 'LargeUtf8') <> column1_utf8view +from test; +---- +Andrew X true true false false +Xiangpeng Xiangpeng false false true true +Raphael R false false true true +NULL R NULL NULL NULL NULL + +# String column to StringView scalar +query TTBBBB +select + column1_utf8, column2_utf8, + column1_utf8 = arrow_cast('Andrew', 'Utf8View'), + arrow_cast('Andrew', 'Utf8View') = column1_utf8, + column1_utf8 <> arrow_cast('Andrew', 'Utf8View'), + arrow_cast('Andrew', 'Utf8View') <> column1_utf8 +from test; +---- +Andrew X true true false false +Xiangpeng Xiangpeng false false true true +Raphael R false false true true +NULL R NULL NULL NULL NULL + +# LargeString column to StringView scalar +query TTBBBB +select + column1_utf8, column2_utf8, + column1_large_utf8 = arrow_cast('Andrew', 'Utf8View'), + arrow_cast('Andrew', 'Utf8View') = column1_large_utf8, + column1_large_utf8 <> arrow_cast('Andrew', 'Utf8View'), + arrow_cast('Andrew', 'Utf8View') <> column1_large_utf8 +from test; +---- +Andrew X true true false false +Xiangpeng Xiangpeng false false true true +Raphael R false false true true +NULL R NULL NULL NULL NULL + +######## +## StringView to Dictionary +######## + +# test StringViewArray with Dictionary columns +query TTBBBB +select + column1_utf8, column2_utf8, + column1_utf8view = column2_dict, + column2_dict = column1_utf8view, + column1_utf8view <> column2_dict, + column2_dict <> column1_utf8view +from test; +---- +Andrew X false false true true +Xiangpeng Xiangpeng true true false false +Raphael R false false true true +NULL R NULL NULL NULL NULL + +# StringView column to Dict scalar +query TTBBBB +select + column1_utf8, column2_utf8, + column1_utf8view = arrow_cast('Andrew', 'Dictionary(Int32, Utf8)'), + arrow_cast('Andrew', 'Dictionary(Int32, Utf8)') = column1_utf8view, + column1_utf8view <> arrow_cast('Andrew', 'Dictionary(Int32, Utf8)'), + arrow_cast('Andrew', 'Dictionary(Int32, Utf8)') <> column1_utf8view +from test; +---- +Andrew X true true false false +Xiangpeng Xiangpeng false false true true +Raphael R false false true true +NULL R NULL NULL NULL NULL + +# Dict column to StringView scalar +query TTBBBB +select + column1_utf8, column2_utf8, + column1_dict = arrow_cast('Andrew', 'Utf8View'), + arrow_cast('Andrew', 'Utf8View') = column1_dict, + column1_dict <> arrow_cast('Andrew', 'Utf8View'), + arrow_cast('Andrew', 'Utf8View') <> column1_dict +from test; +---- +Andrew X true true false false +Xiangpeng Xiangpeng false false true true +Raphael R false false true true +NULL R NULL NULL NULL NULL + + +######## +## Coercion Rules +######## + + +statement ok +set 
datafusion.explain.logical_plan_only = true; + + +# Filter should have a StringView literal and no column cast +query TT +explain SELECT column1_utf8 from test where column1_utf8view = 'Andrew'; +---- +logical_plan +01)Projection: test.column1_utf8 +02)--Filter: test.column1_utf8view = Utf8View("Andrew") +03)----TableScan: test projection=[column1_utf8, column1_utf8view] + +# reverse order should be the same +query TT +explain SELECT column1_utf8 from test where 'Andrew' = column1_utf8view; +---- +logical_plan +01)Projection: test.column1_utf8 +02)--Filter: test.column1_utf8view = Utf8View("Andrew") +03)----TableScan: test projection=[column1_utf8, column1_utf8view] + +query TT +explain SELECT column1_utf8 from test where column1_utf8 = arrow_cast('Andrew', 'Utf8View'); +---- +logical_plan +01)Filter: test.column1_utf8 = Utf8("Andrew") +02)--TableScan: test projection=[column1_utf8] + +query TT +explain SELECT column1_utf8 from test where arrow_cast('Andrew', 'Utf8View') = column1_utf8; +---- +logical_plan +01)Filter: test.column1_utf8 = Utf8("Andrew") +02)--TableScan: test projection=[column1_utf8] + +query TT +explain SELECT column1_utf8 from test where column1_utf8view = arrow_cast('Andrew', 'Dictionary(Int32, Utf8)'); +---- +logical_plan +01)Projection: test.column1_utf8 +02)--Filter: test.column1_utf8view = Utf8View("Andrew") +03)----TableScan: test projection=[column1_utf8, column1_utf8view] + +query TT +explain SELECT column1_utf8 from test where arrow_cast('Andrew', 'Dictionary(Int32, Utf8)') = column1_utf8view; +---- +logical_plan +01)Projection: test.column1_utf8 +02)--Filter: test.column1_utf8view = Utf8View("Andrew") +03)----TableScan: test projection=[column1_utf8, column1_utf8view] + +# compare string / stringview +# Should cast string -> stringview (which is cheap), not stringview -> string (which is not) +query TT +explain SELECT column1_utf8 from test where column1_utf8view = column2_utf8; +---- +logical_plan +01)Projection: test.column1_utf8 +02)--Filter: test.column1_utf8view = CAST(test.column2_utf8 AS Utf8View) +03)----TableScan: test projection=[column1_utf8, column2_utf8, column1_utf8view] + +query TT +explain SELECT column1_utf8 from test where column2_utf8 = column1_utf8view; +---- +logical_plan +01)Projection: test.column1_utf8 +02)--Filter: CAST(test.column2_utf8 AS Utf8View) = test.column1_utf8view +03)----TableScan: test projection=[column1_utf8, column2_utf8, column1_utf8view] + + +statement ok +drop table test; From 0c39b4d2ffefcd1e0e77389e493291aaa315d628 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E6=9E=97=E4=BC=9F?= Date: Wed, 17 Jul 2024 04:33:12 +0800 Subject: [PATCH 49/59] Replace to_lowercase with to_string in sql exmaple (#11486) --- datafusion/sql/examples/sql.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/datafusion/sql/examples/sql.rs b/datafusion/sql/examples/sql.rs index 1b92a7e116b16..b724afabaf097 100644 --- a/datafusion/sql/examples/sql.rs +++ b/datafusion/sql/examples/sql.rs @@ -69,8 +69,7 @@ struct MyContextProvider { impl MyContextProvider { fn with_udaf(mut self, udaf: Arc) -> Self { - // TODO: change to to_string() if all the function name is converted to lowercase - self.udafs.insert(udaf.name().to_lowercase(), udaf); + self.udafs.insert(udaf.name().to_string(), udaf); self } From 169a0d338cd1b3247da199a03add2119b5289d61 Mon Sep 17 00:00:00 2001 From: Arttu Date: Tue, 16 Jul 2024 22:33:28 +0200 Subject: [PATCH 50/59] chore: switch to using proper Substrait types for IntervalYearMonth and IntervalDayTime 
(#11471) also clean up IntervalMonthDayNano type - the type itself needs no parameters --- .../substrait/src/logical_plan/consumer.rs | 46 ++++-- .../substrait/src/logical_plan/producer.rs | 140 ++++++------------ 2 files changed, 79 insertions(+), 107 deletions(-) diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index a4f7242024754..991aa61fbf159 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -65,7 +65,7 @@ use std::str::FromStr; use std::sync::Arc; use substrait::proto::exchange_rel::ExchangeKind; use substrait::proto::expression::literal::user_defined::Val; -use substrait::proto::expression::literal::IntervalDayToSecond; +use substrait::proto::expression::literal::{IntervalDayToSecond, IntervalYearToMonth}; use substrait::proto::expression::subquery::SubqueryType; use substrait::proto::expression::{self, FieldReference, Literal, ScalarFunction}; use substrait::proto::read_rel::local_files::file_or_files::PathType::UriFile; @@ -1414,7 +1414,7 @@ fn from_substrait_type( })?; let field = Arc::new(Field::new_list_field( from_substrait_type(inner_type, dfs_names, name_idx)?, - // We ignore Substrait's nullability here to match to_substrait_literal + // We ignore Substrait's nullability here to match to_substrait_literal // which always creates nullable lists true, )); @@ -1445,12 +1445,15 @@ fn from_substrait_type( )); match map.type_variation_reference { DEFAULT_CONTAINER_TYPE_VARIATION_REF => { - Ok(DataType::Map(Arc::new(Field::new_struct( - "entries", - [key_field, value_field], - false, // The inner map field is always non-nullable (Arrow #1697), - )), false)) - }, + Ok(DataType::Map( + Arc::new(Field::new_struct( + "entries", + [key_field, value_field], + false, // The inner map field is always non-nullable (Arrow #1697), + )), + false, + )) + } v => not_impl_err!( "Unsupported Substrait type variation {v} of type {s_kind:?}" )?, @@ -1467,14 +1470,33 @@ fn from_substrait_type( "Unsupported Substrait type variation {v} of type {s_kind:?}" ), }, + r#type::Kind::IntervalYear(i) => match i.type_variation_reference { + DEFAULT_TYPE_VARIATION_REF => { + Ok(DataType::Interval(IntervalUnit::YearMonth)) + } + v => not_impl_err!( + "Unsupported Substrait type variation {v} of type {s_kind:?}" + ), + }, + r#type::Kind::IntervalDay(i) => match i.type_variation_reference { + DEFAULT_TYPE_VARIATION_REF => { + Ok(DataType::Interval(IntervalUnit::DayTime)) + } + v => not_impl_err!( + "Unsupported Substrait type variation {v} of type {s_kind:?}" + ), + }, r#type::Kind::UserDefined(u) => { match u.type_reference { + // Kept for backwards compatibility, use IntervalYear instead INTERVAL_YEAR_MONTH_TYPE_REF => { Ok(DataType::Interval(IntervalUnit::YearMonth)) } + // Kept for backwards compatibility, use IntervalDay instead INTERVAL_DAY_TIME_TYPE_REF => { Ok(DataType::Interval(IntervalUnit::DayTime)) } + // Not supported yet by Substrait INTERVAL_MONTH_DAY_NANO_TYPE_REF => { Ok(DataType::Interval(IntervalUnit::MonthDayNano)) } @@ -1484,7 +1506,7 @@ fn from_substrait_type( u.type_variation_reference ), } - }, + } r#type::Kind::Struct(s) => Ok(DataType::Struct(from_substrait_struct_type( s, dfs_names, name_idx, )?)), @@ -1753,11 +1775,16 @@ fn from_substrait_literal( seconds, microseconds, })) => { + // DF only supports millisecond precision, so we lose the micros here ScalarValue::new_interval_dt(*days, (seconds * 1000) + (microseconds / 1000)) } + 
Some(LiteralType::IntervalYearToMonth(IntervalYearToMonth { years, months })) => { + ScalarValue::new_interval_ym(*years, *months) + } Some(LiteralType::FixedChar(c)) => ScalarValue::Utf8(Some(c.clone())), Some(LiteralType::UserDefined(user_defined)) => { match user_defined.type_reference { + // Kept for backwards compatibility, use IntervalYearToMonth instead INTERVAL_YEAR_MONTH_TYPE_REF => { let Some(Val::Value(raw_val)) = user_defined.val.as_ref() else { return substrait_err!("Interval year month value is empty"); @@ -1770,6 +1797,7 @@ fn from_substrait_literal( })?; ScalarValue::IntervalYearMonth(Some(i32::from_le_bytes(value_slice))) } + // Kept for backwards compatibility, use IntervalDayToSecond instead INTERVAL_DAY_TIME_TYPE_REF => { let Some(Val::Value(raw_val)) = user_defined.val.as_ref() else { return substrait_err!("Interval day time value is empty"); diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 8d039a0502494..7849d0bd431e6 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -48,12 +48,11 @@ use datafusion::logical_expr::{expr, Between, JoinConstraint, LogicalPlan, Opera use datafusion::prelude::Expr; use pbjson_types::Any as ProtoAny; use substrait::proto::exchange_rel::{ExchangeKind, RoundRobin, ScatterFields}; -use substrait::proto::expression::literal::user_defined::Val; -use substrait::proto::expression::literal::UserDefined; -use substrait::proto::expression::literal::{List, Struct}; +use substrait::proto::expression::literal::{ + user_defined, IntervalDayToSecond, IntervalYearToMonth, List, Struct, UserDefined, +}; use substrait::proto::expression::subquery::InPredicate; use substrait::proto::expression::window_function::BoundsType; -use substrait::proto::r#type::{parameter, Parameter}; use substrait::proto::read_rel::VirtualTable; use substrait::proto::{CrossRel, ExchangeRel}; use substrait::{ @@ -95,9 +94,7 @@ use crate::variation_const::{ DATE_32_TYPE_VARIATION_REF, DATE_64_TYPE_VARIATION_REF, DECIMAL_128_TYPE_VARIATION_REF, DECIMAL_256_TYPE_VARIATION_REF, DEFAULT_CONTAINER_TYPE_VARIATION_REF, DEFAULT_TYPE_VARIATION_REF, - INTERVAL_DAY_TIME_TYPE_REF, INTERVAL_DAY_TIME_TYPE_URL, INTERVAL_MONTH_DAY_NANO_TYPE_REF, INTERVAL_MONTH_DAY_NANO_TYPE_URL, - INTERVAL_YEAR_MONTH_TYPE_REF, INTERVAL_YEAR_MONTH_TYPE_URL, LARGE_CONTAINER_TYPE_VARIATION_REF, TIMESTAMP_MICRO_TYPE_VARIATION_REF, TIMESTAMP_MILLI_TYPE_VARIATION_REF, TIMESTAMP_NANO_TYPE_VARIATION_REF, TIMESTAMP_SECOND_TYPE_VARIATION_REF, UNSIGNED_INTEGER_TYPE_VARIATION_REF, @@ -1534,47 +1531,31 @@ fn to_substrait_type(dt: &DataType, nullable: bool) -> Result { - // define two type parameters for convenience - let i32_param = Parameter { - parameter: Some(parameter::Parameter::DataType(substrait::proto::Type { - kind: Some(r#type::Kind::I32(r#type::I32 { + match interval_unit { + IntervalUnit::YearMonth => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::IntervalYear(r#type::IntervalYear { type_variation_reference: DEFAULT_TYPE_VARIATION_REF, - nullability: r#type::Nullability::Unspecified as i32, + nullability, })), - })), - }; - let i64_param = Parameter { - parameter: Some(parameter::Parameter::DataType(substrait::proto::Type { - kind: Some(r#type::Kind::I64(r#type::I64 { + }), + IntervalUnit::DayTime => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::IntervalDay(r#type::IntervalDay { type_variation_reference: DEFAULT_TYPE_VARIATION_REF, - nullability: 
r#type::Nullability::Unspecified as i32, + nullability, })), - })), - }; - - let (type_parameters, type_reference) = match interval_unit { - IntervalUnit::YearMonth => { - let type_parameters = vec![i32_param]; - (type_parameters, INTERVAL_YEAR_MONTH_TYPE_REF) - } - IntervalUnit::DayTime => { - let type_parameters = vec![i64_param]; - (type_parameters, INTERVAL_DAY_TIME_TYPE_REF) - } + }), IntervalUnit::MonthDayNano => { - // use 2 `i64` as `i128` - let type_parameters = vec![i64_param.clone(), i64_param]; - (type_parameters, INTERVAL_MONTH_DAY_NANO_TYPE_REF) + // Substrait doesn't currently support this type, so we represent it as a UDT + Ok(substrait::proto::Type { + kind: Some(r#type::Kind::UserDefined(r#type::UserDefined { + type_reference: INTERVAL_MONTH_DAY_NANO_TYPE_REF, + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, + nullability, + type_parameters: vec![], + })), + }) } - }; - Ok(substrait::proto::Type { - kind: Some(r#type::Kind::UserDefined(r#type::UserDefined { - type_reference, - type_variation_reference: DEFAULT_TYPE_VARIATION_REF, - nullability, - type_parameters, - })), - }) + } } DataType::Binary => Ok(substrait::proto::Type { kind: Some(r#type::Kind::Binary(r#type::Binary { @@ -1954,45 +1935,23 @@ fn to_substrait_literal(value: &ScalarValue) -> Result { (LiteralType::Date(*d), DATE_32_TYPE_VARIATION_REF) } // Date64 literal is not supported in Substrait - ScalarValue::IntervalYearMonth(Some(i)) => { - let bytes = i.to_le_bytes(); - ( - LiteralType::UserDefined(UserDefined { - type_reference: INTERVAL_YEAR_MONTH_TYPE_REF, - type_parameters: vec![Parameter { - parameter: Some(parameter::Parameter::DataType( - substrait::proto::Type { - kind: Some(r#type::Kind::I32(r#type::I32 { - type_variation_reference: DEFAULT_TYPE_VARIATION_REF, - nullability: r#type::Nullability::Required as i32, - })), - }, - )), - }], - val: Some(Val::Value(ProtoAny { - type_url: INTERVAL_YEAR_MONTH_TYPE_URL.to_string(), - value: bytes.to_vec().into(), - })), - }), - INTERVAL_YEAR_MONTH_TYPE_REF, - ) - } + ScalarValue::IntervalYearMonth(Some(i)) => ( + LiteralType::IntervalYearToMonth(IntervalYearToMonth { + // DF only tracks total months, but there should always be 12 months in a year + years: *i / 12, + months: *i % 12, + }), + DEFAULT_TYPE_VARIATION_REF, + ), ScalarValue::IntervalMonthDayNano(Some(i)) => { - // treat `i128` as two contiguous `i64` + // IntervalMonthDayNano is internally represented as a 128-bit integer, containing + // months (32bit), days (32bit), and nanoseconds (64bit) let bytes = i.to_byte_slice(); - let i64_param = Parameter { - parameter: Some(parameter::Parameter::DataType(substrait::proto::Type { - kind: Some(r#type::Kind::I64(r#type::I64 { - type_variation_reference: DEFAULT_TYPE_VARIATION_REF, - nullability: r#type::Nullability::Required as i32, - })), - })), - }; ( LiteralType::UserDefined(UserDefined { type_reference: INTERVAL_MONTH_DAY_NANO_TYPE_REF, - type_parameters: vec![i64_param.clone(), i64_param], - val: Some(Val::Value(ProtoAny { + type_parameters: vec![], + val: Some(user_defined::Val::Value(ProtoAny { type_url: INTERVAL_MONTH_DAY_NANO_TYPE_URL.to_string(), value: bytes.to_vec().into(), })), @@ -2000,29 +1959,14 @@ fn to_substrait_literal(value: &ScalarValue) -> Result { INTERVAL_MONTH_DAY_NANO_TYPE_REF, ) } - ScalarValue::IntervalDayTime(Some(i)) => { - let bytes = i.to_byte_slice(); - ( - LiteralType::UserDefined(UserDefined { - type_reference: INTERVAL_DAY_TIME_TYPE_REF, - type_parameters: vec![Parameter { - parameter: 
Some(parameter::Parameter::DataType( - substrait::proto::Type { - kind: Some(r#type::Kind::I64(r#type::I64 { - type_variation_reference: DEFAULT_TYPE_VARIATION_REF, - nullability: r#type::Nullability::Required as i32, - })), - }, - )), - }], - val: Some(Val::Value(ProtoAny { - type_url: INTERVAL_DAY_TIME_TYPE_URL.to_string(), - value: bytes.to_vec().into(), - })), - }), - INTERVAL_DAY_TIME_TYPE_REF, - ) - } + ScalarValue::IntervalDayTime(Some(i)) => ( + LiteralType::IntervalDayToSecond(IntervalDayToSecond { + days: i.days, + seconds: i.milliseconds / 1000, + microseconds: (i.milliseconds % 1000) * 1000, + }), + DEFAULT_TYPE_VARIATION_REF, + ), ScalarValue::Binary(Some(b)) => ( LiteralType::Binary(b.clone()), DEFAULT_CONTAINER_TYPE_VARIATION_REF, From 82fd6a7de310fef4e365c333b0f7fc2a3e4ed12e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Berkay=20=C5=9Eahin?= <124376117+berkaysynnada@users.noreply.github.com> Date: Tue, 16 Jul 2024 23:34:01 +0300 Subject: [PATCH 51/59] Move execute_input_stream (#11449) --- datafusion/physical-plan/src/insert.rs | 77 +++----------------- datafusion/physical-plan/src/lib.rs | 97 +++++++++++++++++++++++++- 2 files changed, 103 insertions(+), 71 deletions(-) diff --git a/datafusion/physical-plan/src/insert.rs b/datafusion/physical-plan/src/insert.rs index 1c21991d93c55..5cd864125e29c 100644 --- a/datafusion/physical-plan/src/insert.rs +++ b/datafusion/physical-plan/src/insert.rs @@ -23,8 +23,8 @@ use std::fmt::Debug; use std::sync::Arc; use super::{ - DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, Partitioning, - PlanProperties, SendableRecordBatchStream, + execute_input_stream, DisplayAs, DisplayFormatType, ExecutionPlan, + ExecutionPlanProperties, Partitioning, PlanProperties, SendableRecordBatchStream, }; use crate::metrics::MetricsSet; use crate::stream::RecordBatchStreamAdapter; @@ -33,7 +33,7 @@ use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; use arrow_array::{ArrayRef, UInt64Array}; use arrow_schema::{DataType, Field, Schema}; -use datafusion_common::{exec_err, internal_err, Result}; +use datafusion_common::{internal_err, Result}; use datafusion_execution::TaskContext; use datafusion_physical_expr::{ Distribution, EquivalenceProperties, PhysicalSortRequirement, @@ -120,46 +120,6 @@ impl DataSinkExec { } } - fn execute_input_stream( - &self, - partition: usize, - context: Arc, - ) -> Result { - let input_stream = self.input.execute(partition, context)?; - - debug_assert_eq!( - self.sink_schema.fields().len(), - self.input.schema().fields().len() - ); - - // Find input columns that may violate the not null constraint. 
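// A minimal, self-contained sketch of the nullability check below, using
// arrow-schema types only (the field values are hypothetical):
//
//     use arrow_schema::{DataType, Field};
//
//     let sink = [Field::new("a", DataType::Int32, false)]; // NOT NULL sink column
//     let input = [Field::new("a", DataType::Int32, true)]; // nullable input column
//     let risky: Vec<usize> = sink
//         .iter()
//         .zip(input.iter())
//         .enumerate()
//         .filter_map(|(i, (s, f))| (!s.is_nullable() && f.is_nullable()).then_some(i))
//         .collect();
//     assert_eq!(risky, vec![0]); // column 0 must be null-checked at runtime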
- let risky_columns: Vec<_> = self - .sink_schema - .fields() - .iter() - .zip(self.input.schema().fields().iter()) - .enumerate() - .filter_map(|(i, (sink_field, input_field))| { - if !sink_field.is_nullable() && input_field.is_nullable() { - Some(i) - } else { - None - } - }) - .collect(); - - if risky_columns.is_empty() { - Ok(input_stream) - } else { - // Check not null constraint on the input stream - Ok(Box::pin(RecordBatchStreamAdapter::new( - Arc::clone(&self.sink_schema), - input_stream - .map(move |batch| check_not_null_contraits(batch?, &risky_columns)), - ))) - } - } - /// Input execution plan pub fn input(&self) -> &Arc { &self.input @@ -269,7 +229,12 @@ impl ExecutionPlan for DataSinkExec { if partition != 0 { return internal_err!("DataSinkExec can only be called on partition 0!"); } - let data = self.execute_input_stream(0, Arc::clone(&context))?; + let data = execute_input_stream( + Arc::clone(&self.input), + Arc::clone(&self.sink_schema), + 0, + Arc::clone(&context), + )?; let count_schema = Arc::clone(&self.count_schema); let sink = Arc::clone(&self.sink); @@ -314,27 +279,3 @@ fn make_count_schema() -> SchemaRef { false, )])) } - -fn check_not_null_contraits( - batch: RecordBatch, - column_indices: &Vec, -) -> Result { - for &index in column_indices { - if batch.num_columns() <= index { - return exec_err!( - "Invalid batch column count {} expected > {}", - batch.num_columns(), - index - ); - } - - if batch.column(index).null_count() > 0 { - return exec_err!( - "Invalid batch column at '{}' has null but schema specifies non-nullable", - index - ); - } - } - - Ok(batch) -} diff --git a/datafusion/physical-plan/src/lib.rs b/datafusion/physical-plan/src/lib.rs index f3a709ff76703..dc736993a4533 100644 --- a/datafusion/physical-plan/src/lib.rs +++ b/datafusion/physical-plan/src/lib.rs @@ -36,13 +36,13 @@ use arrow::datatypes::SchemaRef; use arrow::ipc::reader::FileReader; use arrow::record_batch::RecordBatch; use datafusion_common::config::ConfigOptions; -use datafusion_common::{exec_datafusion_err, Result}; +use datafusion_common::{exec_datafusion_err, exec_err, Result}; use datafusion_execution::TaskContext; use datafusion_physical_expr::{ EquivalenceProperties, LexOrdering, PhysicalSortExpr, PhysicalSortRequirement, }; -use futures::stream::TryStreamExt; +use futures::stream::{StreamExt, TryStreamExt}; use log::debug; use tokio::sync::mpsc::Sender; use tokio::task::JoinSet; @@ -97,7 +97,7 @@ pub use datafusion_physical_expr::{ // Backwards compatibility use crate::common::IPCWriter; pub use crate::stream::EmptyRecordBatchStream; -use crate::stream::RecordBatchReceiverStream; +use crate::stream::{RecordBatchReceiverStream, RecordBatchStreamAdapter}; use datafusion_execution::disk_manager::RefCountedTempFile; use datafusion_execution::memory_pool::human_readable_size; pub use datafusion_execution::{RecordBatchStream, SendableRecordBatchStream}; @@ -805,6 +805,97 @@ pub fn execute_stream_partitioned( Ok(streams) } +/// Executes an input stream and ensures that the resulting stream adheres to +/// the `not null` constraints specified in the `sink_schema`. +/// +/// # Arguments +/// +/// * `input` - An execution plan +/// * `sink_schema` - The schema to be applied to the output stream +/// * `partition` - The partition index to be executed +/// * `context` - The task context +/// +/// # Returns +/// +/// * `Result` - A stream of `RecordBatch`es if successful +/// +/// This function first executes the given input plan for the specified partition +/// and context. 
It then checks if there are any columns in the input that might +/// violate the `not null` constraints specified in the `sink_schema`. If there are +/// such columns, it wraps the resulting stream to enforce the `not null` constraints +/// by invoking the `check_not_null_contraits` function on each batch of the stream. +pub fn execute_input_stream( + input: Arc, + sink_schema: SchemaRef, + partition: usize, + context: Arc, +) -> Result { + let input_stream = input.execute(partition, context)?; + + debug_assert_eq!(sink_schema.fields().len(), input.schema().fields().len()); + + // Find input columns that may violate the not null constraint. + let risky_columns: Vec<_> = sink_schema + .fields() + .iter() + .zip(input.schema().fields().iter()) + .enumerate() + .filter_map(|(idx, (sink_field, input_field))| { + (!sink_field.is_nullable() && input_field.is_nullable()).then_some(idx) + }) + .collect(); + + if risky_columns.is_empty() { + Ok(input_stream) + } else { + // Check not null constraint on the input stream + Ok(Box::pin(RecordBatchStreamAdapter::new( + sink_schema, + input_stream + .map(move |batch| check_not_null_contraits(batch?, &risky_columns)), + ))) + } +} + +/// Checks a `RecordBatch` for `not null` constraints on specified columns. +/// +/// # Arguments +/// +/// * `batch` - The `RecordBatch` to be checked +/// * `column_indices` - A vector of column indices that should be checked for +/// `not null` constraints. +/// +/// # Returns +/// +/// * `Result` - The original `RecordBatch` if all constraints are met +/// +/// This function iterates over the specified column indices and ensures that none +/// of the columns contain null values. If any column contains null values, an error +/// is returned. +pub fn check_not_null_contraits( + batch: RecordBatch, + column_indices: &Vec, +) -> Result { + for &index in column_indices { + if batch.num_columns() <= index { + return exec_err!( + "Invalid batch column count {} expected > {}", + batch.num_columns(), + index + ); + } + + if batch.column(index).null_count() > 0 { + return exec_err!( + "Invalid batch column at '{}' has null but schema specifies non-nullable", + index + ); + } + } + + Ok(batch) +} + /// Utility function yielding a string representation of the given [`ExecutionPlan`]. 
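// A usage sketch for the two helpers made public above (the `input`,
// `sink_schema`, and `context` bindings are hypothetical; the signature is
// exactly as declared):
//
//     let stream = execute_input_stream(
//         Arc::clone(&input),       // Arc<dyn ExecutionPlan>
//         Arc::clone(&sink_schema), // SchemaRef carrying NOT NULL fields
//         0,                        // partition to execute
//         Arc::clone(&context),     // Arc<TaskContext>
//     )?;
//     // Batches yielded by `stream` have already been screened by
//     // `check_not_null_contraits` for every non-nullable sink column.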
pub fn get_plan_string(plan: &Arc) -> Vec { let formatted = displayable(plan.as_ref()).indent(true).to_string(); From c54a638585715410fefbe07fd23552e3871bd4f0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E6=9E=97=E4=BC=9F?= Date: Wed, 17 Jul 2024 04:35:12 +0800 Subject: [PATCH 52/59] Enable clone_on_ref_ptr clippy lints on proto (#11465) --- datafusion/proto-common/src/from_proto/mod.rs | 2 +- datafusion/proto-common/src/lib.rs | 2 ++ datafusion/proto/src/lib.rs | 2 ++ .../proto/src/physical_plan/from_proto.rs | 2 +- datafusion/proto/src/physical_plan/mod.rs | 18 +++++++++--------- datafusion/proto/src/physical_plan/to_proto.rs | 12 ++++++------ 6 files changed, 21 insertions(+), 17 deletions(-) diff --git a/datafusion/proto-common/src/from_proto/mod.rs b/datafusion/proto-common/src/from_proto/mod.rs index df673de4e1191..52ca5781dc963 100644 --- a/datafusion/proto-common/src/from_proto/mod.rs +++ b/datafusion/proto-common/src/from_proto/mod.rs @@ -448,7 +448,7 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { None, &message.version(), )?; - Ok(record_batch.column(0).clone()) + Ok(Arc::clone(record_batch.column(0))) } _ => Err(Error::General("dictionary id not found in schema while deserializing ScalarValue::List".to_string())), }?; diff --git a/datafusion/proto-common/src/lib.rs b/datafusion/proto-common/src/lib.rs index 474db652df992..91e3939154424 100644 --- a/datafusion/proto-common/src/lib.rs +++ b/datafusion/proto-common/src/lib.rs @@ -14,6 +14,8 @@ // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. +// Make cheap clones clear: https://github.com/apache/datafusion/issues/11143 +#![deny(clippy::clone_on_ref_ptr)] //! Serialize / Deserialize DataFusion Primitive Types to bytes //! diff --git a/datafusion/proto/src/lib.rs b/datafusion/proto/src/lib.rs index 57a1236ba8f4f..bac31850c875b 100644 --- a/datafusion/proto/src/lib.rs +++ b/datafusion/proto/src/lib.rs @@ -14,6 +14,8 @@ // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. +// Make cheap clones clear: https://github.com/apache/datafusion/issues/11143 +#![deny(clippy::clone_on_ref_ptr)] //! Serialize / Deserialize DataFusion Plans to bytes //! 
diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index 52fbd5cbdcf64..b7311c694d4c9 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -365,7 +365,7 @@ pub fn parse_physical_expr( Some(buf) => codec.try_decode_udf(&e.name, buf)?, None => registry.udf(e.name.as_str())?, }; - let scalar_fun_def = udf.clone(); + let scalar_fun_def = Arc::clone(&udf); let args = parse_physical_exprs(&e.args, registry, input_schema, codec)?; diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index e5429945e97ef..948a39bfe0be7 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -1101,7 +1101,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { where Self: Sized, { - let plan_clone = plan.clone(); + let plan_clone = Arc::clone(&plan); let plan = plan.as_any(); if let Some(exec) = plan.downcast_ref::() { @@ -1128,7 +1128,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { let expr = exec .expr() .iter() - .map(|expr| serialize_physical_expr(expr.0.clone(), extension_codec)) + .map(|expr| serialize_physical_expr(Arc::clone(&expr.0), extension_codec)) .collect::>>()?; let expr_name = exec.expr().iter().map(|expr| expr.1.clone()).collect(); return Ok(protobuf::PhysicalPlanNode { @@ -1169,7 +1169,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { protobuf::FilterExecNode { input: Some(Box::new(input)), expr: Some(serialize_physical_expr( - exec.predicate().clone(), + Arc::clone(exec.predicate()), extension_codec, )?), default_filter_selectivity: exec.default_selectivity() as u32, @@ -1585,7 +1585,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { if let Some(exec) = plan.downcast_ref::() { let predicate = exec .predicate() - .map(|pred| serialize_physical_expr(pred.clone(), extension_codec)) + .map(|pred| serialize_physical_expr(Arc::clone(pred), extension_codec)) .transpose()?; return Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::ParquetScan( @@ -1810,13 +1810,13 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { let window_expr = exec .window_expr() .iter() - .map(|e| serialize_physical_window_expr(e.clone(), extension_codec)) + .map(|e| serialize_physical_window_expr(Arc::clone(e), extension_codec)) .collect::>>()?; let partition_keys = exec .partition_keys .iter() - .map(|e| serialize_physical_expr(e.clone(), extension_codec)) + .map(|e| serialize_physical_expr(Arc::clone(e), extension_codec)) .collect::>>()?; return Ok(protobuf::PhysicalPlanNode { @@ -1840,13 +1840,13 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { let window_expr = exec .window_expr() .iter() - .map(|e| serialize_physical_window_expr(e.clone(), extension_codec)) + .map(|e| serialize_physical_window_expr(Arc::clone(e), extension_codec)) .collect::>>()?; let partition_keys = exec .partition_keys .iter() - .map(|e| serialize_physical_expr(e.clone(), extension_codec)) + .map(|e| serialize_physical_expr(Arc::clone(e), extension_codec)) .collect::>>()?; let input_order_mode = match &exec.input_order_mode { @@ -1949,7 +1949,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { } let mut buf: Vec = vec![]; - match extension_codec.try_encode(plan_clone.clone(), &mut buf) { + match extension_codec.try_encode(Arc::clone(&plan_clone), &mut buf) { Ok(_) => { let inputs: Vec = plan_clone .children() diff --git 
a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index 9c95acc1dcf47..d8d0291e1ca52 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -323,11 +323,11 @@ pub fn serialize_physical_expr( } else if let Some(expr) = expr.downcast_ref::() { let binary_expr = Box::new(protobuf::PhysicalBinaryExprNode { l: Some(Box::new(serialize_physical_expr( - expr.left().clone(), + Arc::clone(expr.left()), codec, )?)), r: Some(Box::new(serialize_physical_expr( - expr.right().clone(), + Arc::clone(expr.right()), codec, )?)), op: format!("{:?}", expr.op()), @@ -347,7 +347,7 @@ pub fn serialize_physical_expr( expr: expr .expr() .map(|exp| { - serialize_physical_expr(exp.clone(), codec) + serialize_physical_expr(Arc::clone(exp), codec) .map(Box::new) }) .transpose()?, @@ -364,7 +364,7 @@ pub fn serialize_physical_expr( else_expr: expr .else_expr() .map(|a| { - serialize_physical_expr(a.clone(), codec) + serialize_physical_expr(Arc::clone(a), codec) .map(Box::new) }) .transpose()?, @@ -552,8 +552,8 @@ fn serialize_when_then_expr( codec: &dyn PhysicalExtensionCodec, ) -> Result { Ok(protobuf::PhysicalWhenThen { - when_expr: Some(serialize_physical_expr(when_expr.clone(), codec)?), - then_expr: Some(serialize_physical_expr(then_expr.clone(), codec)?), + when_expr: Some(serialize_physical_expr(Arc::clone(when_expr), codec)?), + then_expr: Some(serialize_physical_expr(Arc::clone(then_expr), codec)?), }) } From 382bf4f3c7a730828684b9e4ce01369b89717e19 Mon Sep 17 00:00:00 2001 From: Mohamed Abdeen <83442793+MohamedAbdeen21@users.noreply.github.com> Date: Tue, 16 Jul 2024 23:41:20 +0300 Subject: [PATCH 53/59] upgrade sqlparser 0.47 -> 0.48 (#11453) * upgrade sqlparser 0.47 -> 0.48 * clean imports and qualified imports * update df-cli cargo lock * fix trailing commas in slt tests * update slt tests results * restore rowsort in slt tests * fix slt tests * rerun CI * reset unchanged slt files * Revert "clean imports and qualified imports" This reverts commit 7be2263793be7730615c52fec79ca3397eefb40f. * update non-windows systems stack size * update windows stack size * remove windows-only unused import * use same test main for all systems * Reapply "clean imports and qualified imports" This reverts commit 4fc036a9112528ec96926df93b1301465829bbcc. 
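A representative downstream change, in sketch form: sqlparser 0.48 adds a
second field to `GroupByExpr::Expressions` (the new GROUP BY modifiers),
which this crate now matches and ignores:

    // sqlparser 0.47
    let exprs = match select.group_by {
        GroupByExpr::Expressions(exprs) => exprs,
        _ => vec![],
    };

    // sqlparser 0.48
    let exprs = match select.group_by {
        GroupByExpr::Expressions(exprs, _modifiers) => exprs,
        _ => vec![],
    };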
--- Cargo.toml | 2 +- datafusion-cli/Cargo.lock | 4 +- datafusion/sql/src/expr/function.rs | 1 + datafusion/sql/src/parser.rs | 18 +++--- datafusion/sql/src/planner.rs | 21 +++++++ datafusion/sql/src/select.rs | 2 +- datafusion/sql/src/statement.rs | 60 ++++++++++--------- datafusion/sql/src/unparser/ast.rs | 5 +- datafusion/sql/src/unparser/expr.rs | 3 + datafusion/sql/src/unparser/plan.rs | 1 + datafusion/sql/tests/sql_integration.rs | 6 +- datafusion/sqllogictest/bin/sqllogictests.rs | 30 +++------- .../sqllogictest/test_files/aggregate.slt | 2 +- .../sqllogictest/test_files/arrow_typeof.slt | 2 +- .../sqllogictest/test_files/coalesce.slt | 2 +- datafusion/sqllogictest/test_files/copy.slt | 4 +- .../test_files/create_external_table.slt | 14 ++--- .../sqllogictest/test_files/csv_files.slt | 2 +- .../sqllogictest/test_files/encoding.slt | 2 +- datafusion/sqllogictest/test_files/expr.slt | 2 +- .../sqllogictest/test_files/group_by.slt | 2 +- datafusion/sqllogictest/test_files/joins.slt | 5 +- datafusion/sqllogictest/test_files/math.slt | 6 +- datafusion/sqllogictest/test_files/misc.slt | 2 +- .../sqllogictest/test_files/predicates.slt | 4 +- datafusion/sqllogictest/test_files/scalar.slt | 2 +- datafusion/sqllogictest/test_files/select.slt | 8 +-- .../sqllogictest/test_files/strings.slt | 2 +- datafusion/sqllogictest/test_files/struct.slt | 2 +- datafusion/sqllogictest/test_files/union.slt | 4 +- datafusion/sqllogictest/test_files/unnest.slt | 6 +- datafusion/sqllogictest/test_files/window.slt | 2 +- 32 files changed, 123 insertions(+), 105 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 6dd434abc87c9..f61ed7e58fe37 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -123,7 +123,7 @@ rand = "0.8" regex = "1.8" rstest = "0.21.0" serde_json = "1" -sqlparser = { version = "0.47", features = ["visitor"] } +sqlparser = { version = "0.48", features = ["visitor"] } tempfile = "3" thiserror = "1.0.44" tokio = { version = "1.36", features = ["macros", "rt", "sync"] } diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 7da9cc427c37d..e48c6b081e1a5 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -3438,9 +3438,9 @@ checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" [[package]] name = "sqlparser" -version = "0.47.0" +version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "295e9930cd7a97e58ca2a070541a3ca502b17f5d1fa7157376d0fabd85324f25" +checksum = "749780d15ad1ee15fd74f5f84b0665560b6abb913de744c2b69155770f9601da" dependencies = [ "log", "sqlparser_derive", diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index d9ddf57eb192c..dab328cc49080 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -109,6 +109,7 @@ impl FunctionArgs { filter, mut null_treatment, within_group, + .. 
} = function; // Handle no argument form (aka `current_time` as opposed to `current_time()`) diff --git a/datafusion/sql/src/parser.rs b/datafusion/sql/src/parser.rs index 8147092c34aba..bc13484235c39 100644 --- a/datafusion/sql/src/parser.rs +++ b/datafusion/sql/src/parser.rs @@ -1006,14 +1006,15 @@ mod tests { expect_parse_ok(sql, expected)?; // positive case: it is ok for sql stmt with `COMPRESSION TYPE GZIP` tokens - let sqls = vec![ - ("CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV LOCATION 'foo.csv' OPTIONS + let sqls = + vec![ + ("CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV LOCATION 'foo.csv' OPTIONS ('format.compression' 'GZIP')", "GZIP"), - ("CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV LOCATION 'foo.csv' OPTIONS + ("CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV LOCATION 'foo.csv' OPTIONS ('format.compression' 'BZIP2')", "BZIP2"), - ("CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV LOCATION 'foo.csv' OPTIONS + ("CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV LOCATION 'foo.csv' OPTIONS ('format.compression' 'XZ')", "XZ"), - ("CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV LOCATION 'foo.csv' OPTIONS + ("CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV LOCATION 'foo.csv' OPTIONS ('format.compression' 'ZSTD')", "ZSTD"), ]; for (sql, compression) in sqls { @@ -1123,7 +1124,10 @@ mod tests { // negative case: mixed column defs and column names in `PARTITIONED BY` clause let sql = "CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV PARTITIONED BY (p1 int, c1) LOCATION 'foo.csv'"; - expect_parse_error(sql, "sql parser error: Expected a data type name, found: )"); + expect_parse_error( + sql, + "sql parser error: Expected: a data type name, found: )", + ); // negative case: mixed column defs and column names in `PARTITIONED BY` clause let sql = @@ -1291,7 +1295,7 @@ mod tests { LOCATION 'foo.parquet' OPTIONS ('format.compression' 'zstd', 'format.delimiter' '*', - 'ROW_GROUP_SIZE' '1024', + 'ROW_GROUP_SIZE' '1024', 'TRUNCATE' 'NO', 'format.has_header' 'true')"; let expected = Statement::CreateExternalTable(CreateExternalTable { diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index a77f0003f7380..be04f51f4f2c9 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -468,6 +468,27 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { | SQLDataType::Float64 | SQLDataType::JSONB | SQLDataType::Unspecified + // Clickhouse datatypes + | SQLDataType::Int16 + | SQLDataType::Int32 + | SQLDataType::Int128 + | SQLDataType::Int256 + | SQLDataType::UInt8 + | SQLDataType::UInt16 + | SQLDataType::UInt32 + | SQLDataType::UInt64 + | SQLDataType::UInt128 + | SQLDataType::UInt256 + | SQLDataType::Float32 + | SQLDataType::Date32 + | SQLDataType::Datetime64(_, _) + | SQLDataType::FixedString(_) + | SQLDataType::Map(_, _) + | SQLDataType::Tuple(_) + | SQLDataType::Nested(_) + | SQLDataType::Union(_) + | SQLDataType::Nullable(_) + | SQLDataType::LowCardinality(_) => not_impl_err!( "Unsupported SQL type {sql_type:?}" ), diff --git a/datafusion/sql/src/select.rs b/datafusion/sql/src/select.rs index 236403e83d74e..a5891e655a052 100644 --- a/datafusion/sql/src/select.rs +++ b/datafusion/sql/src/select.rs @@ -149,7 +149,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let aggr_exprs = find_aggregate_exprs(&aggr_expr_haystack); // All of the group by expressions - let group_by_exprs = if let GroupByExpr::Expressions(exprs) = select.group_by { + let group_by_exprs = if let GroupByExpr::Expressions(exprs, _) = select.group_by { exprs .into_iter() .map(|e| { diff --git 
a/datafusion/sql/src/statement.rs b/datafusion/sql/src/statement.rs index 1acfac79acc0b..6df25086305d2 100644 --- a/datafusion/sql/src/statement.rs +++ b/datafusion/sql/src/statement.rs @@ -46,18 +46,19 @@ use datafusion_expr::{ cast, col, Analyze, CreateCatalog, CreateCatalogSchema, CreateExternalTable as PlanCreateExternalTable, CreateFunction, CreateFunctionBody, CreateMemoryTable, CreateView, DescribeTable, DmlStatement, DropCatalogSchema, - DropFunction, DropTable, DropView, EmptyRelation, Explain, ExprSchemable, Filter, - LogicalPlan, LogicalPlanBuilder, OperateFunctionArg, PlanType, Prepare, SetVariable, - Statement as PlanStatement, ToStringifiedPlan, TransactionAccessMode, + DropFunction, DropTable, DropView, EmptyRelation, Explain, Expr, ExprSchemable, + Filter, LogicalPlan, LogicalPlanBuilder, OperateFunctionArg, PlanType, Prepare, + SetVariable, Statement as PlanStatement, ToStringifiedPlan, TransactionAccessMode, TransactionConclusion, TransactionEnd, TransactionIsolationLevel, TransactionStart, Volatility, WriteOp, }; use sqlparser::ast; use sqlparser::ast::{ - Assignment, ColumnDef, CreateTableOptions, Delete, DescribeAlias, Expr as SQLExpr, - Expr, FromTable, Ident, Insert, ObjectName, ObjectType, OneOrManyWithParens, Query, - SchemaName, SetExpr, ShowCreateObject, ShowStatementFilter, Statement, - TableConstraint, TableFactor, TableWithJoins, TransactionMode, UnaryOperator, Value, + Assignment, AssignmentTarget, ColumnDef, CreateTable, CreateTableOptions, Delete, + DescribeAlias, Expr as SQLExpr, FromTable, Ident, Insert, ObjectName, ObjectType, + OneOrManyWithParens, Query, SchemaName, SetExpr, ShowCreateObject, + ShowStatementFilter, Statement, TableConstraint, TableFactor, TableWithJoins, + TransactionMode, UnaryOperator, Value, }; use sqlparser::parser::ParserError::ParserError; @@ -240,7 +241,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { value, } => self.set_variable_to_plan(local, hivevar, &variables, value), - Statement::CreateTable { + Statement::CreateTable(CreateTable { query, name, columns, @@ -250,7 +251,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { if_not_exists, or_replace, .. - } if table_properties.is_empty() && with_options.is_empty() => { + }) if table_properties.is_empty() && with_options.is_empty() => { // Merge inline constraints and existing constraints let mut all_constraints = constraints; let inline_constraints = calc_inline_constraints_from_columns(&columns); @@ -954,7 +955,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { order_exprs: Vec, schema: &DFSchemaRef, planner_context: &mut PlannerContext, - ) -> Result>> { + ) -> Result>> { // Ask user to provide a schema if schema is empty. 
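// In DDL terms (a sketch; table and file names are hypothetical), the guard
// below rejects an ordering that has no schema to resolve against:
//
//     -- rejected: no column definitions, so `c1` cannot be resolved
//     CREATE EXTERNAL TABLE t STORED AS CSV WITH ORDER (c1) LOCATION 'foo.csv';
//
//     -- accepted: `c1` resolves against the declared schema
//     CREATE EXTERNAL TABLE t(c1 INT) STORED AS CSV WITH ORDER (c1) LOCATION 'foo.csv';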
if !order_exprs.is_empty() && schema.fields().is_empty() { return plan_err!( @@ -1159,7 +1160,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { local: bool, hivevar: bool, variables: &OneOrManyWithParens, - value: Vec, + value: Vec, ) -> Result { if local { return not_impl_err!("LOCAL is not supported"); @@ -1218,7 +1219,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { fn delete_to_plan( &self, table_name: ObjectName, - predicate_expr: Option, + predicate_expr: Option, ) -> Result { // Do a table lookup to verify the table exists let table_ref = self.object_name_to_table_reference(table_name.clone())?; @@ -1264,7 +1265,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { table: TableWithJoins, assignments: Vec, from: Option, - predicate_expr: Option, + predicate_expr: Option, ) -> Result { let (table_name, table_alias) = match &table.relation { TableFactor::Table { name, alias, .. } => (name.clone(), alias.clone()), @@ -1284,8 +1285,12 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let mut assign_map = assignments .iter() .map(|assign| { - let col_name: &Ident = assign - .id + let cols = match &assign.target { + AssignmentTarget::ColumnName(cols) => cols, + _ => plan_err!("Tuples are not supported")?, + }; + let col_name: &Ident = cols + .0 .iter() .last() .ok_or_else(|| plan_datafusion_err!("Empty column id"))?; @@ -1293,7 +1298,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { table_schema.field_with_unqualified_name(&col_name.value)?; Ok((col_name.value.clone(), assign.value.clone())) }) - .collect::>>()?; + .collect::>>()?; // Build scan, join with from table if it exists. let mut input_tables = vec![table]; @@ -1332,8 +1337,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { &mut planner_context, )?; // Update placeholder's datatype to the type of the target column - if let datafusion_expr::Expr::Placeholder(placeholder) = &mut expr - { + if let Expr::Placeholder(placeholder) = &mut expr { placeholder.data_type = placeholder .data_type .take() @@ -1345,14 +1349,12 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { None => { // If the target table has an alias, use it to qualify the column name if let Some(alias) = &table_alias { - datafusion_expr::Expr::Column(Column::new( + Expr::Column(Column::new( Some(self.normalizer.normalize(alias.name.clone())), field.name(), )) } else { - datafusion_expr::Expr::Column(Column::from(( - qualifier, field, - ))) + Expr::Column(Column::from((qualifier, field))) } } }; @@ -1427,7 +1429,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { if let SetExpr::Values(ast::Values { rows, .. }) = (*source.body).clone() { for row in rows.iter() { for (idx, val) in row.iter().enumerate() { - if let ast::Expr::Value(Value::Placeholder(name)) = val { + if let SQLExpr::Value(Value::Placeholder(name)) = val { let name = name.replace('$', "").parse::().map_err(|_| { plan_datafusion_err!("Can't parse placeholder: {name}") @@ -1460,23 +1462,23 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .map(|(i, value_index)| { let target_field = table_schema.field(i); let expr = match value_index { - Some(v) => datafusion_expr::Expr::Column(Column::from( - source.schema().qualified_field(v), - )) - .cast_to(target_field.data_type(), source.schema())?, + Some(v) => { + Expr::Column(Column::from(source.schema().qualified_field(v))) + .cast_to(target_field.data_type(), source.schema())? + } // The value is not specified. Fill in the default value for the column. 
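// Behavior sketch for this arm (the table definition is hypothetical):
//
//     -- given: CREATE TABLE t(a INT, b INT DEFAULT 42)
//     INSERT INTO t (a) VALUES (1);
//     -- planned as if it were: INSERT INTO t (a, b) VALUES (1, 42),
//     -- with NULL used instead when a column declares no default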
None => table_source .get_column_default(target_field.name()) .cloned() .unwrap_or_else(|| { // If there is no default for the column, then the default is NULL - datafusion_expr::Expr::Literal(ScalarValue::Null) + Expr::Literal(ScalarValue::Null) }) .cast_to(target_field.data_type(), &DFSchema::empty())?, }; Ok(expr.alias(target_field.name())) }) - .collect::>>()?; + .collect::>>()?; let source = project(source, exprs)?; let op = if overwrite { diff --git a/datafusion/sql/src/unparser/ast.rs b/datafusion/sql/src/unparser/ast.rs index 7cbe34825c503..06b4d4a710a31 100644 --- a/datafusion/sql/src/unparser/ast.rs +++ b/datafusion/sql/src/unparser/ast.rs @@ -93,6 +93,8 @@ impl QueryBuilder { fetch: self.fetch.clone(), locks: self.locks.clone(), for_clause: self.for_clause.clone(), + settings: None, + format_clause: None, }) } fn create_empty() -> Self { @@ -234,6 +236,7 @@ impl SelectBuilder { value_table_mode: self.value_table_mode, connect_by: None, window_before_qualify: false, + prewhere: None, }) } fn create_empty() -> Self { @@ -245,7 +248,7 @@ impl SelectBuilder { from: Default::default(), lateral_views: Default::default(), selection: Default::default(), - group_by: Some(ast::GroupByExpr::Expressions(Vec::new())), + group_by: Some(ast::GroupByExpr::Expressions(Vec::new(), Vec::new())), cluster_by: Default::default(), distribute_by: Default::default(), sort_by: Default::default(), diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 6b7775ee3d4db..e6b67b5d9fb2d 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -176,6 +176,7 @@ impl Unparser<'_> { null_treatment: None, over: None, within_group: vec![], + parameters: ast::FunctionArguments::None, })) } Expr::Between(Between { @@ -306,6 +307,7 @@ impl Unparser<'_> { null_treatment: None, over, within_group: vec![], + parameters: ast::FunctionArguments::None, })) } Expr::SimilarTo(Like { @@ -351,6 +353,7 @@ impl Unparser<'_> { null_treatment: None, over: None, within_group: vec![], + parameters: ast::FunctionArguments::None, })) } Expr::ScalarSubquery(subq) => { diff --git a/datafusion/sql/src/unparser/plan.rs b/datafusion/sql/src/unparser/plan.rs index 41a8c968841b3..7a653f80be08b 100644 --- a/datafusion/sql/src/unparser/plan.rs +++ b/datafusion/sql/src/unparser/plan.rs @@ -172,6 +172,7 @@ impl Unparser<'_> { .iter() .map(|expr| self.expr_to_sql(expr)) .collect::>>()?, + vec![], )); } Some(AggVariant::Window(window)) => { diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index aca0d040bb8da..e34e7e20a0f32 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -3627,7 +3627,7 @@ fn test_prepare_statement_to_plan_panic_prepare_wrong_syntax() { let sql = "PREPARE AS SELECT id, age FROM person WHERE age = $foo"; assert_eq!( logical_plan(sql).unwrap_err().strip_backtrace(), - "SQL error: ParserError(\"Expected AS, found: SELECT\")" + "SQL error: ParserError(\"Expected: AS, found: SELECT\")" ) } @@ -3668,7 +3668,7 @@ fn test_non_prepare_statement_should_infer_types() { #[test] #[should_panic( - expected = "value: SQL(ParserError(\"Expected [NOT] NULL or TRUE|FALSE or [NOT] DISTINCT FROM after IS, found: $1\"" + expected = "value: SQL(ParserError(\"Expected: [NOT] NULL or TRUE|FALSE or [NOT] DISTINCT FROM after IS, found: $1\"" )] fn test_prepare_statement_to_plan_panic_is_param() { let sql = "PREPARE my_plan(INT) AS SELECT id, age FROM person WHERE age is $1"; @@ -4347,7 
+4347,7 @@ fn test_parse_escaped_string_literal_value() { let sql = r"SELECT character_length(E'\000') AS len"; assert_eq!( logical_plan(sql).unwrap_err().strip_backtrace(), - "SQL error: TokenizerError(\"Unterminated encoded string literal at Line: 1, Column 25\")" + "SQL error: TokenizerError(\"Unterminated encoded string literal at Line: 1, Column: 25\")" ) } diff --git a/datafusion/sqllogictest/bin/sqllogictests.rs b/datafusion/sqllogictest/bin/sqllogictests.rs index 560328ee8619a..8c8ed2e587439 100644 --- a/datafusion/sqllogictest/bin/sqllogictests.rs +++ b/datafusion/sqllogictest/bin/sqllogictests.rs @@ -18,8 +18,6 @@ use std::ffi::OsStr; use std::fs; use std::path::{Path, PathBuf}; -#[cfg(target_family = "windows")] -use std::thread; use clap::Parser; use datafusion_sqllogictest::{DataFusion, TestContext}; @@ -32,29 +30,15 @@ use datafusion_common_runtime::SpawnedTask; const TEST_DIRECTORY: &str = "test_files/"; const PG_COMPAT_FILE_PREFIX: &str = "pg_compat_"; +const STACK_SIZE: usize = 2 * 1024 * 1024 + 512 * 1024; // 2.5 MBs, the default 2 MBs is currently too small -#[cfg(target_family = "windows")] -pub fn main() { - // Tests from `tpch/tpch.slt` fail with stackoverflow with the default stack size. - thread::Builder::new() - .stack_size(2 * 1024 * 1024) // 2 MB - .spawn(move || { - tokio::runtime::Builder::new_multi_thread() - .enable_all() - .build() - .unwrap() - .block_on(async { run_tests().await }) - .unwrap() - }) +pub fn main() -> Result<()> { + tokio::runtime::Builder::new_multi_thread() + .thread_stack_size(STACK_SIZE) + .enable_all() + .build() .unwrap() - .join() - .unwrap(); -} - -#[tokio::main] -#[cfg(not(target_family = "windows"))] -pub async fn main() -> Result<()> { - run_tests().await + .block_on(run_tests()) } /// Sets up an empty directory at test_files/scratch/ diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index 6fafc0a74110c..a0140b1c5292a 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -3643,7 +3643,7 @@ create table bool_aggregate_functions ( c5 boolean, c6 boolean, c7 boolean, - c8 boolean, + c8 boolean ) as values (true, true, false, false, true, true, null, null), diff --git a/datafusion/sqllogictest/test_files/arrow_typeof.slt b/datafusion/sqllogictest/test_files/arrow_typeof.slt index ab4ff9e2ce926..448706744305a 100644 --- a/datafusion/sqllogictest/test_files/arrow_typeof.slt +++ b/datafusion/sqllogictest/test_files/arrow_typeof.slt @@ -430,5 +430,5 @@ select arrow_cast('MyAwesomeString', 'Utf8View'), arrow_typeof(arrow_cast('MyAwe MyAwesomeString Utf8View # Fails until we update arrow-rs with support for https://github.com/apache/arrow-rs/pull/5894 -query error DataFusion error: SQL error: ParserError\("Expected an SQL statement, found: arrow_cast"\) +query error DataFusion error: SQL error: ParserError\("Expected: an SQL statement, found: arrow_cast"\) arrow_cast('MyAwesomeString', 'BinaryView'), arrow_typeof(arrow_cast('MyAwesomeString', 'BinaryView')) diff --git a/datafusion/sqllogictest/test_files/coalesce.slt b/datafusion/sqllogictest/test_files/coalesce.slt index 17b0e774d9cb7..d16b79734c62c 100644 --- a/datafusion/sqllogictest/test_files/coalesce.slt +++ b/datafusion/sqllogictest/test_files/coalesce.slt @@ -361,7 +361,7 @@ drop table test statement ok CREATE TABLE test( c1 BIGINT, - c2 BIGINT, + c2 BIGINT ) as VALUES (1, 2), (NULL, 2), diff --git a/datafusion/sqllogictest/test_files/copy.slt 
b/datafusion/sqllogictest/test_files/copy.slt index 21c34bc25cee0..6a6ab15a065d3 100644 --- a/datafusion/sqllogictest/test_files/copy.slt +++ b/datafusion/sqllogictest/test_files/copy.slt @@ -600,7 +600,7 @@ query error DataFusion error: Invalid or Unsupported Configuration: Config value COPY source_table to 'test_files/scratch/copy/table.json' STORED AS JSON OPTIONS ('format.row_group_size' 55); # Incomplete statement -query error DataFusion error: SQL error: ParserError\("Expected \), found: EOF"\) +query error DataFusion error: SQL error: ParserError\("Expected: \), found: EOF"\) COPY (select col2, sum(col1) from source_table # Copy from table with non literal @@ -609,4 +609,4 @@ COPY source_table to '/tmp/table.parquet' (row_group_size 55 + 102); # Copy using execution.keep_partition_by_columns with an invalid value query error DataFusion error: Invalid or Unsupported Configuration: provided value for 'execution.keep_partition_by_columns' was not recognized: "invalid_value" -COPY source_table to '/tmp/table.parquet' OPTIONS (execution.keep_partition_by_columns invalid_value); \ No newline at end of file +COPY source_table to '/tmp/table.parquet' OPTIONS (execution.keep_partition_by_columns invalid_value); diff --git a/datafusion/sqllogictest/test_files/create_external_table.slt b/datafusion/sqllogictest/test_files/create_external_table.slt index 607c909fd63d5..e42d14e101f17 100644 --- a/datafusion/sqllogictest/test_files/create_external_table.slt +++ b/datafusion/sqllogictest/test_files/create_external_table.slt @@ -33,23 +33,23 @@ statement error DataFusion error: SQL error: ParserError\("Missing LOCATION clau CREATE EXTERNAL TABLE t STORED AS CSV # Option value is missing -statement error DataFusion error: SQL error: ParserError\("Expected string or numeric value, found: \)"\) +statement error DataFusion error: SQL error: ParserError\("Expected: string or numeric value, found: \)"\) CREATE EXTERNAL TABLE t STORED AS x OPTIONS ('k1' 'v1', k2 v2, k3) LOCATION 'blahblah' # Missing `(` in WITH ORDER clause -statement error DataFusion error: SQL error: ParserError\("Expected \(, found: c1"\) +statement error DataFusion error: SQL error: ParserError\("Expected: \(, found: c1"\) CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV WITH ORDER c1 LOCATION 'foo.csv' # Missing `)` in WITH ORDER clause -statement error DataFusion error: SQL error: ParserError\("Expected \), found: LOCATION"\) +statement error DataFusion error: SQL error: ParserError\("Expected: \), found: LOCATION"\) CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV WITH ORDER (c1 LOCATION 'foo.csv' # Missing `ROW` in WITH HEADER clause -statement error DataFusion error: SQL error: ParserError\("Expected ROW, found: LOCATION"\) +statement error DataFusion error: SQL error: ParserError\("Expected: ROW, found: LOCATION"\) CREATE EXTERNAL TABLE t STORED AS CSV WITH HEADER LOCATION 'abc' # Missing `BY` in PARTITIONED clause -statement error DataFusion error: SQL error: ParserError\("Expected BY, found: LOCATION"\) +statement error DataFusion error: SQL error: ParserError\("Expected: BY, found: LOCATION"\) CREATE EXTERNAL TABLE t STORED AS CSV PARTITIONED LOCATION 'abc' # Duplicate `STORED AS` clause @@ -69,11 +69,11 @@ statement error DataFusion error: SQL error: ParserError\("OPTIONS specified mor CREATE EXTERNAL TABLE t STORED AS CSV OPTIONS ('k1' 'v1', 'k2' 'v2') OPTIONS ('k3' 'v3') LOCATION 'foo.csv' # With typo error -statement error DataFusion error: SQL error: ParserError\("Expected HEADER, found: HEAD"\) +statement error DataFusion 
error: SQL error: ParserError\("Expected: HEADER, found: HEAD"\) CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV WITH HEAD ROW LOCATION 'foo.csv'; # Missing `anything` in WITH clause -statement error DataFusion error: SQL error: ParserError\("Expected HEADER, found: LOCATION"\) +statement error DataFusion error: SQL error: ParserError\("Expected: HEADER, found: LOCATION"\) CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV WITH LOCATION 'foo.csv'; # Unrecognized random clause diff --git a/datafusion/sqllogictest/test_files/csv_files.slt b/datafusion/sqllogictest/test_files/csv_files.slt index a8a689cbb8b5e..ca3bebe79f279 100644 --- a/datafusion/sqllogictest/test_files/csv_files.slt +++ b/datafusion/sqllogictest/test_files/csv_files.slt @@ -167,7 +167,7 @@ physical_plan statement ok CREATE TABLE table_with_necessary_quoting ( int_col INT, - string_col TEXT, + string_col TEXT ) AS VALUES (1, 'e|e|e'), (2, 'f|f|f'), diff --git a/datafusion/sqllogictest/test_files/encoding.slt b/datafusion/sqllogictest/test_files/encoding.slt index 626af88aa9b8c..7a6ac5ca7121a 100644 --- a/datafusion/sqllogictest/test_files/encoding.slt +++ b/datafusion/sqllogictest/test_files/encoding.slt @@ -20,7 +20,7 @@ CREATE TABLE test( num INT, bin_field BYTEA, base64_field TEXT, - hex_field TEXT, + hex_field TEXT ) as VALUES (0, 'abc', encode('abc', 'base64'), encode('abc', 'hex')), (1, 'qweqwe', encode('qweqwe', 'base64'), encode('qweqwe', 'hex')), diff --git a/datafusion/sqllogictest/test_files/expr.slt b/datafusion/sqllogictest/test_files/expr.slt index 4e8f3b59a650a..b08d329d4a863 100644 --- a/datafusion/sqllogictest/test_files/expr.slt +++ b/datafusion/sqllogictest/test_files/expr.slt @@ -2356,7 +2356,7 @@ CREATE TABLE t_source( column1 String, column2 String, column3 String, - column4 String, + column4 String ) AS VALUES ('one', 'one', 'one', 'one'), ('two', 'two', '', 'two'), diff --git a/datafusion/sqllogictest/test_files/group_by.slt b/datafusion/sqllogictest/test_files/group_by.slt index 04a1fcc78fe7a..b2be65a609e37 100644 --- a/datafusion/sqllogictest/test_files/group_by.slt +++ b/datafusion/sqllogictest/test_files/group_by.slt @@ -4489,7 +4489,7 @@ LIMIT 5 statement ok CREATE TABLE src_table ( t1 TIMESTAMP, - c2 INT, + c2 INT ) AS VALUES ('2020-12-10T00:00:00.00Z', 0), ('2020-12-11T00:00:00.00Z', 1), diff --git a/datafusion/sqllogictest/test_files/joins.slt b/datafusion/sqllogictest/test_files/joins.slt index df66bffab8e82..b9897f81a107a 100644 --- a/datafusion/sqllogictest/test_files/joins.slt +++ b/datafusion/sqllogictest/test_files/joins.slt @@ -3844,7 +3844,7 @@ EXPLAIN SELECT * FROM ( ---- logical_plan EmptyRelation -# Left ANTI join with empty right table +# Left ANTI join with empty right table query TT EXPLAIN SELECT * FROM ( SELECT 1 as a @@ -3855,7 +3855,7 @@ logical_plan 02)--Projection: Int64(1) AS a 03)----EmptyRelation -# Right ANTI join with empty left table +# Right ANTI join with empty left table query TT EXPLAIN SELECT * FROM ( SELECT 1 as a WHERE 1=0 @@ -4043,4 +4043,3 @@ physical_plan 03)----MemoryExec: partitions=1, partition_sizes=[1] 04)----SortExec: expr=[b@1 ASC NULLS LAST], preserve_partitioning=[false] 05)------MemoryExec: partitions=1, partition_sizes=[1] - diff --git a/datafusion/sqllogictest/test_files/math.slt b/datafusion/sqllogictest/test_files/math.slt index 573441ab44013..6ff804c3065d9 100644 --- a/datafusion/sqllogictest/test_files/math.slt +++ b/datafusion/sqllogictest/test_files/math.slt @@ -280,7 +280,7 @@ CREATE TABLE test_non_nullable_integer( c5 TINYINT UNSIGNED NOT NULL, 
c6 SMALLINT UNSIGNED NOT NULL, c7 INT UNSIGNED NOT NULL, - c8 BIGINT UNSIGNED NOT NULL, + c8 BIGINT UNSIGNED NOT NULL ); query I @@ -348,7 +348,7 @@ drop table test_non_nullable_integer statement ok CREATE TABLE test_nullable_float( c1 float, - c2 double, + c2 double ) AS VALUES (-1.0, -1.0), (1.0, 1.0), @@ -415,7 +415,7 @@ drop table test_nullable_float statement ok CREATE TABLE test_non_nullable_float( c1 float NOT NULL, - c2 double NOT NULL, + c2 double NOT NULL ); query I diff --git a/datafusion/sqllogictest/test_files/misc.slt b/datafusion/sqllogictest/test_files/misc.slt index 66606df834808..9f4710eb9bcc0 100644 --- a/datafusion/sqllogictest/test_files/misc.slt +++ b/datafusion/sqllogictest/test_files/misc.slt @@ -37,4 +37,4 @@ select 1 where NULL and 1 = 1 query I select 1 where NULL or 1 = 1 ---- -1 \ No newline at end of file +1 diff --git a/datafusion/sqllogictest/test_files/predicates.slt b/datafusion/sqllogictest/test_files/predicates.slt index ffaae7204ecaf..4695e37aa560f 100644 --- a/datafusion/sqllogictest/test_files/predicates.slt +++ b/datafusion/sqllogictest/test_files/predicates.slt @@ -584,7 +584,7 @@ DROP TABLE data_index_bloom_encoding_stats; # String coercion ######## -statement error DataFusion error: SQL error: ParserError\("Expected a data type name, found: ,"\) +statement error DataFusion error: SQL error: ParserError\("Expected: a data type name, found: ,"\) CREATE TABLE t(vendor_id_utf8, vendor_id_dict) AS VALUES (arrow_cast('124', 'Utf8'), arrow_cast('124', 'Dictionary(Int16, Utf8)')), @@ -692,7 +692,7 @@ CREATE TABLE IF NOT EXISTS partsupp ( ps_suppkey BIGINT, ps_availqty INTEGER, ps_supplycost DECIMAL(15, 2), - ps_comment VARCHAR, + ps_comment VARCHAR ) AS VALUES (63700, 7311, 100, 993.49, 'ven ideas. quickly even packages print. 
pending multipliers must have to are fluff'); diff --git a/datafusion/sqllogictest/test_files/scalar.slt b/datafusion/sqllogictest/test_files/scalar.slt index 85ac5b0c242db..5daa9333fb36f 100644 --- a/datafusion/sqllogictest/test_files/scalar.slt +++ b/datafusion/sqllogictest/test_files/scalar.slt @@ -1578,7 +1578,7 @@ false statement ok CREATE TABLE t1( a boolean, - b boolean, + b boolean ) as VALUES (true, true), (true, null), diff --git a/datafusion/sqllogictest/test_files/select.slt b/datafusion/sqllogictest/test_files/select.slt index 95f67245a981e..03426dec874f3 100644 --- a/datafusion/sqllogictest/test_files/select.slt +++ b/datafusion/sqllogictest/test_files/select.slt @@ -101,7 +101,7 @@ statement ok CREATE TABLE test ( c1 BIGINT NOT NULL, c2 BIGINT NOT NULL, - c3 BOOLEAN NOT NULL, + c3 BOOLEAN NOT NULL ) AS VALUES (0, 1, false), (0, 10, true), (0, 2, true), @@ -336,13 +336,13 @@ three 1 NULL 1 # select_values_list -statement error DataFusion error: SQL error: ParserError\("Expected \(, found: EOF"\) +statement error DataFusion error: SQL error: ParserError\("Expected: \(, found: EOF"\) VALUES -statement error DataFusion error: SQL error: ParserError\("Expected an expression:, found: \)"\) +statement error DataFusion error: SQL error: ParserError\("Expected: an expression:, found: \)"\) VALUES () -statement error DataFusion error: SQL error: ParserError\("Expected an expression:, found: \)"\) +statement error DataFusion error: SQL error: ParserError\("Expected: an expression:, found: \)"\) VALUES (1),() statement error DataFusion error: Error during planning: Inconsistent data length across values list: got 2 values in row 1 but expected 1 diff --git a/datafusion/sqllogictest/test_files/strings.slt b/datafusion/sqllogictest/test_files/strings.slt index 3cd6c339b44fb..30fb2d750d95e 100644 --- a/datafusion/sqllogictest/test_files/strings.slt +++ b/datafusion/sqllogictest/test_files/strings.slt @@ -17,7 +17,7 @@ statement ok CREATE TABLE test( - s TEXT, + s TEXT ) as VALUES ('p1'), ('p1e1'), diff --git a/datafusion/sqllogictest/test_files/struct.slt b/datafusion/sqllogictest/test_files/struct.slt index fd6e25ea749df..a7384fd4d8ad6 100644 --- a/datafusion/sqllogictest/test_files/struct.slt +++ b/datafusion/sqllogictest/test_files/struct.slt @@ -24,7 +24,7 @@ CREATE TABLE values( a INT, b FLOAT, c VARCHAR, - n VARCHAR, + n VARCHAR ) AS VALUES (1, 1.1, 'a', NULL), (2, 2.2, 'b', NULL), diff --git a/datafusion/sqllogictest/test_files/union.slt b/datafusion/sqllogictest/test_files/union.slt index 5ede68a42aae6..31b16f975e9ea 100644 --- a/datafusion/sqllogictest/test_files/union.slt +++ b/datafusion/sqllogictest/test_files/union.slt @@ -22,7 +22,7 @@ statement ok CREATE TABLE t1( id INT, - name TEXT, + name TEXT ) as VALUES (1, 'Alex'), (2, 'Bob'), @@ -32,7 +32,7 @@ CREATE TABLE t1( statement ok CREATE TABLE t2( id TINYINT, - name TEXT, + name TEXT ) as VALUES (1, 'Alex'), (2, 'Bob'), diff --git a/datafusion/sqllogictest/test_files/unnest.slt b/datafusion/sqllogictest/test_files/unnest.slt index 06733f7b1e40e..698faf87c9b20 100644 --- a/datafusion/sqllogictest/test_files/unnest.slt +++ b/datafusion/sqllogictest/test_files/unnest.slt @@ -267,7 +267,7 @@ query error DataFusion error: Error during planning: unnest\(\) requires exactly select unnest(); ## Unnest empty expression in from clause -query error DataFusion error: SQL error: ParserError\("Expected an expression:, found: \)"\) +query error DataFusion error: SQL error: ParserError\("Expected: an expression:, found: \)"\) select * from 
unnest();

@@ -496,7 +496,7 @@ select unnest(column1) from (select * from (values([1,2,3]), ([4,5,6])) limit 1
 5
 6

-## FIXME: https://github.com/apache/datafusion/issues/11198
+## FIXME: https://github.com/apache/datafusion/issues/11198
 query error DataFusion error: Error during planning: Projections require unique expression names but the expression "UNNEST\(Column\(Column \{ relation: Some\(Bare \{ table: "unnest_table" \}\), name: "column1" \}\)\)" at position 0 and "UNNEST\(Column\(Column \{ relation: Some\(Bare \{ table: "unnest_table" \}\), name: "column1" \}\)\)" at position 1 have the same name. Consider aliasing \("AS"\) one of them.
 select unnest(column1), unnest(column1) from unnest_table;

@@ -556,4 +556,4 @@ physical_plan
 05)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
 06)----------UnnestExec
 07)------------ProjectionExec: expr=[column3@0 as unnest(recursive_unnest_table.column3), column3@0 as column3]
-08)--------------MemoryExec: partitions=1, partition_sizes=[1]
\ No newline at end of file
+08)--------------MemoryExec: partitions=1, partition_sizes=[1]
diff --git a/datafusion/sqllogictest/test_files/window.slt b/datafusion/sqllogictest/test_files/window.slt
index a865a7ccbd8fb..5296f13de08a5 100644
--- a/datafusion/sqllogictest/test_files/window.slt
+++ b/datafusion/sqllogictest/test_files/window.slt
@@ -3415,7 +3415,7 @@ SELECT
 # window1 spec is defined multiple times
 statement error DataFusion error: Error during planning: The window window1 is defined multiple times!
 SELECT
-  MAX(c12) OVER window1 as min1,
+  MAX(c12) OVER window1 as min1
 FROM aggregate_test_100
 WINDOW window1 AS (ORDER BY C12), window1 AS (ORDER BY C3)

From 81aff944bd76b674a22371f7deaa12560d2f629d Mon Sep 17 00:00:00 2001
From: Arttu
Date: Tue, 16 Jul 2024 22:54:50 +0200
Subject: [PATCH 54/59] feat: support UDWFs in Substrait (#11489)

* feat: support UDWFs in Substrait

Previously, the Substrait consumer would, for window functions, look at:
1. UDAFs
2. built-in window functions
3. built-in aggregate functions

That made it tough to override built-in window function behavior: it
could only be overridden with a UDAF, and some window functions don't
fit nicely into aggregates.

This change adds UDWFs at the top, so the consumer will look at:
1. UDWFs
2. UDAFs
3. built-in window functions
4. built-in aggregate functions

This also paves the way for moving DataFusion's built-in window
functions into UDWFs.
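
To illustrate the new precedence, a minimal sketch (not part of this
patch): it assumes `ctx` is a `SessionContext` and reuses a
`make_partition_evaluator` factory like the one in the new roundtrip
test below; the built-in name "row_number" is only an example.

    use std::sync::Arc;
    use datafusion::arrow::datatypes::DataType;
    use datafusion::logical_expr::Volatility;
    use datafusion::prelude::*;

    // Register a UDWF under the same name as a built-in window function.
    let shadowing = create_udwf(
        "row_number",              // same name as the built-in
        DataType::Int64,           // input type
        Arc::new(DataType::Int64), // return type
        Volatility::Immutable,
        Arc::new(make_partition_evaluator),
    );
    ctx.register_udwf(shadowing);
    // The consumer now resolves "row_number" via ctx.udwf() first, so
    // this UDWF wins over the built-in window function.
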
* check udwf first, then udaf

---
 .../substrait/src/logical_plan/consumer.rs    | 27 +++++++-------
 .../tests/cases/roundtrip_logical_plan.rs     | 36 ++++++++++++++++++-
 2 files changed, 50 insertions(+), 13 deletions(-)

diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs
index 991aa61fbf159..1365630d5079a 100644
--- a/datafusion/substrait/src/logical_plan/consumer.rs
+++ b/datafusion/substrait/src/logical_plan/consumer.rs
@@ -23,8 +23,8 @@ use datafusion::arrow::datatypes::{
 };
 use datafusion::common::plan_err;
 use datafusion::common::{
-    not_impl_datafusion_err, not_impl_err, plan_datafusion_err, substrait_datafusion_err,
-    substrait_err, DFSchema, DFSchemaRef,
+    not_impl_err, plan_datafusion_err, substrait_datafusion_err, substrait_err, DFSchema,
+    DFSchemaRef,
 };
 use datafusion::execution::FunctionRegistry;
 use datafusion::logical_expr::expr::{Exists, InSubquery, Sort};
@@ -1182,16 +1182,19 @@ pub async fn from_substrait_rex(
         };
         let fn_name = substrait_fun_name(fn_name);

-        // check udaf first, then built-in functions
-        let fun = match ctx.udaf(fn_name) {
-            Ok(udaf) => Ok(WindowFunctionDefinition::AggregateUDF(udaf)),
-            Err(_) => find_df_window_func(fn_name).ok_or_else(|| {
-                not_impl_datafusion_err!(
-                    "Window function {} is not supported: function anchor = {:?}",
-                    fn_name,
-                    window.function_reference
-                )
-            }),
+        // check udwf first, then udaf, then built-in window and aggregate functions
+        let fun = if let Ok(udwf) = ctx.udwf(fn_name) {
+            Ok(WindowFunctionDefinition::WindowUDF(udwf))
+        } else if let Ok(udaf) = ctx.udaf(fn_name) {
+            Ok(WindowFunctionDefinition::AggregateUDF(udaf))
+        } else if let Some(fun) = find_df_window_func(fn_name) {
+            Ok(fun)
+        } else {
+            not_impl_err!(
+                "Window function {} is not supported: function anchor = {:?}",
+                fn_name,
+                window.function_reference
+            )
         }?;

         let order_by =
diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
index 5b2d0fbacaef0..a7653e11d598f 100644
--- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
+++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
@@ -31,7 +31,8 @@ use datafusion::error::Result;
 use datafusion::execution::registry::SerializerRegistry;
 use datafusion::execution::runtime_env::RuntimeEnv;
 use datafusion::logical_expr::{
-    Extension, LogicalPlan, Repartition, UserDefinedLogicalNode, Volatility,
+    Extension, LogicalPlan, PartitionEvaluator, Repartition, UserDefinedLogicalNode,
+    Volatility,
 };
 use datafusion::optimizer::simplify_expressions::expr_simplifier::THRESHOLD_INLINE_INLIST;
 use datafusion::prelude::*;
@@ -860,6 +861,39 @@ async fn roundtrip_aggregate_udf() -> Result<()> {
     roundtrip_with_ctx("select dummy_agg(a) from data", ctx).await
 }

+#[tokio::test]
+async fn roundtrip_window_udf() -> Result<()> {
+    #[derive(Debug)]
+    struct Dummy {}
+
+    impl PartitionEvaluator for Dummy {
+        fn evaluate_all(
+            &mut self,
+            values: &[ArrayRef],
+            _num_rows: usize,
+        ) -> Result<ArrayRef> {
+            Ok(values[0].to_owned())
+        }
+    }
+
+    fn make_partition_evaluator() -> Result<Box<dyn PartitionEvaluator>> {
+        Ok(Box::new(Dummy {}))
+    }
+
+    let dummy_agg = create_udwf(
+        "dummy_window",            // name
+        DataType::Int64,           // input type
+        Arc::new(DataType::Int64), // return type
+        Volatility::Immutable,
+        Arc::new(make_partition_evaluator),
+    );
+
+    let ctx = create_context().await?;
+    ctx.register_udwf(dummy_agg);
+
+    roundtrip_with_ctx("select dummy_window(a) OVER () from data", ctx).await
+}
+
 #[tokio::test]
 async fn roundtrip_repartition_roundrobin() -> Result<()> {
     let ctx = create_context().await?;

From 02326998f07a13fda0c93988bf13853413c4a2b2 Mon Sep 17 00:00:00 2001
From: Georgi Krastev
Date: Wed, 17 Jul 2024 00:52:20 +0300
Subject: [PATCH 55/59] Add extension hooks for encoding and decoding UDAFs
 and UDWFs (#11417)

* Add extension hooks for encoding and decoding UDAFs and UDWFs

* Add tests for encoding and decoding UDAF

---
 .../examples/composed_extension_codec.rs      |  80 +++---
 .../physical-expr-common/src/aggregate/mod.rs |   5 +
 datafusion/proto/proto/datafusion.proto       |  35 +--
 datafusion/proto/src/generated/pbjson.rs      | 102 +++++++
 datafusion/proto/src/generated/prost.rs       |  10 +
 .../proto/src/logical_plan/file_formats.rs    |  80 ------
 .../proto/src/logical_plan/from_proto.rs      |  42 +--
 datafusion/proto/src/logical_plan/mod.rs      |  22 +-
 datafusion/proto/src/logical_plan/to_proto.rs |  84 +++---
 .../proto/src/physical_plan/from_proto.rs     |   6 +-
 datafusion/proto/src/physical_plan/mod.rs     |  23 +-
 .../proto/src/physical_plan/to_proto.rs       | 122 +++++----
 datafusion/proto/tests/cases/mod.rs           |  99 +++++++
 .../tests/cases/roundtrip_logical_plan.rs     | 171 +++++------
 .../tests/cases/roundtrip_physical_plan.rs    | 251 +++++++++++-------
 15 files changed, 686 insertions(+), 446 deletions(-)

diff --git a/datafusion-examples/examples/composed_extension_codec.rs b/datafusion-examples/examples/composed_extension_codec.rs
index 43c6daba211ac..5c34eccf26e11 100644
--- a/datafusion-examples/examples/composed_extension_codec.rs
+++ b/datafusion-examples/examples/composed_extension_codec.rs
@@ -30,18 +30,19 @@
 //!   DeltaScan
 //! ```

+use std::any::Any;
+use std::fmt::Debug;
+use std::ops::Deref;
+use std::sync::Arc;
+
 use datafusion::common::Result;
 use datafusion::physical_plan::{DisplayAs, ExecutionPlan};
 use datafusion::prelude::SessionContext;
-use datafusion_common::internal_err;
+use datafusion_common::{internal_err, DataFusionError};
 use datafusion_expr::registry::FunctionRegistry;
-use datafusion_expr::ScalarUDF;
+use datafusion_expr::{AggregateUDF, ScalarUDF};
 use datafusion_proto::physical_plan::{AsExecutionPlan, PhysicalExtensionCodec};
 use datafusion_proto::protobuf;
-use std::any::Any;
-use std::fmt::Debug;
-use std::ops::Deref;
-use std::sync::Arc;

 #[tokio::main]
 async fn main() {
@@ -239,6 +240,25 @@ struct ComposedPhysicalExtensionCodec {
     codecs: Vec<Arc<dyn PhysicalExtensionCodec>>,
 }

+impl ComposedPhysicalExtensionCodec {
+    fn try_any<R>(
+        &self,
+        mut f: impl FnMut(&dyn PhysicalExtensionCodec) -> Result<R>,
+    ) -> Result<R> {
+        let mut last_err = None;
+        for codec in &self.codecs {
+            match f(codec.as_ref()) {
+                Ok(node) => return Ok(node),
+                Err(err) => last_err = Some(err),
+            }
+        }
+
+        Err(last_err.unwrap_or_else(|| {
+            DataFusionError::NotImplemented("Empty list of composed codecs".to_owned())
+        }))
+    }
+}
+
 impl PhysicalExtensionCodec for ComposedPhysicalExtensionCodec {
     fn try_decode(
         &self,
@@ -246,46 +266,26 @@ impl PhysicalExtensionCodec for ComposedPhysicalExtensionCodec {
         inputs: &[Arc<dyn ExecutionPlan>],
         registry: &dyn FunctionRegistry,
     ) -> Result<Arc<dyn ExecutionPlan>> {
-        let mut last_err = None;
-        for codec in &self.codecs {
-            match codec.try_decode(buf, inputs, registry) {
-                Ok(plan) => return Ok(plan),
-                Err(e) => last_err = Some(e),
-            }
-        }
-        Err(last_err.unwrap())
+        self.try_any(|codec| codec.try_decode(buf, inputs, registry))
     }

     fn try_encode(&self, node: Arc<dyn ExecutionPlan>, buf: &mut Vec<u8>) -> Result<()> {
-        let mut last_err = None;
-        for codec in &self.codecs {
-            match codec.try_encode(node.clone(), buf) {
-                Ok(_) => return Ok(()),
-                Err(e) => last_err = Some(e),
-            }
-        }
-
-        Err(last_err.unwrap())
+        self.try_any(|codec| codec.try_encode(node.clone(), buf))
     }

-    fn try_decode_udf(&self, name: &str, _buf: &[u8]) -> Result<Arc<ScalarUDF>> {
-        let mut last_err = None;
-        for codec in &self.codecs {
-            match codec.try_decode_udf(name, _buf) {
-                Ok(plan) => return Ok(plan),
-                Err(e) => last_err = Some(e),
-            }
-        }
-        Err(last_err.unwrap())
+    fn try_decode_udf(&self, name: &str, buf: &[u8]) -> Result<Arc<ScalarUDF>> {
+        self.try_any(|codec| codec.try_decode_udf(name, buf))
     }

-    fn try_encode_udf(&self, _node: &ScalarUDF, _buf: &mut Vec<u8>) -> Result<()> {
-        let mut last_err = None;
-        for codec in &self.codecs {
-            match codec.try_encode_udf(_node, _buf) {
-                Ok(_) => return Ok(()),
-                Err(e) => last_err = Some(e),
-            }
-        }
-        Err(last_err.unwrap())
+    fn try_encode_udf(&self, node: &ScalarUDF, buf: &mut Vec<u8>) -> Result<()> {
+        self.try_any(|codec| codec.try_encode_udf(node, buf))
+    }
+
+    fn try_decode_udaf(&self, name: &str, buf: &[u8]) -> Result<Arc<AggregateUDF>> {
+        self.try_any(|codec| codec.try_decode_udaf(name, buf))
+    }
+
+    fn try_encode_udaf(&self, node: &AggregateUDF, buf: &mut Vec<u8>) -> Result<()> {
+        self.try_any(|codec| codec.try_encode_udaf(node, buf))
     }
 }
diff --git a/datafusion/physical-expr-common/src/aggregate/mod.rs b/datafusion/physical-expr-common/src/aggregate/mod.rs
index db4581a622acc..0e245fd0a66aa 100644
--- a/datafusion/physical-expr-common/src/aggregate/mod.rs
+++ b/datafusion/physical-expr-common/src/aggregate/mod.rs
@@ -283,6 +283,11 @@ impl AggregateFunctionExpr {
     pub fn is_distinct(&self) -> bool {
         self.is_distinct
     }
+
+    /// Return if the aggregation ignores nulls
+    pub fn ignore_nulls(&self) -> bool {
+        self.ignore_nulls
+    }
 }

 impl AggregateExpr for AggregateFunctionExpr {
diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto
index 9ef884531e320..dc551778c5fb2 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -164,7 +164,7 @@ message CreateExternalTableNode {
   map<string, string> options = 8;
   datafusion_common.Constraints constraints = 12;
   map<string, LogicalExprNode> column_defaults = 13;
-  }
+}

 message PrepareNode {
   string name = 1;
@@ -249,24 +249,24 @@ message DistinctOnNode {
 }

 message CopyToNode {
-    LogicalPlanNode input = 1;
-    string output_url = 2;
-    bytes file_type = 3;
-    repeated string partition_by = 7;
+  LogicalPlanNode input = 1;
+  string output_url = 2;
+  bytes file_type = 3;
+  repeated string partition_by = 7;
 }

 message UnnestNode {
-    LogicalPlanNode input = 1;
-    repeated datafusion_common.Column exec_columns = 2;
-    repeated uint64 list_type_columns = 3;
-    repeated uint64 struct_type_columns = 4;
-    repeated uint64 dependency_indices = 5;
-    datafusion_common.DfSchema schema = 6;
-    UnnestOptions options = 7;
+  LogicalPlanNode input = 1;
+  repeated datafusion_common.Column exec_columns = 2;
+  repeated uint64 list_type_columns = 3;
+  repeated uint64 struct_type_columns = 4;
+  repeated uint64 dependency_indices = 5;
+  datafusion_common.DfSchema schema = 6;
+  UnnestOptions options = 7;
 }

 message UnnestOptions {
-    bool preserve_nulls = 1;
+  bool preserve_nulls = 1;
 }

 message UnionNode {
@@ -488,8 +488,8 @@ enum AggregateFunction {
 //   BIT_AND = 19;
 //   BIT_OR = 20;
 //   BIT_XOR = 21;
-//  BOOL_AND = 22;
-//  BOOL_OR = 23;
+  // BOOL_AND = 22;
+  // BOOL_OR = 23;
 //   REGR_SLOPE = 26;
 //   REGR_INTERCEPT = 27;
 //   REGR_COUNT = 28;
@@ -517,6 +517,7 @@ message AggregateUDFExprNode {
   bool distinct = 5;
   LogicalExprNode filter = 3;
   repeated LogicalExprNode order_by = 4;
+  optional bytes fun_definition = 6;
 }

 message ScalarUDFExprNode {
@@ -551,6 +552,7 @@
message WindowExprNode { repeated LogicalExprNode order_by = 6; // repeated LogicalExprNode filter = 7; WindowFrame window_frame = 8; + optional bytes fun_definition = 10; } message BetweenNode { @@ -856,6 +858,8 @@ message PhysicalAggregateExprNode { repeated PhysicalExprNode expr = 2; repeated PhysicalSortExprNode ordering_req = 5; bool distinct = 3; + bool ignore_nulls = 6; + optional bytes fun_definition = 7; } message PhysicalWindowExprNode { @@ -869,6 +873,7 @@ message PhysicalWindowExprNode { repeated PhysicalSortExprNode order_by = 6; WindowFrame window_frame = 7; string name = 8; + optional bytes fun_definition = 9; } message PhysicalIsNull { diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index fa989480fad90..8f77c24bd9117 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -829,6 +829,9 @@ impl serde::Serialize for AggregateUdfExprNode { if !self.order_by.is_empty() { len += 1; } + if self.fun_definition.is_some() { + len += 1; + } let mut struct_ser = serializer.serialize_struct("datafusion.AggregateUDFExprNode", len)?; if !self.fun_name.is_empty() { struct_ser.serialize_field("funName", &self.fun_name)?; @@ -845,6 +848,10 @@ impl serde::Serialize for AggregateUdfExprNode { if !self.order_by.is_empty() { struct_ser.serialize_field("orderBy", &self.order_by)?; } + if let Some(v) = self.fun_definition.as_ref() { + #[allow(clippy::needless_borrow)] + struct_ser.serialize_field("funDefinition", pbjson::private::base64::encode(&v).as_str())?; + } struct_ser.end() } } @@ -862,6 +869,8 @@ impl<'de> serde::Deserialize<'de> for AggregateUdfExprNode { "filter", "order_by", "orderBy", + "fun_definition", + "funDefinition", ]; #[allow(clippy::enum_variant_names)] @@ -871,6 +880,7 @@ impl<'de> serde::Deserialize<'de> for AggregateUdfExprNode { Distinct, Filter, OrderBy, + FunDefinition, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -897,6 +907,7 @@ impl<'de> serde::Deserialize<'de> for AggregateUdfExprNode { "distinct" => Ok(GeneratedField::Distinct), "filter" => Ok(GeneratedField::Filter), "orderBy" | "order_by" => Ok(GeneratedField::OrderBy), + "funDefinition" | "fun_definition" => Ok(GeneratedField::FunDefinition), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -921,6 +932,7 @@ impl<'de> serde::Deserialize<'de> for AggregateUdfExprNode { let mut distinct__ = None; let mut filter__ = None; let mut order_by__ = None; + let mut fun_definition__ = None; while let Some(k) = map_.next_key()? 
{ match k { GeneratedField::FunName => { @@ -953,6 +965,14 @@ impl<'de> serde::Deserialize<'de> for AggregateUdfExprNode { } order_by__ = Some(map_.next_value()?); } + GeneratedField::FunDefinition => { + if fun_definition__.is_some() { + return Err(serde::de::Error::duplicate_field("funDefinition")); + } + fun_definition__ = + map_.next_value::<::std::option::Option<::pbjson::private::BytesDeserialize<_>>>()?.map(|x| x.0) + ; + } } } Ok(AggregateUdfExprNode { @@ -961,6 +981,7 @@ impl<'de> serde::Deserialize<'de> for AggregateUdfExprNode { distinct: distinct__.unwrap_or_default(), filter: filter__, order_by: order_by__.unwrap_or_default(), + fun_definition: fun_definition__, }) } } @@ -12631,6 +12652,12 @@ impl serde::Serialize for PhysicalAggregateExprNode { if self.distinct { len += 1; } + if self.ignore_nulls { + len += 1; + } + if self.fun_definition.is_some() { + len += 1; + } if self.aggregate_function.is_some() { len += 1; } @@ -12644,6 +12671,13 @@ impl serde::Serialize for PhysicalAggregateExprNode { if self.distinct { struct_ser.serialize_field("distinct", &self.distinct)?; } + if self.ignore_nulls { + struct_ser.serialize_field("ignoreNulls", &self.ignore_nulls)?; + } + if let Some(v) = self.fun_definition.as_ref() { + #[allow(clippy::needless_borrow)] + struct_ser.serialize_field("funDefinition", pbjson::private::base64::encode(&v).as_str())?; + } if let Some(v) = self.aggregate_function.as_ref() { match v { physical_aggregate_expr_node::AggregateFunction::AggrFunction(v) => { @@ -12670,6 +12704,10 @@ impl<'de> serde::Deserialize<'de> for PhysicalAggregateExprNode { "ordering_req", "orderingReq", "distinct", + "ignore_nulls", + "ignoreNulls", + "fun_definition", + "funDefinition", "aggr_function", "aggrFunction", "user_defined_aggr_function", @@ -12681,6 +12719,8 @@ impl<'de> serde::Deserialize<'de> for PhysicalAggregateExprNode { Expr, OrderingReq, Distinct, + IgnoreNulls, + FunDefinition, AggrFunction, UserDefinedAggrFunction, } @@ -12707,6 +12747,8 @@ impl<'de> serde::Deserialize<'de> for PhysicalAggregateExprNode { "expr" => Ok(GeneratedField::Expr), "orderingReq" | "ordering_req" => Ok(GeneratedField::OrderingReq), "distinct" => Ok(GeneratedField::Distinct), + "ignoreNulls" | "ignore_nulls" => Ok(GeneratedField::IgnoreNulls), + "funDefinition" | "fun_definition" => Ok(GeneratedField::FunDefinition), "aggrFunction" | "aggr_function" => Ok(GeneratedField::AggrFunction), "userDefinedAggrFunction" | "user_defined_aggr_function" => Ok(GeneratedField::UserDefinedAggrFunction), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), @@ -12731,6 +12773,8 @@ impl<'de> serde::Deserialize<'de> for PhysicalAggregateExprNode { let mut expr__ = None; let mut ordering_req__ = None; let mut distinct__ = None; + let mut ignore_nulls__ = None; + let mut fun_definition__ = None; let mut aggregate_function__ = None; while let Some(k) = map_.next_key()? 
{ match k { @@ -12752,6 +12796,20 @@ impl<'de> serde::Deserialize<'de> for PhysicalAggregateExprNode { } distinct__ = Some(map_.next_value()?); } + GeneratedField::IgnoreNulls => { + if ignore_nulls__.is_some() { + return Err(serde::de::Error::duplicate_field("ignoreNulls")); + } + ignore_nulls__ = Some(map_.next_value()?); + } + GeneratedField::FunDefinition => { + if fun_definition__.is_some() { + return Err(serde::de::Error::duplicate_field("funDefinition")); + } + fun_definition__ = + map_.next_value::<::std::option::Option<::pbjson::private::BytesDeserialize<_>>>()?.map(|x| x.0) + ; + } GeneratedField::AggrFunction => { if aggregate_function__.is_some() { return Err(serde::de::Error::duplicate_field("aggrFunction")); @@ -12770,6 +12828,8 @@ impl<'de> serde::Deserialize<'de> for PhysicalAggregateExprNode { expr: expr__.unwrap_or_default(), ordering_req: ordering_req__.unwrap_or_default(), distinct: distinct__.unwrap_or_default(), + ignore_nulls: ignore_nulls__.unwrap_or_default(), + fun_definition: fun_definition__, aggregate_function: aggregate_function__, }) } @@ -15832,6 +15892,9 @@ impl serde::Serialize for PhysicalWindowExprNode { if !self.name.is_empty() { len += 1; } + if self.fun_definition.is_some() { + len += 1; + } if self.window_function.is_some() { len += 1; } @@ -15851,6 +15914,10 @@ impl serde::Serialize for PhysicalWindowExprNode { if !self.name.is_empty() { struct_ser.serialize_field("name", &self.name)?; } + if let Some(v) = self.fun_definition.as_ref() { + #[allow(clippy::needless_borrow)] + struct_ser.serialize_field("funDefinition", pbjson::private::base64::encode(&v).as_str())?; + } if let Some(v) = self.window_function.as_ref() { match v { physical_window_expr_node::WindowFunction::AggrFunction(v) => { @@ -15886,6 +15953,8 @@ impl<'de> serde::Deserialize<'de> for PhysicalWindowExprNode { "window_frame", "windowFrame", "name", + "fun_definition", + "funDefinition", "aggr_function", "aggrFunction", "built_in_function", @@ -15901,6 +15970,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalWindowExprNode { OrderBy, WindowFrame, Name, + FunDefinition, AggrFunction, BuiltInFunction, UserDefinedAggrFunction, @@ -15930,6 +16000,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalWindowExprNode { "orderBy" | "order_by" => Ok(GeneratedField::OrderBy), "windowFrame" | "window_frame" => Ok(GeneratedField::WindowFrame), "name" => Ok(GeneratedField::Name), + "funDefinition" | "fun_definition" => Ok(GeneratedField::FunDefinition), "aggrFunction" | "aggr_function" => Ok(GeneratedField::AggrFunction), "builtInFunction" | "built_in_function" => Ok(GeneratedField::BuiltInFunction), "userDefinedAggrFunction" | "user_defined_aggr_function" => Ok(GeneratedField::UserDefinedAggrFunction), @@ -15957,6 +16028,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalWindowExprNode { let mut order_by__ = None; let mut window_frame__ = None; let mut name__ = None; + let mut fun_definition__ = None; let mut window_function__ = None; while let Some(k) = map_.next_key()? 
{ match k { @@ -15990,6 +16062,14 @@ impl<'de> serde::Deserialize<'de> for PhysicalWindowExprNode { } name__ = Some(map_.next_value()?); } + GeneratedField::FunDefinition => { + if fun_definition__.is_some() { + return Err(serde::de::Error::duplicate_field("funDefinition")); + } + fun_definition__ = + map_.next_value::<::std::option::Option<::pbjson::private::BytesDeserialize<_>>>()?.map(|x| x.0) + ; + } GeneratedField::AggrFunction => { if window_function__.is_some() { return Err(serde::de::Error::duplicate_field("aggrFunction")); @@ -16016,6 +16096,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalWindowExprNode { order_by: order_by__.unwrap_or_default(), window_frame: window_frame__, name: name__.unwrap_or_default(), + fun_definition: fun_definition__, window_function: window_function__, }) } @@ -20349,6 +20430,9 @@ impl serde::Serialize for WindowExprNode { if self.window_frame.is_some() { len += 1; } + if self.fun_definition.is_some() { + len += 1; + } if self.window_function.is_some() { len += 1; } @@ -20365,6 +20449,10 @@ impl serde::Serialize for WindowExprNode { if let Some(v) = self.window_frame.as_ref() { struct_ser.serialize_field("windowFrame", v)?; } + if let Some(v) = self.fun_definition.as_ref() { + #[allow(clippy::needless_borrow)] + struct_ser.serialize_field("funDefinition", pbjson::private::base64::encode(&v).as_str())?; + } if let Some(v) = self.window_function.as_ref() { match v { window_expr_node::WindowFunction::AggrFunction(v) => { @@ -20402,6 +20490,8 @@ impl<'de> serde::Deserialize<'de> for WindowExprNode { "orderBy", "window_frame", "windowFrame", + "fun_definition", + "funDefinition", "aggr_function", "aggrFunction", "built_in_function", @@ -20416,6 +20506,7 @@ impl<'de> serde::Deserialize<'de> for WindowExprNode { PartitionBy, OrderBy, WindowFrame, + FunDefinition, AggrFunction, BuiltInFunction, Udaf, @@ -20445,6 +20536,7 @@ impl<'de> serde::Deserialize<'de> for WindowExprNode { "partitionBy" | "partition_by" => Ok(GeneratedField::PartitionBy), "orderBy" | "order_by" => Ok(GeneratedField::OrderBy), "windowFrame" | "window_frame" => Ok(GeneratedField::WindowFrame), + "funDefinition" | "fun_definition" => Ok(GeneratedField::FunDefinition), "aggrFunction" | "aggr_function" => Ok(GeneratedField::AggrFunction), "builtInFunction" | "built_in_function" => Ok(GeneratedField::BuiltInFunction), "udaf" => Ok(GeneratedField::Udaf), @@ -20472,6 +20564,7 @@ impl<'de> serde::Deserialize<'de> for WindowExprNode { let mut partition_by__ = None; let mut order_by__ = None; let mut window_frame__ = None; + let mut fun_definition__ = None; let mut window_function__ = None; while let Some(k) = map_.next_key()? 
{ match k { @@ -20499,6 +20592,14 @@ impl<'de> serde::Deserialize<'de> for WindowExprNode { } window_frame__ = map_.next_value()?; } + GeneratedField::FunDefinition => { + if fun_definition__.is_some() { + return Err(serde::de::Error::duplicate_field("funDefinition")); + } + fun_definition__ = + map_.next_value::<::std::option::Option<::pbjson::private::BytesDeserialize<_>>>()?.map(|x| x.0) + ; + } GeneratedField::AggrFunction => { if window_function__.is_some() { return Err(serde::de::Error::duplicate_field("aggrFunction")); @@ -20530,6 +20631,7 @@ impl<'de> serde::Deserialize<'de> for WindowExprNode { partition_by: partition_by__.unwrap_or_default(), order_by: order_by__.unwrap_or_default(), window_frame: window_frame__, + fun_definition: fun_definition__, window_function: window_function__, }) } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 8407e545fe650..605c56fa946a3 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -756,6 +756,8 @@ pub struct AggregateUdfExprNode { pub filter: ::core::option::Option<::prost::alloc::boxed::Box>, #[prost(message, repeated, tag = "4")] pub order_by: ::prost::alloc::vec::Vec, + #[prost(bytes = "vec", optional, tag = "6")] + pub fun_definition: ::core::option::Option<::prost::alloc::vec::Vec>, } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] @@ -779,6 +781,8 @@ pub struct WindowExprNode { /// repeated LogicalExprNode filter = 7; #[prost(message, optional, tag = "8")] pub window_frame: ::core::option::Option, + #[prost(bytes = "vec", optional, tag = "10")] + pub fun_definition: ::core::option::Option<::prost::alloc::vec::Vec>, #[prost(oneof = "window_expr_node::WindowFunction", tags = "1, 2, 3, 9")] pub window_function: ::core::option::Option, } @@ -1291,6 +1295,10 @@ pub struct PhysicalAggregateExprNode { pub ordering_req: ::prost::alloc::vec::Vec, #[prost(bool, tag = "3")] pub distinct: bool, + #[prost(bool, tag = "6")] + pub ignore_nulls: bool, + #[prost(bytes = "vec", optional, tag = "7")] + pub fun_definition: ::core::option::Option<::prost::alloc::vec::Vec>, #[prost(oneof = "physical_aggregate_expr_node::AggregateFunction", tags = "1, 4")] pub aggregate_function: ::core::option::Option< physical_aggregate_expr_node::AggregateFunction, @@ -1320,6 +1328,8 @@ pub struct PhysicalWindowExprNode { pub window_frame: ::core::option::Option, #[prost(string, tag = "8")] pub name: ::prost::alloc::string::String, + #[prost(bytes = "vec", optional, tag = "9")] + pub fun_definition: ::core::option::Option<::prost::alloc::vec::Vec>, #[prost(oneof = "physical_window_expr_node::WindowFunction", tags = "1, 2, 3")] pub window_function: ::core::option::Option< physical_window_expr_node::WindowFunction, diff --git a/datafusion/proto/src/logical_plan/file_formats.rs b/datafusion/proto/src/logical_plan/file_formats.rs index 106d5639489e7..09e36a650b9fa 100644 --- a/datafusion/proto/src/logical_plan/file_formats.rs +++ b/datafusion/proto/src/logical_plan/file_formats.rs @@ -86,22 +86,6 @@ impl LogicalExtensionCodec for CsvLogicalExtensionCodec { ) -> datafusion_common::Result<()> { Ok(()) } - - fn try_decode_udf( - &self, - name: &str, - __buf: &[u8], - ) -> datafusion_common::Result> { - not_impl_err!("LogicalExtensionCodec is not provided for scalar function {name}") - } - - fn try_encode_udf( - &self, - __node: &datafusion_expr::ScalarUDF, - __buf: &mut Vec, - ) -> datafusion_common::Result<()> { - Ok(()) - } } 
#[derive(Debug)] @@ -162,22 +146,6 @@ impl LogicalExtensionCodec for JsonLogicalExtensionCodec { ) -> datafusion_common::Result<()> { Ok(()) } - - fn try_decode_udf( - &self, - name: &str, - __buf: &[u8], - ) -> datafusion_common::Result> { - not_impl_err!("LogicalExtensionCodec is not provided for scalar function {name}") - } - - fn try_encode_udf( - &self, - __node: &datafusion_expr::ScalarUDF, - __buf: &mut Vec, - ) -> datafusion_common::Result<()> { - Ok(()) - } } #[derive(Debug)] @@ -238,22 +206,6 @@ impl LogicalExtensionCodec for ParquetLogicalExtensionCodec { ) -> datafusion_common::Result<()> { Ok(()) } - - fn try_decode_udf( - &self, - name: &str, - __buf: &[u8], - ) -> datafusion_common::Result> { - not_impl_err!("LogicalExtensionCodec is not provided for scalar function {name}") - } - - fn try_encode_udf( - &self, - __node: &datafusion_expr::ScalarUDF, - __buf: &mut Vec, - ) -> datafusion_common::Result<()> { - Ok(()) - } } #[derive(Debug)] @@ -314,22 +266,6 @@ impl LogicalExtensionCodec for ArrowLogicalExtensionCodec { ) -> datafusion_common::Result<()> { Ok(()) } - - fn try_decode_udf( - &self, - name: &str, - __buf: &[u8], - ) -> datafusion_common::Result> { - not_impl_err!("LogicalExtensionCodec is not provided for scalar function {name}") - } - - fn try_encode_udf( - &self, - __node: &datafusion_expr::ScalarUDF, - __buf: &mut Vec, - ) -> datafusion_common::Result<()> { - Ok(()) - } } #[derive(Debug)] @@ -390,20 +326,4 @@ impl LogicalExtensionCodec for AvroLogicalExtensionCodec { ) -> datafusion_common::Result<()> { Ok(()) } - - fn try_decode_udf( - &self, - name: &str, - __buf: &[u8], - ) -> datafusion_common::Result> { - not_impl_err!("LogicalExtensionCodec is not provided for scalar function {name}") - } - - fn try_encode_udf( - &self, - __node: &datafusion_expr::ScalarUDF, - __buf: &mut Vec, - ) -> datafusion_common::Result<()> { - Ok(()) - } } diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 095c6a50973a1..b6b556a8ed6b2 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -308,14 +308,17 @@ pub fn parse_expr( let aggr_function = parse_i32_to_aggregate_function(i)?; Ok(Expr::WindowFunction(WindowFunction::new( - datafusion_expr::expr::WindowFunctionDefinition::AggregateFunction( - aggr_function, - ), - vec![parse_required_expr(expr.expr.as_deref(), registry, "expr", codec)?], + expr::WindowFunctionDefinition::AggregateFunction(aggr_function), + vec![parse_required_expr( + expr.expr.as_deref(), + registry, + "expr", + codec, + )?], partition_by, order_by, window_frame, - None + None, ))) } window_expr_node::WindowFunction::BuiltInFunction(i) => { @@ -329,26 +332,28 @@ pub fn parse_expr( .unwrap_or_else(Vec::new); Ok(Expr::WindowFunction(WindowFunction::new( - datafusion_expr::expr::WindowFunctionDefinition::BuiltInWindowFunction( + expr::WindowFunctionDefinition::BuiltInWindowFunction( built_in_function, ), args, partition_by, order_by, window_frame, - null_treatment + null_treatment, ))) } window_expr_node::WindowFunction::Udaf(udaf_name) => { - let udaf_function = registry.udaf(udaf_name)?; + let udaf_function = match &expr.fun_definition { + Some(buf) => codec.try_decode_udaf(udaf_name, buf)?, + None => registry.udaf(udaf_name)?, + }; + let args = parse_optional_expr(expr.expr.as_deref(), registry, codec)? 
.map(|e| vec![e]) .unwrap_or_else(Vec::new); Ok(Expr::WindowFunction(WindowFunction::new( - datafusion_expr::expr::WindowFunctionDefinition::AggregateUDF( - udaf_function, - ), + expr::WindowFunctionDefinition::AggregateUDF(udaf_function), args, partition_by, order_by, @@ -357,15 +362,17 @@ pub fn parse_expr( ))) } window_expr_node::WindowFunction::Udwf(udwf_name) => { - let udwf_function = registry.udwf(udwf_name)?; + let udwf_function = match &expr.fun_definition { + Some(buf) => codec.try_decode_udwf(udwf_name, buf)?, + None => registry.udwf(udwf_name)?, + }; + let args = parse_optional_expr(expr.expr.as_deref(), registry, codec)? .map(|e| vec![e]) .unwrap_or_else(Vec::new); Ok(Expr::WindowFunction(WindowFunction::new( - datafusion_expr::expr::WindowFunctionDefinition::WindowUDF( - udwf_function, - ), + expr::WindowFunctionDefinition::WindowUDF(udwf_function), args, partition_by, order_by, @@ -613,7 +620,10 @@ pub fn parse_expr( ))) } ExprType::AggregateUdfExpr(pb) => { - let agg_fn = registry.udaf(pb.fun_name.as_str())?; + let agg_fn = match &pb.fun_definition { + Some(buf) => codec.try_decode_udaf(&pb.fun_name, buf)?, + None => registry.udaf(&pb.fun_name)?, + }; Ok(Expr::AggregateFunction(expr::AggregateFunction::new_udf( agg_fn, diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index 664cd7e115557..2a963fb13ccf0 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -51,7 +51,6 @@ use datafusion_common::{ context, internal_datafusion_err, internal_err, not_impl_err, DataFusionError, Result, TableReference, }; -use datafusion_expr::Unnest; use datafusion_expr::{ dml, logical_plan::{ @@ -60,8 +59,9 @@ use datafusion_expr::{ EmptyRelation, Extension, Join, JoinConstraint, Limit, Prepare, Projection, Repartition, Sort, SubqueryAlias, TableScan, Values, Window, }, - DistinctOn, DropView, Expr, LogicalPlan, LogicalPlanBuilder, ScalarUDF, + DistinctOn, DropView, Expr, LogicalPlan, LogicalPlanBuilder, ScalarUDF, WindowUDF, }; +use datafusion_expr::{AggregateUDF, Unnest}; use prost::bytes::BufMut; use prost::Message; @@ -144,6 +144,24 @@ pub trait LogicalExtensionCodec: Debug + Send + Sync { fn try_encode_udf(&self, _node: &ScalarUDF, _buf: &mut Vec) -> Result<()> { Ok(()) } + + fn try_decode_udaf(&self, name: &str, _buf: &[u8]) -> Result> { + not_impl_err!( + "LogicalExtensionCodec is not provided for aggregate function {name}" + ) + } + + fn try_encode_udaf(&self, _node: &AggregateUDF, _buf: &mut Vec) -> Result<()> { + Ok(()) + } + + fn try_decode_udwf(&self, name: &str, _buf: &[u8]) -> Result> { + not_impl_err!("LogicalExtensionCodec is not provided for window function {name}") + } + + fn try_encode_udwf(&self, _node: &WindowUDF, _buf: &mut Vec) -> Result<()> { + Ok(()) + } } #[derive(Debug, Clone)] diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index d8f8ea002b2dd..9607b918eb895 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -319,25 +319,37 @@ pub fn serialize_expr( // TODO: support null treatment in proto null_treatment: _, }) => { - let window_function = match fun { - WindowFunctionDefinition::AggregateFunction(fun) => { + let (window_function, fun_definition) = match fun { + WindowFunctionDefinition::AggregateFunction(fun) => ( protobuf::window_expr_node::WindowFunction::AggrFunction( protobuf::AggregateFunction::from(fun).into(), - ) - } - 
WindowFunctionDefinition::BuiltInWindowFunction(fun) => { + ), + None, + ), + WindowFunctionDefinition::BuiltInWindowFunction(fun) => ( protobuf::window_expr_node::WindowFunction::BuiltInFunction( protobuf::BuiltInWindowFunction::from(fun).into(), - ) - } + ), + None, + ), WindowFunctionDefinition::AggregateUDF(aggr_udf) => { - protobuf::window_expr_node::WindowFunction::Udaf( - aggr_udf.name().to_string(), + let mut buf = Vec::new(); + let _ = codec.try_encode_udaf(aggr_udf, &mut buf); + ( + protobuf::window_expr_node::WindowFunction::Udaf( + aggr_udf.name().to_string(), + ), + (!buf.is_empty()).then_some(buf), ) } WindowFunctionDefinition::WindowUDF(window_udf) => { - protobuf::window_expr_node::WindowFunction::Udwf( - window_udf.name().to_string(), + let mut buf = Vec::new(); + let _ = codec.try_encode_udwf(window_udf, &mut buf); + ( + protobuf::window_expr_node::WindowFunction::Udwf( + window_udf.name().to_string(), + ), + (!buf.is_empty()).then_some(buf), ) } }; @@ -358,6 +370,7 @@ pub fn serialize_expr( partition_by, order_by, window_frame, + fun_definition, }); protobuf::LogicalExprNode { expr_type: Some(ExprType::WindowExpr(window_expr)), @@ -395,23 +408,30 @@ pub fn serialize_expr( expr_type: Some(ExprType::AggregateExpr(Box::new(aggregate_expr))), } } - AggregateFunctionDefinition::UDF(fun) => protobuf::LogicalExprNode { - expr_type: Some(ExprType::AggregateUdfExpr(Box::new( - protobuf::AggregateUdfExprNode { - fun_name: fun.name().to_string(), - args: serialize_exprs(args, codec)?, - distinct: *distinct, - filter: match filter { - Some(e) => Some(Box::new(serialize_expr(e.as_ref(), codec)?)), - None => None, - }, - order_by: match order_by { - Some(e) => serialize_exprs(e, codec)?, - None => vec![], + AggregateFunctionDefinition::UDF(fun) => { + let mut buf = Vec::new(); + let _ = codec.try_encode_udaf(fun, &mut buf); + protobuf::LogicalExprNode { + expr_type: Some(ExprType::AggregateUdfExpr(Box::new( + protobuf::AggregateUdfExprNode { + fun_name: fun.name().to_string(), + args: serialize_exprs(args, codec)?, + distinct: *distinct, + filter: match filter { + Some(e) => { + Some(Box::new(serialize_expr(e.as_ref(), codec)?)) + } + None => None, + }, + order_by: match order_by { + Some(e) => serialize_exprs(e, codec)?, + None => vec![], + }, + fun_definition: (!buf.is_empty()).then_some(buf), }, - }, - ))), - }, + ))), + } + } }, Expr::ScalarVariable(_, _) => { @@ -420,17 +440,13 @@ pub fn serialize_expr( )) } Expr::ScalarFunction(ScalarFunction { func, args }) => { - let args = serialize_exprs(args, codec)?; let mut buf = Vec::new(); - let _ = codec.try_encode_udf(func.as_ref(), &mut buf); - - let fun_definition = if buf.is_empty() { None } else { Some(buf) }; - + let _ = codec.try_encode_udf(func, &mut buf); protobuf::LogicalExprNode { expr_type: Some(ExprType::ScalarUdfExpr(protobuf::ScalarUdfExprNode { fun_name: func.name().to_string(), - fun_definition, - args, + fun_definition: (!buf.is_empty()).then_some(buf), + args: serialize_exprs(args, codec)?, })), } } diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index b7311c694d4c9..5ecca51478053 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -164,8 +164,10 @@ pub fn parse_physical_window_expr( WindowFunctionDefinition::BuiltInWindowFunction(f.into()) } protobuf::physical_window_expr_node::WindowFunction::UserDefinedAggrFunction(udaf_name) => { - let agg_udf = registry.udaf(udaf_name)?; - 
WindowFunctionDefinition::AggregateUDF(agg_udf) + WindowFunctionDefinition::AggregateUDF(match &proto.fun_definition { + Some(buf) => codec.try_decode_udaf(udaf_name, buf)?, + None => registry.udaf(udaf_name)? + }) } } } else { diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index 948a39bfe0be7..1220f42ded836 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -61,7 +61,7 @@ use datafusion::physical_plan::{ udaf, AggregateExpr, ExecutionPlan, InputOrderMode, PhysicalExpr, WindowExpr, }; use datafusion_common::{internal_err, not_impl_err, DataFusionError, Result}; -use datafusion_expr::ScalarUDF; +use datafusion_expr::{AggregateUDF, ScalarUDF}; use crate::common::{byte_to_string, str_to_byte}; use crate::convert_required; @@ -491,19 +491,22 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { &ordering_req, &physical_schema, name.to_string(), - false, + agg_node.ignore_nulls, ) } AggregateFunction::UserDefinedAggrFunction(udaf_name) => { - let agg_udf = registry.udaf(udaf_name)?; + let agg_udf = match &agg_node.fun_definition { + Some(buf) => extension_codec.try_decode_udaf(udaf_name, buf)?, + None => registry.udaf(udaf_name)? + }; + // TODO: 'logical_exprs' is not supported for UDAF yet. // approx_percentile_cont and approx_percentile_cont_weight are not supported for UDAF from protobuf yet. let logical_exprs = &[]; // TODO: `order by` is not supported for UDAF yet let sort_exprs = &[]; let ordering_req = &[]; - let ignore_nulls = false; - udaf::create_aggregate_expr(agg_udf.as_ref(), &input_phy_expr, logical_exprs, sort_exprs, ordering_req, &physical_schema, name, ignore_nulls, false) + udaf::create_aggregate_expr(agg_udf.as_ref(), &input_phy_expr, logical_exprs, sort_exprs, ordering_req, &physical_schema, name, agg_node.ignore_nulls, agg_node.distinct) } } }).transpose()?.ok_or_else(|| { @@ -2034,6 +2037,16 @@ pub trait PhysicalExtensionCodec: Debug + Send + Sync { ) -> Result<()> { not_impl_err!("PhysicalExtensionCodec is not provided") } + + fn try_decode_udaf(&self, name: &str, _buf: &[u8]) -> Result> { + not_impl_err!( + "PhysicalExtensionCodec is not provided for aggregate function {name}" + ) + } + + fn try_encode_udaf(&self, _node: &AggregateUDF, _buf: &mut Vec) -> Result<()> { + Ok(()) + } } #[derive(Debug)] diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index d8d0291e1ca52..7ea2902cf3c09 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -40,6 +40,7 @@ use datafusion::{ physical_plan::expressions::LikeExpr, }; use datafusion_common::{internal_err, not_impl_err, DataFusionError, Result}; +use datafusion_expr::WindowFrame; use crate::protobuf::{ self, physical_aggregate_expr_node, physical_window_expr_node, PhysicalSortExprNode, @@ -58,13 +59,17 @@ pub fn serialize_physical_aggr_expr( if let Some(a) = aggr_expr.as_any().downcast_ref::() { let name = a.fun().name().to_string(); + let mut buf = Vec::new(); + codec.try_encode_udaf(a.fun(), &mut buf)?; return Ok(protobuf::PhysicalExprNode { expr_type: Some(protobuf::physical_expr_node::ExprType::AggregateExpr( protobuf::PhysicalAggregateExprNode { aggregate_function: Some(physical_aggregate_expr_node::AggregateFunction::UserDefinedAggrFunction(name)), expr: expressions, ordering_req, - distinct: false, + distinct: a.is_distinct(), + ignore_nulls: a.ignore_nulls(), + fun_definition: 
(!buf.is_empty()).then_some(buf)
                 },
             )),
         });
@@ -86,11 +91,55 @@
                     expr: expressions,
                     ordering_req,
                     distinct,
+                    ignore_nulls: false,
+                    fun_definition: None,
                 },
             )),
     })
 }
 
+fn serialize_physical_window_aggr_expr(
+    aggr_expr: &dyn AggregateExpr,
+    window_frame: &WindowFrame,
+    codec: &dyn PhysicalExtensionCodec,
+) -> Result<(physical_window_expr_node::WindowFunction, Option<Vec<u8>>)> {
+    if let Some(a) = aggr_expr.as_any().downcast_ref::<AggregateFunctionExpr>() {
+        if a.is_distinct() || a.ignore_nulls() {
+            // TODO
+            return not_impl_err!(
+                "Distinct aggregate functions not supported in window expressions"
+            );
+        }
+
+        let mut buf = Vec::new();
+        codec.try_encode_udaf(a.fun(), &mut buf)?;
+        Ok((
+            physical_window_expr_node::WindowFunction::UserDefinedAggrFunction(
+                a.fun().name().to_string(),
+            ),
+            (!buf.is_empty()).then_some(buf),
+        ))
+    } else {
+        let AggrFn { inner, distinct } = aggr_expr_to_aggr_fn(aggr_expr)?;
+        if distinct {
+            return not_impl_err!(
+                "Distinct aggregate functions not supported in window expressions"
+            );
+        }
+
+        if !window_frame.start_bound.is_unbounded() {
+            return Err(DataFusionError::Internal(format!(
+                "Unbounded start bound in WindowFrame = {window_frame}"
+            )));
+        }
+
+        Ok((
+            physical_window_expr_node::WindowFunction::AggrFunction(inner as i32),
+            None,
+        ))
+    }
+}
+
 pub fn serialize_physical_window_expr(
     window_expr: Arc<dyn WindowExpr>,
     codec: &dyn PhysicalExtensionCodec,
@@ -99,7 +148,7 @@
     let mut args = window_expr.expressions().to_vec();
     let window_frame = window_expr.get_window_frame();
 
-    let window_function = if let Some(built_in_window_expr) =
+    let (window_function, fun_definition) = if let Some(built_in_window_expr) =
         expr.downcast_ref::<BuiltInWindowExpr>()
     {
         let expr = built_in_window_expr.get_built_in_func_expr();
@@ -160,58 +209,26 @@
         return not_impl_err!("BuiltIn function not supported: {expr:?}");
     };
 
-        physical_window_expr_node::WindowFunction::BuiltInFunction(builtin_fn as i32)
+        (
+            physical_window_expr_node::WindowFunction::BuiltInFunction(builtin_fn as i32),
+            None,
+        )
     } else if let Some(plain_aggr_window_expr) =
         expr.downcast_ref::<PlainAggregateWindowExpr>()
     {
-        let aggr_expr = plain_aggr_window_expr.get_aggregate_expr();
-        if let Some(a) = aggr_expr.as_any().downcast_ref::<AggregateFunctionExpr>() {
-            physical_window_expr_node::WindowFunction::UserDefinedAggrFunction(
-                a.fun().name().to_string(),
-            )
-        } else {
-            let AggrFn { inner, distinct } = aggr_expr_to_aggr_fn(
-                plain_aggr_window_expr.get_aggregate_expr().as_ref(),
-            )?;
-
-            if distinct {
-                return not_impl_err!(
-                    "Distinct aggregate functions not supported in window expressions"
-                );
-            }
-
-            if !window_frame.start_bound.is_unbounded() {
-                return Err(DataFusionError::Internal(format!("Invalid PlainAggregateWindowExpr = {window_expr:?} with WindowFrame = {window_frame:?}")));
-            }
-
-            physical_window_expr_node::WindowFunction::AggrFunction(inner as i32)
-        }
+        serialize_physical_window_aggr_expr(
+            plain_aggr_window_expr.get_aggregate_expr().as_ref(),
+            window_frame,
+            codec,
+        )?
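+        // The sliding-window branch below goes through the same helper as the
+        // plain branch above, so both now carry the encoded UDAF payload.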
} else if let Some(sliding_aggr_window_expr) =
        expr.downcast_ref::<SlidingAggregateWindowExpr>()
    {
-        let aggr_expr = sliding_aggr_window_expr.get_aggregate_expr();
-        if let Some(a) = aggr_expr.as_any().downcast_ref::<AggregateFunctionExpr>() {
-            physical_window_expr_node::WindowFunction::UserDefinedAggrFunction(
-                a.fun().name().to_string(),
-            )
-        } else {
-            let AggrFn { inner, distinct } = aggr_expr_to_aggr_fn(
-                sliding_aggr_window_expr.get_aggregate_expr().as_ref(),
-            )?;
-
-            if distinct {
-                // TODO
-                return not_impl_err!(
-                    "Distinct aggregate functions not supported in window expressions"
-                );
-            }
-
-            if window_frame.start_bound.is_unbounded() {
-                return Err(DataFusionError::Internal(format!("Invalid SlidingAggregateWindowExpr = {window_expr:?} with WindowFrame = {window_frame:?}")));
-            }
-
-            physical_window_expr_node::WindowFunction::AggrFunction(inner as i32)
-        }
+        serialize_physical_window_aggr_expr(
+            sliding_aggr_window_expr.get_aggregate_expr().as_ref(),
+            window_frame,
+            codec,
+        )?
     } else {
         return not_impl_err!("WindowExpr not supported: {window_expr:?}");
     };
@@ -232,6 +249,7 @@
         window_frame: Some(window_frame),
         window_function: Some(window_function),
         name: window_expr.name().to_string(),
+        fun_definition,
     })
 }
 
@@ -461,18 +479,14 @@ pub fn serialize_physical_expr(
             ))),
         })
     } else if let Some(expr) = expr.downcast_ref::<ScalarFunctionExpr>() {
-        let args = serialize_physical_exprs(expr.args().to_vec(), codec)?;
-
         let mut buf = Vec::new();
         codec.try_encode_udf(expr.fun(), &mut buf)?;
-
-        let fun_definition = if buf.is_empty() { None } else { Some(buf) };
-
         Ok(protobuf::PhysicalExprNode {
             expr_type: Some(protobuf::physical_expr_node::ExprType::ScalarUdf(
                 protobuf::PhysicalScalarUdfNode {
                     name: expr.name().to_string(),
-                    args,
-                    fun_definition,
+                    args: serialize_physical_exprs(expr.args().to_vec(), codec)?,
+                    fun_definition: (!buf.is_empty()).then_some(buf),
                     return_type: Some(expr.return_type().try_into()?),
                 },
             )),
diff --git a/datafusion/proto/tests/cases/mod.rs b/datafusion/proto/tests/cases/mod.rs
index b17289205f3de..1f837b7f42e86 100644
--- a/datafusion/proto/tests/cases/mod.rs
+++ b/datafusion/proto/tests/cases/mod.rs
@@ -15,6 +15,105 @@
 // specific language governing permissions and limitations
 // under the License.
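+
+// Test-only fixtures: `MyRegexUdf` and `MyAggregateUDF` below are shared by the
+// logical and physical roundtrip suites to exercise the scalar and aggregate
+// extension-codec hooks (`try_encode_udf`/`try_decode_udf` and the new
+// `try_encode_udaf`/`try_decode_udaf`).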
+use std::any::Any;
+
+use arrow::datatypes::DataType;
+
+use datafusion_common::plan_err;
+use datafusion_expr::function::AccumulatorArgs;
+use datafusion_expr::{
+    Accumulator, AggregateUDFImpl, ColumnarValue, ScalarUDFImpl, Signature, Volatility,
+};
+
 mod roundtrip_logical_plan;
 mod roundtrip_physical_plan;
 mod serialize;
+
+#[derive(Debug, PartialEq, Eq, Hash)]
+struct MyRegexUdf {
+    signature: Signature,
+    // regex as original string
+    pattern: String,
+}
+
+impl MyRegexUdf {
+    fn new(pattern: String) -> Self {
+        let signature = Signature::exact(vec![DataType::Utf8], Volatility::Immutable);
+        Self { signature, pattern }
+    }
+}
+
+/// Implement the ScalarUDFImpl trait for MyRegexUdf
+impl ScalarUDFImpl for MyRegexUdf {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+    fn name(&self) -> &str {
+        "regex_udf"
+    }
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+    fn return_type(&self, args: &[DataType]) -> datafusion_common::Result<DataType> {
+        if matches!(args, [DataType::Utf8]) {
+            Ok(DataType::Int64)
+        } else {
+            plan_err!("regex_udf only accepts Utf8 arguments")
+        }
+    }
+    fn invoke(
+        &self,
+        _args: &[ColumnarValue],
+    ) -> datafusion_common::Result<ColumnarValue> {
+        unimplemented!()
+    }
+}
+
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct MyRegexUdfNode {
+    #[prost(string, tag = "1")]
+    pub pattern: String,
+}
+
+#[derive(Debug, PartialEq, Eq, Hash)]
+struct MyAggregateUDF {
+    signature: Signature,
+    result: String,
+}
+
+impl MyAggregateUDF {
+    fn new(result: String) -> Self {
+        let signature = Signature::exact(vec![DataType::Int64], Volatility::Immutable);
+        Self { signature, result }
+    }
+}
+
+impl AggregateUDFImpl for MyAggregateUDF {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+    fn name(&self) -> &str {
+        "aggregate_udf"
+    }
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+    fn return_type(
+        &self,
+        _arg_types: &[DataType],
+    ) -> datafusion_common::Result<DataType> {
+        Ok(DataType::Utf8)
+    }
+    fn accumulator(
+        &self,
+        _acc_args: AccumulatorArgs,
+    ) -> datafusion_common::Result<Box<dyn Accumulator>> {
+        unimplemented!()
+    }
+}
+
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct MyAggregateUdfNode {
+    #[prost(string, tag = "1")]
+    pub result: String,
+}
diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
index d0209d811b7ce..0117502f400d2 100644
--- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
@@ -28,15 +28,12 @@ use arrow::datatypes::{
     DataType, Field, Fields, Int32Type, IntervalDayTimeType, IntervalMonthDayNanoType,
     IntervalUnit, Schema, SchemaRef, TimeUnit, UnionFields, UnionMode,
 };
+use prost::Message;
+
 use datafusion::datasource::file_format::arrow::ArrowFormatFactory;
 use datafusion::datasource::file_format::csv::CsvFormatFactory;
 use datafusion::datasource::file_format::format_as_file_type;
 use datafusion::datasource::file_format::parquet::ParquetFormatFactory;
-use datafusion_proto::logical_plan::file_formats::{
-    ArrowLogicalExtensionCodec, CsvLogicalExtensionCodec, ParquetLogicalExtensionCodec,
-};
-use prost::Message;
-
 use datafusion::datasource::provider::TableProviderFactory;
 use datafusion::datasource::TableProvider;
 use datafusion::execution::session_state::SessionStateBuilder;
@@ -62,9 +59,9 @@ use datafusion_expr::expr::{
 };
 use datafusion_expr::logical_plan::{Extension, UserDefinedLogicalNodeCore};
 use datafusion_expr::{
-    Accumulator, AggregateExt, AggregateFunction, ColumnarValue, ExprSchemable,
-    LogicalPlan, Operator, 
PartitionEvaluator, ScalarUDF, ScalarUDFImpl, Signature, - TryCast, Volatility, WindowFrame, WindowFrameBound, WindowFrameUnits, + Accumulator, AggregateExt, AggregateFunction, AggregateUDF, ColumnarValue, + ExprSchemable, Literal, LogicalPlan, Operator, PartitionEvaluator, ScalarUDF, + Signature, TryCast, Volatility, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, WindowUDF, WindowUDFImpl, }; use datafusion_functions_aggregate::average::avg_udaf; @@ -76,12 +73,17 @@ use datafusion_proto::bytes::{ logical_plan_from_bytes, logical_plan_from_bytes_with_extension_codec, logical_plan_to_bytes, logical_plan_to_bytes_with_extension_codec, }; +use datafusion_proto::logical_plan::file_formats::{ + ArrowLogicalExtensionCodec, CsvLogicalExtensionCodec, ParquetLogicalExtensionCodec, +}; use datafusion_proto::logical_plan::to_proto::serialize_expr; use datafusion_proto::logical_plan::{ from_proto, DefaultLogicalExtensionCodec, LogicalExtensionCodec, }; use datafusion_proto::protobuf; +use crate::cases::{MyAggregateUDF, MyAggregateUdfNode, MyRegexUdf, MyRegexUdfNode}; + #[cfg(feature = "json")] fn roundtrip_json_test(proto: &protobuf::LogicalExprNode) { let string = serde_json::to_string(proto).unwrap(); @@ -744,7 +746,7 @@ pub mod proto { pub k: u64, #[prost(message, optional, tag = "2")] - pub expr: ::core::option::Option, + pub expr: Option, } #[derive(Clone, PartialEq, Eq, ::prost::Message)] @@ -752,12 +754,6 @@ pub mod proto { #[prost(uint64, tag = "1")] pub k: u64, } - - #[derive(Clone, PartialEq, ::prost::Message)] - pub struct MyRegexUdfNode { - #[prost(string, tag = "1")] - pub pattern: String, - } } #[derive(PartialEq, Eq, Hash)] @@ -890,51 +886,9 @@ impl LogicalExtensionCodec for TopKExtensionCodec { } #[derive(Debug)] -struct MyRegexUdf { - signature: Signature, - // regex as original string - pattern: String, -} - -impl MyRegexUdf { - fn new(pattern: String) -> Self { - Self { - signature: Signature::uniform( - 1, - vec![DataType::Int32], - Volatility::Immutable, - ), - pattern, - } - } -} - -/// Implement the ScalarUDFImpl trait for MyRegexUdf -impl ScalarUDFImpl for MyRegexUdf { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { - "regex_udf" - } - fn signature(&self) -> &Signature { - &self.signature - } - fn return_type(&self, args: &[DataType]) -> Result { - if !matches!(args.first(), Some(&DataType::Utf8)) { - return plan_err!("regex_udf only accepts Utf8 arguments"); - } - Ok(DataType::Int32) - } - fn invoke(&self, _args: &[ColumnarValue]) -> Result { - unimplemented!() - } -} - -#[derive(Debug)] -pub struct ScalarUDFExtensionCodec {} +pub struct UDFExtensionCodec; -impl LogicalExtensionCodec for ScalarUDFExtensionCodec { +impl LogicalExtensionCodec for UDFExtensionCodec { fn try_decode( &self, _buf: &[u8], @@ -969,13 +923,11 @@ impl LogicalExtensionCodec for ScalarUDFExtensionCodec { fn try_decode_udf(&self, name: &str, buf: &[u8]) -> Result> { if name == "regex_udf" { - let proto = proto::MyRegexUdfNode::decode(buf).map_err(|err| { - DataFusionError::Internal(format!("failed to decode regex_udf: {}", err)) + let proto = MyRegexUdfNode::decode(buf).map_err(|err| { + DataFusionError::Internal(format!("failed to decode regex_udf: {err}")) })?; - Ok(Arc::new(ScalarUDF::new_from_impl(MyRegexUdf::new( - proto.pattern, - )))) + Ok(Arc::new(ScalarUDF::from(MyRegexUdf::new(proto.pattern)))) } else { not_impl_err!("unrecognized scalar UDF implementation, cannot decode") } @@ -984,11 +936,39 @@ impl LogicalExtensionCodec for 
ScalarUDFExtensionCodec { fn try_encode_udf(&self, node: &ScalarUDF, buf: &mut Vec) -> Result<()> { let binding = node.inner(); let udf = binding.as_any().downcast_ref::().unwrap(); - let proto = proto::MyRegexUdfNode { + let proto = MyRegexUdfNode { pattern: udf.pattern.clone(), }; - proto.encode(buf).map_err(|e| { - DataFusionError::Internal(format!("failed to encode udf: {e:?}")) + proto.encode(buf).map_err(|err| { + DataFusionError::Internal(format!("failed to encode udf: {err}")) + })?; + Ok(()) + } + + fn try_decode_udaf(&self, name: &str, buf: &[u8]) -> Result> { + if name == "aggregate_udf" { + let proto = MyAggregateUdfNode::decode(buf).map_err(|err| { + DataFusionError::Internal(format!( + "failed to decode aggregate_udf: {err}" + )) + })?; + + Ok(Arc::new(AggregateUDF::from(MyAggregateUDF::new( + proto.result, + )))) + } else { + not_impl_err!("unrecognized aggregate UDF implementation, cannot decode") + } + } + + fn try_encode_udaf(&self, node: &AggregateUDF, buf: &mut Vec) -> Result<()> { + let binding = node.inner(); + let udf = binding.as_any().downcast_ref::().unwrap(); + let proto = MyAggregateUdfNode { + result: udf.result.clone(), + }; + proto.encode(buf).map_err(|err| { + DataFusionError::Internal(format!("failed to encode udf: {err}")) })?; Ok(()) } @@ -1563,8 +1543,7 @@ fn roundtrip_null_scalar_values() { for test_case in test_types.into_iter() { let proto_scalar: protobuf::ScalarValue = (&test_case).try_into().unwrap(); - let returned_scalar: datafusion::scalar::ScalarValue = - (&proto_scalar).try_into().unwrap(); + let returned_scalar: ScalarValue = (&proto_scalar).try_into().unwrap(); assert_eq!(format!("{:?}", &test_case), format!("{returned_scalar:?}")); } } @@ -1893,22 +1872,19 @@ fn roundtrip_aggregate_udf() { struct Dummy {} impl Accumulator for Dummy { - fn state(&mut self) -> datafusion::error::Result> { + fn state(&mut self) -> Result> { Ok(vec![]) } - fn update_batch( - &mut self, - _values: &[ArrayRef], - ) -> datafusion::error::Result<()> { + fn update_batch(&mut self, _values: &[ArrayRef]) -> Result<()> { Ok(()) } - fn merge_batch(&mut self, _states: &[ArrayRef]) -> datafusion::error::Result<()> { + fn merge_batch(&mut self, _states: &[ArrayRef]) -> Result<()> { Ok(()) } - fn evaluate(&mut self) -> datafusion::error::Result { + fn evaluate(&mut self) -> Result { Ok(ScalarValue::Float64(None)) } @@ -1976,25 +1952,27 @@ fn roundtrip_scalar_udf() { #[test] fn roundtrip_scalar_udf_extension_codec() { - let pattern = ".*"; - let udf = ScalarUDF::from(MyRegexUdf::new(pattern.to_string())); - let test_expr = - Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(udf.clone()), vec![])); - + let udf = ScalarUDF::from(MyRegexUdf::new(".*".to_owned())); + let test_expr = udf.call(vec!["foo".lit()]); let ctx = SessionContext::new(); - ctx.register_udf(udf); - - let extension_codec = ScalarUDFExtensionCodec {}; - let proto: protobuf::LogicalExprNode = - match serialize_expr(&test_expr, &extension_codec) { - Ok(p) => p, - Err(e) => panic!("Error serializing expression: {:?}", e), - }; - let round_trip: Expr = - from_proto::parse_expr(&proto, &ctx, &extension_codec).unwrap(); + let proto = serialize_expr(&test_expr, &UDFExtensionCodec).expect("serialize expr"); + let round_trip = + from_proto::parse_expr(&proto, &ctx, &UDFExtensionCodec).expect("parse expr"); assert_eq!(format!("{:?}", &test_expr), format!("{round_trip:?}")); + roundtrip_json_test(&proto); +} + +#[test] +fn roundtrip_aggregate_udf_extension_codec() { + let udf = 
AggregateUDF::from(MyAggregateUDF::new("DataFusion".to_owned())); + let test_expr = udf.call(vec![42.lit()]); + let ctx = SessionContext::new(); + let proto = serialize_expr(&test_expr, &UDFExtensionCodec).expect("serialize expr"); + let round_trip = + from_proto::parse_expr(&proto, &ctx, &UDFExtensionCodec).expect("parse expr"); + assert_eq!(format!("{:?}", &test_expr), format!("{round_trip:?}")); roundtrip_json_test(&proto); } @@ -2120,22 +2098,19 @@ fn roundtrip_window() { struct DummyAggr {} impl Accumulator for DummyAggr { - fn state(&mut self) -> datafusion::error::Result> { + fn state(&mut self) -> Result> { Ok(vec![]) } - fn update_batch( - &mut self, - _values: &[ArrayRef], - ) -> datafusion::error::Result<()> { + fn update_batch(&mut self, _values: &[ArrayRef]) -> Result<()> { Ok(()) } - fn merge_batch(&mut self, _states: &[ArrayRef]) -> datafusion::error::Result<()> { + fn merge_batch(&mut self, _states: &[ArrayRef]) -> Result<()> { Ok(()) } - fn evaluate(&mut self) -> datafusion::error::Result { + fn evaluate(&mut self) -> Result { Ok(ScalarValue::Float64(None)) } diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 2fcc65008fd8f..fba6dfe425996 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -15,7 +15,6 @@ // specific language governing permissions and limitations // under the License. -use arrow::array::RecordBatch; use std::any::Any; use std::fmt::Display; use std::hash::Hasher; @@ -23,8 +22,8 @@ use std::ops::Deref; use std::sync::Arc; use std::vec; +use arrow::array::RecordBatch; use arrow::csv::WriterBuilder; -use datafusion::functions_aggregate::sum::sum_udaf; use prost::Message; use datafusion::arrow::array::ArrayRef; @@ -40,9 +39,10 @@ use datafusion::datasource::physical_plan::{ FileSinkConfig, ParquetExec, }; use datafusion::execution::FunctionRegistry; +use datafusion::functions_aggregate::sum::sum_udaf; use datafusion::logical_expr::{create_udf, JoinType, Operator, Volatility}; use datafusion::physical_expr::aggregate::utils::down_cast_any_ref; -use datafusion::physical_expr::expressions::Max; +use datafusion::physical_expr::expressions::{Literal, Max}; use datafusion::physical_expr::window::SlidingAggregateWindowExpr; use datafusion::physical_expr::{PhysicalSortRequirement, ScalarFunctionExpr}; use datafusion::physical_plan::aggregates::{ @@ -70,7 +70,7 @@ use datafusion::physical_plan::windows::{ BuiltInWindowExpr, PlainAggregateWindowExpr, WindowAggExec, }; use datafusion::physical_plan::{ - udaf, AggregateExpr, ExecutionPlan, Partitioning, PhysicalExpr, Statistics, + AggregateExpr, ExecutionPlan, Partitioning, PhysicalExpr, Statistics, }; use datafusion::prelude::SessionContext; use datafusion::scalar::ScalarValue; @@ -79,10 +79,10 @@ use datafusion_common::file_options::csv_writer::CsvWriterOptions; use datafusion_common::file_options::json_writer::JsonWriterOptions; use datafusion_common::parsers::CompressionTypeVariant; use datafusion_common::stats::Precision; -use datafusion_common::{internal_err, not_impl_err, plan_err, DataFusionError, Result}; +use datafusion_common::{internal_err, not_impl_err, DataFusionError, Result}; use datafusion_expr::{ Accumulator, AccumulatorFactoryFunction, AggregateUDF, ColumnarValue, ScalarUDF, - ScalarUDFImpl, Signature, SimpleAggregateUDF, WindowFrame, WindowFrameBound, + Signature, SimpleAggregateUDF, WindowFrame, WindowFrameBound, }; use 
datafusion_functions_aggregate::average::avg_udaf; use datafusion_functions_aggregate::nth_value::nth_value_udaf; @@ -92,6 +92,8 @@ use datafusion_proto::physical_plan::{ }; use datafusion_proto::protobuf; +use crate::cases::{MyAggregateUDF, MyAggregateUdfNode, MyRegexUdf, MyRegexUdfNode}; + /// Perform a serde roundtrip and assert that the string representation of the before and after plans /// are identical. Note that this often isn't sufficient to guarantee that no information is /// lost during serde because the string representation of a plan often only shows a subset of state. @@ -312,7 +314,7 @@ fn roundtrip_window() -> Result<()> { ); let args = vec![cast(col("a", &schema)?, &schema, DataType::Float64)?]; - let sum_expr = udaf::create_aggregate_expr( + let sum_expr = create_aggregate_expr( &sum_udaf(), &args, &[], @@ -367,7 +369,7 @@ fn rountrip_aggregate() -> Result<()> { false, )?], // NTH_VALUE - vec![udaf::create_aggregate_expr( + vec![create_aggregate_expr( &nth_value_udaf(), &[col("b", &schema)?, lit(1u64)], &[], @@ -379,7 +381,7 @@ fn rountrip_aggregate() -> Result<()> { false, )?], // STRING_AGG - vec![udaf::create_aggregate_expr( + vec![create_aggregate_expr( &AggregateUDF::new_from_impl(StringAgg::new()), &[ cast(col("b", &schema)?, &schema, DataType::Utf8)?, @@ -490,7 +492,7 @@ fn roundtrip_aggregate_udaf() -> Result<()> { let groups: Vec<(Arc, String)> = vec![(col("a", &schema)?, "unused".to_string())]; - let aggregates: Vec> = vec![udaf::create_aggregate_expr( + let aggregates: Vec> = vec![create_aggregate_expr( &udaf, &[col("b", &schema)?], &[], @@ -845,123 +847,161 @@ fn roundtrip_scalar_udf() -> Result<()> { roundtrip_test_with_context(Arc::new(project), &ctx) } -#[test] -fn roundtrip_scalar_udf_extension_codec() -> Result<()> { - #[derive(Debug)] - struct MyRegexUdf { - signature: Signature, - // regex as original string - pattern: String, +#[derive(Debug)] +struct UDFExtensionCodec; + +impl PhysicalExtensionCodec for UDFExtensionCodec { + fn try_decode( + &self, + _buf: &[u8], + _inputs: &[Arc], + _registry: &dyn FunctionRegistry, + ) -> Result> { + not_impl_err!("No extension codec provided") } - impl MyRegexUdf { - fn new(pattern: String) -> Self { - Self { - signature: Signature::exact(vec![DataType::Utf8], Volatility::Immutable), - pattern, - } - } + fn try_encode( + &self, + _node: Arc, + _buf: &mut Vec, + ) -> Result<()> { + not_impl_err!("No extension codec provided") } - /// Implement the ScalarUDFImpl trait for MyRegexUdf - impl ScalarUDFImpl for MyRegexUdf { - fn as_any(&self) -> &dyn Any { - self - } + fn try_decode_udf(&self, name: &str, buf: &[u8]) -> Result> { + if name == "regex_udf" { + let proto = MyRegexUdfNode::decode(buf).map_err(|err| { + DataFusionError::Internal(format!("failed to decode regex_udf: {err}")) + })?; - fn name(&self) -> &str { - "regex_udf" + Ok(Arc::new(ScalarUDF::from(MyRegexUdf::new(proto.pattern)))) + } else { + not_impl_err!("unrecognized scalar UDF implementation, cannot decode") } + } - fn signature(&self) -> &Signature { - &self.signature + fn try_encode_udf(&self, node: &ScalarUDF, buf: &mut Vec) -> Result<()> { + let binding = node.inner(); + if let Some(udf) = binding.as_any().downcast_ref::() { + let proto = MyRegexUdfNode { + pattern: udf.pattern.clone(), + }; + proto.encode(buf).map_err(|err| { + DataFusionError::Internal(format!("failed to encode udf: {err}")) + })?; } + Ok(()) + } - fn return_type(&self, args: &[DataType]) -> Result { - if !matches!(args.first(), Some(&DataType::Utf8)) { - return 
plan_err!("regex_udf only accepts Utf8 arguments"); - } - Ok(DataType::Int64) + fn try_decode_udaf(&self, name: &str, buf: &[u8]) -> Result> { + if name == "aggregate_udf" { + let proto = MyAggregateUdfNode::decode(buf).map_err(|err| { + DataFusionError::Internal(format!( + "failed to decode aggregate_udf: {err}" + )) + })?; + + Ok(Arc::new(AggregateUDF::from(MyAggregateUDF::new( + proto.result, + )))) + } else { + not_impl_err!("unrecognized scalar UDF implementation, cannot decode") } + } - fn invoke(&self, _args: &[ColumnarValue]) -> Result { - unimplemented!() + fn try_encode_udaf(&self, node: &AggregateUDF, buf: &mut Vec) -> Result<()> { + let binding = node.inner(); + if let Some(udf) = binding.as_any().downcast_ref::() { + let proto = MyAggregateUdfNode { + result: udf.result.clone(), + }; + proto.encode(buf).map_err(|err| { + DataFusionError::Internal(format!("failed to encode udf: {err:?}")) + })?; } + Ok(()) } +} - #[derive(Clone, PartialEq, ::prost::Message)] - pub struct MyRegexUdfNode { - #[prost(string, tag = "1")] - pub pattern: String, - } +#[test] +fn roundtrip_scalar_udf_extension_codec() -> Result<()> { + let field_text = Field::new("text", DataType::Utf8, true); + let field_published = Field::new("published", DataType::Boolean, false); + let field_author = Field::new("author", DataType::Utf8, false); + let schema = Arc::new(Schema::new(vec![field_text, field_published, field_author])); + let input = Arc::new(EmptyExec::new(schema.clone())); - #[derive(Debug)] - pub struct ScalarUDFExtensionCodec {} + let udf_expr = Arc::new(ScalarFunctionExpr::new( + "regex_udf", + Arc::new(ScalarUDF::from(MyRegexUdf::new(".*".to_string()))), + vec![col("text", &schema)?], + DataType::Int64, + )); - impl PhysicalExtensionCodec for ScalarUDFExtensionCodec { - fn try_decode( - &self, - _buf: &[u8], - _inputs: &[Arc], - _registry: &dyn FunctionRegistry, - ) -> Result> { - not_impl_err!("No extension codec provided") - } + let filter = Arc::new(FilterExec::try_new( + Arc::new(BinaryExpr::new( + col("published", &schema)?, + Operator::And, + Arc::new(BinaryExpr::new(udf_expr.clone(), Operator::Gt, lit(0))), + )), + input, + )?); - fn try_encode( - &self, - _node: Arc, - _buf: &mut Vec, - ) -> Result<()> { - not_impl_err!("No extension codec provided") - } + let window = Arc::new(WindowAggExec::try_new( + vec![Arc::new(PlainAggregateWindowExpr::new( + Arc::new(Max::new(udf_expr.clone(), "max", DataType::Int64)), + &[col("author", &schema)?], + &[], + Arc::new(WindowFrame::new(None)), + ))], + filter, + vec![col("author", &schema)?], + )?); - fn try_decode_udf(&self, name: &str, buf: &[u8]) -> Result> { - if name == "regex_udf" { - let proto = MyRegexUdfNode::decode(buf).map_err(|err| { - DataFusionError::Internal(format!( - "failed to decode regex_udf: {}", - err - )) - })?; - - Ok(Arc::new(ScalarUDF::new_from_impl(MyRegexUdf::new( - proto.pattern, - )))) - } else { - not_impl_err!("unrecognized scalar UDF implementation, cannot decode") - } - } + let aggregate = Arc::new(AggregateExec::try_new( + AggregateMode::Final, + PhysicalGroupBy::new(vec![], vec![], vec![]), + vec![Arc::new(Max::new(udf_expr, "max", DataType::Int64))], + vec![None], + window, + schema.clone(), + )?); - fn try_encode_udf(&self, node: &ScalarUDF, buf: &mut Vec) -> Result<()> { - let binding = node.inner(); - if let Some(udf) = binding.as_any().downcast_ref::() { - let proto = MyRegexUdfNode { - pattern: udf.pattern.clone(), - }; - proto.encode(buf).map_err(|e| { - DataFusionError::Internal(format!("failed to encode udf: 
{e:?}")) - })?; - } - Ok(()) - } - } + let ctx = SessionContext::new(); + roundtrip_test_and_return(aggregate, &ctx, &UDFExtensionCodec)?; + Ok(()) +} +#[test] +fn roundtrip_aggregate_udf_extension_codec() -> Result<()> { let field_text = Field::new("text", DataType::Utf8, true); let field_published = Field::new("published", DataType::Boolean, false); let field_author = Field::new("author", DataType::Utf8, false); let schema = Arc::new(Schema::new(vec![field_text, field_published, field_author])); let input = Arc::new(EmptyExec::new(schema.clone())); - let pattern = ".*"; - let udf = ScalarUDF::from(MyRegexUdf::new(pattern.to_string())); let udf_expr = Arc::new(ScalarFunctionExpr::new( - udf.name(), - Arc::new(udf.clone()), + "regex_udf", + Arc::new(ScalarUDF::from(MyRegexUdf::new(".*".to_string()))), vec![col("text", &schema)?], DataType::Int64, )); + let udaf = AggregateUDF::from(MyAggregateUDF::new("result".to_string())); + let aggr_args: [Arc; 1] = + [Arc::new(Literal::new(ScalarValue::from(42)))]; + let aggr_expr = create_aggregate_expr( + &udaf, + &aggr_args, + &[], + &[], + &[], + &schema, + "aggregate_udf", + false, + false, + )?; + let filter = Arc::new(FilterExec::try_new( Arc::new(BinaryExpr::new( col("published", &schema)?, @@ -973,7 +1013,7 @@ fn roundtrip_scalar_udf_extension_codec() -> Result<()> { let window = Arc::new(WindowAggExec::try_new( vec![Arc::new(PlainAggregateWindowExpr::new( - Arc::new(Max::new(udf_expr.clone(), "max", DataType::Int64)), + aggr_expr, &[col("author", &schema)?], &[], Arc::new(WindowFrame::new(None)), @@ -982,18 +1022,29 @@ fn roundtrip_scalar_udf_extension_codec() -> Result<()> { vec![col("author", &schema)?], )?); + let aggr_expr = create_aggregate_expr( + &udaf, + &aggr_args, + &[], + &[], + &[], + &schema, + "aggregate_udf", + true, + true, + )?; + let aggregate = Arc::new(AggregateExec::try_new( AggregateMode::Final, PhysicalGroupBy::new(vec![], vec![], vec![]), - vec![Arc::new(Max::new(udf_expr, "max", DataType::Int64))], + vec![aggr_expr], vec![None], window, schema.clone(), )?); let ctx = SessionContext::new(); - let codec = ScalarUDFExtensionCodec {}; - roundtrip_test_and_return(aggregate, &ctx, &codec)?; + roundtrip_test_and_return(aggregate, &ctx, &UDFExtensionCodec)?; Ok(()) } From a979f3e5d4745edf31a489185e6dda5008e6e628 Mon Sep 17 00:00:00 2001 From: JasonLi Date: Wed, 17 Jul 2024 09:32:36 +0800 Subject: [PATCH 56/59] feat: support `unnest` in GROUP BY clause (#11469) * feat: support group by unnest * pass slt * refactor: mv process_group_by_unnest into try_process_unnest * chore: add some documentation comments and tests * Avoid cloning input * use consistent field names --------- Co-authored-by: Andrew Lamb --- datafusion/sql/src/select.rs | 118 ++++++++++++++- datafusion/sqllogictest/test_files/unnest.slt | 134 +++++++++++++++++- 2 files changed, 249 insertions(+), 3 deletions(-) diff --git a/datafusion/sql/src/select.rs b/datafusion/sql/src/select.rs index a5891e655a052..84b80c311245c 100644 --- a/datafusion/sql/src/select.rs +++ b/datafusion/sql/src/select.rs @@ -26,18 +26,20 @@ use crate::utils::{ resolve_columns, resolve_positions_to_exprs, transform_bottom_unnest, }; +use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; use datafusion_common::{not_impl_err, plan_err, DataFusionError, Result}; use datafusion_common::{Column, UnnestOptions}; use datafusion_expr::expr::Alias; use datafusion_expr::expr_rewriter::{ normalize_col, normalize_col_with_schemas_and_ambiguity_check, normalize_cols, }; +use 
datafusion_expr::logical_plan::tree_node::unwrap_arc;
 use datafusion_expr::utils::{
     expand_qualified_wildcard, expand_wildcard, expr_as_column_expr, expr_to_columns,
     find_aggregate_exprs, find_window_exprs,
 };
 use datafusion_expr::{
-    Expr, Filter, GroupingSet, LogicalPlan, LogicalPlanBuilder, Partitioning,
+    Aggregate, Expr, Filter, GroupingSet, LogicalPlan, LogicalPlanBuilder, Partitioning,
 };
 use sqlparser::ast::{
     Distinct, Expr as SQLExpr, GroupByExpr, NamedWindowExpr, OrderByExpr,
@@ -297,6 +299,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
         input: LogicalPlan,
         select_exprs: Vec<Expr>,
     ) -> Result<LogicalPlan> {
+        // Try to process group by unnest
+        let input = self.try_process_aggregate_unnest(input)?;
+
         let mut intermediate_plan = input;
         let mut intermediate_select_exprs = select_exprs;
         // Each expr in select_exprs can contain multiple unnest stages
@@ -354,6 +359,117 @@
             .build()
     }
 
+    fn try_process_aggregate_unnest(&self, input: LogicalPlan) -> Result<LogicalPlan> {
+        match input {
+            LogicalPlan::Aggregate(agg) => {
+                let agg_expr = agg.aggr_expr.clone();
+                let (new_input, new_group_by_exprs) =
+                    self.try_process_group_by_unnest(agg)?;
+                LogicalPlanBuilder::from(new_input)
+                    .aggregate(new_group_by_exprs, agg_expr)?
+                    .build()
+            }
+            LogicalPlan::Filter(mut filter) => {
+                filter.input = Arc::new(
+                    self.try_process_aggregate_unnest(unwrap_arc(filter.input))?,
+                );
+                Ok(LogicalPlan::Filter(filter))
+            }
+            _ => Ok(input),
+        }
+    }
+
+    /// Try converting Unnest(Expr) of group by to Unnest/Projection.
+    /// Return the new input and group_by_exprs of Aggregate.
+    fn try_process_group_by_unnest(
+        &self,
+        agg: Aggregate,
+    ) -> Result<(LogicalPlan, Vec<Expr>)> {
+        let mut aggr_expr_using_columns: Option<HashSet<Expr>> = None;
+
+        let Aggregate {
+            input,
+            group_expr,
+            aggr_expr,
+            ..
+        } = agg;
+
+        // process unnest of group_by_exprs, and input of agg will be rewritten
+        // for example:
+        //
+        // ```
+        // Aggregate: groupBy=[[UNNEST(Column(Column { relation: Some(Bare { table: "tab" }), name: "array_col" }))]], aggr=[[]]
+        //   TableScan: tab
+        // ```
+        //
+        // will be transformed into
+        //
+        // ```
+        // Aggregate: groupBy=[[unnest(tab.array_col)]], aggr=[[]]
+        //   Unnest: lists[unnest(tab.array_col)] structs[]
+        //     Projection: tab.array_col AS unnest(tab.array_col)
+        //       TableScan: tab
+        // ```
+        let mut intermediate_plan = unwrap_arc(input);
+        let mut intermediate_select_exprs = group_expr;
+
+        loop {
+            let mut unnest_columns = vec![];
+            let mut inner_projection_exprs = vec![];
+
+            let outer_projection_exprs: Vec<Expr> = intermediate_select_exprs
+                .iter()
+                .map(|expr| {
+                    transform_bottom_unnest(
+                        &intermediate_plan,
+                        &mut unnest_columns,
+                        &mut inner_projection_exprs,
+                        expr,
+                    )
+                })
+                .collect::<Result<Vec<_>>>()?
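+                // transform_bottom_unnest may yield several expressions for a
+                // single group expression, hence the flatten below.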
+ .into_iter() + .flatten() + .collect(); + + if unnest_columns.is_empty() { + break; + } else { + let columns = unnest_columns.into_iter().map(|col| col.into()).collect(); + let unnest_options = UnnestOptions::new().with_preserve_nulls(false); + + let mut projection_exprs = match &aggr_expr_using_columns { + Some(exprs) => (*exprs).clone(), + None => { + let mut columns = HashSet::new(); + for expr in &aggr_expr { + expr.apply(|expr| { + if let Expr::Column(c) = expr { + columns.insert(Expr::Column(c.clone())); + } + Ok(TreeNodeRecursion::Continue) + }) + // As the closure always returns Ok, this "can't" error + .expect("Unexpected error"); + } + aggr_expr_using_columns = Some(columns.clone()); + columns + } + }; + projection_exprs.extend(inner_projection_exprs); + + intermediate_plan = LogicalPlanBuilder::from(intermediate_plan) + .project(projection_exprs)? + .unnest_columns_with_options(columns, unnest_options)? + .build()?; + + intermediate_select_exprs = outer_projection_exprs; + } + } + + Ok((intermediate_plan, intermediate_select_exprs)) + } + fn plan_selection( &self, selection: Option, diff --git a/datafusion/sqllogictest/test_files/unnest.slt b/datafusion/sqllogictest/test_files/unnest.slt index 698faf87c9b20..93146541e107b 100644 --- a/datafusion/sqllogictest/test_files/unnest.slt +++ b/datafusion/sqllogictest/test_files/unnest.slt @@ -500,8 +500,6 @@ select unnest(column1) from (select * from (values([1,2,3]), ([4,5,6])) limit 1 query error DataFusion error: Error during planning: Projections require unique expression names but the expression "UNNEST\(Column\(Column \{ relation: Some\(Bare \{ table: "unnest_table" \}\), name: "column1" \}\)\)" at position 0 and "UNNEST\(Column\(Column \{ relation: Some\(Bare \{ table: "unnest_table" \}\), name: "column1" \}\)\)" at position 1 have the same name. Consider aliasing \("AS"\) one of them. select unnest(column1), unnest(column1) from unnest_table; -statement ok -drop table unnest_table; ## unnest list followed by unnest struct query ??? 
@@ -557,3 +555,135 @@ physical_plan 06)----------UnnestExec 07)------------ProjectionExec: expr=[column3@0 as unnest(recursive_unnest_table.column3), column3@0 as column3] 08)--------------MemoryExec: partitions=1, partition_sizes=[1] + +## group by unnest + +### without agg exprs +query I +select unnest(column1) c1 from unnest_table group by c1 order by c1; +---- +1 +2 +3 +4 +5 +6 +12 + +query II +select unnest(column1) c1, unnest(column2) c2 from unnest_table group by c1, c2 order by c1, c2; +---- +1 7 +2 NULL +3 NULL +4 8 +5 9 +6 11 +12 NULL +NULL 10 +NULL 12 +NULL 42 +NULL NULL + +query III +select unnest(column1) c1, unnest(column2) c2, column3 c3 from unnest_table group by c1, c2, c3 order by c1, c2, c3; +---- +1 7 1 +2 NULL 1 +3 NULL 1 +4 8 2 +5 9 2 +6 11 3 +12 NULL NULL +NULL 10 2 +NULL 12 3 +NULL 42 NULL +NULL NULL NULL + +### with agg exprs + +query IIII +select unnest(column1) c1, unnest(column2) c2, column3 c3, count(1) from unnest_table group by c1, c2, c3 order by c1, c2, c3; +---- +1 7 1 1 +2 NULL 1 1 +3 NULL 1 1 +4 8 2 1 +5 9 2 1 +6 11 3 1 +12 NULL NULL 1 +NULL 10 2 1 +NULL 12 3 1 +NULL 42 NULL 1 +NULL NULL NULL 1 + +query IIII +select unnest(column1) c1, unnest(column2) c2, column3 c3, count(column4) from unnest_table group by c1, c2, c3 order by c1, c2, c3; +---- +1 7 1 1 +2 NULL 1 1 +3 NULL 1 1 +4 8 2 1 +5 9 2 1 +6 11 3 0 +12 NULL NULL 0 +NULL 10 2 1 +NULL 12 3 0 +NULL 42 NULL 0 +NULL NULL NULL 0 + +query IIIII +select unnest(column1) c1, unnest(column2) c2, column3 c3, count(column4), sum(column3) from unnest_table group by c1, c2, c3 order by c1, c2, c3; +---- +1 7 1 1 1 +2 NULL 1 1 1 +3 NULL 1 1 1 +4 8 2 1 2 +5 9 2 1 2 +6 11 3 0 3 +12 NULL NULL 0 NULL +NULL 10 2 1 2 +NULL 12 3 0 3 +NULL 42 NULL 0 NULL +NULL NULL NULL 0 NULL + +query II +select unnest(column1), count(*) from unnest_table group by unnest(column1) order by unnest(column1) desc; +---- +12 1 +6 1 +5 1 +4 1 +3 1 +2 1 +1 1 + +### group by recursive unnest list + +query ? 
+select unnest(unnest(column2)) c2 from recursive_unnest_table group by c2 order by c2; +---- +[1] +[1, 1] +[2] +[3, 4] +[5] +[7, 8] +[, 6] +NULL + +query ?I +select unnest(unnest(column2)) c2, count(column3) from recursive_unnest_table group by c2 order by c2; +---- +[1] 1 +[1, 1] 1 +[2] 1 +[3, 4] 1 +[5] 1 +[7, 8] 1 +[, 6] 1 +NULL 1 + +### TODO: group by unnest struct +query error DataFusion error: Error during planning: Projection references non\-aggregate values +select unnest(column1) c1 from nested_unnest_table group by c1.c0; From d67b0fbf52a2c428399811fabac3eec6cf15da41 Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Wed, 17 Jul 2024 13:34:07 +0800 Subject: [PATCH 57/59] Remove element's nullability of array_agg function (#11447) * rm null Signed-off-by: jayzhan211 * fmt Signed-off-by: jayzhan211 * fix test Signed-off-by: jayzhan211 --------- Signed-off-by: jayzhan211 --- datafusion/core/tests/sql/aggregates.rs | 2 +- .../physical-expr/src/aggregate/array_agg.rs | 23 +++--------- .../src/aggregate/array_agg_distinct.rs | 23 +++--------- .../src/aggregate/array_agg_ordered.rs | 37 +++++-------------- .../physical-expr/src/aggregate/build_in.rs | 12 +----- .../physical-plan/src/aggregates/mod.rs | 1 - 6 files changed, 23 insertions(+), 75 deletions(-) diff --git a/datafusion/core/tests/sql/aggregates.rs b/datafusion/core/tests/sql/aggregates.rs index 86032dc9bc963..1f4f9e77d5dc5 100644 --- a/datafusion/core/tests/sql/aggregates.rs +++ b/datafusion/core/tests/sql/aggregates.rs @@ -36,7 +36,7 @@ async fn csv_query_array_agg_distinct() -> Result<()> { *actual[0].schema(), Schema::new(vec![Field::new_list( "ARRAY_AGG(DISTINCT aggregate_test_100.c2)", - Field::new("item", DataType::UInt32, false), + Field::new("item", DataType::UInt32, true), true ),]) ); diff --git a/datafusion/physical-expr/src/aggregate/array_agg.rs b/datafusion/physical-expr/src/aggregate/array_agg.rs index 38a9738029335..0d5ed730e2834 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg.rs +++ b/datafusion/physical-expr/src/aggregate/array_agg.rs @@ -24,7 +24,7 @@ use arrow::array::ArrayRef; use arrow::datatypes::{DataType, Field}; use arrow_array::Array; use datafusion_common::cast::as_list_array; -use datafusion_common::utils::array_into_list_array; +use datafusion_common::utils::array_into_list_array_nullable; use datafusion_common::Result; use datafusion_common::ScalarValue; use datafusion_expr::Accumulator; @@ -40,8 +40,6 @@ pub struct ArrayAgg { input_data_type: DataType, /// The input expression expr: Arc, - /// If the input expression can have NULLs - nullable: bool, } impl ArrayAgg { @@ -50,13 +48,11 @@ impl ArrayAgg { expr: Arc, name: impl Into, data_type: DataType, - nullable: bool, ) -> Self { Self { name: name.into(), input_data_type: data_type, expr, - nullable, } } } @@ -70,7 +66,7 @@ impl AggregateExpr for ArrayAgg { Ok(Field::new_list( &self.name, // This should be the same as return type of AggregateFunction::ArrayAgg - Field::new("item", self.input_data_type.clone(), self.nullable), + Field::new("item", self.input_data_type.clone(), true), true, )) } @@ -78,14 +74,13 @@ impl AggregateExpr for ArrayAgg { fn create_accumulator(&self) -> Result> { Ok(Box::new(ArrayAggAccumulator::try_new( &self.input_data_type, - self.nullable, )?)) } fn state_fields(&self) -> Result> { Ok(vec![Field::new_list( format_state_name(&self.name, "array_agg"), - Field::new("item", self.input_data_type.clone(), self.nullable), + Field::new("item", self.input_data_type.clone(), true), true, )]) } @@ -116,16 +111,14 @@ 
impl PartialEq for ArrayAgg { pub(crate) struct ArrayAggAccumulator { values: Vec, datatype: DataType, - nullable: bool, } impl ArrayAggAccumulator { /// new array_agg accumulator based on given item data type - pub fn try_new(datatype: &DataType, nullable: bool) -> Result { + pub fn try_new(datatype: &DataType) -> Result { Ok(Self { values: vec![], datatype: datatype.clone(), - nullable, }) } } @@ -169,15 +162,11 @@ impl Accumulator for ArrayAggAccumulator { self.values.iter().map(|a| a.as_ref()).collect(); if element_arrays.is_empty() { - return Ok(ScalarValue::new_null_list( - self.datatype.clone(), - self.nullable, - 1, - )); + return Ok(ScalarValue::new_null_list(self.datatype.clone(), true, 1)); } let concated_array = arrow::compute::concat(&element_arrays)?; - let list_array = array_into_list_array(concated_array, self.nullable); + let list_array = array_into_list_array_nullable(concated_array); Ok(ScalarValue::List(Arc::new(list_array))) } diff --git a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs index 368d11d7421ab..eca6e4ce4f656 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs +++ b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs @@ -42,8 +42,6 @@ pub struct DistinctArrayAgg { input_data_type: DataType, /// The input expression expr: Arc, - /// If the input expression can have NULLs - nullable: bool, } impl DistinctArrayAgg { @@ -52,14 +50,12 @@ impl DistinctArrayAgg { expr: Arc, name: impl Into, input_data_type: DataType, - nullable: bool, ) -> Self { let name = name.into(); Self { name, input_data_type, expr, - nullable, } } } @@ -74,7 +70,7 @@ impl AggregateExpr for DistinctArrayAgg { Ok(Field::new_list( &self.name, // This should be the same as return type of AggregateFunction::ArrayAgg - Field::new("item", self.input_data_type.clone(), self.nullable), + Field::new("item", self.input_data_type.clone(), true), true, )) } @@ -82,14 +78,13 @@ impl AggregateExpr for DistinctArrayAgg { fn create_accumulator(&self) -> Result> { Ok(Box::new(DistinctArrayAggAccumulator::try_new( &self.input_data_type, - self.nullable, )?)) } fn state_fields(&self) -> Result> { Ok(vec![Field::new_list( format_state_name(&self.name, "distinct_array_agg"), - Field::new("item", self.input_data_type.clone(), self.nullable), + Field::new("item", self.input_data_type.clone(), true), true, )]) } @@ -120,15 +115,13 @@ impl PartialEq for DistinctArrayAgg { struct DistinctArrayAggAccumulator { values: HashSet, datatype: DataType, - nullable: bool, } impl DistinctArrayAggAccumulator { - pub fn try_new(datatype: &DataType, nullable: bool) -> Result { + pub fn try_new(datatype: &DataType) -> Result { Ok(Self { values: HashSet::new(), datatype: datatype.clone(), - nullable, }) } } @@ -166,13 +159,9 @@ impl Accumulator for DistinctArrayAggAccumulator { fn evaluate(&mut self) -> Result { let values: Vec = self.values.iter().cloned().collect(); if values.is_empty() { - return Ok(ScalarValue::new_null_list( - self.datatype.clone(), - self.nullable, - 1, - )); + return Ok(ScalarValue::new_null_list(self.datatype.clone(), true, 1)); } - let arr = ScalarValue::new_list(&values, &self.datatype, self.nullable); + let arr = ScalarValue::new_list(&values, &self.datatype, true); Ok(ScalarValue::List(arr)) } @@ -255,7 +244,6 @@ mod tests { col("a", &schema)?, "bla".to_string(), datatype, - true, )); let actual = aggregate(&batch, agg)?; compare_list_contents(expected, actual) @@ -272,7 +260,6 @@ mod tests { 
col("a", &schema)?, "bla".to_string(), datatype, - true, )); let mut accum1 = agg.create_accumulator()?; diff --git a/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs b/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs index d44811192f667..992c06f5bf628 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs +++ b/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs @@ -33,7 +33,7 @@ use arrow::datatypes::{DataType, Field}; use arrow_array::cast::AsArray; use arrow_array::{new_empty_array, Array, ArrayRef, StructArray}; use arrow_schema::Fields; -use datafusion_common::utils::{array_into_list_array, get_row_at_idx}; +use datafusion_common::utils::{array_into_list_array_nullable, get_row_at_idx}; use datafusion_common::{exec_err, Result, ScalarValue}; use datafusion_expr::utils::AggregateOrderSensitivity; use datafusion_expr::Accumulator; @@ -50,8 +50,6 @@ pub struct OrderSensitiveArrayAgg { input_data_type: DataType, /// The input expression expr: Arc, - /// If the input expression can have `NULL`s - nullable: bool, /// Ordering data types order_by_data_types: Vec, /// Ordering requirement @@ -66,7 +64,6 @@ impl OrderSensitiveArrayAgg { expr: Arc, name: impl Into, input_data_type: DataType, - nullable: bool, order_by_data_types: Vec, ordering_req: LexOrdering, ) -> Self { @@ -74,7 +71,6 @@ impl OrderSensitiveArrayAgg { name: name.into(), input_data_type, expr, - nullable, order_by_data_types, ordering_req, reverse: false, @@ -90,8 +86,8 @@ impl AggregateExpr for OrderSensitiveArrayAgg { fn field(&self) -> Result { Ok(Field::new_list( &self.name, - // This should be the same as return type of AggregateFunction::ArrayAgg - Field::new("item", self.input_data_type.clone(), self.nullable), + // This should be the same as return type of AggregateFunction::OrderSensitiveArrayAgg + Field::new("item", self.input_data_type.clone(), true), true, )) } @@ -102,7 +98,6 @@ impl AggregateExpr for OrderSensitiveArrayAgg { &self.order_by_data_types, self.ordering_req.clone(), self.reverse, - self.nullable, ) .map(|acc| Box::new(acc) as _) } @@ -110,17 +105,13 @@ impl AggregateExpr for OrderSensitiveArrayAgg { fn state_fields(&self) -> Result> { let mut fields = vec![Field::new_list( format_state_name(&self.name, "array_agg"), - Field::new("item", self.input_data_type.clone(), self.nullable), + Field::new("item", self.input_data_type.clone(), true), true, // This should be the same as field() )]; let orderings = ordering_fields(&self.ordering_req, &self.order_by_data_types); fields.push(Field::new_list( format_state_name(&self.name, "array_agg_orderings"), - Field::new( - "item", - DataType::Struct(Fields::from(orderings)), - self.nullable, - ), + Field::new("item", DataType::Struct(Fields::from(orderings)), true), false, )); Ok(fields) @@ -147,7 +138,6 @@ impl AggregateExpr for OrderSensitiveArrayAgg { name: self.name.to_string(), input_data_type: self.input_data_type.clone(), expr: Arc::clone(&self.expr), - nullable: self.nullable, order_by_data_types: self.order_by_data_types.clone(), // Reverse requirement: ordering_req: reverse_order_bys(&self.ordering_req), @@ -186,8 +176,6 @@ pub(crate) struct OrderSensitiveArrayAggAccumulator { ordering_req: LexOrdering, /// Whether the aggregation is running in reverse. 
reverse: bool, - /// Whether the input expr is nullable - nullable: bool, } impl OrderSensitiveArrayAggAccumulator { @@ -198,7 +186,6 @@ impl OrderSensitiveArrayAggAccumulator { ordering_dtypes: &[DataType], ordering_req: LexOrdering, reverse: bool, - nullable: bool, ) -> Result { let mut datatypes = vec![datatype.clone()]; datatypes.extend(ordering_dtypes.iter().cloned()); @@ -208,7 +195,6 @@ impl OrderSensitiveArrayAggAccumulator { datatypes, ordering_req, reverse, - nullable, }) } } @@ -312,7 +298,7 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator { if self.values.is_empty() { return Ok(ScalarValue::new_null_list( self.datatypes[0].clone(), - self.nullable, + true, 1, )); } @@ -322,14 +308,10 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator { ScalarValue::new_list_from_iter( values.into_iter().rev(), &self.datatypes[0], - self.nullable, + true, ) } else { - ScalarValue::new_list_from_iter( - values.into_iter(), - &self.datatypes[0], - self.nullable, - ) + ScalarValue::new_list_from_iter(values.into_iter(), &self.datatypes[0], true) }; Ok(ScalarValue::List(array)) } @@ -385,9 +367,8 @@ impl OrderSensitiveArrayAggAccumulator { column_wise_ordering_values, None, )?; - Ok(ScalarValue::List(Arc::new(array_into_list_array( + Ok(ScalarValue::List(Arc::new(array_into_list_array_nullable( Arc::new(ordering_array), - self.nullable, )))) } } diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs index 68c9b4859f1f8..ef21b3d0f7883 100644 --- a/datafusion/physical-expr/src/aggregate/build_in.rs +++ b/datafusion/physical-expr/src/aggregate/build_in.rs @@ -62,16 +62,14 @@ pub fn create_aggregate_expr( Ok(match (fun, distinct) { (AggregateFunction::ArrayAgg, false) => { let expr = Arc::clone(&input_phy_exprs[0]); - let nullable = expr.nullable(input_schema)?; if ordering_req.is_empty() { - Arc::new(expressions::ArrayAgg::new(expr, name, data_type, nullable)) + Arc::new(expressions::ArrayAgg::new(expr, name, data_type)) } else { Arc::new(expressions::OrderSensitiveArrayAgg::new( expr, name, data_type, - nullable, ordering_types, ordering_req.to_vec(), )) @@ -84,13 +82,7 @@ pub fn create_aggregate_expr( ); } let expr = Arc::clone(&input_phy_exprs[0]); - let is_expr_nullable = expr.nullable(input_schema)?; - Arc::new(expressions::DistinctArrayAgg::new( - expr, - name, - data_type, - is_expr_nullable, - )) + Arc::new(expressions::DistinctArrayAgg::new(expr, name, data_type)) } (AggregateFunction::Min, _) => Arc::new(expressions::Min::new( Arc::clone(&input_phy_exprs[0]), diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index 8bf808af3b5b8..5f780f1ff8019 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -2231,7 +2231,6 @@ mod tests { Arc::clone(col_a), "array_agg", DataType::Int32, - false, vec![], order_by_expr.unwrap_or_default(), )) as _ From de0765a97b4c348c2a9667cccd5b652591c8e532 Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Wed, 17 Jul 2024 18:29:58 +0800 Subject: [PATCH 58/59] Get expr planners when creating new planner (#11485) * get expr planners when creating new planner Signed-off-by: jayzhan211 * get expr planner when creating planner Signed-off-by: jayzhan211 * no planners in sqltorel Signed-off-by: jayzhan211 * Add docs about SessionContextProvider * Use Slice rather than Vec to access expr planners * add test Signed-off-by: jayzhan211 * clippy Signed-off-by: jayzhan211 --------- 
Signed-off-by: jayzhan211
Co-authored-by: Andrew Lamb
---
 .../core/src/execution/session_state.rs      | 70 ++++++++++++++-----
 datafusion/expr/src/planner.rs               |  5 ++
 datafusion/sql/src/expr/mod.rs               | 14 ++--
 datafusion/sql/src/expr/substring.rs         |  2 +-
 datafusion/sql/src/expr/value.rs             |  2 +-
 datafusion/sql/src/planner.rs                | 10 ---
 6 files changed, 68 insertions(+), 35 deletions(-)

diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs
index 75eef43454873..03ce8d3b5892a 100644
--- a/datafusion/core/src/execution/session_state.rs
+++ b/datafusion/core/src/execution/session_state.rs
@@ -516,7 +516,7 @@ impl SessionState {
             }
         }
 
-        let query = self.build_sql_query_planner(&provider);
+        let query = SqlToRel::new_with_options(&provider, self.get_parser_options());
         query.statement_to_plan(statement)
     }
 
@@ -569,7 +569,7 @@ impl SessionState {
             tables: HashMap::new(),
         };
 
-        let query = self.build_sql_query_planner(&provider);
+        let query = SqlToRel::new_with_options(&provider, self.get_parser_options());
         query.sql_to_expr(sql_expr, df_schema, &mut PlannerContext::new())
     }
 
@@ -854,20 +854,6 @@ impl SessionState {
         let udtf = self.table_functions.remove(name);
         Ok(udtf.map(|x| x.function().clone()))
     }
-
-    fn build_sql_query_planner<'a, S>(&self, provider: &'a S) -> SqlToRel<'a, S>
-    where
-        S: ContextProvider,
-    {
-        let mut query = SqlToRel::new_with_options(provider, self.get_parser_options());
-
-        // custom planners are registered first, so they're run first and take precedence over built-in planners
-        for planner in self.expr_planners.iter() {
-            query = query.with_user_defined_planner(planner.clone());
-        }
-
-        query
-    }
 }
 
 /// A builder to be used for building [`SessionState`]'s. Defaults will
@@ -1597,12 +1583,20 @@ impl SessionStateDefaults {
     }
 }
 
+/// Adapter that implements the [`ContextProvider`] trait for a [`SessionState`]
+///
+/// This is used so the SQL planner can access the state of the session without
+/// having a direct dependency on the [`SessionState`] struct (and core crate)
 struct SessionContextProvider<'a> {
     state: &'a SessionState,
     tables: HashMap<String, Arc<dyn TableSource>>,
 }
 
 impl<'a> ContextProvider for SessionContextProvider<'a> {
+    fn get_expr_planners(&self) -> &[Arc<dyn ExprPlanner>] {
+        &self.state.expr_planners
+    }
+
     fn get_table_source(
         &self,
         name: TableReference,
@@ -1898,3 +1892,47 @@ impl<'a> SimplifyInfo for SessionSimplifyProvider<'a> {
         expr.get_type(self.df_schema)
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use std::collections::HashMap;
+
+    use arrow_schema::{DataType, Field, Schema};
+    use datafusion_common::DFSchema;
+    use datafusion_common::Result;
+    use datafusion_expr::Expr;
+    use datafusion_sql::planner::{PlannerContext, SqlToRel};
+
+    use crate::execution::context::SessionState;
+
+    use super::{SessionContextProvider, SessionStateBuilder};
+
+    #[test]
+    fn test_session_state_with_default_features() {
+        // test array planners with and without builtin planners
+        fn sql_to_expr(state: &SessionState) -> Result<Expr> {
+            let provider = SessionContextProvider {
+                state,
+                tables: HashMap::new(),
+            };
+
+            let sql = "[1,2,3]";
+            let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
+            let df_schema = DFSchema::try_from(schema)?;
+            let dialect = state.config.options().sql_parser.dialect.as_str();
+            let sql_expr = state.sql_to_expr(sql, dialect)?;
+
+            let query = SqlToRel::new_with_options(&provider, state.get_parser_options());
+            query.sql_to_expr(sql_expr, &df_schema, &mut PlannerContext::new())
+        }
+
+        let state = SessionStateBuilder::new().with_default_features().build();
+
+        assert!(sql_to_expr(&state).is_ok());
+
+        // if no builtin planners exist, you should register your own, otherwise returns error
+        let state = SessionStateBuilder::new().build();
+
+        assert!(sql_to_expr(&state).is_err())
+    }
+}
diff --git a/datafusion/expr/src/planner.rs b/datafusion/expr/src/planner.rs
index 2f13923b1f10a..009f3512c588e 100644
--- a/datafusion/expr/src/planner.rs
+++ b/datafusion/expr/src/planner.rs
@@ -60,6 +60,11 @@ pub trait ContextProvider {
         not_impl_err!("Recursive CTE is not implemented")
     }
 
+    /// Getter for expr planners
+    fn get_expr_planners(&self) -> &[Arc<dyn ExprPlanner>] {
+        &[]
+    }
+
     /// Getter for a UDF description
     fn get_function_meta(&self, name: &str) -> Option<Arc<ScalarUDF>>;
     /// Getter for a UDAF description
diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs
index 062ef805fd9f8..71ff7c03bea2f 100644
--- a/datafusion/sql/src/expr/mod.rs
+++ b/datafusion/sql/src/expr/mod.rs
@@ -111,7 +111,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
     ) -> Result<Expr> {
         // try extension planers
         let mut binary_expr = datafusion_expr::planner::RawBinaryExpr { op, left, right };
-        for planner in self.planners.iter() {
+        for planner in self.context_provider.get_expr_planners() {
             match planner.plan_binary_op(binary_expr, schema)? {
                 PlannerResult::Planned(expr) => {
                     return Ok(expr);
@@ -184,7 +184,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
             self.sql_expr_to_logical_expr(*expr, schema, planner_context)?,
         ];
 
-        for planner in self.planners.iter() {
+        for planner in self.context_provider.get_expr_planners() {
             match planner.plan_extract(extract_args)? {
                 PlannerResult::Planned(expr) => return Ok(expr),
                 PlannerResult::Original(args) => {
@@ -283,7 +283,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
         };
 
         let mut field_access_expr = RawFieldAccessExpr { expr, field_access };
-        for planner in self.planners.iter() {
+        for planner in self.context_provider.get_expr_planners() {
             match planner.plan_field_access(field_access_expr, schema)? {
                 PlannerResult::Planned(expr) => return Ok(expr),
                 PlannerResult::Original(expr) => {
@@ -653,7 +653,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
             self.create_struct_expr(values, schema, planner_context)?
         };
 
-        for planner in self.planners.iter() {
+        for planner in self.context_provider.get_expr_planners() {
             match planner.plan_struct_literal(create_struct_args, is_named_struct)? {
                 PlannerResult::Planned(expr) => return Ok(expr),
                 PlannerResult::Original(args) => create_struct_args = args,
@@ -673,7 +673,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
             self.sql_expr_to_logical_expr(substr_expr, schema, planner_context)?;
         let fullstr = self.sql_expr_to_logical_expr(str_expr, schema, planner_context)?;
         let mut position_args = vec![fullstr, substr];
-        for planner in self.planners.iter() {
+        for planner in self.context_provider.get_expr_planners() {
             match planner.plan_position(position_args)? {
                 PlannerResult::Planned(expr) => return Ok(expr),
                 PlannerResult::Original(args) => {
@@ -703,7 +703,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
 
         let mut raw_expr = RawDictionaryExpr { keys, values };
 
-        for planner in self.planners.iter() {
+        for planner in self.context_provider.get_expr_planners() {
             match planner.plan_dictionary_literal(raw_expr, schema)? {
                 PlannerResult::Planned(expr) => {
                     return Ok(expr);
@@ -927,7 +927,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
             }
             None => vec![arg, what_arg, from_arg],
         };
-        for planner in self.planners.iter() {
+        for planner in self.context_provider.get_expr_planners() {
             match planner.plan_overlay(overlay_args)? {
                 PlannerResult::Planned(expr) => return Ok(expr),
                 PlannerResult::Original(args) => overlay_args = args,
diff --git a/datafusion/sql/src/expr/substring.rs b/datafusion/sql/src/expr/substring.rs
index a0dfee1b9d907..f58ab5ff3612c 100644
--- a/datafusion/sql/src/expr/substring.rs
+++ b/datafusion/sql/src/expr/substring.rs
@@ -68,7 +68,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
             }
         };
 
-        for planner in self.planners.iter() {
+        for planner in self.context_provider.get_expr_planners() {
             match planner.plan_substring(substring_args)? {
                 PlannerResult::Planned(expr) => return Ok(expr),
                 PlannerResult::Original(args) => {
diff --git a/datafusion/sql/src/expr/value.rs b/datafusion/sql/src/expr/value.rs
index 5cd6ffc687888..1564f06fe4b9a 100644
--- a/datafusion/sql/src/expr/value.rs
+++ b/datafusion/sql/src/expr/value.rs
@@ -154,7 +154,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
         schema: &DFSchema,
     ) -> Result<Expr> {
         let mut exprs = values;
-        for planner in self.planners.iter() {
+        for planner in self.context_provider.get_expr_planners() {
             match planner.plan_array_literal(exprs, schema)? {
                 PlannerResult::Planned(expr) => {
                     return Ok(expr);
diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs
index be04f51f4f2c9..901a2ad38d8cc 100644
--- a/datafusion/sql/src/planner.rs
+++ b/datafusion/sql/src/planner.rs
@@ -24,7 +24,6 @@ use arrow_schema::*;
 use datafusion_common::{
     field_not_found, internal_err, plan_datafusion_err, DFSchemaRef, SchemaError,
 };
-use datafusion_expr::planner::ExprPlanner;
 use sqlparser::ast::TimezoneInfo;
 use sqlparser::ast::{ArrayElemTypeDef, ExactNumberInfo};
 use sqlparser::ast::{ColumnDef as SQLColumnDef, ColumnOption};
@@ -186,8 +185,6 @@ pub struct SqlToRel<'a, S: ContextProvider> {
     pub(crate) context_provider: &'a S,
     pub(crate) options: ParserOptions,
     pub(crate) normalizer: IdentNormalizer,
-    /// user defined planner extensions
-    pub(crate) planners: Vec<Arc<dyn ExprPlanner>>,
 }
 
 impl<'a, S: ContextProvider> SqlToRel<'a, S> {
@@ -196,12 +193,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
         Self::new_with_options(context_provider, ParserOptions::default())
     }
 
-    /// add an user defined planner
-    pub fn with_user_defined_planner(mut self, planner: Arc<dyn ExprPlanner>) -> Self {
-        self.planners.push(planner);
-        self
-    }
-
     /// Create a new query planner
     pub fn new_with_options(context_provider: &'a S, options: ParserOptions) -> Self {
         let normalize = options.enable_ident_normalization;
@@ -210,7 +201,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
             context_provider,
             options,
             normalizer: IdentNormalizer::new(normalize),
-            planners: vec![],
         }
     }
 

From b0925c801e1c07bd78c78b045ab58fbd0630b638 Mon Sep 17 00:00:00 2001
From: Sergei Grebnov
Date: Wed, 17 Jul 2024 04:21:34 -0700
Subject: [PATCH 59/59] Support alternate format for Utf8 unparsing (CHAR)
 (#11494)

* Add dialect param to use CHAR instead of TEXT for Utf8 unparsing for MySQL (#12)

* Configurable data type instead of flag for Utf8 unparsing

* Fix type in comment
---
 datafusion/sql/src/unparser/dialect.rs | 52 +++++++++++++++++++++++++-
 datafusion/sql/src/unparser/expr.rs    | 34 +++++++++++++++--
 2 files changed, 83 insertions(+), 3 deletions(-)

diff --git a/datafusion/sql/src/unparser/dialect.rs b/datafusion/sql/src/unparser/dialect.rs
index eca2eb4fd0ec7..87453f81ee3d8 100644
--- a/datafusion/sql/src/unparser/dialect.rs
+++ b/datafusion/sql/src/unparser/dialect.rs
@@ -16,7 +16,7 @@
 // under the License.
 
 use regex::Regex;
-use sqlparser::keywords::ALL_KEYWORDS;
+use sqlparser::{ast, keywords::ALL_KEYWORDS};
 
 /// `Dialect` to use for Unparsing
 ///
@@ -45,6 +45,17 @@ pub trait Dialect {
     fn interval_style(&self) -> IntervalStyle {
         IntervalStyle::PostgresVerbose
     }
+
+    // The SQL type to use for Arrow Utf8 unparsing
+    // Most dialects use VARCHAR, but some, like MySQL, require CHAR
+    fn utf8_cast_dtype(&self) -> ast::DataType {
+        ast::DataType::Varchar(None)
+    }
+    // The SQL type to use for Arrow LargeUtf8 unparsing
+    // Most dialects use TEXT, but some, like MySQL, require CHAR
+    fn large_utf8_cast_dtype(&self) -> ast::DataType {
+        ast::DataType::Text
+    }
 }
 
 /// `IntervalStyle` to use for unparsing
@@ -103,6 +114,14 @@ impl Dialect for MySqlDialect {
     fn interval_style(&self) -> IntervalStyle {
         IntervalStyle::MySQL
     }
+
+    fn utf8_cast_dtype(&self) -> ast::DataType {
+        ast::DataType::Char(None)
+    }
+
+    fn large_utf8_cast_dtype(&self) -> ast::DataType {
+        ast::DataType::Char(None)
+    }
 }
 
 pub struct SqliteDialect {}
@@ -118,6 +137,8 @@ pub struct CustomDialect {
     supports_nulls_first_in_sort: bool,
     use_timestamp_for_date64: bool,
     interval_style: IntervalStyle,
+    utf8_cast_dtype: ast::DataType,
+    large_utf8_cast_dtype: ast::DataType,
 }
 
 impl Default for CustomDialect {
@@ -127,6 +148,8 @@ impl Default for CustomDialect {
             supports_nulls_first_in_sort: true,
             use_timestamp_for_date64: false,
             interval_style: IntervalStyle::SQLStandard,
+            utf8_cast_dtype: ast::DataType::Varchar(None),
+            large_utf8_cast_dtype: ast::DataType::Text,
         }
     }
 }
@@ -158,6 +181,14 @@ impl Dialect for CustomDialect {
     fn interval_style(&self) -> IntervalStyle {
         self.interval_style
     }
+
+    fn utf8_cast_dtype(&self) -> ast::DataType {
+        self.utf8_cast_dtype.clone()
+    }
+
+    fn large_utf8_cast_dtype(&self) -> ast::DataType {
+        self.large_utf8_cast_dtype.clone()
+    }
 }
 
 /// `CustomDialectBuilder` to build `CustomDialect` using builder pattern
@@ -179,6 +210,8 @@ pub struct CustomDialectBuilder {
     supports_nulls_first_in_sort: bool,
     use_timestamp_for_date64: bool,
     interval_style: IntervalStyle,
+    utf8_cast_dtype: ast::DataType,
+    large_utf8_cast_dtype: ast::DataType,
 }
 
 impl Default for CustomDialectBuilder {
@@ -194,6 +227,8 @@ impl CustomDialectBuilder {
             supports_nulls_first_in_sort: true,
             use_timestamp_for_date64: false,
             interval_style: IntervalStyle::PostgresVerbose,
+            utf8_cast_dtype: ast::DataType::Varchar(None),
+            large_utf8_cast_dtype: ast::DataType::Text,
         }
     }
 
@@ -203,6 +238,8 @@ impl CustomDialectBuilder {
             supports_nulls_first_in_sort: self.supports_nulls_first_in_sort,
             use_timestamp_for_date64: self.use_timestamp_for_date64,
             interval_style: self.interval_style,
+            utf8_cast_dtype: self.utf8_cast_dtype,
+            large_utf8_cast_dtype: self.large_utf8_cast_dtype,
         }
     }
 
@@ -235,4 +272,17 @@ impl CustomDialectBuilder {
         self.interval_style = interval_style;
         self
     }
+
+    pub fn with_utf8_cast_dtype(mut self, utf8_cast_dtype: ast::DataType) -> Self {
+        self.utf8_cast_dtype = utf8_cast_dtype;
+        self
+    }
+
+    pub fn with_large_utf8_cast_dtype(
+        mut self,
+        large_utf8_cast_dtype: ast::DataType,
+    ) -> Self {
+        self.large_utf8_cast_dtype = large_utf8_cast_dtype;
+        self
+    }
 }
diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs
index e6b67b5d9fb2d..950e7e11288a7 100644
--- a/datafusion/sql/src/unparser/expr.rs
+++ b/datafusion/sql/src/unparser/expr.rs
@@ -1275,8 +1275,8 @@ impl Unparser<'_> {
             DataType::BinaryView => {
                 not_impl_err!("Unsupported DataType: conversion: {data_type:?}")
             }
-            DataType::Utf8 => Ok(ast::DataType::Varchar(None)),
-            DataType::LargeUtf8 => Ok(ast::DataType::Text),
+            DataType::Utf8 => Ok(self.dialect.utf8_cast_dtype()),
+            DataType::LargeUtf8 => Ok(self.dialect.large_utf8_cast_dtype()),
             DataType::Utf8View => {
                 not_impl_err!("Unsupported DataType: conversion: {data_type:?}")
             }
@@ -1936,4 +1936,34 @@ mod tests {
             assert_eq!(actual, expected);
         }
     }
+
+    #[test]
+    fn custom_dialect_use_char_for_utf8_cast() -> Result<()> {
+        let default_dialect = CustomDialectBuilder::default().build();
+        let mysql_custom_dialect = CustomDialectBuilder::new()
+            .with_utf8_cast_dtype(ast::DataType::Char(None))
+            .with_large_utf8_cast_dtype(ast::DataType::Char(None))
+            .build();
+
+        for (dialect, data_type, identifier) in [
+            (&default_dialect, DataType::Utf8, "VARCHAR"),
+            (&default_dialect, DataType::LargeUtf8, "TEXT"),
+            (&mysql_custom_dialect, DataType::Utf8, "CHAR"),
+            (&mysql_custom_dialect, DataType::LargeUtf8, "CHAR"),
+        ] {
+            let unparser = Unparser::new(dialect);
+
+            let expr = Expr::Cast(Cast {
+                expr: Box::new(col("a")),
+                data_type,
+            });
+            let ast = unparser.expr_to_sql(&expr)?;
+
+            let actual = format!("{}", ast);
+            let expected = format!(r#"CAST(a AS {identifier})"#);
+
+            assert_eq!(actual, expected);
+        }
+        Ok(())
+    }
 }
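The planner hunks in the first patch above all share one contract: `SqlToRel` probes each planner returned by `ContextProvider::get_expr_planners` in order and falls through on `PlannerResult::Original`. A minimal sketch of the consumer side of that contract, where the `ConcatPlanner` name and the `||` rewrite are illustrative only (not part of the patch) and a direct `sqlparser` dependency matching DataFusion's is assumed:

use datafusion_common::{DFSchema, Result};
use datafusion_expr::planner::{ExprPlanner, PlannerResult, RawBinaryExpr};
use datafusion_expr::{BinaryExpr, Expr, Operator};
use sqlparser::ast::BinaryOperator;

// Hypothetical planner: claims the SQL `||` operator and lowers it to a
// logical string-concatenation expression; everything else falls through.
#[derive(Debug)]
struct ConcatPlanner;

impl ExprPlanner for ConcatPlanner {
    fn plan_binary_op(
        &self,
        expr: RawBinaryExpr,
        _schema: &DFSchema,
    ) -> Result<PlannerResult<RawBinaryExpr>> {
        match expr.op {
            BinaryOperator::StringConcat => {
                Ok(PlannerResult::Planned(Expr::BinaryExpr(BinaryExpr::new(
                    Box::new(expr.left),
                    Operator::StringConcat,
                    Box::new(expr.right),
                ))))
            }
            // Hand the expression back unchanged so SqlToRel can try the
            // next planner from `get_expr_planners`
            _ => Ok(PlannerResult::Original(expr)),
        }
    }
}

Because `get_expr_planners` defaults to an empty slice, a provider that does not opt in contributes no planners at all, which is exactly the behavior the `SessionStateBuilder::new().build()` assertion in the new test exercises.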
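The second patch's builder knobs can likewise be exercised end to end. A small sketch mirroring the new unit test, assuming the usual crate layout (`datafusion-sql` re-exports `sqlparser`, and `col`/`Cast` come from `datafusion-expr`):

use arrow_schema::DataType;
use datafusion_common::Result;
use datafusion_expr::{col, Cast, Expr};
use datafusion_sql::sqlparser::ast;
use datafusion_sql::unparser::dialect::CustomDialectBuilder;
use datafusion_sql::unparser::Unparser;

fn main() -> Result<()> {
    // Emulate MySQL, which needs CHAR (not VARCHAR/TEXT) as the CAST target
    let dialect = CustomDialectBuilder::new()
        .with_utf8_cast_dtype(ast::DataType::Char(None))
        .with_large_utf8_cast_dtype(ast::DataType::Char(None))
        .build();

    let unparser = Unparser::new(&dialect);
    let cast = Expr::Cast(Cast {
        expr: Box::new(col("a")),
        data_type: DataType::Utf8,
    });

    // Prints `CAST(a AS CHAR)`; the default dialect renders `CAST(a AS VARCHAR)`
    println!("{}", unparser.expr_to_sql(&cast)?);
    Ok(())
}

Making the cast target a configurable `ast::DataType` rather than a boolean flag (the flag approach was tried first, per the commit message bullets) leaves room for dialects with other string cast types without further trait changes.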