diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml
index 45abeb8f6fe2..fd8c2d2090b9 100644
--- a/.github/workflows/rust.yml
+++ b/.github/workflows/rust.yml
@@ -234,7 +234,7 @@ jobs:
           rust-version: stable
       - name: Run cargo doc
         run: |
-          export RUSTDOCFLAGS="-D warnings -A rustdoc::private-intra-doc-links"
+          export RUSTDOCFLAGS="-D warnings"
           cargo doc --document-private-items --no-deps --workspace
           cd datafusion-cli
           cargo doc --document-private-items --no-deps
diff --git a/benchmarks/queries/clickbench/README.md b/benchmarks/queries/clickbench/README.md
index 29b1a7588f17..560b54181d5f 100644
--- a/benchmarks/queries/clickbench/README.md
+++ b/benchmarks/queries/clickbench/README.md
@@ -14,7 +14,7 @@ ClickBench is focused on aggregation and filtering performance (though it has no
 
 The "extended" queries are not part of the official ClickBench benchmark.
 Instead they are used to test other DataFusion features that are not covered by
-the standard benchmark Each description below is for the corresponding line in
+the standard benchmark. Each description below is for the corresponding line in
 `extended.sql` (line 1 is `Q0`, line 2 is `Q1`, etc.)
 
 ### Q0: Data Exploration
diff --git a/datafusion-examples/examples/dataframe_subquery.rs b/datafusion-examples/examples/dataframe_subquery.rs
index e798751b3353..3e3d0c1b5a84 100644
--- a/datafusion-examples/examples/dataframe_subquery.rs
+++ b/datafusion-examples/examples/dataframe_subquery.rs
@@ -20,6 +20,7 @@ use std::sync::Arc;
 
 use datafusion::error::Result;
 use datafusion::functions_aggregate::average::avg;
+use datafusion::functions_aggregate::min_max::max;
 use datafusion::prelude::*;
 use datafusion::test_util::arrow_test_data;
 use datafusion_common::ScalarValue;
diff --git a/datafusion-examples/examples/parquet_index.rs b/datafusion-examples/examples/parquet_index.rs
index 91e178f1f1a5..d6e17764442d 100644
--- a/datafusion-examples/examples/parquet_index.rs
+++ b/datafusion-examples/examples/parquet_index.rs
@@ -25,12 +25,10 @@ use arrow_schema::SchemaRef;
 use async_trait::async_trait;
 use datafusion::catalog::Session;
 use datafusion::datasource::listing::PartitionedFile;
-use datafusion::datasource::physical_plan::{
-    parquet::StatisticsConverter,
-    {FileScanConfig, ParquetExec},
-};
+use datafusion::datasource::physical_plan::{FileScanConfig, ParquetExec};
 use datafusion::datasource::TableProvider;
 use datafusion::execution::object_store::ObjectStoreUrl;
+use datafusion::parquet::arrow::arrow_reader::statistics::StatisticsConverter;
 use datafusion::parquet::arrow::{
     arrow_reader::ParquetRecordBatchReaderBuilder, ArrowWriter,
 };
diff --git a/datafusion/common/src/error.rs b/datafusion/common/src/error.rs
index 58ff1121e36d..f62acaf0493b 100644
--- a/datafusion/common/src/error.rs
+++ b/datafusion/common/src/error.rs
@@ -321,7 +321,8 @@ impl From<DataFusionError> for io::Error {
 }
 
 impl DataFusionError {
-    const BACK_TRACE_SEP: &'static str = "\n\nbacktrace: ";
+    /// The separator between the error message and the backtrace
+    pub const BACK_TRACE_SEP: &'static str = "\n\nbacktrace: ";
 
     /// Get deepest underlying [`DataFusionError`]
     ///
diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml
index 09b90a56d2aa..0415c3164b38 100644
--- a/datafusion/core/Cargo.toml
+++ b/datafusion/core/Cargo.toml
@@ -217,10 +217,6 @@ name = "sort"
 harness = false
 name = "topk_aggregate"
 
-[[bench]]
-harness = false
-name = "parquet_statistic"
-
 [[bench]]
 harness = false
 name = "map_query_sql"
diff --git a/datafusion/core/benches/parquet_statistic.rs
b/datafusion/core/benches/parquet_statistic.rs deleted file mode 100644 index 3595e8773b07..000000000000 --- a/datafusion/core/benches/parquet_statistic.rs +++ /dev/null @@ -1,287 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Benchmarks of benchmark for extracting arrow statistics from parquet - -use arrow::array::{ArrayRef, DictionaryArray, Float64Array, StringArray, UInt64Array}; -use arrow_array::{Int32Array, Int64Array, RecordBatch}; -use arrow_schema::{ - DataType::{self, *}, - Field, Schema, -}; -use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; -use datafusion::datasource::physical_plan::parquet::StatisticsConverter; -use parquet::{ - arrow::arrow_reader::ArrowReaderOptions, file::properties::WriterProperties, -}; -use parquet::{ - arrow::{arrow_reader::ArrowReaderBuilder, ArrowWriter}, - file::properties::EnabledStatistics, -}; -use std::sync::Arc; -use tempfile::NamedTempFile; -#[derive(Debug, Clone)] -enum TestTypes { - UInt64, - Int64, - F64, - String, - Dictionary, -} - -use std::fmt; - -impl fmt::Display for TestTypes { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - TestTypes::UInt64 => write!(f, "UInt64"), - TestTypes::Int64 => write!(f, "Int64"), - TestTypes::F64 => write!(f, "F64"), - TestTypes::String => write!(f, "String"), - TestTypes::Dictionary => write!(f, "Dictionary(Int32, String)"), - } - } -} - -fn create_parquet_file( - dtype: TestTypes, - row_groups: usize, - data_page_row_count_limit: &Option, -) -> NamedTempFile { - let schema = match dtype { - TestTypes::UInt64 => { - Arc::new(Schema::new(vec![Field::new("col", DataType::UInt64, true)])) - } - TestTypes::Int64 => { - Arc::new(Schema::new(vec![Field::new("col", DataType::Int64, true)])) - } - TestTypes::F64 => Arc::new(Schema::new(vec![Field::new( - "col", - DataType::Float64, - true, - )])), - TestTypes::String => { - Arc::new(Schema::new(vec![Field::new("col", DataType::Utf8, true)])) - } - TestTypes::Dictionary => Arc::new(Schema::new(vec![Field::new( - "col", - DataType::Dictionary(Box::new(Int32), Box::new(Utf8)), - true, - )])), - }; - - let mut props = WriterProperties::builder().set_max_row_group_size(row_groups); - if let Some(limit) = data_page_row_count_limit { - props = props - .set_data_page_row_count_limit(*limit) - .set_statistics_enabled(EnabledStatistics::Page); - }; - let props = props.build(); - - let file = tempfile::Builder::new() - .suffix(".parquet") - .tempfile() - .unwrap(); - let mut writer = - ArrowWriter::try_new(file.reopen().unwrap(), schema.clone(), Some(props)) - .unwrap(); - - for _ in 0..row_groups { - let batch = match dtype { - TestTypes::UInt64 => make_uint64_batch(), - TestTypes::Int64 => make_int64_batch(), - TestTypes::F64 => make_f64_batch(), - 
TestTypes::String => make_string_batch(), - TestTypes::Dictionary => make_dict_batch(), - }; - if data_page_row_count_limit.is_some() { - // Send batches one at a time. This allows the - // writer to apply the page limit, that is only - // checked on RecordBatch boundaries. - for i in 0..batch.num_rows() { - writer.write(&batch.slice(i, 1)).unwrap(); - } - } else { - writer.write(&batch).unwrap(); - } - } - writer.close().unwrap(); - file -} - -fn make_uint64_batch() -> RecordBatch { - let array: ArrayRef = Arc::new(UInt64Array::from(vec![ - Some(1), - Some(2), - Some(3), - Some(4), - Some(5), - ])); - RecordBatch::try_new( - Arc::new(arrow::datatypes::Schema::new(vec![ - arrow::datatypes::Field::new("col", UInt64, false), - ])), - vec![array], - ) - .unwrap() -} - -fn make_int64_batch() -> RecordBatch { - let array: ArrayRef = Arc::new(Int64Array::from(vec![ - Some(1), - Some(2), - Some(3), - Some(4), - Some(5), - ])); - RecordBatch::try_new( - Arc::new(arrow::datatypes::Schema::new(vec![ - arrow::datatypes::Field::new("col", Int64, false), - ])), - vec![array], - ) - .unwrap() -} - -fn make_f64_batch() -> RecordBatch { - let array: ArrayRef = Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0, 5.0])); - RecordBatch::try_new( - Arc::new(arrow::datatypes::Schema::new(vec![ - arrow::datatypes::Field::new("col", Float64, false), - ])), - vec![array], - ) - .unwrap() -} - -fn make_string_batch() -> RecordBatch { - let array: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"])); - RecordBatch::try_new( - Arc::new(arrow::datatypes::Schema::new(vec![ - arrow::datatypes::Field::new("col", Utf8, false), - ])), - vec![array], - ) - .unwrap() -} - -fn make_dict_batch() -> RecordBatch { - let keys = Int32Array::from(vec![0, 1, 2, 3, 4]); - let values = StringArray::from(vec!["a", "b", "c", "d", "e"]); - let array: ArrayRef = - Arc::new(DictionaryArray::try_new(keys, Arc::new(values)).unwrap()); - RecordBatch::try_new( - Arc::new(Schema::new(vec![Field::new( - "col", - Dictionary(Box::new(Int32), Box::new(Utf8)), - false, - )])), - vec![array], - ) - .unwrap() -} - -fn criterion_benchmark(c: &mut Criterion) { - let row_groups = 100; - use TestTypes::*; - let types = vec![Int64, UInt64, F64, String, Dictionary]; - let data_page_row_count_limits = vec![None, Some(1)]; - - for dtype in types { - for data_page_row_count_limit in &data_page_row_count_limits { - let file = - create_parquet_file(dtype.clone(), row_groups, data_page_row_count_limit); - let file = file.reopen().unwrap(); - let options = ArrowReaderOptions::new().with_page_index(true); - let reader = ArrowReaderBuilder::try_new_with_options(file, options).unwrap(); - let metadata = reader.metadata(); - let row_groups = metadata.row_groups(); - let row_group_indices: Vec<_> = (0..row_groups.len()).collect(); - - let statistic_type = if data_page_row_count_limit.is_some() { - "data page" - } else { - "row group" - }; - - let mut group = c.benchmark_group(format!( - "Extract {} statistics for {}", - statistic_type, - dtype.clone() - )); - group.bench_function( - BenchmarkId::new("extract_statistics", dtype.clone()), - |b| { - b.iter(|| { - let converter = StatisticsConverter::try_new( - "col", - reader.schema(), - reader.parquet_schema(), - ) - .unwrap(); - - if data_page_row_count_limit.is_some() { - let column_page_index = reader - .metadata() - .column_index() - .expect("File should have column page indices"); - - let column_offset_index = reader - .metadata() - .offset_index() - .expect("File should have column offset 
indices"); - - let _ = converter.data_page_mins( - column_page_index, - column_offset_index, - &row_group_indices, - ); - let _ = converter.data_page_maxes( - column_page_index, - column_offset_index, - &row_group_indices, - ); - let _ = converter.data_page_null_counts( - column_page_index, - column_offset_index, - &row_group_indices, - ); - let _ = converter.data_page_row_counts( - column_offset_index, - row_groups, - &row_group_indices, - ); - } else { - let _ = converter.row_group_mins(row_groups.iter()).unwrap(); - let _ = converter.row_group_maxes(row_groups.iter()).unwrap(); - let _ = converter - .row_group_null_counts(row_groups.iter()) - .unwrap(); - let _ = converter - .row_group_row_counts(row_groups.iter()) - .unwrap(); - } - }) - }, - ); - group.finish(); - } - } -} - -criterion_group!(benches, criterion_benchmark); -criterion_main!(benches); diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 8feccfb43d6b..cacfa4c6f2aa 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -53,9 +53,11 @@ use datafusion_common::{ }; use datafusion_expr::{case, is_null, lit}; use datafusion_expr::{ - max, min, utils::COUNT_STAR_EXPANSION, TableProviderFilterPushDown, UNNAMED_TABLE, + utils::COUNT_STAR_EXPANSION, TableProviderFilterPushDown, UNNAMED_TABLE, +}; +use datafusion_functions_aggregate::expr_fn::{ + avg, count, max, median, min, stddev, sum, }; -use datafusion_functions_aggregate::expr_fn::{avg, count, median, stddev, sum}; use async_trait::async_trait; use datafusion_catalog::Session; @@ -144,6 +146,7 @@ impl Default for DataFrameWriteOptions { /// ``` /// # use datafusion::prelude::*; /// # use datafusion::error::Result; +/// # use datafusion::functions_aggregate::expr_fn::min; /// # #[tokio::main] /// # async fn main() -> Result<()> { /// let ctx = SessionContext::new(); @@ -407,6 +410,7 @@ impl DataFrame { /// ``` /// # use datafusion::prelude::*; /// # use datafusion::error::Result; + /// # use datafusion::functions_aggregate::expr_fn::min; /// # #[tokio::main] /// # async fn main() -> Result<()> { /// let ctx = SessionContext::new(); diff --git a/datafusion/core/src/datasource/file_format/parquet.rs b/datafusion/core/src/datasource/file_format/parquet.rs index 25956665d56c..f233f3842c8c 100644 --- a/datafusion/core/src/datasource/file_format/parquet.rs +++ b/datafusion/core/src/datasource/file_format/parquet.rs @@ -50,7 +50,7 @@ use datafusion_common::{ use datafusion_common_runtime::SpawnedTask; use datafusion_execution::memory_pool::{MemoryConsumer, MemoryPool, MemoryReservation}; use datafusion_execution::TaskContext; -use datafusion_physical_expr::expressions::{MaxAccumulator, MinAccumulator}; +use datafusion_functions_aggregate::min_max::{MaxAccumulator, MinAccumulator}; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortRequirement}; use datafusion_physical_plan::metrics::MetricsSet; @@ -75,12 +75,11 @@ use tokio::io::{AsyncWrite, AsyncWriteExt}; use tokio::sync::mpsc::{self, Receiver, Sender}; use tokio::task::JoinSet; -use crate::datasource::physical_plan::parquet::{ - ParquetExecBuilder, StatisticsConverter, -}; +use crate::datasource::physical_plan::parquet::ParquetExecBuilder; use futures::{StreamExt, TryStreamExt}; use object_store::path::Path; use object_store::{ObjectMeta, ObjectStore}; +use parquet::arrow::arrow_reader::statistics::StatisticsConverter; /// Initial writing buffer size. Note this is just a size hint for efficiency. 
diff --git a/datafusion/core/src/datasource/listing/table.rs b/datafusion/core/src/datasource/listing/table.rs
index 72c6e0d84c04..80f49e4eb8e6 100644
--- a/datafusion/core/src/datasource/listing/table.rs
+++ b/datafusion/core/src/datasource/listing/table.rs
@@ -743,7 +743,7 @@ impl TableProvider for ListingTable {
         filters: &[Expr],
         limit: Option<usize>,
     ) -> Result<Arc<dyn ExecutionPlan>> {
-        // TODO remove downcast_ref from here?
+        // TODO (https://github.com/apache/datafusion/issues/11600) remove downcast_ref from here?
         let session_state = state.as_any().downcast_ref::<SessionState>().unwrap();
         let (mut partitioned_file_lists, statistics) = self
             .list_files_for_scan(session_state, filters, limit)
@@ -883,7 +883,7 @@ impl TableProvider for ListingTable {
         // Get the object store for the table path.
         let store = state.runtime_env().object_store(table_path)?;
-        // TODO remove downcast_ref from here?
+        // TODO (https://github.com/apache/datafusion/issues/11600) remove downcast_ref from here?
         let session_state = state.as_any().downcast_ref::<SessionState>().unwrap();
         let file_list_stream = pruned_partition_list(
             session_state,
diff --git a/datafusion/core/src/datasource/listing_table_factory.rs b/datafusion/core/src/datasource/listing_table_factory.rs
index ce52dd98166e..591a19aab49b 100644
--- a/datafusion/core/src/datasource/listing_table_factory.rs
+++ b/datafusion/core/src/datasource/listing_table_factory.rs
@@ -52,7 +52,7 @@ impl TableProviderFactory for ListingTableFactory {
         state: &dyn Session,
         cmd: &CreateExternalTable,
     ) -> Result<Arc<dyn TableProvider>> {
-        // TODO remove downcast_ref from here. Should file format factory be an extension to session state?
+        // TODO (https://github.com/apache/datafusion/issues/11600) remove downcast_ref from here. Should file format factory be an extension to session state?
         let session_state = state.as_any().downcast_ref::<SessionState>().unwrap();
         let file_format = session_state
             .get_file_format_factory(cmd.file_type.as_str())
diff --git a/datafusion/core/src/datasource/physical_plan/mod.rs b/datafusion/core/src/datasource/physical_plan/mod.rs
index a897895246e3..f810fb86bd89 100644
--- a/datafusion/core/src/datasource/physical_plan/mod.rs
+++ b/datafusion/core/src/datasource/physical_plan/mod.rs
@@ -35,7 +35,7 @@ pub use self::parquet::{ParquetExec, ParquetFileMetrics, ParquetFileReaderFactory};
 pub use arrow_file::ArrowExec;
 pub use avro::AvroExec;
-pub use csv::{CsvConfig, CsvExec, CsvOpener};
+pub use csv::{CsvConfig, CsvExec, CsvExecBuilder, CsvOpener};
 pub use file_groups::FileGroupPartitioner;
 pub use file_scan_config::{
     wrap_partition_type_in_dict, wrap_partition_value_in_dict, FileScanConfig,
diff --git a/datafusion/core/src/datasource/physical_plan/parquet/mod.rs b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs
index ed71d871b3fd..72aabefba595 100644
--- a/datafusion/core/src/datasource/physical_plan/parquet/mod.rs
+++ b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs
@@ -52,7 +52,6 @@ mod page_filter;
 mod reader;
 mod row_filter;
 mod row_group_filter;
-mod statistics;
 mod writer;
 
 use crate::datasource::schema_adapter::{
@@ -62,7 +61,6 @@ pub use access_plan::{ParquetAccessPlan, RowGroupAccess};
 pub use metrics::ParquetFileMetrics;
 use opener::ParquetOpener;
 pub use reader::{DefaultParquetFileReaderFactory, ParquetFileReaderFactory};
-pub use statistics::StatisticsConverter;
 pub use writer::plan_to_parquet;
 
 /// Execution plan for reading one or more Parquet files.
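With `mod statistics` and its `pub use statistics::StatisticsConverter` removed, the converter now lives upstream in the `parquet` crate. A minimal usage sketch under the new import path, adapted from the doc examples of the deleted module (`metadata` and `arrow_schema` are assumed inputs, and `"foo"` is an illustrative column name):

```rust
use datafusion::arrow::datatypes::Schema;
use datafusion::parquet::arrow::arrow_reader::statistics::StatisticsConverter;
use datafusion::parquet::file::metadata::ParquetMetaData;

/// Extract per-row-group min/max statistics for column "foo" as Arrow arrays.
fn row_group_stats(metadata: &ParquetMetaData, arrow_schema: &Schema) {
    let parquet_schema = metadata.file_metadata().schema_descr();
    // The converter matches the arrow column to the parquet column by name.
    let converter =
        StatisticsConverter::try_new("foo", arrow_schema, parquet_schema).unwrap();
    // One entry per row group; an entry is null when a statistic is absent.
    let mins = converter
        .row_group_mins(metadata.row_groups().iter())
        .unwrap();
    let maxes = converter
        .row_group_maxes(metadata.row_groups().iter())
        .unwrap();
    println!("mins: {mins:?} maxes: {maxes:?}");
}
```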
@@ -191,9 +189,9 @@ pub use writer::plan_to_parquet;
 /// # Execution Overview
 ///
 /// * Step 1: [`ParquetExec::execute`] is called, returning a [`FileStream`]
-///   configured to open parquet files with a [`ParquetOpener`].
+///   configured to open parquet files with a `ParquetOpener`.
 ///
-/// * Step 2: When the stream is polled, the [`ParquetOpener`] is called to open
+/// * Step 2: When the stream is polled, the `ParquetOpener` is called to open
 ///   the file.
 ///
 /// * Step 3: The `ParquetOpener` gets the [`ParquetMetaData`] (file metadata)
diff --git a/datafusion/core/src/datasource/physical_plan/parquet/page_filter.rs b/datafusion/core/src/datasource/physical_plan/parquet/page_filter.rs
index d658608ab4f1..e4d26a460ecd 100644
--- a/datafusion/core/src/datasource/physical_plan/parquet/page_filter.rs
+++ b/datafusion/core/src/datasource/physical_plan/parquet/page_filter.rs
@@ -17,8 +17,8 @@
 
 //! Contains code to filter entire pages
 
+use super::metrics::ParquetFileMetrics;
 use crate::datasource::physical_plan::parquet::ParquetAccessPlan;
-use crate::datasource::physical_plan::parquet::StatisticsConverter;
 use crate::physical_optimizer::pruning::{PruningPredicate, PruningStatistics};
 use arrow::array::BooleanArray;
 use arrow::{array::ArrayRef, datatypes::SchemaRef};
@@ -26,6 +26,7 @@ use arrow_schema::Schema;
 use datafusion_common::ScalarValue;
 use datafusion_physical_expr::{split_conjunction, PhysicalExpr};
 use log::{debug, trace};
+use parquet::arrow::arrow_reader::statistics::StatisticsConverter;
 use parquet::file::metadata::{ParquetColumnIndex, ParquetOffsetIndex};
 use parquet::format::PageLocation;
 use parquet::schema::types::SchemaDescriptor;
@@ -36,8 +37,6 @@ use parquet::{
 use std::collections::HashSet;
 use std::sync::Arc;
 
-use super::metrics::ParquetFileMetrics;
-
 /// Filters a [`ParquetAccessPlan`] based on the [Parquet PageIndex], if present
 ///
 /// It does so by evaluating statistics from the [`ParquetColumnIndex`] and
@@ -377,7 +376,7 @@ impl<'a> PagesPruningStatistics<'a> {
         converter: StatisticsConverter<'a>,
         parquet_metadata: &'a ParquetMetaData,
     ) -> Option<Self> {
-        let Some(parquet_column_index) = converter.parquet_index() else {
+        let Some(parquet_column_index) = converter.parquet_column_index() else {
             trace!(
                 "Column {:?} not in parquet file, skipping",
                 converter.arrow_field()
@@ -432,7 +431,6 @@ impl<'a> PagesPruningStatistics<'a> {
             Some(vec)
         }
     }
-
 impl<'a> PruningStatistics for PagesPruningStatistics<'a> {
     fn min_values(&self, _column: &datafusion_common::Column) -> Option<ArrayRef> {
         match self.converter.data_page_mins(
diff --git a/datafusion/core/src/datasource/physical_plan/parquet/row_group_filter.rs b/datafusion/core/src/datasource/physical_plan/parquet/row_group_filter.rs
index 170beb15ead2..6a6910748fc8 100644
--- a/datafusion/core/src/datasource/physical_plan/parquet/row_group_filter.rs
+++ b/datafusion/core/src/datasource/physical_plan/parquet/row_group_filter.rs
@@ -15,9 +15,13 @@
 // specific language governing permissions and limitations
 // under the License.
+use crate::datasource::listing::FileRange;
+use crate::physical_optimizer::pruning::{PruningPredicate, PruningStatistics};
 use arrow::{array::ArrayRef, datatypes::Schema};
 use arrow_array::BooleanArray;
 use datafusion_common::{Column, Result, ScalarValue};
+use parquet::arrow::arrow_reader::statistics::StatisticsConverter;
+use parquet::arrow::parquet_column;
 use parquet::basic::Type;
 use parquet::data_type::Decimal;
 use parquet::schema::types::SchemaDescriptor;
@@ -29,11 +33,7 @@ use parquet::{
 use std::collections::{HashMap, HashSet};
 use std::sync::Arc;
 
-use crate::datasource::listing::FileRange;
-use crate::datasource::physical_plan::parquet::statistics::parquet_column;
-use crate::physical_optimizer::pruning::{PruningPredicate, PruningStatistics};
-
-use super::{ParquetAccessPlan, ParquetFileMetrics, StatisticsConverter};
+use super::{ParquetAccessPlan, ParquetFileMetrics};
 
 /// Reduces the [`ParquetAccessPlan`] based on row group level metadata.
 ///
@@ -356,20 +356,24 @@ impl<'a> RowGroupPruningStatistics<'a> {
         &'a self,
         column: &'b Column,
     ) -> Result<StatisticsConverter<'a>> {
-        StatisticsConverter::try_new(&column.name, self.arrow_schema, self.parquet_schema)
+        Ok(StatisticsConverter::try_new(
+            &column.name,
+            self.arrow_schema,
+            self.parquet_schema,
+        )?)
     }
 }
 
 impl<'a> PruningStatistics for RowGroupPruningStatistics<'a> {
     fn min_values(&self, column: &Column) -> Option<ArrayRef> {
         self.statistics_converter(column)
-            .and_then(|c| c.row_group_mins(self.metadata_iter()))
+            .and_then(|c| Ok(c.row_group_mins(self.metadata_iter())?))
             .ok()
     }
 
     fn max_values(&self, column: &Column) -> Option<ArrayRef> {
         self.statistics_converter(column)
-            .and_then(|c| c.row_group_maxes(self.metadata_iter()))
+            .and_then(|c| Ok(c.row_group_maxes(self.metadata_iter())?))
             .ok()
     }
 
@@ -379,7 +383,7 @@ impl<'a> PruningStatistics for RowGroupPruningStatistics<'a> {
 
     fn null_counts(&self, column: &Column) -> Option<ArrayRef> {
         self.statistics_converter(column)
-            .and_then(|c| c.row_group_null_counts(self.metadata_iter()))
+            .and_then(|c| Ok(c.row_group_null_counts(self.metadata_iter())?))
             .ok()
             .map(|counts| Arc::new(counts) as ArrayRef)
     }
 
@@ -387,7 +391,7 @@
     fn row_counts(&self, column: &Column) -> Option<ArrayRef> {
         // row counts are the same for all columns in a row group
         self.statistics_converter(column)
-            .and_then(|c| c.row_group_row_counts(self.metadata_iter()))
+            .and_then(|c| Ok(c.row_group_row_counts(self.metadata_iter())?))
             .ok()
             .flatten()
             .map(|counts| Arc::new(counts) as ArrayRef)
diff --git a/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs b/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs
deleted file mode 100644
index 11b8f5fc6c79..000000000000
--- a/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs
+++ /dev/null
@@ -1,2642 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements.  See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership.  The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License.  You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied.
See the License for the -// specific language governing permissions and limitations -// under the License. - -//! [`StatisticsConverter`] to convert statistics in parquet format to arrow [`ArrayRef`]. - -// TODO: potentially move this to arrow-rs: https://github.com/apache/arrow-rs/issues/4328 - -use arrow::array::{ - BooleanBuilder, FixedSizeBinaryBuilder, LargeStringBuilder, StringBuilder, -}; -use arrow::datatypes::i256; -use arrow::{array::ArrayRef, datatypes::DataType}; -use arrow_array::{ - new_empty_array, new_null_array, BinaryArray, BooleanArray, Date32Array, Date64Array, - Decimal128Array, Decimal256Array, Float16Array, Float32Array, Float64Array, - Int16Array, Int32Array, Int64Array, Int8Array, LargeBinaryArray, - Time32MillisecondArray, Time32SecondArray, Time64MicrosecondArray, - Time64NanosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray, - TimestampNanosecondArray, TimestampSecondArray, UInt16Array, UInt32Array, - UInt64Array, UInt8Array, -}; -use arrow_schema::{Field, FieldRef, Schema, TimeUnit}; -use datafusion_common::{internal_datafusion_err, internal_err, plan_err, Result}; -use half::f16; -use parquet::data_type::{ByteArray, FixedLenByteArray}; -use parquet::file::metadata::{ParquetColumnIndex, ParquetOffsetIndex, RowGroupMetaData}; -use parquet::file::page_index::index::{Index, PageIndex}; -use parquet::file::statistics::Statistics as ParquetStatistics; -use parquet::schema::types::SchemaDescriptor; -use paste::paste; -use std::sync::Arc; - -// Convert the bytes array to i128. -// The endian of the input bytes array must be big-endian. -pub(crate) fn from_bytes_to_i128(b: &[u8]) -> i128 { - // The bytes array are from parquet file and must be the big-endian. - // The endian is defined by parquet format, and the reference document - // https://github.com/apache/parquet-format/blob/54e53e5d7794d383529dd30746378f19a12afd58/src/main/thrift/parquet.thrift#L66 - i128::from_be_bytes(sign_extend_be::<16>(b)) -} - -// Convert the bytes array to i256. -// The endian of the input bytes array must be big-endian. -pub(crate) fn from_bytes_to_i256(b: &[u8]) -> i256 { - i256::from_be_bytes(sign_extend_be::<32>(b)) -} - -// Convert the bytes array to f16 -pub(crate) fn from_bytes_to_f16(b: &[u8]) -> Option { - match b { - [low, high] => Some(f16::from_be_bytes([*high, *low])), - _ => None, - } -} - -// Copy from arrow-rs -// https://github.com/apache/arrow-rs/blob/198af7a3f4aa20f9bd003209d9f04b0f37bb120e/parquet/src/arrow/buffer/bit_util.rs#L54 -// Convert the byte slice to fixed length byte array with the length of N. -fn sign_extend_be(b: &[u8]) -> [u8; N] { - assert!(b.len() <= N, "Array too large, expected less than {N}"); - let is_negative = (b[0] & 128u8) == 128u8; - let mut result = if is_negative { [255u8; N] } else { [0u8; N] }; - for (d, s) in result.iter_mut().skip(N - b.len()).zip(b) { - *d = *s; - } - result -} - -/// Define an adapter iterator for extracting statistics from an iterator of -/// `ParquetStatistics` -/// -/// -/// Handles checking if the statistics are present and valid with the correct type. -/// -/// Parameters: -/// * `$iterator_type` is the name of the iterator type (e.g. `MinBooleanStatsIterator`) -/// * `$func` is the function to call to get the value (e.g. `min` or `max`) -/// * `$parquet_statistics_type` is the type of the statistics (e.g. `ParquetStatistics::Boolean`) -/// * `$stat_value_type` is the type of the statistics value (e.g. `bool`) -macro_rules! 
make_stats_iterator { - ($iterator_type:ident, $func:ident, $parquet_statistics_type:path, $stat_value_type:ty) => { - /// Maps an iterator of `ParquetStatistics` into an iterator of - /// `&$stat_value_type`` - /// - /// Yielded elements: - /// * Some(stats) if valid - /// * None if the statistics are not present, not valid, or not $stat_value_type - struct $iterator_type<'a, I> - where - I: Iterator>, - { - iter: I, - } - - impl<'a, I> $iterator_type<'a, I> - where - I: Iterator>, - { - /// Create a new iterator to extract the statistics - fn new(iter: I) -> Self { - Self { iter } - } - } - - /// Implement the Iterator trait for the iterator - impl<'a, I> Iterator for $iterator_type<'a, I> - where - I: Iterator>, - { - type Item = Option<&'a $stat_value_type>; - - /// return the next statistics value - fn next(&mut self) -> Option { - let next = self.iter.next(); - next.map(|x| { - x.and_then(|stats| match stats { - $parquet_statistics_type(s) if stats.has_min_max_set() => { - Some(s.$func()) - } - _ => None, - }) - }) - } - - fn size_hint(&self) -> (usize, Option) { - self.iter.size_hint() - } - } - }; -} - -make_stats_iterator!( - MinBooleanStatsIterator, - min, - ParquetStatistics::Boolean, - bool -); -make_stats_iterator!( - MaxBooleanStatsIterator, - max, - ParquetStatistics::Boolean, - bool -); -make_stats_iterator!(MinInt32StatsIterator, min, ParquetStatistics::Int32, i32); -make_stats_iterator!(MaxInt32StatsIterator, max, ParquetStatistics::Int32, i32); -make_stats_iterator!(MinInt64StatsIterator, min, ParquetStatistics::Int64, i64); -make_stats_iterator!(MaxInt64StatsIterator, max, ParquetStatistics::Int64, i64); -make_stats_iterator!(MinFloatStatsIterator, min, ParquetStatistics::Float, f32); -make_stats_iterator!(MaxFloatStatsIterator, max, ParquetStatistics::Float, f32); -make_stats_iterator!(MinDoubleStatsIterator, min, ParquetStatistics::Double, f64); -make_stats_iterator!(MaxDoubleStatsIterator, max, ParquetStatistics::Double, f64); -make_stats_iterator!( - MinByteArrayStatsIterator, - min_bytes, - ParquetStatistics::ByteArray, - [u8] -); -make_stats_iterator!( - MaxByteArrayStatsIterator, - max_bytes, - ParquetStatistics::ByteArray, - [u8] -); -make_stats_iterator!( - MinFixedLenByteArrayStatsIterator, - min_bytes, - ParquetStatistics::FixedLenByteArray, - [u8] -); -make_stats_iterator!( - MaxFixedLenByteArrayStatsIterator, - max_bytes, - ParquetStatistics::FixedLenByteArray, - [u8] -); - -/// Special iterator adapter for extracting i128 values from from an iterator of -/// `ParquetStatistics` -/// -/// Handles checking if the statistics are present and valid with the correct type. -/// -/// Depending on the parquet file, the statistics for `Decimal128` can be stored as -/// `Int32`, `Int64` or `ByteArray` or `FixedSizeByteArray` :mindblown: -/// -/// This iterator handles all cases, extracting the values -/// and converting it to `stat_value_type`. -/// -/// Parameters: -/// * `$iterator_type` is the name of the iterator type (e.g. `MinBooleanStatsIterator`) -/// * `$func` is the function to call to get the value (e.g. `min` or `max`) -/// * `$bytes_func` is the function to call to get the value as bytes (e.g. `min_bytes` or `max_bytes`) -/// * `$stat_value_type` is the type of the statistics value (e.g. `i128`) -/// * `convert_func` is the function to convert the bytes to stats value (e.g. `from_bytes_to_i128`) -macro_rules! 
make_decimal_stats_iterator { - ($iterator_type:ident, $func:ident, $bytes_func:ident, $stat_value_type:ident, $convert_func: ident) => { - struct $iterator_type<'a, I> - where - I: Iterator>, - { - iter: I, - } - - impl<'a, I> $iterator_type<'a, I> - where - I: Iterator>, - { - fn new(iter: I) -> Self { - Self { iter } - } - } - - impl<'a, I> Iterator for $iterator_type<'a, I> - where - I: Iterator>, - { - type Item = Option<$stat_value_type>; - - fn next(&mut self) -> Option { - let next = self.iter.next(); - next.map(|x| { - x.and_then(|stats| { - if !stats.has_min_max_set() { - return None; - } - match stats { - ParquetStatistics::Int32(s) => { - Some($stat_value_type::from(*s.$func())) - } - ParquetStatistics::Int64(s) => { - Some($stat_value_type::from(*s.$func())) - } - ParquetStatistics::ByteArray(s) => { - Some($convert_func(s.$bytes_func())) - } - ParquetStatistics::FixedLenByteArray(s) => { - Some($convert_func(s.$bytes_func())) - } - _ => None, - } - }) - }) - } - - fn size_hint(&self) -> (usize, Option) { - self.iter.size_hint() - } - } - }; -} - -make_decimal_stats_iterator!( - MinDecimal128StatsIterator, - min, - min_bytes, - i128, - from_bytes_to_i128 -); -make_decimal_stats_iterator!( - MaxDecimal128StatsIterator, - max, - max_bytes, - i128, - from_bytes_to_i128 -); -make_decimal_stats_iterator!( - MinDecimal256StatsIterator, - min, - min_bytes, - i256, - from_bytes_to_i256 -); -make_decimal_stats_iterator!( - MaxDecimal256StatsIterator, - max, - max_bytes, - i256, - from_bytes_to_i256 -); - -/// Special macro to combine the statistics iterators for min and max using the [`mod@paste`] macro. -/// This is used to avoid repeating the same code for min and max statistics extractions -/// -/// Parameters: -/// stat_type_prefix: The prefix of the statistics iterator type (e.g. `Min` or `Max`) -/// data_type: The data type of the statistics (e.g. `DataType::Int32`) -/// iterator: The iterator of [`ParquetStatistics`] to extract the statistics from. -macro_rules! get_statistics { - ($stat_type_prefix: ident, $data_type: ident, $iterator: ident) => { - paste! 
{ - match $data_type { - DataType::Boolean => Ok(Arc::new(BooleanArray::from_iter( - [<$stat_type_prefix BooleanStatsIterator>]::new($iterator).map(|x| x.copied()), - ))), - DataType::Int8 => Ok(Arc::new(Int8Array::from_iter( - [<$stat_type_prefix Int32StatsIterator>]::new($iterator).map(|x| { - x.and_then(|x| i8::try_from(*x).ok()) - }), - ))), - DataType::Int16 => Ok(Arc::new(Int16Array::from_iter( - [<$stat_type_prefix Int32StatsIterator>]::new($iterator).map(|x| { - x.and_then(|x| i16::try_from(*x).ok()) - }), - ))), - DataType::Int32 => Ok(Arc::new(Int32Array::from_iter( - [<$stat_type_prefix Int32StatsIterator>]::new($iterator).map(|x| x.copied()), - ))), - DataType::Int64 => Ok(Arc::new(Int64Array::from_iter( - [<$stat_type_prefix Int64StatsIterator>]::new($iterator).map(|x| x.copied()), - ))), - DataType::UInt8 => Ok(Arc::new(UInt8Array::from_iter( - [<$stat_type_prefix Int32StatsIterator>]::new($iterator).map(|x| { - x.and_then(|x| u8::try_from(*x).ok()) - }), - ))), - DataType::UInt16 => Ok(Arc::new(UInt16Array::from_iter( - [<$stat_type_prefix Int32StatsIterator>]::new($iterator).map(|x| { - x.and_then(|x| u16::try_from(*x).ok()) - }), - ))), - DataType::UInt32 => Ok(Arc::new(UInt32Array::from_iter( - [<$stat_type_prefix Int32StatsIterator>]::new($iterator).map(|x| x.map(|x| *x as u32)), - ))), - DataType::UInt64 => Ok(Arc::new(UInt64Array::from_iter( - [<$stat_type_prefix Int64StatsIterator>]::new($iterator).map(|x| x.map(|x| *x as u64)), - ))), - DataType::Float16 => Ok(Arc::new(Float16Array::from_iter( - [<$stat_type_prefix FixedLenByteArrayStatsIterator>]::new($iterator).map(|x| x.and_then(|x| { - from_bytes_to_f16(x) - })), - ))), - DataType::Float32 => Ok(Arc::new(Float32Array::from_iter( - [<$stat_type_prefix FloatStatsIterator>]::new($iterator).map(|x| x.copied()), - ))), - DataType::Float64 => Ok(Arc::new(Float64Array::from_iter( - [<$stat_type_prefix DoubleStatsIterator>]::new($iterator).map(|x| x.copied()), - ))), - DataType::Date32 => Ok(Arc::new(Date32Array::from_iter( - [<$stat_type_prefix Int32StatsIterator>]::new($iterator).map(|x| x.copied()), - ))), - DataType::Date64 => Ok(Arc::new(Date64Array::from_iter( - [<$stat_type_prefix Int32StatsIterator>]::new($iterator) - .map(|x| x.map(|x| i64::from(*x) * 24 * 60 * 60 * 1000)), - ))), - DataType::Timestamp(unit, timezone) =>{ - let iter = [<$stat_type_prefix Int64StatsIterator>]::new($iterator).map(|x| x.copied()); - Ok(match unit { - TimeUnit::Second => Arc::new(TimestampSecondArray::from_iter(iter).with_timezone_opt(timezone.clone())), - TimeUnit::Millisecond => Arc::new(TimestampMillisecondArray::from_iter(iter).with_timezone_opt(timezone.clone())), - TimeUnit::Microsecond => Arc::new(TimestampMicrosecondArray::from_iter(iter).with_timezone_opt(timezone.clone())), - TimeUnit::Nanosecond => Arc::new(TimestampNanosecondArray::from_iter(iter).with_timezone_opt(timezone.clone())), - }) - }, - DataType::Time32(unit) => { - Ok(match unit { - TimeUnit::Second => Arc::new(Time32SecondArray::from_iter( - [<$stat_type_prefix Int32StatsIterator>]::new($iterator).map(|x| x.copied()), - )), - TimeUnit::Millisecond => Arc::new(Time32MillisecondArray::from_iter( - [<$stat_type_prefix Int32StatsIterator>]::new($iterator).map(|x| x.copied()), - )), - _ => { - let len = $iterator.count(); - // don't know how to extract statistics, so return a null array - new_null_array($data_type, len) - } - }) - }, - DataType::Time64(unit) => { - Ok(match unit { - TimeUnit::Microsecond => Arc::new(Time64MicrosecondArray::from_iter( - 
[<$stat_type_prefix Int64StatsIterator>]::new($iterator).map(|x| x.copied()), - )), - TimeUnit::Nanosecond => Arc::new(Time64NanosecondArray::from_iter( - [<$stat_type_prefix Int64StatsIterator>]::new($iterator).map(|x| x.copied()), - )), - _ => { - let len = $iterator.count(); - // don't know how to extract statistics, so return a null array - new_null_array($data_type, len) - } - }) - }, - DataType::Binary => Ok(Arc::new(BinaryArray::from_iter( - [<$stat_type_prefix ByteArrayStatsIterator>]::new($iterator) - ))), - DataType::LargeBinary => Ok(Arc::new(LargeBinaryArray::from_iter( - [<$stat_type_prefix ByteArrayStatsIterator>]::new($iterator) - ))), - DataType::Utf8 => { - let iterator = [<$stat_type_prefix ByteArrayStatsIterator>]::new($iterator); - let mut builder = StringBuilder::new(); - for x in iterator { - let Some(x) = x else { - builder.append_null(); // no statistics value - continue; - }; - - let Ok(x) = std::str::from_utf8(x) else { - log::debug!("Utf8 statistics is a non-UTF8 value, ignoring it."); - builder.append_null(); - continue; - }; - - builder.append_value(x); - } - Ok(Arc::new(builder.finish())) - }, - DataType::LargeUtf8 => { - let iterator = [<$stat_type_prefix ByteArrayStatsIterator>]::new($iterator); - let mut builder = LargeStringBuilder::new(); - for x in iterator { - let Some(x) = x else { - builder.append_null(); // no statistics value - continue; - }; - - let Ok(x) = std::str::from_utf8(x) else { - log::debug!("Utf8 statistics is a non-UTF8 value, ignoring it."); - builder.append_null(); - continue; - }; - - builder.append_value(x); - } - Ok(Arc::new(builder.finish())) - }, - DataType::FixedSizeBinary(size) => { - let iterator = [<$stat_type_prefix FixedLenByteArrayStatsIterator>]::new($iterator); - let mut builder = FixedSizeBinaryBuilder::new(*size); - for x in iterator { - let Some(x) = x else { - builder.append_null(); // no statistics value - continue; - }; - - // ignore invalid values - if x.len().try_into() != Ok(*size){ - log::debug!( - "FixedSizeBinary({}) statistics is a binary of size {}, ignoring it.", - size, - x.len(), - ); - builder.append_null(); - continue; - } - - builder.append_value(x).expect("ensure to append successfully here, because size have been checked before"); - } - Ok(Arc::new(builder.finish())) - }, - DataType::Decimal128(precision, scale) => { - let arr = Decimal128Array::from_iter( - [<$stat_type_prefix Decimal128StatsIterator>]::new($iterator) - ).with_precision_and_scale(*precision, *scale)?; - Ok(Arc::new(arr)) - }, - DataType::Decimal256(precision, scale) => { - let arr = Decimal256Array::from_iter( - [<$stat_type_prefix Decimal256StatsIterator>]::new($iterator) - ).with_precision_and_scale(*precision, *scale)?; - Ok(Arc::new(arr)) - }, - DataType::Dictionary(_, value_type) => { - [<$stat_type_prefix:lower _ statistics>](value_type, $iterator) - } - - DataType::Map(_,_) | - DataType::Duration(_) | - DataType::Interval(_) | - DataType::Null | - DataType::BinaryView | - DataType::Utf8View | - DataType::List(_) | - DataType::ListView(_) | - DataType::FixedSizeList(_, _) | - DataType::LargeList(_) | - DataType::LargeListView(_) | - DataType::Struct(_) | - DataType::Union(_, _) | - DataType::RunEndEncoded(_, _) => { - let len = $iterator.count(); - // don't know how to extract statistics, so return a null array - Ok(new_null_array($data_type, len)) - } - }}} -} - -macro_rules! 
make_data_page_stats_iterator { - ($iterator_type: ident, $func: expr, $index_type: path, $stat_value_type: ty) => { - struct $iterator_type<'a, I> - where - I: Iterator, - { - iter: I, - } - - impl<'a, I> $iterator_type<'a, I> - where - I: Iterator, - { - fn new(iter: I) -> Self { - Self { iter } - } - } - - impl<'a, I> Iterator for $iterator_type<'a, I> - where - I: Iterator, - { - type Item = Vec>; - - fn next(&mut self) -> Option { - let next = self.iter.next(); - match next { - Some((len, index)) => match index { - $index_type(native_index) => Some( - native_index - .indexes - .iter() - .map(|x| $func(x)) - .collect::>(), - ), - // No matching `Index` found; - // thus no statistics that can be extracted. - // We return vec![None; len] to effectively - // create an arrow null-array with the length - // corresponding to the number of entries in - // `ParquetOffsetIndex` per row group per column. - _ => Some(vec![None; len]), - }, - _ => None, - } - } - - fn size_hint(&self) -> (usize, Option) { - self.iter.size_hint() - } - } - }; -} - -make_data_page_stats_iterator!( - MinBooleanDataPageStatsIterator, - |x: &PageIndex| { x.min }, - Index::BOOLEAN, - bool -); -make_data_page_stats_iterator!( - MaxBooleanDataPageStatsIterator, - |x: &PageIndex| { x.max }, - Index::BOOLEAN, - bool -); -make_data_page_stats_iterator!( - MinInt32DataPageStatsIterator, - |x: &PageIndex| { x.min }, - Index::INT32, - i32 -); -make_data_page_stats_iterator!( - MaxInt32DataPageStatsIterator, - |x: &PageIndex| { x.max }, - Index::INT32, - i32 -); -make_data_page_stats_iterator!( - MinInt64DataPageStatsIterator, - |x: &PageIndex| { x.min }, - Index::INT64, - i64 -); -make_data_page_stats_iterator!( - MaxInt64DataPageStatsIterator, - |x: &PageIndex| { x.max }, - Index::INT64, - i64 -); -make_data_page_stats_iterator!( - MinFloat16DataPageStatsIterator, - |x: &PageIndex| { x.min.clone() }, - Index::FIXED_LEN_BYTE_ARRAY, - FixedLenByteArray -); -make_data_page_stats_iterator!( - MaxFloat16DataPageStatsIterator, - |x: &PageIndex| { x.max.clone() }, - Index::FIXED_LEN_BYTE_ARRAY, - FixedLenByteArray -); -make_data_page_stats_iterator!( - MinFloat32DataPageStatsIterator, - |x: &PageIndex| { x.min }, - Index::FLOAT, - f32 -); -make_data_page_stats_iterator!( - MaxFloat32DataPageStatsIterator, - |x: &PageIndex| { x.max }, - Index::FLOAT, - f32 -); -make_data_page_stats_iterator!( - MinFloat64DataPageStatsIterator, - |x: &PageIndex| { x.min }, - Index::DOUBLE, - f64 -); -make_data_page_stats_iterator!( - MaxFloat64DataPageStatsIterator, - |x: &PageIndex| { x.max }, - Index::DOUBLE, - f64 -); -make_data_page_stats_iterator!( - MinByteArrayDataPageStatsIterator, - |x: &PageIndex| { x.min.clone() }, - Index::BYTE_ARRAY, - ByteArray -); -make_data_page_stats_iterator!( - MaxByteArrayDataPageStatsIterator, - |x: &PageIndex| { x.max.clone() }, - Index::BYTE_ARRAY, - ByteArray -); -make_data_page_stats_iterator!( - MaxFixedLenByteArrayDataPageStatsIterator, - |x: &PageIndex| { x.max.clone() }, - Index::FIXED_LEN_BYTE_ARRAY, - FixedLenByteArray -); - -make_data_page_stats_iterator!( - MinFixedLenByteArrayDataPageStatsIterator, - |x: &PageIndex| { x.min.clone() }, - Index::FIXED_LEN_BYTE_ARRAY, - FixedLenByteArray -); - -macro_rules! 
get_decimal_page_stats_iterator { - ($iterator_type: ident, $func: ident, $stat_value_type: ident, $convert_func: ident) => { - struct $iterator_type<'a, I> - where - I: Iterator, - { - iter: I, - } - - impl<'a, I> $iterator_type<'a, I> - where - I: Iterator, - { - fn new(iter: I) -> Self { - Self { iter } - } - } - - impl<'a, I> Iterator for $iterator_type<'a, I> - where - I: Iterator, - { - type Item = Vec>; - - fn next(&mut self) -> Option { - let next = self.iter.next(); - match next { - Some((len, index)) => match index { - Index::INT32(native_index) => Some( - native_index - .indexes - .iter() - .map(|x| { - x.$func.and_then(|x| Some($stat_value_type::from(x))) - }) - .collect::>(), - ), - Index::INT64(native_index) => Some( - native_index - .indexes - .iter() - .map(|x| { - x.$func.and_then(|x| Some($stat_value_type::from(x))) - }) - .collect::>(), - ), - Index::BYTE_ARRAY(native_index) => Some( - native_index - .indexes - .iter() - .map(|x| { - x.clone() - .$func - .and_then(|x| Some($convert_func(x.data()))) - }) - .collect::>(), - ), - Index::FIXED_LEN_BYTE_ARRAY(native_index) => Some( - native_index - .indexes - .iter() - .map(|x| { - x.clone() - .$func - .and_then(|x| Some($convert_func(x.data()))) - }) - .collect::>(), - ), - _ => Some(vec![None; len]), - }, - _ => None, - } - } - - fn size_hint(&self) -> (usize, Option) { - self.iter.size_hint() - } - } - }; -} - -get_decimal_page_stats_iterator!( - MinDecimal128DataPageStatsIterator, - min, - i128, - from_bytes_to_i128 -); - -get_decimal_page_stats_iterator!( - MaxDecimal128DataPageStatsIterator, - max, - i128, - from_bytes_to_i128 -); - -get_decimal_page_stats_iterator!( - MinDecimal256DataPageStatsIterator, - min, - i256, - from_bytes_to_i256 -); - -get_decimal_page_stats_iterator!( - MaxDecimal256DataPageStatsIterator, - max, - i256, - from_bytes_to_i256 -); - -macro_rules! get_data_page_statistics { - ($stat_type_prefix: ident, $data_type: ident, $iterator: ident) => { - paste! 
{ - match $data_type { - Some(DataType::Boolean) => { - let iterator = [<$stat_type_prefix BooleanDataPageStatsIterator>]::new($iterator); - let mut builder = BooleanBuilder::new(); - for x in iterator { - for x in x.into_iter() { - let Some(x) = x else { - builder.append_null(); // no statistics value - continue; - }; - builder.append_value(x); - } - } - Ok(Arc::new(builder.finish())) - }, - Some(DataType::UInt8) => Ok(Arc::new( - UInt8Array::from_iter( - [<$stat_type_prefix Int32DataPageStatsIterator>]::new($iterator) - .map(|x| { - x.into_iter().map(|x| { - x.and_then(|x| u8::try_from(x).ok()) - }) - }) - .flatten() - ) - )), - Some(DataType::UInt16) => Ok(Arc::new( - UInt16Array::from_iter( - [<$stat_type_prefix Int32DataPageStatsIterator>]::new($iterator) - .map(|x| { - x.into_iter().map(|x| { - x.and_then(|x| u16::try_from(x).ok()) - }) - }) - .flatten() - ) - )), - Some(DataType::UInt32) => Ok(Arc::new( - UInt32Array::from_iter( - [<$stat_type_prefix Int32DataPageStatsIterator>]::new($iterator) - .map(|x| { - x.into_iter().map(|x| { - x.and_then(|x| Some(x as u32)) - }) - }) - .flatten() - ))), - Some(DataType::UInt64) => Ok(Arc::new( - UInt64Array::from_iter( - [<$stat_type_prefix Int64DataPageStatsIterator>]::new($iterator) - .map(|x| { - x.into_iter().map(|x| { - x.and_then(|x| Some(x as u64)) - }) - }) - .flatten() - ))), - Some(DataType::Int8) => Ok(Arc::new( - Int8Array::from_iter( - [<$stat_type_prefix Int32DataPageStatsIterator>]::new($iterator) - .map(|x| { - x.into_iter().map(|x| { - x.and_then(|x| i8::try_from(x).ok()) - }) - }) - .flatten() - ) - )), - Some(DataType::Int16) => Ok(Arc::new( - Int16Array::from_iter( - [<$stat_type_prefix Int32DataPageStatsIterator>]::new($iterator) - .map(|x| { - x.into_iter().map(|x| { - x.and_then(|x| i16::try_from(x).ok()) - }) - }) - .flatten() - ) - )), - Some(DataType::Int32) => Ok(Arc::new(Int32Array::from_iter([<$stat_type_prefix Int32DataPageStatsIterator>]::new($iterator).flatten()))), - Some(DataType::Int64) => Ok(Arc::new(Int64Array::from_iter([<$stat_type_prefix Int64DataPageStatsIterator>]::new($iterator).flatten()))), - Some(DataType::Float16) => Ok(Arc::new( - Float16Array::from_iter( - [<$stat_type_prefix Float16DataPageStatsIterator>]::new($iterator) - .map(|x| { - x.into_iter().map(|x| { - x.and_then(|x| from_bytes_to_f16(x.data())) - }) - }) - .flatten() - ) - )), - Some(DataType::Float32) => Ok(Arc::new(Float32Array::from_iter([<$stat_type_prefix Float32DataPageStatsIterator>]::new($iterator).flatten()))), - Some(DataType::Float64) => Ok(Arc::new(Float64Array::from_iter([<$stat_type_prefix Float64DataPageStatsIterator>]::new($iterator).flatten()))), - Some(DataType::Binary) => Ok(Arc::new(BinaryArray::from_iter([<$stat_type_prefix ByteArrayDataPageStatsIterator>]::new($iterator).flatten()))), - Some(DataType::LargeBinary) => Ok(Arc::new(LargeBinaryArray::from_iter([<$stat_type_prefix ByteArrayDataPageStatsIterator>]::new($iterator).flatten()))), - Some(DataType::Utf8) => { - let mut builder = StringBuilder::new(); - let iterator = [<$stat_type_prefix ByteArrayDataPageStatsIterator>]::new($iterator); - for x in iterator { - for x in x.into_iter() { - let Some(x) = x else { - builder.append_null(); // no statistics value - continue; - }; - - let Ok(x) = std::str::from_utf8(x.data()) else { - log::debug!("Utf8 statistics is a non-UTF8 value, ignoring it."); - builder.append_null(); - continue; - }; - - builder.append_value(x); - } - } - Ok(Arc::new(builder.finish())) - }, - Some(DataType::LargeUtf8) => { - let mut builder 
= LargeStringBuilder::new(); - let iterator = [<$stat_type_prefix ByteArrayDataPageStatsIterator>]::new($iterator); - for x in iterator { - for x in x.into_iter() { - let Some(x) = x else { - builder.append_null(); // no statistics value - continue; - }; - - let Ok(x) = std::str::from_utf8(x.data()) else { - log::debug!("LargeUtf8 statistics is a non-UTF8 value, ignoring it."); - builder.append_null(); - continue; - }; - - builder.append_value(x); - } - } - Ok(Arc::new(builder.finish())) - }, - Some(DataType::Dictionary(_, value_type)) => { - [<$stat_type_prefix:lower _ page_statistics>](Some(value_type), $iterator) - }, - Some(DataType::Timestamp(unit, timezone)) => { - let iter = [<$stat_type_prefix Int64DataPageStatsIterator>]::new($iterator).flatten(); - Ok(match unit { - TimeUnit::Second => Arc::new(TimestampSecondArray::from_iter(iter).with_timezone_opt(timezone.clone())), - TimeUnit::Millisecond => Arc::new(TimestampMillisecondArray::from_iter(iter).with_timezone_opt(timezone.clone())), - TimeUnit::Microsecond => Arc::new(TimestampMicrosecondArray::from_iter(iter).with_timezone_opt(timezone.clone())), - TimeUnit::Nanosecond => Arc::new(TimestampNanosecondArray::from_iter(iter).with_timezone_opt(timezone.clone())), - }) - }, - Some(DataType::Date32) => Ok(Arc::new(Date32Array::from_iter([<$stat_type_prefix Int32DataPageStatsIterator>]::new($iterator).flatten()))), - Some(DataType::Date64) => Ok( - Arc::new( - Date64Array::from_iter([<$stat_type_prefix Int32DataPageStatsIterator>]::new($iterator) - .map(|x| { - x.into_iter() - .map(|x| { - x.and_then(|x| i64::try_from(x).ok()) - }) - .map(|x| x.map(|x| x * 24 * 60 * 60 * 1000)) - }).flatten() - ) - ) - ), - Some(DataType::Decimal128(precision, scale)) => Ok(Arc::new( - Decimal128Array::from_iter([<$stat_type_prefix Decimal128DataPageStatsIterator>]::new($iterator).flatten()).with_precision_and_scale(*precision, *scale)?)), - Some(DataType::Decimal256(precision, scale)) => Ok(Arc::new( - Decimal256Array::from_iter([<$stat_type_prefix Decimal256DataPageStatsIterator>]::new($iterator).flatten()).with_precision_and_scale(*precision, *scale)?)), - Some(DataType::Time32(unit)) => { - Ok(match unit { - TimeUnit::Second => Arc::new(Time32SecondArray::from_iter( - [<$stat_type_prefix Int32DataPageStatsIterator>]::new($iterator).flatten(), - )), - TimeUnit::Millisecond => Arc::new(Time32MillisecondArray::from_iter( - [<$stat_type_prefix Int32DataPageStatsIterator>]::new($iterator).flatten(), - )), - _ => { - // don't know how to extract statistics, so return an empty array - new_empty_array(&DataType::Time32(unit.clone())) - } - }) - } - Some(DataType::Time64(unit)) => { - Ok(match unit { - TimeUnit::Microsecond => Arc::new(Time64MicrosecondArray::from_iter( - [<$stat_type_prefix Int64DataPageStatsIterator>]::new($iterator).flatten(), - )), - TimeUnit::Nanosecond => Arc::new(Time64NanosecondArray::from_iter( - [<$stat_type_prefix Int64DataPageStatsIterator>]::new($iterator).flatten(), - )), - _ => { - // don't know how to extract statistics, so return an empty array - new_empty_array(&DataType::Time64(unit.clone())) - } - }) - }, - Some(DataType::FixedSizeBinary(size)) => { - let mut builder = FixedSizeBinaryBuilder::new(*size); - let iterator = [<$stat_type_prefix FixedLenByteArrayDataPageStatsIterator>]::new($iterator); - for x in iterator { - for x in x.into_iter() { - let Some(x) = x else { - builder.append_null(); // no statistics value - continue; - }; - - if x.len() == *size as usize { - let _ = builder.append_value(x.data()); - } else { 
- log::debug!( - "FixedSizeBinary({}) statistics is a binary of size {}, ignoring it.", - size, - x.len(), - ); - builder.append_null(); - } - } - } - Ok(Arc::new(builder.finish())) - }, - _ => unimplemented!() - } - } - } -} - -/// Lookups up the parquet column by name -/// -/// Returns the parquet column index and the corresponding arrow field -pub(crate) fn parquet_column<'a>( - parquet_schema: &SchemaDescriptor, - arrow_schema: &'a Schema, - name: &str, -) -> Option<(usize, &'a FieldRef)> { - let (root_idx, field) = arrow_schema.fields.find(name)?; - if field.data_type().is_nested() { - // Nested fields are not supported and require non-trivial logic - // to correctly walk the parquet schema accounting for the - // logical type rules - - // - // For example a ListArray could correspond to anything from 1 to 3 levels - // in the parquet schema - return None; - } - - // This could be made more efficient (#TBD) - let parquet_idx = (0..parquet_schema.columns().len()) - .find(|x| parquet_schema.get_column_root_idx(*x) == root_idx)?; - Some((parquet_idx, field)) -} - -/// Extracts the min statistics from an iterator of [`ParquetStatistics`] to an -/// [`ArrayRef`] -/// -/// This is an internal helper -- see [`StatisticsConverter`] for public API -fn min_statistics<'a, I: Iterator>>( - data_type: &DataType, - iterator: I, -) -> Result { - get_statistics!(Min, data_type, iterator) -} - -/// Extracts the max statistics from an iterator of [`ParquetStatistics`] to an [`ArrayRef`] -/// -/// This is an internal helper -- see [`StatisticsConverter`] for public API -fn max_statistics<'a, I: Iterator>>( - data_type: &DataType, - iterator: I, -) -> Result { - get_statistics!(Max, data_type, iterator) -} - -/// Extracts the min statistics from an iterator -/// of parquet page [`Index`]'es to an [`ArrayRef`] -pub(crate) fn min_page_statistics<'a, I>( - data_type: Option<&DataType>, - iterator: I, -) -> Result -where - I: Iterator, -{ - get_data_page_statistics!(Min, data_type, iterator) -} - -/// Extracts the max statistics from an iterator -/// of parquet page [`Index`]'es to an [`ArrayRef`] -pub(crate) fn max_page_statistics<'a, I>( - data_type: Option<&DataType>, - iterator: I, -) -> Result -where - I: Iterator, -{ - get_data_page_statistics!(Max, data_type, iterator) -} - -/// Extracts the null count statistics from an iterator -/// of parquet page [`Index`]'es to an [`ArrayRef`] -/// -/// The returned Array is an [`UInt64Array`] -pub(crate) fn null_counts_page_statistics<'a, I>(iterator: I) -> Result -where - I: Iterator, -{ - let iter = iterator.flat_map(|(len, index)| match index { - Index::NONE => vec![None; len], - Index::BOOLEAN(native_index) => native_index - .indexes - .iter() - .map(|x| x.null_count.map(|x| x as u64)) - .collect::>(), - Index::INT32(native_index) => native_index - .indexes - .iter() - .map(|x| x.null_count.map(|x| x as u64)) - .collect::>(), - Index::INT64(native_index) => native_index - .indexes - .iter() - .map(|x| x.null_count.map(|x| x as u64)) - .collect::>(), - Index::FLOAT(native_index) => native_index - .indexes - .iter() - .map(|x| x.null_count.map(|x| x as u64)) - .collect::>(), - Index::DOUBLE(native_index) => native_index - .indexes - .iter() - .map(|x| x.null_count.map(|x| x as u64)) - .collect::>(), - Index::FIXED_LEN_BYTE_ARRAY(native_index) => native_index - .indexes - .iter() - .map(|x| x.null_count.map(|x| x as u64)) - .collect::>(), - Index::BYTE_ARRAY(native_index) => native_index - .indexes - .iter() - .map(|x| x.null_count.map(|x| x as u64)) - 
.collect::>(), - _ => unimplemented!(), - }); - - Ok(UInt64Array::from_iter(iter)) -} - -/// Extracts Parquet statistics as Arrow arrays -/// -/// This is used to convert Parquet statistics to Arrow arrays, with proper type -/// conversions. This information can be used for pruning parquet files or row -/// groups based on the statistics embedded in parquet files -/// -/// # Schemas -/// -/// The schema of the parquet file and the arrow schema are used to convert the -/// underlying statistics value (stored as a parquet value) into the -/// corresponding Arrow value. For example, Decimals are stored as binary in -/// parquet files. -/// -/// The parquet_schema and arrow_schema do not have to be identical (for -/// example, the columns may be in different orders and one or the other schemas -/// may have additional columns). The function [`parquet_column`] is used to -/// match the column in the parquet file to the column in the arrow schema. -#[derive(Debug)] -pub struct StatisticsConverter<'a> { - /// the index of the matched column in the parquet schema - parquet_index: Option, - /// The field (with data type) of the column in the arrow schema - arrow_field: &'a Field, -} - -impl<'a> StatisticsConverter<'a> { - /// Return the index of the column in the parquet file, if any - pub fn parquet_index(&self) -> Option { - self.parquet_index - } - - /// Return the arrow field of the column in the arrow schema - pub fn arrow_field(&self) -> &'a Field { - self.arrow_field - } - - /// Returns a [`UInt64Array`] with row counts for each row group - /// - /// # Return Value - /// - /// The returned array has no nulls, and has one value for each row group. - /// Each value is the number of rows in the row group. - /// - /// # Example - /// ```no_run - /// # use arrow::datatypes::Schema; - /// # use arrow_array::ArrayRef; - /// # use parquet::file::metadata::ParquetMetaData; - /// # use datafusion::datasource::physical_plan::parquet::StatisticsConverter; - /// # fn get_parquet_metadata() -> ParquetMetaData { unimplemented!() } - /// # fn get_arrow_schema() -> Schema { unimplemented!() } - /// // Given the metadata for a parquet file and the arrow schema - /// let metadata: ParquetMetaData = get_parquet_metadata(); - /// let arrow_schema: Schema = get_arrow_schema(); - /// let parquet_schema = metadata.file_metadata().schema_descr(); - /// // create a converter - /// let converter = StatisticsConverter::try_new("foo", &arrow_schema, parquet_schema) - /// .unwrap(); - /// // get the row counts for each row group - /// let row_counts = converter.row_group_row_counts(metadata - /// .row_groups() - /// .iter() - /// ); - /// ``` - pub fn row_group_row_counts(&self, metadatas: I) -> Result> - where - I: IntoIterator, - { - let Some(_) = self.parquet_index else { - return Ok(None); - }; - - let mut builder = UInt64Array::builder(10); - for metadata in metadatas.into_iter() { - let row_count = metadata.num_rows(); - let row_count: u64 = row_count.try_into().map_err(|e| { - internal_datafusion_err!( - "Parquet row count {row_count} too large to convert to u64: {e}" - ) - })?; - builder.append_value(row_count); - } - Ok(Some(builder.finish())) - } - - /// Create a new `StatisticsConverter` to extract statistics for a column - /// - /// Note if there is no corresponding column in the parquet file, the returned - /// arrays will be null. This can happen if the column is in the arrow - /// schema but not in the parquet schema due to schema evolution. 
- /// - /// See example on [`Self::row_group_mins`] for usage - /// - /// # Errors - /// - /// * If the column is not found in the arrow schema - pub fn try_new<'b>( - column_name: &'b str, - arrow_schema: &'a Schema, - parquet_schema: &'a SchemaDescriptor, - ) -> Result { - // ensure the requested column is in the arrow schema - let Some((_idx, arrow_field)) = arrow_schema.column_with_name(column_name) else { - return plan_err!( - "Column '{}' not found in schema for statistics conversion", - column_name - ); - }; - - // find the column in the parquet schema, if not, return a null array - let parquet_index = match parquet_column( - parquet_schema, - arrow_schema, - column_name, - ) { - Some((parquet_idx, matched_field)) => { - // sanity check that matching field matches the arrow field - if matched_field.as_ref() != arrow_field { - return internal_err!( - "Matched column '{:?}' does not match original matched column '{:?}'", - matched_field, - arrow_field - ); - } - Some(parquet_idx) - } - None => None, - }; - - Ok(Self { - parquet_index, - arrow_field, - }) - } - - /// Extract the minimum values from row group statistics in [`RowGroupMetaData`] - /// - /// # Return Value - /// - /// The returned array contains 1 value for each row group, in the same order as `metadatas` - /// - /// Each value is either - /// * the minimum value for the column - /// * a null value, if the statistics can not be extracted - /// - /// Note that a null value does NOT mean the min value was actually - /// `null` it means it the requested statistic is unknown - /// - /// # Errors - /// - /// Reasons for not being able to extract the statistics include: - /// * the column is not present in the parquet file - /// * statistics for the column are not present in the row group - /// * the stored statistic value can not be converted to the requested type - /// - /// # Example - /// ```no_run - /// # use arrow::datatypes::Schema; - /// # use arrow_array::ArrayRef; - /// # use parquet::file::metadata::ParquetMetaData; - /// # use datafusion::datasource::physical_plan::parquet::StatisticsConverter; - /// # fn get_parquet_metadata() -> ParquetMetaData { unimplemented!() } - /// # fn get_arrow_schema() -> Schema { unimplemented!() } - /// // Given the metadata for a parquet file and the arrow schema - /// let metadata: ParquetMetaData = get_parquet_metadata(); - /// let arrow_schema: Schema = get_arrow_schema(); - /// let parquet_schema = metadata.file_metadata().schema_descr(); - /// // create a converter - /// let converter = StatisticsConverter::try_new("foo", &arrow_schema, parquet_schema) - /// .unwrap(); - /// // get the minimum value for the column "foo" in the parquet file - /// let min_values: ArrayRef = converter - /// .row_group_mins(metadata.row_groups().iter()) - /// .unwrap(); - /// ``` - pub fn row_group_mins(&self, metadatas: I) -> Result - where - I: IntoIterator, - { - let data_type = self.arrow_field.data_type(); - - let Some(parquet_index) = self.parquet_index else { - return Ok(self.make_null_array(data_type, metadatas)); - }; - - let iter = metadatas - .into_iter() - .map(|x| x.column(parquet_index).statistics()); - min_statistics(data_type, iter) - } - - /// Extract the maximum values from row group statistics in [`RowGroupMetaData`] - /// - /// See docs on [`Self::row_group_mins`] for details - pub fn row_group_maxes(&self, metadatas: I) -> Result - where - I: IntoIterator, - { - let data_type = self.arrow_field.data_type(); - - let Some(parquet_index) = self.parquet_index else { - return 
Ok(self.make_null_array(data_type, metadatas)); - }; - - let iter = metadatas - .into_iter() - .map(|x| x.column(parquet_index).statistics()); - max_statistics(data_type, iter) - } - - /// Extract the null counts from row group statistics in [`RowGroupMetaData`] - /// - /// See docs on [`Self::row_group_mins`] for details - pub fn row_group_null_counts(&self, metadatas: I) -> Result - where - I: IntoIterator, - { - let Some(parquet_index) = self.parquet_index else { - let num_row_groups = metadatas.into_iter().count(); - return Ok(UInt64Array::from_iter( - std::iter::repeat(None).take(num_row_groups), - )); - }; - - let null_counts = metadatas - .into_iter() - .map(|x| x.column(parquet_index).statistics()) - .map(|s| s.map(|s| s.null_count())); - Ok(UInt64Array::from_iter(null_counts)) - } - - /// Extract the minimum values from Data Page statistics. - /// - /// In Parquet files, in addition to the Column Chunk level statistics - /// (stored for each column for each row group) there are also - /// optional statistics stored for each data page, as part of - /// the [`ParquetColumnIndex`]. - /// - /// Since a single Column Chunk is stored as one or more pages, - /// page level statistics can prune at a finer granularity. - /// - /// However since they are stored in a separate metadata - /// structure ([`Index`]) there is different code to extract them as - /// compared to arrow statistics. - /// - /// # Parameters: - /// - /// * `column_page_index`: The parquet column page indices, read from - /// `ParquetMetaData` column_index - /// - /// * `column_offset_index`: The parquet column offset indices, read from - /// `ParquetMetaData` offset_index - /// - /// * `row_group_indices`: The indices of the row groups, that are used to - /// extract the column page index and offset index on a per row group - /// per column basis. - /// - /// # Return Value - /// - /// The returned array contains 1 value for each `NativeIndex` - /// in the underlying `Index`es, in the same order as they appear - /// in `metadatas`. - /// - /// For example, if there are two `Index`es in `metadatas`: - /// 1. the first having `3` `PageIndex` entries - /// 2. the second having `2` `PageIndex` entries - /// - /// The returned array would have 5 rows. 
- /// - /// Each value is either: - /// * the minimum value for the page - /// * a null value, if the statistics can not be extracted - /// - /// Note that a null value does NOT mean the min value was actually - /// `null` it means it the requested statistic is unknown - /// - /// # Errors - /// - /// Reasons for not being able to extract the statistics include: - /// * the column is not present in the parquet file - /// * statistics for the pages are not present in the row group - /// * the stored statistic value can not be converted to the requested type - pub fn data_page_mins( - &self, - column_page_index: &ParquetColumnIndex, - column_offset_index: &ParquetOffsetIndex, - row_group_indices: I, - ) -> Result - where - I: IntoIterator, - { - let data_type = self.arrow_field.data_type(); - - let Some(parquet_index) = self.parquet_index else { - return Ok(self.make_null_array(data_type, row_group_indices)); - }; - - let iter = row_group_indices.into_iter().map(|rg_index| { - let column_page_index_per_row_group_per_column = - &column_page_index[*rg_index][parquet_index]; - let num_data_pages = &column_offset_index[*rg_index][parquet_index].len(); - - (*num_data_pages, column_page_index_per_row_group_per_column) - }); - - min_page_statistics(Some(data_type), iter) - } - - /// Extract the maximum values from Data Page statistics. - /// - /// See docs on [`Self::data_page_mins`] for details. - pub fn data_page_maxes( - &self, - column_page_index: &ParquetColumnIndex, - column_offset_index: &ParquetOffsetIndex, - row_group_indices: I, - ) -> Result - where - I: IntoIterator, - { - let data_type = self.arrow_field.data_type(); - - let Some(parquet_index) = self.parquet_index else { - return Ok(self.make_null_array(data_type, row_group_indices)); - }; - - let iter = row_group_indices.into_iter().map(|rg_index| { - let column_page_index_per_row_group_per_column = - &column_page_index[*rg_index][parquet_index]; - let num_data_pages = &column_offset_index[*rg_index][parquet_index].len(); - - (*num_data_pages, column_page_index_per_row_group_per_column) - }); - - max_page_statistics(Some(data_type), iter) - } - - /// Extract the null counts from Data Page statistics. - /// - /// The returned Array is an [`UInt64Array`] - /// - /// See docs on [`Self::data_page_mins`] for details. - pub fn data_page_null_counts( - &self, - column_page_index: &ParquetColumnIndex, - column_offset_index: &ParquetOffsetIndex, - row_group_indices: I, - ) -> Result - where - I: IntoIterator, - { - let Some(parquet_index) = self.parquet_index else { - let num_row_groups = row_group_indices.into_iter().count(); - return Ok(UInt64Array::from_iter( - std::iter::repeat(None).take(num_row_groups), - )); - }; - - let iter = row_group_indices.into_iter().map(|rg_index| { - let column_page_index_per_row_group_per_column = - &column_page_index[*rg_index][parquet_index]; - let num_data_pages = &column_offset_index[*rg_index][parquet_index].len(); - - (*num_data_pages, column_page_index_per_row_group_per_column) - }); - null_counts_page_statistics(iter) - } - - /// Returns an [`ArrayRef`] with row counts for each row group. - /// - /// This function iterates over the given row group indexes and computes - /// the row count for each page in the specified column. 
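// Editorial sketch tying together the three page-level extractors described
// above, assuming `metadata` was read with the page index enabled (the tests
// later in this diff use `ArrowReaderOptions::new().with_page_index(true)`
// for this). `page_stats` is a hypothetical helper; exact paths may differ
// by version.
use datafusion::arrow::array::Array;
use datafusion::datasource::physical_plan::parquet::StatisticsConverter;
use parquet::file::metadata::ParquetMetaData;

fn page_stats(
    metadata: &ParquetMetaData,
    converter: &StatisticsConverter<'_>,
) -> datafusion::error::Result<()> {
    // Both indexes are optional: files written without page statistics
    // simply do not carry them.
    let (Some(column_index), Some(offset_index)) =
        (metadata.column_index(), metadata.offset_index())
    else {
        return Ok(());
    };
    let row_groups: Vec<usize> = (0..metadata.num_row_groups()).collect();
    let mins = converter.data_page_mins(column_index, offset_index, &row_groups)?;
    let maxes = converter.data_page_maxes(column_index, offset_index, &row_groups)?;
    let nulls = converter.data_page_null_counts(column_index, offset_index, &row_groups)?;
    // one entry per data page across the selected row groups
    assert_eq!(mins.len(), maxes.len());
    assert_eq!(mins.len(), nulls.len());
    Ok(())
}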
- /// - /// # Parameters: - /// - /// * `column_offset_index`: The parquet column offset indices, read from - /// `ParquetMetaData` offset_index - /// - /// * `row_group_metadatas`: The metadata slice of the row groups, read - /// from `ParquetMetaData` row_groups - /// - /// * `row_group_indices`: The indices of the row groups, that are used to - /// extract the column offset index on a per row group per column basis. - /// - /// See docs on [`Self::data_page_mins`] for details. - pub fn data_page_row_counts( - &self, - column_offset_index: &ParquetOffsetIndex, - row_group_metadatas: &'a [RowGroupMetaData], - row_group_indices: I, - ) -> Result> - where - I: IntoIterator, - { - let Some(parquet_index) = self.parquet_index else { - // no matching column found in parquet_index; - // thus we cannot extract page_locations in order to determine - // the row count on a per DataPage basis. - return Ok(None); - }; - - let mut row_count_total = Vec::new(); - for rg_idx in row_group_indices { - let page_locations = &column_offset_index[*rg_idx][parquet_index]; - - let row_count_per_page = page_locations.windows(2).map(|loc| { - Some(loc[1].first_row_index as u64 - loc[0].first_row_index as u64) - }); - - // append the last page row count - let num_rows_in_row_group = &row_group_metadatas[*rg_idx].num_rows(); - let row_count_per_page = row_count_per_page - .chain(std::iter::once(Some( - *num_rows_in_row_group as u64 - - page_locations.last().unwrap().first_row_index as u64, - ))) - .collect::>(); - - row_count_total.extend(row_count_per_page); - } - - Ok(Some(UInt64Array::from_iter(row_count_total))) - } - - /// Returns a null array of data_type with one element per row group - fn make_null_array(&self, data_type: &DataType, metadatas: I) -> ArrayRef - where - I: IntoIterator, - { - // column was in the arrow schema but not in the parquet schema, so return a null array - let num_row_groups = metadatas.into_iter().count(); - new_null_array(data_type, num_row_groups) - } -} - -#[cfg(test)] -mod test { - use super::*; - use arrow::compute::kernels::cast_utils::Parser; - use arrow::datatypes::{i256, Date32Type, Date64Type}; - use arrow_array::{ - new_empty_array, new_null_array, Array, BinaryArray, BooleanArray, Date32Array, - Date64Array, Decimal128Array, Decimal256Array, Float32Array, Float64Array, - Int16Array, Int32Array, Int64Array, Int8Array, LargeBinaryArray, RecordBatch, - StringArray, StructArray, TimestampNanosecondArray, - }; - use arrow_schema::{Field, SchemaRef}; - use bytes::Bytes; - use datafusion_common::test_util::parquet_test_data; - use parquet::arrow::arrow_reader::ArrowReaderBuilder; - use parquet::arrow::arrow_writer::ArrowWriter; - use parquet::file::metadata::{ParquetMetaData, RowGroupMetaData}; - use parquet::file::properties::{EnabledStatistics, WriterProperties}; - use std::path::PathBuf; - use std::sync::Arc; - - // TODO error cases (with parquet statistics that are mismatched in expected type) - - #[test] - fn roundtrip_empty() { - let empty_bool_array = new_empty_array(&DataType::Boolean); - Test { - input: empty_bool_array.clone(), - expected_min: empty_bool_array.clone(), - expected_max: empty_bool_array.clone(), - } - .run() - } - - #[test] - fn roundtrip_bool() { - Test { - input: bool_array([ - // row group 1 - Some(true), - None, - Some(true), - // row group 2 - Some(true), - Some(false), - None, - // row group 3 - None, - None, - None, - ]), - expected_min: bool_array([Some(true), Some(false), None]), - expected_max: bool_array([Some(true), Some(true), None]), - } - 
.run() - } - - #[test] - fn roundtrip_int32() { - Test { - input: i32_array([ - // row group 1 - Some(1), - None, - Some(3), - // row group 2 - Some(0), - Some(5), - None, - // row group 3 - None, - None, - None, - ]), - expected_min: i32_array([Some(1), Some(0), None]), - expected_max: i32_array([Some(3), Some(5), None]), - } - .run() - } - - #[test] - fn roundtrip_int64() { - Test { - input: i64_array([ - // row group 1 - Some(1), - None, - Some(3), - // row group 2 - Some(0), - Some(5), - None, - // row group 3 - None, - None, - None, - ]), - expected_min: i64_array([Some(1), Some(0), None]), - expected_max: i64_array(vec![Some(3), Some(5), None]), - } - .run() - } - - #[test] - fn roundtrip_f32() { - Test { - input: f32_array([ - // row group 1 - Some(1.0), - None, - Some(3.0), - // row group 2 - Some(-1.0), - Some(5.0), - None, - // row group 3 - None, - None, - None, - ]), - expected_min: f32_array([Some(1.0), Some(-1.0), None]), - expected_max: f32_array([Some(3.0), Some(5.0), None]), - } - .run() - } - - #[test] - fn roundtrip_f64() { - Test { - input: f64_array([ - // row group 1 - Some(1.0), - None, - Some(3.0), - // row group 2 - Some(-1.0), - Some(5.0), - None, - // row group 3 - None, - None, - None, - ]), - expected_min: f64_array([Some(1.0), Some(-1.0), None]), - expected_max: f64_array([Some(3.0), Some(5.0), None]), - } - .run() - } - - #[test] - fn roundtrip_timestamp() { - Test { - input: timestamp_seconds_array( - [ - // row group 1 - Some(1), - None, - Some(3), - // row group 2 - Some(9), - Some(5), - None, - // row group 3 - None, - None, - None, - ], - None, - ), - expected_min: timestamp_seconds_array([Some(1), Some(5), None], None), - expected_max: timestamp_seconds_array([Some(3), Some(9), None], None), - } - .run(); - - Test { - input: timestamp_milliseconds_array( - [ - // row group 1 - Some(1), - None, - Some(3), - // row group 2 - Some(9), - Some(5), - None, - // row group 3 - None, - None, - None, - ], - None, - ), - expected_min: timestamp_milliseconds_array([Some(1), Some(5), None], None), - expected_max: timestamp_milliseconds_array([Some(3), Some(9), None], None), - } - .run(); - - Test { - input: timestamp_microseconds_array( - [ - // row group 1 - Some(1), - None, - Some(3), - // row group 2 - Some(9), - Some(5), - None, - // row group 3 - None, - None, - None, - ], - None, - ), - expected_min: timestamp_microseconds_array([Some(1), Some(5), None], None), - expected_max: timestamp_microseconds_array([Some(3), Some(9), None], None), - } - .run(); - - Test { - input: timestamp_nanoseconds_array( - [ - // row group 1 - Some(1), - None, - Some(3), - // row group 2 - Some(9), - Some(5), - None, - // row group 3 - None, - None, - None, - ], - None, - ), - expected_min: timestamp_nanoseconds_array([Some(1), Some(5), None], None), - expected_max: timestamp_nanoseconds_array([Some(3), Some(9), None], None), - } - .run() - } - - #[test] - fn roundtrip_timestamp_timezoned() { - Test { - input: timestamp_seconds_array( - [ - // row group 1 - Some(1), - None, - Some(3), - // row group 2 - Some(9), - Some(5), - None, - // row group 3 - None, - None, - None, - ], - Some("UTC"), - ), - expected_min: timestamp_seconds_array([Some(1), Some(5), None], Some("UTC")), - expected_max: timestamp_seconds_array([Some(3), Some(9), None], Some("UTC")), - } - .run(); - - Test { - input: timestamp_milliseconds_array( - [ - // row group 1 - Some(1), - None, - Some(3), - // row group 2 - Some(9), - Some(5), - None, - // row group 3 - None, - None, - None, - ], - Some("UTC"), - ), - 
expected_min: timestamp_milliseconds_array( - [Some(1), Some(5), None], - Some("UTC"), - ), - expected_max: timestamp_milliseconds_array( - [Some(3), Some(9), None], - Some("UTC"), - ), - } - .run(); - - Test { - input: timestamp_microseconds_array( - [ - // row group 1 - Some(1), - None, - Some(3), - // row group 2 - Some(9), - Some(5), - None, - // row group 3 - None, - None, - None, - ], - Some("UTC"), - ), - expected_min: timestamp_microseconds_array( - [Some(1), Some(5), None], - Some("UTC"), - ), - expected_max: timestamp_microseconds_array( - [Some(3), Some(9), None], - Some("UTC"), - ), - } - .run(); - - Test { - input: timestamp_nanoseconds_array( - [ - // row group 1 - Some(1), - None, - Some(3), - // row group 2 - Some(9), - Some(5), - None, - // row group 3 - None, - None, - None, - ], - Some("UTC"), - ), - expected_min: timestamp_nanoseconds_array( - [Some(1), Some(5), None], - Some("UTC"), - ), - expected_max: timestamp_nanoseconds_array( - [Some(3), Some(9), None], - Some("UTC"), - ), - } - .run() - } - - #[test] - fn roundtrip_decimal() { - Test { - input: Arc::new( - Decimal128Array::from(vec![ - // row group 1 - Some(100), - None, - Some(22000), - // row group 2 - Some(500000), - Some(330000), - None, - // row group 3 - None, - None, - None, - ]) - .with_precision_and_scale(9, 2) - .unwrap(), - ), - expected_min: Arc::new( - Decimal128Array::from(vec![Some(100), Some(330000), None]) - .with_precision_and_scale(9, 2) - .unwrap(), - ), - expected_max: Arc::new( - Decimal128Array::from(vec![Some(22000), Some(500000), None]) - .with_precision_and_scale(9, 2) - .unwrap(), - ), - } - .run(); - - Test { - input: Arc::new( - Decimal256Array::from(vec![ - // row group 1 - Some(i256::from(100)), - None, - Some(i256::from(22000)), - // row group 2 - Some(i256::MAX), - Some(i256::MIN), - None, - // row group 3 - None, - None, - None, - ]) - .with_precision_and_scale(76, 76) - .unwrap(), - ), - expected_min: Arc::new( - Decimal256Array::from(vec![Some(i256::from(100)), Some(i256::MIN), None]) - .with_precision_and_scale(76, 76) - .unwrap(), - ), - expected_max: Arc::new( - Decimal256Array::from(vec![ - Some(i256::from(22000)), - Some(i256::MAX), - None, - ]) - .with_precision_and_scale(76, 76) - .unwrap(), - ), - } - .run() - } - - #[test] - fn roundtrip_utf8() { - Test { - input: utf8_array([ - // row group 1 - Some("A"), - None, - Some("Q"), - // row group 2 - Some("ZZ"), - Some("AA"), - None, - // row group 3 - None, - None, - None, - ]), - expected_min: utf8_array([Some("A"), Some("AA"), None]), - expected_max: utf8_array([Some("Q"), Some("ZZ"), None]), - } - .run() - } - - #[test] - fn roundtrip_struct() { - let mut test = Test { - input: struct_array(vec![ - // row group 1 - (Some(true), Some(1)), - (None, None), - (Some(true), Some(3)), - // row group 2 - (Some(true), Some(0)), - (Some(false), Some(5)), - (None, None), - // row group 3 - (None, None), - (None, None), - (None, None), - ]), - expected_min: struct_array(vec![ - (Some(true), Some(1)), - (Some(true), Some(0)), - (None, None), - ]), - - expected_max: struct_array(vec![ - (Some(true), Some(3)), - (Some(true), Some(0)), - (None, None), - ]), - }; - // Due to https://github.com/apache/datafusion/issues/8334, - // statistics for struct arrays are not supported - test.expected_min = - new_null_array(test.input.data_type(), test.expected_min.len()); - test.expected_max = - new_null_array(test.input.data_type(), test.expected_min.len()); - test.run() - } - - #[test] - fn roundtrip_binary() { - Test { - input: 
Arc::new(BinaryArray::from_opt_vec(vec![ - // row group 1 - Some(b"A"), - None, - Some(b"Q"), - // row group 2 - Some(b"ZZ"), - Some(b"AA"), - None, - // row group 3 - None, - None, - None, - ])), - expected_min: Arc::new(BinaryArray::from_opt_vec(vec![ - Some(b"A"), - Some(b"AA"), - None, - ])), - expected_max: Arc::new(BinaryArray::from_opt_vec(vec![ - Some(b"Q"), - Some(b"ZZ"), - None, - ])), - } - .run() - } - - #[test] - fn roundtrip_date32() { - Test { - input: date32_array(vec![ - // row group 1 - Some("2021-01-01"), - None, - Some("2021-01-03"), - // row group 2 - Some("2021-01-01"), - Some("2021-01-05"), - None, - // row group 3 - None, - None, - None, - ]), - expected_min: date32_array(vec![ - Some("2021-01-01"), - Some("2021-01-01"), - None, - ]), - expected_max: date32_array(vec![ - Some("2021-01-03"), - Some("2021-01-05"), - None, - ]), - } - .run() - } - - #[test] - fn roundtrip_date64() { - Test { - input: date64_array(vec![ - // row group 1 - Some("2021-01-01"), - None, - Some("2021-01-03"), - // row group 2 - Some("2021-01-01"), - Some("2021-01-05"), - None, - // row group 3 - None, - None, - None, - ]), - expected_min: date64_array(vec![ - Some("2021-01-01"), - Some("2021-01-01"), - None, - ]), - expected_max: date64_array(vec![ - Some("2021-01-03"), - Some("2021-01-05"), - None, - ]), - } - .run() - } - - #[test] - fn roundtrip_large_binary_array() { - let input: Vec> = vec![ - // row group 1 - Some(b"A"), - None, - Some(b"Q"), - // row group 2 - Some(b"ZZ"), - Some(b"AA"), - None, - // row group 3 - None, - None, - None, - ]; - - let expected_min: Vec> = vec![Some(b"A"), Some(b"AA"), None]; - let expected_max: Vec> = vec![Some(b"Q"), Some(b"ZZ"), None]; - - Test { - input: large_binary_array(input), - expected_min: large_binary_array(expected_min), - expected_max: large_binary_array(expected_max), - } - .run(); - } - - #[test] - fn struct_and_non_struct() { - // Ensures that statistics for an array that appears *after* a struct - // array are not wrong - let struct_col = struct_array(vec![ - // row group 1 - (Some(true), Some(1)), - (None, None), - (Some(true), Some(3)), - ]); - let int_col = i32_array([Some(100), Some(200), Some(300)]); - let expected_min = i32_array([Some(100)]); - let expected_max = i32_array(vec![Some(300)]); - - // use a name that shadows a name in the struct column - match struct_col.data_type() { - DataType::Struct(fields) => { - assert_eq!(fields.get(1).unwrap().name(), "int_col") - } - _ => panic!("unexpected data type for struct column"), - }; - - let input_batch = RecordBatch::try_from_iter([ - ("struct_col", struct_col), - ("int_col", int_col), - ]) - .unwrap(); - - let schema = input_batch.schema(); - - let metadata = parquet_metadata(schema.clone(), input_batch); - let parquet_schema = metadata.file_metadata().schema_descr(); - - // read the int_col statistics - let (idx, _) = parquet_column(parquet_schema, &schema, "int_col").unwrap(); - assert_eq!(idx, 2); - - let row_groups = metadata.row_groups(); - let converter = - StatisticsConverter::try_new("int_col", &schema, parquet_schema).unwrap(); - - let min = converter.row_group_mins(row_groups.iter()).unwrap(); - assert_eq!( - &min, - &expected_min, - "Min. Statistics\n\n{}\n\n", - DisplayStats(row_groups) - ); - - let max = converter.row_group_maxes(row_groups.iter()).unwrap(); - assert_eq!( - &max, - &expected_max, - "Max. 
Statistics\n\n{}\n\n", - DisplayStats(row_groups) - ); - } - - #[test] - fn nan_in_stats() { - // /parquet-testing/data/nan_in_stats.parquet - // row_groups: 1 - // "x": Double({min: Some(1.0), max: Some(NaN), distinct_count: None, null_count: 0, min_max_deprecated: false, min_max_backwards_compatible: false}) - - TestFile::new("nan_in_stats.parquet") - .with_column(ExpectedColumn { - name: "x", - expected_min: Arc::new(Float64Array::from(vec![Some(1.0)])), - expected_max: Arc::new(Float64Array::from(vec![Some(f64::NAN)])), - }) - .run(); - } - - #[test] - fn alltypes_plain() { - // /parquet-testing/data/datapage_v1-snappy-compressed-checksum.parquet - // row_groups: 1 - // (has no statistics) - TestFile::new("alltypes_plain.parquet") - // No column statistics should be read as NULL, but with the right type - .with_column(ExpectedColumn { - name: "id", - expected_min: i32_array([None]), - expected_max: i32_array([None]), - }) - .with_column(ExpectedColumn { - name: "bool_col", - expected_min: bool_array([None]), - expected_max: bool_array([None]), - }) - .run(); - } - - #[test] - fn alltypes_tiny_pages() { - // /parquet-testing/data/alltypes_tiny_pages.parquet - // row_groups: 1 - // "id": Int32({min: Some(0), max: Some(7299), distinct_count: None, null_count: 0, min_max_deprecated: false, min_max_backwards_compatible: false}) - // "bool_col": Boolean({min: Some(false), max: Some(true), distinct_count: None, null_count: 0, min_max_deprecated: false, min_max_backwards_compatible: false}) - // "tinyint_col": Int32({min: Some(0), max: Some(9), distinct_count: None, null_count: 0, min_max_deprecated: false, min_max_backwards_compatible: false}) - // "smallint_col": Int32({min: Some(0), max: Some(9), distinct_count: None, null_count: 0, min_max_deprecated: false, min_max_backwards_compatible: false}) - // "int_col": Int32({min: Some(0), max: Some(9), distinct_count: None, null_count: 0, min_max_deprecated: false, min_max_backwards_compatible: false}) - // "bigint_col": Int64({min: Some(0), max: Some(90), distinct_count: None, null_count: 0, min_max_deprecated: false, min_max_backwards_compatible: false}) - // "float_col": Float({min: Some(0.0), max: Some(9.9), distinct_count: None, null_count: 0, min_max_deprecated: false, min_max_backwards_compatible: false}) - // "double_col": Double({min: Some(0.0), max: Some(90.89999999999999), distinct_count: None, null_count: 0, min_max_deprecated: false, min_max_backwards_compatible: false}) - // "date_string_col": ByteArray({min: Some(ByteArray { data: "01/01/09" }), max: Some(ByteArray { data: "12/31/10" }), distinct_count: None, null_count: 0, min_max_deprecated: false, min_max_backwards_compatible: false}) - // "string_col": ByteArray({min: Some(ByteArray { data: "0" }), max: Some(ByteArray { data: "9" }), distinct_count: None, null_count: 0, min_max_deprecated: false, min_max_backwards_compatible: false}) - // "timestamp_col": Int96({min: None, max: None, distinct_count: None, null_count: 0, min_max_deprecated: true, min_max_backwards_compatible: true}) - // "year": Int32({min: Some(2009), max: Some(2010), distinct_count: None, null_count: 0, min_max_deprecated: false, min_max_backwards_compatible: false}) - // "month": Int32({min: Some(1), max: Some(12), distinct_count: None, null_count: 0, min_max_deprecated: false, min_max_backwards_compatible: false}) - TestFile::new("alltypes_tiny_pages.parquet") - .with_column(ExpectedColumn { - name: "id", - expected_min: i32_array([Some(0)]), - expected_max: i32_array([Some(7299)]), - }) - 
.with_column(ExpectedColumn { - name: "bool_col", - expected_min: bool_array([Some(false)]), - expected_max: bool_array([Some(true)]), - }) - .with_column(ExpectedColumn { - name: "tinyint_col", - expected_min: i8_array([Some(0)]), - expected_max: i8_array([Some(9)]), - }) - .with_column(ExpectedColumn { - name: "smallint_col", - expected_min: i16_array([Some(0)]), - expected_max: i16_array([Some(9)]), - }) - .with_column(ExpectedColumn { - name: "int_col", - expected_min: i32_array([Some(0)]), - expected_max: i32_array([Some(9)]), - }) - .with_column(ExpectedColumn { - name: "bigint_col", - expected_min: i64_array([Some(0)]), - expected_max: i64_array([Some(90)]), - }) - .with_column(ExpectedColumn { - name: "float_col", - expected_min: f32_array([Some(0.0)]), - expected_max: f32_array([Some(9.9)]), - }) - .with_column(ExpectedColumn { - name: "double_col", - expected_min: f64_array([Some(0.0)]), - expected_max: f64_array([Some(90.89999999999999)]), - }) - .with_column(ExpectedColumn { - name: "date_string_col", - expected_min: utf8_array([Some("01/01/09")]), - expected_max: utf8_array([Some("12/31/10")]), - }) - .with_column(ExpectedColumn { - name: "string_col", - expected_min: utf8_array([Some("0")]), - expected_max: utf8_array([Some("9")]), - }) - // File has no min/max for timestamp_col - .with_column(ExpectedColumn { - name: "timestamp_col", - expected_min: timestamp_nanoseconds_array([None], None), - expected_max: timestamp_nanoseconds_array([None], None), - }) - .with_column(ExpectedColumn { - name: "year", - expected_min: i32_array([Some(2009)]), - expected_max: i32_array([Some(2010)]), - }) - .with_column(ExpectedColumn { - name: "month", - expected_min: i32_array([Some(1)]), - expected_max: i32_array([Some(12)]), - }) - .run(); - } - - #[test] - fn fixed_length_decimal_legacy() { - // /parquet-testing/data/fixed_length_decimal_legacy.parquet - // row_groups: 1 - // "value": FixedLenByteArray({min: Some(FixedLenByteArray(ByteArray { data: Some(ByteBufferPtr { data: b"\0\0\0\0\0\xc8" }) })), max: Some(FixedLenByteArray(ByteArray { data: "\0\0\0\0\t`" })), distinct_count: None, null_count: 0, min_max_deprecated: true, min_max_backwards_compatible: true}) - - TestFile::new("fixed_length_decimal_legacy.parquet") - .with_column(ExpectedColumn { - name: "value", - expected_min: Arc::new( - Decimal128Array::from(vec![Some(200)]) - .with_precision_and_scale(13, 2) - .unwrap(), - ), - expected_max: Arc::new( - Decimal128Array::from(vec![Some(2400)]) - .with_precision_and_scale(13, 2) - .unwrap(), - ), - }) - .run(); - } - - const ROWS_PER_ROW_GROUP: usize = 3; - - /// Writes the input batch into a parquet file, with every every three rows as - /// their own row group, and compares the min/maxes to the expected values - struct Test { - input: ArrayRef, - expected_min: ArrayRef, - expected_max: ArrayRef, - } - - impl Test { - fn run(self) { - let Self { - input, - expected_min, - expected_max, - } = self; - - let input_batch = RecordBatch::try_from_iter([("c1", input)]).unwrap(); - - let schema = input_batch.schema(); - - let metadata = parquet_metadata(schema.clone(), input_batch); - let parquet_schema = metadata.file_metadata().schema_descr(); - - let row_groups = metadata.row_groups(); - - for field in schema.fields() { - if field.data_type().is_nested() { - let lookup = parquet_column(parquet_schema, &schema, field.name()); - assert_eq!(lookup, None); - continue; - } - - let converter = - StatisticsConverter::try_new(field.name(), &schema, parquet_schema) - .unwrap(); - - 
assert_eq!(converter.arrow_field, field.as_ref()); - - let mins = converter.row_group_mins(row_groups.iter()).unwrap(); - assert_eq!( - &mins, - &expected_min, - "Min. Statistics\n\n{}\n\n", - DisplayStats(row_groups) - ); - - let maxes = converter.row_group_maxes(row_groups.iter()).unwrap(); - assert_eq!( - &maxes, - &expected_max, - "Max. Statistics\n\n{}\n\n", - DisplayStats(row_groups) - ); - } - } - } - - /// Write the specified batches out as parquet and return the metadata - fn parquet_metadata(schema: SchemaRef, batch: RecordBatch) -> Arc { - let props = WriterProperties::builder() - .set_statistics_enabled(EnabledStatistics::Chunk) - .set_max_row_group_size(ROWS_PER_ROW_GROUP) - .build(); - - let mut buffer = Vec::new(); - let mut writer = ArrowWriter::try_new(&mut buffer, schema, Some(props)).unwrap(); - writer.write(&batch).unwrap(); - writer.close().unwrap(); - - let reader = ArrowReaderBuilder::try_new(Bytes::from(buffer)).unwrap(); - reader.metadata().clone() - } - - /// Formats the statistics nicely for display - struct DisplayStats<'a>(&'a [RowGroupMetaData]); - impl<'a> std::fmt::Display for DisplayStats<'a> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let row_groups = self.0; - writeln!(f, " row_groups: {}", row_groups.len())?; - for rg in row_groups { - for col in rg.columns() { - if let Some(statistics) = col.statistics() { - writeln!(f, " {}: {:?}", col.column_path(), statistics)?; - } - } - } - Ok(()) - } - } - - struct ExpectedColumn { - name: &'static str, - expected_min: ArrayRef, - expected_max: ArrayRef, - } - - /// Reads statistics out of the specified, and compares them to the expected values - struct TestFile { - file_name: &'static str, - expected_columns: Vec, - } - - impl TestFile { - fn new(file_name: &'static str) -> Self { - Self { - file_name, - expected_columns: Vec::new(), - } - } - - fn with_column(mut self, column: ExpectedColumn) -> Self { - self.expected_columns.push(column); - self - } - - /// Reads the specified parquet file and validates that the expected min/max - /// values for the specified columns are as expected. 
- fn run(self) { - let path = PathBuf::from(parquet_test_data()).join(self.file_name); - let file = std::fs::File::open(path).unwrap(); - let reader = ArrowReaderBuilder::try_new(file).unwrap(); - let arrow_schema = reader.schema(); - let metadata = reader.metadata(); - let row_groups = metadata.row_groups(); - let parquet_schema = metadata.file_metadata().schema_descr(); - - for expected_column in self.expected_columns { - let ExpectedColumn { - name, - expected_min, - expected_max, - } = expected_column; - - let converter = - StatisticsConverter::try_new(name, arrow_schema, parquet_schema) - .unwrap(); - let actual_min = converter.row_group_mins(row_groups.iter()).unwrap(); - assert_eq!(&expected_min, &actual_min, "column {name}"); - - let actual_max = converter.row_group_maxes(row_groups.iter()).unwrap(); - assert_eq!(&expected_max, &actual_max, "column {name}"); - } - } - } - - fn bool_array(input: impl IntoIterator>) -> ArrayRef { - let array: BooleanArray = input.into_iter().collect(); - Arc::new(array) - } - - fn i8_array(input: impl IntoIterator>) -> ArrayRef { - let array: Int8Array = input.into_iter().collect(); - Arc::new(array) - } - - fn i16_array(input: impl IntoIterator>) -> ArrayRef { - let array: Int16Array = input.into_iter().collect(); - Arc::new(array) - } - - fn i32_array(input: impl IntoIterator>) -> ArrayRef { - let array: Int32Array = input.into_iter().collect(); - Arc::new(array) - } - - fn i64_array(input: impl IntoIterator>) -> ArrayRef { - let array: Int64Array = input.into_iter().collect(); - Arc::new(array) - } - - fn f32_array(input: impl IntoIterator>) -> ArrayRef { - let array: Float32Array = input.into_iter().collect(); - Arc::new(array) - } - - fn f64_array(input: impl IntoIterator>) -> ArrayRef { - let array: Float64Array = input.into_iter().collect(); - Arc::new(array) - } - - fn timestamp_seconds_array( - input: impl IntoIterator>, - timezone: Option<&str>, - ) -> ArrayRef { - let array: TimestampSecondArray = input.into_iter().collect(); - match timezone { - Some(tz) => Arc::new(array.with_timezone(tz)), - None => Arc::new(array), - } - } - - fn timestamp_milliseconds_array( - input: impl IntoIterator>, - timezone: Option<&str>, - ) -> ArrayRef { - let array: TimestampMillisecondArray = input.into_iter().collect(); - match timezone { - Some(tz) => Arc::new(array.with_timezone(tz)), - None => Arc::new(array), - } - } - - fn timestamp_microseconds_array( - input: impl IntoIterator>, - timezone: Option<&str>, - ) -> ArrayRef { - let array: TimestampMicrosecondArray = input.into_iter().collect(); - match timezone { - Some(tz) => Arc::new(array.with_timezone(tz)), - None => Arc::new(array), - } - } - - fn timestamp_nanoseconds_array( - input: impl IntoIterator>, - timezone: Option<&str>, - ) -> ArrayRef { - let array: TimestampNanosecondArray = input.into_iter().collect(); - match timezone { - Some(tz) => Arc::new(array.with_timezone(tz)), - None => Arc::new(array), - } - } - - fn utf8_array<'a>(input: impl IntoIterator>) -> ArrayRef { - let array: StringArray = input - .into_iter() - .map(|s| s.map(|s| s.to_string())) - .collect(); - Arc::new(array) - } - - // returns a struct array with columns "bool_col" and "int_col" with the specified values - fn struct_array(input: Vec<(Option, Option)>) -> ArrayRef { - let boolean: BooleanArray = input.iter().map(|(b, _i)| b).collect(); - let int: Int32Array = input.iter().map(|(_b, i)| i).collect(); - - let nullable = true; - let struct_array = StructArray::from(vec![ - ( - Arc::new(Field::new("bool_col", 
DataType::Boolean, nullable)), - Arc::new(boolean) as ArrayRef, - ), - ( - Arc::new(Field::new("int_col", DataType::Int32, nullable)), - Arc::new(int) as ArrayRef, - ), - ]); - Arc::new(struct_array) - } - - fn date32_array<'a>(input: impl IntoIterator>) -> ArrayRef { - let array = Date32Array::from( - input - .into_iter() - .map(|s| Date32Type::parse(s.unwrap_or_default())) - .collect::>(), - ); - Arc::new(array) - } - - fn date64_array<'a>(input: impl IntoIterator>) -> ArrayRef { - let array = Date64Array::from( - input - .into_iter() - .map(|s| Date64Type::parse(s.unwrap_or_default())) - .collect::>(), - ); - Arc::new(array) - } - - fn large_binary_array<'a>( - input: impl IntoIterator>, - ) -> ArrayRef { - let array = - LargeBinaryArray::from(input.into_iter().collect::>>()); - - Arc::new(array) - } -} diff --git a/datafusion/core/src/datasource/statistics.rs b/datafusion/core/src/datasource/statistics.rs index a243a1c3558f..8c789e461b08 100644 --- a/datafusion/core/src/datasource/statistics.rs +++ b/datafusion/core/src/datasource/statistics.rs @@ -18,7 +18,7 @@ use super::listing::PartitionedFile; use crate::arrow::datatypes::{Schema, SchemaRef}; use crate::error::Result; -use crate::physical_plan::expressions::{MaxAccumulator, MinAccumulator}; +use crate::functions_aggregate::min_max::{MaxAccumulator, MinAccumulator}; use crate::physical_plan::{Accumulator, ColumnStatistics, Statistics}; use arrow_schema::DataType; diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 9b889c37ab52..24704bc794c2 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -144,6 +144,7 @@ where /// /// ``` /// use datafusion::prelude::*; +/// # use datafusion::functions_aggregate::expr_fn::min; /// # use datafusion::{error::Result, assert_batches_eq}; /// # #[tokio::main] /// # async fn main() -> Result<()> { diff --git a/datafusion/core/src/lib.rs b/datafusion/core/src/lib.rs index cf5a184e3416..3bb0636652c0 100644 --- a/datafusion/core/src/lib.rs +++ b/datafusion/core/src/lib.rs @@ -52,6 +52,7 @@ //! ```rust //! # use datafusion::prelude::*; //! # use datafusion::error::Result; +//! # use datafusion::functions_aggregate::expr_fn::min; //! # use datafusion::arrow::record_batch::RecordBatch; //! //! 
# #[tokio::main]
diff --git a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs
index 590f9dc8fde1..cde8bb241ee4 100644
--- a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs
+++ b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs
@@ -272,39 +272,28 @@ fn is_non_distinct_count(agg_expr: &dyn AggregateExpr) -> bool {
             return true;
         }
     }
-
     false
 }

 // TODO: Move this check into AggregateUDFImpl
 // https://github.com/apache/datafusion/issues/11153
 fn is_min(agg_expr: &dyn AggregateExpr) -> bool {
-    if agg_expr.as_any().is::<expressions::Min>() {
-        return true;
-    }
-
     if let Some(agg_expr) = agg_expr.as_any().downcast_ref::<AggregateFunctionExpr>() {
-        if agg_expr.fun().name() == "min" {
+        if agg_expr.fun().name().to_lowercase() == "min" {
             return true;
         }
     }
-
     false
 }

 // TODO: Move this check into AggregateUDFImpl
 // https://github.com/apache/datafusion/issues/11153
 fn is_max(agg_expr: &dyn AggregateExpr) -> bool {
-    if agg_expr.as_any().is::<expressions::Max>() {
-        return true;
-    }
-
     if let Some(agg_expr) = agg_expr.as_any().downcast_ref::<AggregateFunctionExpr>() {
-        if agg_expr.fun().name() == "max" {
+        if agg_expr.fun().name().to_lowercase() == "max" {
             return true;
         }
     }
-
     false
 }
diff --git a/datafusion/core/src/physical_optimizer/enforce_sorting.rs b/datafusion/core/src/physical_optimizer/enforce_sorting.rs
index cf9d33252ad9..faf8d01a97fd 100644
--- a/datafusion/core/src/physical_optimizer/enforce_sorting.rs
+++ b/datafusion/core/src/physical_optimizer/enforce_sorting.rs
@@ -621,6 +621,7 @@ mod tests {
         limit_exec, local_limit_exec, memory_exec, parquet_exec, parquet_exec_sorted,
         repartition_exec, sort_exec, sort_expr, sort_expr_options, sort_merge_join_exec,
         sort_preserving_merge_exec, spr_repartition_exec, union_exec,
+        RequirementsTestExec,
     };
     use crate::physical_plan::{displayable, get_plan_string, Partitioning};
     use crate::prelude::{SessionConfig, SessionContext};
@@ -2346,4 +2347,67 @@ mod tests {
         assert_optimized!(expected_input, expected_no_change, physical_plan, true);
         Ok(())
     }
+
+    #[tokio::test]
+    async fn test_push_with_required_input_ordering_prohibited() -> Result<()> {
+        // SortExec: expr=[b]            <-- can't push this down
+        //  RequiredInputOrder expr=[a]  <-- this requires input sorted by a, and preserves the input order
+        //    SortExec: expr=[a]
+        //      MemoryExec
+        let schema = create_test_schema3()?;
+        let sort_exprs_a = vec![sort_expr("a", &schema)];
+        let sort_exprs_b = vec![sort_expr("b", &schema)];
+        let plan = memory_exec(&schema);
+        let plan = sort_exec(sort_exprs_a.clone(), plan);
+        let plan = RequirementsTestExec::new(plan)
+            .with_required_input_ordering(sort_exprs_a)
+            .with_maintains_input_order(true)
+            .into_arc();
+        let plan = sort_exec(sort_exprs_b, plan);
+
+        let expected_input = [
+            "SortExec: expr=[b@1 ASC], preserve_partitioning=[false]",
+            "  RequiredInputOrderingExec",
+            "    SortExec: expr=[a@0 ASC], preserve_partitioning=[false]",
+            "      MemoryExec: partitions=1, partition_sizes=[0]",
+        ];
+        // should not be able to push the sort down
+        let expected_no_change = expected_input;
+        assert_optimized!(expected_input, expected_no_change, plan, true);
+        Ok(())
+    }
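// Editorial sketch of the name-based test that replaces the removed downcasts
// in the aggregate_statistics.rs hunk above. `is_min_udaf`/`is_max_udaf` are
// hypothetical helpers: in the hunk the name is reached through
// `AggregateFunctionExpr::fun()`, but the predicate itself is just a
// case-insensitive name match on the UDAF.
use datafusion_expr::AggregateUDF;

fn is_min_udaf(udaf: &AggregateUDF) -> bool {
    udaf.name().to_lowercase() == "min"
}

fn is_max_udaf(udaf: &AggregateUDF) -> bool {
    udaf.name().to_lowercase() == "max"
}

fn main() {
    // the same constructors this diff imports elsewhere from
    // `datafusion_functions_aggregate::min_max`
    use datafusion_functions_aggregate::min_max::{max_udaf, min_udaf};
    assert!(is_min_udaf(&min_udaf()));
    assert!(is_max_udaf(&max_udaf()));
}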
+
+    // test when the required input ordering is satisfied, so the sort can be
+    // pushed through
+    #[tokio::test]
+    async fn test_push_with_required_input_ordering_allowed() -> Result<()> {
+        // SortExec: expr=[a,b]          <-- can push this down (as it is compatible with the required input ordering)
+        //  RequiredInputOrder expr=[a]  <-- this requires input sorted by a, and preserves the input order
+        //    SortExec: expr=[a]
+        //      MemoryExec
+        let schema = create_test_schema3()?;
+        let sort_exprs_a = vec![sort_expr("a", &schema)];
+        let sort_exprs_ab = vec![sort_expr("a", &schema), sort_expr("b", &schema)];
+        let plan = memory_exec(&schema);
+        let plan = sort_exec(sort_exprs_a.clone(), plan);
+        let plan = RequirementsTestExec::new(plan)
+            .with_required_input_ordering(sort_exprs_a)
+            .with_maintains_input_order(true)
+            .into_arc();
+        let plan = sort_exec(sort_exprs_ab, plan);
+
+        let expected_input = [
+            "SortExec: expr=[a@0 ASC,b@1 ASC], preserve_partitioning=[false]",
+            "  RequiredInputOrderingExec",
+            "    SortExec: expr=[a@0 ASC], preserve_partitioning=[false]",
+            "      MemoryExec: partitions=1, partition_sizes=[0]",
+        ];
+        // should be able to push the sort down
+        let expected = [
+            "RequiredInputOrderingExec",
+            "  SortExec: expr=[a@0 ASC,b@1 ASC], preserve_partitioning=[false]",
+            "    MemoryExec: partitions=1, partition_sizes=[0]",
+        ];
+        assert_optimized!(expected_input, expected, plan, true);
+        Ok(())
+    }
+}
diff --git a/datafusion/core/src/physical_optimizer/sort_pushdown.rs b/datafusion/core/src/physical_optimizer/sort_pushdown.rs
index 36ac4b22d594..3577e109b069 100644
--- a/datafusion/core/src/physical_optimizer/sort_pushdown.rs
+++ b/datafusion/core/src/physical_optimizer/sort_pushdown.rs
@@ -176,6 +176,7 @@ fn pushdown_requirement_to_children(
         || plan.as_any().is::()
         || is_limit(plan)
         || plan.as_any().is::()
+        || pushdown_would_violate_requirements(parent_required, plan.as_ref())
     {
         // If the current plan is a leaf node or can not maintain any of the input ordering, can not pushed down requirements.
         // For RepartitionExec, we always choose to not push down the sort requirements even the RepartitionExec(input_partition=1) could maintain input ordering.
@@ -211,6 +212,29 @@
     // TODO: Add support for Projection push down
 }

+/// Return true if pushing the sort requirements through a node would violate
+/// the input sorting requirements for the plan
+fn pushdown_would_violate_requirements(
+    parent_required: LexRequirementRef,
+    child: &dyn ExecutionPlan,
+) -> bool {
+    child
+        .required_input_ordering()
+        .iter()
+        .any(|child_required| {
+            let Some(child_required) = child_required.as_ref() else {
+                // no requirements, so pushing down would not violate anything
+                return false;
+            };
+            // pushing the parent requirements down violates the child's
+            // requirement when none of the corresponding entries are compatible
+            child_required
+                .iter()
+                .zip(parent_required.iter())
+                .all(|(c, p)| !c.compatible(p))
+        })
+}
+
 /// Determine children requirements:
 /// - If children requirements are more specific, do not push down parent
 /// requirements.
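// A self-contained toy model of `pushdown_would_violate_requirements` above
// (editorial illustration, not DataFusion API). A pushed-down parent sort is
// rejected when, for some child requirement, none of the paired entries are
// compatible -- which is exactly what the two new tests exercise.
#[derive(Clone, Copy, PartialEq, Debug)]
struct Req {
    column: &'static str,
    descending: bool,
}

impl Req {
    // stands in for `PhysicalSortRequirement::compatible`
    // (toy version: exact match on column and direction)
    fn compatible(&self, other: &Req) -> bool {
        self == other
    }
}

fn pushdown_would_violate(parent: &[Req], child_required: &[Option<Vec<Req>>]) -> bool {
    child_required.iter().any(|child| {
        let Some(child) = child else {
            return false; // no requirement, nothing to violate
        };
        child.iter().zip(parent.iter()).all(|(c, p)| !c.compatible(p))
    })
}

fn main() {
    let by_a = Req { column: "a", descending: false };
    let by_b = Req { column: "b", descending: false };
    // the child requires input sorted by `a`: pushing a sort on `b` through
    // it would violate the requirement, while pushing a sort on `a` would not
    assert!(pushdown_would_violate(&[by_b], &[Some(vec![by_a])]));
    assert!(!pushdown_would_violate(&[by_a], &[Some(vec![by_a])]));
}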
diff --git a/datafusion/core/src/physical_optimizer/test_utils.rs b/datafusion/core/src/physical_optimizer/test_utils.rs
index 5320938d2eb8..55a0fa814552 100644
--- a/datafusion/core/src/physical_optimizer/test_utils.rs
+++ b/datafusion/core/src/physical_optimizer/test_utils.rs
@@ -17,6 +17,8 @@
 //! Collection of testing utility functions that are leveraged by the query optimizer rules

+use std::any::Any;
+use std::fmt::Formatter;
 use std::sync::Arc;

 use crate::datasource::listing::PartitionedFile;
@@ -47,10 +49,14 @@ use datafusion_expr::{WindowFrame, WindowFunctionDefinition};
 use datafusion_functions_aggregate::count::count_udaf;
 use datafusion_physical_expr::expressions::col;
 use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr};
-use datafusion_physical_plan::displayable;
 use datafusion_physical_plan::tree_node::PlanContext;
+use datafusion_physical_plan::{
+    displayable, DisplayAs, DisplayFormatType, PlanProperties,
+};

 use async_trait::async_trait;
+use datafusion_execution::{SendableRecordBatchStream, TaskContext};
+use datafusion_physical_expr_common::sort_expr::PhysicalSortRequirement;

 async fn register_current_csv(
     ctx: &SessionContext,
@@ -354,6 +360,97 @@ pub fn sort_exec(
     Arc::new(SortExec::new(sort_exprs, input))
 }

+/// A test [`ExecutionPlan`] whose requirements can be configured.
+#[derive(Debug)]
+pub struct RequirementsTestExec {
+    required_input_ordering: Vec<PhysicalSortExpr>,
+    maintains_input_order: bool,
+    input: Arc<dyn ExecutionPlan>,
+}
+
+impl RequirementsTestExec {
+    pub fn new(input: Arc<dyn ExecutionPlan>) -> Self {
+        Self {
+            required_input_ordering: vec![],
+            maintains_input_order: true,
+            input,
+        }
+    }
+
+    /// sets the required input ordering
+    pub fn with_required_input_ordering(
+        mut self,
+        required_input_ordering: Vec<PhysicalSortExpr>,
+    ) -> Self {
+        self.required_input_ordering = required_input_ordering;
+        self
+    }
+
+    /// set the maintains_input_order flag
+    pub fn with_maintains_input_order(mut self, maintains_input_order: bool) -> Self {
+        self.maintains_input_order = maintains_input_order;
+        self
+    }
+
+    /// returns this ExecutionPlan as an `Arc<dyn ExecutionPlan>`
+    pub fn into_arc(self) -> Arc<dyn ExecutionPlan> {
+        Arc::new(self)
+    }
+}
+
+impl DisplayAs for RequirementsTestExec {
+    fn fmt_as(&self, _t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
+        write!(f, "RequiredInputOrderingExec")
+    }
+}
+
+impl ExecutionPlan for RequirementsTestExec {
+    fn name(&self) -> &str {
+        "RequiredInputOrderingExec"
+    }
+
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn properties(&self) -> &PlanProperties {
+        self.input.properties()
+    }
+
+    fn required_input_ordering(&self) -> Vec<Option<Vec<PhysicalSortRequirement>>> {
+        let requirement =
+            PhysicalSortRequirement::from_sort_exprs(&self.required_input_ordering);
+        vec![Some(requirement)]
+    }
+
+    fn maintains_input_order(&self) -> Vec<bool> {
+        vec![self.maintains_input_order]
+    }
+
+    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
+        vec![&self.input]
+    }
+
+    fn with_new_children(
+        self: Arc<Self>,
+        children: Vec<Arc<dyn ExecutionPlan>>,
+    ) -> Result<Arc<dyn ExecutionPlan>> {
+        assert_eq!(children.len(), 1);
+        Ok(RequirementsTestExec::new(children[0].clone())
+            .with_required_input_ordering(self.required_input_ordering.clone())
+            .with_maintains_input_order(self.maintains_input_order)
+            .into_arc())
+    }
+
+    fn execute(
+        &self,
+        _partition: usize,
+        _context: Arc<TaskContext>,
+    ) -> Result<SendableRecordBatchStream> {
+        unimplemented!("Test exec does not support execution")
+    }
+}
+
 /// A [`PlanContext`] object is susceptible to being left in an inconsistent state after
 /// untested mutable operations.
It is crucial that there be no discrepancies between a plan /// associated with the root node and the plan generated after traversing all nodes diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 329d343f13fc..03e20b886e2c 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -59,8 +59,8 @@ use crate::physical_plan::unnest::UnnestExec; use crate::physical_plan::values::ValuesExec; use crate::physical_plan::windows::{BoundedWindowAggExec, WindowAggExec}; use crate::physical_plan::{ - aggregates, displayable, udaf, windows, AggregateExpr, ExecutionPlan, - ExecutionPlanProperties, InputOrderMode, Partitioning, PhysicalExpr, WindowExpr, + displayable, udaf, windows, AggregateExpr, ExecutionPlan, ExecutionPlanProperties, + InputOrderMode, Partitioning, PhysicalExpr, WindowExpr, }; use arrow::compute::SortOptions; @@ -1812,7 +1812,7 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( e: &Expr, name: impl Into, logical_input_schema: &DFSchema, - physical_input_schema: &Schema, + _physical_input_schema: &Schema, execution_props: &ExecutionProps, ) -> Result { match e { @@ -1840,28 +1840,6 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( == NullTreatment::IgnoreNulls; let (agg_expr, filter, order_by) = match func_def { - AggregateFunctionDefinition::BuiltIn(fun) => { - let physical_sort_exprs = match order_by { - Some(exprs) => Some(create_physical_sort_exprs( - exprs, - logical_input_schema, - execution_props, - )?), - None => None, - }; - let ordering_reqs: Vec = - physical_sort_exprs.clone().unwrap_or(vec![]); - let agg_expr = aggregates::create_aggregate_expr( - fun, - *distinct, - &physical_args, - &ordering_reqs, - physical_input_schema, - name, - ignore_nulls, - )?; - (agg_expr, filter, physical_sort_exprs) - } AggregateFunctionDefinition::UDF(fun) => { let sort_exprs = order_by.clone().unwrap_or(vec![]); let physical_sort_exprs = match order_by { diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index d83a47ceb069..86cacbaa06d8 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -54,11 +54,11 @@ use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_expr::expr::{GroupingSet, Sort}; use datafusion_expr::var_provider::{VarProvider, VarType}; use datafusion_expr::{ - cast, col, exists, expr, in_subquery, lit, max, out_ref_col, placeholder, - scalar_subquery, when, wildcard, Expr, ExprFunctionExt, ExprSchemable, WindowFrame, - WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, + cast, col, exists, expr, in_subquery, lit, out_ref_col, placeholder, scalar_subquery, + when, wildcard, Expr, ExprFunctionExt, ExprSchemable, WindowFrame, WindowFrameBound, + WindowFrameUnits, WindowFunctionDefinition, }; -use datafusion_functions_aggregate::expr_fn::{array_agg, avg, count, sum}; +use datafusion_functions_aggregate::expr_fn::{array_agg, avg, count, max, sum}; #[tokio::test] async fn test_count_wildcard_on_sort() -> Result<()> { diff --git a/datafusion/core/tests/fuzz_cases/window_fuzz.rs b/datafusion/core/tests/fuzz_cases/window_fuzz.rs index c97621ec4d01..813862c4cc2f 100644 --- a/datafusion/core/tests/fuzz_cases/window_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/window_fuzz.rs @@ -32,13 +32,13 @@ use datafusion::physical_plan::{collect, InputOrderMode}; use datafusion::prelude::{SessionConfig, SessionContext}; use datafusion_common::{Result, ScalarValue}; use 
datafusion_common_runtime::SpawnedTask; -use datafusion_expr::type_coercion::aggregates::coerce_types; use datafusion_expr::type_coercion::functions::data_types_with_aggregate_udf; use datafusion_expr::{ - AggregateFunction, BuiltInWindowFunction, WindowFrame, WindowFrameBound, - WindowFrameUnits, WindowFunctionDefinition, + BuiltInWindowFunction, WindowFrame, WindowFrameBound, WindowFrameUnits, + WindowFunctionDefinition, }; use datafusion_functions_aggregate::count::count_udaf; +use datafusion_functions_aggregate::min_max::{max_udaf, min_udaf}; use datafusion_functions_aggregate::sum::sum_udaf; use datafusion_physical_expr::expressions::{cast, col, lit}; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; @@ -361,14 +361,14 @@ fn get_random_function( window_fn_map.insert( "min", ( - WindowFunctionDefinition::AggregateFunction(AggregateFunction::Min), + WindowFunctionDefinition::AggregateUDF(min_udaf()), vec![arg.clone()], ), ); window_fn_map.insert( "max", ( - WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), + WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![arg.clone()], ), ); @@ -465,16 +465,7 @@ fn get_random_function( let fn_name = window_fn_map.keys().collect::>()[rand_fn_idx]; let (window_fn, args) = window_fn_map.values().collect::>()[rand_fn_idx]; let mut args = args.clone(); - if let WindowFunctionDefinition::AggregateFunction(f) = window_fn { - if !args.is_empty() { - // Do type coercion first argument - let a = args[0].clone(); - let dt = a.data_type(schema.as_ref()).unwrap(); - let sig = f.signature(); - let coerced = coerce_types(f, &[dt], &sig).unwrap(); - args[0] = cast(a, schema, coerced[0].clone()).unwrap(); - } - } else if let WindowFunctionDefinition::AggregateUDF(udf) = window_fn { + if let WindowFunctionDefinition::AggregateUDF(udf) = window_fn { if !args.is_empty() { // Do type coercion first argument let a = args[0].clone(); diff --git a/datafusion/core/tests/memory_limit/mod.rs b/datafusion/core/tests/memory_limit/mod.rs index a2bdbe64aa43..5c712af80192 100644 --- a/datafusion/core/tests/memory_limit/mod.rs +++ b/datafusion/core/tests/memory_limit/mod.rs @@ -26,10 +26,14 @@ use datafusion::assert_batches_eq; use datafusion::physical_optimizer::PhysicalOptimizerRule; use datafusion::physical_plan::memory::MemoryExec; use datafusion::physical_plan::streaming::PartitionStream; +use datafusion_execution::memory_pool::{ + GreedyMemoryPool, MemoryPool, TrackConsumersPool, +}; use datafusion_expr::{Expr, TableType}; use datafusion_physical_expr::{LexOrdering, PhysicalSortExpr}; use futures::StreamExt; use std::any::Any; +use std::num::NonZeroUsize; use std::sync::{Arc, OnceLock}; use tokio::fs::File; @@ -371,6 +375,39 @@ async fn oom_parquet_sink() { .await } +#[tokio::test] +async fn oom_with_tracked_consumer_pool() { + let dir = tempfile::tempdir().unwrap(); + let path = dir.into_path().join("test.parquet"); + let _ = File::create(path.clone()).await.unwrap(); + + TestCase::new() + .with_config( + SessionConfig::new() + ) + .with_query(format!( + " + COPY (select * from t) + TO '{}' + STORED AS PARQUET OPTIONS (compression 'uncompressed'); + ", + path.to_string_lossy() + )) + .with_expected_errors(vec![ + "Failed to allocate additional", + "for ParquetSink(ArrowColumnWriter)", + "Resources exhausted with top memory consumers (across reservations) are: ParquetSink(ArrowColumnWriter)" + ]) + .with_memory_pool(Arc::new( + TrackConsumersPool::new( + GreedyMemoryPool::new(200_000), + NonZeroUsize::new(1).unwrap() + ) + )) + .run() + 
.await +} + /// Run the query with the specified memory limit, /// and verifies the expected errors are returned #[derive(Clone, Debug)] @@ -378,6 +415,7 @@ struct TestCase { query: Option, expected_errors: Vec, memory_limit: usize, + memory_pool: Option>, config: SessionConfig, scenario: Scenario, /// How should the disk manager (that allows spilling) be @@ -396,6 +434,7 @@ impl TestCase { expected_errors: vec![], memory_limit: 0, config: SessionConfig::new(), + memory_pool: None, scenario: Scenario::AccessLog, disk_manager_config: DiskManagerConfig::Disabled, expected_plan: vec![], @@ -425,6 +464,15 @@ impl TestCase { self } + /// Set the memory pool to be used + /// + /// This will override the memory_limit requested, + /// as the memory pool includes the limit. + fn with_memory_pool(mut self, memory_pool: Arc) -> Self { + self.memory_pool = Some(memory_pool); + self + } + /// Specify the configuration to use pub fn with_config(mut self, config: SessionConfig) -> Self { self.config = config; @@ -465,6 +513,7 @@ impl TestCase { query, expected_errors, memory_limit, + memory_pool, config, scenario, disk_manager_config, @@ -474,11 +523,15 @@ impl TestCase { let table = scenario.table(); - let rt_config = RuntimeConfig::new() + let mut rt_config = RuntimeConfig::new() // disk manager setting controls the spilling .with_disk_manager(disk_manager_config) .with_memory_limit(memory_limit, MEMORY_FRACTION); + if let Some(pool) = memory_pool { + rt_config = rt_config.with_memory_pool(pool); + }; + let runtime = RuntimeEnv::new(rt_config).unwrap(); // Configure execution diff --git a/datafusion/core/tests/parquet/arrow_statistics.rs b/datafusion/core/tests/parquet/arrow_statistics.rs deleted file mode 100644 index 623f321ce152..000000000000 --- a/datafusion/core/tests/parquet/arrow_statistics.rs +++ /dev/null @@ -1,2178 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! This file contains an end to end test of extracting statistics from parquet files. -//! 
diff --git a/datafusion/core/tests/parquet/arrow_statistics.rs b/datafusion/core/tests/parquet/arrow_statistics.rs deleted file mode 100644 index 623f321ce152..000000000000 --- a/datafusion/core/tests/parquet/arrow_statistics.rs +++ /dev/null @@ -1,2178 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! This file contains an end to end test of extracting statistics from parquet files. -//! It writes data into a parquet file, reads statistics and verifies they are correct - -use std::default::Default; -use std::fs::File; -use std::sync::Arc; - -use crate::parquet::{struct_array, Scenario}; -use arrow::compute::kernels::cast_utils::Parser; -use arrow::datatypes::{ - i256, Date32Type, Date64Type, TimestampMicrosecondType, TimestampMillisecondType, - TimestampNanosecondType, TimestampSecondType, -}; -use arrow_array::{ - make_array, new_null_array, Array, ArrayRef, BinaryArray, BooleanArray, Date32Array, - Date64Array, Decimal128Array, Decimal256Array, FixedSizeBinaryArray, Float16Array, - Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, - LargeBinaryArray, LargeStringArray, RecordBatch, StringArray, Time32MillisecondArray, - Time32SecondArray, Time64MicrosecondArray, Time64NanosecondArray, - TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, - TimestampSecondArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array, -}; -use arrow_schema::{DataType, Field, Schema, TimeUnit}; -use datafusion::datasource::physical_plan::parquet::StatisticsConverter; -use half::f16; -use parquet::arrow::arrow_reader::{ - ArrowReaderBuilder, ArrowReaderOptions, ParquetRecordBatchReaderBuilder, -}; -use parquet::arrow::ArrowWriter; -use parquet::file::properties::{EnabledStatistics, WriterProperties}; - -use super::make_test_file_rg; - -#[derive(Debug, Default, Clone)] -struct Int64Case { - /// Number of nulls in the column - null_values: usize, - /// Non null values in the range `[no_null_values_start, - /// no_null_values_end]`, one value for each row - no_null_values_start: i64, - no_null_values_end: i64, - /// Number of rows per row group - row_per_group: usize, - /// if specified, overrides default statistics settings - enable_stats: Option<EnabledStatistics>, - /// If specified, the number of values in each page - data_page_row_count_limit: Option<usize>, -} - -impl Int64Case { - /// Return a record batch with i64 with Null values - /// The first no_null_values_end - no_null_values_start values - /// are non-null with the specified range, the rest are null - fn make_int64_batches_with_null(&self) -> RecordBatch { - let schema = - Arc::new(Schema::new(vec![Field::new("i64", DataType::Int64, true)])); - - let v64: Vec<i64> = - (self.no_null_values_start as _..self.no_null_values_end as _).collect(); - - RecordBatch::try_new( - schema, - vec![make_array( - Int64Array::from_iter( - v64.into_iter() - .map(Some) - .chain(std::iter::repeat(None).take(self.null_values)), - ) - .to_data(), - )], - ) - .unwrap() - } - - // Create a parquet file with the specified settings - pub fn build(&self) -> ParquetRecordBatchReaderBuilder<File> { - let batches = vec![self.make_int64_batches_with_null()]; - build_parquet_file( - self.row_per_group, - self.enable_stats, - self.data_page_row_count_limit, - batches, - ) - } -} - -fn build_parquet_file( - row_per_group: usize, - enable_stats: Option<EnabledStatistics>, - data_page_row_count_limit: Option<usize>, - batches: Vec<RecordBatch>, -) -> ParquetRecordBatchReaderBuilder<File> { - let mut output_file = tempfile::Builder::new() - .prefix("parquert_statistics_test") - .suffix(".parquet") - .tempfile() - .expect("tempfile creation"); - - let mut builder = WriterProperties::builder().set_max_row_group_size(row_per_group); - if let Some(enable_stats) = enable_stats { - builder = builder.set_statistics_enabled(enable_stats); - } - if let Some(data_page_row_count_limit) = data_page_row_count_limit { - builder = builder.set_data_page_row_count_limit(data_page_row_count_limit); - } - let props = builder.build(); - - let schema = batches[0].schema(); - - let mut writer = ArrowWriter::try_new(&mut output_file, schema, Some(props)).unwrap(); - - // if we have a datapage limit send the batches in one at a time to give - // the writer a chance to be split into multiple pages - if data_page_row_count_limit.is_some() { - for batch in &batches { - for i in 0..batch.num_rows() { - writer.write(&batch.slice(i, 1)).expect("writing batch"); - } - } - } else { - for batch in &batches { - writer.write(batch).expect("writing batch"); - } - } - - let _file_meta = writer.close().unwrap(); - - let file = output_file.reopen().unwrap(); - let options = ArrowReaderOptions::new().with_page_index(true); - ArrowReaderBuilder::try_new_with_options(file, options).unwrap() -} -
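The round trip above, condensed into a self-contained sketch: statistics granularity is chosen at write time via `WriterProperties`, and the page index must be requested again at read time or the data page statistics will be absent (the function name is illustrative):

```rust
use std::sync::Arc;

use arrow_array::{Int64Array, RecordBatch};
use arrow_schema::{DataType, Field, Schema};
use parquet::arrow::arrow_reader::{ArrowReaderOptions, ParquetRecordBatchReaderBuilder};
use parquet::arrow::ArrowWriter;
use parquet::errors::Result;
use parquet::file::properties::{EnabledStatistics, WriterProperties};

fn round_trip_with_page_stats() -> Result<()> {
    let schema = Arc::new(Schema::new(vec![Field::new("i64", DataType::Int64, true)]));
    let batch = RecordBatch::try_new(
        schema.clone(),
        vec![Arc::new(Int64Array::from(vec![Some(1), Some(2), None]))],
    )
    .unwrap();

    // Page-level statistics are opt-in at write time.
    let props = WriterProperties::builder()
        .set_max_row_group_size(2)
        .set_data_page_row_count_limit(1)
        .set_statistics_enabled(EnabledStatistics::Page)
        .build();

    let file = tempfile::tempfile().unwrap();
    let mut writer = ArrowWriter::try_new(file.try_clone().unwrap(), schema, Some(props))?;
    // Feed rows one at a time: as the helper above notes, the page row
    // count limit is only applied between writes.
    for i in 0..batch.num_rows() {
        writer.write(&batch.slice(i, 1))?;
    }
    writer.close()?;

    // The page index must also be requested at read time, or the column and
    // offset indexes needed for data page statistics will be missing.
    let options = ArrowReaderOptions::new().with_page_index(true);
    let builder = ParquetRecordBatchReaderBuilder::try_new_with_options(file, options)?;
    assert_eq!(builder.metadata().num_row_groups(), 2);
    Ok(())
}
```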
-/// Defines what data to create in a parquet file -#[derive(Debug, Clone, Copy)] -struct TestReader { - /// What data to create in the parquet file - scenario: Scenario, - /// Number of rows per row group - row_per_group: usize, -} - -impl TestReader { - /// Create a parquet file with the specified data, and return a - /// ParquetRecordBatchReaderBuilder opened to that file. - async fn build(self) -> ParquetRecordBatchReaderBuilder<File> { - let TestReader { - scenario, - row_per_group, - } = self; - let file = make_test_file_rg(scenario, row_per_group).await; - - // open the file & get the reader - let file = file.reopen().unwrap(); - let options = ArrowReaderOptions::new().with_page_index(true); - ArrowReaderBuilder::try_new_with_options(file, options).unwrap() - } -} - -/// Which statistics should we check? -#[derive(Clone, Debug, Copy)] -enum Check { - /// Extract and check row group statistics - RowGroup, - /// Extract and check data page statistics - DataPage, - /// Extract and check both row group and data page statistics. - /// - /// Note if a row group contains a single data page, - /// the statistics for row groups and data pages are the same. - Both, -} - -impl Check { - fn row_group(&self) -> bool { - match self { - Self::RowGroup | Self::Both => true, - Self::DataPage => false, - } - } - - fn data_page(&self) -> bool { - match self { - Self::DataPage | Self::Both => true, - Self::RowGroup => false, - } - } -} - -/// Defines a test case for statistics extraction -struct Test<'a> { - /// The parquet file reader - reader: &'a ParquetRecordBatchReaderBuilder<File>, - expected_min: ArrayRef, - expected_max: ArrayRef, - expected_null_counts: UInt64Array, - expected_row_counts: Option<UInt64Array>, - /// Which column to extract statistics from - column_name: &'static str, - /// What statistics should be checked?
- check: Check, -} - -impl<'a> Test<'a> { - fn run(self) { - let converter = StatisticsConverter::try_new( - self.column_name, - self.reader.schema(), - self.reader.parquet_schema(), - ) - .unwrap(); - - self.run_checks(converter); - } - - fn run_with_schema(self, schema: &Schema) { - let converter = StatisticsConverter::try_new( - self.column_name, - schema, - self.reader.parquet_schema(), - ) - .unwrap(); - - self.run_checks(converter); - } - - fn run_checks(self, converter: StatisticsConverter) { - let Self { - reader, - expected_min, - expected_max, - expected_null_counts, - expected_row_counts, - column_name, - check, - } = self; - - let row_groups = reader.metadata().row_groups(); - - if check.data_page() { - let column_page_index = reader - .metadata() - .column_index() - .expect("File should have column page indices"); - - let column_offset_index = reader - .metadata() - .offset_index() - .expect("File should have column offset indices"); - - let row_group_indices: Vec<_> = (0..row_groups.len()).collect(); - - let min = converter - .data_page_mins( - column_page_index, - column_offset_index, - &row_group_indices, - ) - .unwrap(); - assert_eq!( - &min, &expected_min, - "{column_name}: Mismatch with expected data page minimums" - ); - - let max = converter - .data_page_maxes( - column_page_index, - column_offset_index, - &row_group_indices, - ) - .unwrap(); - assert_eq!( - &max, &expected_max, - "{column_name}: Mismatch with expected data page maximum" - ); - - let null_counts = converter - .data_page_null_counts( - column_page_index, - column_offset_index, - &row_group_indices, - ) - .unwrap(); - - assert_eq!( - &null_counts, &expected_null_counts, - "{column_name}: Mismatch with expected data page null counts. \ - Actual: {null_counts:?}. Expected: {expected_null_counts:?}" - ); - - let row_counts = converter - .data_page_row_counts(column_offset_index, row_groups, &row_group_indices) - .unwrap(); - assert_eq!( - row_counts, expected_row_counts, - "{column_name}: Mismatch with expected row counts. \ - Actual: {row_counts:?}. Expected: {expected_row_counts:?}" - ); - } - - if check.row_group() { - let min = converter.row_group_mins(row_groups).unwrap(); - assert_eq!( - &min, &expected_min, - "{column_name}: Mismatch with expected minimums" - ); - - let max = converter.row_group_maxes(row_groups).unwrap(); - assert_eq!( - &max, &expected_max, - "{column_name}: Mismatch with expected maximum" - ); - - let null_counts = converter.row_group_null_counts(row_groups).unwrap(); - assert_eq!( - &null_counts, &expected_null_counts, - "{column_name}: Mismatch with expected null counts. \ - Actual: {null_counts:?}. Expected: {expected_null_counts:?}" - ); - - let row_counts = converter - .row_group_row_counts(reader.metadata().row_groups().iter()) - .unwrap(); - assert_eq!( - row_counts, expected_row_counts, - "{column_name}: Mismatch with expected row counts. \ - Actual: {row_counts:?}. Expected: {expected_row_counts:?}" - ); - } - } - - /// Run the test and expect a column not found error - fn run_col_not_found(self) { - let Self { - reader, - expected_min: _, - expected_max: _, - expected_null_counts: _, - expected_row_counts: _, - column_name, - .. 
- } = self; - - let converter = StatisticsConverter::try_new( - column_name, - reader.schema(), - reader.parquet_schema(), - ); - - assert!(converter.is_err()); - } -} - -// TESTS -// -// Remaining cases -// f64::NAN -// - Using truncated statistics ("exact min value" and "exact max value" https://docs.rs/parquet/latest/parquet/file/statistics/enum.Statistics.html#method.max_is_exact) - -#[tokio::test] -async fn test_one_row_group_without_null() { - let reader = Int64Case { - null_values: 0, - no_null_values_start: 4, - no_null_values_end: 7, - row_per_group: 20, - ..Default::default() - } - .build(); - - Test { - reader: &reader, - // min is 4 - expected_min: Arc::new(Int64Array::from(vec![4])), - // max is 6 - expected_max: Arc::new(Int64Array::from(vec![6])), - // no nulls - expected_null_counts: UInt64Array::from(vec![0]), - // 3 rows - expected_row_counts: Some(UInt64Array::from(vec![3])), - column_name: "i64", - check: Check::Both, - } - .run() -} - -#[tokio::test] -async fn test_one_row_group_with_null_and_negative() { - let reader = Int64Case { - null_values: 2, - no_null_values_start: -1, - no_null_values_end: 5, - row_per_group: 20, - ..Default::default() - } - .build(); - - Test { - reader: &reader, - // min is -1 - expected_min: Arc::new(Int64Array::from(vec![-1])), - // max is 4 - expected_max: Arc::new(Int64Array::from(vec![4])), - // 2 nulls - expected_null_counts: UInt64Array::from(vec![2]), - // 8 rows - expected_row_counts: Some(UInt64Array::from(vec![8])), - column_name: "i64", - check: Check::Both, - } - .run() -} - -#[tokio::test] -async fn test_two_row_group_with_null() { - let reader = Int64Case { - null_values: 2, - no_null_values_start: 4, - no_null_values_end: 17, - row_per_group: 10, - ..Default::default() - } - .build(); - - Test { - reader: &reader, - // mins are [4, 14] - expected_min: Arc::new(Int64Array::from(vec![4, 14])), - // maxes are [13, 16] - expected_max: Arc::new(Int64Array::from(vec![13, 16])), - // nulls are [0, 2] - expected_null_counts: UInt64Array::from(vec![0, 2]), - // row counts are [10, 5] - expected_row_counts: Some(UInt64Array::from(vec![10, 5])), - column_name: "i64", - check: Check::Both, - } - .run() -} - -#[tokio::test] -async fn test_two_row_groups_with_all_nulls_in_one() { - let reader = Int64Case { - null_values: 4, - no_null_values_start: -2, - no_null_values_end: 2, - row_per_group: 5, - ..Default::default() - } - .build(); - - Test { - reader: &reader, - // mins are [-2, null] - expected_min: Arc::new(Int64Array::from(vec![Some(-2), None])), - // maxes are [1, null] - expected_max: Arc::new(Int64Array::from(vec![Some(1), None])), - // nulls are [1, 3] - expected_null_counts: UInt64Array::from(vec![1, 3]), - // row counts are [5, 3] - expected_row_counts: Some(UInt64Array::from(vec![5, 3])), - column_name: "i64", - check: Check::Both, - } - .run() -} - -#[tokio::test] -async fn test_multiple_data_pages_nulls_and_negatives() { - let reader = Int64Case { - null_values: 3, - no_null_values_start: -1, - no_null_values_end: 10, - row_per_group: 20, - // limit page row count to 4 - data_page_row_count_limit: Some(4), - enable_stats: Some(EnabledStatistics::Page), - } - .build(); - - // Data layout looks like this: - // - // page 0: [-1, 0, 1, 2] - // page 1: [3, 4, 5, 6] - // page 2: [7, 8, 9, null] - // page 3: [null, null] - Test { - reader: &reader, - expected_min: Arc::new(Int64Array::from(vec![Some(-1), Some(3), Some(7), None])), - expected_max: Arc::new(Int64Array::from(vec![Some(2), Some(6), Some(9), None])), - 
expected_null_counts: UInt64Array::from(vec![0, 0, 1, 2]), - expected_row_counts: Some(UInt64Array::from(vec![4, 4, 4, 2])), - column_name: "i64", - check: Check::DataPage, - } - .run() -} - -#[tokio::test] -async fn test_data_page_stats_with_all_null_page() { - for data_type in &[ - DataType::Boolean, - DataType::UInt64, - DataType::UInt32, - DataType::UInt16, - DataType::UInt8, - DataType::Int64, - DataType::Int32, - DataType::Int16, - DataType::Int8, - DataType::Float16, - DataType::Float32, - DataType::Float64, - DataType::Date32, - DataType::Date64, - DataType::Time32(TimeUnit::Millisecond), - DataType::Time32(TimeUnit::Second), - DataType::Time64(TimeUnit::Microsecond), - DataType::Time64(TimeUnit::Nanosecond), - DataType::Timestamp(TimeUnit::Second, None), - DataType::Timestamp(TimeUnit::Millisecond, None), - DataType::Timestamp(TimeUnit::Microsecond, None), - DataType::Timestamp(TimeUnit::Nanosecond, None), - DataType::Binary, - DataType::LargeBinary, - DataType::FixedSizeBinary(3), - DataType::Utf8, - DataType::LargeUtf8, - DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), - DataType::Decimal128(8, 2), // as INT32 - DataType::Decimal128(10, 2), // as INT64 - DataType::Decimal128(20, 2), // as FIXED_LEN_BYTE_ARRAY - DataType::Decimal256(8, 2), // as INT32 - DataType::Decimal256(10, 2), // as INT64 - DataType::Decimal256(20, 2), // as FIXED_LEN_BYTE_ARRAY - ] { - let batch = - RecordBatch::try_from_iter(vec![("col", new_null_array(data_type, 4))]) - .expect("record batch creation"); - - let reader = - build_parquet_file(4, Some(EnabledStatistics::Page), Some(4), vec![batch]); - - let expected_data_type = match data_type { - DataType::Dictionary(_, value_type) => value_type.as_ref(), - _ => data_type, - }; - - // There is one data page with 4 nulls - // The statistics should be present but null - Test { - reader: &reader, - expected_min: new_null_array(expected_data_type, 1), - expected_max: new_null_array(expected_data_type, 1), - expected_null_counts: UInt64Array::from(vec![4]), - expected_row_counts: Some(UInt64Array::from(vec![4])), - column_name: "col", - check: Check::DataPage, - } - .run() - } -} - -/////////////// MORE GENERAL TESTS ////////////////////// -// . Many columns in a file -// . Different data types -// . 
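The deleted harness above exercises the converter that this PR switches to consuming from the parquet crate. A minimal sketch of both extraction paths it checks, assuming the relocated `parquet::arrow::arrow_reader::statistics::StatisticsConverter` keeps the method names the harness calls:

```rust
use std::fs::File;

use parquet::arrow::arrow_reader::statistics::StatisticsConverter;
use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder;

fn extract_stats(reader: &ParquetRecordBatchReaderBuilder<File>) {
    let converter = StatisticsConverter::try_new(
        "i64",
        reader.schema(),
        reader.parquet_schema(),
    )
    .unwrap();
    let metadata = reader.metadata();

    // Row group statistics: one entry per row group, in the column's Arrow
    // type; null counts come back as a UInt64Array.
    let row_groups = metadata.row_groups();
    let mins = converter.row_group_mins(row_groups).unwrap();
    let maxes = converter.row_group_maxes(row_groups).unwrap();
    let nulls = converter.row_group_null_counts(row_groups).unwrap();

    // Data page statistics additionally need the column and offset indexes,
    // which are only present when the page index was read (see above).
    let column_index = metadata.column_index().expect("page index not read");
    let offset_index = metadata.offset_index().expect("page index not read");
    let row_group_indices: Vec<_> = (0..metadata.num_row_groups()).collect();
    let page_mins = converter
        .data_page_mins(column_index, offset_index, &row_group_indices)
        .unwrap();

    println!("{mins:?} {maxes:?} {nulls:?} {page_mins:?}");
}
```

Note the granularity difference the tests rely on: row group arrays have one entry per row group, while data page arrays have one entry per page, and an all-null page yields a null min/max.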
Different row group sizes - -// Four different integer types -#[tokio::test] -async fn test_int_64() { - // This creates a parquet files of 4 columns named "i8", "i16", "i32", "i64" - let reader = TestReader { - scenario: Scenario::Int, - row_per_group: 5, - } - .build() - .await; - - // since each row has only one data page, the statistics are the same - Test { - reader: &reader, - // mins are [-5, -4, 0, 5] - expected_min: Arc::new(Int64Array::from(vec![-5, -4, 0, 5])), - // maxes are [-1, 0, 4, 9] - expected_max: Arc::new(Int64Array::from(vec![-1, 0, 4, 9])), - // nulls are [0, 0, 0, 0] - expected_null_counts: UInt64Array::from(vec![0, 0, 0, 0]), - // row counts are [5, 5, 5, 5] - expected_row_counts: Some(UInt64Array::from(vec![5, 5, 5, 5])), - column_name: "i64", - check: Check::Both, - } - .run(); -} - -#[tokio::test] -async fn test_int_32() { - // This creates a parquet files of 4 columns named "i8", "i16", "i32", "i64" - let reader = TestReader { - scenario: Scenario::Int, - row_per_group: 5, - } - .build() - .await; - - Test { - reader: &reader, - // mins are [-5, -4, 0, 5] - expected_min: Arc::new(Int32Array::from(vec![-5, -4, 0, 5])), - // maxes are [-1, 0, 4, 9] - expected_max: Arc::new(Int32Array::from(vec![-1, 0, 4, 9])), - // nulls are [0, 0, 0, 0] - expected_null_counts: UInt64Array::from(vec![0, 0, 0, 0]), - // row counts are [5, 5, 5, 5] - expected_row_counts: Some(UInt64Array::from(vec![5, 5, 5, 5])), - column_name: "i32", - check: Check::Both, - } - .run(); -} - -#[tokio::test] -async fn test_int_16() { - // This creates a parquet files of 4 columns named "i8", "i16", "i32", "i64" - let reader = TestReader { - scenario: Scenario::Int, - row_per_group: 5, - } - .build() - .await; - - Test { - reader: &reader, - // mins are [-5, -4, 0, 5] - expected_min: Arc::new(Int16Array::from(vec![-5, -4, 0, 5])), // panic here because the actual data is Int32Array - // maxes are [-1, 0, 4, 9] - expected_max: Arc::new(Int16Array::from(vec![-1, 0, 4, 9])), - // nulls are [0, 0, 0, 0] - expected_null_counts: UInt64Array::from(vec![0, 0, 0, 0]), - // row counts are [5, 5, 5, 5] - expected_row_counts: Some(UInt64Array::from(vec![5, 5, 5, 5])), - column_name: "i16", - check: Check::Both, - } - .run(); -} - -#[tokio::test] -async fn test_int_8() { - // This creates a parquet files of 4 columns named "i8", "i16", "i32", "i64" - let reader = TestReader { - scenario: Scenario::Int, - row_per_group: 5, - } - .build() - .await; - - Test { - reader: &reader, - // mins are [-5, -4, 0, 5] - expected_min: Arc::new(Int8Array::from(vec![-5, -4, 0, 5])), // panic here because the actual data is Int32Array - // maxes are [-1, 0, 4, 9] - expected_max: Arc::new(Int8Array::from(vec![-1, 0, 4, 9])), - // nulls are [0, 0, 0, 0] - expected_null_counts: UInt64Array::from(vec![0, 0, 0, 0]), - // row counts are [5, 5, 5, 5] - expected_row_counts: Some(UInt64Array::from(vec![5, 5, 5, 5])), - column_name: "i8", - check: Check::Both, - } - .run(); -} - -#[tokio::test] -async fn test_float_16() { - // This creates a parquet files of 1 column named f - let reader = TestReader { - scenario: Scenario::Float16, - row_per_group: 5, - } - .build() - .await; - - Test { - reader: &reader, - // mins are [-5, -4, 0, 5] - expected_min: Arc::new(Float16Array::from(vec![ - f16::from_f32(-5.), - f16::from_f32(-4.), - f16::from_f32(-0.), - f16::from_f32(5.), - ])), - // maxes are [-1, 0, 4, 9] - expected_max: Arc::new(Float16Array::from(vec![ - f16::from_f32(-1.), - f16::from_f32(0.), - f16::from_f32(4.), - f16::from_f32(9.), - 
])), - // nulls are [0, 0, 0, 0] - expected_null_counts: UInt64Array::from(vec![0, 0, 0, 0]), - // row counts are [5, 5, 5, 5] - expected_row_counts: Some(UInt64Array::from(vec![5, 5, 5, 5])), - column_name: "f", - check: Check::Both, - } - .run(); -} - -#[tokio::test] -async fn test_float_32() { - // This creates a parquet files of 1 column named f - let reader = TestReader { - scenario: Scenario::Float32, - row_per_group: 5, - } - .build() - .await; - - Test { - reader: &reader, - // mins are [-5, -4, 0, 5] - expected_min: Arc::new(Float32Array::from(vec![-5., -4., -0., 5.0])), - // maxes are [-1, 0, 4, 9] - expected_max: Arc::new(Float32Array::from(vec![-1., 0., 4., 9.])), - // nulls are [0, 0, 0, 0] - expected_null_counts: UInt64Array::from(vec![0, 0, 0, 0]), - // row counts are [5, 5, 5, 5] - expected_row_counts: Some(UInt64Array::from(vec![5, 5, 5, 5])), - column_name: "f", - check: Check::Both, - } - .run(); -} - -#[tokio::test] -async fn test_float_64() { - // This creates a parquet files of 1 column named f - let reader = TestReader { - scenario: Scenario::Float64, - row_per_group: 5, - } - .build() - .await; - - Test { - reader: &reader, - // mins are [-5, -4, 0, 5] - expected_min: Arc::new(Float64Array::from(vec![-5., -4., -0., 5.0])), - // maxes are [-1, 0, 4, 9] - expected_max: Arc::new(Float64Array::from(vec![-1., 0., 4., 9.])), - // nulls are [0, 0, 0, 0] - expected_null_counts: UInt64Array::from(vec![0, 0, 0, 0]), - // row counts are [5, 5, 5, 5] - expected_row_counts: Some(UInt64Array::from(vec![5, 5, 5, 5])), - column_name: "f", - check: Check::Both, - } - .run(); -} - -// timestamp -#[tokio::test] -async fn test_timestamp() { - // This creates a parquet files of 9 columns named "nanos", "nanos_timezoned", "micros", "micros_timezoned", "millis", "millis_timezoned", "seconds", "seconds_timezoned", "names" - // "nanos" --> TimestampNanosecondArray - // "nanos_timezoned" --> TimestampNanosecondArray - // "micros" --> TimestampMicrosecondArray - // "micros_timezoned" --> TimestampMicrosecondArray - // "millis" --> TimestampMillisecondArray - // "millis_timezoned" --> TimestampMillisecondArray - // "seconds" --> TimestampSecondArray - // "seconds_timezoned" --> TimestampSecondArray - // "names" --> StringArray - // - // The file is created by 4 record batches, each has 5 rows. 
- // Since the row group size is set to 5, those 4 batches will go into 4 row groups - // This creates a parquet files of 4 columns named "nanos", "nanos_timezoned", "micros", "micros_timezoned", "millis", "millis_timezoned", "seconds", "seconds_timezoned" - let reader = TestReader { - scenario: Scenario::Timestamps, - row_per_group: 5, - } - .build() - .await; - - let tz = "Pacific/Efate"; - - Test { - reader: &reader, - expected_min: Arc::new(TimestampNanosecondArray::from(vec![ - TimestampNanosecondType::parse("2020-01-01T01:01:01"), - TimestampNanosecondType::parse("2020-01-01T01:01:11"), - TimestampNanosecondType::parse("2020-01-01T01:11:01"), - TimestampNanosecondType::parse("2020-01-11T01:01:01"), - ])), - expected_max: Arc::new(TimestampNanosecondArray::from(vec![ - TimestampNanosecondType::parse("2020-01-02T01:01:01"), - TimestampNanosecondType::parse("2020-01-02T01:01:11"), - TimestampNanosecondType::parse("2020-01-02T01:11:01"), - TimestampNanosecondType::parse("2020-01-12T01:01:01"), - ])), - // nulls are [1, 1, 1, 1] - expected_null_counts: UInt64Array::from(vec![1, 1, 1, 1]), - // row counts are [5, 5, 5, 5] - expected_row_counts: Some(UInt64Array::from(vec![5, 5, 5, 5])), - column_name: "nanos", - check: Check::Both, - } - .run(); - - Test { - reader: &reader, - expected_min: Arc::new( - TimestampNanosecondArray::from(vec![ - TimestampNanosecondType::parse("2020-01-01T01:01:01"), - TimestampNanosecondType::parse("2020-01-01T01:01:11"), - TimestampNanosecondType::parse("2020-01-01T01:11:01"), - TimestampNanosecondType::parse("2020-01-11T01:01:01"), - ]) - .with_timezone(tz), - ), - expected_max: Arc::new( - TimestampNanosecondArray::from(vec![ - TimestampNanosecondType::parse("2020-01-02T01:01:01"), - TimestampNanosecondType::parse("2020-01-02T01:01:11"), - TimestampNanosecondType::parse("2020-01-02T01:11:01"), - TimestampNanosecondType::parse("2020-01-12T01:01:01"), - ]) - .with_timezone(tz), - ), - // nulls are [1, 1, 1, 1] - expected_null_counts: UInt64Array::from(vec![1, 1, 1, 1]), - // row counts are [5, 5, 5, 5] - expected_row_counts: Some(UInt64Array::from(vec![5, 5, 5, 5])), - column_name: "nanos_timezoned", - check: Check::Both, - } - .run(); - - // micros - Test { - reader: &reader, - expected_min: Arc::new(TimestampMicrosecondArray::from(vec![ - TimestampMicrosecondType::parse("2020-01-01T01:01:01"), - TimestampMicrosecondType::parse("2020-01-01T01:01:11"), - TimestampMicrosecondType::parse("2020-01-01T01:11:01"), - TimestampMicrosecondType::parse("2020-01-11T01:01:01"), - ])), - expected_max: Arc::new(TimestampMicrosecondArray::from(vec![ - TimestampMicrosecondType::parse("2020-01-02T01:01:01"), - TimestampMicrosecondType::parse("2020-01-02T01:01:11"), - TimestampMicrosecondType::parse("2020-01-02T01:11:01"), - TimestampMicrosecondType::parse("2020-01-12T01:01:01"), - ])), - expected_null_counts: UInt64Array::from(vec![1, 1, 1, 1]), - expected_row_counts: Some(UInt64Array::from(vec![5, 5, 5, 5])), - column_name: "micros", - check: Check::Both, - } - .run(); - - Test { - reader: &reader, - expected_min: Arc::new( - TimestampMicrosecondArray::from(vec![ - TimestampMicrosecondType::parse("2020-01-01T01:01:01"), - TimestampMicrosecondType::parse("2020-01-01T01:01:11"), - TimestampMicrosecondType::parse("2020-01-01T01:11:01"), - TimestampMicrosecondType::parse("2020-01-11T01:01:01"), - ]) - .with_timezone(tz), - ), - expected_max: Arc::new( - TimestampMicrosecondArray::from(vec![ - TimestampMicrosecondType::parse("2020-01-02T01:01:01"), - 
TimestampMicrosecondType::parse("2020-01-02T01:01:11"), - TimestampMicrosecondType::parse("2020-01-02T01:11:01"), - TimestampMicrosecondType::parse("2020-01-12T01:01:01"), - ]) - .with_timezone(tz), - ), - // nulls are [1, 1, 1, 1] - expected_null_counts: UInt64Array::from(vec![1, 1, 1, 1]), - // row counts are [5, 5, 5, 5] - expected_row_counts: Some(UInt64Array::from(vec![5, 5, 5, 5])), - column_name: "micros_timezoned", - check: Check::Both, - } - .run(); - - // millis - Test { - reader: &reader, - expected_min: Arc::new(TimestampMillisecondArray::from(vec![ - TimestampMillisecondType::parse("2020-01-01T01:01:01"), - TimestampMillisecondType::parse("2020-01-01T01:01:11"), - TimestampMillisecondType::parse("2020-01-01T01:11:01"), - TimestampMillisecondType::parse("2020-01-11T01:01:01"), - ])), - expected_max: Arc::new(TimestampMillisecondArray::from(vec![ - TimestampMillisecondType::parse("2020-01-02T01:01:01"), - TimestampMillisecondType::parse("2020-01-02T01:01:11"), - TimestampMillisecondType::parse("2020-01-02T01:11:01"), - TimestampMillisecondType::parse("2020-01-12T01:01:01"), - ])), - expected_null_counts: UInt64Array::from(vec![1, 1, 1, 1]), - expected_row_counts: Some(UInt64Array::from(vec![5, 5, 5, 5])), - column_name: "millis", - check: Check::Both, - } - .run(); - - Test { - reader: &reader, - expected_min: Arc::new( - TimestampMillisecondArray::from(vec![ - TimestampMillisecondType::parse("2020-01-01T01:01:01"), - TimestampMillisecondType::parse("2020-01-01T01:01:11"), - TimestampMillisecondType::parse("2020-01-01T01:11:01"), - TimestampMillisecondType::parse("2020-01-11T01:01:01"), - ]) - .with_timezone(tz), - ), - expected_max: Arc::new( - TimestampMillisecondArray::from(vec![ - TimestampMillisecondType::parse("2020-01-02T01:01:01"), - TimestampMillisecondType::parse("2020-01-02T01:01:11"), - TimestampMillisecondType::parse("2020-01-02T01:11:01"), - TimestampMillisecondType::parse("2020-01-12T01:01:01"), - ]) - .with_timezone(tz), - ), - // nulls are [1, 1, 1, 1] - expected_null_counts: UInt64Array::from(vec![1, 1, 1, 1]), - // row counts are [5, 5, 5, 5] - expected_row_counts: Some(UInt64Array::from(vec![5, 5, 5, 5])), - column_name: "millis_timezoned", - check: Check::Both, - } - .run(); - - // seconds - Test { - reader: &reader, - expected_min: Arc::new(TimestampSecondArray::from(vec![ - TimestampSecondType::parse("2020-01-01T01:01:01"), - TimestampSecondType::parse("2020-01-01T01:01:11"), - TimestampSecondType::parse("2020-01-01T01:11:01"), - TimestampSecondType::parse("2020-01-11T01:01:01"), - ])), - expected_max: Arc::new(TimestampSecondArray::from(vec![ - TimestampSecondType::parse("2020-01-02T01:01:01"), - TimestampSecondType::parse("2020-01-02T01:01:11"), - TimestampSecondType::parse("2020-01-02T01:11:01"), - TimestampSecondType::parse("2020-01-12T01:01:01"), - ])), - expected_null_counts: UInt64Array::from(vec![1, 1, 1, 1]), - expected_row_counts: Some(UInt64Array::from(vec![5, 5, 5, 5])), - column_name: "seconds", - check: Check::Both, - } - .run(); - - Test { - reader: &reader, - expected_min: Arc::new( - TimestampSecondArray::from(vec![ - TimestampSecondType::parse("2020-01-01T01:01:01"), - TimestampSecondType::parse("2020-01-01T01:01:11"), - TimestampSecondType::parse("2020-01-01T01:11:01"), - TimestampSecondType::parse("2020-01-11T01:01:01"), - ]) - .with_timezone(tz), - ), - expected_max: Arc::new( - TimestampSecondArray::from(vec![ - TimestampSecondType::parse("2020-01-02T01:01:01"), - TimestampSecondType::parse("2020-01-02T01:01:11"), - 
TimestampSecondType::parse("2020-01-02T01:11:01"), - TimestampSecondType::parse("2020-01-12T01:01:01"), - ]) - .with_timezone(tz), - ), - // nulls are [1, 1, 1, 1] - expected_null_counts: UInt64Array::from(vec![1, 1, 1, 1]), - // row counts are [5, 5, 5, 5] - expected_row_counts: Some(UInt64Array::from(vec![5, 5, 5, 5])), - column_name: "seconds_timezoned", - check: Check::Both, - } - .run(); -} - -// timestamp with different row group sizes -#[tokio::test] -async fn test_timestamp_diff_rg_sizes() { - // This creates a parquet files of 9 columns named "nanos", "nanos_timezoned", "micros", "micros_timezoned", "millis", "millis_timezoned", "seconds", "seconds_timezoned", "names" - // "nanos" --> TimestampNanosecondArray - // "nanos_timezoned" --> TimestampNanosecondArray - // "micros" --> TimestampMicrosecondArray - // "micros_timezoned" --> TimestampMicrosecondArray - // "millis" --> TimestampMillisecondArray - // "millis_timezoned" --> TimestampMillisecondArray - // "seconds" --> TimestampSecondArray - // "seconds_timezoned" --> TimestampSecondArray - // "names" --> StringArray - // - // The file is created by 4 record batches (each has a null row), each has 5 rows but then will be split into 3 row groups with size 8, 8, 4 - let reader = TestReader { - scenario: Scenario::Timestamps, - row_per_group: 8, // note that the row group size is 8 - } - .build() - .await; - - let tz = "Pacific/Efate"; - - Test { - reader: &reader, - expected_min: Arc::new(TimestampNanosecondArray::from(vec![ - TimestampNanosecondType::parse("2020-01-01T01:01:01"), - TimestampNanosecondType::parse("2020-01-01T01:11:01"), - TimestampNanosecondType::parse("2020-01-11T01:02:01"), - ])), - expected_max: Arc::new(TimestampNanosecondArray::from(vec![ - TimestampNanosecondType::parse("2020-01-02T01:01:01"), - TimestampNanosecondType::parse("2020-01-11T01:01:01"), - TimestampNanosecondType::parse("2020-01-12T01:01:01"), - ])), - // nulls are [1, 2, 1] - expected_null_counts: UInt64Array::from(vec![1, 2, 1]), - // row counts are [8, 8, 4] - expected_row_counts: Some(UInt64Array::from(vec![8, 8, 4])), - column_name: "nanos", - check: Check::Both, - } - .run(); - - Test { - reader: &reader, - expected_min: Arc::new( - TimestampNanosecondArray::from(vec![ - TimestampNanosecondType::parse("2020-01-01T01:01:01"), - TimestampNanosecondType::parse("2020-01-01T01:11:01"), - TimestampNanosecondType::parse("2020-01-11T01:02:01"), - ]) - .with_timezone(tz), - ), - expected_max: Arc::new( - TimestampNanosecondArray::from(vec![ - TimestampNanosecondType::parse("2020-01-02T01:01:01"), - TimestampNanosecondType::parse("2020-01-11T01:01:01"), - TimestampNanosecondType::parse("2020-01-12T01:01:01"), - ]) - .with_timezone(tz), - ), - // nulls are [1, 2, 1] - expected_null_counts: UInt64Array::from(vec![1, 2, 1]), - // row counts are [8, 8, 4] - expected_row_counts: Some(UInt64Array::from(vec![8, 8, 4])), - column_name: "nanos_timezoned", - check: Check::Both, - } - .run(); - - // micros - Test { - reader: &reader, - expected_min: Arc::new(TimestampMicrosecondArray::from(vec![ - TimestampMicrosecondType::parse("2020-01-01T01:01:01"), - TimestampMicrosecondType::parse("2020-01-01T01:11:01"), - TimestampMicrosecondType::parse("2020-01-11T01:02:01"), - ])), - expected_max: Arc::new(TimestampMicrosecondArray::from(vec![ - TimestampMicrosecondType::parse("2020-01-02T01:01:01"), - TimestampMicrosecondType::parse("2020-01-11T01:01:01"), - TimestampMicrosecondType::parse("2020-01-12T01:01:01"), - ])), - expected_null_counts: UInt64Array::from(vec![1, 
2, 1]), - expected_row_counts: Some(UInt64Array::from(vec![8, 8, 4])), - column_name: "micros", - check: Check::Both, - } - .run(); - - Test { - reader: &reader, - expected_min: Arc::new( - TimestampMicrosecondArray::from(vec![ - TimestampMicrosecondType::parse("2020-01-01T01:01:01"), - TimestampMicrosecondType::parse("2020-01-01T01:11:01"), - TimestampMicrosecondType::parse("2020-01-11T01:02:01"), - ]) - .with_timezone(tz), - ), - expected_max: Arc::new( - TimestampMicrosecondArray::from(vec![ - TimestampMicrosecondType::parse("2020-01-02T01:01:01"), - TimestampMicrosecondType::parse("2020-01-11T01:01:01"), - TimestampMicrosecondType::parse("2020-01-12T01:01:01"), - ]) - .with_timezone(tz), - ), - // nulls are [1, 2, 1] - expected_null_counts: UInt64Array::from(vec![1, 2, 1]), - // row counts are [8, 8, 4] - expected_row_counts: Some(UInt64Array::from(vec![8, 8, 4])), - column_name: "micros_timezoned", - check: Check::Both, - } - .run(); - - // millis - Test { - reader: &reader, - expected_min: Arc::new(TimestampMillisecondArray::from(vec![ - TimestampMillisecondType::parse("2020-01-01T01:01:01"), - TimestampMillisecondType::parse("2020-01-01T01:11:01"), - TimestampMillisecondType::parse("2020-01-11T01:02:01"), - ])), - expected_max: Arc::new(TimestampMillisecondArray::from(vec![ - TimestampMillisecondType::parse("2020-01-02T01:01:01"), - TimestampMillisecondType::parse("2020-01-11T01:01:01"), - TimestampMillisecondType::parse("2020-01-12T01:01:01"), - ])), - expected_null_counts: UInt64Array::from(vec![1, 2, 1]), - expected_row_counts: Some(UInt64Array::from(vec![8, 8, 4])), - column_name: "millis", - check: Check::Both, - } - .run(); - - Test { - reader: &reader, - expected_min: Arc::new( - TimestampMillisecondArray::from(vec![ - TimestampMillisecondType::parse("2020-01-01T01:01:01"), - TimestampMillisecondType::parse("2020-01-01T01:11:01"), - TimestampMillisecondType::parse("2020-01-11T01:02:01"), - ]) - .with_timezone(tz), - ), - expected_max: Arc::new( - TimestampMillisecondArray::from(vec![ - TimestampMillisecondType::parse("2020-01-02T01:01:01"), - TimestampMillisecondType::parse("2020-01-11T01:01:01"), - TimestampMillisecondType::parse("2020-01-12T01:01:01"), - ]) - .with_timezone(tz), - ), - // nulls are [1, 2, 1] - expected_null_counts: UInt64Array::from(vec![1, 2, 1]), - // row counts are [8, 8, 4] - expected_row_counts: Some(UInt64Array::from(vec![8, 8, 4])), - column_name: "millis_timezoned", - check: Check::Both, - } - .run(); - - // seconds - Test { - reader: &reader, - expected_min: Arc::new(TimestampSecondArray::from(vec![ - TimestampSecondType::parse("2020-01-01T01:01:01"), - TimestampSecondType::parse("2020-01-01T01:11:01"), - TimestampSecondType::parse("2020-01-11T01:02:01"), - ])), - expected_max: Arc::new(TimestampSecondArray::from(vec![ - TimestampSecondType::parse("2020-01-02T01:01:01"), - TimestampSecondType::parse("2020-01-11T01:01:01"), - TimestampSecondType::parse("2020-01-12T01:01:01"), - ])), - expected_null_counts: UInt64Array::from(vec![1, 2, 1]), - expected_row_counts: Some(UInt64Array::from(vec![8, 8, 4])), - column_name: "seconds", - check: Check::Both, - } - .run(); - - Test { - reader: &reader, - expected_min: Arc::new( - TimestampSecondArray::from(vec![ - TimestampSecondType::parse("2020-01-01T01:01:01"), - TimestampSecondType::parse("2020-01-01T01:11:01"), - TimestampSecondType::parse("2020-01-11T01:02:01"), - ]) - .with_timezone(tz), - ), - expected_max: Arc::new( - TimestampSecondArray::from(vec![ - 
TimestampSecondType::parse("2020-01-02T01:01:01"), - TimestampSecondType::parse("2020-01-11T01:01:01"), - TimestampSecondType::parse("2020-01-12T01:01:01"), - ]) - .with_timezone(tz), - ), - // nulls are [1, 2, 1] - expected_null_counts: UInt64Array::from(vec![1, 2, 1]), - // row counts are [8, 8, 4] - expected_row_counts: Some(UInt64Array::from(vec![8, 8, 4])), - column_name: "seconds_timezoned", - check: Check::Both, - } - .run(); -} - -// date with different row group sizes -#[tokio::test] -async fn test_dates_32_diff_rg_sizes() { - // This creates a parquet files of 3 columns named "date32", "date64", "names" - // "date32" --> Date32Array - // "date64" --> Date64Array - // "names" --> StringArray - // - // The file is created by 4 record batches (each has a null row), each has 5 rows but then will be split into 2 row groups with size 13, 7 - let reader = TestReader { - scenario: Scenario::Dates, - row_per_group: 13, - } - .build() - .await; - - Test { - reader: &reader, - // mins are [2020-01-01, 2020-10-30] - expected_min: Arc::new(Date32Array::from(vec![ - Date32Type::parse("2020-01-01"), - Date32Type::parse("2020-10-30"), - ])), - // maxes are [2020-10-29, 2029-11-12] - expected_max: Arc::new(Date32Array::from(vec![ - Date32Type::parse("2020-10-29"), - Date32Type::parse("2029-11-12"), - ])), - // nulls are [2, 2] - expected_null_counts: UInt64Array::from(vec![2, 2]), - // row counts are [13, 7] - expected_row_counts: Some(UInt64Array::from(vec![13, 7])), - column_name: "date32", - check: Check::Both, - } - .run(); -} - -#[tokio::test] -async fn test_time32_second_diff_rg_sizes() { - let reader = TestReader { - scenario: Scenario::Time32Second, - row_per_group: 4, - } - .build() - .await; - - // Test for Time32Second column - Test { - reader: &reader, - // Assuming specific minimum and maximum values for demonstration - expected_min: Arc::new(Time32SecondArray::from(vec![18506, 18510, 18514, 18518])), - expected_max: Arc::new(Time32SecondArray::from(vec![18509, 18513, 18517, 18521])), - expected_null_counts: UInt64Array::from(vec![0, 0, 0, 0]), // Assuming 1 null per row group for simplicity - expected_row_counts: Some(UInt64Array::from(vec![4, 4, 4, 4])), - column_name: "second", - check: Check::Both, - } - .run(); -} - -#[tokio::test] -async fn test_time32_millisecond_diff_rg_sizes() { - let reader = TestReader { - scenario: Scenario::Time32Millisecond, - row_per_group: 4, - } - .build() - .await; - - // Test for Time32Millisecond column - Test { - reader: &reader, - // Assuming specific minimum and maximum values for demonstration - expected_min: Arc::new(Time32MillisecondArray::from(vec![ - 3600000, 3600004, 3600008, 3600012, - ])), - expected_max: Arc::new(Time32MillisecondArray::from(vec![ - 3600003, 3600007, 3600011, 3600015, - ])), - expected_null_counts: UInt64Array::from(vec![0, 0, 0, 0]), // Assuming 1 null per row group for simplicity - expected_row_counts: Some(UInt64Array::from(vec![4, 4, 4, 4])), - column_name: "millisecond", - check: Check::Both, - } - .run(); -} - -#[tokio::test] -async fn test_time64_microsecond_diff_rg_sizes() { - let reader = TestReader { - scenario: Scenario::Time64Microsecond, - row_per_group: 4, - } - .build() - .await; - - // Test for Time64MicroSecond column - Test { - reader: &reader, - // Assuming specific minimum and maximum values for demonstration - expected_min: Arc::new(Time64MicrosecondArray::from(vec![ - 1234567890123, - 1234567890127, - 1234567890131, - 1234567890135, - ])), - expected_max: 
Arc::new(Time64MicrosecondArray::from(vec![ - 1234567890126, - 1234567890130, - 1234567890134, - 1234567890138, - ])), - expected_null_counts: UInt64Array::from(vec![0, 0, 0, 0]), // Assuming 1 null per row group for simplicity - expected_row_counts: Some(UInt64Array::from(vec![4, 4, 4, 4])), - column_name: "microsecond", - check: Check::Both, - } - .run(); -} - -#[tokio::test] -async fn test_time64_nanosecond_diff_rg_sizes() { - let reader = TestReader { - scenario: Scenario::Time64Nanosecond, - row_per_group: 4, - } - .build() - .await; - - // Test for Time32Second column - Test { - reader: &reader, - // Assuming specific minimum and maximum values for demonstration - expected_min: Arc::new(Time64NanosecondArray::from(vec![ - 987654321012345, - 987654321012349, - 987654321012353, - 987654321012357, - ])), - expected_max: Arc::new(Time64NanosecondArray::from(vec![ - 987654321012348, - 987654321012352, - 987654321012356, - 987654321012360, - ])), - expected_null_counts: UInt64Array::from(vec![0, 0, 0, 0]), // Assuming 1 null per row group for simplicity - expected_row_counts: Some(UInt64Array::from(vec![4, 4, 4, 4])), - column_name: "nanosecond", - check: Check::Both, - } - .run(); -} - -#[tokio::test] -async fn test_dates_64_diff_rg_sizes() { - // The file is created by 4 record batches (each has a null row), each has 5 rows but then will be split into 2 row groups with size 13, 7 - let reader = TestReader { - scenario: Scenario::Dates, - row_per_group: 13, - } - .build() - .await; - Test { - reader: &reader, - expected_min: Arc::new(Date64Array::from(vec![ - Date64Type::parse("2020-01-01"), - Date64Type::parse("2020-10-30"), - ])), - expected_max: Arc::new(Date64Array::from(vec![ - Date64Type::parse("2020-10-29"), - Date64Type::parse("2029-11-12"), - ])), - expected_null_counts: UInt64Array::from(vec![2, 2]), - expected_row_counts: Some(UInt64Array::from(vec![13, 7])), - column_name: "date64", - check: Check::Both, - } - .run(); -} - -#[tokio::test] -async fn test_uint() { - // This creates a parquet files of 4 columns named "u8", "u16", "u32", "u64" - // "u8" --> UInt8Array - // "u16" --> UInt16Array - // "u32" --> UInt32Array - // "u64" --> UInt64Array - - // The file is created by 4 record batches (each has a null row), each has 5 rows but then will be split into 5 row groups with size 4 - let reader = TestReader { - scenario: Scenario::UInt, - row_per_group: 4, - } - .build() - .await; - - Test { - reader: &reader, - expected_min: Arc::new(UInt8Array::from(vec![0, 1, 4, 7, 251])), - expected_max: Arc::new(UInt8Array::from(vec![3, 4, 6, 250, 254])), - expected_null_counts: UInt64Array::from(vec![0, 0, 0, 0, 0]), - expected_row_counts: Some(UInt64Array::from(vec![4, 4, 4, 4, 4])), - column_name: "u8", - check: Check::Both, - } - .run(); - - Test { - reader: &reader, - expected_min: Arc::new(UInt16Array::from(vec![0, 1, 4, 7, 251])), - expected_max: Arc::new(UInt16Array::from(vec![3, 4, 6, 250, 254])), - expected_null_counts: UInt64Array::from(vec![0, 0, 0, 0, 0]), - expected_row_counts: Some(UInt64Array::from(vec![4, 4, 4, 4, 4])), - column_name: "u16", - check: Check::Both, - } - .run(); - - Test { - reader: &reader, - expected_min: Arc::new(UInt32Array::from(vec![0, 1, 4, 7, 251])), - expected_max: Arc::new(UInt32Array::from(vec![3, 4, 6, 250, 254])), - expected_null_counts: UInt64Array::from(vec![0, 0, 0, 0, 0]), - expected_row_counts: Some(UInt64Array::from(vec![4, 4, 4, 4, 4])), - column_name: "u32", - check: Check::Both, - } - .run(); - - Test { - reader: &reader, - expected_min: 
Arc::new(UInt64Array::from(vec![0, 1, 4, 7, 251])), - expected_max: Arc::new(UInt64Array::from(vec![3, 4, 6, 250, 254])), - expected_null_counts: UInt64Array::from(vec![0, 0, 0, 0, 0]), - expected_row_counts: Some(UInt64Array::from(vec![4, 4, 4, 4, 4])), - column_name: "u64", - check: Check::Both, - } - .run(); -} - -#[tokio::test] -async fn test_int32_range() { - // This creates a parquet file of 1 column "i" - // file has 2 record batches, each has 2 rows. They will be saved into one row group - let reader = TestReader { - scenario: Scenario::Int32Range, - row_per_group: 5, - } - .build() - .await; - - Test { - reader: &reader, - expected_min: Arc::new(Int32Array::from(vec![0])), - expected_max: Arc::new(Int32Array::from(vec![300000])), - expected_null_counts: UInt64Array::from(vec![0]), - expected_row_counts: Some(UInt64Array::from(vec![4])), - column_name: "i", - check: Check::Both, - } - .run(); -} - -#[tokio::test] -async fn test_uint32_range() { - // This creates a parquet file of 1 column "u" - // file has 2 record batches, each has 2 rows. They will be saved into one row group - let reader = TestReader { - scenario: Scenario::UInt32Range, - row_per_group: 5, - } - .build() - .await; - - Test { - reader: &reader, - expected_min: Arc::new(UInt32Array::from(vec![0])), - expected_max: Arc::new(UInt32Array::from(vec![300000])), - expected_null_counts: UInt64Array::from(vec![0]), - expected_row_counts: Some(UInt64Array::from(vec![4])), - column_name: "u", - check: Check::Both, - } - .run(); -} - -#[tokio::test] -async fn test_numeric_limits_unsigned() { - // file has 7 rows, 2 row groups: one with 5 rows, one with 2 rows. - let reader = TestReader { - scenario: Scenario::NumericLimits, - row_per_group: 5, - } - .build() - .await; - - Test { - reader: &reader, - expected_min: Arc::new(UInt8Array::from(vec![u8::MIN, 100])), - expected_max: Arc::new(UInt8Array::from(vec![100, u8::MAX])), - expected_null_counts: UInt64Array::from(vec![0, 0]), - expected_row_counts: Some(UInt64Array::from(vec![5, 2])), - column_name: "u8", - check: Check::Both, - } - .run(); - - Test { - reader: &reader, - expected_min: Arc::new(UInt16Array::from(vec![u16::MIN, 100])), - expected_max: Arc::new(UInt16Array::from(vec![100, u16::MAX])), - expected_null_counts: UInt64Array::from(vec![0, 0]), - expected_row_counts: Some(UInt64Array::from(vec![5, 2])), - column_name: "u16", - check: Check::Both, - } - .run(); - - Test { - reader: &reader, - expected_min: Arc::new(UInt32Array::from(vec![u32::MIN, 100])), - expected_max: Arc::new(UInt32Array::from(vec![100, u32::MAX])), - expected_null_counts: UInt64Array::from(vec![0, 0]), - expected_row_counts: Some(UInt64Array::from(vec![5, 2])), - column_name: "u32", - check: Check::Both, - } - .run(); - - Test { - reader: &reader, - expected_min: Arc::new(UInt64Array::from(vec![u64::MIN, 100])), - expected_max: Arc::new(UInt64Array::from(vec![100, u64::MAX])), - expected_null_counts: UInt64Array::from(vec![0, 0]), - expected_row_counts: Some(UInt64Array::from(vec![5, 2])), - column_name: "u64", - check: Check::Both, - } - .run(); -} - -#[tokio::test] -async fn test_numeric_limits_signed() { - // file has 7 rows, 2 row groups: one with 5 rows, one with 2 rows. 
- let reader = TestReader { - scenario: Scenario::NumericLimits, - row_per_group: 5, - } - .build() - .await; - - Test { - reader: &reader, - expected_min: Arc::new(Int8Array::from(vec![i8::MIN, -100])), - expected_max: Arc::new(Int8Array::from(vec![100, i8::MAX])), - expected_null_counts: UInt64Array::from(vec![0, 0]), - expected_row_counts: Some(UInt64Array::from(vec![5, 2])), - column_name: "i8", - check: Check::Both, - } - .run(); - - Test { - reader: &reader, - expected_min: Arc::new(Int16Array::from(vec![i16::MIN, -100])), - expected_max: Arc::new(Int16Array::from(vec![100, i16::MAX])), - expected_null_counts: UInt64Array::from(vec![0, 0]), - expected_row_counts: Some(UInt64Array::from(vec![5, 2])), - column_name: "i16", - check: Check::Both, - } - .run(); - - Test { - reader: &reader, - expected_min: Arc::new(Int32Array::from(vec![i32::MIN, -100])), - expected_max: Arc::new(Int32Array::from(vec![100, i32::MAX])), - expected_null_counts: UInt64Array::from(vec![0, 0]), - expected_row_counts: Some(UInt64Array::from(vec![5, 2])), - column_name: "i32", - check: Check::Both, - } - .run(); - - Test { - reader: &reader, - expected_min: Arc::new(Int64Array::from(vec![i64::MIN, -100])), - expected_max: Arc::new(Int64Array::from(vec![100, i64::MAX])), - expected_null_counts: UInt64Array::from(vec![0, 0]), - expected_row_counts: Some(UInt64Array::from(vec![5, 2])), - column_name: "i64", - check: Check::Both, - } - .run(); -} - -#[tokio::test] -async fn test_numeric_limits_float() { - // file has 7 rows, 2 row groups: one with 5 rows, one with 2 rows. - let reader = TestReader { - scenario: Scenario::NumericLimits, - row_per_group: 5, - } - .build() - .await; - - Test { - reader: &reader, - expected_min: Arc::new(Float32Array::from(vec![f32::MIN, -100.0])), - expected_max: Arc::new(Float32Array::from(vec![100.0, f32::MAX])), - expected_null_counts: UInt64Array::from(vec![0, 0]), - expected_row_counts: Some(UInt64Array::from(vec![5, 2])), - column_name: "f32", - check: Check::Both, - } - .run(); - - Test { - reader: &reader, - expected_min: Arc::new(Float64Array::from(vec![f64::MIN, -100.0])), - expected_max: Arc::new(Float64Array::from(vec![100.0, f64::MAX])), - expected_null_counts: UInt64Array::from(vec![0, 0]), - expected_row_counts: Some(UInt64Array::from(vec![5, 2])), - column_name: "f64", - check: Check::Both, - } - .run(); - - Test { - reader: &reader, - expected_min: Arc::new(Float32Array::from(vec![-1.0, -100.0])), - expected_max: Arc::new(Float32Array::from(vec![100.0, -100.0])), - expected_null_counts: UInt64Array::from(vec![0, 0]), - expected_row_counts: Some(UInt64Array::from(vec![5, 2])), - column_name: "f32_nan", - check: Check::Both, - } - .run(); - - Test { - reader: &reader, - expected_min: Arc::new(Float64Array::from(vec![-1.0, -100.0])), - expected_max: Arc::new(Float64Array::from(vec![100.0, -100.0])), - expected_null_counts: UInt64Array::from(vec![0, 0]), - expected_row_counts: Some(UInt64Array::from(vec![5, 2])), - column_name: "f64_nan", - check: Check::Both, - } - .run(); -} - -#[tokio::test] -async fn test_float64() { - // This creates a parquet file of 1 column "f" - // file has 4 record batches, each has 5 rows. 
They will be saved into 4 row groups - let reader = TestReader { - scenario: Scenario::Float64, - row_per_group: 5, - } - .build() - .await; - - Test { - reader: &reader, - expected_min: Arc::new(Float64Array::from(vec![-5.0, -4.0, -0.0, 5.0])), - expected_max: Arc::new(Float64Array::from(vec![-1.0, 0.0, 4.0, 9.0])), - expected_null_counts: UInt64Array::from(vec![0, 0, 0, 0]), - expected_row_counts: Some(UInt64Array::from(vec![5, 5, 5, 5])), - column_name: "f", - check: Check::Both, - } - .run(); -} - -#[tokio::test] -async fn test_float16() { - // This creates a parquet file of 1 column "f" - // file has 4 record batches, each has 5 rows. They will be saved into 4 row groups - let reader = TestReader { - scenario: Scenario::Float16, - row_per_group: 5, - } - .build() - .await; - - Test { - reader: &reader, - expected_min: Arc::new(Float16Array::from( - vec![-5.0, -4.0, -0.0, 5.0] - .into_iter() - .map(f16::from_f32) - .collect::<Vec<_>>(), - )), - expected_max: Arc::new(Float16Array::from( - vec![-1.0, 0.0, 4.0, 9.0] - .into_iter() - .map(f16::from_f32) - .collect::<Vec<_>>(), - )), - expected_null_counts: UInt64Array::from(vec![0, 0, 0, 0]), - expected_row_counts: Some(UInt64Array::from(vec![5, 5, 5, 5])), - column_name: "f", - check: Check::Both, - } - .run(); -} - -#[tokio::test] -async fn test_decimal() { - // This creates a parquet file of 1 column "decimal_col" with decimal data type and precision 9, scale 2 - // file has 3 record batches, each has 5 rows. They will be saved into 3 row groups - let reader = TestReader { - scenario: Scenario::Decimal, - row_per_group: 5, - } - .build() - .await; - - Test { - reader: &reader, - expected_min: Arc::new( - Decimal128Array::from(vec![100, -500, 2000]) - .with_precision_and_scale(9, 2) - .unwrap(), - ), - expected_max: Arc::new( - Decimal128Array::from(vec![600, 600, 6000]) - .with_precision_and_scale(9, 2) - .unwrap(), - ), - expected_null_counts: UInt64Array::from(vec![0, 0, 0]), - expected_row_counts: Some(UInt64Array::from(vec![5, 5, 5])), - column_name: "decimal_col", - check: Check::Both, - } - .run(); -} -#[tokio::test] -async fn test_decimal_256() { - // This creates a parquet file of 1 column "decimal256_col" with decimal data type and precision 9, scale 2 - // file has 3 record batches, each has 5 rows.
They will be saved into 3 row groups - let reader = TestReader { - scenario: Scenario::Decimal256, - row_per_group: 5, - } - .build() - .await; - - Test { - reader: &reader, - expected_min: Arc::new( - Decimal256Array::from(vec![ - i256::from(100), - i256::from(-500), - i256::from(2000), - ]) - .with_precision_and_scale(9, 2) - .unwrap(), - ), - expected_max: Arc::new( - Decimal256Array::from(vec![ - i256::from(600), - i256::from(600), - i256::from(6000), - ]) - .with_precision_and_scale(9, 2) - .unwrap(), - ), - expected_null_counts: UInt64Array::from(vec![0, 0, 0]), - expected_row_counts: Some(UInt64Array::from(vec![5, 5, 5])), - column_name: "decimal256_col", - check: Check::Both, - } - .run(); -} -#[tokio::test] -async fn test_dictionary() { - let reader = TestReader { - scenario: Scenario::Dictionary, - row_per_group: 5, - } - .build() - .await; - - Test { - reader: &reader, - expected_min: Arc::new(StringArray::from(vec!["abc", "aaa"])), - expected_max: Arc::new(StringArray::from(vec!["def", "fffff"])), - expected_null_counts: UInt64Array::from(vec![1, 0]), - expected_row_counts: Some(UInt64Array::from(vec![5, 2])), - column_name: "string_dict_i8", - check: Check::Both, - } - .run(); - - Test { - reader: &reader, - expected_min: Arc::new(StringArray::from(vec!["abc", "aaa"])), - expected_max: Arc::new(StringArray::from(vec!["def", "fffff"])), - expected_null_counts: UInt64Array::from(vec![1, 0]), - expected_row_counts: Some(UInt64Array::from(vec![5, 2])), - column_name: "string_dict_i32", - check: Check::Both, - } - .run(); - - Test { - reader: &reader, - expected_min: Arc::new(Int64Array::from(vec![-100, 0])), - expected_max: Arc::new(Int64Array::from(vec![0, 100])), - expected_null_counts: UInt64Array::from(vec![1, 0]), - expected_row_counts: Some(UInt64Array::from(vec![5, 2])), - column_name: "int_dict_i8", - check: Check::Both, - } - .run(); -} - -#[tokio::test] -async fn test_byte() { - // This creates a parquet file of 5 columns - // "name" - // "service_string" - // "service_binary" - // "service_fixedsize" - // "service_large_binary" - - // file has 3 record batches, each has 5 rows. 
They will be saved into 3 row groups - let reader = TestReader { - scenario: Scenario::ByteArray, - row_per_group: 5, - } - .build() - .await; - - // column "name" - Test { - reader: &reader, - expected_min: Arc::new(StringArray::from(vec![ - "all frontends", - "mixed", - "all backends", - ])), - expected_max: Arc::new(StringArray::from(vec![ - "all frontends", - "mixed", - "all backends", - ])), - expected_null_counts: UInt64Array::from(vec![0, 0, 0]), - expected_row_counts: Some(UInt64Array::from(vec![5, 5, 5])), - column_name: "name", - check: Check::Both, - } - .run(); - - // column "service_string" - Test { - reader: &reader, - expected_min: Arc::new(StringArray::from(vec![ - "frontend five", - "backend one", - "backend eight", - ])), - expected_max: Arc::new(StringArray::from(vec![ - "frontend two", - "frontend six", - "backend six", - ])), - expected_null_counts: UInt64Array::from(vec![0, 0, 0]), - expected_row_counts: Some(UInt64Array::from(vec![5, 5, 5])), - column_name: "service_string", - check: Check::Both, - } - .run(); - - // column "service_binary" - - let expected_service_binary_min_values: Vec<&[u8]> = - vec![b"frontend five", b"backend one", b"backend eight"]; - - let expected_service_binary_max_values: Vec<&[u8]> = - vec![b"frontend two", b"frontend six", b"backend six"]; - - Test { - reader: &reader, - expected_min: Arc::new(BinaryArray::from(expected_service_binary_min_values)), - expected_max: Arc::new(BinaryArray::from(expected_service_binary_max_values)), - expected_null_counts: UInt64Array::from(vec![0, 0, 0]), - expected_row_counts: Some(UInt64Array::from(vec![5, 5, 5])), - column_name: "service_binary", - check: Check::Both, - } - .run(); - - // column "service_fixedsize" - // b"fe1", b"be1", b"be4" - let min_input = vec![vec![102, 101, 49], vec![98, 101, 49], vec![98, 101, 52]]; - // b"fe5", b"fe6", b"be8" - let max_input = vec![vec![102, 101, 55], vec![102, 101, 54], vec![98, 101, 56]]; - - Test { - reader: &reader, - expected_min: Arc::new( - FixedSizeBinaryArray::try_from_iter(min_input.into_iter()).unwrap(), - ), - expected_max: Arc::new( - FixedSizeBinaryArray::try_from_iter(max_input.into_iter()).unwrap(), - ), - expected_null_counts: UInt64Array::from(vec![0, 0, 0]), - expected_row_counts: Some(UInt64Array::from(vec![5, 5, 5])), - column_name: "service_fixedsize", - check: Check::Both, - } - .run(); - - let expected_service_large_binary_min_values: Vec<&[u8]> = - vec![b"frontend five", b"backend one", b"backend eight"]; - - let expected_service_large_binary_max_values: Vec<&[u8]> = - vec![b"frontend two", b"frontend six", b"backend six"]; - - Test { - reader: &reader, - expected_min: Arc::new(LargeBinaryArray::from( - expected_service_large_binary_min_values, - )), - expected_max: Arc::new(LargeBinaryArray::from( - expected_service_large_binary_max_values, - )), - expected_null_counts: UInt64Array::from(vec![0, 0, 0]), - expected_row_counts: Some(UInt64Array::from(vec![5, 5, 5])), - column_name: "service_large_binary", - check: Check::Both, - } - .run(); -} - -// PeriodsInColumnNames -#[tokio::test] -async fn test_period_in_column_names() { - // This creates a parquet file of 2 columns "name" and "service.name" - // file has 3 record batches, each has 5 rows. 
They will be saved into 3 row groups - let reader = TestReader { - scenario: Scenario::PeriodsInColumnNames, - row_per_group: 5, - } - .build() - .await; - - // column "name" - Test { - reader: &reader, - expected_min: Arc::new(StringArray::from(vec![ - "HTTP GET / DISPATCH", - "HTTP PUT / DISPATCH", - "HTTP GET / DISPATCH", - ])), - expected_max: Arc::new(StringArray::from(vec![ - "HTTP GET / DISPATCH", - "HTTP PUT / DISPATCH", - "HTTP GET / DISPATCH", - ])), - expected_null_counts: UInt64Array::from(vec![0, 0, 0]), - expected_row_counts: Some(UInt64Array::from(vec![5, 5, 5])), - column_name: "name", - check: Check::Both, - } - .run(); - - // column "service.name" - Test { - reader: &reader, - expected_min: Arc::new(StringArray::from(vec!["frontend", "backend", "backend"])), - expected_max: Arc::new(StringArray::from(vec![ - "frontend", "frontend", "backend", - ])), - expected_null_counts: UInt64Array::from(vec![0, 0, 0]), - expected_row_counts: Some(UInt64Array::from(vec![5, 5, 5])), - column_name: "service.name", - check: Check::Both, - } - .run(); -} - -// Boolean -#[tokio::test] -async fn test_boolean() { - // This creates a parquet files of 1 column named "bool" - // The file is created by 2 record batches each has 5 rows --> 2 row groups - let reader = TestReader { - scenario: Scenario::Boolean, - row_per_group: 5, - } - .build() - .await; - - Test { - reader: &reader, - expected_min: Arc::new(BooleanArray::from(vec![false, false])), - expected_max: Arc::new(BooleanArray::from(vec![true, false])), - expected_null_counts: UInt64Array::from(vec![1, 0]), - expected_row_counts: Some(UInt64Array::from(vec![5, 5])), - column_name: "bool", - check: Check::Both, - } - .run(); -} - -// struct array -// BUG -// https://github.com/apache/datafusion/issues/10609 -// Note that: since I have not worked on struct array before, there may be a bug in the test code rather than the real bug in the code -#[ignore] -#[tokio::test] -async fn test_struct() { - // This creates a parquet files of 1 column named "struct" - // The file is created by 1 record batch with 3 rows in the struct array - let reader = TestReader { - scenario: Scenario::StructArray, - row_per_group: 5, - } - .build() - .await; - Test { - reader: &reader, - expected_min: Arc::new(struct_array(vec![(Some(1), Some(6.0), Some(12.0))])), - expected_max: Arc::new(struct_array(vec![(Some(2), Some(8.5), Some(14.0))])), - expected_null_counts: UInt64Array::from(vec![0]), - expected_row_counts: Some(UInt64Array::from(vec![3])), - column_name: "struct", - check: Check::RowGroup, - } - .run(); -} - -// UTF8 -#[tokio::test] -async fn test_utf8() { - let reader = TestReader { - scenario: Scenario::UTF8, - row_per_group: 5, - } - .build() - .await; - - // test for utf8 - Test { - reader: &reader, - expected_min: Arc::new(StringArray::from(vec!["a", "e"])), - expected_max: Arc::new(StringArray::from(vec!["d", "i"])), - expected_null_counts: UInt64Array::from(vec![1, 0]), - expected_row_counts: Some(UInt64Array::from(vec![5, 5])), - column_name: "utf8", - check: Check::Both, - } - .run(); - - // test for large_utf8 - Test { - reader: &reader, - expected_min: Arc::new(LargeStringArray::from(vec!["a", "e"])), - expected_max: Arc::new(LargeStringArray::from(vec!["d", "i"])), - expected_null_counts: UInt64Array::from(vec![1, 0]), - expected_row_counts: Some(UInt64Array::from(vec![5, 5])), - column_name: "large_utf8", - check: Check::Both, - } - .run(); -} - -////// Files with missing statistics /////// - -#[tokio::test] -async fn test_missing_statistics() 
{ - let reader = Int64Case { - null_values: 0, - no_null_values_start: 4, - no_null_values_end: 7, - row_per_group: 5, - enable_stats: Some(EnabledStatistics::None), - ..Default::default() - } - .build(); - - Test { - reader: &reader, - expected_min: Arc::new(Int64Array::from(vec![None])), - expected_max: Arc::new(Int64Array::from(vec![None])), - expected_null_counts: UInt64Array::from(vec![None]), - expected_row_counts: Some(UInt64Array::from(vec![3])), // still has row count statistics - column_name: "i64", - check: Check::Both, - } - .run(); -} - -/////// NEGATIVE TESTS /////// -// column not found -#[tokio::test] -async fn test_column_not_found() { - let reader = TestReader { - scenario: Scenario::Dates, - row_per_group: 5, - } - .build() - .await; - Test { - reader: &reader, - expected_min: Arc::new(Int64Array::from(vec![18262, 18565])), - expected_max: Arc::new(Int64Array::from(vec![18564, 21865])), - expected_null_counts: UInt64Array::from(vec![2, 2]), - expected_row_counts: Some(UInt64Array::from(vec![13, 7])), - column_name: "not_a_column", - check: Check::Both, - } - .run_col_not_found(); -} - -#[tokio::test] -async fn test_column_non_existent() { - // Create a schema with an additional column - // that will not have a matching parquet index - let schema = Arc::new(Schema::new(vec![ - Field::new("i8", DataType::Int8, true), - Field::new("i16", DataType::Int16, true), - Field::new("i32", DataType::Int32, true), - Field::new("i64", DataType::Int64, true), - Field::new("i_do_not_exist", DataType::Int64, true), - ])); - - let reader = TestReader { - scenario: Scenario::Int, - row_per_group: 5, - } - .build() - .await; - - Test { - reader: &reader, - // mins are [-5, -4, 0, 5] - expected_min: Arc::new(Int64Array::from(vec![None, None, None, None])), - // maxes are [-1, 0, 4, 9] - expected_max: Arc::new(Int64Array::from(vec![None, None, None, None])), - // nulls are [0, 0, 0, 0] - expected_null_counts: UInt64Array::from(vec![None, None, None, None]), - // row counts are [5, 5, 5, 5] - expected_row_counts: None, - column_name: "i_do_not_exist", - check: Check::Both, - } - .run_with_schema(&schema); -} diff --git a/datafusion/core/tests/parquet/mod.rs b/datafusion/core/tests/parquet/mod.rs index 1b68a4aa4eb3..60a8dd400786 100644 --- a/datafusion/core/tests/parquet/mod.rs +++ b/datafusion/core/tests/parquet/mod.rs @@ -18,19 +18,15 @@ //! 
Parquet integration tests use crate::parquet::utils::MetricsFinder; use arrow::array::Decimal128Array; -use arrow::datatypes::i256; use arrow::{ array::{ - make_array, Array, ArrayRef, BinaryArray, BooleanArray, Date32Array, Date64Array, - Decimal256Array, DictionaryArray, FixedSizeBinaryArray, Float16Array, - Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, - LargeBinaryArray, LargeStringArray, StringArray, StructArray, - Time32MillisecondArray, Time32SecondArray, Time64MicrosecondArray, - Time64NanosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray, - TimestampNanosecondArray, TimestampSecondArray, UInt16Array, UInt32Array, - UInt64Array, UInt8Array, + make_array, Array, ArrayRef, BinaryArray, Date32Array, Date64Array, + FixedSizeBinaryArray, Float64Array, Int16Array, Int32Array, Int64Array, + Int8Array, LargeBinaryArray, LargeStringArray, StringArray, + TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, + TimestampSecondArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array, }, - datatypes::{DataType, Field, Int32Type, Int8Type, Schema, TimeUnit}, + datatypes::{DataType, Field, Schema}, record_batch::RecordBatch, util::pretty::pretty_format_batches, }; @@ -41,13 +37,11 @@ use datafusion::{ prelude::{ParquetReadOptions, SessionConfig, SessionContext}, }; use datafusion_expr::{Expr, LogicalPlan, LogicalPlanBuilder}; -use half::f16; use parquet::arrow::ArrowWriter; use parquet::file::properties::{EnabledStatistics, WriterProperties}; use std::sync::Arc; use tempfile::NamedTempFile; -mod arrow_statistics; mod custom_reader; // Don't run on windows as tempfiles don't seem to work the same #[cfg(not(target_os = "windows"))] @@ -75,37 +69,23 @@ fn init() { /// What data to use #[derive(Debug, Clone, Copy)] enum Scenario { - Boolean, Timestamps, Dates, Int, Int32Range, UInt, UInt32Range, - Time32Second, - Time32Millisecond, - Time64Nanosecond, - Time64Microsecond, - /// 7 Rows, for each i8, i16, i32, i64, u8, u16, u32, u64, f32, f64 - /// -MIN, -100, -1, 0, 1, 100, MAX - NumericLimits, - Float16, - Float32, Float64, Decimal, - Decimal256, DecimalBloomFilterInt32, DecimalBloomFilterInt64, DecimalLargePrecision, DecimalLargePrecisionBloomFilter, /// StringArray, BinaryArray, FixedSizeBinaryArray ByteArray, - /// DictionaryArray - Dictionary, PeriodsInColumnNames, WithNullValues, WithNullValuesPageLevel, - StructArray, UTF8, } @@ -321,16 +301,6 @@ impl ContextWithParquet { } } -fn make_boolean_batch(v: Vec>) -> RecordBatch { - let schema = Arc::new(Schema::new(vec![Field::new( - "bool", - DataType::Boolean, - true, - )])); - let array = Arc::new(BooleanArray::from(v)) as ArrayRef; - RecordBatch::try_new(schema, vec![array.clone()]).unwrap() -} - /// Return record batch with a few rows of data for all of the supported timestamp types /// values with the specified offset /// @@ -484,55 +454,6 @@ fn make_int_batches(start: i8, end: i8) -> RecordBatch { .unwrap() } -/// Return record batch with Time32Second, Time32Millisecond sequences -fn make_time32_batches(scenario: Scenario, v: Vec) -> RecordBatch { - match scenario { - Scenario::Time32Second => { - let schema = Arc::new(Schema::new(vec![Field::new( - "second", - DataType::Time32(TimeUnit::Second), - true, - )])); - let array = Arc::new(Time32SecondArray::from(v)) as ArrayRef; - RecordBatch::try_new(schema, vec![array]).unwrap() - } - Scenario::Time32Millisecond => { - let schema = Arc::new(Schema::new(vec![Field::new( - "millisecond", - 
DataType::Time32(TimeUnit::Millisecond), - true, - )])); - let array = Arc::new(Time32MillisecondArray::from(v)) as ArrayRef; - RecordBatch::try_new(schema, vec![array]).unwrap() - } - _ => panic!("Unsupported scenario for Time32"), - } -} - -/// Return record batch with Time64Microsecond, Time64Nanosecond sequences -fn make_time64_batches(scenario: Scenario, v: Vec) -> RecordBatch { - match scenario { - Scenario::Time64Microsecond => { - let schema = Arc::new(Schema::new(vec![Field::new( - "microsecond", - DataType::Time64(TimeUnit::Microsecond), - true, - )])); - let array = Arc::new(Time64MicrosecondArray::from(v)) as ArrayRef; - RecordBatch::try_new(schema, vec![array]).unwrap() - } - Scenario::Time64Nanosecond => { - let schema = Arc::new(Schema::new(vec![Field::new( - "nanosecond", - DataType::Time64(TimeUnit::Nanosecond), - true, - )])); - let array = Arc::new(Time64NanosecondArray::from(v)) as ArrayRef; - RecordBatch::try_new(schema, vec![array]).unwrap() - } - _ => panic!("Unsupported scenario for Time64"), - } -} /// Return record batch with u8, u16, u32, and u64 sequences /// /// Columns are named @@ -587,18 +508,6 @@ fn make_f64_batch(v: Vec) -> RecordBatch { RecordBatch::try_new(schema, vec![array.clone()]).unwrap() } -fn make_f32_batch(v: Vec) -> RecordBatch { - let schema = Arc::new(Schema::new(vec![Field::new("f", DataType::Float32, true)])); - let array = Arc::new(Float32Array::from(v)) as ArrayRef; - RecordBatch::try_new(schema, vec![array.clone()]).unwrap() -} - -fn make_f16_batch(v: Vec) -> RecordBatch { - let schema = Arc::new(Schema::new(vec![Field::new("f", DataType::Float16, true)])); - let array = Arc::new(Float16Array::from(v)) as ArrayRef; - RecordBatch::try_new(schema, vec![array.clone()]).unwrap() -} - /// Return record batch with decimal vector /// /// Columns are named @@ -617,24 +526,6 @@ fn make_decimal_batch(v: Vec, precision: u8, scale: i8) -> RecordBatch { RecordBatch::try_new(schema, vec![array.clone()]).unwrap() } -/// Return record batch with decimal256 vector -/// -/// Columns are named -/// "decimal256_col" -> Decimal256Array -fn make_decimal256_batch(v: Vec, precision: u8, scale: i8) -> RecordBatch { - let schema = Arc::new(Schema::new(vec![Field::new( - "decimal256_col", - DataType::Decimal256(precision, scale), - true, - )])); - let array = Arc::new( - Decimal256Array::from(v) - .with_precision_and_scale(precision, scale) - .unwrap(), - ) as ArrayRef; - RecordBatch::try_new(schema, vec![array]).unwrap() -} - /// Return record batch with a few rows of data for all of the supported date /// types with the specified offset (in days) /// @@ -843,39 +734,6 @@ fn make_int_batches_with_null( .unwrap() } -fn make_numeric_limit_batch() -> RecordBatch { - let i8 = Int8Array::from(vec![i8::MIN, 100, -1, 0, 1, -100, i8::MAX]); - let i16 = Int16Array::from(vec![i16::MIN, 100, -1, 0, 1, -100, i16::MAX]); - let i32 = Int32Array::from(vec![i32::MIN, 100, -1, 0, 1, -100, i32::MAX]); - let i64 = Int64Array::from(vec![i64::MIN, 100, -1, 0, 1, -100, i64::MAX]); - let u8 = UInt8Array::from(vec![u8::MIN, 100, 1, 0, 1, 100, u8::MAX]); - let u16 = UInt16Array::from(vec![u16::MIN, 100, 1, 0, 1, 100, u16::MAX]); - let u32 = UInt32Array::from(vec![u32::MIN, 100, 1, 0, 1, 100, u32::MAX]); - let u64 = UInt64Array::from(vec![u64::MIN, 100, 1, 0, 1, 100, u64::MAX]); - let f32 = Float32Array::from(vec![f32::MIN, 100.0, -1.0, 0.0, 1.0, -100.0, f32::MAX]); - let f64 = Float64Array::from(vec![f64::MIN, 100.0, -1.0, 0.0, 1.0, -100.0, f64::MAX]); - let f32_nan = - 
Float32Array::from(vec![f32::NAN, 100.0, -1.0, 0.0, 1.0, -100.0, f32::NAN]); - let f64_nan = - Float64Array::from(vec![f64::NAN, 100.0, -1.0, 0.0, 1.0, -100.0, f64::NAN]); - - RecordBatch::try_from_iter(vec![ - ("i8", Arc::new(i8) as _), - ("i16", Arc::new(i16) as _), - ("i32", Arc::new(i32) as _), - ("i64", Arc::new(i64) as _), - ("u8", Arc::new(u8) as _), - ("u16", Arc::new(u16) as _), - ("u32", Arc::new(u32) as _), - ("u64", Arc::new(u64) as _), - ("f32", Arc::new(f32) as _), - ("f64", Arc::new(f64) as _), - ("f32_nan", Arc::new(f32_nan) as _), - ("f64_nan", Arc::new(f64_nan) as _), - ]) - .unwrap() -} - fn make_utf8_batch(value: Vec>) -> RecordBatch { let utf8 = StringArray::from(value.clone()); let large_utf8 = LargeStringArray::from(value); @@ -886,61 +744,8 @@ fn make_utf8_batch(value: Vec>) -> RecordBatch { .unwrap() } -fn make_dict_batch() -> RecordBatch { - let values = [ - Some("abc"), - Some("def"), - None, - Some("def"), - Some("abc"), - Some("fffff"), - Some("aaa"), - ]; - let dict_i8_array = DictionaryArray::::from_iter(values.iter().cloned()); - let dict_i32_array = DictionaryArray::::from_iter(values.iter().cloned()); - - // Dictionary array of integers - let int64_values = Int64Array::from(vec![0, -100, 100]); - let keys = Int8Array::from_iter([ - Some(0), - Some(1), - None, - Some(0), - Some(0), - Some(2), - Some(0), - ]); - let dict_i8_int_array = - DictionaryArray::::try_new(keys, Arc::new(int64_values)).unwrap(); - - RecordBatch::try_from_iter(vec![ - ("string_dict_i8", Arc::new(dict_i8_array) as _), - ("string_dict_i32", Arc::new(dict_i32_array) as _), - ("int_dict_i8", Arc::new(dict_i8_int_array) as _), - ]) - .unwrap() -} - fn create_data_batch(scenario: Scenario) -> Vec { match scenario { - Scenario::Boolean => { - vec![ - make_boolean_batch(vec![ - Some(true), - Some(false), - Some(true), - Some(false), - None, - ]), - make_boolean_batch(vec![ - Some(false), - Some(false), - Some(false), - Some(false), - Some(false), - ]), - ] - } Scenario::Timestamps => { vec![ make_timestamp_batch(TimeDelta::try_seconds(0).unwrap()), @@ -979,45 +784,7 @@ fn create_data_batch(scenario: Scenario) -> Vec { Scenario::UInt32Range => { vec![make_uint32_range(0, 10), make_uint32_range(200000, 300000)] } - Scenario::NumericLimits => { - vec![make_numeric_limit_batch()] - } - Scenario::Float16 => { - vec![ - make_f16_batch( - vec![-5.0, -4.0, -3.0, -2.0, -1.0] - .into_iter() - .map(f16::from_f32) - .collect(), - ), - make_f16_batch( - vec![-4.0, -3.0, -2.0, -1.0, 0.0] - .into_iter() - .map(f16::from_f32) - .collect(), - ), - make_f16_batch( - vec![0.0, 1.0, 2.0, 3.0, 4.0] - .into_iter() - .map(f16::from_f32) - .collect(), - ), - make_f16_batch( - vec![5.0, 6.0, 7.0, 8.0, 9.0] - .into_iter() - .map(f16::from_f32) - .collect(), - ), - ] - } - Scenario::Float32 => { - vec![ - make_f32_batch(vec![-5.0, -4.0, -3.0, -2.0, -1.0]), - make_f32_batch(vec![-4.0, -3.0, -2.0, -1.0, 0.0]), - make_f32_batch(vec![0.0, 1.0, 2.0, 3.0, 4.0]), - make_f32_batch(vec![5.0, 6.0, 7.0, 8.0, 9.0]), - ] - } + Scenario::Float64 => { vec![ make_f64_batch(vec![-5.0, -4.0, -3.0, -2.0, -1.0]), @@ -1034,44 +801,7 @@ fn create_data_batch(scenario: Scenario) -> Vec { make_decimal_batch(vec![2000, 3000, 3000, 4000, 6000], 9, 2), ] } - Scenario::Decimal256 => { - // decimal256 record batch - vec![ - make_decimal256_batch( - vec![ - i256::from(100), - i256::from(200), - i256::from(300), - i256::from(400), - i256::from(600), - ], - 9, - 2, - ), - make_decimal256_batch( - vec![ - i256::from(-500), - i256::from(100), - 
i256::from(300), - i256::from(400), - i256::from(600), - ], - 9, - 2, - ), - make_decimal256_batch( - vec![ - i256::from(2000), - i256::from(3000), - i256::from(3000), - i256::from(4000), - i256::from(6000), - ], - 9, - 2, - ), - ] - } + Scenario::DecimalBloomFilterInt32 => { // decimal record batch vec![ @@ -1187,9 +917,7 @@ fn create_data_batch(scenario: Scenario) -> Vec { ), ] } - Scenario::Dictionary => { - vec![make_dict_batch()] - } + Scenario::PeriodsInColumnNames => { vec![ // all frontend @@ -1224,120 +952,7 @@ fn create_data_batch(scenario: Scenario) -> Vec { make_int_batches_with_null(5, 1, 6), ] } - Scenario::StructArray => { - let struct_array_data = struct_array(vec![ - (Some(1), Some(6.0), Some(12.0)), - (Some(2), Some(8.5), None), - (None, Some(8.5), Some(14.0)), - ]); - - let schema = Arc::new(Schema::new(vec![Field::new( - "struct", - struct_array_data.data_type().clone(), - true, - )])); - vec![RecordBatch::try_new(schema, vec![struct_array_data]).unwrap()] - } - Scenario::Time32Second => { - vec![ - make_time32_batches( - Scenario::Time32Second, - vec![18506, 18507, 18508, 18509], - ), - make_time32_batches( - Scenario::Time32Second, - vec![18510, 18511, 18512, 18513], - ), - make_time32_batches( - Scenario::Time32Second, - vec![18514, 18515, 18516, 18517], - ), - make_time32_batches( - Scenario::Time32Second, - vec![18518, 18519, 18520, 18521], - ), - ] - } - Scenario::Time32Millisecond => { - vec![ - make_time32_batches( - Scenario::Time32Millisecond, - vec![3600000, 3600001, 3600002, 3600003], - ), - make_time32_batches( - Scenario::Time32Millisecond, - vec![3600004, 3600005, 3600006, 3600007], - ), - make_time32_batches( - Scenario::Time32Millisecond, - vec![3600008, 3600009, 3600010, 3600011], - ), - make_time32_batches( - Scenario::Time32Millisecond, - vec![3600012, 3600013, 3600014, 3600015], - ), - ] - } - Scenario::Time64Microsecond => { - vec![ - make_time64_batches( - Scenario::Time64Microsecond, - vec![1234567890123, 1234567890124, 1234567890125, 1234567890126], - ), - make_time64_batches( - Scenario::Time64Microsecond, - vec![1234567890127, 1234567890128, 1234567890129, 1234567890130], - ), - make_time64_batches( - Scenario::Time64Microsecond, - vec![1234567890131, 1234567890132, 1234567890133, 1234567890134], - ), - make_time64_batches( - Scenario::Time64Microsecond, - vec![1234567890135, 1234567890136, 1234567890137, 1234567890138], - ), - ] - } - Scenario::Time64Nanosecond => { - vec![ - make_time64_batches( - Scenario::Time64Nanosecond, - vec![ - 987654321012345, - 987654321012346, - 987654321012347, - 987654321012348, - ], - ), - make_time64_batches( - Scenario::Time64Nanosecond, - vec![ - 987654321012349, - 987654321012350, - 987654321012351, - 987654321012352, - ], - ), - make_time64_batches( - Scenario::Time64Nanosecond, - vec![ - 987654321012353, - 987654321012354, - 987654321012355, - 987654321012356, - ], - ), - make_time64_batches( - Scenario::Time64Nanosecond, - vec![ - 987654321012357, - 987654321012358, - 987654321012359, - 987654321012360, - ], - ), - ] - } + Scenario::UTF8 => { vec![ make_utf8_batch(vec![Some("a"), Some("b"), Some("c"), Some("d"), None]), @@ -1405,27 +1020,3 @@ async fn make_test_file_page(scenario: Scenario, row_per_page: usize) -> NamedTe writer.close().unwrap(); output_file } - -// returns a struct array with columns "int32_col", "float32_col" and "float64_col" with the specified values -fn struct_array(input: Vec<(Option, Option, Option)>) -> ArrayRef { - let int_32: Int32Array = input.iter().map(|(i, _, _)| 
i).collect(); - let float_32: Float32Array = input.iter().map(|(_, f, _)| f).collect(); - let float_64: Float64Array = input.iter().map(|(_, _, f)| f).collect(); - - let nullable = true; - let struct_array = StructArray::from(vec![ - ( - Arc::new(Field::new("int32_col", DataType::Int32, nullable)), - Arc::new(int_32) as ArrayRef, - ), - ( - Arc::new(Field::new("float32_col", DataType::Float32, nullable)), - Arc::new(float_32) as ArrayRef, - ), - ( - Arc::new(Field::new("float64_col", DataType::Float64, nullable)), - Arc::new(float_64) as ArrayRef, - ), - ]); - Arc::new(struct_array) -} diff --git a/datafusion/execution/src/memory_pool/mod.rs b/datafusion/execution/src/memory_pool/mod.rs index 3df212d466c9..dcd59acbd49e 100644 --- a/datafusion/execution/src/memory_pool/mod.rs +++ b/datafusion/execution/src/memory_pool/mod.rs @@ -117,7 +117,7 @@ pub trait MemoryPool: Send + Sync + std::fmt::Debug { /// For help with allocation accounting, see the [proxy] module. /// /// [proxy]: crate::memory_pool::proxy -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash, Clone)] pub struct MemoryConsumer { name: String, can_spill: bool, diff --git a/datafusion/execution/src/memory_pool/pool.rs b/datafusion/execution/src/memory_pool/pool.rs index fd7724f3076c..4a41602bd961 100644 --- a/datafusion/execution/src/memory_pool/pool.rs +++ b/datafusion/execution/src/memory_pool/pool.rs @@ -17,9 +17,13 @@ use crate::memory_pool::{MemoryConsumer, MemoryPool, MemoryReservation}; use datafusion_common::{resources_datafusion_err, DataFusionError, Result}; +use hashbrown::HashMap; use log::debug; use parking_lot::Mutex; -use std::sync::atomic::{AtomicUsize, Ordering}; +use std::{ + num::NonZeroUsize, + sync::atomic::{AtomicU64, AtomicUsize, Ordering}, +}; /// A [`MemoryPool`] that enforces no limit #[derive(Debug, Default)] @@ -231,13 +235,164 @@ impl MemoryPool for FairSpillPool { } } +/// Constructs a resources error based upon the individual [`MemoryReservation`]. +/// +/// The error references the `bytes already allocated` for the reservation, +/// and not the total within the collective [`MemoryPool`], +/// nor the total across multiple reservations with the same [`MemoryConsumer`]. #[inline(always)] fn insufficient_capacity_err( reservation: &MemoryReservation, additional: usize, available: usize, ) -> DataFusionError { - resources_datafusion_err!("Failed to allocate additional {} bytes for {} with {} bytes already allocated - maximum available is {}", additional, reservation.registration.consumer.name, reservation.size, available) + resources_datafusion_err!("Failed to allocate additional {} bytes for {} with {} bytes already allocated for this reservation - {} bytes remain available for the total pool", additional, reservation.registration.consumer.name, reservation.size, available) +} + +/// A [`MemoryPool`] that tracks the consumers that have +/// reserved memory within the inner memory pool. +/// +/// By tracking memory reservations more carefully, this pool +/// can provide better error messages about the largest memory users. +/// +/// Tracking is per hashed [`MemoryConsumer`], not per [`MemoryReservation`]. +/// The same consumer can have multiple reservations. +#[derive(Debug)] +pub struct TrackConsumersPool<I> { + inner: I, + top: NonZeroUsize, + tracked_consumers: Mutex<HashMap<MemoryConsumer, AtomicU64>>, +} + +impl<I: MemoryPool> TrackConsumersPool<I> { + /// Creates a new [`TrackConsumersPool`]. + /// + /// The `top` determines how many Top K [`MemoryConsumer`]s to include + /// in the reported [`DataFusionError::ResourcesExhausted`].
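+    ///
+    /// A minimal usage sketch (illustrative only; the inner `GreedyMemoryPool`,
+    /// its 100-byte limit, and the top-3 setting are example values borrowed
+    /// from the tests below, not requirements):
+    ///
+    /// ```ignore
+    /// let pool = TrackConsumersPool::new(
+    ///     GreedyMemoryPool::new(100),
+    ///     // include the top 3 consumers in any ResourcesExhausted error
+    ///     NonZeroUsize::new(3).unwrap(),
+    /// );
+    /// ```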
+ pub fn new(inner: I, top: NonZeroUsize) -> Self { + Self { + inner, + top, + tracked_consumers: Default::default(), + } + } + + /// Determine if there are multiple [`MemoryConsumer`]s registered + /// which have the same name. + /// + /// This is tightly coupled to the implementation of the memory consumer. + fn has_multiple_consumers(&self, name: &String) -> bool { + let consumer = MemoryConsumer::new(name); + let consumer_with_spill = consumer.clone().with_can_spill(true); + let guard = self.tracked_consumers.lock(); + guard.contains_key(&consumer) && guard.contains_key(&consumer_with_spill) + } + + /// Returns a report string describing the top `top` memory consumers. + pub fn report_top(&self, top: usize) -> String { + let mut consumers = self + .tracked_consumers + .lock() + .iter() + .map(|(consumer, reserved)| { + ( + (consumer.name().to_owned(), consumer.can_spill()), + reserved.load(Ordering::Acquire), + ) + }) + .collect::<Vec<_>>(); + consumers.sort_by(|a, b| b.1.cmp(&a.1)); // sort descending by reserved bytes + + consumers[0..std::cmp::min(top, consumers.len())] + .iter() + .map(|((name, can_spill), size)| { + if self.has_multiple_consumers(name) { + format!("{name}(can_spill={}) consumed {:?} bytes", can_spill, size) + } else { + format!("{name} consumed {:?} bytes", size) + } + }) + .collect::<Vec<String>>() + .join(", ") + } +} + +impl<I: MemoryPool> MemoryPool for TrackConsumersPool<I> { + fn register(&self, consumer: &MemoryConsumer) { + self.inner.register(consumer); + + let mut guard = self.tracked_consumers.lock(); + if let Some(already_reserved) = guard.insert(consumer.clone(), Default::default()) + { + guard.entry_ref(consumer).and_modify(|bytes| { + bytes.fetch_add( + already_reserved.load(Ordering::Acquire), + Ordering::AcqRel, + ); + }); + } + } + + fn unregister(&self, consumer: &MemoryConsumer) { + self.inner.unregister(consumer); + self.tracked_consumers.lock().remove(consumer); + } + + fn grow(&self, reservation: &MemoryReservation, additional: usize) { + self.inner.grow(reservation, additional); + self.tracked_consumers + .lock() + .entry_ref(reservation.consumer()) + .and_modify(|bytes| { + bytes.fetch_add(additional as u64, Ordering::AcqRel); + }); + } + + fn shrink(&self, reservation: &MemoryReservation, shrink: usize) { + self.inner.shrink(reservation, shrink); + self.tracked_consumers + .lock() + .entry_ref(reservation.consumer()) + .and_modify(|bytes| { + bytes.fetch_sub(shrink as u64, Ordering::AcqRel); + }); + } + + fn try_grow(&self, reservation: &MemoryReservation, additional: usize) -> Result<()> { + self.inner + .try_grow(reservation, additional) + .map_err(|e| match e { + DataFusionError::ResourcesExhausted(e) => { + // wrap the OOM message with a report of the top memory consumers + DataFusionError::ResourcesExhausted( + provide_top_memory_consumers_to_error_msg( + e.to_owned(), + self.report_top(self.top.into()), + ), + ) + } + _ => e, + })?; + + self.tracked_consumers + .lock() + .entry_ref(reservation.consumer()) + .and_modify(|bytes| { + bytes.fetch_add(additional as u64, Ordering::AcqRel); + }); + Ok(()) + } + + fn reserved(&self) -> usize { + self.inner.reserved() + } +} + +fn provide_top_memory_consumers_to_error_msg( + error_msg: String, + top_consumers: String, +) -> String { + format!("Resources exhausted with top memory consumers (across reservations) are: {}.
Error: {}", top_consumers, error_msg) } #[cfg(test)] @@ -263,10 +418,10 @@ mod tests { assert_eq!(pool.reserved(), 4000); let err = r2.try_grow(1).unwrap_err().strip_backtrace(); - assert_eq!(err, "Resources exhausted: Failed to allocate additional 1 bytes for r2 with 2000 bytes already allocated - maximum available is 0"); + assert_eq!(err, "Resources exhausted: Failed to allocate additional 1 bytes for r2 with 2000 bytes already allocated for this reservation - 0 bytes remain available for the total pool"); let err = r2.try_grow(1).unwrap_err().strip_backtrace(); - assert_eq!(err, "Resources exhausted: Failed to allocate additional 1 bytes for r2 with 2000 bytes already allocated - maximum available is 0"); + assert_eq!(err, "Resources exhausted: Failed to allocate additional 1 bytes for r2 with 2000 bytes already allocated for this reservation - 0 bytes remain available for the total pool"); r1.shrink(1990); r2.shrink(2000); @@ -291,12 +446,12 @@ mod tests { .register(&pool); let err = r3.try_grow(70).unwrap_err().strip_backtrace(); - assert_eq!(err, "Resources exhausted: Failed to allocate additional 70 bytes for r3 with 0 bytes already allocated - maximum available is 40"); + assert_eq!(err, "Resources exhausted: Failed to allocate additional 70 bytes for r3 with 0 bytes already allocated for this reservation - 40 bytes remain available for the total pool"); //Shrinking r2 to zero doesn't allow a3 to allocate more than 45 r2.free(); let err = r3.try_grow(70).unwrap_err().strip_backtrace(); - assert_eq!(err, "Resources exhausted: Failed to allocate additional 70 bytes for r3 with 0 bytes already allocated - maximum available is 40"); + assert_eq!(err, "Resources exhausted: Failed to allocate additional 70 bytes for r3 with 0 bytes already allocated for this reservation - 40 bytes remain available for the total pool"); // But dropping r2 does drop(r2); @@ -309,6 +464,226 @@ mod tests { let mut r4 = MemoryConsumer::new("s4").register(&pool); let err = r4.try_grow(30).unwrap_err().strip_backtrace(); - assert_eq!(err, "Resources exhausted: Failed to allocate additional 30 bytes for s4 with 0 bytes already allocated - maximum available is 20"); + assert_eq!(err, "Resources exhausted: Failed to allocate additional 30 bytes for s4 with 0 bytes already allocated for this reservation - 20 bytes remain available for the total pool"); + } + + #[test] + fn test_tracked_consumers_pool() { + let pool: Arc = Arc::new(TrackConsumersPool::new( + GreedyMemoryPool::new(100), + NonZeroUsize::new(3).unwrap(), + )); + + // Test: use all the different interfaces to change reservation size + + // set r1=50, using grow and shrink + let mut r1 = MemoryConsumer::new("r1").register(&pool); + r1.grow(70); + r1.shrink(20); + + // set r2=15 using try_grow + let mut r2 = MemoryConsumer::new("r2").register(&pool); + r2.try_grow(15) + .expect("should succeed in memory allotment for r2"); + + // set r3=20 using try_resize + let mut r3 = MemoryConsumer::new("r3").register(&pool); + r3.try_resize(25) + .expect("should succeed in memory allotment for r3"); + r3.try_resize(20) + .expect("should succeed in memory allotment for r3"); + + // set r4=10 + // this should not be reported in top 3 + let mut r4 = MemoryConsumer::new("r4").register(&pool); + r4.grow(10); + + // Test: reports if new reservation causes error + // using the previously set sizes for other consumers + let mut r5 = MemoryConsumer::new("r5").register(&pool); + let expected = "Resources exhausted with top memory consumers (across reservations) are: r1 
consumed 50 bytes, r3 consumed 20 bytes, r2 consumed 15 bytes. Error: Failed to allocate additional 150 bytes for r5 with 0 bytes already allocated for this reservation - 5 bytes remain available for the total pool"; + let res = r5.try_grow(150); + assert!( + matches!( + &res, + Err(DataFusionError::ResourcesExhausted(ref e)) if e.to_string().contains(expected) + ), + "should provide list of top memory consumers, instead found {:?}", + res + ); + } + + #[test] + fn test_tracked_consumers_pool_register() { + let pool: Arc = Arc::new(TrackConsumersPool::new( + GreedyMemoryPool::new(100), + NonZeroUsize::new(3).unwrap(), + )); + + let same_name = "foo"; + + // Test: see error message when no consumers recorded yet + let mut r0 = MemoryConsumer::new(same_name).register(&pool); + let expected = "Resources exhausted with top memory consumers (across reservations) are: foo consumed 0 bytes. Error: Failed to allocate additional 150 bytes for foo with 0 bytes already allocated for this reservation - 100 bytes remain available for the total pool"; + let res = r0.try_grow(150); + assert!( + matches!( + &res, + Err(DataFusionError::ResourcesExhausted(ref e)) if e.to_string().contains(expected) + ), + "should provide proper error when no reservations have been made yet, instead found {:?}", res + ); + + // API: multiple registrations using the same hashed consumer, + // will be recognized as the same in the TrackConsumersPool. + + // Test: will be the same per Top Consumers reported. + r0.grow(10); // make r0=10, pool available=90 + let new_consumer_same_name = MemoryConsumer::new(same_name); + let mut r1 = new_consumer_same_name.clone().register(&pool); + // TODO: the insufficient_capacity_err() message is per reservation, not per consumer. + // a followup PR will clarify this message "0 bytes already allocated for this reservation" + let expected = "Resources exhausted with top memory consumers (across reservations) are: foo consumed 10 bytes. Error: Failed to allocate additional 150 bytes for foo with 0 bytes already allocated for this reservation - 90 bytes remain available for the total pool"; + let res = r1.try_grow(150); + assert!( + matches!( + &res, + Err(DataFusionError::ResourcesExhausted(ref e)) if e.to_string().contains(expected) + ), + "should provide proper error with same hashed consumer (a single foo=10 bytes, available=90), instead found {:?}", res + ); + + // Test: will accumulate size changes per consumer, not per reservation + r1.grow(20); + let expected = "Resources exhausted with top memory consumers (across reservations) are: foo consumed 30 bytes. Error: Failed to allocate additional 150 bytes for foo with 20 bytes already allocated for this reservation - 70 bytes remain available for the total pool"; + let res = r1.try_grow(150); + assert!( + matches!( + &res, + Err(DataFusionError::ResourcesExhausted(ref e)) if e.to_string().contains(expected) + ), + "should provide proper error with same hashed consumer (a single foo=30 bytes, available=70), instead found {:?}", res + ); + + // Test: different hashed consumer, (even with the same name), + // will be recognized as different in the TrackConsumersPool + let consumer_with_same_name_but_different_hash = + MemoryConsumer::new(same_name).with_can_spill(true); + let mut r2 = consumer_with_same_name_but_different_hash.register(&pool); + let expected = "Resources exhausted with top memory consumers (across reservations) are: foo(can_spill=false) consumed 30 bytes, foo(can_spill=true) consumed 0 bytes. 
Error: Failed to allocate additional 150 bytes for foo with 0 bytes already allocated for this reservation - 70 bytes remain available for the total pool"; + let res = r2.try_grow(150); + assert!( + matches!( + &res, + Err(DataFusionError::ResourcesExhausted(ref e)) if e.to_string().contains(expected) + ), + "should provide proper error with different hashed consumer (foo(can_spill=false)=30 bytes and foo(can_spill=true)=0 bytes, available=70), instead found {:?}", res + ); + } + + #[test] + fn test_tracked_consumers_pool_deregister() { + fn test_per_pool_type(pool: Arc<dyn MemoryPool>) { + // Baseline: see the 2 memory consumers + let mut r0 = MemoryConsumer::new("r0").register(&pool); + r0.grow(10); + let r1_consumer = MemoryConsumer::new("r1"); + let mut r1 = r1_consumer.clone().register(&pool); + r1.grow(20); + let expected = "Resources exhausted with top memory consumers (across reservations) are: r1 consumed 20 bytes, r0 consumed 10 bytes. Error: Failed to allocate additional 150 bytes for r0 with 10 bytes already allocated for this reservation - 70 bytes remain available for the total pool"; + let res = r0.try_grow(150); + assert!( + matches!( + &res, + Err(DataFusionError::ResourcesExhausted(ref e)) if e.to_string().contains(expected) + ), + "should provide proper error with both consumers, instead found {:?}", + res + ); + + // Test: unregister one + // only the remaining one should be listed + pool.unregister(&r1_consumer); + let expected_consumers = "Resources exhausted with top memory consumers (across reservations) are: r0 consumed 10 bytes"; + let res = r0.try_grow(150); + assert!( + matches!( + &res, + Err(DataFusionError::ResourcesExhausted(ref e)) if e.to_string().contains(expected_consumers) + ), + "should provide proper error with only 1 consumer left registered, instead found {:?}", res + ); + + // Test: the actual message still reports 70 bytes available, when it should report 90. + // This is because the pool.shrink() does not automatically occur within the inner_pool.deregister().
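+            // Worked example with the sizes used in this test: the inner pool still
+            // holds r0=10 + r1=20 = 30 bytes, so 100 - 30 = 70 bytes remain available;
+            // only after r1's reservation is freed does that become 100 - 10 = 90.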
+ let expected_70_available = "Failed to allocate additional 150 bytes for r0 with 10 bytes already allocated for this reservation - 70 bytes remain available for the total pool"; + let res = r0.try_grow(150); + assert!( + matches!( + &res, + Err(DataFusionError::ResourcesExhausted(ref e)) if e.to_string().contains(expected_70_available) + ), + "should find that the inner pool will still count all bytes for the deregistered consumer until the reservation is dropped, instead found {:?}", res + ); + + // Test: the registration needs to free itself (or be dropped), + // for the proper error message + r1.free(); + let expected_90_available = "Failed to allocate additional 150 bytes for r0 with 10 bytes already allocated for this reservation - 90 bytes remain available for the total pool"; + let res = r0.try_grow(150); + assert!( + matches!( + &res, + Err(DataFusionError::ResourcesExhausted(ref e)) if e.to_string().contains(expected_90_available) + ), + "should correctly account the total bytes after reservation is free, instead found {:?}", res + ); + } + + let tracked_spill_pool: Arc = Arc::new(TrackConsumersPool::new( + FairSpillPool::new(100), + NonZeroUsize::new(3).unwrap(), + )); + test_per_pool_type(tracked_spill_pool); + + let tracked_greedy_pool: Arc = Arc::new(TrackConsumersPool::new( + GreedyMemoryPool::new(100), + NonZeroUsize::new(3).unwrap(), + )); + test_per_pool_type(tracked_greedy_pool); + } + + #[test] + fn test_tracked_consumers_pool_use_beyond_errors() { + let upcasted: Arc = + Arc::new(TrackConsumersPool::new( + GreedyMemoryPool::new(100), + NonZeroUsize::new(3).unwrap(), + )); + let pool: Arc = Arc::clone(&upcasted) + .downcast::>() + .unwrap(); + // set r1=20 + let mut r1 = MemoryConsumer::new("r1").register(&pool); + r1.grow(20); + // set r2=15 + let mut r2 = MemoryConsumer::new("r2").register(&pool); + r2.grow(15); + // set r3=45 + let mut r3 = MemoryConsumer::new("r3").register(&pool); + r3.grow(45); + + let downcasted = upcasted + .downcast::>() + .unwrap(); + + // Test: can get runtime metrics, even without an error thrown + let expected = "r3 consumed 45 bytes, r1 consumed 20 bytes"; + let res = downcasted.report_top(2); + assert_eq!( + res, expected, + "should provide list of top memory consumers, instead found {:?}", + res + ); } } diff --git a/datafusion/expr/src/aggregate_function.rs b/datafusion/expr/src/aggregate_function.rs deleted file mode 100644 index 4037e3c5db9b..000000000000 --- a/datafusion/expr/src/aggregate_function.rs +++ /dev/null @@ -1,156 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! 
Aggregate function module contains all built-in aggregate functions definitions - -use std::{fmt, str::FromStr}; - -use crate::utils; -use crate::{type_coercion::aggregates::*, Signature, Volatility}; - -use arrow::datatypes::DataType; -use datafusion_common::{plan_datafusion_err, plan_err, DataFusionError, Result}; - -use strum_macros::EnumIter; - -/// Enum of all built-in aggregate functions -// Contributor's guide for adding new aggregate functions -// https://datafusion.apache.org/contributor-guide/index.html#how-to-add-a-new-aggregate-function -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash, EnumIter)] -pub enum AggregateFunction { - /// Minimum - Min, - /// Maximum - Max, -} - -impl AggregateFunction { - pub fn name(&self) -> &str { - use AggregateFunction::*; - match self { - Min => "MIN", - Max => "MAX", - } - } -} - -impl fmt::Display for AggregateFunction { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{}", self.name()) - } -} - -impl FromStr for AggregateFunction { - type Err = DataFusionError; - fn from_str(name: &str) -> Result { - Ok(match name { - // general - "max" => AggregateFunction::Max, - "min" => AggregateFunction::Min, - _ => { - return plan_err!("There is no built-in function named {name}"); - } - }) - } -} - -impl AggregateFunction { - /// Returns the datatype of the aggregate function given its argument types - /// - /// This is used to get the returned data type for aggregate expr. - pub fn return_type( - &self, - input_expr_types: &[DataType], - _input_expr_nullable: &[bool], - ) -> Result { - // Note that this function *must* return the same type that the respective physical expression returns - // or the execution panics. - - let coerced_data_types = coerce_types(self, input_expr_types, &self.signature()) - // original errors are all related to wrong function signature - // aggregate them for better error message - .map_err(|_| { - plan_datafusion_err!( - "{}", - utils::generate_signature_error_msg( - &format!("{self}"), - self.signature(), - input_expr_types, - ) - ) - })?; - - match self { - AggregateFunction::Max | AggregateFunction::Min => { - // For min and max agg function, the returned type is same as input type. - // The coerced_data_types is same with input_types. - Ok(coerced_data_types[0].clone()) - } - } - } - - /// Returns if the return type of the aggregate function is nullable given its argument - /// nullability - pub fn nullable(&self) -> Result { - match self { - AggregateFunction::Max | AggregateFunction::Min => Ok(true), - } - } -} - -impl AggregateFunction { - /// the signatures supported by the function `fun`. - pub fn signature(&self) -> Signature { - // note: the physical expression must accept the type returned by this function or the execution panics. - match self { - AggregateFunction::Min | AggregateFunction::Max => { - let valid = STRINGS - .iter() - .chain(NUMERICS.iter()) - .chain(TIMESTAMPS.iter()) - .chain(DATES.iter()) - .chain(TIMES.iter()) - .chain(BINARYS.iter()) - .cloned() - .collect::>(); - Signature::uniform(1, valid, Volatility::Immutable) - } - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - - use strum::IntoEnumIterator; - - #[test] - // Test for AggregateFunction's Display and from_str() implementations. - // For each variant in AggregateFunction, it converts the variant to a string - // and then back to a variant. The test asserts that the original variant and - // the reconstructed variant are the same. This assertion is also necessary for - // function suggestion. 
See https://github.com/apache/datafusion/issues/8082 - fn test_display_and_from_str() { - for func_original in AggregateFunction::iter() { - let func_name = func_original.to_string(); - let func_from_str = - AggregateFunction::from_str(func_name.to_lowercase().as_str()).unwrap(); - assert_eq!(func_from_str, func_original); - } - } -} diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 68d5504eea48..708843494814 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -28,8 +28,8 @@ use crate::expr_fn::binary_expr; use crate::logical_plan::Subquery; use crate::utils::expr_to_columns; use crate::{ - aggregate_function, built_in_window_function, udaf, BuiltInWindowFunction, - ExprSchemable, Operator, Signature, WindowFrame, WindowUDF, + built_in_window_function, udaf, BuiltInWindowFunction, ExprSchemable, Operator, + Signature, WindowFrame, WindowUDF, }; use crate::{window_frame, Volatility}; @@ -630,7 +630,6 @@ impl Sort { #[derive(Debug, Clone, PartialEq, Eq, Hash)] /// Defines which implementation of an aggregate function DataFusion should call. pub enum AggregateFunctionDefinition { - BuiltIn(aggregate_function::AggregateFunction), /// Resolved to a user defined aggregate function UDF(Arc), } @@ -639,7 +638,6 @@ impl AggregateFunctionDefinition { /// Function's name for display pub fn name(&self) -> &str { match self { - AggregateFunctionDefinition::BuiltIn(fun) => fun.name(), AggregateFunctionDefinition::UDF(udf) => udf.name(), } } @@ -666,24 +664,6 @@ pub struct AggregateFunction { } impl AggregateFunction { - pub fn new( - fun: aggregate_function::AggregateFunction, - args: Vec, - distinct: bool, - filter: Option>, - order_by: Option>, - null_treatment: Option, - ) -> Self { - Self { - func_def: AggregateFunctionDefinition::BuiltIn(fun), - args, - distinct, - filter, - order_by, - null_treatment, - } - } - /// Create a new AggregateFunction expression with a user-defined function (UDF) pub fn new_udf( udf: Arc, @@ -709,7 +689,6 @@ impl AggregateFunction { /// Defines which implementation of an aggregate function DataFusion should call. pub enum WindowFunctionDefinition { /// A built in aggregate function that leverages an aggregate function - AggregateFunction(aggregate_function::AggregateFunction), /// A a built-in window function BuiltInWindowFunction(built_in_window_function::BuiltInWindowFunction), /// A user defined aggregate function @@ -723,12 +702,9 @@ impl WindowFunctionDefinition { pub fn return_type( &self, input_expr_types: &[DataType], - input_expr_nullable: &[bool], + _input_expr_nullable: &[bool], ) -> Result { match self { - WindowFunctionDefinition::AggregateFunction(fun) => { - fun.return_type(input_expr_types, input_expr_nullable) - } WindowFunctionDefinition::BuiltInWindowFunction(fun) => { fun.return_type(input_expr_types) } @@ -742,7 +718,6 @@ impl WindowFunctionDefinition { /// the signatures supported by the function `fun`. 
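///
/// For example (illustrative): for a user-defined aggregate used as a window
/// function, the `AggregateUDF` arm below simply returns a clone of the
/// signature declared by the underlying UDAF.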
pub fn signature(&self) -> Signature { match self { - WindowFunctionDefinition::AggregateFunction(fun) => fun.signature(), WindowFunctionDefinition::BuiltInWindowFunction(fun) => fun.signature(), WindowFunctionDefinition::AggregateUDF(fun) => fun.signature().clone(), WindowFunctionDefinition::WindowUDF(fun) => fun.signature().clone(), @@ -754,7 +729,6 @@ impl WindowFunctionDefinition { match self { WindowFunctionDefinition::BuiltInWindowFunction(fun) => fun.name(), WindowFunctionDefinition::WindowUDF(fun) => fun.name(), - WindowFunctionDefinition::AggregateFunction(fun) => fun.name(), WindowFunctionDefinition::AggregateUDF(fun) => fun.name(), } } @@ -763,9 +737,6 @@ impl WindowFunctionDefinition { impl fmt::Display for WindowFunctionDefinition { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { - WindowFunctionDefinition::AggregateFunction(fun) => { - std::fmt::Display::fmt(fun, f) - } WindowFunctionDefinition::BuiltInWindowFunction(fun) => { std::fmt::Display::fmt(fun, f) } @@ -775,12 +746,6 @@ impl fmt::Display for WindowFunctionDefinition { } } -impl From for WindowFunctionDefinition { - fn from(value: aggregate_function::AggregateFunction) -> Self { - Self::AggregateFunction(value) - } -} - impl From for WindowFunctionDefinition { fn from(value: BuiltInWindowFunction) -> Self { Self::BuiltInWindowFunction(value) @@ -866,10 +831,6 @@ pub fn find_df_window_func(name: &str) -> Option { Some(WindowFunctionDefinition::BuiltInWindowFunction( built_in_function, )) - } else if let Ok(aggregate) = - aggregate_function::AggregateFunction::from_str(name.as_str()) - { - Some(WindowFunctionDefinition::AggregateFunction(aggregate)) } else { None } @@ -2589,8 +2550,6 @@ mod test { "first_value", "last_value", "nth_value", - "min", - "max", ]; for name in names { let fun = find_df_window_func(name).unwrap(); @@ -2607,18 +2566,6 @@ mod test { #[test] fn test_find_df_window_function() { - assert_eq!( - find_df_window_func("max"), - Some(WindowFunctionDefinition::AggregateFunction( - aggregate_function::AggregateFunction::Max - )) - ); - assert_eq!( - find_df_window_func("min"), - Some(WindowFunctionDefinition::AggregateFunction( - aggregate_function::AggregateFunction::Min - )) - ); assert_eq!( find_df_window_func("cume_dist"), Some(WindowFunctionDefinition::BuiltInWindowFunction( diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 1f51cded2239..e9c5485656c8 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -26,9 +26,9 @@ use crate::function::{ StateFieldsArgs, }; use crate::{ - aggregate_function, conditional_expressions::CaseBuilder, logical_plan::Subquery, - AggregateUDF, Expr, LogicalPlan, Operator, ScalarFunctionImplementation, ScalarUDF, - Signature, Volatility, + conditional_expressions::CaseBuilder, logical_plan::Subquery, AggregateUDF, Expr, + LogicalPlan, Operator, ScalarFunctionImplementation, ScalarUDF, Signature, + Volatility, }; use crate::{ AggregateUDFImpl, ColumnarValue, ScalarUDFImpl, WindowFrame, WindowUDF, WindowUDFImpl, @@ -150,30 +150,6 @@ pub fn not(expr: Expr) -> Expr { expr.not() } -/// Create an expression to represent the min() aggregate function -pub fn min(expr: Expr) -> Expr { - Expr::AggregateFunction(AggregateFunction::new( - aggregate_function::AggregateFunction::Min, - vec![expr], - false, - None, - None, - None, - )) -} - -/// Create an expression to represent the max() aggregate function -pub fn max(expr: Expr) -> Expr { - Expr::AggregateFunction(AggregateFunction::new( - 
aggregate_function::AggregateFunction::Max, - vec![expr], - false, - None, - None, - None, - )) -} - /// Return a new expression with bitwise AND pub fn bitwise_and(left: Expr, right: Expr) -> Expr { Expr::BinaryExpr(BinaryExpr::new( diff --git a/datafusion/expr/src/expr_rewriter/order_by.rs b/datafusion/expr/src/expr_rewriter/order_by.rs index 4b56ca3d1c2e..2efdcae1a790 100644 --- a/datafusion/expr/src/expr_rewriter/order_by.rs +++ b/datafusion/expr/src/expr_rewriter/order_by.rs @@ -156,11 +156,13 @@ mod test { use arrow::datatypes::{DataType, Field, Schema}; use crate::{ - cast, col, lit, logical_plan::builder::LogicalTableSource, min, - test::function_stub::avg, try_cast, LogicalPlanBuilder, + cast, col, lit, logical_plan::builder::LogicalTableSource, try_cast, + LogicalPlanBuilder, }; use super::*; + use crate::test::function_stub::avg; + use crate::test::function_stub::min; #[test] fn rewrite_sort_cols_by_agg() { diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 5e0571f712ee..6344b892adb7 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -198,14 +198,7 @@ impl ExprSchemable for Expr { .iter() .map(|e| e.get_type(schema)) .collect::>>()?; - let nullability = args - .iter() - .map(|e| e.nullable(schema)) - .collect::>>()?; match func_def { - AggregateFunctionDefinition::BuiltIn(fun) => { - fun.return_type(&data_types, &nullability) - } AggregateFunctionDefinition::UDF(fun) => { let new_types = data_types_with_aggregate_udf(&data_types, fun) .map_err(|err| { @@ -338,7 +331,6 @@ impl ExprSchemable for Expr { Expr::Cast(Cast { expr, .. }) => expr.nullable(input_schema), Expr::AggregateFunction(AggregateFunction { func_def, .. }) => { match func_def { - AggregateFunctionDefinition::BuiltIn(fun) => fun.nullable(), // TODO: UDF should be able to customize nullability AggregateFunctionDefinition::UDF(udf) if udf.name() == "count" => { Ok(false) diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index 3e02b0fdb3ed..2599ed52ad17 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -39,7 +39,6 @@ mod udaf; mod udf; mod udwf; -pub mod aggregate_function; pub mod conditional_expressions; pub mod execution_props; pub mod expr; @@ -66,7 +65,6 @@ pub mod window_function; pub mod window_state; pub use accumulator::Accumulator; -pub use aggregate_function::AggregateFunction; pub use built_in_window_function::BuiltInWindowFunction; pub use columnar_value::ColumnarValue; pub use expr::{ diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 98e262f0b187..736310c7ac0f 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -1531,7 +1531,12 @@ pub fn wrap_projection_for_join_if_necessary( let need_project = join_keys.iter().any(|key| !matches!(key, Expr::Column(_))); let plan = if need_project { - let mut projection = expand_wildcard(input_schema, &input, None)?; + // Include all columns from the input and extend them with the join keys + let mut projection = input_schema + .columns() + .into_iter() + .map(Expr::Column) + .collect::>(); let join_key_items = alias_join_keys .iter() .flat_map(|expr| expr.try_as_col().is_none().then_some(expr)) diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 54c857a2b701..6bea1ad948a1 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -1895,7 
+1895,7 @@ impl ToStringifiedPlan for LogicalPlan { } /// Produces no rows: An empty relation with an empty schema -#[derive(Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct EmptyRelation { /// Whether to produce a placeholder row pub produce_one_row: bool, @@ -1925,7 +1925,7 @@ pub struct EmptyRelation { /// intermediate table, then empty the intermediate table. /// /// [Postgres Docs]: https://www.postgresql.org/docs/current/queries-with.html#QUERIES-WITH-RECURSIVE -#[derive(Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct RecursiveQuery { /// Name of the query pub name: String, @@ -1942,7 +1942,7 @@ pub struct RecursiveQuery { /// Values expression. See /// [Postgres VALUES](https://www.postgresql.org/docs/current/queries-values.html) /// documentation for more details. -#[derive(Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct Values { /// The table schema pub schema: DFSchemaRef, @@ -2023,7 +2023,7 @@ pub fn projection_schema(input: &LogicalPlan, exprs: &[Expr]) -> Result, @@ -2368,7 +2368,7 @@ impl TableScan { } /// Apply Cross Join to two logical plans -#[derive(Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct CrossJoin { /// Left input pub left: Arc, @@ -2379,7 +2379,7 @@ pub struct CrossJoin { } /// Repartition the plan based on a partitioning scheme. -#[derive(Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct Repartition { /// The incoming logical plan pub input: Arc, @@ -2388,7 +2388,7 @@ pub struct Repartition { } /// Union multiple inputs -#[derive(Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct Union { /// Inputs to merge pub inputs: Vec>, @@ -2398,7 +2398,7 @@ pub struct Union { /// Prepare a statement but do not execute it. Prepare statements can have 0 or more /// `Expr::Placeholder` expressions that are filled in during execution -#[derive(Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct Prepare { /// The name of the statement pub name: String, @@ -2430,7 +2430,7 @@ pub struct Prepare { /// | parent_span_id | Utf8 | YES | /// +--------------------+-----------------------------+-------------+ /// ``` -#[derive(Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct DescribeTable { /// Table schema pub schema: Arc, @@ -2440,7 +2440,7 @@ pub struct DescribeTable { /// Produces a relation with string representations of /// various parts of the plan -#[derive(Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct Explain { /// Should extra (detailed, intermediate plans) be included? pub verbose: bool, @@ -2456,7 +2456,7 @@ pub struct Explain { /// Runs the actual plan, and then prints the physical plan with /// with execution metrics. -#[derive(Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct Analyze { /// Should extra detail be included? pub verbose: bool, @@ -2471,7 +2471,7 @@ pub struct Analyze { // the manual `PartialEq` is removed in favor of a derive. // (see `PartialEq` the impl for details.) #[allow(clippy::derived_hash_with_manual_eq)] -#[derive(Clone, Eq, Hash)] +#[derive(Debug, Clone, Eq, Hash)] pub struct Extension { /// The runtime extension operator pub node: Arc, @@ -2487,7 +2487,7 @@ impl PartialEq for Extension { } /// Produces the first `n` tuples from its input and discards the rest. 
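///
/// For example (illustrative), `SELECT ... LIMIT 5 OFFSET 2` is represented as
/// a `Limit` node with `skip = 2` and `fetch = Some(5)`.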
-#[derive(Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct Limit { /// Number of rows to skip before fetch pub skip: usize, @@ -2499,7 +2499,7 @@ pub struct Limit { } /// Removes duplicate rows from the input -#[derive(Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum Distinct { /// Plain `DISTINCT` referencing all selection expressions All(Arc), @@ -2518,7 +2518,7 @@ impl Distinct { } /// Removes duplicate rows from the input -#[derive(Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct DistinctOn { /// The `DISTINCT ON` clause expression list pub on_expr: Vec, @@ -2604,7 +2604,7 @@ impl DistinctOn { /// Aggregates its input based on a set of grouping and aggregate /// expressions (e.g. SUM). -#[derive(Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] // mark non_exhaustive to encourage use of try_new/new() #[non_exhaustive] pub struct Aggregate { @@ -2767,7 +2767,7 @@ fn calc_func_dependencies_for_project( } /// Sorts its input according to a list of sort expressions. -#[derive(Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct Sort { /// The sort expressions pub expr: Vec, @@ -2778,7 +2778,7 @@ pub struct Sort { } /// Join two logical plans on one or more join columns -#[derive(Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct Join { /// Left input pub left: Arc, diff --git a/datafusion/expr/src/signature.rs b/datafusion/expr/src/signature.rs index eadd7ac2f83f..b1cec3bad774 100644 --- a/datafusion/expr/src/signature.rs +++ b/datafusion/expr/src/signature.rs @@ -266,9 +266,10 @@ impl Signature { } } - pub fn numeric(num: usize, volatility: Volatility) -> Self { + /// A specified number of numeric arguments + pub fn numeric(arg_count: usize, volatility: Volatility) -> Self { Self { - type_signature: TypeSignature::Numeric(num), + type_signature: TypeSignature::Numeric(arg_count), volatility, } } diff --git a/datafusion/expr/src/test/function_stub.rs b/datafusion/expr/src/test/function_stub.rs index 14a6522ebe91..72b73ccee44f 100644 --- a/datafusion/expr/src/test/function_stub.rs +++ b/datafusion/expr/src/test/function_stub.rs @@ -44,11 +44,9 @@ macro_rules! create_func { /// named STATIC_$(UDAF). 
For example `STATIC_FirstValue` #[allow(non_upper_case_globals)] static [< STATIC_ $UDAF >]: std::sync::OnceLock> = - std::sync::OnceLock::new(); + std::sync::OnceLock::new(); - /// AggregateFunction that returns a [AggregateUDF] for [$UDAF] - /// - /// [AggregateUDF]: crate::AggregateUDF + #[doc = concat!("AggregateFunction that returns a [AggregateUDF](crate::AggregateUDF) for [`", stringify!($UDAF), "`]")] pub fn $AGGREGATE_UDF_FN() -> std::sync::Arc { [< STATIC_ $UDAF >] .get_or_init(|| { @@ -291,6 +289,180 @@ impl AggregateUDFImpl for Count { } } +create_func!(Min, min_udaf); + +pub fn min(expr: Expr) -> Expr { + Expr::AggregateFunction(AggregateFunction::new_udf( + min_udaf(), + vec![expr], + false, + None, + None, + None, + )) +} + +/// Testing stub implementation of Min aggregate +pub struct Min { + signature: Signature, + aliases: Vec, +} + +impl std::fmt::Debug for Min { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.debug_struct("Min") + .field("name", &self.name()) + .field("signature", &self.signature) + .finish() + } +} + +impl Default for Min { + fn default() -> Self { + Self::new() + } +} + +impl Min { + pub fn new() -> Self { + Self { + aliases: vec!["min".to_string()], + signature: Signature::variadic_any(Volatility::Immutable), + } + } +} + +impl AggregateUDFImpl for Min { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "MIN" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Int64) + } + + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { + not_impl_err!("no impl for stub") + } + + fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { + not_impl_err!("no impl for stub") + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn create_groups_accumulator( + &self, + _args: AccumulatorArgs, + ) -> Result> { + not_impl_err!("no impl for stub") + } + + fn reverse_expr(&self) -> ReversedUDAF { + ReversedUDAF::Identical + } + fn is_descending(&self) -> Option { + Some(false) + } +} + +create_func!(Max, max_udaf); + +pub fn max(expr: Expr) -> Expr { + Expr::AggregateFunction(AggregateFunction::new_udf( + max_udaf(), + vec![expr], + false, + None, + None, + None, + )) +} + +/// Testing stub implementation of MAX aggregate +pub struct Max { + signature: Signature, + aliases: Vec, +} + +impl std::fmt::Debug for Max { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.debug_struct("Max") + .field("name", &self.name()) + .field("signature", &self.signature) + .finish() + } +} + +impl Default for Max { + fn default() -> Self { + Self::new() + } +} + +impl Max { + pub fn new() -> Self { + Self { + aliases: vec!["max".to_string()], + signature: Signature::variadic_any(Volatility::Immutable), + } + } +} + +impl AggregateUDFImpl for Max { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "MAX" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Int64) + } + + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { + not_impl_err!("no impl for stub") + } + + fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { + not_impl_err!("no impl for stub") + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn create_groups_accumulator( + &self, + _args: AccumulatorArgs, + ) -> Result> { + not_impl_err!("no impl for stub") + } + + fn 
reverse_expr(&self) -> ReversedUDAF { + ReversedUDAF::Identical + } + fn is_descending(&self) -> Option { + Some(true) + } +} + /// Testing stub implementation of avg aggregate #[derive(Debug)] pub struct Avg { diff --git a/datafusion/expr/src/tree_node.rs b/datafusion/expr/src/tree_node.rs index 813257122b65..69656f21c8ea 100644 --- a/datafusion/expr/src/tree_node.rs +++ b/datafusion/expr/src/tree_node.rs @@ -323,16 +323,6 @@ impl TreeNode for Expr { )? .map_data( |(new_args, new_filter, new_order_by)| match func_def { - AggregateFunctionDefinition::BuiltIn(fun) => { - Ok(Expr::AggregateFunction(AggregateFunction::new( - fun, - new_args, - distinct, - new_filter, - new_order_by, - null_treatment, - ))) - } AggregateFunctionDefinition::UDF(fun) => { Ok(Expr::AggregateFunction(AggregateFunction::new_udf( fun, diff --git a/datafusion/expr/src/type_coercion/aggregates.rs b/datafusion/expr/src/type_coercion/aggregates.rs index a024401e18d5..e7e58bf84362 100644 --- a/datafusion/expr/src/type_coercion/aggregates.rs +++ b/datafusion/expr/src/type_coercion/aggregates.rs @@ -15,8 +15,7 @@ // specific language governing permissions and limitations // under the License. -use std::ops::Deref; - +use crate::TypeSignature; use arrow::datatypes::{ DataType, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, @@ -24,8 +23,6 @@ use arrow::datatypes::{ use datafusion_common::{internal_err, plan_err, Result}; -use crate::{AggregateFunction, Signature, TypeSignature}; - pub static STRINGS: &[DataType] = &[DataType::Utf8, DataType::LargeUtf8]; pub static SIGNED_INTEGERS: &[DataType] = &[ @@ -84,25 +81,6 @@ pub static TIMES: &[DataType] = &[ DataType::Time64(TimeUnit::Nanosecond), ]; -/// Returns the coerced data type for each `input_types`. -/// Different aggregate function with different input data type will get corresponding coerced data type. -pub fn coerce_types( - agg_fun: &AggregateFunction, - input_types: &[DataType], - signature: &Signature, -) -> Result> { - // Validate input_types matches (at least one of) the func signature. - check_arg_count(agg_fun.name(), input_types, &signature.type_signature)?; - - match agg_fun { - AggregateFunction::Min | AggregateFunction::Max => { - // min and max support the dictionary data type - // unpack the dictionary to get the value - get_min_max_result_type(input_types) - } - } -} - /// Validate the length of `input_types` matches the `signature` for `agg_fun`. /// /// This method DOES NOT validate the argument types - only that (at least one, @@ -163,22 +141,6 @@ pub fn check_arg_count( Ok(()) } -fn get_min_max_result_type(input_types: &[DataType]) -> Result> { - // make sure that the input types only has one element. 
- assert_eq!(input_types.len(), 1); - // min and max support the dictionary data type - // unpack the dictionary to get the value - match &input_types[0] { - DataType::Dictionary(_, dict_value_type) => { - // TODO add checker, if the value type is complex data type - Ok(vec![dict_value_type.deref().clone()]) - } - // TODO add checker for datatype which min and max supported - // For example, the `Struct` and `Map` type are not supported in the MIN and MAX function - _ => Ok(input_types.to_vec()), - } -} - /// function return type of a sum pub fn sum_return_type(arg_type: &DataType) -> Result { match arg_type { @@ -348,32 +310,6 @@ pub fn coerce_avg_type(func_name: &str, arg_types: &[DataType]) -> Result Result<()> { diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 1d6919494587..3c135c044812 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -1328,8 +1328,9 @@ mod tests { use super::*; use crate::{ col, cube, expr, expr_vec_fmt, grouping_set, lit, rollup, - test::function_stub::sum_udaf, AggregateFunction, Cast, ExprFunctionExt, - WindowFrame, WindowFunctionDefinition, + test::function_stub::max_udaf, test::function_stub::min_udaf, + test::function_stub::sum_udaf, Cast, ExprFunctionExt, WindowFrame, + WindowFunctionDefinition, }; #[test] @@ -1343,15 +1344,15 @@ mod tests { #[test] fn test_group_window_expr_by_sort_keys_empty_window() -> Result<()> { let max1 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), + WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("name")], )); let max2 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), + WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("name")], )); let min3 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunctionDefinition::AggregateFunction(AggregateFunction::Min), + WindowFunctionDefinition::AggregateUDF(min_udaf()), vec![col("name")], )); let sum4 = Expr::WindowFunction(expr::WindowFunction::new( @@ -1374,18 +1375,18 @@ mod tests { let created_at_desc = Expr::Sort(expr::Sort::new(Box::new(col("created_at")), false, true)); let max1 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), + WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("name")], )) .order_by(vec![age_asc.clone(), name_desc.clone()]) .build() .unwrap(); let max2 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), + WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("name")], )); let min3 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunctionDefinition::AggregateFunction(AggregateFunction::Min), + WindowFunctionDefinition::AggregateUDF(min_udaf()), vec![col("name")], )) .order_by(vec![age_asc.clone(), name_desc.clone()]) @@ -1427,7 +1428,7 @@ mod tests { fn test_find_sort_exprs() -> Result<()> { let exprs = &[ Expr::WindowFunction(expr::WindowFunction::new( - WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), + WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("name")], )) .order_by(vec![ diff --git a/datafusion/functions-aggregate/Cargo.toml b/datafusion/functions-aggregate/Cargo.toml index 26630a0352d5..43ddd37cfb6f 100644 --- a/datafusion/functions-aggregate/Cargo.toml +++ b/datafusion/functions-aggregate/Cargo.toml @@ -48,3 +48,6 @@ 
datafusion-physical-expr-common = { workspace = true } log = { workspace = true } paste = "1.0.14" sqlparser = { workspace = true } + +[dev-dependencies] +rand = { workspace = true } diff --git a/datafusion/functions-aggregate/src/approx_median.rs b/datafusion/functions-aggregate/src/approx_median.rs index f37c164799bd..85b79cdd5267 100644 --- a/datafusion/functions-aggregate/src/approx_median.rs +++ b/datafusion/functions-aggregate/src/approx_median.rs @@ -78,7 +78,7 @@ impl AggregateUDFImpl for ApproxMedian { Ok(vec![ Field::new(format_state_name(args.name, "max_size"), UInt64, false), Field::new(format_state_name(args.name, "sum"), Float64, false), - Field::new(format_state_name(args.name, "count"), Float64, false), + Field::new(format_state_name(args.name, "count"), UInt64, false), Field::new(format_state_name(args.name, "max"), Float64, false), Field::new(format_state_name(args.name, "min"), Float64, false), Field::new_list( diff --git a/datafusion/functions-aggregate/src/approx_percentile_cont.rs b/datafusion/functions-aggregate/src/approx_percentile_cont.rs index 8fdd45e71cf3..03ab32e5ab03 100644 --- a/datafusion/functions-aggregate/src/approx_percentile_cont.rs +++ b/datafusion/functions-aggregate/src/approx_percentile_cont.rs @@ -213,7 +213,7 @@ impl AggregateUDFImpl for ApproxPercentileCont { ), Field::new( format_state_name(args.name, "count"), - DataType::Float64, + DataType::UInt64, false, ), Field::new( @@ -405,7 +405,7 @@ impl Accumulator for ApproxPercentileAccumulator { } fn evaluate(&mut self) -> datafusion_common::Result { - if self.digest.count() == 0.0 { + if self.digest.count() == 0 { return ScalarValue::try_from(self.return_type.clone()); } let q = self.digest.estimate_quantile(self.percentile); @@ -486,8 +486,8 @@ mod tests { ApproxPercentileAccumulator::new_with_max_size(0.5, DataType::Float64, 100); accumulator.merge_digests(&[t1]); - assert_eq!(accumulator.digest.count(), 50_000.0); + assert_eq!(accumulator.digest.count(), 50_000); accumulator.merge_digests(&[t2]); - assert_eq!(accumulator.digest.count(), 100_000.0); + assert_eq!(accumulator.digest.count(), 100_000); } } diff --git a/datafusion/functions-aggregate/src/count.rs b/datafusion/functions-aggregate/src/count.rs index aacff28baeea..4668f1caf101 100644 --- a/datafusion/functions-aggregate/src/count.rs +++ b/datafusion/functions-aggregate/src/count.rs @@ -238,7 +238,7 @@ impl AggregateUDFImpl for Count { Box::new(BytesDistinctCountAccumulator::::new(OutputType::Utf8)) } DataType::Utf8View => { - Box::new(BytesViewDistinctCountAccumulator::new(OutputType::Utf8)) + Box::new(BytesViewDistinctCountAccumulator::new(OutputType::Utf8View)) } DataType::LargeUtf8 => { Box::new(BytesDistinctCountAccumulator::::new(OutputType::Utf8)) @@ -246,6 +246,9 @@ impl AggregateUDFImpl for Count { DataType::Binary => Box::new(BytesDistinctCountAccumulator::::new( OutputType::Binary, )), + DataType::BinaryView => Box::new(BytesViewDistinctCountAccumulator::new( + OutputType::BinaryView, + )), DataType::LargeBinary => Box::new(BytesDistinctCountAccumulator::::new( OutputType::Binary, )), diff --git a/datafusion/functions-aggregate/src/lib.rs b/datafusion/functions-aggregate/src/lib.rs index 171186966644..b54cd181a0cb 100644 --- a/datafusion/functions-aggregate/src/lib.rs +++ b/datafusion/functions-aggregate/src/lib.rs @@ -65,6 +65,7 @@ pub mod covariance; pub mod first_last; pub mod hyperloglog; pub mod median; +pub mod min_max; pub mod regr; pub mod stddev; pub mod sum; @@ -110,7 +111,8 @@ pub mod expr_fn { pub use 
super::first_last::last_value; pub use super::grouping::grouping; pub use super::median::median; - pub use super::nth_value::nth_value; + pub use super::min_max::max; + pub use super::min_max::min; pub use super::regr::regr_avgx; pub use super::regr::regr_avgy; pub use super::regr::regr_count; @@ -137,6 +139,8 @@ pub fn all_default_aggregate_functions() -> Vec> { covariance::covar_pop_udaf(), correlation::corr_udaf(), sum::sum_udaf(), + min_max::max_udaf(), + min_max::min_udaf(), median::median_udaf(), count::count_udaf(), regr::regr_slope_udaf(), @@ -192,11 +196,11 @@ mod tests { #[test] fn test_no_duplicate_name() -> Result<()> { let mut names = HashSet::new(); + let migrated_functions = ["array_agg", "count", "max", "min"]; for func in all_default_aggregate_functions() { // TODO: remove this - // These functions are in intermidiate migration state, skip them - let name_lower_case = func.name().to_lowercase(); - if name_lower_case == "count" || name_lower_case == "array_agg" { + // These functions are in intermediate migration state, skip them + if migrated_functions.contains(&func.name().to_lowercase().as_str()) { continue; } assert!( diff --git a/datafusion/functions-aggregate/src/macros.rs b/datafusion/functions-aggregate/src/macros.rs index cae72cf35223..573b9fd5bdb2 100644 --- a/datafusion/functions-aggregate/src/macros.rs +++ b/datafusion/functions-aggregate/src/macros.rs @@ -86,11 +86,9 @@ macro_rules! create_func { /// named STATIC_$(UDAF). For example `STATIC_FirstValue` #[allow(non_upper_case_globals)] static [< STATIC_ $UDAF >]: std::sync::OnceLock> = - std::sync::OnceLock::new(); + std::sync::OnceLock::new(); - /// AggregateFunction that returns a [AggregateUDF] for [$UDAF] - /// - /// [AggregateUDF]: datafusion_expr::AggregateUDF + #[doc = concat!("AggregateFunction that returns a [`AggregateUDF`](datafusion_expr::AggregateUDF) for [`", stringify!($UDAF), "`]")] pub fn $AGGREGATE_UDF_FN() -> std::sync::Arc { [< STATIC_ $UDAF >] .get_or_init(|| { diff --git a/datafusion/physical-expr/src/aggregate/min_max.rs b/datafusion/functions-aggregate/src/min_max.rs similarity index 60% rename from datafusion/physical-expr/src/aggregate/min_max.rs rename to datafusion/functions-aggregate/src/min_max.rs index f9362db30196..4d743983411d 100644 --- a/datafusion/physical-expr/src/aggregate/min_max.rs +++ b/datafusion/functions-aggregate/src/min_max.rs @@ -2,7 +2,6 @@ // or more contributor license agreements. See the NOTICE file // distributed with this work for additional information // regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the // "License"); you may not use this file except in compliance // with the License. You may obtain a copy of the License at // @@ -15,103 +14,107 @@ // specific language governing permissions and limitations // under the License. -//! Defines physical expressions that can evaluated at runtime during query execution +//! [`Max`] and [`MaxAccumulator`] accumulator for the `max` function +//! [`Min`] and [`MinAccumulator`] accumulator for the `min` function -use std::any::Any; -use std::sync::Arc; +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License.
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. -use crate::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator; -use crate::{AggregateExpr, PhysicalExpr}; +use arrow::array::{ + ArrayRef, BinaryArray, BinaryViewArray, BooleanArray, Date32Array, Date64Array, + Decimal128Array, Decimal256Array, Float32Array, Float64Array, Int16Array, Int32Array, + Int64Array, Int8Array, IntervalDayTimeArray, IntervalMonthDayNanoArray, + IntervalYearMonthArray, LargeBinaryArray, LargeStringArray, StringArray, + StringViewArray, Time32MillisecondArray, Time32SecondArray, Time64MicrosecondArray, + Time64NanosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray, + TimestampNanosecondArray, TimestampSecondArray, UInt16Array, UInt32Array, + UInt64Array, UInt8Array, +}; use arrow::compute; use arrow::datatypes::{ - DataType, Date32Type, Date64Type, IntervalUnit, Time32MillisecondType, - Time32SecondType, Time64MicrosecondType, Time64NanosecondType, TimeUnit, - TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, - TimestampSecondType, -}; -use arrow::{ - array::{ - ArrayRef, BinaryArray, BooleanArray, Date32Array, Date64Array, Float32Array, - Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, - IntervalDayTimeArray, IntervalMonthDayNanoArray, IntervalYearMonthArray, - LargeBinaryArray, LargeStringArray, StringArray, Time32MillisecondArray, - Time32SecondArray, Time64MicrosecondArray, Time64NanosecondArray, - TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, - TimestampSecondArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array, - }, - datatypes::Field, -}; -use arrow_array::types::{ - Decimal128Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, - UInt16Type, UInt32Type, UInt64Type, UInt8Type, + DataType, Decimal128Type, Decimal256Type, Float32Type, Float64Type, Int16Type, + Int32Type, Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type, }; -use arrow_array::{BinaryViewArray, StringViewArray}; -use datafusion_common::internal_err; -use datafusion_common::ScalarValue; -use datafusion_common::{downcast_value, DataFusionError, Result}; -use datafusion_expr::{Accumulator, GroupsAccumulator}; - -use crate::aggregate::utils::down_cast_any_ref; -use crate::expressions::format_state_name; -use arrow::array::Array; -use arrow::array::Decimal128Array; -use arrow::array::Decimal256Array; -use arrow::datatypes::i256; -use arrow::datatypes::Decimal256Type; +use arrow_schema::IntervalUnit; +use datafusion_common::{downcast_value, internal_err, DataFusionError, Result}; +use datafusion_physical_expr_common::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator; +use std::fmt::Debug; -use super::moving_min_max; +use arrow::datatypes::i256; +use arrow::datatypes::{ + Date32Type, Date64Type, Time32MillisecondType, Time32SecondType, + Time64MicrosecondType, Time64NanosecondType, TimeUnit, TimestampMicrosecondType, + TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, +}; -// Min/max aggregation can take Dictionary encode input but always produces unpacked -// (aka non Dictionary) output. 
We need to adjust the output data type to reflect this. -// The reason min/max aggregate produces unpacked output because there is only one -// min/max value per group; there is no needs to keep them Dictionary encode -fn min_max_aggregate_data_type(input_type: DataType) -> DataType { - if let DataType::Dictionary(_, value_type) = input_type { - *value_type - } else { - input_type +use datafusion_common::ScalarValue; +use datafusion_expr::GroupsAccumulator; +use datafusion_expr::{ + function::AccumulatorArgs, Accumulator, AggregateUDFImpl, Signature, Volatility, +}; +use std::ops::Deref; + +fn get_min_max_result_type(input_types: &[DataType]) -> Result> { + // make sure that the input types only has one element. + assert_eq!(input_types.len(), 1); + // min and max support the dictionary data type + // unpack the dictionary to get the value + match &input_types[0] { + DataType::Dictionary(_, dict_value_type) => { + // TODO add checker, if the value type is complex data type + Ok(vec![dict_value_type.deref().clone()]) + } + // TODO add checker for datatype which min and max supported + // For example, the `Struct` and `Map` type are not supported in the MIN and MAX function + _ => Ok(input_types.to_vec()), } } -/// MAX aggregate expression -#[derive(Debug, Clone)] +// MAX aggregate UDF +#[derive(Debug)] pub struct Max { - name: String, - data_type: DataType, - nullable: bool, - expr: Arc, + aliases: Vec, + signature: Signature, } impl Max { - /// Create a new MAX aggregate function - pub fn new( - expr: Arc, - name: impl Into, - data_type: DataType, - ) -> Self { + pub fn new() -> Self { Self { - name: name.into(), - expr, - data_type: min_max_aggregate_data_type(data_type), - nullable: true, + aliases: vec!["max".to_owned()], + signature: Signature::user_defined(Volatility::Immutable), } } } + +impl Default for Max { + fn default() -> Self { + Self::new() + } +} /// Creates a [`PrimitiveGroupsAccumulator`] for computing `MAX` /// the specified [`ArrowPrimitiveType`]. /// /// [`ArrowPrimitiveType`]: arrow::datatypes::ArrowPrimitiveType macro_rules! instantiate_max_accumulator { - ($SELF:expr, $NATIVE:ident, $PRIMTYPE:ident) => {{ + ($DATA_TYPE:ident, $NATIVE:ident, $PRIMTYPE:ident) => {{ Ok(Box::new( - PrimitiveGroupsAccumulator::<$PRIMTYPE, _>::new( - &$SELF.data_type, - |cur, new| { - if *cur < new { - *cur = new - } - }, - ) + PrimitiveGroupsAccumulator::<$PRIMTYPE, _>::new($DATA_TYPE, |cur, new| { + if *cur < new { + *cur = new + } + }) // Initialize each accumulator to $NATIVE::MIN .with_starting_value($NATIVE::MIN), )) @@ -124,60 +127,48 @@ macro_rules! instantiate_max_accumulator { /// /// [`ArrowPrimitiveType`]: arrow::datatypes::ArrowPrimitiveType macro_rules! 
instantiate_min_accumulator { - ($SELF:expr, $NATIVE:ident, $PRIMTYPE:ident) => {{ + ($DATA_TYPE:ident, $NATIVE:ident, $PRIMTYPE:ident) => {{ Ok(Box::new( - PrimitiveGroupsAccumulator::<$PRIMTYPE, _>::new( - &$SELF.data_type, - |cur, new| { - if *cur > new { - *cur = new - } - }, - ) + PrimitiveGroupsAccumulator::<$PRIMTYPE, _>::new(&$DATA_TYPE, |cur, new| { + if *cur > new { + *cur = new + } + }) // Initialize each accumulator to $NATIVE::MAX .with_starting_value($NATIVE::MAX), )) }}; } -impl AggregateExpr for Max { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { +impl AggregateUDFImpl for Max { + fn as_any(&self) -> &dyn std::any::Any { self } - fn field(&self) -> Result { - Ok(Field::new( - &self.name, - self.data_type.clone(), - self.nullable, - )) + fn name(&self) -> &str { + "MAX" } - fn state_fields(&self) -> Result> { - Ok(vec![Field::new( - format_state_name(&self.name, "max"), - self.data_type.clone(), - true, - )]) + fn signature(&self) -> &Signature { + &self.signature } - fn expressions(&self) -> Vec> { - vec![Arc::clone(&self.expr)] + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(arg_types[0].to_owned()) } - fn create_accumulator(&self) -> Result> { - Ok(Box::new(MaxAccumulator::try_new(&self.data_type)?)) + fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { + Ok(Box::new(MaxAccumulator::try_new(acc_args.data_type)?)) } - fn name(&self) -> &str { - &self.name + fn aliases(&self) -> &[String] { + &self.aliases } - fn groups_accumulator_supported(&self) -> bool { + fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { use DataType::*; matches!( - self.data_type, + args.data_type, Int8 | Int16 | Int32 | Int64 @@ -197,97 +188,92 @@ impl AggregateExpr for Max { ) } - fn create_groups_accumulator(&self) -> Result> { + fn create_groups_accumulator( + &self, + args: AccumulatorArgs, + ) -> Result> { use DataType::*; use TimeUnit::*; - - match self.data_type { - Int8 => instantiate_max_accumulator!(self, i8, Int8Type), - Int16 => instantiate_max_accumulator!(self, i16, Int16Type), - Int32 => instantiate_max_accumulator!(self, i32, Int32Type), - Int64 => instantiate_max_accumulator!(self, i64, Int64Type), - UInt8 => instantiate_max_accumulator!(self, u8, UInt8Type), - UInt16 => instantiate_max_accumulator!(self, u16, UInt16Type), - UInt32 => instantiate_max_accumulator!(self, u32, UInt32Type), - UInt64 => instantiate_max_accumulator!(self, u64, UInt64Type), + let data_type = args.data_type; + match data_type { + Int8 => instantiate_max_accumulator!(data_type, i8, Int8Type), + Int16 => instantiate_max_accumulator!(data_type, i16, Int16Type), + Int32 => instantiate_max_accumulator!(data_type, i32, Int32Type), + Int64 => instantiate_max_accumulator!(data_type, i64, Int64Type), + UInt8 => instantiate_max_accumulator!(data_type, u8, UInt8Type), + UInt16 => instantiate_max_accumulator!(data_type, u16, UInt16Type), + UInt32 => instantiate_max_accumulator!(data_type, u32, UInt32Type), + UInt64 => instantiate_max_accumulator!(data_type, u64, UInt64Type), Float32 => { - instantiate_max_accumulator!(self, f32, Float32Type) + instantiate_max_accumulator!(data_type, f32, Float32Type) } Float64 => { - instantiate_max_accumulator!(self, f64, Float64Type) + instantiate_max_accumulator!(data_type, f64, Float64Type) } - Date32 => instantiate_max_accumulator!(self, i32, Date32Type), - Date64 => instantiate_max_accumulator!(self, i64, Date64Type), + Date32 => instantiate_max_accumulator!(data_type, i32, 
Date32Type), + Date64 => instantiate_max_accumulator!(data_type, i64, Date64Type), Time32(Second) => { - instantiate_max_accumulator!(self, i32, Time32SecondType) + instantiate_max_accumulator!(data_type, i32, Time32SecondType) } Time32(Millisecond) => { - instantiate_max_accumulator!(self, i32, Time32MillisecondType) + instantiate_max_accumulator!(data_type, i32, Time32MillisecondType) } Time64(Microsecond) => { - instantiate_max_accumulator!(self, i64, Time64MicrosecondType) + instantiate_max_accumulator!(data_type, i64, Time64MicrosecondType) } Time64(Nanosecond) => { - instantiate_max_accumulator!(self, i64, Time64NanosecondType) + instantiate_max_accumulator!(data_type, i64, Time64NanosecondType) } Timestamp(Second, _) => { - instantiate_max_accumulator!(self, i64, TimestampSecondType) + instantiate_max_accumulator!(data_type, i64, TimestampSecondType) } Timestamp(Millisecond, _) => { - instantiate_max_accumulator!(self, i64, TimestampMillisecondType) + instantiate_max_accumulator!(data_type, i64, TimestampMillisecondType) } Timestamp(Microsecond, _) => { - instantiate_max_accumulator!(self, i64, TimestampMicrosecondType) + instantiate_max_accumulator!(data_type, i64, TimestampMicrosecondType) } Timestamp(Nanosecond, _) => { - instantiate_max_accumulator!(self, i64, TimestampNanosecondType) + instantiate_max_accumulator!(data_type, i64, TimestampNanosecondType) } Decimal128(_, _) => { - instantiate_max_accumulator!(self, i128, Decimal128Type) + instantiate_max_accumulator!(data_type, i128, Decimal128Type) } Decimal256(_, _) => { - instantiate_max_accumulator!(self, i256, Decimal256Type) + instantiate_max_accumulator!(data_type, i256, Decimal256Type) } // It would be nice to have a fast implementation for Strings as well // https://github.com/apache/datafusion/issues/6906 // This is only reached if groups_accumulator_supported is out of sync - _ => internal_err!( - "GroupsAccumulator not supported for max({})", - self.data_type - ), + _ => internal_err!("GroupsAccumulator not supported for max({})", data_type), } } - fn reverse_expr(&self) -> Option> { - Some(Arc::new(self.clone())) + fn create_sliding_accumulator( + &self, + args: AccumulatorArgs, + ) -> Result> { + Ok(Box::new(SlidingMaxAccumulator::try_new(args.data_type)?)) } - fn create_sliding_accumulator(&self) -> Result> { - Ok(Box::new(SlidingMaxAccumulator::try_new(&self.data_type)?)) + fn is_descending(&self) -> Option { + Some(true) } - - fn get_minmax_desc(&self) -> Option<(Field, bool)> { - Some((self.field().ok()?, true)) + fn order_sensitivity(&self) -> datafusion_expr::utils::AggregateOrderSensitivity { + datafusion_expr::utils::AggregateOrderSensitivity::Insensitive } -} -impl PartialEq for Max { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.name == x.name - && self.data_type == x.data_type - && self.nullable == x.nullable - && self.expr.eq(&x.expr) - }) - .unwrap_or(false) + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + get_min_max_result_type(arg_types) + } + fn reverse_expr(&self) -> datafusion_expr::ReversedUDAF { + datafusion_expr::ReversedUDAF::Identical } } -// Statically-typed version of min/max(array) -> ScalarValue for string types. +// Statically-typed version of min/max(array) -> ScalarValue for string types macro_rules! typed_min_max_batch_string { ($VALUES:expr, $ARRAYTYPE:ident, $SCALAR:ident, $OP:ident) => {{ let array = downcast_value!($VALUES, $ARRAYTYPE); @@ -296,8 +282,7 @@ macro_rules! 
typed_min_max_batch_string { ScalarValue::$SCALAR(value) }}; } - -// Statically-typed version of min/max(array) -> ScalarValue for binary types. +// Statically-typed version of min/max(array) -> ScalarValue for binary types. macro_rules! typed_min_max_batch_binary { ($VALUES:expr, $ARRAYTYPE:ident, $SCALAR:ident, $OP:ident) => {{ let array = downcast_value!($VALUES, $ARRAYTYPE); @@ -545,7 +530,6 @@ macro_rules! typed_min_max { ) }}; } - macro_rules! typed_min_max_float { ($VALUE:expr, $DELTA:expr, $SCALAR:ident, $OP:ident) => {{ ScalarValue::$SCALAR(match ($VALUE, $DELTA) { @@ -804,16 +788,6 @@ min_max! }}; } -/// the minimum of two scalar values -pub fn min(lhs: &ScalarValue, rhs: &ScalarValue) -> Result { - min_max!(lhs, rhs, min) -} - -/// the maximum of two scalar values -pub fn max(lhs: &ScalarValue, rhs: &ScalarValue) -> Result { - min_max!(lhs, rhs, max) -} - /// An accumulator to compute the maximum value #[derive(Debug)] pub struct MaxAccumulator { @@ -833,7 +807,9 @@ impl Accumulator for MaxAccumulator { fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { let values = &values[0]; let delta = &max_batch(values)?; - self.max = max(&self.max, delta)?; + let new_max: Result = + min_max!(&self.max, delta, max); + self.max = new_max?; Ok(()) } @@ -842,9 +818,8 @@ } fn state(&mut self) -> Result> { - Ok(vec![self.max.clone()]) + Ok(vec![self.evaluate()?]) } - fn evaluate(&mut self) -> Result { Ok(self.max.clone()) } @@ -854,11 +829,10 @@ } } -/// An accumulator to compute the maximum value #[derive(Debug)] pub struct SlidingMaxAccumulator { max: ScalarValue, - moving_max: moving_min_max::MovingMax, + moving_max: MovingMax, } impl SlidingMaxAccumulator { @@ -866,7 +840,7 @@ pub fn try_new(datatype: &DataType) -> Result { Ok(Self { max: ScalarValue::try_from(datatype)?, - moving_max: moving_min_max::MovingMax::::new(), + moving_max: MovingMax::::new(), }) } } @@ -914,69 +888,56 @@ } } -/// MIN aggregate expression -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct Min { - name: String, - data_type: DataType, - nullable: bool, - expr: Arc, + signature: Signature, + aliases: Vec, } impl Min { - /// Create a new MIN aggregate function - pub fn new( - expr: Arc, - name: impl Into, - data_type: DataType, - ) -> Self { + pub fn new() -> Self { Self { - name: name.into(), - expr, - data_type: min_max_aggregate_data_type(data_type), - nullable: true, + signature: Signature::user_defined(Volatility::Immutable), + aliases: vec!["min".to_owned()], } } } -impl AggregateExpr for Min { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { +impl Default for Min { + fn default() -> Self { + Self::new() + } +} + +impl AggregateUDFImpl for Min { + fn as_any(&self) -> &dyn std::any::Any { self } - fn field(&self) -> Result { - Ok(Field::new( - &self.name, - self.data_type.clone(), - self.nullable, - )) + fn name(&self) -> &str { + "MIN" } - fn create_accumulator(&self) -> Result> { - Ok(Box::new(MinAccumulator::try_new(&self.data_type)?)) + fn signature(&self) -> &Signature { + &self.signature } - fn state_fields(&self) -> Result> { - Ok(vec![Field::new( - format_state_name(&self.name, "min"), - self.data_type.clone(), - true, - )]) + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(arg_types[0].to_owned()) } - fn expressions(&self) -> Vec> { - vec![Arc::clone(&self.expr)] +
fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { + Ok(Box::new(MinAccumulator::try_new(acc_args.data_type)?)) } - fn name(&self) -> &str { - &self.name + fn aliases(&self) -> &[String] { + &self.aliases } - fn groups_accumulator_supported(&self) -> bool { + fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { use DataType::*; matches!( - self.data_type, + args.data_type, Int8 | Int16 | Int32 | Int64 @@ -996,91 +957,92 @@ impl AggregateExpr for Min { ) } - fn create_groups_accumulator(&self) -> Result> { + fn create_groups_accumulator( + &self, + args: AccumulatorArgs, + ) -> Result> { use DataType::*; use TimeUnit::*; - match self.data_type { - Int8 => instantiate_min_accumulator!(self, i8, Int8Type), - Int16 => instantiate_min_accumulator!(self, i16, Int16Type), - Int32 => instantiate_min_accumulator!(self, i32, Int32Type), - Int64 => instantiate_min_accumulator!(self, i64, Int64Type), - UInt8 => instantiate_min_accumulator!(self, u8, UInt8Type), - UInt16 => instantiate_min_accumulator!(self, u16, UInt16Type), - UInt32 => instantiate_min_accumulator!(self, u32, UInt32Type), - UInt64 => instantiate_min_accumulator!(self, u64, UInt64Type), + let data_type = args.data_type; + match data_type { + Int8 => instantiate_min_accumulator!(data_type, i8, Int8Type), + Int16 => instantiate_min_accumulator!(data_type, i16, Int16Type), + Int32 => instantiate_min_accumulator!(data_type, i32, Int32Type), + Int64 => instantiate_min_accumulator!(data_type, i64, Int64Type), + UInt8 => instantiate_min_accumulator!(data_type, u8, UInt8Type), + UInt16 => instantiate_min_accumulator!(data_type, u16, UInt16Type), + UInt32 => instantiate_min_accumulator!(data_type, u32, UInt32Type), + UInt64 => instantiate_min_accumulator!(data_type, u64, UInt64Type), Float32 => { - instantiate_min_accumulator!(self, f32, Float32Type) + instantiate_min_accumulator!(data_type, f32, Float32Type) } Float64 => { - instantiate_min_accumulator!(self, f64, Float64Type) + instantiate_min_accumulator!(data_type, f64, Float64Type) } - Date32 => instantiate_min_accumulator!(self, i32, Date32Type), - Date64 => instantiate_min_accumulator!(self, i64, Date64Type), + Date32 => instantiate_min_accumulator!(data_type, i32, Date32Type), + Date64 => instantiate_min_accumulator!(data_type, i64, Date64Type), Time32(Second) => { - instantiate_min_accumulator!(self, i32, Time32SecondType) + instantiate_min_accumulator!(data_type, i32, Time32SecondType) } Time32(Millisecond) => { - instantiate_min_accumulator!(self, i32, Time32MillisecondType) + instantiate_min_accumulator!(data_type, i32, Time32MillisecondType) } Time64(Microsecond) => { - instantiate_min_accumulator!(self, i64, Time64MicrosecondType) + instantiate_min_accumulator!(data_type, i64, Time64MicrosecondType) } Time64(Nanosecond) => { - instantiate_min_accumulator!(self, i64, Time64NanosecondType) + instantiate_min_accumulator!(data_type, i64, Time64NanosecondType) } Timestamp(Second, _) => { - instantiate_min_accumulator!(self, i64, TimestampSecondType) + instantiate_min_accumulator!(data_type, i64, TimestampSecondType) } Timestamp(Millisecond, _) => { - instantiate_min_accumulator!(self, i64, TimestampMillisecondType) + instantiate_min_accumulator!(data_type, i64, TimestampMillisecondType) } Timestamp(Microsecond, _) => { - instantiate_min_accumulator!(self, i64, TimestampMicrosecondType) + instantiate_min_accumulator!(data_type, i64, TimestampMicrosecondType) } Timestamp(Nanosecond, _) => { - instantiate_min_accumulator!(self, i64, TimestampNanosecondType) + 
instantiate_min_accumulator!(data_type, i64, TimestampNanosecondType) } Decimal128(_, _) => { - instantiate_min_accumulator!(self, i128, Decimal128Type) + instantiate_min_accumulator!(data_type, i128, Decimal128Type) } Decimal256(_, _) => { - instantiate_min_accumulator!(self, i256, Decimal256Type) + instantiate_min_accumulator!(data_type, i256, Decimal256Type) } + + // It would be nice to have a fast implementation for Strings as well + // https://github.com/apache/datafusion/issues/6906 + // This is only reached if groups_accumulator_supported is out of sync - _ => internal_err!( - "GroupsAccumulator not supported for min({})", - self.data_type - ), + _ => internal_err!("GroupsAccumulator not supported for min({})", data_type), } } - fn reverse_expr(&self) -> Option> { - Some(Arc::new(self.clone())) + fn create_sliding_accumulator( + &self, + args: AccumulatorArgs, + ) -> Result> { + Ok(Box::new(SlidingMinAccumulator::try_new(args.data_type)?)) } - fn create_sliding_accumulator(&self) -> Result> { - Ok(Box::new(SlidingMinAccumulator::try_new(&self.data_type)?)) + fn is_descending(&self) -> Option { + Some(false) } - fn get_minmax_desc(&self) -> Option<(Field, bool)> { - Some((self.field().ok()?, false)) + fn order_sensitivity(&self) -> datafusion_expr::utils::AggregateOrderSensitivity { + datafusion_expr::utils::AggregateOrderSensitivity::Insensitive } -} -impl PartialEq for Min { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.name == x.name - && self.data_type == x.data_type - && self.nullable == x.nullable - && self.expr.eq(&x.expr) - }) - .unwrap_or(false) + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + get_min_max_result_type(arg_types) } -} + fn reverse_expr(&self) -> datafusion_expr::ReversedUDAF { + datafusion_expr::ReversedUDAF::Identical + } +} /// An accumulator to compute the minimum value #[derive(Debug)] pub struct MinAccumulator { @@ -1098,13 +1060,15 @@ impl MinAccumulator { impl Accumulator for MinAccumulator { fn state(&mut self) -> Result> { - Ok(vec![self.min.clone()]) + Ok(vec![self.evaluate()?]) } fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { let values = &values[0]; let delta = &min_batch(values)?; - self.min = min(&self.min, delta)?; + let new_min: Result = + min_max!(&self.min, delta, min); + self.min = new_min?; Ok(()) } @@ -1121,19 +1085,17 @@ impl Accumulator for MinAccumulator { } } -/// An accumulator to compute the minimum value #[derive(Debug)] pub struct SlidingMinAccumulator { min: ScalarValue, - moving_min: moving_min_max::MovingMin, + moving_min: MovingMin, } impl SlidingMinAccumulator { - /// new min accumulator pub fn try_new(datatype: &DataType) -> Result { Ok(Self { min: ScalarValue::try_from(datatype)?, - moving_min: moving_min_max::MovingMin::::new(), + moving_min: MovingMin::::new(), }) } } @@ -1186,12 +1148,278 @@ impl Accumulator for SlidingMinAccumulator { } } +// +// Moving min and moving max +// The implementation is taken from https://github.com/spebern/moving_min_max/blob/master/src/lib.rs. + +// Keep track of the minimum or maximum value in a sliding window. +// +// `moving min max` provides one data structure for keeping track of the +// minimum value and one for keeping track of the maximum value in a sliding +// window. +// +// Each element is stored with the current min/max. One stack to push and another one for pop. If pop stack is empty, +// push to this stack all elements popped from first stack while updating their current min/max. 
Now pop from +// the second stack (MovingMin/Max struct works as a queue). To find the minimum element of the queue, +// look at the smallest/largest two elements of the individual stacks, then take the minimum of those two values. +// +// The complexity of the operations are +// - O(1) for getting the minimum/maximum +// - O(1) for push +// - amortized O(1) for pop + +/// ``` +/// # use datafusion_functions_aggregate::min_max::MovingMin; +/// let mut moving_min = MovingMin::::new(); +/// moving_min.push(2); +/// moving_min.push(1); +/// moving_min.push(3); +/// +/// assert_eq!(moving_min.min(), Some(&1)); +/// assert_eq!(moving_min.pop(), Some(2)); +/// +/// assert_eq!(moving_min.min(), Some(&1)); +/// assert_eq!(moving_min.pop(), Some(1)); +/// +/// assert_eq!(moving_min.min(), Some(&3)); +/// assert_eq!(moving_min.pop(), Some(3)); +/// +/// assert_eq!(moving_min.min(), None); +/// assert_eq!(moving_min.pop(), None); +/// ``` +#[derive(Debug)] +pub struct MovingMin { + push_stack: Vec<(T, T)>, + pop_stack: Vec<(T, T)>, +} + +impl Default for MovingMin { + fn default() -> Self { + Self { + push_stack: Vec::new(), + pop_stack: Vec::new(), + } + } +} + +impl MovingMin { + /// Creates a new `MovingMin` to keep track of the minimum in a sliding + /// window. + #[inline] + pub fn new() -> Self { + Self::default() + } + + /// Creates a new `MovingMin` to keep track of the minimum in a sliding + /// window with `capacity` allocated slots. + #[inline] + pub fn with_capacity(capacity: usize) -> Self { + Self { + push_stack: Vec::with_capacity(capacity), + pop_stack: Vec::with_capacity(capacity), + } + } + + /// Returns the minimum of the sliding window or `None` if the window is + /// empty. + #[inline] + pub fn min(&self) -> Option<&T> { + match (self.push_stack.last(), self.pop_stack.last()) { + (None, None) => None, + (Some((_, min)), None) => Some(min), + (None, Some((_, min))) => Some(min), + (Some((_, a)), Some((_, b))) => Some(if a < b { a } else { b }), + } + } + + /// Pushes a new element into the sliding window. + #[inline] + pub fn push(&mut self, val: T) { + self.push_stack.push(match self.push_stack.last() { + Some((_, min)) => { + if val > *min { + (val, min.clone()) + } else { + (val.clone(), val) + } + } + None => (val.clone(), val), + }); + } + + /// Removes and returns the last value of the sliding window. + #[inline] + pub fn pop(&mut self) -> Option { + if self.pop_stack.is_empty() { + match self.push_stack.pop() { + Some((val, _)) => { + let mut last = (val.clone(), val); + self.pop_stack.push(last.clone()); + while let Some((val, _)) = self.push_stack.pop() { + let min = if last.1 < val { + last.1.clone() + } else { + val.clone() + }; + last = (val.clone(), min); + self.pop_stack.push(last.clone()); + } + } + None => return None, + } + } + self.pop_stack.pop().map(|(val, _)| val) + } + + /// Returns the number of elements stored in the sliding window. + #[inline] + pub fn len(&self) -> usize { + self.push_stack.len() + self.pop_stack.len() + } + + /// Returns `true` if the moving window contains no elements. 
+ #[inline] + pub fn is_empty(&self) -> bool { + self.len() == 0 + } +} +/// ``` +/// # use datafusion_functions_aggregate::min_max::MovingMax; +/// let mut moving_max = MovingMax::::new(); +/// moving_max.push(2); +/// moving_max.push(3); +/// moving_max.push(1); +/// +/// assert_eq!(moving_max.max(), Some(&3)); +/// assert_eq!(moving_max.pop(), Some(2)); +/// +/// assert_eq!(moving_max.max(), Some(&3)); +/// assert_eq!(moving_max.pop(), Some(3)); +/// +/// assert_eq!(moving_max.max(), Some(&1)); +/// assert_eq!(moving_max.pop(), Some(1)); +/// +/// assert_eq!(moving_max.max(), None); +/// assert_eq!(moving_max.pop(), None); +/// ``` +#[derive(Debug)] +pub struct MovingMax { + push_stack: Vec<(T, T)>, + pop_stack: Vec<(T, T)>, +} + +impl Default for MovingMax { + fn default() -> Self { + Self { + push_stack: Vec::new(), + pop_stack: Vec::new(), + } + } +} + +impl MovingMax { + /// Creates a new `MovingMax` to keep track of the maximum in a sliding window. + #[inline] + pub fn new() -> Self { + Self::default() + } + + /// Creates a new `MovingMax` to keep track of the maximum in a sliding window with + /// `capacity` allocated slots. + #[inline] + pub fn with_capacity(capacity: usize) -> Self { + Self { + push_stack: Vec::with_capacity(capacity), + pop_stack: Vec::with_capacity(capacity), + } + } + + /// Returns the maximum of the sliding window or `None` if the window is empty. + #[inline] + pub fn max(&self) -> Option<&T> { + match (self.push_stack.last(), self.pop_stack.last()) { + (None, None) => None, + (Some((_, max)), None) => Some(max), + (None, Some((_, max))) => Some(max), + (Some((_, a)), Some((_, b))) => Some(if a > b { a } else { b }), + } + } + + /// Pushes a new element into the sliding window. + #[inline] + pub fn push(&mut self, val: T) { + self.push_stack.push(match self.push_stack.last() { + Some((_, max)) => { + if val < *max { + (val, max.clone()) + } else { + (val.clone(), val) + } + } + None => (val.clone(), val), + }); + } + + /// Removes and returns the last value of the sliding window. + #[inline] + pub fn pop(&mut self) -> Option { + if self.pop_stack.is_empty() { + match self.push_stack.pop() { + Some((val, _)) => { + let mut last = (val.clone(), val); + self.pop_stack.push(last.clone()); + while let Some((val, _)) = self.push_stack.pop() { + let max = if last.1 > val { + last.1.clone() + } else { + val.clone() + }; + last = (val.clone(), max); + self.pop_stack.push(last.clone()); + } + } + None => return None, + } + } + self.pop_stack.pop().map(|(val, _)| val) + } + + /// Returns the number of elements stored in the sliding window. + #[inline] + pub fn len(&self) -> usize { + self.push_stack.len() + self.pop_stack.len() + } + + /// Returns `true` if the moving window contains no elements. 
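With `Min` and `Max` now implemented as `AggregateUDFImpl`, call sites go through the `min_udaf()`/`max_udaf()` accessors and the `min()`/`max()` expression functions generated by the `make_udaf_expr_and_func!` invocations that close out this file just below (and re-exported from `expr_fn` in the lib.rs hunk above). A minimal sketch of the resulting call site; the CSV path and column name are hypothetical placeholders, not part of this change:

```rust
use datafusion::error::Result;
use datafusion::functions_aggregate::min_max::{max, min};
use datafusion::prelude::*;

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();
    // "data.csv" and column "a" are placeholders for illustration only
    let df = ctx.read_csv("data.csv", CsvReadOptions::new()).await?;
    // min()/max() build Expr::AggregateFunction over the Min/Max UDAFs,
    // so the planner resolves them like any other user-defined aggregate
    df.aggregate(vec![], vec![min(col("a")), max(col("a"))])?
        .show()
        .await?;
    Ok(())
}
```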
+ #[inline] + pub fn is_empty(&self) -> bool { + self.len() == 0 + } +} + +make_udaf_expr_and_func!( + Max, + max, + expression, + "Returns the maximum of a group of values.", + max_udaf +); + +make_udaf_expr_and_func!( + Min, + min, + expression, + "Returns the minimum of a group of values.", + min_udaf +); + #[cfg(test)] mod tests { use super::*; use arrow::datatypes::{ IntervalDayTimeType, IntervalMonthDayNanoType, IntervalYearMonthType, }; + use std::sync::Arc; #[test] fn interval_min_max() { @@ -1324,4 +1552,100 @@ mod tests { check(&mut max(), &[&[zero], &[neg_inf]], zero); check(&mut max(), &[&[zero, neg_inf]], zero); } + + use datafusion_common::Result; + use rand::Rng; + + fn get_random_vec_i32(len: usize) -> Vec { + let mut rng = rand::thread_rng(); + let mut input = Vec::with_capacity(len); + for _i in 0..len { + input.push(rng.gen_range(0..100)); + } + input + } + + fn moving_min_i32(len: usize, n_sliding_window: usize) -> Result<()> { + let data = get_random_vec_i32(len); + let mut expected = Vec::with_capacity(len); + let mut moving_min = MovingMin::::new(); + let mut res = Vec::with_capacity(len); + for i in 0..len { + let start = i.saturating_sub(n_sliding_window); + expected.push(*data[start..i + 1].iter().min().unwrap()); + + moving_min.push(data[i]); + if i > n_sliding_window { + moving_min.pop(); + } + res.push(*moving_min.min().unwrap()); + } + assert_eq!(res, expected); + Ok(()) + } + + fn moving_max_i32(len: usize, n_sliding_window: usize) -> Result<()> { + let data = get_random_vec_i32(len); + let mut expected = Vec::with_capacity(len); + let mut moving_max = MovingMax::::new(); + let mut res = Vec::with_capacity(len); + for i in 0..len { + let start = i.saturating_sub(n_sliding_window); + expected.push(*data[start..i + 1].iter().max().unwrap()); + + moving_max.push(data[i]); + if i > n_sliding_window { + moving_max.pop(); + } + res.push(*moving_max.max().unwrap()); + } + assert_eq!(res, expected); + Ok(()) + } + + #[test] + fn moving_min_tests() -> Result<()> { + moving_min_i32(100, 10)?; + moving_min_i32(100, 20)?; + moving_min_i32(100, 50)?; + moving_min_i32(100, 100)?; + Ok(()) + } + + #[test] + fn moving_max_tests() -> Result<()> { + moving_max_i32(100, 10)?; + moving_max_i32(100, 20)?; + moving_max_i32(100, 50)?; + moving_max_i32(100, 100)?; + Ok(()) + } + + #[test] + fn test_min_max_coerce_types() { + // the coerced types is same with input types + let funs: Vec> = + vec![Box::new(Min::new()), Box::new(Max::new())]; + let input_types = vec![ + vec![DataType::Int32], + vec![DataType::Decimal128(10, 2)], + vec![DataType::Decimal256(1, 1)], + vec![DataType::Utf8], + ]; + for fun in funs { + for input_type in &input_types { + let result = fun.coerce_types(input_type); + assert_eq!(*input_type, result.unwrap()); + } + } + } + + #[test] + fn test_get_min_max_return_type_coerce_dictionary() -> Result<()> { + let data_type = + DataType::Dictionary(Box::new(DataType::Utf8), Box::new(DataType::Int32)); + let result = get_min_max_result_type(&[data_type])?; + assert_eq!(result, vec![DataType::Int32]); + Ok(()) + } } diff --git a/datafusion/functions-aggregate/src/string_agg.rs b/datafusion/functions-aggregate/src/string_agg.rs index 9d16616e1c9a..30f03f4822f9 100644 --- a/datafusion/functions-aggregate/src/string_agg.rs +++ b/datafusion/functions-aggregate/src/string_agg.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! [`StringAgg`] and [`StringAggAccumulator`] accumulator for the `string_agg` function +//! 
[`StringAgg`] accumulator for the `string_agg` function use arrow::array::ArrayRef; use arrow_schema::DataType; diff --git a/datafusion/functions-nested/src/macros.rs b/datafusion/functions-nested/src/macros.rs index a6e0c2ee62be..00247f39ac10 100644 --- a/datafusion/functions-nested/src/macros.rs +++ b/datafusion/functions-nested/src/macros.rs @@ -90,9 +90,9 @@ macro_rules! create_func { #[allow(non_upper_case_globals)] static [< STATIC_ $UDF >]: std::sync::OnceLock> = std::sync::OnceLock::new(); - /// ScalarFunction that returns a [`ScalarUDF`] for [`$UDF`] - /// - /// [`ScalarUDF`]: datafusion_expr::ScalarUDF + + #[doc = concat!("ScalarFunction that returns a [`ScalarUDF`](datafusion_expr::ScalarUDF) for ")] + #[doc = stringify!($UDF)] pub fn $SCALAR_UDF_FN() -> std::sync::Arc { [< STATIC_ $UDF >] .get_or_init(|| { diff --git a/datafusion/functions-nested/src/planner.rs b/datafusion/functions-nested/src/planner.rs index 97c54cc77beb..fee3e83a0d65 100644 --- a/datafusion/functions-nested/src/planner.rs +++ b/datafusion/functions-nested/src/planner.rs @@ -171,9 +171,6 @@ impl ExprPlanner for FieldAccessPlanner { } fn is_array_agg(agg_func: &datafusion_expr::expr::AggregateFunction) -> bool { - if let AggregateFunctionDefinition::UDF(udf) = &agg_func.func_def { - return udf.name() == "array_agg"; - } - - false + let AggregateFunctionDefinition::UDF(udf) = &agg_func.func_def; + return udf.name() == "array_agg"; } diff --git a/datafusion/functions/src/core/arrow_cast.rs b/datafusion/functions/src/core/arrow_cast.rs index 9227f9e3a2a8..c4db3e77049d 100644 --- a/datafusion/functions/src/core/arrow_cast.rs +++ b/datafusion/functions/src/core/arrow_cast.rs @@ -18,12 +18,11 @@ //! [`ArrowCastFunc`]: Implementation of the `arrow_cast` use std::any::Any; -use std::{fmt::Display, iter::Peekable, str::Chars, sync::Arc}; -use arrow::datatypes::{DataType, Field, IntervalUnit, TimeUnit}; +use arrow::datatypes::DataType; use datafusion_common::{ - internal_err, plan_datafusion_err, plan_err, DataFusionError, ExprSchema, Result, - ScalarValue, + arrow_datafusion_err, internal_err, plan_datafusion_err, plan_err, DataFusionError, + ExprSchema, Result, ScalarValue, }; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; @@ -44,7 +43,7 @@ use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility} /// select cast(column_x as int) ... /// ``` /// -/// You can use the `arrow_cast` functiont to cast to a specific arrow type +/// Use the `arrow_cast` function to cast to a specific arrow type /// /// For example /// ```sql @@ -139,767 +138,11 @@ fn data_type_from_args(args: &[Expr]) -> Result { &args[1] ); }; - parse_data_type(val) -} - -/// Parses `str` into a `DataType`. -/// -/// `parse_data_type` is the reverse of [`DataType`]'s `Display` -/// impl, and maintains the invariant that -/// `parse_data_type(data_type.to_string()) == data_type` -/// -/// Remove if added to arrow: -fn parse_data_type(val: &str) -> Result { - Parser::new(val).parse() -} - -fn make_error(val: &str, msg: &str) -> DataFusionError { - plan_datafusion_err!("Unsupported type '{val}'. Must be a supported arrow type name such as 'Int32' or 'Timestamp(Nanosecond, None)'. 
Error {msg}" ) -} - -fn make_error_expected(val: &str, expected: &Token, actual: &Token) -> DataFusionError { - make_error(val, &format!("Expected '{expected}', got '{actual}'")) -} - -#[derive(Debug)] -/// Implementation of `parse_data_type`, modeled after -struct Parser<'a> { - val: &'a str, - tokenizer: Tokenizer<'a>, -} - -impl<'a> Parser<'a> { - fn new(val: &'a str) -> Self { - Self { - val, - tokenizer: Tokenizer::new(val), - } - } - - fn parse(mut self) -> Result { - let data_type = self.parse_next_type()?; - // ensure that there is no trailing content - if self.tokenizer.next().is_some() { - Err(make_error( - self.val, - &format!("checking trailing content after parsing '{data_type}'"), - )) - } else { - Ok(data_type) - } - } - - /// parses the next full DataType - fn parse_next_type(&mut self) -> Result { - match self.next_token()? { - Token::SimpleType(data_type) => Ok(data_type), - Token::Timestamp => self.parse_timestamp(), - Token::Time32 => self.parse_time32(), - Token::Time64 => self.parse_time64(), - Token::Duration => self.parse_duration(), - Token::Interval => self.parse_interval(), - Token::FixedSizeBinary => self.parse_fixed_size_binary(), - Token::Decimal128 => self.parse_decimal_128(), - Token::Decimal256 => self.parse_decimal_256(), - Token::Dictionary => self.parse_dictionary(), - Token::List => self.parse_list(), - Token::LargeList => self.parse_large_list(), - Token::FixedSizeList => self.parse_fixed_size_list(), - tok => Err(make_error( - self.val, - &format!("finding next type, got unexpected '{tok}'"), - )), - } - } - - /// Parses the List type - fn parse_list(&mut self) -> Result { - self.expect_token(Token::LParen)?; - let data_type = self.parse_next_type()?; - self.expect_token(Token::RParen)?; - Ok(DataType::List(Arc::new(Field::new( - "item", data_type, true, - )))) - } - - /// Parses the LargeList type - fn parse_large_list(&mut self) -> Result { - self.expect_token(Token::LParen)?; - let data_type = self.parse_next_type()?; - self.expect_token(Token::RParen)?; - Ok(DataType::LargeList(Arc::new(Field::new( - "item", data_type, true, - )))) - } - - /// Parses the FixedSizeList type - fn parse_fixed_size_list(&mut self) -> Result { - self.expect_token(Token::LParen)?; - let length = self.parse_i32("FixedSizeList")?; - self.expect_token(Token::Comma)?; - let data_type = self.parse_next_type()?; - self.expect_token(Token::RParen)?; - Ok(DataType::FixedSizeList( - Arc::new(Field::new("item", data_type, true)), - length, - )) - } - - /// Parses the next timeunit - fn parse_time_unit(&mut self, context: &str) -> Result { - match self.next_token()? { - Token::TimeUnit(time_unit) => Ok(time_unit), - tok => Err(make_error( - self.val, - &format!("finding TimeUnit for {context}, got {tok}"), - )), - } - } - - /// Parses the next timezone - fn parse_timezone(&mut self, context: &str) -> Result> { - match self.next_token()? { - Token::None => Ok(None), - Token::Some => { - self.expect_token(Token::LParen)?; - let timezone = self.parse_double_quoted_string("Timezone")?; - self.expect_token(Token::RParen)?; - Ok(Some(timezone)) - } - tok => Err(make_error( - self.val, - &format!("finding Timezone for {context}, got {tok}"), - )), - } - } - - /// Parses the next double quoted string - fn parse_double_quoted_string(&mut self, context: &str) -> Result { - match self.next_token()? 
{ - Token::DoubleQuotedString(s) => Ok(s), - tok => Err(make_error( - self.val, - &format!("finding double quoted string for {context}, got '{tok}'"), - )), - } - } - - /// Parses the next integer value - fn parse_i64(&mut self, context: &str) -> Result { - match self.next_token()? { - Token::Integer(v) => Ok(v), - tok => Err(make_error( - self.val, - &format!("finding i64 for {context}, got '{tok}'"), - )), - } - } - - /// Parses the next i32 integer value - fn parse_i32(&mut self, context: &str) -> Result { - let length = self.parse_i64(context)?; - length.try_into().map_err(|e| { - make_error( - self.val, - &format!("converting {length} into i32 for {context}: {e}"), - ) - }) - } - - /// Parses the next i8 integer value - fn parse_i8(&mut self, context: &str) -> Result { - let length = self.parse_i64(context)?; - length.try_into().map_err(|e| { - make_error( - self.val, - &format!("converting {length} into i8 for {context}: {e}"), - ) - }) - } - - /// Parses the next u8 integer value - fn parse_u8(&mut self, context: &str) -> Result { - let length = self.parse_i64(context)?; - length.try_into().map_err(|e| { - make_error( - self.val, - &format!("converting {length} into u8 for {context}: {e}"), - ) - }) - } - - /// Parses the next timestamp (called after `Timestamp` has been consumed) - fn parse_timestamp(&mut self) -> Result { - self.expect_token(Token::LParen)?; - let time_unit = self.parse_time_unit("Timestamp")?; - self.expect_token(Token::Comma)?; - let timezone = self.parse_timezone("Timestamp")?; - self.expect_token(Token::RParen)?; - Ok(DataType::Timestamp(time_unit, timezone.map(Into::into))) - } - - /// Parses the next Time32 (called after `Time32` has been consumed) - fn parse_time32(&mut self) -> Result { - self.expect_token(Token::LParen)?; - let time_unit = self.parse_time_unit("Time32")?; - self.expect_token(Token::RParen)?; - Ok(DataType::Time32(time_unit)) - } - - /// Parses the next Time64 (called after `Time64` has been consumed) - fn parse_time64(&mut self) -> Result { - self.expect_token(Token::LParen)?; - let time_unit = self.parse_time_unit("Time64")?; - self.expect_token(Token::RParen)?; - Ok(DataType::Time64(time_unit)) - } - - /// Parses the next Duration (called after `Duration` has been consumed) - fn parse_duration(&mut self) -> Result { - self.expect_token(Token::LParen)?; - let time_unit = self.parse_time_unit("Duration")?; - self.expect_token(Token::RParen)?; - Ok(DataType::Duration(time_unit)) - } - - /// Parses the next Interval (called after `Interval` has been consumed) - fn parse_interval(&mut self) -> Result { - self.expect_token(Token::LParen)?; - let interval_unit = match self.next_token()? 
{ - Token::IntervalUnit(interval_unit) => interval_unit, - tok => { - return Err(make_error( - self.val, - &format!("finding IntervalUnit for Interval, got {tok}"), - )) - } - }; - self.expect_token(Token::RParen)?; - Ok(DataType::Interval(interval_unit)) - } - - /// Parses the next FixedSizeBinary (called after `FixedSizeBinary` has been consumed) - fn parse_fixed_size_binary(&mut self) -> Result { - self.expect_token(Token::LParen)?; - let length = self.parse_i32("FixedSizeBinary")?; - self.expect_token(Token::RParen)?; - Ok(DataType::FixedSizeBinary(length)) - } - - /// Parses the next Decimal128 (called after `Decimal128` has been consumed) - fn parse_decimal_128(&mut self) -> Result { - self.expect_token(Token::LParen)?; - let precision = self.parse_u8("Decimal128")?; - self.expect_token(Token::Comma)?; - let scale = self.parse_i8("Decimal128")?; - self.expect_token(Token::RParen)?; - Ok(DataType::Decimal128(precision, scale)) - } - - /// Parses the next Decimal256 (called after `Decimal256` has been consumed) - fn parse_decimal_256(&mut self) -> Result { - self.expect_token(Token::LParen)?; - let precision = self.parse_u8("Decimal256")?; - self.expect_token(Token::Comma)?; - let scale = self.parse_i8("Decimal256")?; - self.expect_token(Token::RParen)?; - Ok(DataType::Decimal256(precision, scale)) - } - - /// Parses the next Dictionary (called after `Dictionary` has been consumed) - fn parse_dictionary(&mut self) -> Result { - self.expect_token(Token::LParen)?; - let key_type = self.parse_next_type()?; - self.expect_token(Token::Comma)?; - let value_type = self.parse_next_type()?; - self.expect_token(Token::RParen)?; - Ok(DataType::Dictionary( - Box::new(key_type), - Box::new(value_type), - )) - } - /// return the next token, or an error if there are none left - fn next_token(&mut self) -> Result { - match self.tokenizer.next() { - None => Err(make_error(self.val, "finding next token")), - Some(token) => token, - } - } - - /// consume the next token, returning OK(()) if it matches tok, and Err if not - fn expect_token(&mut self, tok: Token) -> Result<()> { - let next_token = self.next_token()?; - if next_token == tok { - Ok(()) - } else { - Err(make_error_expected(self.val, &tok, &next_token)) - } - } -} - -/// returns true if this character is a separator -fn is_separator(c: char) -> bool { - c == '(' || c == ')' || c == ',' || c == ' ' -} - -#[derive(Debug)] -/// Splits a strings like Dictionary(Int32, Int64) into tokens suitable for parsing -/// -/// For example the string "Timestamp(Nanosecond, None)" would be parsed into: -/// -/// * Token::Timestamp -/// * Token::Lparen -/// * Token::IntervalUnit(IntervalUnit::Nanosecond) -/// * Token::Comma, -/// * Token::None, -/// * Token::Rparen, -struct Tokenizer<'a> { - val: &'a str, - chars: Peekable>, - // temporary buffer for parsing words - word: String, -} - -impl<'a> Tokenizer<'a> { - fn new(val: &'a str) -> Self { - Self { - val, - chars: val.chars().peekable(), - word: String::new(), - } - } - - /// returns the next char, without consuming it - fn peek_next_char(&mut self) -> Option { - self.chars.peek().copied() - } - - /// returns the next char, and consuming it - fn next_char(&mut self) -> Option { - self.chars.next() - } - - /// parse the characters in val starting at pos, until the next - /// `,`, `(`, or `)` or end of line - fn parse_word(&mut self) -> Result { - // reset temp space - self.word.clear(); - loop { - match self.peek_next_char() { - None => break, - Some(c) if is_separator(c) => break, - Some(c) => { - 
self.next_char(); - self.word.push(c); - } - } - } - - if let Some(c) = self.word.chars().next() { - // if it started with a number, try parsing it as an integer - if c == '-' || c.is_numeric() { - let val: i64 = self.word.parse().map_err(|e| { - make_error( - self.val, - &format!("parsing {} as integer: {e}", self.word), - ) - })?; - return Ok(Token::Integer(val)); - } - // if it started with a double quote `"`, try parsing it as a double quoted string - else if c == '"' { - let len = self.word.chars().count(); - - // to verify it's double quoted - if let Some(last_c) = self.word.chars().last() { - if last_c != '"' || len < 2 { - return Err(make_error( - self.val, - &format!("parsing {} as double quoted string: last char must be \"", self.word), - )); - } - } - - if len == 2 { - return Err(make_error( - self.val, - &format!("parsing {} as double quoted string: empty string isn't supported", self.word), - )); - } - - let val: String = self.word.parse().map_err(|e| { - make_error( - self.val, - &format!("parsing {} as double quoted string: {e}", self.word), - ) - })?; - - let s = val[1..len - 1].to_string(); - if s.contains('"') { - return Err(make_error( - self.val, - &format!("parsing {} as double quoted string: escaped double quote isn't supported", self.word), - )); - } - - return Ok(Token::DoubleQuotedString(s)); - } - } - - // figure out what the word was - let token = match self.word.as_str() { - "Null" => Token::SimpleType(DataType::Null), - "Boolean" => Token::SimpleType(DataType::Boolean), - - "Int8" => Token::SimpleType(DataType::Int8), - "Int16" => Token::SimpleType(DataType::Int16), - "Int32" => Token::SimpleType(DataType::Int32), - "Int64" => Token::SimpleType(DataType::Int64), - - "UInt8" => Token::SimpleType(DataType::UInt8), - "UInt16" => Token::SimpleType(DataType::UInt16), - "UInt32" => Token::SimpleType(DataType::UInt32), - "UInt64" => Token::SimpleType(DataType::UInt64), - - "Utf8" => Token::SimpleType(DataType::Utf8), - "LargeUtf8" => Token::SimpleType(DataType::LargeUtf8), - "Utf8View" => Token::SimpleType(DataType::Utf8View), - "Binary" => Token::SimpleType(DataType::Binary), - "BinaryView" => Token::SimpleType(DataType::BinaryView), - "LargeBinary" => Token::SimpleType(DataType::LargeBinary), - - "Float16" => Token::SimpleType(DataType::Float16), - "Float32" => Token::SimpleType(DataType::Float32), - "Float64" => Token::SimpleType(DataType::Float64), - - "Date32" => Token::SimpleType(DataType::Date32), - "Date64" => Token::SimpleType(DataType::Date64), - - "List" => Token::List, - "LargeList" => Token::LargeList, - "FixedSizeList" => Token::FixedSizeList, - - "Second" => Token::TimeUnit(TimeUnit::Second), - "Millisecond" => Token::TimeUnit(TimeUnit::Millisecond), - "Microsecond" => Token::TimeUnit(TimeUnit::Microsecond), - "Nanosecond" => Token::TimeUnit(TimeUnit::Nanosecond), - - "Timestamp" => Token::Timestamp, - "Time32" => Token::Time32, - "Time64" => Token::Time64, - "Duration" => Token::Duration, - "Interval" => Token::Interval, - "Dictionary" => Token::Dictionary, - - "FixedSizeBinary" => Token::FixedSizeBinary, - "Decimal128" => Token::Decimal128, - "Decimal256" => Token::Decimal256, - - "YearMonth" => Token::IntervalUnit(IntervalUnit::YearMonth), - "DayTime" => Token::IntervalUnit(IntervalUnit::DayTime), - "MonthDayNano" => Token::IntervalUnit(IntervalUnit::MonthDayNano), - - "Some" => Token::Some, - "None" => Token::None, - - _ => { - return Err(make_error( - self.val, - &format!("unrecognized word: {}", self.word), - )) - } - }; - Ok(token) - } -} - 
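The hand-rolled tokenizer removed here is superseded by the upstream parser in arrow-rs: `DataType` now implements `FromStr`, which the replacement hunk at the end of this file delegates to. A minimal sketch of the new call path, assuming an arrow-rs version that ships the upstreamed parser:

```rust
use arrow_schema::DataType;

// Parse a data type from its display form; failures surface as
// `ArrowError::ParseError`, which the new DataFusion code maps to a
// plan error (see the `val.parse()` hunk below).
fn parse_data_type_example() -> Result<DataType, arrow_schema::ArrowError> {
    let dt: DataType = "Timestamp(Nanosecond, None)".parse()?;
    // round-trips through Display, as the removed tests verified
    assert_eq!(dt.to_string(), "Timestamp(Nanosecond, None)");
    Ok(dt)
}
```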
-impl<'a> Iterator for Tokenizer<'a> {
-    type Item = Result<Token>;
-
-    fn next(&mut self) -> Option<Self::Item> {
-        loop {
-            match self.peek_next_char()? {
-                ' ' => {
-                    // skip whitespace
-                    self.next_char();
-                    continue;
-                }
-                '(' => {
-                    self.next_char();
-                    return Some(Ok(Token::LParen));
-                }
-                ')' => {
-                    self.next_char();
-                    return Some(Ok(Token::RParen));
-                }
-                ',' => {
-                    self.next_char();
-                    return Some(Ok(Token::Comma));
-                }
-                _ => return Some(self.parse_word()),
-            }
-        }
-    }
-}
-
-/// Grammar is
-///
-#[derive(Debug, PartialEq)]
-enum Token {
-    // Null, or Int32
-    SimpleType(DataType),
-    Timestamp,
-    Time32,
-    Time64,
-    Duration,
-    Interval,
-    FixedSizeBinary,
-    Decimal128,
-    Decimal256,
-    Dictionary,
-    TimeUnit(TimeUnit),
-    IntervalUnit(IntervalUnit),
-    LParen,
-    RParen,
-    Comma,
-    Some,
-    None,
-    Integer(i64),
-    DoubleQuotedString(String),
-    List,
-    LargeList,
-    FixedSizeList,
-}
-
-impl Display for Token {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        match self {
-            Token::SimpleType(t) => write!(f, "{t}"),
-            Token::List => write!(f, "List"),
-            Token::LargeList => write!(f, "LargeList"),
-            Token::FixedSizeList => write!(f, "FixedSizeList"),
-            Token::Timestamp => write!(f, "Timestamp"),
-            Token::Time32 => write!(f, "Time32"),
-            Token::Time64 => write!(f, "Time64"),
-            Token::Duration => write!(f, "Duration"),
-            Token::Interval => write!(f, "Interval"),
-            Token::TimeUnit(u) => write!(f, "TimeUnit({u:?})"),
-            Token::IntervalUnit(u) => write!(f, "IntervalUnit({u:?})"),
-            Token::LParen => write!(f, "("),
-            Token::RParen => write!(f, ")"),
-            Token::Comma => write!(f, ","),
-            Token::Some => write!(f, "Some"),
-            Token::None => write!(f, "None"),
-            Token::FixedSizeBinary => write!(f, "FixedSizeBinary"),
-            Token::Decimal128 => write!(f, "Decimal128"),
-            Token::Decimal256 => write!(f, "Decimal256"),
-            Token::Dictionary => write!(f, "Dictionary"),
-            Token::Integer(v) => write!(f, "Integer({v})"),
-            Token::DoubleQuotedString(s) => write!(f, "DoubleQuotedString({s})"),
-        }
-    }
-}
-
-#[cfg(test)]
-mod test {
-    use super::*;
-
-    #[test]
-    fn test_parse_data_type() {
-        // this ensures types can be parsed correctly from their string representations
-        for dt in list_datatypes() {
-            round_trip(dt)
-        }
-    }
-
-    /// convert data_type to a string, and then parse it as a type
-    /// verifying it is the same
-    fn round_trip(data_type: DataType) {
-        let data_type_string = data_type.to_string();
-        println!("Input '{data_type_string}' ({data_type:?})");
-        let parsed_type = parse_data_type(&data_type_string).unwrap();
-        assert_eq!(
-            data_type, parsed_type,
-            "Mismatch parsing {data_type_string}"
-        );
-    }
-
-    fn list_datatypes() -> Vec<DataType> {
-        vec![
-            // ---------
-            // Non Nested types
-            // ---------
-            DataType::Null,
-            DataType::Boolean,
-            DataType::Int8,
-            DataType::Int16,
-            DataType::Int32,
-            DataType::Int64,
-            DataType::UInt8,
-            DataType::UInt16,
-            DataType::UInt32,
-            DataType::UInt64,
-            DataType::Float16,
-            DataType::Float32,
-            DataType::Float64,
-            DataType::Timestamp(TimeUnit::Second, None),
-            DataType::Timestamp(TimeUnit::Millisecond, None),
-            DataType::Timestamp(TimeUnit::Microsecond, None),
-            DataType::Timestamp(TimeUnit::Nanosecond, None),
-            // we can't cover all possible timezones, here we only test utc and +08:00
-            DataType::Timestamp(TimeUnit::Nanosecond, Some("+00:00".into())),
-            DataType::Timestamp(TimeUnit::Microsecond, Some("+00:00".into())),
-            DataType::Timestamp(TimeUnit::Millisecond, Some("+00:00".into())),
-            DataType::Timestamp(TimeUnit::Second, Some("+00:00".into())),
-
DataType::Timestamp(TimeUnit::Nanosecond, Some("+08:00".into())), - DataType::Timestamp(TimeUnit::Microsecond, Some("+08:00".into())), - DataType::Timestamp(TimeUnit::Millisecond, Some("+08:00".into())), - DataType::Timestamp(TimeUnit::Second, Some("+08:00".into())), - DataType::Date32, - DataType::Date64, - DataType::Time32(TimeUnit::Second), - DataType::Time32(TimeUnit::Millisecond), - DataType::Time32(TimeUnit::Microsecond), - DataType::Time32(TimeUnit::Nanosecond), - DataType::Time64(TimeUnit::Second), - DataType::Time64(TimeUnit::Millisecond), - DataType::Time64(TimeUnit::Microsecond), - DataType::Time64(TimeUnit::Nanosecond), - DataType::Duration(TimeUnit::Second), - DataType::Duration(TimeUnit::Millisecond), - DataType::Duration(TimeUnit::Microsecond), - DataType::Duration(TimeUnit::Nanosecond), - DataType::Interval(IntervalUnit::YearMonth), - DataType::Interval(IntervalUnit::DayTime), - DataType::Interval(IntervalUnit::MonthDayNano), - DataType::Binary, - DataType::BinaryView, - DataType::FixedSizeBinary(0), - DataType::FixedSizeBinary(1234), - DataType::FixedSizeBinary(-432), - DataType::LargeBinary, - DataType::Utf8, - DataType::Utf8View, - DataType::LargeUtf8, - DataType::Decimal128(7, 12), - DataType::Decimal256(6, 13), - // --------- - // Nested types - // --------- - DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), - DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8)), - DataType::Dictionary( - Box::new(DataType::Int8), - Box::new(DataType::Timestamp(TimeUnit::Nanosecond, None)), - ), - DataType::Dictionary( - Box::new(DataType::Int8), - Box::new(DataType::FixedSizeBinary(23)), - ), - DataType::Dictionary( - Box::new(DataType::Int8), - Box::new( - // nested dictionaries are probably a bad idea but they are possible - DataType::Dictionary( - Box::new(DataType::Int8), - Box::new(DataType::Utf8), - ), - ), - ), - // TODO support more structured types (List, LargeList, Struct, Union, Map, RunEndEncoded, etc) - ] - } - - #[test] - fn test_parse_data_type_whitespace_tolerance() { - // (string to parse, expected DataType) - let cases = [ - ("Int8", DataType::Int8), - ( - "Timestamp (Nanosecond, None)", - DataType::Timestamp(TimeUnit::Nanosecond, None), - ), - ( - "Timestamp (Nanosecond, None) ", - DataType::Timestamp(TimeUnit::Nanosecond, None), - ), - ( - " Timestamp (Nanosecond, None )", - DataType::Timestamp(TimeUnit::Nanosecond, None), - ), - ( - "Timestamp (Nanosecond, None ) ", - DataType::Timestamp(TimeUnit::Nanosecond, None), - ), - ]; - - for (data_type_string, expected_data_type) in cases { - println!("Parsing '{data_type_string}', expecting '{expected_data_type:?}'"); - let parsed_data_type = parse_data_type(data_type_string).unwrap(); - assert_eq!(parsed_data_type, expected_data_type); - } - } - - #[test] - fn parse_data_type_errors() { - // (string to parse, expected error message) - let cases = [ - ("", "Unsupported type ''"), - ("", "Error finding next token"), - ("null", "Unsupported type 'null'"), - ("Nu", "Unsupported type 'Nu'"), - ( - r#"Timestamp(Nanosecond, Some(+00:00))"#, - "Error unrecognized word: +00:00", - ), - ( - r#"Timestamp(Nanosecond, Some("+00:00))"#, - r#"parsing "+00:00 as double quoted string: last char must be ""#, - ), - ( - r#"Timestamp(Nanosecond, Some(""))"#, - r#"parsing "" as double quoted string: empty string isn't supported"#, - ), - ( - r#"Timestamp(Nanosecond, Some("+00:00""))"#, - r#"parsing "+00:00"" as double quoted string: escaped double quote isn't supported"#, - ), - 
("Timestamp(Nanosecond, ", "Error finding next token"), - ( - "Float32 Float32", - "trailing content after parsing 'Float32'", - ), - ("Int32, ", "trailing content after parsing 'Int32'"), - ("Int32(3), ", "trailing content after parsing 'Int32'"), - ("FixedSizeBinary(Int32), ", "Error finding i64 for FixedSizeBinary, got 'Int32'"), - ("FixedSizeBinary(3.0), ", "Error parsing 3.0 as integer: invalid digit found in string"), - // too large for i32 - ("FixedSizeBinary(4000000000), ", "Error converting 4000000000 into i32 for FixedSizeBinary: out of range integral type conversion attempted"), - // can't have negative precision - ("Decimal128(-3, 5)", "Error converting -3 into u8 for Decimal128: out of range integral type conversion attempted"), - ("Decimal256(-3, 5)", "Error converting -3 into u8 for Decimal256: out of range integral type conversion attempted"), - ("Decimal128(3, 500)", "Error converting 500 into i8 for Decimal128: out of range integral type conversion attempted"), - ("Decimal256(3, 500)", "Error converting 500 into i8 for Decimal256: out of range integral type conversion attempted"), - - ]; - - for (data_type_string, expected_message) in cases { - print!("Parsing '{data_type_string}', expecting '{expected_message}'"); - match parse_data_type(data_type_string) { - Ok(d) => panic!( - "Expected error while parsing '{data_type_string}', but got '{d}'" - ), - Err(e) => { - let message = e.to_string(); - assert!( - message.contains(expected_message), - "\n\ndid not find expected in actual.\n\nexpected: {expected_message}\nactual:{message}\n" - ); - // errors should also contain a help message - assert!(message.contains("Must be a supported arrow type name such as 'Int32' or 'Timestamp(Nanosecond, None)'")); - } - } - } - } + val.parse().map_err(|e| match e { + // If the data type cannot be parsed, return a Plan error to signal an + // error in the input rather than a more general ArrowError + arrow::error::ArrowError::ParseError(e) => plan_datafusion_err!("{e}"), + e => arrow_datafusion_err!(e), + }) } diff --git a/datafusion/functions/src/macros.rs b/datafusion/functions/src/macros.rs index e26c94e1bb79..484afb57f74e 100644 --- a/datafusion/functions/src/macros.rs +++ b/datafusion/functions/src/macros.rs @@ -75,9 +75,8 @@ macro_rules! 
make_udf_function {
         static $GNAME: std::sync::OnceLock<std::sync::Arc<datafusion_expr::ScalarUDF>> =
             std::sync::OnceLock::new();
 
-        /// Return a [`ScalarUDF`] for [`$UDF`]
-        ///
-        /// [`ScalarUDF`]: datafusion_expr::ScalarUDF
+        #[doc = "Return a [`ScalarUDF`](datafusion_expr::ScalarUDF) for "]
+        #[doc = stringify!($UDF)]
         pub fn $NAME() -> std::sync::Arc<datafusion_expr::ScalarUDF> {
             $GNAME
                 .get_or_init(|| {
diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs
index 338268e299da..6f832966671c 100644
--- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs
+++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs
@@ -103,11 +103,11 @@ mod tests {
     use datafusion_expr::expr::Sort;
     use datafusion_expr::ExprFunctionExt;
     use datafusion_expr::{
-        col, exists, expr, in_subquery, logical_plan::LogicalPlanBuilder, max,
-        out_ref_col, scalar_subquery, wildcard, WindowFrame, WindowFrameBound,
-        WindowFrameUnits,
+        col, exists, expr, in_subquery, logical_plan::LogicalPlanBuilder, out_ref_col,
+        scalar_subquery, wildcard, WindowFrame, WindowFrameBound, WindowFrameUnits,
     };
     use datafusion_functions_aggregate::count::count_udaf;
+    use datafusion_functions_aggregate::expr_fn::max;
     use std::sync::Arc;
 
     use datafusion_functions_aggregate::expr_fn::{count, sum};
diff --git a/datafusion/optimizer/src/analyzer/mod.rs b/datafusion/optimizer/src/analyzer/mod.rs
index 32bb2bc70452..91ee8a9e1033 100644
--- a/datafusion/optimizer/src/analyzer/mod.rs
+++ b/datafusion/optimizer/src/analyzer/mod.rs
@@ -136,9 +136,15 @@ impl Analyzer {
         // Note this is run before all other rules since it rewrites based on
         // the argument types (List or Scalar), and TypeCoercion may cast the
         // argument types from Scalar to List.
-        let expr_to_function: Arc<dyn AnalyzerRule + Send + Sync> =
-            Arc::new(ApplyFunctionRewrites::new(self.function_rewrites.clone()));
-        let rules = std::iter::once(&expr_to_function).chain(self.rules.iter());
+        let expr_to_function: Option<Arc<dyn AnalyzerRule + Send + Sync>> =
+            if self.function_rewrites.is_empty() {
+                None
+            } else {
+                Some(Arc::new(ApplyFunctionRewrites::new(
+                    self.function_rewrites.clone(),
+                )))
+            };
+        let rules = expr_to_function.iter().chain(self.rules.iter());
 
         // TODO add common rule executor for Analyzer and Optimizer
         for rule in rules {
diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs
index 75dbb4d1adcd..bcd1cbcce23e 100644
--- a/datafusion/optimizer/src/analyzer/type_coercion.rs
+++ b/datafusion/optimizer/src/analyzer/type_coercion.rs
@@ -47,9 +47,8 @@ use datafusion_expr::type_coercion::{is_datetime, is_utf8_or_large_utf8};
 use datafusion_expr::utils::merge_schema;
 use datafusion_expr::{
     is_false, is_not_false, is_not_true, is_not_unknown, is_true, is_unknown, not,
-    type_coercion, AggregateFunction, AggregateUDF, Expr, ExprFunctionExt, ExprSchemable,
-    LogicalPlan, Operator, ScalarUDF, Signature, WindowFrame, WindowFrameBound,
-    WindowFrameUnits,
+    AggregateUDF, Expr, ExprFunctionExt, ExprSchemable, LogicalPlan, Operator, ScalarUDF,
+    WindowFrame, WindowFrameBound, WindowFrameUnits,
 };
 
 use crate::analyzer::AnalyzerRule;
@@ -401,24 +400,6 @@ impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> {
                 order_by,
                 null_treatment,
             }) => match func_def {
-                AggregateFunctionDefinition::BuiltIn(fun) => {
-                    let new_expr = coerce_agg_exprs_for_signature(
-                        &fun,
-                        args,
-                        self.schema,
-                        &fun.signature(),
-                    )?;
-                    Ok(Transformed::yes(Expr::AggregateFunction(
-                        expr::AggregateFunction::new(
-                            fun,
-                            new_expr,
-                            distinct,
-                            filter,
-                            order_by,
-                            null_treatment,
-                        ),
-                    )))
-                }
AggregateFunctionDefinition::UDF(fun) => { let new_expr = coerce_arguments_for_signature_with_aggregate_udf( args, @@ -449,14 +430,6 @@ impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> { coerce_window_frame(window_frame, self.schema, &order_by)?; let args = match &fun { - expr::WindowFunctionDefinition::AggregateFunction(fun) => { - coerce_agg_exprs_for_signature( - fun, - args, - self.schema, - &fun.signature(), - )? - } expr::WindowFunctionDefinition::AggregateUDF(udf) => { coerce_arguments_for_signature_with_aggregate_udf( args, @@ -692,33 +665,6 @@ fn coerce_arguments_for_fun( } } -/// Returns the coerced exprs for each `input_exprs`. -/// Get the coerced data type from `aggregate_rule::coerce_types` and add `try_cast` if the -/// data type of `input_exprs` need to be coerced. -fn coerce_agg_exprs_for_signature( - agg_fun: &AggregateFunction, - input_exprs: Vec, - schema: &DFSchema, - signature: &Signature, -) -> Result> { - if input_exprs.is_empty() { - return Ok(input_exprs); - } - let current_types = input_exprs - .iter() - .map(|e| e.get_type(schema)) - .collect::>>()?; - - let coerced_types = - type_coercion::aggregates::coerce_types(agg_fun, ¤t_types, signature)?; - - input_exprs - .into_iter() - .enumerate() - .map(|(i, expr)| expr.cast_to(&coerced_types[i], schema)) - .collect() -} - fn coerce_case_expression(case: Case, schema: &DFSchema) -> Result { // Given expressions like: // diff --git a/datafusion/optimizer/src/decorrelate.rs b/datafusion/optimizer/src/decorrelate.rs index c998e8442548..6dbf1641bd7c 100644 --- a/datafusion/optimizer/src/decorrelate.rs +++ b/datafusion/optimizer/src/decorrelate.rs @@ -436,9 +436,6 @@ fn agg_exprs_evaluation_result_on_empty_batch( Expr::AggregateFunction(expr::AggregateFunction { func_def, .. 
}) => match func_def { - AggregateFunctionDefinition::BuiltIn(_fun) => { - Transformed::yes(Expr::Literal(ScalarValue::Null)) - } AggregateFunctionDefinition::UDF(fun) => { if fun.name() == "count" { Transformed::yes(Expr::Literal(ScalarValue::Int64(Some( diff --git a/datafusion/optimizer/src/optimize_projections/mod.rs b/datafusion/optimizer/src/optimize_projections/mod.rs index 16abf93f3807..31d59da13323 100644 --- a/datafusion/optimizer/src/optimize_projections/mod.rs +++ b/datafusion/optimizer/src/optimize_projections/mod.rs @@ -814,13 +814,13 @@ mod tests { expr::{self, Cast}, lit, logical_plan::{builder::LogicalPlanBuilder, table_scan}, - max, min, not, try_cast, when, AggregateFunction, BinaryExpr, Expr, Extension, - Like, LogicalPlan, Operator, Projection, UserDefinedLogicalNodeCore, - WindowFunctionDefinition, + not, try_cast, when, BinaryExpr, Expr, Extension, Like, LogicalPlan, Operator, + Projection, UserDefinedLogicalNodeCore, WindowFunctionDefinition, }; use datafusion_functions_aggregate::count::count_udaf; - use datafusion_functions_aggregate::expr_fn::count; + use datafusion_functions_aggregate::expr_fn::{count, max, min}; + use datafusion_functions_aggregate::min_max::max_udaf; fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq(Arc::new(OptimizeProjections::new()), plan, expected) @@ -1917,7 +1917,7 @@ mod tests { let table_scan = test_table_scan()?; let max1 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), + WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("test.a")], )) .partition_by(vec![col("test.b")]) @@ -1925,7 +1925,7 @@ mod tests { .unwrap(); let max2 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), + WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("test.b")], )); let col1 = col(max1.display_name()?); diff --git a/datafusion/optimizer/src/push_down_limit.rs b/datafusion/optimizer/src/push_down_limit.rs index 79980f8fc9ec..d7da3871ee89 100644 --- a/datafusion/optimizer/src/push_down_limit.rs +++ b/datafusion/optimizer/src/push_down_limit.rs @@ -321,8 +321,8 @@ mod test { use super::*; use crate::test::*; - - use datafusion_expr::{col, exists, logical_plan::builder::LogicalPlanBuilder, max}; + use datafusion_expr::{col, exists, logical_plan::builder::LogicalPlanBuilder}; + use datafusion_functions_aggregate::expr_fn::max; fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq(Arc::new(PushDownLimit::new()), plan, expected) diff --git a/datafusion/optimizer/src/scalar_subquery_to_join.rs b/datafusion/optimizer/src/scalar_subquery_to_join.rs index 35691847fb8e..fbec675f6fc4 100644 --- a/datafusion/optimizer/src/scalar_subquery_to_join.rs +++ b/datafusion/optimizer/src/scalar_subquery_to_join.rs @@ -394,7 +394,9 @@ mod tests { use arrow::datatypes::DataType; use datafusion_expr::test::function_stub::sum; - use datafusion_expr::{col, lit, max, min, out_ref_col, scalar_subquery, Between}; + + use datafusion_expr::{col, lit, out_ref_col, scalar_subquery, Between}; + use datafusion_functions_aggregate::min_max::{max, min}; /// Test multiple correlated subqueries #[test] diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 38dfbb3ed551..1e1418744fb8 100644 --- 
a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -289,7 +289,7 @@ impl ExprSimplifier { self } - /// Should [`Canonicalizer`] be applied before simplification? + /// Should `Canonicalizer` be applied before simplification? /// /// If true (the default), the expression will be rewritten to canonical /// form before simplification. This is useful to ensure that the simplifier diff --git a/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs b/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs index e650d4c09c23..e44f60d1df22 100644 --- a/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs +++ b/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs @@ -160,6 +160,7 @@ mod tests { ExprSchemable, JoinType, }; use datafusion_expr::{or, BinaryExpr, Cast, Operator}; + use datafusion_functions_aggregate::expr_fn::{max, min}; use crate::test::{assert_fields_eq, test_table_scan_with_name}; use crate::OptimizerContext; @@ -395,10 +396,7 @@ mod tests { .project(vec![col("a"), col("c"), col("b")])? .aggregate( vec![col("a"), col("c")], - vec![ - datafusion_expr::max(col("b").eq(lit(true))), - datafusion_expr::min(col("b")), - ], + vec![max(col("b").eq(lit(true))), min(col("b"))], )? .build()?; diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index d776e6598cbe..69c1b505727d 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -28,7 +28,6 @@ use datafusion_common::{ use datafusion_expr::builder::project; use datafusion_expr::expr::AggregateFunctionDefinition; use datafusion_expr::{ - aggregate_function::AggregateFunction::{Max, Min}, col, expr::AggregateFunction, logical_plan::{Aggregate, LogicalPlan}, @@ -71,26 +70,6 @@ fn is_single_distinct_agg(aggr_expr: &[Expr]) -> Result { let mut aggregate_count = 0; for expr in aggr_expr { if let Expr::AggregateFunction(AggregateFunction { - func_def: AggregateFunctionDefinition::BuiltIn(fun), - distinct, - args, - filter, - order_by, - null_treatment: _, - }) = expr - { - if filter.is_some() || order_by.is_some() { - return Ok(false); - } - aggregate_count += 1; - if *distinct { - for e in args { - fields_set.insert(e); - } - } else if !matches!(fun, Min | Max) { - return Ok(false); - } - } else if let Expr::AggregateFunction(AggregateFunction { func_def: AggregateFunctionDefinition::UDF(fun), distinct, args, @@ -107,7 +86,10 @@ fn is_single_distinct_agg(aggr_expr: &[Expr]) -> Result { for e in args { fields_set.insert(e); } - } else if fun.name() != "sum" && fun.name() != "MIN" && fun.name() != "MAX" { + } else if fun.name() != "sum" + && fun.name().to_lowercase() != "min" + && fun.name().to_lowercase() != "max" + { return Ok(false); } } else { @@ -173,6 +155,7 @@ impl OptimizerRule for SingleDistinctToGroupBy { // // First aggregate(from bottom) refers to `test.a` column. // Second aggregate refers to the `group_alias_0` column, Which is a valid field in the first aggregate. 
+ // If we were to write plan above as below without alias // // Aggregate: groupBy=[[test.a + Int32(1)]], aggr=[[count(alias1)]] [group_alias_0:Int32, count(alias1):Int64;N]\ @@ -200,55 +183,6 @@ impl OptimizerRule for SingleDistinctToGroupBy { let outer_aggr_exprs = aggr_expr .into_iter() .map(|aggr_expr| match aggr_expr { - Expr::AggregateFunction(AggregateFunction { - func_def: AggregateFunctionDefinition::BuiltIn(fun), - mut args, - distinct, - .. - }) => { - if distinct { - if args.len() != 1 { - return internal_err!("DISTINCT aggregate should have exactly one argument"); - } - let arg = args.swap_remove(0); - - if group_fields_set.insert(arg.display_name()?) { - inner_group_exprs - .push(arg.alias(SINGLE_DISTINCT_ALIAS)); - } - Ok(Expr::AggregateFunction(AggregateFunction::new( - fun, - vec![col(SINGLE_DISTINCT_ALIAS)], - false, // intentional to remove distinct here - None, - None, - None, - ))) - // if the aggregate function is not distinct, we need to rewrite it like two phase aggregation - } else { - index += 1; - let alias_str = format!("alias{}", index); - inner_aggr_exprs.push( - Expr::AggregateFunction(AggregateFunction::new( - fun.clone(), - args, - false, - None, - None, - None, - )) - .alias(&alias_str), - ); - Ok(Expr::AggregateFunction(AggregateFunction::new( - fun, - vec![col(&alias_str)], - false, - None, - None, - None, - ))) - } - } Expr::AggregateFunction(AggregateFunction { func_def: AggregateFunctionDefinition::UDF(udf), mut args, @@ -355,13 +289,23 @@ mod tests { use crate::test::*; use datafusion_expr::expr::{self, GroupingSet}; use datafusion_expr::ExprFunctionExt; - use datafusion_expr::{ - lit, logical_plan::builder::LogicalPlanBuilder, max, min, AggregateFunction, - }; + use datafusion_expr::{lit, logical_plan::builder::LogicalPlanBuilder}; use datafusion_functions_aggregate::count::count_udaf; - use datafusion_functions_aggregate::expr_fn::{count, count_distinct, sum}; + use datafusion_functions_aggregate::expr_fn::{count, count_distinct, max, min, sum}; + use datafusion_functions_aggregate::min_max::max_udaf; use datafusion_functions_aggregate::sum::sum_udaf; + fn max_distinct(expr: Expr) -> Expr { + Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf( + max_udaf(), + vec![expr], + true, + None, + None, + None, + )) + } + fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq_display_indent( Arc::new(SingleDistinctToGroupBy::new()), @@ -520,17 +464,7 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .aggregate( vec![col("a")], - vec![ - count_distinct(col("b")), - Expr::AggregateFunction(expr::AggregateFunction::new( - AggregateFunction::Max, - vec![col("b")], - true, - None, - None, - None, - )), - ], + vec![count_distinct(col("b")), max_distinct(col("b"))], )? .build()?; // Should work @@ -587,14 +521,7 @@ mod tests { vec![ sum(col("c")), count_distinct(col("b")), - Expr::AggregateFunction(expr::AggregateFunction::new( - AggregateFunction::Max, - vec![col("b")], - true, - None, - None, - None, - )), + max_distinct(col("b")), ], )? .build()?; diff --git a/datafusion/physical-expr-common/src/aggregate/tdigest.rs b/datafusion/physical-expr-common/src/aggregate/tdigest.rs index 1da3d7180d84..070ebc46483b 100644 --- a/datafusion/physical-expr-common/src/aggregate/tdigest.rs +++ b/datafusion/physical-expr-common/src/aggregate/tdigest.rs @@ -47,6 +47,17 @@ macro_rules! cast_scalar_f64 { }; } +// Cast a non-null [`ScalarValue::UInt64`] to an [`u64`], or +// panic. 
+macro_rules! cast_scalar_u64 { + ($value:expr ) => { + match &$value { + ScalarValue::UInt64(Some(v)) => *v, + v => panic!("invalid type {:?}", v), + } + }; +} + /// This trait is implemented for each type a [`TDigest`] can operate on, /// allowing it to support both numerical rust types (obtained from /// `PrimitiveArray` instances), and [`ScalarValue`] instances. @@ -142,7 +153,7 @@ pub struct TDigest { centroids: Vec, max_size: usize, sum: f64, - count: f64, + count: u64, max: f64, min: f64, } @@ -153,7 +164,7 @@ impl TDigest { centroids: Vec::new(), max_size, sum: 0_f64, - count: 0_f64, + count: 0, max: f64::NAN, min: f64::NAN, } @@ -164,14 +175,14 @@ impl TDigest { centroids: vec![centroid.clone()], max_size, sum: centroid.mean * centroid.weight, - count: 1_f64, + count: 1, max: centroid.mean, min: centroid.mean, } } #[inline] - pub fn count(&self) -> f64 { + pub fn count(&self) -> u64 { self.count } @@ -203,7 +214,7 @@ impl Default for TDigest { centroids: Vec::new(), max_size: 100, sum: 0_f64, - count: 0_f64, + count: 0, max: f64::NAN, min: f64::NAN, } @@ -211,8 +222,8 @@ impl Default for TDigest { } impl TDigest { - fn k_to_q(k: f64, d: f64) -> f64 { - let k_div_d = k / d; + fn k_to_q(k: u64, d: usize) -> f64 { + let k_div_d = k as f64 / d as f64; if k_div_d >= 0.5 { let base = 1.0 - k_div_d; 1.0 - 2.0 * base * base @@ -244,12 +255,12 @@ impl TDigest { } let mut result = TDigest::new(self.max_size()); - result.count = self.count() + (sorted_values.len() as f64); + result.count = self.count() + sorted_values.len() as u64; let maybe_min = *sorted_values.first().unwrap(); let maybe_max = *sorted_values.last().unwrap(); - if self.count() > 0.0 { + if self.count() > 0 { result.min = self.min.min(maybe_min); result.max = self.max.max(maybe_max); } else { @@ -259,10 +270,10 @@ impl TDigest { let mut compressed: Vec = Vec::with_capacity(self.max_size); - let mut k_limit: f64 = 1.0; + let mut k_limit: u64 = 1; let mut q_limit_times_count = - Self::k_to_q(k_limit, self.max_size as f64) * result.count(); - k_limit += 1.0; + Self::k_to_q(k_limit, self.max_size) * result.count() as f64; + k_limit += 1; let mut iter_centroids = self.centroids.iter().peekable(); let mut iter_sorted_values = sorted_values.iter().peekable(); @@ -309,8 +320,8 @@ impl TDigest { compressed.push(curr.clone()); q_limit_times_count = - Self::k_to_q(k_limit, self.max_size as f64) * result.count(); - k_limit += 1.0; + Self::k_to_q(k_limit, self.max_size) * result.count() as f64; + k_limit += 1; curr = next; } } @@ -381,7 +392,7 @@ impl TDigest { let mut centroids: Vec = Vec::with_capacity(n_centroids); let mut starts: Vec = Vec::with_capacity(digests.len()); - let mut count: f64 = 0.0; + let mut count = 0; let mut min = f64::INFINITY; let mut max = f64::NEG_INFINITY; @@ -389,8 +400,8 @@ impl TDigest { for digest in digests.iter() { starts.push(start); - let curr_count: f64 = digest.count(); - if curr_count > 0.0 { + let curr_count = digest.count(); + if curr_count > 0 { min = min.min(digest.min); max = max.max(digest.max); count += curr_count; @@ -424,8 +435,8 @@ impl TDigest { let mut result = TDigest::new(max_size); let mut compressed: Vec = Vec::with_capacity(max_size); - let mut k_limit: f64 = 1.0; - let mut q_limit_times_count = Self::k_to_q(k_limit, max_size as f64) * (count); + let mut k_limit = 1; + let mut q_limit_times_count = Self::k_to_q(k_limit, max_size) * count as f64; let mut iter_centroids = centroids.iter_mut(); let mut curr = iter_centroids.next().unwrap(); @@ -444,8 +455,8 @@ impl TDigest { 
sums_to_merge = 0_f64; weights_to_merge = 0_f64; compressed.push(curr.clone()); - q_limit_times_count = Self::k_to_q(k_limit, max_size as f64) * (count); - k_limit += 1.0; + q_limit_times_count = Self::k_to_q(k_limit, max_size) * count as f64; + k_limit += 1; curr = centroid; } } @@ -468,8 +479,7 @@ impl TDigest { return 0.0; } - let count_ = self.count; - let rank = q * count_; + let rank = q * self.count as f64; let mut pos: usize; let mut t; @@ -479,7 +489,7 @@ impl TDigest { } pos = 0; - t = count_; + t = self.count as f64; for (k, centroid) in self.centroids.iter().enumerate().rev() { t -= centroid.weight(); @@ -581,7 +591,7 @@ impl TDigest { vec![ ScalarValue::UInt64(Some(self.max_size as u64)), ScalarValue::Float64(Some(self.sum)), - ScalarValue::Float64(Some(self.count)), + ScalarValue::UInt64(Some(self.count)), ScalarValue::Float64(Some(self.max)), ScalarValue::Float64(Some(self.min)), ScalarValue::List(arr), @@ -627,7 +637,7 @@ impl TDigest { Self { max_size, sum: cast_scalar_f64!(state[1]), - count: cast_scalar_f64!(&state[2]), + count: cast_scalar_u64!(&state[2]), max, min, centroids, diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs deleted file mode 100644 index bdc41ff0a9bc..000000000000 --- a/datafusion/physical-expr/src/aggregate/build_in.rs +++ /dev/null @@ -1,208 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Declaration of built-in (aggregate) functions. -//! This module contains built-in aggregates' enumeration and metadata. -//! -//! Generally, an aggregate has: -//! * a signature -//! * a return type, that is a function of the incoming argument's types -//! * the computation, that must accept each valid signature -//! -//! * Signature: see `Signature` -//! * Return type: a function `(arg_types) -> return_type`. E.g. for min, ([f32]) -> f32, ([f64]) -> f64. - -use std::sync::Arc; - -use arrow::datatypes::Schema; - -use datafusion_common::Result; -use datafusion_expr::AggregateFunction; - -use crate::expressions::{self}; -use crate::{AggregateExpr, PhysicalExpr, PhysicalSortExpr}; - -/// Create a physical aggregation expression. -/// This function errors when `input_phy_exprs`' can't be coerced to a valid argument type of the aggregation function. 
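The built-in `Min`/`Max` physical expressions removed below are replaced throughout this PR by UDAF-based equivalents, which the optimizer tests in this diff import from `datafusion_functions_aggregate`. A small illustrative sketch of the pattern (the column name is hypothetical):

```rust
use datafusion_expr::{col, Expr};
use datafusion_functions_aggregate::expr_fn::{max, min};

// Build min/max aggregate expressions through the UDAF-based API that
// replaces the removed AggregateFunction::Min/Max enum variants.
fn min_max_exprs() -> Vec<Expr> {
    vec![min(col("c1")), max(col("c1"))]
}
```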
-pub fn create_aggregate_expr( - fun: &AggregateFunction, - distinct: bool, - input_phy_exprs: &[Arc], - _ordering_req: &[PhysicalSortExpr], - input_schema: &Schema, - name: impl Into, - _ignore_nulls: bool, -) -> Result> { - let name = name.into(); - // get the result data type for this aggregate function - let input_phy_types = input_phy_exprs - .iter() - .map(|e| e.data_type(input_schema)) - .collect::>>()?; - let data_type = input_phy_types[0].clone(); - let input_phy_exprs = input_phy_exprs.to_vec(); - Ok(match (fun, distinct) { - (AggregateFunction::Min, _) => Arc::new(expressions::Min::new( - Arc::clone(&input_phy_exprs[0]), - name, - data_type, - )), - (AggregateFunction::Max, _) => Arc::new(expressions::Max::new( - Arc::clone(&input_phy_exprs[0]), - name, - data_type, - )), - }) -} - -#[cfg(test)] -mod tests { - use arrow::datatypes::{DataType, Field}; - - use datafusion_common::plan_err; - use datafusion_expr::{type_coercion, Signature}; - - use crate::expressions::{try_cast, Max, Min}; - - use super::*; - - #[test] - fn test_min_max_expr() -> Result<()> { - let funcs = vec![AggregateFunction::Min, AggregateFunction::Max]; - let data_types = vec![ - DataType::UInt32, - DataType::Int32, - DataType::Float32, - DataType::Float64, - DataType::Decimal128(10, 2), - DataType::Utf8, - ]; - for fun in funcs { - for data_type in &data_types { - let input_schema = - Schema::new(vec![Field::new("c1", data_type.clone(), true)]); - let input_phy_exprs: Vec> = vec![Arc::new( - expressions::Column::new_with_schema("c1", &input_schema).unwrap(), - )]; - let result_agg_phy_exprs = create_physical_agg_expr_for_test( - &fun, - false, - &input_phy_exprs[0..1], - &input_schema, - "c1", - )?; - match fun { - AggregateFunction::Min => { - assert!(result_agg_phy_exprs.as_any().is::()); - assert_eq!("c1", result_agg_phy_exprs.name()); - assert_eq!( - Field::new("c1", data_type.clone(), true), - result_agg_phy_exprs.field().unwrap() - ); - } - AggregateFunction::Max => { - assert!(result_agg_phy_exprs.as_any().is::()); - assert_eq!("c1", result_agg_phy_exprs.name()); - assert_eq!( - Field::new("c1", data_type.clone(), true), - result_agg_phy_exprs.field().unwrap() - ); - } - }; - } - } - Ok(()) - } - - #[test] - fn test_min_max() -> Result<()> { - let observed = AggregateFunction::Min.return_type(&[DataType::Utf8], &[true])?; - assert_eq!(DataType::Utf8, observed); - - let observed = AggregateFunction::Max.return_type(&[DataType::Int32], &[true])?; - assert_eq!(DataType::Int32, observed); - - // test decimal for min - let observed = AggregateFunction::Min - .return_type(&[DataType::Decimal128(10, 6)], &[true])?; - assert_eq!(DataType::Decimal128(10, 6), observed); - - // test decimal for max - let observed = AggregateFunction::Max - .return_type(&[DataType::Decimal128(28, 13)], &[true])?; - assert_eq!(DataType::Decimal128(28, 13), observed); - - Ok(()) - } - - // Helper function - // Create aggregate expr with type coercion - fn create_physical_agg_expr_for_test( - fun: &AggregateFunction, - distinct: bool, - input_phy_exprs: &[Arc], - input_schema: &Schema, - name: impl Into, - ) -> Result> { - let name = name.into(); - let coerced_phy_exprs = - coerce_exprs_for_test(fun, input_phy_exprs, input_schema, &fun.signature())?; - if coerced_phy_exprs.is_empty() { - return plan_err!( - "Invalid or wrong number of arguments passed to aggregate: '{name}'" - ); - } - create_aggregate_expr( - fun, - distinct, - &coerced_phy_exprs, - &[], - input_schema, - name, - false, - ) - } - - // Returns the coerced exprs for 
each `input_exprs`.
-    // Get the coerced data type from `aggregate_rule::coerce_types` and add `try_cast` if the
-    // data type of `input_exprs` need to be coerced.
-    fn coerce_exprs_for_test(
-        agg_fun: &AggregateFunction,
-        input_exprs: &[Arc<dyn PhysicalExpr>],
-        schema: &Schema,
-        signature: &Signature,
-    ) -> Result<Vec<Arc<dyn PhysicalExpr>>> {
-        if input_exprs.is_empty() {
-            return Ok(vec![]);
-        }
-        let input_types = input_exprs
-            .iter()
-            .map(|e| e.data_type(schema))
-            .collect::<Result<Vec<_>>>()?;
-
-        // get the coerced data types
-        let coerced_types =
-            type_coercion::aggregates::coerce_types(agg_fun, &input_types, signature)?;
-
-        // try cast if need
-        input_exprs
-            .iter()
-            .zip(coerced_types)
-            .map(|(expr, coerced_type)| try_cast(Arc::clone(expr), schema, coerced_type))
-            .collect::<Result<Vec<_>>>()
-    }
-}
diff --git a/datafusion/physical-expr/src/aggregate/groups_accumulator/mod.rs b/datafusion/physical-expr/src/aggregate/groups_accumulator/mod.rs
index 1944e2b2d415..3c0f3a28fedb 100644
--- a/datafusion/physical-expr/src/aggregate/groups_accumulator/mod.rs
+++ b/datafusion/physical-expr/src/aggregate/groups_accumulator/mod.rs
@@ -25,7 +25,3 @@ pub(crate) mod accumulate {
 }
 
 pub use datafusion_physical_expr_common::aggregate::groups_accumulator::accumulate::NullState;
-
-pub(crate) mod prim_op {
-    pub use datafusion_physical_expr_common::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator;
-}
diff --git a/datafusion/physical-expr/src/aggregate/mod.rs b/datafusion/physical-expr/src/aggregate/mod.rs
index 264c48513050..0760986a87c6 100644
--- a/datafusion/physical-expr/src/aggregate/mod.rs
+++ b/datafusion/physical-expr/src/aggregate/mod.rs
@@ -15,12 +15,9 @@
 // specific language governing permissions and limitations
 // under the License.
 
-#[macro_use]
-pub(crate) mod min_max;
 pub(crate) mod groups_accumulator;
 pub(crate) mod stats;
 
-pub mod build_in;
 pub mod moving_min_max;
 
 pub mod utils {
     pub use datafusion_physical_expr_common::aggregate::utils::{
diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs
index 951bef4521e3..2c45de729120 100644
--- a/datafusion/physical-expr/src/expressions/mod.rs
+++ b/datafusion/physical-expr/src/expressions/mod.rs
@@ -31,11 +31,6 @@ mod try_cast;
 mod unknown_column;
 
 /// Module with some convenient methods used in expression building
-pub mod helpers {
-    pub use crate::aggregate::min_max::{max, min};
-}
-pub use crate::aggregate::build_in::create_aggregate_expr;
-pub use crate::aggregate::min_max::{Max, MaxAccumulator, Min, MinAccumulator};
 pub use crate::aggregate::stats::StatsType;
 pub use crate::window::cume_dist::{cume_dist, CumeDist};
 pub use crate::window::lead_lag::{lag, lead, WindowShift};
diff --git a/datafusion/physical-plan/src/aggregates/group_values/row.rs b/datafusion/physical-plan/src/aggregates/group_values/row.rs
index 9f05da7cff53..dc948e28bb2d 100644
--- a/datafusion/physical-plan/src/aggregates/group_values/row.rs
+++ b/datafusion/physical-plan/src/aggregates/group_values/row.rs
@@ -122,10 +122,14 @@ impl GroupValues for GroupValuesRows {
 
         for (row, &target_hash) in batch_hashes.iter().enumerate() {
             let entry = self.map.get_mut(target_hash, |(exist_hash, group_idx)| {
-                // verify that a group that we are inserting with hash is
-                // actually the same key value as the group in
-                // existing_idx (aka group_values @ row)
+                // Somewhat surprisingly, this closure can be called even if the
+                // hash doesn't match, so check the hash first with an integer
+                // comparison to avoid the more expensive comparison with the
+                // group value. https://github.com/apache/datafusion/pull/11718
                 target_hash == *exist_hash
+                // verify that the group that we are inserting with hash is
+                // actually the same key value as the group in
+                // existing_idx (aka group_values @ row)
                     && group_rows.row(row) == group_values.row(*group_idx)
             });
diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs
index fd2510fbf90d..9d7f45603464 100644
--- a/datafusion/physical-plan/src/aggregates/mod.rs
+++ b/datafusion/physical-plan/src/aggregates/mod.rs
@@ -55,9 +55,6 @@ mod row_hash;
 mod topk;
 mod topk_stream;
 
-pub use datafusion_expr::AggregateFunction;
-pub use datafusion_physical_expr::expressions::create_aggregate_expr;
-
 /// Hash aggregate modes
 #[derive(Debug, Copy, Clone, PartialEq, Eq)]
 pub enum AggregateMode {
diff --git a/datafusion/physical-plan/src/coalesce_batches.rs b/datafusion/physical-plan/src/coalesce_batches.rs
index 038727daa7d8..b822ec2dafeb 100644
--- a/datafusion/physical-plan/src/coalesce_batches.rs
+++ b/datafusion/physical-plan/src/coalesce_batches.rs
@@ -15,13 +15,12 @@
 // specific language governing permissions and limitations
 // under the License.
 
-//! CoalesceBatchesExec combines small batches into larger batches for more efficient use of
-//! vectorized processing by upstream operators.
+//! [`CoalesceBatchesExec`] combines small batches into larger batches.
 
 use std::any::Any;
 use std::pin::Pin;
 use std::sync::Arc;
-use std::task::{Context, Poll};
+use std::task::{ready, Context, Poll};
 
 use arrow::array::{AsArray, StringViewBuilder};
 use arrow::compute::concat_batches;
@@ -41,11 +40,43 @@ use super::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
 use super::{DisplayAs, ExecutionPlanProperties, PlanProperties, Statistics};
 
 /// `CoalesceBatchesExec` combines small batches into larger batches for more
-/// efficient use of vectorized processing by later operators. The operator
-/// works by buffering batches until it collects `target_batch_size` rows. When
-/// only a limited number of rows are necessary (specified by the `fetch`
-/// parameter), the operator will stop buffering and return the final batch
-/// once the number of collected rows reaches the `fetch` value.
+/// efficient use of vectorized processing by later operators.
+///
+/// The operator buffers batches until it collects `target_batch_size` rows and
+/// then emits a single concatenated batch. When only a limited number of rows
+/// are necessary (specified by the `fetch` parameter), the operator will stop
+/// buffering and return the final batch once the number of collected rows
+/// reaches the `fetch` value.
+///
+/// # Background
+///
+/// Generally speaking, larger RecordBatches are more efficient to process than
+/// smaller record batches (until the CPU cache is exceeded) because there is
+/// fixed processing overhead per batch. This code concatenates multiple small
+/// record batches into larger ones to amortize this overhead.
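For intuition, a minimal sketch of the amortization idea using arrow's `concat_batches` (an illustration only; the operator below additionally handles limits and string-view garbage collection):

```rust
use arrow::compute::concat_batches;
use arrow_array::RecordBatch;
use arrow_schema::{ArrowError, SchemaRef};

// Pay the per-batch overhead once: combine many small batches into a
// single larger batch before handing it to downstream operators.
fn coalesce(
    schema: &SchemaRef,
    small_batches: &[RecordBatch],
) -> Result<RecordBatch, ArrowError> {
    concat_batches(schema, small_batches)
}
```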
+/// +/// ```text +/// ┌────────────────────┐ +/// │ RecordBatch │ +/// │ num_rows = 23 │ +/// └────────────────────┘ ┌────────────────────┐ +/// │ │ +/// ┌────────────────────┐ Coalesce │ │ +/// │ │ Batches │ │ +/// │ RecordBatch │ │ │ +/// │ num_rows = 50 │ ─ ─ ─ ─ ─ ─ ▶ │ │ +/// │ │ │ RecordBatch │ +/// │ │ │ num_rows = 106 │ +/// └────────────────────┘ │ │ +/// │ │ +/// ┌────────────────────┐ │ │ +/// │ │ │ │ +/// │ RecordBatch │ │ │ +/// │ num_rows = 33 │ └────────────────────┘ +/// │ │ +/// └────────────────────┘ +/// ``` + #[derive(Debug)] pub struct CoalesceBatchesExec { /// The input plan @@ -166,12 +197,11 @@ impl ExecutionPlan for CoalesceBatchesExec { ) -> Result { Ok(Box::pin(CoalesceBatchesStream { input: self.input.execute(partition, context)?, - schema: self.input.schema(), - target_batch_size: self.target_batch_size, - fetch: self.fetch, - buffer: Vec::new(), - buffered_rows: 0, - total_rows: 0, + coalescer: BatchCoalescer::new( + self.input.schema(), + self.target_batch_size, + self.fetch, + ), is_closed: false, baseline_metrics: BaselineMetrics::new(&self.metrics, partition), })) @@ -196,21 +226,12 @@ impl ExecutionPlan for CoalesceBatchesExec { } } +/// Stream for [`CoalesceBatchesExec`]. See [`CoalesceBatchesExec`] for more details. struct CoalesceBatchesStream { /// The input plan input: SendableRecordBatchStream, - /// The input schema - schema: SchemaRef, - /// Minimum number of rows for coalesces batches - target_batch_size: usize, - /// Maximum number of rows to fetch, `None` means fetching all rows - fetch: Option, - /// Buffered batches - buffer: Vec, - /// Buffered row count - buffered_rows: usize, - /// Total number of rows returned - total_rows: usize, + /// Buffer for combining batches + coalescer: BatchCoalescer, /// Whether the stream has finished returning all of its data or not is_closed: bool, /// Execution metrics @@ -249,84 +270,178 @@ impl CoalesceBatchesStream { let input_batch = self.input.poll_next_unpin(cx); // records time on drop let _timer = cloned_time.timer(); - match input_batch { - Poll::Ready(x) => match x { - Some(Ok(batch)) => { - let batch = gc_string_view_batch(&batch); - - // Handle fetch limit: - if let Some(fetch) = self.fetch { - if self.total_rows + batch.num_rows() >= fetch { - // We have reached the fetch limit. 
-                        let remaining_rows = fetch - self.total_rows;
-                        debug_assert!(remaining_rows > 0);
-
+            match ready!(input_batch) {
+                Some(result) => {
+                    let Ok(input_batch) = result else {
+                        return Poll::Ready(Some(result)); // pass back error
+                    };
+                    // Buffer the batch, then either get more input if there
+                    // are not enough rows yet, or emit an output batch
+                    match self.coalescer.push_batch(input_batch) {
+                        Ok(None) => continue,
+                        res => {
+                            if self.coalescer.limit_reached() {
                                 self.is_closed = true;
-                                self.total_rows = fetch;
-                                // Trim the batch and add to buffered batches:
-                                let batch = batch.slice(0, remaining_rows);
-                                self.buffered_rows += batch.num_rows();
-                                self.buffer.push(batch);
-                                // Combine buffered batches:
-                                let batch = concat_batches(&self.schema, &self.buffer)?;
-                                // Reset the buffer state and return final batch:
-                                self.buffer.clear();
-                                self.buffered_rows = 0;
-                                return Poll::Ready(Some(Ok(batch)));
-                            }
-                        }
-                        self.total_rows += batch.num_rows();
-
-                        if batch.num_rows() >= self.target_batch_size
-                            && self.buffer.is_empty()
-                        {
-                            return Poll::Ready(Some(Ok(batch)));
-                        } else if batch.num_rows() == 0 {
-                            // discard empty batches
-                        } else {
-                            // add to the buffered batches
-                            self.buffered_rows += batch.num_rows();
-                            self.buffer.push(batch);
-                            // check to see if we have enough batches yet
-                            if self.buffered_rows >= self.target_batch_size {
-                                // combine the batches and return
-                                let batch = concat_batches(&self.schema, &self.buffer)?;
-                                // reset buffer state
-                                self.buffer.clear();
-                                self.buffered_rows = 0;
-                                // return batch
-                                return Poll::Ready(Some(Ok(batch)));
                             }
+                            return Poll::Ready(res.transpose());
+                        }
+                    }
                 }
-                None => {
-                    self.is_closed = true;
-                    // we have reached the end of the input stream but there could still
-                    // be buffered batches
-                    if self.buffer.is_empty() {
-                        return Poll::Ready(None);
-                    } else {
-                        // combine the batches and return
-                        let batch = concat_batches(&self.schema, &self.buffer)?;
-                        // reset buffer state
-                        self.buffer.clear();
-                        self.buffered_rows = 0;
-                        // return batch
-                        return Poll::Ready(Some(Ok(batch)));
-                    }
-                }
-                other => return Poll::Ready(other),
-            },
-            Poll::Pending => return Poll::Pending,
+                }
+                None => {
+                    self.is_closed = true;
+                    // we have reached the end of the input stream but there could still
+                    // be buffered batches
+                    return match self.coalescer.finish() {
+                        Ok(None) => Poll::Ready(None),
+                        res => Poll::Ready(res.transpose()),
+                    };
+                }
             }
         }
     }
 }
 
 impl RecordBatchStream for CoalesceBatchesStream {
+    fn schema(&self) -> SchemaRef {
+        self.coalescer.schema()
+    }
+}
+
+/// Concatenate multiple record batches into larger batches
+///
+/// See [`CoalesceBatchesExec`] for more details.
+///
+/// Notes:
+///
+/// 1. The output rows are in the same order as the input rows
+///
+/// 2. The output is a sequence of batches, with all but the last being at least
+///    `target_batch_size` rows.
+///
+/// 3. Eventually this may also be able to handle other optimizations such as a
+///    combined filter/coalesce operation.
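A hedged usage sketch of the new `BatchCoalescer` (defined immediately below), mirroring how the stream above and the test harness later in this diff drive it; the `8192` target size is an assumed value and the helper name is illustrative:

```rust
// Push every input batch, collect completed outputs, then flush the
// remainder with `finish()`.
fn drive_coalescer(
    schema: SchemaRef,
    input: Vec<RecordBatch>,
) -> Result<Vec<RecordBatch>> {
    let mut coalescer = BatchCoalescer::new(schema, 8192, None);
    let mut output = vec![];
    for batch in input {
        if let Some(batch) = coalescer.push_batch(batch)? {
            output.push(batch);
        }
    }
    if let Some(batch) = coalescer.finish()? {
        output.push(batch);
    }
    Ok(output)
}
```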
+#[derive(Debug)]
+struct BatchCoalescer {
+    /// The input schema
+    schema: SchemaRef,
+    /// Minimum number of rows for coalesced batches
+    target_batch_size: usize,
+    /// Total number of rows returned so far
+    total_rows: usize,
+    /// Buffered batches
+    buffer: Vec<RecordBatch>,
+    /// Buffered row count
+    buffered_rows: usize,
+    /// Maximum number of rows to fetch, `None` means fetching all rows
+    fetch: Option<usize>,
+}
+
+impl BatchCoalescer {
+    /// Create a new `BatchCoalescer`
+    ///
+    /// # Arguments
+    /// - `schema` - the schema of the output batches
+    /// - `target_batch_size` - the minimum number of rows for each
+    ///   output batch (until limit reached)
+    /// - `fetch` - the maximum number of rows to fetch, `None` means fetch all rows
+    fn new(schema: SchemaRef, target_batch_size: usize, fetch: Option<usize>) -> Self {
+        Self {
+            schema,
+            target_batch_size,
+            total_rows: 0,
+            buffer: vec![],
+            buffered_rows: 0,
+            fetch,
+        }
+    }
+
+    /// Return the schema of the output batches
     fn schema(&self) -> SchemaRef {
         Arc::clone(&self.schema)
     }
+
+    /// Add a batch, returning a batch if the target batch size or limit is reached
+    fn push_batch(&mut self, batch: RecordBatch) -> Result<Option<RecordBatch>> {
+        // discard empty batches
+        if batch.num_rows() == 0 {
+            return Ok(None);
+        }
+
+        // past limit
+        if self.limit_reached() {
+            return Ok(None);
+        }
+
+        let batch = gc_string_view_batch(&batch);
+
+        // Handle fetch limit:
+        if let Some(fetch) = self.fetch {
+            if self.total_rows + batch.num_rows() >= fetch {
+                // We have reached the fetch limit.
+                let remaining_rows = fetch - self.total_rows;
+                debug_assert!(remaining_rows > 0);
+                self.total_rows = fetch;
+                // Trim the batch and add to buffered batches:
+                let batch = batch.slice(0, remaining_rows);
+                self.buffered_rows += batch.num_rows();
+                self.buffer.push(batch);
+                // Combine buffered batches:
+                let batch = concat_batches(&self.schema, &self.buffer)?;
+                // Reset the buffer state and return final batch:
+                self.buffer.clear();
+                self.buffered_rows = 0;
+                return Ok(Some(batch));
+            }
+        }
+        self.total_rows += batch.num_rows();
+
+        // batch itself is already big enough and we have no buffered rows so
+        // return it directly
+        if batch.num_rows() >= self.target_batch_size && self.buffer.is_empty() {
+            return Ok(Some(batch));
+        }
+        // add to the buffered batches
+        self.buffered_rows += batch.num_rows();
+        self.buffer.push(batch);
+        // check to see if we have enough batches yet
+        let batch = if self.buffered_rows >= self.target_batch_size {
+            // combine the batches and return
+            let batch = concat_batches(&self.schema, &self.buffer)?;
+            // reset buffer state
+            self.buffer.clear();
+            self.buffered_rows = 0;
+            // return batch
+            Some(batch)
+        } else {
+            None
+        };
+        Ok(batch)
+    }
+
+    /// Finish the coalescing process, returning all buffered data as a final,
+    /// single batch, if any
+    fn finish(&mut self) -> Result<Option<RecordBatch>> {
+        if self.buffer.is_empty() {
+            Ok(None)
+        } else {
+            // combine the batches and return
+            let batch = concat_batches(&self.schema, &self.buffer)?;
+            // reset buffer state
+            self.buffer.clear();
+            self.buffered_rows = 0;
+            // return batch
+            Ok(Some(batch))
+        }
+    }
+
+    /// returns true if there is a limit and it has been reached
+    pub fn limit_reached(&self) -> bool {
+        if let Some(fetch) = self.fetch {
+            self.total_rows >= fetch
+        } else {
+            false
+        }
+    }
 }
 
 /// Heuristically compact `StringViewArray`s to reduce memory usage, if needed
@@ -400,164 +515,206 @@ fn gc_string_view_batch(batch: &RecordBatch) -> RecordBatch {
 #[cfg(test)]
 mod tests {
+    use super::*;
    use
arrow::datatypes::{DataType, Field, Schema}; use arrow_array::builder::ArrayBuilder; use arrow_array::{StringViewArray, UInt32Array}; + use std::ops::Range; - use crate::{memory::MemoryExec, repartition::RepartitionExec, Partitioning}; - - use super::*; - - #[tokio::test(flavor = "multi_thread")] - async fn test_concat_batches() -> Result<()> { - let schema = test_schema(); - let partition = create_vec_batches(&schema, 10); - let partitions = vec![partition]; - - let output_partitions = coalesce_batches(&schema, partitions, 21, None).await?; - assert_eq!(1, output_partitions.len()); - - // input is 10 batches x 8 rows (80 rows) - // expected output is batches of at least 20 rows (except for the final batch) - let batches = &output_partitions[0]; - assert_eq!(4, batches.len()); - assert_eq!(24, batches[0].num_rows()); - assert_eq!(24, batches[1].num_rows()); - assert_eq!(24, batches[2].num_rows()); - assert_eq!(8, batches[3].num_rows()); - - Ok(()) + #[test] + fn test_coalesce() { + let batch = uint32_batch(0..8); + Test::new() + .with_batches(std::iter::repeat(batch).take(10)) + // expected output is batches of at least 20 rows (except for the final batch) + .with_target_batch_size(21) + .with_expected_output_sizes(vec![24, 24, 24, 8]) + .run() } - #[tokio::test] - async fn test_concat_batches_with_fetch_larger_than_input_size() -> Result<()> { - let schema = test_schema(); - let partition = create_vec_batches(&schema, 10); - let partitions = vec![partition]; - - let output_partitions = - coalesce_batches(&schema, partitions, 21, Some(100)).await?; - assert_eq!(1, output_partitions.len()); + #[test] + fn test_coalesce_with_fetch_larger_than_input_size() { + let batch = uint32_batch(0..8); + Test::new() + .with_batches(std::iter::repeat(batch).take(10)) + // input is 10 batches x 8 rows (80 rows) with fetch limit of 100 + // expected to behave the same as `test_concat_batches` + .with_target_batch_size(21) + .with_fetch(Some(100)) + .with_expected_output_sizes(vec![24, 24, 24, 8]) + .run(); + } - // input is 10 batches x 8 rows (80 rows) with fetch limit of 100 - // expected to behave the same as `test_concat_batches` - let batches = &output_partitions[0]; - assert_eq!(4, batches.len()); - assert_eq!(24, batches[0].num_rows()); - assert_eq!(24, batches[1].num_rows()); - assert_eq!(24, batches[2].num_rows()); - assert_eq!(8, batches[3].num_rows()); + #[test] + fn test_coalesce_with_fetch_less_than_input_size() { + let batch = uint32_batch(0..8); + Test::new() + .with_batches(std::iter::repeat(batch).take(10)) + // input is 10 batches x 8 rows (80 rows) with fetch limit of 50 + .with_target_batch_size(21) + .with_fetch(Some(50)) + .with_expected_output_sizes(vec![24, 24, 2]) + .run(); + } - Ok(()) + #[test] + fn test_coalesce_with_fetch_less_than_target_and_no_remaining_rows() { + let batch = uint32_batch(0..8); + Test::new() + .with_batches(std::iter::repeat(batch).take(10)) + // input is 10 batches x 8 rows (80 rows) with fetch limit of 48 + .with_target_batch_size(21) + .with_fetch(Some(48)) + .with_expected_output_sizes(vec![24, 24]) + .run(); } - #[tokio::test] - async fn test_concat_batches_with_fetch_less_than_input_size() -> Result<()> { - let schema = test_schema(); - let partition = create_vec_batches(&schema, 10); - let partitions = vec![partition]; + #[test] + fn test_coalesce_with_fetch_less_target_batch_size() { + let batch = uint32_batch(0..8); + Test::new() + .with_batches(std::iter::repeat(batch).take(10)) + // input is 10 batches x 8 rows (80 rows) with fetch limit of 10 + 
+ .with_target_batch_size(21) + .with_fetch(Some(10)) + .with_expected_output_sizes(vec![10]) + .run(); + } - let output_partitions = - coalesce_batches(&schema, partitions, 21, Some(50)).await?; - assert_eq!(1, output_partitions.len()); + #[test] + fn test_coalesce_single_large_batch_over_fetch() { + let large_batch = uint32_batch(0..100); + Test::new() + .with_batch(large_batch) + .with_target_batch_size(20) + .with_fetch(Some(7)) + .with_expected_output_sizes(vec![7]) + .run() + } + + /// Test for [`BatchCoalescer`] + /// + /// Pushes the input batches to the coalescer and verifies that the resulting + /// batches have the expected number of rows and contents. + #[derive(Debug, Clone, Default)] + struct Test { + /// Batches to feed to the coalescer. Tests must provide at least one + /// input batch + input_batches: Vec<RecordBatch>, + /// Expected output sizes of the resulting batches + expected_output_sizes: Vec<usize>, + /// target batch size + target_batch_size: usize, + /// Fetch (limit) + fetch: Option<usize>, + } - // input is 10 batches x 8 rows (80 rows) with fetch limit of 50 - let batches = &output_partitions[0]; - assert_eq!(3, batches.len()); - assert_eq!(24, batches[0].num_rows()); - assert_eq!(24, batches[1].num_rows()); - assert_eq!(2, batches[2].num_rows()); + impl Test { + fn new() -> Self { + Self::default() + } - Ok(()) - } + /// Set the target batch size + fn with_target_batch_size(mut self, target_batch_size: usize) -> Self { + self.target_batch_size = target_batch_size; + self + } - #[tokio::test] - async fn test_concat_batches_with_fetch_less_than_target_and_no_remaining_rows( - ) -> Result<()> { - let schema = test_schema(); - let partition = create_vec_batches(&schema, 10); - let partitions = vec![partition]; + /// Set the fetch (limit) + fn with_fetch(mut self, fetch: Option<usize>) -> Self { + self.fetch = fetch; + self + } - let output_partitions = - coalesce_batches(&schema, partitions, 21, Some(48)).await?; - assert_eq!(1, output_partitions.len()); + /// Extend the input batches with `batch` + fn with_batch(mut self, batch: RecordBatch) -> Self { + self.input_batches.push(batch); + self + } - // input is 10 batches x 8 rows (80 rows) with fetch limit of 48 - let batches = &output_partitions[0]; - assert_eq!(2, batches.len()); - assert_eq!(24, batches[0].num_rows()); - assert_eq!(24, batches[1].num_rows()); + /// Extends the input batches with `batches` + fn with_batches( + mut self, + batches: impl IntoIterator<Item = RecordBatch>, + ) -> Self { + self.input_batches.extend(batches); + self + } - Ok(()) - } + /// Extends the expected output sizes with `sizes` + fn with_expected_output_sizes( + mut self, + sizes: impl IntoIterator<Item = usize>, + ) -> Self { + self.expected_output_sizes.extend(sizes); + self + } - #[tokio::test] - async fn test_concat_batches_with_fetch_less_target_batch_size() -> Result<()> { - let schema = test_schema(); - let partition = create_vec_batches(&schema, 10); - let partitions = vec![partition]; + /// Runs the test -- see documentation on [`Test`] for details + fn run(self) { + let Self { + input_batches, + target_batch_size, + fetch, + expected_output_sizes, + } = self; - let output_partitions = - coalesce_batches(&schema, partitions, 21, Some(10)).await?; - assert_eq!(1, output_partitions.len()); + let schema = input_batches[0].schema(); - // input is 10 batches x 8 rows (80 rows) with fetch limit of 10 - let batches = &output_partitions[0]; - assert_eq!(1, batches.len()); - assert_eq!(10, batches[0].num_rows()); + // create a single large input batch for output comparison + let single_input_batch =
concat_batches(&schema, &input_batches).unwrap(); - Ok(()) - } + let mut coalescer = BatchCoalescer::new(schema, target_batch_size, fetch); - fn test_schema() -> Arc<Schema> { - Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)])) - } + let mut output_batches = vec![]; + for batch in input_batches { + if let Some(batch) = coalescer.push_batch(batch).unwrap() { + output_batches.push(batch); + } + } + if let Some(batch) = coalescer.finish().unwrap() { + output_batches.push(batch); + } - async fn coalesce_batches( - schema: &SchemaRef, - input_partitions: Vec<Vec<RecordBatch>>, - target_batch_size: usize, - fetch: Option<usize>, - ) -> Result<Vec<Vec<RecordBatch>>> { - // create physical plan - let exec = MemoryExec::try_new(&input_partitions, Arc::clone(schema), None)?; - let exec = - RepartitionExec::try_new(Arc::new(exec), Partitioning::RoundRobinBatch(1))?; - let exec: Arc<dyn ExecutionPlan> = Arc::new( - CoalesceBatchesExec::new(Arc::new(exec), target_batch_size).with_fetch(fetch), - ); - - // execute and collect results - let output_partition_count = exec.output_partitioning().partition_count(); - let mut output_partitions = Vec::with_capacity(output_partition_count); - for i in 0..output_partition_count { - // execute this *output* partition and collect all batches - let task_ctx = Arc::new(TaskContext::default()); - let mut stream = exec.execute(i, Arc::clone(&task_ctx))?; - let mut batches = vec![]; - while let Some(result) = stream.next().await { - batches.push(result?); + // make sure we got the expected number of output batches and content + let mut starting_idx = 0; + assert_eq!(expected_output_sizes.len(), output_batches.len()); + for (i, (expected_size, batch)) in + expected_output_sizes.iter().zip(output_batches).enumerate() + { + assert_eq!( + *expected_size, + batch.num_rows(), + "Unexpected number of rows in Batch {i}" + ); + + // compare the contents of the batch (using `==` compares the + // underlying memory layout too) + let expected_batch = + single_input_batch.slice(starting_idx, *expected_size); + let batch_strings = batch_to_pretty_strings(&batch); + let expected_batch_strings = batch_to_pretty_strings(&expected_batch); + let batch_strings = batch_strings.lines().collect::<Vec<_>>(); + let expected_batch_strings = + expected_batch_strings.lines().collect::<Vec<_>>(); + assert_eq!( + expected_batch_strings, batch_strings, + "Unexpected content in Batch {i}:\ \n\nExpected:\n{expected_batch_strings:#?}\n\nActual:\n{batch_strings:#?}" + ); + starting_idx += *expected_size; } - output_partitions.push(batches); } - Ok(output_partitions) } - /// Create vector batches - fn create_vec_batches(schema: &Schema, n: usize) -> Vec<RecordBatch> { - let batch = create_batch(schema); - let mut vec = Vec::with_capacity(n); - for _ in 0..n { - vec.push(batch.clone()); - } - vec - } + /// Return a batch of UInt32 with the specified range + fn uint32_batch(range: Range<u32>) -> RecordBatch { + let schema = + Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)])); - /// Create batch - fn create_batch(schema: &Schema) -> RecordBatch { RecordBatch::try_new( - Arc::new(schema.clone()), - vec![Arc::new(UInt32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8]))], + Arc::clone(&schema), + vec![Arc::new(UInt32Array::from_iter_values(range))], ) .unwrap() } @@ -656,4 +813,9 @@ mod tests { } } } + fn batch_to_pretty_strings(batch: &RecordBatch) -> String { + arrow::util::pretty::pretty_format_batches(&[batch.clone()]) + .unwrap() + .to_string() + } } diff --git a/datafusion/physical-plan/src/execution_plan.rs b/datafusion/physical-plan/src/execution_plan.rs new file mode
100644 index 000000000000..5a3fc086c1f8 --- /dev/null +++ b/datafusion/physical-plan/src/execution_plan.rs @@ -0,0 +1,1018 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::fmt::Debug; +use std::sync::Arc; + +use arrow::datatypes::SchemaRef; +use arrow::record_batch::RecordBatch; +use futures::stream::{StreamExt, TryStreamExt}; +use tokio::task::JoinSet; + +use datafusion_common::config::ConfigOptions; +pub use datafusion_common::hash_utils; +pub use datafusion_common::utils::project_schema; +use datafusion_common::{exec_err, Result}; +pub use datafusion_common::{internal_err, ColumnStatistics, Statistics}; +use datafusion_execution::TaskContext; +pub use datafusion_execution::{RecordBatchStream, SendableRecordBatchStream}; +pub use datafusion_expr::{Accumulator, ColumnarValue}; +pub use datafusion_physical_expr::window::WindowExpr; +pub use datafusion_physical_expr::{ + expressions, functions, udf, AggregateExpr, Distribution, Partitioning, PhysicalExpr, +}; +use datafusion_physical_expr::{ + EquivalenceProperties, LexOrdering, PhysicalSortExpr, PhysicalSortRequirement, +}; + +use crate::coalesce_partitions::CoalescePartitionsExec; +use crate::display::DisplayableExecutionPlan; +pub use crate::display::{DefaultDisplay, DisplayAs, DisplayFormatType, VerboseDisplay}; +pub use crate::metrics::Metric; +use crate::metrics::MetricsSet; +pub use crate::ordering::InputOrderMode; +use crate::repartition::RepartitionExec; +use crate::sorts::sort_preserving_merge::SortPreservingMergeExec; +pub use crate::stream::EmptyRecordBatchStream; +use crate::stream::RecordBatchStreamAdapter; + +/// Represent nodes in the DataFusion Physical Plan. +/// +/// Calling [`execute`] produces an `async` [`SendableRecordBatchStream`] of +/// [`RecordBatch`] that incrementally computes a partition of the +/// `ExecutionPlan`'s output from its input. See [`Partitioning`] for more +/// details on partitioning. +/// +/// Methods such as [`Self::schema`] and [`Self::properties`] communicate +/// properties of the output to the DataFusion optimizer, and methods such as +/// [`required_input_distribution`] and [`required_input_ordering`] express +/// requirements of the `ExecutionPlan` from its input. +/// +/// [`ExecutionPlan`] can be displayed in a simplified form using the +/// return value from [`displayable`] in addition to the (normally +/// quite verbose) `Debug` output. +/// +/// [`execute`]: ExecutionPlan::execute +/// [`required_input_distribution`]: ExecutionPlan::required_input_distribution +/// [`required_input_ordering`]: ExecutionPlan::required_input_ordering +pub trait ExecutionPlan: Debug + DisplayAs + Send + Sync { + /// Short name for the ExecutionPlan, such as 'ParquetExec'. 
+ /// + /// Implementation note: this method can just proxy to + /// [`static_name`](ExecutionPlan::static_name) if no special action is + /// needed. It doesn't provide such a default implementation because doing + /// so would require the `Sized` constraint, and avoiding that constraint + /// allows a wider range of use cases. + fn name(&self) -> &str; + + /// Short name for the ExecutionPlan, such as 'ParquetExec'. + /// Like [`name`](ExecutionPlan::name) but can be called without an instance. + fn static_name() -> &'static str + where + Self: Sized, + { + let full_name = std::any::type_name::<Self>(); + let maybe_start_idx = full_name.rfind(':'); + match maybe_start_idx { + Some(start_idx) => &full_name[start_idx + 1..], + None => "UNKNOWN", + } + }
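To illustrate the default behavior: `static_name` keeps only what follows the last `:` in `std::any::type_name`. A standalone model of that suffix logic (a sketch for illustration, not code from this PR) behaves like this:

```rust
// Standalone re-implementation of the suffix extraction used by the
// default static_name above.
fn short_name(full_name: &str) -> &str {
    match full_name.rfind(':') {
        Some(idx) => &full_name[idx + 1..],
        None => "UNKNOWN",
    }
}

fn main() {
    assert_eq!(
        short_name("datafusion::physical_plan::parquet::ParquetExec"),
        "ParquetExec"
    );
    // a path-less type name falls back to the "UNKNOWN" marker
    assert_eq!(short_name("PlainType"), "UNKNOWN");
}
```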
+ + /// Returns the execution plan as [`Any`] so that it can be + /// downcast to a specific implementation. + fn as_any(&self) -> &dyn Any; + + /// Get the schema for this execution plan + fn schema(&self) -> SchemaRef { + Arc::clone(self.properties().schema()) + } + + /// Return properties of the output of the `ExecutionPlan`, such as output + /// ordering(s), partitioning information etc. + /// + /// This information is available via methods on the [`ExecutionPlanProperties`] + /// trait, which is implemented for all `ExecutionPlan`s. + fn properties(&self) -> &PlanProperties; + + /// Specifies the data distribution requirements for all the + /// children of this `ExecutionPlan`. By default it is + /// [`Distribution::UnspecifiedDistribution`] for each child. + fn required_input_distribution(&self) -> Vec<Distribution> { + vec![Distribution::UnspecifiedDistribution; self.children().len()] + } + + /// Specifies the ordering required for all of the children of this + /// `ExecutionPlan`. + /// + /// For each child, it's the local ordering requirement within + /// each partition rather than the global ordering. + /// + /// NOTE that checking `!is_empty()` does **not** check for a + /// required input ordering. Instead, the correct check is that at + /// least one entry must be `Some`. + fn required_input_ordering(&self) -> Vec<Option<Vec<PhysicalSortRequirement>>> { + vec![None; self.children().len()] + } + + /// Returns `false` if this `ExecutionPlan`'s implementation may reorder + /// rows within or between partitions. + /// + /// For example, Projection, Filter, and Limit maintain the order + /// of inputs -- they may transform values (Projection) or not + /// produce the same number of rows that went in (Filter and + /// Limit), but the rows that are produced are emitted in the same + /// order. + /// + /// DataFusion uses this metadata to apply certain optimizations + /// such as automatically repartitioning correctly. + /// + /// The default implementation returns `false` + /// + /// WARNING: if you override this default, you *MUST* ensure that + /// the `ExecutionPlan` maintains the ordering invariant or else + /// DataFusion may produce incorrect results. + fn maintains_input_order(&self) -> Vec<bool> { + vec![false; self.children().len()] + } + + /// Specifies whether the `ExecutionPlan` benefits from increased + /// parallelization at its input for each child. + /// + /// If this returns `true`, the `ExecutionPlan` would benefit from partitioning + /// its corresponding child (and thus from more parallelism). For + /// `ExecutionPlan`s that do very little work, the overhead of extra + /// parallelism may outweigh any benefits + /// + /// The default implementation returns `true` unless this `ExecutionPlan` + /// has signalled it requires a single child input partition. + fn benefits_from_input_partitioning(&self) -> Vec<bool> { + // By default try to maximize parallelism with more CPUs if + // possible + self.required_input_distribution() + .into_iter() + .map(|dist| !matches!(dist, Distribution::SinglePartition)) + .collect() + } + + /// Get a list of children `ExecutionPlan`s that act as inputs to this plan. + /// The returned list will be empty for leaf nodes such as scans, will contain + /// a single value for unary nodes, or two values for binary nodes (such as + /// joins). + fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>>; + + /// Returns a new `ExecutionPlan` where all existing children were replaced + /// by the `children`, in order + fn with_new_children( + self: Arc<Self>, + children: Vec<Arc<dyn ExecutionPlan>>, + ) -> Result<Arc<dyn ExecutionPlan>>; + + /// If supported, attempt to increase the partitioning of this `ExecutionPlan` to + /// produce `target_partitions` partitions. + /// + /// If the `ExecutionPlan` does not support changing its partitioning, + /// returns `Ok(None)` (the default). + /// + /// If the `ExecutionPlan` can increase its partitioning, but not to the + /// `target_partitions`, it may return an `ExecutionPlan` with fewer + /// partitions. This might happen, for example, if each new partition would + /// be too small to be efficiently processed individually. + /// + /// The DataFusion optimizer attempts to use as many threads as possible by + /// repartitioning its inputs to match the target number of threads + /// available (`target_partitions`). Some data sources, such as the built in + /// CSV and Parquet readers, implement this method as they are able to read + /// from their input files in parallel, regardless of how the source data is + /// split amongst files. + fn repartitioned( + &self, + _target_partitions: usize, + _config: &ConfigOptions, + ) -> Result<Option<Arc<dyn ExecutionPlan>>> { + Ok(None) + }
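As a sketch of how an optimizer pass might consume `repartitioned` (the helper name and fallback behavior here are illustrative, not from this PR):

```rust
use std::sync::Arc;

use datafusion_common::config::ConfigOptions;
use datafusion_common::Result;
use datafusion_physical_plan::ExecutionPlan;

// Hypothetical helper: ask a node for more output partitions, keeping the
// original plan when the node cannot change its partitioning (Ok(None)).
fn try_repartition(
    plan: Arc<dyn ExecutionPlan>,
    target_partitions: usize,
    config: &ConfigOptions,
) -> Result<Arc<dyn ExecutionPlan>> {
    match plan.repartitioned(target_partitions, config)? {
        Some(repartitioned) => Ok(repartitioned),
        None => Ok(plan),
    }
}
```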
+ + /// Begin execution of `partition`, returning a [`Stream`] of + /// [`RecordBatch`]es. + /// + /// # Notes + /// + /// The `execute` method itself is not `async` but it returns an `async` + /// [`futures::stream::Stream`]. This `Stream` should incrementally compute + /// the output, `RecordBatch` by `RecordBatch` (in a streaming fashion). + /// Most `ExecutionPlan`s should not do any work before the first + /// `RecordBatch` is requested from the stream. + /// + /// [`RecordBatchStreamAdapter`] can be used to convert an `async` + /// [`Stream`] into a [`SendableRecordBatchStream`]. + /// + /// Using `async` `Streams` allows for network I/O during execution and + /// takes advantage of Rust's built-in support for `async` continuations and + /// its crate ecosystem. + /// + /// [`Stream`]: futures::stream::Stream + /// [`StreamExt`]: futures::stream::StreamExt + /// [`TryStreamExt`]: futures::stream::TryStreamExt + /// [`RecordBatchStreamAdapter`]: crate::stream::RecordBatchStreamAdapter + /// + /// # Cancellation / Aborting Execution + /// + /// The [`Stream`] that is returned must ensure that any allocated resources + /// are freed when the stream itself is dropped. This is particularly + /// important for [`spawn`]ed tasks or threads. Unless care is taken to + /// "abort" such tasks, they may continue to consume resources even after + /// the plan is dropped, generating intermediate results that are never + /// used. + /// Thus, [`spawn`] is disallowed; use [`SpawnedTask`] instead. + /// + /// For more details see [`SpawnedTask`], [`JoinSet`] and [`RecordBatchReceiverStreamBuilder`] + /// for structures to help ensure all background tasks are cancelled. + /// + /// [`spawn`]: tokio::task::spawn + /// [`JoinSet`]: tokio::task::JoinSet + /// [`SpawnedTask`]: datafusion_common_runtime::SpawnedTask + /// [`RecordBatchReceiverStreamBuilder`]: crate::stream::RecordBatchReceiverStreamBuilder + /// + /// # Implementation Examples + /// + /// While `async` `Stream`s have a non-trivial learning curve, the + /// [`futures`] crate provides [`StreamExt`] and [`TryStreamExt`] + /// which help simplify many common operations. + /// + /// Here are some common patterns: + /// + /// ## Return Precomputed `RecordBatch` + /// + /// We can return a precomputed `RecordBatch` as a `Stream`: + /// + /// ``` + /// # use std::sync::Arc; + /// # use arrow_array::RecordBatch; + /// # use arrow_schema::SchemaRef; + /// # use datafusion_common::Result; + /// # use datafusion_execution::{SendableRecordBatchStream, TaskContext}; + /// # use datafusion_physical_plan::memory::MemoryStream; + /// # use datafusion_physical_plan::stream::RecordBatchStreamAdapter; + /// struct MyPlan { + /// batch: RecordBatch, + /// } + /// + /// impl MyPlan { + /// fn execute( + /// &self, + /// partition: usize, + /// context: Arc<TaskContext> + /// ) -> Result<SendableRecordBatchStream> { + /// // use functions from the futures crate to convert the batch into a stream + /// let fut = futures::future::ready(Ok(self.batch.clone())); + /// let stream = futures::stream::once(fut); + /// Ok(Box::pin(RecordBatchStreamAdapter::new(self.batch.schema(), stream))) + /// } + /// } + /// ``` + /// + /// ## Lazily (async) Compute `RecordBatch` + /// + /// We can also lazily compute a `RecordBatch` when the returned `Stream` is polled + /// + /// ``` + /// # use std::sync::Arc; + /// # use arrow_array::RecordBatch; + /// # use arrow_schema::SchemaRef; + /// # use datafusion_common::Result; + /// # use datafusion_execution::{SendableRecordBatchStream, TaskContext}; + /// # use datafusion_physical_plan::memory::MemoryStream; + /// # use datafusion_physical_plan::stream::RecordBatchStreamAdapter; + /// struct MyPlan { + /// schema: SchemaRef, + /// } + /// + /// /// Returns a single batch when the returned stream is polled + /// async fn get_batch() -> Result<RecordBatch> { + /// todo!() + /// } + /// + /// impl MyPlan { + /// fn execute( + /// &self, + /// partition: usize, + /// context: Arc<TaskContext> + /// ) -> Result<SendableRecordBatchStream> { + /// let fut = get_batch(); + /// let stream = futures::stream::once(fut); + /// Ok(Box::pin(RecordBatchStreamAdapter::new(self.schema.clone(), stream))) + /// } + /// } + /// ``` + /// + /// ## Lazily (async) create a Stream + /// + /// If you need to create the return `Stream` using an `async` function, + /// you can do so by flattening the result: + /// + /// ``` + /// # use std::sync::Arc; + /// # use arrow_array::RecordBatch; + /// # use arrow_schema::SchemaRef; + /// # use futures::TryStreamExt; + /// # use datafusion_common::Result; + /// # use datafusion_execution::{SendableRecordBatchStream, TaskContext}; + /// # use datafusion_physical_plan::memory::MemoryStream; + /// # use datafusion_physical_plan::stream::RecordBatchStreamAdapter; + /// struct MyPlan { + /// schema: SchemaRef, + /// } + /// + /// /// async function that returns a stream + /// async fn get_batch_stream() -> Result<SendableRecordBatchStream> { + /// todo!() + /// } + /// + /// impl MyPlan { + /// fn execute( + /// &self, + /// partition: usize, + /// context: Arc<TaskContext> + /// ) -> Result<SendableRecordBatchStream> { + /// // A future that yields a stream + /// let fut = get_batch_stream(); + /// // Use TryStreamExt::try_flatten to flatten the stream of streams + /// let stream = futures::stream::once(fut).try_flatten(); + /// Ok(Box::pin(RecordBatchStreamAdapter::new(self.schema.clone(), stream))) + /// } + /// } + /// ```
+ fn execute( + &self, + partition: usize, + context: Arc<TaskContext>, + ) -> Result<SendableRecordBatchStream>; + + /// Return a snapshot of the set of [`Metric`]s for this + /// [`ExecutionPlan`]. If no `Metric`s are available, return `None`. + /// + /// While the values of the metrics in the returned + /// [`MetricsSet`]s may change as execution progresses, the + /// specific metrics will not. + /// + /// Once `self.execute()` has returned (technically the future is + /// resolved) for all available partitions, the set of metrics + /// should be complete. If this function is called prior to + /// `execute()`, new metrics may appear in subsequent calls. + fn metrics(&self) -> Option<MetricsSet> { + None + } + + /// Returns statistics for this `ExecutionPlan` node. If statistics are not + /// available, should return [`Statistics::new_unknown`] (the default), not + /// an error. + fn statistics(&self) -> Result<Statistics> { + Ok(Statistics::new_unknown(&self.schema())) + } + + /// Returns `true` if a limit can be safely pushed down through this + /// `ExecutionPlan` node. + /// + /// If this method returns `true`, and the query plan contains a limit at + /// the output of this node, DataFusion will push the limit to the input + /// of this node. + fn supports_limit_pushdown(&self) -> bool { + false + } + + /// Returns a fetching variant of this `ExecutionPlan` node, if it supports + /// fetch limits. Returns `None` otherwise. + fn with_fetch(&self, _limit: Option<usize>) -> Option<Arc<dyn ExecutionPlan>> { + None + } +}
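The extension trait defined just below makes these cached properties reachable directly from a plan handle; a small sketch, assuming an already-built `plan` (the `describe` helper is hypothetical):

```rust
use std::sync::Arc;

use datafusion_physical_plan::{ExecutionPlan, ExecutionPlanProperties};

// Reads a few cached properties through the extension trait
fn describe(plan: &Arc<dyn ExecutionPlan>) {
    println!(
        "partitions: {}",
        plan.output_partitioning().partition_count()
    );
    println!("sorted output: {}", plan.output_ordering().is_some());
    println!("unbounded: {}", plan.execution_mode().is_unbounded());
}
```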
+ +/// An extension trait that provides an easy API to fetch various properties of +/// [`ExecutionPlan`] objects based on [`ExecutionPlan::properties`]. +pub trait ExecutionPlanProperties { + /// Specifies how the output of this `ExecutionPlan` is split into + /// partitions. + fn output_partitioning(&self) -> &Partitioning; + + /// Specifies whether this plan generates an infinite stream of records. + /// If the plan does not support pipelining, but its input(s) are + /// infinite, returns [`ExecutionMode::PipelineBreaking`] to indicate this. + fn execution_mode(&self) -> ExecutionMode; + + /// If the output of this `ExecutionPlan` within each partition is sorted, + /// returns `Some(keys)` describing the ordering. A `None` return value + /// indicates no assumptions should be made on the output ordering. + /// + /// For example, `SortExec` (obviously) produces sorted output as does + /// `SortPreservingMergeStream`. Less obviously, `Projection` produces sorted + /// output if its input is sorted as it does not reorder the input rows. + fn output_ordering(&self) -> Option<&[PhysicalSortExpr]>; + + /// Get the [`EquivalenceProperties`] within the plan. + /// + /// Equivalence properties tell DataFusion what columns are known to be + /// equal during various optimization passes. By default, this returns "no + /// known equivalences" which is always correct, but may cause DataFusion to + /// unnecessarily resort data. + /// + /// If this ExecutionPlan makes no changes to the schema of the rows flowing + /// through it or how columns within each row relate to each other, it + /// should return the equivalence properties of its input. For + /// example, since `FilterExec` may remove rows from its input, but does not + /// otherwise modify them, it preserves its input equivalence properties. + /// However, since `ProjectionExec` may calculate derived expressions, it + /// needs special handling. + /// + /// See also [`ExecutionPlan::maintains_input_order`] and [`Self::output_ordering`] + /// for related concepts. + fn equivalence_properties(&self) -> &EquivalenceProperties; +} + +impl ExecutionPlanProperties for Arc<dyn ExecutionPlan> { + fn output_partitioning(&self) -> &Partitioning { + self.properties().output_partitioning() + } + + fn execution_mode(&self) -> ExecutionMode { + self.properties().execution_mode() + } + + fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { + self.properties().output_ordering() + } + + fn equivalence_properties(&self) -> &EquivalenceProperties { + self.properties().equivalence_properties() + } +} + +impl ExecutionPlanProperties for &dyn ExecutionPlan { + fn output_partitioning(&self) -> &Partitioning { + self.properties().output_partitioning() + } + + fn execution_mode(&self) -> ExecutionMode { + self.properties().execution_mode() + } + + fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { + self.properties().output_ordering() + } + + fn equivalence_properties(&self) -> &EquivalenceProperties { + self.properties().equivalence_properties() + } +} + +/// Describes the execution mode of an operator's resulting stream with respect +/// to its size and behavior. There are three possible execution modes: `Bounded`, +/// `Unbounded` and `PipelineBreaking`. +#[derive(Clone, Copy, PartialEq, Debug)] +pub enum ExecutionMode { + /// Represents the mode where the generated stream is bounded, i.e. finite. + Bounded, + /// Represents the mode where the generated stream is unbounded, i.e. infinite. + /// Even though the operator generates an unbounded stream of results, it + /// works with bounded memory and execution can still continue successfully. + /// + /// The stream that results from calling `execute` on an `ExecutionPlan` that is `Unbounded` + /// will never be done (return `None`), except in case of error. + Unbounded, + /// Represents the mode where some of the operator's input stream(s) are + /// unbounded; however, the operator cannot generate streaming results from + /// these streaming inputs. In this case, the execution mode will be pipeline + /// breaking, i.e. the operator requires unbounded memory to generate results. + PipelineBreaking, +} + +impl ExecutionMode { + /// Check whether the execution mode is unbounded or not. + pub fn is_unbounded(&self) -> bool { + matches!(self, ExecutionMode::Unbounded) + } + + /// Check whether the execution is pipeline friendly. If so, the operator can + /// execute safely. + pub fn pipeline_friendly(&self) -> bool { + matches!(self, ExecutionMode::Bounded | ExecutionMode::Unbounded) + } +} + +/// Conservatively "combines" execution modes of a given collection of operators. +pub(crate) fn execution_mode_from_children<'a>( + children: impl IntoIterator<Item = &'a Arc<dyn ExecutionPlan>>, +) -> ExecutionMode { + let mut result = ExecutionMode::Bounded; + for mode in children.into_iter().map(|child| child.execution_mode()) { + match (mode, result) { + (ExecutionMode::PipelineBreaking, _) + | (_, ExecutionMode::PipelineBreaking) => { + // If any of the modes is `PipelineBreaking`, so is the result: + return ExecutionMode::PipelineBreaking; + } + (ExecutionMode::Unbounded, _) | (_, ExecutionMode::Unbounded) => { + // Unbounded mode eats up bounded mode: + result = ExecutionMode::Unbounded; + } + (ExecutionMode::Bounded, ExecutionMode::Bounded) => { + // When both modes are bounded, so is the result: + result = ExecutionMode::Bounded; + } + } + } + result +}
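The combination rule is a fold in which `PipelineBreaking` absorbs everything and `Unbounded` absorbs `Bounded`. A self-contained model of the same lattice (re-declared locally for illustration, since the real function is crate-private):

```rust
#[derive(Clone, Copy, PartialEq, Debug)]
enum Mode {
    Bounded,
    Unbounded,
    PipelineBreaking,
}

// Same combination rule as execution_mode_from_children, over plain modes
fn combine(children: impl IntoIterator<Item = Mode>) -> Mode {
    let mut result = Mode::Bounded;
    for mode in children {
        match (mode, result) {
            (Mode::PipelineBreaking, _) | (_, Mode::PipelineBreaking) => {
                return Mode::PipelineBreaking
            }
            (Mode::Unbounded, _) | (_, Mode::Unbounded) => result = Mode::Unbounded,
            (Mode::Bounded, Mode::Bounded) => result = Mode::Bounded,
        }
    }
    result
}

fn main() {
    assert_eq!(combine([Mode::Bounded, Mode::Bounded]), Mode::Bounded);
    assert_eq!(combine([Mode::Bounded, Mode::Unbounded]), Mode::Unbounded);
    assert_eq!(
        combine([Mode::Unbounded, Mode::PipelineBreaking]),
        Mode::PipelineBreaking
    );
}
```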
+ +/// Stores certain, often expensive to compute, plan properties used in query +/// optimization. +/// +/// These properties are stored in a single structure to permit this information to +/// be computed once and then those cached results used multiple times without +/// recomputation (aka a cache) +#[derive(Debug, Clone)] +pub struct PlanProperties { + /// See [ExecutionPlanProperties::equivalence_properties] + pub eq_properties: EquivalenceProperties, + /// See [ExecutionPlanProperties::output_partitioning] + pub partitioning: Partitioning, + /// See [ExecutionPlanProperties::execution_mode] + pub execution_mode: ExecutionMode, + /// See [ExecutionPlanProperties::output_ordering] + output_ordering: Option<LexOrdering>, +} + +impl PlanProperties { + /// Construct a new `PlanProperties` from the given properties. + pub fn new( + eq_properties: EquivalenceProperties, + partitioning: Partitioning, + execution_mode: ExecutionMode, + ) -> Self { + // Output ordering can be derived from `eq_properties`. + let output_ordering = eq_properties.output_ordering(); + Self { + eq_properties, + partitioning, + execution_mode, + output_ordering, + } + } + + /// Overwrite output partitioning with its new value. + pub fn with_partitioning(mut self, partitioning: Partitioning) -> Self { + self.partitioning = partitioning; + self + } + + /// Overwrite the execution mode with its new value. + pub fn with_execution_mode(mut self, execution_mode: ExecutionMode) -> Self { + self.execution_mode = execution_mode; + self + } + + /// Overwrite equivalence properties with its new value. + pub fn with_eq_properties(mut self, eq_properties: EquivalenceProperties) -> Self { + // Changing equivalence properties also changes output ordering, so + // make sure to overwrite it: + self.output_ordering = eq_properties.output_ordering(); + self.eq_properties = eq_properties; + self + } + + pub fn equivalence_properties(&self) -> &EquivalenceProperties { + &self.eq_properties + } + + pub fn output_partitioning(&self) -> &Partitioning { + &self.partitioning + } + + pub fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { + self.output_ordering.as_deref() + } + + pub fn execution_mode(&self) -> ExecutionMode { + self.execution_mode + } + + /// Get schema of the node. + fn schema(&self) -> &SchemaRef { + self.eq_properties.schema() + } +}
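A sketch of how a hypothetical single-partition, bounded leaf operator might build its cached properties with this API (the helper name is illustrative; the crate paths follow this PR's re-exports):

```rust
use arrow_schema::SchemaRef;
use datafusion_physical_expr::EquivalenceProperties;
use datafusion_physical_plan::{ExecutionMode, Partitioning, PlanProperties};

// Properties for a hypothetical single-partition, bounded leaf operator
fn leaf_plan_properties(schema: SchemaRef) -> PlanProperties {
    PlanProperties::new(
        EquivalenceProperties::new(schema),   // no known equivalences
        Partitioning::UnknownPartitioning(1), // one output partition
        ExecutionMode::Bounded,               // finite input, finite output
    )
}
```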
+ +/// Indicate whether a data exchange is needed for the input of `plan`, which is +/// especially helpful for a distributed engine to judge whether it needs to deal +/// with shuffling. +/// Currently there are 3 kinds of execution plan which need data exchange: +/// 1. RepartitionExec for changing the partition number between two `ExecutionPlan`s +/// 2. CoalescePartitionsExec for collapsing all of the partitions into one without ordering guarantee +/// 3. SortPreservingMergeExec for collapsing all of the sorted partitions into one with ordering guarantee +pub fn need_data_exchange(plan: Arc<dyn ExecutionPlan>) -> bool { + if let Some(repartition) = plan.as_any().downcast_ref::<RepartitionExec>() { + !matches!( + repartition.properties().output_partitioning(), + Partitioning::RoundRobinBatch(_) + ) + } else if let Some(coalesce) = plan.as_any().downcast_ref::<CoalescePartitionsExec>() + { + coalesce.input().output_partitioning().partition_count() > 1 + } else if let Some(sort_preserving_merge) = + plan.as_any().downcast_ref::<SortPreservingMergeExec>() + { + sort_preserving_merge + .input() + .output_partitioning() + .partition_count() + > 1 + } else { + false + } +} + +/// Returns a copy of this plan if we change any child according to the pointer comparison. +/// The size of `children` must be equal to the size of `ExecutionPlan::children()`. +pub fn with_new_children_if_necessary( + plan: Arc<dyn ExecutionPlan>, + children: Vec<Arc<dyn ExecutionPlan>>, +) -> Result<Arc<dyn ExecutionPlan>> { + let old_children = plan.children(); + if children.len() != old_children.len() { + internal_err!("Wrong number of children") + } else if children.is_empty() + || children + .iter() + .zip(old_children.iter()) + .any(|(c1, c2)| !Arc::ptr_eq(c1, c2)) + { + plan.with_new_children(children) + } else { + Ok(plan) + } +} + +/// Return a [wrapper](DisplayableExecutionPlan) around an +/// [`ExecutionPlan`] which can be displayed in various easier to +/// understand ways. +pub fn displayable(plan: &dyn ExecutionPlan) -> DisplayableExecutionPlan<'_> { + DisplayableExecutionPlan::new(plan) +} + +/// Execute the [ExecutionPlan] and collect the results in memory +pub async fn collect( + plan: Arc<dyn ExecutionPlan>, + context: Arc<TaskContext>, +) -> Result<Vec<RecordBatch>> { + let stream = execute_stream(plan, context)?; + crate::common::collect(stream).await +} + +/// Execute the [ExecutionPlan] and return a single stream of `RecordBatch`es. +/// +/// See [collect] to buffer the `RecordBatch`es in memory. +/// +/// # Aborting Execution +/// +/// Dropping the stream will abort the execution of the query, and free up +/// any allocated resources +pub fn execute_stream( + plan: Arc<dyn ExecutionPlan>, + context: Arc<TaskContext>, +) -> Result<SendableRecordBatchStream> { + match plan.output_partitioning().partition_count() { + 0 => Ok(Box::pin(EmptyRecordBatchStream::new(plan.schema()))), + 1 => plan.execute(0, context), + _ => { + // merge into a single partition + let plan = CoalescePartitionsExec::new(Arc::clone(&plan)); + // CoalescePartitionsExec must produce a single partition + assert_eq!(1, plan.properties().output_partitioning().partition_count()); + plan.execute(0, context) + } + } +} + +/// Execute the [ExecutionPlan] and collect the results in memory +pub async fn collect_partitioned( + plan: Arc<dyn ExecutionPlan>, + context: Arc<TaskContext>, +) -> Result<Vec<Vec<RecordBatch>>> { + let streams = execute_stream_partitioned(plan, context)?; + + let mut join_set = JoinSet::new(); + // Execute the plan and collect the results into batches. + streams.into_iter().enumerate().for_each(|(idx, stream)| { + join_set.spawn(async move { + let result: Result<Vec<RecordBatch>> = stream.try_collect().await; + (idx, result) + }); + }); + + let mut batches = vec![]; + // Note that currently this doesn't identify the thread that panicked + // + // TODO: Replace with [join_next_with_id](https://docs.rs/tokio/latest/tokio/task/struct.JoinSet.html#method.join_next_with_id) + // once it is stable + while let Some(result) = join_set.join_next().await { + match result { + Ok((idx, res)) => batches.push((idx, res?)), + Err(e) => { + if e.is_panic() { + std::panic::resume_unwind(e.into_panic()); + } else { + unreachable!(); + } + } + } + } + + batches.sort_by_key(|(idx, _)| *idx); + let batches = batches.into_iter().map(|(_, batch)| batch).collect(); + + Ok(batches) +} + +/// Execute the [ExecutionPlan] and return a vec with one stream per output +/// partition +/// +/// # Aborting Execution +/// +/// Dropping the stream will abort the execution of the query, and free up +/// any allocated resources +pub fn execute_stream_partitioned( + plan: Arc<dyn ExecutionPlan>, + context: Arc<TaskContext>, +) -> Result<Vec<SendableRecordBatchStream>> { + let num_partitions = plan.output_partitioning().partition_count(); + let mut streams = Vec::with_capacity(num_partitions); + for i in 0..num_partitions { + streams.push(plan.execute(i, Arc::clone(&context))?); + } + Ok(streams) +}
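For callers, the typical entry points are `collect` (buffer everything) or `execute_stream` (incremental). A minimal sketch, assuming an already-built `plan` (the `run_plan` wrapper is hypothetical):

```rust
use std::sync::Arc;

use arrow::record_batch::RecordBatch;
use datafusion_common::Result;
use datafusion_execution::TaskContext;
use datafusion_physical_plan::{collect, ExecutionPlan};

// Buffers all partitions of `plan` into memory; `collect` merges multiple
// partitions through a CoalescePartitionsExec before draining the stream.
async fn run_plan(plan: Arc<dyn ExecutionPlan>) -> Result<Vec<RecordBatch>> {
    let ctx = Arc::new(TaskContext::default());
    collect(plan, ctx).await
}
```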
+ +/// Executes an input stream and ensures that the resulting stream adheres to +/// the `not null` constraints specified in the `sink_schema`. +/// +/// # Arguments +/// +/// * `input` - An execution plan +/// * `sink_schema` - The schema to be applied to the output stream +/// * `partition` - The partition index to be executed +/// * `context` - The task context +/// +/// # Returns +/// +/// * `Result<SendableRecordBatchStream>` - A stream of `RecordBatch`es if successful +/// +/// This function first executes the given input plan for the specified partition +/// and context. It then checks if there are any columns in the input that might +/// violate the `not null` constraints specified in the `sink_schema`. If there are +/// such columns, it wraps the resulting stream to enforce the `not null` constraints +/// by invoking the `check_not_null_contraits` function on each batch of the stream. +pub fn execute_input_stream( + input: Arc<dyn ExecutionPlan>, + sink_schema: SchemaRef, + partition: usize, + context: Arc<TaskContext>, +) -> Result<SendableRecordBatchStream> { + let input_stream = input.execute(partition, context)?; + + debug_assert_eq!(sink_schema.fields().len(), input.schema().fields().len()); + + // Find input columns that may violate the not null constraint. + let risky_columns: Vec<_> = sink_schema + .fields() + .iter() + .zip(input.schema().fields().iter()) + .enumerate() + .filter_map(|(idx, (sink_field, input_field))| { + (!sink_field.is_nullable() && input_field.is_nullable()).then_some(idx) + }) + .collect(); + + if risky_columns.is_empty() { + Ok(input_stream) + } else { + // Check not null constraint on the input stream + Ok(Box::pin(RecordBatchStreamAdapter::new( + sink_schema, + input_stream + .map(move |batch| check_not_null_contraits(batch?, &risky_columns)), + ))) + } +} + +/// Checks a `RecordBatch` for `not null` constraints on specified columns. +/// +/// # Arguments +/// +/// * `batch` - The `RecordBatch` to be checked +/// * `column_indices` - A vector of column indices that should be checked for +/// `not null` constraints. +/// +/// # Returns +/// +/// * `Result<RecordBatch>` - The original `RecordBatch` if all constraints are met +/// +/// This function iterates over the specified column indices and ensures that none +/// of the columns contain null values. If any column contains null values, an error +/// is returned. +pub fn check_not_null_contraits( + batch: RecordBatch, + column_indices: &Vec<usize>, +) -> Result<RecordBatch> { + for &index in column_indices { + if batch.num_columns() <= index { + return exec_err!( + "Invalid batch column count {} expected > {}", + batch.num_columns(), + index + ); + } + + if batch.column(index).null_count() > 0 { + return exec_err!( + "Invalid batch column at '{}' has null but schema specifies non-nullable", + index + ); + } + } + + Ok(batch) +} + +/// Utility function yielding a string representation of the given [`ExecutionPlan`].
+pub fn get_plan_string(plan: &Arc<dyn ExecutionPlan>) -> Vec<String> { + let formatted = displayable(plan.as_ref()).indent(true).to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + actual.iter().map(|elem| elem.to_string()).collect() +} + +#[cfg(test)] +mod tests { + use super::*; + use std::any::Any; + use std::sync::Arc; + + use arrow_schema::{Schema, SchemaRef}; + + use datafusion_common::{Result, Statistics}; + use datafusion_execution::{SendableRecordBatchStream, TaskContext}; + + use crate::{DisplayAs, DisplayFormatType, ExecutionPlan}; + + #[derive(Debug)] + pub struct EmptyExec; + + impl EmptyExec { + pub fn new(_schema: SchemaRef) -> Self { + Self + } + } + + impl DisplayAs for EmptyExec { + fn fmt_as( + &self, + _t: DisplayFormatType, + _f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + unimplemented!() + } + } + + impl ExecutionPlan for EmptyExec { + fn name(&self) -> &'static str { + Self::static_name() + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn properties(&self) -> &PlanProperties { + unimplemented!() + } + + fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> { + vec![] + } + + fn with_new_children( + self: Arc<Self>, + _: Vec<Arc<dyn ExecutionPlan>>, + ) -> Result<Arc<dyn ExecutionPlan>> { + unimplemented!() + } + + fn execute( + &self, + _partition: usize, + _context: Arc<TaskContext>, + ) -> Result<SendableRecordBatchStream> { + unimplemented!() + } + + fn statistics(&self) -> Result<Statistics> { + unimplemented!() + } + } + + #[derive(Debug)] + pub struct RenamedEmptyExec; + + impl RenamedEmptyExec { + pub fn new(_schema: SchemaRef) -> Self { + Self + } + } + + impl DisplayAs for RenamedEmptyExec { + fn fmt_as( + &self, + _t: DisplayFormatType, + _f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + unimplemented!() + } + } + + impl ExecutionPlan for RenamedEmptyExec { + fn name(&self) -> &'static str { + Self::static_name() + } + + fn static_name() -> &'static str + where + Self: Sized, + { + "MyRenamedEmptyExec" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn properties(&self) -> &PlanProperties { + unimplemented!() + } + + fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> { + vec![] + } + + fn with_new_children( + self: Arc<Self>, + _: Vec<Arc<dyn ExecutionPlan>>, + ) -> Result<Arc<dyn ExecutionPlan>> { + unimplemented!() + } + + fn execute( + &self, + _partition: usize, + _context: Arc<TaskContext>, + ) -> Result<SendableRecordBatchStream> { + unimplemented!() + } + + fn statistics(&self) -> Result<Statistics> { + unimplemented!() + } + } + + #[test] + fn test_execution_plan_name() { + let schema1 = Arc::new(Schema::empty()); + let default_name_exec = EmptyExec::new(schema1); + assert_eq!(default_name_exec.name(), "EmptyExec"); + + let schema2 = Arc::new(Schema::empty()); + let renamed_exec = RenamedEmptyExec::new(schema2); + assert_eq!(renamed_exec.name(), "MyRenamedEmptyExec"); + assert_eq!(RenamedEmptyExec::static_name(), "MyRenamedEmptyExec"); + } + + /// A compilation test to ensure that the `ExecutionPlan::name()` method can + /// be called from a trait object.
+ /// Related ticket: https://github.com/apache/datafusion/pull/11047 + #[allow(dead_code)] + fn use_execution_plan_as_trait_object(plan: &dyn ExecutionPlan) { + let _ = plan.name(); + } +} + +// pub mod test; diff --git a/datafusion/physical-plan/src/filter.rs b/datafusion/physical-plan/src/filter.rs index 67de0989649e..69bcfefcd476 100644 --- a/datafusion/physical-plan/src/filter.rs +++ b/datafusion/physical-plan/src/filter.rs @@ -380,11 +380,11 @@ impl Stream for FilterExecStream { Some(Ok(batch)) => { let timer = self.baseline_metrics.elapsed_compute().timer(); let filtered_batch = batch_filter(&batch, &self.predicate)?; + timer.done(); // skip entirely filtered batches if filtered_batch.num_rows() == 0 { continue; } - timer.done(); poll = Poll::Ready(Some(Ok(filtered_batch))); break; } diff --git a/datafusion/physical-plan/src/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs index 51744730a5a1..b8a58e4d0d30 100644 --- a/datafusion/physical-plan/src/joins/utils.rs +++ b/datafusion/physical-plan/src/joins/utils.rs @@ -66,7 +66,7 @@ use parking_lot::Mutex; /// E.g. 1 -> [3, 6, 8] indicates that the column values map to rows 3, 6 and 8 for hash value 1 /// As the key is a hash value, we need to check possible hash collisions in the probe stage /// During this stage it might be the case that a row is contained the same hashmap value, -/// but the values don't match. Those are checked in the [`equal_rows_arr`](crate::joins::hash_join::equal_rows_arr) method. +/// but the values don't match. Those are checked in the `equal_rows_arr` method. /// /// The indices (values) are stored in a separate chained list stored in the `Vec`. /// diff --git a/datafusion/physical-plan/src/lib.rs b/datafusion/physical-plan/src/lib.rs index 19554d07f7a0..eeecc017c2af 100644 --- a/datafusion/physical-plan/src/lib.rs +++ b/datafusion/physical-plan/src/lib.rs @@ -14,46 +14,36 @@ // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. + // Make cheap clones clear: https://github.com/apache/datafusion/issues/11143 + #![deny(clippy::clone_on_ref_ptr)] //! Traits for physical query plan, supporting parallel execution for partitioned relations. +//! +//! The entry point of this crate is the trait [`ExecutionPlan`].
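Because `lib.rs` re-exports everything it moved, downstream imports keep working. A small sketch showing that both paths name the same trait after the refactor (the alias names are illustrative):

```rust
// Both paths resolve to the same trait after this refactor
use datafusion_physical_plan::execution_plan::ExecutionPlan as ViaModule;
use datafusion_physical_plan::ExecutionPlan as ViaReexport;

fn same_trait(plan: &dyn ViaReexport) -> &dyn ViaModule {
    plan
}
```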
-use std::any::Any; -use std::fmt::Debug; -use std::sync::Arc; - -use arrow::datatypes::SchemaRef; -use arrow::record_batch::RecordBatch; -use futures::stream::{StreamExt, TryStreamExt}; -use tokio::task::JoinSet; - -use datafusion_common::config::ConfigOptions; pub use datafusion_common::hash_utils; pub use datafusion_common::utils::project_schema; -use datafusion_common::{exec_err, Result}; pub use datafusion_common::{internal_err, ColumnStatistics, Statistics}; -use datafusion_execution::TaskContext; pub use datafusion_execution::{RecordBatchStream, SendableRecordBatchStream}; pub use datafusion_expr::{Accumulator, ColumnarValue}; pub use datafusion_physical_expr::window::WindowExpr; +use datafusion_physical_expr::PhysicalSortExpr; pub use datafusion_physical_expr::{ expressions, functions, udf, AggregateExpr, Distribution, Partitioning, PhysicalExpr, }; -use datafusion_physical_expr::{ - EquivalenceProperties, LexOrdering, PhysicalSortExpr, PhysicalSortRequirement, -}; -use crate::coalesce_partitions::CoalescePartitionsExec; -use crate::display::DisplayableExecutionPlan; pub use crate::display::{DefaultDisplay, DisplayAs, DisplayFormatType, VerboseDisplay}; +pub(crate) use crate::execution_plan::execution_mode_from_children; +pub use crate::execution_plan::{ + collect, collect_partitioned, displayable, execute_input_stream, execute_stream, + execute_stream_partitioned, get_plan_string, with_new_children_if_necessary, + ExecutionMode, ExecutionPlan, ExecutionPlanProperties, PlanProperties, +}; pub use crate::metrics::Metric; -use crate::metrics::MetricsSet; pub use crate::ordering::InputOrderMode; -use crate::repartition::RepartitionExec; -use crate::sorts::sort_preserving_merge::SortPreservingMergeExec; pub use crate::stream::EmptyRecordBatchStream; -use crate::stream::RecordBatchStreamAdapter; pub use crate::topk::TopK; pub use crate::visitor::{accept, visit_execution_plan, ExecutionPlanVisitor}; @@ -68,6 +58,7 @@ pub mod coalesce_partitions; pub mod common; pub mod display; pub mod empty; +pub mod execution_plan; pub mod explain; pub mod filter; pub mod insert; @@ -96,967 +87,5 @@ pub mod udaf { }; } -/// Represent nodes in the DataFusion Physical Plan. -/// -/// Calling [`execute`] produces an `async` [`SendableRecordBatchStream`] of -/// [`RecordBatch`] that incrementally computes a partition of the -/// `ExecutionPlan`'s output from its input. See [`Partitioning`] for more -/// details on partitioning. -/// -/// Methods such as [`Self::schema`] and [`Self::properties`] communicate -/// properties of the output to the DataFusion optimizer, and methods such as -/// [`required_input_distribution`] and [`required_input_ordering`] express -/// requirements of the `ExecutionPlan` from its input. -/// -/// [`ExecutionPlan`] can be displayed in a simplified form using the -/// return value from [`displayable`] in addition to the (normally -/// quite verbose) `Debug` output. -/// -/// [`execute`]: ExecutionPlan::execute -/// [`required_input_distribution`]: ExecutionPlan::required_input_distribution -/// [`required_input_ordering`]: ExecutionPlan::required_input_ordering -pub trait ExecutionPlan: Debug + DisplayAs + Send + Sync { - /// Short name for the ExecutionPlan, such as 'ParquetExec'. - /// - /// Implementation note: this method can just proxy to - /// [`static_name`](ExecutionPlan::static_name) if no special action is - /// needed. 
It doesn't provide a default implementation like that because - /// this method doesn't require the `Sized` constrain to allow a wilder - /// range of use cases. - fn name(&self) -> &str; - - /// Short name for the ExecutionPlan, such as 'ParquetExec'. - /// Like [`name`](ExecutionPlan::name) but can be called without an instance. - fn static_name() -> &'static str - where - Self: Sized, - { - let full_name = std::any::type_name::(); - let maybe_start_idx = full_name.rfind(':'); - match maybe_start_idx { - Some(start_idx) => &full_name[start_idx + 1..], - None => "UNKNOWN", - } - } - - /// Returns the execution plan as [`Any`] so that it can be - /// downcast to a specific implementation. - fn as_any(&self) -> &dyn Any; - - /// Get the schema for this execution plan - fn schema(&self) -> SchemaRef { - Arc::clone(self.properties().schema()) - } - - /// Return properties of the output of the `ExecutionPlan`, such as output - /// ordering(s), partitioning information etc. - /// - /// This information is available via methods on [`ExecutionPlanProperties`] - /// trait, which is implemented for all `ExecutionPlan`s. - fn properties(&self) -> &PlanProperties; - - /// Specifies the data distribution requirements for all the - /// children for this `ExecutionPlan`, By default it's [[Distribution::UnspecifiedDistribution]] for each child, - fn required_input_distribution(&self) -> Vec { - vec![Distribution::UnspecifiedDistribution; self.children().len()] - } - - /// Specifies the ordering required for all of the children of this - /// `ExecutionPlan`. - /// - /// For each child, it's the local ordering requirement within - /// each partition rather than the global ordering - /// - /// NOTE that checking `!is_empty()` does **not** check for a - /// required input ordering. Instead, the correct check is that at - /// least one entry must be `Some` - fn required_input_ordering(&self) -> Vec>> { - vec![None; self.children().len()] - } - - /// Returns `false` if this `ExecutionPlan`'s implementation may reorder - /// rows within or between partitions. - /// - /// For example, Projection, Filter, and Limit maintain the order - /// of inputs -- they may transform values (Projection) or not - /// produce the same number of rows that went in (Filter and - /// Limit), but the rows that are produced go in the same way. - /// - /// DataFusion uses this metadata to apply certain optimizations - /// such as automatically repartitioning correctly. - /// - /// The default implementation returns `false` - /// - /// WARNING: if you override this default, you *MUST* ensure that - /// the `ExecutionPlan`'s maintains the ordering invariant or else - /// DataFusion may produce incorrect results. - fn maintains_input_order(&self) -> Vec { - vec![false; self.children().len()] - } - - /// Specifies whether the `ExecutionPlan` benefits from increased - /// parallelization at its input for each child. - /// - /// If returns `true`, the `ExecutionPlan` would benefit from partitioning - /// its corresponding child (and thus from more parallelism). For - /// `ExecutionPlan` that do very little work the overhead of extra - /// parallelism may outweigh any benefits - /// - /// The default implementation returns `true` unless this `ExecutionPlan` - /// has signalled it requires a single child input partition. 
- fn benefits_from_input_partitioning(&self) -> Vec { - // By default try to maximize parallelism with more CPUs if - // possible - self.required_input_distribution() - .into_iter() - .map(|dist| !matches!(dist, Distribution::SinglePartition)) - .collect() - } - - /// Get a list of children `ExecutionPlan`s that act as inputs to this plan. - /// The returned list will be empty for leaf nodes such as scans, will contain - /// a single value for unary nodes, or two values for binary nodes (such as - /// joins). - fn children(&self) -> Vec<&Arc>; - - /// Returns a new `ExecutionPlan` where all existing children were replaced - /// by the `children`, in order - fn with_new_children( - self: Arc, - children: Vec>, - ) -> Result>; - - /// If supported, attempt to increase the partitioning of this `ExecutionPlan` to - /// produce `target_partitions` partitions. - /// - /// If the `ExecutionPlan` does not support changing its partitioning, - /// returns `Ok(None)` (the default). - /// - /// It is the `ExecutionPlan` can increase its partitioning, but not to the - /// `target_partitions`, it may return an ExecutionPlan with fewer - /// partitions. This might happen, for example, if each new partition would - /// be too small to be efficiently processed individually. - /// - /// The DataFusion optimizer attempts to use as many threads as possible by - /// repartitioning its inputs to match the target number of threads - /// available (`target_partitions`). Some data sources, such as the built in - /// CSV and Parquet readers, implement this method as they are able to read - /// from their input files in parallel, regardless of how the source data is - /// split amongst files. - fn repartitioned( - &self, - _target_partitions: usize, - _config: &ConfigOptions, - ) -> Result>> { - Ok(None) - } - - /// Begin execution of `partition`, returning a [`Stream`] of - /// [`RecordBatch`]es. - /// - /// # Notes - /// - /// The `execute` method itself is not `async` but it returns an `async` - /// [`futures::stream::Stream`]. This `Stream` should incrementally compute - /// the output, `RecordBatch` by `RecordBatch` (in a streaming fashion). - /// Most `ExecutionPlan`s should not do any work before the first - /// `RecordBatch` is requested from the stream. - /// - /// [`RecordBatchStreamAdapter`] can be used to convert an `async` - /// [`Stream`] into a [`SendableRecordBatchStream`]. - /// - /// Using `async` `Streams` allows for network I/O during execution and - /// takes advantage of Rust's built in support for `async` continuations and - /// crate ecosystem. - /// - /// [`Stream`]: futures::stream::Stream - /// [`StreamExt`]: futures::stream::StreamExt - /// [`TryStreamExt`]: futures::stream::TryStreamExt - /// [`RecordBatchStreamAdapter`]: crate::stream::RecordBatchStreamAdapter - /// - /// # Cancellation / Aborting Execution - /// - /// The [`Stream`] that is returned must ensure that any allocated resources - /// are freed when the stream itself is dropped. This is particularly - /// important for [`spawn`]ed tasks or threads. Unless care is taken to - /// "abort" such tasks, they may continue to consume resources even after - /// the plan is dropped, generating intermediate results that are never - /// used. - /// Thus, [`spawn`] is disallowed, and instead use [`SpawnedTask`]. - /// - /// For more details see [`SpawnedTask`], [`JoinSet`] and [`RecordBatchReceiverStreamBuilder`] - /// for structures to help ensure all background tasks are cancelled. 
- /// - /// [`spawn`]: tokio::task::spawn - /// [`JoinSet`]: tokio::task::JoinSet - /// [`SpawnedTask`]: datafusion_common_runtime::SpawnedTask - /// [`RecordBatchReceiverStreamBuilder`]: crate::stream::RecordBatchReceiverStreamBuilder - /// - /// # Implementation Examples - /// - /// While `async` `Stream`s have a non trivial learning curve, the - /// [`futures`] crate provides [`StreamExt`] and [`TryStreamExt`] - /// which help simplify many common operations. - /// - /// Here are some common patterns: - /// - /// ## Return Precomputed `RecordBatch` - /// - /// We can return a precomputed `RecordBatch` as a `Stream`: - /// - /// ``` - /// # use std::sync::Arc; - /// # use arrow_array::RecordBatch; - /// # use arrow_schema::SchemaRef; - /// # use datafusion_common::Result; - /// # use datafusion_execution::{SendableRecordBatchStream, TaskContext}; - /// # use datafusion_physical_plan::memory::MemoryStream; - /// # use datafusion_physical_plan::stream::RecordBatchStreamAdapter; - /// struct MyPlan { - /// batch: RecordBatch, - /// } - /// - /// impl MyPlan { - /// fn execute( - /// &self, - /// partition: usize, - /// context: Arc - /// ) -> Result { - /// // use functions from futures crate convert the batch into a stream - /// let fut = futures::future::ready(Ok(self.batch.clone())); - /// let stream = futures::stream::once(fut); - /// Ok(Box::pin(RecordBatchStreamAdapter::new(self.batch.schema(), stream))) - /// } - /// } - /// ``` - /// - /// ## Lazily (async) Compute `RecordBatch` - /// - /// We can also lazily compute a `RecordBatch` when the returned `Stream` is polled - /// - /// ``` - /// # use std::sync::Arc; - /// # use arrow_array::RecordBatch; - /// # use arrow_schema::SchemaRef; - /// # use datafusion_common::Result; - /// # use datafusion_execution::{SendableRecordBatchStream, TaskContext}; - /// # use datafusion_physical_plan::memory::MemoryStream; - /// # use datafusion_physical_plan::stream::RecordBatchStreamAdapter; - /// struct MyPlan { - /// schema: SchemaRef, - /// } - /// - /// /// Returns a single batch when the returned stream is polled - /// async fn get_batch() -> Result { - /// todo!() - /// } - /// - /// impl MyPlan { - /// fn execute( - /// &self, - /// partition: usize, - /// context: Arc - /// ) -> Result { - /// let fut = get_batch(); - /// let stream = futures::stream::once(fut); - /// Ok(Box::pin(RecordBatchStreamAdapter::new(self.schema.clone(), stream))) - /// } - /// } - /// ``` - /// - /// ## Lazily (async) create a Stream - /// - /// If you need to create the return `Stream` using an `async` function, - /// you can do so by flattening the result: - /// - /// ``` - /// # use std::sync::Arc; - /// # use arrow_array::RecordBatch; - /// # use arrow_schema::SchemaRef; - /// # use futures::TryStreamExt; - /// # use datafusion_common::Result; - /// # use datafusion_execution::{SendableRecordBatchStream, TaskContext}; - /// # use datafusion_physical_plan::memory::MemoryStream; - /// # use datafusion_physical_plan::stream::RecordBatchStreamAdapter; - /// struct MyPlan { - /// schema: SchemaRef, - /// } - /// - /// /// async function that returns a stream - /// async fn get_batch_stream() -> Result { - /// todo!() - /// } - /// - /// impl MyPlan { - /// fn execute( - /// &self, - /// partition: usize, - /// context: Arc - /// ) -> Result { - /// // A future that yields a stream - /// let fut = get_batch_stream(); - /// // Use TryStreamExt::try_flatten to flatten the stream of streams - /// let stream = futures::stream::once(fut).try_flatten(); - /// 
Ok(Box::pin(RecordBatchStreamAdapter::new(self.schema.clone(), stream))) - /// } - /// } - /// ``` - fn execute( - &self, - partition: usize, - context: Arc, - ) -> Result; - - /// Return a snapshot of the set of [`Metric`]s for this - /// [`ExecutionPlan`]. If no `Metric`s are available, return None. - /// - /// While the values of the metrics in the returned - /// [`MetricsSet`]s may change as execution progresses, the - /// specific metrics will not. - /// - /// Once `self.execute()` has returned (technically the future is - /// resolved) for all available partitions, the set of metrics - /// should be complete. If this function is called prior to - /// `execute()` new metrics may appear in subsequent calls. - fn metrics(&self) -> Option { - None - } - - /// Returns statistics for this `ExecutionPlan` node. If statistics are not - /// available, should return [`Statistics::new_unknown`] (the default), not - /// an error. - fn statistics(&self) -> Result { - Ok(Statistics::new_unknown(&self.schema())) - } - - /// Returns `true` if a limit can be safely pushed down through this - /// `ExecutionPlan` node. - /// - /// If this method returns `true`, and the query plan contains a limit at - /// the output of this node, DataFusion will push the limit to the input - /// of this node. - fn supports_limit_pushdown(&self) -> bool { - false - } - - /// Returns a fetching variant of this `ExecutionPlan` node, if it supports - /// fetch limits. Returns `None` otherwise. - fn with_fetch(&self, _limit: Option) -> Option> { - None - } -} - -/// Extension trait provides an easy API to fetch various properties of -/// [`ExecutionPlan`] objects based on [`ExecutionPlan::properties`]. -pub trait ExecutionPlanProperties { - /// Specifies how the output of this `ExecutionPlan` is split into - /// partitions. - fn output_partitioning(&self) -> &Partitioning; - - /// Specifies whether this plan generates an infinite stream of records. - /// If the plan does not support pipelining, but its input(s) are - /// infinite, returns [`ExecutionMode::PipelineBreaking`] to indicate this. - fn execution_mode(&self) -> ExecutionMode; - - /// If the output of this `ExecutionPlan` within each partition is sorted, - /// returns `Some(keys)` describing the ordering. A `None` return value - /// indicates no assumptions should be made on the output ordering. - /// - /// For example, `SortExec` (obviously) produces sorted output as does - /// `SortPreservingMergeStream`. Less obviously, `Projection` produces sorted - /// output if its input is sorted as it does not reorder the input rows. - fn output_ordering(&self) -> Option<&[PhysicalSortExpr]>; - - /// Get the [`EquivalenceProperties`] within the plan. - /// - /// Equivalence properties tell DataFusion what columns are known to be - /// equal, during various optimization passes. By default, this returns "no - /// known equivalences" which is always correct, but may cause DataFusion to - /// unnecessarily resort data. - /// - /// If this ExecutionPlan makes no changes to the schema of the rows flowing - /// through it or how columns within each row relate to each other, it - /// should return the equivalence properties of its input. For - /// example, since `FilterExec` may remove rows from its input, but does not - /// otherwise modify them, it preserves its input equivalence properties. - /// However, since `ProjectionExec` may calculate derived expressions, it - /// needs special handling. 
-/// Extension trait that provides an easy API to fetch various properties of
-/// [`ExecutionPlan`] objects, based on [`ExecutionPlan::properties`].
-pub trait ExecutionPlanProperties {
-    /// Specifies how the output of this `ExecutionPlan` is split into
-    /// partitions.
-    fn output_partitioning(&self) -> &Partitioning;
-
-    /// Specifies whether this plan generates an infinite stream of records.
-    /// If the plan does not support pipelining, but its input(s) are
-    /// infinite, returns [`ExecutionMode::PipelineBreaking`] to indicate this.
-    fn execution_mode(&self) -> ExecutionMode;
-
-    /// If the output of this `ExecutionPlan` within each partition is sorted,
-    /// returns `Some(keys)` describing the ordering. A `None` return value
-    /// indicates no assumptions should be made on the output ordering.
-    ///
-    /// For example, `SortExec` (obviously) produces sorted output, as does
-    /// `SortPreservingMergeStream`. Less obviously, `Projection` produces sorted
-    /// output if its input is sorted, as it does not reorder the input rows.
-    fn output_ordering(&self) -> Option<&[PhysicalSortExpr]>;
-
-    /// Get the [`EquivalenceProperties`] within the plan.
-    ///
-    /// Equivalence properties tell DataFusion what columns are known to be
-    /// equal, during various optimization passes. By default, this returns "no
-    /// known equivalences", which is always correct, but may cause DataFusion to
-    /// unnecessarily resort data.
-    ///
-    /// If this `ExecutionPlan` makes no changes to the schema of the rows flowing
-    /// through it or to how columns within each row relate to each other, it
-    /// should return the equivalence properties of its input. For
-    /// example, since `FilterExec` may remove rows from its input, but does not
-    /// otherwise modify them, it preserves its input equivalence properties.
-    /// However, since `ProjectionExec` may calculate derived expressions, it
-    /// needs special handling.
-    ///
-    /// See also [`ExecutionPlan::maintains_input_order`] and [`Self::output_ordering`]
-    /// for related concepts.
-    fn equivalence_properties(&self) -> &EquivalenceProperties;
-}
-
-impl ExecutionPlanProperties for Arc<dyn ExecutionPlan> {
-    fn output_partitioning(&self) -> &Partitioning {
-        self.properties().output_partitioning()
-    }
-
-    fn execution_mode(&self) -> ExecutionMode {
-        self.properties().execution_mode()
-    }
-
-    fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> {
-        self.properties().output_ordering()
-    }
-
-    fn equivalence_properties(&self) -> &EquivalenceProperties {
-        self.properties().equivalence_properties()
-    }
-}
-
-impl ExecutionPlanProperties for &dyn ExecutionPlan {
-    fn output_partitioning(&self) -> &Partitioning {
-        self.properties().output_partitioning()
-    }
-
-    fn execution_mode(&self) -> ExecutionMode {
-        self.properties().execution_mode()
-    }
-
-    fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> {
-        self.properties().output_ordering()
-    }
-
-    fn equivalence_properties(&self) -> &EquivalenceProperties {
-        self.properties().equivalence_properties()
-    }
-}
-
-/// Describes the execution mode of an operator's resulting stream with respect
-/// to its size and behavior. There are three possible execution modes: `Bounded`,
-/// `Unbounded` and `PipelineBreaking`.
-#[derive(Clone, Copy, PartialEq, Debug)]
-pub enum ExecutionMode {
-    /// Represents the mode where the generated stream is bounded, i.e. finite.
-    Bounded,
-    /// Represents the mode where the generated stream is unbounded, i.e. infinite.
-    /// Even though the operator generates an unbounded stream of results, it
-    /// works with bounded memory and execution can still continue successfully.
-    ///
-    /// The stream that results from calling `execute` on an `ExecutionPlan` that is `Unbounded`
-    /// will never be done (return `None`), except in case of error.
-    Unbounded,
-    /// Represents the mode where some of the operator's input stream(s) are
-    /// unbounded; however, the operator cannot generate streaming results from
-    /// these streaming inputs. In this case, the execution mode will be pipeline
-    /// breaking, i.e. the operator requires unbounded memory to generate results.
-    PipelineBreaking,
-}
-
-impl ExecutionMode {
-    /// Check whether the execution mode is unbounded or not.
-    pub fn is_unbounded(&self) -> bool {
-        matches!(self, ExecutionMode::Unbounded)
-    }
-
-    /// Check whether the execution is pipeline friendly. If so, the operator can
-    /// execute safely.
-    pub fn pipeline_friendly(&self) -> bool {
-        matches!(self, ExecutionMode::Bounded | ExecutionMode::Unbounded)
-    }
-}
-
-/// Conservatively "combines" the execution modes of a given collection of operators.
-fn execution_mode_from_children<'a>(
-    children: impl IntoIterator<Item = &'a Arc<dyn ExecutionPlan>>,
-) -> ExecutionMode {
-    let mut result = ExecutionMode::Bounded;
-    for mode in children.into_iter().map(|child| child.execution_mode()) {
-        match (mode, result) {
-            (ExecutionMode::PipelineBreaking, _)
-            | (_, ExecutionMode::PipelineBreaking) => {
-                // If any of the modes is `PipelineBreaking`, so is the result:
-                return ExecutionMode::PipelineBreaking;
-            }
-            (ExecutionMode::Unbounded, _) | (_, ExecutionMode::Unbounded) => {
-                // Unbounded mode takes precedence over bounded mode:
-                result = ExecutionMode::Unbounded;
-            }
-            (ExecutionMode::Bounded, ExecutionMode::Bounded) => {
-                // When both modes are bounded, so is the result:
-                result = ExecutionMode::Bounded;
-            }
-        }
-    }
-    result
-}
-/// Stores certain, often expensive to compute, plan properties used in query
-/// optimization.
-///
-/// These properties are stored in a single structure to permit this information
-/// to be computed once and the cached results then used multiple times without
-/// recomputation (i.e. a cache).
-#[derive(Debug, Clone)]
-pub struct PlanProperties {
-    /// See [ExecutionPlanProperties::equivalence_properties]
-    pub eq_properties: EquivalenceProperties,
-    /// See [ExecutionPlanProperties::output_partitioning]
-    pub partitioning: Partitioning,
-    /// See [ExecutionPlanProperties::execution_mode]
-    pub execution_mode: ExecutionMode,
-    /// See [ExecutionPlanProperties::output_ordering]
-    output_ordering: Option<LexOrdering>,
-}
-
-impl PlanProperties {
-    /// Construct a new `PlanProperties` from the specified fields.
-    pub fn new(
-        eq_properties: EquivalenceProperties,
-        partitioning: Partitioning,
-        execution_mode: ExecutionMode,
-    ) -> Self {
-        // Output ordering can be derived from `eq_properties`.
-        let output_ordering = eq_properties.output_ordering();
-        Self {
-            eq_properties,
-            partitioning,
-            execution_mode,
-            output_ordering,
-        }
-    }
-
-    /// Overwrite output partitioning with a new value.
-    pub fn with_partitioning(mut self, partitioning: Partitioning) -> Self {
-        self.partitioning = partitioning;
-        self
-    }
-
-    /// Overwrite the execution mode with a new value.
-    pub fn with_execution_mode(mut self, execution_mode: ExecutionMode) -> Self {
-        self.execution_mode = execution_mode;
-        self
-    }
-
-    /// Overwrite equivalence properties with a new value.
-    pub fn with_eq_properties(mut self, eq_properties: EquivalenceProperties) -> Self {
-        // Changing equivalence properties also changes output ordering, so
-        // make sure to overwrite it:
-        self.output_ordering = eq_properties.output_ordering();
-        self.eq_properties = eq_properties;
-        self
-    }
-
-    pub fn equivalence_properties(&self) -> &EquivalenceProperties {
-        &self.eq_properties
-    }
-
-    pub fn output_partitioning(&self) -> &Partitioning {
-        &self.partitioning
-    }
-
-    pub fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> {
-        self.output_ordering.as_deref()
-    }
-
-    pub fn execution_mode(&self) -> ExecutionMode {
-        self.execution_mode
-    }
-
-    /// Get the schema of the node.
-    fn schema(&self) -> &SchemaRef {
-        self.eq_properties.schema()
-    }
-}
-
-/// Indicates whether a data exchange is needed for the input of `plan`. This is
-/// especially helpful for a distributed engine when judging whether it needs to
-/// deal with shuffling. Currently, there are three kinds of execution plans that
-/// need data exchange:
-/// 1. `RepartitionExec` for changing the partition number between two `ExecutionPlan`s
-/// 2. `CoalescePartitionsExec` for collapsing all of the partitions into one without an ordering guarantee
-/// 3. `SortPreservingMergeExec` for collapsing all of the sorted partitions into one with an ordering guarantee
-pub fn need_data_exchange(plan: Arc<dyn ExecutionPlan>) -> bool {
-    if let Some(repartition) = plan.as_any().downcast_ref::<RepartitionExec>() {
-        !matches!(
-            repartition.properties().output_partitioning(),
-            Partitioning::RoundRobinBatch(_)
-        )
-    } else if let Some(coalesce) = plan.as_any().downcast_ref::<CoalescePartitionsExec>()
-    {
-        coalesce.input().output_partitioning().partition_count() > 1
-    } else if let Some(sort_preserving_merge) =
-        plan.as_any().downcast_ref::<SortPreservingMergeExec>()
-    {
-        sort_preserving_merge
-            .input()
-            .output_partitioning()
-            .partition_count()
-            > 1
-    } else {
-        false
-    }
-}
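Since `PlanProperties::new` derives the output ordering from the equivalence properties, a custom operator typically computes this cache once in its constructor and hands out references from `ExecutionPlan::properties`. A minimal sketch under stated assumptions (the single-column schema, single output partition, and bounded mode are all illustrative, not taken from the code above):

```rust
use std::sync::Arc;
use arrow_schema::{DataType, Field, Schema};
use datafusion_physical_expr::{EquivalenceProperties, Partitioning};
use datafusion_physical_plan::{ExecutionMode, PlanProperties};

/// Build the property cache for a hypothetical bounded, single-partition
/// operator with schema `(a BIGINT)`.
fn compute_properties() -> PlanProperties {
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)]));
    PlanProperties::new(
        // No equivalences are known beyond the schema itself
        EquivalenceProperties::new(schema),
        // All output arrives in a single partition
        Partitioning::UnknownPartitioning(1),
        // The operator produces a finite stream
        ExecutionMode::Bounded,
    )
}
```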
-/// Returns a copy of this plan if we change any child according to the pointer
-/// comparison. The size of `children` must be equal to the size of
-/// `ExecutionPlan::children()`.
-pub fn with_new_children_if_necessary(
-    plan: Arc<dyn ExecutionPlan>,
-    children: Vec<Arc<dyn ExecutionPlan>>,
-) -> Result<Arc<dyn ExecutionPlan>> {
-    let old_children = plan.children();
-    if children.len() != old_children.len() {
-        internal_err!("Wrong number of children")
-    } else if children.is_empty()
-        || children
-            .iter()
-            .zip(old_children.iter())
-            .any(|(c1, c2)| !Arc::ptr_eq(c1, c2))
-    {
-        plan.with_new_children(children)
-    } else {
-        Ok(plan)
-    }
-}
-
-/// Return a [wrapper](DisplayableExecutionPlan) around an
-/// [`ExecutionPlan`] which can be displayed in various easier-to-understand
-/// ways.
-pub fn displayable(plan: &dyn ExecutionPlan) -> DisplayableExecutionPlan<'_> {
-    DisplayableExecutionPlan::new(plan)
-}
-
-/// Execute the [ExecutionPlan] and collect the results in memory
-pub async fn collect(
-    plan: Arc<dyn ExecutionPlan>,
-    context: Arc<TaskContext>,
-) -> Result<Vec<RecordBatch>> {
-    let stream = execute_stream(plan, context)?;
-    common::collect(stream).await
-}
-
-/// Execute the [ExecutionPlan] and return a single stream of `RecordBatch`es.
-///
-/// See [collect] to buffer the `RecordBatch`es in memory.
-///
-/// # Aborting Execution
-///
-/// Dropping the stream will abort the execution of the query, and free up
-/// any allocated resources
-pub fn execute_stream(
-    plan: Arc<dyn ExecutionPlan>,
-    context: Arc<TaskContext>,
-) -> Result<SendableRecordBatchStream> {
-    match plan.output_partitioning().partition_count() {
-        0 => Ok(Box::pin(EmptyRecordBatchStream::new(plan.schema()))),
-        1 => plan.execute(0, context),
-        _ => {
-            // merge into a single partition
-            let plan = CoalescePartitionsExec::new(Arc::clone(&plan));
-            // CoalescePartitionsExec must produce a single partition
-            assert_eq!(1, plan.properties().output_partitioning().partition_count());
-            plan.execute(0, context)
-        }
-    }
-}
-
-/// Execute the [ExecutionPlan] and collect the results in memory, keeping
-/// the output of each partition separate
-pub async fn collect_partitioned(
-    plan: Arc<dyn ExecutionPlan>,
-    context: Arc<TaskContext>,
-) -> Result<Vec<Vec<RecordBatch>>> {
-    let streams = execute_stream_partitioned(plan, context)?;
-
-    let mut join_set = JoinSet::new();
-    // Execute the plan and collect the results into batches.
-    streams.into_iter().enumerate().for_each(|(idx, stream)| {
-        join_set.spawn(async move {
-            let result: Result<Vec<RecordBatch>> = stream.try_collect().await;
-            (idx, result)
-        });
-    });
-
-    let mut batches = vec![];
-    // Note that currently this doesn't identify the thread that panicked
-    //
-    // TODO: Replace with [join_next_with_id](https://docs.rs/tokio/latest/tokio/task/struct.JoinSet.html#method.join_next_with_id)
-    // once it is stable
-    while let Some(result) = join_set.join_next().await {
-        match result {
-            Ok((idx, res)) => batches.push((idx, res?)),
-            Err(e) => {
-                if e.is_panic() {
-                    std::panic::resume_unwind(e.into_panic());
-                } else {
-                    unreachable!();
-                }
-            }
-        }
-    }
-
-    batches.sort_by_key(|(idx, _)| *idx);
-    let batches = batches.into_iter().map(|(_, batch)| batch).collect();
-
-    Ok(batches)
-}
-
-/// Execute the [ExecutionPlan] and return a vec with one stream per output
-/// partition
-///
-/// # Aborting Execution
-///
-/// Dropping the stream will abort the execution of the query, and free up
-/// any allocated resources
-pub fn execute_stream_partitioned(
-    plan: Arc<dyn ExecutionPlan>,
-    context: Arc<TaskContext>,
-) -> Result<Vec<SendableRecordBatchStream>> {
-    let num_partitions = plan.output_partitioning().partition_count();
-    let mut streams = Vec::with_capacity(num_partitions);
-    for i in 0..num_partitions {
-        streams.push(plan.execute(i, Arc::clone(&context))?);
-    }
-    Ok(streams)
-}
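For orientation, here is how the helpers above fit together from a caller's perspective; a usage sketch in which the default `TaskContext` stands in for whatever session state the caller actually has:

```rust
use std::sync::Arc;
use datafusion_common::Result;
use datafusion_execution::TaskContext;
use datafusion_physical_plan::{collect, execute_stream, ExecutionPlan};
use futures::StreamExt;

/// Drive a physical plan either eagerly or incrementally.
async fn run(plan: Arc<dyn ExecutionPlan>) -> Result<()> {
    let context = Arc::new(TaskContext::default());

    // Eager: buffer every output batch in memory.
    let batches = collect(Arc::clone(&plan), Arc::clone(&context)).await?;
    println!("collected {} batches", batches.len());

    // Incremental: pull batches one at a time. Dropping the stream early
    // aborts the query and frees its resources, per the docs above.
    let mut stream = execute_stream(plan, context)?;
    while let Some(batch) = stream.next().await {
        println!("got {} rows", batch?.num_rows());
    }
    Ok(())
}
```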
-/// Executes an input stream and ensures that the resulting stream adheres to
-/// the `not null` constraints specified in the `sink_schema`.
-///
-/// # Arguments
-///
-/// * `input` - An execution plan
-/// * `sink_schema` - The schema to be applied to the output stream
-/// * `partition` - The partition index to be executed
-/// * `context` - The task context
-///
-/// # Returns
-///
-/// * `Result<SendableRecordBatchStream>` - A stream of `RecordBatch`es if successful
-///
-/// This function first executes the given input plan for the specified partition
-/// and context. It then checks if there are any columns in the input that might
-/// violate the `not null` constraints specified in the `sink_schema`. If there are
-/// such columns, it wraps the resulting stream to enforce the `not null` constraints
-/// by invoking the `check_not_null_constraints` function on each batch of the stream.
-pub fn execute_input_stream(
-    input: Arc<dyn ExecutionPlan>,
-    sink_schema: SchemaRef,
-    partition: usize,
-    context: Arc<TaskContext>,
-) -> Result<SendableRecordBatchStream> {
-    let input_stream = input.execute(partition, context)?;
-
-    debug_assert_eq!(sink_schema.fields().len(), input.schema().fields().len());
-
-    // Find input columns that may violate the not null constraint.
-    let risky_columns: Vec<_> = sink_schema
-        .fields()
-        .iter()
-        .zip(input.schema().fields().iter())
-        .enumerate()
-        .filter_map(|(idx, (sink_field, input_field))| {
-            (!sink_field.is_nullable() && input_field.is_nullable()).then_some(idx)
-        })
-        .collect();
-
-    if risky_columns.is_empty() {
-        Ok(input_stream)
-    } else {
-        // Check not null constraint on the input stream
-        Ok(Box::pin(RecordBatchStreamAdapter::new(
-            sink_schema,
-            input_stream
-                .map(move |batch| check_not_null_constraints(batch?, &risky_columns)),
-        )))
-    }
-}
-
-/// Checks a `RecordBatch` for `not null` constraints on specified columns.
-///
-/// # Arguments
-///
-/// * `batch` - The `RecordBatch` to be checked
-/// * `column_indices` - A vector of column indices that should be checked for
-///   `not null` constraints.
-///
-/// # Returns
-///
-/// * `Result<RecordBatch>` - The original `RecordBatch` if all constraints are met
-///
-/// This function iterates over the specified column indices and ensures that none
-/// of the columns contain null values. If any column contains null values, an error
-/// is returned.
-pub fn check_not_null_constraints(
-    batch: RecordBatch,
-    column_indices: &Vec<usize>,
-) -> Result<RecordBatch> {
-    for &index in column_indices {
-        if batch.num_columns() <= index {
-            return exec_err!(
-                "Invalid batch column count {}, expected > {}",
-                batch.num_columns(),
-                index
-            );
-        }
-
-        if batch.column(index).null_count() > 0 {
-            return exec_err!(
-                "Invalid batch: column at '{}' has null but schema specifies non-nullable",
-                index
-            );
-        }
-    }
-
-    Ok(batch)
-}
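To make the behavior of `check_not_null_constraints` concrete, a small sketch of the two outcomes (the array contents and index are arbitrary, chosen only for illustration):

```rust
use std::sync::Arc;
use arrow_array::{ArrayRef, Int64Array, RecordBatch};
use arrow_schema::{DataType, Field, Schema};

fn demo() -> datafusion_common::Result<()> {
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, true)]));
    // Column 0 contains a null value.
    let col: ArrayRef = Arc::new(Int64Array::from(vec![Some(1), None]));
    let batch = RecordBatch::try_new(schema, vec![col])?;
    // Checking no columns passes the batch through unchanged:
    //     check_not_null_constraints(batch.clone(), &vec![])  => Ok(batch)
    // Checking column 0 reports the constraint violation:
    //     check_not_null_constraints(batch, &vec![0])         => Err(...)
    Ok(())
}
```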
-/// Utility function yielding a string representation of the given [`ExecutionPlan`].
-pub fn get_plan_string(plan: &Arc<dyn ExecutionPlan>) -> Vec<String> {
-    let formatted = displayable(plan.as_ref()).indent(true).to_string();
-    let actual: Vec<&str> = formatted.trim().lines().collect();
-    actual.iter().map(|elem| elem.to_string()).collect()
-}
-
-#[cfg(test)]
-mod tests {
-    use std::any::Any;
-    use std::sync::Arc;
-
-    use arrow_schema::{Schema, SchemaRef};
-
-    use datafusion_common::{Result, Statistics};
-    use datafusion_execution::{SendableRecordBatchStream, TaskContext};
-
-    use crate::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
-
-    #[derive(Debug)]
-    pub struct EmptyExec;
-
-    impl EmptyExec {
-        pub fn new(_schema: SchemaRef) -> Self {
-            Self
-        }
-    }
-
-    impl DisplayAs for EmptyExec {
-        fn fmt_as(
-            &self,
-            _t: DisplayFormatType,
-            _f: &mut std::fmt::Formatter,
-        ) -> std::fmt::Result {
-            unimplemented!()
-        }
-    }
-
-    impl ExecutionPlan for EmptyExec {
-        fn name(&self) -> &'static str {
-            Self::static_name()
-        }
-
-        fn as_any(&self) -> &dyn Any {
-            self
-        }
-
-        fn properties(&self) -> &PlanProperties {
-            unimplemented!()
-        }
-
-        fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
-            vec![]
-        }
-
-        fn with_new_children(
-            self: Arc<Self>,
-            _: Vec<Arc<dyn ExecutionPlan>>,
-        ) -> Result<Arc<dyn ExecutionPlan>> {
-            unimplemented!()
-        }
-
-        fn execute(
-            &self,
-            _partition: usize,
-            _context: Arc<TaskContext>,
-        ) -> Result<SendableRecordBatchStream> {
-            unimplemented!()
-        }
-
-        fn statistics(&self) -> Result<Statistics> {
-            unimplemented!()
-        }
-    }
-
-    #[derive(Debug)]
-    pub struct RenamedEmptyExec;
-
-    impl RenamedEmptyExec {
-        pub fn new(_schema: SchemaRef) -> Self {
-            Self
-        }
-    }
-
-    impl DisplayAs for RenamedEmptyExec {
-        fn fmt_as(
-            &self,
-            _t: DisplayFormatType,
-            _f: &mut std::fmt::Formatter,
-        ) -> std::fmt::Result {
-            unimplemented!()
-        }
-    }
-
-    impl ExecutionPlan for RenamedEmptyExec {
-        fn name(&self) -> &'static str {
-            Self::static_name()
-        }
-
-        fn static_name() -> &'static str
-        where
-            Self: Sized,
-        {
-            "MyRenamedEmptyExec"
-        }
-
-        fn as_any(&self) -> &dyn Any {
-            self
-        }
-
-        fn properties(&self) -> &PlanProperties {
-            unimplemented!()
-        }
-
-        fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
-            vec![]
-        }
-
-        fn with_new_children(
-            self: Arc<Self>,
-            _: Vec<Arc<dyn ExecutionPlan>>,
-        ) -> Result<Arc<dyn ExecutionPlan>> {
-            unimplemented!()
-        }
-
-        fn execute(
-            &self,
-            _partition: usize,
-            _context: Arc<TaskContext>,
-        ) -> Result<SendableRecordBatchStream> {
-            unimplemented!()
-        }
-
-        fn statistics(&self) -> Result<Statistics> {
-            unimplemented!()
-        }
-    }
-
-    #[test]
-    fn test_execution_plan_name() {
-        let schema1 = Arc::new(Schema::empty());
-        let default_name_exec = EmptyExec::new(schema1);
-        assert_eq!(default_name_exec.name(), "EmptyExec");
-
-        let schema2 = Arc::new(Schema::empty());
-        let renamed_exec = RenamedEmptyExec::new(schema2);
-        assert_eq!(renamed_exec.name(), "MyRenamedEmptyExec");
-        assert_eq!(RenamedEmptyExec::static_name(), "MyRenamedEmptyExec");
-    }
-
-    /// A compilation test to ensure that the `ExecutionPlan::name()` method can
-    /// be called from a trait object.
-    /// Related ticket: <https://github.com/apache/datafusion/pull/11047>
-    #[allow(dead_code)]
-    fn use_execution_plan_as_trait_object(plan: &dyn ExecutionPlan) {
-        let _ = plan.name();
-    }
-}
-
-pub mod test;
diff --git a/datafusion/physical-plan/src/repartition/mod.rs b/datafusion/physical-plan/src/repartition/mod.rs
index f09324c4019c..656d82215bbe 100644
--- a/datafusion/physical-plan/src/repartition/mod.rs
+++ b/datafusion/physical-plan/src/repartition/mod.rs
@@ -414,8 +414,8 @@ pub struct RepartitionExec {
 struct RepartitionMetrics {
     /// Time in nanos to execute child operator and fetch batches
     fetch_time: metrics::Time,
-    /// Time in nanos to perform repartitioning
-    repart_time: metrics::Time,
+    /// Repartitioning elapsed time in nanos
+    repartition_time: metrics::Time,
     /// Time in nanos for sending resulting batches to channels.
     ///
     /// One metric per output partition.
@@ -433,8 +433,8 @@ impl RepartitionMetrics {
             MetricBuilder::new(metrics).subset_time("fetch_time", input_partition);
 
         // Time in nanos to perform repartitioning
-        let repart_time =
-            MetricBuilder::new(metrics).subset_time("repart_time", input_partition);
+        let repartition_time =
+            MetricBuilder::new(metrics).subset_time("repartition_time", input_partition);
 
         // Time in nanos for sending resulting batches to channels
         let send_time = (0..num_output_partitions)
@@ -449,7 +449,7 @@ impl RepartitionMetrics {
 
         Self {
             fetch_time,
-            repart_time,
+            repartition_time,
             send_time,
         }
    }
@@ -775,7 +775,7 @@ impl RepartitionExec {
         context: Arc<TaskContext>,
     ) -> Result<()> {
         let mut partitioner =
-            BatchPartitioner::try_new(partitioning, metrics.repart_time.clone())?;
+            BatchPartitioner::try_new(partitioning, metrics.repartition_time.clone())?;
 
         // execute the child operator
         let timer = metrics.fetch_time.timer();
diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs
index dbfea253959e..2bd33a476c2f 100644
--- a/datafusion/physical-plan/src/windows/mod.rs
+++ b/datafusion/physical-plan/src/windows/mod.rs
@@ -21,7 +21,6 @@ use std::borrow::Borrow;
 use std::sync::Arc;
 
 use crate::{
-    aggregates,
     expressions::{
         cume_dist, dense_rank, lag, lead, percent_rank, rank, Literal, NthValue, Ntile,
         PhysicalSortExpr, RowNumber,
@@ -104,23 +103,6 @@ pub fn create_window_expr(
     ignore_nulls: bool,
 ) -> Result<Arc<dyn WindowExpr>> {
     Ok(match fun {
-        WindowFunctionDefinition::AggregateFunction(fun) => {
-            let aggregate = aggregates::create_aggregate_expr(
-                fun,
-                false,
-                args,
-                &[],
-                input_schema,
-                name,
-                ignore_nulls,
-            )?;
-            window_expr_from_aggregate_expr(
-                partition_by,
-                order_by,
-                window_frame,
-                aggregate,
-            )
-        }
         WindowFunctionDefinition::BuiltInWindowFunction(fun) => {
             Arc::new(BuiltInWindowExpr::new(
                 create_built_in_window_expr(fun, args, input_schema, name, ignore_nulls)?,
diff --git a/datafusion/proto/gen/src/main.rs b/datafusion/proto/gen/src/main.rs
index d38a41a01ac2..d3b3c92f6065 100644
--- a/datafusion/proto/gen/src/main.rs
+++ b/datafusion/proto/gen/src/main.rs
@@ -33,6 +33,7 @@ fn main() -> Result<(), String> {
         .file_descriptor_set_path(&descriptor_path)
         .out_dir(out_dir)
         .compile_well_known_types()
+        .protoc_arg("--experimental_allow_proto3_optional")
         .extern_path(".google.protobuf", "::pbjson_types")
         .compile_protos(&[proto_path], &["proto"])
         .map_err(|e| format!("protobuf compilation failed: {e}"))?;
@@ -52,7 +53,11 @@ fn main() -> Result<(), String> {
     let prost = proto_dir.join("src/datafusion.rs");
     let pbjson = proto_dir.join("src/datafusion.serde.rs");
    let common_path =
proto_dir.join("src/datafusion_common.rs"); - + println!( + "Copying {} to {}", + prost.clone().display(), + proto_dir.join("src/generated/prost.rs").display() + ); std::fs::copy(prost, proto_dir.join("src/generated/prost.rs")).unwrap(); std::fs::copy(pbjson, proto_dir.join("src/generated/pbjson.rs")).unwrap(); std::fs::copy( diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 4c90297263c4..819130b08e86 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -311,8 +311,6 @@ message LogicalExprNode { // binary expressions BinaryExprNode binary_expr = 4; - // aggregate expressions - AggregateExprNode aggregate_expr = 5; // null checks IsNull is_null_expr = 6; @@ -466,51 +464,6 @@ message InListNode { bool negated = 3; } -enum AggregateFunction { - MIN = 0; - MAX = 1; - // SUM = 2; - // AVG = 3; - // COUNT = 4; - // APPROX_DISTINCT = 5; - // ARRAY_AGG = 6; - // VARIANCE = 7; - // VARIANCE_POP = 8; - // COVARIANCE = 9; - // COVARIANCE_POP = 10; - // STDDEV = 11; - // STDDEV_POP = 12; - // CORRELATION = 13; - // APPROX_PERCENTILE_CONT = 14; - // APPROX_MEDIAN = 15; - // APPROX_PERCENTILE_CONT_WITH_WEIGHT = 16; - // GROUPING = 17; - // MEDIAN = 18; - // BIT_AND = 19; - // BIT_OR = 20; - // BIT_XOR = 21; - // BOOL_AND = 22; - // BOOL_OR = 23; - // REGR_SLOPE = 26; - // REGR_INTERCEPT = 27; - // REGR_COUNT = 28; - // REGR_R2 = 29; - // REGR_AVGX = 30; - // REGR_AVGY = 31; - // REGR_SXX = 32; - // REGR_SYY = 33; - // REGR_SXY = 34; - // STRING_AGG = 35; - // NTH_VALUE_AGG = 36; -} - -message AggregateExprNode { - AggregateFunction aggr_function = 1; - repeated LogicalExprNode expr = 2; - bool distinct = 3; - LogicalExprNode filter = 4; - repeated LogicalExprNode order_by = 5; -} message AggregateUDFExprNode { string fun_name = 1; @@ -543,7 +496,6 @@ enum BuiltInWindowFunction { message WindowExprNode { oneof window_function { - AggregateFunction aggr_function = 1; BuiltInWindowFunction built_in_function = 2; string udaf = 3; string udwf = 9; @@ -853,7 +805,6 @@ message PhysicalScalarUdfNode { message PhysicalAggregateExprNode { oneof AggregateFunction { - AggregateFunction aggr_function = 1; string user_defined_aggr_function = 4; } repeated PhysicalExprNode expr = 2; @@ -865,7 +816,6 @@ message PhysicalAggregateExprNode { message PhysicalWindowExprNode { oneof window_function { - AggregateFunction aggr_function = 1; BuiltInWindowFunction built_in_function = 2; string user_defined_aggr_function = 3; } diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 163a4c044aeb..521a0d90c1ed 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -362,240 +362,6 @@ impl<'de> serde::Deserialize<'de> for AggregateExecNode { deserializer.deserialize_struct("datafusion.AggregateExecNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for AggregateExprNode { - #[allow(deprecated)] - fn serialize(&self, serializer: S) -> std::result::Result - where - S: serde::Serializer, - { - use serde::ser::SerializeStruct; - let mut len = 0; - if self.aggr_function != 0 { - len += 1; - } - if !self.expr.is_empty() { - len += 1; - } - if self.distinct { - len += 1; - } - if self.filter.is_some() { - len += 1; - } - if !self.order_by.is_empty() { - len += 1; - } - let mut struct_ser = serializer.serialize_struct("datafusion.AggregateExprNode", len)?; - if self.aggr_function != 0 { - let v = 
AggregateFunction::try_from(self.aggr_function) - .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.aggr_function)))?; - struct_ser.serialize_field("aggrFunction", &v)?; - } - if !self.expr.is_empty() { - struct_ser.serialize_field("expr", &self.expr)?; - } - if self.distinct { - struct_ser.serialize_field("distinct", &self.distinct)?; - } - if let Some(v) = self.filter.as_ref() { - struct_ser.serialize_field("filter", v)?; - } - if !self.order_by.is_empty() { - struct_ser.serialize_field("orderBy", &self.order_by)?; - } - struct_ser.end() - } -} -impl<'de> serde::Deserialize<'de> for AggregateExprNode { - #[allow(deprecated)] - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - const FIELDS: &[&str] = &[ - "aggr_function", - "aggrFunction", - "expr", - "distinct", - "filter", - "order_by", - "orderBy", - ]; - - #[allow(clippy::enum_variant_names)] - enum GeneratedField { - AggrFunction, - Expr, - Distinct, - Filter, - OrderBy, - } - impl<'de> serde::Deserialize<'de> for GeneratedField { - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - struct GeneratedVisitor; - - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = GeneratedField; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } - - #[allow(unused_variables)] - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, - { - match value { - "aggrFunction" | "aggr_function" => Ok(GeneratedField::AggrFunction), - "expr" => Ok(GeneratedField::Expr), - "distinct" => Ok(GeneratedField::Distinct), - "filter" => Ok(GeneratedField::Filter), - "orderBy" | "order_by" => Ok(GeneratedField::OrderBy), - _ => Err(serde::de::Error::unknown_field(value, FIELDS)), - } - } - } - deserializer.deserialize_identifier(GeneratedVisitor) - } - } - struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = AggregateExprNode; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.AggregateExprNode") - } - - fn visit_map(self, mut map_: V) -> std::result::Result - where - V: serde::de::MapAccess<'de>, - { - let mut aggr_function__ = None; - let mut expr__ = None; - let mut distinct__ = None; - let mut filter__ = None; - let mut order_by__ = None; - while let Some(k) = map_.next_key()? { - match k { - GeneratedField::AggrFunction => { - if aggr_function__.is_some() { - return Err(serde::de::Error::duplicate_field("aggrFunction")); - } - aggr_function__ = Some(map_.next_value::()? 
as i32); - } - GeneratedField::Expr => { - if expr__.is_some() { - return Err(serde::de::Error::duplicate_field("expr")); - } - expr__ = Some(map_.next_value()?); - } - GeneratedField::Distinct => { - if distinct__.is_some() { - return Err(serde::de::Error::duplicate_field("distinct")); - } - distinct__ = Some(map_.next_value()?); - } - GeneratedField::Filter => { - if filter__.is_some() { - return Err(serde::de::Error::duplicate_field("filter")); - } - filter__ = map_.next_value()?; - } - GeneratedField::OrderBy => { - if order_by__.is_some() { - return Err(serde::de::Error::duplicate_field("orderBy")); - } - order_by__ = Some(map_.next_value()?); - } - } - } - Ok(AggregateExprNode { - aggr_function: aggr_function__.unwrap_or_default(), - expr: expr__.unwrap_or_default(), - distinct: distinct__.unwrap_or_default(), - filter: filter__, - order_by: order_by__.unwrap_or_default(), - }) - } - } - deserializer.deserialize_struct("datafusion.AggregateExprNode", FIELDS, GeneratedVisitor) - } -} -impl serde::Serialize for AggregateFunction { - #[allow(deprecated)] - fn serialize(&self, serializer: S) -> std::result::Result - where - S: serde::Serializer, - { - let variant = match self { - Self::Min => "MIN", - Self::Max => "MAX", - }; - serializer.serialize_str(variant) - } -} -impl<'de> serde::Deserialize<'de> for AggregateFunction { - #[allow(deprecated)] - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - const FIELDS: &[&str] = &[ - "MIN", - "MAX", - ]; - - struct GeneratedVisitor; - - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = AggregateFunction; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } - - fn visit_i64(self, v: i64) -> std::result::Result - where - E: serde::de::Error, - { - i32::try_from(v) - .ok() - .and_then(|x| x.try_into().ok()) - .ok_or_else(|| { - serde::de::Error::invalid_value(serde::de::Unexpected::Signed(v), &self) - }) - } - - fn visit_u64(self, v: u64) -> std::result::Result - where - E: serde::de::Error, - { - i32::try_from(v) - .ok() - .and_then(|x| x.try_into().ok()) - .ok_or_else(|| { - serde::de::Error::invalid_value(serde::de::Unexpected::Unsigned(v), &self) - }) - } - - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, - { - match value { - "MIN" => Ok(AggregateFunction::Min), - "MAX" => Ok(AggregateFunction::Max), - _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), - } - } - } - deserializer.deserialize_any(GeneratedVisitor) - } -} impl serde::Serialize for AggregateMode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -9488,9 +9254,6 @@ impl serde::Serialize for LogicalExprNode { logical_expr_node::ExprType::BinaryExpr(v) => { struct_ser.serialize_field("binaryExpr", v)?; } - logical_expr_node::ExprType::AggregateExpr(v) => { - struct_ser.serialize_field("aggregateExpr", v)?; - } logical_expr_node::ExprType::IsNullExpr(v) => { struct_ser.serialize_field("isNullExpr", v)?; } @@ -9592,8 +9355,6 @@ impl<'de> serde::Deserialize<'de> for LogicalExprNode { "literal", "binary_expr", "binaryExpr", - "aggregate_expr", - "aggregateExpr", "is_null_expr", "isNullExpr", "is_not_null_expr", @@ -9647,7 +9408,6 @@ impl<'de> serde::Deserialize<'de> for LogicalExprNode { Alias, Literal, BinaryExpr, - AggregateExpr, IsNullExpr, IsNotNullExpr, NotExpr, @@ -9701,7 +9461,6 @@ impl<'de> serde::Deserialize<'de> for LogicalExprNode { 
"alias" => Ok(GeneratedField::Alias), "literal" => Ok(GeneratedField::Literal), "binaryExpr" | "binary_expr" => Ok(GeneratedField::BinaryExpr), - "aggregateExpr" | "aggregate_expr" => Ok(GeneratedField::AggregateExpr), "isNullExpr" | "is_null_expr" => Ok(GeneratedField::IsNullExpr), "isNotNullExpr" | "is_not_null_expr" => Ok(GeneratedField::IsNotNullExpr), "notExpr" | "not_expr" => Ok(GeneratedField::NotExpr), @@ -9778,13 +9537,6 @@ impl<'de> serde::Deserialize<'de> for LogicalExprNode { return Err(serde::de::Error::duplicate_field("binaryExpr")); } expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::BinaryExpr) -; - } - GeneratedField::AggregateExpr => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("aggregateExpr")); - } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::AggregateExpr) ; } GeneratedField::IsNullExpr => { @@ -12708,11 +12460,6 @@ impl serde::Serialize for PhysicalAggregateExprNode { } if let Some(v) = self.aggregate_function.as_ref() { match v { - physical_aggregate_expr_node::AggregateFunction::AggrFunction(v) => { - let v = AggregateFunction::try_from(*v) - .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", *v)))?; - struct_ser.serialize_field("aggrFunction", &v)?; - } physical_aggregate_expr_node::AggregateFunction::UserDefinedAggrFunction(v) => { struct_ser.serialize_field("userDefinedAggrFunction", v)?; } @@ -12736,8 +12483,6 @@ impl<'de> serde::Deserialize<'de> for PhysicalAggregateExprNode { "ignoreNulls", "fun_definition", "funDefinition", - "aggr_function", - "aggrFunction", "user_defined_aggr_function", "userDefinedAggrFunction", ]; @@ -12749,7 +12494,6 @@ impl<'de> serde::Deserialize<'de> for PhysicalAggregateExprNode { Distinct, IgnoreNulls, FunDefinition, - AggrFunction, UserDefinedAggrFunction, } impl<'de> serde::Deserialize<'de> for GeneratedField { @@ -12777,7 +12521,6 @@ impl<'de> serde::Deserialize<'de> for PhysicalAggregateExprNode { "distinct" => Ok(GeneratedField::Distinct), "ignoreNulls" | "ignore_nulls" => Ok(GeneratedField::IgnoreNulls), "funDefinition" | "fun_definition" => Ok(GeneratedField::FunDefinition), - "aggrFunction" | "aggr_function" => Ok(GeneratedField::AggrFunction), "userDefinedAggrFunction" | "user_defined_aggr_function" => Ok(GeneratedField::UserDefinedAggrFunction), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } @@ -12838,12 +12581,6 @@ impl<'de> serde::Deserialize<'de> for PhysicalAggregateExprNode { map_.next_value::<::std::option::Option<::pbjson::private::BytesDeserialize<_>>>()?.map(|x| x.0) ; } - GeneratedField::AggrFunction => { - if aggregate_function__.is_some() { - return Err(serde::de::Error::duplicate_field("aggrFunction")); - } - aggregate_function__ = map_.next_value::<::std::option::Option>()?.map(|x| physical_aggregate_expr_node::AggregateFunction::AggrFunction(x as i32)); - } GeneratedField::UserDefinedAggrFunction => { if aggregate_function__.is_some() { return Err(serde::de::Error::duplicate_field("userDefinedAggrFunction")); @@ -15948,11 +15685,6 @@ impl serde::Serialize for PhysicalWindowExprNode { } if let Some(v) = self.window_function.as_ref() { match v { - physical_window_expr_node::WindowFunction::AggrFunction(v) => { - let v = AggregateFunction::try_from(*v) - .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", *v)))?; - struct_ser.serialize_field("aggrFunction", &v)?; - } physical_window_expr_node::WindowFunction::BuiltInFunction(v) => { let 
v = BuiltInWindowFunction::try_from(*v) .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", *v)))?; @@ -15983,8 +15715,6 @@ impl<'de> serde::Deserialize<'de> for PhysicalWindowExprNode { "name", "fun_definition", "funDefinition", - "aggr_function", - "aggrFunction", "built_in_function", "builtInFunction", "user_defined_aggr_function", @@ -15999,7 +15729,6 @@ impl<'de> serde::Deserialize<'de> for PhysicalWindowExprNode { WindowFrame, Name, FunDefinition, - AggrFunction, BuiltInFunction, UserDefinedAggrFunction, } @@ -16029,7 +15758,6 @@ impl<'de> serde::Deserialize<'de> for PhysicalWindowExprNode { "windowFrame" | "window_frame" => Ok(GeneratedField::WindowFrame), "name" => Ok(GeneratedField::Name), "funDefinition" | "fun_definition" => Ok(GeneratedField::FunDefinition), - "aggrFunction" | "aggr_function" => Ok(GeneratedField::AggrFunction), "builtInFunction" | "built_in_function" => Ok(GeneratedField::BuiltInFunction), "userDefinedAggrFunction" | "user_defined_aggr_function" => Ok(GeneratedField::UserDefinedAggrFunction), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), @@ -16098,12 +15826,6 @@ impl<'de> serde::Deserialize<'de> for PhysicalWindowExprNode { map_.next_value::<::std::option::Option<::pbjson::private::BytesDeserialize<_>>>()?.map(|x| x.0) ; } - GeneratedField::AggrFunction => { - if window_function__.is_some() { - return Err(serde::de::Error::duplicate_field("aggrFunction")); - } - window_function__ = map_.next_value::<::std::option::Option>()?.map(|x| physical_window_expr_node::WindowFunction::AggrFunction(x as i32)); - } GeneratedField::BuiltInFunction => { if window_function__.is_some() { return Err(serde::de::Error::duplicate_field("builtInFunction")); @@ -20483,11 +20205,6 @@ impl serde::Serialize for WindowExprNode { } if let Some(v) = self.window_function.as_ref() { match v { - window_expr_node::WindowFunction::AggrFunction(v) => { - let v = AggregateFunction::try_from(*v) - .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", *v)))?; - struct_ser.serialize_field("aggrFunction", &v)?; - } window_expr_node::WindowFunction::BuiltInFunction(v) => { let v = BuiltInWindowFunction::try_from(*v) .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", *v)))?; @@ -20520,8 +20237,6 @@ impl<'de> serde::Deserialize<'de> for WindowExprNode { "windowFrame", "fun_definition", "funDefinition", - "aggr_function", - "aggrFunction", "built_in_function", "builtInFunction", "udaf", @@ -20535,7 +20250,6 @@ impl<'de> serde::Deserialize<'de> for WindowExprNode { OrderBy, WindowFrame, FunDefinition, - AggrFunction, BuiltInFunction, Udaf, Udwf, @@ -20565,7 +20279,6 @@ impl<'de> serde::Deserialize<'de> for WindowExprNode { "orderBy" | "order_by" => Ok(GeneratedField::OrderBy), "windowFrame" | "window_frame" => Ok(GeneratedField::WindowFrame), "funDefinition" | "fun_definition" => Ok(GeneratedField::FunDefinition), - "aggrFunction" | "aggr_function" => Ok(GeneratedField::AggrFunction), "builtInFunction" | "built_in_function" => Ok(GeneratedField::BuiltInFunction), "udaf" => Ok(GeneratedField::Udaf), "udwf" => Ok(GeneratedField::Udwf), @@ -20628,12 +20341,6 @@ impl<'de> serde::Deserialize<'de> for WindowExprNode { map_.next_value::<::std::option::Option<::pbjson::private::BytesDeserialize<_>>>()?.map(|x| x.0) ; } - GeneratedField::AggrFunction => { - if window_function__.is_some() { - return Err(serde::de::Error::duplicate_field("aggrFunction")); - } - window_function__ = map_.next_value::<::std::option::Option>()?.map(|x| 
window_expr_node::WindowFunction::AggrFunction(x as i32));
-                            }
                             GeneratedField::BuiltInFunction => {
                                 if window_function__.is_some() {
                                     return Err(serde::de::Error::duplicate_field("builtInFunction"));
diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs
index 606fe3c1699f..070c9b31d3d4 100644
--- a/datafusion/proto/src/generated/prost.rs
+++ b/datafusion/proto/src/generated/prost.rs
@@ -488,7 +488,7 @@ pub struct SubqueryAliasNode {
 pub struct LogicalExprNode {
     #[prost(
         oneof = "logical_expr_node::ExprType",
-        tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 17, 18, 19, 20, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35"
+        tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 17, 18, 19, 20, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35"
     )]
     pub expr_type: ::core::option::Option<logical_expr_node::ExprType>,
 }
@@ -508,9 +508,6 @@ pub mod logical_expr_node {
         /// binary expressions
         #[prost(message, tag = "4")]
         BinaryExpr(super::BinaryExprNode),
-        /// aggregate expressions
-        #[prost(message, tag = "5")]
-        AggregateExpr(::prost::alloc::boxed::Box<AggregateExprNode>),
         /// null checks
         #[prost(message, tag = "6")]
         IsNullExpr(::prost::alloc::boxed::Box<IsNull>),
@@ -733,20 +730,6 @@ pub struct InListNode {
 }
 #[allow(clippy::derive_partial_eq_without_eq)]
 #[derive(Clone, PartialEq, ::prost::Message)]
-pub struct AggregateExprNode {
-    #[prost(enumeration = "AggregateFunction", tag = "1")]
-    pub aggr_function: i32,
-    #[prost(message, repeated, tag = "2")]
-    pub expr: ::prost::alloc::vec::Vec<LogicalExprNode>,
-    #[prost(bool, tag = "3")]
-    pub distinct: bool,
-    #[prost(message, optional, boxed, tag = "4")]
-    pub filter: ::core::option::Option<::prost::alloc::boxed::Box<LogicalExprNode>>,
-    #[prost(message, repeated, tag = "5")]
-    pub order_by: ::prost::alloc::vec::Vec<LogicalExprNode>,
-}
-#[allow(clippy::derive_partial_eq_without_eq)]
-#[derive(Clone, PartialEq, ::prost::Message)]
 pub struct AggregateUdfExprNode {
     #[prost(string, tag = "1")]
     pub fun_name: ::prost::alloc::string::String,
@@ -785,7 +768,7 @@ pub struct WindowExprNode {
     pub window_frame: ::core::option::Option<WindowFrame>,
     #[prost(bytes = "vec", optional, tag = "10")]
     pub fun_definition: ::core::option::Option<::prost::alloc::vec::Vec<u8>>,
-    #[prost(oneof = "window_expr_node::WindowFunction", tags = "1, 2, 3, 9")]
+    #[prost(oneof = "window_expr_node::WindowFunction", tags = "2, 3, 9")]
     pub window_function: ::core::option::Option<window_expr_node::WindowFunction>,
 }
 /// Nested message and enum types in `WindowExprNode`.
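With the `AggregateFunction` enum and `AggregateExprNode` removed from the generated types above, `min`/`max` now travel through the same user-defined-aggregate path as every other aggregate, serialized by name via `AggregateUDFExprNode`. A round-trip sketch using the public `datafusion-proto` byte API (the session context and column name are assumptions for illustration, not part of this diff):

```rust
use datafusion::functions_aggregate::expr_fn::max;
use datafusion::prelude::*;
use datafusion_common::Result;
use datafusion_proto::bytes::Serializeable;

fn main() -> Result<()> {
    let ctx = SessionContext::new();
    // `max` is a UDAF now, so this encodes as a function name plus
    // argument expressions rather than as an enum variant.
    let expr = max(col("col1"));
    let bytes = expr.to_bytes()?;
    let decoded = Expr::from_bytes_with_registry(&bytes, &ctx)?;
    assert_eq!(expr, decoded);
    Ok(())
}
```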
@@ -793,8 +776,6 @@ pub mod window_expr_node { #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Oneof)] pub enum WindowFunction { - #[prost(enumeration = "super::AggregateFunction", tag = "1")] - AggrFunction(i32), #[prost(enumeration = "super::BuiltInWindowFunction", tag = "2")] BuiltInFunction(i32), #[prost(string, tag = "3")] @@ -1301,7 +1282,7 @@ pub struct PhysicalAggregateExprNode { pub ignore_nulls: bool, #[prost(bytes = "vec", optional, tag = "7")] pub fun_definition: ::core::option::Option<::prost::alloc::vec::Vec>, - #[prost(oneof = "physical_aggregate_expr_node::AggregateFunction", tags = "1, 4")] + #[prost(oneof = "physical_aggregate_expr_node::AggregateFunction", tags = "4")] pub aggregate_function: ::core::option::Option< physical_aggregate_expr_node::AggregateFunction, >, @@ -1311,8 +1292,6 @@ pub mod physical_aggregate_expr_node { #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Oneof)] pub enum AggregateFunction { - #[prost(enumeration = "super::AggregateFunction", tag = "1")] - AggrFunction(i32), #[prost(string, tag = "4")] UserDefinedAggrFunction(::prost::alloc::string::String), } @@ -1332,7 +1311,7 @@ pub struct PhysicalWindowExprNode { pub name: ::prost::alloc::string::String, #[prost(bytes = "vec", optional, tag = "9")] pub fun_definition: ::core::option::Option<::prost::alloc::vec::Vec>, - #[prost(oneof = "physical_window_expr_node::WindowFunction", tags = "1, 2, 3")] + #[prost(oneof = "physical_window_expr_node::WindowFunction", tags = "2, 3")] pub window_function: ::core::option::Option< physical_window_expr_node::WindowFunction, >, @@ -1342,8 +1321,6 @@ pub mod physical_window_expr_node { #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Oneof)] pub enum WindowFunction { - #[prost(enumeration = "super::AggregateFunction", tag = "1")] - AggrFunction(i32), #[prost(enumeration = "super::BuiltInWindowFunction", tag = "2")] BuiltInFunction(i32), #[prost(string, tag = "3")] @@ -1941,65 +1918,6 @@ pub struct PartitionStats { } #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] -pub enum AggregateFunction { - Min = 0, - /// SUM = 2; - /// AVG = 3; - /// COUNT = 4; - /// APPROX_DISTINCT = 5; - /// ARRAY_AGG = 6; - /// VARIANCE = 7; - /// VARIANCE_POP = 8; - /// COVARIANCE = 9; - /// COVARIANCE_POP = 10; - /// STDDEV = 11; - /// STDDEV_POP = 12; - /// CORRELATION = 13; - /// APPROX_PERCENTILE_CONT = 14; - /// APPROX_MEDIAN = 15; - /// APPROX_PERCENTILE_CONT_WITH_WEIGHT = 16; - /// GROUPING = 17; - /// MEDIAN = 18; - /// BIT_AND = 19; - /// BIT_OR = 20; - /// BIT_XOR = 21; - /// BOOL_AND = 22; - /// BOOL_OR = 23; - /// REGR_SLOPE = 26; - /// REGR_INTERCEPT = 27; - /// REGR_COUNT = 28; - /// REGR_R2 = 29; - /// REGR_AVGX = 30; - /// REGR_AVGY = 31; - /// REGR_SXX = 32; - /// REGR_SYY = 33; - /// REGR_SXY = 34; - /// STRING_AGG = 35; - /// NTH_VALUE_AGG = 36; - Max = 1, -} -impl AggregateFunction { - /// String value of the enum field names used in the ProtoBuf definition. - /// - /// The values are not transformed in any way and thus are considered stable - /// (if the ProtoBuf definition does not change) and safe for programmatic use. - pub fn as_str_name(&self) -> &'static str { - match self { - AggregateFunction::Min => "MIN", - AggregateFunction::Max => "MAX", - } - } - /// Creates an enum from field names used in the ProtoBuf definition. 
-    pub fn from_str_name(value: &str) -> ::core::option::Option<Self> {
-        match value {
-            "MIN" => Some(Self::Min),
-            "MAX" => Some(Self::Max),
-            _ => None,
-        }
-    }
-}
-#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)]
-#[repr(i32)]
 pub enum BuiltInWindowFunction {
     RowNumber = 0,
     Rank = 1,
diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs
index 5e9b9af49ae9..6c4c07428bd3 100644
--- a/datafusion/proto/src/logical_plan/from_proto.rs
+++ b/datafusion/proto/src/logical_plan/from_proto.rs
@@ -22,11 +22,13 @@ use datafusion_common::{
     exec_datafusion_err, internal_err, plan_datafusion_err, Result, ScalarValue,
     TableReference, UnnestOptions,
 };
+use datafusion_expr::expr::Unnest;
+use datafusion_expr::expr::{Alias, Placeholder};
+use datafusion_expr::ExprFunctionExt;
 use datafusion_expr::{
-    expr::{self, Alias, InList, Placeholder, Sort, Unnest, WindowFunction},
+    expr::{self, InList, Sort, WindowFunction},
     logical_plan::{PlanType, StringifiedPlan},
-    AggregateFunction, Between, BinaryExpr, BuiltInWindowFunction, Case, Cast, Expr,
-    ExprFunctionExt, GroupingSet,
+    Between, BinaryExpr, BuiltInWindowFunction, Case, Cast, Expr, GroupingSet,
     GroupingSet::GroupingSets,
     JoinConstraint, JoinType, Like, Operator, TryCast, WindowFrame, WindowFrameBound,
     WindowFrameUnits,
@@ -136,15 +138,6 @@ impl From<&protobuf::StringifiedPlan> for StringifiedPlan {
     }
 }
 
-impl From<protobuf::AggregateFunction> for AggregateFunction {
-    fn from(agg_fun: protobuf::AggregateFunction) -> Self {
-        match agg_fun {
-            protobuf::AggregateFunction::Min => Self::Min,
-            protobuf::AggregateFunction::Max => Self::Max,
-        }
-    }
-}
-
 impl From<protobuf::BuiltInWindowFunction> for BuiltInWindowFunction {
     fn from(built_in_function: protobuf::BuiltInWindowFunction) -> Self {
         match built_in_function {
@@ -231,12 +224,6 @@ impl From<protobuf::JoinConstraint> for JoinConstraint {
     }
 }
 
-pub fn parse_i32_to_aggregate_function(value: &i32) -> Result<AggregateFunction> {
-    protobuf::AggregateFunction::try_from(*value)
-        .map(|a| a.into())
-        .map_err(|_| Error::unknown("AggregateFunction", *value))
-}
-
 pub fn parse_expr(
     proto: &protobuf::LogicalExprNode,
     registry: &dyn FunctionRegistry,
@@ -297,24 +284,6 @@ pub fn parse_expr(
             // TODO: support proto for null treatment
             match window_function {
-                window_expr_node::WindowFunction::AggrFunction(i) => {
-                    let aggr_function = parse_i32_to_aggregate_function(i)?;
-
-                    Expr::WindowFunction(WindowFunction::new(
-                        expr::WindowFunctionDefinition::AggregateFunction(aggr_function),
-                        vec![parse_required_expr(
-                            expr.expr.as_deref(),
-                            registry,
-                            "expr",
-                            codec,
-                        )?],
-                    ))
-                    .partition_by(partition_by)
-                    .order_by(order_by)
-                    .window_frame(window_frame)
-                    .build()
-                    .map_err(Error::DataFusionError)
-                }
                 window_expr_node::WindowFunction::BuiltInFunction(i) => {
                     let built_in_function = protobuf::BuiltInWindowFunction::try_from(*i)
                         .map_err(|_| Error::unknown("BuiltInWindowFunction", *i))?
@@ -379,19 +348,6 @@ pub fn parse_expr(
             }
         }
     }
-        ExprType::AggregateExpr(expr) => {
-            let fun = parse_i32_to_aggregate_function(&expr.aggr_function)?;
-
-            Ok(Expr::AggregateFunction(expr::AggregateFunction::new(
-                fun,
-                parse_exprs(&expr.expr, registry, codec)?,
-                expr.distinct,
-                parse_optional_expr(expr.filter.as_deref(), registry, codec)?
- .map(Box::new), - parse_vec_expr(&expr.order_by, registry, codec)?, - None, - ))) - } ExprType::Alias(alias) => Ok(Expr::Alias(Alias::new( parse_required_expr(alias.expr.as_deref(), registry, "expr", codec)?, alias diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index c2441892e8a8..74d9d61b3a7f 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -25,9 +25,9 @@ use datafusion_expr::expr::{ InList, Like, Placeholder, ScalarFunction, Sort, Unnest, }; use datafusion_expr::{ - logical_plan::PlanType, logical_plan::StringifiedPlan, AggregateFunction, - BuiltInWindowFunction, Expr, JoinConstraint, JoinType, TryCast, WindowFrame, - WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, + logical_plan::PlanType, logical_plan::StringifiedPlan, BuiltInWindowFunction, Expr, + JoinConstraint, JoinType, TryCast, WindowFrame, WindowFrameBound, WindowFrameUnits, + WindowFunctionDefinition, }; use crate::protobuf::{ @@ -111,15 +111,6 @@ impl From<&StringifiedPlan> for protobuf::StringifiedPlan { } } -impl From<&AggregateFunction> for protobuf::AggregateFunction { - fn from(value: &AggregateFunction) -> Self { - match value { - AggregateFunction::Min => Self::Min, - AggregateFunction::Max => Self::Max, - } - } -} - impl From<&BuiltInWindowFunction> for protobuf::BuiltInWindowFunction { fn from(value: &BuiltInWindowFunction) -> Self { match value { @@ -319,12 +310,6 @@ pub fn serialize_expr( null_treatment: _, }) => { let (window_function, fun_definition) = match fun { - WindowFunctionDefinition::AggregateFunction(fun) => ( - protobuf::window_expr_node::WindowFunction::AggrFunction( - protobuf::AggregateFunction::from(fun).into(), - ), - None, - ), WindowFunctionDefinition::BuiltInWindowFunction(fun) => ( protobuf::window_expr_node::WindowFunction::BuiltInFunction( protobuf::BuiltInWindowFunction::from(fun).into(), @@ -383,29 +368,6 @@ pub fn serialize_expr( ref order_by, null_treatment: _, }) => match func_def { - AggregateFunctionDefinition::BuiltIn(fun) => { - let aggr_function = match fun { - AggregateFunction::Min => protobuf::AggregateFunction::Min, - AggregateFunction::Max => protobuf::AggregateFunction::Max, - }; - - let aggregate_expr = protobuf::AggregateExprNode { - aggr_function: aggr_function.into(), - expr: serialize_exprs(args, codec)?, - distinct: *distinct, - filter: match filter { - Some(e) => Some(Box::new(serialize_expr(e, codec)?)), - None => None, - }, - order_by: match order_by { - Some(e) => serialize_exprs(e, codec)?, - None => vec![], - }, - }; - protobuf::LogicalExprNode { - expr_type: Some(ExprType::AggregateExpr(Box::new(aggregate_expr))), - } - } AggregateFunctionDefinition::UDF(fun) => { let mut buf = Vec::new(); let _ = codec.try_encode_udaf(fun, &mut buf); diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index 5ecca5147805..bc0a19336bae 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -145,15 +145,6 @@ pub fn parse_physical_window_expr( let fun = if let Some(window_func) = proto.window_function.as_ref() { match window_func { - protobuf::physical_window_expr_node::WindowFunction::AggrFunction(n) => { - let f = protobuf::AggregateFunction::try_from(*n).map_err(|_| { - proto_error(format!( - "Received an unknown window aggregate function: {n}" - )) - })?; - - WindowFunctionDefinition::AggregateFunction(f.into()) - } 
             protobuf::physical_window_expr_node::WindowFunction::BuiltInFunction(n) => {
                 let f = protobuf::BuiltInWindowFunction::try_from(*n).map_err(|_| {
                     proto_error(format!(
diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs
index 3c0d6664da17..a79eafe43846 100644
--- a/datafusion/proto/src/physical_plan/mod.rs
+++ b/datafusion/proto/src/physical_plan/mod.rs
@@ -35,7 +35,7 @@ use datafusion::datasource::physical_plan::{AvroExec, CsvExec};
 use datafusion::execution::runtime_env::RuntimeEnv;
 use datafusion::execution::FunctionRegistry;
 use datafusion::physical_expr::{PhysicalExprRef, PhysicalSortRequirement};
-use datafusion::physical_plan::aggregates::{create_aggregate_expr, AggregateMode};
+use datafusion::physical_plan::aggregates::AggregateMode;
 use datafusion::physical_plan::aggregates::{AggregateExec, PhysicalGroupBy};
 use datafusion::physical_plan::analyze::AnalyzeExec;
 use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec;
@@ -479,30 +479,10 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode {
                         ExprType::AggregateExpr(agg_node) => {
                             let input_phy_expr: Vec<Arc<dyn PhysicalExpr>> = agg_node.expr.iter()
                                 .map(|e| parse_physical_expr(e, registry, &physical_schema, extension_codec)).collect::<Result<Vec<_>>>()?;
-                            let ordering_req: Vec<PhysicalSortExpr> = agg_node.ordering_req.iter()
+                            let _ordering_req: Vec<PhysicalSortExpr> = agg_node.ordering_req.iter()
                                 .map(|e| parse_physical_sort_expr(e, registry, &physical_schema, extension_codec)).collect::<Result<Vec<_>>>()?;
                             agg_node.aggregate_function.as_ref().map(|func| {
                                 match func {
-                                    AggregateFunction::AggrFunction(i) => {
-                                        let aggr_function = protobuf::AggregateFunction::try_from(*i)
-                                            .map_err(
-                                                |_| {
-                                                    proto_error(format!(
-                                                        "Received an unknown aggregate function: {i}"
-                                                    ))
-                                                },
-                                            )?;
-
-                                        create_aggregate_expr(
-                                            &aggr_function.into(),
-                                            agg_node.distinct,
-                                            input_phy_expr.as_slice(),
-                                            &ordering_req,
-                                            &physical_schema,
-                                            name.to_string(),
-                                            agg_node.ignore_nulls,
-                                        )
-                                    }
                                     AggregateFunction::UserDefinedAggrFunction(udaf_name) => {
                                         let agg_udf = match &agg_node.fun_definition {
                                             Some(buf) => extension_codec.try_decode_udaf(udaf_name, buf)?,
diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs
index 140482b9903c..57cd22a99ae1 100644
--- a/datafusion/proto/src/physical_plan/to_proto.rs
+++ b/datafusion/proto/src/physical_plan/to_proto.rs
@@ -24,8 +24,8 @@ use datafusion::physical_expr::window::{NthValueKind, SlidingAggregateWindowExpr};
 use datafusion::physical_expr::{PhysicalSortExpr, ScalarFunctionExpr};
 use datafusion::physical_plan::expressions::{
     BinaryExpr, CaseExpr, CastExpr, Column, CumeDist, InListExpr, IsNotNullExpr,
-    IsNullExpr, Literal, Max, Min, NegativeExpr, NotExpr, NthValue, Ntile, Rank,
-    RankType, RowNumber, TryCastExpr, WindowShift,
+    IsNullExpr, Literal, NegativeExpr, NotExpr, NthValue, Ntile, Rank, RankType,
+    RowNumber, TryCastExpr, WindowShift,
 };
 use datafusion::physical_plan::udaf::AggregateFunctionExpr;
 use datafusion::physical_plan::windows::{BuiltInWindowExpr, PlainAggregateWindowExpr};
@@ -60,7 +60,7 @@ pub fn serialize_physical_aggr_expr(
         let name = a.fun().name().to_string();
         let mut buf = Vec::new();
         codec.try_encode_udaf(a.fun(), &mut buf)?;
-        return Ok(protobuf::PhysicalExprNode {
+        Ok(protobuf::PhysicalExprNode {
             expr_type: Some(protobuf::physical_expr_node::ExprType::AggregateExpr(
                 protobuf::PhysicalAggregateExprNode {
                     aggregate_function: Some(physical_aggregate_expr_node::AggregateFunction::UserDefinedAggrFunction(name)),
@@ -71,35 +71,15 @@ pub fn serialize_physical_aggr_expr(
                     fun_definition: (!buf.is_empty()).then_some(buf)
                 },
             )),
-        });
+        })
+    } else {
+        unreachable!("No other types exist besides AggregateFunctionExpr");
     }
-
-    let AggrFn {
-        inner: aggr_function,
-        distinct,
-    } = aggr_expr_to_aggr_fn(aggr_expr.as_ref())?;
-
-    Ok(protobuf::PhysicalExprNode {
-        expr_type: Some(protobuf::physical_expr_node::ExprType::AggregateExpr(
-            protobuf::PhysicalAggregateExprNode {
-                aggregate_function: Some(
-                    physical_aggregate_expr_node::AggregateFunction::AggrFunction(
-                        aggr_function as i32,
-                    ),
-                ),
-                expr: expressions,
-                ordering_req,
-                distinct,
-                ignore_nulls: false,
-                fun_definition: None,
-            },
-        )),
-    })
 }
 
 fn serialize_physical_window_aggr_expr(
     aggr_expr: &dyn AggregateExpr,
-    window_frame: &WindowFrame,
+    _window_frame: &WindowFrame,
     codec: &dyn PhysicalExtensionCodec,
 ) -> Result<(physical_window_expr_node::WindowFunction, Option<Vec<u8>>)> {
     if let Some(a) = aggr_expr.as_any().downcast_ref::<AggregateFunctionExpr>() {
@@ -119,23 +99,7 @@ fn serialize_physical_window_aggr_expr(
             (!buf.is_empty()).then_some(buf),
         ))
     } else {
-        let AggrFn { inner, distinct } = aggr_expr_to_aggr_fn(aggr_expr)?;
-        if distinct {
-            return not_impl_err!(
-                "Distinct aggregate functions not supported in window expressions"
-            );
-        }
-
-        if !window_frame.start_bound.is_unbounded() {
-            return Err(DataFusionError::Internal(format!(
-                "Unbounded start bound in WindowFrame = {window_frame}"
-            )));
-        }
-
-        Ok((
-            physical_window_expr_node::WindowFunction::AggrFunction(inner as i32),
-            None,
-        ))
+        unreachable!("No other types exist besides AggregateFunctionExpr");
     }
 }
 
@@ -252,29 +216,6 @@ pub fn serialize_physical_window_expr(
     })
 }
 
-struct AggrFn {
-    inner: protobuf::AggregateFunction,
-    distinct: bool,
-}
-
-fn aggr_expr_to_aggr_fn(expr: &dyn AggregateExpr) -> Result<AggrFn> {
-    let aggr_expr = expr.as_any();
-
-    // TODO: remove Min and Max
-    let inner = if aggr_expr.downcast_ref::<Min>().is_some() {
-        protobuf::AggregateFunction::Min
-    } else if aggr_expr.downcast_ref::<Max>().is_some() {
-        protobuf::AggregateFunction::Max
-    } else {
-        return not_impl_err!("Aggregate function not supported: {expr:?}");
-    };
-
-    Ok(AggrFn {
-        inner,
-        distinct: false,
-    })
-}
-
 pub fn serialize_physical_sort_exprs<I>(
     sort_exprs: I,
     codec: &dyn PhysicalExtensionCodec,
diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
index f7ad2b9b6158..d150c474e88f 100644
--- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
@@ -42,9 +42,10 @@ use datafusion::execution::FunctionRegistry;
 use datafusion::functions_aggregate::count::count_udaf;
 use datafusion::functions_aggregate::expr_fn::{
     approx_median, approx_percentile_cont, approx_percentile_cont_with_weight, count,
-    count_distinct, covar_pop, covar_samp, first_value, grouping, median, stddev,
-    stddev_pop, sum, var_pop, var_sample,
+    count_distinct, covar_pop, covar_samp, first_value, grouping, max, median, min,
+    stddev, stddev_pop, sum, var_pop, var_sample,
 };
+use datafusion::functions_aggregate::min_max::max_udaf;
 use datafusion::functions_nested::map::map;
 use datafusion::prelude::*;
 use datafusion::test_util::{TestTableFactory, TestTableProvider};
@@ -61,10 +62,10 @@ use datafusion_expr::expr::{
 };
 use datafusion_expr::logical_plan::{Extension, UserDefinedLogicalNodeCore};
 use datafusion_expr::{
-    Accumulator, AggregateFunction, AggregateUDF, ColumnarValue, ExprFunctionExt,
-    ExprSchemable, Literal, LogicalPlan, Operator,
PartitionEvaluator, ScalarUDF, - Signature, TryCast, Volatility, WindowFrame, WindowFrameBound, WindowFrameUnits, - WindowFunctionDefinition, WindowUDF, WindowUDFImpl, + Accumulator, AggregateUDF, ColumnarValue, ExprFunctionExt, ExprSchemable, Literal, + LogicalPlan, Operator, PartitionEvaluator, ScalarUDF, Signature, TryCast, Volatility, + WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, WindowUDF, + WindowUDFImpl, }; use datafusion_functions_aggregate::average::avg_udaf; use datafusion_functions_aggregate::expr_fn::{ @@ -875,7 +876,9 @@ async fn roundtrip_expr_api() -> Result<()> { covar_pop(lit(1.5), lit(2.2)), corr(lit(1.5), lit(2.2)), sum(lit(1)), + max(lit(1)), median(lit(2)), + min(lit(2)), var_sample(lit(2.2)), var_pop(lit(2.2)), stddev(lit(2.2)), @@ -2284,7 +2287,7 @@ fn roundtrip_window() { ); let test_expr4 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), + WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("col1")], )) .partition_by(vec![col("col1")]) diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index caab8f0a77f7..9c180e219b5b 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -25,8 +25,10 @@ use std::vec; use arrow::array::RecordBatch; use arrow::csv::WriterBuilder; use datafusion::physical_expr_common::aggregate::AggregateExprBuilder; +use datafusion_functions_aggregate::min_max::max_udaf; use prost::Message; +use crate::cases::{MyAggregateUDF, MyAggregateUdfNode, MyRegexUdf, MyRegexUdfNode}; use datafusion::arrow::array::ArrayRef; use datafusion::arrow::compute::kernels::sort::SortOptions; use datafusion::arrow::datatypes::{DataType, Field, IntervalUnit, Schema}; @@ -43,7 +45,7 @@ use datafusion::execution::FunctionRegistry; use datafusion::functions_aggregate::sum::sum_udaf; use datafusion::logical_expr::{create_udf, JoinType, Operator, Volatility}; use datafusion::physical_expr::aggregate::utils::down_cast_any_ref; -use datafusion::physical_expr::expressions::{Literal, Max}; +use datafusion::physical_expr::expressions::Literal; use datafusion::physical_expr::window::SlidingAggregateWindowExpr; use datafusion::physical_expr::{PhysicalSortRequirement, ScalarFunctionExpr}; use datafusion::physical_plan::aggregates::{ @@ -94,8 +96,6 @@ use datafusion_proto::physical_plan::{ }; use datafusion_proto::protobuf; -use crate::cases::{MyAggregateUDF, MyAggregateUdfNode, MyRegexUdf, MyRegexUdfNode}; - /// Perform a serde roundtrip and assert that the string representation of the before and after plans /// are identical. Note that this often isn't sufficient to guarantee that no information is /// lost during serde because the string representation of a plan often only shows a subset of state. 
@@ -911,11 +911,18 @@ fn roundtrip_scalar_udf_extension_codec() -> Result<()> { )), input, )?); + let aggr_expr = AggregateExprBuilder::new( + max_udaf(), + vec![udf_expr.clone() as Arc<dyn PhysicalExpr>], + ) + .schema(schema.clone()) + .name("max") + .build()?; let window = Arc::new(WindowAggExec::try_new( vec![Arc::new(PlainAggregateWindowExpr::new( - Arc::new(Max::new(udf_expr.clone(), "max", DataType::Int64)), + aggr_expr.clone(), &[col("author", &schema)?], &[], Arc::new(WindowFrame::new(None)), ))], @@ -926,7 +933,7 @@ fn roundtrip_scalar_udf_extension_codec() -> Result<()> { let aggregate = Arc::new(AggregateExec::try_new( AggregateMode::Final, PhysicalGroupBy::new(vec![], vec![], vec![]), - vec![Arc::new(Max::new(udf_expr, "max", DataType::Int64))], + vec![aggr_expr.clone()], vec![None], window, schema.clone(), diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index 2506ef740fde..d16d08b041ae 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. -use std::str::FromStr; - use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; use arrow_schema::DataType; @@ -26,8 +24,7 @@ use datafusion_common::{ }; use datafusion_expr::planner::PlannerResult; use datafusion_expr::{ - expr, AggregateFunction, Expr, ExprFunctionExt, ExprSchemable, WindowFrame, - WindowFunctionDefinition, + expr, Expr, ExprFunctionExt, ExprSchemable, WindowFrame, WindowFunctionDefinition, }; use datafusion_expr::{ expr::{ScalarFunction, Unnest}, @@ -38,7 +35,6 @@ use sqlparser::ast::{ FunctionArgExpr, FunctionArgumentClause, FunctionArgumentList, FunctionArguments, NullTreatment, ObjectName, OrderByExpr, WindowType, }; - use strum::IntoEnumIterator; /// Suggest a valid function based on an invalid input function name @@ -51,7 +47,6 @@ pub fn suggest_valid_function( // All aggregate functions and builtin window functions let mut funcs = Vec::new(); - funcs.extend(AggregateFunction::iter().map(|func| func.to_string())); funcs.extend(ctx.udaf_names()); funcs.extend(BuiltInWindowFunction::iter().map(|func| func.to_string())); funcs.extend(ctx.udwf_names()); @@ -62,7 +57,6 @@ pub fn suggest_valid_function( let mut funcs = Vec::new(); funcs.extend(ctx.udf_names()); - funcs.extend(AggregateFunction::iter().map(|func| func.to_string())); funcs.extend(ctx.udaf_names()); funcs @@ -324,31 +318,15 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { }; if let Ok(fun) = self.find_window_func(&name) { - return match fun { - WindowFunctionDefinition::AggregateFunction(aggregate_fun) => { - let args = - self.function_args_to_expr(args, schema, planner_context)?; - - Expr::WindowFunction(expr::WindowFunction::new( - WindowFunctionDefinition::AggregateFunction(aggregate_fun), - args, - )) - .partition_by(partition_by) - .order_by(order_by) - .window_frame(window_frame) - .null_treatment(null_treatment) - .build() - } - _ => Expr::WindowFunction(expr::WindowFunction::new( - fun, - self.function_args_to_expr(args, schema, planner_context)?, - )) - .partition_by(partition_by) - .order_by(order_by) - .window_frame(window_frame) - .null_treatment(null_treatment) - .build(), - }; + return Expr::WindowFunction(expr::WindowFunction::new( + fun, + self.function_args_to_expr(args, schema, planner_context)?, + )) + .partition_by(partition_by) + .order_by(order_by) + .window_frame(window_frame) + .null_treatment(null_treatment) + .build(); } } else
{ // User defined aggregate functions (UDAF) have precedence in case it has the same name as a scalar built-in function @@ -375,32 +353,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { null_treatment, ))); } - - // next, aggregate built-ins - if let Ok(fun) = AggregateFunction::from_str(&name) { - let order_by = self.order_by_to_sort_expr( - order_by, - schema, - planner_context, - true, - None, - )?; - let order_by = (!order_by.is_empty()).then_some(order_by); - let args = self.function_args_to_expr(args, schema, planner_context)?; - let filter: Option<Box<Expr>> = filter - .map(|e| self.sql_expr_to_logical_expr(*e, schema, planner_context)) - .transpose()? - .map(Box::new); - - return Ok(Expr::AggregateFunction(expr::AggregateFunction::new( - fun, - args, - distinct, - filter, - order_by, - null_treatment, - ))); - }; } // Could not find the relevant function, so return an error diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index 71ff7c03bea2..b80ffb6aed3f 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -21,8 +21,8 @@ use datafusion_expr::planner::PlannerResult; use datafusion_expr::planner::RawDictionaryExpr; use datafusion_expr::planner::RawFieldAccessExpr; use sqlparser::ast::{ - CastKind, DictionaryField, Expr as SQLExpr, StructField, Subscript, TrimWhereField, - Value, + CastKind, DictionaryField, Expr as SQLExpr, MapEntry, StructField, Subscript, + TrimWhereField, Value, }; use datafusion_common::{ @@ -628,6 +628,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { SQLExpr::Dictionary(fields) => { self.try_plan_dictionary_literal(fields, schema, planner_context) } + SQLExpr::Map(map) => { + self.try_plan_map_literal(map.entries, schema, planner_context) + } _ => not_impl_err!("Unsupported ast node in sqltorel: {sql:?}"), } } @@ -711,7 +714,29 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { PlannerResult::Original(expr) => raw_expr = expr, } } - not_impl_err!("Unsupported dictionary literal: {raw_expr:?}") + not_impl_err!("Dictionary not supported by ExprPlanner: {raw_expr:?}") + } + + fn try_plan_map_literal( + &self, + entries: Vec<MapEntry>, + schema: &DFSchema, + planner_context: &mut PlannerContext, + ) -> Result<Expr> { + let mut exprs: Vec<_> = entries + .into_iter() + .flat_map(|entry| vec![entry.key, entry.value].into_iter()) + .map(|expr| self.sql_expr_to_logical_expr(*expr, schema, planner_context)) + .collect::<Result<Vec<_>>>()?; + for planner in self.context_provider.get_expr_planners() { + match planner.plan_make_map(exprs)? { + PlannerResult::Planned(expr) => { + return Ok(expr); + } + PlannerResult::Original(expr) => exprs = expr, + } + } + not_impl_err!("MAP not supported by ExprPlanner: {exprs:?}") } // Handles a call to struct(...) where the arguments are named. For example diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index e144dfd649d2..9b44848a91a8 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -522,8 +522,9 @@ impl Unparser<'_> { } } - /// This function can convert more [`Expr`] types than `expr_to_sql`, returning an [`Unparsed`] - /// like `Sort` expressions to `OrderByExpr` expressions. + /// This function can convert more [`Expr`] types than `expr_to_sql`, + /// returning an [`Unparsed`] for variants such as `Sort`, which unparses + /// to an `OrderByExpr` rather than a plain SQL expression.
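A hedged usage sketch of `expr_to_unparsed` (assuming the `Unparsed` re-export added in `unparser/mod.rs` below; the column name is illustrative, and `Unparser::default()` is assumed to use the `DefaultDialect`):

```rust
use datafusion_common::Result;
use datafusion_expr::{col, expr::Sort, Expr};
use datafusion_sql::unparser::Unparser;

fn sort_to_order_by() -> Result<()> {
    // `Expr::Sort` has no plain SQL-expression form, so `expr_to_sql`
    // rejects it; `expr_to_unparsed` instead returns an `Unparsed` value
    // that can be rendered as an ORDER BY item.
    let sort = Expr::Sort(Sort {
        expr: Box::new(col("a")),
        asc: true,
        nulls_first: false,
    });
    let unparser = Unparser::default();
    let _unparsed = unparser.expr_to_unparsed(&sort)?;
    Ok(())
}
```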
pub fn expr_to_unparsed(&self, expr: &Expr) -> Result<Unparsed> { match expr { Expr::Sort(Sort { diff --git a/datafusion/sql/src/unparser/mod.rs b/datafusion/sql/src/unparser/mod.rs index 83ae64ba238b..b2fd32566aa8 100644 --- a/datafusion/sql/src/unparser/mod.rs +++ b/datafusion/sql/src/unparser/mod.rs @@ -29,6 +29,8 @@ pub use plan::plan_to_sql; use self::dialect::{DefaultDialect, Dialect}; pub mod dialect; +pub use expr::Unparsed; + /// Convert a DataFusion [`Expr`] to [`sqlparser::ast::Expr`] /// /// See [`expr_to_sql`] for background. `Unparser` allows greater control of diff --git a/datafusion/sql/src/unparser/plan.rs b/datafusion/sql/src/unparser/plan.rs index b30e109881c2..e08f25d3c27c 100644 --- a/datafusion/sql/src/unparser/plan.rs +++ b/datafusion/sql/src/unparser/plan.rs @@ -142,6 +142,14 @@ impl Unparser<'_> { return Ok(*body); } + // If no projection is set, add a wildcard projection to the select + // which will be translated to `SELECT *` in the SQL statement + if !select_builder.already_projected() { + select_builder.projection(vec![ast::SelectItem::Wildcard( + ast::WildcardAdditionalOptions::default(), + )]); + } + let mut twj = select_builder.pop_from().unwrap(); twj.relation(relation_builder); select_builder.push_from(twj); diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index a52333e54fac..bae3ec2e2779 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -20,7 +20,7 @@ use std::vec; use arrow_schema::*; use datafusion_common::{DFSchema, Result, TableReference}; -use datafusion_expr::test::function_stub::{count_udaf, sum_udaf}; +use datafusion_expr::test::function_stub::{count_udaf, max_udaf, min_udaf, sum_udaf}; use datafusion_expr::{col, table_scan}; use datafusion_sql::planner::{ContextProvider, PlannerContext, SqlToRel}; use datafusion_sql::unparser::dialect::{ @@ -381,7 +381,9 @@ fn roundtrip_statement_with_dialect() -> Result<()> { .parse_statement()?; let context = MockContextProvider::default() - .with_expr_planner(Arc::new(CoreFunctionPlanner::default())); + .with_expr_planner(Arc::new(CoreFunctionPlanner::default())) + .with_udaf(max_udaf()) + .with_udaf(min_udaf()); let sql_to_rel = SqlToRel::new(&context); let plan = sql_to_rel .sql_statement_to_plan(statement) @@ -449,6 +451,30 @@ fn test_table_references_in_plan_to_sql() { ); } +#[test] +fn test_table_scan_with_no_projection_in_plan_to_sql() { + fn test(table_name: &str, expected_sql: &str) { + let schema = Schema::new(vec![ + Field::new("id", DataType::Utf8, false), + Field::new("value", DataType::Utf8, false), + ]); + + let plan = table_scan(Some(table_name), &schema, None) + .unwrap() + .build() + .unwrap(); + let sql = plan_to_sql(&plan).unwrap(); + assert_eq!(format!("{}", sql), expected_sql) + } + + test( + "catalog.schema.table", + "SELECT * FROM catalog.\"schema\".\"table\"", + ); + test("schema.table", "SELECT * FROM \"schema\".\"table\""); + test("table", "SELECT * FROM \"table\""); +} + #[test] fn test_pretty_roundtrip() -> Result<()> { let schema = Schema::new(vec![ diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index 40a58827b388..c1b2246e4980 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -42,7 +42,8 @@ use datafusion_sql::{ use datafusion_functions::core::planner::CoreFunctionPlanner; use datafusion_functions_aggregate::{ - approx_median::approx_median_udaf, count::count_udaf,
+ approx_median::approx_median_udaf, count::count_udaf, min_max::max_udaf, + min_max::min_udaf, }; use datafusion_functions_aggregate::{average::avg_udaf, grouping::grouping_udaf}; use rstest::rstest; @@ -2764,6 +2765,8 @@ fn logical_plan_with_dialect_and_options( .with_udaf(approx_median_udaf()) .with_udaf(count_udaf()) .with_udaf(avg_udaf()) + .with_udaf(min_udaf()) + .with_udaf(max_udaf()) .with_udaf(grouping_udaf()) .with_expr_planner(Arc::new(CoreFunctionPlanner::default())); diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index 6ec1e0c52690..ee72289d66eb 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -1243,6 +1243,12 @@ SELECT APPROX_PERCENTILE_CONT(v, 0.5) FROM (VALUES (1), (2), (3), (NULL), (NULL)) as t (v); ---- 2 +# percentile_cont_with_nulls_only +query I +SELECT APPROX_PERCENTILE_CONT(v, 0.5) FROM (VALUES (CAST(NULL as INT))) as t (v); +---- +NULL + # csv_query_cube_avg query TIR SELECT c1, c2, AVG(c3) FROM aggregate_test_100 GROUP BY CUBE (c1, c2) ORDER BY c1, c2 diff --git a/datafusion/sqllogictest/test_files/clickbench.slt b/datafusion/sqllogictest/test_files/clickbench.slt index c2dba435263d..733c0a3cd972 100644 --- a/datafusion/sqllogictest/test_files/clickbench.slt +++ b/datafusion/sqllogictest/test_files/clickbench.slt @@ -274,5 +274,23 @@ query PI SELECT DATE_TRUNC('minute', to_timestamp_seconds("EventTime")) AS M, COUNT(*) AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate"::INT::DATE >= '2013-07-14' AND "EventDate"::INT::DATE <= '2013-07-15' AND "IsRefresh" = 0 AND "DontCountHits" = 0 GROUP BY DATE_TRUNC('minute', to_timestamp_seconds("EventTime")) ORDER BY DATE_TRUNC('minute', M) LIMIT 10 OFFSET 1000; ---- +# Clickbench "Extended" queries that test count distinct + +query III +SELECT COUNT(DISTINCT "SearchPhrase"), COUNT(DISTINCT "MobilePhone"), COUNT(DISTINCT "MobilePhoneModel") FROM hits; +---- +1 1 1 + +query III +SELECT COUNT(DISTINCT "HitColor"), COUNT(DISTINCT "BrowserCountry"), COUNT(DISTINCT "BrowserLanguage") FROM hits; +---- +1 1 1 + +query TIIII +SELECT "BrowserCountry", COUNT(DISTINCT "SocialNetwork"), COUNT(DISTINCT "HitColor"), COUNT(DISTINCT "BrowserLanguage"), COUNT(DISTINCT "SocialAction") FROM hits GROUP BY 1 ORDER BY 2 DESC LIMIT 10; +---- +� 1 1 1 1 + + statement ok drop table hits; diff --git a/datafusion/sqllogictest/test_files/explain.slt b/datafusion/sqllogictest/test_files/explain.slt index 5a1733460120..eae4f428b4b4 100644 --- a/datafusion/sqllogictest/test_files/explain.slt +++ b/datafusion/sqllogictest/test_files/explain.slt @@ -176,7 +176,6 @@ EXPLAIN VERBOSE SELECT a, b, c FROM simple_explain_test initial_logical_plan 01)Projection: simple_explain_test.a, simple_explain_test.b, simple_explain_test.c 02)--TableScan: simple_explain_test -logical_plan after apply_function_rewrites SAME TEXT AS ABOVE logical_plan after inline_table_scan SAME TEXT AS ABOVE logical_plan after type_coercion SAME TEXT AS ABOVE logical_plan after count_wildcard_rule SAME TEXT AS ABOVE diff --git a/datafusion/sqllogictest/test_files/join.slt b/datafusion/sqllogictest/test_files/join.slt index 29f1c4384daf..21fea4ad1025 100644 --- a/datafusion/sqllogictest/test_files/join.slt +++ b/datafusion/sqllogictest/test_files/join.slt @@ -1130,4 +1130,51 @@ SELECT * FROM (SELECT * FROM t1 CROSS JOIN t2) WHERE t1.a + t2.a IS NULL; ---- -NULL NULL \ No newline at end of file +NULL NULL + +statement ok +CREATE TABLE t5(v0 BIGINT, v1 STRING, v2
BIGINT, v3 STRING, v4 BOOLEAN); + +statement ok +CREATE TABLE t1(v0 BIGINT, v1 STRING); + +statement ok +CREATE TABLE t0(v0 BIGINT, v1 DOUBLE); + +query TT +explain SELECT * +FROM t1 +NATURAL JOIN t5 +INNER JOIN t0 ON (t0.v1 + t5.v0) > 0 +WHERE t0.v1 = t1.v0; +---- +logical_plan +01)Projection: t1.v0, t1.v1, t5.v2, t5.v3, t5.v4, t0.v0, t0.v1 +02)--Inner Join: CAST(t1.v0 AS Float64) = t0.v1 Filter: t0.v1 + CAST(t5.v0 AS Float64) > Float64(0) +03)----Projection: t1.v0, t1.v1, t5.v0, t5.v2, t5.v3, t5.v4 +04)------Inner Join: Using t1.v0 = t5.v0, t1.v1 = t5.v1 +05)--------TableScan: t1 projection=[v0, v1] +06)--------TableScan: t5 projection=[v0, v1, v2, v3, v4] +07)----TableScan: t0 projection=[v0, v1] +physical_plan +01)CoalesceBatchesExec: target_batch_size=8192 +02)--HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(CAST(t1.v0 AS Float64)@6, v1@1)], filter=v1@1 + CAST(v0@0 AS Float64) > 0, projection=[v0@0, v1@1, v2@3, v3@4, v4@5, v0@7, v1@8] +03)----CoalescePartitionsExec +04)------ProjectionExec: expr=[v0@0 as v0, v1@1 as v1, v0@2 as v0, v2@3 as v2, v3@4 as v3, v4@5 as v4, CAST(v0@0 AS Float64) as CAST(t1.v0 AS Float64)] +05)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +06)----------CoalesceBatchesExec: target_batch_size=8192 +07)------------HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(v0@0, v0@0), (v1@1, v1@1)], projection=[v0@0, v1@1, v0@2, v2@4, v3@5, v4@6] +08)--------------MemoryExec: partitions=1, partition_sizes=[0] +09)--------------MemoryExec: partitions=1, partition_sizes=[0] +10)----MemoryExec: partitions=1, partition_sizes=[0] + + + +statement ok +drop table t5; + +statement ok +drop table t1; + +statement ok +drop table t0; diff --git a/datafusion/sqllogictest/test_files/map.slt b/datafusion/sqllogictest/test_files/map.slt index e530e14df66e..11998eea9044 100644 --- a/datafusion/sqllogictest/test_files/map.slt +++ b/datafusion/sqllogictest/test_files/map.slt @@ -310,3 +310,152 @@ VALUES (MAP(['a'], [1])), (MAP(['b'], [2])), (MAP(['c', 'a'], [3, 1])) {a: 1} {b: 2} {c: 3, a: 1} + +query ? +SELECT MAP {'a':1, 'b':2, 'c':3}; +---- +{a: 1, b: 2, c: 3} + +query ? +SELECT MAP {'a':1, 'b':2, 'c':3 } FROM t; +---- +{a: 1, b: 2, c: 3} +{a: 1, b: 2, c: 3} +{a: 1, b: 2, c: 3} + +query I +SELECT MAP {'a':1, 'b':2, 'c':3}['a']; +---- +1 + +query I +SELECT MAP {'a':1, 'b':2, 'c':3 }['a'] FROM t; +---- +1 +1 +1 + +# TODO(https://github.com/sqlparser-rs/sqlparser-rs/pull/1361): support parsing an empty map. Enable this after upgrading sqlparser-rs. +# query ? +# SELECT MAP {}; +# ---- +# {} + +# values contain null +query ? +SELECT MAP {'a': 1, 'b': null}; +---- +{a: 1, b: } + +# keys contain null +query error DataFusion error: Execution error: map key cannot be null +SELECT MAP {'a': 1, null: 2} + +# array as key +query ? +SELECT MAP {[1,2,3]:1, [2,4]:2}; +---- + {[1, 2, 3]: 1, [2, 4]: 2} + +# array with different type as key +# expect to fail due to type coercion error +query error +SELECT MAP {[1,2,3]:1, ['a', 'b']:2}; + +# array as value +query ? +SELECT MAP {'a':[1,2,3], 'b':[2,4]}; +---- +{a: [1, 2, 3], b: [2, 4]} + +# array with different type as value +# expect to fail due to type coercion error +query error +SELECT MAP {'a':[1,2,3], 'b':['a', 'b']}; + +# struct as key +query ? 
+SELECT MAP {{'a':1, 'b':2}:1, {'a':3, 'b':4}:2}; +---- +{{a: 1, b: 2}: 1, {a: 3, b: 4}: 2} + +# struct with different fields as key +# expect to fail due to type coercion error +query error +SELECT MAP {{'a':1, 'b':2}:1, {'c':3, 'd':4}:2}; + +# struct as value +query ? +SELECT MAP {'a':{'b':1, 'c':2}, 'b':{'b':3, 'c':4}}; +---- +{a: {b: 1, c: 2}, b: {b: 3, c: 4}} + +# struct with different fields as value +# expect to fail due to type coercion error +query error +SELECT MAP {'a':{'b':1, 'c':2}, 'b':{'c':3, 'd':4}}; + +# map as key +query ? +SELECT MAP { MAP {1:'a', 2:'b'}:1, MAP {1:'c', 2:'d'}:2 }; +---- +{{1: a, 2: b}: 1, {1: c, 2: d}: 2} + +# map with different keys as key +query ? +SELECT MAP { MAP {1:'a', 2:'b', 3:'c'}:1, MAP {2:'c', 4:'d'}:2 }; +---- + {{1: a, 2: b, 3: c}: 1, {2: c, 4: d}: 2} + +# map as value +query ? +SELECT MAP {1: MAP {1:'a', 2:'b'}, 2: MAP {1:'c', 2:'d'} }; +---- +{1: {1: a, 2: b}, 2: {1: c, 2: d}} + +# map with different keys as value +query ? +SELECT MAP {'a': MAP {1:'a', 2:'b', 3:'c'}, 'b': MAP {2:'c', 4:'d'} }; +---- +{a: {1: a, 2: b, 3: c}, b: {2: c, 4: d}} + +# complex map for each row +query ? +SELECT MAP {'a': MAP {1:'a', 2:'b', 3:'c'}, 'b': MAP {2:'c', 4:'d'} } from t; +---- +{a: {1: a, 2: b, 3: c}, b: {2: c, 4: d}} +{a: {1: a, 2: b, 3: c}, b: {2: c, 4: d}} +{a: {1: a, 2: b, 3: c}, b: {2: c, 4: d}} + +# access map with non-existent key +query ? +SELECT MAP {'a': MAP {1:'a', 2:'b', 3:'c'}, 'b': MAP {2:'c', 4:'d'} }['c']; +---- +NULL + +# access map with null key +query error +SELECT MAP {'a': MAP {1:'a', 2:'b', 3:'c'}, 'b': MAP {2:'c', 4:'d'} }[NULL]; + +query ? +SELECT MAP { 'a': 1, 2: 3 }; +---- +{a: 1, 2: 3} + +# TODO(https://github.com/apache/datafusion/issues/11785): fix accessing map with non-string key +# query ? +# SELECT MAP { 1: 'a', 2: 'b', 3: 'c' }[1]; +# ---- +# a + +# TODO(https://github.com/apache/datafusion/issues/11785): fix accessing map with non-string key +# query ? +# SELECT MAP { MAP {1:'a', 2:'b'}:1, MAP {1:'c', 2:'d'}:2 }[MAP {1:'a', 2:'b'}]; +# ---- +# 1 + +# TODO(https://github.com/apache/datafusion/issues/11785): fix accessing map with non-string key +# query ? 
+# SELECT MAKE_MAP(1, null, 2, 33, 3, null)[2]; +# ---- +# 33 diff --git a/datafusion/sqllogictest/test_files/string_view.slt b/datafusion/sqllogictest/test_files/string_view.slt index 3f9a4793f655..763b4e99c614 100644 --- a/datafusion/sqllogictest/test_files/string_view.slt +++ b/datafusion/sqllogictest/test_files/string_view.slt @@ -321,6 +321,40 @@ logical_plan 02)--Filter: CAST(test.column2_utf8 AS Utf8View) = test.column1_utf8view 03)----TableScan: test projection=[column1_utf8, column2_utf8, column1_utf8view] +## Test distinct aggregates +query III +SELECT + COUNT(DISTINCT column1_utf8), + COUNT(DISTINCT column1_utf8view), + COUNT(DISTINCT column1_dict) +FROM test; +---- +3 3 3 + +query III +SELECT + COUNT(DISTINCT column1_utf8), + COUNT(DISTINCT column1_utf8view), + COUNT(DISTINCT column1_dict) +FROM test +GROUP BY column2_utf8view; +---- +1 1 1 +1 1 1 +1 1 1 + + +query TT +EXPLAIN SELECT + COUNT(DISTINCT column1_utf8), + COUNT(DISTINCT column1_utf8view), + COUNT(DISTINCT column1_dict) +FROM test; +---- +logical_plan +01)Aggregate: groupBy=[[]], aggr=[[count(DISTINCT test.column1_utf8), count(DISTINCT test.column1_utf8view), count(DISTINCT test.column1_dict)]] +02)--TableScan: test projection=[column1_utf8, column1_utf8view, column1_dict] + statement ok drop table test; diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index eebadb239d56..89f2efec66aa 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -30,8 +30,8 @@ use datafusion::execution::FunctionRegistry; use datafusion::logical_expr::expr::{Exists, InSubquery, Sort}; use datafusion::logical_expr::{ - aggregate_function, expr::find_df_window_func, Aggregate, BinaryExpr, Case, - EmptyRelation, Expr, ExprSchemable, LogicalPlan, Operator, Projection, Values, + expr::find_df_window_func, Aggregate, BinaryExpr, Case, EmptyRelation, Expr, + ExprSchemable, LogicalPlan, Operator, Projection, Values, }; use substrait::proto::expression::subquery::set_predicate::PredicateOp; use url::Url; @@ -67,7 +67,6 @@ use datafusion::{ scalar::ScalarValue, }; use std::collections::{HashMap, HashSet}; -use std::str::FromStr; use std::sync::Arc; use substrait::proto::exchange_rel::ExchangeKind; use substrait::proto::expression::literal::user_defined::Val; @@ -1005,11 +1004,6 @@ pub async fn from_substrait_agg_func( Ok(Arc::new(Expr::AggregateFunction( expr::AggregateFunction::new_udf(fun, args, distinct, filter, order_by, None), ))) - } else if let Ok(fun) = aggregate_function::AggregateFunction::from_str(function_name) - { - Ok(Arc::new(Expr::AggregateFunction( - expr::AggregateFunction::new(fun, args, distinct, filter, order_by, None), - ))) } else { not_impl_err!( "Aggregate function {} is not supported: function anchor = {:?}", diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 8263209ffccc..bd6e0e00491a 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -48,7 +48,6 @@ use datafusion::common::{ }; use datafusion::common::{substrait_err, DFSchemaRef}; #[allow(unused_imports)] -use datafusion::logical_expr::aggregate_function; use datafusion::logical_expr::expr::{ AggregateFunctionDefinition, Alias, BinaryExpr, Case, Cast, GroupingSet, InList, InSubquery, Sort, WindowFunction, @@ -767,37 +766,6 @@ pub fn to_substrait_agg_measure( match expr { 
Expr::AggregateFunction(expr::AggregateFunction { func_def, args, distinct, filter, order_by, null_treatment: _, }) => { match func_def { - AggregateFunctionDefinition::BuiltIn (fun) => { - let sorts = if let Some(order_by) = order_by { - order_by.iter().map(|expr| to_substrait_sort_field(ctx, expr, schema, extensions)).collect::<Result<Vec<_>>>()? - } else { - vec![] - }; - let mut arguments: Vec<FunctionArgument> = vec![]; - for arg in args { - arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(ctx, arg, schema, 0, extensions)?)) }); - } - let function_anchor = extensions.register_function(fun.to_string()); - Ok(Measure { - measure: Some(AggregateFunction { - function_reference: function_anchor, - arguments, - sorts, - output_type: None, - invocation: match distinct { - true => AggregationInvocation::Distinct as i32, - false => AggregationInvocation::All as i32, - }, - phase: AggregationPhase::Unspecified as i32, - args: vec![], - options: vec![], - }), - filter: match filter { - Some(f) => Some(to_substrait_rex(ctx, f, schema, 0, extensions)?), - None => None - } - }) - } AggregateFunctionDefinition::UDF(fun) => { let sorts = if let Some(order_by) = order_by { order_by.iter().map(|expr| to_substrait_sort_field(ctx, expr, schema, extensions)).collect::<Result<Vec<_>>>()? diff --git a/docs/source/user-guide/dataframe.md b/docs/source/user-guide/dataframe.md index c3d0b6c2d688..96be1bb9e256 100644 --- a/docs/source/user-guide/dataframe.md +++ b/docs/source/user-guide/dataframe.md @@ -53,6 +53,7 @@ Here is a minimal example showing the execution of a query using the DataFrame API. ```rust use datafusion::prelude::*; use datafusion::error::Result; +use datafusion::functions_aggregate::expr_fn::min; #[tokio::main] async fn main() -> Result<()> { diff --git a/docs/source/user-guide/example-usage.md b/docs/source/user-guide/example-usage.md index 813dbb1bc02a..6108315f398a 100644 --- a/docs/source/user-guide/example-usage.md +++ b/docs/source/user-guide/example-usage.md @@ -60,6 +60,7 @@ library guide for more information on the SQL API. ```rust use datafusion::prelude::*; +use datafusion::functions_aggregate::expr_fn::min; #[tokio::main] async fn main() -> datafusion::error::Result<()> { @@ -148,6 +149,7 @@ async fn main() -> datafusion::error::Result<()> { ```rust use datafusion::prelude::*; +use datafusion::functions_aggregate::expr_fn::min; #[tokio::main] async fn main() -> datafusion::error::Result<()> { diff --git a/docs/source/user-guide/introduction.md b/docs/source/user-guide/introduction.md index 3a39419236d8..8f8983061eb6 100644 --- a/docs/source/user-guide/introduction.md +++ b/docs/source/user-guide/introduction.md @@ -114,6 +114,7 @@ Here are some active projects using DataFusion: - [qv](https://github.com/timvw/qv) Quickly view your data - [Restate](https://github.com/restatedev) Easily build resilient applications using distributed durable async/await - [ROAPI](https://github.com/roapi/roapi) +- [Sail](https://github.com/lakehq/sail) Unifying stream, batch, and AI workloads with Apache Spark compatibility - [Seafowl](https://github.com/splitgraph/seafowl) CDN-friendly analytical database - [Spice.ai](https://github.com/spiceai/spiceai) Unified SQL query interface & materialization engine - [Synnada](https://synnada.ai/) Streaming-first framework for data products
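The documentation updates above all add the same import because `min` and `max` now come from `functions_aggregate::expr_fn` instead of being `Expr` built-ins. A self-contained sketch of the updated DataFrame usage, assuming a CSV file `example.csv` with columns `a` and `b` (the file and column names are illustrative):

```rust
use datafusion::error::Result;
use datafusion::functions_aggregate::expr_fn::{max, min};
use datafusion::prelude::*;

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();
    // Illustrative file and column names; any CSV with columns `a` and `b` works.
    let df = ctx.read_csv("example.csv", CsvReadOptions::new()).await?;
    // Group by `a` and compute min/max of `b` with the UDAF-backed expr_fn helpers.
    let df = df.aggregate(vec![col("a")], vec![min(col("b")), max(col("b"))])?;
    df.show().await?;
    Ok(())
}
```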