Skip to content

Commit

Permalink
feat: add take_record_batch. (#5333)
Browse files Browse the repository at this point in the history
* feat: add `take_record_batch`.

* Improve the function argument and enrich the docs.

* Follow the comments.
  • Loading branch information
RinChanNOWWW authored Jan 27, 2024
1 parent 8fff5e4 commit 5117b38
Showing 1 changed file with 51 additions and 0 deletions.
51 changes: 51 additions & 0 deletions arrow-select/src/take.rs
Original file line number Diff line number Diff line change
Expand Up @@ -769,6 +769,57 @@ to_indices_reinterpret!(Int32Type, UInt32Type);
to_indices_identity!(UInt64Type);
to_indices_reinterpret!(Int64Type, UInt64Type);

/// Take rows by index from [`RecordBatch`] and returns a new [`RecordBatch`] from those indexes.
///
/// This function will call [`take`] on each array of the [`RecordBatch`] and assemble a new [`RecordBatch`].
///
/// # Example
/// ```
/// # use std::sync::Arc;
/// # use arrow_array::{StringArray, Int32Array, UInt32Array, RecordBatch};
/// # use arrow_schema::{DataType, Field, Schema};
/// # use arrow_select::take::take_record_batch;
///
/// let schema = Arc::new(Schema::new(vec![
/// Field::new("a", DataType::Int32, true),
/// Field::new("b", DataType::Utf8, true),
/// ]));
/// let batch = RecordBatch::try_new(
/// schema.clone(),
/// vec![
/// Arc::new(Int32Array::from_iter_values(0..20)),
/// Arc::new(StringArray::from_iter_values(
/// (0..20).map(|i| format!("str-{}", i)),
/// )),
/// ],
/// )
/// .unwrap();
///
/// let indices = UInt32Array::from(vec![1, 5, 10]);
/// let taken = take_record_batch(&batch, &indices).unwrap();
///
/// let expected = RecordBatch::try_new(
/// schema,
/// vec![
/// Arc::new(Int32Array::from(vec![1, 5, 10])),
/// Arc::new(StringArray::from(vec!["str-1", "str-5", "str-10"])),
/// ],
/// )
/// .unwrap();
/// assert_eq!(taken, expected);
/// ```
pub fn take_record_batch(
record_batch: &RecordBatch,
indices: &dyn Array,
) -> Result<RecordBatch, ArrowError> {
let columns = record_batch
.columns()
.iter()
.map(|c| take(c, indices, None))
.collect::<Result<Vec<_>, _>>()?;
RecordBatch::try_new(record_batch.schema(), columns)
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down

0 comments on commit 5117b38

Please sign in to comment.