Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix duplicate results #2

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions hnswlib/bruteforce.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,10 @@ class BruteforceSearch : public AlgorithmInterface<dist_t> {
cur_element_count--;
}

std::priority_queue<std::pair<dist_t, labeltype >>
searchKnn(const void *query_data, size_t k, const size_t ef, BaseFilterFunctor* isIdAllowed = nullptr) const {
return searchKnn(query_data, k, isIdAllowed);
}

std::priority_queue<std::pair<dist_t, labeltype >>
searchKnn(const void *query_data, size_t k, BaseFilterFunctor* isIdAllowed = nullptr) const {
Expand Down
21 changes: 16 additions & 5 deletions hnswlib/hnswalg.h
Original file line number Diff line number Diff line change
Expand Up @@ -271,9 +271,9 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
_mm_prefetch(getDataByInternalId(candidateSet.top().second), _MM_HINT_T0);
#endif

if (!isMarkedDeleted(candidate_id))
if (!isMarkedDeleted(candidate_id))
top_candidates.emplace(dist1, candidate_id);

if (top_candidates.size() > ef_construction_)
top_candidates.pop();

Expand Down Expand Up @@ -765,6 +765,8 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
lock_table.unlock();

markDeletedInternal(internalId);
lock_table.lock();
label_lookup_.erase(label);
}


Expand Down Expand Up @@ -884,6 +886,10 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
setExternalLabel(internal_id_replaced, label);

std::unique_lock <std::mutex> lock_table(label_lookup_lock);
// check if the label is already in the index
if (label_lookup_.find(label) != label_lookup_.end() && !isMarkedDeleted(label_lookup_[label])) {
markDeletedInternal(label_lookup_[label]);
}
label_lookup_.erase(label_replaced);
label_lookup_[label] = internal_id_replaced;
lock_table.unlock();
Expand Down Expand Up @@ -1173,9 +1179,14 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
return cur_c;
}


std::priority_queue<std::pair<dist_t, labeltype >>
searchKnn(const void *query_data, size_t k, BaseFilterFunctor* isIdAllowed = nullptr) const {
return searchKnn(query_data, k, this->ef_, isIdAllowed);
}


std::priority_queue<std::pair<dist_t, labeltype >>
searchKnn(const void *query_data, size_t k, const size_t ef, BaseFilterFunctor* isIdAllowed = nullptr) const {
std::priority_queue<std::pair<dist_t, labeltype >> result;
if (cur_element_count == 0) return result;

Expand Down Expand Up @@ -1210,10 +1221,10 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates;
if (num_deleted_) {
top_candidates = searchBaseLayerST<true, true>(
currObj, query_data, std::max(ef_, k), isIdAllowed);
currObj, query_data, std::max(ef, k), isIdAllowed);
} else {
top_candidates = searchBaseLayerST<false, true>(
currObj, query_data, std::max(ef_, k), isIdAllowed);
currObj, query_data, std::max(ef, k), isIdAllowed);
}

while (top_candidates.size() > k) {
Expand Down
26 changes: 26 additions & 0 deletions hnswlib/hnswlib.h
Original file line number Diff line number Diff line change
Expand Up @@ -160,18 +160,44 @@ class AlgorithmInterface {
public:
// Pure virtual; presumably adds `datapoint` to the index under `label`.
// The exact semantics (including `replace_deleted`) are defined by the
// implementing class.
virtual void addPoint(const void *datapoint, labeltype label, bool replace_deleted = false) = 0;

// k-nearest-neighbour search taking an explicit per-call `ef_` value in
// addition to `k`. Returns a priority queue of (distance, label) pairs;
// `isIdAllowed` is an optional filter functor (nullptr by default).
virtual std::priority_queue<std::pair<dist_t, labeltype>>
searchKnn(const void*, size_t, const size_t ef_, BaseFilterFunctor* isIdAllowed = nullptr) const = 0;

// k-nearest-neighbour search without a per-call ef value. Returns a
// priority queue of (distance, label) pairs.
virtual std::priority_queue<std::pair<dist_t, labeltype>>
searchKnn(const void*, size_t, BaseFilterFunctor* isIdAllowed = nullptr) const = 0;

// Return the k nearest neighbors in closer-first order
virtual std::vector<std::pair<dist_t, labeltype>>
searchKnnCloserFirst(const void* query_data, size_t k, BaseFilterFunctor* isIdAllowed = nullptr) const;

// Same as above, but forwards an explicit `ef` to searchKnn.
virtual std::vector<std::pair<dist_t, labeltype>>
searchKnnCloserFirst(const void* query_data, size_t k, const size_t ef, BaseFilterFunctor* isIdAllowed = nullptr) const;

// Pure virtual; presumably persists the index to `location` — format is
// defined by the implementing class.
virtual void saveIndex(const std::string &location) = 0;
// Virtual destructor so implementations can be destroyed through a
// pointer to this interface.
virtual ~AlgorithmInterface(){
}
};

// Runs searchKnn with the caller-supplied `ef` and returns the hits as a
// vector ordered closest-first.
template<typename dist_t>
std::vector<std::pair<dist_t, labeltype>>
AlgorithmInterface<dist_t>::searchKnnCloserFirst(const void* query_data, size_t k,
        const size_t ef, BaseFilterFunctor* isIdAllowed) const {
    // searchKnn yields a queue whose top is the furthest hit, so drain it
    // while filling the output from the back towards the front.
    auto furthest_first = searchKnn(query_data, k, ef, isIdAllowed);

    std::vector<std::pair<dist_t, labeltype>> closest_first(furthest_first.size());
    for (auto slot = closest_first.rbegin(); slot != closest_first.rend(); ++slot) {
        *slot = furthest_first.top();
        furthest_first.pop();
    }

    return closest_first;
}

template<typename dist_t>
std::vector<std::pair<dist_t, labeltype>>
AlgorithmInterface<dist_t>::searchKnnCloserFirst(const void* query_data, size_t k,
Expand Down
Loading