Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize Faiss Query With Filters: Reduce iteration and memory for id filter #1402

Merged
merged 22 commits into from
Mar 4, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
9015089
Optimize Faiss Query With Filters. Reduce iteration copy for docid se…
luyuncheng Jan 19, 2024
103307c
Optimize Faiss Query With Filters. Reduce iteration copy for docid se…
luyuncheng Jan 22, 2024
c2cd334
Using int64_t instead of long type for GetLongArrayElements
luyuncheng Jan 24, 2024
bfbfb55
Add IDSelectorJlongBitmap
luyuncheng Jan 29, 2024
a82e59f
1. Add IDSelectorJlongBitmap and UT for it
luyuncheng Jan 31, 2024
965621f
1. Add IDSelectorJlongBitmap and UT for it
luyuncheng Jan 31, 2024
da85df8
Rebase remote-tracking branch 'origin/main' into Filter
luyuncheng Feb 6, 2024
f5a7f95
Merge remote-tracking branch 'origin/main' into Filter
luyuncheng Feb 6, 2024
263a575
Rebase remote-tracking branch 'origin/main' into Filter
luyuncheng Feb 6, 2024
dcef6c2
tidy
luyuncheng Feb 6, 2024
9ca9c98
Add Changelog
luyuncheng Feb 6, 2024
568972f
fix javadoc tasks
luyuncheng Feb 7, 2024
a2b27ee
fix bwc javadoc
luyuncheng Feb 7, 2024
a48a928
UpdatedFilterIdsSelector
luyuncheng Feb 7, 2024
5d303e7
Merge branch 'main' into Filter
luyuncheng Feb 8, 2024
1aed422
UpdatedFilterIdsSelector
luyuncheng Feb 16, 2024
b8e961c
Merge remote-tracking branch 'origin/main' into Filter
luyuncheng Feb 16, 2024
3970f98
Rebase faiss_wrapper.cpp
luyuncheng Feb 16, 2024
f5ebc1a
UpdatedFilterIdsSelector For description Select different FilterIdsSe…
luyuncheng Feb 26, 2024
3e1aaee
UpdatedFilterIdsSelector For description Select different FilterIdsSe…
luyuncheng Feb 26, 2024
f747ec7
UpdatedFilterIdsSelector as Byte.SIZE
luyuncheng Mar 4, 2024
b042b62
UpdatedFilterIdsSelector For comments
luyuncheng Mar 4, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion jni/include/faiss_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ namespace knn_jni {
//
// Return an array of KNNQueryResults
jobjectArray QueryIndex_WithFilter(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong indexPointerJ,
jfloatArray queryVectorJ, jint kJ, jintArray filterIdsJ);
jfloatArray queryVectorJ, jint kJ, jlongArray filterIdsJ,
jint filterIdsTypeJ);
luyuncheng marked this conversation as resolved.
Show resolved Hide resolved

