Skip to content

Commit

Permalink
Merge pull request #911 from quarkiverse/#909
Browse files Browse the repository at this point in the history
Allow the use of `uarkus.langchain4j.ollama.chat-model.model-name`
  • Loading branch information
geoand authored Sep 19, 2024
2 parents 393f5c2 + 4e38167 commit 6e50c5d
Show file tree
Hide file tree
Showing 7 changed files with 206 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@
import io.quarkus.deployment.annotations.ExecutionTime;
import io.quarkus.deployment.annotations.Record;
import io.quarkus.deployment.builditem.FeatureBuildItem;
import io.quarkus.deployment.builditem.nativeimage.ServiceProviderBuildItem;
import io.quarkus.runtime.configuration.ConfigUtils;
import io.smallrye.config.ConfigSourceInterceptor;

public class OllamaProcessor {

Expand All @@ -45,6 +47,12 @@ FeatureBuildItem feature() {
return new FeatureBuildItem(FEATURE);
}

@BuildStep
void nativeSupport(BuildProducer<ServiceProviderBuildItem> serviceProviderProducer) {
serviceProviderProducer
.produce(ServiceProviderBuildItem.allProvidersFromClassPath(ConfigSourceInterceptor.class.getName()));
}

@BuildStep
public void providerCandidates(BuildProducer<ChatModelProviderCandidateBuildItem> chatProducer,
BuildProducer<EmbeddingModelProviderCandidateBuildItem> embeddingProducer,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package io.quarkiverse.langchain4j.ollama.deployment;

import static com.github.tomakehurst.wiremock.client.WireMock.aResponse;
import static com.github.tomakehurst.wiremock.client.WireMock.absent;
import static com.github.tomakehurst.wiremock.client.WireMock.equalTo;
import static com.github.tomakehurst.wiremock.client.WireMock.matchingJsonPath;
import static com.github.tomakehurst.wiremock.client.WireMock.post;
import static com.github.tomakehurst.wiremock.client.WireMock.urlEqualTo;
import static org.assertj.core.api.Assertions.assertThat;

import jakarta.inject.Inject;

import org.jboss.shrinkwrap.api.ShrinkWrap;
import org.jboss.shrinkwrap.api.spec.JavaArchive;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import com.github.tomakehurst.wiremock.verification.LoggedRequest;

import dev.langchain4j.model.chat.ChatLanguageModel;
import io.quarkiverse.langchain4j.ollama.OllamaChatLanguageModel;
import io.quarkiverse.langchain4j.testing.internal.WiremockAware;
import io.quarkus.arc.ClientProxy;
import io.quarkus.test.QuarkusUnitTest;

public class OllamaChatLanguageModelModelIdTest extends WiremockAware {

@RegisterExtension
static final QuarkusUnitTest unitTest = new QuarkusUnitTest()
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class))
.overrideConfigKey("quarkus.langchain4j.ollama.base-url", WiremockAware.wiremockUrlForConfig())
.overrideConfigKey("quarkus.langchain4j.devservices.enabled", "false")
.overrideConfigKey("quarkus.langchain4j.ollama.chat-model.model-name", "foo")
.overrideRuntimeConfigKey("quarkus.langchain4j.ollama.log-requests", "true")
.overrideRuntimeConfigKey("quarkus.langchain4j.ollama.log-responses", "true");

@Inject
ChatLanguageModel chatLanguageModel;

@Test
void blocking() {
assertThat(ClientProxy.unwrap(chatLanguageModel)).isInstanceOf(OllamaChatLanguageModel.class);

wiremock().register(
post(urlEqualTo("/api/chat"))
.withRequestBody(matchingJsonPath("$.model", equalTo("foo")))
.withHeader("Authorization", absent())
.willReturn(aResponse()
.withHeader("Content-Type", "application/json")
.withBody("""
{
"model": "llama3.1",
"created_at": "2024-05-03T10:27:56.84235715Z",
"message": {
"role": "assistant",
"content": "Nice to meet you"
},
"done": true,
"total_duration": 1206200561,
"load_duration": 695039,
"prompt_eval_duration": 18430000,
"eval_count": 105,
"eval_duration": 1057198000
}
""")));

String response = chatLanguageModel.generate("hello");
assertThat(response).isEqualTo("Nice to meet you");

LoggedRequest loggedRequest = singleLoggedRequest();
assertThat(loggedRequest.getHeader("User-Agent")).isEqualTo("Resteasy Reactive Client");
String requestBody = new String(loggedRequest.getBody());
assertThat(requestBody).contains("hello");
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package io.quarkiverse.langchain4j.ollama.deployment;

import static com.github.tomakehurst.wiremock.client.WireMock.aResponse;
import static com.github.tomakehurst.wiremock.client.WireMock.absent;
import static com.github.tomakehurst.wiremock.client.WireMock.equalTo;
import static com.github.tomakehurst.wiremock.client.WireMock.matchingJsonPath;
import static com.github.tomakehurst.wiremock.client.WireMock.post;
import static com.github.tomakehurst.wiremock.client.WireMock.urlEqualTo;
import static org.assertj.core.api.Assertions.assertThat;

import jakarta.inject.Inject;

import org.jboss.shrinkwrap.api.ShrinkWrap;
import org.jboss.shrinkwrap.api.spec.JavaArchive;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import com.github.tomakehurst.wiremock.verification.LoggedRequest;

import dev.langchain4j.model.chat.ChatLanguageModel;
import io.quarkiverse.langchain4j.ollama.OllamaChatLanguageModel;
import io.quarkiverse.langchain4j.testing.internal.WiremockAware;
import io.quarkus.arc.ClientProxy;
import io.quarkus.test.QuarkusUnitTest;

public class OllamaChatLanguageModelModelNameTest extends WiremockAware {

@RegisterExtension
static final QuarkusUnitTest unitTest = new QuarkusUnitTest()
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class))
.overrideConfigKey("quarkus.langchain4j.ollama.base-url", WiremockAware.wiremockUrlForConfig())
.overrideConfigKey("quarkus.langchain4j.devservices.enabled", "false")
.overrideConfigKey("quarkus.langchain4j.ollama.chat-model.model-name", "foo")
.overrideRuntimeConfigKey("quarkus.langchain4j.ollama.log-requests", "true")
.overrideRuntimeConfigKey("quarkus.langchain4j.ollama.log-responses", "true");

@Inject
ChatLanguageModel chatLanguageModel;

@Test
void blocking() {
assertThat(ClientProxy.unwrap(chatLanguageModel)).isInstanceOf(OllamaChatLanguageModel.class);

wiremock().register(
post(urlEqualTo("/api/chat"))
.withRequestBody(matchingJsonPath("$.model", equalTo("foo")))
.withHeader("Authorization", absent())
.willReturn(aResponse()
.withHeader("Content-Type", "application/json")
.withBody("""
{
"model": "llama3.1",
"created_at": "2024-05-03T10:27:56.84235715Z",
"message": {
"role": "assistant",
"content": "Nice to meet you"
},
"done": true,
"total_duration": 1206200561,
"load_duration": 695039,
"prompt_eval_duration": 18430000,
"eval_count": 105,
"eval_duration": 1057198000
}
""")));

String response = chatLanguageModel.generate("hello");
assertThat(response).isEqualTo("Nice to meet you");

LoggedRequest loggedRequest = singleLoggedRequest();
assertThat(loggedRequest.getHeader("User-Agent")).isEqualTo("Resteasy Reactive Client");
String requestBody = new String(loggedRequest.getBody());
assertThat(requestBody).contains("hello");
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package io.quarkiverse.langchain4j.ollama.runtime.config;

import io.smallrye.config.FallbackConfigSourceInterceptor;

public class ModelIdConfigFallbackInterceptor extends FallbackConfigSourceInterceptor {

public ModelIdConfigFallbackInterceptor() {
super(ModelIdToModelNameFunction.INSTANCE);
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package io.quarkiverse.langchain4j.ollama.runtime.config;

import io.smallrye.config.RelocateConfigSourceInterceptor;

public class ModelIdConfigRelocateInterceptor extends RelocateConfigSourceInterceptor {
public ModelIdConfigRelocateInterceptor() {
super(ModelIdToModelNameFunction.INSTANCE);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package io.quarkiverse.langchain4j.ollama.runtime.config;

import java.util.function.Function;

final class ModelIdToModelNameFunction implements Function<String, String> {

static final ModelIdToModelNameFunction INSTANCE = new ModelIdToModelNameFunction();

private ModelIdToModelNameFunction() {
}

@Override
public String apply(String name) {
if (!name.startsWith("quarkus.langchain4j.ollama")) {
return name;
}
if (!name.endsWith("model-id")) {
return name;
}
int index = name.lastIndexOf(".model-id");
if (index < 1) {
return name;
}
return name.substring(0, index) + ".model-name";
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
io.quarkiverse.langchain4j.ollama.runtime.config.ModelIdConfigRelocateInterceptor
io.quarkiverse.langchain4j.ollama.runtime.config.ModelIdConfigFallbackInterceptor

0 comments on commit 6e50c5d

Please sign in to comment.