Skip to content

Commit

Permalink
[huggingface] Adds CrossEncoderTranslator
Browse files Browse the repository at this point in the history
  • Loading branch information
frankfliu committed Oct 22, 2023
1 parent 90059cd commit 0fbe2c1
Show file tree
Hide file tree
Showing 6 changed files with 644 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
/*
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.modality.nlp.translator;

import ai.djl.modality.Input;
import ai.djl.modality.Output;
import ai.djl.ndarray.BytesSupplier;
import ai.djl.ndarray.NDList;
import ai.djl.translate.Batchifier;
import ai.djl.translate.NoBatchifyTranslator;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import ai.djl.util.JsonUtils;
import ai.djl.util.PairList;
import ai.djl.util.StringPair;

import com.google.gson.JsonElement;
import com.google.gson.JsonParseException;

/**
 * A {@link Translator} that can handle generic cross encoder {@link Input} and {@link Output}.
 *
 * <p>The request is read in one of three forms:
 *
 * <ul>
 *   <li>a JSON array, deserialized to {@code StringPair[]} and scored by the batch translator;
 *   <li>a single JSON value, deserialized to one {@link StringPair};
 *   <li>any other content type, where the pair is read from the {@code key} and {@code value}
 *       entries of the {@link Input}.
 * </ul>
 *
 * <p>The response is always written as JSON.
 */
public class CrossEncoderServingTranslator implements NoBatchifyTranslator<Input, Output> {

    private final Translator<StringPair, float[]> translator;
    private final Translator<StringPair[], float[][]> batchTranslator;

    /**
     * Constructs a {@code CrossEncoderServingTranslator} instance.
     *
     * @param translator a {@code Translator} that processes a single cross encoder text pair; its
     *     batch counterpart is obtained through {@code toBatchTranslator()}
     */
    public CrossEncoderServingTranslator(Translator<StringPair, float[]> translator) {
        this.translator = translator;
        this.batchTranslator = translator.toBatchTranslator();
    }

    /** {@inheritDoc} */
    @Override
    public void prepare(TranslatorContext ctx) throws Exception {
        translator.prepare(ctx);
        batchTranslator.prepare(ctx);
    }

    /**
     * Converts the request into the model input.
     *
     * @param ctx the translator context; the {@code "batch"} attachment is set when the request
     *     body is a JSON array so that {@link #processOutput} picks the batch path
     * @param input the request
     * @return the model input
     * @throws TranslateException if the request is empty, is not valid JSON, or lacks a key or a
     *     value
     * @throws Exception if the wrapped translator fails
     */
    @Override
    public NDList processInput(TranslatorContext ctx, Input input) throws Exception {
        PairList<String, BytesSupplier> content = input.getContent();
        if (content.isEmpty()) {
            throw new TranslateException("Input data is empty.");
        }

        StringPair pair;
        if (isJsonContentType(input.getProperty("Content-Type", null))) {
            String json = input.getData().getAsString();
            try {
                JsonElement element = JsonUtils.GSON.fromJson(json, JsonElement.class);
                // Gson yields null for an empty document and JsonNull for a literal "null";
                // neither can be turned into a pair.
                if (element == null || element.isJsonNull()) {
                    throw new TranslateException("Input is not a valid json.");
                }
                if (element.isJsonArray()) {
                    // Reuse the parsed tree instead of parsing the text a second time.
                    StringPair[] inputs = JsonUtils.GSON.fromJson(element, StringPair[].class);
                    for (StringPair item : inputs) {
                        requireComplete(item, "Missing key or value in json.");
                    }
                    ctx.setAttachment("batch", Boolean.TRUE);
                    return batchTranslator.processInput(ctx, inputs);
                }

                pair = JsonUtils.GSON.fromJson(element, StringPair.class);
                requireComplete(pair, "Missing key or value in json.");
            } catch (JsonParseException e) {
                throw new TranslateException("Input is not a valid json.", e);
            }
        } else {
            String key = input.getAsString("key");
            String value = input.getAsString("value");
            if (key == null || value == null) {
                throw new TranslateException("Missing key or value in input.");
            }
            pair = new StringPair(key, value);
        }

        NDList ret = translator.processInput(ctx, pair);
        Batchifier batchifier = translator.getBatchifier();
        if (batchifier != null) {
            // This is a NoBatchifyTranslator, so the single item is batchified here.
            NDList[] batch = {ret};
            return batchifier.batchify(batch);
        }
        return ret;
    }