// Free the index located in memory at indexPointerJ
void Free(jlong indexPointer);
Expand Down
10 changes: 10 additions & 0 deletions jni/include/jni_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ namespace knn_jni {

virtual int GetJavaIntArrayLength(JNIEnv *env, jintArray arrayJ) = 0;

virtual int GetJavaLongArrayLength(JNIEnv *env, jlongArray arrayJ) = 0;

virtual int GetJavaBytesArrayLength(JNIEnv *env, jbyteArray arrayJ) = 0;

virtual int GetJavaFloatArrayLength(JNIEnv *env, jfloatArray arrayJ) = 0;
Expand All @@ -94,6 +96,8 @@ namespace knn_jni {

virtual jint * GetIntArrayElements(JNIEnv *env, jintArray array, jboolean * isCopy) = 0;

virtual jlong * GetLongArrayElements(JNIEnv *env, jlongArray array, jboolean * isCopy) = 0;

virtual jobject GetObjectArrayElement(JNIEnv *env, jobjectArray array, jsize index) = 0;

virtual jobject NewObject(JNIEnv *env, jclass clazz, jmethodID methodId, int id, float distance) = 0;
Expand All @@ -108,6 +112,8 @@ namespace knn_jni {

virtual void ReleaseIntArrayElements(JNIEnv *env, jintArray array, jint *elems, jint mode) = 0;

virtual void ReleaseLongArrayElements(JNIEnv *env, jlongArray array, jlong *elems, jint mode) = 0;

virtual void SetObjectArrayElement(JNIEnv *env, jobjectArray array, jsize index, jobject val) = 0;

virtual void SetByteArrayRegion(JNIEnv *env, jbyteArray array, jsize start, jsize len, const jbyte * buf) = 0;
Expand Down Expand Up @@ -139,20 +145,24 @@ namespace knn_jni {
int GetInnerDimensionOf2dJavaFloatArray(JNIEnv *env, jobjectArray array2dJ);
int GetJavaObjectArrayLength(JNIEnv *env, jobjectArray arrayJ);
int GetJavaIntArrayLength(JNIEnv *env, jintArray arrayJ);
int GetJavaLongArrayLength(JNIEnv *env, jlongArray arrayJ);
int GetJavaBytesArrayLength(JNIEnv *env, jbyteArray arrayJ);
int GetJavaFloatArrayLength(JNIEnv *env, jfloatArray arrayJ);

void DeleteLocalRef(JNIEnv *env, jobject obj);
jbyte * GetByteArrayElements(JNIEnv *env, jbyteArray array, jboolean * isCopy);
jfloat * GetFloatArrayElements(JNIEnv *env, jfloatArray array, jboolean * isCopy);
jint * GetIntArrayElements(JNIEnv *env, jintArray array, jboolean * isCopy);
jlong * GetLongArrayElements(JNIEnv *env, jlongArray array, jboolean * isCopy);

jobject GetObjectArrayElement(JNIEnv *env, jobjectArray array, jsize index);
jobject NewObject(JNIEnv *env, jclass clazz, jmethodID methodId, int id, float distance);
jobjectArray NewObjectArray(JNIEnv *env, jsize len, jclass clazz, jobject init);
jbyteArray NewByteArray(JNIEnv *env, jsize len);
void ReleaseByteArrayElements(JNIEnv *env, jbyteArray array, jbyte *elems, int mode);
void ReleaseFloatArrayElements(JNIEnv *env, jfloatArray array, jfloat *elems, int mode);
void ReleaseIntArrayElements(JNIEnv *env, jintArray array, jint *elems, jint mode);
void ReleaseLongArrayElements(JNIEnv *env, jlongArray array, jlong *elems, jint mode);
void SetObjectArrayElement(JNIEnv *env, jobjectArray array, jsize index, jobject val);
void SetByteArrayRegion(JNIEnv *env, jbyteArray array, jsize start, jsize len, const jbyte * buf);

Expand Down
2 changes: 1 addition & 1 deletion jni/include/org_opensearch_knn_jni_FaissService.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryInd
* Signature: (J[FI[J)[Lorg/opensearch/knn/index/query/KNNQueryResult;
*/
JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryIndexWithFilter
(JNIEnv *, jclass, jlong, jfloatArray, jint, jintArray);
(JNIEnv *, jclass, jlong, jfloatArray, jint, jlongArray, jint);

/*
* Class: org_opensearch_knn_jni_FaissService
Expand Down
56 changes: 12 additions & 44 deletions jni/src/faiss_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

// Defines type of IDSelector
enum FilterIdsSelectorType{
BITMAP, BATCH
BITMAP = 0, BATCH = 1,
};

// Translate space type to faiss metric
Expand Down Expand Up @@ -196,11 +196,11 @@ jlong knn_jni::faiss_wrapper::LoadIndex(knn_jni::JNIUtilInterface * jniUtil, JNI

jobjectArray knn_jni::faiss_wrapper::QueryIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong indexPointerJ,
jfloatArray queryVectorJ, jint kJ) {
return knn_jni::faiss_wrapper::QueryIndex_WithFilter(jniUtil, env, indexPointerJ, queryVectorJ, kJ, nullptr);
return knn_jni::faiss_wrapper::QueryIndex_WithFilter(jniUtil, env, indexPointerJ, queryVectorJ, kJ, nullptr, 0);
}

jobjectArray knn_jni::faiss_wrapper::QueryIndex_WithFilter(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong indexPointerJ,
jfloatArray queryVectorJ, jint kJ, jintArray filterIdsJ) {
jfloatArray queryVectorJ, jint kJ, jlongArray filterIdsJ, jint filterIdsTypeJ) {

if (queryVectorJ == nullptr) {
throw std::runtime_error("Query Vector cannot be null");
Expand All @@ -223,28 +223,16 @@ jobjectArray knn_jni::faiss_wrapper::QueryIndex_WithFilter(knn_jni::JNIUtilInter
omp_set_num_threads(1);
// create the filterSearch params if the filterIdsJ is not a null pointer
if(filterIdsJ != nullptr) {
int *filteredIdsArray = jniUtil->GetIntArrayElements(env, filterIdsJ, nullptr);
int64_t *filteredIdsArray = jniUtil->GetLongArrayElements(env, filterIdsJ, nullptr);
int filterIdsLength = env->GetArrayLength(filterIdsJ);
std::unique_ptr<faiss::IDSelector> idSelector;
FilterIdsSelectorType idSelectorType = getIdSelectorType(filteredIdsArray, filterIdsLength);
// start with empty vectors for 2 different types of empty Selectors. We need define them here to avoid copying of data
// during the returns. We could have used pass by reference, but we choose pointers. Returning reference to local
// vector is also an option which can be efficient than copying during returns but it requires upto date C++ compilers.
// To avoid all those confusions, its better to work with pointers here. Ref: https://cplusplus.com/forum/general/56177/
std::vector<faiss::idx_t> convertedIds;
std::vector<uint8_t> bitmap;
// Choose a selector which suits best
if(idSelectorType == BATCH) {
convertedIds.resize(filterIdsLength);
convertFilterIdsToFaissIdType(filteredIdsArray, filterIdsLength, convertedIds.data());
idSelector.reset(new faiss::IDSelectorBatch(convertedIds.size(), convertedIds.data()));
if(filterIdsTypeJ == BITMAP) {
const int bitsetArraySize = filterIdsLength * 8;
uint8_t *bitmap = reinterpret_cast<uint8_t*>(filteredIdsArray);
heemin32 marked this conversation as resolved.
Show resolved Hide resolved
luyuncheng marked this conversation as resolved.
Show resolved Hide resolved
idSelector.reset(new faiss::IDSelectorBitmap(bitsetArraySize, bitmap));
} else {
int maxIdValue = filteredIdsArray[filterIdsLength - 1];
// >> 3 is equivalent to value / 8
const int bitsetArraySize = (maxIdValue >> 3) + 1;
bitmap.resize(bitsetArraySize, 0);
buildFilterIdsBitMap(filteredIdsArray, filterIdsLength, bitmap.data());
idSelector.reset(new faiss::IDSelectorBitmap(filterIdsLength, bitmap.data()));
faiss::idx_t* batchIndices = reinterpret_cast<faiss::idx_t*>(filteredIdsArray);
heemin32 marked this conversation as resolved.
Show resolved Hide resolved
idSelector.reset(new faiss::IDSelectorBatch(filterIdsLength, batchIndices));
}
faiss::SearchParameters *searchParameters;
faiss::SearchParametersHNSW hnswParams;
Expand All @@ -268,10 +256,10 @@ jobjectArray knn_jni::faiss_wrapper::QueryIndex_WithFilter(knn_jni::JNIUtilInter
indexReader->search(1, rawQueryvector, kJ, dis.data(), ids.data(), searchParameters);
} catch (...) {
jniUtil->ReleaseFloatArrayElements(env, queryVectorJ, rawQueryvector, JNI_ABORT);
jniUtil->ReleaseIntArrayElements(env, filterIdsJ, filteredIdsArray, JNI_ABORT);
jniUtil->ReleaseLongArrayElements(env, filterIdsJ, filteredIdsArray, JNI_ABORT);
throw;
}
jniUtil->ReleaseIntArrayElements(env, filterIdsJ, filteredIdsArray, JNI_ABORT);
jniUtil->ReleaseLongArrayElements(env, filterIdsJ, filteredIdsArray, JNI_ABORT);
} else {
try {
indexReader->search(1, rawQueryvector, kJ, dis.data(), ids.data());
Expand Down Expand Up @@ -469,23 +457,3 @@ FilterIdsSelectorType getIdSelectorType(const int* filterIds, int filterIdsLengt
}
return BITMAP;
}

// Copies filterIdsLength ids from filterIds into convertedFilterIds, converting each
// int to faiss::idx_t on assignment. The output buffer is written at indices
// [0, filterIdsLength), so the caller must provide at least that many slots.
void convertFilterIdsToFaissIdType(const int* filterIds, int filterIdsLength, faiss::idx_t* convertedFilterIds) {
    int position = 0;
    while (position < filterIdsLength) {
        convertedFilterIds[position] = filterIds[position];
        ++position;
    }
}

/**
 * Marks every id listed in filterIds inside bitsetVector, one bit per id.
 *
 * Bit layout (as noted for Faiss IDSelectorBitmap::is_member): an id i is selected
 * iff i / 8 < n and bit number (i % 8) of bitmap[floor(i / 8)] is 1.
 * Bits already set in bitsetVector are preserved; new bits are OR-ed in.
 * The caller must size bitsetVector so that index (id >> 3) is valid for every id.
 */
void buildFilterIdsBitMap(const int* filterIds, int filterIdsLength, uint8_t* bitsetVector) {
    for (int pos = 0; pos < filterIdsLength; ++pos) {
        const int id = filterIds[pos];
        // Shift/mask are used instead of the costlier / and %:
        // id >> 3 selects the byte (id / 8), id & 7 selects the bit within it (id % 8).
        uint8_t& slot = bitsetVector[id >> 3];
        slot |= static_cast<uint8_t>(1u << (id & 7));
    }
}
26 changes: 26 additions & 0 deletions jni/src/jni_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,17 @@ int knn_jni::JNIUtil::GetJavaIntArrayLength(JNIEnv *env, jintArray arrayJ) {
return length;
}

// Returns the element count of the Java long[] referenced by arrayJ.
// Throws std::runtime_error("Array cannot be null") when arrayJ is null.
// After the JNI length lookup, HasExceptionInStack is called with an error message
// (presumably it raises if a Java exception is pending — confirm against its definition).
int knn_jni::JNIUtil::GetJavaLongArrayLength(JNIEnv *env, jlongArray arrayJ) {
    if (nullptr == arrayJ) {
        throw std::runtime_error("Array cannot be null");
    }

    const int elementCount = env->GetArrayLength(arrayJ);
    this->HasExceptionInStack(env, "Unable to get array length");
    return elementCount;
}

int knn_jni::JNIUtil::GetJavaBytesArrayLength(JNIEnv *env, jbyteArray arrayJ) {

if (arrayJ == nullptr) {
Expand Down Expand Up @@ -376,6 +387,17 @@ jint * knn_jni::JNIUtil::GetIntArrayElements(JNIEnv *env, jintArray array, jbool
return intArray;
}

// Obtains the elements of a Java long[] through JNIEnv::GetLongArrayElements.
// On success the returned pointer is handed back to the caller as-is.
// If JNI returns null, HasExceptionInStack is consulted first and then a
// std::runtime_error("Unable to get long array") is thrown, so this never returns null.
jlong * knn_jni::JNIUtil::GetLongArrayElements(JNIEnv *env, jlongArray array, jboolean * isCopy) {
    jlong * const elements = env->GetLongArrayElements(array, isCopy);
    if (elements != nullptr) {
        return elements;
    }

    // Failure path: check the Java exception state before raising the C++ error.
    this->HasExceptionInStack(env, "Unable to get long array");
    throw std::runtime_error("Unable to get long array");
}

jobject knn_jni::JNIUtil::GetObjectArrayElement(JNIEnv *env, jobjectArray array, jsize index) {
jobject object = env->GetObjectArrayElement(array, index);
this->HasExceptionInStack(env, "Unable to get object");
Expand Down Expand Up @@ -424,6 +446,10 @@ void knn_jni::JNIUtil::ReleaseIntArrayElements(JNIEnv *env, jintArray array, jin
env->ReleaseIntArrayElements(array, elems, mode);
}

// Releases elements previously obtained from a Java long[] by forwarding
// array, elems and mode unchanged to JNIEnv::ReleaseLongArrayElements.
// No exception check is performed after the call.
void knn_jni::JNIUtil::ReleaseLongArrayElements(JNIEnv *env, jlongArray array, jlong *elems, jint mode) {
env->ReleaseLongArrayElements(array, elems, mode);
}

void knn_jni::JNIUtil::SetObjectArrayElement(JNIEnv *env, jobjectArray array, jsize index, jobject val) {
env->SetObjectArrayElement(array, index, val);
this->HasExceptionInStack(env, "Unable to set object array element");
Expand Down
4 changes: 2 additions & 2 deletions jni/src/org_opensearch_knn_jni_FaissService.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,10 @@ JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryInd
}

JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryIndexWithFilter
(JNIEnv * env, jclass cls, jlong indexPointerJ, jfloatArray queryVectorJ, jint kJ, jintArray filteredIdsJ) {
(JNIEnv * env, jclass cls, jlong indexPointerJ, jfloatArray queryVectorJ, jint kJ, jlongArray filteredIdsJ, jint filterIdsTypeJ) {

try {
return knn_jni::faiss_wrapper::QueryIndex_WithFilter(&jniUtil, env, indexPointerJ, queryVectorJ, kJ, filteredIdsJ);
return knn_jni::faiss_wrapper::QueryIndex_WithFilter(&jniUtil, env, indexPointerJ, queryVectorJ, kJ, filteredIdsJ, filterIdsTypeJ);
} catch (...) {
jniUtil.CatchCppExceptionAndThrowJava(env);
}
Expand Down
2 changes: 2 additions & 0 deletions src/main/java/org/opensearch/knn/common/KNNConstants.java
Original file line number Diff line number Diff line change
Expand Up @@ -115,4 +115,6 @@ public class KNNConstants {
// Please refer this github issue for more details for choosing this value:
// https://github.com/opensearch-project/k-NN/issues/1049#issuecomment-1694741092
public static int MAX_DISTANCE_COMPUTATIONS = 2048000;
public static int MAX_ID_SELECT_ARRAY = 2048;
luyuncheng marked this conversation as resolved.
Show resolved Hide resolved

}
Loading