diff --git a/crates/sparrow-batch/src/batch.rs b/crates/sparrow-batch/src/batch.rs
index 8f7a26acd..86a48b30d 100644
--- a/crates/sparrow-batch/src/batch.rs
+++ b/crates/sparrow-batch/src/batch.rs
@@ -1,15 +1,17 @@
+use std::sync::Arc;
+
+use arrow::ipc::Timestamp;
 use arrow_array::cast::AsArray;
 use arrow_array::types::{TimestampNanosecondType, UInt64Type};
 use arrow_array::{
-    Array, ArrayRef, ArrowPrimitiveType, RecordBatch, TimestampNanosecondArray, UInt64Array,
+    Array, ArrayRef, ArrowPrimitiveType, RecordBatch, StructArray, TimestampNanosecondArray,
+    UInt64Array,
 };
-use arrow_schema::Schema;
+use arrow_schema::{Field, Fields, Schema};
 use error_stack::{IntoReport, IntoReportCompat, ResultExt};
 use itertools::Itertools;
-use sparrow_core::KeyTriple;
-use sparrow_core::KeyTriples;

-use crate::RowTime;
+use crate::{Error, RowTime};

 /// A batch to be processed by the system.
 #[derive(Clone, PartialEq, Debug)]
@@ -17,64 +19,62 @@ pub struct Batch {
     /// The data associated with the batch.
     pub(crate) data: Option<BatchInfo>,

-    /// The bounds of this batch.
-    ///
-    /// Any rows in this batch must have a key greater than or equal to `lower_bound`
-    /// Any rows in this batch must have a key less than or equal to `upper_bound`
-    /// Any rows in the next batch must have a key greater than or equal to the `upper_bound`.
-    ///
-    /// Note that these bounds may not necessarily represent the actual rows in the batch,
-    /// as the batch may have been filtered. The bounds are unchanged from the original
-    /// so as to allow downstream consumers to proceed in time.
-    pub lower_bound: KeyTriple,
-    pub upper_bound: KeyTriple,
+    /// An indication that the batch stream has completed up to the given time.
+    /// Any rows in future batches on this stream must have a time strictly
+    /// greater than this.
+    pub up_to_time: RowTime,
 }

 impl Batch {
-    pub fn new_empty() -> Self {
+    pub fn new_empty(up_to_time: RowTime) -> Self {
         Self {
             data: None,
-            lower_bound: KeyTriple::MIN,
-            upper_bound: KeyTriple::MIN,
+            up_to_time,
         }
     }

-    /// Construct a new batch, inferring the lower and upper bound from the
-    /// data.
+    /// Construct a new batch, inferring the upper bound from the data.
     ///
     /// It is expected that that data is sorted.
-    ///
-    /// This should be used only when the initial [Batch] is created while
-    /// reading a table. Otherwise, bounds should be preserved using
-    /// [Batch::with_data] or [Batch::try_new_bounds].
     pub fn try_new_from_batch(data: RecordBatch) -> error_stack::Result<Self, Error> {
         error_stack::ensure!(
             data.num_rows() > 0,
-            Error::internal_msg("Unable to create batch from empty data")
+            Error::internal_msg("Unable to create batch from empty data".to_owned())
         );

-        let key_triples = KeyTriples::try_from(&data)
-            .into_report()
-            .change_context(Error::internal())?;
-        Self::try_new_with_bounds(
-            data,
-            key_triples.first().ok_or(Error::internal())?,
-            key_triples.last().ok_or(Error::internal())?,
-        )
-    }
+        let time_column: &TimestampNanosecondArray = data.column(0).as_primitive();
+        let up_to_time = RowTime::from_timestamp_ns(time_column.value(time_column.len() - 1));

-    pub fn try_new_with_bounds(
-        data: RecordBatch,
-        lower_bound: KeyTriple,
-        upper_bound: KeyTriple,
-    ) -> error_stack::Result<Self, Error> {
         #[cfg(debug_assertions)]
-        validate(&data, &lower_bound, &upper_bound)?;
+        validate(&data, up_to_time)?;
+        // TODO: Extract the columns I want
+
+        let time: &TimestampNanosecondArray = data.column(0).as_primitive();
+        let min_present_time = time.value(0);
+        let max_present_time = time.value(time.len() - 1);
+
+        let time = data.column(0).clone();
+        let subsort = data.column(1).clone();
+        let key_hash = data.column(2).clone();
+
+        let schema = data.schema();
+        let fields: Vec<Arc<Field>> = schema.fields()[3..].to_vec();
+        let fields = Fields::from(fields);
+        let columns: Vec<ArrayRef> = data.columns()[3..].to_vec();
+
+        // TODO: ...I can't remember if I need to get the null buffers out for this
+        let data = Arc::new(StructArray::new(fields, columns, None));

         Ok(Self {
-            data,
-            lower_bound,
-            upper_bound,
+            data: Some(BatchInfo {
+                data,
+                time,
+                subsort,
+                key_hash,
+                min_present_time: min_present_time.into(),
+                max_present_time: max_present_time.into(),
+            }),
+            up_to_time,
         })
     }

@@ -335,9 +335,8 @@ impl Batch {
         key_hash: impl Into<UInt64Array>,
         up_to_time: i64,
     ) -> Self {
-        use std::sync::Arc;

-        use arrow_array::StructArray;
+        use std::sync::Arc;

         let time: TimestampNanosecondArray = time.into();
         let subsort: UInt64Array = (0..(time.len() as u64)).collect_vec().into();
@@ -364,41 +363,31 @@ impl Batch {
 }

 #[cfg(debug_assertions)]
-fn validate(
-    data: &RecordBatch,
-    lower_bound: &KeyTriple,
-    upper_bound: &KeyTriple,
-) -> error_stack::Result<(), Error> {
+fn validate(data: &RecordBatch, up_to_time: RowTime) -> error_stack::Result<(), Error> {
     validate_batch_schema(data.schema().as_ref())?;

     for key_column in 0..3 {
         error_stack::ensure!(
             data.column(key_column).null_count() == 0,
-            Error::internal_msg(&format!(
+            Error::internal_msg(format!(
                 "Key column '{}' should not contain null",
                 data.schema().field(key_column).name()
             ))
         );
     }

-    validate_bounds(
-        data.column(0),
-        data.column(1),
-        data.column(2),
-        lower_bound,
-        upper_bound,
-    )
+    validate_bounds(data.column(0), data.column(1), data.column(2), up_to_time)
 }

 #[cfg(debug_assertions)]
-/// Validate that the result is totally sorted and that the lower and upper
-/// bound are correct.
+/// Validate that the result is sorted.
+///
+/// Note: This only validates the up_to_time bound.
 pub(crate) fn validate_bounds(
     time: &ArrayRef,
     subsort: &ArrayRef,
     key_hash: &ArrayRef,
-    lower_bound: &KeyTriple,
-    upper_bound: &KeyTriple,
+    up_to_time: RowTime,
 ) -> error_stack::Result<(), Error> {
     if time.len() == 0 {
         // No more validation necessary for empty batches.
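As an aside (not part of the diff): a minimal sketch of how the new `try_new_from_batch` constructor might be exercised, assuming it lives in this crate's test module. The `_time`/`_subsort`/`_key_hash` column names and the extra `value` column are assumptions, since `validate_batch_schema` is outside this diff; the sketch relies only on what the hunks above show, namely that the input must be non-empty and sorted, the first three columns are the key columns, and `up_to_time` is taken from the last timestamp.

```rust
// Sketch only. Assumes this sits in the crate's test module where `Batch` is in
// scope, and that the key columns are named "_time", "_subsort", "_key_hash".
use std::sync::Arc;

use arrow_array::{Int64Array, RecordBatch, TimestampNanosecondArray, UInt64Array};
use arrow_schema::{DataType, Field, Schema, TimeUnit};

#[test]
fn up_to_time_is_inferred_from_last_row() {
    let schema = Arc::new(Schema::new(vec![
        Field::new("_time", DataType::Timestamp(TimeUnit::Nanosecond, None), false),
        Field::new("_subsort", DataType::UInt64, false),
        Field::new("_key_hash", DataType::UInt64, false),
        Field::new("value", DataType::Int64, true),
    ]));

    // Rows must already be sorted by (time, subsort, key_hash), and the first
    // three (key) columns may not contain nulls.
    let data = RecordBatch::try_new(
        schema,
        vec![
            Arc::new(TimestampNanosecondArray::from(vec![1i64, 2, 3])),
            Arc::new(UInt64Array::from(vec![0u64, 1, 2])),
            Arc::new(UInt64Array::from(vec![7u64, 7, 8])),
            Arc::new(Int64Array::from(vec![10i64, 20, 30])),
        ],
    )
    .unwrap();

    let batch = Batch::try_new_from_batch(data).unwrap();

    // The non-key columns are packed into the StructArray held by BatchInfo,
    // and `up_to_time` is the timestamp of the last (maximum) row.
    assert_eq!(i64::from(batch.up_to_time), 3);
}
```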
@@ -409,30 +398,9 @@ pub(crate) fn validate_bounds(
     let subsort: &UInt64Array = subsort.as_primitive();
     let key_hash: &UInt64Array = key_hash.as_primitive();

-    let mut prev_time = lower_bound.time;
-    let mut prev_subsort = lower_bound.subsort;
-    let mut prev_key_hash = lower_bound.key_hash;
-
-    let curr_time = time.value(0);
-    let curr_subsort = subsort.value(0);
-    let curr_key_hash = key_hash.value(0);
-
-    let order = prev_time
-        .cmp(&curr_time)
-        .then(prev_subsort.cmp(&curr_subsort))
-        .then(prev_key_hash.cmp(&curr_key_hash));
-
-    error_stack::ensure!(
-        order.is_le(),
-        Error::internal_msg(&format!(
-            "Expected lower_bound <= data[0], but ({}, {}, {}) > ({}, {}, {})",
-            prev_time, prev_subsort, prev_key_hash, curr_time, curr_subsort, curr_key_hash
-        ))
-    );
-
-    prev_time = curr_time;
-    prev_subsort = curr_subsort;
-    prev_key_hash = curr_key_hash;
+    let mut prev_time = time.value(0);
+    let mut prev_subsort = subsort.value(0);
+    let mut prev_key_hash = key_hash.value(0);

     for i in 1..time.len() {
         let curr_time = time.value(i);
@@ -446,7 +414,7 @@ pub(crate) fn validate_bounds(

         error_stack::ensure!(
             order.is_lt(),
-            Error::internal_msg(&format!(
+            Error::internal_msg(format!(
                 "Expected data[i - 1] < data[i], but ({}, {}, {}) >= ({}, {}, {})",
                 prev_time, prev_subsort, prev_key_hash, curr_time, curr_subsort, curr_key_hash
             ))
@@ -457,20 +425,13 @@ pub(crate) fn validate_bounds(
         prev_key_hash = curr_key_hash;
     }

-    let curr_time = upper_bound.time;
-    let curr_subsort = upper_bound.subsort;
-    let curr_key_hash = upper_bound.key_hash;
-
-    let order = prev_time
-        .cmp(&curr_time)
-        .then(prev_subsort.cmp(&curr_subsort))
-        .then(prev_key_hash.cmp(&curr_key_hash));
-
+    let curr_time: i64 = up_to_time.into();
+    let order = prev_time.cmp(&curr_time);
     error_stack::ensure!(
         order.is_le(),
-        Error::internal_msg(&format!(
-            "Expected last data <= upper bound, but ({}, {}, {}) > ({}, {}, {})",
-            prev_time, prev_subsort, prev_key_hash, curr_time, curr_subsort, curr_key_hash
+        Error::internal_msg(format!(
+            "Expected last data <= upper bound, but ({}) > ({})",
+            prev_time, curr_time
         ))
     );

@@ -501,7 +462,7 @@ fn validate_key_column(
 ) -> error_stack::Result<(), Error> {
     error_stack::ensure!(
         schema.field(index).name() == expected_name,
-        Error::internal_msg(&format!(
+        Error::internal_msg(format!(
            "Expected column {} to be '{}' but was '{}'",
            index,
            expected_name,
@@ -510,7 +471,7 @@ fn validate_key_column(
     );
     error_stack::ensure!(
         schema.field(index).data_type() == expected_type,
-        Error::internal_msg(&format!(
+        Error::internal_msg(format!(
             "Key column '{}' should be '{:?}' but was '{:?}'",
             expected_name,
             schema.field(index).data_type(),
@@ -617,24 +578,6 @@ impl BatchInfo {
     }
 }

-#[derive(Debug, derive_more::Display)]
-pub enum Error {
-    #[display(fmt = "internal error: {_0:?}")]
-    Internal(&'static str),
-}
-
-impl Error {
-    pub fn internal() -> Self {
-        Self::Internal("no additional context")
-    }
-
-    pub fn internal_msg(msg: &'static str) -> Self {
-        Self::Internal(msg)
-    }
-}
-
-impl error_stack::Context for Error {}
-
 #[cfg(any(test, feature = "testing"))]
 #[static_init::dynamic]
 static MINIMAL_SCHEMA: arrow_schema::SchemaRef = {
diff --git a/crates/sparrow-batch/src/row_time.rs b/crates/sparrow-batch/src/row_time.rs
index 687e80a86..bcba18ae3 100644
--- a/crates/sparrow-batch/src/row_time.rs
+++ b/crates/sparrow-batch/src/row_time.rs
@@ -33,3 +33,9 @@ impl From<RowTime> for i64 {
         val.0
     }
 }
+
+impl Into<RowTime> for i64 {
+    fn into(self) -> RowTime {
+        RowTime(self)
+    }
+}
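One note on the final `row_time.rs` hunk: a hand-written `Into` impl works, but Clippy's `from_over_into` lint generally prefers implementing `From` in the opposite direction, because the standard library's blanket `impl<T, U: From<T>> Into<U> for T` then supplies `Into<RowTime> for i64` for free, so existing `.into()` call sites keep working. A sketch of the equivalent impl, using the same tuple-struct constructor the existing `From<RowTime> for i64` already relies on:

```rust
// Equivalent conversion written as From; `i64.into()` still works via the
// standard library's blanket Into impl.
impl From<i64> for RowTime {
    fn from(value: i64) -> Self {
        RowTime(value)
    }
}
```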