Skip to content

Commit

Permalink
Google Java Format
Browse files Browse the repository at this point in the history
  • Loading branch information
github-actions committed Aug 11, 2023
1 parent c075a61 commit 6c9ce6d
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 67 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ public void upsert(String input) {
} else if (endpoint instanceof MiniLMEndpoint miniLMEndpoint) {
WordEmbeddings embeddings = miniLMEndpoint.embeddings(input, arkRequest);
this.postgresEndpoint.upsert(embeddings, this.filename, this.dimensions);
} else if(endpoint instanceof BgeSmallEndpoint bgeSmallEndpoint) {
} else if (endpoint instanceof BgeSmallEndpoint bgeSmallEndpoint) {
WordEmbeddings embeddings = bgeSmallEndpoint.embeddings(input, arkRequest);
this.postgresEndpoint.upsert(embeddings, this.filename, this.dimensions);
} else
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,16 +48,15 @@ public EdgeChain<BgeSmallResponse> createEmbeddings(String input) {
Observable.create(
emitter -> {
try {
Predictor<String, float[]> predictor =
loadSmallBgeEn().newPredictor();
float[] predict = predictor.predict(input);
List<Float> floatList = new LinkedList<>();
for (float v : predict) {
floatList.add(v);
}

emitter.onNext(new BgeSmallResponse(floatList));
emitter.onComplete();
Predictor<String, float[]> predictor = loadSmallBgeEn().newPredictor();
float[] predict = predictor.predict(input);
List<Float> floatList = new LinkedList<>();
for (float v : predict) {
floatList.add(v);
}

emitter.onNext(new BgeSmallResponse(floatList));
emitter.onComplete();
} catch (final Exception e) {
emitter.onError(e);
}
Expand All @@ -75,22 +74,22 @@ private ZooModel<String, float[]> loadSmallBgeEn() throws IOException {
if (r == null) {
Path path = Paths.get("./model");
HuggingFaceTokenizer tokenizer =
HuggingFaceTokenizer.builder()
.optTokenizerPath(path)
.optManager(NDManager.newBaseManager("PyTorch"))
.build();
HuggingFaceTokenizer.builder()
.optTokenizerPath(path)
.optManager(NDManager.newBaseManager("PyTorch"))
.build();

MyTextEmbeddingTranslator translator =
new MyTextEmbeddingTranslator(tokenizer, Batchifier.STACK, "mean", true, true);
new MyTextEmbeddingTranslator(tokenizer, Batchifier.STACK, "mean", true, true);

Criteria<String, float[]> criteria =
Criteria.builder()
.setTypes(String.class, float[].class)
.optModelPath(path)
.optEngine("OnnxRuntime")
.optTranslator(translator)
.optProgress(new ProgressBar())
.build();
Criteria.builder()
.setTypes(String.class, float[].class)
.optModelPath(path)
.optEngine("OnnxRuntime")
.optTranslator(translator)
.optProgress(new ProgressBar())
.build();
try {
r = criteria.loadModel();
bgeSmallEn = r;
Expand All @@ -104,9 +103,7 @@ private ZooModel<String, float[]> loadSmallBgeEn() throws IOException {
return r;
}



//Custom TextEmbeddingTranslator for BGE-Small Onnx Model
// Custom TextEmbeddingTranslator for BGE-Small Onnx Model
static final class MyTextEmbeddingTranslator implements Translator<String, float[]> {

private static final int[] AXIS = {0};
Expand All @@ -118,11 +115,11 @@ static final class MyTextEmbeddingTranslator implements Translator<String, float
private boolean includeTokenTypes;

MyTextEmbeddingTranslator(
HuggingFaceTokenizer tokenizer,
Batchifier batchifier,
String pooling,
boolean normalize,
boolean includeTokenTypes) {
HuggingFaceTokenizer tokenizer,
Batchifier batchifier,
String pooling,
boolean normalize,
boolean includeTokenTypes) {
this.tokenizer = tokenizer;
this.batchifier = batchifier;
this.pooling = pooling;
Expand Down Expand Up @@ -158,7 +155,7 @@ public float[] processOutput(TranslatorContext ctx, NDList list) {
}

static NDArray processEmbedding(
NDManager manager, NDList list, Encoding encoding, String pooling) {
NDManager manager, NDList list, Encoding encoding, String pooling) {
NDArray embedding = list.get("last_hidden_state");
if (embedding == null) {
// For Onnx model, NDArray name is not present
Expand Down Expand Up @@ -219,4 +216,3 @@ private static NDArray weightedMeanPool(NDArray embeddings, NDArray attentionMas
}
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,9 @@ public BgeSmallEndpoint(String modelUrl, String tokenizerUrl) {
File modelFile = new File(MODEL_PATH);
File tokenizerFile = new File(TOKENIZER_PATH);

//check if the file already exists
if(!modelFile.exists()) downloadFile(modelUrl, MODEL_PATH);
if(!tokenizerFile.exists()) downloadFile(tokenizerUrl, TOKENIZER_PATH);
// check if the file already exists
if (!modelFile.exists()) downloadFile(modelUrl, MODEL_PATH);
if (!tokenizerFile.exists()) downloadFile(tokenizerUrl, TOKENIZER_PATH);
logger.info("Model downloaded successfully!");
}

Expand All @@ -63,7 +63,6 @@ public String getInput() {
return input;
}


public String getCallIdentifier() {
return callIdentifier;
}
Expand All @@ -87,6 +86,7 @@ public WordEmbeddings embeddings(String input, ArkRequest arkRequest) {
.map(m -> new WordEmbeddings(input, m.getEmbedding()))
.blockingGet();
}

private void downloadFile(String urlStr, String path) {

File modelFolderFile = new File(MODEL_FOLDER);
Expand All @@ -95,7 +95,6 @@ private void downloadFile(String urlStr, String path) {
modelFolderFile.mkdir();
}


ReadableByteChannel rbc = null;
FileOutputStream fos = null;
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,6 @@
import retrofit2.http.POST;

public interface BgeSmallService {
@POST(value = "bgeSmall")
Single<BgeSmallResponse> embeddings(@Body BgeSmallEndpoint bgeSmallEndpoint);
@POST(value = "bgeSmall")
Single<BgeSmallResponse> embeddings(@Body BgeSmallEndpoint bgeSmallEndpoint);
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,40 +23,39 @@
@RequestMapping(WebConfiguration.CONTEXT_PATH + "/bgeSmall")
public class BgeSmallController {

@Autowired
private BgeSmallClient bgeSmallClient;
@Autowired private BgeSmallClient bgeSmallClient;

@Autowired private EmbeddingLogService embeddingLogService;
@Autowired private EmbeddingLogService embeddingLogService;

@Autowired private Environment env;
@Autowired private Environment env;

@PostMapping
public Single<BgeSmallResponse> embeddings(@RequestBody BgeSmallEndpoint bgeSmallEndpoint) {
@PostMapping
public Single<BgeSmallResponse> embeddings(@RequestBody BgeSmallEndpoint bgeSmallEndpoint) {

this.bgeSmallClient.setEndpoint(bgeSmallEndpoint);
this.bgeSmallClient.setEndpoint(bgeSmallEndpoint);

EdgeChain<BgeSmallResponse> edgeChain =
this.bgeSmallClient.createEmbeddings(bgeSmallEndpoint.getInput());
EdgeChain<BgeSmallResponse> edgeChain =
this.bgeSmallClient.createEmbeddings(bgeSmallEndpoint.getInput());

if (Objects.nonNull(env.getProperty("postgres.db.host"))) {
if (Objects.nonNull(env.getProperty("postgres.db.host"))) {

EmbeddingLog embeddingLog = new EmbeddingLog();
embeddingLog.setCreatedAt(LocalDateTime.now());
embeddingLog.setCallIdentifier(bgeSmallEndpoint.getCallIdentifier());
embeddingLog.setModel("bge-small-en");
EmbeddingLog embeddingLog = new EmbeddingLog();
embeddingLog.setCreatedAt(LocalDateTime.now());
embeddingLog.setCallIdentifier(bgeSmallEndpoint.getCallIdentifier());
embeddingLog.setModel("bge-small-en");

return edgeChain
.doOnNext(
c -> {
embeddingLog.setCompletedAt(LocalDateTime.now());
Duration duration =
Duration.between(embeddingLog.getCreatedAt(), embeddingLog.getCompletedAt());
embeddingLog.setLatency(duration.toMillis());
embeddingLogService.saveOrUpdate(embeddingLog);
})
.toSingle();
}

return edgeChain.toSingle();
return edgeChain
.doOnNext(
c -> {
embeddingLog.setCompletedAt(LocalDateTime.now());
Duration duration =
Duration.between(embeddingLog.getCreatedAt(), embeddingLog.getCompletedAt());
embeddingLog.setLatency(duration.toMillis());
embeddingLogService.saveOrUpdate(embeddingLog);
})
.toSingle();
}

return edgeChain.toSingle();
}
}

0 comments on commit 6c9ce6d

Please sign in to comment.