diff --git a/ollama/deployment/src/main/java/io/quarkiverse/langchain4j/ollama/deployment/EmbeddingModelBuildConfig.java b/ollama/deployment/src/main/java/io/quarkiverse/langchain4j/ollama/deployment/EmbeddingModelBuildConfig.java
new file mode 100644
index 000000000..9d2bd0440
--- /dev/null
+++ b/ollama/deployment/src/main/java/io/quarkiverse/langchain4j/ollama/deployment/EmbeddingModelBuildConfig.java
@@ -0,0 +1,19 @@
+package io.quarkiverse.langchain4j.ollama.deployment;
+
+import java.util.Optional;
+
+import io.quarkus.runtime.annotations.ConfigDocDefault;
+import io.quarkus.runtime.annotations.ConfigGroup;
+
+/**
+ * Build time settings of the Ollama embedding model.
+ */
+@ConfigGroup
+public interface EmbeddingModelBuildConfig {
+
+    /**
+     * Whether the model should be enabled
+     */
+    @ConfigDocDefault("true")
+    Optional<Boolean> enabled();
+}
diff --git a/ollama/deployment/src/main/java/io/quarkiverse/langchain4j/ollama/deployment/Langchain4jOllamaOpenAiBuildConfig.java b/ollama/deployment/src/main/java/io/quarkiverse/langchain4j/ollama/deployment/Langchain4jOllamaOpenAiBuildConfig.java
index 462180a8d..5efe1304e 100644
--- a/ollama/deployment/src/main/java/io/quarkiverse/langchain4j/ollama/deployment/Langchain4jOllamaOpenAiBuildConfig.java
+++ b/ollama/deployment/src/main/java/io/quarkiverse/langchain4j/ollama/deployment/Langchain4jOllamaOpenAiBuildConfig.java
@@ -13,4 +13,9 @@ public interface Langchain4jOllamaOpenAiBuildConfig {
      * Chat model related settings
      */
     ChatModelBuildConfig chatModel();
+
+    /**
+     * Embedding model related settings
+     */
+    EmbeddingModelBuildConfig embeddingModel();
 }
diff --git a/ollama/deployment/src/main/java/io/quarkiverse/langchain4j/ollama/deployment/OllamaProcessor.java b/ollama/deployment/src/main/java/io/quarkiverse/langchain4j/ollama/deployment/OllamaProcessor.java
index 7eb28a2c0..b9b7eecb6 100644
--- a/ollama/deployment/src/main/java/io/quarkiverse/langchain4j/ollama/deployment/OllamaProcessor.java
+++ b/ollama/deployment/src/main/java/io/quarkiverse/langchain4j/ollama/deployment/OllamaProcessor.java
@@ -1,13 +1,16 @@
 package io.quarkiverse.langchain4j.ollama.deployment;
 
 import static io.quarkiverse.langchain4j.deployment.Langchain4jDotNames.CHAT_MODEL;
+import static io.quarkiverse.langchain4j.deployment.Langchain4jDotNames.EMBEDDING_MODEL;
 
 import java.util.Optional;
 
 import jakarta.enterprise.context.ApplicationScoped;
 
 import io.quarkiverse.langchain4j.deployment.items.ChatModelProviderCandidateBuildItem;
+import io.quarkiverse.langchain4j.deployment.items.EmbeddingModelProviderCandidateBuildItem;
 import io.quarkiverse.langchain4j.deployment.items.SelectedChatModelProviderBuildItem;
+import io.quarkiverse.langchain4j.deployment.items.SelectedEmbeddingModelCandidateBuildItem;
 import io.quarkiverse.langchain4j.ollama.runtime.OllamaRecorder;
 import io.quarkiverse.langchain4j.ollama.runtime.config.Langchain4jOllamaConfig;
 import io.quarkus.arc.deployment.SyntheticBeanBuildItem;
