From 4e381674f97645889bb3ca6c368eea2d8533dda1 Mon Sep 17 00:00:00 2001 From: Georgios Andrianakis Date: Thu, 19 Sep 2024 18:07:09 +0300 Subject: [PATCH] Allow the use of 'quarkus.langchain4j.ollama.chat-model.model-name' This is done in order to make the property configurable in a similar way as with OpenAI --- .../ollama/deployment/OllamaProcessor.java | 8 ++ .../OllamaChatLanguageModelModelIdTest.java | 75 +++++++++++++++++++ .../OllamaChatLanguageModelModelNameTest.java | 75 +++++++++++++++++++ .../ModelIdConfigFallbackInterceptor.java | 11 +++ .../ModelIdConfigRelocateInterceptor.java | 9 +++ .../config/ModelIdToModelNameFunction.java | 26 +++++++ ...io.smallrye.config.ConfigSourceInterceptor | 2 + 7 files changed, 206 insertions(+) create mode 100644 model-providers/ollama/deployment/src/test/java/io/quarkiverse/langchain4j/ollama/deployment/OllamaChatLanguageModelModelIdTest.java create mode 100644 model-providers/ollama/deployment/src/test/java/io/quarkiverse/langchain4j/ollama/deployment/OllamaChatLanguageModelModelNameTest.java create mode 100644 model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/runtime/config/ModelIdConfigFallbackInterceptor.java create mode 100644 model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/runtime/config/ModelIdConfigRelocateInterceptor.java create mode 100644 model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/runtime/config/ModelIdToModelNameFunction.java create mode 100644 model-providers/ollama/runtime/src/main/resources/META-INF/services/io.smallrye.config.ConfigSourceInterceptor diff --git a/model-providers/ollama/deployment/src/main/java/io/quarkiverse/langchain4j/ollama/deployment/OllamaProcessor.java b/model-providers/ollama/deployment/src/main/java/io/quarkiverse/langchain4j/ollama/deployment/OllamaProcessor.java index 0c8c65ed4..45753d521 100644 --- 
a/model-providers/ollama/deployment/src/main/java/io/quarkiverse/langchain4j/ollama/deployment/OllamaProcessor.java
+++ b/model-providers/ollama/deployment/src/main/java/io/quarkiverse/langchain4j/ollama/deployment/OllamaProcessor.java
@@ -33,7 +33,9 @@
 import io.quarkus.deployment.annotations.ExecutionTime;
 import io.quarkus.deployment.annotations.Record;
 import io.quarkus.deployment.builditem.FeatureBuildItem;
+import io.quarkus.deployment.builditem.nativeimage.ServiceProviderBuildItem;
 import io.quarkus.runtime.configuration.ConfigUtils;
