Skip to content

Commit

Permalink
Expose OnnxRuntime's getMetadata() in DJL API (#3596)
Browse files Browse the repository at this point in the history
* Expose OnnxRuntime getMetadata() in DJL API

* Revert "Expose OnnxRuntime getMetadata() in DJL API"

This reverts commit 738a329.

* Expose OnnxRuntime getMetadata() in DJL API

---------

Co-authored-by: Vanja Radulovic <[email protected]>
  • Loading branch information
VanjaRadulovic and Vanja Radulovic authored Jan 28, 2025
1 parent d381c4d commit 2ad90ea
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 0 deletions.
11 changes: 11 additions & 0 deletions api/src/main/java/ai/djl/nn/Block.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.Collections;
import java.util.Map;
import java.util.function.Predicate;

/**
Expand Down Expand Up @@ -353,4 +355,13 @@ static void validateLayout(LayoutType[] expectedLayout, LayoutType[] actualLayou
}
}
}

/**
* Returns a map of all the custom metadata of the block.
*
* @return the map of {@link PairList}
*/
default Map<String, String> getCustomMetadata() {
return Collections.emptyMap();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import ai.djl.util.PairList;
import ai.onnxruntime.OnnxJavaType;
import ai.onnxruntime.OnnxMap;
import ai.onnxruntime.OnnxModelMetadata;
import ai.onnxruntime.OnnxSequence;
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OnnxValue;
Expand Down Expand Up @@ -128,6 +129,17 @@ public PairList<String, Shape> describeInput() {
return result;
}

/** {@inheritDoc} */
@Override
public Map<String, String> getCustomMetadata() {
try {
OnnxModelMetadata modelMetadata = session.getMetadata();
return modelMetadata.getCustomMetadata();
} catch (OrtException e) {
throw new EngineException(e);
}
}

private NDList evaluateOutput(OrtSession.Result results) {
NDList output = new NDList();
for (Map.Entry<String, OnnxValue> r : results) {
Expand Down

0 comments on commit 2ad90ea

Please sign in to comment.