Roberta ONNX as classifier #3566
-
I am trying to use a RoBERTa model as a classifier. The model is https://huggingface.co/guishe/nuner-v1_orgs (which I use as the tokenizer), and the same model converted to ONNX (which I use for classification) is https://huggingface.co/protectai/guishe-nuner-v1_orgs-onnx

The code I run is:

```java
HuggingFaceTokenizer huggingFaceTokenizer = HuggingFaceTokenizer.newInstance("guishe/nuner-v1_orgs");

String modelLocation = Thread.currentThread()
        .getContextClassLoader()
        .getResource("model.onnx").toExternalForm();

Criteria<String, String> criteria = Criteria.builder()
        .setTypes(String.class, String.class)
        .optModelUrls(modelLocation)
        .optTranslator(new NoBatchifyTranslator<String, String>() {
            @Override
            public String processOutput(TranslatorContext translatorContext, NDList ndList) throws Exception {
                return ndList.getAsString();
            }

            @Override
            public NDList processInput(TranslatorContext ctx, String s) throws Exception {
                Encoding input = huggingFaceTokenizer.encode(s);
                ctx.setAttachment("encoding", input);
                return input.toNDList(ctx.getNDManager(), false);
            }
        })
        .optEngine("OnnxRuntime") // use the OnnxRuntime engine
        .build();

try (ZooModel<String, String> model = criteria.loadModel()) {
    System.out.println(model.newPredictor().predict("The CNN tv"));
}
```

When I run it, I get:

```
Unexpected input data type. Actual: (tensor(int32)) , expected: (tensor(int64))
```

The `Encoding` object returns longs, so I don't know where they are being cast to 32 bits. Any help is appreciated.
-
There is a bug introduced by this PR: #3468. For the time being, you need to use DJL 0.29.0 to work around the issue. I will create a PR to fix this bug.
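If you manage DJL versions through Maven, pinning everything to 0.29.0 is easiest via the DJL BOM. This is a sketch assuming the `ai.djl:bom` artifact, which DJL publishes; adjust to match how your build declares the DJL dependencies:

```xml
<!-- Pin all DJL artifacts to 0.29.0 until the fix lands -->
<dependencyManagement>
  <dependencies>
    <dependency>
      <groupId>ai.djl</groupId>
      <artifactId>bom</artifactId>
      <version>0.29.0</version>
      <type>pom</type>
      <scope>import</scope>
    </dependency>
  </dependencies>
</dependencyManagement>
```

With the BOM imported, the individual `ai.djl.*` dependencies (engine, tokenizers, OnnxRuntime) can omit explicit versions and will all resolve to 0.29.0.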
@lordofthejars I tried 0.29.0, and it works for me: