From 44a8f22b2b2a418a0c449489ecce32b77c8ab6ec Mon Sep 17 00:00:00 2001 From: Chester Liu <4710575+skyline75489@users.noreply.github.com> Date: Tue, 3 Dec 2024 11:12:47 +0800 Subject: [PATCH] Added Java Bindings for Adapters API (#1030) * Added Java Bindings for Adapters API and test cases. * Enable Java CI tests on win-x64, win-arm64, linux-x64 and osx-arm64. * Fix existing Java test cases. * Fix some style issues reported by spotless. Internal work item: https://task.ms/aii/52363 --- .github/workflows/linux-cpu-x64-build.yml | 17 ++ .github/workflows/mac-cpu-arm64-build.yml | 19 +++ .github/workflows/win-cpu-arm64-build.yml | 29 +++- .github/workflows/win-cpu-x64-build.yml | 16 ++ .gitignore | 2 + build.py | 14 +- src/java/CMakeLists.txt | 73 +++++--- src/java/build-android.gradle | 1 + src/java/build.gradle | 1 + .../java/ai/onnxruntime/genai/Adapters.java | 81 +++++++++ .../java/ai/onnxruntime/genai/Config.java | 23 ++- .../main/java/ai/onnxruntime/genai/GenAI.java | 2 + .../java/ai/onnxruntime/genai/Generator.java | 33 +++- .../ai/onnxruntime/genai/GeneratorParams.java | 10 +- .../java/ai/onnxruntime/genai/Images.java | 44 ++--- .../main/java/ai/onnxruntime/genai/Model.java | 1 + .../genai/MultiModalProcessor.java | 11 +- .../ai/onnxruntime/genai/NamedTensors.java | 42 ++--- .../ai/onnxruntime/genai/SimpleGenAI.java | 14 +- .../java/ai/onnxruntime/genai/Tensor.java | 37 +++- .../native/ai_onnxruntime_genai_Adapters.cpp | 40 +++++ .../native/ai_onnxruntime_genai_GenAI.cpp | 15 ++ .../native/ai_onnxruntime_genai_Generator.cpp | 24 ++- .../native/ai_onnxruntime_genai_Sequences.cpp | 4 +- .../native/ai_onnxruntime_genai_Tensor.cpp | 32 +++- .../genai/GenAITestExecutionListener.java | 10 ++ .../ai/onnxruntime/genai/GenerationTest.java | 159 ++++++++++++------ .../genai/GeneratorParamsTest.java | 16 +- .../genai/MultiModalProcessorTest.java | 39 ++--- .../java/ai/onnxruntime/genai/TensorTest.java | 32 ++-- .../java/ai/onnxruntime/genai/TestUtils.java | 42 ++++- 
.../ai/onnxruntime/genai/TokenizerTest.java | 13 +- ...it.platform.launcher.TestExecutionListener | 1 + .../test_models/images}/landscape.jpg | Bin 34 files changed, 675 insertions(+), 222 deletions(-) create mode 100644 src/java/src/main/java/ai/onnxruntime/genai/Adapters.java create mode 100644 src/java/src/main/native/ai_onnxruntime_genai_Adapters.cpp create mode 100644 src/java/src/main/native/ai_onnxruntime_genai_GenAI.cpp create mode 100644 src/java/src/test/java/ai/onnxruntime/genai/GenAITestExecutionListener.java create mode 100644 src/java/src/test/resources/META-INF/services/org.junit.platform.launcher.TestExecutionListener rename {src/java/src/test/java/ai/onnxruntime/genai => test/test_models/images}/landscape.jpg (100%) diff --git a/.github/workflows/linux-cpu-x64-build.yml b/.github/workflows/linux-cpu-x64-build.yml index cddfaea8b..5fc97369d 100644 --- a/.github/workflows/linux-cpu-x64-build.yml +++ b/.github/workflows/linux-cpu-x64-build.yml @@ -27,6 +27,18 @@ jobs: with: dotnet-version: '8.0.x' + - name: Setup Java 17 + uses: actions/setup-java@v4 + with: + java-version: '17' + distribution: 'temurin' + cache: 'gradle' + + - name: Setup Gradle + uses: gradle/actions/setup-gradle@v3 + with: + gradle-version: '8.6' + - name: Get the Latest OnnxRuntime Nightly Version shell: pwsh run: | @@ -96,6 +108,11 @@ jobs: cd test/csharp dotnet test /p:Configuration=Release /p:NativeBuildOutputDir="../../build/cpu/" /p:OrtLibDir="../../ort/lib/" --verbosity normal + - name: Build the Java API and Run the Java Tests + run: | + set -e -x + python3 build.py --config=Release --build_dir build/cpu --build_java --parallel --cmake_generator "Ninja" + - name: Run tests run: | set -e -x diff --git a/.github/workflows/mac-cpu-arm64-build.yml b/.github/workflows/mac-cpu-arm64-build.yml index 2a019a2ca..b675450ab 100644 --- a/.github/workflows/mac-cpu-arm64-build.yml +++ b/.github/workflows/mac-cpu-arm64-build.yml @@ -25,6 +25,18 @@ jobs: with: python-version: '3.12.x' + - 
name: Setup Java 17 + uses: actions/setup-java@v4 + with: + java-version: '17' + distribution: 'temurin' + cache: 'gradle' + + - name: Setup Gradle + uses: gradle/actions/setup-gradle@v3 + with: + gradle-version: '8.6' + - name: Get the Latest OnnxRuntime Nightly Version run: | ORT_NIGHTLY_VERSION=$(curl -s "${{ env.ORT_NIGHTLY_REST_API }}" | jq -r '.value[0].versions[0].normalizedVersion') @@ -76,6 +88,7 @@ jobs: source genai-macos-venv/bin/activate export HF_TOKEN="12345" export ORTGENAI_LOG_ORT_LIB=1 + python3 -m pip install requests python3 test/python/test_onnxruntime_genai.py --cwd test/python --test_models test/test_models - name: Build the C# API and Run the C# Tests @@ -84,6 +97,12 @@ jobs: cd test/csharp dotnet test /p:Configuration=Release /p:NativeBuildOutputDir="../../build/cpu/osx-arm64" --verbosity normal + - name: Build the Java API and Run the Java Tests + run: | + set -e -x + source genai-macos-venv/bin/activate + python3 build.py --config=Release --build_dir build/cpu/osx-arm64 --build_java --parallel --cmake_generator "Unix Makefiles" + - name: Run tests run: | set -e -x diff --git a/.github/workflows/win-cpu-arm64-build.yml b/.github/workflows/win-cpu-arm64-build.yml index 9450e10df..b6b20cfc6 100644 --- a/.github/workflows/win-cpu-arm64-build.yml +++ b/.github/workflows/win-cpu-arm64-build.yml @@ -34,6 +34,18 @@ jobs: with: nuget-version: '5.x' + - name: Setup Java 21 + uses: actions/setup-java@v4 + with: + java-version: '21' + distribution: 'temurin' + cache: 'gradle' + + - name: Setup Gradle + uses: gradle/actions/setup-gradle@v3 + with: + gradle-version: '8.6' + - name: Download OnnxRuntime Nightly shell: powershell run: | @@ -54,18 +66,14 @@ jobs: - name: Configure CMake run: | - python -m pip install wheel + python -m pip install wheel requests + cmake --preset windows_arm64_cpu_release - name: Build with CMake run: | cmake --build --preset windows_arm64_cpu_release --parallel - - name: Build the C# API and Run the C# Tests - run: | - cd 
test\csharp - dotnet test /p:NativeBuildOutputDir="$env:GITHUB_WORKSPACE\$env:binaryDir\Release" /p:OrtLibDir="$env:GITHUB_WORKSPACE\ort\lib" - - name: Install the Python Wheel and Test Dependencies run: | python -m pip install "numpy<2" coloredlogs flatbuffers packaging protobuf sympy pytest @@ -76,6 +84,15 @@ jobs: run: | python test/python/test_onnxruntime_genai.py --cwd "test\python" --test_models "test\test_models" + - name: Build the C# API and Run the C# Tests + run: | + cd test\csharp + dotnet test /p:NativeBuildOutputDir="$env:GITHUB_WORKSPACE\$env:binaryDir\Release" /p:OrtLibDir="$env:GITHUB_WORKSPACE\ort\lib" + + - name: Build the Java API and Run the Java Tests + run: | + python build.py --config=Release --build_dir $env:binaryDir --build_java --parallel + - name: Verify Build Artifacts if: always() continue-on-error: true diff --git a/.github/workflows/win-cpu-x64-build.yml b/.github/workflows/win-cpu-x64-build.yml index d6aeaed85..3374a3b6d 100644 --- a/.github/workflows/win-cpu-x64-build.yml +++ b/.github/workflows/win-cpu-x64-build.yml @@ -41,6 +41,18 @@ jobs: with: dotnet-version: '8.0.x' + - name: Setup Java 17 + uses: actions/setup-java@v4 + with: + java-version: '17' + distribution: 'temurin' + cache: 'gradle' + + - name: Setup Gradle + uses: gradle/actions/setup-gradle@v3 + with: + gradle-version: '8.6' + - name: Download OnnxRuntime Nightly shell: pwsh run: | @@ -92,6 +104,10 @@ jobs: cd test\csharp dotnet test /p:NativeBuildOutputDir="$env:GITHUB_WORKSPACE\$env:binaryDir\Release" /p:OrtLibDir="$env:GITHUB_WORKSPACE\ort\lib" --verbosity normal + - name: Build the Java API and Run the Java Tests + run: | + python3 build.py --config=Release --build_dir $env:binaryDir --build_java --parallel + - name: Verify Build Artifacts if: always() continue-on-error: true diff --git a/.gitignore b/.gitignore index 5ee2887de..3b68e21f3 100644 --- a/.gitignore +++ b/.gitignore @@ -35,3 +35,5 @@ examples/csharp/HelloPhi/models /src/java/.gradle 
/src/java/local.properties /src/java/build +/src/java/CMakeFiles +/src/java/CMakeCache.txt \ No newline at end of file diff --git a/build.py b/build.py index d8fb42867..8a750d776 100644 --- a/build.py +++ b/build.py @@ -229,11 +229,13 @@ def _validate_build_dir(args: argparse.Namespace): # also tweak build directory name for mac builds target_sys = "macOS" - args.build_dir = Path("build") / target_sys + # set to a config specific build dir if no build_dir specified from command arguments + args.build_dir = Path("build") / target_sys / args.config - # set to a config specific build dir. it should exist unless we're creating the cmake setup is_strict = not args.update - args.build_dir = args.build_dir.resolve(strict=is_strict) / args.config + # Use user-specified build_dir and ignore args.config + # This is to better accommodate the existing cmake presets which can uses arbitrary paths. + args.build_dir = args.build_dir.resolve(strict=is_strict) def _validate_cuda_args(args: argparse.Namespace): @@ -453,7 +455,7 @@ def update(args: argparse.Namespace, env: dict[str, str]): is_x64_host = platform.machine() == "AMD64" if is_x64_host: - toolset_options += ["host=x64"] + pass if args.use_cuda: toolset_options += ["cuda=" + str(args.cuda_home)] @@ -624,6 +626,10 @@ def test(args: argparse.Namespace, env: dict[str, str]): csharp_test_command += _get_csharp_properties(args, ort_lib_dir=lib_dir) util.run(csharp_test_command, env=env, cwd=str(REPO_ROOT / "test" / "csharp")) + if args.build_java: + ctest_cmd = [str(args.ctest_path), "--build-config", args.config, "--verbose", "--timeout", "10800"] + util.run(ctest_cmd, cwd=str(args.build_dir / "src" / "java")) + if args.android: _run_android_tests(args) diff --git a/src/java/CMakeLists.txt b/src/java/CMakeLists.txt index f2f114a98..cf6187449 100644 --- a/src/java/CMakeLists.txt +++ b/src/java/CMakeLists.txt @@ -13,9 +13,9 @@ endif() set(JAVA_SRC_ROOT ${CMAKE_CURRENT_SOURCE_DIR}) # /src/java (path used with add_subdirectory in 
root CMakeLists.txt) -set(JAVA_OUTPUT_DIR ${CMAKE_CURRENT_BINARY_DIR}) +set(JAVA_OUTPUT_DIR ${CMAKE_CURRENT_BINARY_DIR}) -# Should we use onnxruntime-genai or onnxruntime-genai-static? Using onnxruntime-genai for now. +# Should we use onnxruntime-genai or onnxruntime-genai-static? Using onnxruntime-genai for now. # Add dependency on native target set(JAVA_DEPENDS onnxruntime-genai) @@ -34,20 +34,20 @@ elseif (ANDROID) endif() # this jar is solely used to signaling mechanism for dependency management in CMake -# if any of the Java sources change, the jar (and generated headers) will be regenerated +# if any of the Java sources change, the jar (and generated headers) will be regenerated # and the onnxruntime-genai-jni target will be rebuilt set(JAVA_OUTPUT_JAR ${JAVA_OUTPUT_DIR}/build/libs/onnxruntime-genai.jar) set(GRADLE_ARGS clean jar -x test) # this jar is solely used to signaling mechanism for dependency management in CMake -# if any of the Java sources change, the jar (and generated headers) will be regenerated +# if any of the Java sources change, the jar (and generated headers) will be regenerated # and the onnxruntime-genai-jni target will be rebuilt set(JAVA_OUTPUT_JAR ${JAVA_SRC_ROOT}/build/libs/onnxruntime-genai.jar) set(GRADLE_ARGS clean jar -x test) -add_custom_command(OUTPUT ${JAVA_OUTPUT_JAR} - COMMAND ${GRADLE_EXECUTABLE} ${GRADLE_OPTIONS} ${GRADLE_ARGS} - WORKING_DIRECTORY ${JAVA_SRC_ROOT} +add_custom_command(OUTPUT ${JAVA_OUTPUT_JAR} + COMMAND ${GRADLE_EXECUTABLE} ${GRADLE_OPTIONS} ${GRADLE_ARGS} + WORKING_DIRECTORY ${JAVA_SRC_ROOT} DEPENDS ${genai4j_gradle_files} ${genai4j_srcs}) add_custom_target(onnxruntime-genai4j DEPENDS ${JAVA_OUTPUT_JAR}) @@ -60,13 +60,13 @@ file(GLOB genai4j_native_src "${JAVA_SRC_ROOT}/src/main/native/*.h" "${SRC_ROOT}/ort_genai_c.h" ) - + add_library(onnxruntime-genai-jni SHARED ${genai4j_native_src}) set_property(TARGET onnxruntime-genai-jni PROPERTY CXX_STANDARD 17) add_dependencies(onnxruntime-genai-jni 
onnxruntime-genai4j) # the JNI headers are generated in the genai4j target target_include_directories(onnxruntime-genai-jni PRIVATE ${SRC_ROOT} - ${JAVA_SRC_ROOT}/build/headers + ${JAVA_SRC_ROOT}/build/headers ${JNI_INCLUDE_DIRS}) target_link_libraries(onnxruntime-genai-jni PUBLIC onnxruntime-genai) @@ -104,13 +104,13 @@ file(MAKE_DIRECTORY ${JAVA_PACKAGE_LIB_DIR}) # Add the native genai library to the native-lib dir add_custom_command(TARGET onnxruntime-genai-jni POST_BUILD - COMMAND ${CMAKE_COMMAND} -E copy_if_different - $ + COMMAND ${CMAKE_COMMAND} -E copy_if_different + $ ${JAVA_PACKAGE_LIB_DIR}/$) # Add the JNI bindings to the native-jni dir add_custom_command(TARGET onnxruntime-genai-jni POST_BUILD - COMMAND ${CMAKE_COMMAND} -E copy_if_different + COMMAND ${CMAKE_COMMAND} -E copy_if_different $ ${JAVA_PACKAGE_LIB_DIR}/$) @@ -140,21 +140,21 @@ if (ANDROID) file(MAKE_DIRECTORY ${ANDROID_PACKAGE_HEADERS_DIR}) file(MAKE_DIRECTORY ${ANDROID_PACKAGE_ABI_DIR}) - # copy C/C++ API headers to be packed into Android AAR package - add_custom_command(TARGET onnxruntime-genai-jni POST_BUILD + # copy C/C++ API headers to be packed into Android AAR package + add_custom_command(TARGET onnxruntime-genai-jni POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different ${SRC_ROOT}/ort_genai.h ${ANDROID_PACKAGE_HEADERS_DIR}) - add_custom_command(TARGET onnxruntime-genai-jni POST_BUILD + add_custom_command(TARGET onnxruntime-genai-jni POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different ${SRC_ROOT}/ort_genai_c.h ${ANDROID_PACKAGE_HEADERS_DIR}) # Copy onnxruntime-genai.so and onnxruntime-genai-jni.so for building Android AAR package add_custom_command(TARGET onnxruntime-genai-jni POST_BUILD - COMMAND ${CMAKE_COMMAND} -E copy_if_different + COMMAND ${CMAKE_COMMAND} -E copy_if_different $ ${ANDROID_PACKAGE_ABI_DIR}/$) add_custom_command(TARGET onnxruntime-genai-jni POST_BUILD - COMMAND ${CMAKE_COMMAND} -E copy_if_different - $ + COMMAND ${CMAKE_COMMAND} -E copy_if_different + $ 
${ANDROID_PACKAGE_ABI_DIR}/$) # Generate the Android AAR package @@ -176,20 +176,20 @@ if (ANDROID) file(GLOB android_test_files "${ANDROID_TEST_SRC_ROOT}/*") file(COPY ${android_test_files} DESTINATION ${ANDROID_TEST_PACKAGE_DIR}) - set(ANDROID_TEST_PACKAGE_LIB_DIR ${ANDROID_TEST_PACKAGE_DIR}/app/libs) + set(ANDROID_TEST_PACKAGE_LIB_DIR ${ANDROID_TEST_PACKAGE_DIR}/app/libs) set(ANDROID_TEST_PACKAGE_APP_ASSETS_DIR ${ANDROID_TEST_PACKAGE_DIR}/app/src/main/assets) file(MAKE_DIRECTORY ${ANDROID_TEST_PACKAGE_LIB_DIR}) file(MAKE_DIRECTORY ${ANDROID_TEST_PACKAGE_APP_ASSETS_DIR}) # Copy the test model to the assets folder in the test app add_custom_command(TARGET onnxruntime-genai-jni POST_BUILD - COMMAND ${CMAKE_COMMAND} -E copy_directory_if_different + COMMAND ${CMAKE_COMMAND} -E copy_directory_if_different ${REPO_ROOT}/test/test_models/hf-internal-testing/tiny-random-gpt2-fp32 ${ANDROID_TEST_PACKAGE_APP_ASSETS_DIR}/model) # Copy the Android AAR package we built to the libs folder of our test app add_custom_command(TARGET onnxruntime-genai-jni POST_BUILD - COMMAND ${CMAKE_COMMAND} -E copy_if_different + COMMAND ${CMAKE_COMMAND} -E copy_if_different ${ANDROID_PACKAGE_OUTPUT_DIR}/outputs/aar/onnxruntime-genai-debug.aar ${ANDROID_TEST_PACKAGE_LIB_DIR}/onnxruntime-genai.aar) @@ -207,19 +207,38 @@ if (ENABLE_TESTS) # On windows ctest requires a test to be an .exe(.com) file # With gradle wrapper we get gradlew.bat. 
We delegate execution to a separate .cmake file # That can handle both .exe and .bat - add_test(NAME onnxruntime-genai4j_test - COMMAND ${CMAKE_COMMAND} - -DGRADLE_EXECUTABLE=${GRADLE_EXECUTABLE} - -DBIN_DIR=${JAVA_OUTPUT_DIR} + add_test(NAME onnxruntime-genai4j_test + COMMAND ${CMAKE_COMMAND} + -DGRADLE_EXECUTABLE=${GRADLE_EXECUTABLE} + -DBIN_DIR=${JAVA_OUTPUT_DIR} -DJAVA_SRC_ROOT=${JAVA_SRC_ROOT} -DJAVA_PACKAGE_LIB_DIR=${JAVA_PACKAGE_LIB_DIR} -P ${JAVA_SRC_ROOT}/windows-unittests.cmake) else() - add_test(NAME onnxruntime-genai4j_test - COMMAND ${GRADLE_EXECUTABLE} cmakeCheck + add_test(NAME onnxruntime-genai4j_test + COMMAND ${GRADLE_EXECUTABLE} cmakeCheck -DcmakeBuildDir=${JAVA_OUTPUT_DIR} -DnativeLibDir=${JAVA_PACKAGE_LIB_DIR} WORKING_DIRECTORY ${JAVA_SRC_ROOT}) endif() + if(WIN32) + set(ONNXRUNTIME_GENAI_DEPENDENCY "*.dll") + elseif(APPLE) + set(ONNXRUNTIME_GENAI_DEPENDENCY "*.dylib") + else() + set(ONNXRUNTIME_GENAI_DEPENDENCY "*.so") + endif() + + file(GLOB ort_native_libs "${ORT_LIB_DIR}/${ONNXRUNTIME_GENAI_DEPENDENCY}") + + # Copy ORT native libs for Java tests + foreach(LIB_FILE ${ort_native_libs}) + add_custom_command( + TARGET onnxruntime-genai-jni POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy + ${LIB_FILE} + ${JAVA_PACKAGE_LIB_DIR}/) + endforeach() + set_property(TEST onnxruntime-genai4j_test APPEND PROPERTY DEPENDS onnxruntime-genai-jni) endif() diff --git a/src/java/build-android.gradle b/src/java/build-android.gradle index 31a38b99e..3920e1307 100644 --- a/src/java/build-android.gradle +++ b/src/java/build-android.gradle @@ -119,6 +119,7 @@ artifacts { dependencies { testImplementation 'org.junit.jupiter:junit-jupiter-api:5.7.0' + testImplementation 'org.junit.platform:junit-platform-launcher:1.10.1' testRuntimeOnly 'org.junit.jupiter:junit-jupiter-engine:5.7.0' } diff --git a/src/java/build.gradle b/src/java/build.gradle index a601885e6..bb5c7719d 100644 --- a/src/java/build.gradle +++ b/src/java/build.gradle @@ -153,6 +153,7 @@ if 
(cmakeNativeLibDir != null) { dependencies { testImplementation 'org.junit.jupiter:junit-jupiter-api:5.9.2' + testImplementation 'org.junit.platform:junit-platform-launcher:1.10.1' testRuntimeOnly 'org.junit.jupiter:junit-jupiter-engine:5.9.2' } diff --git a/src/java/src/main/java/ai/onnxruntime/genai/Adapters.java b/src/java/src/main/java/ai/onnxruntime/genai/Adapters.java new file mode 100644 index 000000000..4f5f93b7a --- /dev/null +++ b/src/java/src/main/java/ai/onnxruntime/genai/Adapters.java @@ -0,0 +1,81 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. + */ +package ai.onnxruntime.genai; + +public final class Adapters implements AutoCloseable { + private long nativeHandle = 0; + + /** + * Constructs an Adapters object with the given model. + * + * @param model The model. + * @throws GenAIException If the call to the GenAI native API fails. + */ + public Adapters(Model model) throws GenAIException { + if (model.nativeHandle() == 0) { + throw new IllegalArgumentException("model has been freed and is invalid"); + } + + nativeHandle = createAdapters(model.nativeHandle()); + } + + /** + * Load an adapter from the specified path. + * + * @param adapterFilePath The path of the adapter. + * @param adapterName A unique user supplied adapter identifier. + * @throws GenAIException If the call to the GenAI native API fails. + */ + public void loadAdapters(String adapterFilePath, String adapterName) throws GenAIException { + if (nativeHandle == 0) { + throw new IllegalStateException("Instance has been freed and is invalid"); + } + + loadAdapter(nativeHandle, adapterFilePath, adapterName); + } + + /** + * Unload an adapter. + * + * @param adapterName A unique user supplied adapter identifier. + * @throws GenAIException If the call to the GenAI native API fails. 
+ */ + public void unloadAdapters(String adapterName) throws GenAIException { + if (nativeHandle == 0) { + throw new IllegalStateException("Instance has been freed and is invalid"); + } + + unloadAdapter(nativeHandle, adapterName); + } + + @Override + public void close() { + if (nativeHandle != 0) { + destroyAdapters(nativeHandle); + nativeHandle = 0; + } + } + + long nativeHandle() { + return nativeHandle; + } + + static { + try { + GenAI.init(); + } catch (Exception e) { + throw new RuntimeException("Failed to load onnxruntime-genai native libraries", e); + } + } + + private native long createAdapters(long modelHandle) throws GenAIException; + + private native void destroyAdapters(long nativeHandle); + + private native void loadAdapter(long nativeHandle, String adapterFilePath, String adapterName) + throws GenAIException; + + private native void unloadAdapter(long nativeHandle, String adapterName) throws GenAIException; +} diff --git a/src/java/src/main/java/ai/onnxruntime/genai/Config.java b/src/java/src/main/java/ai/onnxruntime/genai/Config.java index 44c823b63..857eebaa1 100644 --- a/src/java/src/main/java/ai/onnxruntime/genai/Config.java +++ b/src/java/src/main/java/ai/onnxruntime/genai/Config.java @@ -11,23 +11,23 @@ public Config(String modelPath) throws GenAIException { } public void clearProviders() { - if (nativeHandle == 0) { - throw new IllegalStateException("Instance has been freed and is invalid"); - } - clearProviders(nativeHandle); + if (nativeHandle == 0) { + throw new IllegalStateException("Instance has been freed and is invalid"); + } + clearProviders(nativeHandle); } public void appendProvider(String provider_name) { if (nativeHandle == 0) { - throw new IllegalStateException("Instance has been freed and is invalid"); + throw new IllegalStateException("Instance has been freed and is invalid"); } - appendProvider(nativeHandle, provider_name); + appendProvider(nativeHandle, provider_name); } public void setProviderOption(String provider_name, 
String option_name, String option_value) { if (nativeHandle == 0) { - throw new IllegalStateException("Instance has been freed and is invalid"); - } + throw new IllegalStateException("Instance has been freed and is invalid"); + } setProviderOption(nativeHandle, provider_name, option_name, option_value); } @@ -52,8 +52,13 @@ long nativeHandle() { } private native long createConfig(String modelPath) throws GenAIException; + private native void destroyConfig(long configHandle); + private native void clearProviders(long configHandle); + private native void appendProvider(long configHandle, String provider_name); - private native void setProviderOption(long configHandle, String provider_name, String option_name, String option_value); + + private native void setProviderOption( + long configHandle, String provider_name, String option_name, String option_value); } diff --git a/src/java/src/main/java/ai/onnxruntime/genai/GenAI.java b/src/java/src/main/java/ai/onnxruntime/genai/GenAI.java index 914935338..2d31fe196 100644 --- a/src/java/src/main/java/ai/onnxruntime/genai/GenAI.java +++ b/src/java/src/main/java/ai/onnxruntime/genai/GenAI.java @@ -67,6 +67,8 @@ static synchronized void init() throws IOException { } } + static native void shutdown(); + /* Computes and initializes OS_ARCH_STR (such as linux-x64) */ private static String initOsArch() { String detectedOS = null; diff --git a/src/java/src/main/java/ai/onnxruntime/genai/Generator.java b/src/java/src/main/java/ai/onnxruntime/genai/Generator.java index a38c59c14..79c3d8053 100644 --- a/src/java/src/main/java/ai/onnxruntime/genai/Generator.java +++ b/src/java/src/main/java/ai/onnxruntime/genai/Generator.java @@ -150,6 +150,32 @@ public int getLastTokenInSequence(long sequenceIndex) throws GenAIException { return getSequenceLastToken(nativeHandle, sequenceIndex); } + /** + * Fetches and returns the output tensor with the given name. + * + * @param name The name of the output needed. 
+ * @throws GenAIException If the call to the GenAI native API fails. + */ + public Tensor getOutput(String name) throws GenAIException { + long tensorHandle = getOutputNative(nativeHandle, name); + return new Tensor(tensorHandle); + } + + /** + * Activates one of the loaded adapters. + * + * @param adapters The Adapters container. + * @param adapterName The adapter name that was previously loaded. + * @throws GenAIException If the call to the GenAI native API fails. + */ + public void setActiveAdapter(Adapters adapters, String adapterName) throws GenAIException { + if (nativeHandle == 0) { + throw new IllegalStateException("Instance has been freed and is invalid"); + } + + setActiveAdapter(nativeHandle, adapters.nativeHandle(), adapterName); + } + /** Closes the Generator and releases any associated resources. */ @Override public void close() { @@ -191,7 +217,7 @@ private native long createGenerator(long modelHandle, long generatorParamsHandle private native void destroyGenerator(long nativeHandle); private native boolean isDone(long nativeHandle); - + private native void appendTokens(long nativeHandle, int[] tokens) throws GenAIException; private native void appendTokenSequences(long nativeHandle, long sequencesHandle) @@ -206,4 +232,9 @@ private native int[] getSequenceNative(long nativeHandle, long sequenceIndex) private native int getSequenceLastToken(long nativeHandle, long sequenceIndex) throws GenAIException; + + private native void setActiveAdapter( + long nativeHandle, long adaptersNativeHandle, String adapterName) throws GenAIException; + + private native long getOutputNative(long nativeHandle, String outputName) throws GenAIException; } diff --git a/src/java/src/main/java/ai/onnxruntime/genai/GeneratorParams.java b/src/java/src/main/java/ai/onnxruntime/genai/GeneratorParams.java index dbcabeff8..5f9cca786 100644 --- a/src/java/src/main/java/ai/onnxruntime/genai/GeneratorParams.java +++ b/src/java/src/main/java/ai/onnxruntime/genai/GeneratorParams.java 
@@ -5,7 +5,6 @@ package ai.onnxruntime.genai; import java.nio.ByteBuffer; -import java.nio.ByteOrder; /** * The `GeneratorParams` class represents the parameters used for generating sequences with a model. @@ -44,7 +43,7 @@ public void setSearchOption(String optionName, boolean value) throws GenAIExcept * * @param name Name of the model input the tensor will provide. * @param tensor Tensor to add. - * @throws GenAIException + * @throws GenAIException If the call to the GenAI native API fails. */ public void setInput(String name, Tensor tensor) throws GenAIException { if (nativeHandle == 0) { @@ -62,7 +61,7 @@ public void setInput(String name, Tensor tensor) throws GenAIException { * Add a NamedTensors as a model input. * * @param namedTensors NamedTensors to add. - * @throws GenAIException + * @throws GenAIException If the call to the GenAI native API fails. */ public void setInputs(NamedTensors namedTensors) throws GenAIException { if (nativeHandle == 0) { @@ -108,7 +107,6 @@ private native void setSearchOptionBool(long nativeHandle, String optionName, bo private native void setModelInput(long nativeHandle, String inputName, long tensorHandle) throws GenAIException; - - private native void setInputs(long nativeHandle, long namedTensorsHandle) - throws GenAIException; + + private native void setInputs(long nativeHandle, long namedTensorsHandle) throws GenAIException; } diff --git a/src/java/src/main/java/ai/onnxruntime/genai/Images.java b/src/java/src/main/java/ai/onnxruntime/genai/Images.java index 51eaa61ac..fe60c6ff5 100644 --- a/src/java/src/main/java/ai/onnxruntime/genai/Images.java +++ b/src/java/src/main/java/ai/onnxruntime/genai/Images.java @@ -3,34 +3,34 @@ */ package ai.onnxruntime.genai; -public class Images implements AutoCloseable{ - private long nativeHandle; +public class Images implements AutoCloseable { + private long nativeHandle; - public Images(String imagePath) throws GenAIException { - nativeHandle = loadImages(imagePath); - } + public 
Images(String imagePath) throws GenAIException { + nativeHandle = loadImages(imagePath); + } - @Override - public void close() { - if (nativeHandle != 0) { - destroyImages(nativeHandle); - nativeHandle = 0; - } + @Override + public void close() { + if (nativeHandle != 0) { + destroyImages(nativeHandle); + nativeHandle = 0; } + } - long nativeHandle() { - return nativeHandle; - } + long nativeHandle() { + return nativeHandle; + } - static { - try { - GenAI.init(); - } catch (Exception e) { - throw new RuntimeException("Failed to load onnxruntime-genai native libraries", e); - } + static { + try { + GenAI.init(); + } catch (Exception e) { + throw new RuntimeException("Failed to load onnxruntime-genai native libraries", e); } + } - private native long loadImages(String imagePath) throws GenAIException; + private native long loadImages(String imagePath) throws GenAIException; - private native void destroyImages(long imageshandle); + private native void destroyImages(long imageshandle); } diff --git a/src/java/src/main/java/ai/onnxruntime/genai/Model.java b/src/java/src/main/java/ai/onnxruntime/genai/Model.java index e7e8f4783..6500dd805 100644 --- a/src/java/src/main/java/ai/onnxruntime/genai/Model.java +++ b/src/java/src/main/java/ai/onnxruntime/genai/Model.java @@ -68,6 +68,7 @@ long nativeHandle() { } private native long createModel(String modelPath) throws GenAIException; + private native long createModelFromConfig(long configHandle) throws GenAIException; private native void destroyModel(long modelHandle); diff --git a/src/java/src/main/java/ai/onnxruntime/genai/MultiModalProcessor.java b/src/java/src/main/java/ai/onnxruntime/genai/MultiModalProcessor.java index 85670b71d..62c72a2db 100644 --- a/src/java/src/main/java/ai/onnxruntime/genai/MultiModalProcessor.java +++ b/src/java/src/main/java/ai/onnxruntime/genai/MultiModalProcessor.java @@ -3,7 +3,10 @@ */ package ai.onnxruntime.genai; -/** The MultiModalProcessor class is responsible for converting text/images 
into a NamedTensors list that can be fed into a Generator class instance. */ +/** + * The MultiModalProcessor class is responsible for converting text/images into a NamedTensors list + * that can be fed into a Generator class instance. + */ public class MultiModalProcessor implements AutoCloseable { private long nativeHandle; @@ -77,9 +80,11 @@ public void close() { private native void destroyMultiModalProcessor(long tokenizerHandle); - private native long processorProcessImages(long processorHandle, String prompt, long imagesHandle) throws GenAIException; + private native long processorProcessImages(long processorHandle, String prompt, long imagesHandle) + throws GenAIException; private native String processorDecode(long processorHandle, int[] sequence) throws GenAIException; - private native long createTokenizerStreamFromProcessor(long processorHandle) throws GenAIException; + private native long createTokenizerStreamFromProcessor(long processorHandle) + throws GenAIException; } diff --git a/src/java/src/main/java/ai/onnxruntime/genai/NamedTensors.java b/src/java/src/main/java/ai/onnxruntime/genai/NamedTensors.java index 5811e554f..b7b27f71d 100644 --- a/src/java/src/main/java/ai/onnxruntime/genai/NamedTensors.java +++ b/src/java/src/main/java/ai/onnxruntime/genai/NamedTensors.java @@ -3,32 +3,32 @@ */ package ai.onnxruntime.genai; -public class NamedTensors implements AutoCloseable{ - private long nativeHandle; +public class NamedTensors implements AutoCloseable { + private long nativeHandle; - public NamedTensors(long handle) { - nativeHandle = handle; - } + public NamedTensors(long handle) { + nativeHandle = handle; + } - @Override - public void close() { - if (nativeHandle != 0) { - destroyNamedTensors(nativeHandle); - nativeHandle = 0; - } + @Override + public void close() { + if (nativeHandle != 0) { + destroyNamedTensors(nativeHandle); + nativeHandle = 0; } + } - long nativeHandle() { - return nativeHandle; - } + long nativeHandle() { + return 
nativeHandle; + } - static { - try { - GenAI.init(); - } catch (Exception e) { - throw new RuntimeException("Failed to load onnxruntime-genai native libraries", e); - } + static { + try { + GenAI.init(); + } catch (Exception e) { + throw new RuntimeException("Failed to load onnxruntime-genai native libraries", e); } + } - private native void destroyNamedTensors(long handle); + private native void destroyNamedTensors(long handle); } diff --git a/src/java/src/main/java/ai/onnxruntime/genai/SimpleGenAI.java b/src/java/src/main/java/ai/onnxruntime/genai/SimpleGenAI.java index 124933b53..cfd230e7e 100644 --- a/src/java/src/main/java/ai/onnxruntime/genai/SimpleGenAI.java +++ b/src/java/src/main/java/ai/onnxruntime/genai/SimpleGenAI.java @@ -25,7 +25,7 @@ * Create a class that implements the TokenUpdateListener interface and provide an instance of that * class as the `listener` argument. */ -public class SimpleGenAI { +public class SimpleGenAI implements AutoCloseable { private Model model; private Tokenizer tokenizer; @@ -99,4 +99,16 @@ public String generate(GeneratorParams generatorParams, String prompt, Consumer< return result; } + + @Override + public void close() { + if (tokenizer != null) { + tokenizer.close(); + tokenizer = null; + } + if (model != null) { + model.close(); + model = null; + } + } } diff --git a/src/java/src/main/java/ai/onnxruntime/genai/Tensor.java b/src/java/src/main/java/ai/onnxruntime/genai/Tensor.java index d5946c23e..6c98ba488 100644 --- a/src/java/src/main/java/ai/onnxruntime/genai/Tensor.java +++ b/src/java/src/main/java/ai/onnxruntime/genai/Tensor.java @@ -40,7 +40,7 @@ public enum ElementType { * @param data The data for the Tensor. Must be a direct ByteBuffer with native byte order. * @param shape The shape of the Tensor. * @param elementType The type of elements in the Tensor. - * @throws GenAIException + * @throws GenAIException If the call to the GenAI native API fails. 
*/ public Tensor(ByteBuffer data, long[] shape, ElementType elementType) throws GenAIException { if (data == null || shape == null || elementType == ElementType.undefined) { @@ -62,11 +62,40 @@ public Tensor(ByteBuffer data, long[] shape, ElementType elementType) throws Gen this.elementType = elementType; this.shape = shape; - this.dataBuffer = data; // save a reference so the owning buffer will stay around. + this.dataBuffer = data; // save a reference so the owning buffer will stay around. nativeHandle = createTensor(data, shape, elementType.ordinal()); } + /** + * Construct a Tensor from a native handle. + * + * @param handle The native tensor handle. + */ + Tensor(long handle) { + nativeHandle = handle; + elementType = ElementType.values()[getTensorType(handle)]; + shape = getTensorShape(handle); + } + + /** + * Get the element type. + * + * @return The element type. + */ + public ElementType getType() { + return this.elementType; + } + + /** + * Get the tensor shape. + * + * @return The tensor shape. + */ + public long[] getShape() { + return this.shape; + } + @Override public void close() { if (nativeHandle != 0) { @@ -91,4 +120,8 @@ private native long createTensor(ByteBuffer data, long[] shape, int elementType) throws GenAIException; private native void destroyTensor(long tensorHandle); + + private native int getTensorType(long tensorHandle); + + private native long[] getTensorShape(long tensorHandle); } diff --git a/src/java/src/main/native/ai_onnxruntime_genai_Adapters.cpp b/src/java/src/main/native/ai_onnxruntime_genai_Adapters.cpp new file mode 100644 index 000000000..b905c2906 --- /dev/null +++ b/src/java/src/main/native/ai_onnxruntime_genai_Adapters.cpp @@ -0,0 +1,40 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License.
+ */ +#include "ai_onnxruntime_genai_Adapters.h" + +#include "ort_genai_c.h" +#include "utils.h" + +using namespace Helpers; + +JNIEXPORT jlong JNICALL +Java_ai_onnxruntime_genai_Adapters_createAdapters(JNIEnv* env, jobject thiz, jlong model_handle) { + const OgaModel* model = reinterpret_cast(model_handle); + OgaAdapters* adapters = nullptr; + if (ThrowIfError(env, OgaCreateAdapters(model, &adapters))) { + return 0; + } + + return reinterpret_cast(adapters); +} + +JNIEXPORT void JNICALL +Java_ai_onnxruntime_genai_Adapters_destroyAdapters(JNIEnv* env, jobject thiz, jlong native_handle) { + OgaDestroyAdapters(reinterpret_cast(native_handle)); +} + +JNIEXPORT void JNICALL +Java_ai_onnxruntime_genai_Adapters_loadAdapter(JNIEnv* env, jobject thiz, jlong native_handle, + jstring adapter_file_path, jstring adapter_name) { + CString file_path{env, adapter_file_path}; + CString name{env, adapter_name}; + ThrowIfError(env, OgaLoadAdapter(reinterpret_cast(native_handle), file_path, name)); +} + +JNIEXPORT void JNICALL +Java_ai_onnxruntime_genai_Adapters_unloadAdapter(JNIEnv* env, jobject thiz, jlong native_handle, jstring adapter_name) { + CString name{env, adapter_name}; + ThrowIfError(env, OgaUnloadAdapter(reinterpret_cast(native_handle), name)); +} diff --git a/src/java/src/main/native/ai_onnxruntime_genai_GenAI.cpp b/src/java/src/main/native/ai_onnxruntime_genai_GenAI.cpp new file mode 100644 index 000000000..87abd4aae --- /dev/null +++ b/src/java/src/main/native/ai_onnxruntime_genai_GenAI.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. 
+ */ +#include "ai_onnxruntime_genai_GenAI.h" + +#include "ort_genai_c.h" +#include "utils.h" + +using namespace Helpers; + +JNIEXPORT void JNICALL +Java_ai_onnxruntime_genai_GenAI_shutdown(JNIEnv* env, jclass cls) { + OgaShutdown(); +} \ No newline at end of file diff --git a/src/java/src/main/native/ai_onnxruntime_genai_Generator.cpp b/src/java/src/main/native/ai_onnxruntime_genai_Generator.cpp index c7b52839b..251e076f7 100644 --- a/src/java/src/main/native/ai_onnxruntime_genai_Generator.cpp +++ b/src/java/src/main/native/ai_onnxruntime_genai_Generator.cpp @@ -78,9 +78,9 @@ Java_ai_onnxruntime_genai_Generator_getSequenceNative(JNIEnv* env, jobject thiz, // as there's no 'destroy' function in GenAI C API for the tokens we assume the OgaGenerator owns the memory. // copy the tokens so there's no potential for Java code to write to it (values should be treated as const) // or attempt to access the memory after the OgaGenerator is destroyed. - jintArray java_int_array = env->NewIntArray(num_tokens); + jintArray java_int_array = env->NewIntArray(static_cast(num_tokens)); // jint is `long` on Windows and `int` on linux. 32-bit but requires reinterpret_cast. 
- env->SetIntArrayRegion(java_int_array, 0, num_tokens, reinterpret_cast(tokens)); + env->SetIntArrayRegion(java_int_array, 0, static_cast(num_tokens), reinterpret_cast(tokens)); return java_int_array; } @@ -99,3 +99,23 @@ Java_ai_onnxruntime_genai_Generator_getSequenceLastToken(JNIEnv* env, jobject th return jint(tokens[num_tokens - 1]); } + +JNIEXPORT void JNICALL +Java_ai_onnxruntime_genai_Generator_setActiveAdapter(JNIEnv* env, jobject thiz, jlong native_handle, + jlong adapters_native_handle, jstring adapter_name) { + CString name{env, adapter_name}; + ThrowIfError(env, OgaSetActiveAdapter(reinterpret_cast(native_handle), + reinterpret_cast(adapters_native_handle), + name)); +} + +JNIEXPORT jlong JNICALL +Java_ai_onnxruntime_genai_Generator_getOutputNative(JNIEnv* env, jobject thiz, jlong native_handle, + jstring output_name) { + OgaTensor* tensor = nullptr; + CString name{env, output_name}; + if (ThrowIfError(env, OgaGenerator_GetOutput(reinterpret_cast(native_handle), name, &tensor))) { + return 0; + } + return reinterpret_cast(tensor); +} diff --git a/src/java/src/main/native/ai_onnxruntime_genai_Sequences.cpp b/src/java/src/main/native/ai_onnxruntime_genai_Sequences.cpp index 0eff851a2..7f5126e63 100644 --- a/src/java/src/main/native/ai_onnxruntime_genai_Sequences.cpp +++ b/src/java/src/main/native/ai_onnxruntime_genai_Sequences.cpp @@ -34,8 +34,8 @@ Java_ai_onnxruntime_genai_Sequences_getSequenceNative(JNIEnv* env, jobject thiz, // copy the tokens so there's no potential for Java code to write to it (values should be treated as const), // or attempt to access the memory after the OgaSequences is destroyed. // note: jint is `long` on Windows and `int` on linux. both are 32-bit but require reinterpret_cast. 
- jintArray java_int_array = env->NewIntArray(num_tokens); - env->SetIntArrayRegion(java_int_array, 0, num_tokens, reinterpret_cast(tokens)); + jintArray java_int_array = env->NewIntArray(static_cast(num_tokens)); + env->SetIntArrayRegion(java_int_array, 0, static_cast(num_tokens), reinterpret_cast(tokens)); return java_int_array; } diff --git a/src/java/src/main/native/ai_onnxruntime_genai_Tensor.cpp b/src/java/src/main/native/ai_onnxruntime_genai_Tensor.cpp index 58d9aa9fe..12b07793d 100644 --- a/src/java/src/main/native/ai_onnxruntime_genai_Tensor.cpp +++ b/src/java/src/main/native/ai_onnxruntime_genai_Tensor.cpp @@ -7,6 +7,8 @@ #include "ort_genai_c.h" #include "utils.h" +#include + using namespace Helpers; /* @@ -18,7 +20,7 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_genai_Tensor_createTensor(JNIEnv* env, jobject thiz, jobject tensor_data, jlongArray shape_dims_in, jint element_type_in) { void* data = env->GetDirectBufferAddress(tensor_data); - const int64_t* shape_dims = env->GetLongArrayElements(shape_dims_in, /*isCopy*/ 0); + const int64_t* shape_dims = reinterpret_cast(env->GetLongArrayElements(shape_dims_in, /*isCopy*/ 0)); size_t shape_dims_count = env->GetArrayLength(shape_dims_in); OgaElementType element_type = static_cast(element_type_in); OgaTensor* tensor = nullptr; @@ -39,3 +41,31 @@ JNIEXPORT void JNICALL Java_ai_onnxruntime_genai_Tensor_destroyTensor(JNIEnv* env, jobject thiz, jlong native_handle) { OgaDestroyTensor(reinterpret_cast(native_handle)); } + +JNIEXPORT jint JNICALL +Java_ai_onnxruntime_genai_Tensor_getTensorType(JNIEnv* env, jobject thiz, jlong native_handle) { + OgaElementType type; + if (ThrowIfError(env, OgaTensorGetType(reinterpret_cast(native_handle), &type))) { + return 0; + } + return static_cast(type); +} + +JNIEXPORT jlongArray JNICALL +Java_ai_onnxruntime_genai_Tensor_getTensorShape(JNIEnv* env, jobject thiz, jlong native_handle) { + OgaTensor* tensor = reinterpret_cast(native_handle); + size_t size; + if 
(ThrowIfError(env, OgaTensorGetShapeRank(tensor, &size))) { + return nullptr; + } + std::vector shape(size); + if (ThrowIfError(env, OgaTensorGetShape(tensor, shape.data(), shape.size()))) { + return nullptr; + } + + jlongArray result; + result = env->NewLongArray(static_cast(size)); + static_assert(sizeof(jlong) == sizeof(int64_t)); + env->SetLongArrayRegion(result, 0, static_cast(size), reinterpret_cast(shape.data())); + return result; +} diff --git a/src/java/src/test/java/ai/onnxruntime/genai/GenAITestExecutionListener.java b/src/java/src/test/java/ai/onnxruntime/genai/GenAITestExecutionListener.java new file mode 100644 index 000000000..c643218d5 --- /dev/null +++ b/src/java/src/test/java/ai/onnxruntime/genai/GenAITestExecutionListener.java @@ -0,0 +1,10 @@ +package ai.onnxruntime.genai; + +import org.junit.platform.launcher.TestExecutionListener; +import org.junit.platform.launcher.TestPlan; + +public class GenAITestExecutionListener implements TestExecutionListener { + public void testPlanExecutionFinished(TestPlan testPlan) { + GenAI.shutdown(); + } +} diff --git a/src/java/src/test/java/ai/onnxruntime/genai/GenerationTest.java b/src/java/src/test/java/ai/onnxruntime/genai/GenerationTest.java index d9fdabf20..d67faeb36 100644 --- a/src/java/src/test/java/ai/onnxruntime/genai/GenerationTest.java +++ b/src/java/src/test/java/ai/onnxruntime/genai/GenerationTest.java @@ -4,10 +4,9 @@ */ package ai.onnxruntime.genai; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertTrue; -import java.io.File; import java.util.function.Consumer; import java.util.logging.Logger; import org.junit.jupiter.api.Test; @@ -25,17 +24,7 @@ public class GenerationTest { // phi-2 can be used in full end-to-end testing but needs to be manually downloaded. // it's also used this way in the C# unit tests. 
private static final String phi2ModelPath() { - String repoRoot = TestUtils.getRepoRoot(); - File f = new File(repoRoot + "examples/python/example-models/phi2-int4-cpu"); - - if (!f.exists()) { - logger.warning("phi2 model not found at: " + f.getPath()); - logger.warning( - "Please install as per https://github.com/microsoft/onnxruntime-genai/tree/rel-0.2.0/examples/csharp/HelloPhi2"); - return null; - } - - return f.getPath(); + return TestUtils.getFilePathFromResource("/phi-2/int4/cpu"); } @SuppressWarnings("unused") // Used in EnabledIf @@ -43,64 +32,122 @@ private static boolean havePhi2() { return phi2ModelPath() != null; } + @SuppressWarnings("unused") // Used in EnabledIf + private static boolean haveAdapters() { + return TestUtils.testAdapterTestModelPath() != null; + } + @Test @EnabledIf("havePhi2") public void testUsageNoListener() throws GenAIException { - SimpleGenAI generator = new SimpleGenAI(phi2ModelPath()); - GeneratorParams params = generator.createGeneratorParams(); - - String result = generator.generate(params, "What's 6 times 7?", null); - logger.info("Result: " + result); - assertTrue(result.indexOf("Answer: 42") != -1); + try (SimpleGenAI generator = new SimpleGenAI(phi2ModelPath()); + GeneratorParams params = generator.createGeneratorParams(); ) { + params.setSearchOption("max_length", 20); + String result = + generator.generate(params, TestUtils.applyPhi2ChatTemplate("What's 6 times 7?"), null); + logger.info("Result: " + result); + } } @Test @EnabledIf("havePhi2") public void testUsageWithListener() throws GenAIException { - SimpleGenAI generator = new SimpleGenAI(phi2ModelPath()); - GeneratorParams params = generator.createGeneratorParams(); - Consumer listener = token -> logger.info("onTokenGenerate: " + token); - String result = generator.generate(params, "What's 6 times 7?", listener); + try (SimpleGenAI generator = new SimpleGenAI(phi2ModelPath()); + GeneratorParams params = generator.createGeneratorParams(); ) { + 
params.setSearchOption("max_length", 20); + Consumer listener = token -> logger.info("onTokenGenerate: " + token); + String result = + generator.generate( + params, TestUtils.applyPhi2ChatTemplate("What's 6 times 7?"), listener); - logger.info("Result: " + result); - assertTrue(result.indexOf("Answer: 42") != -1); + logger.info("Result: " + result); + } + } + + @Test + @EnabledIf("haveAdapters") + public void testUsageWithAdapters() throws GenAIException { + try (Model model = new Model(TestUtils.testAdapterTestModelPath()); + Tokenizer tokenizer = model.createTokenizer()) { + String[] prompts = { + TestUtils.applyPhi2ChatTemplate("def is_prime(n):"), + TestUtils.applyPhi2ChatTemplate("def compute_gcd(x, y):"), + TestUtils.applyPhi2ChatTemplate("def binary_search(arr, x):"), + }; + + try (Sequences sequences = tokenizer.encodeBatch(prompts); + GeneratorParams params = model.createGeneratorParams()) { + params.setSearchOption("max_length", 200); + params.setSearchOption("batch_size", prompts.length); + + long[] outputShape; + + try (Generator generator = new Generator(model, params); ) { + generator.appendTokenSequences(sequences); + while (!generator.isDone()) { + generator.generateNextToken(); + } + + try (Tensor logits = generator.getOutput("logits")) { + outputShape = logits.getShape(); + assertEquals(logits.getType(), Tensor.ElementType.float32); + } + } + + try (Adapters adapters = new Adapters(model); + Generator generator = new Generator(model, params); ) { + generator.appendTokenSequences(sequences); + adapters.loadAdapters(TestUtils.testAdapterTestAdaptersPath(), "adapters_a_and_b"); + generator.setActiveAdapter(adapters, "adapters_a_and_b"); + while (!generator.isDone()) { + generator.generateNextToken(); + } + try (Tensor logits = generator.getOutput("logits")) { + assertEquals(logits.getType(), Tensor.ElementType.float32); + assertArrayEquals(outputShape, logits.getShape()); + } + } + } + } } @Test public void testWithInputIds() throws GenAIException { 
// test using the HF model. input id values must be < 1000 so we use manually created input. // Input/expected output copied from the C# unit tests - Config config = new Config(TestUtils.testModelPath()); - Model model = new Model(config); - GeneratorParams params = new GeneratorParams(model); - int batchSize = 2; - int sequenceLength = 4; - int maxLength = 10; - int[] inputIDs = - new int[] { - 0, 0, 0, 52, - 0, 0, 195, 731 - }; - - params.setSearchOption("max_length", maxLength); - params.setSearchOption("batch_size", batchSize); - - int[] expectedOutput = - new int[] { - 0, 0, 0, 52, 204, 204, 204, 204, 204, 204, - 0, 0, 195, 731, 731, 114, 114, 114, 114, 114 - }; - - Generator generator = new Generator(model, params); - generator.appendTokens(inputIDs); - while (!generator.isDone()) { - generator.generateNextToken(); - } - - for (int i = 0; i < batchSize; i++) { - int[] outputIds = generator.getSequence(i); - for (int j = 0; j < maxLength; j++) { - assertEquals(outputIds[j], expectedOutput[i * maxLength + j]); + try (Config config = new Config(TestUtils.tinyGpt2ModelPath()); + Model model = new Model(config); + GeneratorParams params = new GeneratorParams(model); ) { + int batchSize = 2; + int sequenceLength = 4; + int maxLength = 10; + int[] inputIDs = + new int[] { + 0, 0, 0, 52, + 0, 0, 195, 731 + }; + + params.setSearchOption("max_length", maxLength); + params.setSearchOption("batch_size", batchSize); + + int[] expectedOutput = + new int[] { + 0, 0, 0, 52, 204, 204, 204, 204, 204, 204, + 0, 0, 195, 731, 731, 114, 114, 114, 114, 114 + }; + + try (Generator generator = new Generator(model, params); ) { + generator.appendTokens(inputIDs); + while (!generator.isDone()) { + generator.generateNextToken(); + } + + for (int i = 0; i < batchSize; i++) { + int[] outputIds = generator.getSequence(i); + for (int j = 0; j < maxLength; j++) { + assertEquals(outputIds[j], expectedOutput[i * maxLength + j]); + } + } } } } diff --git 
a/src/java/src/test/java/ai/onnxruntime/genai/GeneratorParamsTest.java b/src/java/src/test/java/ai/onnxruntime/genai/GeneratorParamsTest.java index bb0f886a5..cd8698de3 100644 --- a/src/java/src/test/java/ai/onnxruntime/genai/GeneratorParamsTest.java +++ b/src/java/src/test/java/ai/onnxruntime/genai/GeneratorParamsTest.java @@ -13,17 +13,19 @@ public class GeneratorParamsTest { @Test public void testValidSearchOption() throws GenAIException { // test setting an invalid search option throws a GenAIException - SimpleGenAI generator = new SimpleGenAI(TestUtils.testModelPath()); - GeneratorParams params = generator.createGeneratorParams(); - params.setSearchOption("early_stopping", true); // boolean - params.setSearchOption("max_length", 20); // number + try (SimpleGenAI generator = new SimpleGenAI(TestUtils.tinyGpt2ModelPath()); + GeneratorParams params = generator.createGeneratorParams(); ) { + params.setSearchOption("early_stopping", true); // boolean + params.setSearchOption("max_length", 20); // number + } } @Test public void testInvalidSearchOption() throws GenAIException { // test setting an invalid search option throws a GenAIException - SimpleGenAI generator = new SimpleGenAI(TestUtils.testModelPath()); - GeneratorParams params = generator.createGeneratorParams(); - assertThrows(GenAIException.class, () -> params.setSearchOption("invalid", true)); + try (SimpleGenAI generator = new SimpleGenAI(TestUtils.tinyGpt2ModelPath()); + GeneratorParams params = generator.createGeneratorParams(); ) { + assertThrows(GenAIException.class, () -> params.setSearchOption("invalid", true)); + } } } diff --git a/src/java/src/test/java/ai/onnxruntime/genai/MultiModalProcessorTest.java b/src/java/src/test/java/ai/onnxruntime/genai/MultiModalProcessorTest.java index 4cdb0f84b..7f90692d2 100644 --- a/src/java/src/test/java/ai/onnxruntime/genai/MultiModalProcessorTest.java +++ b/src/java/src/test/java/ai/onnxruntime/genai/MultiModalProcessorTest.java @@ -4,33 +4,28 @@ */ package 
ai.onnxruntime.genai; -import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import java.util.logging.Logger; import org.junit.jupiter.api.Test; // NOTE: Typical usage is covered in GenerationTest.java so we are just filling test gaps here. public class MultiModalProcessorTest { - @Test - public void testBatchEncodeDecode() throws GenAIException { - try (Model model = new Model(TestUtils.testModelPath()); - MultiModalProcessor multiModalProcessor = new MultiModalProcessor(model)) { - TokenizerStream stream = multiModalProcessor.createStream(); - GeneratorParams generatorParams = model.createGeneratorParams(); - String inputs = new String("This is a test"); - Images image = new Images("/src/java/src/test/java/ai/onnxruntime/genai/landscape.jpg"); - NamedTensors processed = multiModalProcessor.processImages(inputs, image); - generatorParams.setInputs(processed); + private static final Logger logger = Logger.getLogger(MultiModalProcessorTest.class.getName()); - Generator generator = new Generator(model, generatorParams); - - String fullAnswer = new String(); - while (!generator.isDone()) { - generator.generateNextToken(); - - int token = generator.getLastTokenInSequence(0); - - fullAnswer += stream.decode(token); - } - } + @Test + public void testBatchEncodeDecode() throws GenAIException { + try (Model model = new Model(TestUtils.testVisionModelPath()); + MultiModalProcessor multiModalProcessor = new MultiModalProcessor(model); + TokenizerStream stream = multiModalProcessor.createStream(); + GeneratorParams generatorParams = model.createGeneratorParams()) { + String inputs = + new String( + "<|user|>\n<|image_1|>\n Can you convert the table to markdown format?\n<|end|>\n<|assistant|>\n"); + try (Images image = new Images(TestUtils.getFilePathFromResource("/images/sheet.png")); + NamedTensors processed = multiModalProcessor.processImages(inputs, image); ) { + assertNotNull(processed); + } } + } } 
diff --git a/src/java/src/test/java/ai/onnxruntime/genai/TensorTest.java b/src/java/src/test/java/ai/onnxruntime/genai/TensorTest.java index 60edc8439..07833c679 100644 --- a/src/java/src/test/java/ai/onnxruntime/genai/TensorTest.java +++ b/src/java/src/test/java/ai/onnxruntime/genai/TensorTest.java @@ -7,6 +7,7 @@ import static org.junit.jupiter.api.Assertions.assertThrows; import java.nio.ByteBuffer; +import java.nio.ByteOrder; import java.nio.FloatBuffer; import org.junit.jupiter.api.Test; @@ -15,20 +16,21 @@ public class TensorTest { @Test public void testAddTensorInput() throws GenAIException { // test setting an invalid search option throws a GenAIException - SimpleGenAI generator = new SimpleGenAI(TestUtils.testModelPath()); - GeneratorParams params = generator.createGeneratorParams(); - long[] shape = {2, 2}; - Tensor.ElementType elementType = Tensor.ElementType.float32; - ByteBuffer data = ByteBuffer.allocateDirect(4 * Float.BYTES); - - FloatBuffer floatBuffer = data.asFloatBuffer(); - floatBuffer.put(new float[] {1.0f, 2.0f, 3.0f, 4.0f}); - Tensor tensor = new Tensor(data, shape, elementType); - - // no error on setting. - // assuming there's an error on execution if an invalid input has been provided so the user is - // aware of the issue - params.setInput("unknown_value", tensor); + try (SimpleGenAI generator = new SimpleGenAI(TestUtils.tinyGpt2ModelPath()); + GeneratorParams params = generator.createGeneratorParams(); ) { + long[] shape = {2, 2}; + Tensor.ElementType elementType = Tensor.ElementType.float32; + ByteBuffer data = ByteBuffer.allocateDirect(4 * Float.BYTES).order(ByteOrder.nativeOrder()); + + FloatBuffer floatBuffer = data.asFloatBuffer(); + floatBuffer.put(new float[] {1.0f, 2.0f, 3.0f, 4.0f}); + try (Tensor tensor = new Tensor(data, shape, elementType)) { + // no error on setting. 
+ // assuming there's an error on execution if an invalid input has been provided so the user + // is aware of the issue + params.setInput("unknown_value", tensor); + } + } } @Test @@ -42,7 +44,7 @@ public void testInvalidParams() throws GenAIException { // missing data assertThrows(GenAIException.class, () -> new Tensor(null, shape, elementType)); - ByteBuffer data = ByteBuffer.allocateDirect(4 * Float.BYTES); + ByteBuffer data = ByteBuffer.allocateDirect(4 * Float.BYTES).order(ByteOrder.nativeOrder()); // missing shape assertThrows(GenAIException.class, () -> new Tensor(data, null, elementType)); diff --git a/src/java/src/test/java/ai/onnxruntime/genai/TestUtils.java b/src/java/src/test/java/ai/onnxruntime/genai/TestUtils.java index bc60b8dba..e84d9a507 100644 --- a/src/java/src/test/java/ai/onnxruntime/genai/TestUtils.java +++ b/src/java/src/test/java/ai/onnxruntime/genai/TestUtils.java @@ -11,16 +11,20 @@ public class TestUtils { private static final Logger logger = Logger.getLogger(TestUtils.class.getName()); - public static final String testModelPath() { - // get the resources directory from one of the classes - URL url = TestUtils.class.getResource("/hf-internal-testing/tiny-random-gpt2-fp32"); - if (url == null) { - logger.warning("Model not found at /hf-internal-testing/tiny-random-gpt2-fp32"); - return null; - } + public static final String testAdapterTestModelPath() { + return getFilePathFromResource("/adapters"); + } - File f = new File(url.getFile()); - return f.getPath(); + public static final String testAdapterTestAdaptersPath() { + return getFilePathFromResource("/adapters/adapters.onnx_adapter"); + } + + public static final String tinyGpt2ModelPath() { + return getFilePathFromResource("/hf-internal-testing/tiny-random-gpt2-fp32"); + } + + public static final String testVisionModelPath() { + return getFilePathFromResource("/vision-preprocessing"); } public static final String getRepoRoot() { @@ -43,4 +47,24 @@ public static final boolean 
setLocalNativeLibraryPath() { System.setProperty("onnxruntime-genai.native.path", fullPath.getPath()); return true; } + + public static final String getFilePathFromResource(String path) { + // get the resources directory from one of the classes + URL url = TestUtils.class.getResource(path); + if (url == null) { + logger.warning("Model not found at " + path); + return null; + } + + File f = new File(url.getFile()); + return f.getPath(); + } + + public static final String applyPhi2ChatTemplate(String question) { + return "User: " + question + "Assistant:"; + } + + public static final String applyPhi3ChatTemplate(String question) { + return "<|user|>" + question + "<|end|><|assistant|>"; + } } diff --git a/src/java/src/test/java/ai/onnxruntime/genai/TokenizerTest.java b/src/java/src/test/java/ai/onnxruntime/genai/TokenizerTest.java index 72aff5dcd..ca595fbf1 100644 --- a/src/java/src/test/java/ai/onnxruntime/genai/TokenizerTest.java +++ b/src/java/src/test/java/ai/onnxruntime/genai/TokenizerTest.java @@ -12,15 +12,16 @@ public class TokenizerTest { @Test public void testBatchEncodeDecode() throws GenAIException { - try (Model model = new Model(TestUtils.testModelPath()); + try (Model model = new Model(TestUtils.tinyGpt2ModelPath()); Tokenizer tokenizer = new Tokenizer(model)) { String[] inputs = new String[] {"This is a test", "This is another test"}; - Sequences encoded = tokenizer.encodeBatch(inputs); - String[] decoded = tokenizer.decodeBatch(encoded); + try (Sequences encoded = tokenizer.encodeBatch(inputs)) { + String[] decoded = tokenizer.decodeBatch(encoded); - assertEquals(inputs.length, decoded.length); - for (int i = 0; i < inputs.length; i++) { - assert inputs[i].equals(decoded[i]); + assertEquals(inputs.length, decoded.length); + for (int i = 0; i < inputs.length; i++) { + assert inputs[i].equals(decoded[i]); + } } } } diff --git a/src/java/src/test/resources/META-INF/services/org.junit.platform.launcher.TestExecutionListener 
b/src/java/src/test/resources/META-INF/services/org.junit.platform.launcher.TestExecutionListener new file mode 100644 index 000000000..34dece8a6 --- /dev/null +++ b/src/java/src/test/resources/META-INF/services/org.junit.platform.launcher.TestExecutionListener @@ -0,0 +1 @@ +ai.onnxruntime.genai.GenAITestExecutionListener \ No newline at end of file diff --git a/src/java/src/test/java/ai/onnxruntime/genai/landscape.jpg b/test/test_models/images/landscape.jpg similarity index 100% rename from src/java/src/test/java/ai/onnxruntime/genai/landscape.jpg rename to test/test_models/images/landscape.jpg