diff --git a/api/src/main/java/ai/djl/modality/nlp/translator/CrossEncoderServingTranslator.java b/api/src/main/java/ai/djl/modality/nlp/translator/CrossEncoderServingTranslator.java
new file mode 100644
index 000000000000..e62167a34b2d
--- /dev/null
+++ b/api/src/main/java/ai/djl/modality/nlp/translator/CrossEncoderServingTranslator.java
@@ -0,0 +1,115 @@
+/*
+ * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
+package ai.djl.modality.nlp.translator;
+
+import ai.djl.modality.Input;
+import ai.djl.modality.Output;
+import ai.djl.ndarray.BytesSupplier;
+import ai.djl.ndarray.NDList;
+import ai.djl.translate.Batchifier;
+import ai.djl.translate.NoBatchifyTranslator;
+import ai.djl.translate.TranslateException;
+import ai.djl.translate.Translator;
+import ai.djl.translate.TranslatorContext;
+import ai.djl.util.JsonUtils;
+import ai.djl.util.PairList;
+import ai.djl.util.StringPair;
+
+import com.google.gson.JsonElement;
+import com.google.gson.JsonParseException;
+
+/** A {@link Translator} that can handle generic cross encoder {@link Input} and {@link Output}. */
+public class CrossEncoderServingTranslator implements NoBatchifyTranslator {
+
+ private Translator translator;
+ private Translator batchTranslator;
+
+ /**
+ * Constructs a {@code CrossEncoderServingTranslator} instance.
+ *
+ * @param translator a {@code Translator} processes question answering input
+ */
+ public CrossEncoderServingTranslator(Translator translator) {
+ this.translator = translator;
+ this.batchTranslator = translator.toBatchTranslator();
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public void prepare(TranslatorContext ctx) throws Exception {
+ translator.prepare(ctx);
+ batchTranslator.prepare(ctx);
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public NDList processInput(TranslatorContext ctx, Input input) throws Exception {
+ PairList content = input.getContent();
+ if (content.isEmpty()) {
+ throw new TranslateException("Input data is empty.");
+ }
+
+ String contentType = input.getProperty("Content-Type", null);
+ StringPair pair;
+ if ("application/json".equals(contentType)) {
+ String json = input.getData().getAsString();
+ try {
+ JsonElement element = JsonUtils.GSON.fromJson(json, JsonElement.class);
+ if (element.isJsonArray()) {
+ ctx.setAttachment("batch", Boolean.TRUE);
+ StringPair[] inputs = JsonUtils.GSON.fromJson(json, StringPair[].class);
+ return batchTranslator.processInput(ctx, inputs);
+ }
+
+ pair = JsonUtils.GSON.fromJson(json, StringPair.class);
+ if (pair.getKey() == null || pair.getValue() == null) {
+ throw new TranslateException("Missing key or value in json.");
+ }
+ } catch (JsonParseException e) {
+ throw new TranslateException("Input is not a valid json.", e);
+ }
+ } else {
+ String key = input.getAsString("key");
+ String value = input.getAsString("value");
+ if (key == null || value == null) {
+ throw new TranslateException("Missing key or value in input.");
+ }
+ pair = new StringPair(key, value);
+ }
+
+ NDList ret = translator.processInput(ctx, pair);
+ Batchifier batchifier = translator.getBatchifier();
+ if (batchifier != null) {
+ NDList[] batch = {ret};
+ return batchifier.batchify(batch);
+ }
+ return ret;
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public Output processOutput(TranslatorContext ctx, NDList list) throws Exception {
+ Output output = new Output();
+ output.addProperty("Content-Type", "application/json");
+ if (ctx.getAttachment("batch") != null) {
+ output.add(BytesSupplier.wrapAsJson(batchTranslator.processOutput(ctx, list)));
+ } else {
+ Batchifier batchifier = translator.getBatchifier();
+ if (batchifier != null) {
+ list = batchifier.unbatchify(list)[0];
+ }
+ output.add(BytesSupplier.wrapAsJson(translator.processOutput(ctx, list)));
+ }
+ return output;
+ }
+}
diff --git a/api/src/main/java/ai/djl/util/StringPair.java b/api/src/main/java/ai/djl/util/StringPair.java
new file mode 100644
index 000000000000..a42e739614b5
--- /dev/null
+++ b/api/src/main/java/ai/djl/util/StringPair.java
@@ -0,0 +1,27 @@
+/*
+ * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
+package ai.djl.util;
+
+/** A class containing the string key-value pair. */
+public class StringPair extends Pair {
+
+ /**
+ * Constructs a {@code Pair} instance with key and value.
+ *
+ * @param key the key
+ * @param value the value
+ */
+ public StringPair(String key, String value) {
+ super(key, value);
+ }
+}
diff --git a/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/CrossEncoderBatchTranslator.java b/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/CrossEncoderBatchTranslator.java
new file mode 100644
index 000000000000..6f43c7cb480e
--- /dev/null
+++ b/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/CrossEncoderBatchTranslator.java
@@ -0,0 +1,69 @@
+/*
+ * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
+package ai.djl.huggingface.translator;
+
+import ai.djl.huggingface.tokenizers.Encoding;
+import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
+import ai.djl.ndarray.NDArray;
+import ai.djl.ndarray.NDList;
+import ai.djl.ndarray.NDManager;
+import ai.djl.translate.Batchifier;
+import ai.djl.translate.NoBatchifyTranslator;
+import ai.djl.translate.TranslateException;
+import ai.djl.translate.TranslatorContext;
+import ai.djl.util.PairList;
+import ai.djl.util.StringPair;
+
+import java.util.Arrays;
+
+/** The translator for Huggingface cross encoder model. */
+public class CrossEncoderBatchTranslator implements NoBatchifyTranslator {
+
+ private HuggingFaceTokenizer tokenizer;
+ private boolean includeTokenTypes;
+ private Batchifier batchifier;
+
+ CrossEncoderBatchTranslator(
+ HuggingFaceTokenizer tokenizer, boolean includeTokenTypes, Batchifier batchifier) {
+ this.tokenizer = tokenizer;
+ this.includeTokenTypes = includeTokenTypes;
+ this.batchifier = batchifier;
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public NDList processInput(TranslatorContext ctx, StringPair[] inputs)
+ throws TranslateException {
+ NDManager manager = ctx.getNDManager();
+ PairList list = new PairList<>(Arrays.asList(inputs));
+ Encoding[] encodings = tokenizer.batchEncode(list);
+ NDList[] batch = new NDList[encodings.length];
+ for (int i = 0; i < encodings.length; ++i) {
+ batch[i] = encodings[i].toNDList(manager, includeTokenTypes);
+ }
+ return batchifier.batchify(batch);
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public float[][] processOutput(TranslatorContext ctx, NDList list) {
+ NDList[] batch = batchifier.unbatchify(list);
+ float[][] ret = new float[batch.length][];
+ for (int i = 0; i < batch.length; ++i) {
+ NDArray logits = list.get(0);
+ NDArray result = logits.getNDArrayInternal().sigmoid();
+ ret[i] = result.toFloatArray();
+ }
+ return ret;
+ }
+}
diff --git a/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/CrossEncoderTranslator.java b/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/CrossEncoderTranslator.java
new file mode 100644
index 000000000000..b88347bc60ed
--- /dev/null
+++ b/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/CrossEncoderTranslator.java
@@ -0,0 +1,149 @@
+/*
+ * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
+package ai.djl.huggingface.translator;
+
+import ai.djl.huggingface.tokenizers.Encoding;
+import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
+import ai.djl.ndarray.NDArray;
+import ai.djl.ndarray.NDList;
+import ai.djl.translate.ArgumentsUtil;
+import ai.djl.translate.Batchifier;
+import ai.djl.translate.Translator;
+import ai.djl.translate.TranslatorContext;
+import ai.djl.util.StringPair;
+
+import java.io.IOException;
+import java.util.Map;
+
+/** The translator for Huggingface cross encoder model. */
+public class CrossEncoderTranslator implements Translator {
+
+ private HuggingFaceTokenizer tokenizer;
+ private boolean includeTokenTypes;
+ private Batchifier batchifier;
+
+ CrossEncoderTranslator(
+ HuggingFaceTokenizer tokenizer, boolean includeTokenTypes, Batchifier batchifier) {
+ this.tokenizer = tokenizer;
+ this.includeTokenTypes = includeTokenTypes;
+ this.batchifier = batchifier;
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public Batchifier getBatchifier() {
+ return batchifier;
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public NDList processInput(TranslatorContext ctx, StringPair input) {
+ Encoding encoding = tokenizer.encode(input.getKey(), input.getValue());
+ ctx.setAttachment("encoding", encoding);
+ return encoding.toNDList(ctx.getNDManager(), includeTokenTypes);
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public float[] processOutput(TranslatorContext ctx, NDList list) {
+ NDArray logits = list.get(0);
+ NDArray result = logits.getNDArrayInternal().sigmoid();
+ return result.toFloatArray();
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public CrossEncoderBatchTranslator toBatchTranslator(Batchifier batchifier) {
+ tokenizer.enableBatch();
+ return new CrossEncoderBatchTranslator(tokenizer, includeTokenTypes, batchifier);
+ }
+
+ /**
+ * Creates a builder to build a {@code CrossEncoderTranslator}.
+ *
+ * @param tokenizer the tokenizer
+ * @return a new builder
+ */
+ public static Builder builder(HuggingFaceTokenizer tokenizer) {
+ return new Builder(tokenizer);
+ }
+
+ /**
+ * Creates a builder to build a {@code CrossEncoderTranslator}.
+ *
+ * @param tokenizer the tokenizer
+ * @param arguments the models' arguments
+ * @return a new builder
+ */
+ public static Builder builder(HuggingFaceTokenizer tokenizer, Map arguments) {
+ Builder builder = builder(tokenizer);
+ builder.configure(arguments);
+
+ return builder;
+ }
+
+ /** The builder for question answering translator. */
+ public static final class Builder {
+
+ private HuggingFaceTokenizer tokenizer;
+ private boolean includeTokenTypes;
+ private Batchifier batchifier = Batchifier.STACK;
+
+ Builder(HuggingFaceTokenizer tokenizer) {
+ this.tokenizer = tokenizer;
+ }
+
+ /**
+ * Sets if include token types for the {@link Translator}.
+ *
+ * @param includeTokenTypes true to include token types
+ * @return this builder
+ */
+ public Builder optIncludeTokenTypes(boolean includeTokenTypes) {
+ this.includeTokenTypes = includeTokenTypes;
+ return this;
+ }
+
+ /**
+ * Sets the {@link Batchifier} for the {@link Translator}.
+ *
+ * @param batchifier true to include token types
+ * @return this builder
+ */
+ public Builder optBatchifier(Batchifier batchifier) {
+ this.batchifier = batchifier;
+ return this;
+ }
+
+ /**
+ * Configures the builder with the model arguments.
+ *
+ * @param arguments the model arguments
+ */
+ public void configure(Map arguments) {
+ optIncludeTokenTypes(ArgumentsUtil.booleanValue(arguments, "includeTokenTypes"));
+ String batchifierStr = ArgumentsUtil.stringValue(arguments, "batchifier", "stack");
+ optBatchifier(Batchifier.fromString(batchifierStr));
+ }
+
+ /**
+ * Builds the translator.
+ *
+ * @return the new translator
+ * @throws IOException if I/O error occurs
+ */
+ public CrossEncoderTranslator build() throws IOException {
+ return new CrossEncoderTranslator(tokenizer, includeTokenTypes, batchifier);
+ }
+ }
+}
diff --git a/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/CrossEncoderTranslatorFactory.java b/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/CrossEncoderTranslatorFactory.java
new file mode 100644
index 000000000000..f4f9af02c4ba
--- /dev/null
+++ b/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/CrossEncoderTranslatorFactory.java
@@ -0,0 +1,80 @@
+/*
+ * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
+package ai.djl.huggingface.translator;
+
+import ai.djl.Model;
+import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
+import ai.djl.modality.Input;
+import ai.djl.modality.Output;
+import ai.djl.modality.nlp.translator.CrossEncoderServingTranslator;
+import ai.djl.translate.TranslateException;
+import ai.djl.translate.Translator;
+import ai.djl.translate.TranslatorFactory;
+import ai.djl.util.Pair;
+import ai.djl.util.StringPair;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.lang.reflect.Type;
+import java.nio.file.Path;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+
+/** A {@link TranslatorFactory} that creates a {@link CrossEncoderTranslatorFactory} instance. */
+public class CrossEncoderTranslatorFactory implements TranslatorFactory, Serializable {
+
+ private static final long serialVersionUID = 1L;
+
+ private static final Set> SUPPORTED_TYPES = new HashSet<>();
+
+ static {
+ SUPPORTED_TYPES.add(new Pair<>(StringPair.class, float[].class));
+ SUPPORTED_TYPES.add(new Pair<>(StringPair[].class, float[][].class));
+ SUPPORTED_TYPES.add(new Pair<>(Input.class, Output.class));
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public Set> getSupportedTypes() {
+ return SUPPORTED_TYPES;
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ @SuppressWarnings("unchecked")
+ public Translator newInstance(
+ Class input, Class output, Model model, Map arguments)
+ throws TranslateException {
+ Path modelPath = model.getModelPath();
+ try {
+ HuggingFaceTokenizer tokenizer =
+ HuggingFaceTokenizer.builder(arguments)
+ .optTokenizerPath(modelPath)
+ .optManager(model.getNDManager())
+ .build();
+ CrossEncoderTranslator translator =
+ CrossEncoderTranslator.builder(tokenizer, arguments).build();
+ if (input == StringPair.class && output == float[].class) {
+ return (Translator) translator;
+ } else if (input == StringPair[].class && output == float[][].class) {
+ return (Translator) translator.toBatchTranslator();
+ } else if (input == Input.class && output == Output.class) {
+ return (Translator) new CrossEncoderServingTranslator(translator);
+ }
+ throw new IllegalArgumentException("Unsupported input/output types.");
+ } catch (IOException e) {
+ throw new TranslateException("Failed to load tokenizer.", e);
+ }
+ }
+}
diff --git a/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/CrossEncoderTranslatorTest.java b/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/CrossEncoderTranslatorTest.java
new file mode 100644
index 000000000000..f3ee102e3251
--- /dev/null
+++ b/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/CrossEncoderTranslatorTest.java
@@ -0,0 +1,204 @@
+/*
+ * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
+package ai.djl.huggingface.tokenizers;
+
+import ai.djl.Model;
+import ai.djl.ModelException;
+import ai.djl.huggingface.translator.CrossEncoderTranslatorFactory;
+import ai.djl.inference.Predictor;
+import ai.djl.modality.Input;
+import ai.djl.modality.Output;
+import ai.djl.ndarray.NDArray;
+import ai.djl.ndarray.NDList;
+import ai.djl.ndarray.NDManager;
+import ai.djl.nn.Block;
+import ai.djl.nn.LambdaBlock;
+import ai.djl.repository.zoo.Criteria;
+import ai.djl.repository.zoo.ZooModel;
+import ai.djl.translate.TranslateException;
+import ai.djl.util.JsonUtils;
+import ai.djl.util.StringPair;
+
+import org.testng.Assert;
+import org.testng.annotations.Test;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.HashMap;
+import java.util.Map;
+
+public class CrossEncoderTranslatorTest {
+
+ @Test
+ public void testCrossEncoderTranslator()
+ throws ModelException, IOException, TranslateException {
+ String text1 = "Sentence 1";
+ String text2 = "Sentence 2";
+ Block block =
+ new LambdaBlock(
+ a -> {
+ NDManager manager = a.getManager();
+ NDArray array = manager.create(new float[] {-0.7329f});
+ return new NDList(array);
+ },
+ "model");
+ Path modelDir = Paths.get("build/model");
+ Files.createDirectories(modelDir);
+
+ Criteria criteria =
+ Criteria.builder()
+ .setTypes(StringPair.class, float[].class)
+ .optModelPath(modelDir)
+ .optBlock(block)
+ .optEngine("PyTorch")
+ .optArgument("tokenizer", "bert-base-cased")
+ .optOption("hasParameter", "false")
+ .optTranslatorFactory(new CrossEncoderTranslatorFactory())
+ .build();
+
+ try (ZooModel model = criteria.loadModel();
+ Predictor predictor = model.newPredictor()) {
+ StringPair input = new StringPair(text1, text2);
+ float[] res = predictor.predict(input);
+ Assert.assertEquals(res[0], 0.32456556f, 0.0001);
+ }
+
+ Criteria criteria2 =
+ Criteria.builder()
+ .setTypes(Input.class, Output.class)
+ .optModelPath(modelDir)
+ .optBlock(block)
+ .optEngine("PyTorch")
+ .optArgument("tokenizer", "bert-base-cased")
+ .optOption("hasParameter", "false")
+ .optTranslatorFactory(new CrossEncoderTranslatorFactory())
+ .build();
+
+ try (ZooModel model = criteria2.loadModel();
+ Predictor predictor = model.newPredictor()) {
+ Input input = new Input();
+ input.add("key", text1);
+ input.add("value", text2);
+ Output res = predictor.predict(input);
+ float[] buf = (float[]) res.getData().getAsObject();
+ Assert.assertEquals(buf[0], 0.32455865, 0.0001);
+
+ Assert.assertThrows(TranslateException.class, () -> predictor.predict(new Input()));
+
+ Assert.assertThrows(
+ TranslateException.class,
+ () -> {
+ Input req = new Input();
+ req.add("something", "false");
+ predictor.predict(req);
+ });
+
+ Assert.assertThrows(
+ TranslateException.class,
+ () -> {
+ Input req = new Input();
+ req.addProperty("Content-Type", "application/json");
+ req.add("Invalid json");
+ predictor.predict(req);
+ });
+
+ Assert.assertThrows(
+ TranslateException.class,
+ () -> {
+ Input req = new Input();
+ req.addProperty("Content-Type", "application/json");
+ req.add(JsonUtils.GSON.toJson(new StringPair(text1, null)));
+ predictor.predict(req);
+ });
+ }
+
+ try (Model model = Model.newInstance("test")) {
+ model.setBlock(block);
+ Map options = new HashMap<>();
+ options.put("hasParameter", "false");
+ model.load(modelDir, "test", options);
+
+ CrossEncoderTranslatorFactory factory = new CrossEncoderTranslatorFactory();
+ Map arguments = new HashMap<>();
+
+ Assert.assertThrows(
+ TranslateException.class,
+ () -> factory.newInstance(String.class, Integer.class, model, arguments));
+
+ arguments.put("tokenizer", "bert-base-cased");
+
+ Assert.assertThrows(
+ IllegalArgumentException.class,
+ () -> factory.newInstance(String.class, Integer.class, model, arguments));
+ }
+ }
+
+ @Test
+ public void testCrossEncoderBatchTranslator()
+ throws ModelException, IOException, TranslateException {
+ StringPair pair1 = new StringPair("Sentence 1", "Sentence 2");
+ StringPair pair2 = new StringPair("Sentence 3", "Sentence 4");
+
+ Block block =
+ new LambdaBlock(
+ a -> {
+ NDManager manager = a.getManager();
+ NDArray array = manager.create(new float[][] {{-0.7329f}, {-0.7329f}});
+ return new NDList(array);
+ },
+ "model");
+ Path modelDir = Paths.get("build/model");
+ Files.createDirectories(modelDir);
+
+ Criteria criteria =
+ Criteria.builder()
+ .setTypes(StringPair[].class, float[][].class)
+ .optModelPath(modelDir)
+ .optBlock(block)
+ .optEngine("PyTorch")
+ .optArgument("tokenizer", "bert-base-cased")
+ .optOption("hasParameter", "false")
+ .optTranslatorFactory(new CrossEncoderTranslatorFactory())
+ .build();
+
+ try (ZooModel model = criteria.loadModel();
+ Predictor predictor = model.newPredictor()) {
+ StringPair[] inputs = {pair1, pair2};
+ float[][] res = predictor.predict(inputs);
+ Assert.assertEquals(res[1][0], 0.32455865, 0.0001);
+ }
+
+ Criteria criteria2 =
+ Criteria.builder()
+ .setTypes(Input.class, Output.class)
+ .optModelPath(modelDir)
+ .optBlock(block)
+ .optEngine("PyTorch")
+ .optArgument("tokenizer", "bert-base-cased")
+ .optOption("hasParameter", "false")
+ .optTranslatorFactory(new CrossEncoderTranslatorFactory())
+ .build();
+
+ try (ZooModel model = criteria2.loadModel();
+ Predictor predictor = model.newPredictor()) {
+ Input input = new Input();
+ input.add(JsonUtils.GSON.toJson(new StringPair[] {pair1, pair2}));
+ input.addProperty("Content-Type", "application/json");
+ Output out = predictor.predict(input);
+ float[][] buf = (float[][]) out.getData().getAsObject();
+ Assert.assertEquals(buf[0][0], 0.32455865, 0.0001);
+ }
+ }
+}