diff --git a/velox/common/base/SimdUtil.cpp b/velox/common/base/SimdUtil.cpp index 0aaf5398591e..080081804b7b 100644 --- a/velox/common/base/SimdUtil.cpp +++ b/velox/common/base/SimdUtil.cpp @@ -23,6 +23,7 @@ void gatherBits( const uint64_t* bits, folly::Range indexRange, uint64_t* result) { + constexpr int32_t kStep = xsimd::batch::size; auto size = indexRange.size(); auto indices = indexRange.data(); uint8_t* resultPtr = reinterpret_cast(result); @@ -35,38 +36,17 @@ void gatherBits( *resultPtr = smallResult; return; } - constexpr int32_t kStep = xsimd::batch::size; int32_t i = 0; for (; i + kStep < size; i += kStep) { - if constexpr (kStep == 8) { - *(resultPtr++) = - simd::gather8Bits(bits, xsimd::load_unaligned(indices + i), 8); - } else { - VELOX_DCHECK_EQ(kStep, 4); uint16_t flags = simd::gather8Bits(bits, xsimd::load_unaligned(indices + i), kStep); - if (i % 8 == 0) { - resultPtr[i / 8] = flags; - } else { - resultPtr[i / 8] |= flags << 4; - } - } + bits::storeBitsToByte(flags, resultPtr, i); } auto bitsLeft = size - i; if (bitsLeft > 0) { - if constexpr (kStep == 8) { - *resultPtr = - simd::gather8Bits(bits, xsimd::load_unaligned(indices + i), bitsLeft); - } else { - VELOX_DCHECK_EQ(kStep, 4); uint16_t flags = simd::gather8Bits(bits, xsimd::load_unaligned(indices + i), bitsLeft); - if (i % 8 == 0) { - resultPtr[i / 8] = flags; - } else { - resultPtr[i / 8] |= flags << 4; - } - } + bits::storeBitsToByte(flags, resultPtr, i); } }