Skip to content

Commit

Permalink
Add clang-format
Browse files Browse the repository at this point in the history
Change-Id: I6288ad6b96c52862592f951363162894ecd936ef
  • Loading branch information
stu1130 committed Feb 27, 2020
1 parent 8b3122e commit 9690a4a
Show file tree
Hide file tree
Showing 20 changed files with 331 additions and 372 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
.gradle
.DS_Store
.idea
.clang
*.iml
.ipynb_checkpoints
build
Expand Down
2 changes: 1 addition & 1 deletion build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ configure(javaProjects()) {
targetCompatibility = 1.8
compileJava.options.encoding = "UTF-8"

apply from: file("${rootProject.projectDir}/tools/gradle/formatter.gradle")
apply from: file("${rootProject.projectDir}/tools/gradle/java-formatter.gradle")
apply from: file("${rootProject.projectDir}/tools/gradle/check.gradle")

test {
Expand Down
8 changes: 1 addition & 7 deletions pytorch/pytorch-native/build.gradle
Original file line number Diff line number Diff line change
@@ -1,12 +1,6 @@

plugins {
id 'maven-publish'
id 'signing'
}

group = "ai.djl.pytorch"

def VERSION = "1.4.0"
apply from: file("${rootProject.projectDir}/tools/gradle/cpp-formatter.gradle")

task buildJNI(type:Exec) {
if (System.properties['os.name'].toLowerCase(Locale.ROOT).contains("windows")) {
Expand Down
1 change: 1 addition & 0 deletions pytorch/pytorch-native/gradlew
Original file line number Diff line number Diff line change
Expand Up @@ -10,30 +10,31 @@
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
#include <torch/script.h>
#include <torch/torch.h>

#include "../build/include/ai_djl_pytorch_jni_PyTorchLibrary.h"
#include "djl_pytorch_jni_utils.h"
#include <torch/torch.h>
#include <torch/script.h>

// The file is the implementation for PyTorch inference operations

JNIEXPORT jobject JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_moduleLoad
(JNIEnv* env, jobject jthis, jstring jpath, jintArray jarray) {
JNIEXPORT jobject JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_moduleLoad(
JNIEnv* env, jobject jthis, jstring jpath, jintArray jarray) {
const std::string path_string((env)->GetStringUTFChars(jpath, JNI_FALSE));
const c10::Device device = utils::GetDeviceFromJDevice(env, jarray);
const torch::jit::script::Module module = torch::jit::load(path_string, device);
const auto* module_ptr = new torch::jit::script::Module(module);
return utils::CreatePointer<torch::jit::script::Module>(env, module_ptr);
}

JNIEXPORT void JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_moduleEval
(JNIEnv* env, jobject jthis, jobject module_handle) {
JNIEXPORT void JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_moduleEval(
JNIEnv* env, jobject jthis, jobject module_handle) {
auto* module_ptr = utils::GetPointerFromJHandle<torch::jit::script::Module>(env, module_handle);
module_ptr->eval();
}

JNIEXPORT jobject JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_moduleForward
(JNIEnv* env, jobject jthis, jobject module_handle, jobjectArray jivalue_ptr_array) {
JNIEXPORT jobject JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_moduleForward(
JNIEnv* env, jobject jthis, jobject module_handle, jobjectArray jivalue_ptr_array) {
auto ivalue_vec = std::vector<c10::IValue>();
for (auto i = 0; i < env->GetArrayLength(jivalue_ptr_array); ++i) {
auto ivalue = utils::GetPointerFromJHandle<c10::IValue>(env, env->GetObjectArrayElement(jivalue_ptr_array, i));
Expand All @@ -44,28 +45,26 @@ JNIEXPORT jobject JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_moduleForward
return utils::CreatePointer<c10::IValue>(env, result_ptr);
}

JNIEXPORT void JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchDeleteModule
(JNIEnv* env, jobject jthis, jobject jhandle) {
JNIEXPORT void JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchDeleteModule(
JNIEnv* env, jobject jthis, jobject jhandle) {
const auto* module_ptr = utils::GetPointerFromJHandle<const torch::jit::script::Module>(env, jhandle);
delete module_ptr;
}

JNIEXPORT jobject JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueCreateFromTensor
(JNIEnv* env, jobject jthis, jobject jhandle) {
const auto* ivalue_ptr = new c10::IValue(
*utils::GetPointerFromJHandle<torch::Tensor>(env, jhandle));
JNIEXPORT jobject JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueCreateFromTensor(
JNIEnv* env, jobject jthis, jobject jhandle) {
const auto* ivalue_ptr = new c10::IValue(*utils::GetPointerFromJHandle<torch::Tensor>(env, jhandle));
return utils::CreatePointer<c10::IValue>(env, ivalue_ptr);
}

JNIEXPORT jobject JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueToTensor
(JNIEnv* env, jobject jthis, jobject jhandle) {
auto* tensor_ptr = new torch::Tensor(
utils::GetPointerFromJHandle<c10::IValue>(env, jhandle)->toTensor());
JNIEXPORT jobject JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueToTensor(
JNIEnv* env, jobject jthis, jobject jhandle) {
auto* tensor_ptr = new torch::Tensor(utils::GetPointerFromJHandle<c10::IValue>(env, jhandle)->toTensor());
return utils::CreatePointer<torch::Tensor>(env, tensor_ptr);
}

JNIEXPORT jobjectArray JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueToList
(JNIEnv* env, jobject jthis, jobject jhandle) {
JNIEXPORT jobjectArray JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueToList(
JNIEnv* env, jobject jthis, jobject jhandle) {
auto* ivalue_ptr = utils::GetPointerFromJHandle<c10::IValue>(env, jhandle);
auto ivalue_list = ivalue_ptr->toGenericList();
jobjectArray jarray = env->NewObjectArray(ivalue_list.size(), env->FindClass(utils::POINTER_CLASS), nullptr);
Expand All @@ -77,8 +76,8 @@ JNIEXPORT jobjectArray JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueToLi
return jarray;
}

JNIEXPORT jobjectArray JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueToListFromTuple
(JNIEnv *env, jobject jthis, jobject jhandle) {
JNIEXPORT jobjectArray JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueToListFromTuple(
JNIEnv* env, jobject jthis, jobject jhandle) {
auto* ivalue_ptr = utils::GetPointerFromJHandle<c10::IValue>(env, jhandle);
auto ivalue_list = ivalue_ptr->toTuple()->elements();
jobjectArray jarray = env->NewObjectArray(ivalue_list.size(), env->FindClass(utils::POINTER_CLASS), nullptr);
Expand All @@ -90,8 +89,8 @@ JNIEXPORT jobjectArray JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueToLi
return jarray;
}

JNIEXPORT jobjectArray JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueToTensorList
(JNIEnv* env, jobject jthis, jobject jhandle) {
JNIEXPORT jobjectArray JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueToTensorList(
JNIEnv* env, jobject jthis, jobject jhandle) {
auto* ivalue_ptr = utils::GetPointerFromJHandle<c10::IValue>(env, jhandle);
auto ivalue_list = ivalue_ptr->toTensorList();
jobjectArray jarray = env->NewObjectArray(ivalue_list.size(), env->FindClass(utils::POINTER_CLASS), nullptr);
Expand All @@ -103,8 +102,8 @@ JNIEXPORT jobjectArray JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueToTe
return jarray;
}

JNIEXPORT jobjectArray JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueToMap
(JNIEnv* env, jobject jthis, jobject jhandle) {
JNIEXPORT jobjectArray JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueToMap(
JNIEnv* env, jobject jthis, jobject jhandle) {
auto* ivalue_ptr = utils::GetPointerFromJHandle<c10::IValue>(env, jhandle);
auto dict = ivalue_ptr->toGenericDict();
jobjectArray jarray = env->NewObjectArray(dict.size() * 2, env->FindClass(utils::POINTER_CLASS), nullptr);
Expand All @@ -120,39 +119,38 @@ JNIEXPORT jobjectArray JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueToMa
return jarray;
}

JNIEXPORT jstring JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueToString
(JNIEnv* env, jobject jthis, jobject jhandle) {
JNIEXPORT jstring JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueToString(
JNIEnv* env, jobject jthis, jobject jhandle) {
auto* ivalue_ptr = utils::GetPointerFromJHandle<c10::IValue>(env, jhandle);
return env->NewStringUTF(ivalue_ptr->toString()->string().c_str());
}

JNIEXPORT jboolean JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueIsString
(JNIEnv* env, jobject jthis, jobject jhandle) {
JNIEXPORT jboolean JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueIsString(
JNIEnv* env, jobject jthis, jobject jhandle) {
return utils::GetPointerFromJHandle<c10::IValue>(env, jhandle)->isString();
}

JNIEXPORT jboolean JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueIsTensor
(JNIEnv* env, jobject jthis, jobject jhandle) {
JNIEXPORT jboolean JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueIsTensor(
JNIEnv* env, jobject jthis, jobject jhandle) {
return utils::GetPointerFromJHandle<c10::IValue>(env, jhandle)->isTensor();
}

JNIEXPORT jboolean JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueIsTensorList
(JNIEnv* env, jobject jthis, jobject jhandle) {
JNIEXPORT jboolean JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueIsTensorList(
JNIEnv* env, jobject jthis, jobject jhandle) {
return utils::GetPointerFromJHandle<c10::IValue>(env, jhandle)->isTensorList();
}

JNIEXPORT jboolean JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueIsList
(JNIEnv* env, jobject jthis, jobject jhandle) {
JNIEXPORT jboolean JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueIsList(
JNIEnv* env, jobject jthis, jobject jhandle) {
return utils::GetPointerFromJHandle<c10::IValue>(env, jhandle)->isGenericList();
}

JNIEXPORT jboolean JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueIsMap
(JNIEnv* env, jobject jthis, jobject jhandle) {
JNIEXPORT jboolean JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueIsMap(
JNIEnv* env, jobject jthis, jobject jhandle) {
return utils::GetPointerFromJHandle<c10::IValue>(env, jhandle)->isGenericDict();
}

JNIEXPORT jboolean JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueIsTuple
(JNIEnv *env, jobject jthis, jobject jhandle) {
JNIEXPORT jboolean JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueIsTuple(
JNIEnv* env, jobject jthis, jobject jhandle) {
return utils::GetPointerFromJHandle<c10::IValue>(env, jhandle)->isTuple();
}

Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,18 @@

// The file is the implementation for PyTorch neural network functional ops

JNIEXPORT jobject JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchSoftmax
(JNIEnv* env, jobject jthis, jobject jhandle, jlong jdim, jint jdtype) {
JNIEXPORT jobject JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchSoftmax(
JNIEnv* env, jobject jthis, jobject jhandle, jlong jdim, jint jdtype) {
const auto* tensor_ptr = utils::GetPointerFromJHandle<const torch::Tensor>(env, jhandle);
const auto* result_ptr = new torch::Tensor(tensor_ptr->softmax(jdim, utils::GetScalarTypeFromDType(jdtype)));
return utils::CreatePointer<torch::Tensor>(env, result_ptr);
}

JNIEXPORT jobject JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchUpsampleBilinear2d
(JNIEnv* env, jobject jthis, jobject jhandle, jlongArray jsize, jboolean jalign_corners) {
JNIEXPORT jobject JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchUpsampleBilinear2d(
JNIEnv* env, jobject jthis, jobject jhandle, jlongArray jsize, jboolean jalign_corners) {
const auto* tensor_ptr = utils::GetPointerFromJHandle<const torch::Tensor>(env, jhandle);
const auto size_vec = utils::GetVecFromJLongArray(env, jsize);
const auto* result_ptr = new torch::Tensor(torch::upsample_bilinear2d(*tensor_ptr, size_vec, jalign_corners == JNI_TRUE));
const auto* result_ptr =
new torch::Tensor(torch::upsample_bilinear2d(*tensor_ptr, size_vec, jalign_corners == JNI_TRUE));
return utils::CreatePointer<torch::Tensor>(env, result_ptr);
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,17 @@
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
#include <torch/torch.h>

#include "../build/include/ai_djl_pytorch_jni_PyTorchLibrary.h"
#include "djl_pytorch_jni_utils.h"
#include <torch/torch.h>

// The file is the implementation for PyTorch system-wide operations

JNIEXPORT void JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchManualSeed
(JNIEnv* env, jobject jthis, jlong jseed) {
JNIEXPORT void JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchManualSeed(JNIEnv* env, jobject jthis, jlong jseed) {
torch::manual_seed(jseed);
}

JNIEXPORT jboolean JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchCudaAvailable
(JNIEnv* env, jobject jthis) {
JNIEXPORT jboolean JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchCudaAvailable(JNIEnv* env, jobject jthis) {
return torch::cuda::is_available();
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,21 @@

// The file is the implementation for PyTorch tensor core functionality operation

JNIEXPORT jlongArray JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchSizes
(JNIEnv* env, jobject jthis, jobject jhandle) {
JNIEXPORT jlongArray JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchSizes(
JNIEnv* env, jobject jthis, jobject jhandle) {
const auto* tensor_ptr = utils::GetPointerFromJHandle<torch::Tensor>(env, jhandle);
jlongArray size = env->NewLongArray(tensor_ptr->dim());
env->SetLongArrayRegion(size, 0, tensor_ptr->dim(),
reinterpret_cast<const jlong*>(tensor_ptr->sizes().data()));
env->SetLongArrayRegion(size, 0, tensor_ptr->dim(), reinterpret_cast<const jlong*>(tensor_ptr->sizes().data()));
return size;
}

JNIEXPORT jint JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchDType
(JNIEnv* env, jobject jthis, jobject jhandle) {
JNIEXPORT jint JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchDType(JNIEnv* env, jobject jthis, jobject jhandle) {
const auto* tensor_ptr = utils::GetPointerFromJHandle<torch::Tensor>(env, jhandle);
return utils::GetDTypeFromScalarType(tensor_ptr->scalar_type());
}

JNIEXPORT jintArray JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchDevice
(JNIEnv* env, jobject jthis, jobject jhandle) {
JNIEXPORT jintArray JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchDevice(
JNIEnv* env, jobject jthis, jobject jhandle) {
Log log(env);
const auto* tensor_ptr = utils::GetPointerFromJHandle<torch::Tensor>(env, jhandle);
jintArray result = env->NewIntArray(2);
Expand All @@ -44,8 +42,7 @@ JNIEXPORT jintArray JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchDevice
return result;
}

JNIEXPORT jint JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchLayout
(JNIEnv* env, jobject jthis, jobject jhandle) {
JNIEXPORT jint JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchLayout(JNIEnv* env, jobject jthis, jobject jhandle) {
const auto* tensor_ptr = utils::GetPointerFromJHandle<torch::Tensor>(env, jhandle);
auto layout = tensor_ptr->layout();
switch (layout) {
Expand All @@ -60,62 +57,61 @@ JNIEXPORT jint JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchLayout
}
}

JNIEXPORT jobject JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchTo
(JNIEnv* env, jobject jthis, jobject jhandle, jint jdtype, jintArray jdevice, jboolean jcopy) {
JNIEXPORT jobject JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchTo(
JNIEnv* env, jobject jthis, jobject jhandle, jint jdtype, jintArray jdevice, jboolean jcopy) {
const auto* tensor_ptr = utils::GetPointerFromJHandle<torch::Tensor>(env, jhandle);
const auto device = utils::GetDeviceFromJDevice(env, jdevice);
const auto* result_ptr = new torch::Tensor(
tensor_ptr->to(device, utils::GetScalarTypeFromDType(jdtype), false, jcopy == JNI_TRUE));
const auto* result_ptr =
new torch::Tensor(tensor_ptr->to(device, utils::GetScalarTypeFromDType(jdtype), false, jcopy == JNI_TRUE));
return utils::CreatePointer<torch::Tensor>(env, result_ptr);
}

JNIEXPORT jobject JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_tensorClone
(JNIEnv* env, jobject jthis, jobject jhandle) {
JNIEXPORT jobject JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_tensorClone(
JNIEnv* env, jobject jthis, jobject jhandle) {
const auto* tensor_ptr = utils::GetPointerFromJHandle<torch::Tensor>(env, jhandle);
const auto* result_ptr = new torch::Tensor(tensor_ptr->clone());
return utils::CreatePointer<torch::Tensor>(env, result_ptr);
}

JNIEXPORT jobject JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchSlice
(JNIEnv* env, jobject jthis, jobject jhandle, jlong jdim, jlong jstart, jlong jend, jlong jstep) {
JNIEXPORT jobject JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchSlice(
JNIEnv* env, jobject jthis, jobject jhandle, jlong jdim, jlong jstart, jlong jend, jlong jstep) {
const auto* tensor_ptr = utils::GetPointerFromJHandle<torch::Tensor>(env, jhandle);
const auto* result_ptr = new torch::Tensor(tensor_ptr->slice(jdim, jstart, jend, jstep));
return utils::CreatePointer<torch::Tensor>(env, result_ptr);
}

JNIEXPORT jobject JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchMaskedSelect
(JNIEnv *env, jobject jthis, jobject jhandle, jobject jmasked_handle) {
JNIEXPORT jobject JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchMaskedSelect(
JNIEnv* env, jobject jthis, jobject jhandle, jobject jmasked_handle) {
const auto* tensor_ptr = utils::GetPointerFromJHandle<torch::Tensor>(env, jhandle);
const auto* index_ptr = utils::GetPointerFromJHandle<torch::Tensor>(env, jmasked_handle);
const auto* result_ptr = new torch::Tensor(tensor_ptr->masked_select(*index_ptr));
return utils::CreatePointer<torch::Tensor>(env, result_ptr);
}

JNIEXPORT jobject JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchDataPtr
(JNIEnv* env, jobject jthis, jobject jhandle) {
JNIEXPORT jobject JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchDataPtr(
JNIEnv* env, jobject jthis, jobject jhandle) {
const auto* tensor_ptr = utils::GetPointerFromJHandle<torch::Tensor>(env, jhandle);
jobject buf = env->NewDirectByteBuffer(tensor_ptr->data_ptr(), tensor_ptr->nbytes());
return buf;
}

JNIEXPORT void JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchDeleteTensor
(JNIEnv* env, jobject jthis, jobject jhandle) {
JNIEXPORT void JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchDeleteTensor(
JNIEnv* env, jobject jthis, jobject jhandle) {
const auto* tensor_ptr = utils::GetPointerFromJHandle<const torch::Tensor>(env, jhandle);
delete tensor_ptr;
}

JNIEXPORT jobject JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchLogicalXor
(JNIEnv* env, jobject jthis, jobject jself, jobject jother) {
JNIEXPORT jobject JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchLogicalXor(
JNIEnv* env, jobject jthis, jobject jself, jobject jother) {
const auto* self_ptr = utils::GetPointerFromJHandle<const torch::Tensor>(env, jself);
const auto* other_ptr = utils::GetPointerFromJHandle<const torch::Tensor>(env, jother);
const auto* result_ptr = new torch::Tensor(torch::logical_xor(*self_ptr, *other_ptr));
return utils::CreatePointer<torch::Tensor>(env, result_ptr);
}

JNIEXPORT jobject JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchLogicalNot
(JNIEnv* env, jobject jthis, jobject jhandle) {
JNIEXPORT jobject JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchLogicalNot(
JNIEnv* env, jobject jthis, jobject jhandle) {
const auto* tensor_ptr = utils::GetPointerFromJHandle<const torch::Tensor>(env, jhandle);
const auto* result_ptr = new torch::Tensor(tensor_ptr->logical_not());
return utils::CreatePointer<torch::Tensor>(env, result_ptr);

}
Loading

0 comments on commit 9690a4a

Please sign in to comment.