diff --git a/src/include/detail/ivf/qv.h b/src/include/detail/ivf/qv.h index f045813b2..3ea90b4b9 100644 --- a/src/include/detail/ivf/qv.h +++ b/src/include/detail/ivf/qv.h @@ -1040,9 +1040,21 @@ auto query_finite_ram( log_timer _i{tdb_func__ + " in RAM"}; + std::vector> futs; + futs.reserve(nthreads); + size_t part_offset = 0; - while (partitioned_vectors.load()) { - _i.start(); + while (partitioned_vectors.load_tmp()) { + while (!futs.empty()) { + auto min_n = futs.back().get(); + futs.pop_back(); + for (size_t j = 0; j < num_queries; ++j) { + for (auto&& [e, f] : min_n[j]) { + min_scores[j].insert(e, f); + } + } + } + partitioned_vectors.swap(); auto indices = partitioned_vectors.indices(); auto partitioned_ids = partitioned_vectors.ids(); @@ -1050,15 +1062,11 @@ auto query_finite_ram( auto current_part_size = ::num_partitions(partitioned_vectors); size_t parts_per_thread = (current_part_size + nthreads - 1) / nthreads; - std::vector> futs; - futs.reserve(nthreads); - for (size_t n = 0; n < nthreads; ++n) { auto first_part = std::min(n * parts_per_thread, current_part_size); auto last_part = std::min((n + 1) * parts_per_thread, current_part_size); - if (first_part != last_part) { futs.emplace_back(std::async( std::launch::async, @@ -1084,18 +1092,16 @@ auto query_finite_ram( } } - for (size_t n = 0; n < size(futs); ++n) { - auto min_n = futs[n].get(); - - for (size_t j = 0; j < num_queries; ++j) { - for (auto&& [e, f] : min_n[j]) { - min_scores[j].insert(e, f); - } + part_offset += current_part_size; + } + while (!futs.empty()) { + auto min_n = futs.back().get(); + futs.pop_back(); + for (size_t j = 0; j < num_queries; ++j) { + for (auto&& [e, f] : min_n[j]) { + min_scores[j].insert(e, f); } } - - part_offset += current_part_size; - _i.stop(); } return get_top_k_with_scores(min_scores, k_nn); diff --git a/src/include/detail/linalg/partitioned_matrix.h b/src/include/detail/linalg/partitioned_matrix.h index a62ce3ab2..be0971474 100644 --- a/src/include/detail/linalg/partitioned_matrix.h +++ b/src/include/detail/linalg/partitioned_matrix.h @@ -76,6 +76,7 @@ class PartitionedMatrix : public Matrix { using Base::num_rows; public: + using base_type = Base; using value_type = typename Base::value_type; // should be same as T using typename Base::index_type; using typename Base::reference; @@ -217,9 +218,16 @@ class PartitionedMatrix : public Matrix { return part_index_; } + virtual bool load_tmp() { + return false; + } + virtual bool load() { return false; } + + virtual void swap() { + } }; /** diff --git a/src/include/detail/linalg/tdb_partitioned_matrix.h b/src/include/detail/linalg/tdb_partitioned_matrix.h index 5ad644ba0..ed4b26262 100644 --- a/src/include/detail/linalg/tdb_partitioned_matrix.h +++ b/src/include/detail/linalg/tdb_partitioned_matrix.h @@ -143,6 +143,12 @@ class tdbPartitionedMatrix std::unique_ptr partitioned_ids_array_; tiledb::ArraySchema ids_schema_; + std::unique_ptr> temp_ids_; + std::unique_ptr> temp_part_index_; + std::unique_ptr temp_data_; + size_t temp_num_vectors_{0}; + size_t temp_num_parts_{0}; + /***************************************************************************** * Partitioning information ****************************************************************************/ @@ -418,6 +424,11 @@ class tdbPartitionedMatrix std::move(Base{dimensions_, column_capacity_, max_resident_parts_})); this->num_vectors_ = 0; this->num_parts_ = 0; + this->temp_ids_ = std::make_unique>(column_capacity_); + this->temp_part_index_ = + std::make_unique>(max_resident_parts_ + 1); + this->temp_data_ = std::make_unique( + dimensions_, column_capacity_); if (this->part_index_.size() != max_resident_parts_ + 1) { throw std::runtime_error( @@ -433,14 +444,14 @@ class tdbPartitionedMatrix * @todo -- col oriented only for now, should generalize. * */ - bool load() override { + bool load_tmp() override { scoped_timer _{tdb_func__ + " " + partitioned_vectors_uri_}; - if (this->part_index_.size() != max_resident_parts_ + 1) { + if (this->temp_part_index_->size() != max_resident_parts_ + 1) { throw std::runtime_error( "[tdb_partioned_matrix@load] Invalid partitioning, part_index_ " "size " + - std::to_string(this->part_index_.size()) + + std::to_string(this->temp_part_index_->size()) + " != " + std::to_string(max_resident_parts_ + 1)); } @@ -502,11 +513,11 @@ class tdbPartitionedMatrix std::to_string(num_resident_parts) + " resident parts"); } - if (this->part_index_.size() != max_resident_parts_ + 1) { + if (this->temp_part_index_->size() != max_resident_parts_ + 1) { throw std::runtime_error( "[tdb_partioned_matrix@load] Invalid partitioning, part_index_ " "size (" + - std::to_string(this->part_index_.size()) + + std::to_string(this->temp_part_index_->size()) + ") != max_resident_parts_ + 1 (" + std::to_string(max_resident_parts_ + 1) + ")"); } @@ -558,7 +569,7 @@ class tdbPartitionedMatrix // c. Execute the vectors query. tiledb::Query query(ctx_, *(this->partitioned_vectors_array_)); - auto ptr = this->data(); + auto ptr = this->temp_data_->data(); query.set_subarray(subarray) .set_layout(partitioned_vectors_schema_.cell_order()) .set_data_buffer(attr_name, ptr, col_count * dimensions_); @@ -575,7 +586,7 @@ class tdbPartitionedMatrix // d. Execute the IDs query. tiledb::Query ids_query(ctx_, *partitioned_ids_array_); - auto ids_ptr = this->ids_.data(); + auto ids_ptr = this->temp_ids_->data(); ids_query.set_subarray(ids_subarray) .set_data_buffer(ids_attr_name, ids_ptr, col_count); tiledb_helpers::submit_query(tdb_func__, partitioned_ids_uri_, ids_query); @@ -594,21 +605,37 @@ class tdbPartitionedMatrix // Also [first_resident_part, last_resident_part_) auto sub = squashed_indices_[first_resident_part]; for (size_t i = 0; i < num_resident_parts + 1; ++i) { - this->part_index_[i] = squashed_indices_[i + first_resident_part] - sub; + (*this->temp_part_index_)[i] = + squashed_indices_[i + first_resident_part] - sub; } - this->num_vectors_ = num_resident_cols_; - this->num_parts_ = num_resident_parts; + this->temp_num_vectors_ = num_resident_cols_; + this->temp_num_parts_ = num_resident_parts; if (last_resident_part_ == total_num_parts_ && last_resident_col_ == total_max_cols_) { // We have loaded all the data we can, let's close our Array's. close(); } - return true; } + bool load() override { + if (load_tmp()) { + swap(); + return true; + } + return false; + } + + void swap() override { + this->num_vectors_ = this->temp_num_vectors_; + this->num_parts_ = this->temp_num_parts_; + std::swap(this->ids_, *this->temp_ids_); + std::swap(this->part_index_, *this->temp_part_index_); + std::swap(static_cast(*this), *this->temp_data_); + } + /** * Destructor. Closes arrays if they are open. */