diff --git a/Cargo.lock b/Cargo.lock
index 0910d2f6d..970ac9498 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -4692,6 +4692,8 @@ dependencies = [
  "ahash 0.8.3",
  "anyhow",
  "arrow",
+ "arrow-array",
+ "arrow-select",
  "async-once-cell",
  "async-stream",
  "async-trait",
diff --git a/crates/sparrow-runtime/Cargo.toml b/crates/sparrow-runtime/Cargo.toml
index 5a0bf2bca..773f0894f 100644
--- a/crates/sparrow-runtime/Cargo.toml
+++ b/crates/sparrow-runtime/Cargo.toml
@@ -18,6 +18,8 @@ pulsar = ["dep:pulsar", "avro", "lz4"]
 ahash.workspace = true
 anyhow.workspace = true
 arrow.workspace = true
+arrow-array.workspace = true
+arrow-select.workspace = true
 async-once-cell.workspace = true
 async-stream.workspace = true
 async-trait.workspace = true
diff --git a/crates/sparrow-runtime/src/execute.rs b/crates/sparrow-runtime/src/execute.rs
index b516fe6ab..8f9200a31 100644
--- a/crates/sparrow-runtime/src/execute.rs
+++ b/crates/sparrow-runtime/src/execute.rs
@@ -14,10 +14,11 @@ use sparrow_arrow::scalar_value::ScalarValue;
 use sparrow_compiler::{hash_compute_plan_proto, DataContext};
 use sparrow_qfr::kaskada::sparrow::v1alpha::FlightRecordHeader;

+use crate::execute::compute_store_guard::ComputeStoreGuard;
 use crate::execute::error::Error;
-use crate::execute::key_hash_inverse::{KeyHashInverse, ThreadSafeKeyHashInverse};
 use crate::execute::operation::OperationContext;
 use crate::execute::output::Destination;
+use crate::key_hash_inverse::{KeyHashInverse, ThreadSafeKeyHashInverse};
 use crate::stores::ObjectStoreRegistry;
 use crate::RuntimeOptions;

@@ -25,7 +26,6 @@ mod checkpoints;
 mod compute_executor;
 mod compute_store_guard;
 pub mod error;
-pub(crate) mod key_hash_inverse;
 pub(crate) mod operation;
 pub mod output;
 mod progress_reporter;
@@ -68,7 +68,7 @@ pub async fn execute(
     // let output_at_time = request.final_result_time;

-    execute_new(plan, destination, data_context, options).await
+    execute_new(plan, destination, data_context, options, None).await
 }

 #[derive(Default, Debug)]
@@ -173,24 +173,12 @@ impl ExecutionOptions {
     }
 }

-pub async fn execute_new(
-    plan: ComputePlan,
-    destination: Destination,
-    mut data_context: DataContext,
-    options: ExecutionOptions,
-) -> error_stack::Result<impl Stream<Item = error_stack::Result<ExecuteResponse, Error>>, Error> {
-    let object_stores = Arc::new(ObjectStoreRegistry::default());
-
-    let plan_hash = hash_compute_plan_proto(&plan);
-
-    let compute_store = options
-        .compute_store(
-            object_stores.as_ref(),
-            plan.per_entity_behavior(),
-            &plan_hash,
-        )
-        .await?;
-
+async fn load_key_hash_inverse(
+    plan: &ComputePlan,
+    data_context: &mut DataContext,
+    compute_store: &Option<ComputeStoreGuard>,
+    object_stores: &ObjectStoreRegistry,
+) -> error_stack::Result<Arc<ThreadSafeKeyHashInverse>, Error> {
     let primary_grouping_key_type = plan
         .primary_grouping_key_type
         .to_owned()
@@ -200,23 +188,58 @@ pub async fn execute_new(
         .into_report()
         .change_context(Error::internal_msg("decode primary_grouping_key_type"))?;

-    let mut key_hash_inverse = KeyHashInverse::from_data_type(primary_grouping_key_type.clone());
+    let primary_group_id = data_context
+        .get_or_create_group_id(&plan.primary_grouping, &primary_grouping_key_type)
+        .into_report()
+        .change_context(Error::internal_msg("get primary grouping ID"))?;
+
+    let mut key_hash_inverse = KeyHashInverse::from_data_type(&primary_grouping_key_type);
     if let Some(compute_store) = &compute_store {
         if let Ok(restored) = KeyHashInverse::restore_from(compute_store.store_ref()) {
             key_hash_inverse = restored
         }
     }

-    let primary_group_id = data_context
-        .get_or_create_group_id(&plan.primary_grouping, &primary_grouping_key_type)
-        .into_report()
-        .change_context(Error::internal_msg("get primary grouping ID"))?;
-
     key_hash_inverse
-        .add_from_data_context(&data_context, primary_group_id, &object_stores)
+        .add_from_data_context(data_context, primary_group_id, object_stores)
         .await
         .change_context(Error::internal_msg("initialize key hash inverse"))?;
     let key_hash_inverse = Arc::new(ThreadSafeKeyHashInverse::new(key_hash_inverse));
+    Ok(key_hash_inverse)
+}
+
+/// Execute a given query based on the options.
+///
+/// Parameters
+/// ----------
+/// - key_hash_inverses: If set, specifies the key hash inverses to use. If None, the
+///   key hashes will be created.
+pub async fn execute_new(
+    plan: ComputePlan,
+    destination: Destination,
+    mut data_context: DataContext,
+    options: ExecutionOptions,
+    key_hash_inverses: Option<Arc<ThreadSafeKeyHashInverse>>,
+) -> error_stack::Result<impl Stream<Item = error_stack::Result<ExecuteResponse, Error>>, Error> {
+    let object_stores = Arc::new(ObjectStoreRegistry::default());
+
+    let plan_hash = hash_compute_plan_proto(&plan);
+
+    let compute_store = options
+        .compute_store(
+            object_stores.as_ref(),
+            plan.per_entity_behavior(),
+            &plan_hash,
+        )
+        .await?;
+
+    let key_hash_inverse = if let Some(key_hash_inverse) = key_hash_inverses {
+        key_hash_inverse
+    } else {
+        load_key_hash_inverse(&plan, &mut data_context, &compute_store, &object_stores)
+            .await
+            .change_context(Error::internal_msg("load key hash inverse"))?
+    };

     // Channel for the output stats.
     let (progress_updates_tx, progress_updates_rx) =
@@ -303,5 +326,5 @@ pub async fn materialize(
     // TODO: the `execute_with_progress` method contains a lot of additional logic that is theoretically not needed,
     // as the materialization does not exit, and should not need to handle cleanup tasks that regular
     // queries do. We should likely refactor this to use a separate `materialize_with_progress` method.
-    execute_new(plan, destination, data_context, options).await
+    execute_new(plan, destination, data_context, options, None).await
 }
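For context (not part of the patch): with the new optional parameter, a caller that already maintains a `ThreadSafeKeyHashInverse` can hand it to `execute_new` directly. A minimal sketch — the helper name and the `ExecutionOptions::default()` call are assumptions for illustration:

```rust
use std::sync::Arc;

use sparrow_api::kaskada::v1alpha::ComputePlan;
use sparrow_compiler::DataContext;
use sparrow_runtime::execute::output::Destination;
use sparrow_runtime::execute::{execute_new, ExecutionOptions};
use sparrow_runtime::key_hash_inverse::ThreadSafeKeyHashInverse;

// Hypothetical caller that reuses an inverse shared across queries.
async fn run_with_shared_inverse(
    plan: ComputePlan,
    destination: Destination,
    data_context: DataContext,
    inverse: Arc<ThreadSafeKeyHashInverse>,
) {
    // Passing `Some(inverse)` skips `load_key_hash_inverse`, so the inverse
    // is not rebuilt from the compute store / data context.
    let _progress_stream = execute_new(
        plan,
        destination,
        data_context,
        ExecutionOptions::default(),
        Some(inverse),
    )
    .await
    .expect("execution should start");
}
```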
diff --git a/crates/sparrow-runtime/src/execute/operation.rs b/crates/sparrow-runtime/src/execute/operation.rs
index b93dd0690..600cb8abb 100644
--- a/crates/sparrow-runtime/src/execute/operation.rs
+++ b/crates/sparrow-runtime/src/execute/operation.rs
@@ -60,10 +60,10 @@ use self::scan::ScanOperation;
 use self::select::SelectOperation;
 use self::tick::TickOperation;
 use self::with_key::WithKeyOperation;
-use crate::execute::key_hash_inverse::ThreadSafeKeyHashInverse;
 use crate::execute::operation::expression_executor::{ExpressionExecutor, InputColumn};
 use crate::execute::operation::shift_until::ShiftUntilOperation;
 use crate::execute::Error;
+use crate::key_hash_inverse::ThreadSafeKeyHashInverse;
 use crate::stores::ObjectStoreRegistry;
 use crate::Batch;

diff --git a/crates/sparrow-runtime/src/execute/operation/scan.rs b/crates/sparrow-runtime/src/execute/operation/scan.rs
index 430bcc8a5..e99cf8328 100644
--- a/crates/sparrow-runtime/src/execute/operation/scan.rs
+++ b/crates/sparrow-runtime/src/execute/operation/scan.rs
@@ -376,9 +376,9 @@ mod tests {
     use sparrow_compiler::DataContext;
     use uuid::Uuid;

-    use crate::execute::key_hash_inverse::{KeyHashInverse, ThreadSafeKeyHashInverse};
     use crate::execute::operation::testing::batches_to_csv;
     use crate::execute::operation::{OperationContext, OperationExecutor};
+    use crate::key_hash_inverse::ThreadSafeKeyHashInverse;
     use crate::read::testing::write_parquet_file;
     use crate::stores::ObjectStoreRegistry;

@@ -486,8 +486,7 @@ mod tests {
             })),
         };

-        let key_hash_inverse = KeyHashInverse::from_data_type(DataType::Utf8);
-        let key_hash_inverse = Arc::new(ThreadSafeKeyHashInverse::new(key_hash_inverse));
+        let key_hash_inverse = Arc::new(ThreadSafeKeyHashInverse::from_data_type(&DataType::Utf8));

         let (max_event_tx, mut max_event_rx) = tokio::sync::mpsc::unbounded_channel();
         let (sender, receiver) = tokio::sync::mpsc::channel(10);
diff --git a/crates/sparrow-runtime/src/execute/operation/testing.rs b/crates/sparrow-runtime/src/execute/operation/testing.rs
index f1a131623..be209ecd5 100644
--- a/crates/sparrow-runtime/src/execute/operation/testing.rs
+++ b/crates/sparrow-runtime/src/execute/operation/testing.rs
@@ -7,8 +7,8 @@ use itertools::Itertools;
 use sparrow_api::kaskada::v1alpha::{ComputePlan, OperationPlan};
 use sparrow_compiler::DataContext;

-use crate::execute::key_hash_inverse::{KeyHashInverse, ThreadSafeKeyHashInverse};
 use crate::execute::operation::{OperationContext, OperationExecutor};
+use crate::key_hash_inverse::ThreadSafeKeyHashInverse;
 use crate::stores::ObjectStoreRegistry;
 use crate::Batch;

@@ -173,8 +173,7 @@ pub(super) async fn run_operation(
     // Channel for the output stats.
     let (progress_updates_tx, _) = tokio::sync::mpsc::channel(29);

-    let key_hash_inverse = KeyHashInverse::from_data_type(DataType::Utf8);
-    let key_hash_inverse = Arc::new(ThreadSafeKeyHashInverse::new(key_hash_inverse));
+    let key_hash_inverse = Arc::new(ThreadSafeKeyHashInverse::from_data_type(&DataType::Utf8));

     let mut context = OperationContext {
         plan: ComputePlan {
@@ -223,8 +222,7 @@ pub(super) async fn run_operation_json(
         inputs.push(receiver);
     }

-    let key_hash_inverse = KeyHashInverse::from_data_type(DataType::Utf8);
-    let key_hash_inverse = Arc::new(ThreadSafeKeyHashInverse::new(key_hash_inverse));
+    let key_hash_inverse = Arc::new(ThreadSafeKeyHashInverse::from_data_type(&DataType::Utf8));

     let (max_event_tx, mut max_event_rx) = tokio::sync::mpsc::unbounded_channel();

diff --git a/crates/sparrow-runtime/src/execute/operation/with_key.rs b/crates/sparrow-runtime/src/execute/operation/with_key.rs
index a518b4fc9..5dcb2093f 100644
--- a/crates/sparrow-runtime/src/execute/operation/with_key.rs
+++ b/crates/sparrow-runtime/src/execute/operation/with_key.rs
@@ -2,10 +2,10 @@ use std::sync::Arc;

 use super::BoxedOperation;
 use crate::execute::error::{invalid_operation, Error};
-use crate::execute::key_hash_inverse::ThreadSafeKeyHashInverse;
 use crate::execute::operation::expression_executor::InputColumn;
 use crate::execute::operation::single_consumer_helper::SingleConsumerHelper;
 use crate::execute::operation::{InputBatch, Operation, OperationContext};
+use crate::key_hash_inverse::ThreadSafeKeyHashInverse;
 use crate::Batch;
 use anyhow::Context;
 use async_trait::async_trait;
@@ -115,8 +115,9 @@ impl WithKeyOperation {
         // primary grouping to produce the key hash inverse for output.
         if self.is_primary_grouping {
             self.key_hash_inverse
-                .add(new_keys.to_owned(), &new_key_hashes)
-                .await?;
+                .add(new_keys.as_ref(), &new_key_hashes)
+                .await
+                .map_err(|e| e.into_error())?;
         }

         // Get the take indices, which will allow us to get the requested columns from
diff --git a/crates/sparrow-runtime/src/execute/output.rs b/crates/sparrow-runtime/src/execute/output.rs
index e5e96c892..aac939595 100644
--- a/crates/sparrow-runtime/src/execute/output.rs
+++ b/crates/sparrow-runtime/src/execute/output.rs
@@ -12,9 +12,9 @@ use sparrow_api::kaskada::v1alpha::execute_request::Limits;
 use sparrow_api::kaskada::v1alpha::{data_type, ObjectStoreDestination, PulsarDestination};
 use sparrow_arrow::downcast::{downcast_primitive_array, downcast_struct_array};

-use crate::execute::key_hash_inverse::ThreadSafeKeyHashInverse;
 use crate::execute::operation::OperationContext;
 use crate::execute::progress_reporter::ProgressUpdate;
+use crate::key_hash_inverse::ThreadSafeKeyHashInverse;
 use crate::Batch;

 mod object_store;
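Not part of the patch: the rename below also swaps `arrow::compute::{take, concat}` for the lower-level `arrow-select` kernels that the new Cargo entries above pull in. A self-contained sketch of those two calls (assumes `arrow-array`, `arrow-select`, and `arrow-schema` as dependencies):

```rust
use arrow_array::{Array, ArrayRef, StringArray, UInt64Array};

fn main() -> Result<(), arrow_schema::ArrowError> {
    let keys = StringArray::from(vec!["awkward", "tacos"]);

    // `take` gathers rows by index; the patch uses it to pull out only the
    // keys whose hashes have not been seen before.
    let indices = UInt64Array::from_iter_values([1, 0]);
    let taken: ArrayRef = arrow_select::take::take(&keys, &indices, None)?;

    // `concat` appends the newly seen keys onto the existing `key` array.
    let combined = arrow_select::concat::concat(&[taken.as_ref(), &keys])?;
    assert_eq!(combined.len(), 4);
    Ok(())
}
```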
diff --git a/crates/sparrow-runtime/src/execute/key_hash_inverse.rs b/crates/sparrow-runtime/src/key_hash_inverse.rs
similarity index 70%
rename from crates/sparrow-runtime/src/execute/key_hash_inverse.rs
rename to crates/sparrow-runtime/src/key_hash_inverse.rs
index f79c29dd3..779fbd98c 100644
--- a/crates/sparrow-runtime/src/execute/key_hash_inverse.rs
+++ b/crates/sparrow-runtime/src/key_hash_inverse.rs
@@ -4,8 +4,9 @@ use anyhow::Context;
 use arrow::array::{Array, ArrayRef, AsArray, PrimitiveArray, UInt64Array};
 use arrow::datatypes::{DataType, UInt64Type};
-use error_stack::{IntoReportCompat, ResultExt};
+use error_stack::{IntoReport, IntoReportCompat, ResultExt};
 use futures::TryStreamExt;
+use hashbrown::hash_map::Entry;
 use hashbrown::HashMap;
 use sparrow_arrow::downcast::downcast_primitive_array;
 use sparrow_compiler::DataContext;
@@ -21,7 +22,7 @@ use crate::stores::{ObjectStoreRegistry, ObjectStoreUrl};
 /// If the entity key type is null, then all inverse keys are null.
 #[derive(serde::Serialize, serde::Deserialize)]
 pub struct KeyHashInverse {
-    key_hash_to_indices: HashMap<u64, usize>,
+    key_hash_to_indices: HashMap<u64, u64>,
     #[serde(with = "sparrow_arrow::serde::array_ref")]
     key: ArrayRef,
 }
@@ -45,6 +46,19 @@ pub enum Error {
     OpeningMetadata,
     #[display(fmt = "failed to read metadata")]
     ReadingMetadata,
+    #[display(fmt = "key hashes contained nulls")]
+    KeyHashContainedNull,
+    #[display(fmt = "error in Arrow kernel")]
+    Arrow,
+    #[display(fmt = "key hash not registered")]
+    MissingKeyHash,
+    #[display(fmt = "key hashes and keys are of different lengths ({keys} != {key_hashes})")]
+    MismatchedLengths { keys: usize, key_hashes: usize },
+    #[display(fmt = "incompatible key types (expected: {expected:?}, actual: {actual:?})")]
+    IncompatibleKeyTypes {
+        expected: DataType,
+        actual: DataType,
+    },
 }

 impl error_stack::Context for Error {}
@@ -68,10 +82,10 @@ impl KeyHashInverse {
     }

     /// Creates a new key hash inverse from a primary grouping data type.
-    pub fn from_data_type(primary_grouping_type: DataType) -> Self {
+    pub fn from_data_type(primary_grouping_type: &DataType) -> Self {
         Self {
             key_hash_to_indices: HashMap::new(),
-            key: arrow::array::new_empty_array(&primary_grouping_type),
+            key: arrow::array::new_empty_array(primary_grouping_type),
         }
     }

@@ -109,8 +123,7 @@ impl KeyHashInverse {
             .into_report()
             .change_context(Error::ReadingMetadata)?;
         let entity_key_col = batch.column(1);
-        self.add(entity_key_col.to_owned(), hash_col)
-            .into_report()
+        self.add(entity_key_col.as_ref(), hash_col)
             .change_context(Error::ReadingMetadata)?;
     }

@@ -128,8 +141,7 @@ impl KeyHashInverse {
             })
         });
         for (keys, key_hashes) in in_memory {
-            self.add(keys, key_hashes.as_primitive())
-                .into_report()
+            self.add(keys.as_ref(), key_hashes.as_primitive())
                 .change_context(Error::ReadingMetadata)
                 .unwrap();
         }
@@ -143,30 +155,53 @@
     /// values are aligned to map from a key to a hash per index. The
     /// current implementation eagerly adds the keys and hashes to the
     /// inverse but can be optimized to perform the addition lazily.
-    fn add(&mut self, keys: ArrayRef, key_hashes: &UInt64Array) -> anyhow::Result<()> {
+    fn add(
+        &mut self,
+        keys: &dyn Array,
+        key_hashes: &UInt64Array,
+    ) -> error_stack::Result<(), Error> {
         // Since the keys map to the key hashes directly, both arrays need to be the
         // same length
-        anyhow::ensure!(keys.len() == key_hashes.len());
+        error_stack::ensure!(key_hashes.null_count() == 0, Error::KeyHashContainedNull);
+        error_stack::ensure!(
+            keys.data_type() == self.key.data_type(),
+            Error::IncompatibleKeyTypes {
+                expected: self.key.data_type().clone(),
+                actual: keys.data_type().clone(),
+            }
+        );
+
+        let mut len = self.key_hash_to_indices.len() as u64;
+
         // Determine the indices that we need to add.
         let indices_from_batch: Vec<u64> = key_hashes
+            .values()
             .iter()
             .enumerate()
             .flat_map(|(index, key_hash)| {
-                if let Some(key_hash) = key_hash {
-                    if !self.key_hash_to_indices.contains_key(&key_hash) {
-                        self.key_hash_to_indices
-                            .insert(key_hash, self.key_hash_to_indices.len());
-                        return Some(index as u64);
+                match self.key_hash_to_indices.entry(*key_hash) {
+                    Entry::Occupied(_) => {
+                        // Key hash is already registered.
+                        None
+                    }
+                    Entry::Vacant(vacancy) => {
+                        vacancy.insert(len);
+                        len += 1;
+                        Some(index as u64)
                     }
                 }
-                None
             })
             .collect();
+
+        debug_assert_eq!(self.key_hash_to_indices.len(), len as usize);
+
         if !indices_from_batch.is_empty() {
             let indices_from_batch: PrimitiveArray<UInt64Type> =
                 PrimitiveArray::from_iter_values(indices_from_batch);
-            let keys = arrow::compute::take(&keys, &indices_from_batch, None)?;
+            let keys = arrow_select::take::take(keys, &indices_from_batch, None)
+                .into_report()
+                .change_context(Error::Arrow)?;
             let concatenated_keys: Vec<_> = vec![self.key.as_ref(), keys.as_ref()];
-            let concatenated_keys = arrow::compute::concat(&concatenated_keys)?;
+            let concatenated_keys = arrow_select::concat::concat(&concatenated_keys)
+                .into_report()
+                .change_context(Error::Arrow)?;
             self.key = concatenated_keys;
         }
         Ok(())
     }
@@ -184,19 +219,20 @@
     ///
     /// If the entity key type is null, then a null array is returned of same
     /// length.
-    pub fn inverse(&self, key_hashes: &UInt64Array) -> anyhow::Result<ArrayRef> {
+    pub fn inverse(&self, key_hashes: &UInt64Array) -> error_stack::Result<ArrayRef, Error> {
         let mut key_hash_indices: Vec<u64> = Vec::new();
-        for key_hash in key_hashes {
-            let key_hash = key_hash.with_context(|| "unable to get key_hash")?;
+        for key_hash in key_hashes.values() {
             let key_hash_index = self
                 .key_hash_to_indices
-                .get(&key_hash)
-                .with_context(|| "unable to find key")?;
-            key_hash_indices.push(*key_hash_index as u64);
+                .get(key_hash)
+                .ok_or(Error::MissingKeyHash)?;
+            key_hash_indices.push(*key_hash_index);
         }
         let key_hash_indices: PrimitiveArray<UInt64Type> =
             PrimitiveArray::from_iter_values(key_hash_indices);
-        let result = arrow::compute::take(&self.key, &key_hash_indices, None)?;
+        let result = arrow_select::take::take(&self.key, &key_hash_indices, None)
+            .into_report()
+            .change_context(Error::Arrow)?;
         Ok(result)
     }
 }
@@ -238,10 +274,15 @@
         }
     }

+    /// Creates a new key hash inverse from a primary grouping data type.
+    pub fn from_data_type(primary_grouping_type: &DataType) -> Self {
+        Self::new(KeyHashInverse::from_data_type(primary_grouping_type))
+    }
+
     /// Lookup keys from a key hash array.
     ///
     /// This method is thread-safe and acquires the read-lock.
-    pub async fn inverse(&self, key_hashes: &UInt64Array) -> anyhow::Result<ArrayRef> {
+    pub async fn inverse(&self, key_hashes: &UInt64Array) -> error_stack::Result<ArrayRef, Error> {
         let read = self.key_map.read().await;
         read.inverse(key_hashes)
     }
@@ -255,8 +296,18 @@
     /// This method is thread safe. It acquires the read lock to check if
     /// any of the keys need to be added to the inverse map, and only acquires
     /// the write lock if needed.
-    pub async fn add(&self, keys: ArrayRef, key_hashes: &UInt64Array) -> anyhow::Result<()> {
-        anyhow::ensure!(keys.len() == key_hashes.len());
+    pub async fn add(
+        &self,
+        keys: &dyn Array,
+        key_hashes: &UInt64Array,
+    ) -> error_stack::Result<(), Error> {
+        error_stack::ensure!(
+            keys.len() == key_hashes.len(),
+            Error::MismatchedLengths {
+                keys: keys.len(),
+                key_hashes: key_hashes.len()
+            }
+        );
         let has_new_keys = {
             let read = self.key_map.read().await;
             read.has_new_keys(key_hashes)
@@ -270,6 +321,27 @@
         }
     }

+    pub fn blocking_add(
+        &self,
+        keys: &dyn Array,
+        key_hashes: &UInt64Array,
+    ) -> error_stack::Result<(), Error> {
+        error_stack::ensure!(
+            keys.len() == key_hashes.len(),
+            Error::MismatchedLengths {
+                keys: keys.len(),
+                key_hashes: key_hashes.len()
+            }
+        );
+        let has_new_keys = self.key_map.blocking_read().has_new_keys(key_hashes);
+
+        if has_new_keys {
+            self.key_map.blocking_write().add(keys, key_hashes)
+        } else {
+            Ok(())
+        }
+    }
+
     /// Stores the KeyHashInverse to the compute store.
     ///
     /// This method is thread-safe and acquires the read-lock.
@@ -287,15 +359,15 @@ mod tests {
     use arrow::datatypes::DataType;
     use sparrow_instructions::ComputeStore;

-    use crate::execute::key_hash_inverse::{KeyHashInverse, ThreadSafeKeyHashInverse};
+    use crate::key_hash_inverse::{KeyHashInverse, ThreadSafeKeyHashInverse};

     #[test]
     fn test_inverse_with_int32() {
         let keys = Arc::new(Int32Array::from(vec![100, 200]));
         let key_hashes = UInt64Array::from(vec![1, 2]);

-        let mut key_hash = KeyHashInverse::from_data_type(DataType::Int32);
-        key_hash.add(keys, &key_hashes).unwrap();
+        let mut key_hash = KeyHashInverse::from_data_type(&DataType::Int32);
+        key_hash.add(keys.as_ref(), &key_hashes).unwrap();

         let test_hashes = UInt64Array::from_iter_values([1, 2, 1]);
         let result = key_hash.inverse(&test_hashes).unwrap();
@@ -304,11 +376,11 @@

     #[test]
     fn test_inverse_with_string() {
-        let keys = Arc::new(StringArray::from(vec!["awkward", "tacos"]));
+        let keys = StringArray::from(vec!["awkward", "tacos"]);
         let key_hashes = UInt64Array::from(vec![1, 2]);

-        let mut key_hash = KeyHashInverse::from_data_type(DataType::Utf8);
-        key_hash.add(keys, &key_hashes).unwrap();
+        let mut key_hash = KeyHashInverse::from_data_type(&DataType::Utf8);
+        key_hash.add(&keys, &key_hashes).unwrap();

         let test_hashes = UInt64Array::from_iter_values([1, 2, 1]);
         let result = key_hash.inverse(&test_hashes).unwrap();
@@ -320,10 +392,10 @@

     #[test]
     fn test_has_new_keys_no_new_keys() {
-        let keys = Arc::new(Int32Array::from(vec![100, 200]));
+        let keys = Int32Array::from(vec![100, 200]);
         let key_hashes = UInt64Array::from(vec![1, 2]);
-        let mut key_hash = KeyHashInverse::from_data_type(DataType::Int32);
-        key_hash.add(keys, &key_hashes).unwrap();
+        let mut key_hash = KeyHashInverse::from_data_type(&DataType::Int32);
+        key_hash.add(&keys, &key_hashes).unwrap();

         let verify_key_hashes = UInt64Array::from(vec![1, 2]);
         assert!(!key_hash.has_new_keys(&verify_key_hashes));
@@ -331,10 +403,10 @@

     #[test]
     fn test_has_new_keys_some_new_keys() {
-        let keys = Arc::new(Int32Array::from(vec![100, 200]));
+        let keys = Int32Array::from(vec![100, 200]);
         let key_hashes = UInt64Array::from(vec![1, 2]);
-        let mut key_hash = KeyHashInverse::from_data_type(DataType::Int32);
-        key_hash.add(keys, &key_hashes).unwrap();
+        let mut key_hash = KeyHashInverse::from_data_type(&DataType::Int32);
+        key_hash.add(&keys, &key_hashes).unwrap();

         let verify_key_hashes = UInt64Array::from(vec![1, 2, 3]);
         assert!(key_hash.has_new_keys(&verify_key_hashes));
     }
@@ -342,10 +414,10 @@

     #[test]
     fn test_has_new_keys_all_new_keys() {
-        let keys = Arc::new(Int32Array::from(vec![100, 200]));
+        let keys = Int32Array::from(vec![100, 200]);
         let key_hashes = UInt64Array::from(vec![1, 2]);
-        let mut key_hash = KeyHashInverse::from_data_type(DataType::Int32);
-        key_hash.add(keys, &key_hashes).unwrap();
+        let mut key_hash = KeyHashInverse::from_data_type(&DataType::Int32);
+        key_hash.add(&keys, &key_hashes).unwrap();

         let verify_key_hashes = UInt64Array::from(vec![3, 4, 5]);
         assert!(key_hash.has_new_keys(&verify_key_hashes));
@@ -353,12 +425,12 @@

     #[tokio::test]
     async fn test_thread_safe_inverse_with_int32() {
-        let keys = Arc::new(Int32Array::from(vec![100, 200]));
+        let keys = Int32Array::from(vec![100, 200]);
         let key_hashes = UInt64Array::from(vec![1, 2]);
-        let key_hash = KeyHashInverse::from_data_type(DataType::Int32);
+        let key_hash = KeyHashInverse::from_data_type(&DataType::Int32);
         let key_hash = ThreadSafeKeyHashInverse::new(key_hash);

-        key_hash.add(keys, &key_hashes).await.unwrap();
+        key_hash.add(&keys, &key_hashes).await.unwrap();

         let test_hashes = UInt64Array::from_iter_values([1, 2, 1]);
         let result = key_hash.inverse(&test_hashes).await.unwrap();
@@ -367,12 +439,12 @@

     #[tokio::test]
     async fn test_thread_safe_inverse_with_string() {
-        let keys = Arc::new(StringArray::from(vec!["awkward", "tacos"]));
+        let keys = StringArray::from(vec!["awkward", "tacos"]);
         let key_hashes = UInt64Array::from(vec![1, 2]);
-        let key_hash = KeyHashInverse::from_data_type(DataType::Utf8);
+        let key_hash = KeyHashInverse::from_data_type(&DataType::Utf8);
         let key_hash = ThreadSafeKeyHashInverse::new(key_hash);

-        key_hash.add(keys, &key_hashes).await.unwrap();
+        key_hash.add(&keys, &key_hashes).await.unwrap();

         let test_hashes = UInt64Array::from_iter_values([1, 2, 1]);
         let result = key_hash.inverse(&test_hashes).await.unwrap();
@@ -408,9 +480,9 @@
         key_hash.store_to(&compute_store).unwrap();
         let mut key_hash = KeyHashInverse::restore_from(&compute_store).unwrap();

-        let keys = Arc::new(StringArray::from(vec!["party", "pizza"]));
+        let keys = StringArray::from(vec!["party", "pizza"]);
         let key_hashes = UInt64Array::from(vec![3, 4]);
-        key_hash.add(keys, &key_hashes).unwrap();
+        key_hash.add(&keys, &key_hashes).unwrap();
         let test_hashes = UInt64Array::from_iter_values([1, 2, 3, 4]);
         let result = key_hash.inverse(&test_hashes).unwrap();
         assert_eq!(
@@ -420,10 +492,10 @@
     }

     async fn test_key_hash_inverse() -> KeyHashInverse {
-        let keys = Arc::new(StringArray::from(vec!["awkward", "tacos"]));
+        let keys = StringArray::from(vec!["awkward", "tacos"]);
         let key_hashes = UInt64Array::from(vec![1, 2]);
-        let mut key_hash = KeyHashInverse::from_data_type(DataType::Utf8);
-        key_hash.add(keys, &key_hashes).unwrap();
+        let mut key_hash = KeyHashInverse::from_data_type(&DataType::Utf8);
+        key_hash.add(&keys, &key_hashes).unwrap();

         key_hash
     }
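Not part of the patch: a round-trip through the now-public `sparrow_runtime::key_hash_inverse` API. The `#[tokio::main]` wrapper and the `arrow-array`/`arrow-schema`/`tokio` dependencies are assumptions for this sketch:

```rust
use arrow_array::{Array, StringArray, UInt64Array};
use arrow_schema::DataType;
use sparrow_runtime::key_hash_inverse::ThreadSafeKeyHashInverse;

#[tokio::main]
async fn main() {
    let inverse = ThreadSafeKeyHashInverse::from_data_type(&DataType::Utf8);

    // Register keys alongside their hashes...
    let keys = StringArray::from(vec!["awkward", "tacos"]);
    let key_hashes = UInt64Array::from(vec![1, 2]);
    inverse.add(&keys, &key_hashes).await.unwrap();

    // ...re-adding the same hashes is cheap: `add` checks under the read lock
    // and only takes the write lock when a hash is genuinely new.
    inverse.add(&keys, &key_hashes).await.unwrap();

    // ...and map hashes back to the original keys.
    let looked_up = inverse
        .inverse(&UInt64Array::from_iter_values([2, 1]))
        .await
        .unwrap();
    assert_eq!(looked_up.len(), 2); // ["tacos", "awkward"]
}
```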
diff --git a/crates/sparrow-runtime/src/lib.rs b/crates/sparrow-runtime/src/lib.rs
index 9207840a7..e9e9e6621 100644
--- a/crates/sparrow-runtime/src/lib.rs
+++ b/crates/sparrow-runtime/src/lib.rs
@@ -25,6 +25,7 @@
 mod batch;
 pub mod execute;
 mod key_hash_index;
+pub mod key_hash_inverse;
 mod metadata;
 mod min_heap;
 pub mod prepare;
diff --git a/crates/sparrow-runtime/src/prepare/execute_input_stream.rs b/crates/sparrow-runtime/src/prepare/execute_input_stream.rs
index 87da2c725..8932980fa 100644
--- a/crates/sparrow-runtime/src/prepare/execute_input_stream.rs
+++ b/crates/sparrow-runtime/src/prepare/execute_input_stream.rs
@@ -14,7 +14,7 @@ use sparrow_api::kaskada::v1alpha::{slice_plan, TableConfig};
 use sparrow_arrow::downcast::downcast_primitive_array;
 use sparrow_core::TableSchema;

-use crate::execute::key_hash_inverse::ThreadSafeKeyHashInverse;
+use crate::key_hash_inverse::ThreadSafeKeyHashInverse;
 use crate::prepare::slice_preparer::SlicePreparer;
 use crate::prepare::Error;

@@ -308,9 +308,8 @@ async fn update_key_inverse(
         .into_report()
         .change_context(Error::PreparingColumn)?;
     key_hash_inverse
-        .add(keys.clone(), key_hashes)
+        .add(keys.as_ref(), key_hashes)
         .await
-        .into_report()
         .change_context(Error::PreparingColumn)?;
     Ok(())
 }
@@ -328,7 +327,7 @@ mod tests {
     use static_init::dynamic;
     use uuid::Uuid;

-    use crate::execute::key_hash_inverse::{KeyHashInverse, ThreadSafeKeyHashInverse};
+    use crate::key_hash_inverse::ThreadSafeKeyHashInverse;
     use crate::prepare::execute_input_stream;
     use crate::RawMetadata;

@@ -378,9 +377,8 @@ mod tests {
         let batch2 = make_time_batch(&[6, 12, 10, 17, 11, 12]);
         let batch3 = make_time_batch(&[20]);
         let reader = futures::stream::iter(vec![Ok(batch1), Ok(batch2), Ok(batch3)]).boxed();
-        let key_hash_inverse = Arc::new(ThreadSafeKeyHashInverse::new(
-            KeyHashInverse::from_data_type(DataType::UInt64),
-        ));
+        let key_hash_inverse =
+            Arc::new(ThreadSafeKeyHashInverse::from_data_type(&DataType::UInt64));
         let raw_metadata = RawMetadata::from_raw_schema(RAW_SCHEMA.clone()).unwrap();

         let mut stream = execute_input_stream::prepare_input(
@@ -428,9 +426,8 @@ mod tests {
         let batch3 = make_time_batch(&[20]);
         let reader = futures::stream::iter(vec![Ok(batch1), Ok(batch2), Ok(batch3)]).boxed();

-        let key_hash_inverse = Arc::new(ThreadSafeKeyHashInverse::new(
-            KeyHashInverse::from_data_type(DataType::UInt64),
-        ));
+        let key_hash_inverse =
+            Arc::new(ThreadSafeKeyHashInverse::from_data_type(&DataType::UInt64));
         let raw_metadata = RawMetadata::from_raw_schema(RAW_SCHEMA.clone()).unwrap();

         let mut stream = execute_input_stream::prepare_input(
@@ -479,9 +476,8 @@ mod tests {
         let batch3 = make_time_batch(&[7, 17]);
         let reader = futures::stream::iter(vec![Ok(batch1), Ok(batch2), Ok(batch3)]).boxed();

-        let key_hash_inverse = Arc::new(ThreadSafeKeyHashInverse::new(
-            KeyHashInverse::from_data_type(DataType::UInt64),
-        ));
+        let key_hash_inverse =
+            Arc::new(ThreadSafeKeyHashInverse::from_data_type(&DataType::UInt64));
         let raw_metadata = RawMetadata::from_raw_schema(RAW_SCHEMA.clone()).unwrap();

         let mut stream = execute_input_stream::prepare_input(
diff --git a/crates/sparrow-session/src/session.rs b/crates/sparrow-session/src/session.rs
index 3fb909a83..db3791027 100644
--- a/crates/sparrow-session/src/session.rs
+++ b/crates/sparrow-session/src/session.rs
@@ -1,4 +1,6 @@
 use std::borrow::Cow;
+use std::collections::HashMap;
+use std::sync::Arc;

 use arrow_schema::SchemaRef;
 use error_stack::{IntoReport, IntoReportCompat, ResultExt};
@@ -9,7 +11,9 @@ use sparrow_api::kaskada::v1alpha::{
     ComputeTable, FeatureSet, PerEntityBehavior, TableConfig, TableMetadata,
 };
 use sparrow_compiler::{AstDfgRef, DataContext, Dfg, DiagnosticCollector};
+use sparrow_plan::GroupId;
 use sparrow_runtime::execute::output::Destination;
+use sparrow_runtime::key_hash_inverse::ThreadSafeKeyHashInverse;
 use sparrow_syntax::{ExprOp, FenlType, LiteralValue, Located, Location, Resolved};
 use uuid::Uuid;

@@ -20,6 +24,7 @@ use crate::{Error, Expr, Literal, Table};

 pub struct Session {
     data_context: DataContext,
     dfg: Dfg,
+    key_hash_inverse: HashMap<GroupId, Arc<ThreadSafeKeyHashInverse>>,
 }

 #[derive(Default)]
@@ -81,6 +86,10 @@ impl Session {
             file_sets: vec![],
         };

+        let (key_column, key_field) = schema
+            .column_with_name(key_column_name)
+            .expect("expected key column");
+
         let table_info = self
             .data_context
             .add_table(table)
@@ -98,7 +107,17 @@ impl Session {

         let expr = Expr(dfg_node);

-        Ok(Table::new(table_info, expr))
+        let key_hash_inverse = self
+            .key_hash_inverse
+            .entry(table_info.group_id())
+            .or_insert_with(|| {
+                Arc::new(ThreadSafeKeyHashInverse::from_data_type(
+                    key_field.data_type(),
+                ))
+            })
+            .clone();
+
+        Ok(Table::new(table_info, key_hash_inverse, key_column, expr))
     }

     pub fn add_cast(
@@ -262,13 +281,13 @@ impl Session {
         options: ExecutionOptions,
     ) -> error_stack::Result<Execution, Error> {
         // TODO: Decorations?
+        let group_id = expr
+            .0
+            .grouping()
+            .expect("query to be grouped (non-literal)");
         let primary_group_info = self
             .data_context
-            .group_info(
-                expr.0
-                    .grouping()
-                    .expect("query to be grouped (non-literal)"),
-            )
+            .group_info(group_id)
             .expect("missing group info");
         let primary_grouping = primary_group_info.name().to_owned();
         let primary_grouping_key_type = primary_group_info.key_type();
@@ -280,7 +299,7 @@ impl Session {
             .into_report()
             .change_context(Error::Compile)?;

-        // TODO: Run the egraph simplifier.
+        // TODO: Run the egraph simplifications.
         // TODO: Incremental?
         // TODO: Slicing?
         let plan = sparrow_compiler::plan::extract_plan_proto(
@@ -310,6 +329,16 @@ impl Session {
         let mut options = options.to_sparrow_options();
         options.stop_signal_rx = Some(stop_signal_rx);

+        let key_hash_inverse = self
+            .key_hash_inverse
+            .get(&group_id)
+            .cloned()
+            .unwrap_or_else(|| {
+                Arc::new(ThreadSafeKeyHashInverse::from_data_type(
+                    primary_grouping_key_type,
+                ))
+            });
+
         // Hacky. Use the existing execution logic. This does weird things with downloading checkpoints, etc.
         let progress = rt
             .block_on(sparrow_runtime::execute::execute_new(
@@ -317,6 +346,7 @@ impl Session {
                 destination,
                 data_context,
                 options,
+                Some(key_hash_inverse),
             ))
             .change_context(Error::Execute)?
             .map_err(|e| e.change_context(Error::Execute))
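Aside (not from the patch): the map added to `Session` hands out one shared inverse per grouping. A toy sketch of that sharing pattern, with hypothetical stand-ins for `sparrow_plan::GroupId` and the inverse type:

```rust
use std::collections::HashMap;
use std::sync::Arc;

// Hypothetical stand-ins, just to show the sharing pattern.
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
struct GroupId(u32);
struct Inverse;

fn inverse_for_group(
    map: &mut HashMap<GroupId, Arc<Inverse>>,
    group: GroupId,
) -> Arc<Inverse> {
    // Every table (and later, every query) over the same grouping receives a
    // clone of one `Arc`, so keys registered by `Table::add_data` are visible
    // when a query over that grouping materializes its `_key` column.
    map.entry(group).or_insert_with(|| Arc::new(Inverse)).clone()
}

fn main() {
    let mut map = HashMap::new();
    let a = inverse_for_group(&mut map, GroupId(0));
    let b = inverse_for_group(&mut map, GroupId(0));
    assert!(Arc::ptr_eq(&a, &b)); // same grouping, same inverse
}
```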
diff --git a/crates/sparrow-session/src/table.rs b/crates/sparrow-session/src/table.rs
index 6405ed387..2243821c5 100644
--- a/crates/sparrow-session/src/table.rs
+++ b/crates/sparrow-session/src/table.rs
@@ -1,11 +1,13 @@
 use std::sync::Arc;

+use arrow_array::cast::AsArray;
 use arrow_array::types::ArrowPrimitiveType;
 use arrow_array::RecordBatch;
 use arrow_schema::{DataType, Field, Fields, Schema, SchemaRef};
 use error_stack::ResultExt;
 use sparrow_compiler::TableInfo;
 use sparrow_merge::InMemoryBatches;
+use sparrow_runtime::key_hash_inverse::ThreadSafeKeyHashInverse;
 use sparrow_runtime::preparer::Preparer;

 use crate::{Error, Expr};
@@ -14,10 +16,17 @@ pub struct Table {
     pub expr: Expr,
     preparer: Preparer,
     in_memory_batches: Arc<InMemoryBatches>,
+    key_column: usize,
+    key_hash_inverse: Arc<ThreadSafeKeyHashInverse>,
 }

 impl Table {
-    pub(crate) fn new(table_info: &mut TableInfo, expr: Expr) -> Self {
+    pub(crate) fn new(
+        table_info: &mut TableInfo,
+        key_hash_inverse: Arc<ThreadSafeKeyHashInverse>,
+        key_column: usize,
+        expr: Expr,
+    ) -> Self {
         let prepared_fields: Fields = KEY_FIELDS
             .iter()
             .chain(table_info.schema().fields.iter())
@@ -42,6 +51,8 @@ impl Table {
             expr,
             preparer,
             in_memory_batches,
+            key_hash_inverse,
+            key_column: key_column + KEY_FIELDS.len(),
         }
     }

@@ -54,6 +65,13 @@ impl Table {
             .preparer
             .prepare_batch(batch)
             .change_context(Error::Prepare)?;
+
+        let key_hashes = prepared.column(2).as_primitive();
+        let keys = prepared.column(self.key_column);
+        self.key_hash_inverse
+            .blocking_add(keys.as_ref(), key_hashes)
+            .change_context(Error::Prepare)?;
+
         self.in_memory_batches
             .add_batch(prepared)
             .change_context(Error::Prepare)?;
diff --git a/sparrow-py/Cargo.lock b/sparrow-py/Cargo.lock
index a9e98f1d2..48290ab55 100644
--- a/sparrow-py/Cargo.lock
+++ b/sparrow-py/Cargo.lock
@@ -3884,6 +3884,8 @@ dependencies = [
  "ahash 0.8.3",
  "anyhow",
  "arrow",
+ "arrow-array",
+ "arrow-select",
  "async-once-cell",
  "async-stream",
  "async-trait",
diff --git a/sparrow-py/noxfile.py b/sparrow-py/noxfile.py
index 80798ff91..5b922924c 100644
--- a/sparrow-py/noxfile.py
+++ b/sparrow-py/noxfile.py
@@ -95,7 +95,7 @@ def safety(session: Session) -> None:
 @session(python=python_versions)
 def mypy(session: Session) -> None:
     """Type-check using mypy."""
-    args = session.posargs or ["pysrc", "pytests", "docs/conf.py"]
+    args = session.posargs or ["pysrc", "pytests", "docs/source/conf.py"]
     session.install("mypy", "pytest", "pandas-stubs")
     install_self(session)
     # Using `--install-types` should make this less picky about missing stubs.
diff --git a/sparrow-py/pysrc/sparrow_py/windows.py b/sparrow-py/pysrc/sparrow_py/windows.py
index 37371a8be..8d621f4e8 100644
--- a/sparrow-py/pysrc/sparrow_py/windows.py
+++ b/sparrow-py/pysrc/sparrow_py/windows.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from dataclasses import dataclass
 from datetime import timedelta

@@ -18,12 +20,12 @@ class Since(Window):

     Parameters
     ----------
-    predicate : Timestream
+    predicate : Timestream | bool
         The boolean Timestream to use as predicate for the window.
         Each time the predicate evaluates to true the window will be cleared.
     """

-    predicate: Timestream
+    predicate: Timestream | bool


 @dataclass(frozen=True)
@@ -36,13 +38,13 @@ class Sliding(Window):
     duration : int
         The number of sliding intervals to use in the window.

-    predicate : Timestream
+    predicate : Timestream | bool
         The boolean Timestream to use as predicate for the window.
         Each time the predicate evaluates to true the window starts a new interval.
""" duration: int - predicate: Timestream + predicate: Timestream | bool def __post_init__(self): if self.duration <= 0: diff --git a/sparrow-py/pytests/aggregation_test.py b/sparrow-py/pytests/aggregation_test.py index 20602faab..e607cc156 100644 --- a/sparrow-py/pytests/aggregation_test.py +++ b/sparrow-py/pytests/aggregation_test.py @@ -42,6 +42,6 @@ def test_sum_since_true(source, golden) -> None: # `since(True)` should be the same as unwindowed, so equals the original vaule. m_sum_since_true = kt.record({ "m": source.col("m"), - "m_sum": source.col("m").sum(window=kt.SinceWindow(True)), + "m_sum": source.col("m").sum(window=kt.windows.Since(True)), }) golden.jsonl(m_sum_since_true) \ No newline at end of file diff --git a/sparrow-py/pytests/golden/result_test/test_iter_pandas_async_materialize_1.jsonl b/sparrow-py/pytests/golden/result_test/test_iter_pandas_async_materialize_1.jsonl index 79b23c2c3..0dcf67073 100644 --- a/sparrow-py/pytests/golden/result_test/test_iter_pandas_async_materialize_1.jsonl +++ b/sparrow-py/pytests/golden/result_test/test_iter_pandas_async_materialize_1.jsonl @@ -1,6 +1,6 @@ {"_time":851128797000000000,"_subsort":6,"_key_hash":12960666915911099378,"_key":"A","time":"1996-12-20T16:39:57-08:00","key":"A","m":5.0,"n":10.0} {"_time":851128798000000000,"_subsort":7,"_key_hash":2867199309159137213,"_key":"B","time":"1996-12-20T16:39:58-08:00","key":"B","m":24.0,"n":3.0} {"_time":851128799000000000,"_subsort":8,"_key_hash":12960666915911099378,"_key":"A","time":"1996-12-20T16:39:59-08:00","key":"A","m":17.0,"n":6.0} -{"_time":851128800000000000,"_subsort":9,"_key_hash":12960666915911099378,"_key":"A","time":"1996-12-20T16:40:00-08:00","key":"A","m":null,"n":9.0} +{"_time":851128800000000000,"_subsort":9,"_key_hash":2521269998124177631,"_key":"C","time":"1996-12-20T16:40:00-08:00","key":"C","m":null,"n":9.0} {"_time":851128801000000000,"_subsort":10,"_key_hash":12960666915911099378,"_key":"A","time":"1996-12-20T16:40:01-08:00","key":"A","m":12.0,"n":null} {"_time":851128802000000000,"_subsort":11,"_key_hash":12960666915911099378,"_key":"A","time":"1996-12-20T16:40:02-08:00","key":"A","m":null,"n":null} diff --git a/sparrow-py/pytests/result_test.py b/sparrow-py/pytests/result_test.py index b0d750372..8628a6aba 100644 --- a/sparrow-py/pytests/result_test.py +++ b/sparrow-py/pytests/result_test.py @@ -59,7 +59,7 @@ async def test_iter_pandas_async_materialize(golden, source_int64) -> None: "1996-12-20T16:39:57-08:00,A,5,10", "1996-12-20T16:39:58-08:00,B,24,3", "1996-12-20T16:39:59-08:00,A,17,6", - "1996-12-20T16:40:00-08:00,A,,9", + "1996-12-20T16:40:00-08:00,C,,9", "1996-12-20T16:40:01-08:00,A,12,", "1996-12-20T16:40:02-08:00,A,,", ]