Skip to content

Commit

Permalink
Introduce a loading layer in NMSLIB. (#2185)
Browse files Browse the repository at this point in the history
* Introduce a loading layer in NMSLIB.

Signed-off-by: Dooyong Kim <[email protected]>

* Added NMSLIB istream implementation.

Signed-off-by: Dooyong Kim <[email protected]>

* Fix integer overflow issue when passing read size for loading NMSLIB vector index.

Signed-off-by: Dooyong Kim <[email protected]>

* Added unit test for NMSLIB loading layer.

Signed-off-by: Dooyong Kim <[email protected]>

* Made a patch in NMSLIB to avoid frequently calling JNI for better loading index performance.

Signed-off-by: Dooyong Kim <[email protected]>

* Compliance constexpr function in C++11 having nullstatement.

Signed-off-by: Dooyong Kim <[email protected]>

---------

Signed-off-by: Dooyong Kim <[email protected]>
Co-authored-by: Dooyong Kim <[email protected]>
  • Loading branch information
0ctopus13prime and Dooyong Kim authored Oct 15, 2024
1 parent f810493 commit 7cf45c8
Show file tree
Hide file tree
Showing 24 changed files with 1,103 additions and 562 deletions.
3 changes: 2 additions & 1 deletion jni/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,8 @@ if ("${WIN32}" STREQUAL "")
tests/commons_test.cpp
tests/faiss_stream_support_test.cpp
tests/faiss_index_service_test.cpp
)
tests/nmslib_stream_support_test.cpp
)

