diff --git a/datafusion/physical-plan/src/joins/hash_join/stream.rs b/datafusion/physical-plan/src/joins/hash_join/stream.rs index 88d62f1427a8..1fa594747e23 100644 --- a/datafusion/physical-plan/src/joins/hash_join/stream.rs +++ b/datafusion/physical-plan/src/joins/hash_join/stream.rs @@ -42,7 +42,7 @@ use crate::{ RecordBatchStream, SendableRecordBatchStream, }; -use arrow::array::{ArrayRef, UInt32Array, UInt64Array}; +use arrow::array::{Array, ArrayRef, UInt32Array, UInt64Array}; use arrow::datatypes::{Schema, SchemaRef}; use arrow::record_batch::RecordBatch; use datafusion_common::{ @@ -296,6 +296,35 @@ pub(super) fn lookup_join_hashmap( Ok((build_indices, probe_indices, next_offset)) } +/// Counts the number of distinct elements in the input array. +/// +/// The input array must be sorted (e.g., `[0, 1, 1, 2, 2, ...]`) and contain no null values. +#[inline] +fn count_distinct_sorted_indices(indices: &UInt32Array) -> usize { + if indices.is_empty() { + return 0; + } + + debug_assert!(indices.null_count() == 0); + + let values_buf = indices.values(); + let values = values_buf.as_ref(); + let mut iter = values.iter(); + let Some(&first) = iter.next() else { + return 0; + }; + + let mut count = 1usize; + let mut last = first; + for &value in iter { + if value != last { + last = value; + count += 1; + } + } + count +} + impl HashJoinStream { #[allow(clippy::too_many_arguments)] pub(super) fn new( @@ -517,21 +546,7 @@ impl HashJoinStream { state.offset, )?; - let mut last_seen: Option = None; - let distinct_right_indices_count = right_indices - .iter() - .filter(|ele| match ele { - Some(ele_val) => { - if last_seen.is_none() || last_seen.unwrap() != *ele_val { - last_seen = Some(*ele_val); - true - } else { - false - } - } - None => false, - }) - .count(); + let distinct_right_indices_count = count_distinct_sorted_indices(&right_indices); self.join_metrics .probe_hit_rate