diff --git a/src/pq_flash_index.cpp b/src/pq_flash_index.cpp index bfcc65487..775c7ca71 100644 --- a/src/pq_flash_index.cpp +++ b/src/pq_flash_index.cpp @@ -376,6 +376,16 @@ void PQFlashIndex::cache_bfs_levels(uint64_t num_nodes_to_cache, std: //auto this_thread_data = manager.scratch_space(); //IOContext &ctx = this_thread_data->ctx; + // initialize the cache buffer + // Allocate space for neighborhood cache + _nhood_cache_buf = new uint32_t[num_nodes_to_cache * (_max_degree + 1)]; + memset(_nhood_cache_buf, 0, num_nodes_to_cache * (_max_degree + 1)); + + // Allocate space for coordinate cache + size_t coord_cache_buf_len = num_nodes_to_cache * _aligned_dim; + diskann::alloc_aligned((void**)&_coord_cache_buf, coord_cache_buf_len * sizeof(T), 8 * sizeof(T)); + memset(_coord_cache_buf, 0, coord_cache_buf_len * sizeof(T)); + std::unique_ptr> cur_level, prev_level; cur_level = std::make_unique>(); prev_level = std::make_unique>(); @@ -402,7 +412,8 @@ void PQFlashIndex::cache_bfs_levels(uint64_t num_nodes_to_cache, std: uint64_t lvl = 1; uint64_t prev_node_set_size = 0; - while ((node_set.size() + cur_level->size() < num_nodes_to_cache) && cur_level->size() != 0) + uint64_t current_count = 0; + while (cur_level->size() != 0) { // swap prev_level and cur_level std::swap(prev_level, cur_level); @@ -413,10 +424,10 @@ void PQFlashIndex::cache_bfs_levels(uint64_t num_nodes_to_cache, std: for (const uint32_t &id : *prev_level) { - if (node_set.find(id) != node_set.end()) - { - continue; - } + //if (node_set.find(id) != node_set.end()) + //{ + // continue; + //} node_set.insert(id); nodes_to_expand.push_back(id); } @@ -431,20 +442,22 @@ void PQFlashIndex::cache_bfs_levels(uint64_t num_nodes_to_cache, std: uint64_t BLOCK_SIZE = 1024; uint64_t nblocks = DIV_ROUND_UP(nodes_to_expand.size(), BLOCK_SIZE); - for (size_t block = 0; block < nblocks && !finish_flag; block++) + for (size_t block = 0; block < nblocks; block++) { diskann::cout << "." << std::flush; size_t start = block * BLOCK_SIZE; size_t end = (std::min)((block + 1) * BLOCK_SIZE, nodes_to_expand.size()); std::vector nodes_to_read; - std::vector coord_buffers(end - start, nullptr); + std::vector coord_buffers; std::vector> nbr_buffers; for (size_t cur_pt = start; cur_pt < end; cur_pt++) { nodes_to_read.push_back(nodes_to_expand[cur_pt]); - nbr_buffers.emplace_back(0, new uint32_t[_max_degree + 1]); + coord_buffers.push_back(_coord_cache_buf + current_count * _aligned_dim); + nbr_buffers.emplace_back(0, _nhood_cache_buf + current_count * (_max_degree + 1)); + current_count++; } // issue read requests @@ -459,6 +472,9 @@ void PQFlashIndex::cache_bfs_levels(uint64_t num_nodes_to_cache, std: } else { + _coord_cache.insert(std::make_pair(nodes_to_read[i], coord_buffers[i])); + _nhood_cache.insert(std::make_pair(nodes_to_read[i], nbr_buffers[i])); + uint32_t nnbrs = nbr_buffers[i].first; uint32_t *nbrs = nbr_buffers[i].second; @@ -485,18 +501,13 @@ void PQFlashIndex::cache_bfs_levels(uint64_t num_nodes_to_cache, std: lvl++; } - assert(node_set.size() + cur_level->size() == num_nodes_to_cache || cur_level->size() == 0); + assert(node_set.size() == num_nodes_to_cache && cur_level->size() == 0); node_list.clear(); - node_list.reserve(node_set.size() + cur_level->size()); + node_list.reserve(node_set.size()); for (auto node : node_set) node_list.push_back(node); - for (auto node : *cur_level) - node_list.push_back(node); - diskann::cout << "Level: " << lvl << std::flush; - diskann::cout << ". #nodes: " << node_list.size() - prev_node_set_size << ", #nodes thus far: " << node_list.size() - << std::endl; diskann::cout << "done" << std::endl; }