Skip to content

Commit

Permalink
Refactor all cpp code (deepjavalibrary#477)
Browse files Browse the repository at this point in the history
  • Loading branch information
stu1130 authored Dec 30, 2020
1 parent ecda5e9 commit 46053d5
Show file tree
Hide file tree
Showing 28 changed files with 349 additions and 415 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,26 +18,37 @@
#include <vector>
#include <string>

namespace djl {
namespace utils {
namespace jni {

static constexpr const jint RELEASE_MODE = JNI_ABORT;
static constexpr const jlong NULL_PTR = 0;

inline std::string GetStringFromJString(JNIEnv* env, jstring jstr) {
if (jstr == nullptr) {
return std::string();
}
const char* c_str = env->GetStringUTFChars(jstr, JNI_FALSE);
std::string str = std::string(c_str);
env->ReleaseStringUTFChars(jstr, c_str);
return str;
}

template<typename T>
inline std::vector <T> GetObjectVecFromJHandles(JNIEnv *env, jlongArray jhandles) {
inline std::vector<T> GetObjectVecFromJHandles(JNIEnv* env, jlongArray jhandles) {
jsize length = env->GetArrayLength(jhandles);
jlong *jptrs = env->GetLongArrayElements(jhandles, JNI_FALSE);
std::vector <T> vec;
jlong* jptrs = env->GetLongArrayElements(jhandles, JNI_FALSE);
std::vector<T> vec;
vec.reserve(length);
for (size_t i = 0; i < length; ++i) {
vec.emplace_back(*(reinterpret_cast<T *>(jptrs[i])));
vec.emplace_back(*(reinterpret_cast<T*>(jptrs[i])));
}
env->ReleaseLongArrayElements(jhandles, jptrs, RELEASE_MODE);
return std::move(vec);
}

template <typename T1, typename T2>
template<typename T1, typename T2>
inline jlongArray GetPtrArrayFromContainer(JNIEnv* env, T1 list) {
size_t len = list.size();
jlongArray jarray = env->NewLongArray(len);
Expand All @@ -51,7 +62,7 @@ inline jlongArray GetPtrArrayFromContainer(JNIEnv* env, T1 list) {
return jarray;
}

inline std::vector<int64_t> GetVecFromJLongArray(JNIEnv* env, jlongArray jarray) {
inline std::vector <int64_t> GetVecFromJLongArray(JNIEnv* env, jlongArray jarray) {
jlong* jarr = env->GetLongArrayElements(jarray, JNI_FALSE);
jsize length = env->GetArrayLength(jarray);
std::vector<int64_t> vec(jarr, jarr + length);
Expand All @@ -62,7 +73,7 @@ inline std::vector<int64_t> GetVecFromJLongArray(JNIEnv* env, jlongArray jarray)
inline std::vector<int32_t> GetVecFromJIntArray(JNIEnv* env, jintArray jarray) {
jint* jarr = env->GetIntArrayElements(jarray, JNI_FALSE);
jsize length = env->GetArrayLength(jarray);
std::vector<int32_t> vec(jarr, jarr + length);
std::vector <int32_t> vec(jarr, jarr + length);
env->ReleaseIntArrayElements(jarray, jarr, RELEASE_MODE);
return std::move(vec);
}
Expand All @@ -75,17 +86,51 @@ inline std::vector<float> GetVecFromJFloatArray(JNIEnv* env, jfloatArray jarray)
return std::move(vec);
}

inline std::string GetStringFromJString(JNIEnv* env, jstring jstr) {
if (jstr == nullptr) {
return std::string();
inline std::vector <std::string> GetVecFromJStringArray(JNIEnv* env, jobjectArray array) {
std::vector <std::string> vec;
jsize len = env->GetArrayLength(array);
vec.reserve(len);
for (int i = 0; i < len; ++i) {
vec.emplace_back(djl::utils::jni::GetStringFromJString(
env, (jstring) env->GetObjectArrayElement(array, i)));
}
const char* c_str = env->GetStringUTFChars(jstr, JNI_FALSE);
std::string str = std::string(c_str);
env->ReleaseStringUTFChars(jstr, c_str);
return str;
return std::move(vec);
}

// String[]
inline jobjectArray GetStringArrayFromVec(JNIEnv* env, const std::vector <std::string> &vec) {
jobjectArray array = env->NewObjectArray(vec.size(), env->FindClass("Ljava/lang/String;"), nullptr);
for (int i = 0; i < vec.size(); ++i) {
env->SetObjectArrayElement(array, i, env->NewStringUTF(vec[i].c_str()));
}
return array;
}

inline jintArray GetIntArrayFromVec(JNIEnv* env, const std::vector<int> &vec) {
jintArray array = env->NewIntArray(vec.size());
env->SetIntArrayRegion(array, 0, vec.size(), reinterpret_cast<const jint*>(vec.data()));
return array;
}

inline jobjectArray Get2DIntArrayFrom2DVec(JNIEnv* env, const std::vector<std::vector<int>> &vec) {
jobjectArray array = env->NewObjectArray(vec.size(), env->FindClass("[I"), nullptr);
for (size_t i = 0; i < vec.size(); ++i) {
env->SetObjectArrayElement(array, i, djl::utils::jni::GetIntArrayFromVec(env, vec[i]));
}
return array;
}

// String[][]
inline jobjectArray Get2DStringArrayFrom2DVec(JNIEnv* env, const std::vector<std::vector<std::string>> &vec) {
jobjectArray array = env->NewObjectArray(vec.size(), env->FindClass("[Ljava/lang/String;"), nullptr);
for (int i = 0; i < vec.size(); ++i) {
env->SetObjectArrayElement(array, i, GetStringArrayFromVec(env, vec[i]));
}
return array;
}

} // namespace jni
} // namespace utils
} // namespace djl

#endif //DJL_UTILS_H
3 changes: 2 additions & 1 deletion extensions/fasttext/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@ set(CMAKE_CXX_STANDARD 11)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

find_package(JNI REQUIRED)

find_path(UTILS_INCLUDE_DIR NAMES djl/utils.h PATHS ${PROJECT_SOURCE_DIR}/../../api/src/main/native REQUIRED)
add_subdirectory(fastText)

add_library(jni_fasttext SHARED src/main/native/ai_djl_fasttext_jni_FastTextLibrary.cc)
target_include_directories(jni_fasttext PUBLIC
${JNI_INCLUDE_DIRS}
${UTILS_INCLUDE_DIR}
fastText/src
build/include)
target_link_libraries(jni_fasttext fasttext-static_pic)
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

#include <numeric>

#include <djl/utils.h>

#include "args.h"
#include "dictionary.h"
#include "fasttext.cc"
Expand All @@ -35,67 +37,6 @@ struct FastTextPrivateMembers {
std::shared_ptr<fasttext::Model> model_;
};

inline std::string jstringToString(JNIEnv* env, jstring array) {
jsize len = env->GetStringUTFLength(array);

const char* str = env->GetStringUTFChars(array, nullptr);
std::string s(str, len);
env->ReleaseStringUTFChars(array, str);

return s;
}

// String[]
inline jobjectArray GetStringArrayFromVector(JNIEnv* env, const std::vector<std::string>& vec) {
jobjectArray array = env->NewObjectArray(vec.size(), env->FindClass("Ljava/lang/String;"), nullptr);
for (int i = 0; i < vec.size(); ++i) {
env->SetObjectArrayElement(array, i, env->NewStringUTF(vec[i].c_str()));
}
return array;
}

// String[][]
inline jobjectArray Get2DStringArrayFrom2DVector(JNIEnv* env, const std::vector<std::vector<std::string>>& vec) {
jobjectArray array = env->NewObjectArray(vec.size(), env->FindClass("[Ljava/lang/String;"), nullptr);
for (int i = 0; i < vec.size(); ++i) {
env->SetObjectArrayElement(array, i, GetStringArrayFromVector(env, vec[i]));
}
return array;
}

inline void GetVectorFromStringArray(JNIEnv* env, jobjectArray array, std::vector<std::string>* vec) {
jsize len = env->GetArrayLength(array);
vec->resize(len);
for (int i = 0; i < len; ++i) {
std::string stdStr = jstringToString(env, (jstring)env->GetObjectArrayElement(array, i));
(*vec)[i] = stdStr;
}
}

inline jintArray GetIntArrayFromVector(JNIEnv* env, const std::vector<int>& vec) {
jintArray array = env->NewIntArray(vec.size());
env->SetIntArrayRegion(array, 0, vec.size(), vec.data());
return array;
}

inline jobjectArray Get2DIntArrayFrom2DVector(JNIEnv* env, const std::vector<std::vector<int>>& vec) {
jobjectArray array = env->NewObjectArray(vec.size(), env->FindClass("[I"), nullptr);
for (int i = 0; i < vec.size(); ++i) {
env->SetObjectArrayElement(array, i, GetIntArrayFromVector(env, vec[i]));
}
return array;
}

inline std::vector<int> GetVectorFromIntArray(JNIEnv* env, jintArray array) {
jsize len = env->GetArrayLength(array);

void* data = env->GetPrimitiveArrayCritical(array, JNI_FALSE);
std::vector<int> vec((int*)data, ((int*)data) + len);
env->ReleasePrimitiveArrayCritical(array, data, JNI_ABORT);

return vec;
}

JNIEXPORT jlong JNICALL Java_ai_djl_fasttext_jni_FastTextLibrary_createFastText(JNIEnv* env, jobject jthis) {
auto* fasttext_ptr = new fasttext::FastText();
return reinterpret_cast<uintptr_t>(fasttext_ptr);
Expand All @@ -110,7 +51,7 @@ JNIEXPORT void JNICALL Java_ai_djl_fasttext_jni_FastTextLibrary_freeFastText(
JNIEXPORT void JNICALL Java_ai_djl_fasttext_jni_FastTextLibrary_loadModel(
JNIEnv* env, jobject jthis, jlong jhandle, jstring jpath) {
auto* fasttext_ptr = reinterpret_cast<fasttext::FastText*>(jhandle);
const std::string path_string = jstringToString(env, jpath);
const std::string path_string = djl::utils::jni::GetStringFromJString(env, jpath);
try {
fasttext_ptr->loadModel(path_string);
} catch (const std::invalid_argument& e) {
Expand All @@ -121,7 +62,7 @@ JNIEXPORT void JNICALL Java_ai_djl_fasttext_jni_FastTextLibrary_loadModel(

JNIEXPORT jboolean JNICALL Java_ai_djl_fasttext_jni_FastTextLibrary_checkModel(
JNIEnv* env, jobject jthis, jstring jpath) {
const std::string filename = jstringToString(env, jpath);
const std::string filename = djl::utils::jni::GetStringFromJString(env, jpath);
std::ifstream in(filename, std::ifstream::binary);
int32_t magic;
int32_t version;
Expand Down Expand Up @@ -169,7 +110,7 @@ JNIEXPORT jstring JNICALL Java_ai_djl_fasttext_jni_FastTextLibrary_getModelType(
JNIEXPORT jint JNICALL Java_ai_djl_fasttext_jni_FastTextLibrary_predictProba(
JNIEnv* env, jobject jthis, jlong jhandle, jstring jtext, jint top_k, jobjectArray jclasses, jfloatArray jprob) {
auto* fasttext_ptr = reinterpret_cast<fasttext::FastText*>(jhandle);
std::string text = jstringToString(env, jtext);
std::string text = djl::utils::jni::GetStringFromJString(env, jtext);
std::istringstream in(text);
std::vector<std::pair<real, std::string>> predictions;
fasttext_ptr->predictLine(in, predictions, top_k, 0.0);
Expand All @@ -188,7 +129,7 @@ JNIEXPORT jint JNICALL Java_ai_djl_fasttext_jni_FastTextLibrary_predictProba(

JNIEXPORT jfloatArray JNICALL Java_ai_djl_fasttext_jni_FastTextLibrary_getWordVector(
JNIEnv* env, jobject jthis, jlong jhandle, jstring word) {
std::string word_str = jstringToString(env, word);
std::string word_str = djl::utils::jni::GetStringFromJString(env, word);
auto* fasttext_ptr = reinterpret_cast<fasttext::FastText*>(jhandle);
auto* privateMembers = (FastTextPrivateMembers*)fasttext_ptr;

Expand All @@ -201,8 +142,7 @@ JNIEXPORT jfloatArray JNICALL Java_ai_djl_fasttext_jni_FastTextLibrary_getWordVe
}

JNIEXPORT int JNICALL Java_ai_djl_fasttext_jni_FastTextLibrary_runCmd(JNIEnv* env, jobject jthis, jobjectArray args) {
std::vector<std::string> vec;
GetVectorFromStringArray(env, args, &vec);
std::vector<std::string> vec = djl::utils::jni::GetVecFromJStringArray(env, args);
if (vec.size() < 2) {
printUsage();
return -1;
Expand Down
3 changes: 3 additions & 0 deletions extensions/sentencepiece/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,16 @@ set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(SPM_ENABLE_TCMALLOC OFF CACHE BOOL "Build sentencepiece static library")
# disable shared library as we only link static library
set(SPM_ENABLE_SHARED OFF CACHE BOOL "Build sentencepiece static library")

find_package(JNI REQUIRED)
find_path(UTILS_INCLUDE_DIR NAMES djl/utils.h PATHS ${PROJECT_SOURCE_DIR}/../../api/src/main/native REQUIRED)

add_subdirectory(sentencepiece)

add_library(sentencepiece_native SHARED src/main/native/ai_djl_sentencepiece_jni_SentencePieceLibrary.cc)
target_include_directories(sentencepiece_native PUBLIC
${JNI_INCLUDE_DIRS}
${UTILS_INCLUDE_DIR}
sentencepiece/src
build/include)
target_link_libraries(sentencepiece_native sentencepiece-static)
Loading

0 comments on commit 46053d5

Please sign in to comment.