
feat: Allow decoding of non-Polars arrow dictionaries in Arrow and Parquet (#20248)
coastalwhite authored Dec 15, 2024
1 parent 028297a commit 6d0f5df
Showing 52 changed files with 1,104 additions and 1,150 deletions.
10 changes: 9 additions & 1 deletion crates/polars-arrow/src/array/binview/mod.rs
@@ -481,13 +481,21 @@ impl<T: ViewType + ?Sized> BinaryViewArrayGeneric<T> {
let views = self.views.make_mut();
let completed_buffers = self.buffers.to_vec();
let validity = self.validity.map(|bitmap| bitmap.make_mut());

// We need to know the total_bytes_len if we are going to mutate it.
let mut total_bytes_len = self.total_bytes_len.load(Ordering::Relaxed);
if total_bytes_len == UNKNOWN_LEN {
total_bytes_len = views.iter().map(|view| view.length as u64).sum();
}
let total_bytes_len = total_bytes_len as usize;

MutableBinaryViewArray {
views,
completed_buffers,
in_progress_buffer: vec![],
validity,
phantom: Default::default(),
total_bytes_len: self.total_bytes_len.load(Ordering::Relaxed) as usize,
total_bytes_len,
total_buffer_len: self.total_buffer_len,
stolen_buffers: PlHashMap::new(),
}
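
The hunk above fixes a latent bug: the old code copied `self.total_bytes_len` into the mutable array verbatim, so the `UNKNOWN_LEN` sentinel could survive into `MutableBinaryViewArray` as a bogus byte count. A minimal standalone sketch of the sentinel-then-recompute pattern (hypothetical names, not the Polars API itself; the memoizing store is an assumption for illustration):

```rust
use std::sync::atomic::{AtomicU64, Ordering};

const UNKNOWN_LEN: u64 = u64::MAX;

// If the cached total is the UNKNOWN_LEN sentinel, recompute it by summing
// the per-view lengths, then memoize the result.
fn total_bytes_len(cache: &AtomicU64, view_lengths: &[u32]) -> usize {
    let mut len = cache.load(Ordering::Relaxed);
    if len == UNKNOWN_LEN {
        len = view_lengths.iter().map(|&l| u64::from(l)).sum();
        cache.store(len, Ordering::Relaxed);
    }
    len as usize
}

fn main() {
    let cache = AtomicU64::new(UNKNOWN_LEN);
    assert_eq!(total_bytes_len(&cache, &[3, 4]), 7);
    assert_eq!(cache.load(Ordering::Relaxed), 7); // memoized
}
```
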
9 changes: 9 additions & 0 deletions crates/polars-arrow/src/datatypes/field.rs
@@ -7,6 +7,7 @@ use serde::{Deserialize, Serialize};
use super::{ArrowDataType, Metadata};

pub static DTYPE_ENUM_VALUES: &str = "_PL_ENUM_VALUES";
pub static DTYPE_CATEGORICAL: &str = "_PL_CATEGORICAL";

/// Represents Arrow's metadata of a "column".
///
@@ -74,4 +75,12 @@ impl Field {
false
}
}

pub fn is_categorical(&self) -> bool {
if let Some(md) = &self.metadata {
md.get(DTYPE_CATEGORICAL).is_some()
} else {
false
}
}
}
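
`is_categorical` mirrors the existing `DTYPE_ENUM_VALUES` check: a field is treated as a Polars categorical if and only if its metadata carries the `_PL_CATEGORICAL` marker key. A standalone sketch of the lookup, assuming a plain `BTreeMap` in place of the crate's `Metadata` alias:

```rust
use std::collections::BTreeMap;

pub static DTYPE_CATEGORICAL: &str = "_PL_CATEGORICAL";

// A field counts as categorical iff its metadata contains the marker key;
// the associated value is irrelevant.
fn is_categorical(metadata: Option<&BTreeMap<String, String>>) -> bool {
    metadata.map_or(false, |md| md.contains_key(DTYPE_CATEGORICAL))
}

fn main() {
    let md = BTreeMap::from([(DTYPE_CATEGORICAL.to_string(), String::new())]);
    assert!(is_categorical(Some(&md)));
    assert!(!is_categorical(None));
}
```
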
2 changes: 1 addition & 1 deletion crates/polars-arrow/src/datatypes/mod.rs
@@ -8,7 +8,7 @@ mod schema;
use std::collections::BTreeMap;
use std::sync::Arc;

pub use field::{Field, DTYPE_ENUM_VALUES};
pub use field::{Field, DTYPE_CATEGORICAL, DTYPE_ENUM_VALUES};
pub use physical_type::*;
use polars_utils::pl_str::PlSmallStr;
pub use schema::{ArrowSchema, ArrowSchemaRef};
10 changes: 10 additions & 0 deletions crates/polars-arrow/src/io/avro/read/deserialize.rs
@@ -1,3 +1,5 @@
use std::sync::Arc;

use avro_schema::file::Block;
use avro_schema::schema::{Enum, Field as AvroField, Record, Schema as AvroSchema};
use polars_error::{polars_bail, polars_err, PolarsResult};
@@ -506,8 +508,16 @@ pub fn deserialize(
}
}

let projected_schema = fields
.iter_values()
.zip(projection)
.filter_map(|(f, p)| (*p).then_some(f))
.cloned()
.collect();

RecordBatchT::try_new(
rows,
Arc::new(projected_schema),
arrays
.iter_mut()
.zip(projection.iter())
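
The new `projected_schema` is built by zipping each Avro field with its boolean projection flag and keeping only the selected fields, so the schema handed to `RecordBatchT::try_new` matches the projected arrays. The same filter in isolation (string names standing in for full Arrow fields):

```rust
// Keep field `i` iff `projection[i]` is true.
fn project<T: Clone>(fields: &[T], projection: &[bool]) -> Vec<T> {
    fields
        .iter()
        .zip(projection)
        .filter_map(|(f, &keep)| keep.then(|| f.clone()))
        .collect()
}

fn main() {
    assert_eq!(project(&["a", "b", "c"], &[true, false, true]), ["a", "c"]);
}
```
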
23 changes: 18 additions & 5 deletions crates/polars-arrow/src/io/ipc/read/common.rs
@@ -1,5 +1,6 @@
use std::collections::VecDeque;
use std::io::{Read, Seek};
use std::sync::Arc;

use polars_error::{polars_bail, polars_err, PolarsResult};
use polars_utils::aliases::PlHashMap;
@@ -197,7 +198,11 @@ pub fn read_record_batch<R: Read + Seek>(
.map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?;
let length = limit.map(|limit| limit.min(length)).unwrap_or(length);

RecordBatchT::try_new(length, columns)
let mut schema: ArrowSchema = fields.iter_values().cloned().collect();
if let Some(projection) = projection {
schema = schema.try_project_indices(projection).unwrap();
}
RecordBatchT::try_new(length, Arc::new(schema), columns)
}

fn find_first_dict_field_d<'a>(
@@ -373,13 +378,21 @@ pub fn apply_projection(
let length = chunk.len();

// re-order according to projection
let arrays = chunk.into_arrays();
let (schema, arrays) = chunk.into_schema_and_arrays();
let mut new_schema = schema.as_ref().clone();
let mut new_arrays = arrays.clone();

map.iter()
.for_each(|(old, new)| new_arrays[*new] = arrays[*old].clone());
map.iter().for_each(|(old, new)| {
let (old_name, old_field) = schema.get_at_index(*old).unwrap();
let (new_name, new_field) = new_schema.get_at_index_mut(*new).unwrap();

*new_name = old_name.clone();
*new_field = old_field.clone();

new_arrays[*new] = arrays[*old].clone();
});

RecordBatchT::new(length, new_arrays)
RecordBatchT::new(length, Arc::new(new_schema), new_arrays)
}

#[cfg(test)]
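
`apply_projection` previously permuted only the arrays; it now applies the same old-index-to-new-index map to the schema, so field names and columns stay aligned. A sketch of the permutation over a generic slice (illustrative only, not the crate's API):

```rust
use std::collections::HashMap;

// Apply an old-index -> new-index map to a slice, mirroring how
// `apply_projection` moves schema entries and arrays in lockstep.
fn reorder<T: Clone>(items: &[T], map: &HashMap<usize, usize>) -> Vec<T> {
    let mut out = items.to_vec();
    for (&old, &new) in map {
        out[new] = items[old].clone();
    }
    out
}

fn main() {
    let map = HashMap::from([(0, 1), (1, 0)]);
    assert_eq!(reorder(&["a", "b"], &map), ["b", "a"]);
}
```
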
8 changes: 7 additions & 1 deletion crates/polars-arrow/src/mmap/mod.rs
@@ -111,7 +111,13 @@ pub(crate) unsafe fn mmap_record<T: AsRef<[u8]>>(
)
})
.collect::<PolarsResult<_>>()
.and_then(|arr| RecordBatchT::try_new(length, arr))
.and_then(|arr| {
RecordBatchT::try_new(
length,
Arc::new(fields.iter_values().cloned().collect()),
arr,
)
})
}

/// Memory maps an record batch from an IPC file into a [`RecordBatchT`].
29 changes: 25 additions & 4 deletions crates/polars-arrow/src/record_batch.rs
@@ -4,12 +4,14 @@
use polars_error::{polars_ensure, PolarsResult};

use crate::array::{Array, ArrayRef};
use crate::datatypes::{ArrowSchema, ArrowSchemaRef};

/// A vector of trait objects of [`Array`] where every item has
/// the same length, [`RecordBatchT::len`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RecordBatchT<A: AsRef<dyn Array>> {
height: usize,
schema: ArrowSchemaRef,
arrays: Vec<A>,
}

@@ -21,29 +23,42 @@ impl<A: AsRef<dyn Array>> RecordBatchT<A> {
/// # Panics
///
/// I.f.f. the length does not match the length of any of the arrays
pub fn new(length: usize, arrays: Vec<A>) -> Self {
Self::try_new(length, arrays).unwrap()
pub fn new(length: usize, schema: ArrowSchemaRef, arrays: Vec<A>) -> Self {
Self::try_new(length, schema, arrays).unwrap()
}

/// Creates a new [`RecordBatchT`].
///
/// # Error
///
/// I.f.f. the height does not match the length of any of the arrays
pub fn try_new(height: usize, arrays: Vec<A>) -> PolarsResult<Self> {
pub fn try_new(height: usize, schema: ArrowSchemaRef, arrays: Vec<A>) -> PolarsResult<Self> {
polars_ensure!(
schema.len() == arrays.len(),
ComputeError: "RecordBatch requires an equal number of fields and arrays",
);
polars_ensure!(
arrays.iter().all(|arr| arr.as_ref().len() == height),
ComputeError: "RecordBatch requires all its arrays to have an equal number of rows",
);

Ok(Self { height, arrays })
Ok(Self {
height,
schema,
arrays,
})
}

/// returns the [`Array`]s in [`RecordBatchT`]
pub fn arrays(&self) -> &[A] {
&self.arrays
}

/// returns the [`ArrowSchema`]s in [`RecordBatchT`]
pub fn schema(&self) -> &ArrowSchema {
&self.schema
}

/// returns the [`Array`]s in [`RecordBatchT`]
pub fn columns(&self) -> &[A] {
&self.arrays
@@ -74,6 +89,12 @@ impl<A: AsRef<dyn Array>> RecordBatchT<A> {
pub fn into_arrays(self) -> Vec<A> {
self.arrays
}

/// Consumes [`RecordBatchT`] into its underlying schema and arrays.
/// The arrays are guaranteed to have the same length
pub fn into_schema_and_arrays(self) -> (ArrowSchemaRef, Vec<A>) {
(self.schema, self.arrays)
}
}

impl<A: AsRef<dyn Array>> From<RecordBatchT<A>> for Vec<A> {
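
With the schema stored on the batch, `try_new` now enforces two invariants instead of one: the field count must equal the array count, and every array must have exactly `height` rows. A self-contained sketch of both checks (toy stand-ins: `&str` for a field, `Vec<u32>` for a boxed Arrow array):

```rust
fn try_new(height: usize, schema: &[&str], arrays: &[Vec<u32>]) -> Result<(), String> {
    // One field per array.
    if schema.len() != arrays.len() {
        return Err("RecordBatch requires an equal number of fields and arrays".into());
    }
    // Every array has exactly `height` rows.
    if arrays.iter().any(|arr| arr.len() != height) {
        return Err("RecordBatch requires all its arrays to have an equal number of rows".into());
    }
    Ok(())
}

fn main() {
    assert!(try_new(2, &["a"], &[vec![1, 2]]).is_ok());
    assert!(try_new(2, &["a", "b"], &[vec![1, 2]]).is_err()); // field/array mismatch
    assert!(try_new(3, &["a"], &[vec![1, 2]]).is_err()); // wrong height
}
```
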
1 change: 1 addition & 0 deletions crates/polars-compute/src/lib.rs
@@ -22,6 +22,7 @@ pub mod horizontal_flatten;
pub mod hyperloglogplus;
pub mod if_then_else;
pub mod min_max;
pub mod propagate_dictionary;
pub mod size;
pub mod unique;
pub mod var_cov;
87 changes: 87 additions & 0 deletions crates/polars-compute/src/propagate_dictionary.rs
@@ -0,0 +1,87 @@
use arrow::array::{Array, BinaryViewArray, PrimitiveArray, Utf8ViewArray};
use arrow::bitmap::Bitmap;
use arrow::datatypes::ArrowDataType::UInt32;

/// Propagate the nulls from the dictionary values into the keys and remove those nulls from the
/// values.
pub fn propagate_dictionary_value_nulls(
keys: &PrimitiveArray<u32>,
values: &Utf8ViewArray,
) -> (PrimitiveArray<u32>, Utf8ViewArray) {
let Some(values_validity) = values.validity() else {
return (keys.clone(), values.clone().with_validity(None));
};
if values_validity.unset_bits() == 0 {
return (keys.clone(), values.clone().with_validity(None));
}

let num_values = values.len();

// Create a map from the old indices to indices with nulls filtered out
let mut offset = 0;
let new_idx_map: Vec<u32> = (0..num_values)
.map(|i| {
let is_valid = unsafe { values_validity.get_bit_unchecked(i) };
offset += usize::from(!is_valid);
if is_valid {
(i - offset) as u32
} else {
0
}
})
.collect();

let keys = match keys.validity() {
None => {
let values = keys
.values()
.iter()
.map(|&k| unsafe {
// SAFETY: Arrow invariant that all keys are in range of values
*new_idx_map.get_unchecked(k as usize)
})
.collect();
let validity = Bitmap::from_iter(keys.values().iter().map(|&k| unsafe {
// SAFETY: Arrow invariant that all keys are in range of values
values_validity.get_bit_unchecked(k as usize)
}));

PrimitiveArray::new(UInt32, values, Some(validity))
},
Some(keys_validity) => {
let values = keys
.values()
.iter()
.map(|&k| {
// deal with nulls in keys
let idx = (k as usize).min(num_values);
// SAFETY: Arrow invariant that all keys are in range of values
*unsafe { new_idx_map.get_unchecked(idx) }
})
.collect();
let propagated_validity = Bitmap::from_iter(keys.values().iter().map(|&k| {
// deal with nulls in keys
let idx = (k as usize).min(num_values);
// SAFETY: Arrow invariant that all keys are in range of values
unsafe { values_validity.get_bit_unchecked(idx) }
}));

let validity = &propagated_validity & keys_validity;
PrimitiveArray::new(UInt32, values, Some(validity))
},
};

// Filter only handles binary
let values = values.to_binview();

// Filter out the null values
let values = crate::filter::filter_with_bitmap(&values, values_validity);
let values = values.as_any().downcast_ref::<BinaryViewArray>().unwrap();
let values = unsafe { values.to_utf8view_unchecked() }.clone();

// Explicitly set the values validity to none.
assert_eq!(values.null_count(), 0);
let values = values.with_validity(None);

(keys, values)
}
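
To see what this kernel computes: with values `["a", null, "b"]` and keys `[0, 1, 2]`, the null at values index 1 is pushed into the keys and the values are compacted, giving keys `[0, null, 1]` over values `["a", "b"]`. A standalone sketch of the same remapping over plain Rust types (`Option` stands in for Arrow validity bitmaps; this is not the kernel itself):

```rust
fn propagate(keys: &[Option<u32>], values: &[Option<&str>]) -> (Vec<Option<u32>>, Vec<String>) {
    // Map old value indices to indices in the null-filtered values
    // (nulls get a placeholder of 0, as in the kernel above).
    let mut new_idx = Vec::with_capacity(values.len());
    let mut next = 0u32;
    for v in values {
        new_idx.push(if v.is_some() { next } else { 0 });
        next += u32::from(v.is_some());
    }
    // A key becomes null if it was already null or points at a null value.
    let new_keys = keys
        .iter()
        .map(|k| k.and_then(|k| values[k as usize].map(|_| new_idx[k as usize])))
        .collect();
    // Compact the values by dropping the nulls.
    let new_values = values.iter().flatten().map(|s| s.to_string()).collect();
    (new_keys, new_values)
}

fn main() {
    let (keys, values) = propagate(&[Some(0), Some(1), Some(2)], &[Some("a"), None, Some("b")]);
    assert_eq!(keys, [Some(0), None, Some(1)]);
    assert_eq!(values, ["a", "b"]);
}
```
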
@@ -3,23 +3,10 @@ use std::hash::{BuildHasher, Hash, Hasher};

use arrow::array::*;
use polars_utils::aliases::PlRandomState;
#[cfg(any(feature = "serde-lazy", feature = "serde"))]
use serde::{Deserialize, Serialize};

use crate::datatypes::PlHashMap;
use crate::using_string_cache;

#[derive(Debug, Copy, Clone, PartialEq, Default)]
#[cfg_attr(
any(feature = "serde-lazy", feature = "serde"),
derive(Serialize, Deserialize)
)]
pub enum CategoricalOrdering {
#[default]
Physical,
Lexical,
}

#[derive(Clone)]
pub enum RevMapping {
/// Hashmap: maps the indexes from the global cache/categorical array to indexes in the local Utf8Array
@@ -85,6 +72,7 @@ impl RevMapping {
}

pub fn build_local(categories: Utf8ViewArray) -> Self {
debug_assert_eq!(categories.null_count(), 0);
let hash = Self::build_hash(&categories);
Self::Local(categories, hash)
}
20 changes: 0 additions & 20 deletions crates/polars-core/src/datatypes/any_value.rs
@@ -1791,26 +1791,6 @@ mod test {
))),
DataType::List(DataType::Float64.into()),
),
(
ArrowDataType::Dictionary(IntegerType::UInt32, ArrowDataType::Utf8.into(), false),
DataType::Categorical(None, Default::default()),
),
(
ArrowDataType::Dictionary(
IntegerType::UInt32,
ArrowDataType::LargeUtf8.into(),
false,
),
DataType::Categorical(None, Default::default()),
),
(
ArrowDataType::Dictionary(
IntegerType::UInt64,
ArrowDataType::LargeUtf8.into(),
false,
),
DataType::Categorical(None, Default::default()),
),
];

for (dt_a, dt_p) in dtypes {