BGE small model integration for text embeddings (#191)
* Google Java Format

* Google Java Format

* Google Java Format

* Added support for the bge-small-en model for embeddings

* Download the bge-small model from a URL

---------

Co-authored-by: github-actions <>
ArthSrivastava authored Aug 11, 2023
1 parent 663115c commit c075a61
Showing 8 changed files with 450 additions and 3 deletions.
4 changes: 2 additions & 2 deletions Examples/json/JsonFormat.java
@@ -111,7 +111,7 @@ public String extract(ArkRequest arkRequest) {
if (gptResponse == null || gptResponse.isEmpty()) {
System.out.println("ChatGptResponse is null. There was an error processing the request.");
return ("ChatGptResponse is empty. There was an error processing the request. Please try"
+ " again.");
+ " again.");
}

try {
@@ -255,7 +255,7 @@ public Object function(ArkRequest arkRequest) {
System.out.println(
"ChatGptResponse is null or empty. There was an error processing the request.");
return "ChatGptResponse is empty. There was an error processing the request. Please try"
+ " again.";
+ " again.";
} else {
try {
JsonNode jsonNode =
6 changes: 6 additions & 0 deletions FlySpring/edgechain-app/pom.xml
@@ -237,6 +237,12 @@
<artifactId>spring-boot-starter-test</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>ai.djl.onnxruntime</groupId>
<artifactId>onnxruntime-engine</artifactId>
<version>0.23.0</version>
<scope>runtime</scope>
</dependency>

</dependencies>

@@ -2,6 +2,7 @@

import com.edgechain.lib.embeddings.WordEmbeddings;
import com.edgechain.lib.endpoint.Endpoint;
import com.edgechain.lib.endpoint.impl.BgeSmallEndpoint;
import com.edgechain.lib.endpoint.impl.MiniLMEndpoint;
import com.edgechain.lib.endpoint.impl.OpenAiEndpoint;
import com.edgechain.lib.endpoint.impl.PostgresEndpoint;
@@ -49,8 +50,11 @@ public void upsert(String input) {
} else if (endpoint instanceof MiniLMEndpoint miniLMEndpoint) {
WordEmbeddings embeddings = miniLMEndpoint.embeddings(input, arkRequest);
this.postgresEndpoint.upsert(embeddings, this.filename, this.dimensions);
} else if (endpoint instanceof BgeSmallEndpoint bgeSmallEndpoint) {
WordEmbeddings embeddings = bgeSmallEndpoint.embeddings(input, arkRequest);
this.postgresEndpoint.upsert(embeddings, this.filename, this.dimensions);
} else
throw new RuntimeException(
"Invalid Endpoint; Only OpenAIEndpoint & MiniLMEndpoint are supported");
"Invalid Endpoint; Only OpenAIEndpoint, MiniLMEndpoint & BgeSmallEndpoint are supported");
}
}
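With this branch in place, the same upsert path that previously handled OpenAiEndpoint and MiniLMEndpoint also accepts a BgeSmallEndpoint, so embeddings can be produced locally and stored in Postgres. A minimal caller-side sketch of that flow; only the embeddings(...) and upsert(...) calls are taken from the diff, while the surrounding method, parameter types, and names are assumptions made for illustration:

// Hypothetical caller-side sketch: only embeddings(...) and upsert(...) appear in this
// commit; the enclosing method and its parameters are invented for the example.
static void upsertWithLocalBge(
    BgeSmallEndpoint bgeSmallEndpoint,
    PostgresEndpoint postgresEndpoint,
    String input,
    ArkRequest arkRequest,
    String filename,
    int dimensions) {
  WordEmbeddings embeddings = bgeSmallEndpoint.embeddings(input, arkRequest);
  postgresEndpoint.upsert(embeddings, filename, dimensions);
}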
@@ -0,0 +1,222 @@
package com.edgechain.lib.embeddings.bgeSmall;

import ai.djl.MalformedModelException;
import ai.djl.huggingface.tokenizers.Encoding;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import ai.djl.inference.Predictor;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.Batchifier;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import com.edgechain.lib.embeddings.bgeSmall.response.BgeSmallResponse;
import com.edgechain.lib.endpoint.impl.BgeSmallEndpoint;
import com.edgechain.lib.rxjava.transformer.observable.EdgeChain;
import io.reactivex.rxjava3.core.Observable;
import org.springframework.stereotype.Service;

import java.io.IOException;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.LinkedList;
import java.util.List;

@Service
public class BgeSmallClient {

private BgeSmallEndpoint endpoint;

private static volatile ZooModel<String, float[]> bgeSmallEn;

public BgeSmallEndpoint getEndpoint() {
return endpoint;
}

public void setEndpoint(BgeSmallEndpoint endpoint) {
this.endpoint = endpoint;
}

public EdgeChain<BgeSmallResponse> createEmbeddings(String input) {

return new EdgeChain<>(
Observable.create(
emitter -> {
// try-with-resources so the Predictor is released after each prediction
try (Predictor<String, float[]> predictor = loadSmallBgeEn().newPredictor()) {
float[] predict = predictor.predict(input);
List<Float> floatList = new LinkedList<>();
for (float v : predict) {
floatList.add(v);
}

emitter.onNext(new BgeSmallResponse(floatList));
emitter.onComplete();
} catch (final Exception e) {
emitter.onError(e);
}
}),
endpoint);
}

private ZooModel<String, float[]> loadSmallBgeEn() throws IOException {

ZooModel<String, float[]> r = bgeSmallEn;

if (r == null) {
// bgeSmallEn is static, so lock on the class rather than on this instance
synchronized (BgeSmallClient.class) {
r = bgeSmallEn;
if (r == null) {
Path path = Paths.get("./model");
HuggingFaceTokenizer tokenizer =
HuggingFaceTokenizer.builder()
.optTokenizerPath(path)
.optManager(NDManager.newBaseManager("PyTorch"))
.build();

MyTextEmbeddingTranslator translator =
new MyTextEmbeddingTranslator(tokenizer, Batchifier.STACK, "mean", true, true);

Criteria<String, float[]> criteria =
Criteria.builder()
.setTypes(String.class, float[].class)
.optModelPath(path)
.optEngine("OnnxRuntime")
.optTranslator(translator)
.optProgress(new ProgressBar())
.build();
try {
r = criteria.loadModel();
bgeSmallEn = r;
} catch (IOException | ModelNotFoundException | MalformedModelException e) {
e.printStackTrace();
throw new RuntimeException(e);
}
}
}
}
return r;
}
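
// Note: loadSmallBgeEn() above performs lazy, double-checked initialization so the
// ONNX model under ./model is loaded only once per JVM and then reused for every
// Predictor created by createEmbeddings().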

// Custom TextEmbeddingTranslator for the BGE-Small ONNX model
static final class MyTextEmbeddingTranslator implements Translator<String, float[]> {

private static final int[] AXIS = {0};

private HuggingFaceTokenizer tokenizer;
private Batchifier batchifier;
private boolean normalize;
private String pooling;
private boolean includeTokenTypes;

MyTextEmbeddingTranslator(
HuggingFaceTokenizer tokenizer,
Batchifier batchifier,
String pooling,
boolean normalize,
boolean includeTokenTypes) {
this.tokenizer = tokenizer;
this.batchifier = batchifier;
this.pooling = pooling;
this.normalize = normalize;
this.includeTokenTypes = includeTokenTypes;
}

/** {@inheritDoc} */
@Override
public Batchifier getBatchifier() {
return batchifier;
}

/** {@inheritDoc} */
@Override
public NDList processInput(TranslatorContext ctx, String input) {
Encoding encoding = tokenizer.encode(input);
ctx.setAttachment("encoding", encoding);
return encoding.toNDList(ctx.getNDManager(), includeTokenTypes);
}

/** {@inheritDoc} */
@Override
public float[] processOutput(TranslatorContext ctx, NDList list) {
Encoding encoding = (Encoding) ctx.getAttachment("encoding");
NDManager manager = ctx.getNDManager();
NDArray embeddings = processEmbedding(manager, list, encoding, pooling);
if (normalize) {
embeddings = embeddings.normalize(2, 0);
}

return embeddings.toFloatArray();
}

static NDArray processEmbedding(
NDManager manager, NDList list, Encoding encoding, String pooling) {
NDArray embedding = list.get("last_hidden_state");
if (embedding == null) {
// For ONNX models, the NDArray name is not present
embedding = list.head();
}
long[] attentionMask = encoding.getAttentionMask();
try (NDManager ptManager = NDManager.newBaseManager("PyTorch")) {
NDArray inputAttentionMask = ptManager.create(attentionMask).toType(DataType.FLOAT32, true);
switch (pooling) {
case "mean":
return meanPool(embedding, inputAttentionMask, false);
case "mean_sqrt_len":
return meanPool(embedding, inputAttentionMask, true);
case "max":
return maxPool(embedding, inputAttentionMask);
case "weightedmean":
return weightedMeanPool(embedding, inputAttentionMask);
case "cls":
return embedding.get(0);
default:
throw new AssertionError("Unexpected pooling mode: " + pooling);
}
}
}

private static NDArray meanPool(NDArray embeddings, NDArray attentionMask, boolean sqrt) {
long[] shape = embeddings.getShape().getShape();
attentionMask = attentionMask.expandDims(-1).broadcast(shape);
NDArray inputAttentionMaskSum = attentionMask.sum(AXIS);
NDArray clamp = inputAttentionMaskSum.clip(1e-9, 1e12);
NDArray prod = embeddings.mul(attentionMask);
NDArray sum = prod.sum(AXIS);
if (sqrt) {
return sum.div(clamp.sqrt());
}
return sum.div(clamp);
}

private static NDArray maxPool(NDArray embeddings, NDArray inputAttentionMask) {
long[] shape = embeddings.getShape().getShape();
inputAttentionMask = inputAttentionMask.expandDims(-1).broadcast(shape);
inputAttentionMask = inputAttentionMask.eq(0);
embeddings = embeddings.duplicate();
embeddings.set(inputAttentionMask, -1e9); // Set padding tokens to large negative value

return embeddings.max(AXIS, true);
}

private static NDArray weightedMeanPool(NDArray embeddings, NDArray attentionMask) {
long[] shape = embeddings.getShape().getShape();
NDArray weight = embeddings.getManager().arange(1, shape[0] + 1);
weight = weight.expandDims(-1).broadcast(shape);

attentionMask = attentionMask.expandDims(-1).broadcast(shape).mul(weight);
NDArray maskSum = attentionMask.sum(AXIS);
NDArray embeddingSum = embeddings.mul(attentionMask).sum(AXIS);
return embeddingSum.div(maskSum);
}
}
}
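
To make the pooling arithmetic concrete: in the default "mean" mode the translator multiplies each token embedding by its attention-mask value, sums over the token axis, and divides by the clamped count of attended tokens. Below is a plain-array sketch of the same computation, kept independent of DJL purely for illustration; the class name, array shapes, and demo values are invented for this example.

// Standalone illustration of the "mean" pooling implemented by meanPool above,
// re-stated with plain arrays (no DJL).
public class MeanPoolSketch {

  // embeddings: [numTokens][hiddenSize]; attentionMask: [numTokens] of 0/1 values.
  static float[] meanPool(float[][] embeddings, long[] attentionMask) {
    int hidden = embeddings[0].length;
    float[] sum = new float[hidden];
    float maskSum = 0f;
    for (int t = 0; t < embeddings.length; t++) {
      maskSum += attentionMask[t];
      for (int h = 0; h < hidden; h++) {
        sum[h] += embeddings[t][h] * attentionMask[t];
      }
    }
    // Mirrors clip(1e-9, 1e12) in the translator: avoid dividing by zero on an empty mask.
    maskSum = Math.max(maskSum, 1e-9f);
    for (int h = 0; h < hidden; h++) {
      sum[h] /= maskSum;
    }
    return sum;
  }

  public static void main(String[] args) {
    // Two real tokens plus one padding token (mask 0); hidden size 3 for readability.
    float[][] embeddings = {{1f, 2f, 3f}, {3f, 2f, 1f}, {9f, 9f, 9f}};
    long[] attentionMask = {1, 1, 0};
    System.out.println(java.util.Arrays.toString(meanPool(embeddings, attentionMask)));
    // Prints [2.0, 2.0, 2.0]: the padding row is masked out before averaging.
  }
}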

@@ -0,0 +1,22 @@
package com.edgechain.lib.embeddings.bgeSmall.response;

import java.util.List;

public class BgeSmallResponse {

private List<Float> embedding;

public BgeSmallResponse() {}

public BgeSmallResponse(List<Float> embedding) {
this.embedding = embedding;
}

public List<Float> getEmbedding() {
return embedding;
}

public void setEmbedding(List<Float> embedding) {
this.embedding = embedding;
}
}
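
Putting the two new classes together, a caller configures the client with a BgeSmallEndpoint and works with the resulting chain. Only setEndpoint, createEmbeddings, and getEmbedding come from this commit; how a BgeSmallEndpoint is constructed and how an EdgeChain is consumed are project-specific and are therefore only hinted at here as assumptions:

// Hypothetical wiring sketch; everything beyond setEndpoint, createEmbeddings, and
// getEmbedding is assumed for illustration.
static EdgeChain<BgeSmallResponse> embed(BgeSmallEndpoint bgeSmallEndpoint, String text) {
  BgeSmallClient client = new BgeSmallClient();
  client.setEndpoint(bgeSmallEndpoint);
  // The returned chain is consumed through EdgeChain's API (not part of this commit);
  // once a BgeSmallResponse is obtained, the vector is read with getEmbedding().
  return client.createEmbeddings(text);
}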