Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Restructure sum for better auto-vectorization for floats #4560

Closed
wants to merge 10 commits into from
67 changes: 51 additions & 16 deletions arrow-arith/src/aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -286,42 +286,77 @@ where
}

let data: &[T::Native] = array.values();
// TODO choose lanes based on T::Native. Extract from simd module
const LANES: usize = 16;
let mut chunk_acc = [T::default_value(); LANES];
let mut rem_acc = T::default_value();

match array.nulls() {
None => {
let sum = data.iter().fold(T::default_value(), |accumulator, value| {
accumulator.add_wrapping(*value)
let data_chunks = data.chunks_exact(64);
let remainder = data_chunks.remainder();

data_chunks.for_each(|chunk| {
chunk.chunks_exact(LANES).for_each(|chunk| {
let chunk: [T::Native; LANES] = chunk.try_into().unwrap();

for i in 0..LANES {
chunk_acc[i] = chunk_acc[i].add_wrapping(chunk[i]);
}
})
});

remainder.iter().copied().for_each(|value| {
rem_acc = rem_acc.add_wrapping(value);
});

let mut reduced = T::default_value();
for v in chunk_acc {
reduced = reduced.add_wrapping(v);
}
let sum = reduced.add_wrapping(rem_acc);

Some(sum)
}
Some(nulls) => {
let mut sum = T::default_value();
// process data in chunks of 64 elements since we also get 64 bits of validity information at a time
let data_chunks = data.chunks_exact(64);
let remainder = data_chunks.remainder();

let bit_chunks = nulls.inner().bit_chunks();
data_chunks
.zip(bit_chunks.iter())
.for_each(|(chunk, mask)| {
// index_mask has value 1 << i in the loop
let mut index_mask = 1;
chunk.iter().for_each(|value| {
if (mask & index_mask) != 0 {
sum = sum.add_wrapping(*value);
let remainder_bits = bit_chunks.remainder_bits();

data_chunks.zip(bit_chunks).for_each(|(chunk, mut mask)| {
// Split each 64-element chunk further into slices whose length matches the vector length (LANES).
// The compiler is able to unroll this inner loop and remove bounds checks,
// since the outer chunk size (64) is always a multiple of the number of lanes.
chunk.chunks_exact(LANES).for_each(|chunk| {
let mut chunk: [T::Native; LANES] = chunk.try_into().unwrap();

for i in 0..LANES {
if mask & (1 << i) == 0 {
chunk[i] = T::default_value();
}
index_mask <<= 1;
});
});
chunk_acc[i] = chunk_acc[i].add_wrapping(chunk[i]);
}

let remainder_bits = bit_chunks.remainder_bits();
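// Shift out the LANES validity bits that were just consumed. The `% 64` turns the shift
// into a no-op when LANES is 64 (the lane count intended for the u8 type), where shifting
// the 64-bit mask by 64 would overflow.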
mask >>= LANES % 64;
})
});

remainder.iter().enumerate().for_each(|(i, value)| {
if remainder_bits & (1 << i) != 0 {
sum = sum.add_wrapping(*value);
rem_acc = rem_acc.add_wrapping(*value);
}
});

let mut reduced = T::default_value();
for v in chunk_acc {
reduced = reduced.add_wrapping(v);
}
let sum = reduced.add_wrapping(rem_acc);

Some(sum)
}
}
Expand Down