Skip to content

Commit

Permalink
Optimize Faiss Query With Filters: Reduce iteration and memory for id…
Browse files Browse the repository at this point in the history
… filter (#1402)

* Optimize Faiss Query With Filters. Reduce iteration copy for docid set iterator

Signed-off-by: luyuncheng <[email protected]>

* Optimize Faiss Query With Filters. Reduce iteration copy for docid set iterator.
Use Bitmap And Batch to do id filter. and you sparse or fixed bitset do exact ANN search

Signed-off-by: luyuncheng <[email protected]>

* Using int64_t instead of long type for GetLongArrayElements

Signed-off-by: luyuncheng <[email protected]>

* Add IDSelectorJlongBitmap

Signed-off-by: luyuncheng <[email protected]>

* 1. Add IDSelectorJlongBitmap and UT for it
2. Move FilterIdsSelectorType to a util class

Signed-off-by: luyuncheng <[email protected]>

* 1. Add IDSelectorJlongBitmap and UT for it
2. Move FilterIdsSelectorType to a util class
3. Spotless apply

Signed-off-by: luyuncheng <[email protected]>

* Rebase remote-tracking branch 'origin/main' into Filter

Signed-off-by: luyuncheng <[email protected]>

* tidy

Signed-off-by: luyuncheng <[email protected]>

* Add Changelog

Signed-off-by: luyuncheng <[email protected]>

* fix javadoc tasks

Signed-off-by: luyuncheng <[email protected]>

* fix bwc javadoc

Signed-off-by: luyuncheng <[email protected]>

* UpdatedFilterIdsSelector

Signed-off-by: luyuncheng <[email protected]>

* UpdatedFilterIdsSelector

Signed-off-by: luyuncheng <[email protected]>

* Rebase faiss_wrapper.cpp

Signed-off-by: luyuncheng <[email protected]>

* UpdatedFilterIdsSelector For description Select different FilterIdsSelectorType

Signed-off-by: luyuncheng <[email protected]>

* UpdatedFilterIdsSelector For description Select different FilterIdsSelectorType

Signed-off-by: luyuncheng <[email protected]>

* UpdatedFilterIdsSelector as Byte.SIZE

Signed-off-by: luyuncheng <[email protected]>

* UpdatedFilterIdsSelector For comments

Signed-off-by: luyuncheng <[email protected]>

---------

Signed-off-by: luyuncheng <[email protected]>
(cherry picked from commit 3eeb855)
  • Loading branch information
luyuncheng authored and github-actions[bot] committed Mar 4, 2024
1 parent 513e496 commit 4ebb39a
Show file tree
Hide file tree
Showing 23 changed files with 401 additions and 159 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
## [Unreleased 2.x](https://github.com/opensearch-project/k-NN/compare/2.12...2.x)
### Features
### Enhancements
* Optize Faiss Query With Filters: Reduce iteration and memory for id filter [#1402](https://github.com/opensearch-project/k-NN/pull/1402)
### Bug Fixes
### Infrastructure
### Documentation
Expand Down
3 changes: 2 additions & 1 deletion jni/include/faiss_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ namespace knn_jni {
//
// Return an array of KNNQueryResults
jobjectArray QueryIndex_WithFilter(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong indexPointerJ,
jfloatArray queryVectorJ, jint kJ, jintArray filterIdsJ, jintArray parentIdsJ);
jfloatArray queryVectorJ, jint kJ, jlongArray filterIdsJ,
jint filterIdsTypeJ, jintArray parentIdsJ);

// Free the index located in memory at indexPointerJ
void Free(jlong indexPointer);
Expand Down
9 changes: 9 additions & 0 deletions jni/include/jni_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ namespace knn_jni {

virtual int GetJavaIntArrayLength(JNIEnv *env, jintArray arrayJ) = 0;

virtual int GetJavaLongArrayLength(JNIEnv *env, jlongArray arrayJ) = 0;

virtual int GetJavaBytesArrayLength(JNIEnv *env, jbyteArray arrayJ) = 0;

virtual int GetJavaFloatArrayLength(JNIEnv *env, jfloatArray arrayJ) = 0;
Expand All @@ -94,6 +96,8 @@ namespace knn_jni {

virtual jint * GetIntArrayElements(JNIEnv *env, jintArray array, jboolean * isCopy) = 0;

virtual jlong * GetLongArrayElements(JNIEnv *env, jlongArray array, jboolean * isCopy) = 0;

virtual jobject GetObjectArrayElement(JNIEnv *env, jobjectArray array, jsize index) = 0;

virtual jobject NewObject(JNIEnv *env, jclass clazz, jmethodID methodId, int id, float distance) = 0;
Expand All @@ -108,6 +112,8 @@ namespace knn_jni {

virtual void ReleaseIntArrayElements(JNIEnv *env, jintArray array, jint *elems, jint mode) = 0;

virtual void ReleaseLongArrayElements(JNIEnv *env, jlongArray array, jlong *elems, jint mode) = 0;

virtual void SetObjectArrayElement(JNIEnv *env, jobjectArray array, jsize index, jobject val) = 0;

virtual void SetByteArrayRegion(JNIEnv *env, jbyteArray array, jsize start, jsize len, const jbyte * buf) = 0;
Expand Down Expand Up @@ -139,20 +145,23 @@ namespace knn_jni {
int GetInnerDimensionOf2dJavaFloatArray(JNIEnv *env, jobjectArray array2dJ);
int GetJavaObjectArrayLength(JNIEnv *env, jobjectArray arrayJ);
int GetJavaIntArrayLength(JNIEnv *env, jintArray arrayJ);
int GetJavaLongArrayLength(JNIEnv *env, jlongArray arrayJ);
int GetJavaBytesArrayLength(JNIEnv *env, jbyteArray arrayJ);
int GetJavaFloatArrayLength(JNIEnv *env, jfloatArray arrayJ);

void DeleteLocalRef(JNIEnv *env, jobject obj);
jbyte * GetByteArrayElements(JNIEnv *env, jbyteArray array, jboolean * isCopy);
jfloat * GetFloatArrayElements(JNIEnv *env, jfloatArray array, jboolean * isCopy);
jint * GetIntArrayElements(JNIEnv *env, jintArray array, jboolean * isCopy);
jlong * GetLongArrayElements(JNIEnv *env, jlongArray array, jboolean * isCopy);
jobject GetObjectArrayElement(JNIEnv *env, jobjectArray array, jsize index);
jobject NewObject(JNIEnv *env, jclass clazz, jmethodID methodId, int id, float distance);
jobjectArray NewObjectArray(JNIEnv *env, jsize len, jclass clazz, jobject init);
jbyteArray NewByteArray(JNIEnv *env, jsize len);
void ReleaseByteArrayElements(JNIEnv *env, jbyteArray array, jbyte *elems, int mode);
void ReleaseFloatArrayElements(JNIEnv *env, jfloatArray array, jfloat *elems, int mode);
void ReleaseIntArrayElements(JNIEnv *env, jintArray array, jint *elems, jint mode);
void ReleaseLongArrayElements(JNIEnv *env, jlongArray array, jlong *elems, jint mode);
void SetObjectArrayElement(JNIEnv *env, jobjectArray array, jsize index, jobject val);
void SetByteArrayRegion(JNIEnv *env, jbyteArray array, jsize start, jsize len, const jbyte * buf);

Expand Down
2 changes: 1 addition & 1 deletion jni/include/org_opensearch_knn_jni_FaissService.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryInd
* Signature: (J[FI[J)[Lorg/opensearch/knn/index/query/KNNQueryResult;
*/
JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryIndexWithFilter
(JNIEnv *, jclass, jlong, jfloatArray, jint, jintArray, jintArray);
(JNIEnv *, jclass, jlong, jfloatArray, jint, jlongArray, jint, jintArray);

/*
* Class: org_opensearch_knn_jni_FaissService
Expand Down
120 changes: 35 additions & 85 deletions jni/src/faiss_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,32 @@

// Defines type of IDSelector
enum FilterIdsSelectorType{
BITMAP, BATCH
BITMAP = 0, BATCH = 1,
};
namespace faiss {

// Using jlong to do Bitmap selector, jlong[] equals to lucene FixedBitSet#bits
struct IDSelectorJlongBitmap : IDSelector {
size_t n;
const jlong* bitmap;

/** Construct with a binary mask like Lucene FixedBitSet
*
* @param n size of the bitmap array
* @param bitmap id like Lucene FixedBitSet bits
*/
IDSelectorJlongBitmap(size_t n, const jlong* bitmap) : n(n), bitmap(bitmap) {};
bool is_member(idx_t id) const final {
uint64_t index = id;
uint64_t i = index >> 6; // div 64
if (i >= n ) {
return false;
}
return (bitmap[i] >> ( index & 63)) & 1L;
}
~IDSelectorJlongBitmap() override {}
};
}
// Translate space type to faiss metric
faiss::MetricType TranslateSpaceToMetric(const std::string& spaceType);

Expand All @@ -42,9 +65,6 @@ void SetExtraParameters(knn_jni::JNIUtilInterface * jniUtil, JNIEnv *env,
// Train an index with data provided
void InternalTrainIndex(faiss::Index * index, faiss::idx_t n, const float* x);

// Helps to choose the right FilterIdsSelectorType for Faiss
FilterIdsSelectorType getIdSelectorType(const int* filterIds, int filterIdsLength);

// Converts the int FilterIds to Faiss ids type array.
void convertFilterIdsToFaissIdType(const int* filterIds, int filterIdsLength, faiss::idx_t* convertedFilterIds);

Expand Down Expand Up @@ -199,11 +219,12 @@ jlong knn_jni::faiss_wrapper::LoadIndex(knn_jni::JNIUtilInterface * jniUtil, JNI

jobjectArray knn_jni::faiss_wrapper::QueryIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong indexPointerJ,
jfloatArray queryVectorJ, jint kJ, jintArray parentIdsJ) {
return knn_jni::faiss_wrapper::QueryIndex_WithFilter(jniUtil, env, indexPointerJ, queryVectorJ, kJ, nullptr, parentIdsJ);
return knn_jni::faiss_wrapper::QueryIndex_WithFilter(jniUtil, env, indexPointerJ, queryVectorJ, kJ, nullptr, 0, parentIdsJ);
}

jobjectArray knn_jni::faiss_wrapper::QueryIndex_WithFilter(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong indexPointerJ,
jfloatArray queryVectorJ, jint kJ, jintArray filterIdsJ, jintArray parentIdsJ) {
jfloatArray queryVectorJ, jint kJ, jlongArray filterIdsJ, jint filterIdsTypeJ, jintArray parentIdsJ) {

if (queryVectorJ == nullptr) {
throw std::runtime_error("Query Vector cannot be null");
}
Expand All @@ -225,28 +246,14 @@ jobjectArray knn_jni::faiss_wrapper::QueryIndex_WithFilter(knn_jni::JNIUtilInter
omp_set_num_threads(1);
// create the filterSearch params if the filterIdsJ is not a null pointer
if(filterIdsJ != nullptr) {
int *filteredIdsArray = jniUtil->GetIntArrayElements(env, filterIdsJ, nullptr);
int filterIdsLength = jniUtil->GetJavaIntArrayLength(env, filterIdsJ);
jlong *filteredIdsArray = jniUtil->GetLongArrayElements(env, filterIdsJ, nullptr);
int filterIdsLength = jniUtil->GetJavaLongArrayLength(env, filterIdsJ);
std::unique_ptr<faiss::IDSelector> idSelector;
FilterIdsSelectorType idSelectorType = getIdSelectorType(filteredIdsArray, filterIdsLength);
// start with empty vectors for 2 different types of empty Selectors. We need define them here to avoid copying of data
// during the returns. We could have used pass by reference, but we choose pointers. Returning reference to local
// vector is also an option which can be efficient than copying during returns but it requires upto date C++ compilers.
// To avoid all those confusions, its better to work with pointers here. Ref: https://cplusplus.com/forum/general/56177/
std::vector<faiss::idx_t> convertedIds;
std::vector<uint8_t> bitmap;
// Choose a selector which suits best
if(idSelectorType == BATCH) {
convertedIds.resize(filterIdsLength);
convertFilterIdsToFaissIdType(filteredIdsArray, filterIdsLength, convertedIds.data());
idSelector.reset(new faiss::IDSelectorBatch(convertedIds.size(), convertedIds.data()));
if(filterIdsTypeJ == BITMAP) {
idSelector.reset(new faiss::IDSelectorJlongBitmap(filterIdsLength, filteredIdsArray));
} else {
int maxIdValue = filteredIdsArray[filterIdsLength - 1];
// >> 3 is equivalent to value / 8
const int bitsetArraySize = (maxIdValue >> 3) + 1;
bitmap.resize(bitsetArraySize, 0);
buildFilterIdsBitMap(filteredIdsArray, filterIdsLength, bitmap.data());
idSelector.reset(new faiss::IDSelectorBitmap(bitsetArraySize, bitmap.data()));
faiss::idx_t* batchIndices = reinterpret_cast<faiss::idx_t*>(filteredIdsArray);
idSelector.reset(new faiss::IDSelectorBatch(filterIdsLength, batchIndices));
}
faiss::SearchParameters *searchParameters;
faiss::SearchParametersHNSW hnswParams;
Expand Down Expand Up @@ -276,10 +283,10 @@ jobjectArray knn_jni::faiss_wrapper::QueryIndex_WithFilter(knn_jni::JNIUtilInter
indexReader->search(1, rawQueryvector, kJ, dis.data(), ids.data(), searchParameters);
} catch (...) {
jniUtil->ReleaseFloatArrayElements(env, queryVectorJ, rawQueryvector, JNI_ABORT);
jniUtil->ReleaseIntArrayElements(env, filterIdsJ, filteredIdsArray, JNI_ABORT);
jniUtil->ReleaseLongArrayElements(env, filterIdsJ, filteredIdsArray, JNI_ABORT);
throw;
}
jniUtil->ReleaseIntArrayElements(env, filterIdsJ, filteredIdsArray, JNI_ABORT);
jniUtil->ReleaseLongArrayElements(env, filterIdsJ, filteredIdsArray, JNI_ABORT);
} else {
faiss::SearchParameters *searchParameters = nullptr;
faiss::SearchParametersHNSW hnswParams;
Expand Down Expand Up @@ -454,63 +461,6 @@ void InternalTrainIndex(faiss::Index * index, faiss::idx_t n, const float* x) {
}
}

/**
* This function takes a call on what ID Selector to use:
* https://github.com/facebookresearch/faiss/wiki/Setting-search-parameters-for-one-query#idselectorarray-idselectorbatch-and-idselectorbitmap
*
* class storage lookup construction(Opensearch + Faiss)
* IDSelectorArray O(k) O(k) O(2k)
* IDSelectorBatch O(k) O(1) O(2k)
* IDSelectorBitmap O(n/8) O(1) O(k) -> n is the max value of id in the index
*
* TODO: We need to ideally decide when we can take another hit of K iterations in latency. Some facts:
* an OpenSearch Index can have max segment size as 5GB which, which on a vector with dimension of 128 boils down to
* 7.5M vectors.
* Ref: https://opensearch.org/docs/latest/search-plugins/knn/knn-index/#hnsw-memory-estimation
* M = 16
* Dimension = 128
* (1.1 * ( 4 * 128 + 8 * 16) * 7500000)/(1024*1024*1024) ~ 4.9GB
* Ids are sequential in a Segment which means for IDSelectorBitmap total size if the max ID has value of 7.5M will be
* 7500000/(8*1024) = 915KBs in worst case. But with larger dimensions this worst case value will decrease.
*
* With 915KB how many ids can be represented as an array of 64-bit longs : 117,120 ids
* So iterating on 117k ids for 1 single pass is also time consuming. So, we are currently concluding to consider only size
* as factor. We need to improve on this.
*
* TODO: Best way is to implement a SparseBitSet in C++. This can be done by extending the IDSelector Interface of Faiss.
*
* @param filterIds
* @param filterIdsLength
* @return std::string
*/
FilterIdsSelectorType getIdSelectorType(const int* filterIds, int filterIdsLength) {
int maxIdValue = filterIds[filterIdsLength - 1];
if(filterIdsLength * sizeof(faiss::idx_t) * 8 <= maxIdValue ) {
return BATCH;
}
return BITMAP;
}

void convertFilterIdsToFaissIdType(const int* filterIds, int filterIdsLength, faiss::idx_t* convertedFilterIds) {
for (int i = 0; i < filterIdsLength; i++) {
convertedFilterIds[i] = filterIds[i];
}
}

void buildFilterIdsBitMap(const int* filterIds, int filterIdsLength, uint8_t* bitsetVector) {
/**
* Coming from Faiss IDSelectorBitmap::is_member function bitmap id will be selected
* iff id / 8 < n and bit number (i%8) of bitmap[floor(i / 8)] is 1.
*/
for(int i = 0 ; i < filterIdsLength ; i ++) {
int value = filterIds[i];
// / , % are expensive operation. Hence, using BitShift operation as they are fast.
int bitsetArrayIndex = value >> 3 ; // is equivalent to value / 8
// (value & 7) equivalent to value % 8
bitsetVector[bitsetArrayIndex] = bitsetVector[bitsetArrayIndex] | (1 << (value & 7));
}
}

std::unique_ptr<faiss::IDGrouperBitmap> buildIDGrouperBitmap(knn_jni::JNIUtilInterface * jniUtil, JNIEnv *env, jintArray parentIdsJ, std::vector<uint64_t>* bitmap) {
int *parentIdsArray = jniUtil->GetIntArrayElements(env, parentIdsJ, nullptr);
int parentIdsLength = jniUtil->GetJavaIntArrayLength(env, parentIdsJ);
Expand Down
26 changes: 26 additions & 0 deletions jni/src/jni_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,17 @@ int knn_jni::JNIUtil::GetJavaIntArrayLength(JNIEnv *env, jintArray arrayJ) {
return length;
}

int knn_jni::JNIUtil::GetJavaLongArrayLength(JNIEnv *env, jlongArray arrayJ) {

if (arrayJ == nullptr) {
throw std::runtime_error("Array cannot be null");
}

int length = env->GetArrayLength(arrayJ);
this->HasExceptionInStack(env, "Unable to get array length");
return length;
}

int knn_jni::JNIUtil::GetJavaBytesArrayLength(JNIEnv *env, jbyteArray arrayJ) {

if (arrayJ == nullptr) {
Expand Down Expand Up @@ -376,6 +387,17 @@ jint * knn_jni::JNIUtil::GetIntArrayElements(JNIEnv *env, jintArray array, jbool
return intArray;
}

jlong * knn_jni::JNIUtil::GetLongArrayElements(JNIEnv *env, jlongArray array, jboolean * isCopy) {
// Lets check for error here
jlong * longArray = env->GetLongArrayElements(array, isCopy);
if (longArray == nullptr) {
this->HasExceptionInStack(env, "Unable to get long array");
throw std::runtime_error("Unable to get long array");
}

return longArray;
}

jobject knn_jni::JNIUtil::GetObjectArrayElement(JNIEnv *env, jobjectArray array, jsize index) {
jobject object = env->GetObjectArrayElement(array, index);
this->HasExceptionInStack(env, "Unable to get object");
Expand Down Expand Up @@ -424,6 +446,10 @@ void knn_jni::JNIUtil::ReleaseIntArrayElements(JNIEnv *env, jintArray array, jin
env->ReleaseIntArrayElements(array, elems, mode);
}

void knn_jni::JNIUtil::ReleaseLongArrayElements(JNIEnv *env, jlongArray array, jlong *elems, jint mode) {
env->ReleaseLongArrayElements(array, elems, mode);
}

void knn_jni::JNIUtil::SetObjectArrayElement(JNIEnv *env, jobjectArray array, jsize index, jobject val) {
env->SetObjectArrayElement(array, index, val);
this->HasExceptionInStack(env, "Unable to set object array element");
Expand Down
4 changes: 2 additions & 2 deletions jni/src/org_opensearch_knn_jni_FaissService.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,10 @@ JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryInd
}

JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryIndexWithFilter
(JNIEnv * env, jclass cls, jlong indexPointerJ, jfloatArray queryVectorJ, jint kJ, jintArray filteredIdsJ, jintArray parentIdsJ) {
(JNIEnv * env, jclass cls, jlong indexPointerJ, jfloatArray queryVectorJ, jint kJ, jlongArray filteredIdsJ, jint filterIdsTypeJ, jintArray parentIdsJ) {

try {
return knn_jni::faiss_wrapper::QueryIndex_WithFilter(&jniUtil, env, indexPointerJ, queryVectorJ, kJ, filteredIdsJ, parentIdsJ);
return knn_jni::faiss_wrapper::QueryIndex_WithFilter(&jniUtil, env, indexPointerJ, queryVectorJ, kJ, filteredIdsJ, filterIdsTypeJ, parentIdsJ);
} catch (...) {
jniUtil.CatchCppExceptionAndThrowJava(env);
}
Expand Down
14 changes: 9 additions & 5 deletions jni/tests/faiss_wrapper_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -251,9 +251,13 @@ TEST(FaissQueryIndexWithFilterTest1435, BasicAssertions) {
queries.push_back(query);
}

std::vector<int> filterIds;
int num_bits = test_util::bits2words(164);
std::vector<jlong> bitmap(num_bits,0);
std::vector<int64_t> filterIds;

for (int64_t i = 154; i < 163; i++) {
filterIds.push_back(i);
test_util::setBitSet(i, bitmap.data(), bitmap.size());
}
std::unordered_set<int> filterIdSet(filterIds.begin(), filterIds.end());

Expand All @@ -270,9 +274,9 @@ TEST(FaissQueryIndexWithFilterTest1435, BasicAssertions) {
JNIEnv *jniEnv = nullptr;
NiceMock<test_util::MockJNIUtil> mockJNIUtil;
EXPECT_CALL(mockJNIUtil,
GetJavaIntArrayLength(
jniEnv, reinterpret_cast<jintArray>(&filterIds)))
.WillRepeatedly(Return(filterIds.size()));
GetJavaLongArrayLength(
jniEnv, reinterpret_cast<jlongArray>(&bitmap)))
.WillRepeatedly(Return(bitmap.size()));

int k = 20;
for (auto query : queries) {
Expand All @@ -282,7 +286,7 @@ TEST(FaissQueryIndexWithFilterTest1435, BasicAssertions) {
&mockJNIUtil, jniEnv,
reinterpret_cast<jlong>(&createdIndexWithData),
reinterpret_cast<jfloatArray>(&query), k,
reinterpret_cast<jintArray>(&filterIds), nullptr)));
reinterpret_cast<jlongArray>(&bitmap), 0, nullptr)));

ASSERT_TRUE(results->size() <= filterIds.size());
ASSERT_TRUE(results->size() > 0);
Expand Down
Loading

0 comments on commit 4ebb39a

Please sign in to comment.