almost there... just need to select key columns and hook it up
orlp committed Nov 19, 2024
1 parent 207c076 commit c7def77
Showing 9 changed files with 175 additions and 29 deletions.
6 changes: 6 additions & 0 deletions crates/polars-core/src/datatypes/field.rs
@@ -96,6 +96,12 @@ impl Field {
pub fn set_name(&mut self, name: PlSmallStr) {
self.name = name;
}

/// Returns this `Field`, renamed.
pub fn with_name(mut self, name: PlSmallStr) -> Self {
self.name = name;
self
}

/// Converts the `Field` to an `arrow::datatypes::Field`.
///
9 changes: 9 additions & 0 deletions crates/polars-core/src/frame/mod.rs
@@ -380,6 +380,15 @@ impl DataFrame {
unsafe { DataFrame::new_no_checks(0, cols) }
}

/// Create a new `DataFrame` with the given schema, containing only nulls.
pub fn full_null(schema: &Schema, height: usize) -> Self {
let columns = schema
.iter_fields()
.map(|f| Column::full_null(f.name.clone(), height, f.dtype()))
.collect();
DataFrame { height, columns }
}

/// Removes the last `Series` from the `DataFrame` and returns it, or [`None`] if it is empty.
///
/// # Example
2 changes: 1 addition & 1 deletion crates/polars-expr/src/chunked_idx_table/mod.rs
@@ -55,7 +55,7 @@ pub trait ChunkedIdxTable: Any + Send + Sync {
) -> IdxSize;

/// Get the ChunkIds for each key which was never marked during probing.
fn unmarked_keys(&self, out: &mut Vec<ChunkId<32>>);
fn unmarked_keys(&self, out: &mut Vec<ChunkId<32>>, offset: IdxSize, limit: IdxSize);
}

pub fn new_chunked_idx_table(key_schema: Arc<Schema>) -> Box<dyn ChunkedIdxTable> {
10 changes: 8 additions & 2 deletions crates/polars-expr/src/chunked_idx_table/row_encoded.rs
@@ -252,8 +252,10 @@ impl ChunkedIdxTable for RowEncodedChunkedIdxTable {
}
}

fn unmarked_keys(&self, out: &mut Vec<ChunkId<32>>) {
for chunk_ids in self.idx_map.iter_values() {
fn unmarked_keys(&self, out: &mut Vec<ChunkId<32>>, offset: IdxSize, limit: IdxSize) {
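// Scan keys in insertion order starting at `offset`; a key whose first chunk id
// still has its top (mark) bit unset was never matched during probing, so all of
// its chunk ids are emitted. We stop once at least `limit` ids have been collected.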
out.clear();

while let Some((_, _, chunk_ids)) = self.idx_map.get_index(offset) {
let first_chunk_id = unsafe { chunk_ids.get_unchecked(0) };
let first_chunk_val = first_chunk_id.load(Ordering::Acquire);
if first_chunk_val >> 63 == 0 {
@@ -263,6 +265,10 @@ impl ChunkedIdxTable for RowEncodedChunkedIdxTable {
out.push(chunk_id);
}
}

if out.len() >= limit as usize {
break;
}
}
}
}
126 changes: 103 additions & 23 deletions crates/polars-stream/src/nodes/joins/equi_join.rs
@@ -1,7 +1,7 @@
use std::sync::Arc;

use polars_core::prelude::{PlHashSet, PlRandomState};
use polars_core::schema::Schema;
use polars_core::schema::{Schema, SchemaExt};
use polars_core::series::IsSorted;
use polars_core::utils::accumulate_dataframes_vertical_unchecked;
use polars_expr::chunked_idx_table::{new_chunked_idx_table, ChunkedIdxTable};
@@ -15,7 +15,8 @@ use polars_utils::{format_pl_smallstr, IdxSize};
use rayon::prelude::*;

use crate::async_primitives::connector::{Receiver, Sender};
use crate::morsel::get_ideal_morsel_size;
use crate::async_primitives::wait_group::WaitGroup;
use crate::morsel::{get_ideal_morsel_size, SourceToken};
use crate::nodes::compute_node_prelude::*;

/// A payload selector contains for each column whether that column should be
@@ -56,6 +57,13 @@ fn compute_payload_selector(
.collect()
}

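/// Builds the schema of the payload columns kept by `selector`, renaming each
/// retained field to its selected output name.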
fn select_schema(schema: &Schema, selector: &[Option<PlSmallStr>]) -> Schema {
schema.iter_fields()
.zip(selector)
.filter_map(|(f, name)| Some(f.with_name(name.clone()?)))
.collect()
}

fn select_payload(df: DataFrame, selector: &[Option<PlSmallStr>]) -> DataFrame {
// Maintain height of zero-width dataframes.
if df.width() == 0 {
@@ -248,50 +256,86 @@ impl ProbeState {

Ok(())
}
}

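/// Tracks progress while streaming out build-side rows that never found a match,
/// one partition at a time.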
struct EmitUnmatchedState {
partitions: Vec<ProbeTable>,
active_partition_idx: usize,
offset_in_active_p: usize,
}

impl EmitUnmatchedState {
async fn emit_unmatched(
&mut self,
mut send: Sender<Morsel>,
partitions: &[ProbeTable],
params: &EquiJoinParams,
num_pipelines: usize,
) -> PolarsResult<()> {
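// Derive a morsel size from the total key count, rounding the morsel count up
// to a multiple of the pipeline count so the output spreads over all pipelines.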
let total_len: usize = self.partitions.iter().map(|p| p.table.num_keys() as usize).sum();
let ideal_morsel_count = (total_len / get_ideal_morsel_size()).max(1);
let morsel_count = ideal_morsel_count.next_multiple_of(num_pipelines);
let morsel_size = total_len.div_ceil(morsel_count).max(1);

let mut morsel_seq = MorselSeq::default();
let wait_group = WaitGroup::default();
let source_token = SourceToken::new();
let mut unmarked_idxs = Vec::new();
unsafe {
for p in partitions {
p.table.unmarked_keys(&mut unmarked_idxs);
let build_df = p.df.take_chunked_unchecked(&table_match, IsSorted::Not);
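// Drain the partitions in order, pulling up to roughly `morsel_size` unmatched
// keys per iteration and pairing their build rows with all-null probe-side
// payload columns.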
while let Some(p) = self.partitions.get(self.active_partition_idx) {
loop {
p.table.unmarked_keys(&mut unmarked_idxs, self.offset_in_active_p as IdxSize, morsel_size as IdxSize);
self.offset_in_active_p += unmarked_idxs.len();
if unmarked_idxs.is_empty() {
break;
}

let out_df = if params.left_is_build {
build_df.hstack_mut_unchecked(probe_df.get_columns());
build_df
} else {
probe_df.hstack_mut_unchecked(build_df.get_columns());
probe_df
let out_df = unsafe {
let mut build_df = p.df.take_chunked_unchecked(&unmarked_idxs, IsSorted::Not);
let len = build_df.height();
if params.left_is_build {
let probe_df = DataFrame::full_null(&params.right_payload_schema, len);
build_df.hstack_mut_unchecked(probe_df.get_columns());
build_df
} else {
let mut probe_df = DataFrame::full_null(&params.left_payload_schema, len);
probe_df.hstack_mut_unchecked(build_df.get_columns());
probe_df
}
};

let mut morsel = Morsel::new(out_df, morsel_seq, source_token.clone());
morsel_seq = morsel_seq.successor();
morsel.set_consume_token(wait_group.token());
if send.send(morsel).await.is_err() {
return Ok(());
}

wait_group.wait().await;
if source_token.stop_requested() {
return Ok(());
}
}

self.active_partition_idx += 1;
self.offset_in_active_p = 0;
}

Ok(())
}
}

enum EquiJoinState {
Build(BuildState),
Probe(ProbeState),
EmitUnmatchedBuild(ProbeState),
EmitUnmatchedBuild(EmitUnmatchedState),
Done,
}

struct EquiJoinParams {
left_is_build: bool,
left_payload_select: Vec<Option<PlSmallStr>>,
right_payload_select: Vec<Option<PlSmallStr>>,
left_payload_schema: Schema,
right_payload_schema: Schema,
args: JoinArgs,
random_state: PlRandomState,
}
@@ -341,6 +385,9 @@ impl EquiJoinNode {
compute_payload_selector(&left_input_schema, &right_input_schema, true, &args);
let right_payload_select =
compute_payload_selector(&right_input_schema, &left_input_schema, false, &args);

let left_payload_schema = select_schema(&left_input_schema, &left_payload_select);
let right_payload_schema = select_schema(&right_input_schema, &right_payload_select);
Self {
state: EquiJoinState::Build(BuildState {
partitions_per_worker: Vec::new(),
@@ -350,6 +397,8 @@ impl EquiJoinNode {
left_is_build,
left_payload_select,
right_payload_select,
left_payload_schema,
right_payload_schema,
args,
random_state: PlRandomState::new(),
},
@@ -373,9 +422,8 @@ impl ComputeNode for EquiJoinNode {
let build_idx = if self.params.left_is_build { 0 } else { 1 };
let probe_idx = 1 - build_idx;

// If the output doesn't want any more data, or the probe side is done,
// transition to being done.
if send[0] == PortState::Done || recv[probe_idx] == PortState::Done {
// If the output doesn't want any more data, transition to being done.
if send[0] == PortState::Done {
self.state = EquiJoinState::Done;
}

@@ -385,6 +433,29 @@ impl ComputeNode for EquiJoinNode {
self.state = EquiJoinState::Probe(build_state.finalize(&*self.table));
}
}

// If we are probing and the probe input is done, emit unmatched if
// necessary, otherwise we're done.
if let EquiJoinState::Probe(probe_state) = &mut self.state {
if recv[probe_idx] == PortState::Done {
if self.params.emit_unmatched_build() {
self.state = EquiJoinState::EmitUnmatchedBuild(EmitUnmatchedState {
partitions: core::mem::take(&mut probe_state.table_per_partition),
active_partition_idx: 0,
offset_in_active_p: 0,
});
} else {
self.state = EquiJoinState::Done;
}
}
}

// Finally, check if we are done emitting unmatched keys.
if let EquiJoinState::EmitUnmatchedBuild(emit_state) = &mut self.state {
if emit_state.active_partition_idx >= emit_state.partitions.len() {
self.state = EquiJoinState::Done;
}
}

match &mut self.state {
EquiJoinState::Build(_) => {
@@ -471,6 +542,15 @@ impl ComputeNode for EquiJoinNode {
));
}
},
EquiJoinState::EmitUnmatchedBuild(emit_state) => {
assert!(recv_ports[build_idx].is_none());
assert!(recv_ports[probe_idx].is_none());
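// All remaining output comes from the stored build-side tables, so a single
// serial task on the output port emits the unmatched rows.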
let send = send_ports[0].take().unwrap().serial();
join_handles.push(scope.spawn_task(
TaskPriority::Low,
emit_state.emit_unmatched(send, &self.params, self.num_pipelines)
));
},
EquiJoinState::Done => unreachable!(),
}
}
8 changes: 6 additions & 2 deletions crates/polars-stream/src/physical_plan/fmt.rs
@@ -214,8 +214,12 @@ fn visualize_plan_rec(
left_on,
right_on,
args,
} => {
let mut label = "in-memory-join".to_string();
} | PhysNodeKind::EquiJoin { input_left, input_right, left_on, right_on, args } => {
let mut label = if matches!(phys_sm[node_key].kind, PhysNodeKind::EquiJoin { .. }) {
"equi-join".to_string()
} else {
"in-memory-join".to_string()
};
write!(label, r"\nleft_on:\n{}", fmt_exprs(left_on, expr_arena)).unwrap();
write!(label, r"\nright_on:\n{}", fmt_exprs(right_on, expr_arena)).unwrap();
write!(
10 changes: 9 additions & 1 deletion crates/polars-stream/src/physical_plan/mod.rs
@@ -153,6 +153,14 @@ pub enum PhysNodeKind {
key: Vec<ExprIR>,
aggs: Vec<ExprIR>,
},

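/// Streaming equi-join: builds an index table from one input and probes it
/// with morsels from the other.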
EquiJoin {
input_left: PhysNodeKey,
input_right: PhysNodeKey,
left_on: Vec<ExprIR>,
right_on: Vec<ExprIR>,
args: JoinArgs,
},

/// Generic fallback for (as-of-yet) unsupported streaming joins.
/// Fully sinks all data to in-memory data frames and uses the in-memory
@@ -213,7 +221,7 @@ fn insert_multiplexers(
insert_multiplexers(*input, phys_sm, referenced);
},

PhysNodeKind::InMemoryJoin {
PhysNodeKind::InMemoryJoin { input_left, input_right, .. } | PhysNodeKind::EquiJoin {
input_left,
input_right,
..
25 changes: 25 additions & 0 deletions crates/polars-stream/src/physical_plan/to_graph.rs
@@ -23,6 +23,7 @@ use super::{PhysNode, PhysNodeKey, PhysNodeKind};
use crate::expression::StreamExpr;
use crate::graph::{Graph, GraphNodeKey};
use crate::nodes;
use crate::nodes::joins::equi_join::EquiJoinNode;
use crate::physical_plan::lower_expr::compute_output_schema;
use crate::utils::late_materialized_df::LateMaterializedDataFrame;

@@ -503,6 +504,30 @@
[left_input_key, right_input_key],
)
},

EquiJoin {
input_left,
input_right,
left_on,
right_on,
args,
} => {
let args = args.clone();
let left_input_key = to_graph_rec(*input_left, ctx)?;
let right_input_key = to_graph_rec(*input_right, ctx)?;
let left_input_schema = ctx.phys_sm[*input_left].output_schema.clone();
let right_input_schema = ctx.phys_sm[*input_right].output_schema.clone();

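// Not wired up yet: key-column selection is still missing, so the intended
// EquiJoinNode construction is left commented out below.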
todo!()
// ctx.graph.add_node(
// nodes::joins::equi_join::EquiJoinNode::new(
// left_input_schema,
// right_input_schema,
// args,
// ),
// [left_input_key, right_input_key],
// )
},
};

ctx.phys_to_graph.insert(phys_node_key, graph_key);
8 changes: 8 additions & 0 deletions crates/polars-utils/src/idx_map/bytes_idx_map.rs
@@ -102,10 +102,18 @@ impl<V> BytesIndexMap<V> {
}
}

/// Gets the hash, key and value at the given index by insertion order.
#[inline(always)]
pub fn get_index(&self, idx: IdxSize) -> Option<(u64, &[u8], &V)> {
let t = self.tuples.get(idx as usize)?;
Some((t.0.key_hash, unsafe { t.0.get(&self.key_data) }, &t.1))
}

/// Gets the hash, key and value at the given index by insertion order.
///
/// # Safety
/// The index must be less than len().
#[inline(always)]
pub unsafe fn get_index_unchecked(&self, idx: IdxSize) -> (u64, &[u8], &V) {
let t = self.tuples.get_unchecked(idx as usize);
(t.0.key_hash, t.0.get(&self.key_data), &t.1)
