Skip to content

Commit f198fc8

Browse files
ctsk and alamb authored
Fix quadratic runtime in min_max_bytes (#18044)
## Which issue does this PR close? - Closes #17897 ## What changes are included in this PR? This PR replaces the `locations` vector, which was used to reduce the number of allocations / resizes in the accumulator, with a HashMap. ## Are these changes tested? Not in particular. Additional unit tests and broader regression testing would be useful. A microbenchmark verifies that the runtime is no longer quadratic. ## Are there any user-facing changes? No. --------- Co-authored-by: Andrew Lamb <[email protected]>
1 parent f199b00 commit f198fc8

File tree

2 files changed

+26
-30
lines changed

2 files changed

+26
-30
lines changed

datafusion/common/src/lib.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,12 @@ pub use error::{
108108
// The HashMap and HashSet implementations that should be used as the uniform defaults
109109
pub type HashMap<K, V, S = DefaultHashBuilder> = hashbrown::HashMap<K, V, S>;
110110
pub type HashSet<T, S = DefaultHashBuilder> = hashbrown::HashSet<T, S>;
111+
pub mod hash_map {
112+
pub use hashbrown::hash_map::Entry;
113+
}
114+
pub mod hash_set {
115+
pub use hashbrown::hash_set::Entry;
116+
}
111117

112118
/// Downcast an Arrow Array to a concrete type, return an `DataFusionError::Internal` if the cast is
113119
/// not possible. In normal usage of DataFusion the downcast should always succeed.

datafusion/functions-aggregate/src/min_max/min_max_bytes.rs

Lines changed: 20 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@ use arrow::array::{
2020
LargeBinaryBuilder, LargeStringBuilder, StringBuilder, StringViewBuilder,
2121
};
2222
use arrow::datatypes::DataType;
23-
use datafusion_common::{internal_err, Result};
23+
use datafusion_common::hash_map::Entry;
24+
use datafusion_common::{internal_err, HashMap, Result};
2425
use datafusion_expr::{EmitTo, GroupsAccumulator};
2526
use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::apply_filter_as_nulls;
2627
use std::mem::size_of;
@@ -391,14 +392,6 @@ struct MinMaxBytesState {
391392
total_data_bytes: usize,
392393
}
393394

394-
#[derive(Debug, Clone, Copy)]
395-
enum MinMaxLocation<'a> {
396-
/// the min/max value is stored in the existing `min_max` array
397-
ExistingMinMax,
398-
/// the min/max value is stored in the input array at the given index
399-
Input(&'a [u8]),
400-
}
401-
402395
/// Implement the MinMaxBytesAccumulator with a comparison function
403396
/// for comparing strings
404397
impl MinMaxBytesState {
@@ -450,7 +443,7 @@ impl MinMaxBytesState {
450443
// Minimize value copies by calculating the new min/maxes for each group
451444
// in this batch (either the existing min/max or the new input value)
452445
// and updating the owned values in `self.min_maxes` at most once
453-
let mut locations = vec![MinMaxLocation::ExistingMinMax; total_num_groups];
446+
let mut locations = HashMap::<usize, &[u8]>::with_capacity(group_indices.len());
454447

455448
// Figure out the new min value for each group
456449
for (new_val, group_index) in iter.into_iter().zip(group_indices.iter()) {
@@ -459,32 +452,29 @@ impl MinMaxBytesState {
459452
continue; // skip nulls
460453
};
461454

462-
let existing_val = match locations[group_index] {
463-
// previous input value was the min/max, so compare it
464-
MinMaxLocation::Input(existing_val) => existing_val,
465-
MinMaxLocation::ExistingMinMax => {
466-
let Some(existing_val) = self.min_max[group_index].as_ref() else {
467-
// no existing min/max, so this is the new min/max
468-
locations[group_index] = MinMaxLocation::Input(new_val);
469-
continue;
470-
};
471-
existing_val.as_ref()
455+
match locations.entry(group_index) {
456+
Entry::Occupied(mut occupied_entry) => {
457+
if cmp(new_val, occupied_entry.get()) {
458+
occupied_entry.insert(new_val);
459+
}
460+
}
461+
Entry::Vacant(vacant_entry) => {
462+
if let Some(old_val) = self.min_max[group_index].as_ref() {
463+
if cmp(new_val, old_val) {
464+
vacant_entry.insert(new_val);
465+
}
466+
} else {
467+
vacant_entry.insert(new_val);
468+
}
472469
}
473470
};
474-
475-
// Compare the new value to the existing value, replacing if necessary
476-
if cmp(new_val, existing_val) {
477-
locations[group_index] = MinMaxLocation::Input(new_val);
478-
}
479471
}
480472

481473
// Update self.min_max with any new min/max values we found in the input
482-
for (group_index, location) in locations.iter().enumerate() {
483-
match location {
484-
MinMaxLocation::ExistingMinMax => {}
485-
MinMaxLocation::Input(new_val) => self.set_value(group_index, new_val),
486-
}
474+
for (group_index, location) in locations.iter() {
475+
self.set_value(*group_index, location);
487476
}
477+
488478
Ok(())
489479
}
490480

0 commit comments

Comments
 (0)