diff --git a/crates/polars-stream/src/nodes/joins/equi_join.rs b/crates/polars-stream/src/nodes/joins/equi_join.rs index 8afc7c88b167..c0c072d51d8b 100644 --- a/crates/polars-stream/src/nodes/joins/equi_join.rs +++ b/crates/polars-stream/src/nodes/joins/equi_join.rs @@ -202,18 +202,29 @@ impl BuildState { } fn finalize(&mut self, params: &EquiJoinParams, table: &dyn ChunkedIdxTable) -> ProbeState { - let num_partitions = self.partitions_per_worker.len(); + // Transpose. + let num_workers = self.partitions_per_worker.len(); + let num_partitions = self.partitions_per_worker[0].len(); + let mut results_per_partition = (0..num_partitions) + .map(|_| Vec::with_capacity(num_workers)) + .collect_vec(); + for worker in self.partitions_per_worker.drain(..) { + for (p, result) in worker.into_iter().enumerate() { + results_per_partition[p].push(result); + } + } + let track_unmatchable = params.emit_unmatched_build(); - let table_per_partition: Vec<_> = (0..num_partitions) + let table_per_partition: Vec<_> = results_per_partition .into_par_iter() .with_max_len(1) - .map(|p| { + .map(|results| { // Estimate sizes and cardinality. let mut sketch = CardinalitySketch::new(); let mut num_frames = 0; - for worker in &self.partitions_per_worker { - sketch.combine(worker[p].sketch.as_ref().unwrap()); - num_frames += worker[p].frames.len(); + for result in &results { + sketch.combine(result.sketch.as_ref().unwrap()); + num_frames += result.frames.len(); } // Build table for this partition. @@ -223,9 +234,9 @@ impl BuildState { table.reserve(sketch.estimate() * 5 / 4); if params.preserve_order_build { let mut combined = Vec::with_capacity(num_frames); - for worker in &self.partitions_per_worker { + for result in results { for (hash_keys, (seq, frame)) in - worker[p].hash_keys.iter().zip(&worker[p].frames) + result.hash_keys.into_iter().zip(result.frames) { combined.push((seq, hash_keys, frame)); } @@ -239,14 +250,14 @@ impl BuildState { continue; } - table.insert_key_chunk(hash_keys.clone(), track_unmatchable); - combined_frames.push(frame.clone()); - chunk_seq_ids.push(*seq); + table.insert_key_chunk(hash_keys, track_unmatchable); + combined_frames.push(frame); + chunk_seq_ids.push(seq); } } else { - for worker in &self.partitions_per_worker { + for result in results { for (hash_keys, (_, frame)) in - worker[p].hash_keys.iter().zip(&worker[p].frames) + result.hash_keys.into_iter().zip(result.frames) { // Zero-sized chunks can get deleted, so skip entirely to avoid messing // up the chunk counter. @@ -254,8 +265,8 @@ impl BuildState { continue; } - table.insert_key_chunk(hash_keys.clone(), track_unmatchable); - combined_frames.push(frame.clone()); + table.insert_key_chunk(hash_keys, track_unmatchable); + combined_frames.push(frame); } } }