Skip to content

Commit

Permalink
Merge pull request #201 from sebastienblanc/ollama-embedding
Browse files Browse the repository at this point in the history
Add embeddings endpoint support for Ollama
  • Loading branch information
geoand authored Jan 3, 2024
2 parents fd2dd81 + 2743ce2 commit cd127b7
Show file tree
Hide file tree
Showing 9 changed files with 241 additions and 10 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package io.quarkiverse.langchain4j.ollama.deployment;

import java.util.Optional;

import io.quarkus.runtime.annotations.ConfigDocDefault;
import io.quarkus.runtime.annotations.ConfigGroup;

/**
* Build-time configuration group holding the embedding model settings.
*/
@ConfigGroup
public interface EmbeddingModelBuildConfig {

/**
* Whether the model should be enabled.
* The value is optional; the documented default is {@code true}.
*/
@ConfigDocDefault("true")
Optional<Boolean> enabled();
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,9 @@ public interface Langchain4jOllamaOpenAiBuildConfig {
* Chat model related settings
*/
ChatModelBuildConfig chatModel();

/**
* Embedding model related settings (see {@link EmbeddingModelBuildConfig})
*/
EmbeddingModelBuildConfig embeddingModel();
}
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
package io.quarkiverse.langchain4j.ollama.deployment;

import static io.quarkiverse.langchain4j.deployment.Langchain4jDotNames.CHAT_MODEL;
import static io.quarkiverse.langchain4j.deployment.Langchain4jDotNames.EMBEDDING_MODEL;

import java.util.Optional;

import jakarta.enterprise.context.ApplicationScoped;

import io.quarkiverse.langchain4j.deployment.items.ChatModelProviderCandidateBuildItem;
import io.quarkiverse.langchain4j.deployment.items.EmbeddingModelProviderCandidateBuildItem;
import io.quarkiverse.langchain4j.deployment.items.SelectedChatModelProviderBuildItem;
import io.quarkiverse.langchain4j.deployment.items.SelectedEmbeddingModelCandidateBuildItem;
import io.quarkiverse.langchain4j.ollama.runtime.OllamaRecorder;
import io.quarkiverse.langchain4j.ollama.runtime.config.Langchain4jOllamaConfig;
import io.quarkus.arc.deployment.SyntheticBeanBuildItem;
Expand All @@ -29,18 +32,22 @@ FeatureBuildItem feature() {

/**
 * Registers this provider as a candidate for the chat model and for the embedding model.
 * A candidate is produced for a model kind unless its {@code enabled} setting is explicitly
 * {@code false}; an unset value counts as enabled.
 *
 * @param chatProducer receives the chat model provider candidate
 * @param embeddingProducer receives the embedding model provider candidate
 * @param config build-time configuration carrying the {@code enabled} flags
 */
@BuildStep
public void providerCandidates(BuildProducer<ChatModelProviderCandidateBuildItem> chatProducer,
        BuildProducer<EmbeddingModelProviderCandidateBuildItem> embeddingProducer,
        Langchain4jOllamaOpenAiBuildConfig config) {
    boolean chatEnabled = config.chatModel().enabled().orElse(true);
    if (chatEnabled) {
        chatProducer.produce(new ChatModelProviderCandidateBuildItem(PROVIDER));
    }

    boolean embeddingEnabled = config.embeddingModel().enabled().orElse(true);
    if (embeddingEnabled) {
        embeddingProducer.produce(new EmbeddingModelProviderCandidateBuildItem(PROVIDER));
    }
}

@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
@BuildStep
@Record(ExecutionTime.RUNTIME_INIT)
void generateBeans(OllamaRecorder recorder,
Optional<SelectedChatModelProviderBuildItem> selectedChatItem,
// Optional<SelectedEmbeddingModelCandidateBuildItem> selectedEmbedding,
Optional<SelectedEmbeddingModelCandidateBuildItem> selectedEmbedding,
Langchain4jOllamaConfig config,
BuildProducer<SyntheticBeanBuildItem> beanProducer) {
if (selectedChatItem.isPresent() && PROVIDER.equals(selectedChatItem.get().getProvider())) {
Expand All @@ -53,14 +60,14 @@ void generateBeans(OllamaRecorder recorder,
.done());
}

// if (selectedEmbedding.isPresent() && PROVIDER.equals(selectedEmbedding.get().getProvider())) {
// beanProducer.produce(SyntheticBeanBuildItem
// .configure(EMBEDDING_MODEL)
// .setRuntimeInit()
// .defaultBean()
// .scope(ApplicationScoped.class)
// .supplier(recorder.embeddingModel(config))
// .done());
// }
if (selectedEmbedding.isPresent() && PROVIDER.equals(selectedEmbedding.get().getProvider())) {
beanProducer.produce(SyntheticBeanBuildItem
.configure(EMBEDDING_MODEL)
.setRuntimeInit()
.defaultBean()
.scope(ApplicationScoped.class)
.supplier(recorder.embeddingModel(config))
.done());
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package io.quarkiverse.langchain4j.ollama;

/**
 * Immutable value object carrying the model name and the prompt of an embedding request.
 * Instances are created through {@link #builder()}.
 */
public class EmbeddingRequest {

    private final String model;
    private final String prompt;

    private EmbeddingRequest(String model, String prompt) {
        this.model = model;
        this.prompt = prompt;
    }

    /**
     * @return a fresh builder whose model is preset to {@code "llama2"} and whose prompt is unset
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * @return the model name given to the builder, or the builder default if none was given
     */
    public String getModel() {
        return this.model;
    }

    /**
     * @return the prompt given to the builder, or {@code null} if none was given
     */
    public String getPrompt() {
        return this.prompt;
    }

    /**
     * Fluent builder for {@link EmbeddingRequest}.
     */
    public static final class Builder {
        private String model = "llama2";
        private String prompt;

        private Builder() {
        }

        /**
         * Replaces the model name (no validation is performed).
         */
        public Builder model(String val) {
            this.model = val;
            return this;
        }

        /**
         * Sets the prompt (no validation is performed).
         */
        public Builder prompt(String val) {
            this.prompt = val;
            return this;
        }

        /**
         * @return a new request holding the values collected so far
         */
        public EmbeddingRequest build() {
            return new EmbeddingRequest(this.model, this.prompt);
        }
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package io.quarkiverse.langchain4j.ollama;

import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import com.fasterxml.jackson.databind.annotation.JsonPOJOBuilder;

/**
 * Response object holding one embedding vector.
 * Deserialization is routed through the nested {@link Builder}, as declared by the
 * {@code @JsonDeserialize} annotation.
 */
@JsonDeserialize(builder = EmbeddingResponse.Builder.class)
public class EmbeddingResponse {

    private float[] embedding;

    private EmbeddingResponse(Builder builder) {
        this.embedding = builder.embedding;
    }

    /**
     * @return the stored vector (the array itself, not a copy)
     */
    public float[] getEmbedding() {
        return this.embedding;
    }

    /**
     * Replaces the stored vector with the given array (no copy is made).
     */
    public void setEmbedding(float[] embedding) {
        this.embedding = embedding;
    }

    /**
     * Builder whose methods carry no prefix, matching {@code withPrefix = ""}.
     */
    @JsonPOJOBuilder(withPrefix = "")
    public static final class Builder {
        private float[] embedding;

        private Builder() {
        }

        /**
         * Sets the vector to be placed in the built response.
         */
        public Builder embedding(float[] val) {
            this.embedding = val;
            return this;
        }

        /**
         * @return a new response holding the collected vector
         */
        public EmbeddingResponse build() {
            return new EmbeddingResponse(this);
        }
    }

}
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,8 @@ public OllamaClient(String baseUrl, Duration timeout, boolean logRequests, boole
/**
* Forwards the completion request to {@code restApi.generate} and returns its response unchanged.
*
* @param request the completion request to send
* @return the response returned by the REST API
*/
public CompletionResponse completion(CompletionRequest request) {
return restApi.generate(request);
}

/**
* Forwards the embedding request to {@code restApi.embeddings} and returns its response unchanged.
*
* @param request the embedding request to send
* @return the response returned by the REST API
*/
public EmbeddingResponse embedding(EmbeddingRequest request) {
return restApi.embeddings(request);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
package io.quarkiverse.langchain4j.ollama;

import java.time.Duration;
import java.util.ArrayList;
import java.util.List;

import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.output.Response;

public class OllamaEmbeddingModel implements EmbeddingModel {

private final OllamaClient client;
private final String model;

private OllamaEmbeddingModel(Builder builder) {
client = new OllamaClient(builder.baseUrl, builder.timeout, builder.logRequests, builder.logResponses);
model = builder.model;
}

public static Builder builder() {
return new Builder();
}

@Override
public Response<List<Embedding>> embedAll(List<TextSegment> textSegments) {
List<Embedding> embeddings = new ArrayList<>();

textSegments.forEach(textSegment -> {
EmbeddingRequest request = EmbeddingRequest.builder()
.model(model)
.prompt(textSegment.text())
.build();

EmbeddingResponse response = client.embedding(request);

embeddings.add(Embedding.from(response.getEmbedding()));
});

return Response.from(embeddings);
}

public static final class Builder {
private String baseUrl = "http://localhost:11434";
private Duration timeout = Duration.ofSeconds(10);
private String model;

private boolean logRequests = false;
private boolean logResponses = false;

private Builder() {
}

public Builder baseUrl(String val) {
baseUrl = val;
return this;
}

public Builder timeout(Duration val) {
this.timeout = val;
return this;
}

public Builder model(String val) {
model = val;
return this;
}

public Builder logRequests(boolean logRequests) {
this.logRequests = logRequests;
return this;
}

public Builder logResponses(boolean logResponses) {
this.logResponses = logResponses;
return this;
}

public OllamaEmbeddingModel build() {
return new OllamaEmbeddingModel(this);
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ public interface OllamaRestApi {
@POST
CompletionResponse generate(CompletionRequest request);

/**
* Issues a {@code POST} to {@code /api/embeddings} with the given request.
*
* @param request the embedding request to send
* @return the embedding response
*/
@Path("/api/embeddings")
@POST
EmbeddingResponse embeddings(EmbeddingRequest request);

@ClientObjectMapper
static ObjectMapper objectMapper(ObjectMapper defaultObjectMapper) {
return QuarkusJsonCodecFactory.SnakeCaseObjectMapperHolder.MAPPER;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import java.util.function.Supplier;

import io.quarkiverse.langchain4j.ollama.OllamaChatLanguageModel;
import io.quarkiverse.langchain4j.ollama.OllamaEmbeddingModel;
import io.quarkiverse.langchain4j.ollama.Options;
import io.quarkiverse.langchain4j.ollama.runtime.config.ChatModelConfig;
import io.quarkiverse.langchain4j.ollama.runtime.config.Langchain4jOllamaConfig;
Expand Down Expand Up @@ -34,4 +35,27 @@ public Object get() {
}
};
}

/**
 * Returns a supplier that builds an {@link OllamaEmbeddingModel} from the runtime configuration.
 * The builder is prepared once, up front; every {@code get()} call invokes {@code build()} on it.
 *
 * @param runtimeConfig runtime configuration providing base URL, timeout and model id
 * @return a supplier producing the embedding model
 */
public Supplier<?> embeddingModel(Langchain4jOllamaConfig runtimeConfig) {
    // NOTE(review): the model id is taken from the chat-model config; no embedding-specific
    // runtime config is visible here — confirm this is intended.
    ChatModelConfig chatModelConfig = runtimeConfig.chatModel();

    // The chat generation options (temperature, topK, topP, numPredict, stop) were previously
    // assembled here but never passed to the embedding model builder, so they are not built.
    var builder = OllamaEmbeddingModel.builder()
            .baseUrl(runtimeConfig.baseUrl())
            .timeout(runtimeConfig.timeout())
            .model(chatModelConfig.modelId());

    return new Supplier<>() {
        @Override
        public Object get() {
            return builder.build();
        }
    };
}
}

0 comments on commit cd127b7

Please sign in to comment.