@@ -29,10 +32,14 @@ FeatureBuildItem feature() {
 
     @BuildStep
     public void providerCandidates(BuildProducer<ChatModelProviderCandidateBuildItem> chatProducer,
+            BuildProducer<EmbeddingModelProviderCandidateBuildItem> embeddingProducer,
             Langchain4jOllamaOpenAiBuildConfig config) {
         if (config.chatModel().enabled().isEmpty() || config.chatModel().enabled().get()) {
             chatProducer.produce(new ChatModelProviderCandidateBuildItem(PROVIDER));
         }
+        if (config.embeddingModel().enabled().isEmpty() || config.embeddingModel().enabled().get()) {
+            embeddingProducer.produce(new EmbeddingModelProviderCandidateBuildItem(PROVIDER));
+        }
     }
 
     @SuppressWarnings("OptionalUsedAsFieldOrParameterType")
@@ -40,7 +47,7 @@ public void providerCandidates(BuildProducer<ChatModelProviderCandidateBuildItem
     @Record(ExecutionTime.RUNTIME_INIT)
     void generateBeans(OllamaRecorder recorder,
             Optional<SelectedChatModelProviderBuildItem> selectedChatItem,
-            // Optional<SelectedEmbeddingModelCandidateBuildItem> selectedEmbedding,
+            Optional<SelectedEmbeddingModelCandidateBuildItem> selectedEmbedding,
             Langchain4jOllamaConfig config,
             BuildProducer<SyntheticBeanBuildItem> beanProducer) {
         if (selectedChatItem.isPresent() && PROVIDER.equals(selectedChatItem.get().getProvider())) {
@@ -53,14 +60,14 @@ void generateBeans(OllamaRecorder recorder,
                     .done());
         }
 
-        //        if (selectedEmbedding.isPresent() && PROVIDER.equals(selectedEmbedding.get().getProvider())) {
-        //            beanProducer.produce(SyntheticBeanBuildItem
-        //                    .configure(EMBEDDING_MODEL)
-        //                    .setRuntimeInit()
-        //                    .defaultBean()
-        //                    .scope(ApplicationScoped.class)
-        //                    .supplier(recorder.embeddingModel(config))
-        //                    .done());
-        //        }
+        if (selectedEmbedding.isPresent() && PROVIDER.equals(selectedEmbedding.get().getProvider())) {
+            beanProducer.produce(SyntheticBeanBuildItem
+                    .configure(EMBEDDING_MODEL)
+                    .setRuntimeInit()
+                    .defaultBean()
+                    .scope(ApplicationScoped.class)
+                    .supplier(recorder.embeddingModel(config))
+                    .done());
+        }
     }
 }
diff --git a/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/EmbeddingRequest.java b/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/EmbeddingRequest.java
new file mode 100644
index 000000000..06ba30731
--- /dev/null
+++ b/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/EmbeddingRequest.java
@@ -0,0 +1,50 @@
+package io.quarkiverse.langchain4j.ollama;
+
+/**
+ * Immutable request for an embedding: the model to use and the prompt to embed.
+ */
+public class EmbeddingRequest {
+
+    private final String model;
+    private final String prompt;
+
+    private EmbeddingRequest(Builder builder) {
+        model = builder.model;
+        prompt = builder.prompt;
+    }
+
+    public static Builder builder() {
+        return new Builder();
+    }
+
+    public String getModel() {
+        return model;
+    }
+
+    public String getPrompt() {
+        return prompt;
+    }
+
+    public static final class Builder {
+        // default model, used unless model(String) is called
+        private String model = "llama2";
+        private String prompt;
+
+        private Builder() {
+        }
+
+        public Builder model(String val) {
+            model = val;
+            return this;
+        }
+
+        public Builder prompt(String val) {
+            prompt = val;
+            return this;
+        }
+
+        public EmbeddingRequest build() {
+            return new EmbeddingRequest(this);
+        }
+    }
+}
diff --git a/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/EmbeddingResponse.java b/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/EmbeddingResponse.java
new file mode 100644
index 000000000..375b12f24
--- /dev/null
+++ b/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/EmbeddingResponse.java
@@ -0,0 +1,43 @@
+package io.quarkiverse.langchain4j.ollama;
+
+import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
+import com.fasterxml.jackson.databind.annotation.JsonPOJOBuilder;
+
+/**
+ * Holder of an embedding vector; Jackson deserializes it through the nested {@link Builder}.
+ */
+@JsonDeserialize(builder = EmbeddingResponse.Builder.class)
+public class EmbeddingResponse {
+
+    private float[] embedding;
+
+    private EmbeddingResponse(Builder builder) {
+        embedding = builder.embedding;
+    }
+
+    public float[] getEmbedding() {
+        return embedding;
+    }
+
+    public void setEmbedding(float[] embedding) {
+        this.embedding = embedding;
+    }
+
+    @JsonPOJOBuilder(withPrefix = "")
+    public static final class Builder {
+        private float[] embedding;
+
+        private Builder() {
+        }
+
+        public Builder embedding(float[] val) {
+            embedding = val;
+            return this;
+        }
+
+        public EmbeddingResponse build() {
+            return new EmbeddingResponse(this);
+        }
+    }
+
+}
diff --git a/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/OllamaClient.java b/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/OllamaClient.java
index 193edcc22..74168379d 100644
--- a/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/OllamaClient.java
+++ b/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/OllamaClient.java
@@ -34,4 +34,11 @@ public OllamaClient(String baseUrl, Duration timeout, boolean logRequests, boole
     public CompletionResponse completion(CompletionRequest request) {
         return restApi.generate(request);
     }
+
+    /**
+     * Forwards the embedding request to the REST API and returns its response.
+     */
+    public EmbeddingResponse embedding(EmbeddingRequest request) {
+        return restApi.embeddings(request);
+    }
 }
diff --git a/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/OllamaEmbeddingModel.java b/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/OllamaEmbeddingModel.java
new file mode 100644
index 000000000..ce1fd406d
--- /dev/null
+++ b/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/OllamaEmbeddingModel.java
@@ -0,0 +1,95 @@
+package io.quarkiverse.langchain4j.ollama;
+
+import java.time.Duration;
+import java.util.ArrayList;
+import java.util.List;
+
+import dev.langchain4j.data.embedding.Embedding;
+import dev.langchain4j.data.segment.TextSegment;
+import dev.langchain4j.model.embedding.EmbeddingModel;
+import dev.langchain4j.model.output.Response;
+
+/**
+ * {@link EmbeddingModel} implementation that obtains its embeddings through an {@link OllamaClient}.
+ */
+public class OllamaEmbeddingModel implements EmbeddingModel {
+
+    private final OllamaClient client;
+    private final String model;
+
+    private OllamaEmbeddingModel(Builder builder) {
+        client = new OllamaClient(builder.baseUrl, builder.timeout, builder.logRequests, builder.logResponses);
+        model = builder.model;
+    }
+
+    public static Builder builder() {
+        return new Builder();
+    }
+
+    /**
+     * Embeds every given segment, issuing one client call per segment.
+     *
+     * @param textSegments the segments to embed
+     * @return the embeddings, in the same order as the given segments
+     */
+    @Override
+    public Response<List<Embedding>> embedAll(List<TextSegment> textSegments) {
+        List<Embedding> embeddings = new ArrayList<>(textSegments.size());
+
+        // an EmbeddingRequest carries a single prompt, hence one call per segment
+        for (TextSegment textSegment : textSegments) {
+            EmbeddingRequest request = EmbeddingRequest.builder()
+                    .model(model)
+                    .prompt(textSegment.text())
+                    .build();
+
+            EmbeddingResponse response = client.embedding(request);
+
+            embeddings.add(Embedding.from(response.getEmbedding()));
+        }
+
+        return Response.from(embeddings);
+    }
+
+    public static final class Builder {
+        private String baseUrl = "http://localhost:11434";
+        private Duration timeout = Duration.ofSeconds(10);
+        private String model;
+
+        private boolean logRequests = false;
+        private boolean logResponses = false;
+
+        private Builder() {
+        }
+
+        public Builder baseUrl(String val) {
+            baseUrl = val;
+            return this;
+        }
+
+        public Builder timeout(Duration val) {
+            this.timeout = val;
+            return this;
+        }
+
+        public Builder model(String val) {
+            model = val;
+            return this;
+        }
+
+        public Builder logRequests(boolean logRequests) {
+            this.logRequests = logRequests;
+            return this;
+        }
+
+        public Builder logResponses(boolean logResponses) {
+            this.logResponses = logResponses;
+            return this;
+        }
+
+        public OllamaEmbeddingModel build() {
+            return new OllamaEmbeddingModel(this);
+        }
+    }
+
+}
diff --git a/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/OllamaRestApi.java b/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/OllamaRestApi.java
index 5a2d9457f..61d203e17 100644
--- a/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/OllamaRestApi.java
+++ b/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/OllamaRestApi.java
@@ -36,6 +36,10 @@ public interface OllamaRestApi {
     @POST
     CompletionResponse generate(CompletionRequest request);
 
+    @Path("/api/embeddings")
+    @POST
+    EmbeddingResponse embeddings(EmbeddingRequest request);
+
     @ClientObjectMapper
     static ObjectMapper objectMapper(ObjectMapper defaultObjectMapper) {
         return QuarkusJsonCodecFactory.SnakeCaseObjectMapperHolder.MAPPER;
diff --git a/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/runtime/OllamaRecorder.java b/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/runtime/OllamaRecorder.java
index b2c559814..91c179752 100644
--- a/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/runtime/OllamaRecorder.java
+++ b/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/runtime/OllamaRecorder.java
@@ -3,6 +3,7 @@
 import java.util.function.Supplier;
 
 import io.quarkiverse.langchain4j.ollama.OllamaChatLanguageModel;
+import io.quarkiverse.langchain4j.ollama.OllamaEmbeddingModel;
 import io.quarkiverse.langchain4j.ollama.Options;
 import io.quarkiverse.langchain4j.ollama.runtime.config.ChatModelConfig;
 import io.quarkiverse.langchain4j.ollama.runtime.config.Langchain4jOllamaConfig;
@@ -34,4 +35,24 @@ public Object get() {
             }
         };
     }
+
+    /**
+     * Creates a supplier that builds the Ollama embedding model from the runtime configuration.
+     * <p>
+     * NOTE(review): the model id is read from the chat model configuration; a dedicated embedding
+     * model setting would presumably be more appropriate - confirm.
+     */
+    public Supplier<?> embeddingModel(Langchain4jOllamaConfig runtimeConfig) {
+        var builder = OllamaEmbeddingModel.builder()
+                .baseUrl(runtimeConfig.baseUrl())
+                .timeout(runtimeConfig.timeout())
+                .model(runtimeConfig.chatModel().modelId());
+
+        return new Supplier<>() {
+            @Override
+            public Object get() {
+                return builder.build();
+            }
+        };
+    }
 }