Skip to content

Commit

Permalink
Merge SortMergeJoin filtered batches into larger batches (#14160)
Browse files Browse the repository at this point in the history
* Merge SortMergeJoin filtered batches into bigger batches
  • Loading branch information
comphead authored Jan 22, 2025
1 parent 274e535 commit 0ba6e70
Showing 1 changed file with 95 additions and 40 deletions.
135 changes: 95 additions & 40 deletions datafusion/physical-plan/src/joins/sort_merge_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -792,7 +792,10 @@ struct SortMergeJoinStream {
/// optional join filter
pub filter: Option<JoinFilter>,
/// Staging output array builders
pub output_record_batches: JoinedRecordBatches,
pub staging_output_record_batches: JoinedRecordBatches,
/// Output buffer. Currently used by filtering as it requires double buffering
/// to avoid small/empty batches. Non-filtered join outputs directly from `staging_output_record_batches.batches`
pub output: RecordBatch,
/// Staging output size, including output batches and staging joined results.
/// Increased when we put rows into buffer and decreased after we actually output batches.
/// Used to trigger output when sufficient rows are ready
Expand Down Expand Up @@ -1053,13 +1056,35 @@ impl Stream for SortMergeJoinStream {
{
self.freeze_all()?;

if !self.output_record_batches.batches.is_empty()
// If join is filtered and there is joined tuples waiting
// to be filtered
if !self
.staging_output_record_batches
.batches
.is_empty()
{
// Apply filter on joined tuples and get filtered batch
let out_filtered_batch =
self.filter_joined_batch()?;
return Poll::Ready(Some(Ok(
out_filtered_batch,
)));

// Append filtered batch to the output buffer
self.output = concat_batches(
&self.schema(),
vec![&self.output, &out_filtered_batch],
)?;

// Send to output if the output buffer surpassed the `batch_size`
if self.output.num_rows() >= self.batch_size {
let record_batch = std::mem::replace(
&mut self.output,
RecordBatch::new_empty(
out_filtered_batch.schema(),
),
);
return Poll::Ready(Some(Ok(
record_batch,
)));
}
}
}

Expand Down Expand Up @@ -1116,7 +1141,7 @@ impl Stream for SortMergeJoinStream {
}
} else {
self.freeze_all()?;
if !self.output_record_batches.batches.is_empty() {
if !self.staging_output_record_batches.batches.is_empty() {
let record_batch = self.output_record_batch_and_reset()?;
// For non-filtered join output whenever the target output batch size
// is hit. For filtered join its needed to output on later phase
Expand Down Expand Up @@ -1146,7 +1171,8 @@ impl Stream for SortMergeJoinStream {
SortMergeJoinState::Exhausted => {
self.freeze_all()?;

if !self.output_record_batches.batches.is_empty() {
// if there is still something not processed
if !self.staging_output_record_batches.batches.is_empty() {
if self.filter.is_some()
&& matches!(
self.join_type,
Expand All @@ -1159,12 +1185,20 @@ impl Stream for SortMergeJoinStream {
| JoinType::LeftMark
)
{
let out = self.filter_joined_batch()?;
return Poll::Ready(Some(Ok(out)));
let record_batch = self.filter_joined_batch()?;
return Poll::Ready(Some(Ok(record_batch)));
} else {
let record_batch = self.output_record_batch_and_reset()?;
return Poll::Ready(Some(Ok(record_batch)));
}
} else if self.output.num_rows() > 0 {
// if processed but still not outputted because it didn't hit batch size before
let schema = self.output.schema();
let record_batch = std::mem::replace(
&mut self.output,
RecordBatch::new_empty(schema),
);
return Poll::Ready(Some(Ok(record_batch)));
} else {
return Poll::Ready(None);
}
Expand Down Expand Up @@ -1197,7 +1231,7 @@ impl SortMergeJoinStream {
state: SortMergeJoinState::Init,
sort_options,
null_equals_null,
schema,
schema: Arc::clone(&schema),
streamed_schema: Arc::clone(&streamed_schema),
buffered_schema,
streamed,
Expand All @@ -1212,12 +1246,13 @@ impl SortMergeJoinStream {
on_streamed,
on_buffered,
filter,
output_record_batches: JoinedRecordBatches {
staging_output_record_batches: JoinedRecordBatches {
batches: vec![],
filter_mask: BooleanBuilder::new(),
row_indices: UInt64Builder::new(),
batch_ids: vec![],
},
output: RecordBatch::new_empty(schema),
output_size: 0,
batch_size,
join_type,
Expand Down Expand Up @@ -1607,17 +1642,20 @@ impl SortMergeJoinStream {
buffered_batch,
)? {
let num_rows = record_batch.num_rows();
self.output_record_batches
self.staging_output_record_batches
.filter_mask
.append_nulls(num_rows);
self.output_record_batches
self.staging_output_record_batches
.row_indices
.append_nulls(num_rows);
self.output_record_batches
.batch_ids
.resize(self.output_record_batches.batch_ids.len() + num_rows, 0);
self.staging_output_record_batches.batch_ids.resize(
self.staging_output_record_batches.batch_ids.len() + num_rows,
0,
);

self.output_record_batches.batches.push(record_batch);
self.staging_output_record_batches
.batches
.push(record_batch);
}
buffered_batch.null_joined.clear();
}
Expand Down Expand Up @@ -1651,16 +1689,19 @@ impl SortMergeJoinStream {
)? {
let num_rows = record_batch.num_rows();

self.output_record_batches
self.staging_output_record_batches
.filter_mask
.append_nulls(num_rows);
self.output_record_batches
self.staging_output_record_batches
.row_indices
.append_nulls(num_rows);
self.output_record_batches
.batch_ids
.resize(self.output_record_batches.batch_ids.len() + num_rows, 0);
self.output_record_batches.batches.push(record_batch);
self.staging_output_record_batches.batch_ids.resize(
self.staging_output_record_batches.batch_ids.len() + num_rows,
0,
);
self.staging_output_record_batches
.batches
.push(record_batch);
}
buffered_batch.join_filter_not_matched_map.clear();

Expand Down Expand Up @@ -1792,20 +1833,29 @@ impl SortMergeJoinStream {
| JoinType::LeftMark
| JoinType::Full
) {
self.output_record_batches.batches.push(output_batch);
self.staging_output_record_batches
.batches
.push(output_batch);
} else {
let filtered_batch = filter_record_batch(&output_batch, &mask)?;
self.output_record_batches.batches.push(filtered_batch);
self.staging_output_record_batches
.batches
.push(filtered_batch);
}

if !matches!(self.join_type, JoinType::Full) {
self.output_record_batches.filter_mask.extend(&mask);
self.staging_output_record_batches.filter_mask.extend(&mask);
} else {
self.output_record_batches.filter_mask.extend(pre_mask);
self.staging_output_record_batches
.filter_mask
.extend(pre_mask);
}
self.output_record_batches.row_indices.extend(&left_indices);
self.output_record_batches.batch_ids.resize(
self.output_record_batches.batch_ids.len() + left_indices.len(),
self.staging_output_record_batches
.row_indices
.extend(&left_indices);
self.staging_output_record_batches.batch_ids.resize(
self.staging_output_record_batches.batch_ids.len()
+ left_indices.len(),
self.streamed_batch_counter.load(Relaxed),
);

Expand Down Expand Up @@ -1837,10 +1887,14 @@ impl SortMergeJoinStream {
}
}
} else {
self.output_record_batches.batches.push(output_batch);
self.staging_output_record_batches
.batches
.push(output_batch);
}
} else {
self.output_record_batches.batches.push(output_batch);
self.staging_output_record_batches
.batches
.push(output_batch);
}
}

Expand All @@ -1851,7 +1905,7 @@ impl SortMergeJoinStream {

fn output_record_batch_and_reset(&mut self) -> Result<RecordBatch> {
let record_batch =
concat_batches(&self.schema, &self.output_record_batches.batches)?;
concat_batches(&self.schema, &self.staging_output_record_batches.batches)?;
self.join_metrics.output_batches.add(1);
self.join_metrics.output_rows.add(record_batch.num_rows());
// If join filter exists, `self.output_size` is not accurate as we don't know the exact
Expand All @@ -1877,16 +1931,17 @@ impl SortMergeJoinStream {
| JoinType::Full
))
{
self.output_record_batches.batches.clear();
self.staging_output_record_batches.batches.clear();
}
Ok(record_batch)
}

fn filter_joined_batch(&mut self) -> Result<RecordBatch> {
let record_batch = self.output_record_batch_and_reset()?;
let mut out_indices = self.output_record_batches.row_indices.finish();
let mut out_mask = self.output_record_batches.filter_mask.finish();
let mut batch_ids = &self.output_record_batches.batch_ids;
let record_batch =
concat_batches(&self.schema, &self.staging_output_record_batches.batches)?;
let mut out_indices = self.staging_output_record_batches.row_indices.finish();
let mut out_mask = self.staging_output_record_batches.filter_mask.finish();
let mut batch_ids = &self.staging_output_record_batches.batch_ids;
let default_batch_ids = vec![0; record_batch.num_rows()];

// If only nulls come in and indices sizes doesn't match with expected record batch count
Expand All @@ -1901,7 +1956,7 @@ impl SortMergeJoinStream {
}

if out_mask.is_empty() {
self.output_record_batches.batches.clear();
self.staging_output_record_batches.batches.clear();
return Ok(record_batch);
}

Expand Down Expand Up @@ -2044,7 +2099,7 @@ impl SortMergeJoinStream {
)?;
}

self.output_record_batches.clear();
self.staging_output_record_batches.clear();

Ok(filtered_record_batch)
}
Expand Down

0 comments on commit 0ba6e70

Please sign in to comment.