diff --git a/velox/common/base/SimdUtil.cpp b/velox/common/base/SimdUtil.cpp
index f8b4ef432b88..0aaf5398591e 100644
--- a/velox/common/base/SimdUtil.cpp
+++ b/velox/common/base/SimdUtil.cpp
@@ -35,15 +35,38 @@ void gatherBits(
     *resultPtr = smallResult;
     return;
   }
+  constexpr int32_t kStep = xsimd::batch<int32_t>::size;
   int32_t i = 0;
-  for (; i + 8 < size; i += 8) {
-    *(resultPtr++) =
-        simd::gather8Bits(bits, xsimd::load_unaligned(indices + i), 8);
+  for (; i + kStep < size; i += kStep) {
+    if constexpr (kStep == 8) {
+      *(resultPtr++) =
+          simd::gather8Bits(bits, xsimd::load_unaligned(indices + i), 8);
+    } else {
+      VELOX_DCHECK_EQ(kStep, 4);
+      uint16_t flags =
+          simd::gather8Bits(bits, xsimd::load_unaligned(indices + i), kStep);
+      if (i % 8 == 0) {
+        resultPtr[i / 8] = flags;
+      } else {
+        resultPtr[i / 8] |= flags << 4;
+      }
+    }
   }
   auto bitsLeft = size - i;
   if (bitsLeft > 0) {
-    *resultPtr =
-        simd::gather8Bits(bits, xsimd::load_unaligned(indices + i), bitsLeft);
+    if constexpr (kStep == 8) {
+      *resultPtr =
+          simd::gather8Bits(bits, xsimd::load_unaligned(indices + i), bitsLeft);
+    } else {
+      VELOX_DCHECK_EQ(kStep, 4);
+      uint16_t flags =
+          simd::gather8Bits(bits, xsimd::load_unaligned(indices + i), bitsLeft);
+      if (i % 8 == 0) {
+        resultPtr[i / 8] = flags;
+      } else {
+        resultPtr[i / 8] |= flags << 4;
+      }
+    }
   }
 }
 
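
A minimal scalar sketch of the kStep == 4 packing logic above (illustration
only, not part of the patch): gatherBitsScalar is a hypothetical stand-in for
simd::gather8Bits, and kStep is pinned to 4 to model xsimd::batch<int32_t>::size
on a 128-bit ISA. Each 4-wide gather yields 4 flag bits; a step at an even
multiple of 4 writes the low nibble of an output byte, and the following step
ORs in the high nibble, so two gathers fill one byte.

    #include <cassert>
    #include <cstdint>
    #include <cstdio>

    // Hypothetical scalar stand-in for simd::gather8Bits: bit j of the return
    // value is bit indices[j] of the source bit vector `bits`.
    uint8_t gatherBitsScalar(
        const uint64_t* bits,
        const int32_t* indices,
        int count) {
      uint8_t flags = 0;
      for (int j = 0; j < count; ++j) {
        int32_t idx = indices[j];
        if ((bits[idx / 64] >> (idx % 64)) & 1) {
          flags |= uint8_t(1) << j;
        }
      }
      return flags;
    }

    int main() {
      uint64_t bits[1] = {0b101010}; // bits 1, 3, 5 are set
      int32_t indices[8] = {1, 2, 3, 4, 5, 6, 0, 1};
      uint8_t result[1] = {0};
      constexpr int kStep = 4; // models xsimd::batch<int32_t>::size
      for (int i = 0; i < 8; i += kStep) {
        uint8_t flags = gatherBitsScalar(bits, indices + i, kStep);
        if (i % 8 == 0) {
          result[i / 8] = flags;       // first gather -> low nibble
        } else {
          result[i / 8] |= flags << 4; // second gather -> high nibble
        }
      }
      // {1,2,3,4} -> 0b0101 (low nibble), {5,6,0,1} -> 0b1001 (high nibble)
      assert(result[0] == 0b10010101);
      printf("result = 0x%02x\n", result[0]);
      return 0;
    }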