Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf: Add a VarInt encoding for the row encoding #19929

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions crates/polars-arrow/src/datatypes/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,14 @@ impl ArrowDataType {
)
}

pub fn is_integer(&self) -> bool {
use ArrowDataType as D;
matches!(
self,
D::Int8 | D::Int16 | D::Int32 | D::Int64 | D::UInt8 | D::UInt16 | D::UInt32 | D::UInt64
)
}

pub fn to_fixed_size_list(self, size: usize, is_nullable: bool) -> ArrowDataType {
ArrowDataType::FixedSizeList(
Box::new(Field::new(
Expand Down
1 change: 1 addition & 0 deletions crates/polars-core/src/chunked_array/ops/row_encode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ pub fn _get_rows_encoded(
descending: *desc,
nulls_last: *null_last,
no_order: false,
enable_varint: false,
};
cols.push(arr);
fields.push(sort_field);
Expand Down
5 changes: 3 additions & 2 deletions crates/polars-python/src/dataframe/general.rs
Original file line number Diff line number Diff line change
Expand Up @@ -716,7 +716,7 @@ impl PyDataFrame {
fn _row_encode<'py>(
&'py self,
py: Python<'py>,
fields: Vec<(bool, bool, bool)>,
fields: Vec<(bool, bool, bool, bool)>,
) -> PyResult<PySeries> {
py.allow_threads(|| {
let mut df = self.df.clone();
Expand All @@ -732,10 +732,11 @@ impl PyDataFrame {
let fields = fields
.into_iter()
.map(
|(descending, nulls_last, no_order)| polars_row::EncodingField {
|(descending, nulls_last, no_order, enable_varint)| polars_row::EncodingField {
descending,
nulls_last,
no_order,
enable_varint,
},
)
.collect::<Vec<_>>();
Expand Down
5 changes: 3 additions & 2 deletions crates/polars-python/src/series/general.rs
Original file line number Diff line number Diff line change
Expand Up @@ -541,18 +541,19 @@ impl PySeries {
&'py self,
py: Python<'py>,
dtypes: Vec<(String, Wrap<DataType>)>,
fields: Vec<(bool, bool, bool)>,
fields: Vec<(bool, bool, bool, bool)>,
) -> PyResult<PyDataFrame> {
py.allow_threads(|| {
assert_eq!(dtypes.len(), fields.len());

let fields = fields
.into_iter()
.map(
|(descending, nulls_last, no_order)| polars_row::EncodingField {
|(descending, nulls_last, no_order, enable_varint)| polars_row::EncodingField {
descending,
nulls_last,
no_order,
enable_varint,
},
)
.collect::<Vec<_>>();
Expand Down
27 changes: 21 additions & 6 deletions crates/polars-row/src/decode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,13 @@ unsafe fn decode_validity(rows: &mut [&[u8]], field: &EncodingField) -> Option<B

// We inline this in an attempt to avoid the dispatch cost.
#[inline(always)]
fn dtype_and_data_to_encoded_item_len(
unsafe fn dtype_and_data_to_encoded_item_len(
dtype: &ArrowDataType,
data: &[u8],
field: &EncodingField,
) -> usize {
// Fast path: if the size is fixed, we can just divide.
if let Some(size) = fixed_size(dtype) {
if let Some(size) = fixed_size(dtype, field) {
return size;
}

Expand Down Expand Up @@ -129,6 +129,16 @@ fn dtype_and_data_to_encoded_item_len(
item_len
},

D::Int8 => <i8 as crate::variable::varint::VarIntEncoding>::len_from_buffer(data, field),
D::Int16 => <i16 as crate::variable::varint::VarIntEncoding>::len_from_buffer(data, field),
D::Int32 => <i32 as crate::variable::varint::VarIntEncoding>::len_from_buffer(data, field),
D::Int64 => <i64 as crate::variable::varint::VarIntEncoding>::len_from_buffer(data, field),

D::UInt8 => <u8 as crate::variable::varint::VarIntEncoding>::len_from_buffer(data, field),
D::UInt16 => <u16 as crate::variable::varint::VarIntEncoding>::len_from_buffer(data, field),
D::UInt32 => <u32 as crate::variable::varint::VarIntEncoding>::len_from_buffer(data, field),
D::UInt64 => <u64 as crate::variable::varint::VarIntEncoding>::len_from_buffer(data, field),

D::Union(_, _, _) => todo!(),
D::Map(_, _) => todo!(),
D::Dictionary(_, _, _) => todo!(),
Expand All @@ -152,7 +162,7 @@ fn rows_for_fixed_size_list<'a>(
nested_rows.reserve(rows.len() * width);

// Fast path: if the size is fixed, we can just divide.
if let Some(size) = fixed_size(dtype) {
if let Some(size) = fixed_size(dtype, field) {
for row in rows.iter_mut() {
for i in 0..width {
nested_rows.push(&row[(i * size)..][..size]);
Expand Down Expand Up @@ -204,7 +214,7 @@ fn rows_for_fixed_size_list<'a>(
// @TODO: This is quite slow since we need to dispatch for possibly every nested type
for row in rows.iter_mut() {
for _ in 0..width {
let length = dtype_and_data_to_encoded_item_len(dtype, row, field);
let length = unsafe { dtype_and_data_to_encoded_item_len(dtype, row, field) };
let v;
(v, *row) = row.split_at(length);
nested_rows.push(v);
Expand All @@ -223,7 +233,7 @@ fn offsets_from_dtype_and_data(
offsets.clear();

// Fast path: if the size is fixed, we can just divide.
if let Some(size) = fixed_size(dtype) {
if let Some(size) = fixed_size(dtype, field) {
assert!(size == 0 || data.len() % size == 0);
offsets.extend((0..data.len() / size).map(|i| i * size));
return;
Expand Down Expand Up @@ -267,7 +277,7 @@ fn offsets_from_dtype_and_data(
let mut data = data;
let mut offset = 0;
while !data.is_empty() {
let length = dtype_and_data_to_encoded_item_len(dtype, data, field);
let length = unsafe { dtype_and_data_to_encoded_item_len(dtype, data, field) };
offsets.push(offset);
data = &data[length..];
offset += length;
Expand Down Expand Up @@ -362,6 +372,11 @@ unsafe fn decode(rows: &mut [&[u8]], field: &EncodingField, dtype: &ArrowDataTyp
)
.to_boxed()
},
dt if field.enable_varint && dt.is_integer() => {
with_match_arrow_integer_type!(dt, |$T| {
crate::variable::varint::decode::<$T>(rows, field).to_boxed()
})
},
dt => {
with_match_arrow_primitive_type!(dt, |$T| {
decode_primitive::<$T>(rows, field).to_boxed()
Expand Down
130 changes: 102 additions & 28 deletions crates/polars-row/src/encode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@ use arrow::array::{
};
use arrow::bitmap::Bitmap;
use arrow::datatypes::ArrowDataType;
use arrow::trusted_len::TrustMyLength;
use arrow::types::{NativeType, Offset};

use crate::fixed::{get_null_sentinel, FixedLengthEncoding};
use crate::row::{EncodingField, RowsEncoded};
use crate::{with_match_arrow_primitive_type, ArrayRef};
use crate::variable::varint::VarIntEncoding;
use crate::{with_match_arrow_integer_type, with_match_arrow_primitive_type, ArrayRef};

pub fn convert_columns(
num_rows: usize,
Expand Down Expand Up @@ -323,13 +325,33 @@ fn biniter_num_column_bytes(
}
}

fn varint_get_encoder<T: NativeType + VarIntEncoding>(
array: &dyn Array,
row_widths: &mut RowWidths,
) -> Encoder {
let dc_array = array.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
let validity = dc_array.validity();

let widths = if validity.is_none() {
row_widths.append_iter(dc_array.values_iter().map(|v| T::len_from_item(Some(*v))))
} else {
row_widths.append_iter(dc_array.iter().map(|v| T::len_from_item(v.copied())))
};

Encoder {
widths,
array: array.to_boxed(),
state: EncoderState::Stateless,
}
}

/// Get the encoder for a specific array.
fn get_encoder(array: &dyn Array, field: &EncodingField, row_widths: &mut RowWidths) -> Encoder {
use ArrowDataType as D;
let dtype = array.dtype();

// Fast path: column has a fixed size encoding
if let Some(size) = fixed_size(dtype) {
if let Some(size) = fixed_size(dtype, field) {
row_widths.push_constant(size);
let state = match dtype {
D::FixedSizeList(_, width) => {
Expand Down Expand Up @@ -412,6 +434,15 @@ fn get_encoder(array: &dyn Array, field: &EncodingField, row_widths: &mut RowWid
}
},

D::Int8 => varint_get_encoder::<i8>(array, row_widths),
D::Int16 => varint_get_encoder::<i16>(array, row_widths),
D::Int32 => varint_get_encoder::<i32>(array, row_widths),
D::Int64 => varint_get_encoder::<i64>(array, row_widths),
D::UInt8 => varint_get_encoder::<u8>(array, row_widths),
D::UInt16 => varint_get_encoder::<u16>(array, row_widths),
D::UInt32 => varint_get_encoder::<u32>(array, row_widths),
D::UInt64 => varint_get_encoder::<u64>(array, row_widths),

D::List(_) => list_num_column_bytes::<i32>(array, field, row_widths),
D::LargeList(_) => list_num_column_bytes::<i64>(array, field, row_widths),

Expand Down Expand Up @@ -493,12 +524,34 @@ fn get_encoder(array: &dyn Array, field: &EncodingField, row_widths: &mut RowWid
.as_any()
.downcast_ref::<DictionaryArray<u32>>()
.unwrap();
let iter = dc_array
.iter_typed::<Utf8ViewArray>()
.unwrap()
.map(|opt_s| opt_s.map_or(0, |s| s.len()));
// @TODO: Do a better job here. This is just plainly incorrect.
biniter_num_column_bytes(array, iter, dc_array.validity(), field, row_widths)

if field.no_order {
let widths =
if dc_array.values().len() < 64 {
let mut widths = RowWidths::new(row_widths.num_rows());
row_widths.push_constant(1);
widths.push_constant(1);
widths
} else {
row_widths.append_iter(unsafe {
TrustMyLength::new(dc_array.keys_iter().map(|k| {
<usize as crate::variable::varint::VarIntEncoding>::len_from_item(k)
}), dc_array.len())
})
};

Encoder {
widths,
array: array.to_boxed(),
state: EncoderState::Stateless,
}
} else {
let iter = dc_array
.iter_typed::<Utf8ViewArray>()
.unwrap()
.map(|opt_s| opt_s.map_or(0, |s| s.len()));
biniter_num_column_bytes(array, iter, dc_array.validity(), field, row_widths)
}
},
D::Union(_, _, _) => todo!(),
D::Map(_, _) => todo!(),
Expand Down Expand Up @@ -555,10 +608,18 @@ unsafe fn encode_flat_array(
let array = array.as_any().downcast_ref::<BooleanArray>().unwrap();
crate::fixed::encode_iter(buffer, array.iter(), field, offsets);
},
dt if dt.is_numeric() => with_match_arrow_primitive_type!(dt, |$T| {
let array = array.as_any().downcast_ref::<PrimitiveArray<$T>>().unwrap();
encode_primitive(buffer, array, field, offsets);
}),
dt if dt.is_integer() && field.enable_varint => {
with_match_arrow_integer_type!(dt, |$T| {
let array = array.as_any().downcast_ref::<PrimitiveArray<$T>>().unwrap();
crate::variable::varint::encode_iter(buffer, array.iter().map(|v| v.copied()), field, offsets);
})
},
dt if dt.is_numeric() => {
with_match_arrow_primitive_type!(dt, |$T| {
let array = array.as_any().downcast_ref::<PrimitiveArray<$T>>().unwrap();
encode_primitive(buffer, array, field, offsets);
})
},

D::Binary => {
let array = array.as_any().downcast_ref::<BinaryArray<i32>>().unwrap();
Expand Down Expand Up @@ -604,11 +665,17 @@ unsafe fn encode_flat_array(
.as_any()
.downcast_ref::<DictionaryArray<u32>>()
.unwrap();
let iter = dc_array
.iter_typed::<Utf8ViewArray>()
.unwrap()
.map(|opt_s| opt_s.map(|s| s.as_bytes()));
crate::variable::encode_iter(buffer, iter, field, offsets);

if field.no_order {
let iter = unsafe { TrustMyLength::new(dc_array.keys_iter(), dc_array.len()) };
crate::variable::varint::encode_iter(buffer, iter, field, offsets);
} else {
let iter = dc_array
.iter_typed::<Utf8ViewArray>()
.unwrap()
.map(|opt_s| opt_s.map(|s| s.as_bytes()));
crate::variable::encode_iter(buffer, iter, field, offsets);
}
},

D::FixedSizeBinary(_) => todo!(),
Expand Down Expand Up @@ -801,26 +868,33 @@ unsafe fn encode_primitive<T: NativeType + FixedLengthEncoding>(
}
}

pub fn fixed_size(dtype: &ArrowDataType) -> Option<usize> {
pub fn fixed_size(dtype: &ArrowDataType, field: &EncodingField) -> Option<usize> {
use ArrowDataType::*;

if !field.enable_varint {
match dtype {
UInt8 => return Some(u8::ENCODED_LEN),
UInt16 => return Some(u16::ENCODED_LEN),
UInt32 => return Some(u32::ENCODED_LEN),
UInt64 => return Some(u64::ENCODED_LEN),
Int8 => return Some(i8::ENCODED_LEN),
Int16 => return Some(i16::ENCODED_LEN),
Int32 => return Some(i32::ENCODED_LEN),
Int64 => return Some(i64::ENCODED_LEN),
_ => {},
}
}

Some(match dtype {
UInt8 => u8::ENCODED_LEN,
UInt16 => u16::ENCODED_LEN,
UInt32 => u32::ENCODED_LEN,
UInt64 => u64::ENCODED_LEN,
Int8 => i8::ENCODED_LEN,
Int16 => i16::ENCODED_LEN,
Int32 => i32::ENCODED_LEN,
Int64 => i64::ENCODED_LEN,
Decimal(_, _) => i128::ENCODED_LEN,
Float32 => f32::ENCODED_LEN,
Float64 => f64::ENCODED_LEN,
Boolean => bool::ENCODED_LEN,
FixedSizeList(f, width) => 1 + width * fixed_size(f.dtype())?,
FixedSizeList(f, width) => 1 + width * fixed_size(f.dtype(), field)?,
Struct(fs) => {
let mut sum = 0;
for f in fs {
sum += fixed_size(f.dtype())?;
sum += fixed_size(f.dtype(), field)?;
}
1 + sum
},
Expand Down
4 changes: 4 additions & 0 deletions crates/polars-row/src/row.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ pub struct EncodingField {
/// Ignore all order-related flags and don't encode order-preserving.
/// This is faster for variable encoding as we can just memcopy all the bytes.
pub no_order: bool,

/// Enable the variable integer encoding. This compresses large integers into smaller integers.
pub enable_varint: bool,
}

impl EncodingField {
Expand All @@ -21,6 +24,7 @@ impl EncodingField {
descending,
nulls_last,
no_order: false,
enable_varint: false,
}
}

Expand Down
19 changes: 19 additions & 0 deletions crates/polars-row/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,22 @@ macro_rules! with_match_arrow_primitive_type {(
_ => unreachable!(),
}
})}

#[macro_export]
macro_rules! with_match_arrow_integer_type {(
$key_type:expr, | $_:tt $T:ident | $($body:tt)*
) => ({
macro_rules! __with_ty__ {( $_ $T:ident ) => ( $($body)* )}
use arrow::datatypes::ArrowDataType::*;
match $key_type {
Int8 => __with_ty__! { i8 },
Int16 => __with_ty__! { i16 },
Int32 => __with_ty__! { i32 },
Int64 => __with_ty__! { i64 },
UInt8 => __with_ty__! { u8 },
UInt16 => __with_ty__! { u16 },
UInt32 => __with_ty__! { u32 },
UInt64 => __with_ty__! { u64 },
_ => unreachable!(),
}
})}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
//! - `0xFF_u8` if this is not the last block for this string
//! - otherwise the length of the block as a `u8`

pub(crate) mod varint;

use std::mem::MaybeUninit;

use arrow::array::{BinaryArray, BinaryViewArray, MutableBinaryViewArray};
Expand Down
Loading