From 6d7235ba5ab995e42a0e251874e65e9d7eaa2997 Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Sun, 15 Sep 2024 21:55:38 -0400 Subject: [PATCH] [Java] Exposing SessionOptions.SetDeterministicCompute (#18998) ### Description Exposes `SetDeterministicCompute` in Java, added to the C API by #18944. ### Motivation and Context Parity between C and Java APIs. --- .../main/java/ai/onnxruntime/OrtSession.java | 17 +++++++++++++++++ .../ai_onnxruntime_OrtSession_SessionOptions.c | 13 +++++++++++++ .../test/java/ai/onnxruntime/InferenceTest.java | 1 + 3 files changed, 31 insertions(+) diff --git a/java/src/main/java/ai/onnxruntime/OrtSession.java b/java/src/main/java/ai/onnxruntime/OrtSession.java index f87cbc76ef141..6d146d5857d3c 100644 --- a/java/src/main/java/ai/onnxruntime/OrtSession.java +++ b/java/src/main/java/ai/onnxruntime/OrtSession.java @@ -942,6 +942,20 @@ public void setSymbolicDimensionValue(String dimensionName, long dimensionValue) OnnxRuntime.ortApiHandle, nativeHandle, dimensionName, dimensionValue); } + /** + * Set whether to use deterministic compute. + * + *

Default is false. If set to true, this will enable deterministic compute for GPU kernels + * where possible. Note that this most likely will have a performance cost. + * + * @param value Should the compute be deterministic? + * @throws OrtException If there was an error in native code. + */ + public void setDeterministicCompute(boolean value) throws OrtException { + checkClosed(); + setDeterministicCompute(OnnxRuntime.ortApiHandle, nativeHandle, value); + } + /** * Disables the per session thread pools. Must be used in conjunction with an environment * containing global thread pools. @@ -1327,6 +1341,9 @@ private native void registerCustomOpsUsingFunction( private native void closeOptions(long apiHandle, long nativeHandle); + private native void setDeterministicCompute( + long apiHandle, long nativeHandle, boolean isDeterministic) throws OrtException; + private native void addFreeDimensionOverrideByName( long apiHandle, long nativeHandle, String dimensionName, long dimensionValue) throws OrtException; diff --git a/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c b/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c index ff9348c299e90..ff6b7fa703e6e 100644 --- a/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c +++ b/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c @@ -259,6 +259,19 @@ JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_setSes checkOrtStatus(jniEnv,api,api->SetSessionLogVerbosityLevel(options,logLevel)); } +/* + * Class: ai_onnxruntime_OrtSession_SessionOptions + * Method: setDeterministicCompute + * Signature: (JJZ)V + */ +JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_setDeterministicCompute + (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong optionsHandle, jboolean isDeterministic) { + (void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*)apiHandle; + OrtSessionOptions* options = (OrtSessionOptions*) optionsHandle; + checkOrtStatus(jniEnv,api,api->SetDeterministicCompute(options, isDeterministic)); +} + /* * Class: ai_onnxruntime_OrtSession_SessionOptions * Method: registerCustomOpLibrary diff --git a/java/src/test/java/ai/onnxruntime/InferenceTest.java b/java/src/test/java/ai/onnxruntime/InferenceTest.java index f76e1b3b20e19..11141a3a65a3e 100644 --- a/java/src/test/java/ai/onnxruntime/InferenceTest.java +++ b/java/src/test/java/ai/onnxruntime/InferenceTest.java @@ -1263,6 +1263,7 @@ public void testExtraSessionOptions() throws OrtException, IOException { options.setLoggerId("monkeys"); options.setSessionLogLevel(OrtLoggingLevel.ORT_LOGGING_LEVEL_FATAL); options.setSessionLogVerbosityLevel(5); + options.setDeterministicCompute(true); Map configEntries = options.getConfigEntries(); assertTrue(configEntries.isEmpty()); options.addConfigEntry("key", "value");