test: Add a test for RowFilter with nested type #5600

Merged: 1 commit, Apr 9, 2024
parquet/src/arrow/async_reader/mod.rs (88 additions, 0 deletions)
Expand Up @@ -1857,4 +1857,92 @@ mod tests {
assert_eq!(total_rows, expected);
}
}

#[tokio::test]
async fn test_row_filter_nested() {
let a = StringArray::from_iter_values(["a", "b", "b", "b", "c", "c"]);
let b = StructArray::from(vec![
(
Arc::new(Field::new("aa", DataType::Utf8, true)),
Arc::new(StringArray::from(vec!["a", "b", "b", "b", "c", "c"])) as ArrayRef,
),
(
Arc::new(Field::new("bb", DataType::Utf8, true)),
Arc::new(StringArray::from(vec!["1", "2", "3", "4", "5", "6"])) as ArrayRef,
),
]);
let c = Int32Array::from_iter(0..6);
let data = RecordBatch::try_from_iter([
("a", Arc::new(a) as ArrayRef),
("b", Arc::new(b) as ArrayRef),
("c", Arc::new(c) as ArrayRef),
])
.unwrap();

let mut buf = Vec::with_capacity(1024);
let mut writer = ArrowWriter::try_new(&mut buf, data.schema(), None).unwrap();
writer.write(&data).unwrap();
writer.close().unwrap();

let data: Bytes = buf.into();
let metadata = parse_metadata(&data).unwrap();
let parquet_schema = metadata.file_metadata().schema_descr_ptr();

let test = TestReader {
data,
metadata: Arc::new(metadata),
requests: Default::default(),
};
let requests = test.requests.clone();

let a_scalar = StringArray::from_iter_values(["b"]);
let a_filter = ArrowPredicateFn::new(
ProjectionMask::leaves(&parquet_schema, vec![0]),
move |batch| eq(batch.column(0), &Scalar::new(&a_scalar)),
);

let b_scalar = StringArray::from_iter_values(["4"]);
let b_filter = ArrowPredicateFn::new(
ProjectionMask::leaves(&parquet_schema, vec![2]),
move |batch| {
// Filter on the struct's "bb" leaf. Only leaf 2 (b.bb) is projected,
// so it is column 0 of the projected struct.
let struct_array = batch
.column(0)
.as_any()
.downcast_ref::<StructArray>()
.unwrap();
eq(struct_array.column(0), &Scalar::new(&b_scalar))
[Review thread on lines +1908 to +1914]

Member (author): Btw, the row filter needs to know the schema so it can get the correct (nested) column to filter on. For a general filter implementation such as the one apache/iceberg-rust#295 proposes, is there any utility we can use to "flatten" nested columns from the batch? In other words, is there an existing way to flatten the projected (nested) columns in the batch, so that if we know a leaf column's index we can derive its position in the projection mask and in the flattened batch, and then simply fetch the column with flatten_batch.column?

Contributor: I think the challenge here is when the nested arrays are repeated or nullable; in such cases, trying to interpret the leaves in isolation isn't necessarily meaningful.
},
);

let filter = RowFilter::new(vec![Box::new(a_filter), Box::new(b_filter)]);

let mask = ProjectionMask::leaves(&parquet_schema, vec![0, 3]);
let stream = ParquetRecordBatchStreamBuilder::new(test)
.await
.unwrap()
.with_projection(mask.clone())
.with_batch_size(1024)
.with_row_filter(filter)
.build()
.unwrap();

let batches: Vec<_> = stream.try_collect().await.unwrap();
assert_eq!(batches.len(), 1);

let batch = &batches[0];
assert_eq!(batch.num_rows(), 1);
assert_eq!(batch.num_columns(), 2);

let col = batch.column(0);
let val = col.as_any().downcast_ref::<StringArray>().unwrap().value(0);
assert_eq!(val, "b");

let col = batch.column(1);
let val = col.as_any().downcast_ref::<Int32Array>().unwrap().value(0);
assert_eq!(val, 3);

// Should only have made 3 requests
assert_eq!(requests.lock().unwrap().len(), 3);
}
}
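The `ProjectionMask::leaves` indices in the test refer to the Parquet schema's flattened leaf columns: `a` = 0, `b.aa` = 1, `b.bb` = 2, `c` = 3, which is why the `bb` predicate projects `vec![2]` and the final mask `vec![0, 3]` selects `a` and `c`. A minimal, dependency-free sketch of that depth-first leaf enumeration (the `Node` type and helper names here are illustrative, not part of the parquet crate):

```rust
// Illustrative only: models how a nested schema flattens into leaf column
// indices, mirroring ProjectionMask::leaves in the test above.
enum Node {
    Leaf(&'static str),
    Group(&'static str, Vec<Node>),
}

// Depth-first walk: leaves are numbered by the order they are visited.
fn collect_leaves(node: &Node, prefix: &str, out: &mut Vec<String>) {
    match node {
        Node::Leaf(name) => {
            let path = if prefix.is_empty() {
                (*name).to_string()
            } else {
                format!("{prefix}.{name}")
            };
            out.push(path);
        }
        Node::Group(name, children) => {
            let prefix = if prefix.is_empty() {
                (*name).to_string()
            } else {
                format!("{prefix}.{name}")
            };
            for child in children {
                collect_leaves(child, &prefix, out);
            }
        }
    }
}

fn leaf_paths() -> Vec<String> {
    // Schema used by the test: a: Utf8, b: Struct<aa: Utf8, bb: Utf8>, c: Int32
    let schema = [
        Node::Leaf("a"),
        Node::Group("b", vec![Node::Leaf("aa"), Node::Leaf("bb")]),
        Node::Leaf("c"),
    ];
    let mut leaves = Vec::new();
    for node in &schema {
        collect_leaves(node, "", &mut leaves);
    }
    leaves
}

fn main() {
    let leaves = leaf_paths();
    // Leaf index 2 is "b.bb"; leaf indices 0 and 3 are "a" and "c".
    assert_eq!(leaves, ["a", "b.aa", "b.bb", "c"]);
    println!("{leaves:?}");
}
```

This ordering is also why the projected struct in the `b_filter` closure has its `bb` child at position 0: the mask drops leaf 1 (`b.aa`), leaving `bb` as the struct's only remaining child.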