    /**
     * Converts the model output into a JSON response.
     *
     * @param ctx the translator context used by {@link #processInput} for the same request
     * @param list the model output
     * @return an {@link Output} whose {@code Content-Type} is {@code application/json}
     * @throws Exception if the wrapped translator fails
     */
    @Override
    public Output processOutput(TranslatorContext ctx, NDList list) throws Exception {
        Output output = new Output();
        output.addProperty("Content-Type", "application/json");
        if (ctx.getAttachment("batch") != null) {
            output.add(BytesSupplier.wrapAsJson(batchTranslator.processOutput(ctx, list)));
        } else {
            Batchifier batchifier = translator.getBatchifier();
            if (batchifier != null) {
                list = batchifier.unbatchify(list)[0];
            }
            output.add(BytesSupplier.wrapAsJson(translator.processOutput(ctx, list)));
        }
        return output;
    }

    /**
     * Returns whether the content type denotes JSON, ignoring case and any media type parameters
     * such as {@code ; charset=utf-8}.
     */
    private static boolean isJsonContentType(String contentType) {
        if (contentType == null) {
            return false;
        }
        int pos = contentType.indexOf(';');
        String mediaType = pos < 0 ? contentType : contentType.substring(0, pos);
        return "application/json".equalsIgnoreCase(mediaType.trim());
    }

