diff --git a/acceptance/src/data.rs b/acceptance/src/data.rs index d971e7fde..a6a50ef77 100644 --- a/acceptance/src/data.rs +++ b/acceptance/src/data.rs @@ -130,6 +130,7 @@ pub async fn assert_scan_data(engine: Arc, test_case: &TestCaseInfo) .execute(engine)? .map(|scan_result| -> DeltaResult<_> { let scan_result = scan_result?; + let mask = scan_result.full_mask(); let data = scan_result.raw_data?; let record_batch: RecordBatch = data .into_any() @@ -139,7 +140,7 @@ pub async fn assert_scan_data(engine: Arc, test_case: &TestCaseInfo) if schema.is_none() { schema = Some(record_batch.schema()); } - if let Some(mask) = scan_result.mask { + if let Some(mask) = mask { Ok(filter_record_batch(&record_batch, &mask.into())?) } else { Ok(record_batch) diff --git a/kernel/examples/read-table-single-threaded/src/main.rs b/kernel/examples/read-table-single-threaded/src/main.rs index 08b6920d7..5c9d23d41 100644 --- a/kernel/examples/read-table-single-threaded/src/main.rs +++ b/kernel/examples/read-table-single-threaded/src/main.rs @@ -117,18 +117,14 @@ fn try_main() -> DeltaResult<()> { .execute(engine.as_ref())? .map(|scan_result| -> DeltaResult<_> { let scan_result = scan_result?; + let mask = scan_result.full_mask(); let data = scan_result.raw_data?; let record_batch: RecordBatch = data .into_any() .downcast::() .map_err(|_| delta_kernel::Error::EngineDataType("ArrowEngineData".to_string()))? .into(); - if let Some(mut mask) = scan_result.mask { - let extra_rows = record_batch.num_rows() - mask.len(); - if extra_rows > 0 { - // we need to extend the mask here in case it's too short - mask.extend(std::iter::repeat(true).take(extra_rows)); - } + if let Some(mask) = mask { Ok(filter_record_batch(&record_batch, &mask.into())?) } else { Ok(record_batch) diff --git a/kernel/src/scan/mod.rs b/kernel/src/scan/mod.rs index d3219c900..9159313a3 100644 --- a/kernel/src/scan/mod.rs +++ b/kernel/src/scan/mod.rs @@ -124,14 +124,35 @@ impl ScanBuilder { pub struct ScanResult { /// Raw engine data as read from the disk for a particular file included in the query pub raw_data: DeltaResult>, - /// If an item at mask\[i\] is true, the row at that row index is valid, otherwise if it is - /// false, the row at that row index is invalid and should be ignored. If the mask is *shorter* - /// than the number of rows returned, missing elements are considered `true`, i.e. included in - /// the query. If this is None, all rows are valid. NB: If you are using the default engine and - /// plan to call arrow's `filter_record_batch`, you _need_ to extend this vector to the full - /// length of the batch or arrow will drop the extra rows + /// Raw row mask. // TODO(nick) this should be allocated by the engine - pub mask: Option>, + raw_mask: Option>, +} + +impl ScanResult { + /// Returns the raw row mask. If an item at `raw_mask()[i]` is true, row `i` is + /// valid. Otherwise, row `i` is invalid and should be ignored. + /// + /// The raw mask is dangerous to use because it may be shorter than expected. In particular, if + /// you are using the default engine and plan to call arrow's `filter_record_batch`, you _need_ + /// to extend the mask to the full length of the batch or arrow will drop the extra + /// rows. Calling [`full_mask`] instead avoids this risk entirely, at the cost of a copy. + pub fn raw_mask(&self) -> Option<&Vec> { + self.raw_mask.as_ref() + } + + /// Extends the underlying (raw) mask to match the row count of the accompanying data. + /// + /// If the raw mask is *shorter* than the number of rows returned, missing elements are + /// considered `true`, i.e. included in the query. If the mask is `None`, all rows are valid. + /// + /// NB: If you are using the default engine and plan to call arrow's `filter_record_batch`, you + /// _need_ to extend the mask to the full length of the batch or arrow will drop the extra rows. + pub fn full_mask(&self) -> Option> { + let mut mask = self.raw_mask.clone()?; + mask.resize(self.raw_data.as_ref().ok()?.length(), true); + Some(mask) + } } /// Scan uses this to set up what kinds of columns it is scanning. For `Selected` we just store the @@ -314,7 +335,7 @@ impl Scan { let rest = split_vector(sv.as_mut(), len, None); let result = ScanResult { raw_data: logical, - mask: sv, + raw_mask: sv, }; selection_vector = rest; Ok(result) diff --git a/kernel/tests/dv.rs b/kernel/tests/dv.rs index 6ffa9f430..593361e2f 100644 --- a/kernel/tests/dv.rs +++ b/kernel/tests/dv.rs @@ -16,12 +16,10 @@ fn count_total_scan_rows( scan_result_iter .map(|scan_result| { let scan_result = scan_result?; - let data = scan_result.raw_data?; // NOTE: The mask only suppresses rows for which it is both present and false. - let deleted_rows = scan_result - .mask - .as_ref() - .map_or(0, |mask| mask.iter().filter(|&&m| !m).count()); + let mask = scan_result.raw_mask(); + let deleted_rows = mask.into_iter().flatten().filter(|&&m| !m).count(); + let data = scan_result.raw_data?; Ok(data.length() - deleted_rows) }) .fold_ok(0, Add::add) diff --git a/kernel/tests/golden_tables.rs b/kernel/tests/golden_tables.rs index c3c54ef3d..806171373 100644 --- a/kernel/tests/golden_tables.rs +++ b/kernel/tests/golden_tables.rs @@ -161,14 +161,10 @@ async fn latest_snapshot_test( let batches: Vec = scan_res .map(|scan_result| -> DeltaResult<_> { let scan_result = scan_result?; + let mask = scan_result.full_mask(); let data = scan_result.raw_data?; let record_batch = to_arrow(data)?; - if let Some(mut mask) = scan_result.mask { - let extra_rows = record_batch.num_rows() - mask.len(); - if extra_rows > 0 { - // we need to extend the mask here in case it's too short - mask.extend(std::iter::repeat(true).take(extra_rows)); - } + if let Some(mask) = mask { Ok(filter_record_batch(&record_batch, &mask.into())?) } else { Ok(record_batch) diff --git a/kernel/tests/read.rs b/kernel/tests/read.rs index 815d6104a..d478b6a43 100644 --- a/kernel/tests/read.rs +++ b/kernel/tests/read.rs @@ -401,14 +401,10 @@ fn read_with_execute( let batches: Vec = scan_results .map(|scan_result| -> DeltaResult<_> { let scan_result = scan_result?; + let mask = scan_result.full_mask(); let data = scan_result.raw_data?; let record_batch = to_arrow(data)?; - if let Some(mut mask) = scan_result.mask { - let extra_rows = record_batch.num_rows() - mask.len(); - if extra_rows > 0 { - // we need to extend the mask here in case it's too short - mask.extend(std::iter::repeat(true).take(extra_rows)); - } + if let Some(mask) = mask { Ok(filter_record_batch(&record_batch, &mask.into())?) } else { Ok(record_batch)