target_link_libraries(
jni_test
Expand Down
3 changes: 1 addition & 2 deletions jni/cmake/init-nmslib.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,13 @@ if (NOT EXISTS ${NMS_REPO_DIR})
execute_process(COMMAND git submodule update --init -- external/nmslib WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
endif ()


# Apply patches
if(NOT DEFINED APPLY_LIB_PATCHES OR "${APPLY_LIB_PATCHES}" STREQUAL true)
# Define list of patch files
set(PATCH_FILE_LIST)
list(APPEND PATCH_FILE_LIST "${CMAKE_CURRENT_SOURCE_DIR}/patches/nmslib/0001-Initialize-maxlevel-during-add-from-enterpoint-level.patch")
list(APPEND PATCH_FILE_LIST "${CMAKE_CURRENT_SOURCE_DIR}/patches/nmslib/0002-Adds-ability-to-pass-ef-parameter-in-the-query-for-h.patch")
list(APPEND PATCH_FILE_LIST "${CMAKE_CURRENT_SOURCE_DIR}/patches/nmslib/0003-Adding-two-apis-using-stream-to-load-save-in-Hnsw.patch")
list(APPEND PATCH_FILE_LIST "${CMAKE_CURRENT_SOURCE_DIR}/patches/nmslib/0003-Added-streaming-apis-for-vector-index-loading-in-Hnsw.patch")

# Get patch id of the last commit
execute_process(COMMAND sh -c "git --no-pager show HEAD | git patch-id --stable" OUTPUT_VARIABLE PATCH_ID_OUTPUT_FROM_COMMIT WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/external/nmslib)
Expand Down
81 changes: 4 additions & 77 deletions jni/include/faiss_stream_support.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@
* GitHub history for details.
*/

#ifndef OPENSEARCH_KNN_JNI_STREAM_SUPPORT_H
#define OPENSEARCH_KNN_JNI_STREAM_SUPPORT_H
#ifndef OPENSEARCH_KNN_JNI_FAISS_STREAM_SUPPORT_H
#define OPENSEARCH_KNN_JNI_FAISS_STREAM_SUPPORT_H

#include "faiss/impl/io.h"
#include "jni_util.h"
#include "native_engines_stream_support.h"

#include <jni.h>
#include <stdexcept>
Expand All @@ -23,80 +24,6 @@
namespace knn_jni {
namespace stream {

/**
* This class contains Java IndexInputWithBuffer reference and calls its API to copy required bytes into a read buffer.
*/

class NativeEngineIndexInputMediator {
public:
// Expect IndexInputWithBuffer is given as `_indexInput`.
NativeEngineIndexInputMediator(JNIUtilInterface *_jni_interface,
JNIEnv *_env,
jobject _indexInput)
: jni_interface(_jni_interface),
env(_env),
indexInput(_indexInput),
bufferArray((jbyteArray) (_jni_interface->GetObjectField(_env,
_indexInput,
getBufferFieldId(_jni_interface, _env)))),
copyBytesMethod(getCopyBytesMethod(_jni_interface, _env)) {
}

void copyBytes(int64_t nbytes, uint8_t *destination) {
while (nbytes > 0) {
// Call `copyBytes` to read bytes as many as possible.
const auto readBytes =
jni_interface->CallIntMethodLong(env, indexInput, copyBytesMethod, nbytes);

// === Critical Section Start ===

// Get primitive array pointer, no copy is happening in OpenJDK.
auto primitiveArray =
(jbyte *) jni_interface->GetPrimitiveArrayCritical(env, bufferArray, nullptr);

// Copy Java bytes to C++ destination address.
std::memcpy(destination, primitiveArray, readBytes);

// Release the acquired primitive array pointer.
// JNI_ABORT tells JVM to directly free memory without copying back to Java byte[].
// Since we're merely copying data, we don't need to copying back.
jni_interface->ReleasePrimitiveArrayCritical(env, bufferArray, primitiveArray, JNI_ABORT);

// === Critical Section End ===

destination += readBytes;
nbytes -= readBytes;
} // End while
}

private:
static jclass getIndexInputWithBufferClass(JNIUtilInterface *jni_interface, JNIEnv *env) {
static jclass INDEX_INPUT_WITH_BUFFER_CLASS =
jni_interface->FindClassFromJNIEnv(env, "org/opensearch/knn/index/store/IndexInputWithBuffer");
return INDEX_INPUT_WITH_BUFFER_CLASS;
}

static jmethodID getCopyBytesMethod(JNIUtilInterface *jni_interface, JNIEnv *env) {
static jmethodID COPY_METHOD_ID =
jni_interface->GetMethodID(env, getIndexInputWithBufferClass(jni_interface, env), "copyBytes", "(J)I");
return COPY_METHOD_ID;
}

static jfieldID getBufferFieldId(JNIUtilInterface *jni_interface, JNIEnv *env) {
static jfieldID BUFFER_FIELD_ID =
jni_interface->GetFieldID(env, getIndexInputWithBufferClass(jni_interface, env), "buffer", "[B");
return BUFFER_FIELD_ID;
}

JNIUtilInterface *jni_interface;
JNIEnv *env;

// `IndexInputWithBuffer` instance having `IndexInput` instance obtained from `Directory` for reading.
jobject indexInput;
jbyteArray bufferArray;
jmethodID copyBytesMethod;
}; // class NativeEngineIndexInputMediator



/**
Expand Down Expand Up @@ -133,4 +60,4 @@ class FaissOpenSearchIOReader final : public faiss::IOReader {
}
}

#endif //OPENSEARCH_KNN_JNI_STREAM_SUPPORT_H
#endif //OPENSEARCH_KNN_JNI_FAISS_STREAM_SUPPORT_H
9 changes: 7 additions & 2 deletions jni/include/jni_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,11 @@ namespace knn_jni {

virtual void ReleasePrimitiveArrayCritical(JNIEnv * env, jarray array, void *carray, jint mode) = 0;

virtual jint CallIntMethodLong(JNIEnv * env, jobject obj, jmethodID methodID, int64_t longArg) = 0;
virtual jint CallNonvirtualIntMethodA(JNIEnv *env, jobject obj, jclass clazz,
jmethodID methodID, jvalue *args) = 0;

virtual jlong CallNonvirtualLongMethodA(JNIEnv * env, jobject obj, jclass clazz,
jmethodID methodID, jvalue* args) = 0;

// --------------------------------------------------------------------------
};
Expand Down Expand Up @@ -194,7 +198,8 @@ namespace knn_jni {
jclass FindClassFromJNIEnv(JNIEnv * env, const char *name) final;
jmethodID GetMethodID(JNIEnv * env, jclass clazz, const char *name, const char *sig) final;
jfieldID GetFieldID(JNIEnv * env, jclass clazz, const char *name, const char *sig) final;
jint CallIntMethodLong(JNIEnv * env, jobject obj, jmethodID methodID, int64_t longArg) final;
jint CallNonvirtualIntMethodA(JNIEnv *env, jobject obj, jclass clazz, jmethodID methodID, jvalue *args) final;
jlong CallNonvirtualLongMethodA(JNIEnv * env, jobject obj, jclass clazz, jmethodID methodID, jvalue* args) final;
void * GetPrimitiveArrayCritical(JNIEnv * env, jarray array, jboolean *isCopy) final;
void ReleasePrimitiveArrayCritical(JNIEnv * env, jarray array, void *carray, jint mode) final;

Expand Down
125 changes: 125 additions & 0 deletions jni/include/native_engines_stream_support.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

#ifndef OPENSEARCH_KNN_JNI_STREAM_SUPPORT_H
#define OPENSEARCH_KNN_JNI_STREAM_SUPPORT_H

#include "jni_util.h"

#include <jni.h>
#include <stdexcept>
#include <iostream>
#include <cstring>

namespace knn_jni {
namespace stream {



/**
* This class contains Java IndexInputWithBuffer reference and calls its API to copy required bytes into a read buffer.
*/
class NativeEngineIndexInputMediator {
public:
// Expect IndexInputWithBuffer is given as `_indexInput`.
NativeEngineIndexInputMediator(JNIUtilInterface *_jni_interface,
JNIEnv *_env,
jobject _indexInput)
: jni_interface(_jni_interface),
env(_env),
indexInput(_indexInput),
bufferArray((jbyteArray) (_jni_interface->GetObjectField(_env,
_indexInput,
getBufferFieldId(_jni_interface, _env)))),
copyBytesMethod(getCopyBytesMethod(_jni_interface, _env)),
remainingBytesMethod(getRemainingBytesMethod(_jni_interface, _env)) {
}

void copyBytes(int64_t nbytes, uint8_t *destination) {
auto jclazz = getIndexInputWithBufferClass(jni_interface, env);

while (nbytes > 0) {
// Call `copyBytes` to read bytes as many as possible.
jvalue args;
args.j = nbytes;
const auto readBytes =
jni_interface->CallNonvirtualIntMethodA(env, indexInput, jclazz, copyBytesMethod, &args);

// === Critical Section Start ===

// Get primitive array pointer, no copy is happening in OpenJDK.
auto primitiveArray =
(jbyte *) jni_interface->GetPrimitiveArrayCritical(env, bufferArray, nullptr);

// Copy Java bytes to C++ destination address.
std::memcpy(destination, primitiveArray, readBytes);

// Release the acquired primitive array pointer.
// JNI_ABORT tells JVM to directly free memory without copying back to Java byte[].
// Since we're merely copying data, we don't need to copying back.
jni_interface->ReleasePrimitiveArrayCritical(env, bufferArray, primitiveArray, JNI_ABORT);

// === Critical Section End ===

destination += readBytes;
nbytes -= readBytes;
} // End while
}

int64_t remainingBytes() {
return jni_interface->CallNonvirtualLongMethodA(env,
indexInput,
getIndexInputWithBufferClass(jni_interface, env),
remainingBytesMethod,
nullptr);
}

private:
static jclass getIndexInputWithBufferClass(JNIUtilInterface *jni_interface, JNIEnv *env) {
static jclass INDEX_INPUT_WITH_BUFFER_CLASS =
jni_interface->FindClassFromJNIEnv(env, "org/opensearch/knn/index/store/IndexInputWithBuffer");
return INDEX_INPUT_WITH_BUFFER_CLASS;
}

static jmethodID getCopyBytesMethod(JNIUtilInterface *jni_interface, JNIEnv *env) {
static jmethodID COPY_METHOD_ID =
jni_interface->GetMethodID(env, getIndexInputWithBufferClass(jni_interface, env), "copyBytes", "(J)I");
return COPY_METHOD_ID;
}

static jmethodID getRemainingBytesMethod(JNIUtilInterface *jni_interface, JNIEnv *env) {
static jmethodID COPY_METHOD_ID =
jni_interface->GetMethodID(env, getIndexInputWithBufferClass(jni_interface, env), "remainingBytes", "()J");
return COPY_METHOD_ID;
}

static jfieldID getBufferFieldId(JNIUtilInterface *jni_interface, JNIEnv *env) {
static jfieldID BUFFER_FIELD_ID =
jni_interface->GetFieldID(env, getIndexInputWithBufferClass(jni_interface, env), "buffer", "[B");
return BUFFER_FIELD_ID;
}

JNIUtilInterface *jni_interface;
JNIEnv *env;

// `IndexInputWithBuffer` instance having `IndexInput` instance obtained from `Directory` for reading.
jobject indexInput;
jbyteArray bufferArray;
jmethodID copyBytesMethod;
jmethodID remainingBytesMethod;
}; // class NativeEngineIndexInputMediator



}
}

#endif //OPENSEARCH_KNN_JNI_STREAM_SUPPORT_H
51 changes: 51 additions & 0 deletions jni/include/nmslib_stream_support.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

#ifndef OPENSEARCH_KNN_JNI_NMSLIB_STREAM_SUPPORT_H
#define OPENSEARCH_KNN_JNI_NMSLIB_STREAM_SUPPORT_H

#include "native_engines_stream_support.h"

namespace knn_jni {
namespace stream {



/**
* NmslibIOReader implementation delegating NativeEngineIndexInputMediator to read bytes.
*/
class NmslibOpenSearchIOReader final : public similarity::NmslibIOReader {
public:
explicit NmslibOpenSearchIOReader(NativeEngineIndexInputMediator *_mediator)
: mediator(_mediator) {
}

void read(char *bytes, size_t len) final {
if (len > 0) {
// Mediator calls IndexInput, then copy read bytes to `ptr`.
mediator->copyBytes(len, (uint8_t *) bytes);
}
}

size_t remainingBytes() final {
return mediator->remainingBytes();
}

private:
NativeEngineIndexInputMediator *mediator;
}; // class NmslibOpenSearchIOReader



}
}

#endif //OPENSEARCH_KNN_JNI_NMSLIB_STREAM_SUPPORT_H
8 changes: 8 additions & 0 deletions jni/include/nmslib_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,14 @@ namespace knn_jni {
// Return a pointer to the loaded index
jlong LoadIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jstring indexPathJ, jobject parametersJ);

// Load an index via an input stream into memory. Use parametersJ to set any query time parameters
//
// Return a pointer to the loaded index
jlong LoadIndexWithStream(knn_jni::JNIUtilInterface * jniUtil,
JNIEnv * env,
jobject readStream,
jobject parametersJ);

// Execute a query against the index located in memory at indexPointerJ.
//
// Return an array of KNNQueryResults
Expand Down
8 changes: 8 additions & 0 deletions jni/include/org_opensearch_knn_jni_NmslibService.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,14 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_NmslibService_createIndex
JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_NmslibService_loadIndex
(JNIEnv *, jclass, jstring, jobject);

/*
* Class: org_opensearch_knn_jni_NmslibService
* Method: loadIndexWithStream
* Signature: (Lorg/opensearch/knn/index/store/IndexInputWithBuffer;Ljava/util/Map;)J
*/
JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_NmslibService_loadIndexWithStream
(JNIEnv *, jclass, jobject, jobject);

/*
* Class: org_opensearch_knn_jni_NmslibService
* Method: queryIndex
Expand Down
Loading

0 comments on commit 7cf45c8

Please sign in to comment.