From a31868a5ad5885bb303cd323262f1ef749e0e04b Mon Sep 17 00:00:00 2001 From: Orson Peters Date: Thu, 14 Nov 2024 15:27:42 +0100 Subject: [PATCH 01/21] wip --- .../src/nodes/joins/equi_join.rs | 133 ++++++++++++++++++ crates/polars-stream/src/nodes/joins/mod.rs | 1 + 2 files changed, 134 insertions(+) create mode 100644 crates/polars-stream/src/nodes/joins/equi_join.rs diff --git a/crates/polars-stream/src/nodes/joins/equi_join.rs b/crates/polars-stream/src/nodes/joins/equi_join.rs new file mode 100644 index 000000000000..3cba624ec9d6 --- /dev/null +++ b/crates/polars-stream/src/nodes/joins/equi_join.rs @@ -0,0 +1,133 @@ +use std::sync::Arc; + +use polars_core::schema::Schema; +use polars_ops::frame::JoinArgs; + +use crate::nodes::compute_node_prelude::*; +use crate::nodes::in_memory_sink::InMemorySinkNode; +use crate::nodes::in_memory_source::InMemorySourceNode; + +struct BuildPartition { + hash_keys: Vec, + frames: Vec, +} + +struct BuildState { + partitions: Vec, +} + +struct ProbeState { + +} + +enum EquiJoinState { + Build(BuildState), + Probe(ProbeState), + Done, +} + +pub struct EquiJoinNode { + state: EquiJoinState, + num_pipelines: usize, + left_is_build: bool, + coalesce: bool, + emit_unmatched_build: bool, + emit_unmatched_probe: bool, + join_nulls: bool, +} + +impl EquiJoinNode { + pub fn new( + left_input_schema: Arc, + right_input_schema: Arc, + args: JoinArgs, + ) -> Self { + Self { + state: EquiJoinState::Sink { + left: InMemorySinkNode::new(left_input_schema), + right: InMemorySinkNode::new(right_input_schema), + }, + num_pipelines: 0, + } + } +} + +impl ComputeNode for EquiJoinNode { + fn name(&self) -> &str { + "in_memory_join" + } + + fn initialize(&mut self, num_pipelines: usize) { + self.num_pipelines = num_pipelines; + } + + fn update_state(&mut self, recv: &mut [PortState], send: &mut [PortState]) -> PolarsResult<()> { + assert!(recv.len() == 2 && send.len() == 1); + + // If the output doesn't want any more data, transition to being done. + if send[0] == PortState::Done && !matches!(self.state, EquiJoinState::Done) { + self.state = EquiJoinState::Done; + } + + // If the input is done, transition to being a source. + if let EquiJoinState::Sink { left, right } = &mut self.state { + if recv[0] == PortState::Done && recv[1] == PortState::Done { + let left_df = left.get_output()?.unwrap(); + let right_df = right.get_output()?.unwrap(); + let mut source_node = + InMemorySourceNode::new(Arc::new((self.joiner)(left_df, right_df)?)); + source_node.initialize(self.num_pipelines); + self.state = EquiJoinState::Source(source_node); + } + } + + match &mut self.state { + EquiJoinState::Sink { left, right, .. } => { + left.update_state(&mut recv[0..1], &mut [])?; + right.update_state(&mut recv[1..2], &mut [])?; + send[0] = PortState::Blocked; + }, + EquiJoinState::Source(source_node) => { + recv[0] = PortState::Done; + recv[1] = PortState::Done; + source_node.update_state(&mut [], send)?; + }, + EquiJoinState::Done => { + recv[0] = PortState::Done; + recv[1] = PortState::Done; + send[0] = PortState::Done; + }, + } + Ok(()) + } + + fn is_memory_intensive_pipeline_blocker(&self) -> bool { + matches!(self.state, EquiJoinState::Sink { .. }) + } + + fn spawn<'env, 's>( + &'env mut self, + scope: &'s TaskScope<'s, 'env>, + recv_ports: &mut [Option>], + send_ports: &mut [Option>], + state: &'s ExecutionState, + join_handles: &mut Vec>>, + ) { + assert!(recv_ports.len() == 2); + assert!(send_ports.len() == 1); + match &mut self.state { + EquiJoinState::Sink { left, right, .. } => { + if recv_ports[0].is_some() { + left.spawn(scope, &mut recv_ports[0..1], &mut [], state, join_handles); + } + if recv_ports[1].is_some() { + right.spawn(scope, &mut recv_ports[1..2], &mut [], state, join_handles); + } + }, + EquiJoinState::Source(source) => { + source.spawn(scope, &mut [], send_ports, state, join_handles) + }, + EquiJoinState::Done => unreachable!(), + } + } +} diff --git a/crates/polars-stream/src/nodes/joins/mod.rs b/crates/polars-stream/src/nodes/joins/mod.rs index fa2e12699f5e..26f3282b4a76 100644 --- a/crates/polars-stream/src/nodes/joins/mod.rs +++ b/crates/polars-stream/src/nodes/joins/mod.rs @@ -1 +1,2 @@ pub mod in_memory; +pub mod equi_join; \ No newline at end of file From b8c045d7efdd6ca4494f9be29ef0596b163edc2e Mon Sep 17 00:00:00 2001 From: Orson Peters Date: Thu, 14 Nov 2024 15:55:43 +0100 Subject: [PATCH 02/21] wip --- crates/polars-expr/src/hash_keys.rs | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/crates/polars-expr/src/hash_keys.rs b/crates/polars-expr/src/hash_keys.rs index d5f85ef49db7..bde928a1e3df 100644 --- a/crates/polars-expr/src/hash_keys.rs +++ b/crates/polars-expr/src/hash_keys.rs @@ -1,5 +1,6 @@ use arrow::array::BinaryArray; use arrow::compute::take::binary::take_unchecked; +use arrow::compute::utils::combine_validities_and_many; use polars_core::frame::DataFrame; use polars_core::prelude::row_encode::_get_rows_encoded_unordered; use polars_core::prelude::PlRandomState; @@ -18,15 +19,20 @@ pub enum HashKeys { } impl HashKeys { - pub fn from_df(df: &DataFrame, random_state: PlRandomState, force_row_encoding: bool) -> Self { + pub fn from_df(df: &DataFrame, random_state: PlRandomState, null_is_valid: bool, force_row_encoding: bool) -> Self { if df.width() > 1 || force_row_encoding { let keys = df .get_columns() .iter() .map(|c| c.as_materialized_series().clone()) .collect_vec(); - let keys_encoded = _get_rows_encoded_unordered(&keys[..]).unwrap().into_array(); - assert!(keys_encoded.len() == df.height()); + let mut keys_encoded = _get_rows_encoded_unordered(&keys[..]).unwrap().into_array(); + + if !null_is_valid { + let validities = keys.iter().map(|c| c.rechunk_validity()).collect_vec(); + let combined = combine_validities_and_many(&validities); + keys_encoded.set_validity(combined); + } // TODO: use vechash? Not supported yet for lists. // let mut hashes = Vec::with_capacity(df.height()); From bafd230a565b71b5ba02866530fb681f3b83883d Mon Sep 17 00:00:00 2001 From: Orson Peters Date: Fri, 15 Nov 2024 14:51:49 +0100 Subject: [PATCH 03/21] wip --- .../src/nodes/joins/equi_join.rs | 92 ++++++++++++++++--- 1 file changed, 80 insertions(+), 12 deletions(-) diff --git a/crates/polars-stream/src/nodes/joins/equi_join.rs b/crates/polars-stream/src/nodes/joins/equi_join.rs index 3cba624ec9d6..eca6dd5694df 100644 --- a/crates/polars-stream/src/nodes/joins/equi_join.rs +++ b/crates/polars-stream/src/nodes/joins/equi_join.rs @@ -1,11 +1,64 @@ use std::sync::Arc; +use polars_core::prelude::PlHashSet; use polars_core::schema::Schema; -use polars_ops::frame::JoinArgs; +use polars_expr::hash_keys::HashKeys; +use polars_ops::frame::{JoinArgs, JoinType}; +use polars_utils::format_pl_smallstr; +use polars_utils::pl_str::PlSmallStr; use crate::nodes::compute_node_prelude::*; -use crate::nodes::in_memory_sink::InMemorySinkNode; -use crate::nodes::in_memory_source::InMemorySourceNode; + +/// A payload selector contains for each column whether that column should be +/// included in the payload, and if yes with what name. +fn compute_payload_selector( + this: &Schema, + other: &Schema, + is_left: bool, + args: &JoinArgs, +) -> Vec> { + let should_coalesce = args.should_coalesce(); + let other_col_names: PlHashSet = other.iter_names_cloned().collect(); + + this.iter_names() + .map(|c| { + if !other_col_names.contains(c) { + return Some(c.clone()); + } + + if is_left { + if should_coalesce && args.how == JoinType::Right { + None + } else { + Some(c.clone()) + } + } else { + if should_coalesce { + if args.how == JoinType::Right { + Some(c.clone()) + } else { + None + } + } else { + Some(format_pl_smallstr!("{}{}", c, args.suffix())) + } + } + }) + .collect() +} + +fn select_payload(df: DataFrame, selector: &[Option]) -> DataFrame { + // Maintain height of zero-width dataframes. + if df.width() == 0 { + return df; + } + + df.take_columns() + .into_iter() + .zip(selector) + .filter_map(|(c, name)| Some(c.with_name(name.clone()?))) + .collect() +} struct BuildPartition { hash_keys: Vec, @@ -16,9 +69,7 @@ struct BuildState { partitions: Vec, } -struct ProbeState { - -} +struct ProbeState {} enum EquiJoinState { Build(BuildState), @@ -30,10 +81,11 @@ pub struct EquiJoinNode { state: EquiJoinState, num_pipelines: usize, left_is_build: bool, - coalesce: bool, emit_unmatched_build: bool, emit_unmatched_probe: bool, - join_nulls: bool, + left_payload_select: Vec>, + right_payload_select: Vec>, + args: JoinArgs, } impl EquiJoinNode { @@ -42,16 +94,31 @@ impl EquiJoinNode { right_input_schema: Arc, args: JoinArgs, ) -> Self { + // TODO: use cardinality estimation to determine this. + let left_is_build = args.how != JoinType::Left; + + let emit_unmatched_left = args.how == JoinType::Left || args.how == JoinType::Full; + let emit_unmatched_right = args.how == JoinType::Right || args.how == JoinType::Full; + let emit_unmatched_build = if left_is_build { emit_unmatched_left } else { emit_unmatched_right }; + let emit_unmatched_probe = if left_is_build { emit_unmatched_right } else { emit_unmatched_left }; + let left_payload_select = compute_payload_selector(&left_input_schema, &right_input_schema, true, &args); + let right_payload_select = compute_payload_selector(&right_input_schema, &left_input_schema, false, &args); Self { - state: EquiJoinState::Sink { - left: InMemorySinkNode::new(left_input_schema), - right: InMemorySinkNode::new(right_input_schema), - }, + state: EquiJoinState::Build(BuildState { + partitions: Vec::new() + }), num_pipelines: 0, + left_is_build, + emit_unmatched_build, + emit_unmatched_probe, + left_payload_select, + right_payload_select, + args } } } +/* impl ComputeNode for EquiJoinNode { fn name(&self) -> &str { "in_memory_join" @@ -131,3 +198,4 @@ impl ComputeNode for EquiJoinNode { } } } +*/ From bac2dde7d3b25e2e4fc47e724f2f73ed32a2543e Mon Sep 17 00:00:00 2001 From: Orson Peters Date: Fri, 15 Nov 2024 18:52:04 +0100 Subject: [PATCH 04/21] wip --- crates/polars-core/src/frame/mod.rs | 4 +- .../polars-expr/src/chunked_idx_table/mod.rs | 30 ++ .../src/chunked_idx_table/row_encoded.rs | 383 ++++++++++++++++++ crates/polars-expr/src/groups/mod.rs | 2 +- crates/polars-expr/src/hash_keys.rs | 56 ++- crates/polars-expr/src/lib.rs | 1 + crates/polars-stream/src/nodes/group_by.rs | 2 +- .../src/nodes/joins/equi_join.rs | 182 ++++++--- 8 files changed, 594 insertions(+), 66 deletions(-) create mode 100644 crates/polars-expr/src/chunked_idx_table/mod.rs create mode 100644 crates/polars-expr/src/chunked_idx_table/row_encoded.rs diff --git a/crates/polars-core/src/frame/mod.rs b/crates/polars-core/src/frame/mod.rs index da29d8da070b..65c29e306792 100644 --- a/crates/polars-core/src/frame/mod.rs +++ b/crates/polars-core/src/frame/mod.rs @@ -1894,11 +1894,11 @@ impl DataFrame { unsafe { DataFrame::new_no_checks(idx.len(), cols) } } - pub(crate) unsafe fn take_slice_unchecked(&self, idx: &[IdxSize]) -> Self { + pub unsafe fn take_slice_unchecked(&self, idx: &[IdxSize]) -> Self { self.take_slice_unchecked_impl(idx, true) } - unsafe fn take_slice_unchecked_impl(&self, idx: &[IdxSize], allow_threads: bool) -> Self { + pub unsafe fn take_slice_unchecked_impl(&self, idx: &[IdxSize], allow_threads: bool) -> Self { let cols = if allow_threads { POOL.install(|| self._apply_columns_par(&|s| s.take_slice_unchecked(idx))) } else { diff --git a/crates/polars-expr/src/chunked_idx_table/mod.rs b/crates/polars-expr/src/chunked_idx_table/mod.rs new file mode 100644 index 000000000000..50a85a785420 --- /dev/null +++ b/crates/polars-expr/src/chunked_idx_table/mod.rs @@ -0,0 +1,30 @@ +use std::any::Any; + +use polars_core::prelude::*; +use polars_utils::index::ChunkId; +use polars_utils::IdxSize; + +use crate::hash_keys::HashKeys; + +mod row_encoded; + + +pub trait ChunkedIdxTable: Any + Send + Sync { + /// Creates a new empty ChunkedIdxTable similar to this one. + fn new_empty(&self) -> Box; + + /// Reserves space for the given number additional keys. + fn reserve(&mut self, additional: usize); + + /// Returns the number of unique keys in this ChunkedIdxTable. + fn num_keys(&self) -> IdxSize; + + /// Inserts the given key chunk into this ChunkedIdxTable. + fn insert_key_chunk(&mut self, keys: HashKeys); + + /// Probe the table, returning a ChunkId per key. + fn probe(&self, keys: &HashKeys, out: &mut Vec); + + /// Get the ChunkIds for each key which was never probed. + fn unprobed_keys(&self, out: &mut Vec); +} \ No newline at end of file diff --git a/crates/polars-expr/src/chunked_idx_table/row_encoded.rs b/crates/polars-expr/src/chunked_idx_table/row_encoded.rs new file mode 100644 index 000000000000..00bf386780e6 --- /dev/null +++ b/crates/polars-expr/src/chunked_idx_table/row_encoded.rs @@ -0,0 +1,383 @@ +use hashbrown::hash_table::{Entry as TEntry, HashTable, OccupiedEntry as TOccupiedEntry, VacantEntry as TVacantEntry}; +use polars_utils::IdxSize; + +const BASE_KEY_DATA_CAPACITY: usize = 1024; + +struct Key { + key_hash: u64, + key_buffer: u32, + key_offset: usize, + key_length: u32, +} + +impl Key { + unsafe fn get<'k>(&self, key_data: &'k [Vec]) -> &'k [u8] { + let buf = key_data.get_unchecked(self.key_buffer as usize); + buf.get_unchecked(self.key_offset..self.key_offset + self.key_length as usize) + } +} + + +/// An IndexMap where the keys are always [u8] slices which are pre-hashed. +pub struct BytesIndexMap { + table: HashTable, + tuples: Vec<(Key, V)>, + key_data: Vec>, + + // Internal random seed used to keep hash iteration order decorrelated. + // We simply store a random odd number and multiply the canonical hash by it. + seed: u64, +} + +impl Default for BytesIndexMap { + fn default() -> Self { + Self { + table: HashTable::new(), + tuples: Vec::new(), + key_data: vec![Vec::with_capacity(BASE_KEY_DATA_CAPACITY)], + seed: rand::random::() | 1, + } + } +} + +impl BytesIndexMap { + pub fn new() -> Self { + Self::default() + } + + pub fn reserve(&mut self, additional: usize) { + self.table.reserve(additional, |i| unsafe { + let tuple = self.tuples.get_unchecked(*i as usize); + tuple.0.key_hash.wrapping_mul(self.seed) + }); + self.tuples.reserve(additional); + } + + pub fn entry<'k>(&mut self, key: &'k [u8], hash: u64) -> Entry<'_, 'k, V> { + let entry = self.table.entry( + hash.wrapping_mul(self.seed), + |i| unsafe { + let t = self.tuples.get_unchecked(*i as usize); + hash == t.0.key_hash && key == t.0.get(&self.key_data) + }, + |i| unsafe { + let t = self.tuples.get_unchecked(*i as usize); + t.0.key_hash.wrapping_mul(self.seed) + }, + ); + + match entry { + TEntry::Occupied(o) => Entry::Occupied(OccupiedEntry { + entry: o, + tuples: &mut self.tuples, + key_data: &mut self.key_data, + }), + TEntry::Vacant(v) => Entry::Vacant(VacantEntry { + key, + hash, + entry: v, + tuples: &mut self.tuples, + key_data: &mut self.key_data, + }), + } + } +} + +pub enum Entry<'a, 'k, V> { + Occupied(OccupiedEntry<'a, V>), + Vacant(VacantEntry<'a, 'k, V>), +} + +pub struct OccupiedEntry<'a, V> { + entry: TOccupiedEntry<'a, IdxSize>, + tuples: &'a mut Vec<(Key, V)>, + key_data: &'a mut Vec>, +} + +impl<'a, V> OccupiedEntry<'a, V> { + pub fn index(&self) -> IdxSize { + *self.entry.get() + } +} + +pub struct VacantEntry<'a, 'k, V> { + key: &'k [u8], + hash: u64, + entry: TVacantEntry<'a, IdxSize>, + tuples: &'a mut Vec<(Key, V)>, + key_data: &'a mut Vec>, +} + +impl<'a, 'k, V> VacantEntry<'a, 'k, V> { + pub fn index(&self) -> IdxSize { + self.tuples.len() as IdxSize + } + + pub fn insert(self, value: V) -> &'a mut V { + unsafe { + let tuple_idx: IdxSize = self.tuples.len().try_into().unwrap(); + + let mut num_buffers = self.key_data.len() as u32; + let mut active_buf = self.key_data.last_mut().unwrap_unchecked(); + let key_len = self.key.len(); + if active_buf.len() + key_len > active_buf.capacity() { + let ideal_next_cap = BASE_KEY_DATA_CAPACITY.checked_shl(num_buffers).unwrap(); + let next_capacity = std::cmp::max(ideal_next_cap, key_len); + self.key_data.push(Vec::with_capacity(next_capacity)); + active_buf = self.key_data.last_mut().unwrap_unchecked(); + num_buffers += 1; + } + + let tuple_key = Key { + key_hash: self.hash, + key_buffer: num_buffers - 1, + key_offset: active_buf.len(), + key_length: self.key.len().try_into().unwrap(), + }; + self.tuples.push((tuple_key, value)); + active_buf.extend_from_slice(self.key); + self.entry.insert(tuple_idx); + &mut self.tuples.last_mut().unwrap_unchecked().1 + } + } +} + + +/* +use hashbrown::hash_table::{Entry, HashTable}; +use polars_row::EncodingField; +use polars_utils::cardinality_sketch::CardinalitySketch; +use polars_utils::vec::PushUnchecked; + +use super::*; +use crate::hash_keys::HashKeys; + +const BASE_KEY_DATA_CAPACITY: usize = 1024; + +struct Key { + key_hash: u64, + key_buffer: u32, + key_offset: usize, + key_length: u32, +} + +impl Key { + unsafe fn get<'k>(&self, key_data: &'k [Vec]) -> &'k [u8] { + let buf = key_data.get_unchecked(self.key_buffer as usize); + buf.get_unchecked(self.key_offset..self.key_offset + self.key_length as usize) + } +} + +#[derive(Default)] +pub struct RowEncodedHashtupleer { + key_schema: Arc, + table: HashTable, + group_keys: Vec, + key_data: Vec>, + + // Internal random seed used to keep hash iteration order decorrelated. + // We simply store a random odd number and multiply the canonical hash by it. + seed: u64, +} + +impl RowEncodedHashGrouper { + pub fn new(key_schema: Arc) -> Self { + Self { + key_schema, + seed: rand::random::() | 1, + key_data: vec![Vec::with_capacity(BASE_KEY_DATA_CAPACITY)], + ..Default::default() + } + } + + fn insert_key(&mut self, hash: u64, key: &[u8]) -> IdxSize { + let entry = self.table.entry( + hash.wrapping_mul(self.seed), + |g| unsafe { + let gk = self.group_keys.get_unchecked(*g as usize); + hash == gk.key_hash && key == gk.get(&self.key_data) + }, + |g| unsafe { + let gk = self.group_keys.get_unchecked(*g as usize); + gk.key_hash.wrapping_mul(self.seed) + }, + ); + + match entry { + Entry::Occupied(e) => *e.get(), + Entry::Vacant(e) => unsafe { + let mut num_buffers = self.key_data.len() as u32; + let mut active_buf = self.key_data.last_mut().unwrap_unchecked(); + let key_len = key.len(); + if active_buf.len() + key_len > active_buf.capacity() { + let ideal_next_cap = BASE_KEY_DATA_CAPACITY.checked_shl(num_buffers).unwrap(); + let next_capacity = std::cmp::max(ideal_next_cap, key_len); + self.key_data.push(Vec::with_capacity(next_capacity)); + active_buf = self.key_data.last_mut().unwrap_unchecked(); + num_buffers += 1; + } + + let group_idx: IdxSize = self.group_keys.len().try_into().unwrap(); + let group_key = Key { + key_hash: hash, + key_buffer: num_buffers - 1, + key_offset: active_buf.len(), + key_length: key.len().try_into().unwrap(), + }; + self.group_keys.push(group_key); + active_buf.extend_from_slice(key); + e.insert(group_idx); + group_idx + }, + } + } + + fn finalize_keys(&self, mut key_rows: Vec<&[u8]>) -> DataFrame { + let key_dtypes = self + .key_schema + .iter() + .map(|(_name, dt)| dt.to_physical().to_arrow(CompatLevel::newest())) + .collect::>(); + let fields = vec![EncodingField::new_unsorted(); key_dtypes.len()]; + let key_columns = + unsafe { polars_row::decode::decode_rows(&mut key_rows, &fields, &key_dtypes) }; + + let cols = self + .key_schema + .iter() + .zip(key_columns) + .map(|((name, dt), col)| { + let s = Series::try_from((name.clone(), col)).unwrap(); + unsafe { s.to_logical_repr_unchecked(dt) } + .unwrap() + .into_column() + }) + .collect(); + unsafe { DataFrame::new_no_checks_height_from_first(cols) } + } +} + +impl Grouper for RowEncodedHashGrouper { + fn new_empty(&self) -> Box { + Box::new(Self::new(self.key_schema.clone())) + } + + fn reserve(&mut self, additional: usize) { + self.table.reserve(additional, |g| unsafe { + let gk = self.group_keys.get_unchecked(*g as usize); + gk.key_hash.wrapping_mul(self.seed) + }); + self.group_keys.reserve(additional); + } + + fn num_groups(&self) -> IdxSize { + self.table.len() as IdxSize + } + + fn insert_keys(&mut self, keys: HashKeys, group_idxs: &mut Vec) { + let HashKeys::RowEncoded(keys) = keys else { + unreachable!() + }; + group_idxs.clear(); + group_idxs.reserve(keys.hashes.len()); + for (hash, key) in keys.hashes.iter().zip(keys.keys.values_iter()) { + unsafe { + group_idxs.push_unchecked(self.insert_key(*hash, key)); + } + } + } + + fn combine(&mut self, other: &dyn Grouper, group_idxs: &mut Vec) { + let other = other.as_any().downcast_ref::().unwrap(); + + // TODO: cardinality estimation. + self.table.reserve(other.group_keys.len(), |g| unsafe { + let gk = self.group_keys.get_unchecked(*g as usize); + gk.key_hash.wrapping_mul(self.seed) + }); + + unsafe { + group_idxs.clear(); + group_idxs.reserve(other.table.len()); + for group_key in &other.group_keys { + let new_idx = self.insert_key(group_key.key_hash, group_key.get(&other.key_data)); + group_idxs.push_unchecked(new_idx); + } + } + } + + unsafe fn gather_combine( + &mut self, + other: &dyn Grouper, + subset: &[IdxSize], + group_idxs: &mut Vec, + ) { + let other = other.as_any().downcast_ref::().unwrap(); + + // TODO: cardinality estimation. + self.table.reserve(subset.len(), |g| unsafe { + let gk = self.group_keys.get_unchecked(*g as usize); + gk.key_hash.wrapping_mul(self.seed) + }); + self.group_keys.reserve(subset.len()); + + unsafe { + group_idxs.clear(); + group_idxs.reserve(subset.len()); + for i in subset { + let group_key = other.group_keys.get_unchecked(*i as usize); + let new_idx = self.insert_key(group_key.key_hash, group_key.get(&other.key_data)); + group_idxs.push_unchecked(new_idx); + } + } + } + + fn get_keys_in_group_order(&self) -> DataFrame { + let mut key_rows: Vec<&[u8]> = Vec::with_capacity(self.table.len()); + unsafe { + for group_key in &self.group_keys { + key_rows.push_unchecked(group_key.get(&self.key_data)); + } + } + self.finalize_keys(key_rows) + } + + fn gen_partition_idxs( + &self, + partitioner: &HashPartitioner, + partition_idxs: &mut [Vec], + sketches: &mut [CardinalitySketch], + ) { + let num_partitions = partitioner.num_partitions(); + assert!(partition_idxs.len() == num_partitions); + assert!(sketches.len() == num_partitions); + + // Two-pass algorithm to prevent reallocations. + let mut partition_sizes = vec![0; num_partitions]; + unsafe { + for group_key in &self.group_keys { + let p_idx = partitioner.hash_to_partition(group_key.key_hash); + *partition_sizes.get_unchecked_mut(p_idx) += 1; + sketches.get_unchecked_mut(p_idx).insert(group_key.key_hash); + } + } + + for (partition, sz) in partition_idxs.iter_mut().zip(partition_sizes) { + partition.clear(); + partition.reserve(sz); + } + + unsafe { + for (i, group_key) in self.group_keys.iter().enumerate() { + let p_idx = partitioner.hash_to_partition(group_key.key_hash); + let p = partition_idxs.get_unchecked_mut(p_idx); + p.push_unchecked(i as IdxSize); + } + } + } + + fn as_any(&self) -> &dyn Any { + self + } +} +*/ \ No newline at end of file diff --git a/crates/polars-expr/src/groups/mod.rs b/crates/polars-expr/src/groups/mod.rs index 2938536a729e..42d259de7fd8 100644 --- a/crates/polars-expr/src/groups/mod.rs +++ b/crates/polars-expr/src/groups/mod.rs @@ -15,7 +15,7 @@ pub trait Grouper: Any + Send + Sync { /// Creates a new empty Grouper similar to this one. fn new_empty(&self) -> Box; - /// Reserves space for the given number additional of groups. + /// Reserves space for the given number additional groups. fn reserve(&mut self, additional: usize); /// Returns the number of groups in this Grouper. diff --git a/crates/polars-expr/src/hash_keys.rs b/crates/polars-expr/src/hash_keys.rs index bde928a1e3df..90a35d0cfae8 100644 --- a/crates/polars-expr/src/hash_keys.rs +++ b/crates/polars-expr/src/hash_keys.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use arrow::array::BinaryArray; use arrow::compute::take::binary::take_unchecked; use arrow::compute::utils::combine_validities_and_many; @@ -5,6 +7,7 @@ use polars_core::frame::DataFrame; use polars_core::prelude::row_encode::_get_rows_encoded_unordered; use polars_core::prelude::PlRandomState; use polars_core::series::Series; +use polars_utils::cardinality_sketch::CardinalitySketch; use polars_utils::hashing::HashPartitioner; use polars_utils::itertools::Itertools; use polars_utils::vec::PushUnchecked; @@ -43,7 +46,7 @@ impl HashKeys { .map(|k| random_state.hash_one(k)) .collect(); Self::RowEncoded(RowEncodedKeys { - hashes, + hashes: Arc::new(hashes), keys: keys_encoded, }) } else { @@ -56,14 +59,20 @@ impl HashKeys { } } + /// After this call partition_idxs[p] will contain the indices of hashes + /// that belong to partition p, and the cardinality sketches are updated + /// accordingly. + /// + /// If null_is_valid is false rows with nulls do not get assigned a partition. pub fn gen_partition_idxs( &self, partitioner: &HashPartitioner, partition_idxs: &mut [Vec], + sketches: &mut [CardinalitySketch], ) { match self { - Self::RowEncoded(s) => s.gen_partition_idxs(partitioner, partition_idxs), - Self::Single(s) => s.gen_partition_idxs(partitioner, partition_idxs), + Self::RowEncoded(s) => s.gen_partition_idxs(partitioner, partition_idxs, sketches), + Self::Single(s) => s.gen_partition_idxs(partitioner, partition_idxs, sketches), } } @@ -78,7 +87,7 @@ impl HashKeys { } pub struct RowEncodedKeys { - pub hashes: Vec, + pub hashes: Arc>, pub keys: BinaryArray, } @@ -87,13 +96,33 @@ impl RowEncodedKeys { &self, partitioner: &HashPartitioner, partition_idxs: &mut [Vec], + sketches: &mut [CardinalitySketch], ) { - assert!(partitioner.num_partitions() == partition_idxs.len()); - for (i, h) in self.hashes.iter().enumerate() { - unsafe { - // SAFETY: we assured the number of partitions matches. - let p = partitioner.hash_to_partition(*h); - partition_idxs.get_unchecked_mut(p).push(i as IdxSize); + assert!(partition_idxs.len() == partitioner.num_partitions()); + assert!(sketches.len() == partitioner.num_partitions()); + for p in partition_idxs.iter_mut() { + p.clear(); + } + + if let Some(validity) = self.keys.validity() { + for (i, (h, is_v)) in self.hashes.iter().zip(validity).enumerate() { + if is_v { + unsafe { + // SAFETY: we assured the number of partitions matches. + let p = partitioner.hash_to_partition(*h); + partition_idxs.get_unchecked_mut(p).push(i as IdxSize); + sketches.get_unchecked_mut(p).insert(*h); + } + } + } + } else { + for (i, h) in self.hashes.iter().enumerate() { + unsafe { + // SAFETY: we assured the number of partitions matches. + let p = partitioner.hash_to_partition(*h); + partition_idxs.get_unchecked_mut(p).push(i as IdxSize); + sketches.get_unchecked_mut(p).insert(*h); + } } } } @@ -107,7 +136,7 @@ impl RowEncodedKeys { } let idx_arr = arrow::ffi::mmap::slice(idxs); let keys = take_unchecked(&self.keys, &idx_arr); - Self { hashes, keys } + Self { hashes: Arc::new(hashes), keys } } } @@ -124,8 +153,13 @@ impl SingleKeys { &self, partitioner: &HashPartitioner, partition_idxs: &mut [Vec], + sketches: &mut [CardinalitySketch], ) { assert!(partitioner.num_partitions() == partition_idxs.len()); + for p in partition_idxs.iter_mut() { + p.clear(); + } + todo!() } diff --git a/crates/polars-expr/src/lib.rs b/crates/polars-expr/src/lib.rs index 0a7e7b20bfe2..2da894f9e297 100644 --- a/crates/polars-expr/src/lib.rs +++ b/crates/polars-expr/src/lib.rs @@ -5,5 +5,6 @@ pub mod planner; pub mod prelude; pub mod reduce; pub mod state; +pub mod chunked_idx_table; pub use crate::planner::{create_physical_expr, ExpressionConversionState}; diff --git a/crates/polars-stream/src/nodes/group_by.rs b/crates/polars-stream/src/nodes/group_by.rs index 5784c060384f..0151970ee766 100644 --- a/crates/polars-stream/src/nodes/group_by.rs +++ b/crates/polars-stream/src/nodes/group_by.rs @@ -77,7 +77,7 @@ impl GroupBySinkState { key_columns.push(s.into_column()); } let keys = DataFrame::new_with_broadcast_len(key_columns, df.height())?; - let hash_keys = HashKeys::from_df(&keys, random_state.clone(), true); + let hash_keys = HashKeys::from_df(&keys, random_state.clone(), true, true); local.grouper.insert_keys(hash_keys, &mut group_idxs); // Update reductions. diff --git a/crates/polars-stream/src/nodes/joins/equi_join.rs b/crates/polars-stream/src/nodes/joins/equi_join.rs index eca6dd5694df..652a2d966180 100644 --- a/crates/polars-stream/src/nodes/joins/equi_join.rs +++ b/crates/polars-stream/src/nodes/joins/equi_join.rs @@ -1,12 +1,15 @@ use std::sync::Arc; -use polars_core::prelude::PlHashSet; +use polars_core::prelude::{PlHashSet, PlRandomState}; use polars_core::schema::Schema; use polars_expr::hash_keys::HashKeys; use polars_ops::frame::{JoinArgs, JoinType}; +use polars_utils::cardinality_sketch::CardinalitySketch; use polars_utils::format_pl_smallstr; +use polars_utils::hashing::HashPartitioner; use polars_utils::pl_str::PlSmallStr; +use crate::async_primitives::connector::Receiver; use crate::nodes::compute_node_prelude::*; /// A payload selector contains for each column whether that column should be @@ -60,13 +63,57 @@ fn select_payload(df: DataFrame, selector: &[Option]) -> DataFrame { .collect() } +#[derive(Default)] struct BuildPartition { hash_keys: Vec, frames: Vec, + sketch: Option, } struct BuildState { - partitions: Vec, + partitions_per_worker: Vec>, +} + +impl BuildState { + async fn partition_and_sink( + mut recv: Receiver, + partitions: &mut Vec, + partitioner: HashPartitioner, + params: &EquiJoinParams, + ) -> PolarsResult<()> { + let mut partition_idxs = vec![Vec::new(); partitioner.num_partitions()]; + partitions.resize_with(partitioner.num_partitions(), BuildPartition::default); + + let mut sketches = vec![CardinalitySketch::default(); partitioner.num_partitions()]; + + while let Ok(morsel) = recv.recv().await { + let df = morsel.into_df(); + let hash_keys = HashKeys::from_df(&df, params.random_state.clone(), params.args.join_nulls, true); + let selector = if params.left_is_build { + ¶ms.left_payload_select + } else { + ¶ms.right_payload_select + }; + + // We must rechunk the payload for later chunked gathers. + let mut payload = select_payload(df, selector); + payload.rechunk_mut(); + + unsafe { + hash_keys.gen_partition_idxs(&partitioner, &mut partition_idxs, &mut sketches); + for (p, idxs_in_p) in partitions.iter_mut().zip(&partition_idxs) { + p.hash_keys.push(hash_keys.gather(idxs_in_p)); + p.frames.push(payload.take_slice_unchecked_impl(idxs_in_p, false)); + } + } + } + + for (p, sketch) in sketches.into_iter().enumerate() { + partitions[p].sketch = Some(sketch); + } + + Ok(()) + } } struct ProbeState {} @@ -77,15 +124,38 @@ enum EquiJoinState { Done, } -pub struct EquiJoinNode { - state: EquiJoinState, - num_pipelines: usize, +struct EquiJoinParams { left_is_build: bool, - emit_unmatched_build: bool, - emit_unmatched_probe: bool, left_payload_select: Vec>, right_payload_select: Vec>, args: JoinArgs, + random_state: PlRandomState, +} + +impl EquiJoinParams { + /// Should we emit unmatched rows from the build side? + fn emit_unmatched_build(&self) -> bool { + if self.left_is_build { + self.args.how == JoinType::Left || self.args.how == JoinType::Full + } else { + self.args.how == JoinType::Right || self.args.how == JoinType::Full + } + } + + /// Should we emit unmatched rows from the probe side? + fn emit_unmatched_probe(&self) -> bool { + if self.left_is_build { + self.args.how == JoinType::Right || self.args.how == JoinType::Full + } else { + self.args.how == JoinType::Left || self.args.how == JoinType::Full + } + } +} + +pub struct EquiJoinNode { + state: EquiJoinState, + params: EquiJoinParams, + num_pipelines: usize, } impl EquiJoinNode { @@ -97,31 +167,29 @@ impl EquiJoinNode { // TODO: use cardinality estimation to determine this. let left_is_build = args.how != JoinType::Left; - let emit_unmatched_left = args.how == JoinType::Left || args.how == JoinType::Full; - let emit_unmatched_right = args.how == JoinType::Right || args.how == JoinType::Full; - let emit_unmatched_build = if left_is_build { emit_unmatched_left } else { emit_unmatched_right }; - let emit_unmatched_probe = if left_is_build { emit_unmatched_right } else { emit_unmatched_left }; - let left_payload_select = compute_payload_selector(&left_input_schema, &right_input_schema, true, &args); - let right_payload_select = compute_payload_selector(&right_input_schema, &left_input_schema, false, &args); + let left_payload_select = + compute_payload_selector(&left_input_schema, &right_input_schema, true, &args); + let right_payload_select = + compute_payload_selector(&right_input_schema, &left_input_schema, false, &args); Self { state: EquiJoinState::Build(BuildState { - partitions: Vec::new() + partitions_per_worker: Vec::new(), }), num_pipelines: 0, - left_is_build, - emit_unmatched_build, - emit_unmatched_probe, - left_payload_select, - right_payload_select, - args + params: EquiJoinParams { + left_is_build, + left_payload_select, + right_payload_select, + args, + random_state: PlRandomState::new(), + } } } } -/* impl ComputeNode for EquiJoinNode { fn name(&self) -> &str { - "in_memory_join" + "equi_join" } fn initialize(&mut self, num_pipelines: usize) { @@ -131,33 +199,32 @@ impl ComputeNode for EquiJoinNode { fn update_state(&mut self, recv: &mut [PortState], send: &mut [PortState]) -> PolarsResult<()> { assert!(recv.len() == 2 && send.len() == 1); - // If the output doesn't want any more data, transition to being done. - if send[0] == PortState::Done && !matches!(self.state, EquiJoinState::Done) { + let build_idx = if self.params.left_is_build { 0 } else { 1 }; + let probe_idx = 1 - build_idx; + + // If the output doesn't want any more data, or the probe side is done, + // transition to being done. + if send[0] == PortState::Done || recv[probe_idx] == PortState::Done { self.state = EquiJoinState::Done; } - // If the input is done, transition to being a source. - if let EquiJoinState::Sink { left, right } = &mut self.state { - if recv[0] == PortState::Done && recv[1] == PortState::Done { - let left_df = left.get_output()?.unwrap(); - let right_df = right.get_output()?.unwrap(); - let mut source_node = - InMemorySourceNode::new(Arc::new((self.joiner)(left_df, right_df)?)); - source_node.initialize(self.num_pipelines); - self.state = EquiJoinState::Source(source_node); + // If we are building and the build input is done, transition to probing. + if let EquiJoinState::Build(build_state) = &mut self.state { + if recv[build_idx] == PortState::Done { + todo!() } } match &mut self.state { - EquiJoinState::Sink { left, right, .. } => { - left.update_state(&mut recv[0..1], &mut [])?; - right.update_state(&mut recv[1..2], &mut [])?; + EquiJoinState::Build(_) => { + recv[build_idx] = PortState::Ready; + recv[probe_idx] = PortState::Blocked; send[0] = PortState::Blocked; }, - EquiJoinState::Source(source_node) => { - recv[0] = PortState::Done; - recv[1] = PortState::Done; - source_node.update_state(&mut [], send)?; + EquiJoinState::Probe(_) => { + recv[build_idx] = PortState::Done; + recv[probe_idx] = PortState::Ready; + send[0] = PortState::Ready; }, EquiJoinState::Done => { recv[0] = PortState::Done; @@ -169,7 +236,7 @@ impl ComputeNode for EquiJoinNode { } fn is_memory_intensive_pipeline_blocker(&self) -> bool { - matches!(self.state, EquiJoinState::Sink { .. }) + matches!(self.state, EquiJoinState::Build { .. }) } fn spawn<'env, 's>( @@ -177,25 +244,38 @@ impl ComputeNode for EquiJoinNode { scope: &'s TaskScope<'s, 'env>, recv_ports: &mut [Option>], send_ports: &mut [Option>], - state: &'s ExecutionState, + _state: &'s ExecutionState, join_handles: &mut Vec>>, ) { assert!(recv_ports.len() == 2); assert!(send_ports.len() == 1); + + let build_idx = if self.params.left_is_build { 0 } else { 1 }; + let probe_idx = 1 - build_idx; + match &mut self.state { - EquiJoinState::Sink { left, right, .. } => { - if recv_ports[0].is_some() { - left.spawn(scope, &mut recv_ports[0..1], &mut [], state, join_handles); - } - if recv_ports[1].is_some() { - right.spawn(scope, &mut recv_ports[1..2], &mut [], state, join_handles); + EquiJoinState::Build(build_state) => { + assert!(send_ports[0].is_none()); + assert!(recv_ports[probe_idx].is_none()); + let receivers = recv_ports[build_idx].take().unwrap().parallel(); + + build_state + .partitions_per_worker + .resize_with(self.num_pipelines, || Vec::new()); + let partitioner = HashPartitioner::new(self.num_pipelines, 0); + for (worker_ps, recv) in build_state.partitions_per_worker.iter_mut().zip(receivers) + { + join_handles.push(scope.spawn_task( + TaskPriority::High, + BuildState::partition_and_sink(recv, worker_ps, partitioner.clone(), &self.params), + )); } }, - EquiJoinState::Source(source) => { - source.spawn(scope, &mut [], send_ports, state, join_handles) + EquiJoinState::Probe(probe_state) => { + assert!(recv_ports[build_idx].is_none()); + todo!() }, EquiJoinState::Done => unreachable!(), } } } -*/ From 5868491bd21db6b5db300e630fef9ecc7c64dc2a Mon Sep 17 00:00:00 2001 From: Orson Peters Date: Mon, 18 Nov 2024 16:04:45 +0100 Subject: [PATCH 05/21] wip --- .../polars-expr/src/chunked_idx_table/mod.rs | 15 +- .../src/chunked_idx_table/row_encoded.rs | 338 ++++-------------- crates/polars-expr/src/groups/row_encoded.rs | 6 +- crates/polars-expr/src/hash_keys.rs | 17 +- .../src/nodes/joins/equi_join.rs | 108 +++++- crates/polars-utils/src/index.rs | 8 + 6 files changed, 210 insertions(+), 282 deletions(-) diff --git a/crates/polars-expr/src/chunked_idx_table/mod.rs b/crates/polars-expr/src/chunked_idx_table/mod.rs index 50a85a785420..2386a7d45f47 100644 --- a/crates/polars-expr/src/chunked_idx_table/mod.rs +++ b/crates/polars-expr/src/chunked_idx_table/mod.rs @@ -22,9 +22,16 @@ pub trait ChunkedIdxTable: Any + Send + Sync { /// Inserts the given key chunk into this ChunkedIdxTable. fn insert_key_chunk(&mut self, keys: HashKeys); - /// Probe the table, returning a ChunkId per key. - fn probe(&self, keys: &HashKeys, out: &mut Vec); + /// Probe the table, updating table_match and probe_match with + /// (ChunkId, IdxSize) pairs for each match. Will stop processing new keys + /// once limit matches have been generated, returning the number of keys processed. + fn probe(&self, hash_keys: &HashKeys, table_match: &mut Vec, probe_match: &mut Vec, mark_matches: bool, limit: usize) -> usize; - /// Get the ChunkIds for each key which was never probed. - fn unprobed_keys(&self, out: &mut Vec); + /// Get the ChunkIds for each key which was never marked during probing. + fn unmarked_keys(&self, out: &mut Vec); +} + +pub fn new_chunked_idx_table(key_schema: Arc) -> Box { + // Box::new(row_encoded::BytesIndexMap::new(key_schema)) + todo!() } \ No newline at end of file diff --git a/crates/polars-expr/src/chunked_idx_table/row_encoded.rs b/crates/polars-expr/src/chunked_idx_table/row_encoded.rs index 00bf386780e6..4f39322e18a2 100644 --- a/crates/polars-expr/src/chunked_idx_table/row_encoded.rs +++ b/crates/polars-expr/src/chunked_idx_table/row_encoded.rs @@ -1,277 +1,94 @@ -use hashbrown::hash_table::{Entry as TEntry, HashTable, OccupiedEntry as TOccupiedEntry, VacantEntry as TVacantEntry}; -use polars_utils::IdxSize; - -const BASE_KEY_DATA_CAPACITY: usize = 1024; - -struct Key { - key_hash: u64, - key_buffer: u32, - key_offset: usize, - key_length: u32, -} - -impl Key { - unsafe fn get<'k>(&self, key_data: &'k [Vec]) -> &'k [u8] { - let buf = key_data.get_unchecked(self.key_buffer as usize); - buf.get_unchecked(self.key_offset..self.key_offset + self.key_length as usize) - } -} +use std::sync::atomic::AtomicU64; +use arrow::bitmap::MutableBitmap; +use polars_row::EncodingField; +use polars_utils::cardinality_sketch::CardinalitySketch; +use polars_utils::idx_map::bytes_idx_map::{BytesIndexMap, Entry}; +use polars_utils::idx_vec::UnitVec; +use polars_utils::itertools::Itertools; +use polars_utils::unitvec; +use polars_utils::vec::PushUnchecked; -/// An IndexMap where the keys are always [u8] slices which are pre-hashed. -pub struct BytesIndexMap { - table: HashTable, - tuples: Vec<(Key, V)>, - key_data: Vec>, +use super::*; +use crate::hash_keys::HashKeys; - // Internal random seed used to keep hash iteration order decorrelated. - // We simply store a random odd number and multiply the canonical hash by it. - seed: u64, +#[derive(Default)] +pub struct RowEncodedChunkedIdxTable { + // These AtomicU64s actually are ChunkIds, but we use the top bit of the + // first chunk in each to mark keys during probing. + idx_map: BytesIndexMap>, + chunk_ctr: IdxSize, } -impl Default for BytesIndexMap { - fn default() -> Self { +impl RowEncodedChunkedIdxTable { + pub fn new() -> Self { Self { - table: HashTable::new(), - tuples: Vec::new(), - key_data: vec![Vec::with_capacity(BASE_KEY_DATA_CAPACITY)], - seed: rand::random::() | 1, + idx_map: BytesIndexMap::new(), + chunk_ctr: 0, } } } -impl BytesIndexMap { - pub fn new() -> Self { - Self::default() - } - - pub fn reserve(&mut self, additional: usize) { - self.table.reserve(additional, |i| unsafe { - let tuple = self.tuples.get_unchecked(*i as usize); - tuple.0.key_hash.wrapping_mul(self.seed) - }); - self.tuples.reserve(additional); +impl ChunkedIdxTable for RowEncodedChunkedIdxTable { + fn new_empty(&self) -> Box { + Box::new(Self::new()) } - pub fn entry<'k>(&mut self, key: &'k [u8], hash: u64) -> Entry<'_, 'k, V> { - let entry = self.table.entry( - hash.wrapping_mul(self.seed), - |i| unsafe { - let t = self.tuples.get_unchecked(*i as usize); - hash == t.0.key_hash && key == t.0.get(&self.key_data) - }, - |i| unsafe { - let t = self.tuples.get_unchecked(*i as usize); - t.0.key_hash.wrapping_mul(self.seed) - }, - ); - - match entry { - TEntry::Occupied(o) => Entry::Occupied(OccupiedEntry { - entry: o, - tuples: &mut self.tuples, - key_data: &mut self.key_data, - }), - TEntry::Vacant(v) => Entry::Vacant(VacantEntry { - key, - hash, - entry: v, - tuples: &mut self.tuples, - key_data: &mut self.key_data, - }), - } + fn reserve(&mut self, additional: usize) { + self.idx_map.reserve(additional); } -} - -pub enum Entry<'a, 'k, V> { - Occupied(OccupiedEntry<'a, V>), - Vacant(VacantEntry<'a, 'k, V>), -} -pub struct OccupiedEntry<'a, V> { - entry: TOccupiedEntry<'a, IdxSize>, - tuples: &'a mut Vec<(Key, V)>, - key_data: &'a mut Vec>, -} - -impl<'a, V> OccupiedEntry<'a, V> { - pub fn index(&self) -> IdxSize { - *self.entry.get() + fn num_keys(&self) -> IdxSize { + self.idx_map.len() } -} -pub struct VacantEntry<'a, 'k, V> { - key: &'k [u8], - hash: u64, - entry: TVacantEntry<'a, IdxSize>, - tuples: &'a mut Vec<(Key, V)>, - key_data: &'a mut Vec>, -} + fn insert_key_chunk(&mut self, hash_keys: HashKeys) { + let HashKeys::RowEncoded(keys) = hash_keys else { + unreachable!() + }; + if keys.keys.len() >= 1 << 31 { + panic!("overly large chunk in RowEncodedChunkedIdxTable"); + } -impl<'a, 'k, V> VacantEntry<'a, 'k, V> { - pub fn index(&self) -> IdxSize { - self.tuples.len() as IdxSize - } - - pub fn insert(self, value: V) -> &'a mut V { - unsafe { - let tuple_idx: IdxSize = self.tuples.len().try_into().unwrap(); + // for in keys.hashes + // group_idxs.clear(); + // group_idxs.reserve(keys.hashes.len()); + for (i, (hash, key)) in keys.hashes.values_iter().zip(keys.keys.iter()).enumerate_idx() { + if let Some(key) = key { + let chunk_id = AtomicU64::new(ChunkId::<_>::store(self.chunk_ctr, i).into_inner()); + match self.idx_map.entry(*hash, key) { + Entry::Occupied(o) => { o.into_mut().push(chunk_id); }, + Entry::Vacant(v) => { v.insert(unitvec![chunk_id]); }, + } - let mut num_buffers = self.key_data.len() as u32; - let mut active_buf = self.key_data.last_mut().unwrap_unchecked(); - let key_len = self.key.len(); - if active_buf.len() + key_len > active_buf.capacity() { - let ideal_next_cap = BASE_KEY_DATA_CAPACITY.checked_shl(num_buffers).unwrap(); - let next_capacity = std::cmp::max(ideal_next_cap, key_len); - self.key_data.push(Vec::with_capacity(next_capacity)); - active_buf = self.key_data.last_mut().unwrap_unchecked(); - num_buffers += 1; } - - let tuple_key = Key { - key_hash: self.hash, - key_buffer: num_buffers - 1, - key_offset: active_buf.len(), - key_length: self.key.len().try_into().unwrap(), - }; - self.tuples.push((tuple_key, value)); - active_buf.extend_from_slice(self.key); - self.entry.insert(tuple_idx); - &mut self.tuples.last_mut().unwrap_unchecked().1 } + + self.chunk_ctr = self.chunk_ctr.checked_add(1).unwrap(); } -} - - -/* -use hashbrown::hash_table::{Entry, HashTable}; -use polars_row::EncodingField; -use polars_utils::cardinality_sketch::CardinalitySketch; -use polars_utils::vec::PushUnchecked; - -use super::*; -use crate::hash_keys::HashKeys; - -const BASE_KEY_DATA_CAPACITY: usize = 1024; - -struct Key { - key_hash: u64, - key_buffer: u32, - key_offset: usize, - key_length: u32, -} - -impl Key { - unsafe fn get<'k>(&self, key_data: &'k [Vec]) -> &'k [u8] { - let buf = key_data.get_unchecked(self.key_buffer as usize); - buf.get_unchecked(self.key_offset..self.key_offset + self.key_length as usize) - } -} - -#[derive(Default)] -pub struct RowEncodedHashtupleer { - key_schema: Arc, - table: HashTable, - group_keys: Vec, - key_data: Vec>, - - // Internal random seed used to keep hash iteration order decorrelated. - // We simply store a random odd number and multiply the canonical hash by it. - seed: u64, -} - -impl RowEncodedHashGrouper { - pub fn new(key_schema: Arc) -> Self { - Self { - key_schema, - seed: rand::random::() | 1, - key_data: vec![Vec::with_capacity(BASE_KEY_DATA_CAPACITY)], - ..Default::default() - } + + fn probe(&self, keys: &HashKeys, table_match: &mut Vec, probe_match: &mut Vec, mark_matches: bool, limit: usize) -> usize { + todo!() } - - fn insert_key(&mut self, hash: u64, key: &[u8]) -> IdxSize { - let entry = self.table.entry( - hash.wrapping_mul(self.seed), - |g| unsafe { - let gk = self.group_keys.get_unchecked(*g as usize); - hash == gk.key_hash && key == gk.get(&self.key_data) - }, - |g| unsafe { - let gk = self.group_keys.get_unchecked(*g as usize); - gk.key_hash.wrapping_mul(self.seed) - }, - ); - - match entry { - Entry::Occupied(e) => *e.get(), - Entry::Vacant(e) => unsafe { - let mut num_buffers = self.key_data.len() as u32; - let mut active_buf = self.key_data.last_mut().unwrap_unchecked(); - let key_len = key.len(); - if active_buf.len() + key_len > active_buf.capacity() { - let ideal_next_cap = BASE_KEY_DATA_CAPACITY.checked_shl(num_buffers).unwrap(); - let next_capacity = std::cmp::max(ideal_next_cap, key_len); - self.key_data.push(Vec::with_capacity(next_capacity)); - active_buf = self.key_data.last_mut().unwrap_unchecked(); - num_buffers += 1; - } - - let group_idx: IdxSize = self.group_keys.len().try_into().unwrap(); - let group_key = Key { - key_hash: hash, - key_buffer: num_buffers - 1, - key_offset: active_buf.len(), - key_length: key.len().try_into().unwrap(), - }; - self.group_keys.push(group_key); - active_buf.extend_from_slice(key); - e.insert(group_idx); - group_idx - }, - } + + fn unmarked_keys(&self, out: &mut Vec) { + todo!() } - fn finalize_keys(&self, mut key_rows: Vec<&[u8]>) -> DataFrame { - let key_dtypes = self - .key_schema - .iter() - .map(|(_name, dt)| dt.to_physical().to_arrow(CompatLevel::newest())) - .collect::>(); - let fields = vec![EncodingField::new_unsorted(); key_dtypes.len()]; - let key_columns = - unsafe { polars_row::decode::decode_rows(&mut key_rows, &fields, &key_dtypes) }; - - let cols = self - .key_schema - .iter() - .zip(key_columns) - .map(|((name, dt), col)| { - let s = Series::try_from((name.clone(), col)).unwrap(); - unsafe { s.to_logical_repr_unchecked(dt) } - .unwrap() - .into_column() - }) - .collect(); - unsafe { DataFrame::new_no_checks_height_from_first(cols) } - } } -impl Grouper for RowEncodedHashGrouper { +/* +impl Grouper for RowEncodedChunkedIdxTable { fn new_empty(&self) -> Box { Box::new(Self::new(self.key_schema.clone())) } fn reserve(&mut self, additional: usize) { - self.table.reserve(additional, |g| unsafe { - let gk = self.group_keys.get_unchecked(*g as usize); - gk.key_hash.wrapping_mul(self.seed) - }); - self.group_keys.reserve(additional); + self.idx_map.reserve(additional); } fn num_groups(&self) -> IdxSize { - self.table.len() as IdxSize + self.idx_map.len() } fn insert_keys(&mut self, keys: HashKeys, group_idxs: &mut Vec) { @@ -291,17 +108,13 @@ impl Grouper for RowEncodedHashGrouper { let other = other.as_any().downcast_ref::().unwrap(); // TODO: cardinality estimation. - self.table.reserve(other.group_keys.len(), |g| unsafe { - let gk = self.group_keys.get_unchecked(*g as usize); - gk.key_hash.wrapping_mul(self.seed) - }); + self.idx_map.reserve(other.idx_map.len() as usize); unsafe { group_idxs.clear(); - group_idxs.reserve(other.table.len()); - for group_key in &other.group_keys { - let new_idx = self.insert_key(group_key.key_hash, group_key.get(&other.key_data)); - group_idxs.push_unchecked(new_idx); + group_idxs.reserve(other.idx_map.len() as usize); + for (hash, key) in other.idx_map.iter_hash_keys() { + group_idxs.push_unchecked(self.insert_key(hash, key)); } } } @@ -315,31 +128,26 @@ impl Grouper for RowEncodedHashGrouper { let other = other.as_any().downcast_ref::().unwrap(); // TODO: cardinality estimation. - self.table.reserve(subset.len(), |g| unsafe { - let gk = self.group_keys.get_unchecked(*g as usize); - gk.key_hash.wrapping_mul(self.seed) - }); - self.group_keys.reserve(subset.len()); + self.idx_map.reserve(subset.len()); unsafe { group_idxs.clear(); group_idxs.reserve(subset.len()); for i in subset { - let group_key = other.group_keys.get_unchecked(*i as usize); - let new_idx = self.insert_key(group_key.key_hash, group_key.get(&other.key_data)); - group_idxs.push_unchecked(new_idx); + let (hash, key, ()) = other.idx_map.get_index_unchecked(*i); + group_idxs.push_unchecked(self.insert_key(hash, key)); } } } fn get_keys_in_group_order(&self) -> DataFrame { - let mut key_rows: Vec<&[u8]> = Vec::with_capacity(self.table.len()); unsafe { - for group_key in &self.group_keys { - key_rows.push_unchecked(group_key.get(&self.key_data)); + let mut key_rows: Vec<&[u8]> = Vec::with_capacity(self.idx_map.len() as usize); + for (_, key) in self.idx_map.iter_hash_keys() { + key_rows.push_unchecked(key); } + self.finalize_keys(key_rows) } - self.finalize_keys(key_rows) } fn gen_partition_idxs( @@ -355,10 +163,10 @@ impl Grouper for RowEncodedHashGrouper { // Two-pass algorithm to prevent reallocations. let mut partition_sizes = vec![0; num_partitions]; unsafe { - for group_key in &self.group_keys { - let p_idx = partitioner.hash_to_partition(group_key.key_hash); + for (hash, _key) in self.idx_map.iter_hash_keys() { + let p_idx = partitioner.hash_to_partition(hash); *partition_sizes.get_unchecked_mut(p_idx) += 1; - sketches.get_unchecked_mut(p_idx).insert(group_key.key_hash); + sketches.get_unchecked_mut(p_idx).insert(hash); } } @@ -368,8 +176,8 @@ impl Grouper for RowEncodedHashGrouper { } unsafe { - for (i, group_key) in self.group_keys.iter().enumerate() { - let p_idx = partitioner.hash_to_partition(group_key.key_hash); + for (i, (hash, _key)) in self.idx_map.iter_hash_keys().enumerate() { + let p_idx = partitioner.hash_to_partition(hash); let p = partition_idxs.get_unchecked_mut(p_idx); p.push_unchecked(i as IdxSize); } diff --git a/crates/polars-expr/src/groups/row_encoded.rs b/crates/polars-expr/src/groups/row_encoded.rs index 8c25e1ff08f1..05b28ed45548 100644 --- a/crates/polars-expr/src/groups/row_encoded.rs +++ b/crates/polars-expr/src/groups/row_encoded.rs @@ -1,3 +1,4 @@ +use arrow::array::Array; use polars_row::EncodingField; use polars_utils::cardinality_sketch::CardinalitySketch; use polars_utils::idx_map::bytes_idx_map::{BytesIndexMap, Entry}; @@ -73,9 +74,12 @@ impl Grouper for RowEncodedHashGrouper { let HashKeys::RowEncoded(keys) = keys else { unreachable!() }; + assert!(!keys.hashes.has_nulls()); + assert!(!keys.keys.has_nulls()); + group_idxs.clear(); group_idxs.reserve(keys.hashes.len()); - for (hash, key) in keys.hashes.iter().zip(keys.keys.values_iter()) { + for (hash, key) in keys.hashes.values_iter().zip(keys.keys.values_iter()) { unsafe { group_idxs.push_unchecked(self.insert_key(*hash, key)); } diff --git a/crates/polars-expr/src/hash_keys.rs b/crates/polars-expr/src/hash_keys.rs index 90a35d0cfae8..a2753804e9b8 100644 --- a/crates/polars-expr/src/hash_keys.rs +++ b/crates/polars-expr/src/hash_keys.rs @@ -1,6 +1,6 @@ use std::sync::Arc; -use arrow::array::BinaryArray; +use arrow::array::{BinaryArray, PrimitiveArray, UInt64Array}; use arrow::compute::take::binary::take_unchecked; use arrow::compute::utils::combine_validities_and_many; use polars_core::frame::DataFrame; @@ -16,6 +16,7 @@ use polars_utils::IdxSize; /// Represents a DataFrame plus a hash per row, intended for keys in grouping /// or joining. The hashes may or may not actually be physically pre-computed, /// this depends per type. +#[derive(Clone)] pub enum HashKeys { RowEncoded(RowEncodedKeys), Single(SingleKeys), @@ -46,7 +47,7 @@ impl HashKeys { .map(|k| random_state.hash_one(k)) .collect(); Self::RowEncoded(RowEncodedKeys { - hashes: Arc::new(hashes), + hashes: PrimitiveArray::from_vec(hashes), keys: keys_encoded, }) } else { @@ -86,8 +87,9 @@ impl HashKeys { } } +#[derive(Clone)] pub struct RowEncodedKeys { - pub hashes: Arc>, + pub hashes: UInt64Array, pub keys: BinaryArray, } @@ -105,7 +107,7 @@ impl RowEncodedKeys { } if let Some(validity) = self.keys.validity() { - for (i, (h, is_v)) in self.hashes.iter().zip(validity).enumerate() { + for (i, (h, is_v)) in self.hashes.values_iter().zip(validity).enumerate() { if is_v { unsafe { // SAFETY: we assured the number of partitions matches. @@ -116,7 +118,7 @@ impl RowEncodedKeys { } } } else { - for (i, h) in self.hashes.iter().enumerate() { + for (i, h) in self.hashes.values_iter().enumerate() { unsafe { // SAFETY: we assured the number of partitions matches. let p = partitioner.hash_to_partition(*h); @@ -132,16 +134,17 @@ impl RowEncodedKeys { pub unsafe fn gather(&self, idxs: &[IdxSize]) -> Self { let mut hashes = Vec::with_capacity(idxs.len()); for idx in idxs { - hashes.push_unchecked(*self.hashes.get_unchecked(*idx as usize)); + hashes.push_unchecked(*self.hashes.values().get_unchecked(*idx as usize)); } let idx_arr = arrow::ffi::mmap::slice(idxs); let keys = take_unchecked(&self.keys, &idx_arr); - Self { hashes: Arc::new(hashes), keys } + Self { hashes: PrimitiveArray::from_vec(hashes), keys } } } /// Single keys. Does not pre-hash for boolean & integer types, only for strings /// and nested types. +#[derive(Clone)] pub struct SingleKeys { pub random_state: PlRandomState, pub hashes: Option>, diff --git a/crates/polars-stream/src/nodes/joins/equi_join.rs b/crates/polars-stream/src/nodes/joins/equi_join.rs index 652a2d966180..90c53445dba0 100644 --- a/crates/polars-stream/src/nodes/joins/equi_join.rs +++ b/crates/polars-stream/src/nodes/joins/equi_join.rs @@ -2,14 +2,17 @@ use std::sync::Arc; use polars_core::prelude::{PlHashSet, PlRandomState}; use polars_core::schema::Schema; +use polars_core::utils::accumulate_dataframes_vertical_unchecked; +use polars_expr::chunked_idx_table::{new_chunked_idx_table, ChunkedIdxTable}; use polars_expr::hash_keys::HashKeys; use polars_ops::frame::{JoinArgs, JoinType}; use polars_utils::cardinality_sketch::CardinalitySketch; use polars_utils::format_pl_smallstr; use polars_utils::hashing::HashPartitioner; use polars_utils::pl_str::PlSmallStr; +use rayon::prelude::*; -use crate::async_primitives::connector::Receiver; +use crate::async_primitives::connector::{Receiver, Sender}; use crate::nodes::compute_node_prelude::*; /// A payload selector contains for each column whether that column should be @@ -114,9 +117,88 @@ impl BuildState { Ok(()) } + + fn finalize(&mut self, table: &dyn ChunkedIdxTable) -> ProbeState { + let num_partitions = self.partitions_per_worker.len(); + let table_per_partition: Vec<_> = (0..num_partitions) + .into_par_iter() + .with_max_len(1) + .map(|p| { + // Estimate sizes and cardinality. + let mut sketch = CardinalitySketch::new(); + let mut num_frames = 0; + for worker in &self.partitions_per_worker { + sketch.combine(worker[p].sketch.as_ref().unwrap()); + num_frames += worker[p].frames.len(); + } + + // Build table for this partition. + let mut combined_frames = Vec::with_capacity(num_frames); + let mut table = table.new_empty(); + table.reserve(sketch.estimate() * 5 / 4); + for worker in &self.partitions_per_worker { + for (hash_keys, frame) in worker[p].hash_keys.iter().zip(&worker[p].frames) { + table.insert_key_chunk(hash_keys.clone()); + combined_frames.push(frame.clone()); + } + } + + let df = accumulate_dataframes_vertical_unchecked(combined_frames); + ProbeTable { table, df } + }) + .collect(); + + ProbeState { + table_per_partition + } + } } -struct ProbeState {} +struct ProbeTable { + // Important that df is not rechunked, the chunks it was inserted with + // into the table must be preserved for chunked gathers. + table: Box, + df: DataFrame, +} + +struct ProbeState { + table_per_partition: Vec, +} + +impl ProbeState { + // TODO: shuffle after partitioning and keep probe tables thread-local. + async fn partition_and_probe( + mut recv: Receiver, + mut send: Sender, + partitions: &[ProbeTable], + partitioner: HashPartitioner, + params: &EquiJoinParams, + ) -> PolarsResult<()> { + while let Ok(morsel) = recv.recv().await { + // let df = morsel.into_df(); + // let hash_keys = HashKeys::from_df(&df, params.random_state.clone(), params.args.join_nulls, true); + // let selector = if params.left_is_build { + // ¶ms.left_payload_select + // } else { + // ¶ms.right_payload_select + // }; + + // // We must rechunk the payload for later chunked gathers. + // let mut payload = select_payload(df, selector); + // payload.rechunk_mut(); + + // unsafe { + // hash_keys.gen_partition_idxs(&partitioner, &mut partition_idxs, &mut sketches); + // for (p, idxs_in_p) in partitions.iter_mut().zip(&partition_idxs) { + // p.hash_keys.push(hash_keys.gather(idxs_in_p)); + // p.frames.push(payload.take_slice_unchecked_impl(idxs_in_p, false)); + // } + // } + } + + Ok(()) + } +} enum EquiJoinState { Build(BuildState), @@ -156,6 +238,7 @@ pub struct EquiJoinNode { state: EquiJoinState, params: EquiJoinParams, num_pipelines: usize, + table: Box, } impl EquiJoinNode { @@ -166,6 +249,11 @@ impl EquiJoinNode { ) -> Self { // TODO: use cardinality estimation to determine this. let left_is_build = args.how != JoinType::Left; + let table = if left_is_build { + new_chunked_idx_table(left_input_schema.clone()) + } else { + new_chunked_idx_table(right_input_schema.clone()) + }; let left_payload_select = compute_payload_selector(&left_input_schema, &right_input_schema, true, &args); @@ -182,7 +270,8 @@ impl EquiJoinNode { right_payload_select, args, random_state: PlRandomState::new(), - } + }, + table } } } @@ -211,7 +300,7 @@ impl ComputeNode for EquiJoinNode { // If we are building and the build input is done, transition to probing. if let EquiJoinState::Build(build_state) = &mut self.state { if recv[build_idx] == PortState::Done { - todo!() + self.state = EquiJoinState::Probe(build_state.finalize(&*self.table)); } } @@ -273,7 +362,16 @@ impl ComputeNode for EquiJoinNode { }, EquiJoinState::Probe(probe_state) => { assert!(recv_ports[build_idx].is_none()); - todo!() + let receivers = recv_ports[probe_idx].take().unwrap().parallel(); + let senders = send_ports[0].take().unwrap().parallel(); + + let partitioner = HashPartitioner::new(self.num_pipelines, 0); + for (recv, send) in receivers.into_iter().zip(senders.into_iter()) { + join_handles.push(scope.spawn_task( + TaskPriority::High, + ProbeState::partition_and_probe(recv, send, &probe_state.table_per_partition, partitioner.clone(), &self.params), + )); + } }, EquiJoinState::Done => unreachable!(), } diff --git a/crates/polars-utils/src/index.rs b/crates/polars-utils/src/index.rs index fb43a1958cd6..8ef3c05cee4d 100644 --- a/crates/polars-utils/src/index.rs +++ b/crates/polars-utils/src/index.rs @@ -236,6 +236,14 @@ impl ChunkId { pub fn inner_mut(&mut self) -> &mut u64 { &mut self.swizzled } + + pub fn from_inner(inner: u64) -> Self { + Self { swizzled: inner } + } + + pub fn into_inner(self) -> u64 { + self.swizzled + } } #[cfg(test)] From 6262174ccf43742a4582daaa0fbb151b1da8250b Mon Sep 17 00:00:00 2001 From: Orson Peters Date: Tue, 19 Nov 2024 17:19:40 +0100 Subject: [PATCH 06/21] wip --- .../polars-expr/src/chunked_idx_table/mod.rs | 39 ++- .../src/chunked_idx_table/row_encoded.rs | 325 +++++++++++------- crates/polars-expr/src/hash_keys.rs | 27 +- .../src/nodes/joins/equi_join.rs | 164 +++++++-- .../polars-utils/src/idx_map/bytes_idx_map.rs | 18 + 5 files changed, 402 insertions(+), 171 deletions(-) diff --git a/crates/polars-expr/src/chunked_idx_table/mod.rs b/crates/polars-expr/src/chunked_idx_table/mod.rs index 2386a7d45f47..14c7e036bb31 100644 --- a/crates/polars-expr/src/chunked_idx_table/mod.rs +++ b/crates/polars-expr/src/chunked_idx_table/mod.rs @@ -8,7 +8,6 @@ use crate::hash_keys::HashKeys; mod row_encoded; - pub trait ChunkedIdxTable: Any + Send + Sync { /// Creates a new empty ChunkedIdxTable similar to this one. fn new_empty(&self) -> Box; @@ -21,17 +20,45 @@ pub trait ChunkedIdxTable: Any + Send + Sync { /// Inserts the given key chunk into this ChunkedIdxTable. fn insert_key_chunk(&mut self, keys: HashKeys); - + /// Probe the table, updating table_match and probe_match with /// (ChunkId, IdxSize) pairs for each match. Will stop processing new keys - /// once limit matches have been generated, returning the number of keys processed. - fn probe(&self, hash_keys: &HashKeys, table_match: &mut Vec, probe_match: &mut Vec, mark_matches: bool, limit: usize) -> usize; + /// once limit matches have been generated, returning the number of keys + /// processed. + /// + /// If mark_matches is true, matches are marked in the table as such. + /// + /// If emit_unmatched is true, for keys that do not have a match we emit a + /// match with ChunkId::null() on the table match. + fn probe( + &self, + hash_keys: &HashKeys, + table_match: &mut Vec>, + probe_match: &mut Vec, + mark_matches: bool, + emit_unmatched: bool, + limit: IdxSize, + ) -> IdxSize; + /// The same as probe, except it will only apply to the specified subset of keys. + /// # Safety + /// The provided subset indices must be in-bounds. + unsafe fn probe_subset( + &self, + hash_keys: &HashKeys, + subset: &[IdxSize], + table_match: &mut Vec>, + probe_match: &mut Vec, + mark_matches: bool, + emit_unmatched: bool, + limit: IdxSize, + ) -> IdxSize; + /// Get the ChunkIds for each key which was never marked during probing. - fn unmarked_keys(&self, out: &mut Vec); + fn unmarked_keys(&self, out: &mut Vec>); } pub fn new_chunked_idx_table(key_schema: Arc) -> Box { // Box::new(row_encoded::BytesIndexMap::new(key_schema)) todo!() -} \ No newline at end of file +} diff --git a/crates/polars-expr/src/chunked_idx_table/row_encoded.rs b/crates/polars-expr/src/chunked_idx_table/row_encoded.rs index 4f39322e18a2..950e6a03844a 100644 --- a/crates/polars-expr/src/chunked_idx_table/row_encoded.rs +++ b/crates/polars-expr/src/chunked_idx_table/row_encoded.rs @@ -1,13 +1,10 @@ -use std::sync::atomic::AtomicU64; +use std::sync::atomic::{AtomicU64, Ordering}; -use arrow::bitmap::MutableBitmap; -use polars_row::EncodingField; -use polars_utils::cardinality_sketch::CardinalitySketch; +use arrow::array::Array; use polars_utils::idx_map::bytes_idx_map::{BytesIndexMap, Entry}; use polars_utils::idx_vec::UnitVec; use polars_utils::itertools::Itertools; use polars_utils::unitvec; -use polars_utils::vec::PushUnchecked; use super::*; use crate::hash_keys::HashKeys; @@ -17,7 +14,7 @@ pub struct RowEncodedChunkedIdxTable { // These AtomicU64s actually are ChunkIds, but we use the top bit of the // first chunk in each to mark keys during probing. idx_map: BytesIndexMap>, - chunk_ctr: IdxSize, + chunk_ctr: u32, } impl RowEncodedChunkedIdxTable { @@ -29,6 +26,97 @@ impl RowEncodedChunkedIdxTable { } } +impl RowEncodedChunkedIdxTable { + #[inline(always)] + fn probe_one( + &self, + hash: u64, + key: &[u8], + key_idx: IdxSize, + table_match: &mut Vec>, + probe_match: &mut Vec, + ) { + if let Some(chunk_ids) = self.idx_map.get(hash, key) { + for chunk_id in &chunk_ids[..] { + // Create matches, making sure to clear top bit. + let raw_chunk_id = chunk_id.load(Ordering::Relaxed); + let chunk_id = ChunkId::from_inner(raw_chunk_id & !(1 << 63)); + table_match.push(chunk_id); + probe_match.push(key_idx); + } + + // Mark if necessary. This action is idempotent so doesn't + // need any synchronization on the load, nor does it need a + // fetch_or to do it atomically. + if MARK_MATCHES { + let first_chunk_id = unsafe { chunk_ids.get_unchecked(0) }; + let first_chunk_val = first_chunk_id.load(Ordering::Relaxed); + if first_chunk_val >> 63 == 0 { + first_chunk_id.store(first_chunk_val | (1 << 63), Ordering::Release); + } + } + } else if EMIT_UNMATCHED { + table_match.push(ChunkId::null()); + probe_match.push(key_idx); + } + } + + fn probe_impl<'a, const MARK_MATCHES: bool, const EMIT_UNMATCHED: bool>( + &self, + hash_keys: impl Iterator)>, + table_match: &mut Vec>, + probe_match: &mut Vec, + limit: IdxSize, + ) -> IdxSize { + table_match.clear(); + probe_match.clear(); + + let mut keys_processed = 0; + for (hash, key) in hash_keys { + if let Some(key) = key { + self.probe_one::( + hash, + key, + keys_processed, + table_match, + probe_match, + ); + } + + keys_processed += 1; + if table_match.len() >= limit as usize { + break; + } + } + keys_processed + } + + fn probe_dispatch<'a>( + &self, + hash_keys: impl Iterator)>, + table_match: &mut Vec>, + probe_match: &mut Vec, + mark_matches: bool, + emit_unmatched: bool, + limit: IdxSize, + ) -> IdxSize { + match (mark_matches, emit_unmatched) { + (false, false) => { + self.probe_impl::(hash_keys, table_match, probe_match, limit) + }, + (false, true) => { + self.probe_impl::(hash_keys, table_match, probe_match, limit) + }, + (true, false) => { + self.probe_impl::(hash_keys, table_match, probe_match, limit) + }, + (true, true) => { + self.probe_impl::(hash_keys, table_match, probe_match, limit) + }, + } + } +} + impl ChunkedIdxTable for RowEncodedChunkedIdxTable { fn new_empty(&self) -> Box { Box::new(Self::new()) @@ -43,149 +131,138 @@ impl ChunkedIdxTable for RowEncodedChunkedIdxTable { } fn insert_key_chunk(&mut self, hash_keys: HashKeys) { - let HashKeys::RowEncoded(keys) = hash_keys else { + let HashKeys::RowEncoded(hash_keys) = hash_keys else { unreachable!() }; - if keys.keys.len() >= 1 << 31 { + if hash_keys.keys.len() >= 1 << 31 { panic!("overly large chunk in RowEncodedChunkedIdxTable"); } - // for in keys.hashes - // group_idxs.clear(); - // group_idxs.reserve(keys.hashes.len()); - for (i, (hash, key)) in keys.hashes.values_iter().zip(keys.keys.iter()).enumerate_idx() { + for (i, (hash, key)) in hash_keys + .hashes + .values_iter() + .zip(hash_keys.keys.iter()) + .enumerate_idx() + { if let Some(key) = key { - let chunk_id = AtomicU64::new(ChunkId::<_>::store(self.chunk_ctr, i).into_inner()); + let chunk_id = + AtomicU64::new(ChunkId::<32>::store(self.chunk_ctr as IdxSize, i).into_inner()); match self.idx_map.entry(*hash, key) { - Entry::Occupied(o) => { o.into_mut().push(chunk_id); }, - Entry::Vacant(v) => { v.insert(unitvec![chunk_id]); }, + Entry::Occupied(o) => { + o.into_mut().push(chunk_id); + }, + Entry::Vacant(v) => { + v.insert(unitvec![chunk_id]); + }, } - } } - - self.chunk_ctr = self.chunk_ctr.checked_add(1).unwrap(); - } - - fn probe(&self, keys: &HashKeys, table_match: &mut Vec, probe_match: &mut Vec, mark_matches: bool, limit: usize) -> usize { - todo!() - } - - fn unmarked_keys(&self, out: &mut Vec) { - todo!() - } - -} - -/* -impl Grouper for RowEncodedChunkedIdxTable { - fn new_empty(&self) -> Box { - Box::new(Self::new(self.key_schema.clone())) - } - fn reserve(&mut self, additional: usize) { - self.idx_map.reserve(additional); - } - - fn num_groups(&self) -> IdxSize { - self.idx_map.len() + self.chunk_ctr = self.chunk_ctr.checked_add(1).unwrap(); } - fn insert_keys(&mut self, keys: HashKeys, group_idxs: &mut Vec) { - let HashKeys::RowEncoded(keys) = keys else { + fn probe( + &self, + hash_keys: &HashKeys, + table_match: &mut Vec>, + probe_match: &mut Vec, + mark_matches: bool, + emit_unmatched: bool, + limit: IdxSize, + ) -> IdxSize { + let HashKeys::RowEncoded(hash_keys) = hash_keys else { unreachable!() }; - group_idxs.clear(); - group_idxs.reserve(keys.hashes.len()); - for (hash, key) in keys.hashes.iter().zip(keys.keys.values_iter()) { - unsafe { - group_idxs.push_unchecked(self.insert_key(*hash, key)); - } - } - } - - fn combine(&mut self, other: &dyn Grouper, group_idxs: &mut Vec) { - let other = other.as_any().downcast_ref::().unwrap(); - // TODO: cardinality estimation. - self.idx_map.reserve(other.idx_map.len() as usize); - - unsafe { - group_idxs.clear(); - group_idxs.reserve(other.idx_map.len() as usize); - for (hash, key) in other.idx_map.iter_hash_keys() { - group_idxs.push_unchecked(self.insert_key(hash, key)); - } + if hash_keys.keys.has_nulls() { + let iter = hash_keys + .hashes + .values_iter() + .copied() + .zip(hash_keys.keys.iter()); + self.probe_dispatch( + iter, + table_match, + probe_match, + mark_matches, + emit_unmatched, + limit, + ) + } else { + let iter = hash_keys + .hashes + .values_iter() + .copied() + .zip(hash_keys.keys.values_iter().map(Some)); + self.probe_dispatch( + iter, + table_match, + probe_match, + mark_matches, + emit_unmatched, + limit, + ) } } - unsafe fn gather_combine( - &mut self, - other: &dyn Grouper, + unsafe fn probe_subset( + &self, + hash_keys: &HashKeys, subset: &[IdxSize], - group_idxs: &mut Vec, - ) { - let other = other.as_any().downcast_ref::().unwrap(); - - // TODO: cardinality estimation. - self.idx_map.reserve(subset.len()); - - unsafe { - group_idxs.clear(); - group_idxs.reserve(subset.len()); - for i in subset { - let (hash, key, ()) = other.idx_map.get_index_unchecked(*i); - group_idxs.push_unchecked(self.insert_key(hash, key)); - } - } - } + table_match: &mut Vec>, + probe_match: &mut Vec, + mark_matches: bool, + emit_unmatched: bool, + limit: IdxSize, + ) -> IdxSize { + let HashKeys::RowEncoded(hash_keys) = hash_keys else { + unreachable!() + }; - fn get_keys_in_group_order(&self) -> DataFrame { - unsafe { - let mut key_rows: Vec<&[u8]> = Vec::with_capacity(self.idx_map.len() as usize); - for (_, key) in self.idx_map.iter_hash_keys() { - key_rows.push_unchecked(key); - } - self.finalize_keys(key_rows) + if hash_keys.keys.has_nulls() { + let iter = subset.iter().map(|i| { + ( + hash_keys.hashes.value_unchecked(*i as usize), + hash_keys.keys.get_unchecked(*i as usize), + ) + }); + self.probe_dispatch( + iter, + table_match, + probe_match, + mark_matches, + emit_unmatched, + limit, + ) + } else { + let iter = subset.iter().map(|i| { + ( + hash_keys.hashes.value_unchecked(*i as usize), + Some(hash_keys.keys.value_unchecked(*i as usize)), + ) + }); + self.probe_dispatch( + iter, + table_match, + probe_match, + mark_matches, + emit_unmatched, + limit, + ) } } - fn gen_partition_idxs( - &self, - partitioner: &HashPartitioner, - partition_idxs: &mut [Vec], - sketches: &mut [CardinalitySketch], - ) { - let num_partitions = partitioner.num_partitions(); - assert!(partition_idxs.len() == num_partitions); - assert!(sketches.len() == num_partitions); - - // Two-pass algorithm to prevent reallocations. - let mut partition_sizes = vec![0; num_partitions]; - unsafe { - for (hash, _key) in self.idx_map.iter_hash_keys() { - let p_idx = partitioner.hash_to_partition(hash); - *partition_sizes.get_unchecked_mut(p_idx) += 1; - sketches.get_unchecked_mut(p_idx).insert(hash); - } - } - - for (partition, sz) in partition_idxs.iter_mut().zip(partition_sizes) { - partition.clear(); - partition.reserve(sz); - } - - unsafe { - for (i, (hash, _key)) in self.idx_map.iter_hash_keys().enumerate() { - let p_idx = partitioner.hash_to_partition(hash); - let p = partition_idxs.get_unchecked_mut(p_idx); - p.push_unchecked(i as IdxSize); + fn unmarked_keys(&self, out: &mut Vec>) { + for chunk_ids in self.idx_map.iter_values() { + let first_chunk_id = unsafe { chunk_ids.get_unchecked(0) }; + let first_chunk_val = first_chunk_id.load(Ordering::Acquire); + if first_chunk_val >> 63 == 0 { + for chunk_id in &chunk_ids[..] { + let raw_chunk_id = chunk_id.load(Ordering::Relaxed); + let chunk_id = ChunkId::from_inner(raw_chunk_id & !(1 << 63)); + out.push(chunk_id); + } } } } - - fn as_any(&self) -> &dyn Any { - self - } } -*/ \ No newline at end of file diff --git a/crates/polars-expr/src/hash_keys.rs b/crates/polars-expr/src/hash_keys.rs index a2753804e9b8..7c66130f2dab 100644 --- a/crates/polars-expr/src/hash_keys.rs +++ b/crates/polars-expr/src/hash_keys.rs @@ -71,9 +71,16 @@ impl HashKeys { partition_idxs: &mut [Vec], sketches: &mut [CardinalitySketch], ) { - match self { - Self::RowEncoded(s) => s.gen_partition_idxs(partitioner, partition_idxs, sketches), - Self::Single(s) => s.gen_partition_idxs(partitioner, partition_idxs, sketches), + if sketches.is_empty() { + match self { + Self::RowEncoded(s) => s.gen_partition_idxs::(partitioner, partition_idxs, sketches), + Self::Single(s) => s.gen_partition_idxs::(partitioner, partition_idxs, sketches), + } + } else { + match self { + Self::RowEncoded(s) => s.gen_partition_idxs::(partitioner, partition_idxs, sketches), + Self::Single(s) => s.gen_partition_idxs::(partitioner, partition_idxs, sketches), + } } } @@ -94,14 +101,14 @@ pub struct RowEncodedKeys { } impl RowEncodedKeys { - pub fn gen_partition_idxs( + pub fn gen_partition_idxs( &self, partitioner: &HashPartitioner, partition_idxs: &mut [Vec], sketches: &mut [CardinalitySketch], ) { assert!(partition_idxs.len() == partitioner.num_partitions()); - assert!(sketches.len() == partitioner.num_partitions()); + assert!(BUILD_SKETCHES && sketches.len() == partitioner.num_partitions()); for p in partition_idxs.iter_mut() { p.clear(); } @@ -113,7 +120,9 @@ impl RowEncodedKeys { // SAFETY: we assured the number of partitions matches. let p = partitioner.hash_to_partition(*h); partition_idxs.get_unchecked_mut(p).push(i as IdxSize); - sketches.get_unchecked_mut(p).insert(*h); + if BUILD_SKETCHES { + sketches.get_unchecked_mut(p).insert(*h); + } } } } @@ -123,7 +132,9 @@ impl RowEncodedKeys { // SAFETY: we assured the number of partitions matches. let p = partitioner.hash_to_partition(*h); partition_idxs.get_unchecked_mut(p).push(i as IdxSize); - sketches.get_unchecked_mut(p).insert(*h); + if BUILD_SKETCHES { + sketches.get_unchecked_mut(p).insert(*h); + } } } } @@ -152,7 +163,7 @@ pub struct SingleKeys { } impl SingleKeys { - pub fn gen_partition_idxs( + pub fn gen_partition_idxs( &self, partitioner: &HashPartitioner, partition_idxs: &mut [Vec], diff --git a/crates/polars-stream/src/nodes/joins/equi_join.rs b/crates/polars-stream/src/nodes/joins/equi_join.rs index 90c53445dba0..73af8e73cd8b 100644 --- a/crates/polars-stream/src/nodes/joins/equi_join.rs +++ b/crates/polars-stream/src/nodes/joins/equi_join.rs @@ -2,17 +2,20 @@ use std::sync::Arc; use polars_core::prelude::{PlHashSet, PlRandomState}; use polars_core::schema::Schema; +use polars_core::series::IsSorted; use polars_core::utils::accumulate_dataframes_vertical_unchecked; use polars_expr::chunked_idx_table::{new_chunked_idx_table, ChunkedIdxTable}; use polars_expr::hash_keys::HashKeys; use polars_ops::frame::{JoinArgs, JoinType}; +use polars_ops::prelude::TakeChunked; use polars_utils::cardinality_sketch::CardinalitySketch; -use polars_utils::format_pl_smallstr; use polars_utils::hashing::HashPartitioner; use polars_utils::pl_str::PlSmallStr; +use polars_utils::{format_pl_smallstr, IdxSize}; use rayon::prelude::*; use crate::async_primitives::connector::{Receiver, Sender}; +use crate::morsel::get_ideal_morsel_size; use crate::nodes::compute_node_prelude::*; /// A payload selector contains for each column whether that column should be @@ -86,12 +89,17 @@ impl BuildState { ) -> PolarsResult<()> { let mut partition_idxs = vec![Vec::new(); partitioner.num_partitions()]; partitions.resize_with(partitioner.num_partitions(), BuildPartition::default); - + let mut sketches = vec![CardinalitySketch::default(); partitioner.num_partitions()]; while let Ok(morsel) = recv.recv().await { let df = morsel.into_df(); - let hash_keys = HashKeys::from_df(&df, params.random_state.clone(), params.args.join_nulls, true); + let hash_keys = HashKeys::from_df( + &df, + params.random_state.clone(), + params.args.join_nulls, + true, + ); let selector = if params.left_is_build { ¶ms.left_payload_select } else { @@ -101,23 +109,24 @@ impl BuildState { // We must rechunk the payload for later chunked gathers. let mut payload = select_payload(df, selector); payload.rechunk_mut(); - + unsafe { hash_keys.gen_partition_idxs(&partitioner, &mut partition_idxs, &mut sketches); for (p, idxs_in_p) in partitions.iter_mut().zip(&partition_idxs) { p.hash_keys.push(hash_keys.gather(idxs_in_p)); - p.frames.push(payload.take_slice_unchecked_impl(idxs_in_p, false)); + p.frames + .push(payload.take_slice_unchecked_impl(idxs_in_p, false)); } } } - + for (p, sketch) in sketches.into_iter().enumerate() { partitions[p].sketch = Some(sketch); } - + Ok(()) } - + fn finalize(&mut self, table: &dyn ChunkedIdxTable) -> ProbeState { let num_partitions = self.partitions_per_worker.len(); let table_per_partition: Vec<_> = (0..num_partitions) @@ -147,9 +156,9 @@ impl BuildState { ProbeTable { table, df } }) .collect(); - + ProbeState { - table_per_partition + table_per_partition, } } } @@ -174,35 +183,108 @@ impl ProbeState { partitioner: HashPartitioner, params: &EquiJoinParams, ) -> PolarsResult<()> { + let mut partition_idxs = Vec::new(); + let mut table_match = Vec::new(); + let mut probe_match = Vec::new(); + + let probe_limit = get_ideal_morsel_size() as IdxSize; + let mark_matches = params.emit_unmatched_build(); + let emit_unmatched = params.emit_unmatched_probe(); + while let Ok(morsel) = recv.recv().await { - // let df = morsel.into_df(); - // let hash_keys = HashKeys::from_df(&df, params.random_state.clone(), params.args.join_nulls, true); - // let selector = if params.left_is_build { - // ¶ms.left_payload_select - // } else { - // ¶ms.right_payload_select - // }; - - // // We must rechunk the payload for later chunked gathers. - // let mut payload = select_payload(df, selector); - // payload.rechunk_mut(); - - // unsafe { - // hash_keys.gen_partition_idxs(&partitioner, &mut partition_idxs, &mut sketches); - // for (p, idxs_in_p) in partitions.iter_mut().zip(&partition_idxs) { - // p.hash_keys.push(hash_keys.gather(idxs_in_p)); - // p.frames.push(payload.take_slice_unchecked_impl(idxs_in_p, false)); - // } - // } + let (df, seq, src_token, wait_token) = morsel.into_inner(); + let hash_keys = HashKeys::from_df( + &df, + params.random_state.clone(), + params.args.join_nulls, + true, + ); + let selector = if params.left_is_build { + ¶ms.right_payload_select + } else { + ¶ms.left_payload_select + }; + let payload = select_payload(df, selector); + + unsafe { + hash_keys.gen_partition_idxs(&partitioner, &mut partition_idxs, &mut []); + for (p, idxs_in_p) in partitions.iter().zip(&partition_idxs) { + let mut offset = 0; + while let Some(idxs_in_p_slice) = idxs_in_p.get(offset as usize..) { + offset += p.table.probe_subset( + &hash_keys, + idxs_in_p_slice, + &mut table_match, + &mut probe_match, + mark_matches, + emit_unmatched, + probe_limit, + ); + let mut build_df = if emit_unmatched { + p.df.take_opt_chunked_unchecked(&table_match) + } else { + p.df.take_chunked_unchecked(&table_match, IsSorted::Not) + }; + let mut probe_df = payload.take_slice_unchecked(&probe_match); + + let out_df = if params.left_is_build { + build_df.hstack_mut_unchecked(probe_df.get_columns()); + build_df + } else { + probe_df.hstack_mut_unchecked(build_df.get_columns()); + probe_df + }; + + let out_morsel = Morsel::new(out_df, seq, src_token.clone()); + if send.send(out_morsel).await.is_err() { + break; + } + } + } + } + + drop(wait_token); } - + Ok(()) } + + async fn emit_unmatched( + mut send: Sender, + partitions: &[ProbeTable], + params: &EquiJoinParams, + ) -> PolarsResult<()> { + let source_token = SourceToken::new(); + let mut unmarked_idxs = Vec::new(); + unsafe { + for p in partitions { + p.table.unmarked_keys(&mut unmarked_idxs); + let build_df = p.df.take_chunked_unchecked(&table_match, IsSorted::Not); + + let out_df = if params.left_is_build { + build_df.hstack_mut_unchecked(probe_df.get_columns()); + build_df + } else { + probe_df.hstack_mut_unchecked(build_df.get_columns()); + probe_df + }; + + + + let ideal_morsel_count = (len / get_ideal_morsel_size()).max(1); + let morsel_count = ideal_morsel_count.next_multiple_of(num_pipelines); + self.morsel_size = len.div_ceil(morsel_count).max(1); + + + } + } + } } enum EquiJoinState { Build(BuildState), Probe(ProbeState), + EmitUnmatchedBuild(ProbeState), Done, } @@ -271,7 +353,7 @@ impl EquiJoinNode { args, random_state: PlRandomState::new(), }, - table + table, } } } @@ -315,6 +397,11 @@ impl ComputeNode for EquiJoinNode { recv[probe_idx] = PortState::Ready; send[0] = PortState::Ready; }, + EquiJoinState::EmitUnmatchedBuild(_) => { + recv[build_idx] = PortState::Done; + recv[probe_idx] = PortState::Done; + send[0] = PortState::Ready; + }, EquiJoinState::Done => { recv[0] = PortState::Done; recv[1] = PortState::Done; @@ -356,7 +443,12 @@ impl ComputeNode for EquiJoinNode { { join_handles.push(scope.spawn_task( TaskPriority::High, - BuildState::partition_and_sink(recv, worker_ps, partitioner.clone(), &self.params), + BuildState::partition_and_sink( + recv, + worker_ps, + partitioner.clone(), + &self.params, + ), )); } }, @@ -369,7 +461,13 @@ impl ComputeNode for EquiJoinNode { for (recv, send) in receivers.into_iter().zip(senders.into_iter()) { join_handles.push(scope.spawn_task( TaskPriority::High, - ProbeState::partition_and_probe(recv, send, &probe_state.table_per_partition, partitioner.clone(), &self.params), + ProbeState::partition_and_probe( + recv, + send, + &probe_state.table_per_partition, + partitioner.clone(), + &self.params, + ), )); } }, diff --git a/crates/polars-utils/src/idx_map/bytes_idx_map.rs b/crates/polars-utils/src/idx_map/bytes_idx_map.rs index 0df1bbcd9c0c..61848af1d8df 100644 --- a/crates/polars-utils/src/idx_map/bytes_idx_map.rs +++ b/crates/polars-utils/src/idx_map/bytes_idx_map.rs @@ -62,6 +62,17 @@ impl BytesIndexMap { pub fn is_empty(&self) -> bool { self.table.is_empty() } + + pub fn get(&self, hash: u64, key: &[u8]) -> Option<&V> { + let idx = self.table.find( + hash.wrapping_mul(self.seed), + |i| unsafe { + let t = self.tuples.get_unchecked(*i as usize); + hash == t.0.key_hash && key == t.0.get(&self.key_data) + }, + )?; + unsafe { Some(&self.tuples.get_unchecked(*idx as usize).1) } + } pub fn entry<'k>(&mut self, hash: u64, key: &'k [u8]) -> Entry<'_, 'k, V> { let entry = self.table.entry( @@ -106,6 +117,13 @@ impl BytesIndexMap { .iter() .map(|t| unsafe { (t.0.key_hash, t.0.get(&self.key_data)) }) } + + /// Iterates over the values in insertion order. + pub fn iter_values(&self) -> impl Iterator { + self.tuples + .iter() + .map(|t| &t.1) + } } pub enum Entry<'a, 'k, V> { From a58fbcb76a67168195a82c5b9d2efe5d6528829d Mon Sep 17 00:00:00 2001 From: Orson Peters Date: Tue, 19 Nov 2024 18:27:05 +0100 Subject: [PATCH 07/21] almost there... just need to select key columns and hook it up --- crates/polars-core/src/datatypes/field.rs | 6 + crates/polars-core/src/frame/mod.rs | 9 ++ .../polars-expr/src/chunked_idx_table/mod.rs | 2 +- .../src/chunked_idx_table/row_encoded.rs | 10 +- .../src/nodes/joins/equi_join.rs | 126 ++++++++++++++---- crates/polars-stream/src/physical_plan/fmt.rs | 8 +- crates/polars-stream/src/physical_plan/mod.rs | 10 +- .../src/physical_plan/to_graph.rs | 25 ++++ .../polars-utils/src/idx_map/bytes_idx_map.rs | 8 ++ 9 files changed, 175 insertions(+), 29 deletions(-) diff --git a/crates/polars-core/src/datatypes/field.rs b/crates/polars-core/src/datatypes/field.rs index b85caeec0a2e..7ff81d7277ea 100644 --- a/crates/polars-core/src/datatypes/field.rs +++ b/crates/polars-core/src/datatypes/field.rs @@ -96,6 +96,12 @@ impl Field { pub fn set_name(&mut self, name: PlSmallStr) { self.name = name; } + + /// Returns this `Field`, renamed. + pub fn with_name(mut self, name: PlSmallStr) -> Self { + self.name = name; + self + } /// Converts the `Field` to an `arrow::datatypes::Field`. /// diff --git a/crates/polars-core/src/frame/mod.rs b/crates/polars-core/src/frame/mod.rs index 65c29e306792..1b601b2fa36f 100644 --- a/crates/polars-core/src/frame/mod.rs +++ b/crates/polars-core/src/frame/mod.rs @@ -380,6 +380,15 @@ impl DataFrame { unsafe { DataFrame::new_no_checks(0, cols) } } + /// Create a new `DataFrame` with the given schema, only containing nulls. + pub fn full_null(schema: &Schema, height: usize) -> Self { + let columns = schema + .iter_fields() + .map(|f| Column::full_null(f.name.clone(), height, f.dtype())) + .collect(); + DataFrame { height, columns } + } + /// Removes the last `Series` from the `DataFrame` and returns it, or [`None`] if it is empty. /// /// # Example diff --git a/crates/polars-expr/src/chunked_idx_table/mod.rs b/crates/polars-expr/src/chunked_idx_table/mod.rs index 14c7e036bb31..141072083d8a 100644 --- a/crates/polars-expr/src/chunked_idx_table/mod.rs +++ b/crates/polars-expr/src/chunked_idx_table/mod.rs @@ -55,7 +55,7 @@ pub trait ChunkedIdxTable: Any + Send + Sync { ) -> IdxSize; /// Get the ChunkIds for each key which was never marked during probing. - fn unmarked_keys(&self, out: &mut Vec>); + fn unmarked_keys(&self, out: &mut Vec>, offset: IdxSize, limit: IdxSize); } pub fn new_chunked_idx_table(key_schema: Arc) -> Box { diff --git a/crates/polars-expr/src/chunked_idx_table/row_encoded.rs b/crates/polars-expr/src/chunked_idx_table/row_encoded.rs index 950e6a03844a..d82bb7d4425c 100644 --- a/crates/polars-expr/src/chunked_idx_table/row_encoded.rs +++ b/crates/polars-expr/src/chunked_idx_table/row_encoded.rs @@ -252,8 +252,10 @@ impl ChunkedIdxTable for RowEncodedChunkedIdxTable { } } - fn unmarked_keys(&self, out: &mut Vec>) { - for chunk_ids in self.idx_map.iter_values() { + fn unmarked_keys(&self, out: &mut Vec>, offset: IdxSize, limit: IdxSize) { + out.clear(); + + while let Some((_, _, chunk_ids)) = self.idx_map.get_index(offset) { let first_chunk_id = unsafe { chunk_ids.get_unchecked(0) }; let first_chunk_val = first_chunk_id.load(Ordering::Acquire); if first_chunk_val >> 63 == 0 { @@ -263,6 +265,10 @@ impl ChunkedIdxTable for RowEncodedChunkedIdxTable { out.push(chunk_id); } } + + if out.len() >= limit as usize { + break; + } } } } diff --git a/crates/polars-stream/src/nodes/joins/equi_join.rs b/crates/polars-stream/src/nodes/joins/equi_join.rs index 73af8e73cd8b..4087b4796856 100644 --- a/crates/polars-stream/src/nodes/joins/equi_join.rs +++ b/crates/polars-stream/src/nodes/joins/equi_join.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use polars_core::prelude::{PlHashSet, PlRandomState}; -use polars_core::schema::Schema; +use polars_core::schema::{Schema, SchemaExt}; use polars_core::series::IsSorted; use polars_core::utils::accumulate_dataframes_vertical_unchecked; use polars_expr::chunked_idx_table::{new_chunked_idx_table, ChunkedIdxTable}; @@ -15,7 +15,8 @@ use polars_utils::{format_pl_smallstr, IdxSize}; use rayon::prelude::*; use crate::async_primitives::connector::{Receiver, Sender}; -use crate::morsel::get_ideal_morsel_size; +use crate::async_primitives::wait_group::WaitGroup; +use crate::morsel::{get_ideal_morsel_size, SourceToken}; use crate::nodes::compute_node_prelude::*; /// A payload selector contains for each column whether that column should be @@ -56,6 +57,13 @@ fn compute_payload_selector( .collect() } +fn select_schema(schema: &Schema, selector: &[Option]) -> Schema { + schema.iter_fields() + .zip(selector) + .filter_map(|(f, name)| Some(f.with_name(name.clone()?))) + .collect() +} + fn select_payload(df: DataFrame, selector: &[Option]) -> DataFrame { // Maintain height of zero-width dataframes. if df.width() == 0 { @@ -248,43 +256,77 @@ impl ProbeState { Ok(()) } +} + +struct EmitUnmatchedState { + partitions: Vec, + active_partition_idx: usize, + offset_in_active_p: usize, +} +impl EmitUnmatchedState { async fn emit_unmatched( + &mut self, mut send: Sender, - partitions: &[ProbeTable], params: &EquiJoinParams, + num_pipelines: usize, ) -> PolarsResult<()> { + let total_len: usize = self.partitions.iter().map(|p| p.table.num_keys() as usize).sum(); + let ideal_morsel_count = (total_len / get_ideal_morsel_size()).max(1); + let morsel_count = ideal_morsel_count.next_multiple_of(num_pipelines); + let morsel_size = total_len.div_ceil(morsel_count).max(1); + + let mut morsel_seq = MorselSeq::default(); + let wait_group = WaitGroup::default(); let source_token = SourceToken::new(); let mut unmarked_idxs = Vec::new(); - unsafe { - for p in partitions { - p.table.unmarked_keys(&mut unmarked_idxs); - let build_df = p.df.take_chunked_unchecked(&table_match, IsSorted::Not); + while let Some(p) = self.partitions.get(self.active_partition_idx) { + loop { + p.table.unmarked_keys(&mut unmarked_idxs, self.offset_in_active_p as IdxSize, morsel_size as IdxSize); + self.offset_in_active_p += unmarked_idxs.len(); + if unmarked_idxs.is_empty() { + break; + } - let out_df = if params.left_is_build { - build_df.hstack_mut_unchecked(probe_df.get_columns()); - build_df - } else { - probe_df.hstack_mut_unchecked(build_df.get_columns()); - probe_df + let out_df = unsafe { + let mut build_df = p.df.take_chunked_unchecked(&unmarked_idxs, IsSorted::Not); + let len = build_df.height(); + if params.left_is_build { + let probe_df = DataFrame::full_null(¶ms.right_payload_schema, len); + build_df.hstack_mut_unchecked(probe_df.get_columns()); + build_df + } else { + let mut probe_df = DataFrame::full_null(¶ms.left_payload_schema, len); + probe_df.hstack_mut_unchecked(build_df.get_columns()); + probe_df + } }; + let mut morsel = Morsel::new(out_df, morsel_seq, source_token.clone()); + morsel_seq = morsel_seq.successor(); + morsel.set_consume_token(wait_group.token()); + if send.send(morsel).await.is_err() { + return Ok(()); + } - - let ideal_morsel_count = (len / get_ideal_morsel_size()).max(1); - let morsel_count = ideal_morsel_count.next_multiple_of(num_pipelines); - self.morsel_size = len.div_ceil(morsel_count).max(1); - - + wait_group.wait().await; + if source_token.stop_requested() { + return Ok(()); + } } + + self.active_partition_idx += 1; + self.offset_in_active_p = 0; } + + Ok(()) } } enum EquiJoinState { Build(BuildState), Probe(ProbeState), - EmitUnmatchedBuild(ProbeState), + EmitUnmatchedBuild(EmitUnmatchedState), Done, } @@ -292,6 +334,8 @@ struct EquiJoinParams { left_is_build: bool, left_payload_select: Vec>, right_payload_select: Vec>, + left_payload_schema: Schema, + right_payload_schema: Schema, args: JoinArgs, random_state: PlRandomState, } @@ -341,6 +385,9 @@ impl EquiJoinNode { compute_payload_selector(&left_input_schema, &right_input_schema, true, &args); let right_payload_select = compute_payload_selector(&right_input_schema, &left_input_schema, false, &args); + + let left_payload_schema = select_schema(&left_input_schema, &left_payload_select); + let right_payload_schema = select_schema(&right_input_schema, &right_payload_select); Self { state: EquiJoinState::Build(BuildState { partitions_per_worker: Vec::new(), @@ -350,6 +397,8 @@ impl EquiJoinNode { left_is_build, left_payload_select, right_payload_select, + left_payload_schema, + right_payload_schema, args, random_state: PlRandomState::new(), }, @@ -373,9 +422,8 @@ impl ComputeNode for EquiJoinNode { let build_idx = if self.params.left_is_build { 0 } else { 1 }; let probe_idx = 1 - build_idx; - // If the output doesn't want any more data, or the probe side is done, - // transition to being done. - if send[0] == PortState::Done || recv[probe_idx] == PortState::Done { + // If the output doesn't want any more data, transition to being done. + if send[0] == PortState::Done { self.state = EquiJoinState::Done; } @@ -385,6 +433,29 @@ impl ComputeNode for EquiJoinNode { self.state = EquiJoinState::Probe(build_state.finalize(&*self.table)); } } + + // If we are probing and the probe input is done, emit unmatched if + // necessary, otherwise we're done. + if let EquiJoinState::Probe(probe_state) = &mut self.state { + if recv[probe_idx] == PortState::Done { + if self.params.emit_unmatched_build() { + self.state = EquiJoinState::EmitUnmatchedBuild(EmitUnmatchedState { + partitions: core::mem::take(&mut probe_state.table_per_partition), + active_partition_idx: 0, + offset_in_active_p: 0, + }); + } else { + self.state = EquiJoinState::Done; + } + } + } + + // Finally, check if we are done emitting unmatched keys. + if let EquiJoinState::EmitUnmatchedBuild(emit_state) = &mut self.state { + if emit_state.active_partition_idx >= emit_state.partitions.len() { + self.state = EquiJoinState::Done; + } + } match &mut self.state { EquiJoinState::Build(_) => { @@ -471,6 +542,15 @@ impl ComputeNode for EquiJoinNode { )); } }, + EquiJoinState::EmitUnmatchedBuild(emit_state) => { + assert!(recv_ports[build_idx].is_none()); + assert!(recv_ports[probe_idx].is_none()); + let send = send_ports[0].take().unwrap().serial(); + join_handles.push(scope.spawn_task( + TaskPriority::Low, + emit_state.emit_unmatched(send, &self.params, self.num_pipelines) + )); + }, EquiJoinState::Done => unreachable!(), } } diff --git a/crates/polars-stream/src/physical_plan/fmt.rs b/crates/polars-stream/src/physical_plan/fmt.rs index 7ef74d5b0ad9..57ae8119db11 100644 --- a/crates/polars-stream/src/physical_plan/fmt.rs +++ b/crates/polars-stream/src/physical_plan/fmt.rs @@ -214,8 +214,12 @@ fn visualize_plan_rec( left_on, right_on, args, - } => { - let mut label = "in-memory-join".to_string(); + } | PhysNodeKind::EquiJoin { input_left, input_right, left_on, right_on, args } => { + let mut label = if matches!(phys_sm[node_key].kind, PhysNodeKind::EquiJoin { .. }) { + "equi-join".to_string() + } else { + "in-memory-join".to_string() + }; write!(label, r"\nleft_on:\n{}", fmt_exprs(left_on, expr_arena)).unwrap(); write!(label, r"\nright_on:\n{}", fmt_exprs(right_on, expr_arena)).unwrap(); write!( diff --git a/crates/polars-stream/src/physical_plan/mod.rs b/crates/polars-stream/src/physical_plan/mod.rs index 707c2a53dec2..aa821e6b0a38 100644 --- a/crates/polars-stream/src/physical_plan/mod.rs +++ b/crates/polars-stream/src/physical_plan/mod.rs @@ -153,6 +153,14 @@ pub enum PhysNodeKind { key: Vec, aggs: Vec, }, + + EquiJoin { + input_left: PhysNodeKey, + input_right: PhysNodeKey, + left_on: Vec, + right_on: Vec, + args: JoinArgs, + }, /// Generic fallback for (as-of-yet) unsupported streaming joins. /// Fully sinks all data to in-memory data frames and uses the in-memory @@ -213,7 +221,7 @@ fn insert_multiplexers( insert_multiplexers(*input, phys_sm, referenced); }, - PhysNodeKind::InMemoryJoin { + PhysNodeKind::InMemoryJoin { input_left, input_right, .. } | PhysNodeKind::EquiJoin { input_left, input_right, .. diff --git a/crates/polars-stream/src/physical_plan/to_graph.rs b/crates/polars-stream/src/physical_plan/to_graph.rs index 66bb1f4180a8..f5dd8d02b94c 100644 --- a/crates/polars-stream/src/physical_plan/to_graph.rs +++ b/crates/polars-stream/src/physical_plan/to_graph.rs @@ -23,6 +23,7 @@ use super::{PhysNode, PhysNodeKey, PhysNodeKind}; use crate::expression::StreamExpr; use crate::graph::{Graph, GraphNodeKey}; use crate::nodes; +use crate::nodes::joins::equi_join::EquiJoinNode; use crate::physical_plan::lower_expr::compute_output_schema; use crate::utils::late_materialized_df::LateMaterializedDataFrame; @@ -503,6 +504,30 @@ fn to_graph_rec<'a>( [left_input_key, right_input_key], ) }, + + EquiJoin { + input_left, + input_right, + left_on, + right_on, + args, + } => { + let args = args.clone(); + let left_input_key = to_graph_rec(*input_left, ctx)?; + let right_input_key = to_graph_rec(*input_right, ctx)?; + let left_input_schema = ctx.phys_sm[*input_left].output_schema.clone(); + let right_input_schema = ctx.phys_sm[*input_right].output_schema.clone(); + + todo!() + // ctx.graph.add_node( + // nodes::joins::equi_join::EquiJoinNode::new( + // left_input_schema, + // right_input_schema, + // args, + // ), + // [left_input_key, right_input_key], + // ) + }, }; ctx.phys_to_graph.insert(phys_node_key, graph_key); diff --git a/crates/polars-utils/src/idx_map/bytes_idx_map.rs b/crates/polars-utils/src/idx_map/bytes_idx_map.rs index 61848af1d8df..c362361e2620 100644 --- a/crates/polars-utils/src/idx_map/bytes_idx_map.rs +++ b/crates/polars-utils/src/idx_map/bytes_idx_map.rs @@ -102,10 +102,18 @@ impl BytesIndexMap { } } + /// Gets the hash, key and value at the given index by insertion order. + #[inline(always)] + pub fn get_index(&self, idx: IdxSize) -> Option<(u64, &[u8], &V)> { + let t = self.tuples.get(idx as usize)?; + Some((t.0.key_hash, unsafe { t.0.get(&self.key_data) }, &t.1)) + } + /// Gets the hash, key and value at the given index by insertion order. /// /// # Safety /// The index must be less than len(). + #[inline(always)] pub unsafe fn get_index_unchecked(&self, idx: IdxSize) -> (u64, &[u8], &V) { let t = self.tuples.get_unchecked(idx as usize); (t.0.key_hash, t.0.get(&self.key_data), &t.1) From 9c69575d07c145268fc74e372aaa5deb91175d34 Mon Sep 17 00:00:00 2001 From: Orson Peters Date: Wed, 20 Nov 2024 19:22:10 +0100 Subject: [PATCH 08/21] wip, just some bugs left --- .../polars-expr/src/chunked_idx_table/mod.rs | 7 +- .../src/chunked_idx_table/row_encoded.rs | 44 +++- crates/polars-expr/src/hash_keys.rs | 24 +- crates/polars-ops/src/frame/join/args.rs | 4 + .../src/nodes/joins/equi_join.rs | 212 +++++++++++------- .../src/physical_plan/lower_ir.rs | 33 ++- .../src/physical_plan/to_graph.rs | 36 ++- crates/polars-utils/src/index.rs | 5 +- 8 files changed, 243 insertions(+), 122 deletions(-) diff --git a/crates/polars-expr/src/chunked_idx_table/mod.rs b/crates/polars-expr/src/chunked_idx_table/mod.rs index 141072083d8a..10a51731ecd5 100644 --- a/crates/polars-expr/src/chunked_idx_table/mod.rs +++ b/crates/polars-expr/src/chunked_idx_table/mod.rs @@ -19,7 +19,7 @@ pub trait ChunkedIdxTable: Any + Send + Sync { fn num_keys(&self) -> IdxSize; /// Inserts the given key chunk into this ChunkedIdxTable. - fn insert_key_chunk(&mut self, keys: HashKeys); + fn insert_key_chunk(&mut self, keys: HashKeys, track_unmatchable: bool); /// Probe the table, updating table_match and probe_match with /// (ChunkId, IdxSize) pairs for each match. Will stop processing new keys @@ -58,7 +58,6 @@ pub trait ChunkedIdxTable: Any + Send + Sync { fn unmarked_keys(&self, out: &mut Vec>, offset: IdxSize, limit: IdxSize); } -pub fn new_chunked_idx_table(key_schema: Arc) -> Box { - // Box::new(row_encoded::BytesIndexMap::new(key_schema)) - todo!() +pub fn new_chunked_idx_table(_key_schema: Arc) -> Box { + Box::new(row_encoded::RowEncodedChunkedIdxTable::new()) } diff --git a/crates/polars-expr/src/chunked_idx_table/row_encoded.rs b/crates/polars-expr/src/chunked_idx_table/row_encoded.rs index d82bb7d4425c..364662e029f6 100644 --- a/crates/polars-expr/src/chunked_idx_table/row_encoded.rs +++ b/crates/polars-expr/src/chunked_idx_table/row_encoded.rs @@ -15,6 +15,7 @@ pub struct RowEncodedChunkedIdxTable { // first chunk in each to mark keys during probing. idx_map: BytesIndexMap>, chunk_ctr: u32, + null_keys: Vec>, } impl RowEncodedChunkedIdxTable { @@ -22,20 +23,21 @@ impl RowEncodedChunkedIdxTable { Self { idx_map: BytesIndexMap::new(), chunk_ctr: 0, + null_keys: Vec::new(), } } } impl RowEncodedChunkedIdxTable { #[inline(always)] - fn probe_one( + fn probe_one( &self, hash: u64, key: &[u8], key_idx: IdxSize, table_match: &mut Vec>, probe_match: &mut Vec, - ) { + ) -> bool { if let Some(chunk_ids) = self.idx_map.get(hash, key) { for chunk_id in &chunk_ids[..] { // Create matches, making sure to clear top bit. @@ -55,9 +57,9 @@ impl RowEncodedChunkedIdxTable { first_chunk_id.store(first_chunk_val | (1 << 63), Ordering::Release); } } - } else if EMIT_UNMATCHED { - table_match.push(ChunkId::null()); - probe_match.push(key_idx); + true + } else { + false } } @@ -73,14 +75,21 @@ impl RowEncodedChunkedIdxTable { let mut keys_processed = 0; for (hash, key) in hash_keys { - if let Some(key) = key { - self.probe_one::( + let found_match = if let Some(key) = key { + self.probe_one::( hash, key, keys_processed, table_match, probe_match, - ); + ) + } else { + false + }; + + if EMIT_UNMATCHED && !found_match { + table_match.push(ChunkId::null()); + probe_match.push(keys_processed); } keys_processed += 1; @@ -130,7 +139,7 @@ impl ChunkedIdxTable for RowEncodedChunkedIdxTable { self.idx_map.len() } - fn insert_key_chunk(&mut self, hash_keys: HashKeys) { + fn insert_key_chunk(&mut self, hash_keys: HashKeys, track_unmatchable: bool) { let HashKeys::RowEncoded(hash_keys) = hash_keys else { unreachable!() }; @@ -144,9 +153,9 @@ impl ChunkedIdxTable for RowEncodedChunkedIdxTable { .zip(hash_keys.keys.iter()) .enumerate_idx() { + let chunk_id = ChunkId::<32>::store(self.chunk_ctr as IdxSize, i); if let Some(key) = key { - let chunk_id = - AtomicU64::new(ChunkId::<32>::store(self.chunk_ctr as IdxSize, i).into_inner()); + let chunk_id = AtomicU64::new(chunk_id.into_inner()); match self.idx_map.entry(*hash, key) { Entry::Occupied(o) => { o.into_mut().push(chunk_id); @@ -155,6 +164,8 @@ impl ChunkedIdxTable for RowEncodedChunkedIdxTable { v.insert(unitvec![chunk_id]); }, } + } else if track_unmatchable { + self.null_keys.push(chunk_id); } } @@ -252,9 +263,16 @@ impl ChunkedIdxTable for RowEncodedChunkedIdxTable { } } - fn unmarked_keys(&self, out: &mut Vec>, offset: IdxSize, limit: IdxSize) { + fn unmarked_keys(&self, out: &mut Vec>, mut offset: IdxSize, limit: IdxSize) { out.clear(); + if (offset as usize) < self.null_keys.len() { + out.extend(self.null_keys[offset as usize..].iter().copied().take(limit as usize)); + return; + } + + offset -= self.null_keys.len() as IdxSize; + while let Some((_, _, chunk_ids)) = self.idx_map.get_index(offset) { let first_chunk_id = unsafe { chunk_ids.get_unchecked(0) }; let first_chunk_val = first_chunk_id.load(Ordering::Acquire); @@ -269,6 +287,8 @@ impl ChunkedIdxTable for RowEncodedChunkedIdxTable { if out.len() >= limit as usize { break; } + + offset += 1; } } } diff --git a/crates/polars-expr/src/hash_keys.rs b/crates/polars-expr/src/hash_keys.rs index 7c66130f2dab..4a22c32d991f 100644 --- a/crates/polars-expr/src/hash_keys.rs +++ b/crates/polars-expr/src/hash_keys.rs @@ -1,5 +1,3 @@ -use std::sync::Arc; - use arrow::array::{BinaryArray, PrimitiveArray, UInt64Array}; use arrow::compute::take::binary::take_unchecked; use arrow::compute::utils::combine_validities_and_many; @@ -63,23 +61,22 @@ impl HashKeys { /// After this call partition_idxs[p] will contain the indices of hashes /// that belong to partition p, and the cardinality sketches are updated /// accordingly. - /// - /// If null_is_valid is false rows with nulls do not get assigned a partition. pub fn gen_partition_idxs( &self, partitioner: &HashPartitioner, partition_idxs: &mut [Vec], sketches: &mut [CardinalitySketch], + partition_nulls: bool, ) { if sketches.is_empty() { match self { - Self::RowEncoded(s) => s.gen_partition_idxs::(partitioner, partition_idxs, sketches), - Self::Single(s) => s.gen_partition_idxs::(partitioner, partition_idxs, sketches), + Self::RowEncoded(s) => s.gen_partition_idxs::(partitioner, partition_idxs, sketches, partition_nulls), + Self::Single(s) => s.gen_partition_idxs::(partitioner, partition_idxs, sketches, partition_nulls), } } else { match self { - Self::RowEncoded(s) => s.gen_partition_idxs::(partitioner, partition_idxs, sketches), - Self::Single(s) => s.gen_partition_idxs::(partitioner, partition_idxs, sketches), + Self::RowEncoded(s) => s.gen_partition_idxs::(partitioner, partition_idxs, sketches, partition_nulls), + Self::Single(s) => s.gen_partition_idxs::(partitioner, partition_idxs, sketches, partition_nulls), } } } @@ -106,9 +103,10 @@ impl RowEncodedKeys { partitioner: &HashPartitioner, partition_idxs: &mut [Vec], sketches: &mut [CardinalitySketch], + partition_nulls: bool, ) { assert!(partition_idxs.len() == partitioner.num_partitions()); - assert!(BUILD_SKETCHES && sketches.len() == partitioner.num_partitions()); + assert!(!BUILD_SKETCHES || sketches.len() == partitioner.num_partitions()); for p in partition_idxs.iter_mut() { p.clear(); } @@ -124,6 +122,11 @@ impl RowEncodedKeys { sketches.get_unchecked_mut(p).insert(*h); } } + } else if partition_nulls { + // Arbitrarily put nulls in partition 0. + unsafe { + partition_idxs.get_unchecked_mut(0).push(i as IdxSize); + } } } } else { @@ -167,7 +170,8 @@ impl SingleKeys { &self, partitioner: &HashPartitioner, partition_idxs: &mut [Vec], - sketches: &mut [CardinalitySketch], + _sketches: &mut [CardinalitySketch], + _partition_nulls: bool, ) { assert!(partitioner.num_partitions() == partition_idxs.len()); for p in partition_idxs.iter_mut() { diff --git a/crates/polars-ops/src/frame/join/args.rs b/crates/polars-ops/src/frame/join/args.rs index def36b76a677..b005fa896a63 100644 --- a/crates/polars-ops/src/frame/join/args.rs +++ b/crates/polars-ops/src/frame/join/args.rs @@ -163,6 +163,10 @@ impl Debug for JoinType { } impl JoinType { + pub fn is_equi(&self) -> bool { + matches!(self, JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full) + } + pub fn is_asof(&self) -> bool { #[cfg(feature = "asof_join")] { diff --git a/crates/polars-stream/src/nodes/joins/equi_join.rs b/crates/polars-stream/src/nodes/joins/equi_join.rs index 4087b4796856..c7d9da14b7a7 100644 --- a/crates/polars-stream/src/nodes/joins/equi_join.rs +++ b/crates/polars-stream/src/nodes/joins/equi_join.rs @@ -1,6 +1,6 @@ use std::sync::Arc; -use polars_core::prelude::{PlHashSet, PlRandomState}; +use polars_core::prelude::{IntoColumn, PlHashSet, PlRandomState}; use polars_core::schema::{Schema, SchemaExt}; use polars_core::series::IsSorted; use polars_core::utils::accumulate_dataframes_vertical_unchecked; @@ -16,6 +16,7 @@ use rayon::prelude::*; use crate::async_primitives::connector::{Receiver, Sender}; use crate::async_primitives::wait_group::WaitGroup; +use crate::expression::StreamExpr; use crate::morsel::{get_ideal_morsel_size, SourceToken}; use crate::nodes::compute_node_prelude::*; @@ -24,46 +25,63 @@ use crate::nodes::compute_node_prelude::*; fn compute_payload_selector( this: &Schema, other: &Schema, + this_key_schema: &Schema, is_left: bool, args: &JoinArgs, ) -> Vec> { let should_coalesce = args.should_coalesce(); - let other_col_names: PlHashSet = other.iter_names_cloned().collect(); this.iter_names() .map(|c| { - if !other_col_names.contains(c) { - return Some(c.clone()); + if should_coalesce && this_key_schema.contains(c) { + if is_left != (args.how == JoinType::Right) { + return Some(c.clone()); + } else { + return None; + } } + if !other.contains(c) { + return Some(c.clone()); + } + if is_left { - if should_coalesce && args.how == JoinType::Right { - None - } else { - Some(c.clone()) - } + Some(c.clone()) } else { - if should_coalesce { - if args.how == JoinType::Right { - Some(c.clone()) - } else { - None - } - } else { - Some(format_pl_smallstr!("{}{}", c, args.suffix())) - } + Some(format_pl_smallstr!("{}{}", c, args.suffix())) } }) .collect() } fn select_schema(schema: &Schema, selector: &[Option]) -> Schema { - schema.iter_fields() + schema + .iter_fields() .zip(selector) .filter_map(|(f, name)| Some(f.with_name(name.clone()?))) .collect() } +async fn select_keys( + df: &DataFrame, + key_selectors: &[StreamExpr], + params: &EquiJoinParams, + state: &ExecutionState, +) -> PolarsResult { + let mut key_columns = Vec::new(); + for selector in key_selectors { + let s = selector.evaluate(&df, state).await?; + key_columns.push(s.into_column()); + } + let keys = DataFrame::new_with_broadcast_len(key_columns, df.height())?; + Ok(HashKeys::from_df( + &keys, + params.random_state.clone(), + params.args.join_nulls, + true, + )) +} + fn select_payload(df: DataFrame, selector: &[Option]) -> DataFrame { // Maintain height of zero-width dataframes. if df.width() == 0 { @@ -94,32 +112,32 @@ impl BuildState { partitions: &mut Vec, partitioner: HashPartitioner, params: &EquiJoinParams, + state: &ExecutionState, ) -> PolarsResult<()> { + let track_unmatchable = params.emit_unmatched_build(); let mut partition_idxs = vec![Vec::new(); partitioner.num_partitions()]; partitions.resize_with(partitioner.num_partitions(), BuildPartition::default); - let mut sketches = vec![CardinalitySketch::default(); partitioner.num_partitions()]; + let (key_selectors, payload_selector); + if params.left_is_build { + payload_selector = ¶ms.left_payload_select; + key_selectors = ¶ms.left_key_selectors; + } else { + payload_selector = ¶ms.right_payload_select; + key_selectors = ¶ms.right_key_selectors; + }; + while let Ok(morsel) = recv.recv().await { + // Compute hashed keys and payload. We must rechunk the payload for + // later chunked gathers. let df = morsel.into_df(); - let hash_keys = HashKeys::from_df( - &df, - params.random_state.clone(), - params.args.join_nulls, - true, - ); - let selector = if params.left_is_build { - ¶ms.left_payload_select - } else { - ¶ms.right_payload_select - }; - - // We must rechunk the payload for later chunked gathers. - let mut payload = select_payload(df, selector); + let hash_keys = select_keys(&df, &key_selectors, params, state).await?; + let mut payload = select_payload(df, payload_selector); payload.rechunk_mut(); unsafe { - hash_keys.gen_partition_idxs(&partitioner, &mut partition_idxs, &mut sketches); + hash_keys.gen_partition_idxs(&partitioner, &mut partition_idxs, &mut sketches, track_unmatchable); for (p, idxs_in_p) in partitions.iter_mut().zip(&partition_idxs) { p.hash_keys.push(hash_keys.gather(idxs_in_p)); p.frames @@ -135,8 +153,9 @@ impl BuildState { Ok(()) } - fn finalize(&mut self, table: &dyn ChunkedIdxTable) -> ProbeState { + fn finalize(&mut self, params: &EquiJoinParams, table: &dyn ChunkedIdxTable) -> ProbeState { let num_partitions = self.partitions_per_worker.len(); + let track_unmatchable = params.emit_unmatched_build(); let table_per_partition: Vec<_> = (0..num_partitions) .into_par_iter() .with_max_len(1) @@ -155,12 +174,26 @@ impl BuildState { table.reserve(sketch.estimate() * 5 / 4); for worker in &self.partitions_per_worker { for (hash_keys, frame) in worker[p].hash_keys.iter().zip(&worker[p].frames) { - table.insert_key_chunk(hash_keys.clone()); + // Zero-sized chunks can get deleted, so skip entirely to avoid messing + // up the chunk counter. + if frame.height() == 0 { + continue; + } + + table.insert_key_chunk(hash_keys.clone(), track_unmatchable); combined_frames.push(frame.clone()); } } - let df = accumulate_dataframes_vertical_unchecked(combined_frames); + let df = if combined_frames.is_empty() { + if params.left_is_build { + DataFrame::empty_with_schema(¶ms.left_payload_schema) + } else { + DataFrame::empty_with_schema(¶ms.right_payload_schema) + } + } else { + accumulate_dataframes_vertical_unchecked(combined_frames) + }; ProbeTable { table, df } }) .collect(); @@ -190,44 +223,48 @@ impl ProbeState { partitions: &[ProbeTable], partitioner: HashPartitioner, params: &EquiJoinParams, + state: &ExecutionState, ) -> PolarsResult<()> { - let mut partition_idxs = Vec::new(); + let mut partition_idxs = vec![Vec::new(); partitioner.num_partitions()]; let mut table_match = Vec::new(); let mut probe_match = Vec::new(); - + let probe_limit = get_ideal_morsel_size() as IdxSize; let mark_matches = params.emit_unmatched_build(); let emit_unmatched = params.emit_unmatched_probe(); + let (key_selectors, payload_selector); + if params.left_is_build { + payload_selector = ¶ms.right_payload_select; + key_selectors = ¶ms.right_key_selectors; + } else { + payload_selector = ¶ms.left_payload_select; + key_selectors = ¶ms.left_key_selectors; + }; + while let Ok(morsel) = recv.recv().await { + // Compute hashed keys and payload. let (df, seq, src_token, wait_token) = morsel.into_inner(); - let hash_keys = HashKeys::from_df( - &df, - params.random_state.clone(), - params.args.join_nulls, - true, - ); - let selector = if params.left_is_build { - ¶ms.right_payload_select - } else { - ¶ms.left_payload_select - }; - let payload = select_payload(df, selector); + let hash_keys = select_keys(&df, &key_selectors, params, state).await?; + let payload = select_payload(df, payload_selector); unsafe { - hash_keys.gen_partition_idxs(&partitioner, &mut partition_idxs, &mut []); + // Partition and probe the tables. + hash_keys.gen_partition_idxs(&partitioner, &mut partition_idxs, &mut [], emit_unmatched); for (p, idxs_in_p) in partitions.iter().zip(&partition_idxs) { let mut offset = 0; - while let Some(idxs_in_p_slice) = idxs_in_p.get(offset as usize..) { + while offset < idxs_in_p.len() { offset += p.table.probe_subset( &hash_keys, - idxs_in_p_slice, + &idxs_in_p[offset..], &mut table_match, &mut probe_match, mark_matches, emit_unmatched, probe_limit, - ); + ) as usize; + + // Gather output and send. let mut build_df = if emit_unmatched { p.df.take_opt_chunked_unchecked(&table_match) } else { @@ -271,7 +308,11 @@ impl EmitUnmatchedState { params: &EquiJoinParams, num_pipelines: usize, ) -> PolarsResult<()> { - let total_len: usize = self.partitions.iter().map(|p| p.table.num_keys() as usize).sum(); + let total_len: usize = self + .partitions + .iter() + .map(|p| p.table.num_keys() as usize) + .sum(); let ideal_morsel_count = (total_len / get_ideal_morsel_size()).max(1); let morsel_count = ideal_morsel_count.next_multiple_of(num_pipelines); let morsel_size = total_len.div_ceil(morsel_count).max(1); @@ -282,12 +323,18 @@ impl EmitUnmatchedState { let mut unmarked_idxs = Vec::new(); while let Some(p) = self.partitions.get(self.active_partition_idx) { loop { - p.table.unmarked_keys(&mut unmarked_idxs, self.offset_in_active_p as IdxSize, morsel_size as IdxSize); + // Generate a chunk of unmarked key indices. + p.table.unmarked_keys( + &mut unmarked_idxs, + self.offset_in_active_p as IdxSize, + morsel_size as IdxSize, + ); self.offset_in_active_p += unmarked_idxs.len(); if unmarked_idxs.is_empty() { break; } - + + // Gather and create full-null counterpart. let out_df = unsafe { let mut build_df = p.df.take_chunked_unchecked(&unmarked_idxs, IsSorted::Not); let len = build_df.height(); @@ -302,6 +349,7 @@ impl EmitUnmatchedState { } }; + // Send and wait until consume token is consumed. let mut morsel = Morsel::new(out_df, morsel_seq, source_token.clone()); morsel_seq = morsel_seq.successor(); morsel.set_consume_token(wait_group.token()); @@ -314,7 +362,7 @@ impl EmitUnmatchedState { return Ok(()); } } - + self.active_partition_idx += 1; self.offset_in_active_p = 0; } @@ -332,6 +380,8 @@ enum EquiJoinState { struct EquiJoinParams { left_is_build: bool, + left_key_selectors: Vec, + right_key_selectors: Vec, left_payload_select: Vec>, right_payload_select: Vec>, left_payload_schema: Schema, @@ -371,21 +421,26 @@ impl EquiJoinNode { pub fn new( left_input_schema: Arc, right_input_schema: Arc, + left_key_schema: Arc, + right_key_schema: Arc, + left_key_selectors: Vec, + right_key_selectors: Vec, args: JoinArgs, ) -> Self { // TODO: use cardinality estimation to determine this. let left_is_build = args.how != JoinType::Left; + + let left_payload_select = + compute_payload_selector(&left_input_schema, &right_input_schema, &left_key_schema, true, &args); + let right_payload_select = + compute_payload_selector(&right_input_schema, &left_input_schema, &right_key_schema, false, &args); + let table = if left_is_build { - new_chunked_idx_table(left_input_schema.clone()) + new_chunked_idx_table(left_key_schema) } else { - new_chunked_idx_table(right_input_schema.clone()) + new_chunked_idx_table(right_key_schema) }; - let left_payload_select = - compute_payload_selector(&left_input_schema, &right_input_schema, true, &args); - let right_payload_select = - compute_payload_selector(&right_input_schema, &left_input_schema, false, &args); - let left_payload_schema = select_schema(&left_input_schema, &left_payload_select); let right_payload_schema = select_schema(&right_input_schema, &right_payload_select); Self { @@ -395,6 +450,8 @@ impl EquiJoinNode { num_pipelines: 0, params: EquiJoinParams { left_is_build, + left_key_selectors, + right_key_selectors, left_payload_select, right_payload_select, left_payload_schema, @@ -430,10 +487,10 @@ impl ComputeNode for EquiJoinNode { // If we are building and the build input is done, transition to probing. if let EquiJoinState::Build(build_state) = &mut self.state { if recv[build_idx] == PortState::Done { - self.state = EquiJoinState::Probe(build_state.finalize(&*self.table)); + self.state = EquiJoinState::Probe(build_state.finalize(&self.params, &*self.table)); } } - + // If we are probing and the probe input is done, emit unmatched if // necessary, otherwise we're done. if let EquiJoinState::Probe(probe_state) = &mut self.state { @@ -449,7 +506,7 @@ impl ComputeNode for EquiJoinNode { } } } - + // Finally, check if we are done emitting unmatched keys. if let EquiJoinState::EmitUnmatchedBuild(emit_state) = &mut self.state { if emit_state.active_partition_idx >= emit_state.partitions.len() { @@ -459,24 +516,23 @@ impl ComputeNode for EquiJoinNode { match &mut self.state { EquiJoinState::Build(_) => { + send[0] = PortState::Blocked; recv[build_idx] = PortState::Ready; recv[probe_idx] = PortState::Blocked; - send[0] = PortState::Blocked; }, EquiJoinState::Probe(_) => { + core::mem::swap(&mut send[0], &mut recv[probe_idx]); recv[build_idx] = PortState::Done; - recv[probe_idx] = PortState::Ready; - send[0] = PortState::Ready; }, EquiJoinState::EmitUnmatchedBuild(_) => { + send[0] = PortState::Ready; recv[build_idx] = PortState::Done; recv[probe_idx] = PortState::Done; - send[0] = PortState::Ready; }, EquiJoinState::Done => { + send[0] = PortState::Done; recv[0] = PortState::Done; recv[1] = PortState::Done; - send[0] = PortState::Done; }, } Ok(()) @@ -491,7 +547,7 @@ impl ComputeNode for EquiJoinNode { scope: &'s TaskScope<'s, 'env>, recv_ports: &mut [Option>], send_ports: &mut [Option>], - _state: &'s ExecutionState, + state: &'s ExecutionState, join_handles: &mut Vec>>, ) { assert!(recv_ports.len() == 2); @@ -519,6 +575,7 @@ impl ComputeNode for EquiJoinNode { worker_ps, partitioner.clone(), &self.params, + state, ), )); } @@ -538,6 +595,7 @@ impl ComputeNode for EquiJoinNode { &probe_state.table_per_partition, partitioner.clone(), &self.params, + state, ), )); } @@ -548,7 +606,7 @@ impl ComputeNode for EquiJoinNode { let send = send_ports[0].take().unwrap().serial(); join_handles.push(scope.spawn_task( TaskPriority::Low, - emit_state.emit_unmatched(send, &self.params, self.num_pipelines) + emit_state.emit_unmatched(send, &self.params, self.num_pipelines), )); }, EquiJoinState::Done => unreachable!(), diff --git a/crates/polars-stream/src/physical_plan/lower_ir.rs b/crates/polars-stream/src/physical_plan/lower_ir.rs index db999481bf97..02640adc3777 100644 --- a/crates/polars-stream/src/physical_plan/lower_ir.rs +++ b/crates/polars-stream/src/physical_plan/lower_ir.rs @@ -519,12 +519,33 @@ pub fn lower_ir( let args = options.args.clone(); let phys_left = lower_ir!(input_left)?; let phys_right = lower_ir!(input_right)?; - PhysNodeKind::InMemoryJoin { - input_left: phys_left, - input_right: phys_right, - left_on, - right_on, - args, + if args.how.is_equi() && !args.validation.needs_checks() { + let (trans_input_left, trans_left_on) = + lower_exprs(phys_left, &left_on, expr_arena, phys_sm, expr_cache)?; + let (trans_input_right, trans_right_on) = + lower_exprs(phys_right, &right_on, expr_arena, phys_sm, expr_cache)?; + let mut node = phys_sm.insert(PhysNode::new( + output_schema, + PhysNodeKind::EquiJoin { + input_left: trans_input_left, + input_right: trans_input_right, + left_on: trans_left_on, + right_on: trans_right_on, + args: args.clone() + } + )); + if let Some((offset, len)) = args.slice { + node = build_slice_node(node, offset, len, phys_sm); + } + return Ok(node); + } else { + PhysNodeKind::InMemoryJoin { + input_left: phys_left, + input_right: phys_right, + left_on, + right_on, + args, + } } }, IR::Distinct { .. } => todo!(), diff --git a/crates/polars-stream/src/physical_plan/to_graph.rs b/crates/polars-stream/src/physical_plan/to_graph.rs index f5dd8d02b94c..6b51ac1ecf90 100644 --- a/crates/polars-stream/src/physical_plan/to_graph.rs +++ b/crates/polars-stream/src/physical_plan/to_graph.rs @@ -23,7 +23,6 @@ use super::{PhysNode, PhysNodeKey, PhysNodeKind}; use crate::expression::StreamExpr; use crate::graph::{Graph, GraphNodeKey}; use crate::nodes; -use crate::nodes::joins::equi_join::EquiJoinNode; use crate::physical_plan::lower_expr::compute_output_schema; use crate::utils::late_materialized_df::LateMaterializedDataFrame; @@ -518,15 +517,32 @@ fn to_graph_rec<'a>( let left_input_schema = ctx.phys_sm[*input_left].output_schema.clone(); let right_input_schema = ctx.phys_sm[*input_right].output_schema.clone(); - todo!() - // ctx.graph.add_node( - // nodes::joins::equi_join::EquiJoinNode::new( - // left_input_schema, - // right_input_schema, - // args, - // ), - // [left_input_key, right_input_key], - // ) + let left_key_schema = compute_output_schema(&left_input_schema, left_on, ctx.expr_arena)? + .materialize_unknown_dtypes()?; + let right_key_schema = compute_output_schema(&right_input_schema, right_on, ctx.expr_arena)? + .materialize_unknown_dtypes()?; + + let left_key_selectors = left_on + .iter() + .map(|e| create_stream_expr(e, ctx, &left_input_schema)) + .try_collect_vec()?; + let right_key_selectors = right_on + .iter() + .map(|e| create_stream_expr(e, ctx, &right_input_schema)) + .try_collect_vec()?; + + ctx.graph.add_node( + nodes::joins::equi_join::EquiJoinNode::new( + left_input_schema, + right_input_schema, + Arc::new(left_key_schema), + Arc::new(right_key_schema), + left_key_selectors, + right_key_selectors, + args + ), + [left_input_key, right_input_key], + ) }, }; diff --git a/crates/polars-utils/src/index.rs b/crates/polars-utils/src/index.rs index 8ef3c05cee4d..9ef037954c39 100644 --- a/crates/polars-utils/src/index.rs +++ b/crates/polars-utils/src/index.rs @@ -226,9 +226,8 @@ impl ChunkId { #[allow(clippy::unnecessary_cast)] pub fn extract(self) -> (IdxSize, IdxSize) { let row = (self.swizzled >> CHUNK_BITS) as IdxSize; - - let mask: IdxSize = IdxSize::MAX << CHUNK_BITS; - let chunk = (self.swizzled as IdxSize) & !mask; + let mask = (1u64 << CHUNK_BITS) - 1; + let chunk = (self.swizzled & mask) as IdxSize; (chunk, row) } From 7ca20acf574873df7915a37c38a682ac54a7871d Mon Sep 17 00:00:00 2001 From: Orson Peters Date: Fri, 22 Nov 2024 15:48:22 +0100 Subject: [PATCH 09/21] wip --- crates/polars-core/src/frame/mod.rs | 2 +- .../polars-expr/src/chunked_idx_table/mod.rs | 2 +- .../src/chunked_idx_table/row_encoded.rs | 34 +-- crates/polars-expr/src/hash_keys.rs | 51 +++++ .../src/nodes/joins/equi_join.rs | 216 +++++++++++++++--- 5 files changed, 257 insertions(+), 48 deletions(-) diff --git a/crates/polars-core/src/frame/mod.rs b/crates/polars-core/src/frame/mod.rs index 1b601b2fa36f..f4dad06dea7a 100644 --- a/crates/polars-core/src/frame/mod.rs +++ b/crates/polars-core/src/frame/mod.rs @@ -722,7 +722,7 @@ impl DataFrame { /// - The length of each appended column matches the height of the [`DataFrame`]. For /// `DataFrame`]s with no columns (ZCDFs), it is important that the height is set afterwards /// with [`DataFrame::set_height`]. - pub unsafe fn column_extend_unchecked(&mut self, iter: impl Iterator) { + pub unsafe fn column_extend_unchecked(&mut self, iter: impl IntoIterator) { unsafe { self.get_columns_mut() }.extend(iter) } diff --git a/crates/polars-expr/src/chunked_idx_table/mod.rs b/crates/polars-expr/src/chunked_idx_table/mod.rs index 10a51731ecd5..a9d853c866db 100644 --- a/crates/polars-expr/src/chunked_idx_table/mod.rs +++ b/crates/polars-expr/src/chunked_idx_table/mod.rs @@ -55,7 +55,7 @@ pub trait ChunkedIdxTable: Any + Send + Sync { ) -> IdxSize; /// Get the ChunkIds for each key which was never marked during probing. - fn unmarked_keys(&self, out: &mut Vec>, offset: IdxSize, limit: IdxSize); + fn unmarked_keys(&self, out: &mut Vec>, offset: IdxSize, limit: IdxSize) -> IdxSize; } pub fn new_chunked_idx_table(_key_schema: Arc) -> Box { diff --git a/crates/polars-expr/src/chunked_idx_table/row_encoded.rs b/crates/polars-expr/src/chunked_idx_table/row_encoded.rs index 364662e029f6..f7952ad91dea 100644 --- a/crates/polars-expr/src/chunked_idx_table/row_encoded.rs +++ b/crates/polars-expr/src/chunked_idx_table/row_encoded.rs @@ -32,9 +32,9 @@ impl RowEncodedChunkedIdxTable { #[inline(always)] fn probe_one( &self, + key_idx: IdxSize, hash: u64, key: &[u8], - key_idx: IdxSize, table_match: &mut Vec>, probe_match: &mut Vec, ) -> bool { @@ -65,7 +65,7 @@ impl RowEncodedChunkedIdxTable { fn probe_impl<'a, const MARK_MATCHES: bool, const EMIT_UNMATCHED: bool>( &self, - hash_keys: impl Iterator)>, + hash_keys: impl Iterator)>, table_match: &mut Vec>, probe_match: &mut Vec, limit: IdxSize, @@ -74,12 +74,12 @@ impl RowEncodedChunkedIdxTable { probe_match.clear(); let mut keys_processed = 0; - for (hash, key) in hash_keys { + for (key_idx, hash, key) in hash_keys { let found_match = if let Some(key) = key { self.probe_one::( + key_idx, hash, key, - keys_processed, table_match, probe_match, ) @@ -89,7 +89,7 @@ impl RowEncodedChunkedIdxTable { if EMIT_UNMATCHED && !found_match { table_match.push(ChunkId::null()); - probe_match.push(keys_processed); + probe_match.push(key_idx); } keys_processed += 1; @@ -102,7 +102,7 @@ impl RowEncodedChunkedIdxTable { fn probe_dispatch<'a>( &self, - hash_keys: impl Iterator)>, + hash_keys: impl Iterator)>, table_match: &mut Vec>, probe_match: &mut Vec, mark_matches: bool, @@ -190,7 +190,9 @@ impl ChunkedIdxTable for RowEncodedChunkedIdxTable { .hashes .values_iter() .copied() - .zip(hash_keys.keys.iter()); + .zip(hash_keys.keys.iter()) + .enumerate_idx() + .map(|(i, (h, k))| (i, h, k)); self.probe_dispatch( iter, table_match, @@ -204,7 +206,9 @@ impl ChunkedIdxTable for RowEncodedChunkedIdxTable { .hashes .values_iter() .copied() - .zip(hash_keys.keys.values_iter().map(Some)); + .zip(hash_keys.keys.values_iter().map(Some)) + .enumerate_idx() + .map(|(i, (h, k))| (i, h, k)); self.probe_dispatch( iter, table_match, @@ -233,6 +237,7 @@ impl ChunkedIdxTable for RowEncodedChunkedIdxTable { if hash_keys.keys.has_nulls() { let iter = subset.iter().map(|i| { ( + *i, hash_keys.hashes.value_unchecked(*i as usize), hash_keys.keys.get_unchecked(*i as usize), ) @@ -248,6 +253,7 @@ impl ChunkedIdxTable for RowEncodedChunkedIdxTable { } else { let iter = subset.iter().map(|i| { ( + *i, hash_keys.hashes.value_unchecked(*i as usize), Some(hash_keys.keys.value_unchecked(*i as usize)), ) @@ -263,17 +269,18 @@ impl ChunkedIdxTable for RowEncodedChunkedIdxTable { } } - fn unmarked_keys(&self, out: &mut Vec>, mut offset: IdxSize, limit: IdxSize) { + fn unmarked_keys(&self, out: &mut Vec>, mut offset: IdxSize, limit: IdxSize) -> IdxSize { out.clear(); if (offset as usize) < self.null_keys.len() { out.extend(self.null_keys[offset as usize..].iter().copied().take(limit as usize)); - return; + return out.len() as IdxSize; } offset -= self.null_keys.len() as IdxSize; - while let Some((_, _, chunk_ids)) = self.idx_map.get_index(offset) { + let mut keys_processed = 0; + while let Some((_, _, chunk_ids)) = self.idx_map.get_index(offset + keys_processed) { let first_chunk_id = unsafe { chunk_ids.get_unchecked(0) }; let first_chunk_val = first_chunk_id.load(Ordering::Acquire); if first_chunk_val >> 63 == 0 { @@ -284,11 +291,12 @@ impl ChunkedIdxTable for RowEncodedChunkedIdxTable { } } + keys_processed += 1; if out.len() >= limit as usize { break; } - - offset += 1; } + + keys_processed } } diff --git a/crates/polars-expr/src/hash_keys.rs b/crates/polars-expr/src/hash_keys.rs index 4a22c32d991f..59aa4add3443 100644 --- a/crates/polars-expr/src/hash_keys.rs +++ b/crates/polars-expr/src/hash_keys.rs @@ -7,6 +7,7 @@ use polars_core::prelude::PlRandomState; use polars_core::series::Series; use polars_utils::cardinality_sketch::CardinalitySketch; use polars_utils::hashing::HashPartitioner; +use polars_utils::index::ChunkId; use polars_utils::itertools::Itertools; use polars_utils::vec::PushUnchecked; use polars_utils::IdxSize; @@ -58,6 +59,13 @@ impl HashKeys { } } + pub fn len(&self) -> usize { + match self { + HashKeys::RowEncoded(s) => s.keys.len(), + HashKeys::Single(s) => s.keys.len(), + } + } + /// After this call partition_idxs[p] will contain the indices of hashes /// that belong to partition p, and the cardinality sketches are updated /// accordingly. @@ -80,6 +88,20 @@ impl HashKeys { } } } + + /// Generates indices for a chunked gather such that the ith key gathers + /// the next gathers_per_key[i] elements from the partition[i]th chunk. + pub fn gen_partitioned_gather_idxs( + &self, + partitioner: &HashPartitioner, + gathers_per_key: &[IdxSize], + gather_idxs: &mut Vec>, + ) { + match self { + Self::RowEncoded(s) => s.gen_partitioned_gather_idxs(partitioner, gathers_per_key, gather_idxs), + Self::Single(s) => s.gen_partitioned_gather_idxs(partitioner, gathers_per_key, gather_idxs), + } + } /// # Safety /// The indices must be in-bounds. @@ -143,6 +165,26 @@ impl RowEncodedKeys { } } + pub fn gen_partitioned_gather_idxs( + &self, + partitioner: &HashPartitioner, + gathers_per_key: &[IdxSize], + gather_idxs: &mut Vec>, + ) { + assert!(gathers_per_key.len() == self.keys.len()); + unsafe { + let mut offsets = vec![0; partitioner.num_partitions()]; + for (hash, &n) in self.hashes.values_iter().zip(gathers_per_key) { + let p = partitioner.hash_to_partition(*hash); + let offset = *offsets.get_unchecked(p); + for i in offset..offset+n { + gather_idxs.push(ChunkId::store(p as IdxSize, i)); + } + *offsets.get_unchecked_mut(p) += n; + } + } + } + /// # Safety /// The indices must be in-bounds. pub unsafe fn gather(&self, idxs: &[IdxSize]) -> Self { @@ -181,6 +223,15 @@ impl SingleKeys { todo!() } + pub fn gen_partitioned_gather_idxs( + &self, + _partitioner: &HashPartitioner, + _gathers_per_key: &[IdxSize], + _gather_idxs: &mut Vec>, + ) { + todo!() + } + /// # Safety /// The indices must be in-bounds. pub unsafe fn gather(&self, idxs: &[IdxSize]) -> Self { diff --git a/crates/polars-stream/src/nodes/joins/equi_join.rs b/crates/polars-stream/src/nodes/joins/equi_join.rs index c7d9da14b7a7..a7e55c193fb7 100644 --- a/crates/polars-stream/src/nodes/joins/equi_join.rs +++ b/crates/polars-stream/src/nodes/joins/equi_join.rs @@ -1,6 +1,6 @@ use std::sync::Arc; -use polars_core::prelude::{IntoColumn, PlHashSet, PlRandomState}; +use polars_core::prelude::*; use polars_core::schema::{Schema, SchemaExt}; use polars_core::series::IsSorted; use polars_core::utils::accumulate_dataframes_vertical_unchecked; @@ -19,6 +19,7 @@ use crate::async_primitives::wait_group::WaitGroup; use crate::expression::StreamExpr; use crate::morsel::{get_ideal_morsel_size, SourceToken}; use crate::nodes::compute_node_prelude::*; +use crate::nodes::in_memory_source::InMemorySourceNode; /// A payload selector contains for each column whether that column should be /// included in the payload, and if yes with what name. @@ -69,9 +70,12 @@ async fn select_keys( state: &ExecutionState, ) -> PolarsResult { let mut key_columns = Vec::new(); - for selector in key_selectors { + for (i, selector) in key_selectors.iter().enumerate() { + // We use key columns entirely by position, and allow duplicate names, + // so just assign arbitrary unique names. + let unique_name = format_pl_smallstr!("__POLARS_KEYCOL_{i}"); let s = selector.evaluate(&df, state).await?; - key_columns.push(s.into_column()); + key_columns.push(s.into_column().with_name(unique_name)); } let keys = DataFrame::new_with_broadcast_len(key_columns, df.height())?; Ok(HashKeys::from_df( @@ -98,7 +102,7 @@ fn select_payload(df: DataFrame, selector: &[Option]) -> DataFrame { #[derive(Default)] struct BuildPartition { hash_keys: Vec, - frames: Vec, + frames: Vec<(MorselSeq, DataFrame)>, sketch: Option, } @@ -131,9 +135,8 @@ impl BuildState { while let Ok(morsel) = recv.recv().await { // Compute hashed keys and payload. We must rechunk the payload for // later chunked gathers. - let df = morsel.into_df(); - let hash_keys = select_keys(&df, &key_selectors, params, state).await?; - let mut payload = select_payload(df, payload_selector); + let hash_keys = select_keys(morsel.df(), &key_selectors, params, state).await?; + let mut payload = select_payload(morsel.df().clone(), payload_selector); payload.rechunk_mut(); unsafe { @@ -141,7 +144,7 @@ impl BuildState { for (p, idxs_in_p) in partitions.iter_mut().zip(&partition_idxs) { p.hash_keys.push(hash_keys.gather(idxs_in_p)); p.frames - .push(payload.take_slice_unchecked_impl(idxs_in_p, false)); + .push((morsel.seq(), payload.take_slice_unchecked_impl(idxs_in_p, false))); } } } @@ -167,13 +170,21 @@ impl BuildState { sketch.combine(worker[p].sketch.as_ref().unwrap()); num_frames += worker[p].frames.len(); } - + // Build table for this partition. let mut combined_frames = Vec::with_capacity(num_frames); let mut table = table.new_empty(); table.reserve(sketch.estimate() * 5 / 4); - for worker in &self.partitions_per_worker { - for (hash_keys, frame) in worker[p].hash_keys.iter().zip(&worker[p].frames) { + if params.preserve_order_build { + let mut combined = Vec::with_capacity(num_frames); + for worker in &self.partitions_per_worker { + for (hash_keys, (seq, frame)) in worker[p].hash_keys.iter().zip(&worker[p].frames) { + combined.push((seq, hash_keys, frame)); + } + } + + combined.sort_unstable_by_key(|c| c.0); + for (_seq, hash_keys, frame) in combined { // Zero-sized chunks can get deleted, so skip entirely to avoid messing // up the chunk counter. if frame.height() == 0 { @@ -183,6 +194,19 @@ impl BuildState { table.insert_key_chunk(hash_keys.clone(), track_unmatchable); combined_frames.push(frame.clone()); } + } else { + for worker in &self.partitions_per_worker { + for (hash_keys, (_, frame)) in worker[p].hash_keys.iter().zip(&worker[p].frames) { + // Zero-sized chunks can get deleted, so skip entirely to avoid messing + // up the chunk counter. + if frame.height() == 0 { + continue; + } + + table.insert_key_chunk(hash_keys.clone(), track_unmatchable); + combined_frames.push(frame.clone()); + } + } } let df = if combined_frames.is_empty() { @@ -209,6 +233,7 @@ struct ProbeTable { // into the table must be preserved for chunked gathers. table: Box, df: DataFrame, + chunk_seq_ids: Vec, } struct ProbeState { @@ -251,38 +276,92 @@ impl ProbeState { unsafe { // Partition and probe the tables. hash_keys.gen_partition_idxs(&partitioner, &mut partition_idxs, &mut [], emit_unmatched); - for (p, idxs_in_p) in partitions.iter().zip(&partition_idxs) { - let mut offset = 0; - while offset < idxs_in_p.len() { - offset += p.table.probe_subset( + if params.preserve_order_probe { + // TODO: non-sort based implementation, can directly scatter + // after finding matches for each partition. + let mut out_per_partition = Vec::with_capacity(partitioner.num_partitions()); + let name = PlSmallStr::from_static("__POLARS_PROBE_PRESERVE_ORDER_IDX"); + for (p, idxs_in_p) in partitions.iter().zip(&partition_idxs) { + p.table.probe_subset( &hash_keys, - &idxs_in_p[offset..], + &idxs_in_p, &mut table_match, &mut probe_match, mark_matches, emit_unmatched, - probe_limit, - ) as usize; + IdxSize::MAX, + ); - // Gather output and send. let mut build_df = if emit_unmatched { p.df.take_opt_chunked_unchecked(&table_match) } else { p.df.take_chunked_unchecked(&table_match, IsSorted::Not) }; let mut probe_df = payload.take_slice_unchecked(&probe_match); - - let out_df = if params.left_is_build { + + let mut out_df = if params.left_is_build { build_df.hstack_mut_unchecked(probe_df.get_columns()); build_df } else { probe_df.hstack_mut_unchecked(build_df.get_columns()); probe_df }; - - let out_morsel = Morsel::new(out_df, seq, src_token.clone()); - if send.send(out_morsel).await.is_err() { - break; + + let idxs_ca = IdxCa::from_vec(name.clone(), core::mem::take(&mut probe_match)); + out_df.with_column_unchecked(idxs_ca.into_column()); + out_per_partition.push(out_df); + } + + let sort_options = SortMultipleOptions { + descending: vec![false], + nulls_last: vec![false], + multithreaded: false, + maintain_order: true, + limit: None, + }; + let mut out_df = accumulate_dataframes_vertical_unchecked(out_per_partition); + out_df.sort_in_place([name.clone()], sort_options).unwrap(); + out_df.drop_in_place(&name).unwrap(); + + // TODO: break in smaller morsels. + let out_morsel = Morsel::new(out_df, seq, src_token.clone()); + if send.send(out_morsel).await.is_err() { + break; + } + } else { + for (p, idxs_in_p) in partitions.iter().zip(&partition_idxs) { + let mut offset = 0; + while offset < idxs_in_p.len() { + offset += p.table.probe_subset( + &hash_keys, + &idxs_in_p[offset..], + &mut table_match, + &mut probe_match, + mark_matches, + emit_unmatched, + probe_limit, + ) as usize; + + // Gather output and send. + let mut build_df = if emit_unmatched { + p.df.take_opt_chunked_unchecked(&table_match) + } else { + p.df.take_chunked_unchecked(&table_match, IsSorted::Not) + }; + let mut probe_df = payload.take_slice_unchecked(&probe_match); + + let out_df = if params.left_is_build { + build_df.hstack_mut_unchecked(probe_df.get_columns()); + build_df + } else { + probe_df.hstack_mut_unchecked(build_df.get_columns()); + probe_df + }; + + let out_morsel = Morsel::new(out_df, seq, src_token.clone()); + if send.send(out_morsel).await.is_err() { + break; + } } } } @@ -293,6 +372,55 @@ impl ProbeState { Ok(()) } + + fn ordered_unmatched(&mut self) -> DataFrame { + let mut out_per_partition = Vec::with_capacity(partitioner.num_partitions()); + let name = PlSmallStr::from_static("__POLARS_PROBE_PRESERVE_ORDER_IDX"); + let mut unmarked_idxs = Vec::new(); + for p in self.table_per_partition.iter() { + p.table.unmarked_keys( + &mut unmarked_idxs, + 0, + IdxSize::MAX, + ); + + let mut build_df = if emit_unmatched { + p.df.take_opt_chunked_unchecked(&table_match) + } else { + p.df.take_chunked_unchecked(&table_match, IsSorted::Not) + }; + let mut probe_df = payload.take_slice_unchecked(&probe_match); + + let mut out_df = if params.left_is_build { + build_df.hstack_mut_unchecked(probe_df.get_columns()); + build_df + } else { + probe_df.hstack_mut_unchecked(build_df.get_columns()); + probe_df + }; + + let idxs_ca = IdxCa::from_vec(name.clone(), core::mem::take(&mut probe_match)); + out_df.with_column_unchecked(idxs_ca.into_column()); + out_per_partition.push(out_df); + } + + let sort_options = SortMultipleOptions { + descending: vec![false], + nulls_last: vec![false], + multithreaded: false, + maintain_order: true, + limit: None, + }; + let mut out_df = accumulate_dataframes_vertical_unchecked(out_per_partition); + out_df.sort_in_place([name.clone()], sort_options).unwrap(); + out_df.drop_in_place(&name).unwrap(); + + // TODO: break in smaller morsels. + let out_morsel = Morsel::new(out_df, seq, src_token.clone()); + if send.send(out_morsel).await.is_err() { + break; + } + } } struct EmitUnmatchedState { @@ -324,12 +452,11 @@ impl EmitUnmatchedState { while let Some(p) = self.partitions.get(self.active_partition_idx) { loop { // Generate a chunk of unmarked key indices. - p.table.unmarked_keys( + self.offset_in_active_p += p.table.unmarked_keys( &mut unmarked_idxs, self.offset_in_active_p as IdxSize, morsel_size as IdxSize, - ); - self.offset_in_active_p += unmarked_idxs.len(); + ) as usize; if unmarked_idxs.is_empty() { break; } @@ -375,11 +502,14 @@ enum EquiJoinState { Build(BuildState), Probe(ProbeState), EmitUnmatchedBuild(EmitUnmatchedState), + EmitUnmatchedBuildInOrder(InMemorySourceNode), Done, } struct EquiJoinParams { left_is_build: bool, + preserve_order_build: bool, + preserve_order_probe: bool, left_key_selectors: Vec, right_key_selectors: Vec, left_payload_select: Vec>, @@ -429,6 +559,9 @@ impl EquiJoinNode { ) -> Self { // TODO: use cardinality estimation to determine this. let left_is_build = args.how != JoinType::Left; + + // TODO: expose as a parameter, and let you choose the primary order to preserve. + let preserve_order = std::env::var("POLARS_JOIN_IGNORE_ORDER").as_deref() != Ok("1"); let left_payload_select = compute_payload_selector(&left_input_schema, &right_input_schema, &left_key_schema, true, &args); @@ -450,6 +583,8 @@ impl EquiJoinNode { num_pipelines: 0, params: EquiJoinParams { left_is_build, + preserve_order_build: preserve_order, + preserve_order_probe: preserve_order, left_key_selectors, right_key_selectors, left_payload_select, @@ -496,11 +631,15 @@ impl ComputeNode for EquiJoinNode { if let EquiJoinState::Probe(probe_state) = &mut self.state { if recv[probe_idx] == PortState::Done { if self.params.emit_unmatched_build() { - self.state = EquiJoinState::EmitUnmatchedBuild(EmitUnmatchedState { - partitions: core::mem::take(&mut probe_state.table_per_partition), - active_partition_idx: 0, - offset_in_active_p: 0, - }); + if self.params.preserve_order_build { + self.state = EquiJoinState::EmitUnmatchedBuild(EmitUnmatchedState { + partitions: core::mem::take(&mut probe_state.table_per_partition), + active_partition_idx: 0, + offset_in_active_p: 0, + }); + } else { + + } } else { self.state = EquiJoinState::Done; } @@ -529,6 +668,14 @@ impl ComputeNode for EquiJoinNode { recv[build_idx] = PortState::Done; recv[probe_idx] = PortState::Done; }, + EquiJoinState::EmitUnmatchedBuildInOrder(src_node) => { + recv[build_idx] = PortState::Done; + recv[probe_idx] = PortState::Done; + src_node.update_state(&mut [], &mut send[0..1])?; + if send[0] == PortState::Done { + self.state = EquiJoinState::Done; + } + } EquiJoinState::Done => { send[0] = PortState::Done; recv[0] = PortState::Done; @@ -609,6 +756,9 @@ impl ComputeNode for EquiJoinNode { emit_state.emit_unmatched(send, &self.params, self.num_pipelines), )); }, + EquiJoinState::EmitUnmatchedBuildInOrder(src_node) => { + src_node.spawn(scope, recv_ports, send_ports, state, join_handles); + } EquiJoinState::Done => unreachable!(), } } From ccb2d8049a33231f9d8f490803525be820ff02f8 Mon Sep 17 00:00:00 2001 From: Orson Peters Date: Mon, 25 Nov 2024 11:17:05 +0100 Subject: [PATCH 10/21] fix morsel sequence monotonicity --- crates/polars-stream/src/nodes/group_by.rs | 2 +- .../polars-stream/src/nodes/in_memory_map.rs | 2 +- .../src/nodes/in_memory_source.rs | 7 +- .../src/nodes/joins/equi_join.rs | 250 +++++++++++------- .../src/nodes/joins/in_memory.rs | 2 +- .../src/physical_plan/to_graph.rs | 3 +- 6 files changed, 169 insertions(+), 97 deletions(-) diff --git a/crates/polars-stream/src/nodes/group_by.rs b/crates/polars-stream/src/nodes/group_by.rs index 0151970ee766..c08284d9f009 100644 --- a/crates/polars-stream/src/nodes/group_by.rs +++ b/crates/polars-stream/src/nodes/group_by.rs @@ -221,7 +221,7 @@ impl GroupBySinkState { Self::combine_locals_parallel(num_partitions, output_schema, self.local) }; - let mut source_node = InMemorySourceNode::new(Arc::new(df?)); + let mut source_node = InMemorySourceNode::new(Arc::new(df?), MorselSeq::default()); source_node.initialize(num_pipelines); Ok(source_node) } diff --git a/crates/polars-stream/src/nodes/in_memory_map.rs b/crates/polars-stream/src/nodes/in_memory_map.rs index 27af6be9aa87..118827f76529 100644 --- a/crates/polars-stream/src/nodes/in_memory_map.rs +++ b/crates/polars-stream/src/nodes/in_memory_map.rs @@ -56,7 +56,7 @@ impl ComputeNode for InMemoryMapNode { { if recv[0] == PortState::Done { let df = sink_node.get_output()?; - let mut source_node = InMemorySourceNode::new(Arc::new(map.call_udf(df.unwrap())?)); + let mut source_node = InMemorySourceNode::new(Arc::new(map.call_udf(df.unwrap())?), MorselSeq::default()); source_node.initialize(*num_pipelines); *self = Self::Source(source_node); } diff --git a/crates/polars-stream/src/nodes/in_memory_source.rs b/crates/polars-stream/src/nodes/in_memory_source.rs index ab3231b1c759..b8d07c756a34 100644 --- a/crates/polars-stream/src/nodes/in_memory_source.rs +++ b/crates/polars-stream/src/nodes/in_memory_source.rs @@ -9,14 +9,16 @@ pub struct InMemorySourceNode { source: Option>, morsel_size: usize, seq: AtomicU64, + seq_offset: MorselSeq, } impl InMemorySourceNode { - pub fn new(source: Arc) -> Self { + pub fn new(source: Arc, seq_offset: MorselSeq) -> Self { InMemorySourceNode { source: Some(source), morsel_size: 0, seq: AtomicU64::new(0), + seq_offset } } } @@ -87,7 +89,8 @@ impl ComputeNode for InMemorySourceNode { break; } - let mut morsel = Morsel::new(df, MorselSeq::new(seq), source_token.clone()); + let morsel_seq = MorselSeq::new(seq).offset_by(slf.seq_offset); + let mut morsel = Morsel::new(df, morsel_seq, source_token.clone()); morsel.set_consume_token(wait_group.token()); if send.send(morsel).await.is_err() { break; diff --git a/crates/polars-stream/src/nodes/joins/equi_join.rs b/crates/polars-stream/src/nodes/joins/equi_join.rs index a7e55c193fb7..69f78fe1ca86 100644 --- a/crates/polars-stream/src/nodes/joins/equi_join.rs +++ b/crates/polars-stream/src/nodes/joins/equi_join.rs @@ -10,6 +10,7 @@ use polars_ops::frame::{JoinArgs, JoinType}; use polars_ops::prelude::TakeChunked; use polars_utils::cardinality_sketch::CardinalitySketch; use polars_utils::hashing::HashPartitioner; +use polars_utils::itertools::Itertools; use polars_utils::pl_str::PlSmallStr; use polars_utils::{format_pl_smallstr, IdxSize}; use rayon::prelude::*; @@ -45,7 +46,7 @@ fn compute_payload_selector( if !other.contains(c) { return Some(c.clone()); } - + if is_left { Some(c.clone()) } else { @@ -140,11 +141,18 @@ impl BuildState { payload.rechunk_mut(); unsafe { - hash_keys.gen_partition_idxs(&partitioner, &mut partition_idxs, &mut sketches, track_unmatchable); + hash_keys.gen_partition_idxs( + &partitioner, + &mut partition_idxs, + &mut sketches, + track_unmatchable, + ); for (p, idxs_in_p) in partitions.iter_mut().zip(&partition_idxs) { p.hash_keys.push(hash_keys.gather(idxs_in_p)); - p.frames - .push((morsel.seq(), payload.take_slice_unchecked_impl(idxs_in_p, false))); + p.frames.push(( + morsel.seq(), + payload.take_slice_unchecked_impl(idxs_in_p, false), + )); } } } @@ -170,21 +178,24 @@ impl BuildState { sketch.combine(worker[p].sketch.as_ref().unwrap()); num_frames += worker[p].frames.len(); } - + // Build table for this partition. let mut combined_frames = Vec::with_capacity(num_frames); + let mut chunk_seq_ids = Vec::with_capacity(num_frames); let mut table = table.new_empty(); table.reserve(sketch.estimate() * 5 / 4); if params.preserve_order_build { let mut combined = Vec::with_capacity(num_frames); for worker in &self.partitions_per_worker { - for (hash_keys, (seq, frame)) in worker[p].hash_keys.iter().zip(&worker[p].frames) { + for (hash_keys, (seq, frame)) in + worker[p].hash_keys.iter().zip(&worker[p].frames) + { combined.push((seq, hash_keys, frame)); } } - + combined.sort_unstable_by_key(|c| c.0); - for (_seq, hash_keys, frame) in combined { + for (seq, hash_keys, frame) in combined { // Zero-sized chunks can get deleted, so skip entirely to avoid messing // up the chunk counter. if frame.height() == 0 { @@ -193,10 +204,13 @@ impl BuildState { table.insert_key_chunk(hash_keys.clone(), track_unmatchable); combined_frames.push(frame.clone()); + chunk_seq_ids.push(*seq); } } else { for worker in &self.partitions_per_worker { - for (hash_keys, (_, frame)) in worker[p].hash_keys.iter().zip(&worker[p].frames) { + for (hash_keys, (_, frame)) in + worker[p].hash_keys.iter().zip(&worker[p].frames) + { // Zero-sized chunks can get deleted, so skip entirely to avoid messing // up the chunk counter. if frame.height() == 0 { @@ -218,12 +232,17 @@ impl BuildState { } else { accumulate_dataframes_vertical_unchecked(combined_frames) }; - ProbeTable { table, df } + ProbeTable { + table, + df, + chunk_seq_ids, + } }) .collect(); ProbeState { table_per_partition, + max_seq_sent: MorselSeq::default(), } } } @@ -233,15 +252,16 @@ struct ProbeTable { // into the table must be preserved for chunked gathers. table: Box, df: DataFrame, - chunk_seq_ids: Vec, + chunk_seq_ids: Vec, } struct ProbeState { table_per_partition: Vec, + max_seq_sent: MorselSeq, } impl ProbeState { - // TODO: shuffle after partitioning and keep probe tables thread-local. + /// Returns the max morsel sequence sent. async fn partition_and_probe( mut recv: Receiver, mut send: Sender, @@ -249,11 +269,13 @@ impl ProbeState { partitioner: HashPartitioner, params: &EquiJoinParams, state: &ExecutionState, - ) -> PolarsResult<()> { + ) -> PolarsResult { + // TODO: shuffle after partitioning and keep probe tables thread-local. let mut partition_idxs = vec![Vec::new(); partitioner.num_partitions()]; let mut table_match = Vec::new(); let mut probe_match = Vec::new(); - + let mut max_seq = MorselSeq::default(); + let probe_limit = get_ideal_morsel_size() as IdxSize; let mark_matches = params.emit_unmatched_build(); let emit_unmatched = params.emit_unmatched_probe(); @@ -272,10 +294,16 @@ impl ProbeState { let (df, seq, src_token, wait_token) = morsel.into_inner(); let hash_keys = select_keys(&df, &key_selectors, params, state).await?; let payload = select_payload(df, payload_selector); + max_seq = seq; unsafe { // Partition and probe the tables. - hash_keys.gen_partition_idxs(&partitioner, &mut partition_idxs, &mut [], emit_unmatched); + hash_keys.gen_partition_idxs( + &partitioner, + &mut partition_idxs, + &mut [], + emit_unmatched, + ); if params.preserve_order_probe { // TODO: non-sort based implementation, can directly scatter // after finding matches for each partition. @@ -298,7 +326,7 @@ impl ProbeState { p.df.take_chunked_unchecked(&table_match, IsSorted::Not) }; let mut probe_df = payload.take_slice_unchecked(&probe_match); - + let mut out_df = if params.left_is_build { build_df.hstack_mut_unchecked(probe_df.get_columns()); build_df @@ -306,12 +334,13 @@ impl ProbeState { probe_df.hstack_mut_unchecked(build_df.get_columns()); probe_df }; - - let idxs_ca = IdxCa::from_vec(name.clone(), core::mem::take(&mut probe_match)); + + let idxs_ca = + IdxCa::from_vec(name.clone(), core::mem::take(&mut probe_match)); out_df.with_column_unchecked(idxs_ca.into_column()); out_per_partition.push(out_df); } - + let sort_options = SortMultipleOptions { descending: vec![false], nulls_last: vec![false], @@ -322,7 +351,7 @@ impl ProbeState { let mut out_df = accumulate_dataframes_vertical_unchecked(out_per_partition); out_df.sort_in_place([name.clone()], sort_options).unwrap(); out_df.drop_in_place(&name).unwrap(); - + // TODO: break in smaller morsels. let out_morsel = Morsel::new(out_df, seq, src_token.clone()); if send.send(out_morsel).await.is_err() { @@ -349,7 +378,7 @@ impl ProbeState { p.df.take_chunked_unchecked(&table_match, IsSorted::Not) }; let mut probe_df = payload.take_slice_unchecked(&probe_match); - + let out_df = if params.left_is_build { build_df.hstack_mut_unchecked(probe_df.get_columns()); build_df @@ -357,7 +386,7 @@ impl ProbeState { probe_df.hstack_mut_unchecked(build_df.get_columns()); probe_df }; - + let out_morsel = Morsel::new(out_df, seq, src_token.clone()); if send.send(out_morsel).await.is_err() { break; @@ -370,55 +399,67 @@ impl ProbeState { drop(wait_token); } - Ok(()) + Ok(max_seq) } - - fn ordered_unmatched(&mut self) -> DataFrame { + + fn ordered_unmatched( + &mut self, + partitioner: &HashPartitioner, + params: &EquiJoinParams, + ) -> DataFrame { let mut out_per_partition = Vec::with_capacity(partitioner.num_partitions()); - let name = PlSmallStr::from_static("__POLARS_PROBE_PRESERVE_ORDER_IDX"); + let seq_name = PlSmallStr::from_static("__POLARS_PROBE_PRESERVE_ORDER_SEQ"); + let idx_name = PlSmallStr::from_static("__POLARS_PROBE_PRESERVE_ORDER_IDX"); let mut unmarked_idxs = Vec::new(); - for p in self.table_per_partition.iter() { - p.table.unmarked_keys( - &mut unmarked_idxs, - 0, - IdxSize::MAX, - ); - - let mut build_df = if emit_unmatched { - p.df.take_opt_chunked_unchecked(&table_match) - } else { - p.df.take_chunked_unchecked(&table_match, IsSorted::Not) - }; - let mut probe_df = payload.take_slice_unchecked(&probe_match); - - let mut out_df = if params.left_is_build { - build_df.hstack_mut_unchecked(probe_df.get_columns()); - build_df - } else { - probe_df.hstack_mut_unchecked(build_df.get_columns()); - probe_df + unsafe { + for p in self.table_per_partition.iter() { + p.table.unmarked_keys(&mut unmarked_idxs, 0, IdxSize::MAX); + + // Gather and create full-null counterpart. + let mut build_df = p.df.take_chunked_unchecked(&unmarked_idxs, IsSorted::Not); + let len = build_df.height(); + let mut out_df = if params.left_is_build { + let probe_df = DataFrame::full_null(¶ms.right_payload_schema, len); + build_df.hstack_mut_unchecked(probe_df.get_columns()); + build_df + } else { + let mut probe_df = DataFrame::full_null(¶ms.left_payload_schema, len); + probe_df.hstack_mut_unchecked(build_df.get_columns()); + probe_df + }; + + // The indices are not ordered globally, but within each chunk they are, so sorting + // by chunk sequence id, breaking ties by inner chunk idx works. + let (chunk_seqs, idx_in_chunk) = unmarked_idxs + .iter() + .map(|chunk_id| { + let (chunk, idx_in_chunk) = chunk_id.extract(); + (p.chunk_seq_ids[chunk as usize].to_u64(), idx_in_chunk) + }) + .unzip(); + + let chunk_seqs_ca = UInt64Chunked::from_vec(seq_name.clone(), chunk_seqs); + let idxs_ca = IdxCa::from_vec(idx_name.clone(), idx_in_chunk); + out_df.with_column_unchecked(chunk_seqs_ca.into_column()); + out_df.with_column_unchecked(idxs_ca.into_column()); + out_per_partition.push(out_df); + } + + // Sort by chunk sequence id, then by inner chunk idx. + let sort_options = SortMultipleOptions { + descending: vec![false], + nulls_last: vec![false], + multithreaded: true, + maintain_order: false, + limit: None, }; - - let idxs_ca = IdxCa::from_vec(name.clone(), core::mem::take(&mut probe_match)); - out_df.with_column_unchecked(idxs_ca.into_column()); - out_per_partition.push(out_df); - } - - let sort_options = SortMultipleOptions { - descending: vec![false], - nulls_last: vec![false], - multithreaded: false, - maintain_order: true, - limit: None, - }; - let mut out_df = accumulate_dataframes_vertical_unchecked(out_per_partition); - out_df.sort_in_place([name.clone()], sort_options).unwrap(); - out_df.drop_in_place(&name).unwrap(); - - // TODO: break in smaller morsels. - let out_morsel = Morsel::new(out_df, seq, src_token.clone()); - if send.send(out_morsel).await.is_err() { - break; + let mut out_df = accumulate_dataframes_vertical_unchecked(out_per_partition); + out_df + .sort_in_place([seq_name.clone(), idx_name.clone()], sort_options) + .unwrap(); + out_df.drop_in_place(&seq_name).unwrap(); + out_df.drop_in_place(&idx_name).unwrap(); + out_df } } } @@ -427,6 +468,7 @@ struct EmitUnmatchedState { partitions: Vec, active_partition_idx: usize, offset_in_active_p: usize, + morsel_seq: MorselSeq, } impl EmitUnmatchedState { @@ -445,7 +487,6 @@ impl EmitUnmatchedState { let morsel_count = ideal_morsel_count.next_multiple_of(num_pipelines); let morsel_size = total_len.div_ceil(morsel_count).max(1); - let mut morsel_seq = MorselSeq::default(); let wait_group = WaitGroup::default(); let source_token = SourceToken::new(); let mut unmarked_idxs = Vec::new(); @@ -477,8 +518,8 @@ impl EmitUnmatchedState { }; // Send and wait until consume token is consumed. - let mut morsel = Morsel::new(out_df, morsel_seq, source_token.clone()); - morsel_seq = morsel_seq.successor(); + let mut morsel = Morsel::new(out_df, self.morsel_seq, source_token.clone()); + self.morsel_seq = self.morsel_seq.successor(); morsel.set_consume_token(wait_group.token()); if send.send(morsel).await.is_err() { return Ok(()); @@ -559,14 +600,24 @@ impl EquiJoinNode { ) -> Self { // TODO: use cardinality estimation to determine this. let left_is_build = args.how != JoinType::Left; - + // TODO: expose as a parameter, and let you choose the primary order to preserve. let preserve_order = std::env::var("POLARS_JOIN_IGNORE_ORDER").as_deref() != Ok("1"); - let left_payload_select = - compute_payload_selector(&left_input_schema, &right_input_schema, &left_key_schema, true, &args); - let right_payload_select = - compute_payload_selector(&right_input_schema, &left_input_schema, &right_key_schema, false, &args); + let left_payload_select = compute_payload_selector( + &left_input_schema, + &right_input_schema, + &left_key_schema, + true, + &args, + ); + let right_payload_select = compute_payload_selector( + &right_input_schema, + &left_input_schema, + &right_key_schema, + false, + &args, + ); let table = if left_is_build { new_chunked_idx_table(left_key_schema) @@ -636,9 +687,14 @@ impl ComputeNode for EquiJoinNode { partitions: core::mem::take(&mut probe_state.table_per_partition), active_partition_idx: 0, offset_in_active_p: 0, + morsel_seq: probe_state.max_seq_sent.successor(), }); } else { - + let partitioner = HashPartitioner::new(self.num_pipelines, 0); + let unmatched = probe_state.ordered_unmatched(&partitioner, &self.params); + let mut src = InMemorySourceNode::new(Arc::new(unmatched), probe_state.max_seq_sent.successor()); + src.initialize(self.num_pipelines); + self.state = EquiJoinState::EmitUnmatchedBuildInOrder(src); } } else { self.state = EquiJoinState::Done; @@ -675,7 +731,7 @@ impl ComputeNode for EquiJoinNode { if send[0] == PortState::Done { self.state = EquiJoinState::Done; } - } + }, EquiJoinState::Done => { send[0] = PortState::Done; recv[0] = PortState::Done; @@ -733,19 +789,31 @@ impl ComputeNode for EquiJoinNode { let senders = send_ports[0].take().unwrap().parallel(); let partitioner = HashPartitioner::new(self.num_pipelines, 0); - for (recv, send) in receivers.into_iter().zip(senders.into_iter()) { - join_handles.push(scope.spawn_task( - TaskPriority::High, - ProbeState::partition_and_probe( - recv, - send, - &probe_state.table_per_partition, - partitioner.clone(), - &self.params, - state, - ), - )); - } + let probe_tasks = receivers + .into_iter() + .zip(senders.into_iter()) + .map(|(recv, send)| { + scope.spawn_task( + TaskPriority::High, + ProbeState::partition_and_probe( + recv, + send, + &probe_state.table_per_partition, + partitioner.clone(), + &self.params, + state, + ), + ) + }) + .collect_vec(); + + let max_seq_sent = &mut probe_state.max_seq_sent; + join_handles.push(scope.spawn_task(TaskPriority::High, async move { + for probe_task in probe_tasks { + *max_seq_sent = (*max_seq_sent).max(probe_task.await?); + } + Ok(()) + })); }, EquiJoinState::EmitUnmatchedBuild(emit_state) => { assert!(recv_ports[build_idx].is_none()); @@ -758,7 +826,7 @@ impl ComputeNode for EquiJoinNode { }, EquiJoinState::EmitUnmatchedBuildInOrder(src_node) => { src_node.spawn(scope, recv_ports, send_ports, state, join_handles); - } + }, EquiJoinState::Done => unreachable!(), } } diff --git a/crates/polars-stream/src/nodes/joins/in_memory.rs b/crates/polars-stream/src/nodes/joins/in_memory.rs index a98c23a435b0..79b45d074e5c 100644 --- a/crates/polars-stream/src/nodes/joins/in_memory.rs +++ b/crates/polars-stream/src/nodes/joins/in_memory.rs @@ -61,7 +61,7 @@ impl ComputeNode for InMemoryJoinNode { let left_df = left.get_output()?.unwrap(); let right_df = right.get_output()?.unwrap(); let mut source_node = - InMemorySourceNode::new(Arc::new((self.joiner)(left_df, right_df)?)); + InMemorySourceNode::new(Arc::new((self.joiner)(left_df, right_df)?), MorselSeq::default()); source_node.initialize(self.num_pipelines); self.state = InMemoryJoinState::Source(source_node); } diff --git a/crates/polars-stream/src/physical_plan/to_graph.rs b/crates/polars-stream/src/physical_plan/to_graph.rs index 6b51ac1ecf90..d7f776b9d0ad 100644 --- a/crates/polars-stream/src/physical_plan/to_graph.rs +++ b/crates/polars-stream/src/physical_plan/to_graph.rs @@ -22,6 +22,7 @@ use slotmap::{SecondaryMap, SlotMap}; use super::{PhysNode, PhysNodeKey, PhysNodeKind}; use crate::expression::StreamExpr; use crate::graph::{Graph, GraphNodeKey}; +use crate::morsel::MorselSeq; use crate::nodes; use crate::physical_plan::lower_expr::compute_output_schema; use crate::utils::late_materialized_df::LateMaterializedDataFrame; @@ -92,7 +93,7 @@ fn to_graph_rec<'a>( let node = &ctx.phys_sm[phys_node_key]; let graph_key = match &node.kind { InMemorySource { df } => ctx.graph.add_node( - nodes::in_memory_source::InMemorySourceNode::new(df.clone()), + nodes::in_memory_source::InMemorySourceNode::new(df.clone(), MorselSeq::default()), [], ), From 57df7c09420941b9ec43769ea3fae980ce36cd66 Mon Sep 17 00:00:00 2001 From: Orson Peters Date: Mon, 25 Nov 2024 14:37:52 +0100 Subject: [PATCH 11/21] physically coalesce for full joins and check duplicate column --- .../src/nodes/joins/equi_join.rs | 90 ++++++++++++++----- .../src/physical_plan/to_graph.rs | 2 +- 2 files changed, 70 insertions(+), 22 deletions(-) diff --git a/crates/polars-stream/src/nodes/joins/equi_join.rs b/crates/polars-stream/src/nodes/joins/equi_join.rs index 69f78fe1ca86..929597e9f5cf 100644 --- a/crates/polars-stream/src/nodes/joins/equi_join.rs +++ b/crates/polars-stream/src/nodes/joins/equi_join.rs @@ -8,6 +8,7 @@ use polars_expr::chunked_idx_table::{new_chunked_idx_table, ChunkedIdxTable}; use polars_expr::hash_keys::HashKeys; use polars_ops::frame::{JoinArgs, JoinType}; use polars_ops::prelude::TakeChunked; +use polars_ops::series::coalesce_columns; use polars_utils::cardinality_sketch::CardinalitySketch; use polars_utils::hashing::HashPartitioner; use polars_utils::itertools::Itertools; @@ -30,32 +31,70 @@ fn compute_payload_selector( this_key_schema: &Schema, is_left: bool, args: &JoinArgs, -) -> Vec> { +) -> PolarsResult>> { let should_coalesce = args.should_coalesce(); this.iter_names() - .map(|c| { - if should_coalesce && this_key_schema.contains(c) { + .enumerate() + .map(|(i, c)| { + let selector = if should_coalesce && this_key_schema.contains(c) { if is_left != (args.how == JoinType::Right) { - return Some(c.clone()); + Some(c.clone()) } else { - return None; + if args.how == JoinType::Full { + // We must keep the right-hand side keycols around for + // coalescing. + Some(format_pl_smallstr!("__POLARS_COALESCE_KEYCOL{i}")) + } else { + None + } } - } - - if !other.contains(c) { - return Some(c.clone()); - } - - if is_left { + } else if !other.contains(c) || is_left { Some(c.clone()) } else { - Some(format_pl_smallstr!("{}{}", c, args.suffix())) - } + let suffixed = format_pl_smallstr!("{}{}", c, args.suffix()); + if other.contains(&suffixed) { + polars_bail!(Duplicate: "column with name '{suffixed}' already exists\n\n\ + You may want to try:\n\ + - renaming the column prior to joining\n\ + - using the `suffix` parameter to specify a suffix different to the default one ('_right')") + } + Some(suffixed) + }; + Ok(selector) }) .collect() } +fn postprocess_join(df: DataFrame, params: &EquiJoinParams) -> DataFrame { + if params.args.how == JoinType::Full && params.args.should_coalesce() { + // TODO: don't do string-based column lookups for each dataframe, pre-compute coalesce indices. + let mut key_idx = 0; + df.get_columns() + .iter() + .filter_map(|c| { + if let Some((key_name, _)) = params.left_key_schema.get_at_index(key_idx) { + if c.name() == key_name { + let other = df + .column(&format_pl_smallstr!("__POLARS_COALESCE_KEYCOL{key_idx}")) + .unwrap(); + key_idx += 1; + return Some(coalesce_columns(&[c.clone(), other.clone()]).unwrap()); + } + } + + if c.name().starts_with("__POLARS_COALESCE_KEYCOL") { + return None; + } + + Some(c.clone()) + }) + .collect() + } else { + df + } +} + fn select_schema(schema: &Schema, selector: &[Option]) -> Schema { schema .iter_fields() @@ -351,6 +390,7 @@ impl ProbeState { let mut out_df = accumulate_dataframes_vertical_unchecked(out_per_partition); out_df.sort_in_place([name.clone()], sort_options).unwrap(); out_df.drop_in_place(&name).unwrap(); + out_df = postprocess_join(out_df, params); // TODO: break in smaller morsels. let out_morsel = Morsel::new(out_df, seq, src_token.clone()); @@ -386,6 +426,7 @@ impl ProbeState { probe_df.hstack_mut_unchecked(build_df.get_columns()); probe_df }; + let out_df = postprocess_join(out_df, params); let out_morsel = Morsel::new(out_df, seq, src_token.clone()); if send.send(out_morsel).await.is_err() { @@ -459,6 +500,7 @@ impl ProbeState { .unwrap(); out_df.drop_in_place(&seq_name).unwrap(); out_df.drop_in_place(&idx_name).unwrap(); + out_df = postprocess_join(out_df, params); out_df } } @@ -516,6 +558,7 @@ impl EmitUnmatchedState { probe_df } }; + let out_df = postprocess_join(out_df, params); // Send and wait until consume token is consumed. let mut morsel = Morsel::new(out_df, self.morsel_seq, source_token.clone()); @@ -551,6 +594,7 @@ struct EquiJoinParams { left_is_build: bool, preserve_order_build: bool, preserve_order_probe: bool, + left_key_schema: Arc, left_key_selectors: Vec, right_key_selectors: Vec, left_payload_select: Vec>, @@ -597,7 +641,7 @@ impl EquiJoinNode { left_key_selectors: Vec, right_key_selectors: Vec, args: JoinArgs, - ) -> Self { + ) -> PolarsResult { // TODO: use cardinality estimation to determine this. let left_is_build = args.how != JoinType::Left; @@ -610,24 +654,24 @@ impl EquiJoinNode { &left_key_schema, true, &args, - ); + )?; let right_payload_select = compute_payload_selector( &right_input_schema, &left_input_schema, &right_key_schema, false, &args, - ); + )?; let table = if left_is_build { - new_chunked_idx_table(left_key_schema) + new_chunked_idx_table(left_key_schema.clone()) } else { new_chunked_idx_table(right_key_schema) }; let left_payload_schema = select_schema(&left_input_schema, &left_payload_select); let right_payload_schema = select_schema(&right_input_schema, &right_payload_select); - Self { + Ok(Self { state: EquiJoinState::Build(BuildState { partitions_per_worker: Vec::new(), }), @@ -636,6 +680,7 @@ impl EquiJoinNode { left_is_build, preserve_order_build: preserve_order, preserve_order_probe: preserve_order, + left_key_schema, left_key_selectors, right_key_selectors, left_payload_select, @@ -646,7 +691,7 @@ impl EquiJoinNode { random_state: PlRandomState::new(), }, table, - } + }) } } @@ -692,7 +737,10 @@ impl ComputeNode for EquiJoinNode { } else { let partitioner = HashPartitioner::new(self.num_pipelines, 0); let unmatched = probe_state.ordered_unmatched(&partitioner, &self.params); - let mut src = InMemorySourceNode::new(Arc::new(unmatched), probe_state.max_seq_sent.successor()); + let mut src = InMemorySourceNode::new( + Arc::new(unmatched), + probe_state.max_seq_sent.successor(), + ); src.initialize(self.num_pipelines); self.state = EquiJoinState::EmitUnmatchedBuildInOrder(src); } diff --git a/crates/polars-stream/src/physical_plan/to_graph.rs b/crates/polars-stream/src/physical_plan/to_graph.rs index d7f776b9d0ad..4a36eb3b5a2e 100644 --- a/crates/polars-stream/src/physical_plan/to_graph.rs +++ b/crates/polars-stream/src/physical_plan/to_graph.rs @@ -541,7 +541,7 @@ fn to_graph_rec<'a>( left_key_selectors, right_key_selectors, args - ), + )?, [left_input_key, right_input_key], ) }, From d1d5ccd97b829c848f625e5687e0b436bdf228aa Mon Sep 17 00:00:00 2001 From: Orson Peters Date: Mon, 25 Nov 2024 15:54:11 +0100 Subject: [PATCH 12/21] ignore row order for full/inner join in update --- py-polars/tests/unit/operations/test_join.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/py-polars/tests/unit/operations/test_join.py b/py-polars/tests/unit/operations/test_join.py index 27cba18e18d5..84433dcac3f2 100644 --- a/py-polars/tests/unit/operations/test_join.py +++ b/py-polars/tests/unit/operations/test_join.py @@ -537,10 +537,10 @@ def test_update() -> None: assert result.collect().to_series().to_list() == [1, 2, 3] result = a.update(b, how="inner", left_on="a", right_on="c") - assert result.collect().to_series().to_list() == [1, 3] + assert sorted(result.collect().to_series().to_list()) == [1, 3] result = a.update(b.rename({"b": "a"}), how="full", on="a") - assert result.collect().to_series().sort().to_list() == [1, 2, 3, 4, 5] + assert sorted(result.collect().to_series().sort().to_list()) == [1, 2, 3, 4, 5] # check behavior of include_nulls=True df = pl.DataFrame( @@ -562,7 +562,7 @@ def test_update() -> None: "B": [-99, 500, None, 700, -66], } ) - assert_frame_equal(out, expected) + assert_frame_equal(out, expected, check_row_order=False) # edge-case #11684 x = pl.DataFrame({"a": [0, 1]}) @@ -604,6 +604,7 @@ def test_join_concat_projection_pd_case_7071() -> None: assert_frame_equal(result, expected) +@pytest.mark.may_fail_auto_streaming # legacy full join is not order-preserving whereas new-streaming is def test_join_sorted_fast_paths_null() -> None: df1 = pl.DataFrame({"x": [0, 1, 0]}).sort("x") df2 = pl.DataFrame({"x": [0, None], "y": [0, 1]}) From 17ef4b90ebf2db8982e23b1c15edf1ca60a3c99b Mon Sep 17 00:00:00 2001 From: Orson Peters Date: Mon, 25 Nov 2024 15:54:29 +0100 Subject: [PATCH 13/21] skip tests not made for new-streaming engine --- py-polars/tests/unit/operations/test_is_sorted.py | 1 + py-polars/tests/unit/test_string_cache.py | 1 + 2 files changed, 2 insertions(+) diff --git a/py-polars/tests/unit/operations/test_is_sorted.py b/py-polars/tests/unit/operations/test_is_sorted.py index 093dae47bfbf..b14062a988ab 100644 --- a/py-polars/tests/unit/operations/test_is_sorted.py +++ b/py-polars/tests/unit/operations/test_is_sorted.py @@ -331,6 +331,7 @@ def test_sorted_flag() -> None: pl.Series([{"a": 1}], dtype=pl.Object).set_sorted(descending=True) +@pytest.mark.may_fail_auto_streaming def test_sorted_flag_after_joins() -> None: np.random.seed(1) dfa = pl.DataFrame( diff --git a/py-polars/tests/unit/test_string_cache.py b/py-polars/tests/unit/test_string_cache.py index def1f15db07a..084aebb065f1 100644 --- a/py-polars/tests/unit/test_string_cache.py +++ b/py-polars/tests/unit/test_string_cache.py @@ -107,6 +107,7 @@ def my_function() -> None: sc(True) +@pytest.mark.may_fail_auto_streaming def test_string_cache_join() -> None: df1 = pl.DataFrame({"a": ["foo", "bar", "ham"], "b": [1, 2, 3]}) df2 = pl.DataFrame({"a": ["eggs", "spam", "foo"], "c": [2, 2, 3]}) From 58553be340439f3f9ca2cb4334a8f1297df5cc58 Mon Sep 17 00:00:00 2001 From: Orson Peters Date: Mon, 25 Nov 2024 15:54:38 +0100 Subject: [PATCH 14/21] fix feature flags --- crates/polars-lazy/Cargo.toml | 2 +- crates/polars-stream/Cargo.toml | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/crates/polars-lazy/Cargo.toml b/crates/polars-lazy/Cargo.toml index 3f8c64dd1970..be992164a4d9 100644 --- a/crates/polars-lazy/Cargo.toml +++ b/crates/polars-lazy/Cargo.toml @@ -161,7 +161,7 @@ dtype-time = [ dtype-u16 = ["polars-plan/dtype-u16", "polars-pipe?/dtype-u16", "polars-expr/dtype-u16", "polars-mem-engine/dtype-u16"] dtype-u8 = ["polars-plan/dtype-u8", "polars-pipe?/dtype-u8", "polars-expr/dtype-u8", "polars-mem-engine/dtype-u8"] -object = ["polars-plan/object", "polars-mem-engine/object"] +object = ["polars-plan/object", "polars-mem-engine/object", "polars-stream?/object"] month_start = ["polars-plan/month_start"] month_end = ["polars-plan/month_end"] offset_by = ["polars-plan/offset_by"] diff --git a/crates/polars-stream/Cargo.toml b/crates/polars-stream/Cargo.toml index f0b3b1c30e35..d90047b70074 100644 --- a/crates/polars-stream/Cargo.toml +++ b/crates/polars-stream/Cargo.toml @@ -46,6 +46,7 @@ parquet = ["polars-mem-engine/parquet", "polars-plan/parquet"] csv = ["polars-mem-engine/csv", "polars-plan/csv"] json = ["polars-mem-engine/json", "polars-plan/json"] cloud = ["polars-mem-engine/cloud", "polars-plan/cloud", "polars-io/cloud"] +object = ["polars-ops/object"] # We need to specify default features here to match workspace defaults. # Otherwise we get warnings with cargo check/clippy. From 7269a4e3f834b3245cdeb1dc8d473a018d83dcdc Mon Sep 17 00:00:00 2001 From: Orson Peters Date: Mon, 25 Nov 2024 15:55:32 +0100 Subject: [PATCH 15/21] fix expression lowering --- .../src/nodes/joins/equi_join.rs | 11 +++++-- .../src/physical_plan/lower_ir.rs | 29 +++++++++++++++---- 2 files changed, 31 insertions(+), 9 deletions(-) diff --git a/crates/polars-stream/src/nodes/joins/equi_join.rs b/crates/polars-stream/src/nodes/joins/equi_join.rs index 929597e9f5cf..d44ccbdf7708 100644 --- a/crates/polars-stream/src/nodes/joins/equi_join.rs +++ b/crates/polars-stream/src/nodes/joins/equi_join.rs @@ -642,11 +642,16 @@ impl EquiJoinNode { right_key_selectors: Vec, args: JoinArgs, ) -> PolarsResult { - // TODO: use cardinality estimation to determine this. - let left_is_build = args.how != JoinType::Left; - // TODO: expose as a parameter, and let you choose the primary order to preserve. let preserve_order = std::env::var("POLARS_JOIN_IGNORE_ORDER").as_deref() != Ok("1"); + + let left_is_build = if preserve_order { + // Legacy, preserve right -> left unless join type is left, then preserve left -> right. + args.how != JoinType::Left + } else { + // TODO: use cardinality estimation to determine this. + true + }; let left_payload_select = compute_payload_selector( &left_input_schema, diff --git a/crates/polars-stream/src/physical_plan/lower_ir.rs b/crates/polars-stream/src/physical_plan/lower_ir.rs index 02640adc3777..e7aa33b520fc 100644 --- a/crates/polars-stream/src/physical_plan/lower_ir.rs +++ b/crates/polars-stream/src/physical_plan/lower_ir.rs @@ -520,10 +520,27 @@ pub fn lower_ir( let phys_left = lower_ir!(input_left)?; let phys_right = lower_ir!(input_right)?; if args.how.is_equi() && !args.validation.needs_checks() { - let (trans_input_left, trans_left_on) = - lower_exprs(phys_left, &left_on, expr_arena, phys_sm, expr_cache)?; - let (trans_input_right, trans_right_on) = - lower_exprs(phys_right, &right_on, expr_arena, phys_sm, expr_cache)?; + // When lowering the expressions for the keys we need to ensure we keep around the + // payload columns, otherwise the input nodes can get replaced by input-independent + // nodes since the lowering code does not see we access any non-literal expressions. + // So we add dummy expressions before lowering and remove them afterwards. + let mut aug_left_on = left_on.clone(); + for name in phys_sm[phys_left].output_schema.iter_names() { + let col_expr = expr_arena.add(AExpr::Column(name.clone())); + aug_left_on.push(ExprIR::new(col_expr, OutputName::ColumnLhs(name.clone()))); + } + let mut aug_right_on = right_on.clone(); + for name in phys_sm[phys_right].output_schema.iter_names() { + let col_expr = expr_arena.add(AExpr::Column(name.clone())); + aug_right_on.push(ExprIR::new(col_expr, OutputName::ColumnLhs(name.clone()))); + } + let (trans_input_left, mut trans_left_on) = + lower_exprs(phys_left, &aug_left_on, expr_arena, phys_sm, expr_cache)?; + let (trans_input_right, mut trans_right_on) = + lower_exprs(phys_right, &aug_right_on, expr_arena, phys_sm, expr_cache)?; + trans_left_on.drain(left_on.len()..); + trans_right_on.drain(right_on.len()..); + let mut node = phys_sm.insert(PhysNode::new( output_schema, PhysNodeKind::EquiJoin { @@ -531,8 +548,8 @@ pub fn lower_ir( input_right: trans_input_right, left_on: trans_left_on, right_on: trans_right_on, - args: args.clone() - } + args: args.clone(), + }, )); if let Some((offset, len)) = args.slice { node = build_slice_node(node, offset, len, phys_sm); From 77ea298da038f6b55ba7675d741e4d0c47da093f Mon Sep 17 00:00:00 2001 From: Orson Peters Date: Mon, 25 Nov 2024 15:56:02 +0100 Subject: [PATCH 16/21] fmt --- crates/polars-core/src/datatypes/field.rs | 2 +- .../polars-expr/src/chunked_idx_table/mod.rs | 9 ++-- .../src/chunked_idx_table/row_encoded.rs | 36 +++++++------ crates/polars-expr/src/hash_keys.rs | 54 +++++++++++++++---- crates/polars-expr/src/lib.rs | 2 +- crates/polars-ops/src/frame/join/args.rs | 5 +- .../polars-stream/src/nodes/in_memory_map.rs | 5 +- .../src/nodes/in_memory_source.rs | 2 +- .../src/nodes/joins/equi_join.rs | 2 +- .../src/nodes/joins/in_memory.rs | 6 ++- crates/polars-stream/src/nodes/joins/mod.rs | 2 +- crates/polars-stream/src/physical_plan/fmt.rs | 9 +++- crates/polars-stream/src/physical_plan/mod.rs | 9 +++- .../src/physical_plan/to_graph.rs | 14 ++--- .../polars-utils/src/idx_map/bytes_idx_map.rs | 17 +++--- crates/polars-utils/src/index.rs | 2 +- 16 files changed, 115 insertions(+), 61 deletions(-) diff --git a/crates/polars-core/src/datatypes/field.rs b/crates/polars-core/src/datatypes/field.rs index 7ff81d7277ea..0c30aacdd275 100644 --- a/crates/polars-core/src/datatypes/field.rs +++ b/crates/polars-core/src/datatypes/field.rs @@ -96,7 +96,7 @@ impl Field { pub fn set_name(&mut self, name: PlSmallStr) { self.name = name; } - + /// Returns this `Field`, renamed. pub fn with_name(mut self, name: PlSmallStr) -> Self { self.name = name; diff --git a/crates/polars-expr/src/chunked_idx_table/mod.rs b/crates/polars-expr/src/chunked_idx_table/mod.rs index a9d853c866db..bd6cc0a05fe9 100644 --- a/crates/polars-expr/src/chunked_idx_table/mod.rs +++ b/crates/polars-expr/src/chunked_idx_table/mod.rs @@ -25,9 +25,9 @@ pub trait ChunkedIdxTable: Any + Send + Sync { /// (ChunkId, IdxSize) pairs for each match. Will stop processing new keys /// once limit matches have been generated, returning the number of keys /// processed. - /// + /// /// If mark_matches is true, matches are marked in the table as such. - /// + /// /// If emit_unmatched is true, for keys that do not have a match we emit a /// match with ChunkId::null() on the table match. fn probe( @@ -39,7 +39,7 @@ pub trait ChunkedIdxTable: Any + Send + Sync { emit_unmatched: bool, limit: IdxSize, ) -> IdxSize; - + /// The same as probe, except it will only apply to the specified subset of keys. /// # Safety /// The provided subset indices must be in-bounds. @@ -55,7 +55,8 @@ pub trait ChunkedIdxTable: Any + Send + Sync { ) -> IdxSize; /// Get the ChunkIds for each key which was never marked during probing. - fn unmarked_keys(&self, out: &mut Vec>, offset: IdxSize, limit: IdxSize) -> IdxSize; + fn unmarked_keys(&self, out: &mut Vec>, offset: IdxSize, limit: IdxSize) + -> IdxSize; } pub fn new_chunked_idx_table(_key_schema: Arc) -> Box { diff --git a/crates/polars-expr/src/chunked_idx_table/row_encoded.rs b/crates/polars-expr/src/chunked_idx_table/row_encoded.rs index f7952ad91dea..fc67aca159a9 100644 --- a/crates/polars-expr/src/chunked_idx_table/row_encoded.rs +++ b/crates/polars-expr/src/chunked_idx_table/row_encoded.rs @@ -59,7 +59,7 @@ impl RowEncodedChunkedIdxTable { } true } else { - false + false } } @@ -76,17 +76,11 @@ impl RowEncodedChunkedIdxTable { let mut keys_processed = 0; for (key_idx, hash, key) in hash_keys { let found_match = if let Some(key) = key { - self.probe_one::( - key_idx, - hash, - key, - table_match, - probe_match, - ) + self.probe_one::(key_idx, hash, key, table_match, probe_match) } else { false }; - + if EMIT_UNMATCHED && !found_match { table_match.push(ChunkId::null()); probe_match.push(key_idx); @@ -269,16 +263,26 @@ impl ChunkedIdxTable for RowEncodedChunkedIdxTable { } } - fn unmarked_keys(&self, out: &mut Vec>, mut offset: IdxSize, limit: IdxSize) -> IdxSize { + fn unmarked_keys( + &self, + out: &mut Vec>, + mut offset: IdxSize, + limit: IdxSize, + ) -> IdxSize { out.clear(); - + if (offset as usize) < self.null_keys.len() { - out.extend(self.null_keys[offset as usize..].iter().copied().take(limit as usize)); + out.extend( + self.null_keys[offset as usize..] + .iter() + .copied() + .take(limit as usize), + ); return out.len() as IdxSize; } - + offset -= self.null_keys.len() as IdxSize; - + let mut keys_processed = 0; while let Some((_, _, chunk_ids)) = self.idx_map.get_index(offset + keys_processed) { let first_chunk_id = unsafe { chunk_ids.get_unchecked(0) }; @@ -290,13 +294,13 @@ impl ChunkedIdxTable for RowEncodedChunkedIdxTable { out.push(chunk_id); } } - + keys_processed += 1; if out.len() >= limit as usize { break; } } - + keys_processed } } diff --git a/crates/polars-expr/src/hash_keys.rs b/crates/polars-expr/src/hash_keys.rs index 59aa4add3443..4690da1e47b1 100644 --- a/crates/polars-expr/src/hash_keys.rs +++ b/crates/polars-expr/src/hash_keys.rs @@ -22,7 +22,12 @@ pub enum HashKeys { } impl HashKeys { - pub fn from_df(df: &DataFrame, random_state: PlRandomState, null_is_valid: bool, force_row_encoding: bool) -> Self { + pub fn from_df( + df: &DataFrame, + random_state: PlRandomState, + null_is_valid: bool, + force_row_encoding: bool, + ) -> Self { if df.width() > 1 || force_row_encoding { let keys = df .get_columns() @@ -30,7 +35,7 @@ impl HashKeys { .map(|c| c.as_materialized_series().clone()) .collect_vec(); let mut keys_encoded = _get_rows_encoded_unordered(&keys[..]).unwrap().into_array(); - + if !null_is_valid { let validities = keys.iter().map(|c| c.rechunk_validity()).collect_vec(); let combined = combine_validities_and_many(&validities); @@ -78,17 +83,37 @@ impl HashKeys { ) { if sketches.is_empty() { match self { - Self::RowEncoded(s) => s.gen_partition_idxs::(partitioner, partition_idxs, sketches, partition_nulls), - Self::Single(s) => s.gen_partition_idxs::(partitioner, partition_idxs, sketches, partition_nulls), + Self::RowEncoded(s) => s.gen_partition_idxs::( + partitioner, + partition_idxs, + sketches, + partition_nulls, + ), + Self::Single(s) => s.gen_partition_idxs::( + partitioner, + partition_idxs, + sketches, + partition_nulls, + ), } } else { match self { - Self::RowEncoded(s) => s.gen_partition_idxs::(partitioner, partition_idxs, sketches, partition_nulls), - Self::Single(s) => s.gen_partition_idxs::(partitioner, partition_idxs, sketches, partition_nulls), + Self::RowEncoded(s) => s.gen_partition_idxs::( + partitioner, + partition_idxs, + sketches, + partition_nulls, + ), + Self::Single(s) => s.gen_partition_idxs::( + partitioner, + partition_idxs, + sketches, + partition_nulls, + ), } } } - + /// Generates indices for a chunked gather such that the ith key gathers /// the next gathers_per_key[i] elements from the partition[i]th chunk. pub fn gen_partitioned_gather_idxs( @@ -98,8 +123,12 @@ impl HashKeys { gather_idxs: &mut Vec>, ) { match self { - Self::RowEncoded(s) => s.gen_partitioned_gather_idxs(partitioner, gathers_per_key, gather_idxs), - Self::Single(s) => s.gen_partitioned_gather_idxs(partitioner, gathers_per_key, gather_idxs), + Self::RowEncoded(s) => { + s.gen_partitioned_gather_idxs(partitioner, gathers_per_key, gather_idxs) + }, + Self::Single(s) => { + s.gen_partitioned_gather_idxs(partitioner, gathers_per_key, gather_idxs) + }, } } @@ -177,7 +206,7 @@ impl RowEncodedKeys { for (hash, &n) in self.hashes.values_iter().zip(gathers_per_key) { let p = partitioner.hash_to_partition(*hash); let offset = *offsets.get_unchecked(p); - for i in offset..offset+n { + for i in offset..offset + n { gather_idxs.push(ChunkId::store(p as IdxSize, i)); } *offsets.get_unchecked_mut(p) += n; @@ -194,7 +223,10 @@ impl RowEncodedKeys { } let idx_arr = arrow::ffi::mmap::slice(idxs); let keys = take_unchecked(&self.keys, &idx_arr); - Self { hashes: PrimitiveArray::from_vec(hashes), keys } + Self { + hashes: PrimitiveArray::from_vec(hashes), + keys, + } } } diff --git a/crates/polars-expr/src/lib.rs b/crates/polars-expr/src/lib.rs index 2da894f9e297..138068e3c268 100644 --- a/crates/polars-expr/src/lib.rs +++ b/crates/polars-expr/src/lib.rs @@ -1,3 +1,4 @@ +pub mod chunked_idx_table; mod expressions; pub mod groups; pub mod hash_keys; @@ -5,6 +6,5 @@ pub mod planner; pub mod prelude; pub mod reduce; pub mod state; -pub mod chunked_idx_table; pub use crate::planner::{create_physical_expr, ExpressionConversionState}; diff --git a/crates/polars-ops/src/frame/join/args.rs b/crates/polars-ops/src/frame/join/args.rs index b005fa896a63..5d55367b1f4a 100644 --- a/crates/polars-ops/src/frame/join/args.rs +++ b/crates/polars-ops/src/frame/join/args.rs @@ -164,7 +164,10 @@ impl Debug for JoinType { impl JoinType { pub fn is_equi(&self) -> bool { - matches!(self, JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full) + matches!( + self, + JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full + ) } pub fn is_asof(&self) -> bool { diff --git a/crates/polars-stream/src/nodes/in_memory_map.rs b/crates/polars-stream/src/nodes/in_memory_map.rs index 118827f76529..bce07359ec10 100644 --- a/crates/polars-stream/src/nodes/in_memory_map.rs +++ b/crates/polars-stream/src/nodes/in_memory_map.rs @@ -56,7 +56,10 @@ impl ComputeNode for InMemoryMapNode { { if recv[0] == PortState::Done { let df = sink_node.get_output()?; - let mut source_node = InMemorySourceNode::new(Arc::new(map.call_udf(df.unwrap())?), MorselSeq::default()); + let mut source_node = InMemorySourceNode::new( + Arc::new(map.call_udf(df.unwrap())?), + MorselSeq::default(), + ); source_node.initialize(*num_pipelines); *self = Self::Source(source_node); } diff --git a/crates/polars-stream/src/nodes/in_memory_source.rs b/crates/polars-stream/src/nodes/in_memory_source.rs index b8d07c756a34..5c9f1e63fbce 100644 --- a/crates/polars-stream/src/nodes/in_memory_source.rs +++ b/crates/polars-stream/src/nodes/in_memory_source.rs @@ -18,7 +18,7 @@ impl InMemorySourceNode { source: Some(source), morsel_size: 0, seq: AtomicU64::new(0), - seq_offset + seq_offset, } } } diff --git a/crates/polars-stream/src/nodes/joins/equi_join.rs b/crates/polars-stream/src/nodes/joins/equi_join.rs index d44ccbdf7708..fc6cd8619a50 100644 --- a/crates/polars-stream/src/nodes/joins/equi_join.rs +++ b/crates/polars-stream/src/nodes/joins/equi_join.rs @@ -644,7 +644,7 @@ impl EquiJoinNode { ) -> PolarsResult { // TODO: expose as a parameter, and let you choose the primary order to preserve. let preserve_order = std::env::var("POLARS_JOIN_IGNORE_ORDER").as_deref() != Ok("1"); - + let left_is_build = if preserve_order { // Legacy, preserve right -> left unless join type is left, then preserve left -> right. args.how != JoinType::Left diff --git a/crates/polars-stream/src/nodes/joins/in_memory.rs b/crates/polars-stream/src/nodes/joins/in_memory.rs index 79b45d074e5c..3fb981c25d20 100644 --- a/crates/polars-stream/src/nodes/joins/in_memory.rs +++ b/crates/polars-stream/src/nodes/joins/in_memory.rs @@ -60,8 +60,10 @@ impl ComputeNode for InMemoryJoinNode { if recv[0] == PortState::Done && recv[1] == PortState::Done { let left_df = left.get_output()?.unwrap(); let right_df = right.get_output()?.unwrap(); - let mut source_node = - InMemorySourceNode::new(Arc::new((self.joiner)(left_df, right_df)?), MorselSeq::default()); + let mut source_node = InMemorySourceNode::new( + Arc::new((self.joiner)(left_df, right_df)?), + MorselSeq::default(), + ); source_node.initialize(self.num_pipelines); self.state = InMemoryJoinState::Source(source_node); } diff --git a/crates/polars-stream/src/nodes/joins/mod.rs b/crates/polars-stream/src/nodes/joins/mod.rs index 26f3282b4a76..f5304162d56a 100644 --- a/crates/polars-stream/src/nodes/joins/mod.rs +++ b/crates/polars-stream/src/nodes/joins/mod.rs @@ -1,2 +1,2 @@ +pub mod equi_join; pub mod in_memory; -pub mod equi_join; \ No newline at end of file diff --git a/crates/polars-stream/src/physical_plan/fmt.rs b/crates/polars-stream/src/physical_plan/fmt.rs index 57ae8119db11..6da72a3bcb4c 100644 --- a/crates/polars-stream/src/physical_plan/fmt.rs +++ b/crates/polars-stream/src/physical_plan/fmt.rs @@ -214,7 +214,14 @@ fn visualize_plan_rec( left_on, right_on, args, - } | PhysNodeKind::EquiJoin { input_left, input_right, left_on, right_on, args } => { + } + | PhysNodeKind::EquiJoin { + input_left, + input_right, + left_on, + right_on, + args, + } => { let mut label = if matches!(phys_sm[node_key].kind, PhysNodeKind::EquiJoin { .. }) { "equi-join".to_string() } else { diff --git a/crates/polars-stream/src/physical_plan/mod.rs b/crates/polars-stream/src/physical_plan/mod.rs index aa821e6b0a38..c5f679c7fe56 100644 --- a/crates/polars-stream/src/physical_plan/mod.rs +++ b/crates/polars-stream/src/physical_plan/mod.rs @@ -153,7 +153,7 @@ pub enum PhysNodeKind { key: Vec, aggs: Vec, }, - + EquiJoin { input_left: PhysNodeKey, input_right: PhysNodeKey, @@ -221,7 +221,12 @@ fn insert_multiplexers( insert_multiplexers(*input, phys_sm, referenced); }, - PhysNodeKind::InMemoryJoin { input_left, input_right, .. } | PhysNodeKind::EquiJoin { + PhysNodeKind::InMemoryJoin { + input_left, + input_right, + .. + } + | PhysNodeKind::EquiJoin { input_left, input_right, .. diff --git a/crates/polars-stream/src/physical_plan/to_graph.rs b/crates/polars-stream/src/physical_plan/to_graph.rs index 4a36eb3b5a2e..4bf706c00aa2 100644 --- a/crates/polars-stream/src/physical_plan/to_graph.rs +++ b/crates/polars-stream/src/physical_plan/to_graph.rs @@ -504,7 +504,7 @@ fn to_graph_rec<'a>( [left_input_key, right_input_key], ) }, - + EquiJoin { input_left, input_right, @@ -518,10 +518,12 @@ fn to_graph_rec<'a>( let left_input_schema = ctx.phys_sm[*input_left].output_schema.clone(); let right_input_schema = ctx.phys_sm[*input_right].output_schema.clone(); - let left_key_schema = compute_output_schema(&left_input_schema, left_on, ctx.expr_arena)? - .materialize_unknown_dtypes()?; - let right_key_schema = compute_output_schema(&right_input_schema, right_on, ctx.expr_arena)? - .materialize_unknown_dtypes()?; + let left_key_schema = + compute_output_schema(&left_input_schema, left_on, ctx.expr_arena)? + .materialize_unknown_dtypes()?; + let right_key_schema = + compute_output_schema(&right_input_schema, right_on, ctx.expr_arena)? + .materialize_unknown_dtypes()?; let left_key_selectors = left_on .iter() @@ -540,7 +542,7 @@ fn to_graph_rec<'a>( Arc::new(right_key_schema), left_key_selectors, right_key_selectors, - args + args, )?, [left_input_key, right_input_key], ) diff --git a/crates/polars-utils/src/idx_map/bytes_idx_map.rs b/crates/polars-utils/src/idx_map/bytes_idx_map.rs index c362361e2620..90e7a5c33ee4 100644 --- a/crates/polars-utils/src/idx_map/bytes_idx_map.rs +++ b/crates/polars-utils/src/idx_map/bytes_idx_map.rs @@ -62,15 +62,12 @@ impl BytesIndexMap { pub fn is_empty(&self) -> bool { self.table.is_empty() } - + pub fn get(&self, hash: u64, key: &[u8]) -> Option<&V> { - let idx = self.table.find( - hash.wrapping_mul(self.seed), - |i| unsafe { - let t = self.tuples.get_unchecked(*i as usize); - hash == t.0.key_hash && key == t.0.get(&self.key_data) - }, - )?; + let idx = self.table.find(hash.wrapping_mul(self.seed), |i| unsafe { + let t = self.tuples.get_unchecked(*i as usize); + hash == t.0.key_hash && key == t.0.get(&self.key_data) + })?; unsafe { Some(&self.tuples.get_unchecked(*idx as usize).1) } } @@ -128,9 +125,7 @@ impl BytesIndexMap { /// Iterates over the values in insertion order. pub fn iter_values(&self) -> impl Iterator { - self.tuples - .iter() - .map(|t| &t.1) + self.tuples.iter().map(|t| &t.1) } } diff --git a/crates/polars-utils/src/index.rs b/crates/polars-utils/src/index.rs index 9ef037954c39..2ecad3cbf92b 100644 --- a/crates/polars-utils/src/index.rs +++ b/crates/polars-utils/src/index.rs @@ -235,7 +235,7 @@ impl ChunkId { pub fn inner_mut(&mut self) -> &mut u64 { &mut self.swizzled } - + pub fn from_inner(inner: u64) -> Self { Self { swizzled: inner } } From c9e6b97675bc90727da83e443dfff7a4e07c90dd Mon Sep 17 00:00:00 2001 From: Orson Peters Date: Mon, 25 Nov 2024 16:10:13 +0100 Subject: [PATCH 17/21] clippy --- crates/polars-core/src/frame/mod.rs | 4 ++++ .../polars-expr/src/chunked_idx_table/mod.rs | 1 + crates/polars-expr/src/hash_keys.rs | 5 ++++ .../src/nodes/joins/equi_join.rs | 24 +++++++++---------- 4 files changed, 21 insertions(+), 13 deletions(-) diff --git a/crates/polars-core/src/frame/mod.rs b/crates/polars-core/src/frame/mod.rs index f4dad06dea7a..c86c6d0ca2d0 100644 --- a/crates/polars-core/src/frame/mod.rs +++ b/crates/polars-core/src/frame/mod.rs @@ -1903,10 +1903,14 @@ impl DataFrame { unsafe { DataFrame::new_no_checks(idx.len(), cols) } } + /// # Safety + /// The indices must be in-bounds. pub unsafe fn take_slice_unchecked(&self, idx: &[IdxSize]) -> Self { self.take_slice_unchecked_impl(idx, true) } + /// # Safety + /// The indices must be in-bounds. pub unsafe fn take_slice_unchecked_impl(&self, idx: &[IdxSize], allow_threads: bool) -> Self { let cols = if allow_threads { POOL.install(|| self._apply_columns_par(&|s| s.take_slice_unchecked(idx))) diff --git a/crates/polars-expr/src/chunked_idx_table/mod.rs b/crates/polars-expr/src/chunked_idx_table/mod.rs index bd6cc0a05fe9..948e34effad0 100644 --- a/crates/polars-expr/src/chunked_idx_table/mod.rs +++ b/crates/polars-expr/src/chunked_idx_table/mod.rs @@ -43,6 +43,7 @@ pub trait ChunkedIdxTable: Any + Send + Sync { /// The same as probe, except it will only apply to the specified subset of keys. /// # Safety /// The provided subset indices must be in-bounds. + #[allow(clippy::too_many_arguments)] unsafe fn probe_subset( &self, hash_keys: &HashKeys, diff --git a/crates/polars-expr/src/hash_keys.rs b/crates/polars-expr/src/hash_keys.rs index 4690da1e47b1..602ca7c52216 100644 --- a/crates/polars-expr/src/hash_keys.rs +++ b/crates/polars-expr/src/hash_keys.rs @@ -71,6 +71,10 @@ impl HashKeys { } } + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + /// After this call partition_idxs[p] will contain the indices of hashes /// that belong to partition p, and the cardinality sketches are updated /// accordingly. @@ -255,6 +259,7 @@ impl SingleKeys { todo!() } + #[allow(clippy::ptr_arg)] // Remove when implemented. pub fn gen_partitioned_gather_idxs( &self, _partitioner: &HashPartitioner, diff --git a/crates/polars-stream/src/nodes/joins/equi_join.rs b/crates/polars-stream/src/nodes/joins/equi_join.rs index fc6cd8619a50..d960433206f6 100644 --- a/crates/polars-stream/src/nodes/joins/equi_join.rs +++ b/crates/polars-stream/src/nodes/joins/equi_join.rs @@ -40,14 +40,12 @@ fn compute_payload_selector( let selector = if should_coalesce && this_key_schema.contains(c) { if is_left != (args.how == JoinType::Right) { Some(c.clone()) + } else if args.how == JoinType::Full { + // We must keep the right-hand side keycols around for + // coalescing. + Some(format_pl_smallstr!("__POLARS_COALESCE_KEYCOL{i}")) } else { - if args.how == JoinType::Full { - // We must keep the right-hand side keycols around for - // coalescing. - Some(format_pl_smallstr!("__POLARS_COALESCE_KEYCOL{i}")) - } else { - None - } + None } } else if !other.contains(c) || is_left { Some(c.clone()) @@ -114,7 +112,7 @@ async fn select_keys( // We use key columns entirely by position, and allow duplicate names, // so just assign arbitrary unique names. let unique_name = format_pl_smallstr!("__POLARS_KEYCOL_{i}"); - let s = selector.evaluate(&df, state).await?; + let s = selector.evaluate(df, state).await?; key_columns.push(s.into_column().with_name(unique_name)); } let keys = DataFrame::new_with_broadcast_len(key_columns, df.height())?; @@ -175,7 +173,7 @@ impl BuildState { while let Ok(morsel) = recv.recv().await { // Compute hashed keys and payload. We must rechunk the payload for // later chunked gathers. - let hash_keys = select_keys(morsel.df(), &key_selectors, params, state).await?; + let hash_keys = select_keys(morsel.df(), key_selectors, params, state).await?; let mut payload = select_payload(morsel.df().clone(), payload_selector); payload.rechunk_mut(); @@ -331,7 +329,7 @@ impl ProbeState { while let Ok(morsel) = recv.recv().await { // Compute hashed keys and payload. let (df, seq, src_token, wait_token) = morsel.into_inner(); - let hash_keys = select_keys(&df, &key_selectors, params, state).await?; + let hash_keys = select_keys(&df, key_selectors, params, state).await?; let payload = select_payload(df, payload_selector); max_seq = seq; @@ -351,7 +349,7 @@ impl ProbeState { for (p, idxs_in_p) in partitions.iter().zip(&partition_idxs) { p.table.probe_subset( &hash_keys, - &idxs_in_p, + idxs_in_p, &mut table_match, &mut probe_match, mark_matches, @@ -820,7 +818,7 @@ impl ComputeNode for EquiJoinNode { build_state .partitions_per_worker - .resize_with(self.num_pipelines, || Vec::new()); + .resize_with(self.num_pipelines, Vec::new); let partitioner = HashPartitioner::new(self.num_pipelines, 0); for (worker_ps, recv) in build_state.partitions_per_worker.iter_mut().zip(receivers) { @@ -844,7 +842,7 @@ impl ComputeNode for EquiJoinNode { let partitioner = HashPartitioner::new(self.num_pipelines, 0); let probe_tasks = receivers .into_iter() - .zip(senders.into_iter()) + .zip(senders) .map(|(recv, send)| { scope.spawn_task( TaskPriority::High, From 917bc86675a4e738b3acd61886be328a173dda79 Mon Sep 17 00:00:00 2001 From: Orson Peters Date: Mon, 25 Nov 2024 16:41:19 +0100 Subject: [PATCH 18/21] fix rayon being used in gather --- crates/polars-stream/src/nodes/joins/equi_join.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/crates/polars-stream/src/nodes/joins/equi_join.rs b/crates/polars-stream/src/nodes/joins/equi_join.rs index d960433206f6..1341a02cfe7b 100644 --- a/crates/polars-stream/src/nodes/joins/equi_join.rs +++ b/crates/polars-stream/src/nodes/joins/equi_join.rs @@ -362,7 +362,7 @@ impl ProbeState { } else { p.df.take_chunked_unchecked(&table_match, IsSorted::Not) }; - let mut probe_df = payload.take_slice_unchecked(&probe_match); + let mut probe_df = payload.take_slice_unchecked_impl(&probe_match, false); let mut out_df = if params.left_is_build { build_df.hstack_mut_unchecked(probe_df.get_columns()); @@ -415,7 +415,7 @@ impl ProbeState { } else { p.df.take_chunked_unchecked(&table_match, IsSorted::Not) }; - let mut probe_df = payload.take_slice_unchecked(&probe_match); + let mut probe_df = payload.take_slice_unchecked_impl(&probe_match, false); let out_df = if params.left_is_build { build_df.hstack_mut_unchecked(probe_df.get_columns()); From fe756659c77aa13454b05cfe3db256212eb9756e Mon Sep 17 00:00:00 2001 From: Orson Peters Date: Mon, 25 Nov 2024 17:26:10 +0100 Subject: [PATCH 19/21] fmt --- crates/polars-stream/src/nodes/joins/equi_join.rs | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/crates/polars-stream/src/nodes/joins/equi_join.rs b/crates/polars-stream/src/nodes/joins/equi_join.rs index 1341a02cfe7b..fc2d69ffb742 100644 --- a/crates/polars-stream/src/nodes/joins/equi_join.rs +++ b/crates/polars-stream/src/nodes/joins/equi_join.rs @@ -415,7 +415,8 @@ impl ProbeState { } else { p.df.take_chunked_unchecked(&table_match, IsSorted::Not) }; - let mut probe_df = payload.take_slice_unchecked_impl(&probe_match, false); + let mut probe_df = + payload.take_slice_unchecked_impl(&probe_match, false); let out_df = if params.left_is_build { build_df.hstack_mut_unchecked(probe_df.get_columns()); @@ -643,13 +644,8 @@ impl EquiJoinNode { // TODO: expose as a parameter, and let you choose the primary order to preserve. let preserve_order = std::env::var("POLARS_JOIN_IGNORE_ORDER").as_deref() != Ok("1"); - let left_is_build = if preserve_order { - // Legacy, preserve right -> left unless join type is left, then preserve left -> right. - args.how != JoinType::Left - } else { - // TODO: use cardinality estimation to determine this. - true - }; + // TODO: use cardinality estimation to determine this when not order-preserving. + let left_is_build = args.how != JoinType::Left; let left_payload_select = compute_payload_selector( &left_input_schema, From 587bb048b4007dab6fe20d7579b295a65d22651f Mon Sep 17 00:00:00 2001 From: Orson Peters Date: Tue, 26 Nov 2024 17:39:47 +0100 Subject: [PATCH 20/21] fix incorrect binaryarray gather kernel --- .../src/compute/take/generic_binary.rs | 30 ++++++++++++------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/crates/polars-arrow/src/compute/take/generic_binary.rs b/crates/polars-arrow/src/compute/take/generic_binary.rs index dc281bda129a..32e890ec42a1 100644 --- a/crates/polars-arrow/src/compute/take/generic_binary.rs +++ b/crates/polars-arrow/src/compute/take/generic_binary.rs @@ -73,7 +73,7 @@ pub(super) unsafe fn take_values_validity( values: &[u8], indices: &PrimitiveArray, ) -> (OffsetsBuffer, Buffer, Option) { - let mut length = O::default(); + let mut total_length = O::default(); let offsets = offsets.buffer(); let mut starts = Vec::::with_capacity(indices.len()); let lengths = indices.values().iter().map(|index| { let index = index.to_usize(); + let length; match offsets.get(index + 1) { Some(&next) => { let start = *offsets.get_unchecked(index); - length += next - start; + length = next - start; + total_length += length; starts.push_unchecked(start); }, - None => starts.push_unchecked(O::default()), + None => { + length = O::zero(); + starts.push_unchecked(O::default()); + } }; length.to_usize() }); let offsets = create_offsets(lengths, indices.len()); - let buffer = take_values(length, &starts, &offsets, values); + let buffer = take_values(total_length, &starts, &offsets, values); (offsets, buffer, indices.validity().cloned()) } @@ -127,7 +133,7 @@ pub(super) unsafe fn take_values_indices_validity, ) -> (OffsetsBuffer, Buffer, Option) { - let mut length = O::default(); + let mut total_length = O::default(); let mut validity = MutableBitmap::with_capacity(indices.len()); let values_validity = values.validity().unwrap(); @@ -136,28 +142,32 @@ pub(super) unsafe fn take_values_indices_validity::with_capacity(indices.len()); let lengths = indices.iter().map(|index| { + let length; match index { Some(index) => { let index = index.to_usize(); if values_validity.get_bit(index) { validity.push(true); - length += *offsets.get_unchecked(index + 1) - *offsets.get_unchecked(index); + length = *offsets.get_unchecked(index + 1) - *offsets.get_unchecked(index); starts.push_unchecked(*offsets.get_unchecked(index)); } else { validity.push(false); + length = O::zero(); starts.push_unchecked(O::default()); } }, None => { validity.push(false); + length = O::zero(); starts.push_unchecked(O::default()); }, }; + total_length += length; length.to_usize() }); let offsets = create_offsets(lengths, indices.len()); - let buffer = take_values(length, &starts, &offsets, values_values); + let buffer = take_values(total_length, &starts, &offsets, values_values); (offsets, buffer, validity.into()) } From a601af2e65590d8aead34404df6c0a9ded2b0ae0 Mon Sep 17 00:00:00 2001 From: Orson Peters Date: Tue, 26 Nov 2024 17:40:10 +0100 Subject: [PATCH 21/21] fmt --- crates/polars-arrow/src/compute/take/generic_binary.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/polars-arrow/src/compute/take/generic_binary.rs b/crates/polars-arrow/src/compute/take/generic_binary.rs index 32e890ec42a1..a1c220ddbd29 100644 --- a/crates/polars-arrow/src/compute/take/generic_binary.rs +++ b/crates/polars-arrow/src/compute/take/generic_binary.rs @@ -117,7 +117,7 @@ pub(super) unsafe fn take_indices_validity( None => { length = O::zero(); starts.push_unchecked(O::default()); - } + }, }; length.to_usize() });