+import io.smallrye.config.ConfigSourceInterceptor;
 
 public class OllamaProcessor {
 
@@ -45,6 +47,12 @@ FeatureBuildItem feature() {
         return new FeatureBuildItem(FEATURE);
     }
 
+    @BuildStep
+    void nativeSupport(BuildProducer<ServiceProviderBuildItem> serviceProviderProducer) {
+        serviceProviderProducer
+                .produce(ServiceProviderBuildItem.allProvidersFromClassPath(ConfigSourceInterceptor.class.getName()));
+    }
+
     @BuildStep
     public void providerCandidates(BuildProducer chatProducer,
             BuildProducer embeddingProducer,
diff --git a/model-providers/ollama/deployment/src/test/java/io/quarkiverse/langchain4j/ollama/deployment/OllamaChatLanguageModelModelIdTest.java b/model-providers/ollama/deployment/src/test/java/io/quarkiverse/langchain4j/ollama/deployment/OllamaChatLanguageModelModelIdTest.java
new file mode 100644
index 000000000..9b1759b07
--- /dev/null
+++ b/model-providers/ollama/deployment/src/test/java/io/quarkiverse/langchain4j/ollama/deployment/OllamaChatLanguageModelModelIdTest.java
@@ -0,0 +1,75 @@
+package io.quarkiverse.langchain4j.ollama.deployment;
+
+import static com.github.tomakehurst.wiremock.client.WireMock.aResponse;
+import static com.github.tomakehurst.wiremock.client.WireMock.absent;
+import static com.github.tomakehurst.wiremock.client.WireMock.equalTo;
+import static com.github.tomakehurst.wiremock.client.WireMock.matchingJsonPath;
+import static com.github.tomakehurst.wiremock.client.WireMock.post;
+import static 
com.github.tomakehurst.wiremock.client.WireMock.urlEqualTo;
+import static org.assertj.core.api.Assertions.assertThat;
+
+import jakarta.inject.Inject;
+
+import org.jboss.shrinkwrap.api.ShrinkWrap;
+import org.jboss.shrinkwrap.api.spec.JavaArchive;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.RegisterExtension;
+
+import com.github.tomakehurst.wiremock.verification.LoggedRequest;
+
+import dev.langchain4j.model.chat.ChatLanguageModel;
+import io.quarkiverse.langchain4j.ollama.OllamaChatLanguageModel;
+import io.quarkiverse.langchain4j.testing.internal.WiremockAware;
+import io.quarkus.arc.ClientProxy;
+import io.quarkus.test.QuarkusUnitTest;
+
+public class OllamaChatLanguageModelModelIdTest extends WiremockAware {
+
+    @RegisterExtension
+    static final QuarkusUnitTest unitTest = new QuarkusUnitTest()
+            .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class))
+            .overrideConfigKey("quarkus.langchain4j.ollama.base-url", WiremockAware.wiremockUrlForConfig())
+            .overrideConfigKey("quarkus.langchain4j.devservices.enabled", "false")
+            .overrideConfigKey("quarkus.langchain4j.ollama.chat-model.model-id", "foo")
+            .overrideRuntimeConfigKey("quarkus.langchain4j.ollama.log-requests", "true")
+            .overrideRuntimeConfigKey("quarkus.langchain4j.ollama.log-responses", "true");
+
+    @Inject
+    ChatLanguageModel chatLanguageModel;
+
+    @Test
+    void blocking() {
+        assertThat(ClientProxy.unwrap(chatLanguageModel)).isInstanceOf(OllamaChatLanguageModel.class);
+
+        wiremock().register(
+                post(urlEqualTo("/api/chat"))
+                        .withRequestBody(matchingJsonPath("$.model", equalTo("foo")))
+                        .withHeader("Authorization", absent())
+                        .willReturn(aResponse()
+                                .withHeader("Content-Type", "application/json")
+                                .withBody("""
+                                        {
+                                            "model": "llama3.1",
+                                            "created_at": "2024-05-03T10:27:56.84235715Z",
+                                            "message": {
+                                                "role": "assistant",
+                                                "content": "Nice to meet you"
+                                            },
+                                            "done": true,
+                                            "total_duration": 1206200561,
+                                            "load_duration": 695039,
+                                            
"prompt_eval_duration": 18430000, + "eval_count": 105, + "eval_duration": 1057198000 + } + """))); + + String response = chatLanguageModel.generate("hello"); + assertThat(response).isEqualTo("Nice to meet you"); + + LoggedRequest loggedRequest = singleLoggedRequest(); + assertThat(loggedRequest.getHeader("User-Agent")).isEqualTo("Resteasy Reactive Client"); + String requestBody = new String(loggedRequest.getBody()); + assertThat(requestBody).contains("hello"); + } +} diff --git a/model-providers/ollama/deployment/src/test/java/io/quarkiverse/langchain4j/ollama/deployment/OllamaChatLanguageModelModelNameTest.java b/model-providers/ollama/deployment/src/test/java/io/quarkiverse/langchain4j/ollama/deployment/OllamaChatLanguageModelModelNameTest.java new file mode 100644 index 000000000..7878e5fde --- /dev/null +++ b/model-providers/ollama/deployment/src/test/java/io/quarkiverse/langchain4j/ollama/deployment/OllamaChatLanguageModelModelNameTest.java @@ -0,0 +1,75 @@ +package io.quarkiverse.langchain4j.ollama.deployment; + +import static com.github.tomakehurst.wiremock.client.WireMock.aResponse; +import static com.github.tomakehurst.wiremock.client.WireMock.absent; +import static com.github.tomakehurst.wiremock.client.WireMock.equalTo; +import static com.github.tomakehurst.wiremock.client.WireMock.matchingJsonPath; +import static com.github.tomakehurst.wiremock.client.WireMock.post; +import static com.github.tomakehurst.wiremock.client.WireMock.urlEqualTo; +import static org.assertj.core.api.Assertions.assertThat; + +import jakarta.inject.Inject; + +import org.jboss.shrinkwrap.api.ShrinkWrap; +import org.jboss.shrinkwrap.api.spec.JavaArchive; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import com.github.tomakehurst.wiremock.verification.LoggedRequest; + +import dev.langchain4j.model.chat.ChatLanguageModel; +import io.quarkiverse.langchain4j.ollama.OllamaChatLanguageModel; +import 
io.quarkiverse.langchain4j.testing.internal.WiremockAware; +import io.quarkus.arc.ClientProxy; +import io.quarkus.test.QuarkusUnitTest; + +public class OllamaChatLanguageModelModelNameTest extends WiremockAware { + + @RegisterExtension + static final QuarkusUnitTest unitTest = new QuarkusUnitTest() + .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class)) + .overrideConfigKey("quarkus.langchain4j.ollama.base-url", WiremockAware.wiremockUrlForConfig()) + .overrideConfigKey("quarkus.langchain4j.devservices.enabled", "false") + .overrideConfigKey("quarkus.langchain4j.ollama.chat-model.model-name", "foo") + .overrideRuntimeConfigKey("quarkus.langchain4j.ollama.log-requests", "true") + .overrideRuntimeConfigKey("quarkus.langchain4j.ollama.log-responses", "true"); + + @Inject + ChatLanguageModel chatLanguageModel; + + @Test + void blocking() { + assertThat(ClientProxy.unwrap(chatLanguageModel)).isInstanceOf(OllamaChatLanguageModel.class); + + wiremock().register( + post(urlEqualTo("/api/chat")) + .withRequestBody(matchingJsonPath("$.model", equalTo("foo"))) + .withHeader("Authorization", absent()) + .willReturn(aResponse() + .withHeader("Content-Type", "application/json") + .withBody(""" + { + "model": "llama3.1", + "created_at": "2024-05-03T10:27:56.84235715Z", + "message": { + "role": "assistant", + "content": "Nice to meet you" + }, + "done": true, + "total_duration": 1206200561, + "load_duration": 695039, + "prompt_eval_duration": 18430000, + "eval_count": 105, + "eval_duration": 1057198000 + } + """))); + + String response = chatLanguageModel.generate("hello"); + assertThat(response).isEqualTo("Nice to meet you"); + + LoggedRequest loggedRequest = singleLoggedRequest(); + assertThat(loggedRequest.getHeader("User-Agent")).isEqualTo("Resteasy Reactive Client"); + String requestBody = new String(loggedRequest.getBody()); + assertThat(requestBody).contains("hello"); + } +} diff --git 
a/model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/runtime/config/ModelIdConfigFallbackInterceptor.java b/model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/runtime/config/ModelIdConfigFallbackInterceptor.java new file mode 100644 index 000000000..f86d14fb6 --- /dev/null +++ b/model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/runtime/config/ModelIdConfigFallbackInterceptor.java @@ -0,0 +1,11 @@ +package io.quarkiverse.langchain4j.ollama.runtime.config; + +import io.smallrye.config.FallbackConfigSourceInterceptor; + +public class ModelIdConfigFallbackInterceptor extends FallbackConfigSourceInterceptor { + + public ModelIdConfigFallbackInterceptor() { + super(ModelIdToModelNameFunction.INSTANCE); + } + +} diff --git a/model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/runtime/config/ModelIdConfigRelocateInterceptor.java b/model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/runtime/config/ModelIdConfigRelocateInterceptor.java new file mode 100644 index 000000000..ce5045aed --- /dev/null +++ b/model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/runtime/config/ModelIdConfigRelocateInterceptor.java @@ -0,0 +1,9 @@ +package io.quarkiverse.langchain4j.ollama.runtime.config; + +import io.smallrye.config.RelocateConfigSourceInterceptor; + +public class ModelIdConfigRelocateInterceptor extends RelocateConfigSourceInterceptor { + public ModelIdConfigRelocateInterceptor() { + super(ModelIdToModelNameFunction.INSTANCE); + } +} diff --git a/model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/runtime/config/ModelIdToModelNameFunction.java b/model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/runtime/config/ModelIdToModelNameFunction.java new file mode 100644 index 000000000..516e9ea39 --- /dev/null +++ 
b/model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/runtime/config/ModelIdToModelNameFunction.java
@@ -0,0 +1,26 @@
+package io.quarkiverse.langchain4j.ollama.runtime.config;
+
+import java.util.function.Function;
+
+final class ModelIdToModelNameFunction implements Function<String, String> {
+
+    static final ModelIdToModelNameFunction INSTANCE = new ModelIdToModelNameFunction();
+
+    private ModelIdToModelNameFunction() {
+    }
+
+    @Override
+    public String apply(String name) {
+        if (!name.startsWith("quarkus.langchain4j.ollama")) {
+            return name;
+        }
+        if (!name.endsWith("model-id")) {
+            return name;
+        }
+        int index = name.lastIndexOf(".model-id");
+        if (index < 1) {
+            return name;
+        }
+        return name.substring(0, index) + ".model-name";
+    }
+}
diff --git a/model-providers/ollama/runtime/src/main/resources/META-INF/services/io.smallrye.config.ConfigSourceInterceptor b/model-providers/ollama/runtime/src/main/resources/META-INF/services/io.smallrye.config.ConfigSourceInterceptor
new file mode 100644
index 000000000..427fac231
--- /dev/null
+++ b/model-providers/ollama/runtime/src/main/resources/META-INF/services/io.smallrye.config.ConfigSourceInterceptor
@@ -0,0 +1,2 @@
+io.quarkiverse.langchain4j.ollama.runtime.config.ModelIdConfigRelocateInterceptor
+io.quarkiverse.langchain4j.ollama.runtime.config.ModelIdConfigFallbackInterceptor