-
Notifications
You must be signed in to change notification settings - Fork 841
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Restructure sum
for better auto-vectorization for floats
#4560
Changes from all commits
0e41362
4912ae3
51be103
857fdd4
4bc686a
f472f3f
be00492
94f7b18
181c6da
0a03c83
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -285,44 +285,178 @@ where | |
return None; | ||
} | ||
|
||
let data: &[T::Native] = array.values(); | ||
fn sum_impl_integer<T>(array: &PrimitiveArray<T>) -> Option<T::Native> | ||
where | ||
T: ArrowNumericType, | ||
T::Native: ArrowNativeTypeOp, | ||
{ | ||
let data: &[T::Native] = array.values(); | ||
|
||
match array.nulls() { | ||
None => { | ||
let sum = data.iter().fold(T::default_value(), |accumulator, value| { | ||
accumulator.add_wrapping(*value) | ||
}); | ||
match array.nulls() { | ||
None => { | ||
let sum = data.iter().fold(T::default_value(), |accumulator, value| { | ||
accumulator.add_wrapping(*value) | ||
}); | ||
|
||
Some(sum) | ||
Some(sum) | ||
} | ||
Some(nulls) => { | ||
let mut sum = T::default_value(); | ||
let data_chunks = data.chunks_exact(64); | ||
let remainder = data_chunks.remainder(); | ||
|
||
let bit_chunks = nulls.inner().bit_chunks(); | ||
data_chunks | ||
.zip(bit_chunks.iter()) | ||
.for_each(|(chunk, mask)| { | ||
// index_mask has value 1 << i in the loop | ||
let mut index_mask = 1; | ||
chunk.iter().for_each(|value| { | ||
if (mask & index_mask) != 0 { | ||
sum = sum.add_wrapping(*value); | ||
} | ||
index_mask <<= 1; | ||
}); | ||
}); | ||
|
||
let remainder_bits = bit_chunks.remainder_bits(); | ||
|
||
remainder.iter().enumerate().for_each(|(i, value)| { | ||
if remainder_bits & (1 << i) != 0 { | ||
sum = sum.add_wrapping(*value); | ||
} | ||
}); | ||
|
||
Some(sum) | ||
} | ||
} | ||
Some(nulls) => { | ||
let mut sum = T::default_value(); | ||
let data_chunks = data.chunks_exact(64); | ||
let remainder = data_chunks.remainder(); | ||
|
||
let bit_chunks = nulls.inner().bit_chunks(); | ||
data_chunks | ||
.zip(bit_chunks.iter()) | ||
.for_each(|(chunk, mask)| { | ||
// index_mask has value 1 << i in the loop | ||
let mut index_mask = 1; | ||
chunk.iter().for_each(|value| { | ||
if (mask & index_mask) != 0 { | ||
sum = sum.add_wrapping(*value); | ||
} | ||
|
||
fn sum_impl_floating<T, const LANES: usize>( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same comment as above |
||
array: &PrimitiveArray<T>, | ||
) -> Option<T::Native> | ||
where | ||
T: ArrowNumericType, | ||
T::Native: ArrowNativeTypeOp, | ||
{ | ||
let data: &[T::Native] = array.values(); | ||
let mut chunk_acc = [T::default_value(); LANES]; | ||
let mut rem_acc = T::default_value(); | ||
|
||
match array.nulls() { | ||
None => { | ||
let data_chunks = data.chunks_exact(LANES); | ||
let remainder = data_chunks.remainder(); | ||
|
||
data_chunks.for_each(|chunk| { | ||
let chunk: [T::Native; LANES] = chunk.try_into().unwrap(); | ||
|
||
for i in 0..LANES { | ||
chunk_acc[i] = chunk_acc[i].add_wrapping(chunk[i]); | ||
} | ||
}); | ||
|
||
remainder.iter().copied().for_each(|value| { | ||
rem_acc = rem_acc.add_wrapping(value); | ||
}); | ||
|
||
let mut reduced = T::default_value(); | ||
for v in chunk_acc { | ||
reduced = reduced.add_wrapping(v); | ||
} | ||
let sum = reduced.add_wrapping(rem_acc); | ||
|
||
Some(sum) | ||
} | ||
Some(nulls) => { | ||
// process data in chunks of 64 elements since we also get 64 bits of validity information at a time | ||
let data_chunks = data.chunks_exact(64); | ||
let remainder = data_chunks.remainder(); | ||
|
||
let bit_chunks = nulls.inner().bit_chunks(); | ||
let remainder_bits = bit_chunks.remainder_bits(); | ||
|
||
data_chunks.zip(bit_chunks).for_each(|(chunk, mut mask)| { | ||
// split chunks further into slices corresponding to the vector length | ||
// the compiler is able to unroll this inner loop and remove bounds checks | ||
// since the outer chunk size (64) is always a multiple of the number of lanes | ||
chunk.chunks_exact(LANES).for_each(|chunk| { | ||
let mut chunk: [T::Native; LANES] = chunk.try_into().unwrap(); | ||
|
||
for i in 0..LANES { | ||
if mask & (1 << i) == 0 { | ||
chunk[i] = T::default_value(); | ||
} | ||
chunk_acc[i] = chunk_acc[i].add_wrapping(chunk[i]); | ||
} | ||
index_mask <<= 1; | ||
}); | ||
|
||
mask >>= LANES; | ||
}) | ||
}); | ||
|
||
let remainder_bits = bit_chunks.remainder_bits(); | ||
remainder.iter().enumerate().for_each(|(i, value)| { | ||
if remainder_bits & (1 << i) != 0 { | ||
rem_acc = rem_acc.add_wrapping(*value); | ||
} | ||
}); | ||
|
||
remainder.iter().enumerate().for_each(|(i, value)| { | ||
if remainder_bits & (1 << i) != 0 { | ||
sum = sum.add_wrapping(*value); | ||
let mut reduced = T::default_value(); | ||
for v in chunk_acc { | ||
reduced = reduced.add_wrapping(v); | ||
} | ||
}); | ||
let sum = reduced.add_wrapping(rem_acc); | ||
|
||
Some(sum) | ||
Some(sum) | ||
} | ||
} | ||
} | ||
|
||
match T::DATA_TYPE { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This match block is kind of grim, but I don't have a better solution off the top of my head... Perhaps some sort of trait 🤔 |
||
DataType::Timestamp(_, _) | ||
| DataType::Time32(_) | ||
| DataType::Time64(_) | ||
| DataType::Date32 | ||
| DataType::Date64 | ||
| DataType::Duration(_) | ||
| DataType::Interval(_) | ||
| DataType::Int8 | ||
| DataType::Int16 | ||
| DataType::Int32 | ||
| DataType::Int64 | ||
| DataType::UInt8 | ||
| DataType::UInt16 | ||
| DataType::UInt32 | ||
| DataType::UInt64 => sum_impl_integer(array), | ||
DataType::Float16 | ||
| DataType::Float32 | ||
| DataType::Float64 | ||
| DataType::Decimal128(_, _) | ||
| DataType::Decimal256(_, _) => match T::lanes() { | ||
Comment on lines
+433
to
+434
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why is decimal here? |
||
1 => sum_impl_floating::<T, 1>(array), | ||
2 => sum_impl_floating::<T, 2>(array), | ||
4 => sum_impl_floating::<T, 4>(array), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It occurs to me that we have 3 floating point types, we could just dispatch to sum_impl_floating with the appropriate constant specified, without needing ArrowNumericType? |
||
8 => sum_impl_floating::<T, 8>(array), | ||
16 => sum_impl_floating::<T, 16>(array), | ||
32 => sum_impl_floating::<T, 32>(array), | ||
64 => sum_impl_floating::<T, 64>(array), | ||
unhandled => unreachable!("Unhandled number of lanes: {unhandled}"), | ||
}, | ||
DataType::Null | ||
| DataType::Boolean | ||
| DataType::Binary | ||
| DataType::FixedSizeBinary(_) | ||
| DataType::LargeBinary | ||
| DataType::Utf8 | ||
| DataType::LargeUtf8 | ||
| DataType::List(_) | ||
| DataType::FixedSizeList(_, _) | ||
| DataType::LargeList(_) | ||
| DataType::Struct(_) | ||
| DataType::Union(_, _) | ||
| DataType::Dictionary(_, _) | ||
| DataType::Map(_, _) | ||
| DataType::RunEndEncoded(_, _) => { | ||
unreachable!("Unsupported data type: {:?}", T::DATA_TYPE) | ||
} | ||
} | ||
} | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -113,10 +113,13 @@ where | |
|
||
/// A subtype of primitive type that represents numeric values. | ||
#[cfg(not(feature = "simd"))] | ||
pub trait ArrowNumericType: ArrowPrimitiveType {} | ||
pub trait ArrowNumericType: ArrowPrimitiveType { | ||
/// The number of SIMD lanes available | ||
fn lanes() -> usize; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It feels a little off to define this for all the types, but then only use it for a special case of floats 🤔 |
||
} | ||
|
||
macro_rules! make_numeric_type { | ||
($impl_ty:ty, $native_ty:ty, $simd_ty:ident, $simd_mask_ty:ident) => { | ||
($impl_ty:ty, $native_ty:ty, $simd_ty:ident, $simd_mask_ty:ident, $lanes:expr) => { | ||
#[cfg(feature = "simd")] | ||
impl ArrowNumericType for $impl_ty { | ||
type Simd = $simd_ty; | ||
|
@@ -336,42 +339,52 @@ macro_rules! make_numeric_type { | |
} | ||
|
||
#[cfg(not(feature = "simd"))] | ||
impl ArrowNumericType for $impl_ty {} | ||
impl ArrowNumericType for $impl_ty { | ||
#[inline] | ||
fn lanes() -> usize { | ||
$lanes | ||
} | ||
} | ||
}; | ||
} | ||
|
||
make_numeric_type!(Int8Type, i8, i8x64, m8x64); | ||
make_numeric_type!(Int16Type, i16, i16x32, m16x32); | ||
make_numeric_type!(Int32Type, i32, i32x16, m32x16); | ||
make_numeric_type!(Int64Type, i64, i64x8, m64x8); | ||
make_numeric_type!(UInt8Type, u8, u8x64, m8x64); | ||
make_numeric_type!(UInt16Type, u16, u16x32, m16x32); | ||
make_numeric_type!(UInt32Type, u32, u32x16, m32x16); | ||
make_numeric_type!(UInt64Type, u64, u64x8, m64x8); | ||
make_numeric_type!(Float32Type, f32, f32x16, m32x16); | ||
make_numeric_type!(Float64Type, f64, f64x8, m64x8); | ||
|
||
make_numeric_type!(TimestampSecondType, i64, i64x8, m64x8); | ||
make_numeric_type!(TimestampMillisecondType, i64, i64x8, m64x8); | ||
make_numeric_type!(TimestampMicrosecondType, i64, i64x8, m64x8); | ||
make_numeric_type!(TimestampNanosecondType, i64, i64x8, m64x8); | ||
make_numeric_type!(Date32Type, i32, i32x16, m32x16); | ||
make_numeric_type!(Date64Type, i64, i64x8, m64x8); | ||
make_numeric_type!(Time32SecondType, i32, i32x16, m32x16); | ||
make_numeric_type!(Time32MillisecondType, i32, i32x16, m32x16); | ||
make_numeric_type!(Time64MicrosecondType, i64, i64x8, m64x8); | ||
make_numeric_type!(Time64NanosecondType, i64, i64x8, m64x8); | ||
make_numeric_type!(IntervalYearMonthType, i32, i32x16, m32x16); | ||
make_numeric_type!(IntervalDayTimeType, i64, i64x8, m64x8); | ||
make_numeric_type!(IntervalMonthDayNanoType, i128, i128x4, m128x4); | ||
make_numeric_type!(DurationSecondType, i64, i64x8, m64x8); | ||
make_numeric_type!(DurationMillisecondType, i64, i64x8, m64x8); | ||
make_numeric_type!(DurationMicrosecondType, i64, i64x8, m64x8); | ||
make_numeric_type!(DurationNanosecondType, i64, i64x8, m64x8); | ||
make_numeric_type!(Decimal128Type, i128, i128x4, m128x4); | ||
make_numeric_type!(Int8Type, i8, i8x64, m8x64, 64); | ||
make_numeric_type!(Int16Type, i16, i16x32, m16x32, 32); | ||
make_numeric_type!(Int32Type, i32, i32x16, m32x16, 16); | ||
make_numeric_type!(Int64Type, i64, i64x8, m64x8, 8); | ||
make_numeric_type!(UInt8Type, u8, u8x64, m8x64, 64); | ||
make_numeric_type!(UInt16Type, u16, u16x32, m16x32, 32); | ||
make_numeric_type!(UInt32Type, u32, u32x16, m32x16, 16); | ||
make_numeric_type!(UInt64Type, u64, u64x8, m64x8, 8); | ||
make_numeric_type!(Float32Type, f32, f32x16, m32x16, 16); | ||
make_numeric_type!(Float64Type, f64, f64x8, m64x8, 8); | ||
|
||
make_numeric_type!(TimestampSecondType, i64, i64x8, m64x8, 8); | ||
make_numeric_type!(TimestampMillisecondType, i64, i64x8, m64x8, 8); | ||
make_numeric_type!(TimestampMicrosecondType, i64, i64x8, m64x8, 8); | ||
make_numeric_type!(TimestampNanosecondType, i64, i64x8, m64x8, 8); | ||
make_numeric_type!(Date32Type, i32, i32x16, m32x16, 16); | ||
make_numeric_type!(Date64Type, i64, i64x8, m64x8, 8); | ||
make_numeric_type!(Time32SecondType, i32, i32x16, m32x16, 16); | ||
make_numeric_type!(Time32MillisecondType, i32, i32x16, m32x16, 16); | ||
make_numeric_type!(Time64MicrosecondType, i64, i64x8, m64x8, 8); | ||
make_numeric_type!(Time64NanosecondType, i64, i64x8, m64x8, 8); | ||
make_numeric_type!(IntervalYearMonthType, i32, i32x16, m32x16, 16); | ||
make_numeric_type!(IntervalDayTimeType, i64, i64x8, m64x8, 8); | ||
make_numeric_type!(IntervalMonthDayNanoType, i128, i128x4, m128x4, 4); | ||
make_numeric_type!(DurationSecondType, i64, i64x8, m64x8, 8); | ||
make_numeric_type!(DurationMillisecondType, i64, i64x8, m64x8, 8); | ||
make_numeric_type!(DurationMicrosecondType, i64, i64x8, m64x8, 8); | ||
make_numeric_type!(DurationNanosecondType, i64, i64x8, m64x8, 8); | ||
make_numeric_type!(Decimal128Type, i128, i128x4, m128x4, 4); | ||
|
||
#[cfg(not(feature = "simd"))] | ||
impl ArrowNumericType for Float16Type {} | ||
impl ArrowNumericType for Float16Type { | ||
#[inline] | ||
fn lanes() -> usize { | ||
Float32Type::lanes() | ||
} | ||
} | ||
|
||
#[cfg(feature = "simd")] | ||
impl ArrowNumericType for Float16Type { | ||
|
@@ -467,7 +480,12 @@ impl ArrowNumericType for Float16Type { | |
} | ||
|
||
#[cfg(not(feature = "simd"))] | ||
impl ArrowNumericType for Decimal256Type {} | ||
impl ArrowNumericType for Decimal256Type { | ||
#[inline] | ||
fn lanes() -> usize { | ||
1 | ||
} | ||
} | ||
|
||
#[cfg(feature = "simd")] | ||
impl ArrowNumericType for Decimal256Type { | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
FWIW if you changed the signature to
It would potentially save on codegen, as it would be instantiated per native type not per primitive type