From 2ad90eaa3ed9b4c282160ebad938a0bd150ce9b0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vanja=20Radulovi=C4=87?= <72040772+VanjaRadulovic@users.noreply.github.com> Date: Tue, 28 Jan 2025 04:50:35 +0100 Subject: [PATCH] Expose OnnxRuntime's getMetadata() in DJL API (#3596) * Expose OnnxRuntime getMetadata() in DJL API * Revert "Expose OnnxRuntime getMetadata() in DJL API" This reverts commit 738a329bb27107fd247194ad3981292057cfaf68. * Expose OnnxRuntime getMetadata() in DJL API --------- Co-authored-by: Vanja Radulovic --- api/src/main/java/ai/djl/nn/Block.java | 11 +++++++++++ .../ai/djl/onnxruntime/engine/OrtSymbolBlock.java | 12 ++++++++++++ 2 files changed, 23 insertions(+) diff --git a/api/src/main/java/ai/djl/nn/Block.java b/api/src/main/java/ai/djl/nn/Block.java index 56c92a23e6b..14bd85735ab 100644 --- a/api/src/main/java/ai/djl/nn/Block.java +++ b/api/src/main/java/ai/djl/nn/Block.java @@ -26,6 +26,8 @@ import java.io.DataInputStream; import java.io.DataOutputStream; import java.io.IOException; +import java.util.Collections; +import java.util.Map; import java.util.function.Predicate; /** @@ -353,4 +355,13 @@ static void validateLayout(LayoutType[] expectedLayout, LayoutType[] actualLayou } } } + + /** + * Returns a map of all the custom metadata of the block. + * + * @return the map of {@link PairList} + */ + default Map getCustomMetadata() { + return Collections.emptyMap(); + } } diff --git a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtSymbolBlock.java b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtSymbolBlock.java index 4e8df210d40..9eb0f875140 100644 --- a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtSymbolBlock.java +++ b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtSymbolBlock.java @@ -26,6 +26,7 @@ import ai.djl.util.PairList; import ai.onnxruntime.OnnxJavaType; import ai.onnxruntime.OnnxMap; +import ai.onnxruntime.OnnxModelMetadata; import ai.onnxruntime.OnnxSequence; import ai.onnxruntime.OnnxTensor; import ai.onnxruntime.OnnxValue; @@ -128,6 +129,17 @@ public PairList describeInput() { return result; } + /** {@inheritDoc} */ + @Override + public Map getCustomMetadata() { + try { + OnnxModelMetadata modelMetadata = session.getMetadata(); + return modelMetadata.getCustomMetadata(); + } catch (OrtException e) { + throw new EngineException(e); + } + } + private NDList evaluateOutput(OrtSession.Result results) { NDList output = new NDList(); for (Map.Entry r : results) {