diff --git a/arrow-select/src/take.rs b/arrow-select/src/take.rs index 44269e38758..172e61dca7e 100644 --- a/arrow-select/src/take.rs +++ b/arrow-select/src/take.rs @@ -769,6 +769,57 @@ to_indices_reinterpret!(Int32Type, UInt32Type); to_indices_identity!(UInt64Type); to_indices_reinterpret!(Int64Type, UInt64Type); +/// Take rows by index from [`RecordBatch`] and returns a new [`RecordBatch`] from those indexes. +/// +/// This function will call [`take`] on each array of the [`RecordBatch`] and assemble a new [`RecordBatch`]. +/// +/// # Example +/// ``` +/// # use std::sync::Arc; +/// # use arrow_array::{StringArray, Int32Array, UInt32Array, RecordBatch}; +/// # use arrow_schema::{DataType, Field, Schema}; +/// # use arrow_select::take::take_record_batch; +/// +/// let schema = Arc::new(Schema::new(vec![ +/// Field::new("a", DataType::Int32, true), +/// Field::new("b", DataType::Utf8, true), +/// ])); +/// let batch = RecordBatch::try_new( +/// schema.clone(), +/// vec![ +/// Arc::new(Int32Array::from_iter_values(0..20)), +/// Arc::new(StringArray::from_iter_values( +/// (0..20).map(|i| format!("str-{}", i)), +/// )), +/// ], +/// ) +/// .unwrap(); +/// +/// let indices = UInt32Array::from(vec![1, 5, 10]); +/// let taken = take_record_batch(&batch, &indices).unwrap(); +/// +/// let expected = RecordBatch::try_new( +/// schema, +/// vec![ +/// Arc::new(Int32Array::from(vec![1, 5, 10])), +/// Arc::new(StringArray::from(vec!["str-1", "str-5", "str-10"])), +/// ], +/// ) +/// .unwrap(); +/// assert_eq!(taken, expected); +/// ``` +pub fn take_record_batch( + record_batch: &RecordBatch, + indices: &dyn Array, +) -> Result { + let columns = record_batch + .columns() + .iter() + .map(|c| take(c, indices, None)) + .collect::, _>>()?; + RecordBatch::try_new(record_batch.schema(), columns) +} + #[cfg(test)] mod tests { use super::*;