feat: Async back-pressure on adding to a source (#746)

This closes #743

---------

Co-authored-by: Jordan Frazier <[email protected]>
bjchambers and jordanrfrazier authored Sep 12, 2023
1 parent 83d4b59 commit bf294dd
Showing 72 changed files with 728 additions and 390 deletions.
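
At the heart of the change, `InMemoryBatches` (below) moves from `tokio::sync::broadcast`, which drops the oldest messages when a receiver falls behind, to `async_broadcast`, whose `broadcast(..).await` suspends the producer while the channel is full. A minimal, self-contained sketch of that pattern — not code from this commit; the capacity, payload type, and `tokio` runtime are illustrative assumptions:

```rust
// Sketch of the back-pressure pattern adopted in this commit (illustrative only;
// assumes `async-broadcast` and `tokio` as dependencies).
use async_broadcast::broadcast;

#[tokio::main]
async fn main() {
    // Small capacity so back-pressure kicks in quickly.
    let (mut sender, receiver) = broadcast::<u64>(2);
    // Don't wait for an active receiver; with none, `broadcast` errors instead
    // of blocking forever.
    sender.set_await_active(false);
    // An inactive receiver keeps the channel open without buffering messages.
    let _keepalive = receiver.deactivate();

    // An active subscriber created from the sender.
    let mut subscriber = sender.new_receiver();
    let consumer = tokio::spawn(async move {
        while let Ok(value) = subscriber.recv().await {
            println!("received {value}");
        }
    });

    for value in 0..10u64 {
        // While the channel is full, this await suspends the producer rather
        // than dropping or lagging messages.
        let _ = sender.broadcast(value).await;
    }
    drop(sender); // Closing the channel ends the consumer loop.
    consumer.await.unwrap();
}
```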
11 changes: 11 additions & 0 deletions Cargo.lock

Generated lockfile; diff not shown.

1 change: 1 addition & 0 deletions Cargo.toml
@@ -23,6 +23,7 @@ arrow-ord = { version = "43.0.0" }
 arrow-schema = { version = "43.0.0", features = ["serde"] }
 arrow-select = { version = "43.0.0" }
 arrow-string = { version = "43.0.0" }
+async-broadcast = "0.5.1"
 async-once-cell = "0.5.3"
 async-stream = "0.3.4"
 async-trait = "0.1.68"
1 change: 1 addition & 0 deletions crates/sparrow-merge/Cargo.toml
@@ -19,6 +19,7 @@ arrow-array.workspace = true
 arrow-csv = { workspace = true, optional = true }
 arrow-schema.workspace = true
 arrow-select.workspace = true
+async-broadcast.workspace = true
 async-stream.workspace = true
 bit-set.workspace = true
 derive_more.workspace = true
44 changes: 26 additions & 18 deletions crates/sparrow-merge/src/in_memory_batches.rs
@@ -2,7 +2,7 @@ use std::sync::RwLock;
 
 use arrow_array::RecordBatch;
 use arrow_schema::SchemaRef;
-use error_stack::{IntoReport, IntoReportCompat, ResultExt};
+use error_stack::{IntoReportCompat, ResultExt};
 use futures::Stream;
 
 use crate::old::homogeneous_merge;
@@ -20,12 +20,15 @@ impl error_stack::Context for Error {}
 /// Struct for managing in-memory batches.
 #[derive(Debug)]
 pub struct InMemoryBatches {
-    retained: bool,
+    /// Whether rows added will be available for interactive queries.
+    /// If False, rows will be discarded after being sent to any active
+    /// materializations.
+    queryable: bool,
     current: RwLock<Current>,
-    updates: tokio::sync::broadcast::Sender<(usize, RecordBatch)>,
+    sender: async_broadcast::Sender<(usize, RecordBatch)>,
     /// A subscriber that is never used -- it exists only to keep the sender
     /// alive.
-    _subscriber: tokio::sync::broadcast::Receiver<(usize, RecordBatch)>,
+    _receiver: async_broadcast::InactiveReceiver<(usize, RecordBatch)>,
 }
 
 #[derive(Debug)]
@@ -61,38 +64,43 @@ impl Current {
 }
 
 impl InMemoryBatches {
-    pub fn new(retained: bool, schema: SchemaRef) -> Self {
-        let (updates, _subscriber) = tokio::sync::broadcast::channel(10);
+    pub fn new(queryable: bool, schema: SchemaRef) -> Self {
+        let (mut sender, receiver) = async_broadcast::broadcast(10);
+
+        // Don't wait for a receiver. If no-one receives, `send` will fail.
+        sender.set_await_active(false);
+
         let current = RwLock::new(Current::new(schema.clone()));
         Self {
-            retained,
+            queryable,
             current,
-            updates,
-            _subscriber,
+            sender,
+            _receiver: receiver.deactivate(),
         }
     }
 
     /// Add a batch, merging it into the in-memory version.
     ///
     /// Publishes the new batch to the subscribers.
-    pub fn add_batch(&self, batch: RecordBatch) -> error_stack::Result<(), Error> {
+    pub async fn add_batch(&self, batch: RecordBatch) -> error_stack::Result<(), Error> {
         if batch.num_rows() == 0 {
             return Ok(());
         }
 
         let new_version = {
             let mut write = self.current.write().map_err(|_| Error::Add)?;
-            if self.retained {
+            if self.queryable {
                 write.add_batch(&batch)?;
             }
             write.version += 1;
             write.version
         };
 
-        self.updates
-            .send((new_version, batch))
-            .into_report()
-            .change_context(Error::Add)?;
+        let send_result = self.sender.broadcast((new_version, batch)).await;
+        if send_result.is_err() {
+            assert!(!self.sender.is_closed());
+            tracing::info!("No-one subscribed for new batch");
+        }
         Ok(())
     }
 
@@ -107,7 +115,7 @@ impl InMemoryBatches {
             let read = self.current.read().unwrap();
             (read.version, read.batch.clone())
         };
-        let mut recv = self.updates.subscribe();
+        let mut recv = self.sender.new_receiver();
 
         async_stream::try_stream! {
             tracing::info!("Starting subscriber with version {version}");
@@ -126,11 +134,11 @@
                             tracing::warn!("Ignoring old version {recv_version}");
                         }
                     }
-                    Err(tokio::sync::broadcast::error::RecvError::Closed) => {
+                    Err(async_broadcast::RecvError::Closed) => {
                         tracing::info!("Sender closed.");
                         break;
                     },
-                    Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => {
+                    Err(async_broadcast::RecvError::Overflowed(_)) => {
                         Err(Error::ReceiverLagged)?;
                     }
                 }
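
For orientation, a hedged sketch of how the reworked `InMemoryBatches` is driven end to end. The subscription method is assumed here to be named `subscribe` and to yield `Result<RecordBatch, Error>` items, as the `try_stream!` body above suggests; the rest follows the signatures in this diff:

```rust
// Illustrative only; `subscribe` is an assumed name for the method that wraps
// the `async_stream::try_stream!` block above.
use futures::StreamExt;

async fn pump(batches: &InMemoryBatches, incoming: Vec<RecordBatch>) -> error_stack::Result<(), Error> {
    for batch in incoming {
        // The new back-pressure point: if the broadcast channel is full, the
        // producer suspends here instead of failing or dropping the batch.
        batches.add_batch(batch).await?;
    }
    Ok(())
}

async fn consume(batches: &InMemoryBatches) {
    // Consume the subscription stream; items follow the try_stream! body above.
    let stream = batches.subscribe();
    futures::pin_mut!(stream);
    while let Some(next) = stream.next().await {
        match next {
            Ok(batch) => tracing::info!("batch with {} rows", batch.num_rows()),
            Err(e) => {
                tracing::warn!("subscription ended: {e:?}");
                break;
            }
        }
    }
}
```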
21 changes: 0 additions & 21 deletions crates/sparrow-runtime/src/key_hash_inverse.rs
@@ -324,27 +324,6 @@ impl ThreadSafeKeyHashInverse {
         }
     }
 
-    pub fn blocking_add(
-        &self,
-        keys: &dyn Array,
-        key_hashes: &UInt64Array,
-    ) -> error_stack::Result<(), Error> {
-        error_stack::ensure!(
-            keys.len() == key_hashes.len(),
-            Error::MismatchedLengths {
-                keys: keys.len(),
-                key_hashes: key_hashes.len()
-            }
-        );
-        let has_new_keys = self.key_map.blocking_read().has_new_keys(key_hashes);
-
-        if has_new_keys {
-            self.key_map.blocking_write().add(keys, key_hashes)
-        } else {
-            Ok(())
-        }
-    }
-
     /// Stores the KeyHashInverse to the compute store.
     ///
     /// This method is thread-safe and acquires the read-lock.
16 changes: 8 additions & 8 deletions crates/sparrow-runtime/src/prepare/preparer.rs
@@ -1,3 +1,4 @@
+use std::sync::atomic::{AtomicU64, Ordering};
 use std::sync::Arc;
 
 use arrow::array::{ArrayRef, UInt64Array};
@@ -31,7 +32,7 @@ pub struct Preparer {
     prepared_schema: SchemaRef,
     time_column_name: String,
     subsort_column_name: Option<String>,
-    next_subsort: u64,
+    next_subsort: AtomicU64,
     key_column_name: String,
     time_multiplier: Option<i64>,
 }
@@ -51,7 +52,7 @@ impl Preparer {
             prepared_schema,
             time_column_name,
             subsort_column_name,
-            next_subsort: prepare_hash,
+            next_subsort: prepare_hash.into(),
             key_column_name,
             time_multiplier,
         })
@@ -66,10 +67,7 @@ impl Preparer {
     /// - This computes and adds the key columns.
     /// - This sorts the batch by time, subsort and key hash.
     /// - This adds or casts columns as needed.
-    ///
-    /// Self is mutated as necessary to ensure the `subsort` column is increasing, if
-    /// it is added.
-    pub fn prepare_batch(&mut self, batch: RecordBatch) -> error_stack::Result<RecordBatch, Error> {
+    pub fn prepare_batch(&self, batch: RecordBatch) -> error_stack::Result<RecordBatch, Error> {
         let time = get_required_column(&batch, &self.time_column_name)?;
         let time = cast_to_timestamp(time, self.time_multiplier)?;
 
@@ -80,8 +78,10 @@
                 .into_report()
                 .change_context_lazy(|| Error::ConvertSubsort(subsort.data_type().clone()))?
         } else {
-            let subsort: UInt64Array = (self.next_subsort..).take(num_rows).collect();
-            self.next_subsort += num_rows as u64;
+            let subsort_start = self
+                .next_subsort
+                .fetch_add(num_rows as u64, Ordering::SeqCst);
+            let subsort: UInt64Array = (subsort_start..).take(num_rows).collect();
             Arc::new(subsort)
         };
 
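
The `next_subsort: u64` counter behind `&mut self` becomes an `AtomicU64` reserved with `fetch_add`, which is what lets `prepare_batch` take `&self` and be called from the now-async ingestion path without exclusive access. A standalone sketch of the counter pattern (illustrative names, not from this diff):

```rust
use std::sync::atomic::{AtomicU64, Ordering};

/// Illustrative stand-in for the subsort counter: hands out `count`
/// consecutive values per call, safely through a shared reference.
struct SubsortCounter {
    next: AtomicU64,
}

impl SubsortCounter {
    fn new(start: u64) -> Self {
        Self { next: AtomicU64::new(start) }
    }

    /// Reserve a contiguous range of `count` subsort values.
    fn reserve(&self, count: u64) -> std::ops::Range<u64> {
        // `fetch_add` returns the previous value, so concurrent callers get
        // disjoint, increasing ranges.
        let start = self.next.fetch_add(count, Ordering::SeqCst);
        start..start + count
    }
}

fn main() {
    let counter = SubsortCounter::new(0);
    assert_eq!(counter.reserve(3), 0..3);
    assert_eq!(counter.reserve(2), 3..5);
}
```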
4 changes: 2 additions & 2 deletions crates/sparrow-session/src/session.rs
@@ -87,7 +87,7 @@ impl Session {
         name: &str,
         schema: SchemaRef,
         time_column_name: &str,
-        retained: bool,
+        queryable: bool,
         subsort_column_name: Option<&str>,
         key_column_name: &str,
         grouping_name: Option<&str>,
@@ -150,7 +150,7 @@
             key_hash_inverse,
             key_column,
             expr,
-            retained,
+            queryable,
             time_unit,
         )
     }
10 changes: 6 additions & 4 deletions crates/sparrow-session/src/table.rs
@@ -26,7 +26,7 @@ impl Table {
         key_hash_inverse: Arc<ThreadSafeKeyHashInverse>,
         key_column: usize,
         expr: Expr,
-        retained: bool,
+        queryable: bool,
         time_unit: Option<&str>,
     ) -> error_stack::Result<Self, Error> {
         let prepared_fields: Fields = KEY_FIELDS
@@ -38,7 +38,7 @@
         let prepare_hash = 0;
 
         assert!(table_info.in_memory.is_none());
-        let in_memory_batches = Arc::new(InMemoryBatches::new(retained, prepared_schema.clone()));
+        let in_memory_batches = Arc::new(InMemoryBatches::new(queryable, prepared_schema.clone()));
         table_info.in_memory = Some(in_memory_batches.clone());
 
         let preparer = Preparer::new(
@@ -66,7 +66,7 @@
         self.preparer.schema()
     }
 
-    pub fn add_data(&mut self, batch: RecordBatch) -> error_stack::Result<(), Error> {
+    pub async fn add_data(&self, batch: RecordBatch) -> error_stack::Result<(), Error> {
         let prepared = self
             .preparer
             .prepare_batch(batch)
@@ -75,11 +75,13 @@
         let key_hashes = prepared.column(2).as_primitive();
         let keys = prepared.column(self.key_column);
         self.key_hash_inverse
-            .blocking_add(keys.as_ref(), key_hashes)
+            .add(keys.as_ref(), key_hashes)
+            .await
             .change_context(Error::Prepare)?;
 
         self.in_memory_batches
             .add_batch(prepared)
+            .await
             .change_context(Error::Prepare)?;
         Ok(())
     }
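
Since `add_data` is now `async` and takes `&self`, call sites have to drive it on an async runtime. A hedged sketch of both styles of caller — the runtime handle and surrounding setup are assumptions, not part of this commit:

```rust
// Illustrative call sites only; `Table`, `Error`, and the batch construction
// come from context above, while the runtime handle is an assumption.
async fn ingest(table: &Table, batch: RecordBatch) -> error_stack::Result<(), Error> {
    // Awaits both the key-hash registration and the in-memory publish, so a
    // slow downstream consumer back-pressures ingestion here.
    table.add_data(batch).await
}

fn ingest_blocking(
    handle: &tokio::runtime::Handle,
    table: &Table,
    batch: RecordBatch,
) -> error_stack::Result<(), Error> {
    // From synchronous code, a Tokio runtime handle can drive the future.
    handle.block_on(table.add_data(batch))
}
```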
9 changes: 4 additions & 5 deletions examples/event-api/server.py
@@ -13,15 +13,14 @@ async def main():
 
     # Initialize event source with schema from historical data.
     events = kd.sources.PyDict(
-        rows = [],
-        schema = pa.schema([
+        schema=pa.schema([
             pa.field("ts", pa.float64()),
             pa.field("user", pa.string()),
             pa.field("request_id", pa.string()),
         ]),
-        time_column = "ts",
-        key_column = "user",
-        time_unit = "s",
+        time_column="ts",
+        key_column="user",
+        time_unit="s",
         retained=False,
     )
 
3 changes: 2 additions & 1 deletion examples/slackbot/Notebook.ipynb
@@ -119,10 +119,11 @@
    "metadata": {},
    "outputs": [],
    "source": [
+    "import asyncio\n",
     "import pandas\n",
     "import sparrow_pi.sources as sources\n",
     "\n",
-    "messages = kt.sources.Parquet(\"./messages.parquet\", time = \"ts\", entity = \"channel\")\n",
+    "messages = await kt.sources.Parquet.create(\"./messages.parquet\", time = \"ts\", entity = \"channel\")\n",
     "messages = messages.with_key(kt.record({ # !!!\n",
     "    \"channel\": messages.col(\"channel\"),\n",
     "    \"thread\": messages.col(\"thread_ts\"),\n",
11 changes: 11 additions & 0 deletions python/Cargo.lock

Generated lockfile; diff not shown.

4 changes: 3 additions & 1 deletion python/docs/source/examples/time_centric.ipynb
@@ -125,6 +125,8 @@
    "metadata": {},
    "outputs": [],
    "source": [
+    "import asyncio\n",
+    "\n",
     "# For demo simplicity, instead of a CSV file, we read and then parse data from a\n",
     "# CSV string. Kaskadaa\n",
     "event_data_string = \"\"\"\n",
@@ -151,7 +153,7 @@
     " ev_00020,2022-01-01 22:20:00,user_002,view_item,0\n",
     "\"\"\"\n",
     "\n",
-    "events = kd.sources.CsvString(\n",
+    "events = await kd.sources.CsvString.create(\n",
     "    event_data_string, time_column=\"event_at\", key_column=\"entity_id\"\n",
     ")\n",
     "\n",
3 changes: 2 additions & 1 deletion python/docs/source/guide/entities.md
@@ -67,6 +67,7 @@ This is helpful since the _feature vector_ for an entity will depend only on eve
 ---
 tags: [hide-input]
 ---
+import asyncio
 import kaskada as kd
 kd.init_session()
 data = "\n".join(
@@ -79,7 +80,7 @@
         "1996-12-23T16:40:01,A,12",
     ]
 )
-multi_entity = kd.sources.CsvString(data, time_column="time", key_column="key")
+multi_entity = await kd.sources.CsvString.create(data, time_column="time", key_column="key")
 kd.plot.render(
     kd.plot.Plot(multi_entity.col("m"), name="m"),
3 changes: 2 additions & 1 deletion python/docs/source/guide/quickstart.md
@@ -28,6 +28,7 @@ The following Python code imports the Kaskada library, creates a session, and lo
 It then runs a query to produce a Pandas DataFrame.
 
 ```{code-cell}
+import asyncio
 import kaskada as kd
 kd.init_session()
 content = "\n".join(
@@ -41,6 +42,6 @@ content = "\n".join(
     "1996-12-19T16:40:02,A,,",
     ]
 )
-source = kd.sources.CsvString(content, time_column="time", key_column="key")
+source = await kd.sources.CsvString.create(content, time_column="time", key_column="key")
 source.select("m", "n").extend({"sum_m": source.col("m").sum() }).to_pandas()
 ```