From d91a03f495b0856c1c2bfe5fc4db2ed1fe5352bf Mon Sep 17 00:00:00 2001 From: Xin Li <33629085+xinlifoobar@users.noreply.github.com> Date: Thu, 18 Jul 2024 02:44:15 -0700 Subject: [PATCH 01/37] Move handlign of NULL literals in where clause to type coercion pass (#11491) * Revert "Support `NULL` literals in where clause (#11266)" This reverts commit fa0191772e87e04da2598aedb7fe11dd49f88f88. * Followup Support NULL literals in where clause * misc err change * adopt comparison_coercion * Fix comments * Fix comments --- .../optimizer/src/analyzer/type_coercion.rs | 12 +++++- datafusion/physical-plan/src/filter.rs | 39 +++++-------------- 2 files changed, 20 insertions(+), 31 deletions(-) diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 80a8c864e431..337492d1a55b 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -84,7 +84,7 @@ impl AnalyzerRule for TypeCoercion { /// Assumes that children have already been optimized fn analyze_internal( external_schema: &DFSchema, - plan: LogicalPlan, + mut plan: LogicalPlan, ) -> Result> { // get schema representing all available input fields. This is used for data type // resolution only, so order does not matter here @@ -103,6 +103,16 @@ fn analyze_internal( // select t2.c2 from t1 where t1.c1 in (select t2.c1 from t2 where t2.c2=t1.c3) schema.merge(external_schema); + if let LogicalPlan::Filter(filter) = &mut plan { + if let Ok(new_predicate) = filter + .predicate + .clone() + .cast_to(&DataType::Boolean, filter.input.schema()) + { + filter.predicate = new_predicate; + } + } + let mut expr_rewrite = TypeCoercionRewriter::new(&schema); let name_preserver = NamePreserver::new(&plan); diff --git a/datafusion/physical-plan/src/filter.rs b/datafusion/physical-plan/src/filter.rs index c5ba3992d3b4..a9d78d059f5c 100644 --- a/datafusion/physical-plan/src/filter.rs +++ b/datafusion/physical-plan/src/filter.rs @@ -28,11 +28,11 @@ use crate::{ metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}, DisplayFormatType, ExecutionPlan, }; + use arrow::compute::filter_record_batch; use arrow::datatypes::{DataType, SchemaRef}; use arrow::record_batch::RecordBatch; -use arrow_array::{Array, BooleanArray}; -use datafusion_common::cast::{as_boolean_array, as_null_array}; +use datafusion_common::cast::as_boolean_array; use datafusion_common::stats::Precision; use datafusion_common::{internal_err, plan_err, DataFusionError, Result}; use datafusion_execution::TaskContext; @@ -81,19 +81,6 @@ impl FilterExec { cache, }) } - DataType::Null => { - let default_selectivity = 0; - let cache = - Self::compute_properties(&input, &predicate, default_selectivity)?; - - Ok(Self { - predicate, - input: Arc::clone(&input), - metrics: ExecutionPlanMetricsSet::new(), - default_selectivity, - cache, - }) - } other => { plan_err!("Filter predicate must return BOOLEAN values, got {other:?}") } @@ -367,23 +354,15 @@ pub(crate) fn batch_filter( .evaluate(batch) .and_then(|v| v.into_array(batch.num_rows())) .and_then(|array| { - let filter_array = match as_boolean_array(&array) { - Ok(boolean_array) => Ok(boolean_array.to_owned()), + Ok(match as_boolean_array(&array) { + // apply filter array to record batch + Ok(filter_array) => filter_record_batch(batch, filter_array)?, Err(_) => { - let Ok(null_array) = as_null_array(&array) else { - return internal_err!( - "Cannot create filter_array from non-boolean predicates" - ); - }; - - // if the predicate is null, then the result is also null - Ok::(BooleanArray::new_null( - null_array.len(), - )) + return internal_err!( + "Cannot create filter_array from non-boolean predicates" + ); } - }?; - - Ok(filter_record_batch(batch, &filter_array)?) + }) }) } From dff2f3c3c637fd5c3b30ed0cf26fac75c22973ac Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Thu, 18 Jul 2024 06:03:39 -0400 Subject: [PATCH 02/37] Minor: Clarify which parquet options are used for reading/writing (#11511) --- datafusion/common/src/config.rs | 63 ++++++++++--------- .../common/src/file_options/parquet_writer.rs | 1 + .../test_files/information_schema.slt | 52 +++++++-------- docs/source/user-guide/configs.md | 52 +++++++-------- 4 files changed, 87 insertions(+), 81 deletions(-) diff --git a/datafusion/common/src/config.rs b/datafusion/common/src/config.rs index 880f0119ce0d..b46b002baac0 100644 --- a/datafusion/common/src/config.rs +++ b/datafusion/common/src/config.rs @@ -315,93 +315,96 @@ config_namespace! { } config_namespace! { - /// Options related to parquet files + /// Options for reading and writing parquet files /// /// See also: [`SessionConfig`] /// /// [`SessionConfig`]: https://docs.rs/datafusion/latest/datafusion/prelude/struct.SessionConfig.html pub struct ParquetOptions { - /// If true, reads the Parquet data page level metadata (the + // The following options affect reading parquet files + + /// (reading) If true, reads the Parquet data page level metadata (the /// Page Index), if present, to reduce the I/O and number of /// rows decoded. pub enable_page_index: bool, default = true - /// If true, the parquet reader attempts to skip entire row groups based + /// (reading) If true, the parquet reader attempts to skip entire row groups based /// on the predicate in the query and the metadata (min/max values) stored in /// the parquet file pub pruning: bool, default = true - /// If true, the parquet reader skip the optional embedded metadata that may be in + /// (reading) If true, the parquet reader skip the optional embedded metadata that may be in /// the file Schema. This setting can help avoid schema conflicts when querying /// multiple parquet files with schemas containing compatible types but different metadata pub skip_metadata: bool, default = true - /// If specified, the parquet reader will try and fetch the last `size_hint` + /// (reading) If specified, the parquet reader will try and fetch the last `size_hint` /// bytes of the parquet file optimistically. If not specified, two reads are required: /// One read to fetch the 8-byte parquet footer and /// another to fetch the metadata length encoded in the footer pub metadata_size_hint: Option, default = None - /// If true, filter expressions are be applied during the parquet decoding operation to + /// (reading) If true, filter expressions are be applied during the parquet decoding operation to /// reduce the number of rows decoded. This optimization is sometimes called "late materialization". pub pushdown_filters: bool, default = false - /// If true, filter expressions evaluated during the parquet decoding operation + /// (reading) If true, filter expressions evaluated during the parquet decoding operation /// will be reordered heuristically to minimize the cost of evaluation. If false, /// the filters are applied in the same order as written in the query pub reorder_filters: bool, default = false - // The following map to parquet::file::properties::WriterProperties + // The following options affect writing to parquet files + // and map to parquet::file::properties::WriterProperties - /// Sets best effort maximum size of data page in bytes + /// (writing) Sets best effort maximum size of data page in bytes pub data_pagesize_limit: usize, default = 1024 * 1024 - /// Sets write_batch_size in bytes + /// (writing) Sets write_batch_size in bytes pub write_batch_size: usize, default = 1024 - /// Sets parquet writer version + /// (writing) Sets parquet writer version /// valid values are "1.0" and "2.0" pub writer_version: String, default = "1.0".into() - /// Sets default parquet compression codec + /// (writing) Sets default parquet compression codec. /// Valid values are: uncompressed, snappy, gzip(level), /// lzo, brotli(level), lz4, zstd(level), and lz4_raw. /// These values are not case sensitive. If NULL, uses /// default parquet writer setting pub compression: Option, default = Some("zstd(3)".into()) - /// Sets if dictionary encoding is enabled. If NULL, uses + /// (writing) Sets if dictionary encoding is enabled. If NULL, uses /// default parquet writer setting pub dictionary_enabled: Option, default = None - /// Sets best effort maximum dictionary page size, in bytes + /// (writing) Sets best effort maximum dictionary page size, in bytes pub dictionary_page_size_limit: usize, default = 1024 * 1024 - /// Sets if statistics are enabled for any column + /// (writing) Sets if statistics are enabled for any column /// Valid values are: "none", "chunk", and "page" /// These values are not case sensitive. If NULL, uses /// default parquet writer setting pub statistics_enabled: Option, default = None - /// Sets max statistics size for any column. If NULL, uses + /// (writing) Sets max statistics size for any column. If NULL, uses /// default parquet writer setting pub max_statistics_size: Option, default = None - /// Target maximum number of rows in each row group (defaults to 1M + /// (writing) Target maximum number of rows in each row group (defaults to 1M /// rows). Writing larger row groups requires more memory to write, but /// can get better compression and be faster to read. pub max_row_group_size: usize, default = 1024 * 1024 - /// Sets "created by" property + /// (writing) Sets "created by" property pub created_by: String, default = concat!("datafusion version ", env!("CARGO_PKG_VERSION")).into() - /// Sets column index truncate length + /// (writing) Sets column index truncate length pub column_index_truncate_length: Option, default = None - /// Sets best effort maximum number of rows in data page + /// (writing) Sets best effort maximum number of rows in data page pub data_page_row_count_limit: usize, default = usize::MAX - /// Sets default encoding for any column + /// (writing) Sets default encoding for any column. /// Valid values are: plain, plain_dictionary, rle, /// bit_packed, delta_binary_packed, delta_length_byte_array, /// delta_byte_array, rle_dictionary, and byte_stream_split. @@ -409,27 +412,27 @@ config_namespace! { /// default parquet writer setting pub encoding: Option, default = None - /// Use any available bloom filters when reading parquet files + /// (writing) Use any available bloom filters when reading parquet files pub bloom_filter_on_read: bool, default = true - /// Write bloom filters for all columns when creating parquet files + /// (writing) Write bloom filters for all columns when creating parquet files pub bloom_filter_on_write: bool, default = false - /// Sets bloom filter false positive probability. If NULL, uses + /// (writing) Sets bloom filter false positive probability. If NULL, uses /// default parquet writer setting pub bloom_filter_fpp: Option, default = None - /// Sets bloom filter number of distinct values. If NULL, uses + /// (writing) Sets bloom filter number of distinct values. If NULL, uses /// default parquet writer setting pub bloom_filter_ndv: Option, default = None - /// Controls whether DataFusion will attempt to speed up writing + /// (writing) Controls whether DataFusion will attempt to speed up writing /// parquet files by serializing them in parallel. Each column /// in each row group in each output file are serialized in parallel /// leveraging a maximum possible core count of n_files*n_row_groups*n_columns. pub allow_single_file_parallelism: bool, default = true - /// By default parallel parquet writer is tuned for minimum + /// (writing) By default parallel parquet writer is tuned for minimum /// memory usage in a streaming execution plan. You may see /// a performance benefit when writing large parquet files /// by increasing maximum_parallel_row_group_writers and @@ -440,7 +443,7 @@ config_namespace! { /// data frame. pub maximum_parallel_row_group_writers: usize, default = 1 - /// By default parallel parquet writer is tuned for minimum + /// (writing) By default parallel parquet writer is tuned for minimum /// memory usage in a streaming execution plan. You may see /// a performance benefit when writing large parquet files /// by increasing maximum_parallel_row_group_writers and @@ -450,7 +453,6 @@ config_namespace! { /// writing out already in-memory data, such as from a cached /// data frame. pub maximum_buffered_record_batches_per_stream: usize, default = 2 - } } @@ -1534,6 +1536,9 @@ macro_rules! config_namespace_with_hashmap { } config_namespace_with_hashmap! { + /// Options controlling parquet format for individual columns. + /// + /// See [`ParquetOptions`] for more details pub struct ColumnOptions { /// Sets if bloom filter is enabled for the column path. pub bloom_filter_enabled: Option, default = None diff --git a/datafusion/common/src/file_options/parquet_writer.rs b/datafusion/common/src/file_options/parquet_writer.rs index dd4bb8ce505e..abe7db2009a2 100644 --- a/datafusion/common/src/file_options/parquet_writer.rs +++ b/datafusion/common/src/file_options/parquet_writer.rs @@ -35,6 +35,7 @@ use parquet::{ /// Options for writing parquet files #[derive(Clone, Debug)] pub struct ParquetWriterOptions { + /// parquet-rs writer properties pub writer_options: WriterProperties, } diff --git a/datafusion/sqllogictest/test_files/information_schema.slt b/datafusion/sqllogictest/test_files/information_schema.slt index 95bea1223a9c..f7b755b01911 100644 --- a/datafusion/sqllogictest/test_files/information_schema.slt +++ b/datafusion/sqllogictest/test_files/information_schema.slt @@ -262,32 +262,32 @@ datafusion.execution.listing_table_ignore_subdirectory true Should sub directori datafusion.execution.max_buffered_batches_per_output_file 2 This is the maximum number of RecordBatches buffered for each output file being worked. Higher values can potentially give faster write performance at the cost of higher peak memory consumption datafusion.execution.meta_fetch_concurrency 32 Number of files to read in parallel when inferring schema and statistics datafusion.execution.minimum_parallel_output_files 4 Guarantees a minimum level of output files running in parallel. RecordBatches will be distributed in round robin fashion to each parallel writer. Each writer is closed and a new file opened once soft_max_rows_per_output_file is reached. -datafusion.execution.parquet.allow_single_file_parallelism true Controls whether DataFusion will attempt to speed up writing parquet files by serializing them in parallel. Each column in each row group in each output file are serialized in parallel leveraging a maximum possible core count of n_files*n_row_groups*n_columns. -datafusion.execution.parquet.bloom_filter_fpp NULL Sets bloom filter false positive probability. If NULL, uses default parquet writer setting -datafusion.execution.parquet.bloom_filter_ndv NULL Sets bloom filter number of distinct values. If NULL, uses default parquet writer setting -datafusion.execution.parquet.bloom_filter_on_read true Use any available bloom filters when reading parquet files -datafusion.execution.parquet.bloom_filter_on_write false Write bloom filters for all columns when creating parquet files -datafusion.execution.parquet.column_index_truncate_length NULL Sets column index truncate length -datafusion.execution.parquet.compression zstd(3) Sets default parquet compression codec Valid values are: uncompressed, snappy, gzip(level), lzo, brotli(level), lz4, zstd(level), and lz4_raw. These values are not case sensitive. If NULL, uses default parquet writer setting -datafusion.execution.parquet.created_by datafusion Sets "created by" property -datafusion.execution.parquet.data_page_row_count_limit 18446744073709551615 Sets best effort maximum number of rows in data page -datafusion.execution.parquet.data_pagesize_limit 1048576 Sets best effort maximum size of data page in bytes -datafusion.execution.parquet.dictionary_enabled NULL Sets if dictionary encoding is enabled. If NULL, uses default parquet writer setting -datafusion.execution.parquet.dictionary_page_size_limit 1048576 Sets best effort maximum dictionary page size, in bytes -datafusion.execution.parquet.enable_page_index true If true, reads the Parquet data page level metadata (the Page Index), if present, to reduce the I/O and number of rows decoded. -datafusion.execution.parquet.encoding NULL Sets default encoding for any column Valid values are: plain, plain_dictionary, rle, bit_packed, delta_binary_packed, delta_length_byte_array, delta_byte_array, rle_dictionary, and byte_stream_split. These values are not case sensitive. If NULL, uses default parquet writer setting -datafusion.execution.parquet.max_row_group_size 1048576 Target maximum number of rows in each row group (defaults to 1M rows). Writing larger row groups requires more memory to write, but can get better compression and be faster to read. -datafusion.execution.parquet.max_statistics_size NULL Sets max statistics size for any column. If NULL, uses default parquet writer setting -datafusion.execution.parquet.maximum_buffered_record_batches_per_stream 2 By default parallel parquet writer is tuned for minimum memory usage in a streaming execution plan. You may see a performance benefit when writing large parquet files by increasing maximum_parallel_row_group_writers and maximum_buffered_record_batches_per_stream if your system has idle cores and can tolerate additional memory usage. Boosting these values is likely worthwhile when writing out already in-memory data, such as from a cached data frame. -datafusion.execution.parquet.maximum_parallel_row_group_writers 1 By default parallel parquet writer is tuned for minimum memory usage in a streaming execution plan. You may see a performance benefit when writing large parquet files by increasing maximum_parallel_row_group_writers and maximum_buffered_record_batches_per_stream if your system has idle cores and can tolerate additional memory usage. Boosting these values is likely worthwhile when writing out already in-memory data, such as from a cached data frame. -datafusion.execution.parquet.metadata_size_hint NULL If specified, the parquet reader will try and fetch the last `size_hint` bytes of the parquet file optimistically. If not specified, two reads are required: One read to fetch the 8-byte parquet footer and another to fetch the metadata length encoded in the footer -datafusion.execution.parquet.pruning true If true, the parquet reader attempts to skip entire row groups based on the predicate in the query and the metadata (min/max values) stored in the parquet file -datafusion.execution.parquet.pushdown_filters false If true, filter expressions are be applied during the parquet decoding operation to reduce the number of rows decoded. This optimization is sometimes called "late materialization". -datafusion.execution.parquet.reorder_filters false If true, filter expressions evaluated during the parquet decoding operation will be reordered heuristically to minimize the cost of evaluation. If false, the filters are applied in the same order as written in the query -datafusion.execution.parquet.skip_metadata true If true, the parquet reader skip the optional embedded metadata that may be in the file Schema. This setting can help avoid schema conflicts when querying multiple parquet files with schemas containing compatible types but different metadata -datafusion.execution.parquet.statistics_enabled NULL Sets if statistics are enabled for any column Valid values are: "none", "chunk", and "page" These values are not case sensitive. If NULL, uses default parquet writer setting -datafusion.execution.parquet.write_batch_size 1024 Sets write_batch_size in bytes -datafusion.execution.parquet.writer_version 1.0 Sets parquet writer version valid values are "1.0" and "2.0" +datafusion.execution.parquet.allow_single_file_parallelism true (writing) Controls whether DataFusion will attempt to speed up writing parquet files by serializing them in parallel. Each column in each row group in each output file are serialized in parallel leveraging a maximum possible core count of n_files*n_row_groups*n_columns. +datafusion.execution.parquet.bloom_filter_fpp NULL (writing) Sets bloom filter false positive probability. If NULL, uses default parquet writer setting +datafusion.execution.parquet.bloom_filter_ndv NULL (writing) Sets bloom filter number of distinct values. If NULL, uses default parquet writer setting +datafusion.execution.parquet.bloom_filter_on_read true (writing) Use any available bloom filters when reading parquet files +datafusion.execution.parquet.bloom_filter_on_write false (writing) Write bloom filters for all columns when creating parquet files +datafusion.execution.parquet.column_index_truncate_length NULL (writing) Sets column index truncate length +datafusion.execution.parquet.compression zstd(3) (writing) Sets default parquet compression codec. Valid values are: uncompressed, snappy, gzip(level), lzo, brotli(level), lz4, zstd(level), and lz4_raw. These values are not case sensitive. If NULL, uses default parquet writer setting +datafusion.execution.parquet.created_by datafusion (writing) Sets "created by" property +datafusion.execution.parquet.data_page_row_count_limit 18446744073709551615 (writing) Sets best effort maximum number of rows in data page +datafusion.execution.parquet.data_pagesize_limit 1048576 (writing) Sets best effort maximum size of data page in bytes +datafusion.execution.parquet.dictionary_enabled NULL (writing) Sets if dictionary encoding is enabled. If NULL, uses default parquet writer setting +datafusion.execution.parquet.dictionary_page_size_limit 1048576 (writing) Sets best effort maximum dictionary page size, in bytes +datafusion.execution.parquet.enable_page_index true (reading) If true, reads the Parquet data page level metadata (the Page Index), if present, to reduce the I/O and number of rows decoded. +datafusion.execution.parquet.encoding NULL (writing) Sets default encoding for any column. Valid values are: plain, plain_dictionary, rle, bit_packed, delta_binary_packed, delta_length_byte_array, delta_byte_array, rle_dictionary, and byte_stream_split. These values are not case sensitive. If NULL, uses default parquet writer setting +datafusion.execution.parquet.max_row_group_size 1048576 (writing) Target maximum number of rows in each row group (defaults to 1M rows). Writing larger row groups requires more memory to write, but can get better compression and be faster to read. +datafusion.execution.parquet.max_statistics_size NULL (writing) Sets max statistics size for any column. If NULL, uses default parquet writer setting +datafusion.execution.parquet.maximum_buffered_record_batches_per_stream 2 (writing) By default parallel parquet writer is tuned for minimum memory usage in a streaming execution plan. You may see a performance benefit when writing large parquet files by increasing maximum_parallel_row_group_writers and maximum_buffered_record_batches_per_stream if your system has idle cores and can tolerate additional memory usage. Boosting these values is likely worthwhile when writing out already in-memory data, such as from a cached data frame. +datafusion.execution.parquet.maximum_parallel_row_group_writers 1 (writing) By default parallel parquet writer is tuned for minimum memory usage in a streaming execution plan. You may see a performance benefit when writing large parquet files by increasing maximum_parallel_row_group_writers and maximum_buffered_record_batches_per_stream if your system has idle cores and can tolerate additional memory usage. Boosting these values is likely worthwhile when writing out already in-memory data, such as from a cached data frame. +datafusion.execution.parquet.metadata_size_hint NULL (reading) If specified, the parquet reader will try and fetch the last `size_hint` bytes of the parquet file optimistically. If not specified, two reads are required: One read to fetch the 8-byte parquet footer and another to fetch the metadata length encoded in the footer +datafusion.execution.parquet.pruning true (reading) If true, the parquet reader attempts to skip entire row groups based on the predicate in the query and the metadata (min/max values) stored in the parquet file +datafusion.execution.parquet.pushdown_filters false (reading) If true, filter expressions are be applied during the parquet decoding operation to reduce the number of rows decoded. This optimization is sometimes called "late materialization". +datafusion.execution.parquet.reorder_filters false (reading) If true, filter expressions evaluated during the parquet decoding operation will be reordered heuristically to minimize the cost of evaluation. If false, the filters are applied in the same order as written in the query +datafusion.execution.parquet.skip_metadata true (reading) If true, the parquet reader skip the optional embedded metadata that may be in the file Schema. This setting can help avoid schema conflicts when querying multiple parquet files with schemas containing compatible types but different metadata +datafusion.execution.parquet.statistics_enabled NULL (writing) Sets if statistics are enabled for any column Valid values are: "none", "chunk", and "page" These values are not case sensitive. If NULL, uses default parquet writer setting +datafusion.execution.parquet.write_batch_size 1024 (writing) Sets write_batch_size in bytes +datafusion.execution.parquet.writer_version 1.0 (writing) Sets parquet writer version valid values are "1.0" and "2.0" datafusion.execution.planning_concurrency 13 Fan-out during initial physical planning. This is mostly use to plan `UNION` children in parallel. Defaults to the number of CPU cores on the system datafusion.execution.soft_max_rows_per_output_file 50000000 Target number of rows in output files when writing multiple. This is a soft max, so it can be exceeded slightly. There also will be one file smaller than the limit if the total number of rows written is not roughly divisible by the soft max datafusion.execution.sort_in_place_threshold_bytes 1048576 When sorting, below what size should data be concatenated and sorted in a single RecordBatch rather than sorted in batches and merged. diff --git a/docs/source/user-guide/configs.md b/docs/source/user-guide/configs.md index 5130b0a56d0e..8d3ecbc98544 100644 --- a/docs/source/user-guide/configs.md +++ b/docs/source/user-guide/configs.md @@ -49,32 +49,32 @@ Environment variables are read during `SessionConfig` initialisation so they mus | datafusion.execution.collect_statistics | false | Should DataFusion collect statistics after listing files | | datafusion.execution.target_partitions | 0 | Number of partitions for query execution. Increasing partitions can increase concurrency. Defaults to the number of CPU cores on the system | | datafusion.execution.time_zone | +00:00 | The default time zone Some functions, e.g. `EXTRACT(HOUR from SOME_TIME)`, shift the underlying datetime according to this time zone, and then extract the hour | -| datafusion.execution.parquet.enable_page_index | true | If true, reads the Parquet data page level metadata (the Page Index), if present, to reduce the I/O and number of rows decoded. | -| datafusion.execution.parquet.pruning | true | If true, the parquet reader attempts to skip entire row groups based on the predicate in the query and the metadata (min/max values) stored in the parquet file | -| datafusion.execution.parquet.skip_metadata | true | If true, the parquet reader skip the optional embedded metadata that may be in the file Schema. This setting can help avoid schema conflicts when querying multiple parquet files with schemas containing compatible types but different metadata | -| datafusion.execution.parquet.metadata_size_hint | NULL | If specified, the parquet reader will try and fetch the last `size_hint` bytes of the parquet file optimistically. If not specified, two reads are required: One read to fetch the 8-byte parquet footer and another to fetch the metadata length encoded in the footer | -| datafusion.execution.parquet.pushdown_filters | false | If true, filter expressions are be applied during the parquet decoding operation to reduce the number of rows decoded. This optimization is sometimes called "late materialization". | -| datafusion.execution.parquet.reorder_filters | false | If true, filter expressions evaluated during the parquet decoding operation will be reordered heuristically to minimize the cost of evaluation. If false, the filters are applied in the same order as written in the query | -| datafusion.execution.parquet.data_pagesize_limit | 1048576 | Sets best effort maximum size of data page in bytes | -| datafusion.execution.parquet.write_batch_size | 1024 | Sets write_batch_size in bytes | -| datafusion.execution.parquet.writer_version | 1.0 | Sets parquet writer version valid values are "1.0" and "2.0" | -| datafusion.execution.parquet.compression | zstd(3) | Sets default parquet compression codec Valid values are: uncompressed, snappy, gzip(level), lzo, brotli(level), lz4, zstd(level), and lz4_raw. These values are not case sensitive. If NULL, uses default parquet writer setting | -| datafusion.execution.parquet.dictionary_enabled | NULL | Sets if dictionary encoding is enabled. If NULL, uses default parquet writer setting | -| datafusion.execution.parquet.dictionary_page_size_limit | 1048576 | Sets best effort maximum dictionary page size, in bytes | -| datafusion.execution.parquet.statistics_enabled | NULL | Sets if statistics are enabled for any column Valid values are: "none", "chunk", and "page" These values are not case sensitive. If NULL, uses default parquet writer setting | -| datafusion.execution.parquet.max_statistics_size | NULL | Sets max statistics size for any column. If NULL, uses default parquet writer setting | -| datafusion.execution.parquet.max_row_group_size | 1048576 | Target maximum number of rows in each row group (defaults to 1M rows). Writing larger row groups requires more memory to write, but can get better compression and be faster to read. | -| datafusion.execution.parquet.created_by | datafusion version 40.0.0 | Sets "created by" property | -| datafusion.execution.parquet.column_index_truncate_length | NULL | Sets column index truncate length | -| datafusion.execution.parquet.data_page_row_count_limit | 18446744073709551615 | Sets best effort maximum number of rows in data page | -| datafusion.execution.parquet.encoding | NULL | Sets default encoding for any column Valid values are: plain, plain_dictionary, rle, bit_packed, delta_binary_packed, delta_length_byte_array, delta_byte_array, rle_dictionary, and byte_stream_split. These values are not case sensitive. If NULL, uses default parquet writer setting | -| datafusion.execution.parquet.bloom_filter_on_read | true | Use any available bloom filters when reading parquet files | -| datafusion.execution.parquet.bloom_filter_on_write | false | Write bloom filters for all columns when creating parquet files | -| datafusion.execution.parquet.bloom_filter_fpp | NULL | Sets bloom filter false positive probability. If NULL, uses default parquet writer setting | -| datafusion.execution.parquet.bloom_filter_ndv | NULL | Sets bloom filter number of distinct values. If NULL, uses default parquet writer setting | -| datafusion.execution.parquet.allow_single_file_parallelism | true | Controls whether DataFusion will attempt to speed up writing parquet files by serializing them in parallel. Each column in each row group in each output file are serialized in parallel leveraging a maximum possible core count of n_files*n_row_groups*n_columns. | -| datafusion.execution.parquet.maximum_parallel_row_group_writers | 1 | By default parallel parquet writer is tuned for minimum memory usage in a streaming execution plan. You may see a performance benefit when writing large parquet files by increasing maximum_parallel_row_group_writers and maximum_buffered_record_batches_per_stream if your system has idle cores and can tolerate additional memory usage. Boosting these values is likely worthwhile when writing out already in-memory data, such as from a cached data frame. | -| datafusion.execution.parquet.maximum_buffered_record_batches_per_stream | 2 | By default parallel parquet writer is tuned for minimum memory usage in a streaming execution plan. You may see a performance benefit when writing large parquet files by increasing maximum_parallel_row_group_writers and maximum_buffered_record_batches_per_stream if your system has idle cores and can tolerate additional memory usage. Boosting these values is likely worthwhile when writing out already in-memory data, such as from a cached data frame. | +| datafusion.execution.parquet.enable_page_index | true | (reading) If true, reads the Parquet data page level metadata (the Page Index), if present, to reduce the I/O and number of rows decoded. | +| datafusion.execution.parquet.pruning | true | (reading) If true, the parquet reader attempts to skip entire row groups based on the predicate in the query and the metadata (min/max values) stored in the parquet file | +| datafusion.execution.parquet.skip_metadata | true | (reading) If true, the parquet reader skip the optional embedded metadata that may be in the file Schema. This setting can help avoid schema conflicts when querying multiple parquet files with schemas containing compatible types but different metadata | +| datafusion.execution.parquet.metadata_size_hint | NULL | (reading) If specified, the parquet reader will try and fetch the last `size_hint` bytes of the parquet file optimistically. If not specified, two reads are required: One read to fetch the 8-byte parquet footer and another to fetch the metadata length encoded in the footer | +| datafusion.execution.parquet.pushdown_filters | false | (reading) If true, filter expressions are be applied during the parquet decoding operation to reduce the number of rows decoded. This optimization is sometimes called "late materialization". | +| datafusion.execution.parquet.reorder_filters | false | (reading) If true, filter expressions evaluated during the parquet decoding operation will be reordered heuristically to minimize the cost of evaluation. If false, the filters are applied in the same order as written in the query | +| datafusion.execution.parquet.data_pagesize_limit | 1048576 | (writing) Sets best effort maximum size of data page in bytes | +| datafusion.execution.parquet.write_batch_size | 1024 | (writing) Sets write_batch_size in bytes | +| datafusion.execution.parquet.writer_version | 1.0 | (writing) Sets parquet writer version valid values are "1.0" and "2.0" | +| datafusion.execution.parquet.compression | zstd(3) | (writing) Sets default parquet compression codec. Valid values are: uncompressed, snappy, gzip(level), lzo, brotli(level), lz4, zstd(level), and lz4_raw. These values are not case sensitive. If NULL, uses default parquet writer setting | +| datafusion.execution.parquet.dictionary_enabled | NULL | (writing) Sets if dictionary encoding is enabled. If NULL, uses default parquet writer setting | +| datafusion.execution.parquet.dictionary_page_size_limit | 1048576 | (writing) Sets best effort maximum dictionary page size, in bytes | +| datafusion.execution.parquet.statistics_enabled | NULL | (writing) Sets if statistics are enabled for any column Valid values are: "none", "chunk", and "page" These values are not case sensitive. If NULL, uses default parquet writer setting | +| datafusion.execution.parquet.max_statistics_size | NULL | (writing) Sets max statistics size for any column. If NULL, uses default parquet writer setting | +| datafusion.execution.parquet.max_row_group_size | 1048576 | (writing) Target maximum number of rows in each row group (defaults to 1M rows). Writing larger row groups requires more memory to write, but can get better compression and be faster to read. | +| datafusion.execution.parquet.created_by | datafusion version 40.0.0 | (writing) Sets "created by" property | +| datafusion.execution.parquet.column_index_truncate_length | NULL | (writing) Sets column index truncate length | +| datafusion.execution.parquet.data_page_row_count_limit | 18446744073709551615 | (writing) Sets best effort maximum number of rows in data page | +| datafusion.execution.parquet.encoding | NULL | (writing) Sets default encoding for any column. Valid values are: plain, plain_dictionary, rle, bit_packed, delta_binary_packed, delta_length_byte_array, delta_byte_array, rle_dictionary, and byte_stream_split. These values are not case sensitive. If NULL, uses default parquet writer setting | +| datafusion.execution.parquet.bloom_filter_on_read | true | (writing) Use any available bloom filters when reading parquet files | +| datafusion.execution.parquet.bloom_filter_on_write | false | (writing) Write bloom filters for all columns when creating parquet files | +| datafusion.execution.parquet.bloom_filter_fpp | NULL | (writing) Sets bloom filter false positive probability. If NULL, uses default parquet writer setting | +| datafusion.execution.parquet.bloom_filter_ndv | NULL | (writing) Sets bloom filter number of distinct values. If NULL, uses default parquet writer setting | +| datafusion.execution.parquet.allow_single_file_parallelism | true | (writing) Controls whether DataFusion will attempt to speed up writing parquet files by serializing them in parallel. Each column in each row group in each output file are serialized in parallel leveraging a maximum possible core count of n_files*n_row_groups*n_columns. | +| datafusion.execution.parquet.maximum_parallel_row_group_writers | 1 | (writing) By default parallel parquet writer is tuned for minimum memory usage in a streaming execution plan. You may see a performance benefit when writing large parquet files by increasing maximum_parallel_row_group_writers and maximum_buffered_record_batches_per_stream if your system has idle cores and can tolerate additional memory usage. Boosting these values is likely worthwhile when writing out already in-memory data, such as from a cached data frame. | +| datafusion.execution.parquet.maximum_buffered_record_batches_per_stream | 2 | (writing) By default parallel parquet writer is tuned for minimum memory usage in a streaming execution plan. You may see a performance benefit when writing large parquet files by increasing maximum_parallel_row_group_writers and maximum_buffered_record_batches_per_stream if your system has idle cores and can tolerate additional memory usage. Boosting these values is likely worthwhile when writing out already in-memory data, such as from a cached data frame. | | datafusion.execution.aggregate.scalar_update_factor | 10 | Specifies the threshold for using `ScalarValue`s to update accumulators during high-cardinality aggregations for each input batch. The aggregation is considered high-cardinality if the number of affected groups is greater than or equal to `batch_size / scalar_update_factor`. In such cases, `ScalarValue`s are utilized for updating accumulators, rather than the default batch-slice approach. This can lead to performance improvements. By adjusting the `scalar_update_factor`, you can balance the trade-off between more efficient accumulator updates and the number of groups affected. | | datafusion.execution.planning_concurrency | 0 | Fan-out during initial physical planning. This is mostly use to plan `UNION` children in parallel. Defaults to the number of CPU cores on the system | | datafusion.execution.sort_spill_reservation_bytes | 10485760 | Specifies the reserved memory for each spillable sort operation to facilitate an in-memory merge. When a sort operation spills to disk, the in-memory data must be sorted and merged before being written to a file. This setting reserves a specific amount of memory for that in-memory sort/merge process. Note: This setting is irrelevant if the sort operation cannot spill (i.e., if there's no `DiskManager` configured). | From b19744968770c4ab426d065dec3cc5147534e87a Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Thu, 18 Jul 2024 06:04:25 -0400 Subject: [PATCH 03/37] Update parquet page pruning code to use the `StatisticsExtractor` (#11483) * Update the parquet code prune_pages_in_one_row_group to use the `StatisticsExtractor` * fix doc * Increase evaluation error counter if error determining data page row counts * Optimize `single_column` --- .../datasource/physical_plan/parquet/mod.rs | 51 +- .../physical_plan/parquet/opener.rs | 4 +- .../physical_plan/parquet/page_filter.rs | 556 ++++++++---------- .../physical_plan/parquet/statistics.rs | 10 + .../core/src/physical_optimizer/pruning.rs | 27 +- 5 files changed, 279 insertions(+), 369 deletions(-) diff --git a/datafusion/core/src/datasource/physical_plan/parquet/mod.rs b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs index ed0fc5f0169e..1eea4eab8ba2 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/mod.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs @@ -24,7 +24,7 @@ use std::sync::Arc; use crate::datasource::listing::PartitionedFile; use crate::datasource::physical_plan::file_stream::FileStream; use crate::datasource::physical_plan::{ - parquet::page_filter::PagePruningPredicate, DisplayAs, FileGroupPartitioner, + parquet::page_filter::PagePruningAccessPlanFilter, DisplayAs, FileGroupPartitioner, FileScanConfig, }; use crate::{ @@ -39,13 +39,11 @@ use crate::{ }, }; -use arrow::datatypes::{DataType, SchemaRef}; +use arrow::datatypes::SchemaRef; use datafusion_physical_expr::{EquivalenceProperties, LexOrdering, PhysicalExpr}; use itertools::Itertools; use log::debug; -use parquet::basic::{ConvertedType, LogicalType}; -use parquet::schema::types::ColumnDescriptor; mod access_plan; mod metrics; @@ -225,7 +223,7 @@ pub struct ParquetExec { /// Optional predicate for pruning row groups (derived from `predicate`) pruning_predicate: Option>, /// Optional predicate for pruning pages (derived from `predicate`) - page_pruning_predicate: Option>, + page_pruning_predicate: Option>, /// Optional hint for the size of the parquet metadata metadata_size_hint: Option, /// Optional user defined parquet file reader factory @@ -381,19 +379,12 @@ impl ParquetExecBuilder { }) .filter(|p| !p.always_true()); - let page_pruning_predicate = predicate.as_ref().and_then(|predicate_expr| { - match PagePruningPredicate::try_new(predicate_expr, file_schema.clone()) { - Ok(pruning_predicate) => Some(Arc::new(pruning_predicate)), - Err(e) => { - debug!( - "Could not create page pruning predicate for '{:?}': {}", - pruning_predicate, e - ); - predicate_creation_errors.add(1); - None - } - } - }); + let page_pruning_predicate = predicate + .as_ref() + .map(|predicate_expr| { + PagePruningAccessPlanFilter::new(predicate_expr, file_schema.clone()) + }) + .map(Arc::new); let (projected_schema, projected_statistics, projected_output_ordering) = base_config.project(); @@ -739,7 +730,7 @@ impl ExecutionPlan for ParquetExec { fn should_enable_page_index( enable_page_index: bool, - page_pruning_predicate: &Option>, + page_pruning_predicate: &Option>, ) -> bool { enable_page_index && page_pruning_predicate.is_some() @@ -749,26 +740,6 @@ fn should_enable_page_index( .unwrap_or(false) } -// Convert parquet column schema to arrow data type, and just consider the -// decimal data type. -pub(crate) fn parquet_to_arrow_decimal_type( - parquet_column: &ColumnDescriptor, -) -> Option { - let type_ptr = parquet_column.self_type_ptr(); - match type_ptr.get_basic_info().logical_type() { - Some(LogicalType::Decimal { scale, precision }) => { - Some(DataType::Decimal128(precision as u8, scale as i8)) - } - _ => match type_ptr.get_basic_info().converted_type() { - ConvertedType::DECIMAL => Some(DataType::Decimal128( - type_ptr.get_precision() as u8, - type_ptr.get_scale() as i8, - )), - _ => None, - }, - } -} - #[cfg(test)] mod tests { // See also `parquet_exec` integration test @@ -798,7 +769,7 @@ mod tests { }; use arrow::datatypes::{Field, Schema, SchemaBuilder}; use arrow::record_batch::RecordBatch; - use arrow_schema::Fields; + use arrow_schema::{DataType, Fields}; use datafusion_common::{assert_contains, ScalarValue}; use datafusion_expr::{col, lit, when, Expr}; use datafusion_physical_expr::planner::logical2physical; diff --git a/datafusion/core/src/datasource/physical_plan/parquet/opener.rs b/datafusion/core/src/datasource/physical_plan/parquet/opener.rs index c97b0282626a..ffe879eb8de0 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/opener.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/opener.rs @@ -17,7 +17,7 @@ //! [`ParquetOpener`] for opening Parquet files -use crate::datasource::physical_plan::parquet::page_filter::PagePruningPredicate; +use crate::datasource::physical_plan::parquet::page_filter::PagePruningAccessPlanFilter; use crate::datasource::physical_plan::parquet::row_group_filter::RowGroupAccessPlanFilter; use crate::datasource::physical_plan::parquet::{ row_filter, should_enable_page_index, ParquetAccessPlan, @@ -46,7 +46,7 @@ pub(super) struct ParquetOpener { pub limit: Option, pub predicate: Option>, pub pruning_predicate: Option>, - pub page_pruning_predicate: Option>, + pub page_pruning_predicate: Option>, pub table_schema: SchemaRef, pub metadata_size_hint: Option, pub metrics: ExecutionPlanMetricsSet, diff --git a/datafusion/core/src/datasource/physical_plan/parquet/page_filter.rs b/datafusion/core/src/datasource/physical_plan/parquet/page_filter.rs index 7429ca593820..d658608ab4f1 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/page_filter.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/page_filter.rs @@ -17,40 +17,33 @@ //! Contains code to filter entire pages -use arrow::array::{ - BooleanArray, Decimal128Array, Float32Array, Float64Array, Int32Array, Int64Array, - StringArray, -}; -use arrow::datatypes::DataType; +use crate::datasource::physical_plan::parquet::ParquetAccessPlan; +use crate::datasource::physical_plan::parquet::StatisticsConverter; +use crate::physical_optimizer::pruning::{PruningPredicate, PruningStatistics}; +use arrow::array::BooleanArray; use arrow::{array::ArrayRef, datatypes::SchemaRef}; use arrow_schema::Schema; -use datafusion_common::{Result, ScalarValue}; -use datafusion_physical_expr::expressions::Column; +use datafusion_common::ScalarValue; use datafusion_physical_expr::{split_conjunction, PhysicalExpr}; use log::{debug, trace}; -use parquet::schema::types::{ColumnDescriptor, SchemaDescriptor}; +use parquet::file::metadata::{ParquetColumnIndex, ParquetOffsetIndex}; +use parquet::format::PageLocation; +use parquet::schema::types::SchemaDescriptor; use parquet::{ arrow::arrow_reader::{RowSelection, RowSelector}, - file::{ - metadata::{ParquetMetaData, RowGroupMetaData}, - page_index::index::Index, - }, - format::PageLocation, + file::metadata::{ParquetMetaData, RowGroupMetaData}, }; use std::collections::HashSet; use std::sync::Arc; -use crate::datasource::physical_plan::parquet::parquet_to_arrow_decimal_type; -use crate::datasource::physical_plan::parquet::statistics::{ - from_bytes_to_i128, parquet_column, -}; -use crate::datasource::physical_plan::parquet::ParquetAccessPlan; -use crate::physical_optimizer::pruning::{PruningPredicate, PruningStatistics}; - use super::metrics::ParquetFileMetrics; -/// A [`PagePruningPredicate`] provides the ability to construct a [`RowSelection`] -/// based on parquet page level statistics, if any +/// Filters a [`ParquetAccessPlan`] based on the [Parquet PageIndex], if present +/// +/// It does so by evaluating statistics from the [`ParquetColumnIndex`] and +/// [`ParquetOffsetIndex`] and converting them to [`RowSelection`]. +/// +/// [Parquet PageIndex]: https://github.com/apache/parquet-format/blob/master/PageIndex.md /// /// For example, given a row group with two column (chunks) for `A` /// and `B` with the following with page level statistics: @@ -103,30 +96,52 @@ use super::metrics::ParquetFileMetrics; /// /// So we can entirely skip rows 0->199 and 250->299 as we know they /// can not contain rows that match the predicate. +/// +/// # Implementation notes +/// +/// Single column predicates are evaluated using the PageIndex information +/// for that column to determine which row ranges can be skipped based. +/// +/// The resulting [`RowSelection`]'s are combined into a final +/// row selection that is added to the [`ParquetAccessPlan`]. #[derive(Debug)] -pub struct PagePruningPredicate { +pub struct PagePruningAccessPlanFilter { + /// single column predicates (e.g. (`col = 5`) extracted from the overall + /// predicate. Must all be true for a row to be included in the result. predicates: Vec, } -impl PagePruningPredicate { - /// Create a new [`PagePruningPredicate`] - // TODO: this is infallaible -- it can not return an error - pub fn try_new(expr: &Arc, schema: SchemaRef) -> Result { +impl PagePruningAccessPlanFilter { + /// Create a new [`PagePruningAccessPlanFilter`] from a physical + /// expression. + pub fn new(expr: &Arc, schema: SchemaRef) -> Self { + // extract any single column predicates let predicates = split_conjunction(expr) .into_iter() .filter_map(|predicate| { - match PruningPredicate::try_new(predicate.clone(), schema.clone()) { - Ok(p) - if (!p.always_true()) - && (p.required_columns().n_columns() < 2) => - { - Some(Ok(p)) - } - _ => None, + let pp = + match PruningPredicate::try_new(predicate.clone(), schema.clone()) { + Ok(pp) => pp, + Err(e) => { + debug!("Ignoring error creating page pruning predicate: {e}"); + return None; + } + }; + + if pp.always_true() { + debug!("Ignoring always true page pruning predicate: {predicate}"); + return None; + } + + if pp.required_columns().single_column().is_none() { + debug!("Ignoring multi-column page pruning predicate: {predicate}"); + return None; } + + Some(pp) }) - .collect::>>()?; - Ok(Self { predicates }) + .collect::>(); + Self { predicates } } /// Returns an updated [`ParquetAccessPlan`] by applying predicates to the @@ -136,7 +151,7 @@ impl PagePruningPredicate { mut access_plan: ParquetAccessPlan, arrow_schema: &Schema, parquet_schema: &SchemaDescriptor, - file_metadata: &ParquetMetaData, + parquet_metadata: &ParquetMetaData, file_metrics: &ParquetFileMetrics, ) -> ParquetAccessPlan { // scoped timer updates on drop @@ -146,18 +161,18 @@ impl PagePruningPredicate { } let page_index_predicates = &self.predicates; - let groups = file_metadata.row_groups(); + let groups = parquet_metadata.row_groups(); if groups.is_empty() { return access_plan; } - let (Some(file_offset_indexes), Some(file_page_indexes)) = - (file_metadata.offset_index(), file_metadata.column_index()) - else { - trace!( - "skip page pruning due to lack of indexes. Have offset: {}, column index: {}", - file_metadata.offset_index().is_some(), file_metadata.column_index().is_some() + if parquet_metadata.offset_index().is_none() + || parquet_metadata.column_index().is_none() + { + debug!( + "Can not prune pages due to lack of indexes. Have offset: {}, column index: {}", + parquet_metadata.offset_index().is_some(), parquet_metadata.column_index().is_some() ); return access_plan; }; @@ -165,33 +180,39 @@ impl PagePruningPredicate { // track the total number of rows that should be skipped let mut total_skip = 0; + // for each row group specified in the access plan let row_group_indexes = access_plan.row_group_indexes(); - for r in row_group_indexes { + for row_group_index in row_group_indexes { // The selection for this particular row group let mut overall_selection = None; for predicate in page_index_predicates { - // find column index in the parquet schema - let col_idx = find_column_index(predicate, arrow_schema, parquet_schema); - let row_group_metadata = &groups[r]; - - let (Some(rg_page_indexes), Some(rg_offset_indexes), Some(col_idx)) = ( - file_page_indexes.get(r), - file_offset_indexes.get(r), - col_idx, - ) else { - trace!( - "Did not have enough metadata to prune with page indexes, \ - falling back to all rows", - ); - continue; + let column = predicate + .required_columns() + .single_column() + .expect("Page pruning requires single column predicates"); + + let converter = StatisticsConverter::try_new( + column.name(), + arrow_schema, + parquet_schema, + ); + + let converter = match converter { + Ok(converter) => converter, + Err(e) => { + debug!( + "Could not create statistics converter for column {}: {e}", + column.name() + ); + continue; + } }; let selection = prune_pages_in_one_row_group( - row_group_metadata, + row_group_index, predicate, - rg_offset_indexes.get(col_idx), - rg_page_indexes.get(col_idx), - groups[r].column(col_idx).column_descr(), + converter, + parquet_metadata, file_metrics, ); @@ -224,15 +245,15 @@ impl PagePruningPredicate { let rows_skipped = rows_skipped(&overall_selection); trace!("Overall selection from predicate skipped {rows_skipped}: {overall_selection:?}"); total_skip += rows_skipped; - access_plan.scan_selection(r, overall_selection) + access_plan.scan_selection(row_group_index, overall_selection) } else { // Selection skips all rows, so skip the entire row group - let rows_skipped = groups[r].num_rows() as usize; - access_plan.skip(r); + let rows_skipped = groups[row_group_index].num_rows() as usize; + access_plan.skip(row_group_index); total_skip += rows_skipped; trace!( "Overall selection from predicate is empty, \ - skipping all {rows_skipped} rows in row group {r}" + skipping all {rows_skipped} rows in row group {row_group_index}" ); } } @@ -242,7 +263,7 @@ impl PagePruningPredicate { access_plan } - /// Returns the number of filters in the [`PagePruningPredicate`] + /// Returns the number of filters in the [`PagePruningAccessPlanFilter`] pub fn filter_number(&self) -> usize { self.predicates.len() } @@ -266,97 +287,53 @@ fn update_selection( } } -/// Returns the column index in the row parquet schema for the single -/// column of a single column pruning predicate. -/// -/// For example, give the predicate `y > 5` +/// Returns a [`RowSelection`] for the rows in this row group to scan. /// -/// And columns in the RowGroupMetadata like `['x', 'y', 'z']` will -/// return 1. +/// This Row Selection is formed from the page index and the predicate skips row +/// ranges that can be ruled out based on the predicate. /// -/// Returns `None` if the column is not found, or if there are no -/// required columns, which is the case for predicate like `abs(i) = -/// 1` which are rewritten to `lit(true)` -/// -/// Panics: -/// -/// If the predicate contains more than one column reference (assumes -/// that `extract_page_index_push_down_predicates` only returns -/// predicate with one col) -fn find_column_index( - predicate: &PruningPredicate, - arrow_schema: &Schema, - parquet_schema: &SchemaDescriptor, -) -> Option { - let mut found_required_column: Option<&Column> = None; - - for required_column_details in predicate.required_columns().iter() { - let column = &required_column_details.0; - if let Some(found_required_column) = found_required_column.as_ref() { - // make sure it is the same name we have seen previously - assert_eq!( - column.name(), - found_required_column.name(), - "Unexpected multi column predicate" - ); - } else { - found_required_column = Some(column); - } - } - - let Some(column) = found_required_column.as_ref() else { - trace!("No column references in pruning predicate"); - return None; - }; - - parquet_column(parquet_schema, arrow_schema, column.name()).map(|x| x.0) -} - -/// Returns a `RowSelection` for the pages in this RowGroup if any -/// rows can be pruned based on the page index +/// Returns `None` if there is an error evaluating the predicate or the required +/// page information is not present. fn prune_pages_in_one_row_group( - group: &RowGroupMetaData, - predicate: &PruningPredicate, - col_offset_indexes: Option<&Vec>, - col_page_indexes: Option<&Index>, - col_desc: &ColumnDescriptor, + row_group_index: usize, + pruning_predicate: &PruningPredicate, + converter: StatisticsConverter<'_>, + parquet_metadata: &ParquetMetaData, metrics: &ParquetFileMetrics, ) -> Option { - let num_rows = group.num_rows() as usize; - let (Some(col_offset_indexes), Some(col_page_indexes)) = - (col_offset_indexes, col_page_indexes) - else { - return None; - }; - - let target_type = parquet_to_arrow_decimal_type(col_desc); - let pruning_stats = PagesPruningStatistics { - col_page_indexes, - col_offset_indexes, - target_type: &target_type, - num_rows_in_row_group: group.num_rows(), - }; + let pruning_stats = + PagesPruningStatistics::try_new(row_group_index, converter, parquet_metadata)?; - let values = match predicate.prune(&pruning_stats) { + // Each element in values is a boolean indicating whether the page may have + // values that match the predicate (true) or could not possibly have values + // that match the predicate (false). + let values = match pruning_predicate.prune(&pruning_stats) { Ok(values) => values, Err(e) => { - // stats filter array could not be built - // return a result which will not filter out any pages debug!("Error evaluating page index predicate values {e}"); metrics.predicate_evaluation_errors.add(1); return None; } }; + // Convert the information of which pages to skip into a RowSelection + // that describes the ranges of rows to skip. + let Some(page_row_counts) = pruning_stats.page_row_counts() else { + debug!( + "Can not determine page row counts for row group {row_group_index}, skipping" + ); + metrics.predicate_evaluation_errors.add(1); + return None; + }; + let mut vec = Vec::with_capacity(values.len()); - let row_vec = create_row_count_in_each_page(col_offset_indexes, num_rows); - assert_eq!(row_vec.len(), values.len()); - let mut sum_row = *row_vec.first().unwrap(); + assert_eq!(page_row_counts.len(), values.len()); + let mut sum_row = *page_row_counts.first().unwrap(); let mut selected = *values.first().unwrap(); trace!("Pruned to {:?} using {:?}", values, pruning_stats); for (i, &f) in values.iter().enumerate().skip(1) { if f == selected { - sum_row += *row_vec.get(i).unwrap(); + sum_row += *page_row_counts.get(i).unwrap(); } else { let selector = if selected { RowSelector::select(sum_row) @@ -364,7 +341,7 @@ fn prune_pages_in_one_row_group( RowSelector::skip(sum_row) }; vec.push(selector); - sum_row = *row_vec.get(i).unwrap(); + sum_row = *page_row_counts.get(i).unwrap(); selected = f; } } @@ -378,206 +355,143 @@ fn prune_pages_in_one_row_group( Some(RowSelection::from(vec)) } -fn create_row_count_in_each_page( - location: &[PageLocation], - num_rows: usize, -) -> Vec { - let mut vec = Vec::with_capacity(location.len()); - location.windows(2).for_each(|x| { - let start = x[0].first_row_index as usize; - let end = x[1].first_row_index as usize; - vec.push(end - start); - }); - vec.push(num_rows - location.last().unwrap().first_row_index as usize); - vec -} - -/// Wraps one col page_index in one rowGroup statistics in a way -/// that implements [`PruningStatistics`] +/// Implement [`PruningStatistics`] for one column's PageIndex (column_index + offset_index) #[derive(Debug)] struct PagesPruningStatistics<'a> { - col_page_indexes: &'a Index, - col_offset_indexes: &'a Vec, - // target_type means the logical type in schema: like 'DECIMAL' is the logical type, but the - // real physical type in parquet file may be `INT32, INT64, FIXED_LEN_BYTE_ARRAY` - target_type: &'a Option, - num_rows_in_row_group: i64, + row_group_index: usize, + row_group_metadatas: &'a [RowGroupMetaData], + converter: StatisticsConverter<'a>, + column_index: &'a ParquetColumnIndex, + offset_index: &'a ParquetOffsetIndex, + page_offsets: &'a Vec, } -// Extract the min or max value calling `func` from page idex -macro_rules! get_min_max_values_for_page_index { - ($self:expr, $func:ident) => {{ - match $self.col_page_indexes { - Index::NONE => None, - Index::INT32(index) => { - match $self.target_type { - // int32 to decimal with the precision and scale - Some(DataType::Decimal128(precision, scale)) => { - let vec = &index.indexes; - let vec: Vec> = vec - .iter() - .map(|x| x.$func().and_then(|x| Some(*x as i128))) - .collect(); - Decimal128Array::from(vec) - .with_precision_and_scale(*precision, *scale) - .ok() - .map(|arr| Arc::new(arr) as ArrayRef) - } - _ => { - let vec = &index.indexes; - Some(Arc::new(Int32Array::from_iter( - vec.iter().map(|x| x.$func().cloned()), - ))) - } - } - } - Index::INT64(index) => { - match $self.target_type { - // int64 to decimal with the precision and scale - Some(DataType::Decimal128(precision, scale)) => { - let vec = &index.indexes; - let vec: Vec> = vec - .iter() - .map(|x| x.$func().and_then(|x| Some(*x as i128))) - .collect(); - Decimal128Array::from(vec) - .with_precision_and_scale(*precision, *scale) - .ok() - .map(|arr| Arc::new(arr) as ArrayRef) - } - _ => { - let vec = &index.indexes; - Some(Arc::new(Int64Array::from_iter( - vec.iter().map(|x| x.$func().cloned()), - ))) - } - } - } - Index::FLOAT(index) => { - let vec = &index.indexes; - Some(Arc::new(Float32Array::from_iter( - vec.iter().map(|x| x.$func().cloned()), - ))) - } - Index::DOUBLE(index) => { - let vec = &index.indexes; - Some(Arc::new(Float64Array::from_iter( - vec.iter().map(|x| x.$func().cloned()), - ))) - } - Index::BOOLEAN(index) => { - let vec = &index.indexes; - Some(Arc::new(BooleanArray::from_iter( - vec.iter().map(|x| x.$func().cloned()), - ))) - } - Index::BYTE_ARRAY(index) => match $self.target_type { - Some(DataType::Decimal128(precision, scale)) => { - let vec = &index.indexes; - Decimal128Array::from( - vec.iter() - .map(|x| { - x.$func() - .and_then(|x| Some(from_bytes_to_i128(x.as_ref()))) - }) - .collect::>>(), - ) - .with_precision_and_scale(*precision, *scale) - .ok() - .map(|arr| Arc::new(arr) as ArrayRef) - } - _ => { - let vec = &index.indexes; - let array: StringArray = vec - .iter() - .map(|x| x.$func()) - .map(|x| x.and_then(|x| std::str::from_utf8(x.as_ref()).ok())) - .collect(); - Some(Arc::new(array)) - } - }, - Index::INT96(_) => { - //Todo support these type - None - } - Index::FIXED_LEN_BYTE_ARRAY(index) => match $self.target_type { - Some(DataType::Decimal128(precision, scale)) => { - let vec = &index.indexes; - Decimal128Array::from( - vec.iter() - .map(|x| { - x.$func() - .and_then(|x| Some(from_bytes_to_i128(x.as_ref()))) - }) - .collect::>>(), - ) - .with_precision_and_scale(*precision, *scale) - .ok() - .map(|arr| Arc::new(arr) as ArrayRef) - } - _ => None, - }, - } - }}; +impl<'a> PagesPruningStatistics<'a> { + /// Creates a new [`PagesPruningStatistics`] for a column in a row group, if + /// possible. + /// + /// Returns None if the `parquet_metadata` does not have sufficient + /// information to create the statistics. + fn try_new( + row_group_index: usize, + converter: StatisticsConverter<'a>, + parquet_metadata: &'a ParquetMetaData, + ) -> Option { + let Some(parquet_column_index) = converter.parquet_index() else { + trace!( + "Column {:?} not in parquet file, skipping", + converter.arrow_field() + ); + return None; + }; + + let column_index = parquet_metadata.column_index()?; + let offset_index = parquet_metadata.offset_index()?; + let row_group_metadatas = parquet_metadata.row_groups(); + + let Some(row_group_page_offsets) = offset_index.get(row_group_index) else { + trace!("No page offsets for row group {row_group_index}, skipping"); + return None; + }; + let Some(page_offsets) = row_group_page_offsets.get(parquet_column_index) else { + trace!( + "No page offsets for column {:?} in row group {row_group_index}, skipping", + converter.arrow_field() + ); + return None; + }; + + Some(Self { + row_group_index, + row_group_metadatas, + converter, + column_index, + offset_index, + page_offsets, + }) + } + + /// return the row counts in each data page, if possible. + fn page_row_counts(&self) -> Option> { + let row_group_metadata = self + .row_group_metadatas + .get(self.row_group_index) + // fail fast/panic if row_group_index is out of bounds + .unwrap(); + + let num_rows_in_row_group = row_group_metadata.num_rows() as usize; + + let page_offsets = self.page_offsets; + let mut vec = Vec::with_capacity(page_offsets.len()); + page_offsets.windows(2).for_each(|x| { + let start = x[0].first_row_index as usize; + let end = x[1].first_row_index as usize; + vec.push(end - start); + }); + vec.push(num_rows_in_row_group - page_offsets.last()?.first_row_index as usize); + Some(vec) + } } impl<'a> PruningStatistics for PagesPruningStatistics<'a> { fn min_values(&self, _column: &datafusion_common::Column) -> Option { - get_min_max_values_for_page_index!(self, min) + match self.converter.data_page_mins( + self.column_index, + self.offset_index, + [&self.row_group_index], + ) { + Ok(min_values) => Some(min_values), + Err(e) => { + debug!("Error evaluating data page min values {e}"); + None + } + } } fn max_values(&self, _column: &datafusion_common::Column) -> Option { - get_min_max_values_for_page_index!(self, max) + match self.converter.data_page_maxes( + self.column_index, + self.offset_index, + [&self.row_group_index], + ) { + Ok(min_values) => Some(min_values), + Err(e) => { + debug!("Error evaluating data page max values {e}"); + None + } + } } fn num_containers(&self) -> usize { - self.col_offset_indexes.len() + self.page_offsets.len() } fn null_counts(&self, _column: &datafusion_common::Column) -> Option { - match self.col_page_indexes { - Index::NONE => None, - Index::BOOLEAN(index) => Some(Arc::new(Int64Array::from_iter( - index.indexes.iter().map(|x| x.null_count), - ))), - Index::INT32(index) => Some(Arc::new(Int64Array::from_iter( - index.indexes.iter().map(|x| x.null_count), - ))), - Index::INT64(index) => Some(Arc::new(Int64Array::from_iter( - index.indexes.iter().map(|x| x.null_count), - ))), - Index::FLOAT(index) => Some(Arc::new(Int64Array::from_iter( - index.indexes.iter().map(|x| x.null_count), - ))), - Index::DOUBLE(index) => Some(Arc::new(Int64Array::from_iter( - index.indexes.iter().map(|x| x.null_count), - ))), - Index::INT96(index) => Some(Arc::new(Int64Array::from_iter( - index.indexes.iter().map(|x| x.null_count), - ))), - Index::BYTE_ARRAY(index) => Some(Arc::new(Int64Array::from_iter( - index.indexes.iter().map(|x| x.null_count), - ))), - Index::FIXED_LEN_BYTE_ARRAY(index) => Some(Arc::new(Int64Array::from_iter( - index.indexes.iter().map(|x| x.null_count), - ))), + match self.converter.data_page_null_counts( + self.column_index, + self.offset_index, + [&self.row_group_index], + ) { + Ok(null_counts) => Some(Arc::new(null_counts)), + Err(e) => { + debug!("Error evaluating data page null counts {e}"); + None + } } } fn row_counts(&self, _column: &datafusion_common::Column) -> Option { - // see https://github.com/apache/arrow-rs/blob/91f0b1771308609ca27db0fb1d2d49571b3980d8/parquet/src/file/metadata.rs#L979-L982 - - let row_count_per_page = self.col_offset_indexes.windows(2).map(|location| { - Some(location[1].first_row_index - location[0].first_row_index) - }); - - // append the last page row count - let row_count_per_page = row_count_per_page.chain(std::iter::once(Some( - self.num_rows_in_row_group - - self.col_offset_indexes.last().unwrap().first_row_index, - ))); - - Some(Arc::new(Int64Array::from_iter(row_count_per_page))) + match self.converter.data_page_row_counts( + self.offset_index, + self.row_group_metadatas, + [&self.row_group_index], + ) { + Ok(row_counts) => row_counts.map(|a| Arc::new(a) as ArrayRef), + Err(e) => { + debug!("Error evaluating data page row counts {e}"); + None + } + } } fn contained( diff --git a/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs b/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs index 44e22f778075..3d250718f736 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs @@ -1136,6 +1136,16 @@ pub struct StatisticsConverter<'a> { } impl<'a> StatisticsConverter<'a> { + /// Return the index of the column in the parquet file, if any + pub fn parquet_index(&self) -> Option { + self.parquet_index + } + + /// Return the arrow field of the column in the arrow schema + pub fn arrow_field(&self) -> &'a Field { + self.arrow_field + } + /// Returns a [`UInt64Array`] with row counts for each row group /// /// # Return Value diff --git a/datafusion/core/src/physical_optimizer/pruning.rs b/datafusion/core/src/physical_optimizer/pruning.rs index a1ace229985e..3c18e53497fd 100644 --- a/datafusion/core/src/physical_optimizer/pruning.rs +++ b/datafusion/core/src/physical_optimizer/pruning.rs @@ -609,6 +609,8 @@ impl PruningPredicate { /// /// This happens if the predicate is a literal `true` and /// literal_guarantees is empty. + /// + /// This can happen when a predicate is simplified to a constant `true` pub fn always_true(&self) -> bool { is_always_true(&self.predicate_expr) && self.literal_guarantees.is_empty() } @@ -736,12 +738,25 @@ impl RequiredColumns { Self::default() } - /// Returns number of unique columns - pub(crate) fn n_columns(&self) -> usize { - self.iter() - .map(|(c, _s, _f)| c) - .collect::>() - .len() + /// Returns Some(column) if this is a single column predicate. + /// + /// Returns None if this is a multi-column predicate. + /// + /// Examples: + /// * `a > 5 OR a < 10` returns `Some(a)` + /// * `a > 5 OR b < 10` returns `None` + /// * `true` returns None + pub(crate) fn single_column(&self) -> Option<&phys_expr::Column> { + if self.columns.windows(2).all(|w| { + // check if all columns are the same (ignoring statistics and field) + let c1 = &w[0].0; + let c2 = &w[1].0; + c1 == c2 + }) { + self.columns.first().map(|r| &r.0) + } else { + None + } } /// Returns an iterator over items in columns (see doc on From be130b46709e084f969b15e7686cddb289a198ff Mon Sep 17 00:00:00 2001 From: Oleks V Date: Thu, 18 Jul 2024 09:53:27 -0700 Subject: [PATCH 04/37] Enable SortMergeJoin LeftAnti filtered fuzz tests (#11535) * Enable LeftAnti filtered fuzz tests * Enable LeftAnti filtered fuzz tests. Add git reference --- datafusion/core/tests/fuzz_cases/join_fuzz.rs | 38 ++++++++++++------- 1 file changed, 25 insertions(+), 13 deletions(-) diff --git a/datafusion/core/tests/fuzz_cases/join_fuzz.rs b/datafusion/core/tests/fuzz_cases/join_fuzz.rs index 17dbf3a0ff28..604c1f93e55e 100644 --- a/datafusion/core/tests/fuzz_cases/join_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/join_fuzz.rs @@ -83,7 +83,7 @@ fn less_than_100_join_filter(schema1: Arc, _schema2: Arc) -> Joi } fn col_lt_col_filter(schema1: Arc, schema2: Arc) -> JoinFilter { - let less_than_100 = Arc::new(BinaryExpr::new( + let less_filter = Arc::new(BinaryExpr::new( Arc::new(Column::new("x", 1)), Operator::Lt, Arc::new(Column::new("x", 0)), @@ -99,11 +99,19 @@ fn col_lt_col_filter(schema1: Arc, schema2: Arc) -> JoinFilter { }, ]; let intermediate_schema = Schema::new(vec![ - schema1.field_with_name("x").unwrap().to_owned(), - schema2.field_with_name("x").unwrap().to_owned(), + schema1 + .field_with_name("x") + .unwrap() + .clone() + .with_nullable(true), + schema2 + .field_with_name("x") + .unwrap() + .clone() + .with_nullable(true), ]); - JoinFilter::new(less_than_100, column_indices, intermediate_schema) + JoinFilter::new(less_filter, column_indices, intermediate_schema) } #[tokio::test] @@ -217,6 +225,8 @@ async fn test_semi_join_1k() { #[tokio::test] async fn test_semi_join_1k_filtered() { + // NLJ vs HJ gives wrong result + // Tracked in https://github.com/apache/datafusion/issues/11537 JoinFuzzTestCase::new( make_staggered_batches(1000), make_staggered_batches(1000), @@ -239,17 +249,17 @@ async fn test_anti_join_1k() { .await } -// Test failed for now. https://github.com/apache/datafusion/issues/10872 -#[ignore] #[tokio::test] async fn test_anti_join_1k_filtered() { + // NLJ vs HJ gives wrong result + // Tracked in https://github.com/apache/datafusion/issues/11537 JoinFuzzTestCase::new( make_staggered_batches(1000), make_staggered_batches(1000), JoinType::LeftAnti, - Some(Box::new(less_than_100_join_filter)), + Some(Box::new(col_lt_col_filter)), ) - .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false) + .run_test(&[JoinTestType::HjSmj], false) .await } @@ -422,12 +432,13 @@ impl JoinFuzzTestCase { let session_config = SessionConfig::new().with_batch_size(*batch_size); let ctx = SessionContext::new_with_config(session_config); let task_ctx = ctx.task_ctx(); - let smj = self.sort_merge_join(); - let smj_collected = collect(smj, task_ctx.clone()).await.unwrap(); let hj = self.hash_join(); let hj_collected = collect(hj, task_ctx.clone()).await.unwrap(); + let smj = self.sort_merge_join(); + let smj_collected = collect(smj, task_ctx.clone()).await.unwrap(); + let nlj = self.nested_loop_join(); let nlj_collected = collect(nlj, task_ctx.clone()).await.unwrap(); @@ -437,11 +448,12 @@ impl JoinFuzzTestCase { let nlj_rows = nlj_collected.iter().fold(0, |acc, b| acc + b.num_rows()); if debug { - println!("The debug is ON. Input data will be saved"); let fuzz_debug = "fuzz_test_debug"; std::fs::remove_dir_all(fuzz_debug).unwrap_or(()); std::fs::create_dir_all(fuzz_debug).unwrap(); let out_dir_name = &format!("{fuzz_debug}/batch_size_{batch_size}"); + println!("The debug is ON. Input data will be saved to {out_dir_name}"); + Self::save_partitioned_batches_as_parquet( &self.input1, out_dir_name, @@ -562,8 +574,7 @@ impl JoinFuzzTestCase { /// Some(Box::new(col_lt_col_filter)), /// ) /// .run_test(&[JoinTestType::HjSmj], false) - /// .await - /// } + /// .await; /// /// let ctx: SessionContext = SessionContext::new(); /// let df = ctx @@ -592,6 +603,7 @@ impl JoinFuzzTestCase { /// ) /// .run_test() /// .await + /// } fn save_partitioned_batches_as_parquet( input: &[RecordBatch], output_dir: &str, From b685e2d4f1f245dd1dbe468b32b115ae99316689 Mon Sep 17 00:00:00 2001 From: JasonLi Date: Fri, 19 Jul 2024 03:22:39 +0800 Subject: [PATCH 05/37] chore: fix typos of expr, functions, optimizer, physical-expr-common, physical-expr, and physical-plan packages (#11538) --- datafusion/expr/src/aggregate_function.rs | 4 ++-- datafusion/expr/src/expr.rs | 6 +++--- datafusion/expr/src/expr_rewriter/mod.rs | 4 ++-- datafusion/expr/src/logical_plan/builder.rs | 6 +++--- datafusion/expr/src/logical_plan/display.rs | 4 ++-- datafusion/expr/src/logical_plan/plan.rs | 6 +++--- datafusion/expr/src/partition_evaluator.rs | 2 +- datafusion/expr/src/signature.rs | 2 +- datafusion/expr/src/type_coercion/binary.rs | 4 ++-- datafusion/expr/src/type_coercion/functions.rs | 2 +- datafusion/expr/src/type_coercion/mod.rs | 2 +- datafusion/expr/src/utils.rs | 2 +- .../src/approx_percentile_cont_with_weight.rs | 2 +- datafusion/functions-array/src/remove.rs | 2 +- datafusion/functions/src/core/arrow_cast.rs | 2 +- datafusion/functions/src/datetime/to_local_time.rs | 4 ++-- datafusion/functions/src/regex/regexpreplace.rs | 2 +- datafusion/functions/src/unicode/substrindex.rs | 8 ++++---- datafusion/optimizer/src/analyzer/subquery.rs | 4 ++-- datafusion/optimizer/src/common_subexpr_eliminate.rs | 4 ++-- .../optimizer/src/decorrelate_predicate_subquery.rs | 2 +- .../src/optimize_projections/required_indices.rs | 2 +- datafusion/optimizer/src/push_down_filter.rs | 2 +- .../src/simplify_expressions/expr_simplifier.rs | 4 ++-- .../optimizer/src/unwrap_cast_in_comparison.rs | 2 +- .../src/aggregate/groups_accumulator/accumulate.rs | 4 ++-- datafusion/physical-expr-common/src/aggregate/mod.rs | 4 ++-- datafusion/physical-expr-common/src/binary_map.rs | 4 ++-- .../physical-expr-common/src/expressions/column.rs | 2 +- .../src/aggregate/groups_accumulator/adapter.rs | 2 +- datafusion/physical-expr/src/aggregate/min_max.rs | 2 +- datafusion/physical-expr/src/equivalence/class.rs | 2 +- datafusion/physical-expr/src/expressions/case.rs | 2 +- datafusion/physical-expr/src/expressions/column.rs | 2 +- datafusion/physical-expr/src/expressions/try_cast.rs | 2 +- datafusion/physical-expr/src/utils/guarantee.rs | 2 +- .../physical-plan/src/aggregates/group_values/row.rs | 2 +- datafusion/physical-plan/src/aggregates/mod.rs | 6 +++--- datafusion/physical-plan/src/analyze.rs | 2 +- datafusion/physical-plan/src/display.rs | 2 +- datafusion/physical-plan/src/joins/cross_join.rs | 2 +- .../physical-plan/src/joins/nested_loop_join.rs | 2 +- .../physical-plan/src/joins/sort_merge_join.rs | 2 +- .../physical-plan/src/joins/symmetric_hash_join.rs | 2 +- datafusion/physical-plan/src/joins/utils.rs | 12 ++++++------ datafusion/physical-plan/src/limit.rs | 2 +- datafusion/physical-plan/src/repartition/mod.rs | 6 +++--- datafusion/physical-plan/src/sorts/sort.rs | 8 ++++---- datafusion/physical-plan/src/streaming.rs | 2 +- datafusion/physical-plan/src/test/exec.rs | 2 +- datafusion/physical-plan/src/topk/mod.rs | 8 ++++---- datafusion/physical-plan/src/windows/mod.rs | 2 +- .../physical-plan/src/windows/window_agg_exec.rs | 2 +- datafusion/physical-plan/src/work_table.rs | 2 +- 54 files changed, 89 insertions(+), 89 deletions(-) diff --git a/datafusion/expr/src/aggregate_function.rs b/datafusion/expr/src/aggregate_function.rs index 3cae78eaed9b..39b3b4ed3b5a 100644 --- a/datafusion/expr/src/aggregate_function.rs +++ b/datafusion/expr/src/aggregate_function.rs @@ -152,8 +152,8 @@ mod tests { use strum::IntoEnumIterator; #[test] - // Test for AggregateFuncion's Display and from_str() implementations. - // For each variant in AggregateFuncion, it converts the variant to a string + // Test for AggregateFunction's Display and from_str() implementations. + // For each variant in AggregateFunction, it converts the variant to a string // and then back to a variant. The test asserts that the original variant and // the reconstructed variant are the same. This assertion is also necessary for // function suggestion. See https://github.com/apache/datafusion/issues/8082 diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index a344e621ddb1..e3620501d9a8 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -109,7 +109,7 @@ use sqlparser::ast::NullTreatment; /// ## Binary Expressions /// /// Exprs implement traits that allow easy to understand construction of more -/// complex expresions. For example, to create `c1 + c2` to add columns "c1" and +/// complex expressions. For example, to create `c1 + c2` to add columns "c1" and /// "c2" together /// /// ``` @@ -1398,7 +1398,7 @@ impl Expr { } Ok(TreeNodeRecursion::Continue) }) - .expect("traversal is infallable"); + .expect("traversal is infallible"); } /// Return all references to columns and their occurrence counts in the expression. @@ -1433,7 +1433,7 @@ impl Expr { } Ok(TreeNodeRecursion::Continue) }) - .expect("traversal is infallable"); + .expect("traversal is infallible"); } /// Returns true if there are any column references in this Expr diff --git a/datafusion/expr/src/expr_rewriter/mod.rs b/datafusion/expr/src/expr_rewriter/mod.rs index 91bec501f4a0..8d460bdc8e7d 100644 --- a/datafusion/expr/src/expr_rewriter/mod.rs +++ b/datafusion/expr/src/expr_rewriter/mod.rs @@ -155,7 +155,7 @@ pub fn unnormalize_col(expr: Expr) -> Expr { }) }) .data() - .expect("Unnormalize is infallable") + .expect("Unnormalize is infallible") } /// Create a Column from the Scalar Expr @@ -201,7 +201,7 @@ pub fn strip_outer_reference(expr: Expr) -> Expr { }) }) .data() - .expect("strip_outer_reference is infallable") + .expect("strip_outer_reference is infallible") } /// Returns plan with expressions coerced to types compatible with diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 4ad3bd5018a4..98e262f0b187 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -412,14 +412,14 @@ impl LogicalPlanBuilder { /// Add missing sort columns to all downstream projection /// - /// Thus, if you have a LogialPlan that selects A and B and have + /// Thus, if you have a LogicalPlan that selects A and B and have /// not requested a sort by C, this code will add C recursively to /// all input projections. /// /// Adding a new column is not correct if there is a `Distinct` /// node, which produces only distinct values of its /// inputs. Adding a new column to its input will result in - /// potententially different results than with the original column. + /// potentially different results than with the original column. /// /// For example, if the input is like: /// @@ -1763,7 +1763,7 @@ mod tests { .unwrap(); assert_eq!(&expected, plan.schema().as_ref()); - // Note scan of "EMPLOYEE_CSV" is treated as a SQL identifer + // Note scan of "EMPLOYEE_CSV" is treated as a SQL identifier // (and thus normalized to "employee"csv") as well let projection = None; let plan = diff --git a/datafusion/expr/src/logical_plan/display.rs b/datafusion/expr/src/logical_plan/display.rs index 81fd03555abb..343eda056ffe 100644 --- a/datafusion/expr/src/logical_plan/display.rs +++ b/datafusion/expr/src/logical_plan/display.rs @@ -338,9 +338,9 @@ impl<'a, 'b> PgJsonVisitor<'a, 'b> { .collect::>() .join(", "); - let elipse = if values.len() > 5 { "..." } else { "" }; + let eclipse = if values.len() > 5 { "..." } else { "" }; - let values_str = format!("{}{}", str_values, elipse); + let values_str = format!("{}{}", str_values, eclipse); json!({ "Node Type": "Values", "Values": values_str diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index bde9655b8a39..48fa6270b202 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -263,7 +263,7 @@ pub enum LogicalPlan { /// Prepare a statement and find any bind parameters /// (e.g. `?`). This is used to implement SQL-prepared statements. Prepare(Prepare), - /// Data Manipulaton Language (DML): Insert / Update / Delete + /// Data Manipulation Language (DML): Insert / Update / Delete Dml(DmlStatement), /// Data Definition Language (DDL): CREATE / DROP TABLES / VIEWS / SCHEMAS Ddl(DdlStatement), @@ -1598,8 +1598,8 @@ impl LogicalPlan { }) .collect(); - let elipse = if values.len() > 5 { "..." } else { "" }; - write!(f, "Values: {}{}", str_values.join(", "), elipse) + let eclipse = if values.len() > 5 { "..." } else { "" }; + write!(f, "Values: {}{}", str_values.join(", "), eclipse) } LogicalPlan::TableScan(TableScan { diff --git a/datafusion/expr/src/partition_evaluator.rs b/datafusion/expr/src/partition_evaluator.rs index 04b6faf55ae1..a0f0988b4f4e 100644 --- a/datafusion/expr/src/partition_evaluator.rs +++ b/datafusion/expr/src/partition_evaluator.rs @@ -135,7 +135,7 @@ pub trait PartitionEvaluator: Debug + Send { /// must produce an output column with one output row for every /// input row. /// - /// `num_rows` is requied to correctly compute the output in case + /// `num_rows` is required to correctly compute the output in case /// `values.len() == 0` /// /// Implementing this function is an optimization: certain window diff --git a/datafusion/expr/src/signature.rs b/datafusion/expr/src/signature.rs index fba793dd229d..eadd7ac2f83f 100644 --- a/datafusion/expr/src/signature.rs +++ b/datafusion/expr/src/signature.rs @@ -65,7 +65,7 @@ pub enum Volatility { /// automatically coerces (add casts to) function arguments so they match the type signature. /// /// For example, a function like `cos` may only be implemented for `Float64` arguments. To support a query -/// that calles `cos` with a different argument type, such as `cos(int_column)`, type coercion automatically +/// that calls `cos` with a different argument type, such as `cos(int_column)`, type coercion automatically /// adds a cast such as `cos(CAST int_column AS DOUBLE)` during planning. /// /// # Data Types diff --git a/datafusion/expr/src/type_coercion/binary.rs b/datafusion/expr/src/type_coercion/binary.rs index 70139aaa4a0c..e1765b5c3e6a 100644 --- a/datafusion/expr/src/type_coercion/binary.rs +++ b/datafusion/expr/src/type_coercion/binary.rs @@ -370,7 +370,7 @@ impl From<&DataType> for TypeCategory { /// The rules in the document provide a clue, but adhering strictly to them doesn't precisely /// align with the behavior of Postgres. Therefore, we've made slight adjustments to the rules /// to better match the behavior of both Postgres and DuckDB. For example, we expect adjusted -/// decimal percision and scale when coercing decimal types. +/// decimal precision and scale when coercing decimal types. pub fn type_union_resolution(data_types: &[DataType]) -> Option { if data_types.is_empty() { return None; @@ -718,7 +718,7 @@ pub fn get_wider_type(lhs: &DataType, rhs: &DataType) -> Result { (Int16 | Int32 | Int64, Int8) | (Int32 | Int64, Int16) | (Int64, Int32) | // Left Float is larger than right Float. (Float32 | Float64, Float16) | (Float64, Float32) | - // Left String is larget than right String. + // Left String is larger than right String. (LargeUtf8, Utf8) | // Any left type is wider than a right hand side Null. (_, Null) => lhs.clone(), diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index b430b343e484..ef52a01e0598 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -646,7 +646,7 @@ mod tests { vec![DataType::UInt8, DataType::UInt16], Some(vec![DataType::UInt8, DataType::UInt16]), ), - // 2 entries, can coerse values + // 2 entries, can coerce values ( vec![DataType::UInt16, DataType::UInt16], vec![DataType::UInt8, DataType::UInt16], diff --git a/datafusion/expr/src/type_coercion/mod.rs b/datafusion/expr/src/type_coercion/mod.rs index 86005da3dafa..e0d1236aac2d 100644 --- a/datafusion/expr/src/type_coercion/mod.rs +++ b/datafusion/expr/src/type_coercion/mod.rs @@ -19,7 +19,7 @@ //! //! Coercion is performed automatically by DataFusion when the types //! of arguments passed to a function or needed by operators do not -//! exacty match the types required by that function / operator. In +//! exactly match the types required by that function / operator. In //! this case, DataFusion will attempt to *coerce* the arguments to //! types accepted by the function by inserting CAST operations. //! diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 45155cbd2c27..889aa0952e51 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -1212,7 +1212,7 @@ pub fn merge_schema(inputs: Vec<&LogicalPlan>) -> DFSchema { } } -/// Build state name. State is the intermidiate state of the aggregate function. +/// Build state name. State is the intermediate state of the aggregate function. pub fn format_state_name(name: &str, state_name: &str) -> String { format!("{name}[{state_name}]") } diff --git a/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs b/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs index a64218c606c4..0dbea1fb1ff7 100644 --- a/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs +++ b/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs @@ -43,7 +43,7 @@ make_udaf_expr_and_func!( approx_percentile_cont_with_weight_udaf ); -/// APPROX_PERCENTILE_CONT_WITH_WEIGTH aggregate expression +/// APPROX_PERCENTILE_CONT_WITH_WEIGHT aggregate expression pub struct ApproxPercentileContWithWeight { signature: Signature, approx_percentile_cont: ApproxPercentileCont, diff --git a/datafusion/functions-array/src/remove.rs b/datafusion/functions-array/src/remove.rs index 589dd4d0c41c..0b7cfc283c06 100644 --- a/datafusion/functions-array/src/remove.rs +++ b/datafusion/functions-array/src/remove.rs @@ -228,7 +228,7 @@ fn array_remove_internal( } } -/// For each element of `list_array[i]`, removed up to `arr_n[i]` occurences +/// For each element of `list_array[i]`, removed up to `arr_n[i]` occurrences /// of `element_array[i]`. /// /// The type of each **element** in `list_array` must be the same as the type of diff --git a/datafusion/functions/src/core/arrow_cast.rs b/datafusion/functions/src/core/arrow_cast.rs index 9c410d4e18e8..9227f9e3a2a8 100644 --- a/datafusion/functions/src/core/arrow_cast.rs +++ b/datafusion/functions/src/core/arrow_cast.rs @@ -444,7 +444,7 @@ fn is_separator(c: char) -> bool { } #[derive(Debug)] -/// Splits a strings like Dictionary(Int32, Int64) into tokens sutable for parsing +/// Splits a strings like Dictionary(Int32, Int64) into tokens suitable for parsing /// /// For example the string "Timestamp(Nanosecond, None)" would be parsed into: /// diff --git a/datafusion/functions/src/datetime/to_local_time.rs b/datafusion/functions/src/datetime/to_local_time.rs index c84d1015bd7e..634e28e6f393 100644 --- a/datafusion/functions/src/datetime/to_local_time.rs +++ b/datafusion/functions/src/datetime/to_local_time.rs @@ -84,7 +84,7 @@ impl ToLocalTimeFunc { let arg_type = time_value.data_type(); match arg_type { DataType::Timestamp(_, None) => { - // if no timezone specificed, just return the input + // if no timezone specified, just return the input Ok(time_value.clone()) } // If has timezone, adjust the underlying time value. The current time value @@ -165,7 +165,7 @@ impl ToLocalTimeFunc { match array.data_type() { Timestamp(_, None) => { - // if no timezone specificed, just return the input + // if no timezone specified, just return the input Ok(time_value.clone()) } Timestamp(Nanosecond, Some(_)) => { diff --git a/datafusion/functions/src/regex/regexpreplace.rs b/datafusion/functions/src/regex/regexpreplace.rs index 378b6ced076c..d820f991be18 100644 --- a/datafusion/functions/src/regex/regexpreplace.rs +++ b/datafusion/functions/src/regex/regexpreplace.rs @@ -562,7 +562,7 @@ mod tests { #[test] fn test_static_pattern_regexp_replace_pattern_error() { let values = StringArray::from(vec!["abc"; 5]); - // Delibaretely using an invalid pattern to see how the single pattern + // Deliberately using an invalid pattern to see how the single pattern // error is propagated on regexp_replace. let patterns = StringArray::from(vec!["["; 5]); let replacements = StringArray::from(vec!["foo"; 5]); diff --git a/datafusion/functions/src/unicode/substrindex.rs b/datafusion/functions/src/unicode/substrindex.rs index a057e4298546..f8ecab9073c4 100644 --- a/datafusion/functions/src/unicode/substrindex.rs +++ b/datafusion/functions/src/unicode/substrindex.rs @@ -122,15 +122,15 @@ pub fn substr_index(args: &[ArrayRef]) -> Result { let occurrences = usize::try_from(n.unsigned_abs()).unwrap_or(usize::MAX); let length = if n > 0 { - let splitted = string.split(delimiter); - splitted + let split = string.split(delimiter); + split .take(occurrences) .map(|s| s.len() + delimiter.len()) .sum::() - delimiter.len() } else { - let splitted = string.rsplit(delimiter); - splitted + let split = string.rsplit(delimiter); + split .take(occurrences) .map(|s| s.len() + delimiter.len()) .sum::() diff --git a/datafusion/optimizer/src/analyzer/subquery.rs b/datafusion/optimizer/src/analyzer/subquery.rs index db39f8f7737d..9856ea271ca5 100644 --- a/datafusion/optimizer/src/analyzer/subquery.rs +++ b/datafusion/optimizer/src/analyzer/subquery.rs @@ -159,11 +159,11 @@ fn check_inner_plan( let (correlated, _): (Vec<_>, Vec<_>) = split_conjunction(predicate) .into_iter() .partition(|e| e.contains_outer()); - let maybe_unsupport = correlated + let maybe_unsupported = correlated .into_iter() .filter(|expr| !can_pullup_over_aggregation(expr)) .collect::>(); - if is_aggregate && is_scalar && !maybe_unsupport.is_empty() { + if is_aggregate && is_scalar && !maybe_unsupported.is_empty() { return plan_err!( "Correlated column is not allowed in predicate: {predicate}" ); diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index e4b36652974d..bbf2091c2217 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -248,7 +248,7 @@ impl CommonSubexprEliminate { } /// Rewrites the expression in `exprs_list` with common sub-expressions - /// replaced with a new colum and adds a ProjectionExec on top of `input` + /// replaced with a new column and adds a ProjectionExec on top of `input` /// which computes any replaced common sub-expressions. /// /// Returns a tuple of: @@ -636,7 +636,7 @@ impl CommonSubexprEliminate { /// Returns the window expressions, and the input to the deepest child /// LogicalPlan. /// -/// For example, if the input widnow looks like +/// For example, if the input window looks like /// /// ```text /// LogicalPlan::Window(exprs=[a, b, c]) diff --git a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs index 4e3ca7e33a2e..b6d49490d437 100644 --- a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs +++ b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs @@ -1232,7 +1232,7 @@ mod tests { } #[test] - fn in_subquery_muti_project_subquery_cols() -> Result<()> { + fn in_subquery_multi_project_subquery_cols() -> Result<()> { let table_scan = test_table_scan()?; let subquery_scan = test_table_scan_with_name("sq")?; diff --git a/datafusion/optimizer/src/optimize_projections/required_indices.rs b/datafusion/optimizer/src/optimize_projections/required_indices.rs index 3f32a0c36a9a..a9a18898c82e 100644 --- a/datafusion/optimizer/src/optimize_projections/required_indices.rs +++ b/datafusion/optimizer/src/optimize_projections/required_indices.rs @@ -160,7 +160,7 @@ impl RequiredIndicies { (l, r.map_indices(|idx| idx - n)) } - /// Partitions the indicies in this instance into two groups based on the + /// Partitions the indices in this instance into two groups based on the /// given predicate function `f`. fn partition(&self, f: F) -> (Self, Self) where diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 20e2ac07dffd..33b2883d6ed8 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -1913,7 +1913,7 @@ mod tests { assert_optimized_plan_eq(plan, expected) } - /// post-join predicates with columns from both sides are converted to join filterss + /// post-join predicates with columns from both sides are converted to join filters #[test] fn filter_join_on_common_dependent() -> Result<()> { let table_scan = test_table_scan()?; diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 8414f39f3060..56556f387d1b 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -478,7 +478,7 @@ struct ConstEvaluator<'a> { #[allow(dead_code)] /// The simplify result of ConstEvaluator enum ConstSimplifyResult { - // Expr was simplifed and contains the new expression + // Expr was simplified and contains the new expression Simplified(ScalarValue), // Expr was not simplified and original value is returned NotSimplified(ScalarValue), @@ -519,7 +519,7 @@ impl<'a> TreeNodeRewriter for ConstEvaluator<'a> { fn f_up(&mut self, expr: Expr) -> Result> { match self.can_evaluate.pop() { // Certain expressions such as `CASE` and `COALESCE` are short circuiting - // and may not evalute all their sub expressions. Thus if + // and may not evaluate all their sub expressions. Thus if // if any error is countered during simplification, return the original // so that normal evaluation can occur Some(true) => { diff --git a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs index 7238dd5bbd97..e0f50a470d43 100644 --- a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs +++ b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs @@ -893,7 +893,7 @@ mod tests { DataType::Timestamp(TimeUnit::Nanosecond, utc) } - // a dictonary type for storing string tags + // a dictionary type for storing string tags fn dictionary_tag_type() -> DataType { DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)) } diff --git a/datafusion/physical-expr-common/src/aggregate/groups_accumulator/accumulate.rs b/datafusion/physical-expr-common/src/aggregate/groups_accumulator/accumulate.rs index f109079f6a26..3fcd570f514e 100644 --- a/datafusion/physical-expr-common/src/aggregate/groups_accumulator/accumulate.rs +++ b/datafusion/physical-expr-common/src/aggregate/groups_accumulator/accumulate.rs @@ -410,7 +410,7 @@ pub fn accumulate_indices( }, ); - // handle any remaining bits (after the intial 64) + // handle any remaining bits (after the initial 64) let remainder_bits = bit_chunks.remainder_bits(); group_indices_remainder .iter() @@ -835,7 +835,7 @@ mod test { } } - /// Parallel implementaiton of NullState to check expected values + /// Parallel implementation of NullState to check expected values #[derive(Debug, Default)] struct MockNullState { /// group indices that had values that passed the filter diff --git a/datafusion/physical-expr-common/src/aggregate/mod.rs b/datafusion/physical-expr-common/src/aggregate/mod.rs index 0e245fd0a66a..7a4a3a6cac4b 100644 --- a/datafusion/physical-expr-common/src/aggregate/mod.rs +++ b/datafusion/physical-expr-common/src/aggregate/mod.rs @@ -346,7 +346,7 @@ impl AggregateExpr for AggregateFunctionExpr { let accumulator = self.fun.create_sliding_accumulator(args)?; // Accumulators that have window frame startings different - // than `UNBOUNDED PRECEDING`, such as `1 PRECEEDING`, need to + // than `UNBOUNDED PRECEDING`, such as `1 PRECEDING`, need to // implement retract_batch method in order to run correctly // currently in DataFusion. // @@ -377,7 +377,7 @@ impl AggregateExpr for AggregateFunctionExpr { // 3. Third sum we add to the state sum value between `[2, 3)` // (`[0, 2)` is already in the state sum). Also we need to // retract values between `[0, 1)` by this way we can obtain sum - // between [1, 3) which is indeed the apropriate range. + // between [1, 3) which is indeed the appropriate range. // // When we use `UNBOUNDED PRECEDING` in the query starting // index will always be 0 for the desired range, and hence the diff --git a/datafusion/physical-expr-common/src/binary_map.rs b/datafusion/physical-expr-common/src/binary_map.rs index bff571f5b5be..23280701013d 100644 --- a/datafusion/physical-expr-common/src/binary_map.rs +++ b/datafusion/physical-expr-common/src/binary_map.rs @@ -355,7 +355,7 @@ where assert_eq!(values.len(), batch_hashes.len()); for (value, &hash) in values.iter().zip(batch_hashes.iter()) { - // hande null value + // handle null value let Some(value) = value else { let payload = if let Some(&(payload, _offset)) = self.null.as_ref() { payload @@ -439,7 +439,7 @@ where // Put the small values into buffer and offsets so it // appears the output array, and store that offset // so the bytes can be compared if needed - let offset = self.buffer.len(); // offset of start fof data + let offset = self.buffer.len(); // offset of start for data self.buffer.append_slice(value); self.offsets.push(O::usize_as(self.buffer.len())); diff --git a/datafusion/physical-expr-common/src/expressions/column.rs b/datafusion/physical-expr-common/src/expressions/column.rs index 956c33d59b20..d972d35b9e4e 100644 --- a/datafusion/physical-expr-common/src/expressions/column.rs +++ b/datafusion/physical-expr-common/src/expressions/column.rs @@ -80,7 +80,7 @@ impl PhysicalExpr for Column { Ok(input_schema.field(self.index).data_type().clone()) } - /// Decide whehter this expression is nullable, given the schema of the input + /// Decide whether this expression is nullable, given the schema of the input fn nullable(&self, input_schema: &Schema) -> Result { self.bounds_check(input_schema)?; Ok(input_schema.field(self.index).is_nullable()) diff --git a/datafusion/physical-expr/src/aggregate/groups_accumulator/adapter.rs b/datafusion/physical-expr/src/aggregate/groups_accumulator/adapter.rs index 9856e1c989b3..592c130b69d8 100644 --- a/datafusion/physical-expr/src/aggregate/groups_accumulator/adapter.rs +++ b/datafusion/physical-expr/src/aggregate/groups_accumulator/adapter.rs @@ -69,7 +69,7 @@ impl AccumulatorState { } } - /// Returns the amount of memory taken by this structre and its accumulator + /// Returns the amount of memory taken by this structure and its accumulator fn size(&self) -> usize { self.accumulator.size() + std::mem::size_of_val(self) diff --git a/datafusion/physical-expr/src/aggregate/min_max.rs b/datafusion/physical-expr/src/aggregate/min_max.rs index 65bb9e478c3d..9987e97b38d3 100644 --- a/datafusion/physical-expr/src/aggregate/min_max.rs +++ b/datafusion/physical-expr/src/aggregate/min_max.rs @@ -296,7 +296,7 @@ macro_rules! typed_min_max_batch_string { }}; } -// Statically-typed version of min/max(array) -> ScalarValue for binay types. +// Statically-typed version of min/max(array) -> ScalarValue for binary types. macro_rules! typed_min_max_batch_binary { ($VALUES:expr, $ARRAYTYPE:ident, $SCALAR:ident, $OP:ident) => {{ let array = downcast_value!($VALUES, $ARRAYTYPE); diff --git a/datafusion/physical-expr/src/equivalence/class.rs b/datafusion/physical-expr/src/equivalence/class.rs index e483f935b75c..ffa58e385322 100644 --- a/datafusion/physical-expr/src/equivalence/class.rs +++ b/datafusion/physical-expr/src/equivalence/class.rs @@ -67,7 +67,7 @@ impl ConstExpr { pub fn new(expr: Arc) -> Self { Self { expr, - // By default, assume constant expressions are not same accross partitions. + // By default, assume constant expressions are not same across partitions. across_partitions: false, } } diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index cd73c5cb579c..7a434c940229 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -931,7 +931,7 @@ mod tests { } #[test] - fn case_tranform() -> Result<()> { + fn case_transform() -> Result<()> { let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); let when1 = lit("foo"); diff --git a/datafusion/physical-expr/src/expressions/column.rs b/datafusion/physical-expr/src/expressions/column.rs index f6525c7c0462..38779c54607f 100644 --- a/datafusion/physical-expr/src/expressions/column.rs +++ b/datafusion/physical-expr/src/expressions/column.rs @@ -67,7 +67,7 @@ impl PhysicalExpr for UnKnownColumn { Ok(DataType::Null) } - /// Decide whehter this expression is nullable, given the schema of the input + /// Decide whether this expression is nullable, given the schema of the input fn nullable(&self, _input_schema: &Schema) -> Result { Ok(true) } diff --git a/datafusion/physical-expr/src/expressions/try_cast.rs b/datafusion/physical-expr/src/expressions/try_cast.rs index 3549a3df83bb..43b6c993d2b2 100644 --- a/datafusion/physical-expr/src/expressions/try_cast.rs +++ b/datafusion/physical-expr/src/expressions/try_cast.rs @@ -31,7 +31,7 @@ use datafusion_common::format::DEFAULT_FORMAT_OPTIONS; use datafusion_common::{not_impl_err, Result, ScalarValue}; use datafusion_expr::ColumnarValue; -/// TRY_CAST expression casts an expression to a specific data type and retuns NULL on invalid cast +/// TRY_CAST expression casts an expression to a specific data type and returns NULL on invalid cast #[derive(Debug, Hash)] pub struct TryCastExpr { /// The expression to cast diff --git a/datafusion/physical-expr/src/utils/guarantee.rs b/datafusion/physical-expr/src/utils/guarantee.rs index 42e5e6fcf3ac..993ff5610063 100644 --- a/datafusion/physical-expr/src/utils/guarantee.rs +++ b/datafusion/physical-expr/src/utils/guarantee.rs @@ -283,7 +283,7 @@ impl<'a> GuaranteeBuilder<'a> { ) } - /// Aggregates a new single column, multi literal term to ths builder + /// Aggregates a new single column, multi literal term to this builder /// combining with previously known guarantees if possible. /// /// # Examples diff --git a/datafusion/physical-plan/src/aggregates/group_values/row.rs b/datafusion/physical-plan/src/aggregates/group_values/row.rs index 96a12d7b62da..8c2a4ba5c497 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/row.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/row.rs @@ -190,7 +190,7 @@ impl GroupValues for GroupValuesRows { let groups_rows = group_values.iter().take(n); let output = self.row_converter.convert_rows(groups_rows)?; // Clear out first n group keys by copying them to a new Rows. - // TODO file some ticket in arrow-rs to make this more efficent? + // TODO file some ticket in arrow-rs to make this more efficient? let mut new_group_values = self.row_converter.empty_rows(0, 0); for row in group_values.iter().skip(n) { new_group_values.push(row); diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index 5f780f1ff801..4146dda7641d 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -75,12 +75,12 @@ pub enum AggregateMode { /// Applies the entire logical aggregation operation in a single operator, /// as opposed to Partial / Final modes which apply the logical aggregation using /// two operators. - /// This mode requires tha the input is a single partition (like Final) + /// This mode requires that the input is a single partition (like Final) Single, /// Applies the entire logical aggregation operation in a single operator, /// as opposed to Partial / Final modes which apply the logical aggregation using /// two operators. - /// This mode requires tha the input is partitioned by group key (like FinalPartitioned) + /// This mode requires that the input is partitioned by group key (like FinalPartitioned) SinglePartitioned, } @@ -733,7 +733,7 @@ impl ExecutionPlan for AggregateExec { // - once expressions will be able to compute their own stats, use it here // - case where we group by on a column for which with have the `distinct` stat // TODO stats: aggr expression: - // - aggregations somtimes also preserve invariants such as min, max... + // - aggregations sometimes also preserve invariants such as min, max... let column_statistics = Statistics::unknown_column(&self.schema()); match self.mode { AggregateMode::Final | AggregateMode::FinalPartitioned diff --git a/datafusion/physical-plan/src/analyze.rs b/datafusion/physical-plan/src/analyze.rs index b4c1e25e6191..287446328f8d 100644 --- a/datafusion/physical-plan/src/analyze.rs +++ b/datafusion/physical-plan/src/analyze.rs @@ -206,7 +206,7 @@ impl ExecutionPlan for AnalyzeExec { } } -/// Creates the ouput of AnalyzeExec as a RecordBatch +/// Creates the output of AnalyzeExec as a RecordBatch fn create_output_batch( verbose: bool, show_statistics: bool, diff --git a/datafusion/physical-plan/src/display.rs b/datafusion/physical-plan/src/display.rs index 7f4ae5797d97..0d2653c5c775 100644 --- a/datafusion/physical-plan/src/display.rs +++ b/datafusion/physical-plan/src/display.rs @@ -236,7 +236,7 @@ enum ShowMetrics { /// Do not show any metrics None, - /// Show aggregrated metrics across partition + /// Show aggregated metrics across partition Aggregated, /// Show full per-partition metrics diff --git a/datafusion/physical-plan/src/joins/cross_join.rs b/datafusion/physical-plan/src/joins/cross_join.rs index 33a9c061bf31..8304ddc7331a 100644 --- a/datafusion/physical-plan/src/joins/cross_join.rs +++ b/datafusion/physical-plan/src/joins/cross_join.rs @@ -578,7 +578,7 @@ mod tests { } #[tokio::test] - async fn test_stats_cartesian_product_with_unknwon_size() { + async fn test_stats_cartesian_product_with_unknown_size() { let left_row_count = 11; let left = Statistics { diff --git a/datafusion/physical-plan/src/joins/nested_loop_join.rs b/datafusion/physical-plan/src/joins/nested_loop_join.rs index 754e55e49650..f8ca38980850 100644 --- a/datafusion/physical-plan/src/joins/nested_loop_join.rs +++ b/datafusion/physical-plan/src/joins/nested_loop_join.rs @@ -160,7 +160,7 @@ pub struct NestedLoopJoinExec { } impl NestedLoopJoinExec { - /// Try to create a nwe [`NestedLoopJoinExec`] + /// Try to create a new [`NestedLoopJoinExec`] pub fn try_new( left: Arc, right: Arc, diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index e9124a72970a..a03e4a83fd2d 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -634,7 +634,7 @@ struct SMJStream { pub buffered: SendableRecordBatchStream, /// Current processing record batch of streamed pub streamed_batch: StreamedBatch, - /// Currrent buffered data + /// Current buffered data pub buffered_data: BufferedData, /// (used in outer join) Is current streamed row joined at least once? pub streamed_joined: bool, diff --git a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs index c23dc2032c4b..2299b7ff07f1 100644 --- a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs +++ b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs @@ -215,7 +215,7 @@ impl SymmetricHashJoinExec { let left_schema = left.schema(); let right_schema = right.schema(); - // Error out if no "on" contraints are given: + // Error out if no "on" constraints are given: if on.is_empty() { return plan_err!( "On constraints in SymmetricHashJoinExec should be non-empty" diff --git a/datafusion/physical-plan/src/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs index e3ec242ce8de..51744730a5a1 100644 --- a/datafusion/physical-plan/src/joins/utils.rs +++ b/datafusion/physical-plan/src/joins/utils.rs @@ -145,7 +145,7 @@ impl JoinHashMap { pub(crate) type JoinHashMapOffset = (usize, Option); // Macro for traversing chained values with limit. -// Early returns in case of reacing output tuples limit. +// Early returns in case of reaching output tuples limit. macro_rules! chain_traverse { ( $input_indices:ident, $match_indices:ident, $hash_values:ident, $next_chain:ident, @@ -477,7 +477,7 @@ fn offset_ordering( offset: usize, ) -> Vec { match join_type { - // In the case below, right ordering should be offseted with the left + // In the case below, right ordering should be offsetted with the left // side length, since we append the right table to the left table. JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right => ordering .iter() @@ -910,7 +910,7 @@ fn estimate_inner_join_cardinality( left_stats: Statistics, right_stats: Statistics, ) -> Option> { - // Immediatedly return if inputs considered as non-overlapping + // Immediately return if inputs considered as non-overlapping if let Some(estimation) = estimate_disjoint_inputs(&left_stats, &right_stats) { return Some(estimation); }; @@ -2419,7 +2419,7 @@ mod tests { ); assert!( absent_outer_estimation.is_none(), - "Expected \"None\" esimated SemiJoin cardinality for absent outer num_rows" + "Expected \"None\" estimated SemiJoin cardinality for absent outer num_rows" ); let absent_inner_estimation = estimate_join_cardinality( @@ -2437,7 +2437,7 @@ mod tests { &join_on, ).expect("Expected non-empty PartialJoinStatistics for SemiJoin with absent inner num_rows"); - assert_eq!(absent_inner_estimation.num_rows, 500, "Expected outer.num_rows esimated SemiJoin cardinality for absent inner num_rows"); + assert_eq!(absent_inner_estimation.num_rows, 500, "Expected outer.num_rows estimated SemiJoin cardinality for absent inner num_rows"); let absent_inner_estimation = estimate_join_cardinality( &JoinType::LeftSemi, @@ -2453,7 +2453,7 @@ mod tests { }, &join_on, ); - assert!(absent_inner_estimation.is_none(), "Expected \"None\" esimated SemiJoin cardinality for absent outer and inner num_rows"); + assert!(absent_inner_estimation.is_none(), "Expected \"None\" estimated SemiJoin cardinality for absent outer and inner num_rows"); Ok(()) } diff --git a/datafusion/physical-plan/src/limit.rs b/datafusion/physical-plan/src/limit.rs index 9c77a3d05cc2..f3dad6afabde 100644 --- a/datafusion/physical-plan/src/limit.rs +++ b/datafusion/physical-plan/src/limit.rs @@ -393,7 +393,7 @@ impl ExecutionPlan for LocalLimitExec { .. } if nr <= self.fetch => input_stats, // if the input is greater than the limit, the num_row will be greater - // than the limit because the partitions will be limited separatly + // than the limit because the partitions will be limited separately // the statistic Statistics { num_rows: Precision::Exact(nr), diff --git a/datafusion/physical-plan/src/repartition/mod.rs b/datafusion/physical-plan/src/repartition/mod.rs index e5c506403ff6..4870e9e95eb5 100644 --- a/datafusion/physical-plan/src/repartition/mod.rs +++ b/datafusion/physical-plan/src/repartition/mod.rs @@ -1345,8 +1345,8 @@ mod tests { #[tokio::test] // As the hash results might be different on different platforms or - // wiht different compilers, we will compare the same execution with - // and without droping the output stream. + // with different compilers, we will compare the same execution with + // and without dropping the output stream. async fn hash_repartition_with_dropping_output_stream() { let task_ctx = Arc::new(TaskContext::default()); let partitioning = Partitioning::Hash( @@ -1357,7 +1357,7 @@ mod tests { 2, ); - // We first collect the results without droping the output stream. + // We first collect the results without dropping the output stream. let input = Arc::new(make_barrier_exec()); let exec = RepartitionExec::try_new( Arc::clone(&input) as Arc, diff --git a/datafusion/physical-plan/src/sorts/sort.rs b/datafusion/physical-plan/src/sorts/sort.rs index 5b99f8bc7161..d576f77d9f74 100644 --- a/datafusion/physical-plan/src/sorts/sort.rs +++ b/datafusion/physical-plan/src/sorts/sort.rs @@ -602,7 +602,7 @@ pub fn sort_batch( .collect::>>()?; let indices = if is_multi_column_with_lists(&sort_columns) { - // lex_sort_to_indices doesn't support List with more than one colum + // lex_sort_to_indices doesn't support List with more than one column // https://github.com/apache/arrow-rs/issues/5454 lexsort_to_indices_multi_columns(sort_columns, fetch)? } else { @@ -802,12 +802,12 @@ impl DisplayAs for SortExec { match t { DisplayFormatType::Default | DisplayFormatType::Verbose => { let expr = PhysicalSortExpr::format_list(&self.expr); - let preserve_partioning = self.preserve_partitioning; + let preserve_partitioning = self.preserve_partitioning; match self.fetch { Some(fetch) => { - write!(f, "SortExec: TopK(fetch={fetch}), expr=[{expr}], preserve_partitioning=[{preserve_partioning}]",) + write!(f, "SortExec: TopK(fetch={fetch}), expr=[{expr}], preserve_partitioning=[{preserve_partitioning}]",) } - None => write!(f, "SortExec: expr=[{expr}], preserve_partitioning=[{preserve_partioning}]"), + None => write!(f, "SortExec: expr=[{expr}], preserve_partitioning=[{preserve_partitioning}]"), } } } diff --git a/datafusion/physical-plan/src/streaming.rs b/datafusion/physical-plan/src/streaming.rs index 5a9035c8dbfc..e10e5c9a6995 100644 --- a/datafusion/physical-plan/src/streaming.rs +++ b/datafusion/physical-plan/src/streaming.rs @@ -80,7 +80,7 @@ impl StreamingTableExec { if !schema.eq(partition_schema) { debug!( "Target schema does not match with partition schema. \ - Target_schema: {schema:?}. Partiton Schema: {partition_schema:?}" + Target_schema: {schema:?}. Partition Schema: {partition_schema:?}" ); return plan_err!("Mismatch between schema and batches"); } diff --git a/datafusion/physical-plan/src/test/exec.rs b/datafusion/physical-plan/src/test/exec.rs index ac4eb1ca9e58..cf1c0e313733 100644 --- a/datafusion/physical-plan/src/test/exec.rs +++ b/datafusion/physical-plan/src/test/exec.rs @@ -725,7 +725,7 @@ pub struct PanicExec { schema: SchemaRef, /// Number of output partitions. Each partition will produce this - /// many empty output record batches prior to panicing + /// many empty output record batches prior to panicking batches_until_panics: Vec, cache: PlanProperties, } diff --git a/datafusion/physical-plan/src/topk/mod.rs b/datafusion/physical-plan/src/topk/mod.rs index 5366a5707696..d3f1a4fd96ca 100644 --- a/datafusion/physical-plan/src/topk/mod.rs +++ b/datafusion/physical-plan/src/topk/mod.rs @@ -94,7 +94,7 @@ pub struct TopK { impl TopK { /// Create a new [`TopK`] that stores the top `k` values, as /// defined by the sort expressions in `expr`. - // TOOD: make a builder or some other nicer API to avoid the + // TODO: make a builder or some other nicer API to avoid the // clippy warning #[allow(clippy::too_many_arguments)] pub fn try_new( @@ -258,7 +258,7 @@ impl TopKMetrics { /// Using the `Row` format handles things such as ascending vs /// descending and nulls first vs nulls last. struct TopKHeap { - /// The maximum number of elemenents to store in this heap. + /// The maximum number of elements to store in this heap. k: usize, /// The target number of rows for output batches batch_size: usize, @@ -421,7 +421,7 @@ impl TopKHeap { let num_rows = self.inner.len(); let (new_batch, mut topk_rows) = self.emit_with_state()?; - // clear all old entires in store (this invalidates all + // clear all old entries in store (this invalidates all // store_ids in `inner`) self.store.clear(); @@ -453,7 +453,7 @@ impl TopKHeap { /// Represents one of the top K rows held in this heap. Orders /// according to memcmp of row (e.g. the arrow Row format, but could -/// also be primtive values) +/// also be primitive values) /// /// Reuses allocations to minimize runtime overhead of creating new Vecs #[derive(Debug, PartialEq)] diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs index 7f794556a241..5eca7af19d16 100644 --- a/datafusion/physical-plan/src/windows/mod.rs +++ b/datafusion/physical-plan/src/windows/mod.rs @@ -805,7 +805,7 @@ mod tests { } #[tokio::test] - async fn test_satisfiy_nullable() -> Result<()> { + async fn test_satisfy_nullable() -> Result<()> { let schema = create_test_schema()?; let params = vec![ ((true, true), (false, false), false), diff --git a/datafusion/physical-plan/src/windows/window_agg_exec.rs b/datafusion/physical-plan/src/windows/window_agg_exec.rs index b6330f65e0b7..1d5c6061a0f9 100644 --- a/datafusion/physical-plan/src/windows/window_agg_exec.rs +++ b/datafusion/physical-plan/src/windows/window_agg_exec.rs @@ -126,7 +126,7 @@ impl WindowAggExec { // Get output partitioning: // Because we can have repartitioning using the partition keys this - // would be either 1 or more than 1 depending on the presense of repartitioning. + // would be either 1 or more than 1 depending on the presence of repartitioning. let output_partitioning = input.output_partitioning().clone(); // Determine execution mode: diff --git a/datafusion/physical-plan/src/work_table.rs b/datafusion/physical-plan/src/work_table.rs index 5f3cf6e2aee8..ba95640a87c7 100644 --- a/datafusion/physical-plan/src/work_table.rs +++ b/datafusion/physical-plan/src/work_table.rs @@ -225,7 +225,7 @@ mod tests { #[test] fn test_work_table() { let work_table = WorkTable::new(); - // cann't take from empty work_table + // can't take from empty work_table assert!(work_table.take().is_err()); let pool = Arc::new(UnboundedMemoryPool::default()) as _; From 4dd8532e6cd52c480a29a7851c6676a69f261545 Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Fri, 19 Jul 2024 04:12:17 +0800 Subject: [PATCH 06/37] rm clone (#11532) Signed-off-by: jayzhan211 --- datafusion/optimizer/src/push_down_filter.rs | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 33b2883d6ed8..a22f2e83e211 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -1020,7 +1020,7 @@ impl OptimizerRule for PushDownFilter { /// ``` fn rewrite_projection( predicates: Vec, - projection: Projection, + mut projection: Projection, ) -> Result<(Transformed, Option)> { // A projection is filter-commutable if it do not contain volatile predicates or contain volatile // predicates that are not used in the filter. However, we should re-writes all predicate expressions. @@ -1053,11 +1053,13 @@ fn rewrite_projection( // E.g. in `Filter: b\n Projection: a > 1 as b`, we can swap them, but the filter must be "a > 1" let new_filter = LogicalPlan::Filter(Filter::try_new( replace_cols_by_name(expr, &non_volatile_map)?, - Arc::clone(&projection.input), + std::mem::take(&mut projection.input), )?); + projection.input = Arc::new(new_filter); + Ok(( - insert_below(LogicalPlan::Projection(projection), new_filter)?, + Transformed::yes(LogicalPlan::Projection(projection)), conjunction(keep_predicates), )) } From 723a595528e945c0ebc59a62ece2e24e90627764 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Thu, 18 Jul 2024 16:32:09 -0400 Subject: [PATCH 07/37] Minor: avoid a clone in type coercion (#11530) * Minor: avoid a clone in type coercion * Fix test --- .../optimizer/src/analyzer/type_coercion.rs | 18 ++++++++---------- datafusion/sqllogictest/test_files/misc.slt | 4 ++++ 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 337492d1a55b..50fb1b8193ce 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -84,7 +84,7 @@ impl AnalyzerRule for TypeCoercion { /// Assumes that children have already been optimized fn analyze_internal( external_schema: &DFSchema, - mut plan: LogicalPlan, + plan: LogicalPlan, ) -> Result> { // get schema representing all available input fields. This is used for data type // resolution only, so order does not matter here @@ -103,15 +103,13 @@ fn analyze_internal( // select t2.c2 from t1 where t1.c1 in (select t2.c1 from t2 where t2.c2=t1.c3) schema.merge(external_schema); - if let LogicalPlan::Filter(filter) = &mut plan { - if let Ok(new_predicate) = filter - .predicate - .clone() - .cast_to(&DataType::Boolean, filter.input.schema()) - { - filter.predicate = new_predicate; - } - } + // Coerce filter predicates to boolean (handles `WHERE NULL`) + let plan = if let LogicalPlan::Filter(mut filter) = plan { + filter.predicate = filter.predicate.cast_to(&DataType::Boolean, &schema)?; + LogicalPlan::Filter(filter) + } else { + plan + }; let mut expr_rewrite = TypeCoercionRewriter::new(&schema); diff --git a/datafusion/sqllogictest/test_files/misc.slt b/datafusion/sqllogictest/test_files/misc.slt index 9f4710eb9bcc..9bd3023b56f7 100644 --- a/datafusion/sqllogictest/test_files/misc.slt +++ b/datafusion/sqllogictest/test_files/misc.slt @@ -30,6 +30,10 @@ query I select 1 where NULL ---- +# Where clause does not accept non boolean and has nice error message +query error Cannot create filter with non\-boolean predicate 'Utf8\("foo"\)' returning Utf8 +select 1 where 'foo' + query I select 1 where NULL and 1 = 1 ---- From 12d82c427d6c37f7884a508707ccd3058a446908 Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Fri, 19 Jul 2024 11:59:01 +0800 Subject: [PATCH 08/37] Move array `ArrayAgg` to a `UserDefinedAggregate` (#11448) * Add input_nullable to UDAF args StateField/AccumulatorArgs This follows how it done for input_type and only provide a single value. But might need to be changed into a Vec in the future. This is need when we are moving `arrag_agg` to udaf where one of the states nullability will depend on the nullability of the input. * Make ArragAgg (not ordered or distinct) into a UDAF * Add roundtrip_expr_api test case * Address PR comments * Propegate input nullability for aggregates * Remove from accumulator args * first draft Signed-off-by: jayzhan211 * cleanup Signed-off-by: jayzhan211 * fix test Signed-off-by: jayzhan211 * distinct Signed-off-by: jayzhan211 * fix Signed-off-by: jayzhan211 * address comment Signed-off-by: jayzhan211 --------- Signed-off-by: jayzhan211 Co-authored-by: Emil Ejbyfeldt --- datafusion/core/src/dataframe/mod.rs | 6 +- datafusion/core/src/physical_planner.rs | 28 ++ datafusion/core/tests/dataframe/mod.rs | 6 +- datafusion/expr/src/expr_fn.rs | 12 - .../functions-aggregate/src/array_agg.rs | 261 +++++++++++ datafusion/functions-aggregate/src/lib.rs | 8 +- datafusion/functions-array/src/planner.rs | 12 +- .../physical-expr/src/aggregate/array_agg.rs | 185 -------- .../src/aggregate/array_agg_distinct.rs | 433 ------------------ .../physical-expr/src/aggregate/build_in.rs | 80 +--- datafusion/physical-expr/src/aggregate/mod.rs | 2 - .../physical-expr/src/expressions/mod.rs | 159 ------- .../src/aggregates/no_grouping.rs | 1 + .../proto/src/physical_plan/to_proto.rs | 21 +- .../tests/cases/roundtrip_logical_plan.rs | 4 +- .../sqllogictest/test_files/aggregate.slt | 2 +- 16 files changed, 328 insertions(+), 892 deletions(-) create mode 100644 datafusion/functions-aggregate/src/array_agg.rs delete mode 100644 datafusion/physical-expr/src/aggregate/array_agg.rs delete mode 100644 datafusion/physical-expr/src/aggregate/array_agg_distinct.rs diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index c55b7c752765..fb28b5c1ab47 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -1696,10 +1696,10 @@ mod tests { use datafusion_common::{Constraint, Constraints, ScalarValue}; use datafusion_common_runtime::SpawnedTask; use datafusion_expr::{ - array_agg, cast, create_udf, expr, lit, BuiltInWindowFunction, - ScalarFunctionImplementation, Volatility, WindowFrame, WindowFunctionDefinition, + cast, create_udf, expr, lit, BuiltInWindowFunction, ScalarFunctionImplementation, + Volatility, WindowFrame, WindowFunctionDefinition, }; - use datafusion_functions_aggregate::expr_fn::count_distinct; + use datafusion_functions_aggregate::expr_fn::{array_agg, count_distinct}; use datafusion_physical_expr::expressions::Column; use datafusion_physical_plan::{get_plan_string, ExecutionPlanProperties}; diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 0accf9d83516..97533cd5276a 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -1839,7 +1839,34 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( .unwrap_or(sqlparser::ast::NullTreatment::RespectNulls) == NullTreatment::IgnoreNulls; + // TODO: Remove this after array_agg are all udafs let (agg_expr, filter, order_by) = match func_def { + AggregateFunctionDefinition::UDF(udf) + if udf.name() == "ARRAY_AGG" && order_by.is_some() => + { + // not yet support UDAF, fallback to builtin + let physical_sort_exprs = match order_by { + Some(exprs) => Some(create_physical_sort_exprs( + exprs, + logical_input_schema, + execution_props, + )?), + None => None, + }; + let ordering_reqs: Vec = + physical_sort_exprs.clone().unwrap_or(vec![]); + let fun = aggregates::AggregateFunction::ArrayAgg; + let agg_expr = aggregates::create_aggregate_expr( + &fun, + *distinct, + &physical_args, + &ordering_reqs, + physical_input_schema, + name, + ignore_nulls, + )?; + (agg_expr, filter, physical_sort_exprs) + } AggregateFunctionDefinition::BuiltIn(fun) => { let physical_sort_exprs = match order_by { Some(exprs) => Some(create_physical_sort_exprs( @@ -1888,6 +1915,7 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( (agg_expr, filter, physical_sort_exprs) } }; + Ok((agg_expr, filter, order_by)) } other => internal_err!("Invalid aggregate expression '{other:?}'"), diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index 9f7bd5227e34..d68b80691917 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -54,11 +54,11 @@ use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_expr::expr::{GroupingSet, Sort}; use datafusion_expr::var_provider::{VarProvider, VarType}; use datafusion_expr::{ - array_agg, cast, col, exists, expr, in_subquery, lit, max, out_ref_col, placeholder, + cast, col, exists, expr, in_subquery, lit, max, out_ref_col, placeholder, scalar_subquery, when, wildcard, Expr, ExprSchemable, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, }; -use datafusion_functions_aggregate::expr_fn::{avg, count, sum}; +use datafusion_functions_aggregate::expr_fn::{array_agg, avg, count, sum}; #[tokio::test] async fn test_count_wildcard_on_sort() -> Result<()> { @@ -1389,7 +1389,7 @@ async fn unnest_with_redundant_columns() -> Result<()> { let expected = vec![ "Projection: shapes.shape_id [shape_id:UInt32]", " Unnest: lists[shape_id2] structs[] [shape_id:UInt32, shape_id2:UInt32;N]", - " Aggregate: groupBy=[[shapes.shape_id]], aggr=[[ARRAY_AGG(shapes.shape_id) AS shape_id2]] [shape_id:UInt32, shape_id2:List(Field { name: \"item\", data_type: UInt32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} });N]", + " Aggregate: groupBy=[[shapes.shape_id]], aggr=[[ARRAY_AGG(shapes.shape_id) AS shape_id2]] [shape_id:UInt32, shape_id2:List(Field { name: \"item\", data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} });N]", " TableScan: shapes projection=[shape_id] [shape_id:UInt32]", ]; diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 8b0213fd52fd..9187e8352205 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -171,18 +171,6 @@ pub fn max(expr: Expr) -> Expr { )) } -/// Create an expression to represent the array_agg() aggregate function -pub fn array_agg(expr: Expr) -> Expr { - Expr::AggregateFunction(AggregateFunction::new( - aggregate_function::AggregateFunction::ArrayAgg, - vec![expr], - false, - None, - None, - None, - )) -} - /// Return a new expression with bitwise AND pub fn bitwise_and(left: Expr, right: Expr) -> Expr { Expr::BinaryExpr(BinaryExpr::new( diff --git a/datafusion/functions-aggregate/src/array_agg.rs b/datafusion/functions-aggregate/src/array_agg.rs new file mode 100644 index 000000000000..9ad453d7a4b2 --- /dev/null +++ b/datafusion/functions-aggregate/src/array_agg.rs @@ -0,0 +1,261 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! `ARRAY_AGG` aggregate implementation: [`ArrayAgg`] + +use arrow::array::{Array, ArrayRef, AsArray}; +use arrow::datatypes::DataType; +use arrow_schema::Field; + +use datafusion_common::cast::as_list_array; +use datafusion_common::utils::array_into_list_array_nullable; +use datafusion_common::ScalarValue; +use datafusion_common::{internal_err, Result}; +use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; +use datafusion_expr::utils::format_state_name; +use datafusion_expr::AggregateUDFImpl; +use datafusion_expr::{Accumulator, Signature, Volatility}; +use std::collections::HashSet; +use std::sync::Arc; + +make_udaf_expr_and_func!( + ArrayAgg, + array_agg, + expression, + "input values, including nulls, concatenated into an array", + array_agg_udaf +); + +#[derive(Debug)] +/// ARRAY_AGG aggregate expression +pub struct ArrayAgg { + signature: Signature, + alias: Vec, +} + +impl Default for ArrayAgg { + fn default() -> Self { + Self { + signature: Signature::any(1, Volatility::Immutable), + alias: vec!["array_agg".to_string()], + } + } +} + +impl AggregateUDFImpl for ArrayAgg { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + // TODO: change name to lowercase + fn name(&self) -> &str { + "ARRAY_AGG" + } + + fn aliases(&self) -> &[String] { + &self.alias + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(DataType::List(Arc::new(Field::new( + "item", + arg_types[0].clone(), + true, + )))) + } + + fn state_fields(&self, args: StateFieldsArgs) -> Result> { + if args.is_distinct { + return Ok(vec![Field::new_list( + format_state_name(args.name, "distinct_array_agg"), + Field::new("item", args.input_type.clone(), true), + true, + )]); + } + + Ok(vec![Field::new_list( + format_state_name(args.name, "array_agg"), + Field::new("item", args.input_type.clone(), true), + true, + )]) + } + + fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { + if acc_args.is_distinct { + return Ok(Box::new(DistinctArrayAggAccumulator::try_new( + acc_args.input_type, + )?)); + } + + Ok(Box::new(ArrayAggAccumulator::try_new(acc_args.input_type)?)) + } +} + +#[derive(Debug)] +pub struct ArrayAggAccumulator { + values: Vec, + datatype: DataType, +} + +impl ArrayAggAccumulator { + /// new array_agg accumulator based on given item data type + pub fn try_new(datatype: &DataType) -> Result { + Ok(Self { + values: vec![], + datatype: datatype.clone(), + }) + } +} + +impl Accumulator for ArrayAggAccumulator { + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + // Append value like Int64Array(1,2,3) + if values.is_empty() { + return Ok(()); + } + + if values.len() != 1 { + return internal_err!("expects single batch"); + } + + let val = Arc::clone(&values[0]); + if val.len() > 0 { + self.values.push(val); + } + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + // Append value like ListArray(Int64Array(1,2,3), Int64Array(4,5,6)) + if states.is_empty() { + return Ok(()); + } + + if states.len() != 1 { + return internal_err!("expects single state"); + } + + let list_arr = as_list_array(&states[0])?; + for arr in list_arr.iter().flatten() { + self.values.push(arr); + } + Ok(()) + } + + fn state(&mut self) -> Result> { + Ok(vec![self.evaluate()?]) + } + + fn evaluate(&mut self) -> Result { + // Transform Vec to ListArr + let element_arrays: Vec<&dyn Array> = + self.values.iter().map(|a| a.as_ref()).collect(); + + if element_arrays.is_empty() { + return Ok(ScalarValue::new_null_list(self.datatype.clone(), true, 1)); + } + + let concated_array = arrow::compute::concat(&element_arrays)?; + let list_array = array_into_list_array_nullable(concated_array); + + Ok(ScalarValue::List(Arc::new(list_array))) + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) + + (std::mem::size_of::() * self.values.capacity()) + + self + .values + .iter() + .map(|arr| arr.get_array_memory_size()) + .sum::() + + self.datatype.size() + - std::mem::size_of_val(&self.datatype) + } +} + +#[derive(Debug)] +struct DistinctArrayAggAccumulator { + values: HashSet, + datatype: DataType, +} + +impl DistinctArrayAggAccumulator { + pub fn try_new(datatype: &DataType) -> Result { + Ok(Self { + values: HashSet::new(), + datatype: datatype.clone(), + }) + } +} + +impl Accumulator for DistinctArrayAggAccumulator { + fn state(&mut self) -> Result> { + Ok(vec![self.evaluate()?]) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + if values.len() != 1 { + return internal_err!("expects single batch"); + } + + let array = &values[0]; + + for i in 0..array.len() { + let scalar = ScalarValue::try_from_array(&array, i)?; + self.values.insert(scalar); + } + + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + if states.is_empty() { + return Ok(()); + } + + if states.len() != 1 { + return internal_err!("expects single state"); + } + + states[0] + .as_list::() + .iter() + .flatten() + .try_for_each(|val| self.update_batch(&[val])) + } + + fn evaluate(&mut self) -> Result { + let values: Vec = self.values.iter().cloned().collect(); + if values.is_empty() { + return Ok(ScalarValue::new_null_list(self.datatype.clone(), true, 1)); + } + let arr = ScalarValue::new_list(&values, &self.datatype, true); + Ok(ScalarValue::List(arr)) + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) + ScalarValue::size_of_hashset(&self.values) + - std::mem::size_of_val(&self.values) + + self.datatype.size() + - std::mem::size_of_val(&self.datatype) + } +} diff --git a/datafusion/functions-aggregate/src/lib.rs b/datafusion/functions-aggregate/src/lib.rs index a3808a08b007..b39b1955bb07 100644 --- a/datafusion/functions-aggregate/src/lib.rs +++ b/datafusion/functions-aggregate/src/lib.rs @@ -58,6 +58,7 @@ pub mod macros; pub mod approx_distinct; +pub mod array_agg; pub mod correlation; pub mod count; pub mod covariance; @@ -93,6 +94,7 @@ pub mod expr_fn { pub use super::approx_median::approx_median; pub use super::approx_percentile_cont::approx_percentile_cont; pub use super::approx_percentile_cont_with_weight::approx_percentile_cont_with_weight; + pub use super::array_agg::array_agg; pub use super::average::avg; pub use super::bit_and_or_xor::bit_and; pub use super::bit_and_or_xor::bit_or; @@ -128,6 +130,7 @@ pub mod expr_fn { /// Returns all default aggregate functions pub fn all_default_aggregate_functions() -> Vec> { vec![ + array_agg::array_agg_udaf(), first_last::first_value_udaf(), first_last::last_value_udaf(), covariance::covar_samp_udaf(), @@ -191,8 +194,9 @@ mod tests { let mut names = HashSet::new(); for func in all_default_aggregate_functions() { // TODO: remove this - // These functions are in intermediate migration state, skip them - if func.name().to_lowercase() == "count" { + // These functions are in intermidiate migration state, skip them + let name_lower_case = func.name().to_lowercase(); + if name_lower_case == "count" || name_lower_case == "array_agg" { continue; } assert!( diff --git a/datafusion/functions-array/src/planner.rs b/datafusion/functions-array/src/planner.rs index cfbe99b4b7fd..dfb620f84f3a 100644 --- a/datafusion/functions-array/src/planner.rs +++ b/datafusion/functions-array/src/planner.rs @@ -19,8 +19,9 @@ use datafusion_common::{utils::list_ndims, DFSchema, Result}; use datafusion_expr::{ + expr::AggregateFunctionDefinition, planner::{ExprPlanner, PlannerResult, RawBinaryExpr, RawFieldAccessExpr}, - sqlparser, AggregateFunction, Expr, ExprSchemable, GetFieldAccess, + sqlparser, Expr, ExprSchemable, GetFieldAccess, }; use datafusion_functions::expr_fn::get_field; use datafusion_functions_aggregate::nth_value::nth_value_udaf; @@ -153,8 +154,9 @@ impl ExprPlanner for FieldAccessPlanner { } fn is_array_agg(agg_func: &datafusion_expr::expr::AggregateFunction) -> bool { - agg_func.func_def - == datafusion_expr::expr::AggregateFunctionDefinition::BuiltIn( - AggregateFunction::ArrayAgg, - ) + if let AggregateFunctionDefinition::UDF(udf) = &agg_func.func_def { + return udf.name() == "ARRAY_AGG"; + } + + false } diff --git a/datafusion/physical-expr/src/aggregate/array_agg.rs b/datafusion/physical-expr/src/aggregate/array_agg.rs deleted file mode 100644 index 0d5ed730e283..000000000000 --- a/datafusion/physical-expr/src/aggregate/array_agg.rs +++ /dev/null @@ -1,185 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Defines physical expressions that can evaluated at runtime during query execution - -use crate::aggregate::utils::down_cast_any_ref; -use crate::expressions::format_state_name; -use crate::{AggregateExpr, PhysicalExpr}; -use arrow::array::ArrayRef; -use arrow::datatypes::{DataType, Field}; -use arrow_array::Array; -use datafusion_common::cast::as_list_array; -use datafusion_common::utils::array_into_list_array_nullable; -use datafusion_common::Result; -use datafusion_common::ScalarValue; -use datafusion_expr::Accumulator; -use std::any::Any; -use std::sync::Arc; - -/// ARRAY_AGG aggregate expression -#[derive(Debug)] -pub struct ArrayAgg { - /// Column name - name: String, - /// The DataType for the input expression - input_data_type: DataType, - /// The input expression - expr: Arc, -} - -impl ArrayAgg { - /// Create a new ArrayAgg aggregate function - pub fn new( - expr: Arc, - name: impl Into, - data_type: DataType, - ) -> Self { - Self { - name: name.into(), - input_data_type: data_type, - expr, - } - } -} - -impl AggregateExpr for ArrayAgg { - fn as_any(&self) -> &dyn Any { - self - } - - fn field(&self) -> Result { - Ok(Field::new_list( - &self.name, - // This should be the same as return type of AggregateFunction::ArrayAgg - Field::new("item", self.input_data_type.clone(), true), - true, - )) - } - - fn create_accumulator(&self) -> Result> { - Ok(Box::new(ArrayAggAccumulator::try_new( - &self.input_data_type, - )?)) - } - - fn state_fields(&self) -> Result> { - Ok(vec![Field::new_list( - format_state_name(&self.name, "array_agg"), - Field::new("item", self.input_data_type.clone(), true), - true, - )]) - } - - fn expressions(&self) -> Vec> { - vec![Arc::clone(&self.expr)] - } - - fn name(&self) -> &str { - &self.name - } -} - -impl PartialEq for ArrayAgg { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.name == x.name - && self.input_data_type == x.input_data_type - && self.expr.eq(&x.expr) - }) - .unwrap_or(false) - } -} - -#[derive(Debug)] -pub(crate) struct ArrayAggAccumulator { - values: Vec, - datatype: DataType, -} - -impl ArrayAggAccumulator { - /// new array_agg accumulator based on given item data type - pub fn try_new(datatype: &DataType) -> Result { - Ok(Self { - values: vec![], - datatype: datatype.clone(), - }) - } -} - -impl Accumulator for ArrayAggAccumulator { - // Append value like Int64Array(1,2,3) - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - if values.is_empty() { - return Ok(()); - } - assert!(values.len() == 1, "array_agg can only take 1 param!"); - - let val = Arc::clone(&values[0]); - if val.len() > 0 { - self.values.push(val); - } - Ok(()) - } - - // Append value like ListArray(Int64Array(1,2,3), Int64Array(4,5,6)) - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - if states.is_empty() { - return Ok(()); - } - assert!(states.len() == 1, "array_agg states must be singleton!"); - - let list_arr = as_list_array(&states[0])?; - for arr in list_arr.iter().flatten() { - self.values.push(arr); - } - Ok(()) - } - - fn state(&mut self) -> Result> { - Ok(vec![self.evaluate()?]) - } - - fn evaluate(&mut self) -> Result { - // Transform Vec to ListArr - let element_arrays: Vec<&dyn Array> = - self.values.iter().map(|a| a.as_ref()).collect(); - - if element_arrays.is_empty() { - return Ok(ScalarValue::new_null_list(self.datatype.clone(), true, 1)); - } - - let concated_array = arrow::compute::concat(&element_arrays)?; - let list_array = array_into_list_array_nullable(concated_array); - - Ok(ScalarValue::List(Arc::new(list_array))) - } - - fn size(&self) -> usize { - std::mem::size_of_val(self) - + (std::mem::size_of::() * self.values.capacity()) - + self - .values - .iter() - .map(|arr| arr.get_array_memory_size()) - .sum::() - + self.datatype.size() - - std::mem::size_of_val(&self.datatype) - } -} diff --git a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs deleted file mode 100644 index eca6e4ce4f65..000000000000 --- a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs +++ /dev/null @@ -1,433 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Implementations for DISTINCT expressions, e.g. `COUNT(DISTINCT c)` - -use std::any::Any; -use std::collections::HashSet; -use std::fmt::Debug; -use std::sync::Arc; - -use arrow::array::ArrayRef; -use arrow::datatypes::{DataType, Field}; -use arrow_array::cast::AsArray; - -use crate::aggregate::utils::down_cast_any_ref; -use crate::expressions::format_state_name; -use crate::{AggregateExpr, PhysicalExpr}; - -use datafusion_common::{Result, ScalarValue}; -use datafusion_expr::Accumulator; - -/// Expression for a ARRAY_AGG(DISTINCT) aggregation. -#[derive(Debug)] -pub struct DistinctArrayAgg { - /// Column name - name: String, - /// The DataType for the input expression - input_data_type: DataType, - /// The input expression - expr: Arc, -} - -impl DistinctArrayAgg { - /// Create a new DistinctArrayAgg aggregate function - pub fn new( - expr: Arc, - name: impl Into, - input_data_type: DataType, - ) -> Self { - let name = name.into(); - Self { - name, - input_data_type, - expr, - } - } -} - -impl AggregateExpr for DistinctArrayAgg { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn field(&self) -> Result { - Ok(Field::new_list( - &self.name, - // This should be the same as return type of AggregateFunction::ArrayAgg - Field::new("item", self.input_data_type.clone(), true), - true, - )) - } - - fn create_accumulator(&self) -> Result> { - Ok(Box::new(DistinctArrayAggAccumulator::try_new( - &self.input_data_type, - )?)) - } - - fn state_fields(&self) -> Result> { - Ok(vec![Field::new_list( - format_state_name(&self.name, "distinct_array_agg"), - Field::new("item", self.input_data_type.clone(), true), - true, - )]) - } - - fn expressions(&self) -> Vec> { - vec![Arc::clone(&self.expr)] - } - - fn name(&self) -> &str { - &self.name - } -} - -impl PartialEq for DistinctArrayAgg { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.name == x.name - && self.input_data_type == x.input_data_type - && self.expr.eq(&x.expr) - }) - .unwrap_or(false) - } -} - -#[derive(Debug)] -struct DistinctArrayAggAccumulator { - values: HashSet, - datatype: DataType, -} - -impl DistinctArrayAggAccumulator { - pub fn try_new(datatype: &DataType) -> Result { - Ok(Self { - values: HashSet::new(), - datatype: datatype.clone(), - }) - } -} - -impl Accumulator for DistinctArrayAggAccumulator { - fn state(&mut self) -> Result> { - Ok(vec![self.evaluate()?]) - } - - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - assert_eq!(values.len(), 1, "batch input should only include 1 column!"); - - let array = &values[0]; - - for i in 0..array.len() { - let scalar = ScalarValue::try_from_array(&array, i)?; - self.values.insert(scalar); - } - - Ok(()) - } - - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - if states.is_empty() { - return Ok(()); - } - - states[0] - .as_list::() - .iter() - .flatten() - .try_for_each(|val| self.update_batch(&[val])) - } - - fn evaluate(&mut self) -> Result { - let values: Vec = self.values.iter().cloned().collect(); - if values.is_empty() { - return Ok(ScalarValue::new_null_list(self.datatype.clone(), true, 1)); - } - let arr = ScalarValue::new_list(&values, &self.datatype, true); - Ok(ScalarValue::List(arr)) - } - - fn size(&self) -> usize { - std::mem::size_of_val(self) + ScalarValue::size_of_hashset(&self.values) - - std::mem::size_of_val(&self.values) - + self.datatype.size() - - std::mem::size_of_val(&self.datatype) - } -} - -#[cfg(test)] -mod tests { - - use super::*; - use crate::expressions::col; - use crate::expressions::tests::aggregate; - use arrow::array::Int32Array; - use arrow::datatypes::Schema; - use arrow::record_batch::RecordBatch; - use arrow_array::types::Int32Type; - use arrow_array::Array; - use arrow_array::ListArray; - use arrow_buffer::OffsetBuffer; - use datafusion_common::internal_err; - - // arrow::compute::sort can't sort nested ListArray directly, so we compare the scalar values pair-wise. - fn compare_list_contents( - expected: Vec, - actual: ScalarValue, - ) -> Result<()> { - let array = actual.to_array()?; - let list_array = array.as_list::(); - let inner_array = list_array.value(0); - let mut actual_scalars = vec![]; - for index in 0..inner_array.len() { - let sv = ScalarValue::try_from_array(&inner_array, index)?; - actual_scalars.push(sv); - } - - if actual_scalars.len() != expected.len() { - return internal_err!( - "Expected and actual list lengths differ: expected={}, actual={}", - expected.len(), - actual_scalars.len() - ); - } - - let mut seen = vec![false; expected.len()]; - for v in expected { - let mut found = false; - for (i, sv) in actual_scalars.iter().enumerate() { - if sv == &v { - seen[i] = true; - found = true; - break; - } - } - if !found { - return internal_err!( - "Expected value {:?} not found in actual values {:?}", - v, - actual_scalars - ); - } - } - - Ok(()) - } - - fn check_distinct_array_agg( - input: ArrayRef, - expected: Vec, - datatype: DataType, - ) -> Result<()> { - let schema = Schema::new(vec![Field::new("a", datatype.clone(), false)]); - let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![input])?; - - let agg = Arc::new(DistinctArrayAgg::new( - col("a", &schema)?, - "bla".to_string(), - datatype, - )); - let actual = aggregate(&batch, agg)?; - compare_list_contents(expected, actual) - } - - fn check_merge_distinct_array_agg( - input1: ArrayRef, - input2: ArrayRef, - expected: Vec, - datatype: DataType, - ) -> Result<()> { - let schema = Schema::new(vec![Field::new("a", datatype.clone(), false)]); - let agg = Arc::new(DistinctArrayAgg::new( - col("a", &schema)?, - "bla".to_string(), - datatype, - )); - - let mut accum1 = agg.create_accumulator()?; - let mut accum2 = agg.create_accumulator()?; - - accum1.update_batch(&[input1])?; - accum2.update_batch(&[input2])?; - - let array = accum2.state()?[0].raw_data()?; - accum1.merge_batch(&[array])?; - - let actual = accum1.evaluate()?; - compare_list_contents(expected, actual) - } - - #[test] - fn distinct_array_agg_i32() -> Result<()> { - let col: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 4, 5, 2])); - - let expected = vec![ - ScalarValue::Int32(Some(1)), - ScalarValue::Int32(Some(2)), - ScalarValue::Int32(Some(4)), - ScalarValue::Int32(Some(5)), - ScalarValue::Int32(Some(7)), - ]; - - check_distinct_array_agg(col, expected, DataType::Int32) - } - - #[test] - fn merge_distinct_array_agg_i32() -> Result<()> { - let col1: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 4, 5, 2])); - let col2: ArrayRef = Arc::new(Int32Array::from(vec![1, 3, 7, 8, 4])); - - let expected = vec![ - ScalarValue::Int32(Some(1)), - ScalarValue::Int32(Some(2)), - ScalarValue::Int32(Some(3)), - ScalarValue::Int32(Some(4)), - ScalarValue::Int32(Some(5)), - ScalarValue::Int32(Some(7)), - ScalarValue::Int32(Some(8)), - ]; - - check_merge_distinct_array_agg(col1, col2, expected, DataType::Int32) - } - - #[test] - fn distinct_array_agg_nested() -> Result<()> { - // [[1, 2, 3], [4, 5]] - let a1 = ListArray::from_iter_primitive::(vec![Some(vec![ - Some(1), - Some(2), - Some(3), - ])]); - let a2 = ListArray::from_iter_primitive::(vec![Some(vec![ - Some(4), - Some(5), - ])]); - let l1 = ListArray::new( - Arc::new(Field::new("item", a1.data_type().to_owned(), true)), - OffsetBuffer::from_lengths([2]), - arrow::compute::concat(&[&a1, &a2]).unwrap(), - None, - ); - - // [[6], [7, 8]] - let a1 = - ListArray::from_iter_primitive::(vec![Some(vec![Some(6)])]); - let a2 = ListArray::from_iter_primitive::(vec![Some(vec![ - Some(7), - Some(8), - ])]); - let l2 = ListArray::new( - Arc::new(Field::new("item", a1.data_type().to_owned(), true)), - OffsetBuffer::from_lengths([2]), - arrow::compute::concat(&[&a1, &a2]).unwrap(), - None, - ); - - // [[9]] - let a1 = - ListArray::from_iter_primitive::(vec![Some(vec![Some(9)])]); - let l3 = ListArray::new( - Arc::new(Field::new("item", a1.data_type().to_owned(), true)), - OffsetBuffer::from_lengths([1]), - Arc::new(a1), - None, - ); - - let l1 = ScalarValue::List(Arc::new(l1)); - let l2 = ScalarValue::List(Arc::new(l2)); - let l3 = ScalarValue::List(Arc::new(l3)); - - // Duplicate l1 and l3 in the input array and check that it is deduped in the output. - let array = ScalarValue::iter_to_array(vec![ - l1.clone(), - l2.clone(), - l3.clone(), - l3.clone(), - l1.clone(), - ]) - .unwrap(); - let expected = vec![l1, l2, l3]; - - check_distinct_array_agg( - array, - expected, - DataType::List(Arc::new(Field::new_list( - "item", - Field::new("item", DataType::Int32, true), - true, - ))), - ) - } - - #[test] - fn merge_distinct_array_agg_nested() -> Result<()> { - // [[1, 2], [3, 4]] - let a1 = ListArray::from_iter_primitive::(vec![Some(vec![ - Some(1), - Some(2), - ])]); - let a2 = ListArray::from_iter_primitive::(vec![Some(vec![ - Some(3), - Some(4), - ])]); - let l1 = ListArray::new( - Arc::new(Field::new("item", a1.data_type().to_owned(), true)), - OffsetBuffer::from_lengths([2]), - arrow::compute::concat(&[&a1, &a2]).unwrap(), - None, - ); - - let a1 = - ListArray::from_iter_primitive::(vec![Some(vec![Some(5)])]); - let l2 = ListArray::new( - Arc::new(Field::new("item", a1.data_type().to_owned(), true)), - OffsetBuffer::from_lengths([1]), - Arc::new(a1), - None, - ); - - // [[6, 7], [8]] - let a1 = ListArray::from_iter_primitive::(vec![Some(vec![ - Some(6), - Some(7), - ])]); - let a2 = - ListArray::from_iter_primitive::(vec![Some(vec![Some(8)])]); - let l3 = ListArray::new( - Arc::new(Field::new("item", a1.data_type().to_owned(), true)), - OffsetBuffer::from_lengths([2]), - arrow::compute::concat(&[&a1, &a2]).unwrap(), - None, - ); - - let l1 = ScalarValue::List(Arc::new(l1)); - let l2 = ScalarValue::List(Arc::new(l2)); - let l3 = ScalarValue::List(Arc::new(l3)); - - // Duplicate l1 in the input array and check that it is deduped in the output. - let input1 = ScalarValue::iter_to_array(vec![l1.clone(), l2.clone()]).unwrap(); - let input2 = ScalarValue::iter_to_array(vec![l1.clone(), l3.clone()]).unwrap(); - - let expected = vec![l1, l2, l3]; - - check_merge_distinct_array_agg(input1, input2, expected, DataType::Int32) - } -} diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs index ef21b3d0f788..9c270561f37d 100644 --- a/datafusion/physical-expr/src/aggregate/build_in.rs +++ b/datafusion/physical-expr/src/aggregate/build_in.rs @@ -30,7 +30,7 @@ use std::sync::Arc; use arrow::datatypes::Schema; -use datafusion_common::{not_impl_err, Result}; +use datafusion_common::{internal_err, Result}; use datafusion_expr::AggregateFunction; use crate::expressions::{self}; @@ -60,11 +60,13 @@ pub fn create_aggregate_expr( .collect::>>()?; let input_phy_exprs = input_phy_exprs.to_vec(); Ok(match (fun, distinct) { - (AggregateFunction::ArrayAgg, false) => { + (AggregateFunction::ArrayAgg, _) => { let expr = Arc::clone(&input_phy_exprs[0]); if ordering_req.is_empty() { - Arc::new(expressions::ArrayAgg::new(expr, name, data_type)) + return internal_err!( + "ArrayAgg without ordering should be handled as UDAF" + ); } else { Arc::new(expressions::OrderSensitiveArrayAgg::new( expr, @@ -75,15 +77,6 @@ pub fn create_aggregate_expr( )) } } - (AggregateFunction::ArrayAgg, true) => { - if !ordering_req.is_empty() { - return not_impl_err!( - "ARRAY_AGG(DISTINCT ORDER BY a ASC) order-sensitive aggregations are not available" - ); - } - let expr = Arc::clone(&input_phy_exprs[0]); - Arc::new(expressions::DistinctArrayAgg::new(expr, name, data_type)) - } (AggregateFunction::Min, _) => Arc::new(expressions::Min::new( Arc::clone(&input_phy_exprs[0]), name, @@ -104,70 +97,9 @@ mod tests { use datafusion_common::plan_err; use datafusion_expr::{type_coercion, Signature}; - use crate::expressions::{try_cast, ArrayAgg, DistinctArrayAgg, Max, Min}; + use crate::expressions::{try_cast, Max, Min}; use super::*; - #[test] - fn test_approx_expr() -> Result<()> { - let funcs = vec![AggregateFunction::ArrayAgg]; - let data_types = vec![ - DataType::UInt32, - DataType::Int32, - DataType::Float32, - DataType::Float64, - DataType::Decimal128(10, 2), - DataType::Utf8, - ]; - for fun in funcs { - for data_type in &data_types { - let input_schema = - Schema::new(vec![Field::new("c1", data_type.clone(), true)]); - let input_phy_exprs: Vec> = vec![Arc::new( - expressions::Column::new_with_schema("c1", &input_schema).unwrap(), - )]; - let result_agg_phy_exprs = create_physical_agg_expr_for_test( - &fun, - false, - &input_phy_exprs[0..1], - &input_schema, - "c1", - )?; - if fun == AggregateFunction::ArrayAgg { - assert!(result_agg_phy_exprs.as_any().is::()); - assert_eq!("c1", result_agg_phy_exprs.name()); - assert_eq!( - Field::new_list( - "c1", - Field::new("item", data_type.clone(), true), - true, - ), - result_agg_phy_exprs.field().unwrap() - ); - } - - let result_distinct = create_physical_agg_expr_for_test( - &fun, - true, - &input_phy_exprs[0..1], - &input_schema, - "c1", - )?; - if fun == AggregateFunction::ArrayAgg { - assert!(result_distinct.as_any().is::()); - assert_eq!("c1", result_distinct.name()); - assert_eq!( - Field::new_list( - "c1", - Field::new("item", data_type.clone(), true), - true, - ), - result_agg_phy_exprs.field().unwrap() - ); - } - } - } - Ok(()) - } #[test] fn test_min_max_expr() -> Result<()> { diff --git a/datafusion/physical-expr/src/aggregate/mod.rs b/datafusion/physical-expr/src/aggregate/mod.rs index b9d803900f53..749cf2be7297 100644 --- a/datafusion/physical-expr/src/aggregate/mod.rs +++ b/datafusion/physical-expr/src/aggregate/mod.rs @@ -17,8 +17,6 @@ pub use datafusion_physical_expr_common::aggregate::AggregateExpr; -pub(crate) mod array_agg; -pub(crate) mod array_agg_distinct; pub(crate) mod array_agg_ordered; #[macro_use] pub(crate) mod min_max; diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index 7d8f12091f46..fa80bc9873f0 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -34,8 +34,6 @@ mod try_cast; pub mod helpers { pub use crate::aggregate::min_max::{max, min}; } -pub use crate::aggregate::array_agg::ArrayAgg; -pub use crate::aggregate::array_agg_distinct::DistinctArrayAgg; pub use crate::aggregate::array_agg_ordered::OrderSensitiveArrayAgg; pub use crate::aggregate::build_in::create_aggregate_expr; pub use crate::aggregate::min_max::{Max, MaxAccumulator, Min, MinAccumulator}; @@ -63,160 +61,3 @@ pub use negative::{negative, NegativeExpr}; pub use no_op::NoOp; pub use not::{not, NotExpr}; pub use try_cast::{try_cast, TryCastExpr}; - -#[cfg(test)] -pub(crate) mod tests { - use std::sync::Arc; - - use crate::AggregateExpr; - - use arrow::record_batch::RecordBatch; - use datafusion_common::{Result, ScalarValue}; - - /// macro to perform an aggregation using [`datafusion_expr::Accumulator`] and verify the - /// result. - #[macro_export] - macro_rules! generic_test_op { - ($ARRAY:expr, $DATATYPE:expr, $OP:ident, $EXPECTED:expr) => { - generic_test_op!($ARRAY, $DATATYPE, $OP, $EXPECTED, $EXPECTED.data_type()) - }; - ($ARRAY:expr, $DATATYPE:expr, $OP:ident, $EXPECTED:expr, $EXPECTED_DATATYPE:expr) => {{ - let schema = Schema::new(vec![Field::new("a", $DATATYPE, true)]); - - let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![$ARRAY])?; - - let agg = Arc::new(<$OP>::new( - col("a", &schema)?, - "bla".to_string(), - $EXPECTED_DATATYPE, - )); - let actual = aggregate(&batch, agg)?; - let expected = ScalarValue::from($EXPECTED); - - assert_eq!(expected, actual); - - Ok(()) as Result<(), ::datafusion_common::DataFusionError> - }}; - } - - /// Same as [`generic_test_op`] but with support for providing a 4th argument, usually - /// a boolean to indicate if using the distinct version of the op. - #[macro_export] - macro_rules! generic_test_distinct_op { - ($ARRAY:expr, $DATATYPE:expr, $OP:ident, $DISTINCT:expr, $EXPECTED:expr) => { - generic_test_distinct_op!( - $ARRAY, - $DATATYPE, - $OP, - $DISTINCT, - $EXPECTED, - $EXPECTED.data_type() - ) - }; - ($ARRAY:expr, $DATATYPE:expr, $OP:ident, $DISTINCT:expr, $EXPECTED:expr, $EXPECTED_DATATYPE:expr) => {{ - let schema = Schema::new(vec![Field::new("a", $DATATYPE, true)]); - - let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![$ARRAY])?; - - let agg = Arc::new(<$OP>::new( - col("a", &schema)?, - "bla".to_string(), - $EXPECTED_DATATYPE, - $DISTINCT, - )); - let actual = aggregate(&batch, agg)?; - let expected = ScalarValue::from($EXPECTED); - - assert_eq!(expected, actual); - - Ok(()) as Result<(), ::datafusion_common::DataFusionError> - }}; - } - - /// macro to perform an aggregation using [`crate::GroupsAccumulator`] and verify the result. - /// - /// The difference between this and the above `generic_test_op` is that the former checks - /// the old slow-path [`datafusion_expr::Accumulator`] implementation, while this checks - /// the new [`crate::GroupsAccumulator`] implementation. - #[macro_export] - macro_rules! generic_test_op_new { - ($ARRAY:expr, $DATATYPE:expr, $OP:ident, $EXPECTED:expr) => { - generic_test_op_new!( - $ARRAY, - $DATATYPE, - $OP, - $EXPECTED, - $EXPECTED.data_type().clone() - ) - }; - ($ARRAY:expr, $DATATYPE:expr, $OP:ident, $EXPECTED:expr, $EXPECTED_DATATYPE:expr) => {{ - let schema = Schema::new(vec![Field::new("a", $DATATYPE, true)]); - - let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![$ARRAY])?; - - let agg = Arc::new(<$OP>::new( - col("a", &schema)?, - "bla".to_string(), - $EXPECTED_DATATYPE, - )); - let actual = aggregate_new(&batch, agg)?; - assert_eq!($EXPECTED, &actual); - - Ok(()) as Result<(), ::datafusion_common::DataFusionError> - }}; - } - - /// macro to perform an aggregation with two inputs and verify the result. - #[macro_export] - macro_rules! generic_test_op2 { - ($ARRAY1:expr, $ARRAY2:expr, $DATATYPE1:expr, $DATATYPE2:expr, $OP:ident, $EXPECTED:expr) => { - generic_test_op2!( - $ARRAY1, - $ARRAY2, - $DATATYPE1, - $DATATYPE2, - $OP, - $EXPECTED, - $EXPECTED.data_type() - ) - }; - ($ARRAY1:expr, $ARRAY2:expr, $DATATYPE1:expr, $DATATYPE2:expr, $OP:ident, $EXPECTED:expr, $EXPECTED_DATATYPE:expr) => {{ - let schema = Schema::new(vec![ - Field::new("a", $DATATYPE1, true), - Field::new("b", $DATATYPE2, true), - ]); - let batch = - RecordBatch::try_new(Arc::new(schema.clone()), vec![$ARRAY1, $ARRAY2])?; - - let agg = Arc::new(<$OP>::new( - col("a", &schema)?, - col("b", &schema)?, - "bla".to_string(), - $EXPECTED_DATATYPE, - )); - let actual = aggregate(&batch, agg)?; - let expected = ScalarValue::from($EXPECTED); - - assert_eq!(expected, actual); - - Ok(()) - }}; - } - - pub fn aggregate( - batch: &RecordBatch, - agg: Arc, - ) -> Result { - let mut accum = agg.create_accumulator()?; - let expr = agg.expressions(); - let values = expr - .iter() - .map(|e| { - e.evaluate(batch) - .and_then(|v| v.into_array(batch.num_rows())) - }) - .collect::>>()?; - accum.update_batch(&values)?; - accum.evaluate() - } -} diff --git a/datafusion/physical-plan/src/aggregates/no_grouping.rs b/datafusion/physical-plan/src/aggregates/no_grouping.rs index f85164f7f1e2..99417e4ee3e9 100644 --- a/datafusion/physical-plan/src/aggregates/no_grouping.rs +++ b/datafusion/physical-plan/src/aggregates/no_grouping.rs @@ -218,6 +218,7 @@ fn aggregate_batch( Some(filter) => Cow::Owned(batch_filter(&batch, filter)?), None => Cow::Borrowed(&batch), }; + // 1.3 let values = &expr .iter() diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index 7ea2902cf3c0..e9a90fce2663 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -23,10 +23,9 @@ use datafusion::datasource::file_format::parquet::ParquetSink; use datafusion::physical_expr::window::{NthValueKind, SlidingAggregateWindowExpr}; use datafusion::physical_expr::{PhysicalSortExpr, ScalarFunctionExpr}; use datafusion::physical_plan::expressions::{ - ArrayAgg, BinaryExpr, CaseExpr, CastExpr, Column, CumeDist, DistinctArrayAgg, - InListExpr, IsNotNullExpr, IsNullExpr, Literal, Max, Min, NegativeExpr, NotExpr, - NthValue, Ntile, OrderSensitiveArrayAgg, Rank, RankType, RowNumber, TryCastExpr, - WindowShift, + BinaryExpr, CaseExpr, CastExpr, Column, CumeDist, InListExpr, IsNotNullExpr, + IsNullExpr, Literal, Max, Min, NegativeExpr, NotExpr, NthValue, Ntile, + OrderSensitiveArrayAgg, Rank, RankType, RowNumber, TryCastExpr, WindowShift, }; use datafusion::physical_plan::udaf::AggregateFunctionExpr; use datafusion::physical_plan::windows::{BuiltInWindowExpr, PlainAggregateWindowExpr}; @@ -260,14 +259,9 @@ struct AggrFn { fn aggr_expr_to_aggr_fn(expr: &dyn AggregateExpr) -> Result { let aggr_expr = expr.as_any(); - let mut distinct = false; - let inner = if aggr_expr.downcast_ref::().is_some() { - protobuf::AggregateFunction::ArrayAgg - } else if aggr_expr.downcast_ref::().is_some() { - distinct = true; - protobuf::AggregateFunction::ArrayAgg - } else if aggr_expr.downcast_ref::().is_some() { + // TODO: remove OrderSensitiveArrayAgg + let inner = if aggr_expr.downcast_ref::().is_some() { protobuf::AggregateFunction::ArrayAgg } else if aggr_expr.downcast_ref::().is_some() { protobuf::AggregateFunction::Min @@ -277,7 +271,10 @@ fn aggr_expr_to_aggr_fn(expr: &dyn AggregateExpr) -> Result { return not_impl_err!("Aggregate function not supported: {expr:?}"); }; - Ok(AggrFn { inner, distinct }) + Ok(AggrFn { + inner, + distinct: false, + }) } pub fn serialize_physical_sort_exprs( diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 0117502f400d..11945f39589a 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -66,7 +66,7 @@ use datafusion_expr::{ }; use datafusion_functions_aggregate::average::avg_udaf; use datafusion_functions_aggregate::expr_fn::{ - avg, bit_and, bit_or, bit_xor, bool_and, bool_or, corr, + array_agg, avg, bit_and, bit_or, bit_xor, bool_and, bool_or, corr, }; use datafusion_functions_aggregate::string_agg::string_agg; use datafusion_proto::bytes::{ @@ -702,6 +702,8 @@ async fn roundtrip_expr_api() -> Result<()> { string_agg(col("a").cast_to(&DataType::Utf8, &schema)?, lit("|")), bool_and(lit(true)), bool_or(lit(true)), + array_agg(lit(1)), + array_agg(lit(1)).distinct().build().unwrap(), ]; // ensure expressions created with the expr api can be round tripped diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index a0140b1c5292..1976951b8ce6 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -183,7 +183,7 @@ CREATE TABLE array_agg_distinct_list_table AS VALUES ; # Apply array_sort to have deterministic result, higher dimension nested array also works but not for array sort, -# so they are covered in `datafusion/physical-expr/src/aggregate/array_agg_distinct.rs` +# so they are covered in `datafusion/functions-aggregate/src/array_agg.rs` query ?? select array_sort(c1), array_sort(c2) from ( select array_agg(distinct column1) as c1, array_agg(distinct column2) as c2 from array_agg_distinct_list_table From 5f0dfbb8e7424964303b00f4781f8df4f445d928 Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Fri, 19 Jul 2024 17:32:24 +0800 Subject: [PATCH 09/37] Move `MAKE_MAP` to ExprPlanner (#11452) * move make_map to ExprPlanner * add benchmark for make_map * remove todo comment * update lock * refactor plan_make_map * implement make_array_strict for type checking strictly * fix planner provider * roll back to `make_array` * update lock --- datafusion-cli/Cargo.lock | 9 +- datafusion/expr/src/planner.rs | 7 + datafusion/functions-array/Cargo.toml | 5 + datafusion/functions-array/benches/map.rs | 69 ++++++++++ datafusion/functions-array/src/planner.rs | 21 ++- datafusion/functions/benches/map.rs | 23 +--- datafusion/functions/src/core/map.rs | 149 +-------------------- datafusion/functions/src/core/mod.rs | 6 - datafusion/sql/src/expr/function.rs | 12 ++ datafusion/sqllogictest/test_files/map.slt | 14 +- 10 files changed, 131 insertions(+), 184 deletions(-) create mode 100644 datafusion/functions-array/benches/map.rs diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index cdf0e7f57316..61d9c72b89d9 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -1331,6 +1331,7 @@ dependencies = [ "itertools", "log", "paste", + "rand", ] [[package]] @@ -3593,18 +3594,18 @@ checksum = "23d434d3f8967a09480fb04132ebe0a3e088c173e6d0ee7897abbdf4eab0f8b9" [[package]] name = "thiserror" -version = "1.0.62" +version = "1.0.63" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f2675633b1499176c2dff06b0856a27976a8f9d436737b4cf4f312d4d91d8bbb" +checksum = "c0342370b38b6a11b6cc11d6a805569958d54cfa061a29969c3b5ce2ea405724" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.62" +version = "1.0.63" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d20468752b09f49e909e55a5d338caa8bedf615594e9d80bc4c565d30faf798c" +checksum = "a4558b58466b9ad7ca0f102865eccc95938dca1a74a856f2b57b6629050da261" dependencies = [ "proc-macro2", "quote", diff --git a/datafusion/expr/src/planner.rs b/datafusion/expr/src/planner.rs index 009f3512c588..415af1bf94dc 100644 --- a/datafusion/expr/src/planner.rs +++ b/datafusion/expr/src/planner.rs @@ -173,6 +173,13 @@ pub trait ExprPlanner: Send + Sync { fn plan_overlay(&self, args: Vec) -> Result>> { Ok(PlannerResult::Original(args)) } + + /// Plan a make_map expression, e.g., `make_map(key1, value1, key2, value2, ...)` + /// + /// Returns origin expression arguments if not possible + fn plan_make_map(&self, args: Vec) -> Result>> { + Ok(PlannerResult::Original(args)) + } } /// An operator with two arguments to plan diff --git a/datafusion/functions-array/Cargo.toml b/datafusion/functions-array/Cargo.toml index 73c5b9114a2c..de424b259694 100644 --- a/datafusion/functions-array/Cargo.toml +++ b/datafusion/functions-array/Cargo.toml @@ -53,6 +53,7 @@ datafusion-functions-aggregate = { workspace = true } itertools = { version = "0.12", features = ["use_std"] } log = { workspace = true } paste = "1.0.14" +rand = "0.8.5" [dev-dependencies] criterion = { version = "0.5", features = ["async_tokio"] } @@ -60,3 +61,7 @@ criterion = { version = "0.5", features = ["async_tokio"] } [[bench]] harness = false name = "array_expression" + +[[bench]] +harness = false +name = "map" diff --git a/datafusion/functions-array/benches/map.rs b/datafusion/functions-array/benches/map.rs new file mode 100644 index 000000000000..2e9b45266abc --- /dev/null +++ b/datafusion/functions-array/benches/map.rs @@ -0,0 +1,69 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +extern crate criterion; + +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use rand::prelude::ThreadRng; +use rand::Rng; + +use datafusion_common::ScalarValue; +use datafusion_expr::planner::ExprPlanner; +use datafusion_expr::Expr; +use datafusion_functions_array::planner::ArrayFunctionPlanner; + +fn keys(rng: &mut ThreadRng) -> Vec { + let mut keys = vec![]; + for _ in 0..1000 { + keys.push(rng.gen_range(0..9999).to_string()); + } + keys +} + +fn values(rng: &mut ThreadRng) -> Vec { + let mut values = vec![]; + for _ in 0..1000 { + values.push(rng.gen_range(0..9999)); + } + values +} + +fn criterion_benchmark(c: &mut Criterion) { + c.bench_function("make_map_1000", |b| { + let mut rng = rand::thread_rng(); + let keys = keys(&mut rng); + let values = values(&mut rng); + let mut buffer = Vec::new(); + for i in 0..1000 { + buffer.push(Expr::Literal(ScalarValue::Utf8(Some(keys[i].clone())))); + buffer.push(Expr::Literal(ScalarValue::Int32(Some(values[i])))); + } + + let planner = ArrayFunctionPlanner {}; + + b.iter(|| { + black_box( + planner + .plan_make_map(buffer.clone()) + .expect("map should work on valid values"), + ); + }); + }); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions-array/src/planner.rs b/datafusion/functions-array/src/planner.rs index dfb620f84f3a..fbb541d9b151 100644 --- a/datafusion/functions-array/src/planner.rs +++ b/datafusion/functions-array/src/planner.rs @@ -17,7 +17,8 @@ //! SQL planning extensions like [`ArrayFunctionPlanner`] and [`FieldAccessPlanner`] -use datafusion_common::{utils::list_ndims, DFSchema, Result}; +use datafusion_common::{exec_err, utils::list_ndims, DFSchema, Result}; +use datafusion_expr::expr::ScalarFunction; use datafusion_expr::{ expr::AggregateFunctionDefinition, planner::{ExprPlanner, PlannerResult, RawBinaryExpr, RawFieldAccessExpr}, @@ -98,6 +99,24 @@ impl ExprPlanner for ArrayFunctionPlanner { ) -> Result>> { Ok(PlannerResult::Planned(make_array(exprs))) } + + fn plan_make_map(&self, args: Vec) -> Result>> { + if args.len() % 2 != 0 { + return exec_err!("make_map requires an even number of arguments"); + } + + let (keys, values): (Vec<_>, Vec<_>) = + args.into_iter().enumerate().partition(|(i, _)| i % 2 == 0); + let keys = make_array(keys.into_iter().map(|(_, e)| e).collect()); + let values = make_array(values.into_iter().map(|(_, e)| e).collect()); + + Ok(PlannerResult::Planned(Expr::ScalarFunction( + ScalarFunction::new_udf( + datafusion_functions::core::map(), + vec![keys, values], + ), + ))) + } } pub struct FieldAccessPlanner; diff --git a/datafusion/functions/benches/map.rs b/datafusion/functions/benches/map.rs index cd863d0e3311..811c21a41b46 100644 --- a/datafusion/functions/benches/map.rs +++ b/datafusion/functions/benches/map.rs @@ -23,7 +23,7 @@ use arrow_buffer::{OffsetBuffer, ScalarBuffer}; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use datafusion_common::ScalarValue; use datafusion_expr::ColumnarValue; -use datafusion_functions::core::{make_map, map}; +use datafusion_functions::core::map; use rand::prelude::ThreadRng; use rand::Rng; use std::sync::Arc; @@ -45,27 +45,6 @@ fn values(rng: &mut ThreadRng) -> Vec { } fn criterion_benchmark(c: &mut Criterion) { - c.bench_function("make_map_1000", |b| { - let mut rng = rand::thread_rng(); - let keys = keys(&mut rng); - let values = values(&mut rng); - let mut buffer = Vec::new(); - for i in 0..1000 { - buffer.push(ColumnarValue::Scalar(ScalarValue::Utf8(Some( - keys[i].clone(), - )))); - buffer.push(ColumnarValue::Scalar(ScalarValue::Int32(Some(values[i])))); - } - - b.iter(|| { - black_box( - make_map() - .invoke(&buffer) - .expect("map should work on valid values"), - ); - }); - }); - c.bench_function("map_1000", |b| { let mut rng = rand::thread_rng(); let field = Arc::new(Field::new("item", DataType::Utf8, true)); diff --git a/datafusion/functions/src/core/map.rs b/datafusion/functions/src/core/map.rs index 1834c7ac6060..2deef242f8a0 100644 --- a/datafusion/functions/src/core/map.rs +++ b/datafusion/functions/src/core/map.rs @@ -20,12 +20,11 @@ use std::collections::VecDeque; use std::sync::Arc; use arrow::array::{Array, ArrayData, ArrayRef, MapArray, StructArray}; -use arrow::compute::concat; use arrow::datatypes::{DataType, Field, SchemaBuilder}; use arrow_buffer::{Buffer, ToByteSlice}; -use datafusion_common::{exec_err, internal_err, ScalarValue}; -use datafusion_common::{not_impl_err, Result}; +use datafusion_common::Result; +use datafusion_common::{exec_err, ScalarValue}; use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; /// Check if we can evaluate the expr to constant directly. @@ -40,41 +39,6 @@ fn can_evaluate_to_const(args: &[ColumnarValue]) -> bool { .all(|arg| matches!(arg, ColumnarValue::Scalar(_))) } -fn make_map(args: &[ColumnarValue]) -> Result { - let can_evaluate_to_const = can_evaluate_to_const(args); - - let (key, value): (Vec<_>, Vec<_>) = args - .chunks_exact(2) - .map(|chunk| { - if let ColumnarValue::Array(_) = chunk[0] { - return not_impl_err!("make_map does not support array keys"); - } - if let ColumnarValue::Array(_) = chunk[1] { - return not_impl_err!("make_map does not support array values"); - } - Ok((chunk[0].clone(), chunk[1].clone())) - }) - .collect::>>()? - .into_iter() - .unzip(); - - let keys = ColumnarValue::values_to_arrays(&key)?; - let values = ColumnarValue::values_to_arrays(&value)?; - - let keys: Vec<_> = keys.iter().map(|k| k.as_ref()).collect(); - let values: Vec<_> = values.iter().map(|v| v.as_ref()).collect(); - - let key = match concat(&keys) { - Ok(key) => key, - Err(e) => return internal_err!("Error concatenating keys: {}", e), - }; - let value = match concat(&values) { - Ok(value) => value, - Err(e) => return internal_err!("Error concatenating values: {}", e), - }; - make_map_batch_internal(key, value, can_evaluate_to_const) -} - fn make_map_batch(args: &[ColumnarValue]) -> Result { if args.len() != 2 { return exec_err!( @@ -154,115 +118,6 @@ fn make_map_batch_internal( }) } -#[derive(Debug)] -pub struct MakeMap { - signature: Signature, -} - -impl Default for MakeMap { - fn default() -> Self { - Self::new() - } -} - -impl MakeMap { - pub fn new() -> Self { - Self { - signature: Signature::user_defined(Volatility::Immutable), - } - } -} - -impl ScalarUDFImpl for MakeMap { - fn as_any(&self) -> &dyn Any { - self - } - - fn name(&self) -> &str { - "make_map" - } - - fn signature(&self) -> &Signature { - &self.signature - } - - fn coerce_types(&self, arg_types: &[DataType]) -> Result> { - if arg_types.is_empty() { - return exec_err!( - "make_map requires at least one pair of arguments, got 0 instead" - ); - } - if arg_types.len() % 2 != 0 { - return exec_err!( - "make_map requires an even number of arguments, got {} instead", - arg_types.len() - ); - } - - let key_type = &arg_types[0]; - let mut value_type = &arg_types[1]; - - for (i, chunk) in arg_types.chunks_exact(2).enumerate() { - if chunk[0].is_null() { - return exec_err!("make_map key cannot be null at position {}", i); - } - if &chunk[0] != key_type { - return exec_err!( - "make_map requires all keys to have the same type {}, got {} instead at position {}", - key_type, - chunk[0], - i - ); - } - - if !chunk[1].is_null() { - if value_type.is_null() { - value_type = &chunk[1]; - } else if &chunk[1] != value_type { - return exec_err!( - "map requires all values to have the same type {}, got {} instead at position {}", - value_type, - &chunk[1], - i - ); - } - } - } - - let mut result = Vec::new(); - for _ in 0..arg_types.len() / 2 { - result.push(key_type.clone()); - result.push(value_type.clone()); - } - - Ok(result) - } - - fn return_type(&self, arg_types: &[DataType]) -> Result { - let key_type = &arg_types[0]; - let mut value_type = &arg_types[1]; - - for chunk in arg_types.chunks_exact(2) { - if !chunk[1].is_null() && value_type.is_null() { - value_type = &chunk[1]; - } - } - - let mut builder = SchemaBuilder::new(); - builder.push(Field::new("key", key_type.clone(), false)); - builder.push(Field::new("value", value_type.clone(), true)); - let fields = builder.finish().fields; - Ok(DataType::Map( - Arc::new(Field::new("entries", DataType::Struct(fields), false)), - false, - )) - } - - fn invoke(&self, args: &[ColumnarValue]) -> Result { - make_map(args) - } -} - #[derive(Debug)] pub struct MapFunc { signature: Signature, diff --git a/datafusion/functions/src/core/mod.rs b/datafusion/functions/src/core/mod.rs index 31bce04beec1..cbfaa592b012 100644 --- a/datafusion/functions/src/core/mod.rs +++ b/datafusion/functions/src/core/mod.rs @@ -43,7 +43,6 @@ make_udf_function!(r#struct::StructFunc, STRUCT, r#struct); make_udf_function!(named_struct::NamedStructFunc, NAMED_STRUCT, named_struct); make_udf_function!(getfield::GetFieldFunc, GET_FIELD, get_field); make_udf_function!(coalesce::CoalesceFunc, COALESCE, coalesce); -make_udf_function!(map::MakeMap, MAKE_MAP, make_map); make_udf_function!(map::MapFunc, MAP, map); pub mod expr_fn { @@ -81,10 +80,6 @@ pub mod expr_fn { coalesce, "Returns `coalesce(args...)`, which evaluates to the value of the first expr which is not NULL", args, - ),( - make_map, - "Returns a map created from the given keys and values pairs. This function isn't efficient for large maps. Use the `map` function instead.", - args, ),( map, "Returns a map created from a key list and a value list", @@ -107,7 +102,6 @@ pub fn functions() -> Vec> { named_struct(), get_field(), coalesce(), - make_map(), map(), ] } diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index dab328cc4908..4a4b16b804e2 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -21,6 +21,7 @@ use datafusion_common::{ internal_datafusion_err, not_impl_err, plan_datafusion_err, plan_err, DFSchema, Dependency, Result, }; +use datafusion_expr::planner::PlannerResult; use datafusion_expr::window_frame::{check_window_frame, regularize_window_order_by}; use datafusion_expr::{ expr, AggregateFunction, Expr, ExprSchemable, WindowFrame, WindowFunctionDefinition, @@ -227,6 +228,17 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { crate::utils::normalize_ident(name.0[0].clone()) }; + if name.eq("make_map") { + let mut fn_args = + self.function_args_to_expr(args.clone(), schema, planner_context)?; + for planner in self.context_provider.get_expr_planners().iter() { + match planner.plan_make_map(fn_args)? { + PlannerResult::Planned(expr) => return Ok(expr), + PlannerResult::Original(args) => fn_args = args, + } + } + } + // user-defined function (UDF) should have precedence if let Some(fm) = self.context_provider.get_function_meta(&name) { let args = self.function_args_to_expr(args, schema, planner_context)?; diff --git a/datafusion/sqllogictest/test_files/map.slt b/datafusion/sqllogictest/test_files/map.slt index fb8917a5f4fe..26bfb4a5922e 100644 --- a/datafusion/sqllogictest/test_files/map.slt +++ b/datafusion/sqllogictest/test_files/map.slt @@ -131,17 +131,23 @@ SELECT MAKE_MAP([1,2], ['a', 'b'], [3,4], ['b']); ---- {[1, 2]: [a, b], [3, 4]: [b]} -query error +query ? SELECT MAKE_MAP('POST', 41, 'HEAD', 'ab', 'PATCH', 30); +---- +{POST: 41, HEAD: ab, PATCH: 30} query error SELECT MAKE_MAP('POST', 41, 'HEAD', 33, null, 30); -query error -SELECT MAKE_MAP('POST', 41, 123, 33,'PATCH', 30); +query ? +SELECT MAKE_MAP('POST', 41, 'HEAD', 'ab', 'PATCH', 30); +---- +{POST: 41, HEAD: ab, PATCH: 30} -query error +query ? SELECT MAKE_MAP() +---- +{} query error SELECT MAKE_MAP('POST', 41, 'HEAD'); From 28fa74bf0fb69f46fd03ef97eb301090de23b5f5 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 19 Jul 2024 06:53:21 -0600 Subject: [PATCH 10/37] feat: Optimize CASE expression for "column or null" use case (#11534) --- datafusion/core/example.parquet | Bin 0 -> 976 bytes datafusion/physical-expr/benches/case_when.rs | 41 ++++- .../physical-expr/src/expressions/case.rs | 161 +++++++++++++++++- datafusion/sqllogictest/README.md | 2 +- datafusion/sqllogictest/test_files/case.slt | 52 ++++++ 5 files changed, 242 insertions(+), 14 deletions(-) create mode 100644 datafusion/core/example.parquet create mode 100644 datafusion/sqllogictest/test_files/case.slt diff --git a/datafusion/core/example.parquet b/datafusion/core/example.parquet new file mode 100644 index 0000000000000000000000000000000000000000..94de10394b33d26a23a9888e88faa1fa90f14043 GIT binary patch literal 976 zcmb7@y-UMD7{=dR+8CijD}7&bkU<0w2k`?GGL%A>;?SWub(6GKRM0|Ob#>@0PW}lF zj)D#j4jufP929)7>E!~g7L&j|d7ryqp4>;XcDRc@6Mv}ynjBo~6V zH`y+thhEt7jT5A*RN}trNWm|{0f9NW4_;9QPK*T-bm!2sqpAu*JJPBsrP-S1{t{1r zL|?PhFl5S2s;{?7@i{f==;+c__5v4R+(_C2+(#~f_ zqL5^4^S5jpnYGQ=*fw%%hg8QQV?c&9a#K0ZCz5+L4hnI<-@7>)bWXb$F?zh7>w(ctf)69oWF1C%Kz8Cp)ViHG+x3gs82To&93&yRUd+}=C`ei=G63r()}_L z-TE5)>SeImRT}5jD9>0kT~Se9KA*kUVhCXR^s>`3E!aTC)HE literal 0 HcmV?d00001 diff --git a/datafusion/physical-expr/benches/case_when.rs b/datafusion/physical-expr/benches/case_when.rs index 9cc7bdc465fb..862edd9c1fac 100644 --- a/datafusion/physical-expr/benches/case_when.rs +++ b/datafusion/physical-expr/benches/case_when.rs @@ -40,6 +40,7 @@ fn criterion_benchmark(c: &mut Criterion) { // create input data let mut c1 = Int32Builder::new(); let mut c2 = StringBuilder::new(); + let mut c3 = StringBuilder::new(); for i in 0..1000 { c1.append_value(i); if i % 7 == 0 { @@ -47,14 +48,21 @@ fn criterion_benchmark(c: &mut Criterion) { } else { c2.append_value(&format!("string {i}")); } + if i % 9 == 0 { + c3.append_null(); + } else { + c3.append_value(&format!("other string {i}")); + } } let c1 = Arc::new(c1.finish()); let c2 = Arc::new(c2.finish()); + let c3 = Arc::new(c3.finish()); let schema = Schema::new(vec![ Field::new("c1", DataType::Int32, true), Field::new("c2", DataType::Utf8, true), + Field::new("c3", DataType::Utf8, true), ]); - let batch = RecordBatch::try_new(Arc::new(schema), vec![c1, c2]).unwrap(); + let batch = RecordBatch::try_new(Arc::new(schema), vec![c1, c2, c3]).unwrap(); // use same predicate for all benchmarks let predicate = Arc::new(BinaryExpr::new( @@ -63,7 +71,7 @@ fn criterion_benchmark(c: &mut Criterion) { make_lit_i32(500), )); - // CASE WHEN expr THEN 1 ELSE 0 END + // CASE WHEN c1 <= 500 THEN 1 ELSE 0 END c.bench_function("case_when: scalar or scalar", |b| { let expr = Arc::new( CaseExpr::try_new( @@ -76,13 +84,38 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| black_box(expr.evaluate(black_box(&batch)).unwrap())) }); - // CASE WHEN expr THEN col ELSE null END + // CASE WHEN c1 <= 500 THEN c2 [ELSE NULL] END c.bench_function("case_when: column or null", |b| { + let expr = Arc::new( + CaseExpr::try_new(None, vec![(predicate.clone(), make_col("c2", 1))], None) + .unwrap(), + ); + b.iter(|| black_box(expr.evaluate(black_box(&batch)).unwrap())) + }); + + // CASE WHEN c1 <= 500 THEN c2 ELSE c3 END + c.bench_function("case_when: expr or expr", |b| { let expr = Arc::new( CaseExpr::try_new( None, vec![(predicate.clone(), make_col("c2", 1))], - Some(Arc::new(Literal::new(ScalarValue::Utf8(None)))), + Some(make_col("c3", 2)), + ) + .unwrap(), + ); + b.iter(|| black_box(expr.evaluate(black_box(&batch)).unwrap())) + }); + + // CASE c1 WHEN 1 THEN c2 WHEN 2 THEN c3 END + c.bench_function("case_when: CASE expr", |b| { + let expr = Arc::new( + CaseExpr::try_new( + Some(make_col("c1", 0)), + vec![ + (make_lit_i32(1), make_col("c2", 1)), + (make_lit_i32(2), make_col("c3", 2)), + ], + None, ) .unwrap(), ); diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 7a434c940229..521a7ed9acae 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -32,10 +32,33 @@ use datafusion_common::cast::as_boolean_array; use datafusion_common::{exec_err, internal_err, DataFusionError, Result, ScalarValue}; use datafusion_expr::ColumnarValue; +use datafusion_physical_expr_common::expressions::column::Column; +use datafusion_physical_expr_common::expressions::Literal; use itertools::Itertools; type WhenThen = (Arc, Arc); +#[derive(Debug, Hash)] +enum EvalMethod { + /// CASE WHEN condition THEN result + /// [WHEN ...] + /// [ELSE result] + /// END + NoExpression, + /// CASE expression + /// WHEN value THEN result + /// [WHEN ...] + /// [ELSE result] + /// END + WithExpression, + /// This is a specialization for a specific use case where we can take a fast path + /// for expressions that are infallible and can be cheaply computed for the entire + /// record batch rather than just for the rows where the predicate is true. + /// + /// CASE WHEN condition THEN column [ELSE NULL] END + InfallibleExprOrNull, +} + /// The CASE expression is similar to a series of nested if/else and there are two forms that /// can be used. The first form consists of a series of boolean "when" expressions with /// corresponding "then" expressions, and an optional "else" expression. @@ -61,6 +84,8 @@ pub struct CaseExpr { when_then_expr: Vec, /// Optional "else" expression else_expr: Option>, + /// Evaluation method to use + eval_method: EvalMethod, } impl std::fmt::Display for CaseExpr { @@ -79,6 +104,15 @@ impl std::fmt::Display for CaseExpr { } } +/// This is a specialization for a specific use case where we can take a fast path +/// for expressions that are infallible and can be cheaply computed for the entire +/// record batch rather than just for the rows where the predicate is true. For now, +/// this is limited to use with Column expressions but could potentially be used for other +/// expressions in the future +fn is_cheap_and_infallible(expr: &Arc) -> bool { + expr.as_any().is::() +} + impl CaseExpr { /// Create a new CASE WHEN expression pub fn try_new( @@ -86,13 +120,35 @@ impl CaseExpr { when_then_expr: Vec, else_expr: Option>, ) -> Result { + // normalize null literals to None in the else_expr (this already happens + // during SQL planning, but not necessarily for other use cases) + let else_expr = match &else_expr { + Some(e) => match e.as_any().downcast_ref::() { + Some(lit) if lit.value().is_null() => None, + _ => else_expr, + }, + _ => else_expr, + }; + if when_then_expr.is_empty() { exec_err!("There must be at least one WHEN clause") } else { + let eval_method = if expr.is_some() { + EvalMethod::WithExpression + } else if when_then_expr.len() == 1 + && is_cheap_and_infallible(&(when_then_expr[0].1)) + && else_expr.is_none() + { + EvalMethod::InfallibleExprOrNull + } else { + EvalMethod::NoExpression + }; + Ok(Self { expr, when_then_expr, else_expr, + eval_method, }) } } @@ -256,6 +312,38 @@ impl CaseExpr { Ok(ColumnarValue::Array(current_value)) } + + /// This function evaluates the specialized case of: + /// + /// CASE WHEN condition THEN column + /// [ELSE NULL] + /// END + /// + /// Note that this function is only safe to use for "then" expressions + /// that are infallible because the expression will be evaluated for all + /// rows in the input batch. + fn case_column_or_null(&self, batch: &RecordBatch) -> Result { + let when_expr = &self.when_then_expr[0].0; + let then_expr = &self.when_then_expr[0].1; + if let ColumnarValue::Array(bit_mask) = when_expr.evaluate(batch)? { + let bit_mask = bit_mask + .as_any() + .downcast_ref::() + .expect("predicate should evaluate to a boolean array"); + // invert the bitmask + let bit_mask = not(bit_mask)?; + match then_expr.evaluate(batch)? { + ColumnarValue::Array(array) => { + Ok(ColumnarValue::Array(nullif(&array, &bit_mask)?)) + } + ColumnarValue::Scalar(_) => { + internal_err!("expression did not evaluate to an array") + } + } + } else { + internal_err!("predicate did not evaluate to an array") + } + } } impl PhysicalExpr for CaseExpr { @@ -303,14 +391,21 @@ impl PhysicalExpr for CaseExpr { } fn evaluate(&self, batch: &RecordBatch) -> Result { - if self.expr.is_some() { - // this use case evaluates "expr" and then compares the values with the "when" - // values - self.case_when_with_expr(batch) - } else { - // The "when" conditions all evaluate to boolean in this use case and can be - // arbitrary expressions - self.case_when_no_expr(batch) + match self.eval_method { + EvalMethod::WithExpression => { + // this use case evaluates "expr" and then compares the values with the "when" + // values + self.case_when_with_expr(batch) + } + EvalMethod::NoExpression => { + // The "when" conditions all evaluate to boolean in this use case and can be + // arbitrary expressions + self.case_when_no_expr(batch) + } + EvalMethod::InfallibleExprOrNull => { + // Specialization for CASE WHEN expr THEN column [ELSE NULL] END + self.case_column_or_null(batch) + } } } @@ -409,7 +504,7 @@ pub fn case( #[cfg(test)] mod tests { use super::*; - use crate::expressions::{binary, cast, col, lit}; + use crate::expressions::{binary, cast, col, lit, BinaryExpr}; use arrow::buffer::Buffer; use arrow::datatypes::DataType::Float64; @@ -419,6 +514,7 @@ mod tests { use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_expr::type_coercion::binary::comparison_coercion; use datafusion_expr::Operator; + use datafusion_physical_expr_common::expressions::Literal; #[test] fn case_with_expr() -> Result<()> { @@ -998,6 +1094,53 @@ mod tests { Ok(()) } + #[test] + fn test_column_or_null_specialization() -> Result<()> { + // create input data + let mut c1 = Int32Builder::new(); + let mut c2 = StringBuilder::new(); + for i in 0..1000 { + c1.append_value(i); + if i % 7 == 0 { + c2.append_null(); + } else { + c2.append_value(&format!("string {i}")); + } + } + let c1 = Arc::new(c1.finish()); + let c2 = Arc::new(c2.finish()); + let schema = Schema::new(vec![ + Field::new("c1", DataType::Int32, true), + Field::new("c2", DataType::Utf8, true), + ]); + let batch = RecordBatch::try_new(Arc::new(schema), vec![c1, c2]).unwrap(); + + // CaseWhenExprOrNull should produce same results as CaseExpr + let predicate = Arc::new(BinaryExpr::new( + make_col("c1", 0), + Operator::LtEq, + make_lit_i32(250), + )); + let expr = CaseExpr::try_new(None, vec![(predicate, make_col("c2", 1))], None)?; + assert!(matches!(expr.eval_method, EvalMethod::InfallibleExprOrNull)); + match expr.evaluate(&batch)? { + ColumnarValue::Array(array) => { + assert_eq!(1000, array.len()); + assert_eq!(785, array.null_count()); + } + _ => unreachable!(), + } + Ok(()) + } + + fn make_col(name: &str, index: usize) -> Arc { + Arc::new(Column::new(name, index)) + } + + fn make_lit_i32(n: i32) -> Arc { + Arc::new(Literal::new(ScalarValue::Int32(Some(n)))) + } + fn generate_case_when_with_type_coercion( expr: Option>, when_thens: Vec, diff --git a/datafusion/sqllogictest/README.md b/datafusion/sqllogictest/README.md index c7f04c0d762c..5becc75c985a 100644 --- a/datafusion/sqllogictest/README.md +++ b/datafusion/sqllogictest/README.md @@ -133,7 +133,7 @@ In order to run the sqllogictests running against a previously running Postgres PG_COMPAT=true PG_URI="postgresql://postgres@127.0.0.1/postgres" cargo test --features=postgres --test sqllogictests ``` -The environemnt variables: +The environment variables: 1. `PG_COMPAT` instructs sqllogictest to run against Postgres (not DataFusion) 2. `PG_URI` contains a `libpq` style connection string, whose format is described in diff --git a/datafusion/sqllogictest/test_files/case.slt b/datafusion/sqllogictest/test_files/case.slt new file mode 100644 index 000000000000..fac1042bb6dd --- /dev/null +++ b/datafusion/sqllogictest/test_files/case.slt @@ -0,0 +1,52 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# create test data +statement ok +create table foo (a int, b int) as values (1, 2), (3, 4), (5, 6); + +# CASE WHEN with condition +query T +SELECT CASE a WHEN 1 THEN 'one' WHEN 3 THEN 'three' ELSE '?' END FROM foo +---- +one +three +? + +# CASE WHEN with no condition +query I +SELECT CASE WHEN a > 2 THEN a ELSE b END FROM foo +---- +2 +3 +5 + +# column or explicit null +query I +SELECT CASE WHEN a > 2 THEN b ELSE null END FROM foo +---- +NULL +4 +6 + +# column or implicit null +query I +SELECT CASE WHEN a > 2 THEN b END FROM foo +---- +NULL +4 +6 From cf9da768306a1e103bfeae68f4f2ed3dfe87df7b Mon Sep 17 00:00:00 2001 From: JasonLi Date: Fri, 19 Jul 2024 21:56:14 +0800 Subject: [PATCH 11/37] fix: typos of sql, sqllogictest and substrait packages (#11548) --- datafusion/sql/src/expr/function.rs | 2 +- datafusion/sql/src/parser.rs | 4 ++-- datafusion/sql/src/relation/mod.rs | 2 +- datafusion/sql/src/select.rs | 2 +- datafusion/sql/src/unparser/ast.rs | 2 +- datafusion/sql/src/unparser/plan.rs | 6 +++--- datafusion/sql/src/unparser/utils.rs | 4 ++-- datafusion/sql/src/utils.rs | 2 +- datafusion/sql/tests/common/mod.rs | 2 +- datafusion/sql/tests/sql_integration.rs | 4 ++-- datafusion/sqllogictest/test_files/aggregate.slt | 8 ++++---- datafusion/sqllogictest/test_files/array.slt | 4 ++-- datafusion/sqllogictest/test_files/binary.slt | 2 +- datafusion/sqllogictest/test_files/copy.slt | 2 +- datafusion/sqllogictest/test_files/explain.slt | 2 +- datafusion/sqllogictest/test_files/interval.slt | 2 +- datafusion/sqllogictest/test_files/math.slt | 2 +- datafusion/sqllogictest/test_files/options.slt | 2 +- datafusion/sqllogictest/test_files/scalar.slt | 2 +- datafusion/sqllogictest/test_files/timestamps.slt | 2 +- datafusion/sqllogictest/test_files/unnest.slt | 4 ++-- datafusion/sqllogictest/test_files/update.slt | 2 +- datafusion/sqllogictest/test_files/window.slt | 2 +- datafusion/substrait/src/logical_plan/producer.rs | 2 +- 24 files changed, 34 insertions(+), 34 deletions(-) diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index 4a4b16b804e2..4804752d8389 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -67,7 +67,7 @@ pub fn suggest_valid_function( find_closest_match(valid_funcs, input_function_name) } -/// Find the closest matching string to the target string in the candidates list, using edit distance(case insensitve) +/// Find the closest matching string to the target string in the candidates list, using edit distance(case insensitive) /// Input `candidates` must not be empty otherwise it will panic fn find_closest_match(candidates: Vec, target: &str) -> String { let target = target.to_lowercase(); diff --git a/datafusion/sql/src/parser.rs b/datafusion/sql/src/parser.rs index bc13484235c3..a743aa72829d 100644 --- a/datafusion/sql/src/parser.rs +++ b/datafusion/sql/src/parser.rs @@ -218,7 +218,7 @@ impl fmt::Display for CreateExternalTable { /// /// This can either be a [`Statement`] from [`sqlparser`] from a /// standard SQL dialect, or a DataFusion extension such as `CREATE -/// EXTERAL TABLE`. See [`DFParser`] for more information. +/// EXTERNAL TABLE`. See [`DFParser`] for more information. /// /// [`Statement`]: sqlparser::ast::Statement #[derive(Debug, Clone, PartialEq, Eq)] @@ -1101,7 +1101,7 @@ mod tests { }); expect_parse_ok(sql, expected)?; - // positive case: column definiton allowed in 'partition by' clause + // positive case: column definition allowed in 'partition by' clause let sql = "CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV PARTITIONED BY (p1 int) LOCATION 'foo.csv'"; let expected = Statement::CreateExternalTable(CreateExternalTable { diff --git a/datafusion/sql/src/relation/mod.rs b/datafusion/sql/src/relation/mod.rs index 9380e569f2e4..b812dae5ae3e 100644 --- a/datafusion/sql/src/relation/mod.rs +++ b/datafusion/sql/src/relation/mod.rs @@ -105,7 +105,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // Unnest table factor has empty input let schema = DFSchema::empty(); let input = LogicalPlanBuilder::empty(true).build()?; - // Unnest table factor can have multiple arugments. + // Unnest table factor can have multiple arguments. // We treat each argument as a separate unnest expression. let unnest_exprs = array_exprs .into_iter() diff --git a/datafusion/sql/src/select.rs b/datafusion/sql/src/select.rs index 84b80c311245..fc46c3a841b5 100644 --- a/datafusion/sql/src/select.rs +++ b/datafusion/sql/src/select.rs @@ -306,7 +306,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let mut intermediate_select_exprs = select_exprs; // Each expr in select_exprs can contains multiple unnest stage // The transformation happen bottom up, one at a time for each iteration - // Ony exaust the loop if no more unnest transformation is found + // Only exaust the loop if no more unnest transformation is found for i in 0.. { let mut unnest_columns = vec![]; // from which column used for projection, before the unnest happen diff --git a/datafusion/sql/src/unparser/ast.rs b/datafusion/sql/src/unparser/ast.rs index 06b4d4a710a3..02eb44dbb657 100644 --- a/datafusion/sql/src/unparser/ast.rs +++ b/datafusion/sql/src/unparser/ast.rs @@ -497,7 +497,7 @@ impl Default for DerivedRelationBuilder { pub(super) struct UninitializedFieldError(&'static str); impl UninitializedFieldError { - /// Create a new `UnitializedFieldError` for the specified field name. + /// Create a new `UninitializedFieldError` for the specified field name. pub fn new(field_name: &'static str) -> Self { UninitializedFieldError(field_name) } diff --git a/datafusion/sql/src/unparser/plan.rs b/datafusion/sql/src/unparser/plan.rs index 7a653f80be08..26fd47299637 100644 --- a/datafusion/sql/src/unparser/plan.rs +++ b/datafusion/sql/src/unparser/plan.rs @@ -214,12 +214,12 @@ impl Unparser<'_> { } else { let mut derived_builder = DerivedRelationBuilder::default(); derived_builder.lateral(false).alias(None).subquery({ - let inner_statment = self.plan_to_sql(plan)?; - if let ast::Statement::Query(inner_query) = inner_statment { + let inner_statement = self.plan_to_sql(plan)?; + if let ast::Statement::Query(inner_query) = inner_statement { inner_query } else { return internal_err!( - "Subquery must be a Query, but found {inner_statment:?}" + "Subquery must be a Query, but found {inner_statement:?}" ); } }); diff --git a/datafusion/sql/src/unparser/utils.rs b/datafusion/sql/src/unparser/utils.rs index 331da9773f16..71f64f1cf459 100644 --- a/datafusion/sql/src/unparser/utils.rs +++ b/datafusion/sql/src/unparser/utils.rs @@ -31,7 +31,7 @@ pub(crate) enum AggVariant<'a> { /// Recursively searches children of [LogicalPlan] to find an Aggregate or window node if one exists /// prior to encountering a Join, TableScan, or a nested subquery (derived table factor). /// If an Aggregate or window node is not found prior to this or at all before reaching the end -/// of the tree, None is returned. It is assumed that a Window and Aggegate node cannot both +/// of the tree, None is returned. It is assumed that a Window and Aggregate node cannot both /// be found in a single select query. pub(crate) fn find_agg_node_within_select<'a>( plan: &'a LogicalPlan, @@ -82,7 +82,7 @@ pub(crate) fn unproject_agg_exprs(expr: &Expr, agg: &Aggregate) -> Result expr.clone() .transform(|sub_expr| { if let Expr::Column(c) = sub_expr { - // find the column in the agg schmea + // find the column in the agg schema if let Ok(n) = agg.schema.index_of_column(&c) { let unprojected_expr = agg .group_expr diff --git a/datafusion/sql/src/utils.rs b/datafusion/sql/src/utils.rs index 2eacbd174fc2..a70e3e9be930 100644 --- a/datafusion/sql/src/utils.rs +++ b/datafusion/sql/src/utils.rs @@ -325,7 +325,7 @@ pub(crate) fn transform_bottom_unnest( let (data_type, _) = arg.data_type_and_nullable(input.schema())?; if let DataType::Struct(_) = data_type { - return internal_err!("unnest on struct can ony be applied at the root level of select expression"); + return internal_err!("unnest on struct can only be applied at the root level of select expression"); } let mut transformed_exprs = transform(&expr, arg)?; diff --git a/datafusion/sql/tests/common/mod.rs b/datafusion/sql/tests/common/mod.rs index b8d8bd12d28b..bcfb8f43848e 100644 --- a/datafusion/sql/tests/common/mod.rs +++ b/datafusion/sql/tests/common/mod.rs @@ -56,7 +56,7 @@ pub(crate) struct MockContextProvider { } impl MockContextProvider { - // Surpressing dead code warning, as this is used in integration test crates + // Suppressing dead code warning, as this is used in integration test crates #[allow(dead_code)] pub(crate) fn options_mut(&mut self) -> &mut ConfigOptions { &mut self.options diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index e34e7e20a0f3..57dab81331b3 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -1144,7 +1144,7 @@ fn select_aggregate_with_group_by_with_having_that_reuses_aggregate_multiple_tim } #[test] -fn select_aggregate_with_group_by_with_having_using_aggreagate_not_in_select() { +fn select_aggregate_with_group_by_with_having_using_aggregate_not_in_select() { let sql = "SELECT first_name, MAX(age) FROM person GROUP BY first_name @@ -1185,7 +1185,7 @@ fn select_aggregate_compound_aliased_with_group_by_with_having_referencing_compo } #[test] -fn select_aggregate_with_group_by_with_having_using_derived_column_aggreagate_not_in_select( +fn select_aggregate_with_group_by_with_having_using_derived_column_aggregate_not_in_select( ) { let sql = "SELECT first_name, MAX(age) FROM person diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index 1976951b8ce6..d0f7f2d9ac7a 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -3705,7 +3705,7 @@ select column3 as tag from t_source; -# Demonstate the contents +# Demonstrate the contents query PPPPPPPPTT select * from t; ---- @@ -3816,7 +3816,7 @@ select column3 as tag from t_source; -# Demonstate the contents +# Demonstrate the contents query DDTT select * from t; ---- @@ -3914,7 +3914,7 @@ select column3 as tag from t_source; -# Demonstate the contents +# Demonstrate the contents query DDDDTT select * from t; ---- @@ -4108,7 +4108,7 @@ select sum(c1), arrow_typeof(sum(c1)) from d_table; ---- 100 Decimal128(20, 3) -# aggregate sum with deciaml +# aggregate sum with decimal statement ok create table t (c decimal(35, 3)) as values (10), (null), (20); diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index 7917f1d78da8..f2972e4c14c2 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -5690,7 +5690,7 @@ select ---- [] [] [0] [0] -# Test range for other egde cases +# Test range for other edge cases query ???????? select range(9223372036854775807, 9223372036854775807, -1) as c1, @@ -5828,7 +5828,7 @@ select [-9223372036854775808] [9223372036854775807] [0, -9223372036854775808] [0, 9223372036854775807] -# Test generate_series for other egde cases +# Test generate_series for other edge cases query ???? select generate_series(9223372036854775807, 9223372036854775807, -1) as c1, diff --git a/datafusion/sqllogictest/test_files/binary.slt b/datafusion/sqllogictest/test_files/binary.slt index 621cd3e528f1..5c5f9d510e55 100644 --- a/datafusion/sqllogictest/test_files/binary.slt +++ b/datafusion/sqllogictest/test_files/binary.slt @@ -25,7 +25,7 @@ SELECT X'FF01', arrow_typeof(X'FF01'); ---- ff01 Binary -# Invaid hex values +# Invalid hex values query error DataFusion error: Error during planning: Invalid HexStringLiteral 'Z' SELECT X'Z' diff --git a/datafusion/sqllogictest/test_files/copy.slt b/datafusion/sqllogictest/test_files/copy.slt index 6a6ab15a065d..7af4c52c654b 100644 --- a/datafusion/sqllogictest/test_files/copy.slt +++ b/datafusion/sqllogictest/test_files/copy.slt @@ -271,7 +271,7 @@ select * from validate_struct_with_array; {c0: foo, c1: [1, 2, 3], c2: {c0: bar, c1: [2, 3, 4]}} -# Copy parquet with all supported statment overrides +# Copy parquet with all supported statement overrides query IT COPY source_table TO 'test_files/scratch/copy/table_with_options/' diff --git a/datafusion/sqllogictest/test_files/explain.slt b/datafusion/sqllogictest/test_files/explain.slt index 3a4e8072bbc7..172cbad44dca 100644 --- a/datafusion/sqllogictest/test_files/explain.slt +++ b/datafusion/sqllogictest/test_files/explain.slt @@ -394,7 +394,7 @@ physical_plan_with_schema statement ok set datafusion.execution.collect_statistics = false; -# Explain ArrayFuncions +# Explain ArrayFunctions statement ok set datafusion.explain.physical_plan_only = false diff --git a/datafusion/sqllogictest/test_files/interval.slt b/datafusion/sqllogictest/test_files/interval.slt index eab4eed00269..afb262cf95a5 100644 --- a/datafusion/sqllogictest/test_files/interval.slt +++ b/datafusion/sqllogictest/test_files/interval.slt @@ -325,7 +325,7 @@ select ---- Interval(MonthDayNano) Interval(MonthDayNano) -# cast with explicit cast sytax +# cast with explicit cast syntax query TT select arrow_typeof(cast ('5 months' as interval)), diff --git a/datafusion/sqllogictest/test_files/math.slt b/datafusion/sqllogictest/test_files/math.slt index 6ff804c3065d..6884d762612d 100644 --- a/datafusion/sqllogictest/test_files/math.slt +++ b/datafusion/sqllogictest/test_files/math.slt @@ -112,7 +112,7 @@ SELECT iszero(1.0), iszero(0.0), iszero(-0.0), iszero(NULL) ---- false true true NULL -# abs: empty argumnet +# abs: empty argument statement error SELECT abs(); diff --git a/datafusion/sqllogictest/test_files/options.slt b/datafusion/sqllogictest/test_files/options.slt index ba9eedcbbd34..aafaa054964e 100644 --- a/datafusion/sqllogictest/test_files/options.slt +++ b/datafusion/sqllogictest/test_files/options.slt @@ -42,7 +42,7 @@ physical_plan statement ok set datafusion.execution.coalesce_batches = false -# expect no coalsece +# expect no coalescence query TT explain SELECT * FROM a WHERE c0 < 1; ---- diff --git a/datafusion/sqllogictest/test_files/scalar.slt b/datafusion/sqllogictest/test_files/scalar.slt index 5daa9333fb36..dd19a1344139 100644 --- a/datafusion/sqllogictest/test_files/scalar.slt +++ b/datafusion/sqllogictest/test_files/scalar.slt @@ -1203,7 +1203,7 @@ FROM t1 999 999 -# case_when_else_with_null_contant() +# case_when_else_with_null_constant() query I SELECT CASE WHEN c1 = 'a' THEN 1 diff --git a/datafusion/sqllogictest/test_files/timestamps.slt b/datafusion/sqllogictest/test_files/timestamps.slt index f4e492649b9f..2ca2d49997a6 100644 --- a/datafusion/sqllogictest/test_files/timestamps.slt +++ b/datafusion/sqllogictest/test_files/timestamps.slt @@ -1161,7 +1161,7 @@ ts_data_secs 2020-09-08T00:00:00 ts_data_secs 2020-09-08T00:00:00 ts_data_secs 2020-09-08T00:00:00 -# Test date trun on different granularity +# Test date turn on different granularity query TP rowsort SELECT 'millisecond', DATE_TRUNC('millisecond', ts) FROM ts_data_nanos UNION ALL diff --git a/datafusion/sqllogictest/test_files/unnest.slt b/datafusion/sqllogictest/test_files/unnest.slt index 93146541e107..d818c0e92795 100644 --- a/datafusion/sqllogictest/test_files/unnest.slt +++ b/datafusion/sqllogictest/test_files/unnest.slt @@ -484,7 +484,7 @@ query error DataFusion error: type_coercion\ncaused by\nThis feature is not impl select sum(unnest(generate_series(1,10))); ## TODO: support unnest as a child expr -query error DataFusion error: Internal error: unnest on struct can ony be applied at the root level of select expression +query error DataFusion error: Internal error: unnest on struct can only be applied at the root level of select expression select arrow_typeof(unnest(column5)) from unnest_table; @@ -517,7 +517,7 @@ select unnest(unnest(unnest(column3)['c1'])), column3 from recursive_unnest_tabl 3 [{c0: [2], c1: [[3], [4]]}] 4 [{c0: [2], c1: [[3], [4]]}] -## tripple list unnest +## triple list unnest query I? select unnest(unnest(unnest(column2))), column2 from recursive_unnest_table; ---- diff --git a/datafusion/sqllogictest/test_files/update.slt b/datafusion/sqllogictest/test_files/update.slt index 49b2bd9aa0b5..3d455d7a88ca 100644 --- a/datafusion/sqllogictest/test_files/update.slt +++ b/datafusion/sqllogictest/test_files/update.slt @@ -74,7 +74,7 @@ logical_plan statement ok create table t3(a int, b varchar, c double, d int); -# set from mutiple tables, sqlparser only supports from one table +# set from multiple tables, sqlparser only supports from one table query error DataFusion error: SQL error: ParserError\("Expected end of statement, found: ,"\) explain update t1 set b = t2.b, c = t3.a, d = 1 from t2, t3 where t1.a = t2.a and t1.a = t3.a; diff --git a/datafusion/sqllogictest/test_files/window.slt b/datafusion/sqllogictest/test_files/window.slt index 5296f13de08a..37214e11eae8 100644 --- a/datafusion/sqllogictest/test_files/window.slt +++ b/datafusion/sqllogictest/test_files/window.slt @@ -3968,7 +3968,7 @@ CREATE TABLE table_with_pk ( # However, if we know that contains a unique column (e.g. a PRIMARY KEY), # it can be treated as `OVER (ORDER BY ROWS BETWEEN UNBOUNDED PRECEDING # AND CURRENT ROW)` where window frame units change from `RANGE` to `ROWS`. This -# conversion makes the window frame manifestly causal by eliminating the possiblity +# conversion makes the window frame manifestly causal by eliminating the possibility # of ties explicitly (see window frame documentation for a discussion of causality # in this context). The Query below should have `ROWS` in its window frame. query TT diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 7849d0bd431e..0fd59d528086 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -994,7 +994,7 @@ pub fn make_binary_op_scalar_func( /// /// * `expr` - DataFusion expression to be parse into a Substrait expression /// * `schema` - DataFusion input schema for looking up field qualifiers -/// * `col_ref_offset` - Offset for caculating Substrait field reference indices. +/// * `col_ref_offset` - Offset for calculating Substrait field reference indices. /// This should only be set by caller with more than one input relations i.e. Join. /// Substrait expects one set of indices when joining two relations. /// Let's say `left` and `right` have `m` and `n` columns, respectively. The `right` From a4c9bb45744bfdfa714cb1a7a234e89608196169 Mon Sep 17 00:00:00 2001 From: Arttu Date: Fri, 19 Jul 2024 21:38:04 +0200 Subject: [PATCH 12/37] feat: consume and produce Substrait type extensions (#11510) * support reading type extensions in consumer * read extension for UDTs * support also type extensions in producer * produce extensions for MonthDayNano UDT * unify extensions between consumer and producer * fixes * add doc comments * add extension tests * fix * fix docs * fix test * fix clipppy --- datafusion/substrait/src/extensions.rs | 157 ++++++ datafusion/substrait/src/lib.rs | 1 + .../substrait/src/logical_plan/consumer.rs | 269 +++++---- .../substrait/src/logical_plan/producer.rs | 524 +++++++++--------- datafusion/substrait/src/variation_const.rs | 24 +- .../tests/cases/roundtrip_logical_plan.rs | 99 ++-- 6 files changed, 644 insertions(+), 430 deletions(-) create mode 100644 datafusion/substrait/src/extensions.rs diff --git a/datafusion/substrait/src/extensions.rs b/datafusion/substrait/src/extensions.rs new file mode 100644 index 000000000000..459d0e0c5ae5 --- /dev/null +++ b/datafusion/substrait/src/extensions.rs @@ -0,0 +1,157 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion::common::{plan_err, DataFusionError}; +use std::collections::HashMap; +use substrait::proto::extensions::simple_extension_declaration::{ + ExtensionFunction, ExtensionType, ExtensionTypeVariation, MappingType, +}; +use substrait::proto::extensions::SimpleExtensionDeclaration; + +/// Substrait uses [SimpleExtensions](https://substrait.io/extensions/#simple-extensions) to define +/// behavior of plans in addition to what's supported directly by the protobuf definitions. +/// That includes functions, but also provides support for custom types and variations for existing +/// types. This structs facilitates the use of these extensions in DataFusion. +/// TODO: DF doesn't yet use extensions for type variations +/// TODO: DF doesn't yet provide valid extensionUris +#[derive(Default, Debug, PartialEq)] +pub struct Extensions { + pub functions: HashMap, // anchor -> function name + pub types: HashMap, // anchor -> type name + pub type_variations: HashMap, // anchor -> type variation name +} + +impl Extensions { + /// Registers a function and returns the anchor (reference) to it. If the function has already + /// been registered, it returns the existing anchor. + /// Function names are case-insensitive (converted to lowercase). + pub fn register_function(&mut self, function_name: String) -> u32 { + let function_name = function_name.to_lowercase(); + + // Some functions are named differently in Substrait default extensions than in DF + // Rename those to match the Substrait extensions for interoperability + let function_name = match function_name.as_str() { + "substr" => "substring".to_string(), + _ => function_name, + }; + + match self.functions.iter().find(|(_, f)| *f == &function_name) { + Some((function_anchor, _)) => *function_anchor, // Function has been registered + None => { + // Function has NOT been registered + let function_anchor = self.functions.len() as u32; + self.functions + .insert(function_anchor, function_name.clone()); + function_anchor + } + } + } + + /// Registers a type and returns the anchor (reference) to it. If the type has already + /// been registered, it returns the existing anchor. + pub fn register_type(&mut self, type_name: String) -> u32 { + let type_name = type_name.to_lowercase(); + match self.types.iter().find(|(_, t)| *t == &type_name) { + Some((type_anchor, _)) => *type_anchor, // Type has been registered + None => { + // Type has NOT been registered + let type_anchor = self.types.len() as u32; + self.types.insert(type_anchor, type_name.clone()); + type_anchor + } + } + } +} + +impl TryFrom<&Vec> for Extensions { + type Error = DataFusionError; + + fn try_from( + value: &Vec, + ) -> datafusion::common::Result { + let mut functions = HashMap::new(); + let mut types = HashMap::new(); + let mut type_variations = HashMap::new(); + + for ext in value { + match &ext.mapping_type { + Some(MappingType::ExtensionFunction(ext_f)) => { + functions.insert(ext_f.function_anchor, ext_f.name.to_owned()); + } + Some(MappingType::ExtensionType(ext_t)) => { + types.insert(ext_t.type_anchor, ext_t.name.to_owned()); + } + Some(MappingType::ExtensionTypeVariation(ext_v)) => { + type_variations + .insert(ext_v.type_variation_anchor, ext_v.name.to_owned()); + } + None => return plan_err!("Cannot parse empty extension"), + } + } + + Ok(Extensions { + functions, + types, + type_variations, + }) + } +} + +impl From for Vec { + fn from(val: Extensions) -> Vec { + let mut extensions = vec![]; + for (f_anchor, f_name) in val.functions { + let function_extension = ExtensionFunction { + extension_uri_reference: u32::MAX, + function_anchor: f_anchor, + name: f_name, + }; + let simple_extension = SimpleExtensionDeclaration { + mapping_type: Some(MappingType::ExtensionFunction(function_extension)), + }; + extensions.push(simple_extension); + } + + for (t_anchor, t_name) in val.types { + let type_extension = ExtensionType { + extension_uri_reference: u32::MAX, // https://github.com/apache/datafusion/issues/11545 + type_anchor: t_anchor, + name: t_name, + }; + let simple_extension = SimpleExtensionDeclaration { + mapping_type: Some(MappingType::ExtensionType(type_extension)), + }; + extensions.push(simple_extension); + } + + for (tv_anchor, tv_name) in val.type_variations { + let type_variation_extension = ExtensionTypeVariation { + extension_uri_reference: u32::MAX, // We don't register proper extension URIs yet + type_variation_anchor: tv_anchor, + name: tv_name, + }; + let simple_extension = SimpleExtensionDeclaration { + mapping_type: Some(MappingType::ExtensionTypeVariation( + type_variation_extension, + )), + }; + extensions.push(simple_extension); + } + + extensions + } +} diff --git a/datafusion/substrait/src/lib.rs b/datafusion/substrait/src/lib.rs index 454f0e7b7cb9..0b1c796553c0 100644 --- a/datafusion/substrait/src/lib.rs +++ b/datafusion/substrait/src/lib.rs @@ -72,6 +72,7 @@ //! # Ok(()) //! # } //! ``` +pub mod extensions; pub mod logical_plan; pub mod physical_plan; pub mod serializer; diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index 1365630d5079..5768c44bbf6c 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -36,16 +36,21 @@ use datafusion::logical_expr::{ use substrait::proto::expression::subquery::set_predicate::PredicateOp; use url::Url; +use crate::extensions::Extensions; use crate::variation_const::{ DATE_32_TYPE_VARIATION_REF, DATE_64_TYPE_VARIATION_REF, DECIMAL_128_TYPE_VARIATION_REF, DECIMAL_256_TYPE_VARIATION_REF, DEFAULT_CONTAINER_TYPE_VARIATION_REF, DEFAULT_TYPE_VARIATION_REF, - INTERVAL_DAY_TIME_TYPE_REF, INTERVAL_MONTH_DAY_NANO_TYPE_REF, - INTERVAL_YEAR_MONTH_TYPE_REF, LARGE_CONTAINER_TYPE_VARIATION_REF, + INTERVAL_MONTH_DAY_NANO_TYPE_NAME, LARGE_CONTAINER_TYPE_VARIATION_REF, TIMESTAMP_MICRO_TYPE_VARIATION_REF, TIMESTAMP_MILLI_TYPE_VARIATION_REF, TIMESTAMP_NANO_TYPE_VARIATION_REF, TIMESTAMP_SECOND_TYPE_VARIATION_REF, UNSIGNED_INTEGER_TYPE_VARIATION_REF, }; +#[allow(deprecated)] +use crate::variation_const::{ + INTERVAL_DAY_TIME_TYPE_REF, INTERVAL_MONTH_DAY_NANO_TYPE_REF, + INTERVAL_YEAR_MONTH_TYPE_REF, +}; use datafusion::common::scalar::ScalarStructBuilder; use datafusion::logical_expr::expr::InList; use datafusion::logical_expr::{ @@ -65,7 +70,9 @@ use std::str::FromStr; use std::sync::Arc; use substrait::proto::exchange_rel::ExchangeKind; use substrait::proto::expression::literal::user_defined::Val; -use substrait::proto::expression::literal::{IntervalDayToSecond, IntervalYearToMonth}; +use substrait::proto::expression::literal::{ + IntervalDayToSecond, IntervalYearToMonth, UserDefined, +}; use substrait::proto::expression::subquery::SubqueryType; use substrait::proto::expression::{self, FieldReference, Literal, ScalarFunction}; use substrait::proto::read_rel::local_files::file_or_files::PathType::UriFile; @@ -78,7 +85,6 @@ use substrait::proto::{ window_function::bound::Kind as BoundKind, window_function::Bound, window_function::BoundsType, MaskExpression, RexType, }, - extensions::simple_extension_declaration::MappingType, function_argument::ArgType, join_rel, plan_rel, r#type, read_rel::ReadType, @@ -185,19 +191,10 @@ pub async fn from_substrait_plan( plan: &Plan, ) -> Result { // Register function extension - let function_extension = plan - .extensions - .iter() - .map(|e| match &e.mapping_type { - Some(ext) => match ext { - MappingType::ExtensionFunction(ext_f) => { - Ok((ext_f.function_anchor, &ext_f.name)) - } - _ => not_impl_err!("Extension type not supported: {ext:?}"), - }, - None => not_impl_err!("Cannot parse empty extension"), - }) - .collect::>>()?; + let extensions = Extensions::try_from(&plan.extensions)?; + if !extensions.type_variations.is_empty() { + return not_impl_err!("Type variation extensions are not supported"); + } // Parse relations match plan.relations.len() { @@ -205,10 +202,10 @@ pub async fn from_substrait_plan( match plan.relations[0].rel_type.as_ref() { Some(rt) => match rt { plan_rel::RelType::Rel(rel) => { - Ok(from_substrait_rel(ctx, rel, &function_extension).await?) + Ok(from_substrait_rel(ctx, rel, &extensions).await?) }, plan_rel::RelType::Root(root) => { - let plan = from_substrait_rel(ctx, root.input.as_ref().unwrap(), &function_extension).await?; + let plan = from_substrait_rel(ctx, root.input.as_ref().unwrap(), &extensions).await?; if root.names.is_empty() { // Backwards compatibility for plans missing names return Ok(plan); @@ -396,7 +393,7 @@ fn make_renamed_schema( pub async fn from_substrait_rel( ctx: &SessionContext, rel: &Rel, - extensions: &HashMap, + extensions: &Extensions, ) -> Result { match &rel.rel_type { Some(RelType::Project(p)) => { @@ -660,7 +657,7 @@ pub async fn from_substrait_rel( substrait_datafusion_err!("No base schema provided for Virtual Table") })?; - let schema = from_substrait_named_struct(base_schema)?; + let schema = from_substrait_named_struct(base_schema, extensions)?; if vt.values.is_empty() { return Ok(LogicalPlan::EmptyRelation(EmptyRelation { @@ -681,6 +678,7 @@ pub async fn from_substrait_rel( name_idx += 1; // top-level names are provided through schema Ok(Expr::Literal(from_substrait_literal( lit, + extensions, &base_schema.names, &mut name_idx, )?)) @@ -892,7 +890,7 @@ pub async fn from_substrait_sorts( ctx: &SessionContext, substrait_sorts: &Vec, input_schema: &DFSchema, - extensions: &HashMap, + extensions: &Extensions, ) -> Result> { let mut sorts: Vec = vec![]; for s in substrait_sorts { @@ -942,7 +940,7 @@ pub async fn from_substrait_rex_vec( ctx: &SessionContext, exprs: &Vec, input_schema: &DFSchema, - extensions: &HashMap, + extensions: &Extensions, ) -> Result> { let mut expressions: Vec = vec![]; for expr in exprs { @@ -957,7 +955,7 @@ pub async fn from_substrait_func_args( ctx: &SessionContext, arguments: &Vec, input_schema: &DFSchema, - extensions: &HashMap, + extensions: &Extensions, ) -> Result> { let mut args: Vec = vec![]; for arg in arguments { @@ -977,7 +975,7 @@ pub async fn from_substrait_agg_func( ctx: &SessionContext, f: &AggregateFunction, input_schema: &DFSchema, - extensions: &HashMap, + extensions: &Extensions, filter: Option>, order_by: Option>, distinct: bool, @@ -985,14 +983,14 @@ pub async fn from_substrait_agg_func( let args = from_substrait_func_args(ctx, &f.arguments, input_schema, extensions).await?; - let Some(function_name) = extensions.get(&f.function_reference) else { + let Some(function_name) = extensions.functions.get(&f.function_reference) else { return plan_err!( "Aggregate function not registered: function anchor = {:?}", f.function_reference ); }; - let function_name = substrait_fun_name((**function_name).as_str()); + let function_name = substrait_fun_name(function_name); // try udaf first, then built-in aggr fn. if let Ok(fun) = ctx.udaf(function_name) { // deal with situation that count(*) got no arguments @@ -1025,7 +1023,7 @@ pub async fn from_substrait_rex( ctx: &SessionContext, e: &Expression, input_schema: &DFSchema, - extensions: &HashMap, + extensions: &Extensions, ) -> Result { match &e.rex_type { Some(RexType::SingularOrList(s)) => { @@ -1105,7 +1103,7 @@ pub async fn from_substrait_rex( })) } Some(RexType::ScalarFunction(f)) => { - let Some(fn_name) = extensions.get(&f.function_reference) else { + let Some(fn_name) = extensions.functions.get(&f.function_reference) else { return plan_err!( "Scalar function not found: function reference = {:?}", f.function_reference @@ -1155,7 +1153,7 @@ pub async fn from_substrait_rex( } } Some(RexType::Literal(lit)) => { - let scalar_value = from_substrait_literal_without_names(lit)?; + let scalar_value = from_substrait_literal_without_names(lit, extensions)?; Ok(Expr::Literal(scalar_value)) } Some(RexType::Cast(cast)) => match cast.as_ref().r#type.as_ref() { @@ -1169,12 +1167,13 @@ pub async fn from_substrait_rex( ) .await?, ), - from_substrait_type_without_names(output_type)?, + from_substrait_type_without_names(output_type, extensions)?, ))), None => substrait_err!("Cast expression without output type is not allowed"), }, Some(RexType::WindowFunction(window)) => { - let Some(fn_name) = extensions.get(&window.function_reference) else { + let Some(fn_name) = extensions.functions.get(&window.function_reference) + else { return plan_err!( "Window function not found: function reference = {:?}", window.function_reference @@ -1328,12 +1327,16 @@ pub async fn from_substrait_rex( } } -pub(crate) fn from_substrait_type_without_names(dt: &Type) -> Result { - from_substrait_type(dt, &[], &mut 0) +pub(crate) fn from_substrait_type_without_names( + dt: &Type, + extensions: &Extensions, +) -> Result { + from_substrait_type(dt, extensions, &[], &mut 0) } fn from_substrait_type( dt: &Type, + extensions: &Extensions, dfs_names: &[String], name_idx: &mut usize, ) -> Result { @@ -1416,7 +1419,7 @@ fn from_substrait_type( substrait_datafusion_err!("List type must have inner type") })?; let field = Arc::new(Field::new_list_field( - from_substrait_type(inner_type, dfs_names, name_idx)?, + from_substrait_type(inner_type, extensions, dfs_names, name_idx)?, // We ignore Substrait's nullability here to match to_substrait_literal // which always creates nullable lists true, @@ -1438,12 +1441,12 @@ fn from_substrait_type( })?; let key_field = Arc::new(Field::new( "key", - from_substrait_type(key_type, dfs_names, name_idx)?, + from_substrait_type(key_type, extensions, dfs_names, name_idx)?, false, )); let value_field = Arc::new(Field::new( "value", - from_substrait_type(value_type, dfs_names, name_idx)?, + from_substrait_type(value_type, extensions, dfs_names, name_idx)?, true, )); match map.type_variation_reference { @@ -1490,28 +1493,41 @@ fn from_substrait_type( ), }, r#type::Kind::UserDefined(u) => { - match u.type_reference { - // Kept for backwards compatibility, use IntervalYear instead - INTERVAL_YEAR_MONTH_TYPE_REF => { - Ok(DataType::Interval(IntervalUnit::YearMonth)) - } - // Kept for backwards compatibility, use IntervalDay instead - INTERVAL_DAY_TIME_TYPE_REF => { - Ok(DataType::Interval(IntervalUnit::DayTime)) + if let Some(name) = extensions.types.get(&u.type_reference) { + match name.as_ref() { + INTERVAL_MONTH_DAY_NANO_TYPE_NAME => Ok(DataType::Interval(IntervalUnit::MonthDayNano)), + _ => not_impl_err!( + "Unsupported Substrait user defined type with ref {} and variation {}", + u.type_reference, + u.type_variation_reference + ), } - // Not supported yet by Substrait - INTERVAL_MONTH_DAY_NANO_TYPE_REF => { - Ok(DataType::Interval(IntervalUnit::MonthDayNano)) - } - _ => not_impl_err!( + } else { + // Kept for backwards compatibility, new plans should include the extension instead + #[allow(deprecated)] + match u.type_reference { + // Kept for backwards compatibility, use IntervalYear instead + INTERVAL_YEAR_MONTH_TYPE_REF => { + Ok(DataType::Interval(IntervalUnit::YearMonth)) + } + // Kept for backwards compatibility, use IntervalDay instead + INTERVAL_DAY_TIME_TYPE_REF => { + Ok(DataType::Interval(IntervalUnit::DayTime)) + } + // Not supported yet by Substrait + INTERVAL_MONTH_DAY_NANO_TYPE_REF => { + Ok(DataType::Interval(IntervalUnit::MonthDayNano)) + } + _ => not_impl_err!( "Unsupported Substrait user defined type with ref {} and variation {}", u.type_reference, u.type_variation_reference ), + } } } r#type::Kind::Struct(s) => Ok(DataType::Struct(from_substrait_struct_type( - s, dfs_names, name_idx, + s, extensions, dfs_names, name_idx, )?)), r#type::Kind::Varchar(_) => Ok(DataType::Utf8), r#type::Kind::FixedChar(_) => Ok(DataType::Utf8), @@ -1523,6 +1539,7 @@ fn from_substrait_type( fn from_substrait_struct_type( s: &r#type::Struct, + extensions: &Extensions, dfs_names: &[String], name_idx: &mut usize, ) -> Result { @@ -1530,7 +1547,7 @@ fn from_substrait_struct_type( for (i, f) in s.types.iter().enumerate() { let field = Field::new( next_struct_field_name(i, dfs_names, name_idx)?, - from_substrait_type(f, dfs_names, name_idx)?, + from_substrait_type(f, extensions, dfs_names, name_idx)?, true, // We assume everything to be nullable since that's easier than ensuring it matches ); fields.push(field); @@ -1556,12 +1573,16 @@ fn next_struct_field_name( } } -fn from_substrait_named_struct(base_schema: &NamedStruct) -> Result { +fn from_substrait_named_struct( + base_schema: &NamedStruct, + extensions: &Extensions, +) -> Result { let mut name_idx = 0; let fields = from_substrait_struct_type( base_schema.r#struct.as_ref().ok_or_else(|| { substrait_datafusion_err!("Named struct must contain a struct") })?, + extensions, &base_schema.names, &mut name_idx, ); @@ -1621,12 +1642,16 @@ fn from_substrait_bound( } } -pub(crate) fn from_substrait_literal_without_names(lit: &Literal) -> Result { - from_substrait_literal(lit, &vec![], &mut 0) +pub(crate) fn from_substrait_literal_without_names( + lit: &Literal, + extensions: &Extensions, +) -> Result { + from_substrait_literal(lit, extensions, &vec![], &mut 0) } fn from_substrait_literal( lit: &Literal, + extensions: &Extensions, dfs_names: &Vec, name_idx: &mut usize, ) -> Result { @@ -1721,7 +1746,7 @@ fn from_substrait_literal( let elements = l .values .iter() - .map(|el| from_substrait_literal(el, dfs_names, name_idx)) + .map(|el| from_substrait_literal(el, extensions, dfs_names, name_idx)) .collect::>>()?; if elements.is_empty() { return substrait_err!( @@ -1744,6 +1769,7 @@ fn from_substrait_literal( Some(LiteralType::EmptyList(l)) => { let element_type = from_substrait_type( l.r#type.clone().unwrap().as_ref(), + extensions, dfs_names, name_idx, )?; @@ -1763,7 +1789,7 @@ fn from_substrait_literal( let mut builder = ScalarStructBuilder::new(); for (i, field) in s.fields.iter().enumerate() { let name = next_struct_field_name(i, dfs_names, name_idx)?; - let sv = from_substrait_literal(field, dfs_names, name_idx)?; + let sv = from_substrait_literal(field, extensions, dfs_names, name_idx)?; // We assume everything to be nullable, since Arrow's strict about things matching // and it's hard to match otherwise. builder = builder.with_scalar(Field::new(name, sv.data_type(), true), sv); @@ -1771,7 +1797,7 @@ fn from_substrait_literal( builder.build()? } Some(LiteralType::Null(ntype)) => { - from_substrait_null(ntype, dfs_names, name_idx)? + from_substrait_null(ntype, extensions, dfs_names, name_idx)? } Some(LiteralType::IntervalDayToSecond(IntervalDayToSecond { days, @@ -1786,40 +1812,9 @@ fn from_substrait_literal( } Some(LiteralType::FixedChar(c)) => ScalarValue::Utf8(Some(c.clone())), Some(LiteralType::UserDefined(user_defined)) => { - match user_defined.type_reference { - // Kept for backwards compatibility, use IntervalYearToMonth instead - INTERVAL_YEAR_MONTH_TYPE_REF => { - let Some(Val::Value(raw_val)) = user_defined.val.as_ref() else { - return substrait_err!("Interval year month value is empty"); - }; - let value_slice: [u8; 4] = - (*raw_val.value).try_into().map_err(|_| { - substrait_datafusion_err!( - "Failed to parse interval year month value" - ) - })?; - ScalarValue::IntervalYearMonth(Some(i32::from_le_bytes(value_slice))) - } - // Kept for backwards compatibility, use IntervalDayToSecond instead - INTERVAL_DAY_TIME_TYPE_REF => { - let Some(Val::Value(raw_val)) = user_defined.val.as_ref() else { - return substrait_err!("Interval day time value is empty"); - }; - let value_slice: [u8; 8] = - (*raw_val.value).try_into().map_err(|_| { - substrait_datafusion_err!( - "Failed to parse interval day time value" - ) - })?; - let days = i32::from_le_bytes(value_slice[0..4].try_into().unwrap()); - let milliseconds = - i32::from_le_bytes(value_slice[4..8].try_into().unwrap()); - ScalarValue::IntervalDayTime(Some(IntervalDayTime { - days, - milliseconds, - })) - } - INTERVAL_MONTH_DAY_NANO_TYPE_REF => { + // Helper function to prevent duplicating this code - can be inlined once the non-extension path is removed + let interval_month_day_nano = + |user_defined: &UserDefined| -> Result { let Some(Val::Value(raw_val)) = user_defined.val.as_ref() else { return substrait_err!("Interval month day nano value is empty"); }; @@ -1834,17 +1829,76 @@ fn from_substrait_literal( let days = i32::from_le_bytes(value_slice[4..8].try_into().unwrap()); let nanoseconds = i64::from_le_bytes(value_slice[8..16].try_into().unwrap()); - ScalarValue::IntervalMonthDayNano(Some(IntervalMonthDayNano { - months, - days, - nanoseconds, - })) - } - _ => { - return not_impl_err!( - "Unsupported Substrait user defined type with ref {}", - user_defined.type_reference + Ok(ScalarValue::IntervalMonthDayNano(Some( + IntervalMonthDayNano { + months, + days, + nanoseconds, + }, + ))) + }; + + if let Some(name) = extensions.types.get(&user_defined.type_reference) { + match name.as_ref() { + INTERVAL_MONTH_DAY_NANO_TYPE_NAME => { + interval_month_day_nano(user_defined)? + } + _ => { + return not_impl_err!( + "Unsupported Substrait user defined type with ref {} and name {}", + user_defined.type_reference, + name ) + } + } + } else { + // Kept for backwards compatibility - new plans should include extension instead + #[allow(deprecated)] + match user_defined.type_reference { + // Kept for backwards compatibility, use IntervalYearToMonth instead + INTERVAL_YEAR_MONTH_TYPE_REF => { + let Some(Val::Value(raw_val)) = user_defined.val.as_ref() else { + return substrait_err!("Interval year month value is empty"); + }; + let value_slice: [u8; 4] = + (*raw_val.value).try_into().map_err(|_| { + substrait_datafusion_err!( + "Failed to parse interval year month value" + ) + })?; + ScalarValue::IntervalYearMonth(Some(i32::from_le_bytes( + value_slice, + ))) + } + // Kept for backwards compatibility, use IntervalDayToSecond instead + INTERVAL_DAY_TIME_TYPE_REF => { + let Some(Val::Value(raw_val)) = user_defined.val.as_ref() else { + return substrait_err!("Interval day time value is empty"); + }; + let value_slice: [u8; 8] = + (*raw_val.value).try_into().map_err(|_| { + substrait_datafusion_err!( + "Failed to parse interval day time value" + ) + })?; + let days = + i32::from_le_bytes(value_slice[0..4].try_into().unwrap()); + let milliseconds = + i32::from_le_bytes(value_slice[4..8].try_into().unwrap()); + ScalarValue::IntervalDayTime(Some(IntervalDayTime { + days, + milliseconds, + })) + } + INTERVAL_MONTH_DAY_NANO_TYPE_REF => { + interval_month_day_nano(user_defined)? + } + _ => { + return not_impl_err!( + "Unsupported Substrait user defined type literal with ref {}", + user_defined.type_reference + ) + } } } } @@ -1856,6 +1910,7 @@ fn from_substrait_literal( fn from_substrait_null( null_type: &Type, + extensions: &Extensions, dfs_names: &[String], name_idx: &mut usize, ) -> Result { @@ -1940,6 +1995,7 @@ fn from_substrait_null( let field = Field::new_list_field( from_substrait_type( l.r#type.clone().unwrap().as_ref(), + extensions, dfs_names, name_idx, )?, @@ -1958,7 +2014,8 @@ fn from_substrait_null( } } r#type::Kind::Struct(s) => { - let fields = from_substrait_struct_type(s, dfs_names, name_idx)?; + let fields = + from_substrait_struct_type(s, extensions, dfs_names, name_idx)?; Ok(ScalarStructBuilder::new_null(fields)) } _ => not_impl_err!("Unsupported Substrait type for null: {kind:?}"), @@ -2012,7 +2069,7 @@ impl BuiltinExprBuilder { ctx: &SessionContext, f: &ScalarFunction, input_schema: &DFSchema, - extensions: &HashMap, + extensions: &Extensions, ) -> Result { match self.expr_name.as_str() { "like" => { @@ -2037,7 +2094,7 @@ impl BuiltinExprBuilder { fn_name: &str, f: &ScalarFunction, input_schema: &DFSchema, - extensions: &HashMap, + extensions: &Extensions, ) -> Result { if f.arguments.len() != 1 { return substrait_err!("Expect one argument for {fn_name} expr"); @@ -2071,7 +2128,7 @@ impl BuiltinExprBuilder { case_insensitive: bool, f: &ScalarFunction, input_schema: &DFSchema, - extensions: &HashMap, + extensions: &Extensions, ) -> Result { let fn_name = if case_insensitive { "ILIKE" } else { "LIKE" }; if f.arguments.len() != 2 && f.arguments.len() != 3 { diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 0fd59d528086..8f69cc5e218f 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -16,7 +16,6 @@ // under the License. use itertools::Itertools; -use std::collections::HashMap; use std::ops::Deref; use std::sync::Arc; @@ -33,6 +32,16 @@ use datafusion::{ scalar::ScalarValue, }; +use crate::extensions::Extensions; +use crate::variation_const::{ + DATE_32_TYPE_VARIATION_REF, DATE_64_TYPE_VARIATION_REF, + DECIMAL_128_TYPE_VARIATION_REF, DECIMAL_256_TYPE_VARIATION_REF, + DEFAULT_CONTAINER_TYPE_VARIATION_REF, DEFAULT_TYPE_VARIATION_REF, + INTERVAL_MONTH_DAY_NANO_TYPE_NAME, LARGE_CONTAINER_TYPE_VARIATION_REF, + TIMESTAMP_MICRO_TYPE_VARIATION_REF, TIMESTAMP_MILLI_TYPE_VARIATION_REF, + TIMESTAMP_NANO_TYPE_VARIATION_REF, TIMESTAMP_SECOND_TYPE_VARIATION_REF, + UNSIGNED_INTEGER_TYPE_VARIATION_REF, +}; use datafusion::arrow::array::{Array, GenericListArray, OffsetSizeTrait}; use datafusion::common::{ exec_err, internal_err, not_impl_err, plan_err, substrait_datafusion_err, @@ -72,10 +81,6 @@ use substrait::{ ScalarFunction, SingularOrList, Subquery, WindowFunction as SubstraitWindowFunction, }, - extensions::{ - self, - simple_extension_declaration::{ExtensionFunction, MappingType}, - }, function_argument::ArgType, join_rel, plan_rel, r#type, read_rel::{NamedTable, ReadType}, @@ -90,39 +95,24 @@ use substrait::{ version, }; -use crate::variation_const::{ - DATE_32_TYPE_VARIATION_REF, DATE_64_TYPE_VARIATION_REF, - DECIMAL_128_TYPE_VARIATION_REF, DECIMAL_256_TYPE_VARIATION_REF, - DEFAULT_CONTAINER_TYPE_VARIATION_REF, DEFAULT_TYPE_VARIATION_REF, - INTERVAL_MONTH_DAY_NANO_TYPE_REF, INTERVAL_MONTH_DAY_NANO_TYPE_URL, - LARGE_CONTAINER_TYPE_VARIATION_REF, TIMESTAMP_MICRO_TYPE_VARIATION_REF, - TIMESTAMP_MILLI_TYPE_VARIATION_REF, TIMESTAMP_NANO_TYPE_VARIATION_REF, - TIMESTAMP_SECOND_TYPE_VARIATION_REF, UNSIGNED_INTEGER_TYPE_VARIATION_REF, -}; - /// Convert DataFusion LogicalPlan to Substrait Plan pub fn to_substrait_plan(plan: &LogicalPlan, ctx: &SessionContext) -> Result> { + let mut extensions = Extensions::default(); // Parse relation nodes - let mut extension_info: ( - Vec, - HashMap, - ) = (vec![], HashMap::new()); // Generate PlanRel(s) // Note: Only 1 relation tree is currently supported let plan_rels = vec![PlanRel { rel_type: Some(plan_rel::RelType::Root(RelRoot { - input: Some(*to_substrait_rel(plan, ctx, &mut extension_info)?), - names: to_substrait_named_struct(plan.schema())?.names, + input: Some(*to_substrait_rel(plan, ctx, &mut extensions)?), + names: to_substrait_named_struct(plan.schema(), &mut extensions)?.names, })), }]; - let (function_extensions, _) = extension_info; - // Return parsed plan Ok(Box::new(Plan { version: Some(version::version_with_producer("datafusion")), extension_uris: vec![], - extensions: function_extensions, + extensions: extensions.into(), relations: plan_rels, advanced_extensions: None, expected_type_urls: vec![], @@ -133,10 +123,7 @@ pub fn to_substrait_plan(plan: &LogicalPlan, ctx: &SessionContext) -> Result, - HashMap, - ), + extensions: &mut Extensions, ) -> Result> { match plan { LogicalPlan::TableScan(scan) => { @@ -187,7 +174,7 @@ pub fn to_substrait_rel( Ok(Box::new(Rel { rel_type: Some(RelType::Read(Box::new(ReadRel { common: None, - base_schema: Some(to_substrait_named_struct(&e.schema)?), + base_schema: Some(to_substrait_named_struct(&e.schema, extensions)?), filter: None, best_effort_filter: None, projection: None, @@ -206,10 +193,10 @@ pub fn to_substrait_rel( let fields = row .iter() .map(|v| match v { - Expr::Literal(sv) => to_substrait_literal(sv), + Expr::Literal(sv) => to_substrait_literal(sv, extensions), Expr::Alias(alias) => match alias.expr.as_ref() { // The schema gives us the names, so we can skip aliases - Expr::Literal(sv) => to_substrait_literal(sv), + Expr::Literal(sv) => to_substrait_literal(sv, extensions), _ => Err(substrait_datafusion_err!( "Only literal types can be aliased in Virtual Tables, got: {}", alias.expr.variant_name() )), @@ -225,7 +212,7 @@ pub fn to_substrait_rel( Ok(Box::new(Rel { rel_type: Some(RelType::Read(Box::new(ReadRel { common: None, - base_schema: Some(to_substrait_named_struct(&v.schema)?), + base_schema: Some(to_substrait_named_struct(&v.schema, extensions)?), filter: None, best_effort_filter: None, projection: None, @@ -238,25 +225,25 @@ pub fn to_substrait_rel( let expressions = p .expr .iter() - .map(|e| to_substrait_rex(ctx, e, p.input.schema(), 0, extension_info)) + .map(|e| to_substrait_rex(ctx, e, p.input.schema(), 0, extensions)) .collect::>>()?; Ok(Box::new(Rel { rel_type: Some(RelType::Project(Box::new(ProjectRel { common: None, - input: Some(to_substrait_rel(p.input.as_ref(), ctx, extension_info)?), + input: Some(to_substrait_rel(p.input.as_ref(), ctx, extensions)?), expressions, advanced_extension: None, }))), })) } LogicalPlan::Filter(filter) => { - let input = to_substrait_rel(filter.input.as_ref(), ctx, extension_info)?; + let input = to_substrait_rel(filter.input.as_ref(), ctx, extensions)?; let filter_expr = to_substrait_rex( ctx, &filter.predicate, filter.input.schema(), 0, - extension_info, + extensions, )?; Ok(Box::new(Rel { rel_type: Some(RelType::Filter(Box::new(FilterRel { @@ -268,7 +255,7 @@ pub fn to_substrait_rel( })) } LogicalPlan::Limit(limit) => { - let input = to_substrait_rel(limit.input.as_ref(), ctx, extension_info)?; + let input = to_substrait_rel(limit.input.as_ref(), ctx, extensions)?; // Since protobuf can't directly distinguish `None` vs `0` encode `None` as `MAX` let limit_fetch = limit.fetch.unwrap_or(usize::MAX); Ok(Box::new(Rel { @@ -282,13 +269,11 @@ pub fn to_substrait_rel( })) } LogicalPlan::Sort(sort) => { - let input = to_substrait_rel(sort.input.as_ref(), ctx, extension_info)?; + let input = to_substrait_rel(sort.input.as_ref(), ctx, extensions)?; let sort_fields = sort .expr .iter() - .map(|e| { - substrait_sort_field(ctx, e, sort.input.schema(), extension_info) - }) + .map(|e| substrait_sort_field(ctx, e, sort.input.schema(), extensions)) .collect::>>()?; Ok(Box::new(Rel { rel_type: Some(RelType::Sort(Box::new(SortRel { @@ -300,19 +285,17 @@ pub fn to_substrait_rel( })) } LogicalPlan::Aggregate(agg) => { - let input = to_substrait_rel(agg.input.as_ref(), ctx, extension_info)?; + let input = to_substrait_rel(agg.input.as_ref(), ctx, extensions)?; let groupings = to_substrait_groupings( ctx, &agg.group_expr, agg.input.schema(), - extension_info, + extensions, )?; let measures = agg .aggr_expr .iter() - .map(|e| { - to_substrait_agg_measure(ctx, e, agg.input.schema(), extension_info) - }) + .map(|e| to_substrait_agg_measure(ctx, e, agg.input.schema(), extensions)) .collect::>>()?; Ok(Box::new(Rel { @@ -327,7 +310,7 @@ pub fn to_substrait_rel( } LogicalPlan::Distinct(Distinct::All(plan)) => { // Use Substrait's AggregateRel with empty measures to represent `select distinct` - let input = to_substrait_rel(plan.as_ref(), ctx, extension_info)?; + let input = to_substrait_rel(plan.as_ref(), ctx, extensions)?; // Get grouping keys from the input relation's number of output fields let grouping = (0..plan.schema().fields().len()) .map(substrait_field_ref) @@ -346,8 +329,8 @@ pub fn to_substrait_rel( })) } LogicalPlan::Join(join) => { - let left = to_substrait_rel(join.left.as_ref(), ctx, extension_info)?; - let right = to_substrait_rel(join.right.as_ref(), ctx, extension_info)?; + let left = to_substrait_rel(join.left.as_ref(), ctx, extensions)?; + let right = to_substrait_rel(join.right.as_ref(), ctx, extensions)?; let join_type = to_substrait_jointype(join.join_type); // we only support basic joins so return an error for anything not yet supported match join.join_constraint { @@ -364,7 +347,7 @@ pub fn to_substrait_rel( filter, &Arc::new(in_join_schema), 0, - extension_info, + extensions, )?), None => None, }; @@ -382,7 +365,7 @@ pub fn to_substrait_rel( eq_op, join.left.schema(), join.right.schema(), - extension_info, + extensions, )?; // create conjunction between `join_on` and `join_filter` to embed all join conditions, @@ -393,7 +376,7 @@ pub fn to_substrait_rel( on_expr, filter, Operator::And, - extension_info, + extensions, ))), None => join_on.map(Box::new), // the join expression will only contain `join_on` if filter doesn't exist }, @@ -421,8 +404,8 @@ pub fn to_substrait_rel( right, schema: _, } = cross_join; - let left = to_substrait_rel(left.as_ref(), ctx, extension_info)?; - let right = to_substrait_rel(right.as_ref(), ctx, extension_info)?; + let left = to_substrait_rel(left.as_ref(), ctx, extensions)?; + let right = to_substrait_rel(right.as_ref(), ctx, extensions)?; Ok(Box::new(Rel { rel_type: Some(RelType::Cross(Box::new(CrossRel { common: None, @@ -435,13 +418,13 @@ pub fn to_substrait_rel( LogicalPlan::SubqueryAlias(alias) => { // Do nothing if encounters SubqueryAlias // since there is no corresponding relation type in Substrait - to_substrait_rel(alias.input.as_ref(), ctx, extension_info) + to_substrait_rel(alias.input.as_ref(), ctx, extensions) } LogicalPlan::Union(union) => { let input_rels = union .inputs .iter() - .map(|input| to_substrait_rel(input.as_ref(), ctx, extension_info)) + .map(|input| to_substrait_rel(input.as_ref(), ctx, extensions)) .collect::>>()? .into_iter() .map(|ptr| *ptr) @@ -456,7 +439,7 @@ pub fn to_substrait_rel( })) } LogicalPlan::Window(window) => { - let input = to_substrait_rel(window.input.as_ref(), ctx, extension_info)?; + let input = to_substrait_rel(window.input.as_ref(), ctx, extensions)?; // If the input is a Project relation, we can just append the WindowFunction expressions // before returning // Otherwise, wrap the input in a Project relation before appending the WindowFunction @@ -484,7 +467,7 @@ pub fn to_substrait_rel( expr, window.input.schema(), 0, - extension_info, + extensions, )?); } // Append parsed WindowFunction expressions @@ -494,8 +477,7 @@ pub fn to_substrait_rel( })) } LogicalPlan::Repartition(repartition) => { - let input = - to_substrait_rel(repartition.input.as_ref(), ctx, extension_info)?; + let input = to_substrait_rel(repartition.input.as_ref(), ctx, extensions)?; let partition_count = match repartition.partitioning_scheme { Partitioning::RoundRobinBatch(num) => num, Partitioning::Hash(_, num) => num, @@ -553,7 +535,7 @@ pub fn to_substrait_rel( .node .inputs() .into_iter() - .map(|plan| to_substrait_rel(plan, ctx, extension_info)) + .map(|plan| to_substrait_rel(plan, ctx, extensions)) .collect::>>()?; let rel_type = match inputs_rel.len() { 0 => RelType::ExtensionLeaf(ExtensionLeafRel { @@ -579,7 +561,10 @@ pub fn to_substrait_rel( } } -fn to_substrait_named_struct(schema: &DFSchemaRef) -> Result { +fn to_substrait_named_struct( + schema: &DFSchemaRef, + extensions: &mut Extensions, +) -> Result { // Substrait wants a list of all field names, including nested fields from structs, // also from within e.g. lists and maps. However, it does not want the list and map field names // themselves - only proper structs fields are considered to have useful names. @@ -624,7 +609,7 @@ fn to_substrait_named_struct(schema: &DFSchemaRef) -> Result { types: schema .fields() .iter() - .map(|f| to_substrait_type(f.data_type(), f.is_nullable())) + .map(|f| to_substrait_type(f.data_type(), f.is_nullable(), extensions)) .collect::>()?, type_variation_reference: DEFAULT_TYPE_VARIATION_REF, nullability: r#type::Nullability::Unspecified as i32, @@ -642,30 +627,27 @@ fn to_substrait_join_expr( eq_op: Operator, left_schema: &DFSchemaRef, right_schema: &DFSchemaRef, - extension_info: &mut ( - Vec, - HashMap, - ), + extensions: &mut Extensions, ) -> Result> { // Only support AND conjunction for each binary expression in join conditions let mut exprs: Vec = vec![]; for (left, right) in join_conditions { // Parse left - let l = to_substrait_rex(ctx, left, left_schema, 0, extension_info)?; + let l = to_substrait_rex(ctx, left, left_schema, 0, extensions)?; // Parse right let r = to_substrait_rex( ctx, right, right_schema, left_schema.fields().len(), // offset to return the correct index - extension_info, + extensions, )?; // AND with existing expression - exprs.push(make_binary_op_scalar_func(&l, &r, eq_op, extension_info)); + exprs.push(make_binary_op_scalar_func(&l, &r, eq_op, extensions)); } let join_expr: Option = exprs.into_iter().reduce(|acc: Expression, e: Expression| { - make_binary_op_scalar_func(&acc, &e, Operator::And, extension_info) + make_binary_op_scalar_func(&acc, &e, Operator::And, extensions) }); Ok(join_expr) } @@ -722,14 +704,11 @@ pub fn parse_flat_grouping_exprs( ctx: &SessionContext, exprs: &[Expr], schema: &DFSchemaRef, - extension_info: &mut ( - Vec, - HashMap, - ), + extensions: &mut Extensions, ) -> Result { let grouping_expressions = exprs .iter() - .map(|e| to_substrait_rex(ctx, e, schema, 0, extension_info)) + .map(|e| to_substrait_rex(ctx, e, schema, 0, extensions)) .collect::>>()?; Ok(Grouping { grouping_expressions, @@ -740,10 +719,7 @@ pub fn to_substrait_groupings( ctx: &SessionContext, exprs: &[Expr], schema: &DFSchemaRef, - extension_info: &mut ( - Vec, - HashMap, - ), + extensions: &mut Extensions, ) -> Result> { match exprs.len() { 1 => match &exprs[0] { @@ -753,9 +729,7 @@ pub fn to_substrait_groupings( )), GroupingSet::GroupingSets(sets) => Ok(sets .iter() - .map(|set| { - parse_flat_grouping_exprs(ctx, set, schema, extension_info) - }) + .map(|set| parse_flat_grouping_exprs(ctx, set, schema, extensions)) .collect::>>()?), GroupingSet::Rollup(set) => { let mut sets: Vec> = vec![vec![]]; @@ -766,23 +740,17 @@ pub fn to_substrait_groupings( .iter() .rev() .map(|set| { - parse_flat_grouping_exprs(ctx, set, schema, extension_info) + parse_flat_grouping_exprs(ctx, set, schema, extensions) }) .collect::>>()?) } }, _ => Ok(vec![parse_flat_grouping_exprs( - ctx, - exprs, - schema, - extension_info, + ctx, exprs, schema, extensions, )?]), }, _ => Ok(vec![parse_flat_grouping_exprs( - ctx, - exprs, - schema, - extension_info, + ctx, exprs, schema, extensions, )?]), } } @@ -792,25 +760,22 @@ pub fn to_substrait_agg_measure( ctx: &SessionContext, expr: &Expr, schema: &DFSchemaRef, - extension_info: &mut ( - Vec, - HashMap, - ), + extensions: &mut Extensions, ) -> Result { match expr { Expr::AggregateFunction(expr::AggregateFunction { func_def, args, distinct, filter, order_by, null_treatment: _, }) => { match func_def { AggregateFunctionDefinition::BuiltIn (fun) => { let sorts = if let Some(order_by) = order_by { - order_by.iter().map(|expr| to_substrait_sort_field(ctx, expr, schema, extension_info)).collect::>>()? + order_by.iter().map(|expr| to_substrait_sort_field(ctx, expr, schema, extensions)).collect::>>()? } else { vec![] }; let mut arguments: Vec = vec![]; for arg in args { - arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(ctx, arg, schema, 0, extension_info)?)) }); + arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(ctx, arg, schema, 0, extensions)?)) }); } - let function_anchor = register_function(fun.to_string(), extension_info); + let function_anchor = extensions.register_function(fun.to_string()); Ok(Measure { measure: Some(AggregateFunction { function_reference: function_anchor, @@ -826,22 +791,22 @@ pub fn to_substrait_agg_measure( options: vec![], }), filter: match filter { - Some(f) => Some(to_substrait_rex(ctx, f, schema, 0, extension_info)?), + Some(f) => Some(to_substrait_rex(ctx, f, schema, 0, extensions)?), None => None } }) } AggregateFunctionDefinition::UDF(fun) => { let sorts = if let Some(order_by) = order_by { - order_by.iter().map(|expr| to_substrait_sort_field(ctx, expr, schema, extension_info)).collect::>>()? + order_by.iter().map(|expr| to_substrait_sort_field(ctx, expr, schema, extensions)).collect::>>()? } else { vec![] }; let mut arguments: Vec = vec![]; for arg in args { - arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(ctx, arg, schema, 0, extension_info)?)) }); + arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(ctx, arg, schema, 0, extensions)?)) }); } - let function_anchor = register_function(fun.name().to_string(), extension_info); + let function_anchor = extensions.register_function(fun.name().to_string()); Ok(Measure { measure: Some(AggregateFunction { function_reference: function_anchor, @@ -857,7 +822,7 @@ pub fn to_substrait_agg_measure( options: vec![], }), filter: match filter { - Some(f) => Some(to_substrait_rex(ctx, f, schema, 0, extension_info)?), + Some(f) => Some(to_substrait_rex(ctx, f, schema, 0, extensions)?), None => None } }) @@ -866,7 +831,7 @@ pub fn to_substrait_agg_measure( } Expr::Alias(Alias{expr,..})=> { - to_substrait_agg_measure(ctx, expr, schema, extension_info) + to_substrait_agg_measure(ctx, expr, schema, extensions) } _ => internal_err!( "Expression must be compatible with aggregation. Unsupported expression: {:?}. ExpressionType: {:?}", @@ -881,10 +846,7 @@ fn to_substrait_sort_field( ctx: &SessionContext, expr: &Expr, schema: &DFSchemaRef, - extension_info: &mut ( - Vec, - HashMap, - ), + extensions: &mut Extensions, ) -> Result { match expr { Expr::Sort(sort) => { @@ -900,7 +862,7 @@ fn to_substrait_sort_field( sort.expr.deref(), schema, 0, - extension_info, + extensions, )?), sort_kind: Some(SortKind::Direction(sort_kind.into())), }) @@ -909,67 +871,15 @@ fn to_substrait_sort_field( } } -fn register_function( - function_name: String, - extension_info: &mut ( - Vec, - HashMap, - ), -) -> u32 { - let (function_extensions, function_set) = extension_info; - let function_name = function_name.to_lowercase(); - - // Some functions are named differently in Substrait default extensions than in DF - // Rename those to match the Substrait extensions for interoperability - let function_name = match function_name.as_str() { - "substr" => "substring".to_string(), - _ => function_name, - }; - - // To prevent ambiguous references between ScalarFunctions and AggregateFunctions, - // a plan-relative identifier starting from 0 is used as the function_anchor. - // The consumer is responsible for correctly registering - // mapping info stored in the extensions by the producer. - let function_anchor = match function_set.get(&function_name) { - Some(function_anchor) => { - // Function has been registered - *function_anchor - } - None => { - // Function has NOT been registered - let function_anchor = function_set.len() as u32; - function_set.insert(function_name.clone(), function_anchor); - - let function_extension = ExtensionFunction { - extension_uri_reference: u32::MAX, - function_anchor, - name: function_name, - }; - let simple_extension = extensions::SimpleExtensionDeclaration { - mapping_type: Some(MappingType::ExtensionFunction(function_extension)), - }; - function_extensions.push(simple_extension); - function_anchor - } - }; - - // Return function anchor - function_anchor -} - /// Return Substrait scalar function with two arguments #[allow(deprecated)] pub fn make_binary_op_scalar_func( lhs: &Expression, rhs: &Expression, op: Operator, - extension_info: &mut ( - Vec, - HashMap, - ), + extensions: &mut Extensions, ) -> Expression { - let function_anchor = - register_function(operator_to_name(op).to_string(), extension_info); + let function_anchor = extensions.register_function(operator_to_name(op).to_string()); Expression { rex_type: Some(RexType::ScalarFunction(ScalarFunction { function_reference: function_anchor, @@ -1010,17 +920,14 @@ pub fn make_binary_op_scalar_func( /// `col_ref(1) = col_ref(3 + 0)` /// , where `3` is the number of `left` columns (`col_ref_offset`) and `0` is the index /// of the join key column from `right` -/// * `extension_info` - Substrait extension info. Contains registered function information +/// * `extensions` - Substrait extension info. Contains registered function information #[allow(deprecated)] pub fn to_substrait_rex( ctx: &SessionContext, expr: &Expr, schema: &DFSchemaRef, col_ref_offset: usize, - extension_info: &mut ( - Vec, - HashMap, - ), + extensions: &mut Extensions, ) -> Result { match expr { Expr::InList(InList { @@ -1030,10 +937,10 @@ pub fn to_substrait_rex( }) => { let substrait_list = list .iter() - .map(|x| to_substrait_rex(ctx, x, schema, col_ref_offset, extension_info)) + .map(|x| to_substrait_rex(ctx, x, schema, col_ref_offset, extensions)) .collect::>>()?; let substrait_expr = - to_substrait_rex(ctx, expr, schema, col_ref_offset, extension_info)?; + to_substrait_rex(ctx, expr, schema, col_ref_offset, extensions)?; let substrait_or_list = Expression { rex_type: Some(RexType::SingularOrList(Box::new(SingularOrList { @@ -1043,8 +950,7 @@ pub fn to_substrait_rex( }; if *negated { - let function_anchor = - register_function("not".to_string(), extension_info); + let function_anchor = extensions.register_function("not".to_string()); Ok(Expression { rex_type: Some(RexType::ScalarFunction(ScalarFunction { @@ -1070,13 +976,12 @@ pub fn to_substrait_rex( arg, schema, col_ref_offset, - extension_info, + extensions, )?)), }); } - let function_anchor = - register_function(fun.name().to_string(), extension_info); + let function_anchor = extensions.register_function(fun.name().to_string()); Ok(Expression { rex_type: Some(RexType::ScalarFunction(ScalarFunction { function_reference: function_anchor, @@ -1096,58 +1001,58 @@ pub fn to_substrait_rex( if *negated { // `expr NOT BETWEEN low AND high` can be translated into (expr < low OR high < expr) let substrait_expr = - to_substrait_rex(ctx, expr, schema, col_ref_offset, extension_info)?; + to_substrait_rex(ctx, expr, schema, col_ref_offset, extensions)?; let substrait_low = - to_substrait_rex(ctx, low, schema, col_ref_offset, extension_info)?; + to_substrait_rex(ctx, low, schema, col_ref_offset, extensions)?; let substrait_high = - to_substrait_rex(ctx, high, schema, col_ref_offset, extension_info)?; + to_substrait_rex(ctx, high, schema, col_ref_offset, extensions)?; let l_expr = make_binary_op_scalar_func( &substrait_expr, &substrait_low, Operator::Lt, - extension_info, + extensions, ); let r_expr = make_binary_op_scalar_func( &substrait_high, &substrait_expr, Operator::Lt, - extension_info, + extensions, ); Ok(make_binary_op_scalar_func( &l_expr, &r_expr, Operator::Or, - extension_info, + extensions, )) } else { // `expr BETWEEN low AND high` can be translated into (low <= expr AND expr <= high) let substrait_expr = - to_substrait_rex(ctx, expr, schema, col_ref_offset, extension_info)?; + to_substrait_rex(ctx, expr, schema, col_ref_offset, extensions)?; let substrait_low = - to_substrait_rex(ctx, low, schema, col_ref_offset, extension_info)?; + to_substrait_rex(ctx, low, schema, col_ref_offset, extensions)?; let substrait_high = - to_substrait_rex(ctx, high, schema, col_ref_offset, extension_info)?; + to_substrait_rex(ctx, high, schema, col_ref_offset, extensions)?; let l_expr = make_binary_op_scalar_func( &substrait_low, &substrait_expr, Operator::LtEq, - extension_info, + extensions, ); let r_expr = make_binary_op_scalar_func( &substrait_expr, &substrait_high, Operator::LtEq, - extension_info, + extensions, ); Ok(make_binary_op_scalar_func( &l_expr, &r_expr, Operator::And, - extension_info, + extensions, )) } } @@ -1156,10 +1061,10 @@ pub fn to_substrait_rex( substrait_field_ref(index + col_ref_offset) } Expr::BinaryExpr(BinaryExpr { left, op, right }) => { - let l = to_substrait_rex(ctx, left, schema, col_ref_offset, extension_info)?; - let r = to_substrait_rex(ctx, right, schema, col_ref_offset, extension_info)?; + let l = to_substrait_rex(ctx, left, schema, col_ref_offset, extensions)?; + let r = to_substrait_rex(ctx, right, schema, col_ref_offset, extensions)?; - Ok(make_binary_op_scalar_func(&l, &r, *op, extension_info)) + Ok(make_binary_op_scalar_func(&l, &r, *op, extensions)) } Expr::Case(Case { expr, @@ -1176,7 +1081,7 @@ pub fn to_substrait_rex( e, schema, col_ref_offset, - extension_info, + extensions, )?), then: None, }); @@ -1189,14 +1094,14 @@ pub fn to_substrait_rex( r#if, schema, col_ref_offset, - extension_info, + extensions, )?), then: Some(to_substrait_rex( ctx, then, schema, col_ref_offset, - extension_info, + extensions, )?), }); } @@ -1208,7 +1113,7 @@ pub fn to_substrait_rex( e, schema, col_ref_offset, - extension_info, + extensions, )?)), None => None, }; @@ -1221,22 +1126,22 @@ pub fn to_substrait_rex( Ok(Expression { rex_type: Some(RexType::Cast(Box::new( substrait::proto::expression::Cast { - r#type: Some(to_substrait_type(data_type, true)?), + r#type: Some(to_substrait_type(data_type, true, extensions)?), input: Some(Box::new(to_substrait_rex( ctx, expr, schema, col_ref_offset, - extension_info, + extensions, )?)), failure_behavior: 0, // FAILURE_BEHAVIOR_UNSPECIFIED }, ))), }) } - Expr::Literal(value) => to_substrait_literal_expr(value), + Expr::Literal(value) => to_substrait_literal_expr(value, extensions), Expr::Alias(Alias { expr, .. }) => { - to_substrait_rex(ctx, expr, schema, col_ref_offset, extension_info) + to_substrait_rex(ctx, expr, schema, col_ref_offset, extensions) } Expr::WindowFunction(WindowFunction { fun, @@ -1247,7 +1152,7 @@ pub fn to_substrait_rex( null_treatment: _, }) => { // function reference - let function_anchor = register_function(fun.to_string(), extension_info); + let function_anchor = extensions.register_function(fun.to_string()); // arguments let mut arguments: Vec = vec![]; for arg in args { @@ -1257,19 +1162,19 @@ pub fn to_substrait_rex( arg, schema, col_ref_offset, - extension_info, + extensions, )?)), }); } // partition by expressions let partition_by = partition_by .iter() - .map(|e| to_substrait_rex(ctx, e, schema, col_ref_offset, extension_info)) + .map(|e| to_substrait_rex(ctx, e, schema, col_ref_offset, extensions)) .collect::>>()?; // order by expressions let order_by = order_by .iter() - .map(|e| substrait_sort_field(ctx, e, schema, extension_info)) + .map(|e| substrait_sort_field(ctx, e, schema, extensions)) .collect::>>()?; // window frame let bounds = to_substrait_bounds(window_frame)?; @@ -1298,7 +1203,7 @@ pub fn to_substrait_rex( *escape_char, schema, col_ref_offset, - extension_info, + extensions, ), Expr::InSubquery(InSubquery { expr, @@ -1306,10 +1211,10 @@ pub fn to_substrait_rex( negated, }) => { let substrait_expr = - to_substrait_rex(ctx, expr, schema, col_ref_offset, extension_info)?; + to_substrait_rex(ctx, expr, schema, col_ref_offset, extensions)?; let subquery_plan = - to_substrait_rel(subquery.subquery.as_ref(), ctx, extension_info)?; + to_substrait_rel(subquery.subquery.as_ref(), ctx, extensions)?; let substrait_subquery = Expression { rex_type: Some(RexType::Subquery(Box::new(Subquery { @@ -1324,8 +1229,7 @@ pub fn to_substrait_rex( }))), }; if *negated { - let function_anchor = - register_function("not".to_string(), extension_info); + let function_anchor = extensions.register_function("not".to_string()); Ok(Expression { rex_type: Some(RexType::ScalarFunction(ScalarFunction { @@ -1348,7 +1252,7 @@ pub fn to_substrait_rex( arg, schema, col_ref_offset, - extension_info, + extensions, ), Expr::IsNull(arg) => to_substrait_unary_scalar_fn( ctx, @@ -1356,7 +1260,7 @@ pub fn to_substrait_rex( arg, schema, col_ref_offset, - extension_info, + extensions, ), Expr::IsNotNull(arg) => to_substrait_unary_scalar_fn( ctx, @@ -1364,7 +1268,7 @@ pub fn to_substrait_rex( arg, schema, col_ref_offset, - extension_info, + extensions, ), Expr::IsTrue(arg) => to_substrait_unary_scalar_fn( ctx, @@ -1372,7 +1276,7 @@ pub fn to_substrait_rex( arg, schema, col_ref_offset, - extension_info, + extensions, ), Expr::IsFalse(arg) => to_substrait_unary_scalar_fn( ctx, @@ -1380,7 +1284,7 @@ pub fn to_substrait_rex( arg, schema, col_ref_offset, - extension_info, + extensions, ), Expr::IsUnknown(arg) => to_substrait_unary_scalar_fn( ctx, @@ -1388,7 +1292,7 @@ pub fn to_substrait_rex( arg, schema, col_ref_offset, - extension_info, + extensions, ), Expr::IsNotTrue(arg) => to_substrait_unary_scalar_fn( ctx, @@ -1396,7 +1300,7 @@ pub fn to_substrait_rex( arg, schema, col_ref_offset, - extension_info, + extensions, ), Expr::IsNotFalse(arg) => to_substrait_unary_scalar_fn( ctx, @@ -1404,7 +1308,7 @@ pub fn to_substrait_rex( arg, schema, col_ref_offset, - extension_info, + extensions, ), Expr::IsNotUnknown(arg) => to_substrait_unary_scalar_fn( ctx, @@ -1412,7 +1316,7 @@ pub fn to_substrait_rex( arg, schema, col_ref_offset, - extension_info, + extensions, ), Expr::Negative(arg) => to_substrait_unary_scalar_fn( ctx, @@ -1420,7 +1324,7 @@ pub fn to_substrait_rex( arg, schema, col_ref_offset, - extension_info, + extensions, ), _ => { not_impl_err!("Unsupported expression: {expr:?}") @@ -1428,7 +1332,11 @@ pub fn to_substrait_rex( } } -fn to_substrait_type(dt: &DataType, nullable: bool) -> Result { +fn to_substrait_type( + dt: &DataType, + nullable: bool, + extensions: &mut Extensions, +) -> Result { let nullability = if nullable { r#type::Nullability::Nullable as i32 } else { @@ -1548,7 +1456,9 @@ fn to_substrait_type(dt: &DataType, nullable: bool) -> Result Result { - let inner_type = to_substrait_type(inner.data_type(), inner.is_nullable())?; + let inner_type = + to_substrait_type(inner.data_type(), inner.is_nullable(), extensions)?; Ok(substrait::proto::Type { kind: Some(r#type::Kind::List(Box::new(r#type::List { r#type: Some(Box::new(inner_type)), @@ -1599,7 +1510,8 @@ fn to_substrait_type(dt: &DataType, nullable: bool) -> Result { - let inner_type = to_substrait_type(inner.data_type(), inner.is_nullable())?; + let inner_type = + to_substrait_type(inner.data_type(), inner.is_nullable(), extensions)?; Ok(substrait::proto::Type { kind: Some(r#type::Kind::List(Box::new(r#type::List { r#type: Some(Box::new(inner_type)), @@ -1613,10 +1525,12 @@ fn to_substrait_type(dt: &DataType, nullable: bool) -> Result Result { let field_types = fields .iter() - .map(|field| to_substrait_type(field.data_type(), field.is_nullable())) + .map(|field| { + to_substrait_type(field.data_type(), field.is_nullable(), extensions) + }) .collect::>>()?; Ok(substrait::proto::Type { kind: Some(r#type::Kind::Struct(r#type::Struct { @@ -1700,21 +1616,19 @@ fn make_substrait_like_expr( escape_char: Option, schema: &DFSchemaRef, col_ref_offset: usize, - extension_info: &mut ( - Vec, - HashMap, - ), + extensions: &mut Extensions, ) -> Result { let function_anchor = if ignore_case { - register_function("ilike".to_string(), extension_info) + extensions.register_function("ilike".to_string()) } else { - register_function("like".to_string(), extension_info) + extensions.register_function("like".to_string()) }; - let expr = to_substrait_rex(ctx, expr, schema, col_ref_offset, extension_info)?; - let pattern = to_substrait_rex(ctx, pattern, schema, col_ref_offset, extension_info)?; - let escape_char = to_substrait_literal_expr(&ScalarValue::Utf8( - escape_char.map(|c| c.to_string()), - ))?; + let expr = to_substrait_rex(ctx, expr, schema, col_ref_offset, extensions)?; + let pattern = to_substrait_rex(ctx, pattern, schema, col_ref_offset, extensions)?; + let escape_char = to_substrait_literal_expr( + &ScalarValue::Utf8(escape_char.map(|c| c.to_string())), + extensions, + )?; let arguments = vec![ FunctionArgument { arg_type: Some(ArgType::Value(expr)), @@ -1738,7 +1652,7 @@ fn make_substrait_like_expr( }; if negated { - let function_anchor = register_function("not".to_string(), extension_info); + let function_anchor = extensions.register_function("not".to_string()); Ok(Expression { rex_type: Some(RexType::ScalarFunction(ScalarFunction { @@ -1870,7 +1784,10 @@ fn to_substrait_bounds(window_frame: &WindowFrame) -> Result<(Bound, Bound)> { )) } -fn to_substrait_literal(value: &ScalarValue) -> Result { +fn to_substrait_literal( + value: &ScalarValue, + extensions: &mut Extensions, +) -> Result { if value.is_null() { return Ok(Literal { nullable: true, @@ -1878,6 +1795,7 @@ fn to_substrait_literal(value: &ScalarValue) -> Result { literal_type: Some(LiteralType::Null(to_substrait_type( &value.data_type(), true, + extensions, )?)), }); } @@ -1949,14 +1867,15 @@ fn to_substrait_literal(value: &ScalarValue) -> Result { let bytes = i.to_byte_slice(); ( LiteralType::UserDefined(UserDefined { - type_reference: INTERVAL_MONTH_DAY_NANO_TYPE_REF, + type_reference: extensions + .register_type(INTERVAL_MONTH_DAY_NANO_TYPE_NAME.to_string()), type_parameters: vec![], val: Some(user_defined::Val::Value(ProtoAny { - type_url: INTERVAL_MONTH_DAY_NANO_TYPE_URL.to_string(), + type_url: INTERVAL_MONTH_DAY_NANO_TYPE_NAME.to_string(), value: bytes.to_vec().into(), })), }), - INTERVAL_MONTH_DAY_NANO_TYPE_REF, + DEFAULT_TYPE_VARIATION_REF, ) } ScalarValue::IntervalDayTime(Some(i)) => ( @@ -1996,11 +1915,11 @@ fn to_substrait_literal(value: &ScalarValue) -> Result { DECIMAL_128_TYPE_VARIATION_REF, ), ScalarValue::List(l) => ( - convert_array_to_literal_list(l)?, + convert_array_to_literal_list(l, extensions)?, DEFAULT_CONTAINER_TYPE_VARIATION_REF, ), ScalarValue::LargeList(l) => ( - convert_array_to_literal_list(l)?, + convert_array_to_literal_list(l, extensions)?, LARGE_CONTAINER_TYPE_VARIATION_REF, ), ScalarValue::Struct(s) => ( @@ -2009,7 +1928,10 @@ fn to_substrait_literal(value: &ScalarValue) -> Result { .columns() .iter() .map(|col| { - to_substrait_literal(&ScalarValue::try_from_array(col, 0)?) + to_substrait_literal( + &ScalarValue::try_from_array(col, 0)?, + extensions, + ) }) .collect::>>()?, }), @@ -2030,16 +1952,26 @@ fn to_substrait_literal(value: &ScalarValue) -> Result { fn convert_array_to_literal_list( array: &GenericListArray, + extensions: &mut Extensions, ) -> Result { assert_eq!(array.len(), 1); let nested_array = array.value(0); let values = (0..nested_array.len()) - .map(|i| to_substrait_literal(&ScalarValue::try_from_array(&nested_array, i)?)) + .map(|i| { + to_substrait_literal( + &ScalarValue::try_from_array(&nested_array, i)?, + extensions, + ) + }) .collect::>>()?; if values.is_empty() { - let et = match to_substrait_type(array.data_type(), array.is_nullable())? { + let et = match to_substrait_type( + array.data_type(), + array.is_nullable(), + extensions, + )? { substrait::proto::Type { kind: Some(r#type::Kind::List(lt)), } => lt.as_ref().to_owned(), @@ -2051,8 +1983,11 @@ fn convert_array_to_literal_list( } } -fn to_substrait_literal_expr(value: &ScalarValue) -> Result { - let literal = to_substrait_literal(value)?; +fn to_substrait_literal_expr( + value: &ScalarValue, + extensions: &mut Extensions, +) -> Result { + let literal = to_substrait_literal(value, extensions)?; Ok(Expression { rex_type: Some(RexType::Literal(literal)), }) @@ -2065,14 +2000,10 @@ fn to_substrait_unary_scalar_fn( arg: &Expr, schema: &DFSchemaRef, col_ref_offset: usize, - extension_info: &mut ( - Vec, - HashMap, - ), + extensions: &mut Extensions, ) -> Result { - let function_anchor = register_function(fn_name.to_string(), extension_info); - let substrait_expr = - to_substrait_rex(ctx, arg, schema, col_ref_offset, extension_info)?; + let function_anchor = extensions.register_function(fn_name.to_string()); + let substrait_expr = to_substrait_rex(ctx, arg, schema, col_ref_offset, extensions)?; Ok(Expression { rex_type: Some(RexType::ScalarFunction(ScalarFunction { @@ -2116,10 +2047,7 @@ fn substrait_sort_field( ctx: &SessionContext, expr: &Expr, schema: &DFSchemaRef, - extension_info: &mut ( - Vec, - HashMap, - ), + extensions: &mut Extensions, ) -> Result { match expr { Expr::Sort(Sort { @@ -2127,7 +2055,7 @@ fn substrait_sort_field( asc, nulls_first, }) => { - let e = to_substrait_rex(ctx, expr, schema, 0, extension_info)?; + let e = to_substrait_rex(ctx, expr, schema, 0, extensions)?; let d = match (asc, nulls_first) { (true, true) => SortDirection::AscNullsFirst, (true, false) => SortDirection::AscNullsLast, @@ -2161,6 +2089,7 @@ fn substrait_field_ref(index: usize) -> Result { #[cfg(test)] mod test { + use super::*; use crate::logical_plan::consumer::{ from_substrait_literal_without_names, from_substrait_type_without_names, }; @@ -2168,8 +2097,7 @@ mod test { use datafusion::arrow::array::GenericListArray; use datafusion::arrow::datatypes::Field; use datafusion::common::scalar::ScalarStructBuilder; - - use super::*; + use std::collections::HashMap; #[test] fn round_trip_literals() -> Result<()> { @@ -2258,12 +2186,47 @@ mod test { fn round_trip_literal(scalar: ScalarValue) -> Result<()> { println!("Checking round trip of {scalar:?}"); - let substrait_literal = to_substrait_literal(&scalar)?; - let roundtrip_scalar = from_substrait_literal_without_names(&substrait_literal)?; + let mut extensions = Extensions::default(); + let substrait_literal = to_substrait_literal(&scalar, &mut extensions)?; + let roundtrip_scalar = + from_substrait_literal_without_names(&substrait_literal, &extensions)?; assert_eq!(scalar, roundtrip_scalar); Ok(()) } + #[test] + fn custom_type_literal_extensions() -> Result<()> { + let mut extensions = Extensions::default(); + // IntervalMonthDayNano is represented as a custom type in Substrait + let scalar = ScalarValue::IntervalMonthDayNano(Some(IntervalMonthDayNano::new( + 17, 25, 1234567890, + ))); + let substrait_literal = to_substrait_literal(&scalar, &mut extensions)?; + let roundtrip_scalar = + from_substrait_literal_without_names(&substrait_literal, &extensions)?; + assert_eq!(scalar, roundtrip_scalar); + + assert_eq!( + extensions, + Extensions { + functions: HashMap::new(), + types: HashMap::from([( + 0, + INTERVAL_MONTH_DAY_NANO_TYPE_NAME.to_string() + )]), + type_variations: HashMap::new(), + } + ); + + // Check we fail if we don't propagate extensions + assert!(from_substrait_literal_without_names( + &substrait_literal, + &Extensions::default() + ) + .is_err()); + Ok(()) + } + #[test] fn round_trip_types() -> Result<()> { round_trip_type(DataType::Boolean)?; @@ -2329,11 +2292,44 @@ mod test { fn round_trip_type(dt: DataType) -> Result<()> { println!("Checking round trip of {dt:?}"); + let mut extensions = Extensions::default(); + // As DataFusion doesn't consider nullability as a property of the type, but field, // it doesn't matter if we set nullability to true or false here. - let substrait = to_substrait_type(&dt, true)?; - let roundtrip_dt = from_substrait_type_without_names(&substrait)?; + let substrait = to_substrait_type(&dt, true, &mut extensions)?; + let roundtrip_dt = from_substrait_type_without_names(&substrait, &extensions)?; + assert_eq!(dt, roundtrip_dt); + Ok(()) + } + + #[test] + fn custom_type_extensions() -> Result<()> { + let mut extensions = Extensions::default(); + // IntervalMonthDayNano is represented as a custom type in Substrait + let dt = DataType::Interval(IntervalUnit::MonthDayNano); + + let substrait = to_substrait_type(&dt, true, &mut extensions)?; + let roundtrip_dt = from_substrait_type_without_names(&substrait, &extensions)?; assert_eq!(dt, roundtrip_dt); + + assert_eq!( + extensions, + Extensions { + functions: HashMap::new(), + types: HashMap::from([( + 0, + INTERVAL_MONTH_DAY_NANO_TYPE_NAME.to_string() + )]), + type_variations: HashMap::new(), + } + ); + + // Check we fail if we don't propagate extensions + assert!( + from_substrait_type_without_names(&substrait, &Extensions::default()) + .is_err() + ); + Ok(()) } } diff --git a/datafusion/substrait/src/variation_const.rs b/datafusion/substrait/src/variation_const.rs index 27f4b3ea228a..c94ad2d669fd 100644 --- a/datafusion/substrait/src/variation_const.rs +++ b/datafusion/substrait/src/variation_const.rs @@ -25,13 +25,16 @@ //! - Default type reference is 0. It is used when the actual type is the same with the original type. //! - Extended variant type references start from 1, and ususlly increase by 1. //! -//! Definitions here are not the final form. All the non-system-preferred variations will be defined +//! TODO: Definitions here are not the final form. All the non-system-preferred variations will be defined //! using [simple extensions] as per the [spec of type_variations](https://substrait.io/types/type_variations/) +//! //! //! [simple extensions]: (https://substrait.io/extensions/#simple-extensions) // For [type variations](https://substrait.io/types/type_variations/#type-variations) in substrait. // Type variations are used to represent different types based on one type class. +// TODO: Define as extensions: + /// The "system-preferred" variation (i.e., no variation). pub const DEFAULT_TYPE_VARIATION_REF: u32 = 0; pub const UNSIGNED_INTEGER_TYPE_VARIATION_REF: u32 = 1; @@ -55,6 +58,7 @@ pub const DECIMAL_256_TYPE_VARIATION_REF: u32 = 1; /// [`DataType::Interval`]: datafusion::arrow::datatypes::DataType::Interval /// [`IntervalUnit::YearMonth`]: datafusion::arrow::datatypes::IntervalUnit::YearMonth /// [`ScalarValue::IntervalYearMonth`]: datafusion::common::ScalarValue::IntervalYearMonth +#[deprecated(since = "41.0.0", note = "Use Substrait `IntervalYear` type instead")] pub const INTERVAL_YEAR_MONTH_TYPE_REF: u32 = 1; /// For [`DataType::Interval`] with [`IntervalUnit::DayTime`]. @@ -68,6 +72,7 @@ pub const INTERVAL_YEAR_MONTH_TYPE_REF: u32 = 1; /// [`DataType::Interval`]: datafusion::arrow::datatypes::DataType::Interval /// [`IntervalUnit::DayTime`]: datafusion::arrow::datatypes::IntervalUnit::DayTime /// [`ScalarValue::IntervalDayTime`]: datafusion::common::ScalarValue::IntervalDayTime +#[deprecated(since = "41.0.0", note = "Use Substrait `IntervalDay` type instead")] pub const INTERVAL_DAY_TIME_TYPE_REF: u32 = 2; /// For [`DataType::Interval`] with [`IntervalUnit::MonthDayNano`]. @@ -82,21 +87,14 @@ pub const INTERVAL_DAY_TIME_TYPE_REF: u32 = 2; /// [`DataType::Interval`]: datafusion::arrow::datatypes::DataType::Interval /// [`IntervalUnit::MonthDayNano`]: datafusion::arrow::datatypes::IntervalUnit::MonthDayNano /// [`ScalarValue::IntervalMonthDayNano`]: datafusion::common::ScalarValue::IntervalMonthDayNano +#[deprecated( + since = "41.0.0", + note = "Use Substrait `UserDefinedType` with name `INTERVAL_MONTH_DAY_NANO_TYPE_NAME` instead" +)] pub const INTERVAL_MONTH_DAY_NANO_TYPE_REF: u32 = 3; -// For User Defined URLs -/// For [`DataType::Interval`] with [`IntervalUnit::YearMonth`]. -/// -/// [`DataType::Interval`]: datafusion::arrow::datatypes::DataType::Interval -/// [`IntervalUnit::YearMonth`]: datafusion::arrow::datatypes::IntervalUnit::YearMonth -pub const INTERVAL_YEAR_MONTH_TYPE_URL: &str = "interval-year-month"; -/// For [`DataType::Interval`] with [`IntervalUnit::DayTime`]. -/// -/// [`DataType::Interval`]: datafusion::arrow::datatypes::DataType::Interval -/// [`IntervalUnit::DayTime`]: datafusion::arrow::datatypes::IntervalUnit::DayTime -pub const INTERVAL_DAY_TIME_TYPE_URL: &str = "interval-day-time"; /// For [`DataType::Interval`] with [`IntervalUnit::MonthDayNano`]. /// /// [`DataType::Interval`]: datafusion::arrow::datatypes::DataType::Interval /// [`IntervalUnit::MonthDayNano`]: datafusion::arrow::datatypes::IntervalUnit::MonthDayNano -pub const INTERVAL_MONTH_DAY_NANO_TYPE_URL: &str = "interval-month-day-nano"; +pub const INTERVAL_MONTH_DAY_NANO_TYPE_NAME: &str = "interval-month-day-nano"; diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index a7653e11d598..5b4389c832c7 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -38,7 +38,10 @@ use datafusion::optimizer::simplify_expressions::expr_simplifier::THRESHOLD_INLI use datafusion::prelude::*; use datafusion::execution::session_state::SessionStateBuilder; -use substrait::proto::extensions::simple_extension_declaration::MappingType; +use substrait::proto::extensions::simple_extension_declaration::{ + ExtensionType, MappingType, +}; +use substrait::proto::extensions::SimpleExtensionDeclaration; use substrait::proto::rel::RelType; use substrait::proto::{plan_rel, Plan, Rel}; @@ -175,15 +178,46 @@ async fn select_with_filter() -> Result<()> { #[tokio::test] async fn select_with_reused_functions() -> Result<()> { + let ctx = create_context().await?; let sql = "SELECT * FROM data WHERE a > 1 AND a < 10 AND b > 0"; - roundtrip(sql).await?; - let (mut function_names, mut function_anchors) = function_extension_info(sql).await?; - function_names.sort(); - function_anchors.sort(); + let proto = roundtrip_with_ctx(sql, ctx).await?; + let mut functions = proto + .extensions + .iter() + .map(|e| match e.mapping_type.as_ref().unwrap() { + MappingType::ExtensionFunction(ext_f) => { + (ext_f.function_anchor, ext_f.name.to_owned()) + } + _ => unreachable!("Non-function extensions not expected"), + }) + .collect::>(); + functions.sort_by_key(|(anchor, _)| *anchor); + + // Functions are encountered (and thus registered) depth-first + let expected = vec![ + (0, "gt".to_string()), + (1, "lt".to_string()), + (2, "and".to_string()), + ]; + assert_eq!(functions, expected); - assert_eq!(function_names, ["and", "gt", "lt"]); - assert_eq!(function_anchors, [0, 1, 2]); + Ok(()) +} +#[tokio::test] +async fn roundtrip_udt_extensions() -> Result<()> { + let ctx = create_context().await?; + let proto = + roundtrip_with_ctx("SELECT INTERVAL '1 YEAR 1 DAY 1 SECOND' FROM data", ctx) + .await?; + let expected_type = SimpleExtensionDeclaration { + mapping_type: Some(MappingType::ExtensionType(ExtensionType { + extension_uri_reference: u32::MAX, + type_anchor: 0, + name: "interval-month-day-nano".to_string(), + })), + }; + assert_eq!(proto.extensions, vec![expected_type]); Ok(()) } @@ -858,7 +892,8 @@ async fn roundtrip_aggregate_udf() -> Result<()> { let ctx = create_context().await?; ctx.register_udaf(dummy_agg); - roundtrip_with_ctx("select dummy_agg(a) from data", ctx).await + roundtrip_with_ctx("select dummy_agg(a) from data", ctx).await?; + Ok(()) } #[tokio::test] @@ -891,7 +926,8 @@ async fn roundtrip_window_udf() -> Result<()> { let ctx = create_context().await?; ctx.register_udwf(dummy_agg); - roundtrip_with_ctx("select dummy_window(a) OVER () from data", ctx).await + roundtrip_with_ctx("select dummy_window(a) OVER () from data", ctx).await?; + Ok(()) } #[tokio::test] @@ -1083,7 +1119,7 @@ async fn test_alias(sql_with_alias: &str, sql_no_alias: &str) -> Result<()> { Ok(()) } -async fn roundtrip_with_ctx(sql: &str, ctx: SessionContext) -> Result<()> { +async fn roundtrip_with_ctx(sql: &str, ctx: SessionContext) -> Result> { let df = ctx.sql(sql).await?; let plan = df.into_optimized_plan()?; let proto = to_substrait_plan(&plan, &ctx)?; @@ -1102,56 +1138,25 @@ async fn roundtrip_with_ctx(sql: &str, ctx: SessionContext) -> Result<()> { assert_eq!(plan.schema(), plan2.schema()); DataFrame::new(ctx.state(), plan2).show().await?; - Ok(()) + Ok(proto) } async fn roundtrip(sql: &str) -> Result<()> { - roundtrip_with_ctx(sql, create_context().await?).await + roundtrip_with_ctx(sql, create_context().await?).await?; + Ok(()) } async fn roundtrip_verify_post_join_filter(sql: &str) -> Result<()> { let ctx = create_context().await?; - let df = ctx.sql(sql).await?; - let plan = df.into_optimized_plan()?; - let proto = to_substrait_plan(&plan, &ctx)?; - let plan2 = from_substrait_plan(&ctx, &proto).await?; - let plan2 = ctx.state().optimize(&plan2)?; - - println!("{plan:#?}"); - println!("{plan2:#?}"); - - let plan1str = format!("{plan:?}"); - let plan2str = format!("{plan2:?}"); - assert_eq!(plan1str, plan2str); - - assert_eq!(plan.schema(), plan2.schema()); + let proto = roundtrip_with_ctx(sql, ctx).await?; // verify that the join filters are None verify_post_join_filter_value(proto).await } async fn roundtrip_all_types(sql: &str) -> Result<()> { - roundtrip_with_ctx(sql, create_all_type_context().await?).await -} - -async fn function_extension_info(sql: &str) -> Result<(Vec, Vec)> { - let ctx = create_context().await?; - let df = ctx.sql(sql).await?; - let plan = df.into_optimized_plan()?; - let proto = to_substrait_plan(&plan, &ctx)?; - - let mut function_names: Vec = vec![]; - let mut function_anchors: Vec = vec![]; - for e in &proto.extensions { - let (function_anchor, function_name) = match e.mapping_type.as_ref().unwrap() { - MappingType::ExtensionFunction(ext_f) => (ext_f.function_anchor, &ext_f.name), - _ => unreachable!("Producer does not generate a non-function extension"), - }; - function_names.push(function_name.to_string()); - function_anchors.push(function_anchor); - } - - Ok((function_names, function_anchors)) + roundtrip_with_ctx(sql, create_all_type_context().await?).await?; + Ok(()) } async fn create_context() -> Result { From af0d2baf02e169760beedb8465ec1d9e5563d2be Mon Sep 17 00:00:00 2001 From: yfu Date: Sat, 20 Jul 2024 05:55:42 +1000 Subject: [PATCH 13/37] make unparser Dialect Send + Sync (#11504) --- datafusion/sql/src/unparser/dialect.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/sql/src/unparser/dialect.rs b/datafusion/sql/src/unparser/dialect.rs index 87453f81ee3d..1e82fc2b3c1b 100644 --- a/datafusion/sql/src/unparser/dialect.rs +++ b/datafusion/sql/src/unparser/dialect.rs @@ -27,7 +27,7 @@ use sqlparser::{ast, keywords::ALL_KEYWORDS}; /// /// See /// See also the discussion in -pub trait Dialect { +pub trait Dialect: Send + Sync { /// Return the character used to quote identifiers. fn identifier_quote_style(&self, _identifier: &str) -> Option; From f1953528187828bc3636e90fa7d640d5cb3e54d1 Mon Sep 17 00:00:00 2001 From: yfu Date: Sat, 20 Jul 2024 05:56:28 +1000 Subject: [PATCH 14/37] fix: unparser generates wrong sql for derived table with columns (#17) (#11505) * fix unparser for derived table with columns * refactoring * renaming * case in tests --- datafusion/sql/src/unparser/plan.rs | 77 ++++++++++++++++++++--- datafusion/sql/tests/cases/plan_to_sql.rs | 29 +++++++++ 2 files changed, 96 insertions(+), 10 deletions(-) diff --git a/datafusion/sql/src/unparser/plan.rs b/datafusion/sql/src/unparser/plan.rs index 26fd47299637..7f050d8a0690 100644 --- a/datafusion/sql/src/unparser/plan.rs +++ b/datafusion/sql/src/unparser/plan.rs @@ -19,7 +19,7 @@ use datafusion_common::{internal_err, not_impl_err, plan_err, DataFusionError, R use datafusion_expr::{ expr::Alias, Distinct, Expr, JoinConstraint, JoinType, LogicalPlan, Projection, }; -use sqlparser::ast::{self, SetExpr}; +use sqlparser::ast::{self, Ident, SetExpr}; use crate::unparser::utils::unproject_agg_exprs; @@ -457,15 +457,11 @@ impl Unparser<'_> { } LogicalPlan::SubqueryAlias(plan_alias) => { // Handle bottom-up to allocate relation - self.select_to_sql_recursively( - plan_alias.input.as_ref(), - query, - select, - relation, - )?; + let (plan, columns) = subquery_alias_inner_query_and_columns(plan_alias); + self.select_to_sql_recursively(plan, query, select, relation)?; relation.alias(Some( - self.new_table_alias(plan_alias.alias.table().to_string()), + self.new_table_alias(plan_alias.alias.table().to_string(), columns), )); Ok(()) @@ -599,10 +595,10 @@ impl Unparser<'_> { self.binary_op_to_sql(lhs, rhs, ast::BinaryOperator::And) } - fn new_table_alias(&self, alias: String) -> ast::TableAlias { + fn new_table_alias(&self, alias: String, columns: Vec) -> ast::TableAlias { ast::TableAlias { name: self.new_ident_quoted_if_needs(alias), - columns: Vec::new(), + columns, } } @@ -611,6 +607,67 @@ impl Unparser<'_> { } } +// This logic is to work out the columns and inner query for SubqueryAlias plan for both types of +// subquery +// - `(SELECT column_a as a from table) AS A` +// - `(SELECT column_a from table) AS A (a)` +// +// A roundtrip example for table alias with columns +// +// query: SELECT id FROM (SELECT j1_id from j1) AS c (id) +// +// LogicPlan: +// Projection: c.id +// SubqueryAlias: c +// Projection: j1.j1_id AS id +// Projection: j1.j1_id +// TableScan: j1 +// +// Before introducing this logic, the unparsed query would be `SELECT c.id FROM (SELECT j1.j1_id AS +// id FROM (SELECT j1.j1_id FROM j1)) AS c`. +// The query is invalid as `j1.j1_id` is not a valid identifier in the derived table +// `(SELECT j1.j1_id FROM j1)` +// +// With this logic, the unparsed query will be: +// `SELECT c.id FROM (SELECT j1.j1_id FROM j1) AS c (id)` +// +// Caveat: this won't handle the case like `select * from (select 1, 2) AS a (b, c)` +// as the parser gives a wrong plan which has mismatch `Int(1)` types: Literal and +// Column in the Projections. Once the parser side is fixed, this logic should work +fn subquery_alias_inner_query_and_columns( + subquery_alias: &datafusion_expr::SubqueryAlias, +) -> (&LogicalPlan, Vec) { + let plan: &LogicalPlan = subquery_alias.input.as_ref(); + + let LogicalPlan::Projection(outer_projections) = plan else { + return (plan, vec![]); + }; + + // check if it's projection inside projection + let LogicalPlan::Projection(inner_projection) = outer_projections.input.as_ref() + else { + return (plan, vec![]); + }; + + let mut columns: Vec = vec![]; + // check if the inner projection and outer projection have a matching pattern like + // Projection: j1.j1_id AS id + // Projection: j1.j1_id + for (i, inner_expr) in inner_projection.expr.iter().enumerate() { + let Expr::Alias(ref outer_alias) = &outer_projections.expr[i] else { + return (plan, vec![]); + }; + + if outer_alias.expr.as_ref() != inner_expr { + return (plan, vec![]); + }; + + columns.push(outer_alias.name.as_str().into()); + } + + (outer_projections.input.as_ref(), columns) +} + impl From for DataFusionError { fn from(e: BuilderError) -> Self { DataFusionError::External(Box::new(e)) diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index 91295b2e8aae..ed79a1dfc0c7 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -240,6 +240,35 @@ fn roundtrip_statement_with_dialect() -> Result<()> { parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(UnparserDefaultDialect {}), }, + // more tests around subquery/derived table roundtrip + TestStatementWithDialect { + sql: "SELECT string_count FROM ( + SELECT + j1_id, + MIN(j2_string) + FROM + j1 LEFT OUTER JOIN j2 ON + j1_id = j2_id + GROUP BY + j1_id + ) AS agg (id, string_count) + ", + expected: r#"SELECT agg.string_count FROM (SELECT j1.j1_id, MIN(j2.j2_string) FROM j1 LEFT JOIN j2 ON (j1.j1_id = j2.j2_id) GROUP BY j1.j1_id) AS agg (id, string_count)"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(UnparserDefaultDialect {}), + }, + TestStatementWithDialect { + sql: "SELECT id FROM (SELECT j1_id from j1) AS c (id)", + expected: r#"SELECT c.id FROM (SELECT j1.j1_id FROM j1) AS c (id)"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(UnparserDefaultDialect {}), + }, + TestStatementWithDialect { + sql: "SELECT id FROM (SELECT j1_id as id from j1) AS c", + expected: r#"SELECT c.id FROM (SELECT j1.j1_id AS id FROM j1) AS c"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(UnparserDefaultDialect {}), + }, ]; for query in tests { From 9189a1acddbe0da9ab3cbdb3a317a6a45a561f41 Mon Sep 17 00:00:00 2001 From: Piotr Findeisen Date: Fri, 19 Jul 2024 22:02:30 +0200 Subject: [PATCH 15/37] Prevent bigger files from being checked in (#11508) --- .github/workflows/large_files.yml | 55 +++++++++++++++++++++++++++++++ 1 file changed, 55 insertions(+) create mode 100644 .github/workflows/large_files.yml diff --git a/.github/workflows/large_files.yml b/.github/workflows/large_files.yml new file mode 100644 index 000000000000..aa96d55a0d85 --- /dev/null +++ b/.github/workflows/large_files.yml @@ -0,0 +1,55 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +name: Large files PR check + +concurrency: + group: ${{ github.repository }}-${{ github.head_ref || github.sha }}-${{ github.workflow }} + cancel-in-progress: true + +on: + pull_request: + +jobs: + check-files: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + - name: Check size of new Git objects + env: + # 1 MB ought to be enough for anybody. + # TODO in case we may want to consciously commit a bigger file to the repo without using Git LFS we may disable the check e.g. with a label + MAX_FILE_SIZE_BYTES: 1048576 + shell: bash + run: | + git rev-list --objects ${{ github.event.pull_request.base.sha }}..${{ github.event.pull_request.head.sha }} \ + > pull-request-objects.txt + exit_code=0 + while read -r id path; do + # Skip objects which are not files (commits, trees) + if [ ! -z "${path}" ]; then + size="$(git cat-file -s "${id}")" + if [ "${size}" -gt "${MAX_FILE_SIZE_BYTES}" ]; then + exit_code=1 + echo "Object ${id} [${path}] has size ${size}, exceeding ${MAX_FILE_SIZE_BYTES} limit." >&2 + echo "::error file=${path}::File ${path} has size ${size}, exceeding ${MAX_FILE_SIZE_BYTES} limit." + fi + fi + done < pull-request-objects.txt + exit "${exit_code}" From ebe61bae2aeda41b576c4a6e6fc96c5a502e7150 Mon Sep 17 00:00:00 2001 From: Jonah Gao Date: Sat, 20 Jul 2024 04:22:31 +0800 Subject: [PATCH 16/37] fix: make `UnKnownColumn`s not equal to others physical exprs (#11536) * fix: fall back to `UnionExec` if can't interleave * alternative fix * check interleavable in with_new_children * link to pr --- .../physical-expr/src/expressions/column.rs | 10 ++--- datafusion/physical-plan/src/union.rs | 6 +++ datafusion/sqllogictest/test_files/union.slt | 45 +++++++++++++++++++ 3 files changed, 55 insertions(+), 6 deletions(-) diff --git a/datafusion/physical-expr/src/expressions/column.rs b/datafusion/physical-expr/src/expressions/column.rs index 38779c54607f..ab43201ceb75 100644 --- a/datafusion/physical-expr/src/expressions/column.rs +++ b/datafusion/physical-expr/src/expressions/column.rs @@ -21,7 +21,6 @@ use std::any::Any; use std::hash::{Hash, Hasher}; use std::sync::Arc; -use crate::physical_expr::down_cast_any_ref; use crate::PhysicalExpr; use arrow::{ @@ -95,11 +94,10 @@ impl PhysicalExpr for UnKnownColumn { } impl PartialEq for UnKnownColumn { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| self == x) - .unwrap_or(false) + fn eq(&self, _other: &dyn Any) -> bool { + // UnknownColumn is not a valid expression, so it should not be equal to any other expression. + // See https://github.com/apache/datafusion/pull/11536 + false } } diff --git a/datafusion/physical-plan/src/union.rs b/datafusion/physical-plan/src/union.rs index b39c6aee82b9..24c80048ab4a 100644 --- a/datafusion/physical-plan/src/union.rs +++ b/datafusion/physical-plan/src/union.rs @@ -431,6 +431,12 @@ impl ExecutionPlan for InterleaveExec { self: Arc, children: Vec>, ) -> Result> { + // New children are no longer interleavable, which might be a bug of optimization rewrite. + if !can_interleave(children.iter()) { + return internal_err!( + "Can not create InterleaveExec: new children can not be interleaved" + ); + } Ok(Arc::new(InterleaveExec::try_new(children)?)) } diff --git a/datafusion/sqllogictest/test_files/union.slt b/datafusion/sqllogictest/test_files/union.slt index 31b16f975e9e..2dc8385bf191 100644 --- a/datafusion/sqllogictest/test_files/union.slt +++ b/datafusion/sqllogictest/test_files/union.slt @@ -602,3 +602,48 @@ physical_plan 09)--ProjectionExec: expr=[1 as count, MAX(Int64(10))@0 as n] 10)----AggregateExec: mode=Single, gby=[], aggr=[MAX(Int64(10))] 11)------PlaceholderRowExec + + +# Test issue: https://github.com/apache/datafusion/issues/11409 +statement ok +CREATE TABLE t1(v0 BIGINT, v1 BIGINT, v2 BIGINT, v3 BOOLEAN); + +statement ok +CREATE TABLE t2(v0 DOUBLE); + +query I +INSERT INTO t1(v0, v2, v1) VALUES (-1229445667, -342312412, -1507138076); +---- +1 + +query I +INSERT INTO t1(v0, v1) VALUES (1541512604, -1229445667); +---- +1 + +query I +INSERT INTO t1(v1, v3, v0, v2) VALUES (-1020641465, false, -1493773377, 1751276473); +---- +1 + +query I +INSERT INTO t1(v3) VALUES (true), (true), (false); +---- +3 + +query I +INSERT INTO t2(v0) VALUES (0.28014577292925047); +---- +1 + +query II +SELECT t1.v2, t1.v0 FROM t2 NATURAL JOIN t1 + UNION ALL +SELECT t1.v2, t1.v0 FROM t2 NATURAL JOIN t1 WHERE (t1.v2 IS NULL); +---- + +statement ok +DROP TABLE t1; + +statement ok +DROP TABLE t2; From 827d0e3a29c0ea34bafbf03f5102407bd8e9b826 Mon Sep 17 00:00:00 2001 From: Qianqian <130200611+Sevenannn@users.noreply.github.com> Date: Fri, 19 Jul 2024 13:23:32 -0700 Subject: [PATCH 17/37] Add dialect param to use double precision for float64 in Postgres (#11495) * Add dialect param to use double precision for float64 in Postgres * return ast data type instead of bool * Fix errors in merging * fix --- datafusion/sql/src/unparser/dialect.rs | 28 ++++++++++++++++++++++++ datafusion/sql/src/unparser/expr.rs | 30 +++++++++++++++++++++++++- 2 files changed, 57 insertions(+), 1 deletion(-) diff --git a/datafusion/sql/src/unparser/dialect.rs b/datafusion/sql/src/unparser/dialect.rs index 1e82fc2b3c1b..ed0cfddc3827 100644 --- a/datafusion/sql/src/unparser/dialect.rs +++ b/datafusion/sql/src/unparser/dialect.rs @@ -46,11 +46,18 @@ pub trait Dialect: Send + Sync { IntervalStyle::PostgresVerbose } + // Does the dialect use DOUBLE PRECISION to represent Float64 rather than DOUBLE? + // E.g. Postgres uses DOUBLE PRECISION instead of DOUBLE + fn float64_ast_dtype(&self) -> sqlparser::ast::DataType { + sqlparser::ast::DataType::Double + } + // The SQL type to use for Arrow Utf8 unparsing // Most dialects use VARCHAR, but some, like MySQL, require CHAR fn utf8_cast_dtype(&self) -> ast::DataType { ast::DataType::Varchar(None) } + // The SQL type to use for Arrow LargeUtf8 unparsing // Most dialects use TEXT, but some, like MySQL, require CHAR fn large_utf8_cast_dtype(&self) -> ast::DataType { @@ -98,6 +105,10 @@ impl Dialect for PostgreSqlDialect { fn interval_style(&self) -> IntervalStyle { IntervalStyle::PostgresVerbose } + + fn float64_ast_dtype(&self) -> sqlparser::ast::DataType { + sqlparser::ast::DataType::DoublePrecision + } } pub struct MySqlDialect {} @@ -137,6 +148,7 @@ pub struct CustomDialect { supports_nulls_first_in_sort: bool, use_timestamp_for_date64: bool, interval_style: IntervalStyle, + float64_ast_dtype: sqlparser::ast::DataType, utf8_cast_dtype: ast::DataType, large_utf8_cast_dtype: ast::DataType, } @@ -148,6 +160,7 @@ impl Default for CustomDialect { supports_nulls_first_in_sort: true, use_timestamp_for_date64: false, interval_style: IntervalStyle::SQLStandard, + float64_ast_dtype: sqlparser::ast::DataType::Double, utf8_cast_dtype: ast::DataType::Varchar(None), large_utf8_cast_dtype: ast::DataType::Text, } @@ -182,6 +195,10 @@ impl Dialect for CustomDialect { self.interval_style } + fn float64_ast_dtype(&self) -> sqlparser::ast::DataType { + self.float64_ast_dtype.clone() + } + fn utf8_cast_dtype(&self) -> ast::DataType { self.utf8_cast_dtype.clone() } @@ -210,6 +227,7 @@ pub struct CustomDialectBuilder { supports_nulls_first_in_sort: bool, use_timestamp_for_date64: bool, interval_style: IntervalStyle, + float64_ast_dtype: sqlparser::ast::DataType, utf8_cast_dtype: ast::DataType, large_utf8_cast_dtype: ast::DataType, } @@ -227,6 +245,7 @@ impl CustomDialectBuilder { supports_nulls_first_in_sort: true, use_timestamp_for_date64: false, interval_style: IntervalStyle::PostgresVerbose, + float64_ast_dtype: sqlparser::ast::DataType::Double, utf8_cast_dtype: ast::DataType::Varchar(None), large_utf8_cast_dtype: ast::DataType::Text, } @@ -238,6 +257,7 @@ impl CustomDialectBuilder { supports_nulls_first_in_sort: self.supports_nulls_first_in_sort, use_timestamp_for_date64: self.use_timestamp_for_date64, interval_style: self.interval_style, + float64_ast_dtype: self.float64_ast_dtype, utf8_cast_dtype: self.utf8_cast_dtype, large_utf8_cast_dtype: self.large_utf8_cast_dtype, } @@ -273,6 +293,14 @@ impl CustomDialectBuilder { self } + pub fn with_float64_ast_dtype( + mut self, + float64_ast_dtype: sqlparser::ast::DataType, + ) -> Self { + self.float64_ast_dtype = float64_ast_dtype; + self + } + pub fn with_utf8_cast_dtype(mut self, utf8_cast_dtype: ast::DataType) -> Self { self.utf8_cast_dtype = utf8_cast_dtype; self diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 950e7e11288a..2f7854c1a183 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -1240,7 +1240,7 @@ impl Unparser<'_> { not_impl_err!("Unsupported DataType: conversion: {data_type:?}") } DataType::Float32 => Ok(ast::DataType::Float(None)), - DataType::Float64 => Ok(ast::DataType::Double), + DataType::Float64 => Ok(self.dialect.float64_ast_dtype()), DataType::Timestamp(_, tz) => { let tz_info = match tz { Some(_) => TimezoneInfo::WithTimeZone, @@ -1822,6 +1822,34 @@ mod tests { Ok(()) } + #[test] + fn custom_dialect_float64_ast_dtype() -> Result<()> { + for (float64_ast_dtype, identifier) in [ + (sqlparser::ast::DataType::Double, "DOUBLE"), + ( + sqlparser::ast::DataType::DoublePrecision, + "DOUBLE PRECISION", + ), + ] { + let dialect = CustomDialectBuilder::new() + .with_float64_ast_dtype(float64_ast_dtype) + .build(); + let unparser = Unparser::new(&dialect); + + let expr = Expr::Cast(Cast { + expr: Box::new(col("a")), + data_type: DataType::Float64, + }); + let ast = unparser.expr_to_sql(&expr)?; + + let actual = format!("{}", ast); + + let expected = format!(r#"CAST(a AS {identifier})"#); + assert_eq!(actual, expected); + } + Ok(()) + } + #[test] fn customer_dialect_support_nulls_first_in_ort() -> Result<()> { let tests: Vec<(Expr, &str, bool)> = vec![ From 5da7ab300215c44ca5dc16771091890de22af99b Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Sat, 20 Jul 2024 09:19:22 -0400 Subject: [PATCH 18/37] Minor: move `SessionStateDefaults` into its own module (#11566) * Minor: move `SessionStateDefaults` into its own module * Fix no default features --- .../core/src/datasource/listing/table.rs | 2 +- .../core/src/datasource/schema_adapter.rs | 1 + datafusion/core/src/execution/mod.rs | 3 + .../core/src/execution/session_state.rs | 185 +--------------- .../src/execution/session_state_defaults.rs | 202 ++++++++++++++++++ 5 files changed, 211 insertions(+), 182 deletions(-) create mode 100644 datafusion/core/src/execution/session_state_defaults.rs diff --git a/datafusion/core/src/datasource/listing/table.rs b/datafusion/core/src/datasource/listing/table.rs index 1a7390d46f89..4d0a7738b039 100644 --- a/datafusion/core/src/datasource/listing/table.rs +++ b/datafusion/core/src/datasource/listing/table.rs @@ -1038,8 +1038,8 @@ mod tests { use crate::datasource::file_format::avro::AvroFormat; use crate::datasource::file_format::csv::CsvFormat; use crate::datasource::file_format::json::JsonFormat; - use crate::datasource::file_format::parquet::ParquetFormat; #[cfg(feature = "parquet")] + use crate::datasource::file_format::parquet::ParquetFormat; use crate::datasource::{provider_as_source, MemTable}; use crate::execution::options::ArrowReadOptions; use crate::physical_plan::collect; diff --git a/datafusion/core/src/datasource/schema_adapter.rs b/datafusion/core/src/datasource/schema_adapter.rs index 715e2da5d978..f485c49e9109 100644 --- a/datafusion/core/src/datasource/schema_adapter.rs +++ b/datafusion/core/src/datasource/schema_adapter.rs @@ -246,6 +246,7 @@ mod tests { use crate::datasource::schema_adapter::{ SchemaAdapter, SchemaAdapterFactory, SchemaMapper, }; + #[cfg(feature = "parquet")] use parquet::arrow::ArrowWriter; use tempfile::TempDir; diff --git a/datafusion/core/src/execution/mod.rs b/datafusion/core/src/execution/mod.rs index ac02c7317256..a1b3eab25f33 100644 --- a/datafusion/core/src/execution/mod.rs +++ b/datafusion/core/src/execution/mod.rs @@ -19,6 +19,9 @@ pub mod context; pub mod session_state; +mod session_state_defaults; + +pub use session_state_defaults::SessionStateDefaults; // backwards compatibility pub use crate::datasource::file_format::options; diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index 0824b249b7d1..59cc620dae4d 100644 --- a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -18,29 +18,17 @@ //! [`SessionState`]: information required to run queries in a session use crate::catalog::information_schema::{InformationSchemaProvider, INFORMATION_SCHEMA}; -use crate::catalog::listing_schema::ListingSchemaProvider; -use crate::catalog::schema::{MemorySchemaProvider, SchemaProvider}; -use crate::catalog::{ - CatalogProvider, CatalogProviderList, MemoryCatalogProvider, - MemoryCatalogProviderList, -}; +use crate::catalog::schema::SchemaProvider; +use crate::catalog::{CatalogProviderList, MemoryCatalogProviderList}; use crate::datasource::cte_worktable::CteWorkTable; -use crate::datasource::file_format::arrow::ArrowFormatFactory; -use crate::datasource::file_format::avro::AvroFormatFactory; -use crate::datasource::file_format::csv::CsvFormatFactory; -use crate::datasource::file_format::json::JsonFormatFactory; -#[cfg(feature = "parquet")] -use crate::datasource::file_format::parquet::ParquetFormatFactory; use crate::datasource::file_format::{format_as_file_type, FileFormatFactory}; use crate::datasource::function::{TableFunction, TableFunctionImpl}; -use crate::datasource::provider::{DefaultTableFactory, TableProviderFactory}; +use crate::datasource::provider::TableProviderFactory; use crate::datasource::provider_as_source; use crate::execution::context::{EmptySerializerRegistry, FunctionFactory, QueryPlanner}; -#[cfg(feature = "array_expressions")] -use crate::functions_array; +use crate::execution::SessionStateDefaults; use crate::physical_optimizer::optimizer::PhysicalOptimizer; use crate::physical_planner::{DefaultPhysicalPlanner, PhysicalPlanner}; -use crate::{functions, functions_aggregate}; use arrow_schema::{DataType, SchemaRef}; use async_trait::async_trait; use chrono::{DateTime, Utc}; @@ -54,7 +42,6 @@ use datafusion_common::{ ResolvedTableReference, TableReference, }; use datafusion_execution::config::SessionConfig; -use datafusion_execution::object_store::ObjectStoreUrl; use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_execution::TaskContext; use datafusion_expr::execution_props::ExecutionProps; @@ -85,7 +72,6 @@ use std::collections::hash_map::Entry; use std::collections::{HashMap, HashSet}; use std::fmt::Debug; use std::sync::Arc; -use url::Url; use uuid::Uuid; /// Execution context for registering data sources and executing queries. @@ -1420,169 +1406,6 @@ impl From for SessionStateBuilder { } } -/// Defaults that are used as part of creating a SessionState such as table providers, -/// file formats, registering of builtin functions, etc. -pub struct SessionStateDefaults {} - -impl SessionStateDefaults { - /// returns a map of the default [`TableProviderFactory`]s - pub fn default_table_factories() -> HashMap> { - let mut table_factories: HashMap> = - HashMap::new(); - #[cfg(feature = "parquet")] - table_factories.insert("PARQUET".into(), Arc::new(DefaultTableFactory::new())); - table_factories.insert("CSV".into(), Arc::new(DefaultTableFactory::new())); - table_factories.insert("JSON".into(), Arc::new(DefaultTableFactory::new())); - table_factories.insert("NDJSON".into(), Arc::new(DefaultTableFactory::new())); - table_factories.insert("AVRO".into(), Arc::new(DefaultTableFactory::new())); - table_factories.insert("ARROW".into(), Arc::new(DefaultTableFactory::new())); - - table_factories - } - - /// returns the default MemoryCatalogProvider - pub fn default_catalog( - config: &SessionConfig, - table_factories: &HashMap>, - runtime: &Arc, - ) -> MemoryCatalogProvider { - let default_catalog = MemoryCatalogProvider::new(); - - default_catalog - .register_schema( - &config.options().catalog.default_schema, - Arc::new(MemorySchemaProvider::new()), - ) - .expect("memory catalog provider can register schema"); - - Self::register_default_schema(config, table_factories, runtime, &default_catalog); - - default_catalog - } - - /// returns the list of default [`ExprPlanner`]s - pub fn default_expr_planners() -> Vec> { - let expr_planners: Vec> = vec![ - Arc::new(functions::core::planner::CoreFunctionPlanner::default()), - // register crate of array expressions (if enabled) - #[cfg(feature = "array_expressions")] - Arc::new(functions_array::planner::ArrayFunctionPlanner), - #[cfg(feature = "array_expressions")] - Arc::new(functions_array::planner::FieldAccessPlanner), - #[cfg(any( - feature = "datetime_expressions", - feature = "unicode_expressions" - ))] - Arc::new(functions::planner::UserDefinedFunctionPlanner), - ]; - - expr_planners - } - - /// returns the list of default [`ScalarUDF']'s - pub fn default_scalar_functions() -> Vec> { - let mut functions: Vec> = functions::all_default_functions(); - #[cfg(feature = "array_expressions")] - functions.append(&mut functions_array::all_default_array_functions()); - - functions - } - - /// returns the list of default [`AggregateUDF']'s - pub fn default_aggregate_functions() -> Vec> { - functions_aggregate::all_default_aggregate_functions() - } - - /// returns the list of default [`FileFormatFactory']'s - pub fn default_file_formats() -> Vec> { - let file_formats: Vec> = vec![ - #[cfg(feature = "parquet")] - Arc::new(ParquetFormatFactory::new()), - Arc::new(JsonFormatFactory::new()), - Arc::new(CsvFormatFactory::new()), - Arc::new(ArrowFormatFactory::new()), - Arc::new(AvroFormatFactory::new()), - ]; - - file_formats - } - - /// registers all builtin functions - scalar, array and aggregate - pub fn register_builtin_functions(state: &mut SessionState) { - Self::register_scalar_functions(state); - Self::register_array_functions(state); - Self::register_aggregate_functions(state); - } - - /// registers all the builtin scalar functions - pub fn register_scalar_functions(state: &mut SessionState) { - functions::register_all(state).expect("can not register built in functions"); - } - - /// registers all the builtin array functions - pub fn register_array_functions(state: &mut SessionState) { - // register crate of array expressions (if enabled) - #[cfg(feature = "array_expressions")] - functions_array::register_all(state).expect("can not register array expressions"); - } - - /// registers all the builtin aggregate functions - pub fn register_aggregate_functions(state: &mut SessionState) { - functions_aggregate::register_all(state) - .expect("can not register aggregate functions"); - } - - /// registers the default schema - pub fn register_default_schema( - config: &SessionConfig, - table_factories: &HashMap>, - runtime: &Arc, - default_catalog: &MemoryCatalogProvider, - ) { - let url = config.options().catalog.location.as_ref(); - let format = config.options().catalog.format.as_ref(); - let (url, format) = match (url, format) { - (Some(url), Some(format)) => (url, format), - _ => return, - }; - let url = url.to_string(); - let format = format.to_string(); - - let url = Url::parse(url.as_str()).expect("Invalid default catalog location!"); - let authority = match url.host_str() { - Some(host) => format!("{}://{}", url.scheme(), host), - None => format!("{}://", url.scheme()), - }; - let path = &url.as_str()[authority.len()..]; - let path = object_store::path::Path::parse(path).expect("Can't parse path"); - let store = ObjectStoreUrl::parse(authority.as_str()) - .expect("Invalid default catalog url"); - let store = match runtime.object_store(store) { - Ok(store) => store, - _ => return, - }; - let factory = match table_factories.get(format.as_str()) { - Some(factory) => factory, - _ => return, - }; - let schema = - ListingSchemaProvider::new(authority, path, factory.clone(), store, format); - let _ = default_catalog - .register_schema("default", Arc::new(schema)) - .expect("Failed to register default schema"); - } - - /// registers the default [`FileFormatFactory`]s - pub fn register_default_file_formats(state: &mut SessionState) { - let formats = SessionStateDefaults::default_file_formats(); - for format in formats { - if let Err(e) = state.register_file_format(format, false) { - log::info!("Unable to register default file format: {e}") - }; - } - } -} - /// Adapter that implements the [`ContextProvider`] trait for a [`SessionState`] /// /// This is used so the SQL planner can access the state of the session without diff --git a/datafusion/core/src/execution/session_state_defaults.rs b/datafusion/core/src/execution/session_state_defaults.rs new file mode 100644 index 000000000000..0b0465e44605 --- /dev/null +++ b/datafusion/core/src/execution/session_state_defaults.rs @@ -0,0 +1,202 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::catalog::listing_schema::ListingSchemaProvider; +use crate::catalog::{CatalogProvider, MemoryCatalogProvider, MemorySchemaProvider}; +use crate::datasource::file_format::arrow::ArrowFormatFactory; +use crate::datasource::file_format::avro::AvroFormatFactory; +use crate::datasource::file_format::csv::CsvFormatFactory; +use crate::datasource::file_format::json::JsonFormatFactory; +#[cfg(feature = "parquet")] +use crate::datasource::file_format::parquet::ParquetFormatFactory; +use crate::datasource::file_format::FileFormatFactory; +use crate::datasource::provider::{DefaultTableFactory, TableProviderFactory}; +use crate::execution::context::SessionState; +#[cfg(feature = "array_expressions")] +use crate::functions_array; +use crate::{functions, functions_aggregate}; +use datafusion_execution::config::SessionConfig; +use datafusion_execution::object_store::ObjectStoreUrl; +use datafusion_execution::runtime_env::RuntimeEnv; +use datafusion_expr::planner::ExprPlanner; +use datafusion_expr::{AggregateUDF, ScalarUDF}; +use std::collections::HashMap; +use std::sync::Arc; +use url::Url; + +/// Defaults that are used as part of creating a SessionState such as table providers, +/// file formats, registering of builtin functions, etc. +pub struct SessionStateDefaults {} + +impl SessionStateDefaults { + /// returns a map of the default [`TableProviderFactory`]s + pub fn default_table_factories() -> HashMap> { + let mut table_factories: HashMap> = + HashMap::new(); + #[cfg(feature = "parquet")] + table_factories.insert("PARQUET".into(), Arc::new(DefaultTableFactory::new())); + table_factories.insert("CSV".into(), Arc::new(DefaultTableFactory::new())); + table_factories.insert("JSON".into(), Arc::new(DefaultTableFactory::new())); + table_factories.insert("NDJSON".into(), Arc::new(DefaultTableFactory::new())); + table_factories.insert("AVRO".into(), Arc::new(DefaultTableFactory::new())); + table_factories.insert("ARROW".into(), Arc::new(DefaultTableFactory::new())); + + table_factories + } + + /// returns the default MemoryCatalogProvider + pub fn default_catalog( + config: &SessionConfig, + table_factories: &HashMap>, + runtime: &Arc, + ) -> MemoryCatalogProvider { + let default_catalog = MemoryCatalogProvider::new(); + + default_catalog + .register_schema( + &config.options().catalog.default_schema, + Arc::new(MemorySchemaProvider::new()), + ) + .expect("memory catalog provider can register schema"); + + Self::register_default_schema(config, table_factories, runtime, &default_catalog); + + default_catalog + } + + /// returns the list of default [`ExprPlanner`]s + pub fn default_expr_planners() -> Vec> { + let expr_planners: Vec> = vec![ + Arc::new(functions::core::planner::CoreFunctionPlanner::default()), + // register crate of array expressions (if enabled) + #[cfg(feature = "array_expressions")] + Arc::new(functions_array::planner::ArrayFunctionPlanner), + #[cfg(feature = "array_expressions")] + Arc::new(functions_array::planner::FieldAccessPlanner), + #[cfg(any( + feature = "datetime_expressions", + feature = "unicode_expressions" + ))] + Arc::new(functions::planner::UserDefinedFunctionPlanner), + ]; + + expr_planners + } + + /// returns the list of default [`ScalarUDF']'s + pub fn default_scalar_functions() -> Vec> { + let mut functions: Vec> = functions::all_default_functions(); + #[cfg(feature = "array_expressions")] + functions.append(&mut functions_array::all_default_array_functions()); + + functions + } + + /// returns the list of default [`AggregateUDF']'s + pub fn default_aggregate_functions() -> Vec> { + functions_aggregate::all_default_aggregate_functions() + } + + /// returns the list of default [`FileFormatFactory']'s + pub fn default_file_formats() -> Vec> { + let file_formats: Vec> = vec![ + #[cfg(feature = "parquet")] + Arc::new(ParquetFormatFactory::new()), + Arc::new(JsonFormatFactory::new()), + Arc::new(CsvFormatFactory::new()), + Arc::new(ArrowFormatFactory::new()), + Arc::new(AvroFormatFactory::new()), + ]; + + file_formats + } + + /// registers all builtin functions - scalar, array and aggregate + pub fn register_builtin_functions(state: &mut SessionState) { + Self::register_scalar_functions(state); + Self::register_array_functions(state); + Self::register_aggregate_functions(state); + } + + /// registers all the builtin scalar functions + pub fn register_scalar_functions(state: &mut SessionState) { + functions::register_all(state).expect("can not register built in functions"); + } + + /// registers all the builtin array functions + pub fn register_array_functions(state: &mut SessionState) { + // register crate of array expressions (if enabled) + #[cfg(feature = "array_expressions")] + functions_array::register_all(state).expect("can not register array expressions"); + } + + /// registers all the builtin aggregate functions + pub fn register_aggregate_functions(state: &mut SessionState) { + functions_aggregate::register_all(state) + .expect("can not register aggregate functions"); + } + + /// registers the default schema + pub fn register_default_schema( + config: &SessionConfig, + table_factories: &HashMap>, + runtime: &Arc, + default_catalog: &MemoryCatalogProvider, + ) { + let url = config.options().catalog.location.as_ref(); + let format = config.options().catalog.format.as_ref(); + let (url, format) = match (url, format) { + (Some(url), Some(format)) => (url, format), + _ => return, + }; + let url = url.to_string(); + let format = format.to_string(); + + let url = Url::parse(url.as_str()).expect("Invalid default catalog location!"); + let authority = match url.host_str() { + Some(host) => format!("{}://{}", url.scheme(), host), + None => format!("{}://", url.scheme()), + }; + let path = &url.as_str()[authority.len()..]; + let path = object_store::path::Path::parse(path).expect("Can't parse path"); + let store = ObjectStoreUrl::parse(authority.as_str()) + .expect("Invalid default catalog url"); + let store = match runtime.object_store(store) { + Ok(store) => store, + _ => return, + }; + let factory = match table_factories.get(format.as_str()) { + Some(factory) => factory, + _ => return, + }; + let schema = + ListingSchemaProvider::new(authority, path, factory.clone(), store, format); + let _ = default_catalog + .register_schema("default", Arc::new(schema)) + .expect("Failed to register default schema"); + } + + /// registers the default [`FileFormatFactory`]s + pub fn register_default_file_formats(state: &mut SessionState) { + let formats = SessionStateDefaults::default_file_formats(); + for format in formats { + if let Err(e) = state.register_file_format(format, false) { + log::info!("Unable to register default file format: {e}") + }; + } + } +} From 7df2bde8fc12554ad92b8941f7916069c1651f11 Mon Sep 17 00:00:00 2001 From: Trent Hauck Date: Sun, 21 Jul 2024 05:31:26 -0700 Subject: [PATCH 19/37] fix: fixes trig function order by (#11559) * fix: remove assert * tests: add tests from ticket * tests: clean up table --- datafusion/common/src/scalar/mod.rs | 3 -- datafusion/sqllogictest/test_files/scalar.slt | 34 +++++++++++++++++++ 2 files changed, 34 insertions(+), 3 deletions(-) diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs index 38f70e4c1466..065101390115 100644 --- a/datafusion/common/src/scalar/mod.rs +++ b/datafusion/common/src/scalar/mod.rs @@ -1063,7 +1063,6 @@ impl ScalarValue { /// Create an one value in the given type. pub fn new_one(datatype: &DataType) -> Result { - assert!(datatype.is_primitive()); Ok(match datatype { DataType::Int8 => ScalarValue::Int8(Some(1)), DataType::Int16 => ScalarValue::Int16(Some(1)), @@ -1086,7 +1085,6 @@ impl ScalarValue { /// Create a negative one value in the given type. pub fn new_negative_one(datatype: &DataType) -> Result { - assert!(datatype.is_primitive()); Ok(match datatype { DataType::Int8 | DataType::UInt8 => ScalarValue::Int8(Some(-1)), DataType::Int16 | DataType::UInt16 => ScalarValue::Int16(Some(-1)), @@ -1104,7 +1102,6 @@ impl ScalarValue { } pub fn new_ten(datatype: &DataType) -> Result { - assert!(datatype.is_primitive()); Ok(match datatype { DataType::Int8 => ScalarValue::Int8(Some(10)), DataType::Int16 => ScalarValue::Int16(Some(10)), diff --git a/datafusion/sqllogictest/test_files/scalar.slt b/datafusion/sqllogictest/test_files/scalar.slt index dd19a1344139..48f94fc080a4 100644 --- a/datafusion/sqllogictest/test_files/scalar.slt +++ b/datafusion/sqllogictest/test_files/scalar.slt @@ -1982,3 +1982,37 @@ query I select strpos('joséésoj', arrow_cast(null, 'Utf8')); ---- NULL + +statement ok +CREATE TABLE t1 (v1 int) AS VALUES (1), (2), (3); + +query I +SELECT * FROM t1 ORDER BY ACOS(SIN(v1)); +---- +2 +1 +3 + +query I +SELECT * FROM t1 ORDER BY ACOSH(SIN(v1)); +---- +1 +2 +3 + +query I +SELECT * FROM t1 ORDER BY ASIN(SIN(v1)); +---- +3 +1 +2 + +query I +SELECT * FROM t1 ORDER BY ATANH(SIN(v1)); +---- +3 +1 +2 + +statement ok +drop table t1; From d232065c3b710d0c8e035de49730238a30073eb2 Mon Sep 17 00:00:00 2001 From: Lorrens Pantelis <100197010+LorrensP-2158466@users.noreply.github.com> Date: Sun, 21 Jul 2024 14:31:41 +0200 Subject: [PATCH 20/37] refactor: rewrite mega type to an enum containing both cases (#11539) --- .../file_format/write/orchestration.rs | 62 ++++++++++++++----- 1 file changed, 48 insertions(+), 14 deletions(-) diff --git a/datafusion/core/src/datasource/file_format/write/orchestration.rs b/datafusion/core/src/datasource/file_format/write/orchestration.rs index f788865b070f..1d32063ee9f3 100644 --- a/datafusion/core/src/datasource/file_format/write/orchestration.rs +++ b/datafusion/core/src/datasource/file_format/write/orchestration.rs @@ -42,6 +42,37 @@ use tokio::task::JoinSet; type WriterType = Box; type SerializerType = Arc; +/// Result of calling [`serialize_rb_stream_to_object_store`] +pub(crate) enum SerializedRecordBatchResult { + Success { + /// the writer + writer: WriterType, + + /// the number of rows successfully written + row_count: usize, + }, + Failure { + /// As explained in [`serialize_rb_stream_to_object_store`]: + /// - If an IO error occured that involved the ObjectStore writer, then the writer will not be returned to the caller + /// - Otherwise, the writer is returned to the caller + writer: Option, + + /// the actual error that occured + err: DataFusionError, + }, +} + +impl SerializedRecordBatchResult { + /// Create the success variant + pub fn success(writer: WriterType, row_count: usize) -> Self { + Self::Success { writer, row_count } + } + + pub fn failure(writer: Option, err: DataFusionError) -> Self { + Self::Failure { writer, err } + } +} + /// Serializes a single data stream in parallel and writes to an ObjectStore concurrently. /// Data order is preserved. /// @@ -55,7 +86,7 @@ pub(crate) async fn serialize_rb_stream_to_object_store( mut data_rx: Receiver, serializer: Arc, mut writer: WriterType, -) -> std::result::Result<(WriterType, u64), (Option, DataFusionError)> { +) -> SerializedRecordBatchResult { let (tx, mut rx) = mpsc::channel::>>(100); let serialize_task = SpawnedTask::spawn(async move { @@ -86,43 +117,43 @@ pub(crate) async fn serialize_rb_stream_to_object_store( match writer.write_all(&bytes).await { Ok(_) => (), Err(e) => { - return Err(( + return SerializedRecordBatchResult::failure( None, DataFusionError::Execution(format!( "Error writing to object store: {e}" )), - )) + ) } }; row_count += cnt; } Ok(Err(e)) => { // Return the writer along with the error - return Err((Some(writer), e)); + return SerializedRecordBatchResult::failure(Some(writer), e); } Err(e) => { // Handle task panic or cancellation - return Err(( + return SerializedRecordBatchResult::failure( Some(writer), DataFusionError::Execution(format!( "Serialization task panicked or was cancelled: {e}" )), - )); + ); } } } match serialize_task.join().await { Ok(Ok(_)) => (), - Ok(Err(e)) => return Err((Some(writer), e)), + Ok(Err(e)) => return SerializedRecordBatchResult::failure(Some(writer), e), Err(_) => { - return Err(( + return SerializedRecordBatchResult::failure( Some(writer), internal_datafusion_err!("Unknown error writing to object store"), - )) + ) } } - Ok((writer, row_count as u64)) + SerializedRecordBatchResult::success(writer, row_count) } type FileWriteBundle = (Receiver, SerializerType, WriterType); @@ -153,14 +184,17 @@ pub(crate) async fn stateless_serialize_and_write_files( while let Some(result) = join_set.join_next().await { match result { Ok(res) => match res { - Ok((writer, cnt)) => { + SerializedRecordBatchResult::Success { + writer, + row_count: cnt, + } => { finished_writers.push(writer); row_count += cnt; } - Err((writer, e)) => { + SerializedRecordBatchResult::Failure { writer, err } => { finished_writers.extend(writer); any_errors = true; - triggering_error = Some(e); + triggering_error = Some(err); } }, Err(e) => { @@ -193,7 +227,7 @@ pub(crate) async fn stateless_serialize_and_write_files( } } - tx.send(row_count).map_err(|_| { + tx.send(row_count as u64).map_err(|_| { internal_datafusion_err!( "Error encountered while sending row count back to file sink!" ) From 36660fe10d9c0cdff62e0da0b94bee28422d3419 Mon Sep 17 00:00:00 2001 From: Dharan Aditya Date: Sun, 21 Jul 2024 18:03:27 +0530 Subject: [PATCH 21/37] Move `sql_compound_identifier_to_expr ` to `ExprPlanner` (#11487) * move get_field to expr planner * formatting * formatting * documentation * refactor * documentation & fix's * move optimizer tests to core * fix breaking tc's * cleanup * fix examples * formatting * rm datafusion-functions from optimizer * update compound identifier * update planner * update planner * formatting * reverting optimizer tests * formatting --- datafusion/expr/src/planner.rs | 19 +++++++++- datafusion/functions/src/core/mod.rs | 1 - datafusion/functions/src/core/planner.rs | 27 +++++++++++++- datafusion/sql/examples/sql.rs | 20 +++++++++- datafusion/sql/src/expr/identifier.rs | 45 +++++++++++------------ datafusion/sql/tests/cases/plan_to_sql.rs | 11 ++++-- datafusion/sql/tests/common/mod.rs | 11 ++++++ datafusion/sql/tests/sql_integration.rs | 5 ++- 8 files changed, 106 insertions(+), 33 deletions(-) diff --git a/datafusion/expr/src/planner.rs b/datafusion/expr/src/planner.rs index 415af1bf94dc..c775427df138 100644 --- a/datafusion/expr/src/planner.rs +++ b/datafusion/expr/src/planner.rs @@ -19,7 +19,7 @@ use std::sync::Arc; -use arrow::datatypes::{DataType, SchemaRef}; +use arrow::datatypes::{DataType, Field, SchemaRef}; use datafusion_common::{ config::ConfigOptions, file_options::file_type::FileType, not_impl_err, DFSchema, Result, TableReference, @@ -180,6 +180,23 @@ pub trait ExprPlanner: Send + Sync { fn plan_make_map(&self, args: Vec) -> Result>> { Ok(PlannerResult::Original(args)) } + + /// Plans compound identifier eg `db.schema.table` for non-empty nested names + /// + /// Note: + /// Currently compound identifier for outer query schema is not supported. + /// + /// Returns planned expression + fn plan_compound_identifier( + &self, + _field: &Field, + _qualifier: Option<&TableReference>, + _nested_names: &[String], + ) -> Result>> { + not_impl_err!( + "Default planner compound identifier hasn't been implemented for ExprPlanner" + ) + } } /// An operator with two arguments to plan diff --git a/datafusion/functions/src/core/mod.rs b/datafusion/functions/src/core/mod.rs index cbfaa592b012..ee0309e59382 100644 --- a/datafusion/functions/src/core/mod.rs +++ b/datafusion/functions/src/core/mod.rs @@ -100,7 +100,6 @@ pub fn functions() -> Vec> { nvl2(), arrow_typeof(), named_struct(), - get_field(), coalesce(), map(), ] diff --git a/datafusion/functions/src/core/planner.rs b/datafusion/functions/src/core/planner.rs index 63eaa9874c2b..889f191d592f 100644 --- a/datafusion/functions/src/core/planner.rs +++ b/datafusion/functions/src/core/planner.rs @@ -15,11 +15,12 @@ // specific language governing permissions and limitations // under the License. -use datafusion_common::DFSchema; +use arrow::datatypes::Field; use datafusion_common::Result; +use datafusion_common::{not_impl_err, Column, DFSchema, ScalarValue, TableReference}; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::planner::{ExprPlanner, PlannerResult, RawDictionaryExpr}; -use datafusion_expr::Expr; +use datafusion_expr::{lit, Expr}; use super::named_struct; @@ -62,4 +63,26 @@ impl ExprPlanner for CoreFunctionPlanner { ScalarFunction::new_udf(crate::string::overlay(), args), ))) } + + fn plan_compound_identifier( + &self, + field: &Field, + qualifier: Option<&TableReference>, + nested_names: &[String], + ) -> Result>> { + // TODO: remove when can support multiple nested identifiers + if nested_names.len() > 1 { + return not_impl_err!( + "Nested identifiers not yet supported for column {}", + Column::from((qualifier, field)).quoted_flat_name() + ); + } + let nested_name = nested_names[0].to_string(); + + let col = Expr::Column(Column::from((qualifier, field))); + let get_field_args = vec![col, lit(ScalarValue::from(nested_name))]; + Ok(PlannerResult::Planned(Expr::ScalarFunction( + ScalarFunction::new_udf(crate::core::get_field(), get_field_args), + ))) + } } diff --git a/datafusion/sql/examples/sql.rs b/datafusion/sql/examples/sql.rs index b724afabaf09..d9ee1b4db8e2 100644 --- a/datafusion/sql/examples/sql.rs +++ b/datafusion/sql/examples/sql.rs @@ -15,13 +15,18 @@ // specific language governing permissions and limitations // under the License. +use std::{collections::HashMap, sync::Arc}; + use arrow_schema::{DataType, Field, Schema}; + use datafusion_common::config::ConfigOptions; use datafusion_common::{plan_err, Result}; +use datafusion_expr::planner::ExprPlanner; use datafusion_expr::WindowUDF; use datafusion_expr::{ logical_plan::builder::LogicalTableSource, AggregateUDF, ScalarUDF, TableSource, }; +use datafusion_functions::core::planner::CoreFunctionPlanner; use datafusion_functions_aggregate::count::count_udaf; use datafusion_functions_aggregate::sum::sum_udaf; use datafusion_sql::{ @@ -29,7 +34,6 @@ use datafusion_sql::{ sqlparser::{dialect::GenericDialect, parser::Parser}, TableReference, }; -use std::{collections::HashMap, sync::Arc}; fn main() { let sql = "SELECT \ @@ -53,7 +57,8 @@ fn main() { // create a logical query plan let context_provider = MyContextProvider::new() .with_udaf(sum_udaf()) - .with_udaf(count_udaf()); + .with_udaf(count_udaf()) + .with_expr_planner(Arc::new(CoreFunctionPlanner::default())); let sql_to_rel = SqlToRel::new(&context_provider); let plan = sql_to_rel.sql_statement_to_plan(statement.clone()).unwrap(); @@ -65,6 +70,7 @@ struct MyContextProvider { options: ConfigOptions, tables: HashMap>, udafs: HashMap>, + expr_planners: Vec>, } impl MyContextProvider { @@ -73,6 +79,11 @@ impl MyContextProvider { self } + fn with_expr_planner(mut self, planner: Arc) -> Self { + self.expr_planners.push(planner); + self + } + fn new() -> Self { let mut tables = HashMap::new(); tables.insert( @@ -105,6 +116,7 @@ impl MyContextProvider { tables, options: Default::default(), udafs: Default::default(), + expr_planners: vec![], } } } @@ -154,4 +166,8 @@ impl ContextProvider for MyContextProvider { fn udwf_names(&self) -> Vec { Vec::new() } + + fn get_expr_planners(&self) -> &[Arc] { + &self.expr_planners + } } diff --git a/datafusion/sql/src/expr/identifier.rs b/datafusion/sql/src/expr/identifier.rs index 39736b1fbba5..f8979bde3086 100644 --- a/datafusion/sql/src/expr/identifier.rs +++ b/datafusion/sql/src/expr/identifier.rs @@ -15,14 +15,17 @@ // specific language governing permissions and limitations // under the License. -use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; use arrow_schema::Field; +use sqlparser::ast::{Expr as SQLExpr, Ident}; + use datafusion_common::{ internal_err, not_impl_err, plan_datafusion_err, Column, DFSchema, DataFusionError, - Result, ScalarValue, TableReference, + Result, TableReference, }; -use datafusion_expr::{expr::ScalarFunction, lit, Case, Expr}; -use sqlparser::ast::{Expr as SQLExpr, Ident}; +use datafusion_expr::planner::PlannerResult; +use datafusion_expr::{Case, Expr}; + +use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; impl<'a, S: ContextProvider> SqlToRel<'a, S> { pub(super) fn sql_identifier_to_expr( @@ -125,26 +128,22 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { match search_result { // found matching field with spare identifier(s) for nested field(s) in structure Some((field, qualifier, nested_names)) if !nested_names.is_empty() => { - // TODO: remove when can support multiple nested identifiers - if nested_names.len() > 1 { - return not_impl_err!( - "Nested identifiers not yet supported for column {}", - Column::from((qualifier, field)).quoted_flat_name() - ); - } - let nested_name = nested_names[0].to_string(); - - let col = Expr::Column(Column::from((qualifier, field))); - if let Some(udf) = - self.context_provider.get_function_meta("get_field") - { - Ok(Expr::ScalarFunction(ScalarFunction::new_udf( - udf, - vec![col, lit(ScalarValue::from(nested_name))], - ))) - } else { - internal_err!("get_field not found") + // found matching field with spare identifier(s) for nested field(s) in structure + for planner in self.context_provider.get_expr_planners() { + if let Ok(planner_result) = planner.plan_compound_identifier( + field, + qualifier, + nested_names, + ) { + match planner_result { + PlannerResult::Planned(expr) => return Ok(expr), + PlannerResult::Original(_args) => {} + } + } } + not_impl_err!( + "Compound identifiers not supported by ExprPlanner: {ids:?}" + ) } // found matching field with no spare identifier(s) Some((field, qualifier, _nested_names)) => { diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index ed79a1dfc0c7..e9c4114353c0 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use std::sync::Arc; use std::vec; use arrow_schema::*; @@ -28,6 +29,7 @@ use datafusion_sql::unparser::dialect::{ }; use datafusion_sql::unparser::{expr_to_sql, plan_to_sql, Unparser}; +use datafusion_functions::core::planner::CoreFunctionPlanner; use sqlparser::dialect::{Dialect, GenericDialect, MySqlDialect}; use sqlparser::parser::Parser; @@ -155,7 +157,8 @@ fn roundtrip_statement() -> Result<()> { let context = MockContextProvider::default() .with_udaf(sum_udaf()) - .with_udaf(count_udaf()); + .with_udaf(count_udaf()) + .with_expr_planner(Arc::new(CoreFunctionPlanner::default())); let sql_to_rel = SqlToRel::new(&context); let plan = sql_to_rel.sql_statement_to_plan(statement).unwrap(); @@ -184,7 +187,8 @@ fn roundtrip_crossjoin() -> Result<()> { .try_with_sql(query)? .parse_statement()?; - let context = MockContextProvider::default(); + let context = MockContextProvider::default() + .with_expr_planner(Arc::new(CoreFunctionPlanner::default())); let sql_to_rel = SqlToRel::new(&context); let plan = sql_to_rel.sql_statement_to_plan(statement).unwrap(); @@ -276,7 +280,8 @@ fn roundtrip_statement_with_dialect() -> Result<()> { .try_with_sql(query.sql)? .parse_statement()?; - let context = MockContextProvider::default(); + let context = MockContextProvider::default() + .with_expr_planner(Arc::new(CoreFunctionPlanner::default())); let sql_to_rel = SqlToRel::new(&context); let plan = sql_to_rel .sql_statement_to_plan(statement) diff --git a/datafusion/sql/tests/common/mod.rs b/datafusion/sql/tests/common/mod.rs index bcfb8f43848e..374aa9db6714 100644 --- a/datafusion/sql/tests/common/mod.rs +++ b/datafusion/sql/tests/common/mod.rs @@ -25,6 +25,7 @@ use arrow_schema::*; use datafusion_common::config::ConfigOptions; use datafusion_common::file_options::file_type::FileType; use datafusion_common::{plan_err, GetExt, Result, TableReference}; +use datafusion_expr::planner::ExprPlanner; use datafusion_expr::{AggregateUDF, ScalarUDF, TableSource, WindowUDF}; use datafusion_sql::planner::ContextProvider; @@ -53,6 +54,7 @@ pub(crate) struct MockContextProvider { options: ConfigOptions, udfs: HashMap>, udafs: HashMap>, + expr_planners: Vec>, } impl MockContextProvider { @@ -73,6 +75,11 @@ impl MockContextProvider { self.udafs.insert(udaf.name().to_lowercase(), udaf); self } + + pub(crate) fn with_expr_planner(mut self, planner: Arc) -> Self { + self.expr_planners.push(planner); + self + } } impl ContextProvider for MockContextProvider { @@ -240,6 +247,10 @@ impl ContextProvider for MockContextProvider { fn udwf_names(&self) -> Vec { Vec::new() } + + fn get_expr_planners(&self) -> &[Arc] { + &self.expr_planners + } } struct EmptyTable { diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index 57dab81331b3..3291560383df 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -18,6 +18,7 @@ use std::any::Any; #[cfg(test)] use std::collections::HashMap; +use std::sync::Arc; use std::vec; use arrow_schema::TimeUnit::Nanosecond; @@ -37,6 +38,7 @@ use datafusion_sql::{ planner::{ParserOptions, SqlToRel}, }; +use datafusion_functions::core::planner::CoreFunctionPlanner; use datafusion_functions_aggregate::{ approx_median::approx_median_udaf, count::count_udaf, }; @@ -2694,7 +2696,8 @@ fn logical_plan_with_dialect_and_options( .with_udaf(approx_median_udaf()) .with_udaf(count_udaf()) .with_udaf(avg_udaf()) - .with_udaf(grouping_udaf()); + .with_udaf(grouping_udaf()) + .with_expr_planner(Arc::new(CoreFunctionPlanner::default())); let planner = SqlToRel::new_with_options(&context, options); let result = DFParser::parse_sql_with_dialect(sql, dialect); From 2587df09c3fd9659f5076cedf98046e258764b2e Mon Sep 17 00:00:00 2001 From: Chris Connelly Date: Mon, 22 Jul 2024 00:42:03 +0100 Subject: [PATCH 22/37] Support `newlines_in_values` CSV option (#11533) * feat!: support `newlines_in_values` CSV option This significantly simplifies the UX when dealing with large CSV files that must support newlines in (quoted) values. By default, large CSV files will be repartitioned into multiple parallel range scans. This is great for performance in the common case but when large CSVs contain newlines in values the parallel scan will fail due to splitting on newlines within quotes rather than actual line terminators. With the current implementation, this behaviour can be controlled by the session-level `datafusion.optimizer.repartition_file_scans` and `datafusion.optimizer.repartition_file_min_size` settings. This commit introduces a `newlines_in_values` option to `CsvOptions` and plumbs it through to `CsvExec`, which includes it in the test for whether parallel execution is supported. This provides a convenient and searchable way to disable file scan repartitioning on a per-CSV basis. BREAKING CHANGE: This adds new public fields to types with all public fields, which is a breaking change. * docs: normalise `newlines_in_values` documentation * test: add/fix sqllogictests for `newlines_in_values` * docs: document `datafusion.catalog.newlines_in_values` * fix: typo in config.md * chore: suppress lint on too many arguments for `CsvExec::new` * fix: always checkout `*.slt` with LF line endings This is a bit of a stab in the dark, but it might fix multiline tests on Windows. * fix: always checkout `newlines_in_values.csv` with `LF` line endings The default git behaviour of converting line endings for checked out files causes the `csv_files.slt` test to fail when testing `newlines_in_values`. This appears to be due to the quoted newlines being converted to CRLF, which are not then normalised when the CSV is read. Assuming that the sqllogictests do normalise line endings in the expected output, this could then lead to a "spurious" diff from the actual output. --------- Co-authored-by: Andrew Lamb --- .gitattributes | 1 + datafusion/common/src/config.rs | 30 +++++++++++ .../core/src/datasource/file_format/csv.rs | 50 +++++++++++++++++++ .../src/datasource/file_format/options.rs | 22 ++++++++ .../core/src/datasource/physical_plan/csv.rs | 27 ++++++++-- .../enforce_distribution.rs | 3 ++ .../physical_optimizer/projection_pushdown.rs | 3 ++ .../replace_with_order_preserving_variants.rs | 1 + datafusion/core/src/test/mod.rs | 3 ++ .../core/tests/data/newlines_in_values.csv | 13 +++++ .../proto/datafusion_common.proto | 1 + datafusion/proto-common/src/from_proto/mod.rs | 1 + .../proto-common/src/generated/pbjson.rs | 21 ++++++++ .../proto-common/src/generated/prost.rs | 3 ++ datafusion/proto-common/src/to_proto/mod.rs | 3 ++ datafusion/proto/proto/datafusion.proto | 1 + .../src/generated/datafusion_proto_common.rs | 3 ++ datafusion/proto/src/generated/pbjson.rs | 18 +++++++ datafusion/proto/src/generated/prost.rs | 2 + datafusion/proto/src/physical_plan/mod.rs | 2 + .../sqllogictest/test_files/csv_files.slt | 42 ++++++++++++++++ .../test_files/information_schema.slt | 2 + docs/source/user-guide/configs.md | 1 + 23 files changed, 250 insertions(+), 3 deletions(-) create mode 100644 datafusion/core/tests/data/newlines_in_values.csv diff --git a/.gitattributes b/.gitattributes index bcdeffc09a11..84b47a6fc56e 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1,3 +1,4 @@ .github/ export-ignore +datafusion/core/tests/data/newlines_in_values.csv text eol=lf datafusion/proto/src/generated/prost.rs linguist-generated datafusion/proto/src/generated/pbjson.rs linguist-generated diff --git a/datafusion/common/src/config.rs b/datafusion/common/src/config.rs index b46b002baac0..3cbe14cb558e 100644 --- a/datafusion/common/src/config.rs +++ b/datafusion/common/src/config.rs @@ -184,6 +184,16 @@ config_namespace! { /// Default value for `format.has_header` for `CREATE EXTERNAL TABLE` /// if not specified explicitly in the statement. pub has_header: bool, default = false + + /// Specifies whether newlines in (quoted) CSV values are supported. + /// + /// This is the default value for `format.newlines_in_values` for `CREATE EXTERNAL TABLE` + /// if not specified explicitly in the statement. + /// + /// Parsing newlines in quoted values may be affected by execution behaviour such as + /// parallel file scanning. Setting this to `true` ensures that newlines in values are + /// parsed successfully, which may reduce performance. + pub newlines_in_values: bool, default = false } } @@ -1593,6 +1603,14 @@ config_namespace! { pub quote: u8, default = b'"' pub escape: Option, default = None pub double_quote: Option, default = None + /// Specifies whether newlines in (quoted) values are supported. + /// + /// Parsing newlines in quoted values may be affected by execution behaviour such as + /// parallel file scanning. Setting this to `true` ensures that newlines in values are + /// parsed successfully, which may reduce performance. + /// + /// The default behaviour depends on the `datafusion.catalog.newlines_in_values` setting. + pub newlines_in_values: Option, default = None pub compression: CompressionTypeVariant, default = CompressionTypeVariant::UNCOMPRESSED pub schema_infer_max_rec: usize, default = 100 pub date_format: Option, default = None @@ -1665,6 +1683,18 @@ impl CsvOptions { self } + /// Specifies whether newlines in (quoted) values are supported. + /// + /// Parsing newlines in quoted values may be affected by execution behaviour such as + /// parallel file scanning. Setting this to `true` ensures that newlines in values are + /// parsed successfully, which may reduce performance. + /// + /// The default behaviour depends on the `datafusion.catalog.newlines_in_values` setting. + pub fn with_newlines_in_values(mut self, newlines_in_values: bool) -> Self { + self.newlines_in_values = Some(newlines_in_values); + self + } + /// Set a `CompressionTypeVariant` of CSV /// - defaults to `CompressionTypeVariant::UNCOMPRESSED` pub fn with_file_compression_type( diff --git a/datafusion/core/src/datasource/file_format/csv.rs b/datafusion/core/src/datasource/file_format/csv.rs index 5daa8447551b..185f50883b2c 100644 --- a/datafusion/core/src/datasource/file_format/csv.rs +++ b/datafusion/core/src/datasource/file_format/csv.rs @@ -233,6 +233,18 @@ impl CsvFormat { self } + /// Specifies whether newlines in (quoted) values are supported. + /// + /// Parsing newlines in quoted values may be affected by execution behaviour such as + /// parallel file scanning. Setting this to `true` ensures that newlines in values are + /// parsed successfully, which may reduce performance. + /// + /// The default behaviour depends on the `datafusion.catalog.newlines_in_values` setting. + pub fn with_newlines_in_values(mut self, newlines_in_values: bool) -> Self { + self.options.newlines_in_values = Some(newlines_in_values); + self + } + /// Set a `FileCompressionType` of CSV /// - defaults to `FileCompressionType::UNCOMPRESSED` pub fn with_file_compression_type( @@ -330,6 +342,9 @@ impl FileFormat for CsvFormat { self.options.quote, self.options.escape, self.options.comment, + self.options + .newlines_in_values + .unwrap_or(state.config_options().catalog.newlines_in_values), self.options.compression.into(), ); Ok(Arc::new(exec)) @@ -1052,6 +1067,41 @@ mod tests { Ok(()) } + #[rstest(n_partitions, case(1), case(2), case(3), case(4))] + #[tokio::test] + async fn test_csv_parallel_newlines_in_values(n_partitions: usize) -> Result<()> { + let config = SessionConfig::new() + .with_repartition_file_scans(true) + .with_repartition_file_min_size(0) + .with_target_partitions(n_partitions); + let csv_options = CsvReadOptions::default() + .has_header(true) + .newlines_in_values(true); + let ctx = SessionContext::new_with_config(config); + let testdata = arrow_test_data(); + ctx.register_csv( + "aggr", + &format!("{testdata}/csv/aggregate_test_100.csv"), + csv_options, + ) + .await?; + + let query = "select sum(c3) from aggr;"; + let query_result = ctx.sql(query).await?.collect().await?; + let actual_partitions = count_query_csv_partitions(&ctx, query).await?; + + #[rustfmt::skip] + let expected = ["+--------------+", + "| sum(aggr.c3) |", + "+--------------+", + "| 781 |", + "+--------------+"]; + assert_batches_eq!(expected, &query_result); + assert_eq!(1, actual_partitions); // csv won't be scanned in parallel when newlines_in_values is set + + Ok(()) + } + /// Read a single empty csv file in parallel /// /// empty_0_byte.csv: diff --git a/datafusion/core/src/datasource/file_format/options.rs b/datafusion/core/src/datasource/file_format/options.rs index c6d143ed6749..552977baba17 100644 --- a/datafusion/core/src/datasource/file_format/options.rs +++ b/datafusion/core/src/datasource/file_format/options.rs @@ -63,6 +63,14 @@ pub struct CsvReadOptions<'a> { pub escape: Option, /// If enabled, lines beginning with this byte are ignored. pub comment: Option, + /// Specifies whether newlines in (quoted) values are supported. + /// + /// Parsing newlines in quoted values may be affected by execution behaviour such as + /// parallel file scanning. Setting this to `true` ensures that newlines in values are + /// parsed successfully, which may reduce performance. + /// + /// The default behaviour depends on the `datafusion.catalog.newlines_in_values` setting. + pub newlines_in_values: bool, /// An optional schema representing the CSV files. If None, CSV reader will try to infer it /// based on data in file. pub schema: Option<&'a Schema>, @@ -95,6 +103,7 @@ impl<'a> CsvReadOptions<'a> { delimiter: b',', quote: b'"', escape: None, + newlines_in_values: false, file_extension: DEFAULT_CSV_EXTENSION, table_partition_cols: vec![], file_compression_type: FileCompressionType::UNCOMPRESSED, @@ -133,6 +142,18 @@ impl<'a> CsvReadOptions<'a> { self } + /// Specifies whether newlines in (quoted) values are supported. + /// + /// Parsing newlines in quoted values may be affected by execution behaviour such as + /// parallel file scanning. Setting this to `true` ensures that newlines in values are + /// parsed successfully, which may reduce performance. + /// + /// The default behaviour depends on the `datafusion.catalog.newlines_in_values` setting. + pub fn newlines_in_values(mut self, newlines_in_values: bool) -> Self { + self.newlines_in_values = newlines_in_values; + self + } + /// Specify the file extension for CSV file selection pub fn file_extension(mut self, file_extension: &'a str) -> Self { self.file_extension = file_extension; @@ -490,6 +511,7 @@ impl ReadOptions<'_> for CsvReadOptions<'_> { .with_delimiter(self.delimiter) .with_quote(self.quote) .with_escape(self.escape) + .with_newlines_in_values(self.newlines_in_values) .with_schema_infer_max_rec(self.schema_infer_max_records) .with_file_compression_type(self.file_compression_type.to_owned()); diff --git a/datafusion/core/src/datasource/physical_plan/csv.rs b/datafusion/core/src/datasource/physical_plan/csv.rs index 327fbd976e87..fb0e23c6c164 100644 --- a/datafusion/core/src/datasource/physical_plan/csv.rs +++ b/datafusion/core/src/datasource/physical_plan/csv.rs @@ -59,6 +59,7 @@ pub struct CsvExec { quote: u8, escape: Option, comment: Option, + newlines_in_values: bool, /// Execution metrics metrics: ExecutionPlanMetricsSet, /// Compression type of the file associated with CsvExec @@ -68,6 +69,7 @@ pub struct CsvExec { impl CsvExec { /// Create a new CSV reader execution plan provided base and specific configurations + #[allow(clippy::too_many_arguments)] pub fn new( base_config: FileScanConfig, has_header: bool, @@ -75,6 +77,7 @@ impl CsvExec { quote: u8, escape: Option, comment: Option, + newlines_in_values: bool, file_compression_type: FileCompressionType, ) -> Self { let (projected_schema, projected_statistics, projected_output_ordering) = @@ -91,6 +94,7 @@ impl CsvExec { delimiter, quote, escape, + newlines_in_values, metrics: ExecutionPlanMetricsSet::new(), file_compression_type, cache, @@ -126,6 +130,17 @@ impl CsvExec { self.escape } + /// Specifies whether newlines in (quoted) values are supported. + /// + /// Parsing newlines in quoted values may be affected by execution behaviour such as + /// parallel file scanning. Setting this to `true` ensures that newlines in values are + /// parsed successfully, which may reduce performance. + /// + /// The default behaviour depends on the `datafusion.catalog.newlines_in_values` setting. + pub fn newlines_in_values(&self) -> bool { + self.newlines_in_values + } + fn output_partitioning_helper(file_scan_config: &FileScanConfig) -> Partitioning { Partitioning::UnknownPartitioning(file_scan_config.file_groups.len()) } @@ -196,15 +211,15 @@ impl ExecutionPlan for CsvExec { /// Redistribute files across partitions according to their size /// See comments on [`FileGroupPartitioner`] for more detail. /// - /// Return `None` if can't get repartitioned(empty/compressed file). + /// Return `None` if can't get repartitioned (empty, compressed file, or `newlines_in_values` set). fn repartitioned( &self, target_partitions: usize, config: &ConfigOptions, ) -> Result>> { let repartition_file_min_size = config.optimizer.repartition_file_min_size; - // Parallel execution on compressed CSV file is not supported yet. - if self.file_compression_type.is_compressed() { + // Parallel execution on compressed CSV files or files that must support newlines in values is not supported yet. + if self.file_compression_type.is_compressed() || self.newlines_in_values { return Ok(None); } @@ -589,6 +604,7 @@ mod tests { b'"', None, None, + false, file_compression_type.to_owned(), ); assert_eq!(13, csv.base_config.file_schema.fields().len()); @@ -658,6 +674,7 @@ mod tests { b'"', None, None, + false, file_compression_type.to_owned(), ); assert_eq!(13, csv.base_config.file_schema.fields().len()); @@ -727,6 +744,7 @@ mod tests { b'"', None, None, + false, file_compression_type.to_owned(), ); assert_eq!(13, csv.base_config.file_schema.fields().len()); @@ -793,6 +811,7 @@ mod tests { b'"', None, None, + false, file_compression_type.to_owned(), ); assert_eq!(14, csv.base_config.file_schema.fields().len()); @@ -858,6 +877,7 @@ mod tests { b'"', None, None, + false, file_compression_type.to_owned(), ); assert_eq!(13, csv.base_config.file_schema.fields().len()); @@ -953,6 +973,7 @@ mod tests { b'"', None, None, + false, file_compression_type.to_owned(), ); diff --git a/datafusion/core/src/physical_optimizer/enforce_distribution.rs b/datafusion/core/src/physical_optimizer/enforce_distribution.rs index afed5dd37535..9791f23f963e 100644 --- a/datafusion/core/src/physical_optimizer/enforce_distribution.rs +++ b/datafusion/core/src/physical_optimizer/enforce_distribution.rs @@ -1472,6 +1472,7 @@ pub(crate) mod tests { b'"', None, None, + false, FileCompressionType::UNCOMPRESSED, )) } @@ -1496,6 +1497,7 @@ pub(crate) mod tests { b'"', None, None, + false, FileCompressionType::UNCOMPRESSED, )) } @@ -3770,6 +3772,7 @@ pub(crate) mod tests { b'"', None, None, + false, compression_type, )), vec![("a".to_string(), "a".to_string())], diff --git a/datafusion/core/src/physical_optimizer/projection_pushdown.rs b/datafusion/core/src/physical_optimizer/projection_pushdown.rs index 84f898431762..d0d0c985b8b6 100644 --- a/datafusion/core/src/physical_optimizer/projection_pushdown.rs +++ b/datafusion/core/src/physical_optimizer/projection_pushdown.rs @@ -186,6 +186,7 @@ fn try_swapping_with_csv( csv.quote(), csv.escape(), csv.comment(), + csv.newlines_in_values(), csv.file_compression_type, )) as _ }) @@ -1700,6 +1701,7 @@ mod tests { 0, None, None, + false, FileCompressionType::UNCOMPRESSED, )) } @@ -1723,6 +1725,7 @@ mod tests { 0, None, None, + false, FileCompressionType::UNCOMPRESSED, )) } diff --git a/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs b/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs index 013155b8400a..6565e3e7d0d2 100644 --- a/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs +++ b/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs @@ -1503,6 +1503,7 @@ mod tests { b'"', None, None, + false, FileCompressionType::UNCOMPRESSED, )) } diff --git a/datafusion/core/src/test/mod.rs b/datafusion/core/src/test/mod.rs index e8550a79cb0e..5cb1b6ea7017 100644 --- a/datafusion/core/src/test/mod.rs +++ b/datafusion/core/src/test/mod.rs @@ -99,6 +99,7 @@ pub fn scan_partitioned_csv(partitions: usize, work_dir: &Path) -> Result for CsvOptions { quote: proto_opts.quote[0], escape: proto_opts.escape.first().copied(), double_quote: proto_opts.has_header.first().map(|h| *h != 0), + newlines_in_values: proto_opts.newlines_in_values.first().map(|h| *h != 0), compression: proto_opts.compression().into(), schema_infer_max_rec: proto_opts.schema_infer_max_rec as usize, date_format: (!proto_opts.date_format.is_empty()) diff --git a/datafusion/proto-common/src/generated/pbjson.rs b/datafusion/proto-common/src/generated/pbjson.rs index be3cc58b23df..4b34660ae2ef 100644 --- a/datafusion/proto-common/src/generated/pbjson.rs +++ b/datafusion/proto-common/src/generated/pbjson.rs @@ -1884,6 +1884,9 @@ impl serde::Serialize for CsvOptions { if !self.double_quote.is_empty() { len += 1; } + if !self.newlines_in_values.is_empty() { + len += 1; + } let mut struct_ser = serializer.serialize_struct("datafusion_common.CsvOptions", len)?; if !self.has_header.is_empty() { #[allow(clippy::needless_borrow)] @@ -1936,6 +1939,10 @@ impl serde::Serialize for CsvOptions { #[allow(clippy::needless_borrow)] struct_ser.serialize_field("doubleQuote", pbjson::private::base64::encode(&self.double_quote).as_str())?; } + if !self.newlines_in_values.is_empty() { + #[allow(clippy::needless_borrow)] + struct_ser.serialize_field("newlinesInValues", pbjson::private::base64::encode(&self.newlines_in_values).as_str())?; + } struct_ser.end() } } @@ -1969,6 +1976,8 @@ impl<'de> serde::Deserialize<'de> for CsvOptions { "comment", "double_quote", "doubleQuote", + "newlines_in_values", + "newlinesInValues", ]; #[allow(clippy::enum_variant_names)] @@ -1987,6 +1996,7 @@ impl<'de> serde::Deserialize<'de> for CsvOptions { NullValue, Comment, DoubleQuote, + NewlinesInValues, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -2022,6 +2032,7 @@ impl<'de> serde::Deserialize<'de> for CsvOptions { "nullValue" | "null_value" => Ok(GeneratedField::NullValue), "comment" => Ok(GeneratedField::Comment), "doubleQuote" | "double_quote" => Ok(GeneratedField::DoubleQuote), + "newlinesInValues" | "newlines_in_values" => Ok(GeneratedField::NewlinesInValues), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -2055,6 +2066,7 @@ impl<'de> serde::Deserialize<'de> for CsvOptions { let mut null_value__ = None; let mut comment__ = None; let mut double_quote__ = None; + let mut newlines_in_values__ = None; while let Some(k) = map_.next_key()? { match k { GeneratedField::HasHeader => { @@ -2155,6 +2167,14 @@ impl<'de> serde::Deserialize<'de> for CsvOptions { Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) ; } + GeneratedField::NewlinesInValues => { + if newlines_in_values__.is_some() { + return Err(serde::de::Error::duplicate_field("newlinesInValues")); + } + newlines_in_values__ = + Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) + ; + } } } Ok(CsvOptions { @@ -2172,6 +2192,7 @@ impl<'de> serde::Deserialize<'de> for CsvOptions { null_value: null_value__.unwrap_or_default(), comment: comment__.unwrap_or_default(), double_quote: double_quote__.unwrap_or_default(), + newlines_in_values: newlines_in_values__.unwrap_or_default(), }) } } diff --git a/datafusion/proto-common/src/generated/prost.rs b/datafusion/proto-common/src/generated/prost.rs index b0674ff28d75..9a2770997f15 100644 --- a/datafusion/proto-common/src/generated/prost.rs +++ b/datafusion/proto-common/src/generated/prost.rs @@ -633,6 +633,9 @@ pub struct CsvOptions { /// Indicates if quotes are doubled #[prost(bytes = "vec", tag = "14")] pub double_quote: ::prost::alloc::vec::Vec, + /// Indicates if newlines are supported in values + #[prost(bytes = "vec", tag = "15")] + pub newlines_in_values: ::prost::alloc::vec::Vec, } /// Options controlling CSV format #[allow(clippy::derive_partial_eq_without_eq)] diff --git a/datafusion/proto-common/src/to_proto/mod.rs b/datafusion/proto-common/src/to_proto/mod.rs index 705a479e0178..9dcb65444a47 100644 --- a/datafusion/proto-common/src/to_proto/mod.rs +++ b/datafusion/proto-common/src/to_proto/mod.rs @@ -900,6 +900,9 @@ impl TryFrom<&CsvOptions> for protobuf::CsvOptions { quote: vec![opts.quote], escape: opts.escape.map_or_else(Vec::new, |e| vec![e]), double_quote: opts.double_quote.map_or_else(Vec::new, |h| vec![h as u8]), + newlines_in_values: opts + .newlines_in_values + .map_or_else(Vec::new, |h| vec![h as u8]), compression: compression.into(), schema_infer_max_rec: opts.schema_infer_max_rec as u64, date_format: opts.date_format.clone().unwrap_or_default(), diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index dc551778c5fb..49d9f2dde67f 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -1007,6 +1007,7 @@ message CsvScanExecNode { oneof optional_comment { string comment = 6; } + bool newlines_in_values = 7; } message AvroScanExecNode { diff --git a/datafusion/proto/src/generated/datafusion_proto_common.rs b/datafusion/proto/src/generated/datafusion_proto_common.rs index b0674ff28d75..9a2770997f15 100644 --- a/datafusion/proto/src/generated/datafusion_proto_common.rs +++ b/datafusion/proto/src/generated/datafusion_proto_common.rs @@ -633,6 +633,9 @@ pub struct CsvOptions { /// Indicates if quotes are doubled #[prost(bytes = "vec", tag = "14")] pub double_quote: ::prost::alloc::vec::Vec, + /// Indicates if newlines are supported in values + #[prost(bytes = "vec", tag = "15")] + pub newlines_in_values: ::prost::alloc::vec::Vec, } /// Options controlling CSV format #[allow(clippy::derive_partial_eq_without_eq)] diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 8f77c24bd911..25f6646d2a9a 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -3605,6 +3605,9 @@ impl serde::Serialize for CsvScanExecNode { if !self.quote.is_empty() { len += 1; } + if self.newlines_in_values { + len += 1; + } if self.optional_escape.is_some() { len += 1; } @@ -3624,6 +3627,9 @@ impl serde::Serialize for CsvScanExecNode { if !self.quote.is_empty() { struct_ser.serialize_field("quote", &self.quote)?; } + if self.newlines_in_values { + struct_ser.serialize_field("newlinesInValues", &self.newlines_in_values)?; + } if let Some(v) = self.optional_escape.as_ref() { match v { csv_scan_exec_node::OptionalEscape::Escape(v) => { @@ -3654,6 +3660,8 @@ impl<'de> serde::Deserialize<'de> for CsvScanExecNode { "hasHeader", "delimiter", "quote", + "newlines_in_values", + "newlinesInValues", "escape", "comment", ]; @@ -3664,6 +3672,7 @@ impl<'de> serde::Deserialize<'de> for CsvScanExecNode { HasHeader, Delimiter, Quote, + NewlinesInValues, Escape, Comment, } @@ -3691,6 +3700,7 @@ impl<'de> serde::Deserialize<'de> for CsvScanExecNode { "hasHeader" | "has_header" => Ok(GeneratedField::HasHeader), "delimiter" => Ok(GeneratedField::Delimiter), "quote" => Ok(GeneratedField::Quote), + "newlinesInValues" | "newlines_in_values" => Ok(GeneratedField::NewlinesInValues), "escape" => Ok(GeneratedField::Escape), "comment" => Ok(GeneratedField::Comment), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), @@ -3716,6 +3726,7 @@ impl<'de> serde::Deserialize<'de> for CsvScanExecNode { let mut has_header__ = None; let mut delimiter__ = None; let mut quote__ = None; + let mut newlines_in_values__ = None; let mut optional_escape__ = None; let mut optional_comment__ = None; while let Some(k) = map_.next_key()? { @@ -3744,6 +3755,12 @@ impl<'de> serde::Deserialize<'de> for CsvScanExecNode { } quote__ = Some(map_.next_value()?); } + GeneratedField::NewlinesInValues => { + if newlines_in_values__.is_some() { + return Err(serde::de::Error::duplicate_field("newlinesInValues")); + } + newlines_in_values__ = Some(map_.next_value()?); + } GeneratedField::Escape => { if optional_escape__.is_some() { return Err(serde::de::Error::duplicate_field("escape")); @@ -3763,6 +3780,7 @@ impl<'de> serde::Deserialize<'de> for CsvScanExecNode { has_header: has_header__.unwrap_or_default(), delimiter: delimiter__.unwrap_or_default(), quote: quote__.unwrap_or_default(), + newlines_in_values: newlines_in_values__.unwrap_or_default(), optional_escape: optional_escape__, optional_comment: optional_comment__, }) diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 605c56fa946a..ba288fe3d1b8 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1542,6 +1542,8 @@ pub struct CsvScanExecNode { pub delimiter: ::prost::alloc::string::String, #[prost(string, tag = "4")] pub quote: ::prost::alloc::string::String, + #[prost(bool, tag = "7")] + pub newlines_in_values: bool, #[prost(oneof = "csv_scan_exec_node::OptionalEscape", tags = "5")] pub optional_escape: ::core::option::Option, #[prost(oneof = "csv_scan_exec_node::OptionalComment", tags = "6")] diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index 1220f42ded83..9e17c19ecbc5 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -211,6 +211,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { } else { None }, + scan.newlines_in_values, FileCompressionType::UNCOMPRESSED, ))), #[cfg(feature = "parquet")] @@ -1579,6 +1580,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { } else { None }, + newlines_in_values: exec.newlines_in_values(), }, )), }); diff --git a/datafusion/sqllogictest/test_files/csv_files.slt b/datafusion/sqllogictest/test_files/csv_files.slt index ca3bebe79f27..f7f5aa54dd0d 100644 --- a/datafusion/sqllogictest/test_files/csv_files.slt +++ b/datafusion/sqllogictest/test_files/csv_files.slt @@ -293,3 +293,45 @@ id0 "value0" id1 "value1" id2 "value2" id3 "value3" + +# Handling of newlines in values + +statement ok +SET datafusion.optimizer.repartition_file_min_size = 1; + +statement ok +CREATE EXTERNAL TABLE stored_table_with_newlines_in_values_unsafe ( +col1 TEXT, +col2 TEXT +) STORED AS CSV +LOCATION '../core/tests/data/newlines_in_values.csv'; + +statement error incorrect number of fields +select * from stored_table_with_newlines_in_values_unsafe; + +statement ok +CREATE EXTERNAL TABLE stored_table_with_newlines_in_values_safe ( +col1 TEXT, +col2 TEXT +) STORED AS CSV +LOCATION '../core/tests/data/newlines_in_values.csv' +OPTIONS ('format.newlines_in_values' 'true'); + +query TT +select * from stored_table_with_newlines_in_values_safe; +---- +id message +1 +01)hello +02)world +2 +01)something +02)else +3 +01) +02)many +03)lines +04)make +05)good test +4 unquoted +value end diff --git a/datafusion/sqllogictest/test_files/information_schema.slt b/datafusion/sqllogictest/test_files/information_schema.slt index f7b755b01911..c8c0d1d45b97 100644 --- a/datafusion/sqllogictest/test_files/information_schema.slt +++ b/datafusion/sqllogictest/test_files/information_schema.slt @@ -168,6 +168,7 @@ datafusion.catalog.format NULL datafusion.catalog.has_header false datafusion.catalog.information_schema true datafusion.catalog.location NULL +datafusion.catalog.newlines_in_values false datafusion.execution.aggregate.scalar_update_factor 10 datafusion.execution.batch_size 8192 datafusion.execution.coalesce_batches true @@ -252,6 +253,7 @@ datafusion.catalog.format NULL Type of `TableProvider` to use when loading `defa datafusion.catalog.has_header false Default value for `format.has_header` for `CREATE EXTERNAL TABLE` if not specified explicitly in the statement. datafusion.catalog.information_schema true Should DataFusion provide access to `information_schema` virtual tables for displaying schema information datafusion.catalog.location NULL Location scanned to load tables for `default` schema +datafusion.catalog.newlines_in_values false Specifies whether newlines in (quoted) CSV values are supported. This is the default value for `format.newlines_in_values` for `CREATE EXTERNAL TABLE` if not specified explicitly in the statement. Parsing newlines in quoted values may be affected by execution behaviour such as parallel file scanning. Setting this to `true` ensures that newlines in values are parsed successfully, which may reduce performance. datafusion.execution.aggregate.scalar_update_factor 10 Specifies the threshold for using `ScalarValue`s to update accumulators during high-cardinality aggregations for each input batch. The aggregation is considered high-cardinality if the number of affected groups is greater than or equal to `batch_size / scalar_update_factor`. In such cases, `ScalarValue`s are utilized for updating accumulators, rather than the default batch-slice approach. This can lead to performance improvements. By adjusting the `scalar_update_factor`, you can balance the trade-off between more efficient accumulator updates and the number of groups affected. datafusion.execution.batch_size 8192 Default batch size while creating new batches, it's especially useful for buffer-in-memory batches since creating tiny batches would result in too much metadata memory consumption datafusion.execution.coalesce_batches true When set to true, record batches will be examined between each operator and small batches will be coalesced into larger batches. This is helpful when there are highly selective filters or joins that could produce tiny output batches. The target batch size is determined by the configuration setting diff --git a/docs/source/user-guide/configs.md b/docs/source/user-guide/configs.md index 8d3ecbc98544..5e5de016e375 100644 --- a/docs/source/user-guide/configs.md +++ b/docs/source/user-guide/configs.md @@ -44,6 +44,7 @@ Environment variables are read during `SessionConfig` initialisation so they mus | datafusion.catalog.location | NULL | Location scanned to load tables for `default` schema | | datafusion.catalog.format | NULL | Type of `TableProvider` to use when loading `default` schema | | datafusion.catalog.has_header | false | Default value for `format.has_header` for `CREATE EXTERNAL TABLE` if not specified explicitly in the statement. | +| datafusion.catalog.newlines_in_values | false | Specifies whether newlines in (quoted) CSV values are supported. This is the default value for `format.newlines_in_values` for `CREATE EXTERNAL TABLE` if not specified explicitly in the statement. Parsing newlines in quoted values may be affected by execution behaviour such as parallel file scanning. Setting this to `true` ensures that newlines in values are parsed successfully, which may reduce performance. | | datafusion.execution.batch_size | 8192 | Default batch size while creating new batches, it's especially useful for buffer-in-memory batches since creating tiny batches would result in too much metadata memory consumption | | datafusion.execution.coalesce_batches | true | When set to true, record batches will be examined between each operator and small batches will be coalesced into larger batches. This is helpful when there are highly selective filters or joins that could produce tiny output batches. The target batch size is determined by the configuration setting | | datafusion.execution.collect_statistics | false | Should DataFusion collect statistics after listing files | From 63efaee2555ddd1381b4885867860621ec791f82 Mon Sep 17 00:00:00 2001 From: Oleks V Date: Sun, 21 Jul 2024 17:09:54 -0700 Subject: [PATCH 23/37] Support SortMergeJoin spilling (#11218) * Support SortMerge spilling --- datafusion/core/tests/memory_limit/mod.rs | 27 +- datafusion/execution/src/memory_pool/mod.rs | 19 +- .../src/joins/sort_merge_join.rs | 457 +++++++++++++++--- datafusion/physical-plan/src/sorts/sort.rs | 7 +- datafusion/physical-plan/src/spill.rs | 103 +++- 5 files changed, 529 insertions(+), 84 deletions(-) diff --git a/datafusion/core/tests/memory_limit/mod.rs b/datafusion/core/tests/memory_limit/mod.rs index f4f4f8cd89cb..bc2c3315da59 100644 --- a/datafusion/core/tests/memory_limit/mod.rs +++ b/datafusion/core/tests/memory_limit/mod.rs @@ -164,7 +164,7 @@ async fn cross_join() { } #[tokio::test] -async fn merge_join() { +async fn sort_merge_join_no_spill() { // Planner chooses MergeJoin only if number of partitions > 1 let config = SessionConfig::new() .with_target_partitions(2) @@ -175,11 +175,32 @@ async fn merge_join() { "select t1.* from t t1 JOIN t t2 ON t1.pod = t2.pod AND t1.time = t2.time", ) .with_expected_errors(vec![ - "Resources exhausted: Failed to allocate additional", + "Failed to allocate additional", "SMJStream", + "Disk spilling disabled", ]) .with_memory_limit(1_000) .with_config(config) + .with_scenario(Scenario::AccessLogStreaming) + .run() + .await +} + +#[tokio::test] +async fn sort_merge_join_spill() { + // Planner chooses MergeJoin only if number of partitions > 1 + let config = SessionConfig::new() + .with_target_partitions(2) + .set_bool("datafusion.optimizer.prefer_hash_join", false); + + TestCase::new() + .with_query( + "select t1.* from t t1 JOIN t t2 ON t1.pod = t2.pod AND t1.time = t2.time", + ) + .with_memory_limit(1_000) + .with_config(config) + .with_disk_manager_config(DiskManagerConfig::NewOs) + .with_scenario(Scenario::AccessLogStreaming) .run() .await } @@ -453,7 +474,7 @@ impl TestCase { let table = scenario.table(); let rt_config = RuntimeConfig::new() - // do not allow spilling + // disk manager setting controls the spilling .with_disk_manager(disk_manager_config) .with_memory_limit(memory_limit, MEMORY_FRACTION); diff --git a/datafusion/execution/src/memory_pool/mod.rs b/datafusion/execution/src/memory_pool/mod.rs index 3f66a304dc18..92ed1b2918de 100644 --- a/datafusion/execution/src/memory_pool/mod.rs +++ b/datafusion/execution/src/memory_pool/mod.rs @@ -18,7 +18,7 @@ //! [`MemoryPool`] for memory management during query execution, [`proxy]` for //! help with allocation accounting. -use datafusion_common::Result; +use datafusion_common::{internal_err, Result}; use std::{cmp::Ordering, sync::Arc}; mod pool; @@ -220,6 +220,23 @@ impl MemoryReservation { self.size = new_size } + /// Tries to free `capacity` bytes from this reservation + /// if `capacity` does not exceed [`Self::size`] + /// Returns new reservation size + /// or error if shrinking capacity is more than allocated size + pub fn try_shrink(&mut self, capacity: usize) -> Result { + if let Some(new_size) = self.size.checked_sub(capacity) { + self.registration.pool.shrink(self, capacity); + self.size = new_size; + Ok(new_size) + } else { + internal_err!( + "Cannot free the capacity {capacity} out of allocated size {}", + self.size + ) + } + } + /// Sets the size of this reservation to `capacity` pub fn resize(&mut self, capacity: usize) { match capacity.cmp(&self.size) { diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index a03e4a83fd2d..5fde028c7f48 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -24,40 +24,46 @@ use std::any::Any; use std::cmp::Ordering; use std::collections::VecDeque; use std::fmt::Formatter; +use std::fs::File; +use std::io::BufReader; use std::mem; use std::ops::Range; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; -use crate::expressions::PhysicalSortExpr; -use crate::joins::utils::{ - build_join_schema, check_join_is_valid, estimate_join_statistics, - symmetric_join_output_partitioning, JoinFilter, JoinOn, JoinOnRef, -}; -use crate::metrics::{ExecutionPlanMetricsSet, MetricBuilder, MetricsSet}; -use crate::{ - execution_mode_from_children, metrics, DisplayAs, DisplayFormatType, Distribution, - ExecutionPlan, ExecutionPlanProperties, PhysicalExpr, PlanProperties, - RecordBatchStream, SendableRecordBatchStream, Statistics, -}; - use arrow::array::*; use arrow::compute::{self, concat_batches, take, SortOptions}; use arrow::datatypes::{DataType, SchemaRef, TimeUnit}; use arrow::error::ArrowError; +use arrow::ipc::reader::FileReader; use arrow_array::types::UInt64Type; +use futures::{Stream, StreamExt}; +use hashbrown::HashSet; use datafusion_common::{ - internal_err, not_impl_err, plan_err, DataFusionError, JoinSide, JoinType, Result, + exec_err, internal_err, not_impl_err, plan_err, DataFusionError, JoinSide, JoinType, + Result, }; +use datafusion_execution::disk_manager::RefCountedTempFile; use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; +use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_execution::TaskContext; use datafusion_physical_expr::equivalence::join_equivalence_properties; use datafusion_physical_expr::{PhysicalExprRef, PhysicalSortRequirement}; -use futures::{Stream, StreamExt}; -use hashbrown::HashSet; +use crate::expressions::PhysicalSortExpr; +use crate::joins::utils::{ + build_join_schema, check_join_is_valid, estimate_join_statistics, + symmetric_join_output_partitioning, JoinFilter, JoinOn, JoinOnRef, +}; +use crate::metrics::{Count, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet}; +use crate::spill::spill_record_batches; +use crate::{ + execution_mode_from_children, metrics, DisplayAs, DisplayFormatType, Distribution, + ExecutionPlan, ExecutionPlanProperties, PhysicalExpr, PlanProperties, + RecordBatchStream, SendableRecordBatchStream, Statistics, +}; /// join execution plan executes partitions in parallel and combines them into a set of /// partitions. @@ -234,11 +240,6 @@ impl SortMergeJoinExec { impl DisplayAs for SortMergeJoinExec { fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result { - let display_filter = self.filter.as_ref().map_or_else( - || "".to_string(), - |f| format!(", filter={}", f.expression()), - ); - match t { DisplayFormatType::Default | DisplayFormatType::Verbose => { let on = self @@ -250,7 +251,12 @@ impl DisplayAs for SortMergeJoinExec { write!( f, "SortMergeJoin: join_type={:?}, on=[{}]{}", - self.join_type, on, display_filter + self.join_type, + on, + self.filter.as_ref().map_or("".to_string(), |f| format!( + ", filter={}", + f.expression() + )) ) } } @@ -375,6 +381,7 @@ impl ExecutionPlan for SortMergeJoinExec { batch_size, SortMergeJoinMetrics::new(partition, &self.metrics), reservation, + context.runtime_env(), )?)) } @@ -412,6 +419,12 @@ struct SortMergeJoinMetrics { /// Peak memory used for buffered data. /// Calculated as sum of peak memory values across partitions peak_mem_used: metrics::Gauge, + /// count of spills during the execution of the operator + spill_count: Count, + /// total spilled bytes during the execution of the operator + spilled_bytes: Count, + /// total spilled rows during the execution of the operator + spilled_rows: Count, } impl SortMergeJoinMetrics { @@ -425,6 +438,9 @@ impl SortMergeJoinMetrics { MetricBuilder::new(metrics).counter("output_batches", partition); let output_rows = MetricBuilder::new(metrics).output_rows(partition); let peak_mem_used = MetricBuilder::new(metrics).gauge("peak_mem_used", partition); + let spill_count = MetricBuilder::new(metrics).spill_count(partition); + let spilled_bytes = MetricBuilder::new(metrics).spilled_bytes(partition); + let spilled_rows = MetricBuilder::new(metrics).spilled_rows(partition); Self { join_time, @@ -433,6 +449,9 @@ impl SortMergeJoinMetrics { output_batches, output_rows, peak_mem_used, + spill_count, + spilled_bytes, + spilled_rows, } } } @@ -565,7 +584,8 @@ impl StreamedBatch { #[derive(Debug)] struct BufferedBatch { /// The buffered record batch - pub batch: RecordBatch, + /// None if the batch spilled to disk th + pub batch: Option, /// The range in which the rows share the same join key pub range: Range, /// Array refs of the join key @@ -577,6 +597,14 @@ struct BufferedBatch { /// The indices of buffered batch that failed the join filter. /// When dequeuing the buffered batch, we need to produce null joined rows for these indices. pub join_filter_failed_idxs: HashSet, + /// Current buffered batch number of rows. Equal to batch.num_rows() + /// but if batch is spilled to disk this property is preferable + /// and less expensive + pub num_rows: usize, + /// An optional temp spill file name on the disk if the batch spilled + /// None by default + /// Some(fileName) if the batch spilled to the disk + pub spill_file: Option, } impl BufferedBatch { @@ -602,13 +630,16 @@ impl BufferedBatch { + mem::size_of::>() + mem::size_of::(); + let num_rows = batch.num_rows(); BufferedBatch { - batch, + batch: Some(batch), range, join_arrays, null_joined: vec![], size_estimation, join_filter_failed_idxs: HashSet::new(), + num_rows, + spill_file: None, } } } @@ -666,6 +697,8 @@ struct SMJStream { pub join_metrics: SortMergeJoinMetrics, /// Memory reservation pub reservation: MemoryReservation, + /// Runtime env + pub runtime_env: Arc, } impl RecordBatchStream for SMJStream { @@ -785,6 +818,7 @@ impl SMJStream { batch_size: usize, join_metrics: SortMergeJoinMetrics, reservation: MemoryReservation, + runtime_env: Arc, ) -> Result { let streamed_schema = streamed.schema(); let buffered_schema = buffered.schema(); @@ -813,6 +847,7 @@ impl SMJStream { join_type, join_metrics, reservation, + runtime_env, }) } @@ -858,6 +893,58 @@ impl SMJStream { } } + fn free_reservation(&mut self, buffered_batch: BufferedBatch) -> Result<()> { + // Shrink memory usage for in-memory batches only + if buffered_batch.spill_file.is_none() && buffered_batch.batch.is_some() { + self.reservation + .try_shrink(buffered_batch.size_estimation)?; + } + + Ok(()) + } + + fn allocate_reservation(&mut self, mut buffered_batch: BufferedBatch) -> Result<()> { + match self.reservation.try_grow(buffered_batch.size_estimation) { + Ok(_) => { + self.join_metrics + .peak_mem_used + .set_max(self.reservation.size()); + Ok(()) + } + Err(_) if self.runtime_env.disk_manager.tmp_files_enabled() => { + // spill buffered batch to disk + let spill_file = self + .runtime_env + .disk_manager + .create_tmp_file("sort_merge_join_buffered_spill")?; + + if let Some(batch) = buffered_batch.batch { + spill_record_batches( + vec![batch], + spill_file.path().into(), + Arc::clone(&self.buffered_schema), + )?; + buffered_batch.spill_file = Some(spill_file); + buffered_batch.batch = None; + + // update metrics to register spill + self.join_metrics.spill_count.add(1); + self.join_metrics + .spilled_bytes + .add(buffered_batch.size_estimation); + self.join_metrics.spilled_rows.add(buffered_batch.num_rows); + Ok(()) + } else { + internal_err!("Buffered batch has empty body") + } + } + Err(e) => exec_err!("{}. Disk spilling disabled.", e.message()), + }?; + + self.buffered_data.batches.push_back(buffered_batch); + Ok(()) + } + /// Poll next buffered batches fn poll_buffered_batches(&mut self, cx: &mut Context) -> Poll>> { loop { @@ -867,12 +954,12 @@ impl SMJStream { while !self.buffered_data.batches.is_empty() { let head_batch = self.buffered_data.head_batch(); // If the head batch is fully processed, dequeue it and produce output of it. - if head_batch.range.end == head_batch.batch.num_rows() { + if head_batch.range.end == head_batch.num_rows { self.freeze_dequeuing_buffered()?; if let Some(buffered_batch) = self.buffered_data.batches.pop_front() { - self.reservation.shrink(buffered_batch.size_estimation); + self.free_reservation(buffered_batch)?; } } else { // If the head batch is not fully processed, break the loop. @@ -900,25 +987,22 @@ impl SMJStream { Poll::Ready(Some(batch)) => { self.join_metrics.input_batches.add(1); self.join_metrics.input_rows.add(batch.num_rows()); + if batch.num_rows() > 0 { let buffered_batch = BufferedBatch::new(batch, 0..1, &self.on_buffered); - self.reservation.try_grow(buffered_batch.size_estimation)?; - self.join_metrics - .peak_mem_used - .set_max(self.reservation.size()); - self.buffered_data.batches.push_back(buffered_batch); + self.allocate_reservation(buffered_batch)?; self.buffered_state = BufferedState::PollingRest; } } }, BufferedState::PollingRest => { if self.buffered_data.tail_batch().range.end - < self.buffered_data.tail_batch().batch.num_rows() + < self.buffered_data.tail_batch().num_rows { while self.buffered_data.tail_batch().range.end - < self.buffered_data.tail_batch().batch.num_rows() + < self.buffered_data.tail_batch().num_rows { if is_join_arrays_equal( &self.buffered_data.head_batch().join_arrays, @@ -941,6 +1025,7 @@ impl SMJStream { self.buffered_state = BufferedState::Ready; } Poll::Ready(Some(batch)) => { + // Polling batches coming concurrently as multiple partitions self.join_metrics.input_batches.add(1); self.join_metrics.input_rows.add(batch.num_rows()); if batch.num_rows() > 0 { @@ -949,12 +1034,7 @@ impl SMJStream { 0..0, &self.on_buffered, ); - self.reservation - .try_grow(buffered_batch.size_estimation)?; - self.join_metrics - .peak_mem_used - .set_max(self.reservation.size()); - self.buffered_data.batches.push_back(buffered_batch); + self.allocate_reservation(buffered_batch)?; } } } @@ -1473,13 +1553,8 @@ fn produce_buffered_null_batch( } // Take buffered (right) columns - let buffered_columns = buffered_batch - .batch - .columns() - .iter() - .map(|column| take(column, &buffered_indices, None)) - .collect::, ArrowError>>() - .map_err(Into::::into)?; + let buffered_columns = + get_buffered_columns_from_batch(buffered_batch, buffered_indices)?; // Create null streamed (left) columns let mut streamed_columns = streamed_schema @@ -1502,13 +1577,45 @@ fn get_buffered_columns( buffered_data: &BufferedData, buffered_batch_idx: usize, buffered_indices: &UInt64Array, -) -> Result, ArrowError> { - buffered_data.batches[buffered_batch_idx] - .batch - .columns() - .iter() - .map(|column| take(column, &buffered_indices, None)) - .collect::, ArrowError>>() +) -> Result> { + get_buffered_columns_from_batch( + &buffered_data.batches[buffered_batch_idx], + buffered_indices, + ) +} + +#[inline(always)] +fn get_buffered_columns_from_batch( + buffered_batch: &BufferedBatch, + buffered_indices: &UInt64Array, +) -> Result> { + match (&buffered_batch.spill_file, &buffered_batch.batch) { + // In memory batch + (None, Some(batch)) => Ok(batch + .columns() + .iter() + .map(|column| take(column, &buffered_indices, None)) + .collect::, ArrowError>>() + .map_err(Into::::into)?), + // If the batch was spilled to disk, less likely + (Some(spill_file), None) => { + let mut buffered_cols: Vec = + Vec::with_capacity(buffered_indices.len()); + + let file = BufReader::new(File::open(spill_file.path())?); + let reader = FileReader::try_new(file, None)?; + + for batch in reader { + batch?.columns().iter().for_each(|column| { + buffered_cols.extend(take(column, &buffered_indices, None)) + }); + } + + Ok(buffered_cols) + } + // Invalid combination + (spill, batch) => internal_err!("Unexpected buffered batch spill status. Spill exists: {}. In-memory exists: {}", spill.is_some(), batch.is_some()), + } } /// Calculate join filter bit mask considering join type specifics @@ -1854,6 +1961,7 @@ mod tests { assert_batches_eq, assert_batches_sorted_eq, assert_contains, JoinType, Result, }; use datafusion_execution::config::SessionConfig; + use datafusion_execution::disk_manager::DiskManagerConfig; use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use datafusion_execution::TaskContext; @@ -2749,7 +2857,7 @@ mod tests { } #[tokio::test] - async fn overallocation_single_batch() -> Result<()> { + async fn overallocation_single_batch_no_spill() -> Result<()> { let left = build_table( ("a1", &vec![0, 1, 2, 3, 4, 5]), ("b1", &vec![1, 2, 3, 4, 5, 6]), @@ -2775,14 +2883,17 @@ mod tests { JoinType::LeftAnti, ]; - for join_type in join_types { - let runtime_config = RuntimeConfig::new().with_memory_limit(100, 1.0); - let runtime = Arc::new(RuntimeEnv::new(runtime_config)?); - let session_config = SessionConfig::default().with_batch_size(50); + // Disable DiskManager to prevent spilling + let runtime_config = RuntimeConfig::new() + .with_memory_limit(100, 1.0) + .with_disk_manager(DiskManagerConfig::Disabled); + let runtime = Arc::new(RuntimeEnv::new(runtime_config)?); + let session_config = SessionConfig::default().with_batch_size(50); + for join_type in join_types { let task_ctx = TaskContext::default() - .with_session_config(session_config) - .with_runtime(runtime); + .with_session_config(session_config.clone()) + .with_runtime(Arc::clone(&runtime)); let task_ctx = Arc::new(task_ctx); let join = join_with_options( @@ -2797,18 +2908,20 @@ mod tests { let stream = join.execute(0, task_ctx)?; let err = common::collect(stream).await.unwrap_err(); - assert_contains!( - err.to_string(), - "Resources exhausted: Failed to allocate additional" - ); + assert_contains!(err.to_string(), "Failed to allocate additional"); assert_contains!(err.to_string(), "SMJStream[0]"); + assert_contains!(err.to_string(), "Disk spilling disabled"); + assert!(join.metrics().is_some()); + assert_eq!(join.metrics().unwrap().spill_count(), Some(0)); + assert_eq!(join.metrics().unwrap().spilled_bytes(), Some(0)); + assert_eq!(join.metrics().unwrap().spilled_rows(), Some(0)); } Ok(()) } #[tokio::test] - async fn overallocation_multi_batch() -> Result<()> { + async fn overallocation_multi_batch_no_spill() -> Result<()> { let left_batch_1 = build_table_i32( ("a1", &vec![0, 1]), ("b1", &vec![1, 1]), @@ -2855,13 +2968,17 @@ mod tests { JoinType::LeftAnti, ]; + // Disable DiskManager to prevent spilling + let runtime_config = RuntimeConfig::new() + .with_memory_limit(100, 1.0) + .with_disk_manager(DiskManagerConfig::Disabled); + let runtime = Arc::new(RuntimeEnv::new(runtime_config)?); + let session_config = SessionConfig::default().with_batch_size(50); + for join_type in join_types { - let runtime_config = RuntimeConfig::new().with_memory_limit(100, 1.0); - let runtime = Arc::new(RuntimeEnv::new(runtime_config)?); - let session_config = SessionConfig::default().with_batch_size(50); let task_ctx = TaskContext::default() - .with_session_config(session_config) - .with_runtime(runtime); + .with_session_config(session_config.clone()) + .with_runtime(Arc::clone(&runtime)); let task_ctx = Arc::new(task_ctx); let join = join_with_options( Arc::clone(&left), @@ -2875,11 +2992,205 @@ mod tests { let stream = join.execute(0, task_ctx)?; let err = common::collect(stream).await.unwrap_err(); - assert_contains!( - err.to_string(), - "Resources exhausted: Failed to allocate additional" - ); + assert_contains!(err.to_string(), "Failed to allocate additional"); assert_contains!(err.to_string(), "SMJStream[0]"); + assert_contains!(err.to_string(), "Disk spilling disabled"); + assert!(join.metrics().is_some()); + assert_eq!(join.metrics().unwrap().spill_count(), Some(0)); + assert_eq!(join.metrics().unwrap().spilled_bytes(), Some(0)); + assert_eq!(join.metrics().unwrap().spilled_rows(), Some(0)); + } + + Ok(()) + } + + #[tokio::test] + async fn overallocation_single_batch_spill() -> Result<()> { + let left = build_table( + ("a1", &vec![0, 1, 2, 3, 4, 5]), + ("b1", &vec![1, 2, 3, 4, 5, 6]), + ("c1", &vec![4, 5, 6, 7, 8, 9]), + ); + let right = build_table( + ("a2", &vec![0, 10, 20, 30, 40]), + ("b2", &vec![1, 3, 4, 6, 8]), + ("c2", &vec![50, 60, 70, 80, 90]), + ); + let on = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, + )]; + let sort_options = vec![SortOptions::default(); on.len()]; + + let join_types = [ + JoinType::Inner, + JoinType::Left, + JoinType::Right, + JoinType::Full, + JoinType::LeftSemi, + JoinType::LeftAnti, + ]; + + // Enable DiskManager to allow spilling + let runtime_config = RuntimeConfig::new() + .with_memory_limit(100, 1.0) + .with_disk_manager(DiskManagerConfig::NewOs); + let runtime = Arc::new(RuntimeEnv::new(runtime_config)?); + + for batch_size in [1, 50] { + let session_config = SessionConfig::default().with_batch_size(batch_size); + + for join_type in &join_types { + let task_ctx = TaskContext::default() + .with_session_config(session_config.clone()) + .with_runtime(Arc::clone(&runtime)); + let task_ctx = Arc::new(task_ctx); + + let join = join_with_options( + Arc::clone(&left), + Arc::clone(&right), + on.clone(), + *join_type, + sort_options.clone(), + false, + )?; + + let stream = join.execute(0, task_ctx)?; + let spilled_join_result = common::collect(stream).await.unwrap(); + + assert!(join.metrics().is_some()); + assert!(join.metrics().unwrap().spill_count().unwrap() > 0); + assert!(join.metrics().unwrap().spilled_bytes().unwrap() > 0); + assert!(join.metrics().unwrap().spilled_rows().unwrap() > 0); + + // Run the test with no spill configuration as + let task_ctx_no_spill = + TaskContext::default().with_session_config(session_config.clone()); + let task_ctx_no_spill = Arc::new(task_ctx_no_spill); + + let join = join_with_options( + Arc::clone(&left), + Arc::clone(&right), + on.clone(), + *join_type, + sort_options.clone(), + false, + )?; + let stream = join.execute(0, task_ctx_no_spill)?; + let no_spilled_join_result = common::collect(stream).await.unwrap(); + + assert!(join.metrics().is_some()); + assert_eq!(join.metrics().unwrap().spill_count(), Some(0)); + assert_eq!(join.metrics().unwrap().spilled_bytes(), Some(0)); + assert_eq!(join.metrics().unwrap().spilled_rows(), Some(0)); + // Compare spilled and non spilled data to check spill logic doesn't corrupt the data + assert_eq!(spilled_join_result, no_spilled_join_result); + } + } + + Ok(()) + } + + #[tokio::test] + async fn overallocation_multi_batch_spill() -> Result<()> { + let left_batch_1 = build_table_i32( + ("a1", &vec![0, 1]), + ("b1", &vec![1, 1]), + ("c1", &vec![4, 5]), + ); + let left_batch_2 = build_table_i32( + ("a1", &vec![2, 3]), + ("b1", &vec![1, 1]), + ("c1", &vec![6, 7]), + ); + let left_batch_3 = build_table_i32( + ("a1", &vec![4, 5]), + ("b1", &vec![1, 1]), + ("c1", &vec![8, 9]), + ); + let right_batch_1 = build_table_i32( + ("a2", &vec![0, 10]), + ("b2", &vec![1, 1]), + ("c2", &vec![50, 60]), + ); + let right_batch_2 = build_table_i32( + ("a2", &vec![20, 30]), + ("b2", &vec![1, 1]), + ("c2", &vec![70, 80]), + ); + let right_batch_3 = + build_table_i32(("a2", &vec![40]), ("b2", &vec![1]), ("c2", &vec![90])); + let left = + build_table_from_batches(vec![left_batch_1, left_batch_2, left_batch_3]); + let right = + build_table_from_batches(vec![right_batch_1, right_batch_2, right_batch_3]); + let on = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, + )]; + let sort_options = vec![SortOptions::default(); on.len()]; + + let join_types = [ + JoinType::Inner, + JoinType::Left, + JoinType::Right, + JoinType::Full, + JoinType::LeftSemi, + JoinType::LeftAnti, + ]; + + // Enable DiskManager to allow spilling + let runtime_config = RuntimeConfig::new() + .with_memory_limit(500, 1.0) + .with_disk_manager(DiskManagerConfig::NewOs); + let runtime = Arc::new(RuntimeEnv::new(runtime_config)?); + for batch_size in [1, 50] { + let session_config = SessionConfig::default().with_batch_size(batch_size); + + for join_type in &join_types { + let task_ctx = TaskContext::default() + .with_session_config(session_config.clone()) + .with_runtime(Arc::clone(&runtime)); + let task_ctx = Arc::new(task_ctx); + let join = join_with_options( + Arc::clone(&left), + Arc::clone(&right), + on.clone(), + *join_type, + sort_options.clone(), + false, + )?; + + let stream = join.execute(0, task_ctx)?; + let spilled_join_result = common::collect(stream).await.unwrap(); + assert!(join.metrics().is_some()); + assert!(join.metrics().unwrap().spill_count().unwrap() > 0); + assert!(join.metrics().unwrap().spilled_bytes().unwrap() > 0); + assert!(join.metrics().unwrap().spilled_rows().unwrap() > 0); + + // Run the test with no spill configuration as + let task_ctx_no_spill = + TaskContext::default().with_session_config(session_config.clone()); + let task_ctx_no_spill = Arc::new(task_ctx_no_spill); + + let join = join_with_options( + Arc::clone(&left), + Arc::clone(&right), + on.clone(), + *join_type, + sort_options.clone(), + false, + )?; + let stream = join.execute(0, task_ctx_no_spill)?; + let no_spilled_join_result = common::collect(stream).await.unwrap(); + + assert!(join.metrics().is_some()); + assert_eq!(join.metrics().unwrap().spill_count(), Some(0)); + assert_eq!(join.metrics().unwrap().spilled_bytes(), Some(0)); + assert_eq!(join.metrics().unwrap().spilled_rows(), Some(0)); + // Compare spilled and non spilled data to check spill logic doesn't corrupt the data + assert_eq!(spilled_join_result, no_spilled_join_result); + } } Ok(()) diff --git a/datafusion/physical-plan/src/sorts/sort.rs b/datafusion/physical-plan/src/sorts/sort.rs index d576f77d9f74..13ff63c17405 100644 --- a/datafusion/physical-plan/src/sorts/sort.rs +++ b/datafusion/physical-plan/src/sorts/sort.rs @@ -45,7 +45,7 @@ use arrow::record_batch::RecordBatch; use arrow::row::{RowConverter, SortField}; use arrow_array::{Array, RecordBatchOptions, UInt32Array}; use arrow_schema::DataType; -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::{internal_err, Result}; use datafusion_execution::disk_manager::RefCountedTempFile; use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; use datafusion_execution::runtime_env::RuntimeEnv; @@ -333,10 +333,7 @@ impl ExternalSorter { for spill in self.spills.drain(..) { if !spill.path().exists() { - return Err(DataFusionError::Internal(format!( - "Spill file {:?} does not exist", - spill.path() - ))); + return internal_err!("Spill file {:?} does not exist", spill.path()); } let stream = read_spill_as_stream(spill, Arc::clone(&self.schema), 2)?; streams.push(stream); diff --git a/datafusion/physical-plan/src/spill.rs b/datafusion/physical-plan/src/spill.rs index 0018a27bd22b..21ca58fa0a9f 100644 --- a/datafusion/physical-plan/src/spill.rs +++ b/datafusion/physical-plan/src/spill.rs @@ -40,7 +40,7 @@ use crate::stream::RecordBatchReceiverStream; /// `path` - temp file /// `schema` - batches schema, should be the same across batches /// `buffer` - internal buffer of capacity batches -pub fn read_spill_as_stream( +pub(crate) fn read_spill_as_stream( path: RefCountedTempFile, schema: SchemaRef, buffer: usize, @@ -56,7 +56,7 @@ pub fn read_spill_as_stream( /// Spills in-memory `batches` to disk. /// /// Returns total number of the rows spilled to disk. -pub fn spill_record_batches( +pub(crate) fn spill_record_batches( batches: Vec, path: PathBuf, schema: SchemaRef, @@ -85,3 +85,102 @@ fn read_spill(sender: Sender>, path: &Path) -> Result<()> { } Ok(()) } + +/// Spill the `RecordBatch` to disk as smaller batches +/// split by `batch_size_rows` +/// Return `total_rows` what is spilled +pub fn spill_record_batch_by_size( + batch: &RecordBatch, + path: PathBuf, + schema: SchemaRef, + batch_size_rows: usize, +) -> Result<()> { + let mut offset = 0; + let total_rows = batch.num_rows(); + let mut writer = IPCWriter::new(&path, schema.as_ref())?; + + while offset < total_rows { + let length = std::cmp::min(total_rows - offset, batch_size_rows); + let batch = batch.slice(offset, length); + offset += batch.num_rows(); + writer.write(&batch)?; + } + writer.finish()?; + + Ok(()) +} + +#[cfg(test)] +mod tests { + use crate::spill::{spill_record_batch_by_size, spill_record_batches}; + use crate::test::build_table_i32; + use datafusion_common::Result; + use datafusion_execution::disk_manager::DiskManagerConfig; + use datafusion_execution::DiskManager; + use std::fs::File; + use std::io::BufReader; + use std::sync::Arc; + + #[test] + fn test_batch_spill_and_read() -> Result<()> { + let batch1 = build_table_i32( + ("a2", &vec![0, 1, 2]), + ("b2", &vec![3, 4, 5]), + ("c2", &vec![4, 5, 6]), + ); + + let batch2 = build_table_i32( + ("a2", &vec![10, 11, 12]), + ("b2", &vec![13, 14, 15]), + ("c2", &vec![14, 15, 16]), + ); + + let disk_manager = DiskManager::try_new(DiskManagerConfig::NewOs)?; + + let spill_file = disk_manager.create_tmp_file("Test Spill")?; + let schema = batch1.schema(); + let num_rows = batch1.num_rows() + batch2.num_rows(); + let cnt = spill_record_batches( + vec![batch1, batch2], + spill_file.path().into(), + Arc::clone(&schema), + ); + assert_eq!(cnt.unwrap(), num_rows); + + let file = BufReader::new(File::open(spill_file.path())?); + let reader = arrow::ipc::reader::FileReader::try_new(file, None)?; + + assert_eq!(reader.num_batches(), 2); + assert_eq!(reader.schema(), schema); + + Ok(()) + } + + #[test] + fn test_batch_spill_by_size() -> Result<()> { + let batch1 = build_table_i32( + ("a2", &vec![0, 1, 2, 3]), + ("b2", &vec![3, 4, 5, 6]), + ("c2", &vec![4, 5, 6, 7]), + ); + + let disk_manager = DiskManager::try_new(DiskManagerConfig::NewOs)?; + + let spill_file = disk_manager.create_tmp_file("Test Spill")?; + let schema = batch1.schema(); + spill_record_batch_by_size( + &batch1, + spill_file.path().into(), + Arc::clone(&schema), + 1, + )?; + + let file = BufReader::new(File::open(spill_file.path())?); + let reader = arrow::ipc::reader::FileReader::try_new(file, None)?; + + assert_eq!(reader.num_batches(), 4); + assert_eq!(reader.schema(), schema); + + Ok(()) + } +} From ecf5323eaa38869ed2f911b02f98e17aa6db639a Mon Sep 17 00:00:00 2001 From: yfu Date: Mon, 22 Jul 2024 21:04:36 +1000 Subject: [PATCH 24/37] Fix unparser invalid sql for query with order (#11527) * wip * fix wrong unparsed query for original query that has derived table with limit/sort/distinct; fix wrong unparsed query for original query with sort column that is not in select * clippy * addressed the comments, also fix one issue when selected column is aliased - see test --- datafusion/sql/src/unparser/plan.rs | 67 +++++++++------- datafusion/sql/src/unparser/rewrite.rs | 80 +++++++++++++++++- datafusion/sql/tests/cases/plan_to_sql.rs | 98 +++++++++++++++++++++++ 3 files changed, 215 insertions(+), 30 deletions(-) diff --git a/datafusion/sql/src/unparser/plan.rs b/datafusion/sql/src/unparser/plan.rs index 7f050d8a0690..59660f4f0404 100644 --- a/datafusion/sql/src/unparser/plan.rs +++ b/datafusion/sql/src/unparser/plan.rs @@ -29,6 +29,7 @@ use super::{ SelectBuilder, TableRelationBuilder, TableWithJoinsBuilder, }, rewrite::normalize_union_schema, + rewrite::rewrite_plan_for_sort_on_non_projected_fields, utils::{find_agg_node_within_select, unproject_window_exprs, AggVariant}, Unparser, }; @@ -199,33 +200,21 @@ impl Unparser<'_> { Ok(()) } - fn projection_to_sql( - &self, - plan: &LogicalPlan, - p: &Projection, - query: &mut Option, - select: &mut SelectBuilder, - relation: &mut RelationBuilder, - ) -> Result<()> { - // A second projection implies a derived tablefactor - if !select.already_projected() { - self.reconstruct_select_statement(plan, p, select)?; - self.select_to_sql_recursively(p.input.as_ref(), query, select, relation) - } else { - let mut derived_builder = DerivedRelationBuilder::default(); - derived_builder.lateral(false).alias(None).subquery({ - let inner_statement = self.plan_to_sql(plan)?; - if let ast::Statement::Query(inner_query) = inner_statement { - inner_query - } else { - return internal_err!( - "Subquery must be a Query, but found {inner_statement:?}" - ); - } - }); - relation.derived(derived_builder); - Ok(()) - } + fn derive(&self, plan: &LogicalPlan, relation: &mut RelationBuilder) -> Result<()> { + let mut derived_builder = DerivedRelationBuilder::default(); + derived_builder.lateral(false).alias(None).subquery({ + let inner_statement = self.plan_to_sql(plan)?; + if let ast::Statement::Query(inner_query) = inner_statement { + inner_query + } else { + return internal_err!( + "Subquery must be a Query, but found {inner_statement:?}" + ); + } + }); + relation.derived(derived_builder); + + Ok(()) } fn select_to_sql_recursively( @@ -256,7 +245,17 @@ impl Unparser<'_> { Ok(()) } LogicalPlan::Projection(p) => { - self.projection_to_sql(plan, p, query, select, relation) + if let Some(new_plan) = rewrite_plan_for_sort_on_non_projected_fields(p) { + return self + .select_to_sql_recursively(&new_plan, query, select, relation); + } + + // Projection can be top-level plan for derived table + if select.already_projected() { + return self.derive(plan, relation); + } + self.reconstruct_select_statement(plan, p, select)?; + self.select_to_sql_recursively(p.input.as_ref(), query, select, relation) } LogicalPlan::Filter(filter) => { if let Some(AggVariant::Aggregate(agg)) = @@ -278,6 +277,10 @@ impl Unparser<'_> { ) } LogicalPlan::Limit(limit) => { + // Limit can be top-level plan for derived table + if select.already_projected() { + return self.derive(plan, relation); + } if let Some(fetch) = limit.fetch { let Some(query) = query.as_mut() else { return internal_err!( @@ -298,6 +301,10 @@ impl Unparser<'_> { ) } LogicalPlan::Sort(sort) => { + // Sort can be top-level plan for derived table + if select.already_projected() { + return self.derive(plan, relation); + } if let Some(query_ref) = query { query_ref.order_by(self.sort_to_sql(sort.expr.clone())?); } else { @@ -323,6 +330,10 @@ impl Unparser<'_> { ) } LogicalPlan::Distinct(distinct) => { + // Distinct can be top-level plan for derived table + if select.already_projected() { + return self.derive(plan, relation); + } let (select_distinct, input) = match distinct { Distinct::All(input) => (ast::Distinct::Distinct, input.as_ref()), Distinct::On(on) => { diff --git a/datafusion/sql/src/unparser/rewrite.rs b/datafusion/sql/src/unparser/rewrite.rs index a73fce30ced3..fba95ad48f32 100644 --- a/datafusion/sql/src/unparser/rewrite.rs +++ b/datafusion/sql/src/unparser/rewrite.rs @@ -15,13 +15,16 @@ // specific language governing permissions and limitations // under the License. -use std::sync::Arc; +use std::{ + collections::{HashMap, HashSet}, + sync::Arc, +}; use datafusion_common::{ tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeIterator}, Result, }; -use datafusion_expr::{Expr, LogicalPlan, Sort}; +use datafusion_expr::{Expr, LogicalPlan, Projection, Sort}; /// Normalize the schema of a union plan to remove qualifiers from the schema fields and sort expressions. /// @@ -99,3 +102,76 @@ fn rewrite_sort_expr_for_union(exprs: Vec) -> Result> { Ok(sort_exprs) } + +// Rewrite logic plan for query that order by columns are not in projections +// Plan before rewrite: +// +// Projection: j1.j1_string, j2.j2_string +// Sort: j1.j1_id DESC NULLS FIRST, j2.j2_id DESC NULLS FIRST +// Projection: j1.j1_string, j2.j2_string, j1.j1_id, j2.j2_id +// Inner Join: Filter: j1.j1_id = j2.j2_id +// TableScan: j1 +// TableScan: j2 +// +// Plan after rewrite +// +// Sort: j1.j1_id DESC NULLS FIRST, j2.j2_id DESC NULLS FIRST +// Projection: j1.j1_string, j2.j2_string +// Inner Join: Filter: j1.j1_id = j2.j2_id +// TableScan: j1 +// TableScan: j2 +// +// This prevents the original plan generate query with derived table but missing alias. +pub(super) fn rewrite_plan_for_sort_on_non_projected_fields( + p: &Projection, +) -> Option { + let LogicalPlan::Sort(sort) = p.input.as_ref() else { + return None; + }; + + let LogicalPlan::Projection(inner_p) = sort.input.as_ref() else { + return None; + }; + + let mut map = HashMap::new(); + let inner_exprs = inner_p + .expr + .iter() + .map(|f| { + if let Expr::Alias(alias) = f { + let a = Expr::Column(alias.name.clone().into()); + map.insert(a.clone(), f.clone()); + a + } else { + f.clone() + } + }) + .collect::>(); + + let mut collects = p.expr.clone(); + for expr in &sort.expr { + if let Expr::Sort(s) = expr { + collects.push(s.expr.as_ref().clone()); + } + } + + if collects.iter().collect::>() + == inner_exprs.iter().collect::>() + { + let mut sort = sort.clone(); + let mut inner_p = inner_p.clone(); + + let new_exprs = p + .expr + .iter() + .map(|e| map.get(e).unwrap_or(e).clone()) + .collect::>(); + + inner_p.expr.clone_from(&new_exprs); + sort.input = Arc::new(LogicalPlan::Projection(inner_p)); + + Some(LogicalPlan::Sort(sort)) + } else { + None + } +} diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index e9c4114353c0..aada560fd884 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -244,6 +244,50 @@ fn roundtrip_statement_with_dialect() -> Result<()> { parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(UnparserDefaultDialect {}), }, + // Test query with derived tables that put distinct,sort,limit on the wrong level + TestStatementWithDialect { + sql: "SELECT j1_string from j1 order by j1_id", + expected: r#"SELECT j1.j1_string FROM j1 ORDER BY j1.j1_id ASC NULLS LAST"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(UnparserDefaultDialect {}), + }, + TestStatementWithDialect { + sql: "SELECT j1_string AS a from j1 order by j1_id", + expected: r#"SELECT j1.j1_string AS a FROM j1 ORDER BY j1.j1_id ASC NULLS LAST"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(UnparserDefaultDialect {}), + }, + TestStatementWithDialect { + sql: "SELECT j1_string from j1 join j2 on j1.j1_id = j2.j2_id order by j1_id", + expected: r#"SELECT j1.j1_string FROM j1 JOIN j2 ON (j1.j1_id = j2.j2_id) ORDER BY j1.j1_id ASC NULLS LAST"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(UnparserDefaultDialect {}), + }, + TestStatementWithDialect { + sql: " + SELECT + j1_string, + j2_string + FROM + ( + SELECT + distinct j1_id, + j1_string, + j2_string + from + j1 + INNER join j2 ON j1.j1_id = j2.j2_id + order by + j1.j1_id desc + limit + 10 + ) abc + ORDER BY + abc.j2_string", + expected: r#"SELECT abc.j1_string, abc.j2_string FROM (SELECT DISTINCT j1.j1_id, j1.j1_string, j2.j2_string FROM j1 JOIN j2 ON (j1.j1_id = j2.j2_id) ORDER BY j1.j1_id DESC NULLS FIRST LIMIT 10) AS abc ORDER BY abc.j2_string ASC NULLS LAST"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(UnparserDefaultDialect {}), + }, // more tests around subquery/derived table roundtrip TestStatementWithDialect { sql: "SELECT string_count FROM ( @@ -261,6 +305,60 @@ fn roundtrip_statement_with_dialect() -> Result<()> { parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(UnparserDefaultDialect {}), }, + TestStatementWithDialect { + sql: " + SELECT + j1_string, + j2_string + FROM + ( + SELECT + j1_id, + j1_string, + j2_string + from + j1 + INNER join j2 ON j1.j1_id = j2.j2_id + group by + j1_id, + j1_string, + j2_string + order by + j1.j1_id desc + limit + 10 + ) abc + ORDER BY + abc.j2_string", + expected: r#"SELECT abc.j1_string, abc.j2_string FROM (SELECT j1.j1_id, j1.j1_string, j2.j2_string FROM j1 JOIN j2 ON (j1.j1_id = j2.j2_id) GROUP BY j1.j1_id, j1.j1_string, j2.j2_string ORDER BY j1.j1_id DESC NULLS FIRST LIMIT 10) AS abc ORDER BY abc.j2_string ASC NULLS LAST"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(UnparserDefaultDialect {}), + }, + // Test query that order by columns are not in select columns + TestStatementWithDialect { + sql: " + SELECT + j1_string + FROM + ( + SELECT + j1_string, + j2_string + from + j1 + INNER join j2 ON j1.j1_id = j2.j2_id + order by + j1.j1_id desc, + j2.j2_id desc + limit + 10 + ) abc + ORDER BY + j2_string", + expected: r#"SELECT abc.j1_string FROM (SELECT j1.j1_string, j2.j2_string FROM j1 JOIN j2 ON (j1.j1_id = j2.j2_id) ORDER BY j1.j1_id DESC NULLS FIRST, j2.j2_id DESC NULLS FIRST LIMIT 10) AS abc ORDER BY abc.j2_string ASC NULLS LAST"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(UnparserDefaultDialect {}), + }, TestStatementWithDialect { sql: "SELECT id FROM (SELECT j1_id from j1) AS c (id)", expected: r#"SELECT c.id FROM (SELECT j1.j1_id FROM j1) AS c (id)"#, From 12c0a1e2e21a750e2672bf3109e244836a12b399 Mon Sep 17 00:00:00 2001 From: Kaviraj Kanagaraj Date: Mon, 22 Jul 2024 14:27:07 +0200 Subject: [PATCH 25/37] chore: Minor cleanup `simplify_demo()` example (#11576) * chore: fix examples and comments Signed-off-by: Kaviraj * remove unused `b` field Signed-off-by: Kaviraj * fix the number of days Signed-off-by: Kaviraj --------- Signed-off-by: Kaviraj --- datafusion-examples/examples/expr_api.rs | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/datafusion-examples/examples/expr_api.rs b/datafusion-examples/examples/expr_api.rs index a5cf7011f811..a48171c625a8 100644 --- a/datafusion-examples/examples/expr_api.rs +++ b/datafusion-examples/examples/expr_api.rs @@ -177,16 +177,12 @@ fn simplify_demo() -> Result<()> { ); // here are some other examples of what DataFusion is capable of - let schema = Schema::new(vec![ - make_field("i", DataType::Int64), - make_field("b", DataType::Boolean), - ]) - .to_dfschema_ref()?; + let schema = Schema::new(vec![make_field("i", DataType::Int64)]).to_dfschema_ref()?; let context = SimplifyContext::new(&props).with_schema(schema.clone()); let simplifier = ExprSimplifier::new(context); // basic arithmetic simplification - // i + 1 + 2 => a + 3 + // i + 1 + 2 => i + 3 // (note this is not done if the expr is (col("i") + (lit(1) + lit(2)))) assert_eq!( simplifier.simplify(col("i") + (lit(1) + lit(2)))?, @@ -209,7 +205,7 @@ fn simplify_demo() -> Result<()> { ); // String --> Date simplification - // `cast('2020-09-01' as date)` --> 18500 + // `cast('2020-09-01' as date)` --> 18506 # number of days since epoch 1970-01-01 assert_eq!( simplifier.simplify(lit("2020-09-01").cast_to(&DataType::Date32, &schema)?)?, lit(ScalarValue::Date32(Some(18506))) From f9457de779e213f610fc92dd9165076c7ee770a2 Mon Sep 17 00:00:00 2001 From: Devesh Rahatekar <79015420+devesh-2002@users.noreply.github.com> Date: Mon, 22 Jul 2024 18:01:52 +0530 Subject: [PATCH 26/37] Move Datafusion Query Optimizer to library user guide (#11563) * Added Datafusion Query Optimizer to user guide * Updated Query optimizer name, Added to index and replaced the README content * Fix RAT check --------- Co-authored-by: Andrew Lamb --- datafusion/optimizer/README.md | 318 +---------------- docs/source/index.rst | 2 +- .../library-user-guide/query-optimizer.md | 336 ++++++++++++++++++ 3 files changed, 339 insertions(+), 317 deletions(-) create mode 100644 docs/source/library-user-guide/query-optimizer.md diff --git a/datafusion/optimizer/README.md b/datafusion/optimizer/README.md index 5aacfaf59cb1..61bc1cd70145 100644 --- a/datafusion/optimizer/README.md +++ b/datafusion/optimizer/README.md @@ -17,320 +17,6 @@ under the License. --> -# DataFusion Query Optimizer +Please see [Query Optimizer] in the Library User Guide -[DataFusion][df] is an extensible query execution framework, written in Rust, that uses Apache Arrow as its in-memory -format. - -DataFusion has modular design, allowing individual crates to be re-used in other projects. - -This crate is a submodule of DataFusion that provides a query optimizer for logical plans, and -contains an extensive set of OptimizerRules that may rewrite the plan and/or its expressions so -they execute more quickly while still computing the same result. - -## Running the Optimizer - -The following code demonstrates the basic flow of creating the optimizer with a default set of optimization rules -and applying it to a logical plan to produce an optimized logical plan. - -```rust - -// We need a logical plan as the starting point. There are many ways to build a logical plan: -// -// The `datafusion-expr` crate provides a LogicalPlanBuilder -// The `datafusion-sql` crate provides a SQL query planner that can create a LogicalPlan from SQL -// The `datafusion` crate provides a DataFrame API that can create a LogicalPlan -let logical_plan = ... - -let mut config = OptimizerContext::default(); -let optimizer = Optimizer::new(&config); -let optimized_plan = optimizer.optimize(&logical_plan, &config, observe)?; - -fn observe(plan: &LogicalPlan, rule: &dyn OptimizerRule) { - println!( - "After applying rule '{}':\n{}", - rule.name(), - plan.display_indent() - ) -} -``` - -## Providing Custom Rules - -The optimizer can be created with a custom set of rules. - -```rust -let optimizer = Optimizer::with_rules(vec![ - Arc::new(MyRule {}) -]); -``` - -## Writing Optimization Rules - -Please refer to the -[optimizer_rule.rs](../../datafusion-examples/examples/optimizer_rule.rs) -example to learn more about the general approach to writing optimizer rules and -then move onto studying the existing rules. - -All rules must implement the `OptimizerRule` trait. - -```rust -/// `OptimizerRule` transforms one ['LogicalPlan'] into another which -/// computes the same results, but in a potentially more efficient -/// way. If there are no suitable transformations for the input plan, -/// the optimizer can simply return it as is. -pub trait OptimizerRule { - /// Rewrite `plan` to an optimized form - fn optimize( - &self, - plan: &LogicalPlan, - config: &dyn OptimizerConfig, - ) -> Result; - - /// A human readable name for this optimizer rule - fn name(&self) -> &str; -} -``` - -### General Guidelines - -Rules typical walk the logical plan and walk the expression trees inside operators and selectively mutate -individual operators or expressions. - -Sometimes there is an initial pass that visits the plan and builds state that is used in a second pass that performs -the actual optimization. This approach is used in projection push down and filter push down. - -### Expression Naming - -Every expression in DataFusion has a name, which is used as the column name. For example, in this example the output -contains a single column with the name `"COUNT(aggregate_test_100.c9)"`: - -```text -> select count(c9) from aggregate_test_100; -+------------------------------+ -| COUNT(aggregate_test_100.c9) | -+------------------------------+ -| 100 | -+------------------------------+ -``` - -These names are used to refer to the columns in both subqueries as well as internally from one stage of the LogicalPlan -to another. For example: - -```text -> select "COUNT(aggregate_test_100.c9)" + 1 from (select count(c9) from aggregate_test_100) as sq; -+--------------------------------------------+ -| sq.COUNT(aggregate_test_100.c9) + Int64(1) | -+--------------------------------------------+ -| 101 | -+--------------------------------------------+ -``` - -### Implication - -Because DataFusion identifies columns using a string name, it means it is critical that the names of expressions are -not changed by the optimizer when it rewrites expressions. This is typically accomplished by renaming a rewritten -expression by adding an alias. - -Here is a simple example of such a rewrite. The expression `1 + 2` can be internally simplified to 3 but must still be -displayed the same as `1 + 2`: - -```text -> select 1 + 2; -+---------------------+ -| Int64(1) + Int64(2) | -+---------------------+ -| 3 | -+---------------------+ -``` - -Looking at the `EXPLAIN` output we can see that the optimizer has effectively rewritten `1 + 2` into effectively -`3 as "1 + 2"`: - -```text -> explain select 1 + 2; -+---------------+-------------------------------------------------+ -| plan_type | plan | -+---------------+-------------------------------------------------+ -| logical_plan | Projection: Int64(3) AS Int64(1) + Int64(2) | -| | EmptyRelation | -| physical_plan | ProjectionExec: expr=[3 as Int64(1) + Int64(2)] | -| | PlaceholderRowExec | -| | | -+---------------+-------------------------------------------------+ -``` - -If the expression name is not preserved, bugs such as [#3704](https://github.com/apache/datafusion/issues/3704) -and [#3555](https://github.com/apache/datafusion/issues/3555) occur where the expected columns can not be found. - -### Building Expression Names - -There are currently two ways to create a name for an expression in the logical plan. - -```rust -impl Expr { - /// Returns the name of this expression as it should appear in a schema. This name - /// will not include any CAST expressions. - pub fn display_name(&self) -> Result { - create_name(self) - } - - /// Returns a full and complete string representation of this expression. - pub fn canonical_name(&self) -> String { - format!("{}", self) - } -} -``` - -When comparing expressions to determine if they are equivalent, `canonical_name` should be used, and when creating a -name to be used in a schema, `display_name` should be used. - -### Utilities - -There are a number of utility methods provided that take care of some common tasks. - -### ExprVisitor - -The `ExprVisitor` and `ExprVisitable` traits provide a mechanism for applying a visitor pattern to an expression tree. - -Here is an example that demonstrates this. - -```rust -fn extract_subquery_filters(expression: &Expr, extracted: &mut Vec) -> Result<()> { - struct InSubqueryVisitor<'a> { - accum: &'a mut Vec, - } - - impl ExpressionVisitor for InSubqueryVisitor<'_> { - fn pre_visit(self, expr: &Expr) -> Result> { - if let Expr::InSubquery(_) = expr { - self.accum.push(expr.to_owned()); - } - Ok(Recursion::Continue(self)) - } - } - - expression.accept(InSubqueryVisitor { accum: extracted })?; - Ok(()) -} -``` - -### Rewriting Expressions - -The `MyExprRewriter` trait can be implemented to provide a way to rewrite expressions. This rule can then be applied -to an expression by calling `Expr::rewrite` (from the `ExprRewritable` trait). - -The `rewrite` method will perform a depth first walk of the expression and its children to rewrite an expression, -consuming `self` producing a new expression. - -```rust -let mut expr_rewriter = MyExprRewriter {}; -let expr = expr.rewrite(&mut expr_rewriter)?; -``` - -Here is an example implementation which will rewrite `expr BETWEEN a AND b` as `expr >= a AND expr <= b`. Note that the -implementation does not need to perform any recursion since this is handled by the `rewrite` method. - -```rust -struct MyExprRewriter {} - -impl ExprRewriter for MyExprRewriter { - fn mutate(&mut self, expr: Expr) -> Result { - match expr { - Expr::Between { - negated, - expr, - low, - high, - } => { - let expr: Expr = expr.as_ref().clone(); - let low: Expr = low.as_ref().clone(); - let high: Expr = high.as_ref().clone(); - if negated { - Ok(expr.clone().lt(low).or(expr.clone().gt(high))) - } else { - Ok(expr.clone().gt_eq(low).and(expr.clone().lt_eq(high))) - } - } - _ => Ok(expr.clone()), - } - } -} -``` - -### optimize_children - -Typically a rule is applied recursively to all operators within a query plan. Rather than duplicate -that logic in each rule, an `optimize_children` method is provided. This recursively invokes the `optimize` method on -the plan's children and then returns a node of the same type. - -```rust -fn optimize( - &self, - plan: &LogicalPlan, - _config: &mut OptimizerConfig, -) -> Result { - // recurse down and optimize children first - let plan = utils::optimize_children(self, plan, _config)?; - - ... -} -``` - -### Writing Tests - -There should be unit tests in the same file as the new rule that test the effect of the rule being applied to a plan -in isolation (without any other rule being applied). - -There should also be a test in `integration-tests.rs` that tests the rule as part of the overall optimization process. - -### Debugging - -The `EXPLAIN VERBOSE` command can be used to show the effect of each optimization rule on a query. - -In the following example, the `type_coercion` and `simplify_expressions` passes have simplified the plan so that it returns the constant `"3.2"` rather than doing a computation at execution time. - -```text -> explain verbose select cast(1 + 2.2 as string) as foo; -+------------------------------------------------------------+---------------------------------------------------------------------------+ -| plan_type | plan | -+------------------------------------------------------------+---------------------------------------------------------------------------+ -| initial_logical_plan | Projection: CAST(Int64(1) + Float64(2.2) AS Utf8) AS foo | -| | EmptyRelation | -| logical_plan after type_coercion | Projection: CAST(CAST(Int64(1) AS Float64) + Float64(2.2) AS Utf8) AS foo | -| | EmptyRelation | -| logical_plan after simplify_expressions | Projection: Utf8("3.2") AS foo | -| | EmptyRelation | -| logical_plan after unwrap_cast_in_comparison | SAME TEXT AS ABOVE | -| logical_plan after decorrelate_where_exists | SAME TEXT AS ABOVE | -| logical_plan after decorrelate_where_in | SAME TEXT AS ABOVE | -| logical_plan after scalar_subquery_to_join | SAME TEXT AS ABOVE | -| logical_plan after subquery_filter_to_join | SAME TEXT AS ABOVE | -| logical_plan after simplify_expressions | SAME TEXT AS ABOVE | -| logical_plan after eliminate_filter | SAME TEXT AS ABOVE | -| logical_plan after reduce_cross_join | SAME TEXT AS ABOVE | -| logical_plan after common_sub_expression_eliminate | SAME TEXT AS ABOVE | -| logical_plan after eliminate_limit | SAME TEXT AS ABOVE | -| logical_plan after projection_push_down | SAME TEXT AS ABOVE | -| logical_plan after rewrite_disjunctive_predicate | SAME TEXT AS ABOVE | -| logical_plan after reduce_outer_join | SAME TEXT AS ABOVE | -| logical_plan after filter_push_down | SAME TEXT AS ABOVE | -| logical_plan after limit_push_down | SAME TEXT AS ABOVE | -| logical_plan after single_distinct_aggregation_to_group_by | SAME TEXT AS ABOVE | -| logical_plan | Projection: Utf8("3.2") AS foo | -| | EmptyRelation | -| initial_physical_plan | ProjectionExec: expr=[3.2 as foo] | -| | PlaceholderRowExec | -| | | -| physical_plan after aggregate_statistics | SAME TEXT AS ABOVE | -| physical_plan after join_selection | SAME TEXT AS ABOVE | -| physical_plan after coalesce_batches | SAME TEXT AS ABOVE | -| physical_plan after repartition | SAME TEXT AS ABOVE | -| physical_plan after add_merge_exec | SAME TEXT AS ABOVE | -| physical_plan | ProjectionExec: expr=[3.2 as foo] | -| | PlaceholderRowExec | -| | | -+------------------------------------------------------------+---------------------------------------------------------------------------+ -``` - -[df]: https://crates.io/crates/datafusion +[query optimizer]: https://datafusion.apache.org/library-user-guide/query-optimizer.html diff --git a/docs/source/index.rst b/docs/source/index.rst index ca6905c434f3..9c8c886d2502 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -107,7 +107,7 @@ To get started, see library-user-guide/custom-table-providers library-user-guide/extending-operators library-user-guide/profiling - + library-user-guide/query-optimizer .. _toc.contributor-guide: .. toctree:: diff --git a/docs/source/library-user-guide/query-optimizer.md b/docs/source/library-user-guide/query-optimizer.md new file mode 100644 index 000000000000..5aacfaf59cb1 --- /dev/null +++ b/docs/source/library-user-guide/query-optimizer.md @@ -0,0 +1,336 @@ + + +# DataFusion Query Optimizer + +[DataFusion][df] is an extensible query execution framework, written in Rust, that uses Apache Arrow as its in-memory +format. + +DataFusion has modular design, allowing individual crates to be re-used in other projects. + +This crate is a submodule of DataFusion that provides a query optimizer for logical plans, and +contains an extensive set of OptimizerRules that may rewrite the plan and/or its expressions so +they execute more quickly while still computing the same result. + +## Running the Optimizer + +The following code demonstrates the basic flow of creating the optimizer with a default set of optimization rules +and applying it to a logical plan to produce an optimized logical plan. + +```rust + +// We need a logical plan as the starting point. There are many ways to build a logical plan: +// +// The `datafusion-expr` crate provides a LogicalPlanBuilder +// The `datafusion-sql` crate provides a SQL query planner that can create a LogicalPlan from SQL +// The `datafusion` crate provides a DataFrame API that can create a LogicalPlan +let logical_plan = ... + +let mut config = OptimizerContext::default(); +let optimizer = Optimizer::new(&config); +let optimized_plan = optimizer.optimize(&logical_plan, &config, observe)?; + +fn observe(plan: &LogicalPlan, rule: &dyn OptimizerRule) { + println!( + "After applying rule '{}':\n{}", + rule.name(), + plan.display_indent() + ) +} +``` + +## Providing Custom Rules + +The optimizer can be created with a custom set of rules. + +```rust +let optimizer = Optimizer::with_rules(vec![ + Arc::new(MyRule {}) +]); +``` + +## Writing Optimization Rules + +Please refer to the +[optimizer_rule.rs](../../datafusion-examples/examples/optimizer_rule.rs) +example to learn more about the general approach to writing optimizer rules and +then move onto studying the existing rules. + +All rules must implement the `OptimizerRule` trait. + +```rust +/// `OptimizerRule` transforms one ['LogicalPlan'] into another which +/// computes the same results, but in a potentially more efficient +/// way. If there are no suitable transformations for the input plan, +/// the optimizer can simply return it as is. +pub trait OptimizerRule { + /// Rewrite `plan` to an optimized form + fn optimize( + &self, + plan: &LogicalPlan, + config: &dyn OptimizerConfig, + ) -> Result; + + /// A human readable name for this optimizer rule + fn name(&self) -> &str; +} +``` + +### General Guidelines + +Rules typical walk the logical plan and walk the expression trees inside operators and selectively mutate +individual operators or expressions. + +Sometimes there is an initial pass that visits the plan and builds state that is used in a second pass that performs +the actual optimization. This approach is used in projection push down and filter push down. + +### Expression Naming + +Every expression in DataFusion has a name, which is used as the column name. For example, in this example the output +contains a single column with the name `"COUNT(aggregate_test_100.c9)"`: + +```text +> select count(c9) from aggregate_test_100; ++------------------------------+ +| COUNT(aggregate_test_100.c9) | ++------------------------------+ +| 100 | ++------------------------------+ +``` + +These names are used to refer to the columns in both subqueries as well as internally from one stage of the LogicalPlan +to another. For example: + +```text +> select "COUNT(aggregate_test_100.c9)" + 1 from (select count(c9) from aggregate_test_100) as sq; ++--------------------------------------------+ +| sq.COUNT(aggregate_test_100.c9) + Int64(1) | ++--------------------------------------------+ +| 101 | ++--------------------------------------------+ +``` + +### Implication + +Because DataFusion identifies columns using a string name, it means it is critical that the names of expressions are +not changed by the optimizer when it rewrites expressions. This is typically accomplished by renaming a rewritten +expression by adding an alias. + +Here is a simple example of such a rewrite. The expression `1 + 2` can be internally simplified to 3 but must still be +displayed the same as `1 + 2`: + +```text +> select 1 + 2; ++---------------------+ +| Int64(1) + Int64(2) | ++---------------------+ +| 3 | ++---------------------+ +``` + +Looking at the `EXPLAIN` output we can see that the optimizer has effectively rewritten `1 + 2` into effectively +`3 as "1 + 2"`: + +```text +> explain select 1 + 2; ++---------------+-------------------------------------------------+ +| plan_type | plan | ++---------------+-------------------------------------------------+ +| logical_plan | Projection: Int64(3) AS Int64(1) + Int64(2) | +| | EmptyRelation | +| physical_plan | ProjectionExec: expr=[3 as Int64(1) + Int64(2)] | +| | PlaceholderRowExec | +| | | ++---------------+-------------------------------------------------+ +``` + +If the expression name is not preserved, bugs such as [#3704](https://github.com/apache/datafusion/issues/3704) +and [#3555](https://github.com/apache/datafusion/issues/3555) occur where the expected columns can not be found. + +### Building Expression Names + +There are currently two ways to create a name for an expression in the logical plan. + +```rust +impl Expr { + /// Returns the name of this expression as it should appear in a schema. This name + /// will not include any CAST expressions. + pub fn display_name(&self) -> Result { + create_name(self) + } + + /// Returns a full and complete string representation of this expression. + pub fn canonical_name(&self) -> String { + format!("{}", self) + } +} +``` + +When comparing expressions to determine if they are equivalent, `canonical_name` should be used, and when creating a +name to be used in a schema, `display_name` should be used. + +### Utilities + +There are a number of utility methods provided that take care of some common tasks. + +### ExprVisitor + +The `ExprVisitor` and `ExprVisitable` traits provide a mechanism for applying a visitor pattern to an expression tree. + +Here is an example that demonstrates this. + +```rust +fn extract_subquery_filters(expression: &Expr, extracted: &mut Vec) -> Result<()> { + struct InSubqueryVisitor<'a> { + accum: &'a mut Vec, + } + + impl ExpressionVisitor for InSubqueryVisitor<'_> { + fn pre_visit(self, expr: &Expr) -> Result> { + if let Expr::InSubquery(_) = expr { + self.accum.push(expr.to_owned()); + } + Ok(Recursion::Continue(self)) + } + } + + expression.accept(InSubqueryVisitor { accum: extracted })?; + Ok(()) +} +``` + +### Rewriting Expressions + +The `MyExprRewriter` trait can be implemented to provide a way to rewrite expressions. This rule can then be applied +to an expression by calling `Expr::rewrite` (from the `ExprRewritable` trait). + +The `rewrite` method will perform a depth first walk of the expression and its children to rewrite an expression, +consuming `self` producing a new expression. + +```rust +let mut expr_rewriter = MyExprRewriter {}; +let expr = expr.rewrite(&mut expr_rewriter)?; +``` + +Here is an example implementation which will rewrite `expr BETWEEN a AND b` as `expr >= a AND expr <= b`. Note that the +implementation does not need to perform any recursion since this is handled by the `rewrite` method. + +```rust +struct MyExprRewriter {} + +impl ExprRewriter for MyExprRewriter { + fn mutate(&mut self, expr: Expr) -> Result { + match expr { + Expr::Between { + negated, + expr, + low, + high, + } => { + let expr: Expr = expr.as_ref().clone(); + let low: Expr = low.as_ref().clone(); + let high: Expr = high.as_ref().clone(); + if negated { + Ok(expr.clone().lt(low).or(expr.clone().gt(high))) + } else { + Ok(expr.clone().gt_eq(low).and(expr.clone().lt_eq(high))) + } + } + _ => Ok(expr.clone()), + } + } +} +``` + +### optimize_children + +Typically a rule is applied recursively to all operators within a query plan. Rather than duplicate +that logic in each rule, an `optimize_children` method is provided. This recursively invokes the `optimize` method on +the plan's children and then returns a node of the same type. + +```rust +fn optimize( + &self, + plan: &LogicalPlan, + _config: &mut OptimizerConfig, +) -> Result { + // recurse down and optimize children first + let plan = utils::optimize_children(self, plan, _config)?; + + ... +} +``` + +### Writing Tests + +There should be unit tests in the same file as the new rule that test the effect of the rule being applied to a plan +in isolation (without any other rule being applied). + +There should also be a test in `integration-tests.rs` that tests the rule as part of the overall optimization process. + +### Debugging + +The `EXPLAIN VERBOSE` command can be used to show the effect of each optimization rule on a query. + +In the following example, the `type_coercion` and `simplify_expressions` passes have simplified the plan so that it returns the constant `"3.2"` rather than doing a computation at execution time. + +```text +> explain verbose select cast(1 + 2.2 as string) as foo; ++------------------------------------------------------------+---------------------------------------------------------------------------+ +| plan_type | plan | ++------------------------------------------------------------+---------------------------------------------------------------------------+ +| initial_logical_plan | Projection: CAST(Int64(1) + Float64(2.2) AS Utf8) AS foo | +| | EmptyRelation | +| logical_plan after type_coercion | Projection: CAST(CAST(Int64(1) AS Float64) + Float64(2.2) AS Utf8) AS foo | +| | EmptyRelation | +| logical_plan after simplify_expressions | Projection: Utf8("3.2") AS foo | +| | EmptyRelation | +| logical_plan after unwrap_cast_in_comparison | SAME TEXT AS ABOVE | +| logical_plan after decorrelate_where_exists | SAME TEXT AS ABOVE | +| logical_plan after decorrelate_where_in | SAME TEXT AS ABOVE | +| logical_plan after scalar_subquery_to_join | SAME TEXT AS ABOVE | +| logical_plan after subquery_filter_to_join | SAME TEXT AS ABOVE | +| logical_plan after simplify_expressions | SAME TEXT AS ABOVE | +| logical_plan after eliminate_filter | SAME TEXT AS ABOVE | +| logical_plan after reduce_cross_join | SAME TEXT AS ABOVE | +| logical_plan after common_sub_expression_eliminate | SAME TEXT AS ABOVE | +| logical_plan after eliminate_limit | SAME TEXT AS ABOVE | +| logical_plan after projection_push_down | SAME TEXT AS ABOVE | +| logical_plan after rewrite_disjunctive_predicate | SAME TEXT AS ABOVE | +| logical_plan after reduce_outer_join | SAME TEXT AS ABOVE | +| logical_plan after filter_push_down | SAME TEXT AS ABOVE | +| logical_plan after limit_push_down | SAME TEXT AS ABOVE | +| logical_plan after single_distinct_aggregation_to_group_by | SAME TEXT AS ABOVE | +| logical_plan | Projection: Utf8("3.2") AS foo | +| | EmptyRelation | +| initial_physical_plan | ProjectionExec: expr=[3.2 as foo] | +| | PlaceholderRowExec | +| | | +| physical_plan after aggregate_statistics | SAME TEXT AS ABOVE | +| physical_plan after join_selection | SAME TEXT AS ABOVE | +| physical_plan after coalesce_batches | SAME TEXT AS ABOVE | +| physical_plan after repartition | SAME TEXT AS ABOVE | +| physical_plan after add_merge_exec | SAME TEXT AS ABOVE | +| physical_plan | ProjectionExec: expr=[3.2 as foo] | +| | PlaceholderRowExec | +| | | ++------------------------------------------------------------+---------------------------------------------------------------------------+ +``` + +[df]: https://crates.io/crates/datafusion From 47d5d1fe1ac7d2eb363d4b2b52268629e89b64f9 Mon Sep 17 00:00:00 2001 From: June <61218022+itsjunetime@users.noreply.github.com> Date: Mon, 22 Jul 2024 06:32:12 -0600 Subject: [PATCH 27/37] feat: Error when a SHOW command is passed in with an accompanying non-existant variable (#11540) * feat: Error when a SHOW command is passed in with an accompanying non-existant variable * fix: Run fmt * Switch to 'query error' instead of 'statement error' in sqllogictest test to see if that fixes CI * Move some errors in sqllogictest to line above to maybe fix CI * Fix (hopefully final) failing information_schema slt test due to multiline error message/placement --- datafusion/sql/src/statement.rs | 16 ++++++++++++++++ .../test_files/information_schema.slt | 11 ++++++----- 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/datafusion/sql/src/statement.rs b/datafusion/sql/src/statement.rs index 6df25086305d..8eb4113f80a6 100644 --- a/datafusion/sql/src/statement.rs +++ b/datafusion/sql/src/statement.rs @@ -1146,6 +1146,22 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // we could introduce alias in OptionDefinition if this string matching thing grows format!("{base_query} WHERE name = 'datafusion.execution.time_zone'") } else { + // These values are what are used to make the information_schema table, so we just + // check here, before actually planning or executing the query, if it would produce no + // results, and error preemptively if it would (for a better UX) + let is_valid_variable = self + .context_provider + .options() + .entries() + .iter() + .any(|opt| opt.key == variable); + + if !is_valid_variable { + return plan_err!( + "'{variable}' is not a variable which can be viewed with 'SHOW'" + ); + } + format!("{base_query} WHERE name = '{variable}'") }; diff --git a/datafusion/sqllogictest/test_files/information_schema.slt b/datafusion/sqllogictest/test_files/information_schema.slt index c8c0d1d45b97..1c6ffd44b1ef 100644 --- a/datafusion/sqllogictest/test_files/information_schema.slt +++ b/datafusion/sqllogictest/test_files/information_schema.slt @@ -370,9 +370,12 @@ datafusion.execution.time_zone +00:00 The default time zone Some functions, e.g. # show empty verbose -query TTT +statement error DataFusion error: Error during planning: '' is not a variable which can be viewed with 'SHOW' SHOW VERBOSE ----- + +# show nonsense verbose +statement error DataFusion error: Error during planning: 'nonsense' is not a variable which can be viewed with 'SHOW' +SHOW NONSENSE VERBOSE # information_schema_describe_table @@ -508,9 +511,7 @@ SHOW columns from datafusion.public.t2 # show_non_existing_variable -# FIXME -# currently we cannot know whether a variable exists, this will output 0 row instead -statement ok +statement error DataFusion error: Error during planning: 'something_unknown' is not a variable which can be viewed with 'SHOW' SHOW SOMETHING_UNKNOWN; statement ok From 5c65efc79954ce495328d63fd0445e982a7319a9 Mon Sep 17 00:00:00 2001 From: Alex Huang Date: Mon, 22 Jul 2024 20:35:42 +0800 Subject: [PATCH 28/37] fix: CASE with NULL (#11542) * fix: CASE with NULL * chore: Add tests * chore * chore: Fix CI * chore: Support all types are NULL * chore: Fix CI * chore: add more tests * fix: Return first non-null type in then exprs * chore: Fix CI * Update datafusion/expr/src/expr_schema.rs Co-authored-by: Jonah Gao * Update datafusion/expr/src/expr_schema.rs Co-authored-by: Jonah Gao --------- Co-authored-by: Jonah Gao --- datafusion/expr/src/expr_schema.rs | 12 +++++++- .../sqllogictest/test_files/aggregate.slt | 28 +++++++++++++++++++ datafusion/sqllogictest/test_files/scalar.slt | 8 +++--- datafusion/sqllogictest/test_files/select.slt | 27 ++++++++++++++++++ 4 files changed, 70 insertions(+), 5 deletions(-) diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 1df5d6c4d736..5e0571f712ee 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -112,7 +112,17 @@ impl ExprSchemable for Expr { Expr::OuterReferenceColumn(ty, _) => Ok(ty.clone()), Expr::ScalarVariable(ty, _) => Ok(ty.clone()), Expr::Literal(l) => Ok(l.data_type()), - Expr::Case(case) => case.when_then_expr[0].1.get_type(schema), + Expr::Case(case) => { + for (_, then_expr) in &case.when_then_expr { + let then_type = then_expr.get_type(schema)?; + if !then_type.is_null() { + return Ok(then_type); + } + } + case.else_expr + .as_ref() + .map_or(Ok(DataType::Null), |e| e.get_type(schema)) + } Expr::Cast(Cast { data_type, .. }) | Expr::TryCast(TryCast { data_type, .. }) => Ok(data_type.clone()), Expr::Unnest(Unnest { expr }) => { diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index d0f7f2d9ac7a..bb5ce1150a58 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -5418,6 +5418,34 @@ SELECT LAST_VALUE(column1 ORDER BY column2 DESC) IGNORE NULLS FROM t; statement ok DROP TABLE t; +# Test for CASE with NULL in aggregate function +statement ok +CREATE TABLE example(data double precision); + +statement ok +INSERT INTO example VALUES (1), (2), (NULL), (4); + +query RR +SELECT + sum(CASE WHEN data is NULL THEN NULL ELSE data+1 END) as then_null, + sum(CASE WHEN data is NULL THEN data+1 ELSE NULL END) as else_null +FROM example; +---- +10 NULL + +query R +SELECT + CASE data WHEN 1 THEN NULL WHEN 2 THEN 3.3 ELSE NULL END as case_null +FROM example; +---- +NULL +3.3 +NULL +NULL + +statement ok +drop table example; + # Test Convert FirstLast optimizer rule statement ok CREATE EXTERNAL TABLE convert_first_last_table ( diff --git a/datafusion/sqllogictest/test_files/scalar.slt b/datafusion/sqllogictest/test_files/scalar.slt index 48f94fc080a4..ff9afa94f40a 100644 --- a/datafusion/sqllogictest/test_files/scalar.slt +++ b/datafusion/sqllogictest/test_files/scalar.slt @@ -1238,27 +1238,27 @@ SELECT CASE WHEN NULL THEN 'foo' ELSE 'bar' END bar # case_expr_with_null() -query ? +query I select case when b is null then null else b end from (select a,b from (values (1,null),(2,3)) as t (a,b)) a; ---- NULL 3 -query ? +query I select case when b is null then null else b end from (select a,b from (values (1,1),(2,3)) as t (a,b)) a; ---- 1 3 # case_expr_with_nulls() -query ? +query I select case when b is null then null when b < 3 then null when b >=3 then b + 1 else b end from (select a,b from (values (1,null),(1,2),(2,3)) as t (a,b)) a ---- NULL NULL 4 -query ? +query I select case b when 1 then null when 2 then null when 3 then b + 1 else b end from (select a,b from (values (1,null),(1,2),(2,3)) as t (a,b)) a; ---- NULL diff --git a/datafusion/sqllogictest/test_files/select.slt b/datafusion/sqllogictest/test_files/select.slt index 03426dec874f..6884efc07e15 100644 --- a/datafusion/sqllogictest/test_files/select.slt +++ b/datafusion/sqllogictest/test_files/select.slt @@ -613,6 +613,33 @@ END; ---- 2 +# select case when type is null +query I +select CASE + WHEN NULL THEN 1 + ELSE 2 +END; +---- +2 + +# select case then type is null +query I +select CASE + WHEN 10 > 5 THEN NULL + ELSE 2 +END; +---- +NULL + +# select case else type is null +query I +select CASE + WHEN 10 = 5 THEN 1 + ELSE NULL +END; +---- +NULL + # Binary Expression for LargeUtf8 # issue: https://github.com/apache/datafusion/issues/5893 statement ok From 51da92fb9fe1b1bc2344fa78be52c448b36880d9 Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Mon, 22 Jul 2024 20:36:58 +0800 Subject: [PATCH 29/37] Provide DataFrame API for `map` and move `map` to `functions-array` (#11560) * move map to `functions-array` and implement dataframe api * add benchmark for dataframe api * fix format * add roundtrip_expr_api test --- datafusion/core/Cargo.toml | 5 + datafusion/core/benches/map_query_sql.rs | 93 +++++++++++++++++++ .../tests/dataframe/dataframe_functions.rs | 22 +++++ datafusion/functions-array/benches/map.rs | 37 +++++++- datafusion/functions-array/src/lib.rs | 3 + .../src/core => functions-array/src}/map.rs | 35 ++++--- datafusion/functions-array/src/planner.rs | 6 +- datafusion/functions/Cargo.toml | 5 - datafusion/functions/benches/map.rs | 80 ---------------- datafusion/functions/src/core/mod.rs | 7 -- .../tests/cases/roundtrip_logical_plan.rs | 5 + 11 files changed, 189 insertions(+), 109 deletions(-) create mode 100644 datafusion/core/benches/map_query_sql.rs rename datafusion/{functions/src/core => functions-array/src}/map.rs (83%) delete mode 100644 datafusion/functions/benches/map.rs diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index c937a6f6e59a..4301396b231f 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -217,3 +217,8 @@ name = "topk_aggregate" [[bench]] harness = false name = "parquet_statistic" + +[[bench]] +harness = false +name = "map_query_sql" +required-features = ["array_expressions"] diff --git a/datafusion/core/benches/map_query_sql.rs b/datafusion/core/benches/map_query_sql.rs new file mode 100644 index 000000000000..b6ac8b6b647a --- /dev/null +++ b/datafusion/core/benches/map_query_sql.rs @@ -0,0 +1,93 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use arrow_array::{ArrayRef, Int32Array, RecordBatch}; +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use parking_lot::Mutex; +use rand::prelude::ThreadRng; +use rand::Rng; +use tokio::runtime::Runtime; + +use datafusion::prelude::SessionContext; +use datafusion_common::ScalarValue; +use datafusion_expr::Expr; +use datafusion_functions_array::map::map; + +mod data_utils; + +fn build_keys(rng: &mut ThreadRng) -> Vec { + let mut keys = vec![]; + for _ in 0..1000 { + keys.push(rng.gen_range(0..9999).to_string()); + } + keys +} + +fn build_values(rng: &mut ThreadRng) -> Vec { + let mut values = vec![]; + for _ in 0..1000 { + values.push(rng.gen_range(0..9999)); + } + values +} + +fn t_batch(num: i32) -> RecordBatch { + let value: Vec = (0..num).collect(); + let c1: ArrayRef = Arc::new(Int32Array::from(value)); + RecordBatch::try_from_iter(vec![("c1", c1)]).unwrap() +} + +fn create_context(num: i32) -> datafusion_common::Result>> { + let ctx = SessionContext::new(); + ctx.register_batch("t", t_batch(num))?; + Ok(Arc::new(Mutex::new(ctx))) +} + +fn criterion_benchmark(c: &mut Criterion) { + let ctx = create_context(1).unwrap(); + let rt = Runtime::new().unwrap(); + let df = rt.block_on(ctx.lock().table("t")).unwrap(); + + let mut rng = rand::thread_rng(); + let keys = build_keys(&mut rng); + let values = build_values(&mut rng); + let mut key_buffer = Vec::new(); + let mut value_buffer = Vec::new(); + + for i in 0..1000 { + key_buffer.push(Expr::Literal(ScalarValue::Utf8(Some(keys[i].clone())))); + value_buffer.push(Expr::Literal(ScalarValue::Int32(Some(values[i])))); + } + c.bench_function("map_1000_1", |b| { + b.iter(|| { + black_box( + rt.block_on( + df.clone() + .select(vec![map(key_buffer.clone(), value_buffer.clone())]) + .unwrap() + .collect(), + ) + .unwrap(), + ); + }); + }); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/core/tests/dataframe/dataframe_functions.rs b/datafusion/core/tests/dataframe/dataframe_functions.rs index 1c55c48fea40..f7b02196d8ed 100644 --- a/datafusion/core/tests/dataframe/dataframe_functions.rs +++ b/datafusion/core/tests/dataframe/dataframe_functions.rs @@ -34,6 +34,7 @@ use datafusion_common::{DFSchema, ScalarValue}; use datafusion_expr::expr::Alias; use datafusion_expr::ExprSchemable; use datafusion_functions_aggregate::expr_fn::{approx_median, approx_percentile_cont}; +use datafusion_functions_array::map::map; fn test_schema() -> SchemaRef { Arc::new(Schema::new(vec![ @@ -1087,3 +1088,24 @@ async fn test_fn_array_to_string() -> Result<()> { Ok(()) } + +#[tokio::test] +async fn test_fn_map() -> Result<()> { + let expr = map( + vec![lit("a"), lit("b"), lit("c")], + vec![lit(1), lit(2), lit(3)], + ); + let expected = [ + "+---------------------------------------------------------------------------------------+", + "| map(make_array(Utf8(\"a\"),Utf8(\"b\"),Utf8(\"c\")),make_array(Int32(1),Int32(2),Int32(3))) |", + "+---------------------------------------------------------------------------------------+", + "| {a: 1, b: 2, c: 3} |", + "| {a: 1, b: 2, c: 3} |", + "| {a: 1, b: 2, c: 3} |", + "| {a: 1, b: 2, c: 3} |", + "+---------------------------------------------------------------------------------------+", + ]; + assert_fn_batches!(expr, expected); + + Ok(()) +} diff --git a/datafusion/functions-array/benches/map.rs b/datafusion/functions-array/benches/map.rs index 2e9b45266abc..c2e0e641e80d 100644 --- a/datafusion/functions-array/benches/map.rs +++ b/datafusion/functions-array/benches/map.rs @@ -17,13 +17,18 @@ extern crate criterion; +use arrow_array::{Int32Array, ListArray, StringArray}; +use arrow_buffer::{OffsetBuffer, ScalarBuffer}; +use arrow_schema::{DataType, Field}; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use rand::prelude::ThreadRng; use rand::Rng; +use std::sync::Arc; use datafusion_common::ScalarValue; use datafusion_expr::planner::ExprPlanner; -use datafusion_expr::Expr; +use datafusion_expr::{ColumnarValue, Expr}; +use datafusion_functions_array::map::map_udf; use datafusion_functions_array::planner::ArrayFunctionPlanner; fn keys(rng: &mut ThreadRng) -> Vec { @@ -63,6 +68,36 @@ fn criterion_benchmark(c: &mut Criterion) { ); }); }); + + c.bench_function("map_1000", |b| { + let mut rng = rand::thread_rng(); + let field = Arc::new(Field::new("item", DataType::Utf8, true)); + let offsets = OffsetBuffer::new(ScalarBuffer::from(vec![0, 1000])); + let key_list = ListArray::new( + field, + offsets, + Arc::new(StringArray::from(keys(&mut rng))), + None, + ); + let field = Arc::new(Field::new("item", DataType::Int32, true)); + let offsets = OffsetBuffer::new(ScalarBuffer::from(vec![0, 1000])); + let value_list = ListArray::new( + field, + offsets, + Arc::new(Int32Array::from(values(&mut rng))), + None, + ); + let keys = ColumnarValue::Scalar(ScalarValue::List(Arc::new(key_list))); + let values = ColumnarValue::Scalar(ScalarValue::List(Arc::new(value_list))); + + b.iter(|| { + black_box( + map_udf() + .invoke(&[keys.clone(), values.clone()]) + .expect("map should work on valid values"), + ); + }); + }); } criterion_group!(benches, criterion_benchmark); diff --git a/datafusion/functions-array/src/lib.rs b/datafusion/functions-array/src/lib.rs index 9717d29883fd..f68f59dcd6a1 100644 --- a/datafusion/functions-array/src/lib.rs +++ b/datafusion/functions-array/src/lib.rs @@ -41,6 +41,7 @@ pub mod extract; pub mod flatten; pub mod length; pub mod make_array; +pub mod map; pub mod planner; pub mod position; pub mod range; @@ -53,6 +54,7 @@ pub mod set_ops; pub mod sort; pub mod string; pub mod utils; + use datafusion_common::Result; use datafusion_execution::FunctionRegistry; use datafusion_expr::ScalarUDF; @@ -140,6 +142,7 @@ pub fn all_default_array_functions() -> Vec> { replace::array_replace_n_udf(), replace::array_replace_all_udf(), replace::array_replace_udf(), + map::map_udf(), ] } diff --git a/datafusion/functions/src/core/map.rs b/datafusion/functions-array/src/map.rs similarity index 83% rename from datafusion/functions/src/core/map.rs rename to datafusion/functions-array/src/map.rs index 2deef242f8a0..e218b501dcf1 100644 --- a/datafusion/functions/src/core/map.rs +++ b/datafusion/functions-array/src/map.rs @@ -15,17 +15,26 @@ // specific language governing permissions and limitations // under the License. +use crate::make_array::make_array; +use arrow::array::ArrayData; +use arrow_array::{Array, ArrayRef, MapArray, StructArray}; +use arrow_buffer::{Buffer, ToByteSlice}; +use arrow_schema::{DataType, Field, SchemaBuilder}; +use datafusion_common::{exec_err, ScalarValue}; +use datafusion_expr::expr::ScalarFunction; +use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility}; use std::any::Any; use std::collections::VecDeque; use std::sync::Arc; -use arrow::array::{Array, ArrayData, ArrayRef, MapArray, StructArray}; -use arrow::datatypes::{DataType, Field, SchemaBuilder}; -use arrow_buffer::{Buffer, ToByteSlice}; +/// Returns a map created from a key list and a value list +pub fn map(keys: Vec, values: Vec) -> Expr { + let keys = make_array(keys); + let values = make_array(values); + Expr::ScalarFunction(ScalarFunction::new_udf(map_udf(), vec![keys, values])) +} -use datafusion_common::Result; -use datafusion_common::{exec_err, ScalarValue}; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +create_func!(MapFunc, map_udf); /// Check if we can evaluate the expr to constant directly. /// @@ -39,7 +48,7 @@ fn can_evaluate_to_const(args: &[ColumnarValue]) -> bool { .all(|arg| matches!(arg, ColumnarValue::Scalar(_))) } -fn make_map_batch(args: &[ColumnarValue]) -> Result { +fn make_map_batch(args: &[ColumnarValue]) -> datafusion_common::Result { if args.len() != 2 { return exec_err!( "make_map requires exactly 2 arguments, got {} instead", @@ -54,7 +63,9 @@ fn make_map_batch(args: &[ColumnarValue]) -> Result { make_map_batch_internal(key, value, can_evaluate_to_const) } -fn get_first_array_ref(columnar_value: &ColumnarValue) -> Result { +fn get_first_array_ref( + columnar_value: &ColumnarValue, +) -> datafusion_common::Result { match columnar_value { ColumnarValue::Scalar(value) => match value { ScalarValue::List(array) => Ok(array.value(0)), @@ -70,7 +81,7 @@ fn make_map_batch_internal( keys: ArrayRef, values: ArrayRef, can_evaluate_to_const: bool, -) -> Result { +) -> datafusion_common::Result { if keys.null_count() > 0 { return exec_err!("map key cannot be null"); } @@ -150,7 +161,7 @@ impl ScalarUDFImpl for MapFunc { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { + fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result { if arg_types.len() % 2 != 0 { return exec_err!( "map requires an even number of arguments, got {} instead", @@ -175,12 +186,12 @@ impl ScalarUDFImpl for MapFunc { )) } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke(&self, args: &[ColumnarValue]) -> datafusion_common::Result { make_map_batch(args) } } -fn get_element_type(data_type: &DataType) -> Result<&DataType> { +fn get_element_type(data_type: &DataType) -> datafusion_common::Result<&DataType> { match data_type { DataType::List(element) => Ok(element.data_type()), DataType::LargeList(element) => Ok(element.data_type()), diff --git a/datafusion/functions-array/src/planner.rs b/datafusion/functions-array/src/planner.rs index fbb541d9b151..c63c2c83e66e 100644 --- a/datafusion/functions-array/src/planner.rs +++ b/datafusion/functions-array/src/planner.rs @@ -27,6 +27,7 @@ use datafusion_expr::{ use datafusion_functions::expr_fn::get_field; use datafusion_functions_aggregate::nth_value::nth_value_udaf; +use crate::map::map_udf; use crate::{ array_has::array_has_all, expr_fn::{array_append, array_concat, array_prepend}, @@ -111,10 +112,7 @@ impl ExprPlanner for ArrayFunctionPlanner { let values = make_array(values.into_iter().map(|(_, e)| e).collect()); Ok(PlannerResult::Planned(Expr::ScalarFunction( - ScalarFunction::new_udf( - datafusion_functions::core::map(), - vec![keys, values], - ), + ScalarFunction::new_udf(map_udf(), vec![keys, values]), ))) } } diff --git a/datafusion/functions/Cargo.toml b/datafusion/functions/Cargo.toml index b143080b1962..0281676cabf2 100644 --- a/datafusion/functions/Cargo.toml +++ b/datafusion/functions/Cargo.toml @@ -141,8 +141,3 @@ required-features = ["string_expressions"] harness = false name = "upper" required-features = ["string_expressions"] - -[[bench]] -harness = false -name = "map" -required-features = ["core_expressions"] diff --git a/datafusion/functions/benches/map.rs b/datafusion/functions/benches/map.rs deleted file mode 100644 index 811c21a41b46..000000000000 --- a/datafusion/functions/benches/map.rs +++ /dev/null @@ -1,80 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -extern crate criterion; - -use arrow::array::{Int32Array, ListArray, StringArray}; -use arrow::datatypes::{DataType, Field}; -use arrow_buffer::{OffsetBuffer, ScalarBuffer}; -use criterion::{black_box, criterion_group, criterion_main, Criterion}; -use datafusion_common::ScalarValue; -use datafusion_expr::ColumnarValue; -use datafusion_functions::core::map; -use rand::prelude::ThreadRng; -use rand::Rng; -use std::sync::Arc; - -fn keys(rng: &mut ThreadRng) -> Vec { - let mut keys = vec![]; - for _ in 0..1000 { - keys.push(rng.gen_range(0..9999).to_string()); - } - keys -} - -fn values(rng: &mut ThreadRng) -> Vec { - let mut values = vec![]; - for _ in 0..1000 { - values.push(rng.gen_range(0..9999)); - } - values -} - -fn criterion_benchmark(c: &mut Criterion) { - c.bench_function("map_1000", |b| { - let mut rng = rand::thread_rng(); - let field = Arc::new(Field::new("item", DataType::Utf8, true)); - let offsets = OffsetBuffer::new(ScalarBuffer::from(vec![0, 1000])); - let key_list = ListArray::new( - field, - offsets, - Arc::new(StringArray::from(keys(&mut rng))), - None, - ); - let field = Arc::new(Field::new("item", DataType::Int32, true)); - let offsets = OffsetBuffer::new(ScalarBuffer::from(vec![0, 1000])); - let value_list = ListArray::new( - field, - offsets, - Arc::new(Int32Array::from(values(&mut rng))), - None, - ); - let keys = ColumnarValue::Scalar(ScalarValue::List(Arc::new(key_list))); - let values = ColumnarValue::Scalar(ScalarValue::List(Arc::new(value_list))); - - b.iter(|| { - black_box( - map() - .invoke(&[keys.clone(), values.clone()]) - .expect("map should work on valid values"), - ); - }); - }); -} - -criterion_group!(benches, criterion_benchmark); -criterion_main!(benches); diff --git a/datafusion/functions/src/core/mod.rs b/datafusion/functions/src/core/mod.rs index ee0309e59382..8c5121397284 100644 --- a/datafusion/functions/src/core/mod.rs +++ b/datafusion/functions/src/core/mod.rs @@ -25,7 +25,6 @@ pub mod arrowtypeof; pub mod coalesce; pub mod expr_ext; pub mod getfield; -pub mod map; pub mod named_struct; pub mod nullif; pub mod nvl; @@ -43,7 +42,6 @@ make_udf_function!(r#struct::StructFunc, STRUCT, r#struct); make_udf_function!(named_struct::NamedStructFunc, NAMED_STRUCT, named_struct); make_udf_function!(getfield::GetFieldFunc, GET_FIELD, get_field); make_udf_function!(coalesce::CoalesceFunc, COALESCE, coalesce); -make_udf_function!(map::MapFunc, MAP, map); pub mod expr_fn { use datafusion_expr::{Expr, Literal}; @@ -80,10 +78,6 @@ pub mod expr_fn { coalesce, "Returns `coalesce(args...)`, which evaluates to the value of the first expr which is not NULL", args, - ),( - map, - "Returns a map created from a key list and a value list", - args, )); #[doc = "Returns the value of the field with the given name from the struct"] @@ -101,6 +95,5 @@ pub fn functions() -> Vec> { arrow_typeof(), named_struct(), coalesce(), - map(), ] } diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 11945f39589a..3476d5d042cc 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -44,6 +44,7 @@ use datafusion::functions_aggregate::expr_fn::{ count_distinct, covar_pop, covar_samp, first_value, grouping, median, stddev, stddev_pop, sum, var_pop, var_sample, }; +use datafusion::functions_array::map::map; use datafusion::prelude::*; use datafusion::test_util::{TestTableFactory, TestTableProvider}; use datafusion_common::config::TableOptions; @@ -704,6 +705,10 @@ async fn roundtrip_expr_api() -> Result<()> { bool_or(lit(true)), array_agg(lit(1)), array_agg(lit(1)).distinct().build().unwrap(), + map( + vec![lit(1), lit(2), lit(3)], + vec![lit(10), lit(20), lit(30)], + ), ]; // ensure expressions created with the expr api can be round tripped From 81d06f2e103385fe744fb909563d4fb4c4b13d49 Mon Sep 17 00:00:00 2001 From: Xin Li <33629085+xinlifoobar@users.noreply.github.com> Date: Mon, 22 Jul 2024 05:37:55 -0700 Subject: [PATCH 30/37] Move OutputRequirements to datafusion-physical-optimizer crate (#11579) * Move OutputRequirements to datafusion-physical-optimizer crate * Fix fmt * Fix cargo for cli --- datafusion-cli/Cargo.lock | 10 ++++---- .../enforce_distribution.rs | 4 ++-- datafusion/core/src/physical_optimizer/mod.rs | 1 - datafusion/physical-optimizer/Cargo.toml | 2 ++ datafusion/physical-optimizer/src/lib.rs | 1 + .../src}/output_requirements.rs | 24 +++++++++++-------- 6 files changed, 25 insertions(+), 17 deletions(-) rename datafusion/{core/src/physical_optimizer => physical-optimizer/src}/output_requirements.rs (94%) diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 61d9c72b89d9..84bff8c87190 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -118,9 +118,9 @@ dependencies = [ [[package]] name = "arrayref" -version = "0.3.7" +version = "0.3.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6b4930d2cb77ce62f89ee5d5289b4ac049559b1c45539271f5ed4fdc7db34545" +checksum = "9d151e35f61089500b617991b791fc8bfd237ae50cd5950803758a179b41e67a" [[package]] name = "arrayvec" @@ -875,9 +875,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.1.5" +version = "1.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "324c74f2155653c90b04f25b2a47a8a631360cb908f92a772695f430c7e31052" +checksum = "2aba8f4e9906c7ce3c73463f62a7f0c65183ada1a2d47e397cc8810827f9694f" dependencies = [ "jobserver", "libc", @@ -1397,6 +1397,8 @@ name = "datafusion-physical-optimizer" version = "40.0.0" dependencies = [ "datafusion-common", + "datafusion-execution", + "datafusion-physical-expr", "datafusion-physical-plan", ] diff --git a/datafusion/core/src/physical_optimizer/enforce_distribution.rs b/datafusion/core/src/physical_optimizer/enforce_distribution.rs index 9791f23f963e..62ac9089e2b4 100644 --- a/datafusion/core/src/physical_optimizer/enforce_distribution.rs +++ b/datafusion/core/src/physical_optimizer/enforce_distribution.rs @@ -24,7 +24,6 @@ use std::fmt::Debug; use std::sync::Arc; -use super::output_requirements::OutputRequirementExec; use crate::config::ConfigOptions; use crate::error::Result; use crate::physical_optimizer::utils::{ @@ -55,6 +54,7 @@ use datafusion_physical_expr::{ use datafusion_physical_plan::windows::{get_best_fitting_window, BoundedWindowAggExec}; use datafusion_physical_plan::ExecutionPlanProperties; +use datafusion_physical_optimizer::output_requirements::OutputRequirementExec; use datafusion_physical_optimizer::PhysicalOptimizerRule; use itertools::izip; @@ -1290,7 +1290,6 @@ pub(crate) mod tests { use crate::datasource::object_store::ObjectStoreUrl; use crate::datasource::physical_plan::{CsvExec, FileScanConfig, ParquetExec}; use crate::physical_optimizer::enforce_sorting::EnforceSorting; - use crate::physical_optimizer::output_requirements::OutputRequirements; use crate::physical_optimizer::test_utils::{ check_integrity, coalesce_partitions_exec, repartition_exec, }; @@ -1301,6 +1300,7 @@ pub(crate) mod tests { use crate::physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; use crate::physical_plan::sorts::sort::SortExec; use crate::physical_plan::{displayable, DisplayAs, DisplayFormatType, Statistics}; + use datafusion_physical_optimizer::output_requirements::OutputRequirements; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::ScalarValue; diff --git a/datafusion/core/src/physical_optimizer/mod.rs b/datafusion/core/src/physical_optimizer/mod.rs index 582f340151ae..a0c9c3697744 100644 --- a/datafusion/core/src/physical_optimizer/mod.rs +++ b/datafusion/core/src/physical_optimizer/mod.rs @@ -29,7 +29,6 @@ pub mod enforce_sorting; pub mod join_selection; pub mod limited_distinct_aggregation; pub mod optimizer; -pub mod output_requirements; pub mod projection_pushdown; pub mod pruning; pub mod replace_with_order_preserving_variants; diff --git a/datafusion/physical-optimizer/Cargo.toml b/datafusion/physical-optimizer/Cargo.toml index 9c0ee61da52a..125ea6acc77f 100644 --- a/datafusion/physical-optimizer/Cargo.toml +++ b/datafusion/physical-optimizer/Cargo.toml @@ -33,4 +33,6 @@ workspace = true [dependencies] datafusion-common = { workspace = true, default-features = true } +datafusion-execution = { workspace = true } +datafusion-physical-expr = { workspace = true } datafusion-physical-plan = { workspace = true } diff --git a/datafusion/physical-optimizer/src/lib.rs b/datafusion/physical-optimizer/src/lib.rs index c5a49216f5fd..6b9df7cad5c8 100644 --- a/datafusion/physical-optimizer/src/lib.rs +++ b/datafusion/physical-optimizer/src/lib.rs @@ -18,5 +18,6 @@ #![deny(clippy::clone_on_ref_ptr)] mod optimizer; +pub mod output_requirements; pub use optimizer::PhysicalOptimizerRule; diff --git a/datafusion/core/src/physical_optimizer/output_requirements.rs b/datafusion/physical-optimizer/src/output_requirements.rs similarity index 94% rename from datafusion/core/src/physical_optimizer/output_requirements.rs rename to datafusion/physical-optimizer/src/output_requirements.rs index cb9a0cb90e6c..f971d8f1f0aa 100644 --- a/datafusion/core/src/physical_optimizer/output_requirements.rs +++ b/datafusion/physical-optimizer/src/output_requirements.rs @@ -24,17 +24,21 @@ use std::sync::Arc; -use crate::physical_plan::sorts::sort::SortExec; -use crate::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan}; +use datafusion_execution::TaskContext; +use datafusion_physical_plan::sorts::sort::SortExec; +use datafusion_physical_plan::{ + DisplayAs, DisplayFormatType, ExecutionPlan, SendableRecordBatchStream, +}; use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{Result, Statistics}; use datafusion_physical_expr::{Distribution, LexRequirement, PhysicalSortRequirement}; -use datafusion_physical_optimizer::PhysicalOptimizerRule; use datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; use datafusion_physical_plan::{ExecutionPlanProperties, PlanProperties}; +use crate::PhysicalOptimizerRule; + /// This rule either adds or removes [`OutputRequirements`]s to/from the physical /// plan according to its `mode` attribute, which is set by the constructors /// `new_add_mode` and `new_remove_mode`. With this rule, we can keep track of @@ -86,7 +90,7 @@ enum RuleMode { /// /// See [`OutputRequirements`] for more details #[derive(Debug)] -pub(crate) struct OutputRequirementExec { +pub struct OutputRequirementExec { input: Arc, order_requirement: Option, dist_requirement: Distribution, @@ -94,7 +98,7 @@ pub(crate) struct OutputRequirementExec { } impl OutputRequirementExec { - pub(crate) fn new( + pub fn new( input: Arc, requirements: Option, dist_requirement: Distribution, @@ -108,8 +112,8 @@ impl OutputRequirementExec { } } - pub(crate) fn input(&self) -> Arc { - self.input.clone() + pub fn input(&self) -> Arc { + Arc::clone(&self.input) } /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. @@ -179,8 +183,8 @@ impl ExecutionPlan for OutputRequirementExec { fn execute( &self, _partition: usize, - _context: Arc, - ) -> Result { + _context: Arc, + ) -> Result { unreachable!(); } @@ -275,7 +279,7 @@ fn require_top_ordering_helper( // When an operator requires an ordering, any `SortExec` below can not // be responsible for (i.e. the originator of) the global ordering. let (new_child, is_changed) = - require_top_ordering_helper(children.swap_remove(0).clone())?; + require_top_ordering_helper(Arc::clone(children.swap_remove(0)))?; Ok((plan.with_new_children(vec![new_child])?, is_changed)) } else { // Stop searching, there is no global ordering desired for the query. From 4417a9404f99eeb662d887cbb12de3445eb9cd2a Mon Sep 17 00:00:00 2001 From: Jonah Gao Date: Mon, 22 Jul 2024 22:23:26 +0800 Subject: [PATCH 31/37] Minor: move `Column` related tests and rename `column.rs` (#11573) * Minor: move `Column` related tests * Rename column.rs to unknown_column.rs --- .../src/expressions/column.rs | 46 ++++++++++++++++++ .../physical-expr/src/expressions/mod.rs | 4 +- .../{column.rs => unknown_column.rs} | 48 +------------------ 3 files changed, 49 insertions(+), 49 deletions(-) rename datafusion/physical-expr/src/expressions/{column.rs => unknown_column.rs} (56%) diff --git a/datafusion/physical-expr-common/src/expressions/column.rs b/datafusion/physical-expr-common/src/expressions/column.rs index d972d35b9e4e..5397599ea2dc 100644 --- a/datafusion/physical-expr-common/src/expressions/column.rs +++ b/datafusion/physical-expr-common/src/expressions/column.rs @@ -135,3 +135,49 @@ impl Column { pub fn col(name: &str, schema: &Schema) -> Result> { Ok(Arc::new(Column::new_with_schema(name, schema)?)) } + +#[cfg(test)] +mod test { + use super::Column; + use crate::physical_expr::PhysicalExpr; + + use arrow::array::StringArray; + use arrow::datatypes::{DataType, Field, Schema}; + use arrow::record_batch::RecordBatch; + use datafusion_common::Result; + + use std::sync::Arc; + + #[test] + fn out_of_bounds_data_type() { + let schema = Schema::new(vec![Field::new("foo", DataType::Utf8, true)]); + let col = Column::new("id", 9); + let error = col.data_type(&schema).expect_err("error").strip_backtrace(); + assert!("Internal error: PhysicalExpr Column references column 'id' at index 9 (zero-based) \ + but input schema only has 1 columns: [\"foo\"].\nThis was likely caused by a bug in \ + DataFusion's code and we would welcome that you file an bug report in our issue tracker".starts_with(&error)) + } + + #[test] + fn out_of_bounds_nullable() { + let schema = Schema::new(vec![Field::new("foo", DataType::Utf8, true)]); + let col = Column::new("id", 9); + let error = col.nullable(&schema).expect_err("error").strip_backtrace(); + assert!("Internal error: PhysicalExpr Column references column 'id' at index 9 (zero-based) \ + but input schema only has 1 columns: [\"foo\"].\nThis was likely caused by a bug in \ + DataFusion's code and we would welcome that you file an bug report in our issue tracker".starts_with(&error)) + } + + #[test] + fn out_of_bounds_evaluate() -> Result<()> { + let schema = Schema::new(vec![Field::new("foo", DataType::Utf8, true)]); + let data: StringArray = vec!["data"].into(); + let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(data)])?; + let col = Column::new("id", 9); + let error = col.evaluate(&batch).expect_err("error").strip_backtrace(); + assert!("Internal error: PhysicalExpr Column references column 'id' at index 9 (zero-based) \ + but input schema only has 1 columns: [\"foo\"].\nThis was likely caused by a bug in \ + DataFusion's code and we would welcome that you file an bug report in our issue tracker".starts_with(&error)); + Ok(()) + } +} diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index fa80bc9873f0..5a2bcb63b18e 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -20,7 +20,6 @@ #[macro_use] mod binary; mod case; -mod column; mod in_list; mod is_not_null; mod is_null; @@ -29,6 +28,7 @@ mod negative; mod no_op; mod not; mod try_cast; +mod unknown_column; /// Module with some convenient methods used in expression building pub mod helpers { @@ -48,7 +48,6 @@ pub use crate::PhysicalSortExpr; pub use binary::{binary, BinaryExpr}; pub use case::{case, CaseExpr}; -pub use column::UnKnownColumn; pub use datafusion_expr::utils::format_state_name; pub use datafusion_physical_expr_common::expressions::column::{col, Column}; pub use datafusion_physical_expr_common::expressions::literal::{lit, Literal}; @@ -61,3 +60,4 @@ pub use negative::{negative, NegativeExpr}; pub use no_op::NoOp; pub use not::{not, NotExpr}; pub use try_cast::{try_cast, TryCastExpr}; +pub use unknown_column::UnKnownColumn; diff --git a/datafusion/physical-expr/src/expressions/column.rs b/datafusion/physical-expr/src/expressions/unknown_column.rs similarity index 56% rename from datafusion/physical-expr/src/expressions/column.rs rename to datafusion/physical-expr/src/expressions/unknown_column.rs index ab43201ceb75..cb7221e7fa15 100644 --- a/datafusion/physical-expr/src/expressions/column.rs +++ b/datafusion/physical-expr/src/expressions/unknown_column.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! Column expression +//! UnKnownColumn expression use std::any::Any; use std::hash::{Hash, Hasher}; @@ -100,49 +100,3 @@ impl PartialEq for UnKnownColumn { false } } - -#[cfg(test)] -mod test { - use crate::expressions::Column; - use crate::PhysicalExpr; - - use arrow::array::StringArray; - use arrow::datatypes::{DataType, Field, Schema}; - use arrow::record_batch::RecordBatch; - use datafusion_common::Result; - - use std::sync::Arc; - - #[test] - fn out_of_bounds_data_type() { - let schema = Schema::new(vec![Field::new("foo", DataType::Utf8, true)]); - let col = Column::new("id", 9); - let error = col.data_type(&schema).expect_err("error").strip_backtrace(); - assert!("Internal error: PhysicalExpr Column references column 'id' at index 9 (zero-based) \ - but input schema only has 1 columns: [\"foo\"].\nThis was likely caused by a bug in \ - DataFusion's code and we would welcome that you file an bug report in our issue tracker".starts_with(&error)) - } - - #[test] - fn out_of_bounds_nullable() { - let schema = Schema::new(vec![Field::new("foo", DataType::Utf8, true)]); - let col = Column::new("id", 9); - let error = col.nullable(&schema).expect_err("error").strip_backtrace(); - assert!("Internal error: PhysicalExpr Column references column 'id' at index 9 (zero-based) \ - but input schema only has 1 columns: [\"foo\"].\nThis was likely caused by a bug in \ - DataFusion's code and we would welcome that you file an bug report in our issue tracker".starts_with(&error)) - } - - #[test] - fn out_of_bounds_evaluate() -> Result<()> { - let schema = Schema::new(vec![Field::new("foo", DataType::Utf8, true)]); - let data: StringArray = vec!["data"].into(); - let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(data)])?; - let col = Column::new("id", 9); - let error = col.evaluate(&batch).expect_err("error").strip_backtrace(); - assert!("Internal error: PhysicalExpr Column references column 'id' at index 9 (zero-based) \ - but input schema only has 1 columns: [\"foo\"].\nThis was likely caused by a bug in \ - DataFusion's code and we would welcome that you file an bug report in our issue tracker".starts_with(&error)); - Ok(()) - } -} From b6e55d7e9cf17cfd1dcf633350cc6d205608ecd0 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Mon, 22 Jul 2024 09:51:40 -0600 Subject: [PATCH 32/37] feat: Optimize CASE expression for usage where then and else values are literals (#11553) * Optimize CASE expression for usage where then and else values are literals * add slt test * add more test cases --- .../physical-expr/src/expressions/case.rs | 44 ++++++++++++++ datafusion/sqllogictest/test_files/case.slt | 60 ++++++++++++++++++- 2 files changed, 103 insertions(+), 1 deletion(-) diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 521a7ed9acae..b428d562bd1b 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -57,6 +57,11 @@ enum EvalMethod { /// /// CASE WHEN condition THEN column [ELSE NULL] END InfallibleExprOrNull, + /// This is a specialization for a specific use case where we can take a fast path + /// if there is just one when/then pair and both the `then` and `else` expressions + /// are literal values + /// CASE WHEN condition THEN literal ELSE literal END + ScalarOrScalar, } /// The CASE expression is similar to a series of nested if/else and there are two forms that @@ -140,6 +145,12 @@ impl CaseExpr { && else_expr.is_none() { EvalMethod::InfallibleExprOrNull + } else if when_then_expr.len() == 1 + && when_then_expr[0].1.as_any().is::() + && else_expr.is_some() + && else_expr.as_ref().unwrap().as_any().is::() + { + EvalMethod::ScalarOrScalar } else { EvalMethod::NoExpression }; @@ -344,6 +355,38 @@ impl CaseExpr { internal_err!("predicate did not evaluate to an array") } } + + fn scalar_or_scalar(&self, batch: &RecordBatch) -> Result { + let return_type = self.data_type(&batch.schema())?; + + // evaluate when expression + let when_value = self.when_then_expr[0].0.evaluate(batch)?; + let when_value = when_value.into_array(batch.num_rows())?; + let when_value = as_boolean_array(&when_value).map_err(|e| { + DataFusionError::Context( + "WHEN expression did not return a BooleanArray".to_string(), + Box::new(e), + ) + })?; + + // Treat 'NULL' as false value + let when_value = match when_value.null_count() { + 0 => Cow::Borrowed(when_value), + _ => Cow::Owned(prep_null_mask_filter(when_value)), + }; + + // evaluate then_value + let then_value = self.when_then_expr[0].1.evaluate(batch)?; + let then_value = Scalar::new(then_value.into_array(1)?); + + // keep `else_expr`'s data type and return type consistent + let e = self.else_expr.as_ref().unwrap(); + let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone()) + .unwrap_or_else(|_| Arc::clone(e)); + let else_ = Scalar::new(expr.evaluate(batch)?.into_array(1)?); + + Ok(ColumnarValue::Array(zip(&when_value, &then_value, &else_)?)) + } } impl PhysicalExpr for CaseExpr { @@ -406,6 +449,7 @@ impl PhysicalExpr for CaseExpr { // Specialization for CASE WHEN expr THEN column [ELSE NULL] END self.case_column_or_null(batch) } + EvalMethod::ScalarOrScalar => self.scalar_or_scalar(batch), } } diff --git a/datafusion/sqllogictest/test_files/case.slt b/datafusion/sqllogictest/test_files/case.slt index fac1042bb6dd..70063b88fb19 100644 --- a/datafusion/sqllogictest/test_files/case.slt +++ b/datafusion/sqllogictest/test_files/case.slt @@ -17,7 +17,7 @@ # create test data statement ok -create table foo (a int, b int) as values (1, 2), (3, 4), (5, 6); +create table foo (a int, b int) as values (1, 2), (3, 4), (5, 6), (null, null), (6, null), (null, 7); # CASE WHEN with condition query T @@ -26,6 +26,9 @@ SELECT CASE a WHEN 1 THEN 'one' WHEN 3 THEN 'three' ELSE '?' END FROM foo one three ? +? +? +? # CASE WHEN with no condition query I @@ -34,6 +37,9 @@ SELECT CASE WHEN a > 2 THEN a ELSE b END FROM foo 2 3 5 +NULL +6 +7 # column or explicit null query I @@ -42,6 +48,9 @@ SELECT CASE WHEN a > 2 THEN b ELSE null END FROM foo NULL 4 6 +NULL +NULL +7 # column or implicit null query I @@ -50,3 +59,52 @@ SELECT CASE WHEN a > 2 THEN b END FROM foo NULL 4 6 +NULL +NULL +7 + +# scalar or scalar (string) +query T +SELECT CASE WHEN a > 2 THEN 'even' ELSE 'odd' END FROM foo +---- +odd +even +even +odd +even +odd + +# scalar or scalar (int) +query I +SELECT CASE WHEN a > 2 THEN 1 ELSE 0 END FROM foo +---- +0 +1 +1 +0 +1 +0 + +# predicate binary expression with scalars (does not make much sense because the expression in +# this case is always false, so this expression could be rewritten as a literal 0 during planning +query I +SELECT CASE WHEN 1 > 2 THEN 1 ELSE 0 END FROM foo +---- +0 +0 +0 +0 +0 +0 + +# predicate using boolean literal (does not make much sense because the expression in +# this case is always false, so this expression could be rewritten as a literal 0 during planning +query I +SELECT CASE WHEN false THEN 1 ELSE 0 END FROM foo +---- +0 +0 +0 +0 +0 +0 From 7d078d8c11155fd098595126b1ed60cad9afce5a Mon Sep 17 00:00:00 2001 From: Oleks V Date: Mon, 22 Jul 2024 11:10:53 -0700 Subject: [PATCH 33/37] Fix SortMergeJoin antijoin flaky condition (#11604) --- .../physical-plan/src/joins/sort_merge_join.rs | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index 5fde028c7f48..96d5ba728a30 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -1681,22 +1681,25 @@ fn get_filtered_join_mask( JoinType::LeftAnti => { // have we seen a filter match for a streaming index before for i in 0..streamed_indices_length { - if mask.value(i) && !seen_as_true { + let streamed_idx = streamed_indices.value(i); + if mask.value(i) + && !seen_as_true + && !matched_indices.contains(&streamed_idx) + { seen_as_true = true; - filter_matched_indices.push(streamed_indices.value(i)); + filter_matched_indices.push(streamed_idx); } // Reset `seen_as_true` flag and calculate mask for the current streaming index // - if within the batch it switched to next streaming index(e.g. from 0 to 1, or from 1 to 2) // - if it is at the end of the all buffered batches for the given streaming index, 0 index comes last if (i < streamed_indices_length - 1 - && streamed_indices.value(i) != streamed_indices.value(i + 1)) + && streamed_idx != streamed_indices.value(i + 1)) || (i == streamed_indices_length - 1 && *scanning_buffered_offset == 0) { corrected_mask.append_value( - !matched_indices.contains(&streamed_indices.value(i)) - && !seen_as_true, + !matched_indices.contains(&streamed_idx) && !seen_as_true, ); seen_as_true = false; } else { From a2ac00da1b3aa7879317ae88d1b356b27f49f887 Mon Sep 17 00:00:00 2001 From: Mustafa Akur <106137913+mustafasrepo@users.noreply.github.com> Date: Mon, 22 Jul 2024 23:51:43 +0300 Subject: [PATCH 34/37] Improve Union Equivalence Propagation (#11506) * Initial commit * Fix formatting * Minor changes * Fix failing test * Change union calculation algorithm to make it symmetric * Minor changes * Add unit tests * Simplifications * Review Part 1 * Move test and union equivalence * Add new tests * Support for union with different schema * Address reviews * Review Part 2 * Add new tests * Final Review --------- Co-authored-by: Mehmet Ozan Kabak --- .../physical-expr-common/src/physical_expr.rs | 33 +- .../physical-expr/src/equivalence/mod.rs | 4 +- .../src/equivalence/properties.rs | 641 ++++++++++++++++-- datafusion/physical-expr/src/lib.rs | 2 +- datafusion/physical-plan/src/common.rs | 356 +--------- datafusion/physical-plan/src/union.rs | 115 +--- datafusion/sqllogictest/test_files/order.slt | 7 + 7 files changed, 647 insertions(+), 511 deletions(-) diff --git a/datafusion/physical-expr-common/src/physical_expr.rs b/datafusion/physical-expr-common/src/physical_expr.rs index 1998f1439646..c74fb9c2d1b7 100644 --- a/datafusion/physical-expr-common/src/physical_expr.rs +++ b/datafusion/physical-expr-common/src/physical_expr.rs @@ -20,13 +20,15 @@ use std::fmt::{Debug, Display}; use std::hash::{Hash, Hasher}; use std::sync::Arc; +use crate::expressions::column::Column; use crate::utils::scatter; use arrow::array::BooleanArray; use arrow::compute::filter_record_batch; -use arrow::datatypes::{DataType, Schema}; +use arrow::datatypes::{DataType, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; -use datafusion_common::{internal_err, not_impl_err, Result}; +use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::{internal_err, not_impl_err, plan_err, Result}; use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::sort_properties::ExprProperties; use datafusion_expr::ColumnarValue; @@ -191,6 +193,33 @@ pub fn with_new_children_if_necessary( } } +/// Rewrites an expression according to new schema; i.e. changes the columns it +/// refers to with the column at corresponding index in the new schema. Returns +/// an error if the given schema has fewer columns than the original schema. +/// Note that the resulting expression may not be valid if data types in the +/// new schema is incompatible with expression nodes. +pub fn with_new_schema( + expr: Arc, + schema: &SchemaRef, +) -> Result> { + Ok(expr + .transform_up(|expr| { + if let Some(col) = expr.as_any().downcast_ref::() { + let idx = col.index(); + let Some(field) = schema.fields().get(idx) else { + return plan_err!( + "New schema has fewer columns than original schema" + ); + }; + let new_col = Column::new(field.name(), idx); + Ok(Transformed::yes(Arc::new(new_col) as _)) + } else { + Ok(Transformed::no(expr)) + } + })? + .data) +} + pub fn down_cast_any_ref(any: &dyn Any) -> &dyn Any { if any.is::>() { any.downcast_ref::>() diff --git a/datafusion/physical-expr/src/equivalence/mod.rs b/datafusion/physical-expr/src/equivalence/mod.rs index 83f94057f740..b9228282b081 100644 --- a/datafusion/physical-expr/src/equivalence/mod.rs +++ b/datafusion/physical-expr/src/equivalence/mod.rs @@ -30,7 +30,9 @@ mod properties; pub use class::{ConstExpr, EquivalenceClass, EquivalenceGroup}; pub use ordering::OrderingEquivalenceClass; pub use projection::ProjectionMapping; -pub use properties::{join_equivalence_properties, EquivalenceProperties}; +pub use properties::{ + calculate_union, join_equivalence_properties, EquivalenceProperties, +}; /// This function constructs a duplicate-free `LexOrderingReq` by filtering out /// duplicate entries that have same physical expression inside. For example, diff --git a/datafusion/physical-expr/src/equivalence/properties.rs b/datafusion/physical-expr/src/equivalence/properties.rs index 8c327fbaf409..64c22064d4b7 100644 --- a/datafusion/physical-expr/src/equivalence/properties.rs +++ b/datafusion/physical-expr/src/equivalence/properties.rs @@ -21,7 +21,8 @@ use std::sync::Arc; use super::ordering::collapse_lex_ordering; use crate::equivalence::class::const_exprs_contains; use crate::equivalence::{ - collapse_lex_req, EquivalenceGroup, OrderingEquivalenceClass, ProjectionMapping, + collapse_lex_req, EquivalenceClass, EquivalenceGroup, OrderingEquivalenceClass, + ProjectionMapping, }; use crate::expressions::Literal; use crate::{ @@ -32,11 +33,12 @@ use crate::{ use arrow_schema::{SchemaRef, SortOptions}; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; -use datafusion_common::{JoinSide, JoinType, Result}; +use datafusion_common::{plan_err, JoinSide, JoinType, Result}; use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_physical_expr_common::expressions::column::Column; use datafusion_physical_expr_common::expressions::CastExpr; +use datafusion_physical_expr_common::physical_expr::with_new_schema; use datafusion_physical_expr_common::utils::ExprPropertiesNode; use indexmap::{IndexMap, IndexSet}; @@ -536,33 +538,6 @@ impl EquivalenceProperties { .then_some(if lhs.len() >= rhs.len() { lhs } else { rhs }) } - /// Calculates the "meet" of the given orderings (`lhs` and `rhs`). - /// The meet of a set of orderings is the finest ordering that is satisfied - /// by all the orderings in that set. For details, see: - /// - /// - /// - /// If there is no ordering that satisfies both `lhs` and `rhs`, returns - /// `None`. As an example, the meet of orderings `[a ASC]` and `[a ASC, b ASC]` - /// is `[a ASC]`. - pub fn get_meet_ordering( - &self, - lhs: LexOrderingRef, - rhs: LexOrderingRef, - ) -> Option { - let lhs = self.normalize_sort_exprs(lhs); - let rhs = self.normalize_sort_exprs(rhs); - let mut meet = vec![]; - for (lhs, rhs) in lhs.into_iter().zip(rhs.into_iter()) { - if lhs.eq(&rhs) { - meet.push(lhs); - } else { - break; - } - } - (!meet.is_empty()).then_some(meet) - } - /// we substitute the ordering according to input expression type, this is a simplified version /// In this case, we just substitute when the expression satisfy the following condition: /// I. just have one column and is a CAST expression @@ -1007,6 +982,74 @@ impl EquivalenceProperties { .map(|node| node.data) .unwrap_or(ExprProperties::new_unknown()) } + + /// Transforms this `EquivalenceProperties` into a new `EquivalenceProperties` + /// by mapping columns in the original schema to columns in the new schema + /// by index. + pub fn with_new_schema(self, schema: SchemaRef) -> Result { + // The new schema and the original schema is aligned when they have the + // same number of columns, and fields at the same index have the same + // type in both schemas. + let schemas_aligned = (self.schema.fields.len() == schema.fields.len()) + && self + .schema + .fields + .iter() + .zip(schema.fields.iter()) + .all(|(lhs, rhs)| lhs.data_type().eq(rhs.data_type())); + if !schemas_aligned { + // Rewriting equivalence properties in terms of new schema is not + // safe when schemas are not aligned: + return plan_err!( + "Cannot rewrite old_schema:{:?} with new schema: {:?}", + self.schema, + schema + ); + } + // Rewrite constants according to new schema: + let new_constants = self + .constants + .into_iter() + .map(|const_expr| { + let across_partitions = const_expr.across_partitions(); + let new_const_expr = with_new_schema(const_expr.owned_expr(), &schema)?; + Ok(ConstExpr::new(new_const_expr) + .with_across_partitions(across_partitions)) + }) + .collect::>>()?; + + // Rewrite orderings according to new schema: + let mut new_orderings = vec![]; + for ordering in self.oeq_class.orderings { + let new_ordering = ordering + .into_iter() + .map(|mut sort_expr| { + sort_expr.expr = with_new_schema(sort_expr.expr, &schema)?; + Ok(sort_expr) + }) + .collect::>()?; + new_orderings.push(new_ordering); + } + + // Rewrite equivalence classes according to the new schema: + let mut eq_classes = vec![]; + for eq_class in self.eq_group.classes { + let new_eq_exprs = eq_class + .into_vec() + .into_iter() + .map(|expr| with_new_schema(expr, &schema)) + .collect::>()?; + eq_classes.push(EquivalenceClass::new(new_eq_exprs)); + } + + // Construct the resulting equivalence properties: + let mut result = EquivalenceProperties::new(schema); + result.constants = new_constants; + result.add_new_orderings(new_orderings); + result.add_equivalence_group(EquivalenceGroup::new(eq_classes)); + + Ok(result) + } } /// Calculates the properties of a given [`ExprPropertiesNode`]. @@ -1484,6 +1527,84 @@ impl Hash for ExprWrapper { } } +/// Calculates the union (in the sense of `UnionExec`) `EquivalenceProperties` +/// of `lhs` and `rhs` according to the schema of `lhs`. +fn calculate_union_binary( + lhs: EquivalenceProperties, + mut rhs: EquivalenceProperties, +) -> Result { + // TODO: In some cases, we should be able to preserve some equivalence + // classes. Add support for such cases. + + // Harmonize the schema of the rhs with the schema of the lhs (which is the accumulator schema): + if !rhs.schema.eq(&lhs.schema) { + rhs = rhs.with_new_schema(Arc::clone(&lhs.schema))?; + } + + // First, calculate valid constants for the union. A quantity is constant + // after the union if it is constant in both sides. + let constants = lhs + .constants() + .iter() + .filter(|const_expr| const_exprs_contains(rhs.constants(), const_expr.expr())) + .map(|const_expr| { + // TODO: When both sides' constants are valid across partitions, + // the union's constant should also be valid if values are + // the same. However, we do not have the capability to + // check this yet. + ConstExpr::new(Arc::clone(const_expr.expr())).with_across_partitions(false) + }) + .collect(); + + // Next, calculate valid orderings for the union by searching for prefixes + // in both sides. + let mut orderings = vec![]; + for mut ordering in lhs.normalized_oeq_class().orderings { + // Progressively shorten the ordering to search for a satisfied prefix: + while !rhs.ordering_satisfy(&ordering) { + ordering.pop(); + } + // There is a non-trivial satisfied prefix, add it as a valid ordering: + if !ordering.is_empty() { + orderings.push(ordering); + } + } + for mut ordering in rhs.normalized_oeq_class().orderings { + // Progressively shorten the ordering to search for a satisfied prefix: + while !lhs.ordering_satisfy(&ordering) { + ordering.pop(); + } + // There is a non-trivial satisfied prefix, add it as a valid ordering: + if !ordering.is_empty() { + orderings.push(ordering); + } + } + let mut eq_properties = EquivalenceProperties::new(lhs.schema); + eq_properties.constants = constants; + eq_properties.add_new_orderings(orderings); + Ok(eq_properties) +} + +/// Calculates the union (in the sense of `UnionExec`) `EquivalenceProperties` +/// of the given `EquivalenceProperties` in `eqps` according to the given +/// output `schema` (which need not be the same with those of `lhs` and `rhs` +/// as details such as nullability may be different). +pub fn calculate_union( + eqps: Vec, + schema: SchemaRef, +) -> Result { + // TODO: In some cases, we should be able to preserve some equivalence + // classes. Add support for such cases. + let mut init = eqps[0].clone(); + // Harmonize the schema of the init with the schema of the union: + if !init.schema.eq(&schema) { + init = init.with_new_schema(schema)?; + } + eqps.into_iter() + .skip(1) + .try_fold(init, calculate_union_binary) +} + #[cfg(test)] mod tests { use std::ops::Not; @@ -2188,50 +2309,6 @@ mod tests { Ok(()) } - #[test] - fn test_get_meet_ordering() -> Result<()> { - let schema = create_test_schema()?; - let col_a = &col("a", &schema)?; - let col_b = &col("b", &schema)?; - let eq_properties = EquivalenceProperties::new(schema); - let option_asc = SortOptions { - descending: false, - nulls_first: false, - }; - let option_desc = SortOptions { - descending: true, - nulls_first: true, - }; - let tests_cases = vec![ - // Get meet ordering between [a ASC] and [a ASC, b ASC] - // result should be [a ASC] - ( - vec![(col_a, option_asc)], - vec![(col_a, option_asc), (col_b, option_asc)], - Some(vec![(col_a, option_asc)]), - ), - // Get meet ordering between [a ASC] and [a DESC] - // result should be None. - (vec![(col_a, option_asc)], vec![(col_a, option_desc)], None), - // Get meet ordering between [a ASC, b ASC] and [a ASC, b DESC] - // result should be [a ASC]. - ( - vec![(col_a, option_asc), (col_b, option_asc)], - vec![(col_a, option_asc), (col_b, option_desc)], - Some(vec![(col_a, option_asc)]), - ), - ]; - for (lhs, rhs, expected) in tests_cases { - let lhs = convert_to_sort_exprs(&lhs); - let rhs = convert_to_sort_exprs(&rhs); - let expected = expected.map(|expected| convert_to_sort_exprs(&expected)); - let finer = eq_properties.get_meet_ordering(&lhs, &rhs); - assert_eq!(finer, expected) - } - - Ok(()) - } - #[test] fn test_get_finer() -> Result<()> { let schema = create_test_schema()?; @@ -2525,4 +2602,422 @@ mod tests { Ok(()) } + + fn append_fields(schema: &SchemaRef, text: &str) -> SchemaRef { + Arc::new(Schema::new( + schema + .fields() + .iter() + .map(|field| { + Field::new( + // Annotate name with `text`: + format!("{}{}", field.name(), text), + field.data_type().clone(), + field.is_nullable(), + ) + }) + .collect::>(), + )) + } + + #[tokio::test] + async fn test_union_equivalence_properties_multi_children() -> Result<()> { + let schema = create_test_schema()?; + let schema2 = append_fields(&schema, "1"); + let schema3 = append_fields(&schema, "2"); + let test_cases = vec![ + // --------- TEST CASE 1 ---------- + ( + vec![ + // Children 1 + ( + // Orderings + vec![vec!["a", "b", "c"]], + Arc::clone(&schema), + ), + // Children 2 + ( + // Orderings + vec![vec!["a1", "b1", "c1"]], + Arc::clone(&schema2), + ), + // Children 3 + ( + // Orderings + vec![vec!["a2", "b2"]], + Arc::clone(&schema3), + ), + ], + // Expected + vec![vec!["a", "b"]], + ), + // --------- TEST CASE 2 ---------- + ( + vec![ + // Children 1 + ( + // Orderings + vec![vec!["a", "b", "c"]], + Arc::clone(&schema), + ), + // Children 2 + ( + // Orderings + vec![vec!["a1", "b1", "c1"]], + Arc::clone(&schema2), + ), + // Children 3 + ( + // Orderings + vec![vec!["a2", "b2", "c2"]], + Arc::clone(&schema3), + ), + ], + // Expected + vec![vec!["a", "b", "c"]], + ), + // --------- TEST CASE 3 ---------- + ( + vec![ + // Children 1 + ( + // Orderings + vec![vec!["a", "b"]], + Arc::clone(&schema), + ), + // Children 2 + ( + // Orderings + vec![vec!["a1", "b1", "c1"]], + Arc::clone(&schema2), + ), + // Children 3 + ( + // Orderings + vec![vec!["a2", "b2", "c2"]], + Arc::clone(&schema3), + ), + ], + // Expected + vec![vec!["a", "b"]], + ), + // --------- TEST CASE 4 ---------- + ( + vec![ + // Children 1 + ( + // Orderings + vec![vec!["a", "b"]], + Arc::clone(&schema), + ), + // Children 2 + ( + // Orderings + vec![vec!["a1", "b1"]], + Arc::clone(&schema2), + ), + // Children 3 + ( + // Orderings + vec![vec!["b2", "c2"]], + Arc::clone(&schema3), + ), + ], + // Expected + vec![], + ), + // --------- TEST CASE 5 ---------- + ( + vec![ + // Children 1 + ( + // Orderings + vec![vec!["a", "b"], vec!["c"]], + Arc::clone(&schema), + ), + // Children 2 + ( + // Orderings + vec![vec!["a1", "b1"], vec!["c1"]], + Arc::clone(&schema2), + ), + ], + // Expected + vec![vec!["a", "b"], vec!["c"]], + ), + ]; + for (children, expected) in test_cases { + let children_eqs = children + .iter() + .map(|(orderings, schema)| { + let orderings = orderings + .iter() + .map(|ordering| { + ordering + .iter() + .map(|name| PhysicalSortExpr { + expr: col(name, schema).unwrap(), + options: SortOptions::default(), + }) + .collect::>() + }) + .collect::>(); + EquivalenceProperties::new_with_orderings( + Arc::clone(schema), + &orderings, + ) + }) + .collect::>(); + let actual = calculate_union(children_eqs, Arc::clone(&schema))?; + + let expected_ordering = expected + .into_iter() + .map(|ordering| { + ordering + .into_iter() + .map(|name| PhysicalSortExpr { + expr: col(name, &schema).unwrap(), + options: SortOptions::default(), + }) + .collect::>() + }) + .collect::>(); + let expected = EquivalenceProperties::new_with_orderings( + Arc::clone(&schema), + &expected_ordering, + ); + assert_eq_properties_same( + &actual, + &expected, + format!("expected: {:?}, actual: {:?}", expected, actual), + ); + } + Ok(()) + } + + #[tokio::test] + async fn test_union_equivalence_properties_binary() -> Result<()> { + let schema = create_test_schema()?; + let schema2 = append_fields(&schema, "1"); + let col_a = &col("a", &schema)?; + let col_b = &col("b", &schema)?; + let col_c = &col("c", &schema)?; + let col_a1 = &col("a1", &schema2)?; + let col_b1 = &col("b1", &schema2)?; + let options = SortOptions::default(); + let options_desc = !SortOptions::default(); + let test_cases = [ + //-----------TEST CASE 1----------// + ( + ( + // First child orderings + vec![ + // [a ASC] + (vec![(col_a, options)]), + ], + // First child constants + vec![col_b, col_c], + Arc::clone(&schema), + ), + ( + // Second child orderings + vec![ + // [b ASC] + (vec![(col_b, options)]), + ], + // Second child constants + vec![col_a, col_c], + Arc::clone(&schema), + ), + ( + // Union expected orderings + vec![ + // [a ASC] + vec![(col_a, options)], + // [b ASC] + vec![(col_b, options)], + ], + // Union + vec![col_c], + ), + ), + //-----------TEST CASE 2----------// + // Meet ordering between [a ASC], [a ASC, b ASC] should be [a ASC] + ( + ( + // First child orderings + vec![ + // [a ASC] + vec![(col_a, options)], + ], + // No constant + vec![], + Arc::clone(&schema), + ), + ( + // Second child orderings + vec![ + // [a ASC, b ASC] + vec![(col_a, options), (col_b, options)], + ], + // No constant + vec![], + Arc::clone(&schema), + ), + ( + // Union orderings + vec![ + // [a ASC] + vec![(col_a, options)], + ], + // No constant + vec![], + ), + ), + //-----------TEST CASE 3----------// + // Meet ordering between [a ASC], [a DESC] should be [] + ( + ( + // First child orderings + vec![ + // [a ASC] + vec![(col_a, options)], + ], + // No constant + vec![], + Arc::clone(&schema), + ), + ( + // Second child orderings + vec![ + // [a DESC] + vec![(col_a, options_desc)], + ], + // No constant + vec![], + Arc::clone(&schema), + ), + ( + // Union doesn't have any ordering + vec![], + // No constant + vec![], + ), + ), + //-----------TEST CASE 4----------// + // Meet ordering between [a ASC], [a1 ASC, b1 ASC] should be [a ASC] + // Where a, and a1 ath the same index for their corresponding schemas. + ( + ( + // First child orderings + vec![ + // [a ASC] + vec![(col_a, options)], + ], + // No constant + vec![], + Arc::clone(&schema), + ), + ( + // Second child orderings + vec![ + // [a1 ASC, b1 ASC] + vec![(col_a1, options), (col_b1, options)], + ], + // No constant + vec![], + Arc::clone(&schema2), + ), + ( + // Union orderings + vec![ + // [a ASC] + vec![(col_a, options)], + ], + // No constant + vec![], + ), + ), + ]; + + for ( + test_idx, + ( + (first_child_orderings, first_child_constants, first_schema), + (second_child_orderings, second_child_constants, second_schema), + (union_orderings, union_constants), + ), + ) in test_cases.iter().enumerate() + { + let first_orderings = first_child_orderings + .iter() + .map(|ordering| convert_to_sort_exprs(ordering)) + .collect::>(); + let first_constants = first_child_constants + .iter() + .map(|expr| ConstExpr::new(Arc::clone(expr))) + .collect::>(); + let mut lhs = EquivalenceProperties::new(Arc::clone(first_schema)); + lhs = lhs.add_constants(first_constants); + lhs.add_new_orderings(first_orderings); + + let second_orderings = second_child_orderings + .iter() + .map(|ordering| convert_to_sort_exprs(ordering)) + .collect::>(); + let second_constants = second_child_constants + .iter() + .map(|expr| ConstExpr::new(Arc::clone(expr))) + .collect::>(); + let mut rhs = EquivalenceProperties::new(Arc::clone(second_schema)); + rhs = rhs.add_constants(second_constants); + rhs.add_new_orderings(second_orderings); + + let union_expected_orderings = union_orderings + .iter() + .map(|ordering| convert_to_sort_exprs(ordering)) + .collect::>(); + let union_constants = union_constants + .iter() + .map(|expr| ConstExpr::new(Arc::clone(expr))) + .collect::>(); + let mut union_expected_eq = EquivalenceProperties::new(Arc::clone(&schema)); + union_expected_eq = union_expected_eq.add_constants(union_constants); + union_expected_eq.add_new_orderings(union_expected_orderings); + + let actual_union_eq = calculate_union_binary(lhs, rhs)?; + let err_msg = format!( + "Error in test id: {:?}, test case: {:?}", + test_idx, test_cases[test_idx] + ); + assert_eq_properties_same(&actual_union_eq, &union_expected_eq, err_msg); + } + Ok(()) + } + + fn assert_eq_properties_same( + lhs: &EquivalenceProperties, + rhs: &EquivalenceProperties, + err_msg: String, + ) { + // Check whether constants are same + let lhs_constants = lhs.constants(); + let rhs_constants = rhs.constants(); + assert_eq!(lhs_constants.len(), rhs_constants.len(), "{}", err_msg); + for rhs_constant in rhs_constants { + assert!( + const_exprs_contains(lhs_constants, rhs_constant.expr()), + "{}", + err_msg + ); + } + + // Check whether orderings are same. + let lhs_orderings = lhs.oeq_class(); + let rhs_orderings = &rhs.oeq_class.orderings; + assert_eq!(lhs_orderings.len(), rhs_orderings.len(), "{}", err_msg); + for rhs_ordering in rhs_orderings { + assert!(lhs_orderings.contains(rhs_ordering), "{}", err_msg); + } + } } diff --git a/datafusion/physical-expr/src/lib.rs b/datafusion/physical-expr/src/lib.rs index 4f83ae01959b..2e78119eba46 100644 --- a/datafusion/physical-expr/src/lib.rs +++ b/datafusion/physical-expr/src/lib.rs @@ -48,7 +48,7 @@ pub use analysis::{analyze, AnalysisContext, ExprBoundaries}; pub use datafusion_physical_expr_common::aggregate::{ AggregateExpr, AggregatePhysicalExpressions, }; -pub use equivalence::{ConstExpr, EquivalenceProperties}; +pub use equivalence::{calculate_union, ConstExpr, EquivalenceProperties}; pub use partitioning::{Distribution, Partitioning}; pub use physical_expr::{ physical_exprs_bag_equal, physical_exprs_contains, physical_exprs_equal, diff --git a/datafusion/physical-plan/src/common.rs b/datafusion/physical-plan/src/common.rs index bf9d14e73dd8..4b5eea6b760d 100644 --- a/datafusion/physical-plan/src/common.rs +++ b/datafusion/physical-plan/src/common.rs @@ -22,9 +22,9 @@ use std::fs::{metadata, File}; use std::path::{Path, PathBuf}; use std::sync::Arc; -use super::{ExecutionPlanProperties, SendableRecordBatchStream}; +use super::SendableRecordBatchStream; use crate::stream::RecordBatchReceiverStream; -use crate::{ColumnStatistics, ExecutionPlan, Statistics}; +use crate::{ColumnStatistics, Statistics}; use arrow::datatypes::Schema; use arrow::ipc::writer::{FileWriter, IpcWriteOptions}; @@ -33,8 +33,6 @@ use arrow_array::Array; use datafusion_common::stats::Precision; use datafusion_common::{plan_err, DataFusionError, Result}; use datafusion_execution::memory_pool::MemoryReservation; -use datafusion_physical_expr::expressions::{BinaryExpr, Column}; -use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; use futures::{StreamExt, TryStreamExt}; use parking_lot::Mutex; @@ -178,71 +176,6 @@ pub fn compute_record_batch_statistics( } } -/// Calculates the "meet" of given orderings. -/// The meet is the finest ordering that satisfied by all the given -/// orderings, see . -pub fn get_meet_of_orderings( - given: &[Arc], -) -> Option<&[PhysicalSortExpr]> { - given - .iter() - .map(|item| item.output_ordering()) - .collect::>>() - .and_then(get_meet_of_orderings_helper) -} - -fn get_meet_of_orderings_helper( - orderings: Vec<&[PhysicalSortExpr]>, -) -> Option<&[PhysicalSortExpr]> { - let mut idx = 0; - let first = orderings[0]; - loop { - for ordering in orderings.iter() { - if idx >= ordering.len() { - return Some(ordering); - } else { - let schema_aligned = check_expr_alignment( - ordering[idx].expr.as_ref(), - first[idx].expr.as_ref(), - ); - if !schema_aligned || (ordering[idx].options != first[idx].options) { - // In a union, the output schema is that of the first child (by convention). - // Therefore, generate the result from the first child's schema: - return if idx > 0 { Some(&first[..idx]) } else { None }; - } - } - } - idx += 1; - } - - fn check_expr_alignment(first: &dyn PhysicalExpr, second: &dyn PhysicalExpr) -> bool { - match ( - first.as_any().downcast_ref::(), - second.as_any().downcast_ref::(), - first.as_any().downcast_ref::(), - second.as_any().downcast_ref::(), - ) { - (Some(first_col), Some(second_col), _, _) => { - first_col.index() == second_col.index() - } - (_, _, Some(first_binary), Some(second_binary)) => { - if first_binary.op() == second_binary.op() { - check_expr_alignment( - first_binary.left().as_ref(), - second_binary.left().as_ref(), - ) && check_expr_alignment( - first_binary.right().as_ref(), - second_binary.right().as_ref(), - ) - } else { - false - } - } - (_, _, _, _) => false, - } - } -} - /// Write in Arrow IPC format. pub struct IPCWriter { /// path @@ -342,297 +275,12 @@ pub fn can_project( #[cfg(test)] mod tests { - use std::ops::Not; - use super::*; - use crate::memory::MemoryExec; - use crate::sorts::sort::SortExec; - use crate::union::UnionExec; - use arrow::compute::SortOptions; use arrow::{ array::{Float32Array, Float64Array, UInt64Array}, datatypes::{DataType, Field}, }; - use datafusion_expr::Operator; - use datafusion_physical_expr::expressions::col; - - #[test] - fn get_meet_of_orderings_helper_common_prefix_test() -> Result<()> { - let input1: Vec = vec![ - PhysicalSortExpr { - expr: Arc::new(Column::new("a", 0)), - options: SortOptions::default(), - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("b", 1)), - options: SortOptions::default(), - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("c", 2)), - options: SortOptions::default(), - }, - ]; - - let input2: Vec = vec![ - PhysicalSortExpr { - expr: Arc::new(Column::new("x", 0)), - options: SortOptions::default(), - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("y", 1)), - options: SortOptions::default(), - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("z", 2)), - options: SortOptions::default(), - }, - ]; - - let input3: Vec = vec![ - PhysicalSortExpr { - expr: Arc::new(Column::new("d", 0)), - options: SortOptions::default(), - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("e", 1)), - options: SortOptions::default(), - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("f", 2)), - options: SortOptions::default(), - }, - ]; - - let input4: Vec = vec![ - PhysicalSortExpr { - expr: Arc::new(Column::new("g", 0)), - options: SortOptions::default(), - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("h", 1)), - options: SortOptions::default(), - }, - PhysicalSortExpr { - // Note that index of this column is not 2. Hence this 3rd entry shouldn't be - // in the output ordering. - expr: Arc::new(Column::new("i", 3)), - options: SortOptions::default(), - }, - ]; - - let expected = vec![ - PhysicalSortExpr { - expr: Arc::new(Column::new("a", 0)), - options: SortOptions::default(), - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("b", 1)), - options: SortOptions::default(), - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("c", 2)), - options: SortOptions::default(), - }, - ]; - let result = get_meet_of_orderings_helper(vec![&input1, &input2, &input3]); - assert_eq!(result.unwrap(), expected); - - let expected = vec![ - PhysicalSortExpr { - expr: Arc::new(Column::new("a", 0)), - options: SortOptions::default(), - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("b", 1)), - options: SortOptions::default(), - }, - ]; - let result = get_meet_of_orderings_helper(vec![&input1, &input2, &input4]); - assert_eq!(result.unwrap(), expected); - Ok(()) - } - - #[test] - fn get_meet_of_orderings_helper_subset_test() -> Result<()> { - let input1: Vec = vec![ - PhysicalSortExpr { - expr: Arc::new(Column::new("a", 0)), - options: SortOptions::default(), - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("b", 1)), - options: SortOptions::default(), - }, - ]; - - let input2: Vec = vec![ - PhysicalSortExpr { - expr: Arc::new(Column::new("c", 0)), - options: SortOptions::default(), - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("d", 1)), - options: SortOptions::default(), - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("e", 2)), - options: SortOptions::default(), - }, - ]; - - let input3: Vec = vec![ - PhysicalSortExpr { - expr: Arc::new(Column::new("f", 0)), - options: SortOptions::default(), - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("g", 1)), - options: SortOptions::default(), - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("h", 2)), - options: SortOptions::default(), - }, - ]; - - let result = get_meet_of_orderings_helper(vec![&input1, &input2, &input3]); - assert_eq!(result.unwrap(), input1); - Ok(()) - } - - #[test] - fn get_meet_of_orderings_helper_no_overlap_test() -> Result<()> { - let input1: Vec = vec![ - PhysicalSortExpr { - expr: Arc::new(Column::new("a", 0)), - // Since ordering is conflicting with other inputs - // output ordering should be empty - options: SortOptions::default().not(), - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("b", 1)), - options: SortOptions::default(), - }, - ]; - - let input2: Vec = vec![ - PhysicalSortExpr { - expr: Arc::new(Column::new("x", 0)), - options: SortOptions::default(), - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("a", 1)), - options: SortOptions::default(), - }, - ]; - - let input3: Vec = vec![ - PhysicalSortExpr { - expr: Arc::new(Column::new("a", 2)), - options: SortOptions::default(), - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("y", 1)), - options: SortOptions::default(), - }, - ]; - - let result = get_meet_of_orderings_helper(vec![&input1, &input2]); - assert!(result.is_none()); - - let result = get_meet_of_orderings_helper(vec![&input2, &input3]); - assert!(result.is_none()); - - let result = get_meet_of_orderings_helper(vec![&input1, &input3]); - assert!(result.is_none()); - Ok(()) - } - - #[test] - fn get_meet_of_orderings_helper_binary_exprs() -> Result<()> { - let input1: Vec = vec![ - PhysicalSortExpr { - expr: Arc::new(BinaryExpr::new( - Arc::new(Column::new("a", 0)), - Operator::Plus, - Arc::new(Column::new("b", 1)), - )), - options: SortOptions::default(), - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("c", 2)), - options: SortOptions::default(), - }, - ]; - - let input2: Vec = vec![ - PhysicalSortExpr { - expr: Arc::new(BinaryExpr::new( - Arc::new(Column::new("x", 0)), - Operator::Plus, - Arc::new(Column::new("y", 1)), - )), - options: SortOptions::default(), - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("z", 2)), - options: SortOptions::default(), - }, - ]; - - // erroneous input - let input3: Vec = vec![ - PhysicalSortExpr { - expr: Arc::new(BinaryExpr::new( - Arc::new(Column::new("a", 1)), - Operator::Plus, - Arc::new(Column::new("b", 0)), - )), - options: SortOptions::default(), - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("c", 2)), - options: SortOptions::default(), - }, - ]; - - let result = get_meet_of_orderings_helper(vec![&input1, &input2]); - assert_eq!(input1, result.unwrap()); - - let result = get_meet_of_orderings_helper(vec![&input2, &input3]); - assert!(result.is_none()); - - let result = get_meet_of_orderings_helper(vec![&input1, &input3]); - assert!(result.is_none()); - Ok(()) - } - - #[test] - fn test_meet_of_orderings() -> Result<()> { - let schema = Arc::new(Schema::new(vec![ - Field::new("f32", DataType::Float32, false), - Field::new("f64", DataType::Float64, false), - ])); - let sort_expr = vec![PhysicalSortExpr { - expr: col("f32", &schema).unwrap(), - options: SortOptions::default(), - }]; - let memory_exec = - Arc::new(MemoryExec::try_new(&[], Arc::clone(&schema), None)?) as _; - let sort_exec = Arc::new(SortExec::new(sort_expr.clone(), memory_exec)) - as Arc; - let memory_exec2 = Arc::new(MemoryExec::try_new(&[], schema, None)?) as _; - // memory_exec2 doesn't have output ordering - let union_exec = UnionExec::new(vec![Arc::clone(&sort_exec), memory_exec2]); - let res = get_meet_of_orderings(union_exec.inputs()); - assert!(res.is_none()); - - let union_exec = UnionExec::new(vec![Arc::clone(&sort_exec), sort_exec]); - let res = get_meet_of_orderings(union_exec.inputs()); - assert_eq!(res, Some(&sort_expr[..])); - Ok(()) - } #[test] fn test_compute_record_batch_statistics_empty() -> Result<()> { diff --git a/datafusion/physical-plan/src/union.rs b/datafusion/physical-plan/src/union.rs index 24c80048ab4a..9321fdb2cadf 100644 --- a/datafusion/physical-plan/src/union.rs +++ b/datafusion/physical-plan/src/union.rs @@ -41,7 +41,7 @@ use arrow::record_batch::RecordBatch; use datafusion_common::stats::Precision; use datafusion_common::{exec_err, internal_err, Result}; use datafusion_execution::TaskContext; -use datafusion_physical_expr::{ConstExpr, EquivalenceProperties}; +use datafusion_physical_expr::{calculate_union, EquivalenceProperties}; use futures::Stream; use itertools::Itertools; @@ -99,7 +99,12 @@ impl UnionExec { /// Create a new UnionExec pub fn new(inputs: Vec>) -> Self { let schema = union_schema(&inputs); - let cache = Self::compute_properties(&inputs, schema); + // The schema of the inputs and the union schema is consistent when: + // - They have the same number of fields, and + // - Their fields have same types at the same indices. + // Here, we know that schemas are consistent and the call below can + // not return an error. + let cache = Self::compute_properties(&inputs, schema).unwrap(); UnionExec { inputs, metrics: ExecutionPlanMetricsSet::new(), @@ -116,13 +121,13 @@ impl UnionExec { fn compute_properties( inputs: &[Arc], schema: SchemaRef, - ) -> PlanProperties { + ) -> Result { // Calculate equivalence properties: - let children_eqs = inputs + let children_eqps = inputs .iter() - .map(|child| child.equivalence_properties()) + .map(|child| child.equivalence_properties().clone()) .collect::>(); - let eq_properties = calculate_union_eq_properties(&children_eqs, schema); + let eq_properties = calculate_union(children_eqps, schema)?; // Calculate output partitioning; i.e. sum output partitions of the inputs. let num_partitions = inputs @@ -134,71 +139,13 @@ impl UnionExec { // Determine execution mode: let mode = execution_mode_from_children(inputs.iter()); - PlanProperties::new(eq_properties, output_partitioning, mode) + Ok(PlanProperties::new( + eq_properties, + output_partitioning, + mode, + )) } } -/// Calculate `EquivalenceProperties` for `UnionExec` from the `EquivalenceProperties` -/// of its children. -fn calculate_union_eq_properties( - children_eqs: &[&EquivalenceProperties], - schema: SchemaRef, -) -> EquivalenceProperties { - // Calculate equivalence properties: - // TODO: In some cases, we should be able to preserve some equivalence - // classes and constants. Add support for such cases. - let mut eq_properties = EquivalenceProperties::new(schema); - // Use the ordering equivalence class of the first child as the seed: - let mut meets = children_eqs[0] - .oeq_class() - .iter() - .map(|item| item.to_vec()) - .collect::>(); - // Iterate over all the children: - for child_eqs in &children_eqs[1..] { - // Compute meet orderings of the current meets and the new ordering - // equivalence class. - let mut idx = 0; - while idx < meets.len() { - // Find all the meets of `current_meet` with this child's orderings: - let valid_meets = child_eqs.oeq_class().iter().filter_map(|ordering| { - child_eqs.get_meet_ordering(ordering, &meets[idx]) - }); - // Use the longest of these meets as others are redundant: - if let Some(next_meet) = valid_meets.max_by_key(|m| m.len()) { - meets[idx] = next_meet; - idx += 1; - } else { - meets.swap_remove(idx); - } - } - } - // We know have all the valid orderings after union, remove redundant - // entries (implicitly) and return: - eq_properties.add_new_orderings(meets); - - let mut meet_constants = children_eqs[0].constants().to_vec(); - // Iterate over all the children: - for child_eqs in &children_eqs[1..] { - let constants = child_eqs.constants(); - meet_constants = meet_constants - .into_iter() - .filter_map(|meet_constant| { - for const_expr in constants { - if const_expr.expr().eq(meet_constant.expr()) { - // TODO: Check whether constant expressions evaluates the same value or not for each partition - let across_partitions = false; - return Some( - ConstExpr::from(meet_constant.owned_expr()) - .with_across_partitions(across_partitions), - ); - } - } - None - }) - .collect::>(); - } - eq_properties.add_constants(meet_constants) -} impl DisplayAs for UnionExec { fn fmt_as( @@ -639,8 +586,8 @@ mod tests { use arrow_schema::{DataType, SortOptions}; use datafusion_common::ScalarValue; - use datafusion_physical_expr::expressions::col; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; + use datafusion_physical_expr_common::expressions::column::col; // Generate a schema which consists of 7 columns (a, b, c, d, e, f, g) fn create_test_schema() -> Result { @@ -856,23 +803,31 @@ mod tests { .with_sort_information(second_orderings), ); + let mut union_expected_eq = EquivalenceProperties::new(Arc::clone(&schema)); + union_expected_eq.add_new_orderings(union_expected_orderings); + let union = UnionExec::new(vec![child1, child2]); let union_eq_properties = union.properties().equivalence_properties(); - let union_actual_orderings = union_eq_properties.oeq_class(); let err_msg = format!( "Error in test id: {:?}, test case: {:?}", test_idx, test_cases[test_idx] ); - assert_eq!( - union_actual_orderings.len(), - union_expected_orderings.len(), - "{}", - err_msg - ); - for expected in &union_expected_orderings { - assert!(union_actual_orderings.contains(expected), "{}", err_msg); - } + assert_eq_properties_same(union_eq_properties, &union_expected_eq, err_msg); } Ok(()) } + + fn assert_eq_properties_same( + lhs: &EquivalenceProperties, + rhs: &EquivalenceProperties, + err_msg: String, + ) { + // Check whether orderings are same. + let lhs_orderings = lhs.oeq_class(); + let rhs_orderings = &rhs.oeq_class.orderings; + assert_eq!(lhs_orderings.len(), rhs_orderings.len(), "{}", err_msg); + for rhs_ordering in rhs_orderings { + assert!(lhs_orderings.contains(rhs_ordering), "{}", err_msg); + } + } } diff --git a/datafusion/sqllogictest/test_files/order.slt b/datafusion/sqllogictest/test_files/order.slt index 51de40fb1972..1aeaf9b76d48 100644 --- a/datafusion/sqllogictest/test_files/order.slt +++ b/datafusion/sqllogictest/test_files/order.slt @@ -1132,3 +1132,10 @@ physical_plan 02)--ProjectionExec: expr=[CAST(inc_col@0 > desc_col@1 AS Int32) as c] 03)----RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 04)------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_1.csv]]}, projection=[inc_col, desc_col], output_orderings=[[inc_col@0 ASC NULLS LAST], [desc_col@1 DESC]], has_header=true + +# Union a query with the actual data and one with a constant +query I +SELECT (SELECT c from ordered_table ORDER BY c LIMIT 1) UNION ALL (SELECT 23 as c from ordered_table ORDER BY c LIMIT 1) ORDER BY c; +---- +0 +23 From d941dc3f0196e85520c6aa923edbcefcc8c7b265 Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Tue, 23 Jul 2024 08:02:11 +0800 Subject: [PATCH 35/37] Migrate `OrderSensitiveArrayAgg` to be a user defined aggregate (#11564) * first draft Signed-off-by: jayzhan211 * rm old agg Signed-off-by: jayzhan211 * replace udaf with interal function - create aggregate with dfschema Signed-off-by: jayzhan211 * rm test Signed-off-by: jayzhan211 * cleanup Signed-off-by: jayzhan211 * rm useless Signed-off-by: jayzhan211 * fix test Signed-off-by: jayzhan211 * rename Signed-off-by: jayzhan211 --------- Signed-off-by: jayzhan211 --- .../aggregate_statistics.rs | 1 + .../combine_partial_final_agg.rs | 2 + datafusion/core/src/physical_planner.rs | 35 +- .../core/tests/fuzz_cases/aggregate_fuzz.rs | 1 + datafusion/expr/src/function.rs | 8 +- .../src/approx_percentile_cont.rs | 9 +- .../functions-aggregate/src/array_agg.rs | 423 +++++++++++++- .../functions-aggregate/src/first_last.rs | 11 +- .../functions-aggregate/src/nth_value.rs | 7 +- datafusion/functions-aggregate/src/stddev.rs | 6 + .../physical-expr-common/src/aggregate/mod.rs | 95 +++- .../physical-expr-common/src/sort_expr.rs | 13 +- datafusion/physical-expr-common/src/utils.rs | 23 +- .../src/aggregate/array_agg_ordered.rs | 520 ------------------ .../physical-expr/src/aggregate/build_in.rs | 24 +- datafusion/physical-expr/src/aggregate/mod.rs | 5 +- .../physical-expr/src/expressions/mod.rs | 1 - .../physical-plan/src/aggregates/mod.rs | 125 ++++- datafusion/physical-plan/src/lib.rs | 2 +- datafusion/physical-plan/src/windows/mod.rs | 1 + datafusion/proto/src/physical_plan/mod.rs | 2 +- .../proto/src/physical_plan/to_proto.rs | 10 +- .../tests/cases/roundtrip_physical_plan.rs | 9 + 23 files changed, 681 insertions(+), 652 deletions(-) delete mode 100644 datafusion/physical-expr/src/aggregate/array_agg_ordered.rs diff --git a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs index e412d814239d..e7580d3e33ef 100644 --- a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs +++ b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs @@ -429,6 +429,7 @@ pub(crate) mod tests { self.column_name(), false, false, + false, ) .unwrap() } diff --git a/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs b/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs index 29148a594f31..ddb7d36fb595 100644 --- a/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs +++ b/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs @@ -288,6 +288,7 @@ mod tests { name, false, false, + false, ) .unwrap() } @@ -378,6 +379,7 @@ mod tests { "Sum(b)", false, false, + false, )?]; let groups: Vec<(Arc, String)> = vec![(col("c", &schema)?, "c".to_string())]; diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 97533cd5276a..329d343f13fc 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -1839,34 +1839,7 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( .unwrap_or(sqlparser::ast::NullTreatment::RespectNulls) == NullTreatment::IgnoreNulls; - // TODO: Remove this after array_agg are all udafs let (agg_expr, filter, order_by) = match func_def { - AggregateFunctionDefinition::UDF(udf) - if udf.name() == "ARRAY_AGG" && order_by.is_some() => - { - // not yet support UDAF, fallback to builtin - let physical_sort_exprs = match order_by { - Some(exprs) => Some(create_physical_sort_exprs( - exprs, - logical_input_schema, - execution_props, - )?), - None => None, - }; - let ordering_reqs: Vec = - physical_sort_exprs.clone().unwrap_or(vec![]); - let fun = aggregates::AggregateFunction::ArrayAgg; - let agg_expr = aggregates::create_aggregate_expr( - &fun, - *distinct, - &physical_args, - &ordering_reqs, - physical_input_schema, - name, - ignore_nulls, - )?; - (agg_expr, filter, physical_sort_exprs) - } AggregateFunctionDefinition::BuiltIn(fun) => { let physical_sort_exprs = match order_by { Some(exprs) => Some(create_physical_sort_exprs( @@ -1899,19 +1872,23 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( )?), None => None, }; + let ordering_reqs: Vec = physical_sort_exprs.clone().unwrap_or(vec![]); - let agg_expr = udaf::create_aggregate_expr( + + let agg_expr = udaf::create_aggregate_expr_with_dfschema( fun, &physical_args, args, &sort_exprs, &ordering_reqs, - physical_input_schema, + logical_input_schema, name, ignore_nulls, *distinct, + false, )?; + (agg_expr, filter, physical_sort_exprs) } }; diff --git a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs index a04f4f349122..736560da97db 100644 --- a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs @@ -113,6 +113,7 @@ async fn run_aggregate_test(input1: Vec, group_by_columns: Vec<&str "sum1", false, false, + false, ) .unwrap()]; let expr = group_by_columns diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs index 73ab51494de6..d722e55de487 100644 --- a/datafusion/expr/src/function.rs +++ b/datafusion/expr/src/function.rs @@ -20,7 +20,7 @@ use crate::ColumnarValue; use crate::{Accumulator, Expr, PartitionEvaluator}; use arrow::datatypes::{DataType, Field, Schema}; -use datafusion_common::Result; +use datafusion_common::{DFSchema, Result}; use std::sync::Arc; #[derive(Debug, Clone, Copy)] @@ -57,6 +57,9 @@ pub struct AccumulatorArgs<'a> { /// The schema of the input arguments pub schema: &'a Schema, + /// The schema of the input arguments + pub dfschema: &'a DFSchema, + /// Whether to ignore nulls. /// /// SQL allows the user to specify `IGNORE NULLS`, for example: @@ -78,6 +81,9 @@ pub struct AccumulatorArgs<'a> { /// If no `ORDER BY` is specified, `sort_exprs`` will be empty. pub sort_exprs: &'a [Expr], + /// Whether the aggregation is running in reverse order + pub is_reversed: bool, + /// The name of the aggregate expression pub name: &'a str, diff --git a/datafusion/functions-aggregate/src/approx_percentile_cont.rs b/datafusion/functions-aggregate/src/approx_percentile_cont.rs index bbe7d21e2486..dfb94a84cbec 100644 --- a/datafusion/functions-aggregate/src/approx_percentile_cont.rs +++ b/datafusion/functions-aggregate/src/approx_percentile_cont.rs @@ -30,7 +30,8 @@ use arrow::{ use arrow_schema::{Field, Schema}; use datafusion_common::{ - downcast_value, internal_err, not_impl_err, plan_err, DataFusionError, ScalarValue, + downcast_value, internal_err, not_impl_err, plan_err, DFSchema, DataFusionError, + ScalarValue, }; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::type_coercion::aggregates::{INTEGERS, NUMERICS}; @@ -42,7 +43,7 @@ use datafusion_expr::{ use datafusion_physical_expr_common::aggregate::tdigest::{ TDigest, TryIntoF64, DEFAULT_MAX_SIZE, }; -use datafusion_physical_expr_common::utils::limited_convert_logical_expr_to_physical_expr; +use datafusion_physical_expr_common::utils::limited_convert_logical_expr_to_physical_expr_with_dfschema; make_udaf_expr_and_func!( ApproxPercentileCont, @@ -135,7 +136,9 @@ impl ApproxPercentileCont { fn get_lit_value(expr: &Expr) -> datafusion_common::Result { let empty_schema = Arc::new(Schema::empty()); let empty_batch = RecordBatch::new_empty(Arc::clone(&empty_schema)); - let expr = limited_convert_logical_expr_to_physical_expr(expr, &empty_schema)?; + let dfschema = DFSchema::empty(); + let expr = + limited_convert_logical_expr_to_physical_expr_with_dfschema(expr, &dfschema)?; let result = expr.evaluate(&empty_batch)?; match result { ColumnarValue::Array(_) => Err(DataFusionError::Internal(format!( diff --git a/datafusion/functions-aggregate/src/array_agg.rs b/datafusion/functions-aggregate/src/array_agg.rs index 9ad453d7a4b2..777a242aa27e 100644 --- a/datafusion/functions-aggregate/src/array_agg.rs +++ b/datafusion/functions-aggregate/src/array_agg.rs @@ -17,19 +17,25 @@ //! `ARRAY_AGG` aggregate implementation: [`ArrayAgg`] -use arrow::array::{Array, ArrayRef, AsArray}; +use arrow::array::{new_empty_array, Array, ArrayRef, AsArray, StructArray}; use arrow::datatypes::DataType; -use arrow_schema::Field; +use arrow_schema::{Field, Fields}; use datafusion_common::cast::as_list_array; -use datafusion_common::utils::array_into_list_array_nullable; -use datafusion_common::ScalarValue; +use datafusion_common::utils::{array_into_list_array_nullable, get_row_at_idx}; +use datafusion_common::{exec_err, ScalarValue}; use datafusion_common::{internal_err, Result}; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::utils::format_state_name; use datafusion_expr::AggregateUDFImpl; use datafusion_expr::{Accumulator, Signature, Volatility}; -use std::collections::HashSet; +use datafusion_physical_expr_common::aggregate::merge_arrays::merge_ordered_arrays; +use datafusion_physical_expr_common::aggregate::utils::ordering_fields; +use datafusion_physical_expr_common::sort_expr::{ + limited_convert_logical_sort_exprs_to_physical_with_dfschema, LexOrdering, + PhysicalSortExpr, +}; +use std::collections::{HashSet, VecDeque}; use std::sync::Arc; make_udaf_expr_and_func!( @@ -91,11 +97,24 @@ impl AggregateUDFImpl for ArrayAgg { )]); } - Ok(vec![Field::new_list( + let mut fields = vec![Field::new_list( format_state_name(args.name, "array_agg"), Field::new("item", args.input_type.clone(), true), true, - )]) + )]; + + if args.ordering_fields.is_empty() { + return Ok(fields); + } + + let orderings = args.ordering_fields.to_vec(); + fields.push(Field::new_list( + format_state_name(args.name, "array_agg_orderings"), + Field::new("item", DataType::Struct(Fields::from(orderings)), true), + false, + )); + + Ok(fields) } fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { @@ -105,7 +124,31 @@ impl AggregateUDFImpl for ArrayAgg { )?)); } - Ok(Box::new(ArrayAggAccumulator::try_new(acc_args.input_type)?)) + if acc_args.sort_exprs.is_empty() { + return Ok(Box::new(ArrayAggAccumulator::try_new(acc_args.input_type)?)); + } + + let ordering_req = limited_convert_logical_sort_exprs_to_physical_with_dfschema( + acc_args.sort_exprs, + acc_args.dfschema, + )?; + + let ordering_dtypes = ordering_req + .iter() + .map(|e| e.expr.data_type(acc_args.schema)) + .collect::>>()?; + + OrderSensitiveArrayAggAccumulator::try_new( + acc_args.input_type, + &ordering_dtypes, + ordering_req, + acc_args.is_reversed, + ) + .map(|acc| Box::new(acc) as _) + } + + fn reverse_expr(&self) -> datafusion_expr::ReversedUDAF { + datafusion_expr::ReversedUDAF::Reversed(array_agg_udaf()) } } @@ -259,3 +302,367 @@ impl Accumulator for DistinctArrayAggAccumulator { - std::mem::size_of_val(&self.datatype) } } + +/// Accumulator for a `ARRAY_AGG(... ORDER BY ..., ...)` aggregation. In a multi +/// partition setting, partial aggregations are computed for every partition, +/// and then their results are merged. +#[derive(Debug)] +pub(crate) struct OrderSensitiveArrayAggAccumulator { + /// Stores entries in the `ARRAY_AGG` result. + values: Vec, + /// Stores values of ordering requirement expressions corresponding to each + /// entry in `values`. This information is used when merging results from + /// different partitions. For detailed information how merging is done, see + /// [`merge_ordered_arrays`]. + ordering_values: Vec>, + /// Stores datatypes of expressions inside values and ordering requirement + /// expressions. + datatypes: Vec, + /// Stores the ordering requirement of the `Accumulator`. + ordering_req: LexOrdering, + /// Whether the aggregation is running in reverse. + reverse: bool, +} + +impl OrderSensitiveArrayAggAccumulator { + /// Create a new order-sensitive ARRAY_AGG accumulator based on the given + /// item data type. + pub fn try_new( + datatype: &DataType, + ordering_dtypes: &[DataType], + ordering_req: LexOrdering, + reverse: bool, + ) -> Result { + let mut datatypes = vec![datatype.clone()]; + datatypes.extend(ordering_dtypes.iter().cloned()); + Ok(Self { + values: vec![], + ordering_values: vec![], + datatypes, + ordering_req, + reverse, + }) + } +} + +impl Accumulator for OrderSensitiveArrayAggAccumulator { + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + if values.is_empty() { + return Ok(()); + } + + let n_row = values[0].len(); + for index in 0..n_row { + let row = get_row_at_idx(values, index)?; + self.values.push(row[0].clone()); + self.ordering_values.push(row[1..].to_vec()); + } + + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + if states.is_empty() { + return Ok(()); + } + + // First entry in the state is the aggregation result. Second entry + // stores values received for ordering requirement columns for each + // aggregation value inside `ARRAY_AGG` list. For each `StructArray` + // inside `ARRAY_AGG` list, we will receive an `Array` that stores values + // received from its ordering requirement expression. (This information + // is necessary for during merging). + let [array_agg_values, agg_orderings, ..] = &states else { + return exec_err!("State should have two elements"); + }; + let Some(agg_orderings) = agg_orderings.as_list_opt::() else { + return exec_err!("Expects to receive a list array"); + }; + + // Stores ARRAY_AGG results coming from each partition + let mut partition_values = vec![]; + // Stores ordering requirement expression results coming from each partition + let mut partition_ordering_values = vec![]; + + // Existing values should be merged also. + partition_values.push(self.values.clone().into()); + partition_ordering_values.push(self.ordering_values.clone().into()); + + // Convert array to Scalars to sort them easily. Convert back to array at evaluation. + let array_agg_res = ScalarValue::convert_array_to_scalar_vec(array_agg_values)?; + for v in array_agg_res.into_iter() { + partition_values.push(v.into()); + } + + let orderings = ScalarValue::convert_array_to_scalar_vec(agg_orderings)?; + + for partition_ordering_rows in orderings.into_iter() { + // Extract value from struct to ordering_rows for each group/partition + let ordering_value = partition_ordering_rows.into_iter().map(|ordering_row| { + if let ScalarValue::Struct(s) = ordering_row { + let mut ordering_columns_per_row = vec![]; + + for column in s.columns() { + let sv = ScalarValue::try_from_array(column, 0)?; + ordering_columns_per_row.push(sv); + } + + Ok(ordering_columns_per_row) + } else { + exec_err!( + "Expects to receive ScalarValue::Struct(Arc) but got:{:?}", + ordering_row.data_type() + ) + } + }).collect::>>()?; + + partition_ordering_values.push(ordering_value); + } + + let sort_options = self + .ordering_req + .iter() + .map(|sort_expr| sort_expr.options) + .collect::>(); + + (self.values, self.ordering_values) = merge_ordered_arrays( + &mut partition_values, + &mut partition_ordering_values, + &sort_options, + )?; + + Ok(()) + } + + fn state(&mut self) -> Result> { + let mut result = vec![self.evaluate()?]; + result.push(self.evaluate_orderings()?); + + Ok(result) + } + + fn evaluate(&mut self) -> Result { + if self.values.is_empty() { + return Ok(ScalarValue::new_null_list( + self.datatypes[0].clone(), + true, + 1, + )); + } + + let values = self.values.clone(); + let array = if self.reverse { + ScalarValue::new_list_from_iter( + values.into_iter().rev(), + &self.datatypes[0], + true, + ) + } else { + ScalarValue::new_list_from_iter(values.into_iter(), &self.datatypes[0], true) + }; + Ok(ScalarValue::List(array)) + } + + fn size(&self) -> usize { + let mut total = std::mem::size_of_val(self) + + ScalarValue::size_of_vec(&self.values) + - std::mem::size_of_val(&self.values); + + // Add size of the `self.ordering_values` + total += + std::mem::size_of::>() * self.ordering_values.capacity(); + for row in &self.ordering_values { + total += ScalarValue::size_of_vec(row) - std::mem::size_of_val(row); + } + + // Add size of the `self.datatypes` + total += std::mem::size_of::() * self.datatypes.capacity(); + for dtype in &self.datatypes { + total += dtype.size() - std::mem::size_of_val(dtype); + } + + // Add size of the `self.ordering_req` + total += std::mem::size_of::() * self.ordering_req.capacity(); + // TODO: Calculate size of each `PhysicalSortExpr` more accurately. + total + } +} + +impl OrderSensitiveArrayAggAccumulator { + fn evaluate_orderings(&self) -> Result { + let fields = ordering_fields(&self.ordering_req, &self.datatypes[1..]); + let num_columns = fields.len(); + let struct_field = Fields::from(fields.clone()); + + let mut column_wise_ordering_values = vec![]; + for i in 0..num_columns { + let column_values = self + .ordering_values + .iter() + .map(|x| x[i].clone()) + .collect::>(); + let array = if column_values.is_empty() { + new_empty_array(fields[i].data_type()) + } else { + ScalarValue::iter_to_array(column_values.into_iter())? + }; + column_wise_ordering_values.push(array); + } + + let ordering_array = StructArray::try_new( + struct_field.clone(), + column_wise_ordering_values, + None, + )?; + Ok(ScalarValue::List(Arc::new(array_into_list_array_nullable( + Arc::new(ordering_array), + )))) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use std::collections::VecDeque; + use std::sync::Arc; + + use arrow::array::Int64Array; + use arrow_schema::SortOptions; + + use datafusion_common::utils::get_row_at_idx; + use datafusion_common::{Result, ScalarValue}; + + #[test] + fn test_merge_asc() -> Result<()> { + let lhs_arrays: Vec = vec![ + Arc::new(Int64Array::from(vec![0, 0, 1, 1, 2])), + Arc::new(Int64Array::from(vec![0, 1, 2, 3, 4])), + ]; + let n_row = lhs_arrays[0].len(); + let lhs_orderings = (0..n_row) + .map(|idx| get_row_at_idx(&lhs_arrays, idx)) + .collect::>>()?; + + let rhs_arrays: Vec = vec![ + Arc::new(Int64Array::from(vec![0, 0, 1, 1, 2])), + Arc::new(Int64Array::from(vec![0, 1, 2, 3, 4])), + ]; + let n_row = rhs_arrays[0].len(); + let rhs_orderings = (0..n_row) + .map(|idx| get_row_at_idx(&rhs_arrays, idx)) + .collect::>>()?; + let sort_options = vec![ + SortOptions { + descending: false, + nulls_first: false, + }, + SortOptions { + descending: false, + nulls_first: false, + }, + ]; + + let lhs_vals_arr = Arc::new(Int64Array::from(vec![0, 1, 2, 3, 4])) as ArrayRef; + let lhs_vals = (0..lhs_vals_arr.len()) + .map(|idx| ScalarValue::try_from_array(&lhs_vals_arr, idx)) + .collect::>>()?; + + let rhs_vals_arr = Arc::new(Int64Array::from(vec![0, 1, 2, 3, 4])) as ArrayRef; + let rhs_vals = (0..rhs_vals_arr.len()) + .map(|idx| ScalarValue::try_from_array(&rhs_vals_arr, idx)) + .collect::>>()?; + let expected = + Arc::new(Int64Array::from(vec![0, 0, 1, 1, 2, 2, 3, 3, 4, 4])) as ArrayRef; + let expected_ts = vec![ + Arc::new(Int64Array::from(vec![0, 0, 0, 0, 1, 1, 1, 1, 2, 2])) as ArrayRef, + Arc::new(Int64Array::from(vec![0, 0, 1, 1, 2, 2, 3, 3, 4, 4])) as ArrayRef, + ]; + + let (merged_vals, merged_ts) = merge_ordered_arrays( + &mut [lhs_vals, rhs_vals], + &mut [lhs_orderings, rhs_orderings], + &sort_options, + )?; + let merged_vals = ScalarValue::iter_to_array(merged_vals.into_iter())?; + let merged_ts = (0..merged_ts[0].len()) + .map(|col_idx| { + ScalarValue::iter_to_array( + (0..merged_ts.len()) + .map(|row_idx| merged_ts[row_idx][col_idx].clone()), + ) + }) + .collect::>>()?; + + assert_eq!(&merged_vals, &expected); + assert_eq!(&merged_ts, &expected_ts); + + Ok(()) + } + + #[test] + fn test_merge_desc() -> Result<()> { + let lhs_arrays: Vec = vec![ + Arc::new(Int64Array::from(vec![2, 1, 1, 0, 0])), + Arc::new(Int64Array::from(vec![4, 3, 2, 1, 0])), + ]; + let n_row = lhs_arrays[0].len(); + let lhs_orderings = (0..n_row) + .map(|idx| get_row_at_idx(&lhs_arrays, idx)) + .collect::>>()?; + + let rhs_arrays: Vec = vec![ + Arc::new(Int64Array::from(vec![2, 1, 1, 0, 0])), + Arc::new(Int64Array::from(vec![4, 3, 2, 1, 0])), + ]; + let n_row = rhs_arrays[0].len(); + let rhs_orderings = (0..n_row) + .map(|idx| get_row_at_idx(&rhs_arrays, idx)) + .collect::>>()?; + let sort_options = vec![ + SortOptions { + descending: true, + nulls_first: false, + }, + SortOptions { + descending: true, + nulls_first: false, + }, + ]; + + // Values (which will be merged) doesn't have to be ordered. + let lhs_vals_arr = Arc::new(Int64Array::from(vec![0, 1, 2, 1, 2])) as ArrayRef; + let lhs_vals = (0..lhs_vals_arr.len()) + .map(|idx| ScalarValue::try_from_array(&lhs_vals_arr, idx)) + .collect::>>()?; + + let rhs_vals_arr = Arc::new(Int64Array::from(vec![0, 1, 2, 1, 2])) as ArrayRef; + let rhs_vals = (0..rhs_vals_arr.len()) + .map(|idx| ScalarValue::try_from_array(&rhs_vals_arr, idx)) + .collect::>>()?; + let expected = + Arc::new(Int64Array::from(vec![0, 0, 1, 1, 2, 2, 1, 1, 2, 2])) as ArrayRef; + let expected_ts = vec![ + Arc::new(Int64Array::from(vec![2, 2, 1, 1, 1, 1, 0, 0, 0, 0])) as ArrayRef, + Arc::new(Int64Array::from(vec![4, 4, 3, 3, 2, 2, 1, 1, 0, 0])) as ArrayRef, + ]; + let (merged_vals, merged_ts) = merge_ordered_arrays( + &mut [lhs_vals, rhs_vals], + &mut [lhs_orderings, rhs_orderings], + &sort_options, + )?; + let merged_vals = ScalarValue::iter_to_array(merged_vals.into_iter())?; + let merged_ts = (0..merged_ts[0].len()) + .map(|col_idx| { + ScalarValue::iter_to_array( + (0..merged_ts.len()) + .map(|row_idx| merged_ts[row_idx][col_idx].clone()), + ) + }) + .collect::>>()?; + + assert_eq!(&merged_vals, &expected); + assert_eq!(&merged_ts, &expected_ts); + Ok(()) + } +} diff --git a/datafusion/functions-aggregate/src/first_last.rs b/datafusion/functions-aggregate/src/first_last.rs index 0e619bacef82..ba11f7e91e07 100644 --- a/datafusion/functions-aggregate/src/first_last.rs +++ b/datafusion/functions-aggregate/src/first_last.rs @@ -36,7 +36,8 @@ use datafusion_expr::{ }; use datafusion_physical_expr_common::aggregate::utils::get_sort_options; use datafusion_physical_expr_common::sort_expr::{ - limited_convert_logical_sort_exprs_to_physical, LexOrdering, PhysicalSortExpr, + limited_convert_logical_sort_exprs_to_physical_with_dfschema, LexOrdering, + PhysicalSortExpr, }; create_func!(FirstValue, first_value_udaf); @@ -116,9 +117,9 @@ impl AggregateUDFImpl for FirstValue { } fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { - let ordering_req = limited_convert_logical_sort_exprs_to_physical( + let ordering_req = limited_convert_logical_sort_exprs_to_physical_with_dfschema( acc_args.sort_exprs, - acc_args.schema, + acc_args.dfschema, )?; let ordering_dtypes = ordering_req @@ -415,9 +416,9 @@ impl AggregateUDFImpl for LastValue { } fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { - let ordering_req = limited_convert_logical_sort_exprs_to_physical( + let ordering_req = limited_convert_logical_sort_exprs_to_physical_with_dfschema( acc_args.sort_exprs, - acc_args.schema, + acc_args.dfschema, )?; let ordering_dtypes = ordering_req diff --git a/datafusion/functions-aggregate/src/nth_value.rs b/datafusion/functions-aggregate/src/nth_value.rs index 6719c673c55b..9bbd68c9bdf6 100644 --- a/datafusion/functions-aggregate/src/nth_value.rs +++ b/datafusion/functions-aggregate/src/nth_value.rs @@ -36,7 +36,8 @@ use datafusion_expr::{ use datafusion_physical_expr_common::aggregate::merge_arrays::merge_ordered_arrays; use datafusion_physical_expr_common::aggregate::utils::ordering_fields; use datafusion_physical_expr_common::sort_expr::{ - limited_convert_logical_sort_exprs_to_physical, LexOrdering, PhysicalSortExpr, + limited_convert_logical_sort_exprs_to_physical_with_dfschema, LexOrdering, + PhysicalSortExpr, }; make_udaf_expr_and_func!( @@ -111,9 +112,9 @@ impl AggregateUDFImpl for NthValueAgg { ), }?; - let ordering_req = limited_convert_logical_sort_exprs_to_physical( + let ordering_req = limited_convert_logical_sort_exprs_to_physical_with_dfschema( acc_args.sort_exprs, - acc_args.schema, + acc_args.dfschema, )?; let ordering_dtypes = ordering_req diff --git a/datafusion/functions-aggregate/src/stddev.rs b/datafusion/functions-aggregate/src/stddev.rs index 42cf44f65d8f..247962dc2ce1 100644 --- a/datafusion/functions-aggregate/src/stddev.rs +++ b/datafusion/functions-aggregate/src/stddev.rs @@ -273,6 +273,7 @@ mod tests { use arrow::{array::*, datatypes::*}; + use datafusion_common::DFSchema; use datafusion_expr::AggregateUDF; use datafusion_physical_expr_common::aggregate::utils::get_accum_scalar_values_as_arrays; use datafusion_physical_expr_common::expressions::column::col; @@ -324,13 +325,16 @@ mod tests { agg2: Arc, schema: &Schema, ) -> Result { + let dfschema = DFSchema::empty(); let args1 = AccumulatorArgs { data_type: &DataType::Float64, schema, + dfschema: &dfschema, ignore_nulls: false, sort_exprs: &[], name: "a", is_distinct: false, + is_reversed: false, input_type: &DataType::Float64, input_exprs: &[datafusion_expr::col("a")], }; @@ -338,10 +342,12 @@ mod tests { let args2 = AccumulatorArgs { data_type: &DataType::Float64, schema, + dfschema: &dfschema, ignore_nulls: false, sort_exprs: &[], name: "a", is_distinct: false, + is_reversed: false, input_type: &DataType::Float64, input_exprs: &[datafusion_expr::col("a")], }; diff --git a/datafusion/physical-expr-common/src/aggregate/mod.rs b/datafusion/physical-expr-common/src/aggregate/mod.rs index 7a4a3a6cac4b..05c7e1caed0e 100644 --- a/datafusion/physical-expr-common/src/aggregate/mod.rs +++ b/datafusion/physical-expr-common/src/aggregate/mod.rs @@ -23,7 +23,7 @@ pub mod tdigest; pub mod utils; use arrow::datatypes::{DataType, Field, Schema}; -use datafusion_common::{not_impl_err, Result}; +use datafusion_common::{not_impl_err, DFSchema, Result}; use datafusion_expr::function::StateFieldsArgs; use datafusion_expr::type_coercion::aggregates::check_arg_count; use datafusion_expr::ReversedUDAF; @@ -51,6 +51,10 @@ use datafusion_expr::utils::AggregateOrderSensitivity; /// /// `input_exprs` and `sort_exprs` are used for customizing Accumulator as the arguments in `AccumulatorArgs`, /// if you don't need them it is fine to pass empty slice `&[]`. +/// +/// `is_reversed` is used to indicate whether the aggregation is running in reverse order, +/// it could be used to hint Accumulator to accumulate in the reversed order, +/// you can just set to false if you are not reversing expression #[allow(clippy::too_many_arguments)] pub fn create_aggregate_expr( fun: &AggregateUDF, @@ -62,6 +66,7 @@ pub fn create_aggregate_expr( name: impl Into, ignore_nulls: bool, is_distinct: bool, + is_reversed: bool, ) -> Result> { debug_assert_eq!(sort_exprs.len(), ordering_req.len()); @@ -81,6 +86,61 @@ pub fn create_aggregate_expr( .map(|e| e.expr.data_type(schema)) .collect::>>()?; + let ordering_fields = ordering_fields(ordering_req, &ordering_types); + let name = name.into(); + + Ok(Arc::new(AggregateFunctionExpr { + fun: fun.clone(), + args: input_phy_exprs.to_vec(), + logical_args: input_exprs.to_vec(), + data_type: fun.return_type(&input_exprs_types)?, + name, + schema: schema.clone(), + dfschema: DFSchema::empty(), + sort_exprs: sort_exprs.to_vec(), + ordering_req: ordering_req.to_vec(), + ignore_nulls, + ordering_fields, + is_distinct, + input_type: input_exprs_types[0].clone(), + is_reversed, + })) +} + +#[allow(clippy::too_many_arguments)] +// This is not for external usage, consider creating with `create_aggregate_expr` instead. +pub fn create_aggregate_expr_with_dfschema( + fun: &AggregateUDF, + input_phy_exprs: &[Arc], + input_exprs: &[Expr], + sort_exprs: &[Expr], + ordering_req: &[PhysicalSortExpr], + dfschema: &DFSchema, + name: impl Into, + ignore_nulls: bool, + is_distinct: bool, + is_reversed: bool, +) -> Result> { + debug_assert_eq!(sort_exprs.len(), ordering_req.len()); + + let schema: Schema = dfschema.into(); + + let input_exprs_types = input_phy_exprs + .iter() + .map(|arg| arg.data_type(&schema)) + .collect::>>()?; + + check_arg_count( + fun.name(), + &input_exprs_types, + &fun.signature().type_signature, + )?; + + let ordering_types = ordering_req + .iter() + .map(|e| e.expr.data_type(&schema)) + .collect::>>()?; + let ordering_fields = ordering_fields(ordering_req, &ordering_types); Ok(Arc::new(AggregateFunctionExpr { @@ -90,12 +150,14 @@ pub fn create_aggregate_expr( data_type: fun.return_type(&input_exprs_types)?, name: name.into(), schema: schema.clone(), + dfschema: dfschema.clone(), sort_exprs: sort_exprs.to_vec(), ordering_req: ordering_req.to_vec(), ignore_nulls, ordering_fields, is_distinct, input_type: input_exprs_types[0].clone(), + is_reversed, })) } @@ -261,6 +323,7 @@ pub struct AggregateFunctionExpr { data_type: DataType, name: String, schema: Schema, + dfschema: DFSchema, // The logical order by expressions sort_exprs: Vec, // The physical order by expressions @@ -270,6 +333,7 @@ pub struct AggregateFunctionExpr { // fields used for order sensitive aggregation functions ordering_fields: Vec, is_distinct: bool, + is_reversed: bool, input_type: DataType, } @@ -288,6 +352,11 @@ impl AggregateFunctionExpr { pub fn ignore_nulls(&self) -> bool { self.ignore_nulls } + + /// Return if the aggregation is distinct + pub fn is_reversed(&self) -> bool { + self.is_reversed + } } impl AggregateExpr for AggregateFunctionExpr { @@ -320,12 +389,14 @@ impl AggregateExpr for AggregateFunctionExpr { let acc_args = AccumulatorArgs { data_type: &self.data_type, schema: &self.schema, + dfschema: &self.dfschema, ignore_nulls: self.ignore_nulls, sort_exprs: &self.sort_exprs, is_distinct: self.is_distinct, input_type: &self.input_type, input_exprs: &self.logical_args, name: &self.name, + is_reversed: self.is_reversed, }; self.fun.accumulator(acc_args) @@ -335,12 +406,14 @@ impl AggregateExpr for AggregateFunctionExpr { let args = AccumulatorArgs { data_type: &self.data_type, schema: &self.schema, + dfschema: &self.dfschema, ignore_nulls: self.ignore_nulls, sort_exprs: &self.sort_exprs, is_distinct: self.is_distinct, input_type: &self.input_type, input_exprs: &self.logical_args, name: &self.name, + is_reversed: self.is_reversed, }; let accumulator = self.fun.create_sliding_accumulator(args)?; @@ -405,12 +478,14 @@ impl AggregateExpr for AggregateFunctionExpr { let args = AccumulatorArgs { data_type: &self.data_type, schema: &self.schema, + dfschema: &self.dfschema, ignore_nulls: self.ignore_nulls, sort_exprs: &self.sort_exprs, is_distinct: self.is_distinct, input_type: &self.input_type, input_exprs: &self.logical_args, name: &self.name, + is_reversed: self.is_reversed, }; self.fun.groups_accumulator_supported(args) } @@ -419,12 +494,14 @@ impl AggregateExpr for AggregateFunctionExpr { let args = AccumulatorArgs { data_type: &self.data_type, schema: &self.schema, + dfschema: &self.dfschema, ignore_nulls: self.ignore_nulls, sort_exprs: &self.sort_exprs, is_distinct: self.is_distinct, input_type: &self.input_type, input_exprs: &self.logical_args, name: &self.name, + is_reversed: self.is_reversed, }; self.fun.create_groups_accumulator(args) } @@ -462,16 +539,17 @@ impl AggregateExpr for AggregateFunctionExpr { else { return Ok(None); }; - create_aggregate_expr( + create_aggregate_expr_with_dfschema( &updated_fn, &self.args, &self.logical_args, &self.sort_exprs, &self.ordering_req, - &self.schema, + &self.dfschema, self.name(), self.ignore_nulls, self.is_distinct, + self.is_reversed, ) .map(Some) } @@ -495,18 +573,23 @@ impl AggregateExpr for AggregateFunctionExpr { }) .collect::>(); let mut name = self.name().to_string(); - replace_order_by_clause(&mut name); + // TODO: Generalize order-by clause rewrite + if reverse_udf.name() == "ARRAY_AGG" { + } else { + replace_order_by_clause(&mut name); + } replace_fn_name_clause(&mut name, self.fun.name(), reverse_udf.name()); - let reverse_aggr = create_aggregate_expr( + let reverse_aggr = create_aggregate_expr_with_dfschema( &reverse_udf, &self.args, &self.logical_args, &reverse_sort_exprs, &reverse_ordering_req, - &self.schema, + &self.dfschema, name, self.ignore_nulls, self.is_distinct, + !self.is_reversed, ) .unwrap(); diff --git a/datafusion/physical-expr-common/src/sort_expr.rs b/datafusion/physical-expr-common/src/sort_expr.rs index 8fb1356a8092..2b506b74216f 100644 --- a/datafusion/physical-expr-common/src/sort_expr.rs +++ b/datafusion/physical-expr-common/src/sort_expr.rs @@ -22,12 +22,12 @@ use std::hash::{Hash, Hasher}; use std::sync::Arc; use crate::physical_expr::PhysicalExpr; -use crate::utils::limited_convert_logical_expr_to_physical_expr; +use crate::utils::limited_convert_logical_expr_to_physical_expr_with_dfschema; use arrow::compute::kernels::sort::{SortColumn, SortOptions}; use arrow::datatypes::Schema; use arrow::record_batch::RecordBatch; -use datafusion_common::{exec_err, Result}; +use datafusion_common::{exec_err, DFSchema, Result}; use datafusion_expr::{ColumnarValue, Expr}; /// Represents Sort operation for a column in a RecordBatch @@ -275,9 +275,9 @@ pub type LexRequirementRef<'a> = &'a [PhysicalSortRequirement]; /// Converts each [`Expr::Sort`] into a corresponding [`PhysicalSortExpr`]. /// Returns an error if the given logical expression is not a [`Expr::Sort`]. -pub fn limited_convert_logical_sort_exprs_to_physical( +pub fn limited_convert_logical_sort_exprs_to_physical_with_dfschema( exprs: &[Expr], - schema: &Schema, + dfschema: &DFSchema, ) -> Result> { // Construct PhysicalSortExpr objects from Expr objects: let mut sort_exprs = vec![]; @@ -286,7 +286,10 @@ pub fn limited_convert_logical_sort_exprs_to_physical( return exec_err!("Expects to receive sort expression"); }; sort_exprs.push(PhysicalSortExpr::new( - limited_convert_logical_expr_to_physical_expr(sort.expr.as_ref(), schema)?, + limited_convert_logical_expr_to_physical_expr_with_dfschema( + sort.expr.as_ref(), + dfschema, + )?, SortOptions { descending: !sort.asc, nulls_first: sort.nulls_first, diff --git a/datafusion/physical-expr-common/src/utils.rs b/datafusion/physical-expr-common/src/utils.rs index 44622bd309df..0978a906a5dc 100644 --- a/datafusion/physical-expr-common/src/utils.rs +++ b/datafusion/physical-expr-common/src/utils.rs @@ -19,15 +19,15 @@ use std::sync::Arc; use arrow::array::{make_array, Array, ArrayRef, BooleanArray, MutableArrayData}; use arrow::compute::{and_kleene, is_not_null, SlicesIterator}; -use arrow::datatypes::Schema; -use datafusion_common::{exec_err, Result}; +use datafusion_common::{exec_err, DFSchema, Result}; use datafusion_expr::expr::Alias; use datafusion_expr::sort_properties::ExprProperties; use datafusion_expr::Expr; +use crate::expressions::column::Column; use crate::expressions::literal::Literal; -use crate::expressions::{self, CastExpr}; +use crate::expressions::CastExpr; use crate::physical_expr::PhysicalExpr; use crate::sort_expr::PhysicalSortExpr; use crate::tree_node::ExprContext; @@ -110,19 +110,22 @@ pub fn reverse_order_bys(order_bys: &[PhysicalSortExpr]) -> Vec`. /// If conversion is not supported yet, returns Error. -pub fn limited_convert_logical_expr_to_physical_expr( +pub fn limited_convert_logical_expr_to_physical_expr_with_dfschema( expr: &Expr, - schema: &Schema, + dfschema: &DFSchema, ) -> Result> { match expr { - Expr::Alias(Alias { expr, .. }) => { - Ok(limited_convert_logical_expr_to_physical_expr(expr, schema)?) + Expr::Alias(Alias { expr, .. }) => Ok( + limited_convert_logical_expr_to_physical_expr_with_dfschema(expr, dfschema)?, + ), + Expr::Column(col) => { + let idx = dfschema.index_of_column(col)?; + Ok(Arc::new(Column::new(&col.name, idx))) } - Expr::Column(col) => expressions::column::col(&col.name, schema), Expr::Cast(cast_expr) => Ok(Arc::new(CastExpr::new( - limited_convert_logical_expr_to_physical_expr( + limited_convert_logical_expr_to_physical_expr_with_dfschema( cast_expr.expr.as_ref(), - schema, + dfschema, )?, cast_expr.data_type.clone(), None, diff --git a/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs b/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs deleted file mode 100644 index 992c06f5bf62..000000000000 --- a/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs +++ /dev/null @@ -1,520 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Defines physical expressions which specify ordering requirement -//! that can evaluated at runtime during query execution - -use std::any::Any; -use std::collections::VecDeque; -use std::fmt::Debug; -use std::sync::Arc; - -use crate::aggregate::utils::{down_cast_any_ref, ordering_fields}; -use crate::expressions::format_state_name; -use crate::{ - reverse_order_bys, AggregateExpr, LexOrdering, PhysicalExpr, PhysicalSortExpr, -}; - -use arrow::datatypes::{DataType, Field}; -use arrow_array::cast::AsArray; -use arrow_array::{new_empty_array, Array, ArrayRef, StructArray}; -use arrow_schema::Fields; -use datafusion_common::utils::{array_into_list_array_nullable, get_row_at_idx}; -use datafusion_common::{exec_err, Result, ScalarValue}; -use datafusion_expr::utils::AggregateOrderSensitivity; -use datafusion_expr::Accumulator; -use datafusion_physical_expr_common::aggregate::merge_arrays::merge_ordered_arrays; - -/// Expression for a `ARRAY_AGG(... ORDER BY ..., ...)` aggregation. In a multi -/// partition setting, partial aggregations are computed for every partition, -/// and then their results are merged. -#[derive(Debug)] -pub struct OrderSensitiveArrayAgg { - /// Column name - name: String, - /// The `DataType` for the input expression - input_data_type: DataType, - /// The input expression - expr: Arc, - /// Ordering data types - order_by_data_types: Vec, - /// Ordering requirement - ordering_req: LexOrdering, - /// Whether the aggregation is running in reverse - reverse: bool, -} - -impl OrderSensitiveArrayAgg { - /// Create a new `OrderSensitiveArrayAgg` aggregate function - pub fn new( - expr: Arc, - name: impl Into, - input_data_type: DataType, - order_by_data_types: Vec, - ordering_req: LexOrdering, - ) -> Self { - Self { - name: name.into(), - input_data_type, - expr, - order_by_data_types, - ordering_req, - reverse: false, - } - } -} - -impl AggregateExpr for OrderSensitiveArrayAgg { - fn as_any(&self) -> &dyn Any { - self - } - - fn field(&self) -> Result { - Ok(Field::new_list( - &self.name, - // This should be the same as return type of AggregateFunction::OrderSensitiveArrayAgg - Field::new("item", self.input_data_type.clone(), true), - true, - )) - } - - fn create_accumulator(&self) -> Result> { - OrderSensitiveArrayAggAccumulator::try_new( - &self.input_data_type, - &self.order_by_data_types, - self.ordering_req.clone(), - self.reverse, - ) - .map(|acc| Box::new(acc) as _) - } - - fn state_fields(&self) -> Result> { - let mut fields = vec![Field::new_list( - format_state_name(&self.name, "array_agg"), - Field::new("item", self.input_data_type.clone(), true), - true, // This should be the same as field() - )]; - let orderings = ordering_fields(&self.ordering_req, &self.order_by_data_types); - fields.push(Field::new_list( - format_state_name(&self.name, "array_agg_orderings"), - Field::new("item", DataType::Struct(Fields::from(orderings)), true), - false, - )); - Ok(fields) - } - - fn expressions(&self) -> Vec> { - vec![Arc::clone(&self.expr)] - } - - fn order_bys(&self) -> Option<&[PhysicalSortExpr]> { - (!self.ordering_req.is_empty()).then_some(&self.ordering_req) - } - - fn order_sensitivity(&self) -> AggregateOrderSensitivity { - AggregateOrderSensitivity::HardRequirement - } - - fn name(&self) -> &str { - &self.name - } - - fn reverse_expr(&self) -> Option> { - Some(Arc::new(Self { - name: self.name.to_string(), - input_data_type: self.input_data_type.clone(), - expr: Arc::clone(&self.expr), - order_by_data_types: self.order_by_data_types.clone(), - // Reverse requirement: - ordering_req: reverse_order_bys(&self.ordering_req), - reverse: !self.reverse, - })) - } -} - -impl PartialEq for OrderSensitiveArrayAgg { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.name == x.name - && self.input_data_type == x.input_data_type - && self.order_by_data_types == x.order_by_data_types - && self.expr.eq(&x.expr) - }) - .unwrap_or(false) - } -} - -#[derive(Debug)] -pub(crate) struct OrderSensitiveArrayAggAccumulator { - /// Stores entries in the `ARRAY_AGG` result. - values: Vec, - /// Stores values of ordering requirement expressions corresponding to each - /// entry in `values`. This information is used when merging results from - /// different partitions. For detailed information how merging is done, see - /// [`merge_ordered_arrays`]. - ordering_values: Vec>, - /// Stores datatypes of expressions inside values and ordering requirement - /// expressions. - datatypes: Vec, - /// Stores the ordering requirement of the `Accumulator`. - ordering_req: LexOrdering, - /// Whether the aggregation is running in reverse. - reverse: bool, -} - -impl OrderSensitiveArrayAggAccumulator { - /// Create a new order-sensitive ARRAY_AGG accumulator based on the given - /// item data type. - pub fn try_new( - datatype: &DataType, - ordering_dtypes: &[DataType], - ordering_req: LexOrdering, - reverse: bool, - ) -> Result { - let mut datatypes = vec![datatype.clone()]; - datatypes.extend(ordering_dtypes.iter().cloned()); - Ok(Self { - values: vec![], - ordering_values: vec![], - datatypes, - ordering_req, - reverse, - }) - } -} - -impl Accumulator for OrderSensitiveArrayAggAccumulator { - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - if values.is_empty() { - return Ok(()); - } - - let n_row = values[0].len(); - for index in 0..n_row { - let row = get_row_at_idx(values, index)?; - self.values.push(row[0].clone()); - self.ordering_values.push(row[1..].to_vec()); - } - - Ok(()) - } - - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - if states.is_empty() { - return Ok(()); - } - - // First entry in the state is the aggregation result. Second entry - // stores values received for ordering requirement columns for each - // aggregation value inside `ARRAY_AGG` list. For each `StructArray` - // inside `ARRAY_AGG` list, we will receive an `Array` that stores values - // received from its ordering requirement expression. (This information - // is necessary for during merging). - let [array_agg_values, agg_orderings, ..] = &states else { - return exec_err!("State should have two elements"); - }; - let Some(agg_orderings) = agg_orderings.as_list_opt::() else { - return exec_err!("Expects to receive a list array"); - }; - - // Stores ARRAY_AGG results coming from each partition - let mut partition_values = vec![]; - // Stores ordering requirement expression results coming from each partition - let mut partition_ordering_values = vec![]; - - // Existing values should be merged also. - partition_values.push(self.values.clone().into()); - partition_ordering_values.push(self.ordering_values.clone().into()); - - // Convert array to Scalars to sort them easily. Convert back to array at evaluation. - let array_agg_res = ScalarValue::convert_array_to_scalar_vec(array_agg_values)?; - for v in array_agg_res.into_iter() { - partition_values.push(v.into()); - } - - let orderings = ScalarValue::convert_array_to_scalar_vec(agg_orderings)?; - - for partition_ordering_rows in orderings.into_iter() { - // Extract value from struct to ordering_rows for each group/partition - let ordering_value = partition_ordering_rows.into_iter().map(|ordering_row| { - if let ScalarValue::Struct(s) = ordering_row { - let mut ordering_columns_per_row = vec![]; - - for column in s.columns() { - let sv = ScalarValue::try_from_array(column, 0)?; - ordering_columns_per_row.push(sv); - } - - Ok(ordering_columns_per_row) - } else { - exec_err!( - "Expects to receive ScalarValue::Struct(Arc) but got:{:?}", - ordering_row.data_type() - ) - } - }).collect::>>()?; - - partition_ordering_values.push(ordering_value); - } - - let sort_options = self - .ordering_req - .iter() - .map(|sort_expr| sort_expr.options) - .collect::>(); - - (self.values, self.ordering_values) = merge_ordered_arrays( - &mut partition_values, - &mut partition_ordering_values, - &sort_options, - )?; - - Ok(()) - } - - fn state(&mut self) -> Result> { - let mut result = vec![self.evaluate()?]; - result.push(self.evaluate_orderings()?); - Ok(result) - } - - fn evaluate(&mut self) -> Result { - if self.values.is_empty() { - return Ok(ScalarValue::new_null_list( - self.datatypes[0].clone(), - true, - 1, - )); - } - - let values = self.values.clone(); - let array = if self.reverse { - ScalarValue::new_list_from_iter( - values.into_iter().rev(), - &self.datatypes[0], - true, - ) - } else { - ScalarValue::new_list_from_iter(values.into_iter(), &self.datatypes[0], true) - }; - Ok(ScalarValue::List(array)) - } - - fn size(&self) -> usize { - let mut total = std::mem::size_of_val(self) - + ScalarValue::size_of_vec(&self.values) - - std::mem::size_of_val(&self.values); - - // Add size of the `self.ordering_values` - total += - std::mem::size_of::>() * self.ordering_values.capacity(); - for row in &self.ordering_values { - total += ScalarValue::size_of_vec(row) - std::mem::size_of_val(row); - } - - // Add size of the `self.datatypes` - total += std::mem::size_of::() * self.datatypes.capacity(); - for dtype in &self.datatypes { - total += dtype.size() - std::mem::size_of_val(dtype); - } - - // Add size of the `self.ordering_req` - total += std::mem::size_of::() * self.ordering_req.capacity(); - // TODO: Calculate size of each `PhysicalSortExpr` more accurately. - total - } -} - -impl OrderSensitiveArrayAggAccumulator { - fn evaluate_orderings(&self) -> Result { - let fields = ordering_fields(&self.ordering_req, &self.datatypes[1..]); - let num_columns = fields.len(); - let struct_field = Fields::from(fields.clone()); - - let mut column_wise_ordering_values = vec![]; - for i in 0..num_columns { - let column_values = self - .ordering_values - .iter() - .map(|x| x[i].clone()) - .collect::>(); - let array = if column_values.is_empty() { - new_empty_array(fields[i].data_type()) - } else { - ScalarValue::iter_to_array(column_values.into_iter())? - }; - column_wise_ordering_values.push(array); - } - - let ordering_array = StructArray::try_new( - struct_field.clone(), - column_wise_ordering_values, - None, - )?; - Ok(ScalarValue::List(Arc::new(array_into_list_array_nullable( - Arc::new(ordering_array), - )))) - } -} - -#[cfg(test)] -mod tests { - use std::collections::VecDeque; - use std::sync::Arc; - - use crate::aggregate::array_agg_ordered::merge_ordered_arrays; - - use arrow_array::{Array, ArrayRef, Int64Array}; - use arrow_schema::SortOptions; - use datafusion_common::utils::get_row_at_idx; - use datafusion_common::{Result, ScalarValue}; - - #[test] - fn test_merge_asc() -> Result<()> { - let lhs_arrays: Vec = vec![ - Arc::new(Int64Array::from(vec![0, 0, 1, 1, 2])), - Arc::new(Int64Array::from(vec![0, 1, 2, 3, 4])), - ]; - let n_row = lhs_arrays[0].len(); - let lhs_orderings = (0..n_row) - .map(|idx| get_row_at_idx(&lhs_arrays, idx)) - .collect::>>()?; - - let rhs_arrays: Vec = vec![ - Arc::new(Int64Array::from(vec![0, 0, 1, 1, 2])), - Arc::new(Int64Array::from(vec![0, 1, 2, 3, 4])), - ]; - let n_row = rhs_arrays[0].len(); - let rhs_orderings = (0..n_row) - .map(|idx| get_row_at_idx(&rhs_arrays, idx)) - .collect::>>()?; - let sort_options = vec![ - SortOptions { - descending: false, - nulls_first: false, - }, - SortOptions { - descending: false, - nulls_first: false, - }, - ]; - - let lhs_vals_arr = Arc::new(Int64Array::from(vec![0, 1, 2, 3, 4])) as ArrayRef; - let lhs_vals = (0..lhs_vals_arr.len()) - .map(|idx| ScalarValue::try_from_array(&lhs_vals_arr, idx)) - .collect::>>()?; - - let rhs_vals_arr = Arc::new(Int64Array::from(vec![0, 1, 2, 3, 4])) as ArrayRef; - let rhs_vals = (0..rhs_vals_arr.len()) - .map(|idx| ScalarValue::try_from_array(&rhs_vals_arr, idx)) - .collect::>>()?; - let expected = - Arc::new(Int64Array::from(vec![0, 0, 1, 1, 2, 2, 3, 3, 4, 4])) as ArrayRef; - let expected_ts = vec![ - Arc::new(Int64Array::from(vec![0, 0, 0, 0, 1, 1, 1, 1, 2, 2])) as ArrayRef, - Arc::new(Int64Array::from(vec![0, 0, 1, 1, 2, 2, 3, 3, 4, 4])) as ArrayRef, - ]; - - let (merged_vals, merged_ts) = merge_ordered_arrays( - &mut [lhs_vals, rhs_vals], - &mut [lhs_orderings, rhs_orderings], - &sort_options, - )?; - let merged_vals = ScalarValue::iter_to_array(merged_vals.into_iter())?; - let merged_ts = (0..merged_ts[0].len()) - .map(|col_idx| { - ScalarValue::iter_to_array( - (0..merged_ts.len()) - .map(|row_idx| merged_ts[row_idx][col_idx].clone()), - ) - }) - .collect::>>()?; - - assert_eq!(&merged_vals, &expected); - assert_eq!(&merged_ts, &expected_ts); - - Ok(()) - } - - #[test] - fn test_merge_desc() -> Result<()> { - let lhs_arrays: Vec = vec![ - Arc::new(Int64Array::from(vec![2, 1, 1, 0, 0])), - Arc::new(Int64Array::from(vec![4, 3, 2, 1, 0])), - ]; - let n_row = lhs_arrays[0].len(); - let lhs_orderings = (0..n_row) - .map(|idx| get_row_at_idx(&lhs_arrays, idx)) - .collect::>>()?; - - let rhs_arrays: Vec = vec![ - Arc::new(Int64Array::from(vec![2, 1, 1, 0, 0])), - Arc::new(Int64Array::from(vec![4, 3, 2, 1, 0])), - ]; - let n_row = rhs_arrays[0].len(); - let rhs_orderings = (0..n_row) - .map(|idx| get_row_at_idx(&rhs_arrays, idx)) - .collect::>>()?; - let sort_options = vec![ - SortOptions { - descending: true, - nulls_first: false, - }, - SortOptions { - descending: true, - nulls_first: false, - }, - ]; - - // Values (which will be merged) doesn't have to be ordered. - let lhs_vals_arr = Arc::new(Int64Array::from(vec![0, 1, 2, 1, 2])) as ArrayRef; - let lhs_vals = (0..lhs_vals_arr.len()) - .map(|idx| ScalarValue::try_from_array(&lhs_vals_arr, idx)) - .collect::>>()?; - - let rhs_vals_arr = Arc::new(Int64Array::from(vec![0, 1, 2, 1, 2])) as ArrayRef; - let rhs_vals = (0..rhs_vals_arr.len()) - .map(|idx| ScalarValue::try_from_array(&rhs_vals_arr, idx)) - .collect::>>()?; - let expected = - Arc::new(Int64Array::from(vec![0, 0, 1, 1, 2, 2, 1, 1, 2, 2])) as ArrayRef; - let expected_ts = vec![ - Arc::new(Int64Array::from(vec![2, 2, 1, 1, 1, 1, 0, 0, 0, 0])) as ArrayRef, - Arc::new(Int64Array::from(vec![4, 4, 3, 3, 2, 2, 1, 1, 0, 0])) as ArrayRef, - ]; - let (merged_vals, merged_ts) = merge_ordered_arrays( - &mut [lhs_vals, rhs_vals], - &mut [lhs_orderings, rhs_orderings], - &sort_options, - )?; - let merged_vals = ScalarValue::iter_to_array(merged_vals.into_iter())?; - let merged_ts = (0..merged_ts[0].len()) - .map(|col_idx| { - ScalarValue::iter_to_array( - (0..merged_ts.len()) - .map(|row_idx| merged_ts[row_idx][col_idx].clone()), - ) - }) - .collect::>>()?; - - assert_eq!(&merged_vals, &expected); - assert_eq!(&merged_ts, &expected_ts); - Ok(()) - } -} diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs index 9c270561f37d..27c1533d0552 100644 --- a/datafusion/physical-expr/src/aggregate/build_in.rs +++ b/datafusion/physical-expr/src/aggregate/build_in.rs @@ -42,7 +42,7 @@ pub fn create_aggregate_expr( fun: &AggregateFunction, distinct: bool, input_phy_exprs: &[Arc], - ordering_req: &[PhysicalSortExpr], + _ordering_req: &[PhysicalSortExpr], input_schema: &Schema, name: impl Into, _ignore_nulls: bool, @@ -54,29 +54,9 @@ pub fn create_aggregate_expr( .map(|e| e.data_type(input_schema)) .collect::>>()?; let data_type = input_phy_types[0].clone(); - let ordering_types = ordering_req - .iter() - .map(|e| e.expr.data_type(input_schema)) - .collect::>>()?; let input_phy_exprs = input_phy_exprs.to_vec(); Ok(match (fun, distinct) { - (AggregateFunction::ArrayAgg, _) => { - let expr = Arc::clone(&input_phy_exprs[0]); - - if ordering_req.is_empty() { - return internal_err!( - "ArrayAgg without ordering should be handled as UDAF" - ); - } else { - Arc::new(expressions::OrderSensitiveArrayAgg::new( - expr, - name, - data_type, - ordering_types, - ordering_req.to_vec(), - )) - } - } + (AggregateFunction::ArrayAgg, _) => return internal_err!("not reachable"), (AggregateFunction::Min, _) => Arc::new(expressions::Min::new( Arc::clone(&input_phy_exprs[0]), name, diff --git a/datafusion/physical-expr/src/aggregate/mod.rs b/datafusion/physical-expr/src/aggregate/mod.rs index 749cf2be7297..264c48513050 100644 --- a/datafusion/physical-expr/src/aggregate/mod.rs +++ b/datafusion/physical-expr/src/aggregate/mod.rs @@ -15,9 +15,6 @@ // specific language governing permissions and limitations // under the License. -pub use datafusion_physical_expr_common::aggregate::AggregateExpr; - -pub(crate) mod array_agg_ordered; #[macro_use] pub(crate) mod min_max; pub(crate) mod groups_accumulator; @@ -31,3 +28,5 @@ pub mod utils { get_sort_options, ordering_fields, DecimalAverager, Hashable, }; } + +pub use datafusion_physical_expr_common::aggregate::AggregateExpr; diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index 5a2bcb63b18e..7cbe4e796844 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -34,7 +34,6 @@ mod unknown_column; pub mod helpers { pub use crate::aggregate::min_max::{max, min}; } -pub use crate::aggregate::array_agg_ordered::OrderSensitiveArrayAgg; pub use crate::aggregate::build_in::create_aggregate_expr; pub use crate::aggregate::min_max::{Max, MaxAccumulator, Min, MinAccumulator}; pub use crate::aggregate::stats::StatsType; diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index 4146dda7641d..e7cd5cb2725b 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -1194,22 +1194,25 @@ mod tests { use arrow::datatypes::DataType; use arrow_array::{Float32Array, Int32Array}; use datafusion_common::{ - assert_batches_eq, assert_batches_sorted_eq, internal_err, DataFusionError, - ScalarValue, + assert_batches_eq, assert_batches_sorted_eq, internal_err, DFSchema, DFSchemaRef, + DataFusionError, ScalarValue, }; use datafusion_execution::config::SessionConfig; use datafusion_execution::memory_pool::FairSpillPool; use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use datafusion_expr::expr::Sort; + use datafusion_functions_aggregate::array_agg::array_agg_udaf; use datafusion_functions_aggregate::average::avg_udaf; use datafusion_functions_aggregate::count::count_udaf; use datafusion_functions_aggregate::first_last::{FirstValue, LastValue}; use datafusion_functions_aggregate::median::median_udaf; - use datafusion_physical_expr::expressions::{lit, OrderSensitiveArrayAgg}; + use datafusion_physical_expr::expressions::lit; use datafusion_physical_expr::PhysicalSortExpr; use crate::common::collect; - use datafusion_physical_expr_common::aggregate::create_aggregate_expr; + use datafusion_physical_expr_common::aggregate::{ + create_aggregate_expr, create_aggregate_expr_with_dfschema, + }; use datafusion_physical_expr_common::expressions::Literal; use futures::{FutureExt, Stream}; @@ -1258,19 +1261,22 @@ mod tests { } /// Generates some mock data for aggregate tests. - fn some_data_v2() -> (Arc, Vec) { + fn some_data_v2() -> (Arc, DFSchemaRef, Vec) { // Define a schema: let schema = Arc::new(Schema::new(vec![ Field::new("a", DataType::UInt32, false), Field::new("b", DataType::Float64, false), ])); + let df_schema = DFSchema::try_from(Arc::clone(&schema)).unwrap(); + // Generate data so that first and last value results are at 2nd and // 3rd partitions. With this construction, we guarantee we don't receive // the expected result by accident, but merging actually works properly; // i.e. it doesn't depend on the data insertion order. ( Arc::clone(&schema), + Arc::new(df_schema), vec![ RecordBatch::try_new( Arc::clone(&schema), @@ -1355,6 +1361,7 @@ mod tests { "COUNT(1)", false, false, + false, )?]; let task_ctx = if spill { @@ -1504,6 +1511,7 @@ mod tests { "AVG(b)", false, false, + false, )?]; let task_ctx = if spill { @@ -1808,6 +1816,7 @@ mod tests { "MEDIAN(a)", false, false, + false, ) } @@ -1844,6 +1853,7 @@ mod tests { "AVG(b)", false, false, + false, )?]; for (version, groups, aggregates) in [ @@ -1908,6 +1918,7 @@ mod tests { "AVG(a)", false, false, + false, )?]; let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1)); @@ -1952,6 +1963,7 @@ mod tests { "AVG(b)", false, false, + false, )?]; let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1)); @@ -1996,12 +2008,11 @@ mod tests { // FIRST_VALUE(b ORDER BY b ) fn test_first_value_agg_expr( schema: &Schema, + dfschema: &DFSchema, sort_options: SortOptions, ) -> Result> { let sort_exprs = vec![datafusion_expr::Expr::Sort(Sort { - expr: Box::new(datafusion_expr::Expr::Column( - datafusion_common::Column::new(Some("table1"), "b"), - )), + expr: Box::new(datafusion_expr::col("b")), asc: !sort_options.descending, nulls_first: sort_options.nulls_first, })]; @@ -2012,28 +2023,28 @@ mod tests { let args = vec![col("b", schema)?]; let logical_args = vec![datafusion_expr::col("b")]; let func = datafusion_expr::AggregateUDF::new_from_impl(FirstValue::new()); - datafusion_physical_expr_common::aggregate::create_aggregate_expr( + datafusion_physical_expr_common::aggregate::create_aggregate_expr_with_dfschema( &func, &args, &logical_args, &sort_exprs, &ordering_req, - schema, + dfschema, "FIRST_VALUE(b)", false, false, + false, ) } // LAST_VALUE(b ORDER BY b ) fn test_last_value_agg_expr( schema: &Schema, + dfschema: &DFSchema, sort_options: SortOptions, ) -> Result> { let sort_exprs = vec![datafusion_expr::Expr::Sort(Sort { - expr: Box::new(datafusion_expr::Expr::Column( - datafusion_common::Column::new(Some("table1"), "b"), - )), + expr: Box::new(datafusion_expr::col("b")), asc: !sort_options.descending, nulls_first: sort_options.nulls_first, })]; @@ -2044,16 +2055,17 @@ mod tests { let args = vec![col("b", schema)?]; let logical_args = vec![datafusion_expr::col("b")]; let func = datafusion_expr::AggregateUDF::new_from_impl(LastValue::new()); - create_aggregate_expr( + create_aggregate_expr_with_dfschema( &func, &args, &logical_args, &sort_exprs, &ordering_req, - schema, + dfschema, "LAST_VALUE(b)", false, false, + false, ) } @@ -2086,7 +2098,7 @@ mod tests { Arc::new(TaskContext::default()) }; - let (schema, data) = some_data_v2(); + let (schema, df_schema, data) = some_data_v2(); let partition1 = data[0].clone(); let partition2 = data[1].clone(); let partition3 = data[2].clone(); @@ -2100,9 +2112,13 @@ mod tests { nulls_first: false, }; let aggregates: Vec> = if is_first_acc { - vec![test_first_value_agg_expr(&schema, sort_options)?] + vec![test_first_value_agg_expr( + &schema, + &df_schema, + sort_options, + )?] } else { - vec![test_last_value_agg_expr(&schema, sort_options)?] + vec![test_last_value_agg_expr(&schema, &df_schema, sort_options)?] }; let memory_exec = Arc::new(MemoryExec::try_new( @@ -2169,6 +2185,8 @@ mod tests { #[tokio::test] async fn test_get_finest_requirements() -> Result<()> { let test_schema = create_test_schema()?; + let test_df_schema = DFSchema::try_from(Arc::clone(&test_schema)).unwrap(); + // Assume column a and b are aliases // Assume also that a ASC and c DESC describe the same global ordering for the table. (Since they are ordering equivalent). let options1 = SortOptions { @@ -2178,7 +2196,7 @@ mod tests { let col_a = &col("a", &test_schema)?; let col_b = &col("b", &test_schema)?; let col_c = &col("c", &test_schema)?; - let mut eq_properties = EquivalenceProperties::new(test_schema); + let mut eq_properties = EquivalenceProperties::new(Arc::clone(&test_schema)); // Columns a and b are equal. eq_properties.add_equal_conditions(col_a, col_b)?; // Aggregate requirements are @@ -2214,6 +2232,46 @@ mod tests { }, ]), ]; + let col_expr_a = Box::new(datafusion_expr::col("a")); + let col_expr_b = Box::new(datafusion_expr::col("b")); + let col_expr_c = Box::new(datafusion_expr::col("c")); + let sort_exprs = vec![ + None, + Some(vec![datafusion_expr::Expr::Sort(Sort::new( + col_expr_a.clone(), + options1.descending, + options1.nulls_first, + ))]), + Some(vec![ + datafusion_expr::Expr::Sort(Sort::new( + col_expr_a.clone(), + options1.descending, + options1.nulls_first, + )), + datafusion_expr::Expr::Sort(Sort::new( + col_expr_b.clone(), + options1.descending, + options1.nulls_first, + )), + datafusion_expr::Expr::Sort(Sort::new( + col_expr_c, + options1.descending, + options1.nulls_first, + )), + ]), + Some(vec![ + datafusion_expr::Expr::Sort(Sort::new( + col_expr_a, + options1.descending, + options1.nulls_first, + )), + datafusion_expr::Expr::Sort(Sort::new( + col_expr_b, + options1.descending, + options1.nulls_first, + )), + ]), + ]; let common_requirement = vec![ PhysicalSortExpr { expr: Arc::clone(col_a), @@ -2226,14 +2284,23 @@ mod tests { ]; let mut aggr_exprs = order_by_exprs .into_iter() - .map(|order_by_expr| { - Arc::new(OrderSensitiveArrayAgg::new( - Arc::clone(col_a), + .zip(sort_exprs.into_iter()) + .map(|(order_by_expr, sort_exprs)| { + let ordering_req = order_by_expr.unwrap_or_default(); + let sort_exprs = sort_exprs.unwrap_or_default(); + create_aggregate_expr_with_dfschema( + &array_agg_udaf(), + &[Arc::clone(col_a)], + &[], + &sort_exprs, + &ordering_req, + &test_df_schema, "array_agg", - DataType::Int32, - vec![], - order_by_expr.unwrap_or_default(), - )) as _ + false, + false, + false, + ) + .unwrap() }) .collect::>(); let group_by = PhysicalGroupBy::new_single(vec![]); @@ -2254,6 +2321,7 @@ mod tests { Field::new("a", DataType::Float32, true), Field::new("b", DataType::Float32, true), ])); + let df_schema = DFSchema::try_from(Arc::clone(&schema)).unwrap(); let col_a = col("a", &schema)?; let option_desc = SortOptions { @@ -2263,8 +2331,8 @@ mod tests { let groups = PhysicalGroupBy::new_single(vec![(col_a, "a".to_string())]); let aggregates: Vec> = vec![ - test_first_value_agg_expr(&schema, option_desc)?, - test_last_value_agg_expr(&schema, option_desc)?, + test_first_value_agg_expr(&schema, &df_schema, option_desc)?, + test_last_value_agg_expr(&schema, &df_schema, option_desc)?, ]; let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1)); let aggregate_exec = Arc::new(AggregateExec::try_new( @@ -2330,6 +2398,7 @@ mod tests { "1", false, false, + false, )?]; let input_batches = (0..4) diff --git a/datafusion/physical-plan/src/lib.rs b/datafusion/physical-plan/src/lib.rs index 046977da0a37..c834005bb7c3 100644 --- a/datafusion/physical-plan/src/lib.rs +++ b/datafusion/physical-plan/src/lib.rs @@ -92,7 +92,7 @@ pub mod work_table; pub mod udaf { pub use datafusion_physical_expr_common::aggregate::{ - create_aggregate_expr, AggregateFunctionExpr, + create_aggregate_expr, create_aggregate_expr_with_dfschema, AggregateFunctionExpr, }; } diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs index 5eca7af19d16..959796489c19 100644 --- a/datafusion/physical-plan/src/windows/mod.rs +++ b/datafusion/physical-plan/src/windows/mod.rs @@ -157,6 +157,7 @@ pub fn create_window_expr( name, ignore_nulls, false, + false, )?; window_expr_from_aggregate_expr( partition_by, diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index 9e17c19ecbc5..8c9e5bbd0e95 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -507,7 +507,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { // TODO: `order by` is not supported for UDAF yet let sort_exprs = &[]; let ordering_req = &[]; - udaf::create_aggregate_expr(agg_udf.as_ref(), &input_phy_expr, logical_exprs, sort_exprs, ordering_req, &physical_schema, name, agg_node.ignore_nulls, agg_node.distinct) + udaf::create_aggregate_expr(agg_udf.as_ref(), &input_phy_expr, logical_exprs, sort_exprs, ordering_req, &physical_schema, name, agg_node.ignore_nulls, agg_node.distinct, false) } } }).transpose()?.ok_or_else(|| { diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index e9a90fce2663..140482b9903c 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -24,8 +24,8 @@ use datafusion::physical_expr::window::{NthValueKind, SlidingAggregateWindowExpr use datafusion::physical_expr::{PhysicalSortExpr, ScalarFunctionExpr}; use datafusion::physical_plan::expressions::{ BinaryExpr, CaseExpr, CastExpr, Column, CumeDist, InListExpr, IsNotNullExpr, - IsNullExpr, Literal, Max, Min, NegativeExpr, NotExpr, NthValue, Ntile, - OrderSensitiveArrayAgg, Rank, RankType, RowNumber, TryCastExpr, WindowShift, + IsNullExpr, Literal, Max, Min, NegativeExpr, NotExpr, NthValue, Ntile, Rank, + RankType, RowNumber, TryCastExpr, WindowShift, }; use datafusion::physical_plan::udaf::AggregateFunctionExpr; use datafusion::physical_plan::windows::{BuiltInWindowExpr, PlainAggregateWindowExpr}; @@ -260,10 +260,8 @@ struct AggrFn { fn aggr_expr_to_aggr_fn(expr: &dyn AggregateExpr) -> Result { let aggr_expr = expr.as_any(); - // TODO: remove OrderSensitiveArrayAgg - let inner = if aggr_expr.downcast_ref::().is_some() { - protobuf::AggregateFunction::ArrayAgg - } else if aggr_expr.downcast_ref::().is_some() { + // TODO: remove Min and Max + let inner = if aggr_expr.downcast_ref::().is_some() { protobuf::AggregateFunction::Min } else if aggr_expr.downcast_ref::().is_some() { protobuf::AggregateFunction::Max diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index fba6dfe42599..31ed0837d2f5 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -301,6 +301,7 @@ fn roundtrip_window() -> Result<()> { "avg(b)", false, false, + false, )?, &[], &[], @@ -324,6 +325,7 @@ fn roundtrip_window() -> Result<()> { "SUM(a) RANGE BETWEEN CURRENT ROW AND UNBOUNDED PRECEEDING", false, false, + false, )?; let sliding_aggr_window_expr = Arc::new(SlidingAggregateWindowExpr::new( @@ -367,6 +369,7 @@ fn rountrip_aggregate() -> Result<()> { "AVG(b)", false, false, + false, )?], // NTH_VALUE vec![create_aggregate_expr( @@ -379,6 +382,7 @@ fn rountrip_aggregate() -> Result<()> { "NTH_VALUE(b, 1)", false, false, + false, )?], // STRING_AGG vec![create_aggregate_expr( @@ -394,6 +398,7 @@ fn rountrip_aggregate() -> Result<()> { "STRING_AGG(name, ',')", false, false, + false, )?], ]; @@ -431,6 +436,7 @@ fn rountrip_aggregate_with_limit() -> Result<()> { "AVG(b)", false, false, + false, )?]; let agg = AggregateExec::try_new( @@ -502,6 +508,7 @@ fn roundtrip_aggregate_udaf() -> Result<()> { "example_agg", false, false, + false, )?]; roundtrip_test_with_context( @@ -1000,6 +1007,7 @@ fn roundtrip_aggregate_udf_extension_codec() -> Result<()> { "aggregate_udf", false, false, + false, )?; let filter = Arc::new(FilterExec::try_new( @@ -1032,6 +1040,7 @@ fn roundtrip_aggregate_udf_extension_codec() -> Result<()> { "aggregate_udf", true, true, + false, )?; let aggregate = Arc::new(AggregateExec::try_new( From deef834e7adfd859414448ab4da461e2d4eabb9e Mon Sep 17 00:00:00 2001 From: Oleks V Date: Mon, 22 Jul 2024 17:46:23 -0700 Subject: [PATCH 36/37] Minor: Disable flaky antijoin test until perm fix (#11608) --- datafusion/core/tests/fuzz_cases/join_fuzz.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/datafusion/core/tests/fuzz_cases/join_fuzz.rs b/datafusion/core/tests/fuzz_cases/join_fuzz.rs index 604c1f93e55e..f1cca66712d7 100644 --- a/datafusion/core/tests/fuzz_cases/join_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/join_fuzz.rs @@ -250,6 +250,9 @@ async fn test_anti_join_1k() { } #[tokio::test] +#[ignore] +// flaky test giving 1 rows difference sometimes +// https://github.com/apache/datafusion/issues/11555 async fn test_anti_join_1k_filtered() { // NLJ vs HJ gives wrong result // Tracked in https://github.com/apache/datafusion/issues/11537 From 77311a5896272c7ed252d8cd53d48ec6ea7c0ccf Mon Sep 17 00:00:00 2001 From: Leonardo Yvens Date: Tue, 23 Jul 2024 11:18:00 +0100 Subject: [PATCH 37/37] support Decimal256 type in datafusion-proto (#11606) --- .../proto/datafusion_common.proto | 7 + datafusion/proto-common/src/from_proto/mod.rs | 4 + .../proto-common/src/generated/pbjson.rs | 125 ++++++++++++++++++ .../proto-common/src/generated/prost.rs | 12 +- datafusion/proto-common/src/to_proto/mod.rs | 7 +- .../src/generated/datafusion_proto_common.rs | 12 +- .../tests/cases/roundtrip_logical_plan.rs | 2 + 7 files changed, 164 insertions(+), 5 deletions(-) diff --git a/datafusion/proto-common/proto/datafusion_common.proto b/datafusion/proto-common/proto/datafusion_common.proto index ca95136dadd9..8e8fd2352c6c 100644 --- a/datafusion/proto-common/proto/datafusion_common.proto +++ b/datafusion/proto-common/proto/datafusion_common.proto @@ -130,6 +130,12 @@ message Decimal{ int32 scale = 4; } +message Decimal256Type{ + reserved 1, 2; + uint32 precision = 3; + int32 scale = 4; +} + message List{ Field field_type = 1; } @@ -335,6 +341,7 @@ message ArrowType{ TimeUnit TIME64 = 22 ; IntervalUnit INTERVAL = 23 ; Decimal DECIMAL = 24 ; + Decimal256Type DECIMAL256 = 36; List LIST = 25; List LARGE_LIST = 26; FixedSizeList FIXED_SIZE_LIST = 27; diff --git a/datafusion/proto-common/src/from_proto/mod.rs b/datafusion/proto-common/src/from_proto/mod.rs index 9191ff185a04..5fe9d937f7c4 100644 --- a/datafusion/proto-common/src/from_proto/mod.rs +++ b/datafusion/proto-common/src/from_proto/mod.rs @@ -260,6 +260,10 @@ impl TryFrom<&protobuf::arrow_type::ArrowTypeEnum> for DataType { precision, scale, }) => DataType::Decimal128(*precision as u8, *scale as i8), + arrow_type::ArrowTypeEnum::Decimal256(protobuf::Decimal256Type { + precision, + scale, + }) => DataType::Decimal256(*precision as u8, *scale as i8), arrow_type::ArrowTypeEnum::List(list) => { let list_type = list.as_ref().field_type.as_deref().required("field_type")?; diff --git a/datafusion/proto-common/src/generated/pbjson.rs b/datafusion/proto-common/src/generated/pbjson.rs index 4b34660ae2ef..511072f3cb55 100644 --- a/datafusion/proto-common/src/generated/pbjson.rs +++ b/datafusion/proto-common/src/generated/pbjson.rs @@ -175,6 +175,9 @@ impl serde::Serialize for ArrowType { arrow_type::ArrowTypeEnum::Decimal(v) => { struct_ser.serialize_field("DECIMAL", v)?; } + arrow_type::ArrowTypeEnum::Decimal256(v) => { + struct_ser.serialize_field("DECIMAL256", v)?; + } arrow_type::ArrowTypeEnum::List(v) => { struct_ser.serialize_field("LIST", v)?; } @@ -241,6 +244,7 @@ impl<'de> serde::Deserialize<'de> for ArrowType { "TIME64", "INTERVAL", "DECIMAL", + "DECIMAL256", "LIST", "LARGE_LIST", "LARGELIST", @@ -282,6 +286,7 @@ impl<'de> serde::Deserialize<'de> for ArrowType { Time64, Interval, Decimal, + Decimal256, List, LargeList, FixedSizeList, @@ -338,6 +343,7 @@ impl<'de> serde::Deserialize<'de> for ArrowType { "TIME64" => Ok(GeneratedField::Time64), "INTERVAL" => Ok(GeneratedField::Interval), "DECIMAL" => Ok(GeneratedField::Decimal), + "DECIMAL256" => Ok(GeneratedField::Decimal256), "LIST" => Ok(GeneratedField::List), "LARGELIST" | "LARGE_LIST" => Ok(GeneratedField::LargeList), "FIXEDSIZELIST" | "FIXED_SIZE_LIST" => Ok(GeneratedField::FixedSizeList), @@ -556,6 +562,13 @@ impl<'de> serde::Deserialize<'de> for ArrowType { return Err(serde::de::Error::duplicate_field("DECIMAL")); } arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Decimal) +; + } + GeneratedField::Decimal256 => { + if arrow_type_enum__.is_some() { + return Err(serde::de::Error::duplicate_field("DECIMAL256")); + } + arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Decimal256) ; } GeneratedField::List => { @@ -2849,6 +2862,118 @@ impl<'de> serde::Deserialize<'de> for Decimal256 { deserializer.deserialize_struct("datafusion_common.Decimal256", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for Decimal256Type { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.precision != 0 { + len += 1; + } + if self.scale != 0 { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion_common.Decimal256Type", len)?; + if self.precision != 0 { + struct_ser.serialize_field("precision", &self.precision)?; + } + if self.scale != 0 { + struct_ser.serialize_field("scale", &self.scale)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for Decimal256Type { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "precision", + "scale", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Precision, + Scale, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "precision" => Ok(GeneratedField::Precision), + "scale" => Ok(GeneratedField::Scale), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = Decimal256Type; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion_common.Decimal256Type") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut precision__ = None; + let mut scale__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Precision => { + if precision__.is_some() { + return Err(serde::de::Error::duplicate_field("precision")); + } + precision__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::Scale => { + if scale__.is_some() { + return Err(serde::de::Error::duplicate_field("scale")); + } + scale__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + } + } + Ok(Decimal256Type { + precision: precision__.unwrap_or_default(), + scale: scale__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion_common.Decimal256Type", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for DfField { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result diff --git a/datafusion/proto-common/src/generated/prost.rs b/datafusion/proto-common/src/generated/prost.rs index 9a2770997f15..62919e218b13 100644 --- a/datafusion/proto-common/src/generated/prost.rs +++ b/datafusion/proto-common/src/generated/prost.rs @@ -140,6 +140,14 @@ pub struct Decimal { } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] +pub struct Decimal256Type { + #[prost(uint32, tag = "3")] + pub precision: u32, + #[prost(int32, tag = "4")] + pub scale: i32, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] pub struct List { #[prost(message, optional, boxed, tag = "1")] pub field_type: ::core::option::Option<::prost::alloc::boxed::Box>, @@ -446,7 +454,7 @@ pub struct Decimal256 { pub struct ArrowType { #[prost( oneof = "arrow_type::ArrowTypeEnum", - tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 35, 32, 15, 34, 16, 31, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 33" + tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 35, 32, 15, 34, 16, 31, 17, 18, 19, 20, 21, 22, 23, 24, 36, 25, 26, 27, 28, 29, 30, 33" )] pub arrow_type_enum: ::core::option::Option, } @@ -516,6 +524,8 @@ pub mod arrow_type { Interval(i32), #[prost(message, tag = "24")] Decimal(super::Decimal), + #[prost(message, tag = "36")] + Decimal256(super::Decimal256Type), #[prost(message, tag = "25")] List(::prost::alloc::boxed::Box), #[prost(message, tag = "26")] diff --git a/datafusion/proto-common/src/to_proto/mod.rs b/datafusion/proto-common/src/to_proto/mod.rs index 9dcb65444a47..c15da2895b7c 100644 --- a/datafusion/proto-common/src/to_proto/mod.rs +++ b/datafusion/proto-common/src/to_proto/mod.rs @@ -191,9 +191,10 @@ impl TryFrom<&DataType> for protobuf::arrow_type::ArrowTypeEnum { precision: *precision as u32, scale: *scale as i32, }), - DataType::Decimal256(_, _) => { - return Err(Error::General("Proto serialization error: The Decimal256 data type is not yet supported".to_owned())) - } + DataType::Decimal256(precision, scale) => Self::Decimal256(protobuf::Decimal256Type { + precision: *precision as u32, + scale: *scale as i32, + }), DataType::Map(field, sorted) => { Self::Map(Box::new( protobuf::Map { diff --git a/datafusion/proto/src/generated/datafusion_proto_common.rs b/datafusion/proto/src/generated/datafusion_proto_common.rs index 9a2770997f15..62919e218b13 100644 --- a/datafusion/proto/src/generated/datafusion_proto_common.rs +++ b/datafusion/proto/src/generated/datafusion_proto_common.rs @@ -140,6 +140,14 @@ pub struct Decimal { } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] +pub struct Decimal256Type { + #[prost(uint32, tag = "3")] + pub precision: u32, + #[prost(int32, tag = "4")] + pub scale: i32, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] pub struct List { #[prost(message, optional, boxed, tag = "1")] pub field_type: ::core::option::Option<::prost::alloc::boxed::Box>, @@ -446,7 +454,7 @@ pub struct Decimal256 { pub struct ArrowType { #[prost( oneof = "arrow_type::ArrowTypeEnum", - tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 35, 32, 15, 34, 16, 31, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 33" + tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 35, 32, 15, 34, 16, 31, 17, 18, 19, 20, 21, 22, 23, 24, 36, 25, 26, 27, 28, 29, 30, 33" )] pub arrow_type_enum: ::core::option::Option, } @@ -516,6 +524,8 @@ pub mod arrow_type { Interval(i32), #[prost(message, tag = "24")] Decimal(super::Decimal), + #[prost(message, tag = "36")] + Decimal256(super::Decimal256Type), #[prost(message, tag = "25")] List(::prost::alloc::boxed::Box), #[prost(message, tag = "26")] diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 3476d5d042cc..f6557c7b2d8f 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -27,6 +27,7 @@ use arrow::array::{ use arrow::datatypes::{ DataType, Field, Fields, Int32Type, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit, Schema, SchemaRef, TimeUnit, UnionFields, UnionMode, + DECIMAL256_MAX_PRECISION, }; use prost::Message; @@ -1379,6 +1380,7 @@ fn round_trip_datatype() { DataType::Utf8, DataType::LargeUtf8, DataType::Decimal128(7, 12), + DataType::Decimal256(DECIMAL256_MAX_PRECISION, 0), // Recursive list tests DataType::List(new_arc_field("Level1", DataType::Binary, true)), DataType::List(new_arc_field(