    /** Rejects a pair that is absent or lacks either its key or its value. */
    private static void requireComplete(StringPair pair, String message)
            throws TranslateException {
        if (pair == null || pair.getKey() == null || pair.getValue() == null) {
            throw new TranslateException(message);
        }
    }
}
27 changes: 27 additions & 0 deletions api/src/main/java/ai/djl/util/StringPair.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
/*
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.util;

/** A {@link Pair} whose key and value are both {@code String}s. */
public class StringPair extends Pair<String, String> {

/**
* Constructs a {@code StringPair} instance with key and value.
*
* @param key the key
* @param value the value
*/
public StringPair(String key, String value) {
super(key, value);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
/*
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.huggingface.translator;

import ai.djl.huggingface.tokenizers.Encoding;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.translate.Batchifier;
import ai.djl.translate.NoBatchifyTranslator;
import ai.djl.translate.TranslateException;
import ai.djl.translate.TranslatorContext;
import ai.djl.util.PairList;
import ai.djl.util.StringPair;

import java.util.Arrays;

/**
 * The batch translator for Huggingface cross encoder model.
 *
 * <p>Encodes an array of {@link StringPair}s in one tokenizer call and returns one score array per
 * input pair.
 */
public class CrossEncoderBatchTranslator implements NoBatchifyTranslator<StringPair[], float[][]> {

    private final HuggingFaceTokenizer tokenizer;
    private final boolean includeTokenTypes;
    private final Batchifier batchifier;

    /**
     * Constructs a {@code CrossEncoderBatchTranslator} instance.
     *
     * @param tokenizer the tokenizer used to encode the pairs
     * @param includeTokenTypes true to include token types in the model input
     * @param batchifier the batchifier that combines and splits the per-item {@link NDList}s
     */
    CrossEncoderBatchTranslator(
            HuggingFaceTokenizer tokenizer, boolean includeTokenTypes, Batchifier batchifier) {
        this.tokenizer = tokenizer;
        this.includeTokenTypes = includeTokenTypes;
        this.batchifier = batchifier;
    }

    /**
     * Tokenizes all pairs and batchifies the resulting encodings.
     *
     * @param ctx the translator context that supplies the {@link NDManager}
     * @param inputs the text pairs to score
     * @return the batchified model input
     * @throws TranslateException declared by the translator contract
     */
    @Override
    public NDList processInput(TranslatorContext ctx, StringPair[] inputs)
            throws TranslateException {
        NDManager manager = ctx.getNDManager();
        PairList<String, String> list = new PairList<>(Arrays.asList(inputs));
        Encoding[] encodings = tokenizer.batchEncode(list);
        NDList[] batch = new NDList[encodings.length];
        for (int i = 0; i < encodings.length; ++i) {
            batch[i] = encodings[i].toNDList(manager, includeTokenTypes);
        }
        return batchifier.batchify(batch);
    }

    /**
     * Splits the batched model output and applies sigmoid to each item's logits.
     *
     * @param ctx the translator context
     * @param list the batched model output
     * @return one score array per input pair, in input order
     */
    @Override
    public float[][] processOutput(TranslatorContext ctx, NDList list) {
        NDList[] batch = batchifier.unbatchify(list);
        float[][] ret = new float[batch.length][];
        for (int i = 0; i < batch.length; ++i) {
            // Read the i-th unbatchified item; reading list.get(0) here would copy the logits
            // of the whole batch into every row.
            NDArray logits = batch[i].get(0);
            NDArray result = logits.getNDArrayInternal().sigmoid();
            ret[i] = result.toFloatArray();
        }
        return ret;
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
/*
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.huggingface.translator;

import ai.djl.huggingface.tokenizers.Encoding;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.translate.ArgumentsUtil;
import ai.djl.translate.Batchifier;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import ai.djl.util.StringPair;

import java.io.IOException;
import java.util.Map;

/**
 * A {@link Translator} that tokenizes a {@link StringPair} for a Huggingface cross encoder model
 * and converts the model logits into sigmoid scores.
 */
public class CrossEncoderTranslator implements Translator<StringPair, float[]> {

    private HuggingFaceTokenizer tokenizer;
    private boolean includeTokenTypes;
    private Batchifier batchifier;

    CrossEncoderTranslator(
            HuggingFaceTokenizer tokenizer, boolean includeTokenTypes, Batchifier batchifier) {
        this.batchifier = batchifier;
        this.includeTokenTypes = includeTokenTypes;
        this.tokenizer = tokenizer;
    }

    /** {@inheritDoc} */
    @Override
    public Batchifier getBatchifier() {
        return this.batchifier;
    }

    /**
     * Encodes the key and value of the pair together and stores the resulting {@link Encoding}
     * in the context under the {@code "encoding"} attachment.
     */
    @Override
    public NDList processInput(TranslatorContext ctx, StringPair input) {
        String first = input.getKey();
        String second = input.getValue();
        Encoding encoded = this.tokenizer.encode(first, second);
        ctx.setAttachment("encoding", encoded);
        return encoded.toNDList(ctx.getNDManager(), this.includeTokenTypes);
    }

    /** Applies sigmoid to the first array of the model output and returns it as floats. */
    @Override
    public float[] processOutput(TranslatorContext ctx, NDList list) {
        NDArray scores = list.get(0).getNDArrayInternal().sigmoid();
        return scores.toFloatArray();
    }

    /**
     * Switches the tokenizer of this translator into batch mode and returns a batch translator
     * that shares the same tokenizer.
     */
    @Override
    public CrossEncoderBatchTranslator toBatchTranslator(Batchifier batchifier) {
        this.tokenizer.enableBatch();
        return new CrossEncoderBatchTranslator(this.tokenizer, this.includeTokenTypes, batchifier);
    }

    /**
     * Creates a builder to build a {@code CrossEncoderTranslator}.
     *
     * @param tokenizer the tokenizer
     * @return a new builder
     */
    public static Builder builder(HuggingFaceTokenizer tokenizer) {
        return new Builder(tokenizer);
    }

    /**
     * Creates a builder to build a {@code CrossEncoderTranslator} and configures it with the
     * model arguments.
     *
     * @param tokenizer the tokenizer
     * @param arguments the model's arguments
     * @return a new, configured builder
     */
    public static Builder builder(HuggingFaceTokenizer tokenizer, Map<String, ?> arguments) {
        Builder configured = new Builder(tokenizer);
        configured.configure(arguments);
        return configured;
    }

    /** The builder for the cross encoder translator. */
    public static final class Builder {

        private HuggingFaceTokenizer tokenizer;
        private boolean includeTokenTypes;
        private Batchifier batchifier = Batchifier.STACK;

        Builder(HuggingFaceTokenizer tokenizer) {
            this.tokenizer = tokenizer;
        }

        /**
         * Sets whether token types are included for the {@link Translator}.
         *
         * @param includeTokenTypes true to include token types
         * @return this builder
         */
        public Builder optIncludeTokenTypes(boolean includeTokenTypes) {
            this.includeTokenTypes = includeTokenTypes;
            return this;
        }

        /**
         * Sets the {@link Batchifier} for the {@link Translator}.
         *
         * @param batchifier the batchifier to use; the builder starts with {@link
         *     Batchifier#STACK}
         * @return this builder
         */
        public Builder optBatchifier(Batchifier batchifier) {
            this.batchifier = batchifier;
            return this;
        }

        /**
         * Configures the builder with the model arguments.
         *
         * <p>Reads the {@code includeTokenTypes} and {@code batchifier} arguments; the latter
         * falls back to {@code "stack"} when absent.
         *
         * @param arguments the model arguments
         */
        public void configure(Map<String, ?> arguments) {
            this.includeTokenTypes = ArgumentsUtil.booleanValue(arguments, "includeTokenTypes");
            String batchifierName = ArgumentsUtil.stringValue(arguments, "batchifier", "stack");
            this.batchifier = Batchifier.fromString(batchifierName);
        }

        /**
         * Builds the translator.
         *
         * @return the new translator
         * @throws IOException declared by the signature; nothing in this method body performs I/O
         */
        public CrossEncoderTranslator build() throws IOException {
            return new CrossEncoderTranslator(
                    this.tokenizer, this.includeTokenTypes, this.batchifier);
        }
    }
}
Loading

0 comments on commit 0fbe2c1

Please sign in to comment.