Skip to content

Commit

Permalink
Use ef as a function parameter (#1)
Browse files Browse the repository at this point in the history
* Use `ef` as a function parameter

* Update searchKnnCloserFirst method signature

* Remove set_ef call from example

* Remove ef parameter check from Index

* Revert all the changes.

* Add overloaded searchKnn method with explicit ef parameter
  • Loading branch information
ozanarmagan authored Jan 30, 2024
1 parent 5100d3f commit 2fec56c
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 3 deletions.
4 changes: 4 additions & 0 deletions hnswlib/bruteforce.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,10 @@ class BruteforceSearch : public AlgorithmInterface<dist_t> {
cur_element_count--;
}

std::priority_queue<std::pair<dist_t, labeltype >>
searchKnn(const void *query_data, size_t k, const size_t ef, BaseFilterFunctor* isIdAllowed = nullptr) const {
return searchKnn(query_data, k, isIdAllowed);
}

std::priority_queue<std::pair<dist_t, labeltype >>
searchKnn(const void *query_data, size_t k, BaseFilterFunctor* isIdAllowed = nullptr) const {
Expand Down
11 changes: 8 additions & 3 deletions hnswlib/hnswalg.h
Original file line number Diff line number Diff line change
Expand Up @@ -1173,9 +1173,14 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
return cur_c;
}


std::priority_queue<std::pair<dist_t, labeltype >>
searchKnn(const void *query_data, size_t k, BaseFilterFunctor* isIdAllowed = nullptr) const {
return searchKnn(query_data, k, this->ef_, isIdAllowed);
}


std::priority_queue<std::pair<dist_t, labeltype >>
searchKnn(const void *query_data, size_t k, const size_t ef, BaseFilterFunctor* isIdAllowed = nullptr) const {
std::priority_queue<std::pair<dist_t, labeltype >> result;
if (cur_element_count == 0) return result;

Expand Down Expand Up @@ -1210,10 +1215,10 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates;
if (num_deleted_) {
top_candidates = searchBaseLayerST<true, true>(
currObj, query_data, std::max(ef_, k), isIdAllowed);
currObj, query_data, std::max(ef, k), isIdAllowed);
} else {
top_candidates = searchBaseLayerST<false, true>(
currObj, query_data, std::max(ef_, k), isIdAllowed);
currObj, query_data, std::max(ef, k), isIdAllowed);
}

while (top_candidates.size() > k) {
Expand Down
26 changes: 26 additions & 0 deletions hnswlib/hnswlib.h
Original file line number Diff line number Diff line change
Expand Up @@ -160,18 +160,44 @@ class AlgorithmInterface {
public:
virtual void addPoint(const void *datapoint, labeltype label, bool replace_deleted = false) = 0;

virtual std::priority_queue<std::pair<dist_t, labeltype>>
searchKnn(const void*, size_t, const size_t ef_, BaseFilterFunctor* isIdAllowed = nullptr) const = 0;

virtual std::priority_queue<std::pair<dist_t, labeltype>>
searchKnn(const void*, size_t, BaseFilterFunctor* isIdAllowed = nullptr) const = 0;

// Return k nearest neighbor in the order of closer fist
virtual std::vector<std::pair<dist_t, labeltype>>
searchKnnCloserFirst(const void* query_data, size_t k, BaseFilterFunctor* isIdAllowed = nullptr) const;

virtual std::vector<std::pair<dist_t, labeltype>>
searchKnnCloserFirst(const void* query_data, size_t k, const size_t ef, BaseFilterFunctor* isIdAllowed = nullptr) const;

virtual void saveIndex(const std::string &location) = 0;
virtual ~AlgorithmInterface(){
}
};

template<typename dist_t>
std::vector<std::pair<dist_t, labeltype>>
AlgorithmInterface<dist_t>::searchKnnCloserFirst(const void* query_data, size_t k,
const size_t ef, BaseFilterFunctor* isIdAllowed) const {
std::vector<std::pair<dist_t, labeltype>> result;

// here searchKnn returns the result in the order of further first
auto ret = searchKnn(query_data, k, ef, isIdAllowed);
{
size_t sz = ret.size();
result.resize(sz);
while (!ret.empty()) {
result[--sz] = ret.top();
ret.pop();
}
}

return result;
}

template<typename dist_t>
std::vector<std::pair<dist_t, labeltype>>
AlgorithmInterface<dist_t>::searchKnnCloserFirst(const void* query_data, size_t k,
Expand Down

0 comments on commit 2fec56c

Please sign in to comment.