diff --git a/src/io/memory_block_io.h b/src/io/memory_block_io.h index 12dfee90..2b8000d7 100644 --- a/src/io/memory_block_io.h +++ b/src/io/memory_block_io.h @@ -36,15 +36,31 @@ namespace vsag { class MemoryBlockIO : public BasicIO { public: explicit MemoryBlockIO(Allocator* allocator, uint64_t block_size = DEFAULT_BLOCK_SIZE) - : block_size_(block_size), allocator_(allocator), blocks_(0, allocator) { + : allocator_(allocator), blocks_(0, allocator) { + block_bit_ = 0; + while (block_size_ > 0) { + block_size_ >>= 1; + block_bit_ += 1; + } + block_bit_ -= 1; + in_block_mask_ = (1ULL << block_bit_) - 1; + block_size_ = in_block_mask_ + 1; } MemoryBlockIO(const JsonType& io_param, const IndexCommonParam& common_param) - : allocator_(common_param.allocator_), blocks_(0, common_param.allocator_) { + : MemoryBlockIO(common_param.allocator_) { if (io_param.contains(BLOCK_IO_BLOCK_SIZE_KEY)) { this->block_size_ = io_param[BLOCK_IO_BLOCK_SIZE_KEY]; // TODO(LHT): trans str to uint64_t } + block_bit_ = 0; + while (block_size_ > 0) { + block_size_ >>= 1; + block_bit_ += 1; + } + block_bit_ -= 1; + in_block_mask_ = (1ULL << block_bit_) - 1; + block_size_ = in_block_mask_ + 1; } ~MemoryBlockIO() override { @@ -83,7 +99,7 @@ class MemoryBlockIO : public BasicIO { private: [[nodiscard]] inline bool check_valid_offset(uint64_t size) const { - return size <= blocks_.size() * block_size_; + return size <= (blocks_.size() << block_bit_); } inline void @@ -91,14 +107,14 @@ class MemoryBlockIO : public BasicIO { [[nodiscard]] inline const uint8_t* get_data_ptr(uint64_t offset) const { - auto block_no = offset / block_size_; - auto block_off = offset % block_size_; + auto block_no = offset >> block_bit_; + auto block_off = offset & in_block_mask_; return blocks_[block_no] + block_off; } [[nodiscard]] inline bool check_in_one_block(uint64_t off1, uint64_t off2) const { - return (off1 / block_size_) == (off2 / block_size_); + return (off1 ^ off2) < block_size_; } private: @@ -109,14 +125,18 @@ class MemoryBlockIO : public BasicIO { Allocator* const allocator_{nullptr}; static const uint64_t DEFAULT_BLOCK_SIZE = 128 * 1024 * 1024; // 128MB + + uint64_t block_bit_ = 27; + + uint64_t in_block_mask_ = (1 << 27) - 1; }; void MemoryBlockIO::WriteImpl(const uint8_t* data, uint64_t size, uint64_t offset) { check_and_realloc(size + offset); uint64_t cur_size = 0; - auto start_no = offset / block_size_; - auto start_off = offset % block_size_; + auto start_no = offset >> block_bit_; + auto start_off = offset & in_block_mask_; auto max_size = block_size_ - start_off; while (cur_size < size) { uint8_t* cur_write = blocks_[start_no] + start_off; @@ -134,8 +154,8 @@ MemoryBlockIO::ReadImpl(uint64_t size, uint64_t offset, uint8_t* data) const { bool ret = check_valid_offset(size + offset); if (ret) { uint64_t cur_size = 0; - auto start_no = offset / block_size_; - auto start_off = offset % block_size_; + auto start_no = offset >> block_bit_; + auto start_off = offset & in_block_mask_; auto max_size = block_size_ - start_off; while (cur_size < size) { const uint8_t* cur_read = blocks_[start_no] + start_off; @@ -189,7 +209,7 @@ MemoryBlockIO::check_and_realloc(uint64_t size) { if (check_valid_offset(size)) { return; } - const uint64_t new_block_count = (size + this->block_size_ - 1) / block_size_; + const uint64_t new_block_count = (size + this->block_size_ - 1) >> block_bit_; auto cur_block_size = this->blocks_.size(); while (cur_block_size < new_block_count) { this->blocks_.emplace_back((uint8_t*)(allocator_->Allocate(block_size_))); diff --git a/src/simd/avx512.cpp b/src/simd/avx512.cpp index e8c4e2b1..8f6ff24c 100644 --- a/src/simd/avx512.cpp +++ b/src/simd/avx512.cpp @@ -16,6 +16,7 @@ #include #include +#include #include "fp32_simd.h" #include "normalize.h" @@ -299,9 +300,10 @@ SQ8ComputeL2Sqr(const float* query, for (; i + 15 < dim; i += 16) { // Load data into registers __m128i code_values = _mm_loadu_si128(reinterpret_cast(codes + i)); + __m512 diff_values = _mm512_loadu_ps(diff + i); + __m512i codes_512 = _mm512_cvtepu8_epi32(code_values); __m512 code_floats = _mm512_div_ps(_mm512_cvtepi32_ps(codes_512), _mm512_set1_ps(255.0f)); - __m512 diff_values = _mm512_loadu_ps(diff + i); __m512 lowerBound_values = _mm512_loadu_ps(lowerBound + i); __m512 query_values = _mm512_loadu_ps(query + i); @@ -373,19 +375,14 @@ SQ8ComputeCodesL2Sqr(const uint8_t* codes1, for (; i + 15 < dim; i += 16) { __m128i code1_values = _mm_loadu_si128(reinterpret_cast(codes1 + i)); __m128i code2_values = _mm_loadu_si128(reinterpret_cast(codes2 + i)); - __m512i codes1_512 = _mm512_cvtepu8_epi32(code1_values); - __m512i codes2_512 = _mm512_cvtepu8_epi32(code2_values); - __m512 code1_floats = _mm512_div_ps(_mm512_cvtepi32_ps(codes1_512), _mm512_set1_ps(255.0f)); - __m512 code2_floats = _mm512_div_ps(_mm512_cvtepi32_ps(codes2_512), _mm512_set1_ps(255.0f)); __m512 diff_values = _mm512_loadu_ps(diff + i); - __m512 lowerBound_values = _mm512_loadu_ps(lowerBound + i); - // Perform calculations - __m512 scaled_codes1 = _mm512_fmadd_ps(code1_floats, diff_values, lowerBound_values); - __m512 scaled_codes2 = _mm512_fmadd_ps(code2_floats, diff_values, lowerBound_values); - __m512 val = _mm512_sub_ps(scaled_codes1, scaled_codes2); - val = _mm512_mul_ps(val, val); - sum = _mm512_add_ps(sum, val); + __m512i codes1_512 = _mm512_cvtepu8_epi32(code1_values); + __m512i codes2_512 = _mm512_cvtepu8_epi32(code2_values); + __m512 sub = _mm512_cvtepi32_ps(_mm512_sub_epi32(codes1_512, codes2_512)); + __m512 scaled = _mm512_mul_ps(sub, _mm512_set1_ps(1.0 / 255.0f)); + __m512 val = _mm512_mul_ps(scaled, diff_values); + sum = _mm512_fmadd_ps(val, val, sum); } // Horizontal addition