From 7617c754fc4513cebf7f97a1fbbcf79c85e9e7cf Mon Sep 17 00:00:00 2001 From: Georgios Andrianakis Date: Fri, 12 Jan 2024 17:39:12 +0200 Subject: [PATCH] WIP - Add experimental Watsonx repo --- integration-tests/pom.xml | 1 + integration-tests/watsonx/pom.xml | 118 +++++++++++ .../watsonx/ChatLanguageModelResource.java | 29 +++ .../src/main/resources/application.properties | 1 + pom.xml | 1 + watsonx/deployment/pom.xml | 62 ++++++ .../watsonx/deployment/WatsonxProcessor.java | 14 ++ watsonx/pom.xml | 21 ++ watsonx/runtime/pom.xml | 79 ++++++++ .../langchain4j/watsonx/Message.java | 5 + .../langchain4j/watsonx/Parameters.java | 26 +++ .../watsonx/TextGenerationRequest.java | 7 + .../watsonx/TextGenerationResponse.java | 8 + .../langchain4j/watsonx/WatsonxChatModel.java | 186 ++++++++++++++++++ .../langchain4j/watsonx/WatsonxRestApi.java | 162 +++++++++++++++ .../watsonx/runtime/WatsoxRecorder.java | 33 ++++ .../runtime/config/ChatModelConfig.java | 20 ++ .../config/Langchain4jWatsonxConfig.java | 51 +++++ .../src/main/resources/META-INF/beans.xml | 0 .../resources/META-INF/quarkus-extension.yaml | 12 ++ 20 files changed, 836 insertions(+) create mode 100644 integration-tests/watsonx/pom.xml create mode 100644 integration-tests/watsonx/src/main/java/org/acme/example/watsonx/ChatLanguageModelResource.java create mode 100644 integration-tests/watsonx/src/main/resources/application.properties create mode 100644 watsonx/deployment/pom.xml create mode 100644 watsonx/deployment/src/main/java/io/quarkiverse/langchain4j/watsonx/deployment/WatsonxProcessor.java create mode 100644 watsonx/pom.xml create mode 100644 watsonx/runtime/pom.xml create mode 100644 watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/Message.java create mode 100644 watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/Parameters.java create mode 100644 watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/TextGenerationRequest.java create mode 100644 watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/TextGenerationResponse.java create mode 100644 watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxChatModel.java create mode 100644 watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxRestApi.java create mode 100644 watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/runtime/WatsoxRecorder.java create mode 100644 watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/runtime/config/ChatModelConfig.java create mode 100644 watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/runtime/config/Langchain4jWatsonxConfig.java create mode 100644 watsonx/runtime/src/main/resources/META-INF/beans.xml create mode 100644 watsonx/runtime/src/main/resources/META-INF/quarkus-extension.yaml diff --git a/integration-tests/pom.xml b/integration-tests/pom.xml index add6d678a..6770dae95 100644 --- a/integration-tests/pom.xml +++ b/integration-tests/pom.xml @@ -13,6 +13,7 @@ openai hugging-face + watsonx ollama simple-ollama azure-openai diff --git a/integration-tests/watsonx/pom.xml b/integration-tests/watsonx/pom.xml new file mode 100644 index 000000000..f41f18039 --- /dev/null +++ b/integration-tests/watsonx/pom.xml @@ -0,0 +1,118 @@ + + + 4.0.0 + + io.quarkiverse.langchain4j + quarkus-langchain4j-integration-tests-parent + 999-SNAPSHOT + + quarkus-langchain4j-integration-tests-watsonx + Quarkus Langchain4j - Integration Tests - Watsonx + + true + + + + io.quarkus + quarkus-resteasy-reactive-jackson + + + io.quarkiverse.langchain4j + quarkus-langchain4j-watsonx + ${project.version} + + + io.quarkus + quarkus-junit5 + test + + + io.rest-assured + rest-assured + test + + + org.assertj + assertj-core + ${assertj.version} + test + + + io.quarkus + quarkus-devtools-testing + test + + + + + io.quarkiverse.langchain4j + quarkus-langchain4j-watsonx-deployment + ${project.version} + pom + test + + + * + * + + + + + + + + io.quarkus + quarkus-maven-plugin + + + + build + + + + + + maven-failsafe-plugin + + + + integration-test + verify + + + + ${project.build.directory}/${project.build.finalName}-runner + org.jboss.logmanager.LogManager + ${maven.home} + + + + + + + + + + native-image + + + native + + + + + + maven-surefire-plugin + + ${native.surefire.skip} + + + + + + false + native + + + + diff --git a/integration-tests/watsonx/src/main/java/org/acme/example/watsonx/ChatLanguageModelResource.java b/integration-tests/watsonx/src/main/java/org/acme/example/watsonx/ChatLanguageModelResource.java new file mode 100644 index 000000000..12ef15eae --- /dev/null +++ b/integration-tests/watsonx/src/main/java/org/acme/example/watsonx/ChatLanguageModelResource.java @@ -0,0 +1,29 @@ +package org.acme.example.watsonx; + +import jakarta.ws.rs.GET; +import jakarta.ws.rs.Path; + +import dev.langchain4j.model.chat.ChatLanguageModel; +import io.quarkiverse.langchain4j.watsonx.WatsonxChatModel; + +@Path("chat") +public class ChatLanguageModelResource { + + private final ChatLanguageModel chatLanguageModel; + + public ChatLanguageModelResource() { + this.chatLanguageModel = WatsonxChatModel.builder() + .accessToken("pak-bvZOKvMD1b6ETeSO8lgBohQL8Um6KsGWxi0MVNyfaj4") + .modelId("meta-llama/llama-2-70b-chat") + .version("2024-01-10") + .logRequests(true) + .logResponses(true) + .build(); + } + + @GET + @Path("basic") + public String basic() { + return chatLanguageModel.generate("When was the nobel prize for economics first awarded?"); + } +} diff --git a/integration-tests/watsonx/src/main/resources/application.properties b/integration-tests/watsonx/src/main/resources/application.properties new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/integration-tests/watsonx/src/main/resources/application.properties @@ -0,0 +1 @@ + diff --git a/pom.xml b/pom.xml index 48da46e4c..f661b815e 100644 --- a/pom.xml +++ b/pom.xml @@ -16,6 +16,7 @@ core docs hugging-face + watsonx milvus ollama openai/azure-openai diff --git a/watsonx/deployment/pom.xml b/watsonx/deployment/pom.xml new file mode 100644 index 000000000..8ce07f960 --- /dev/null +++ b/watsonx/deployment/pom.xml @@ -0,0 +1,62 @@ + + + 4.0.0 + + io.quarkiverse.langchain4j + quarkus-langchain4j-watsonx-parent + 999-SNAPSHOT + + quarkus-langchain4j-watsonx-deployment + Quarkus Langchain4j - WatsonX - Deployment + + + io.quarkiverse.langchain4j + quarkus-langchain4j-watsonx + ${project.version} + + + io.quarkus + quarkus-rest-client-reactive-jackson-deployment + + + io.quarkiverse.langchain4j + quarkus-langchain4j-core-deployment + ${project.version} + + + io.quarkus + quarkus-junit5-internal + test + + + org.assertj + assertj-core + ${assertj.version} + test + + + org.wiremock + wiremock-standalone + ${wiremock.version} + test + + + + + + maven-compiler-plugin + + + + io.quarkus + quarkus-extension-processor + ${quarkus.version} + + + + + + + + + diff --git a/watsonx/deployment/src/main/java/io/quarkiverse/langchain4j/watsonx/deployment/WatsonxProcessor.java b/watsonx/deployment/src/main/java/io/quarkiverse/langchain4j/watsonx/deployment/WatsonxProcessor.java new file mode 100644 index 000000000..71190f6a2 --- /dev/null +++ b/watsonx/deployment/src/main/java/io/quarkiverse/langchain4j/watsonx/deployment/WatsonxProcessor.java @@ -0,0 +1,14 @@ +package io.quarkiverse.langchain4j.watsonx.deployment; + +import io.quarkus.deployment.annotations.BuildStep; +import io.quarkus.deployment.builditem.FeatureBuildItem; + +public class WatsonxProcessor { + + private static final String FEATURE = "langchain4j-watsonx"; + + @BuildStep + FeatureBuildItem feature() { + return new FeatureBuildItem(FEATURE); + } +} diff --git a/watsonx/pom.xml b/watsonx/pom.xml new file mode 100644 index 000000000..27085621b --- /dev/null +++ b/watsonx/pom.xml @@ -0,0 +1,21 @@ + + + 4.0.0 + + io.quarkiverse.langchain4j + quarkus-langchain4j-parent + 999-SNAPSHOT + + quarkus-langchain4j-watsonx-parent + Quarkus Langchain4j - WatsonX - Parent + pom + + + deployment + runtime + + + + diff --git a/watsonx/runtime/pom.xml b/watsonx/runtime/pom.xml new file mode 100644 index 000000000..a9142162f --- /dev/null +++ b/watsonx/runtime/pom.xml @@ -0,0 +1,79 @@ + + + 4.0.0 + + io.quarkiverse.langchain4j + quarkus-langchain4j-watsonx-parent + 999-SNAPSHOT + + quarkus-langchain4j-watsonx + Quarkus Langchain4j - WatsonX - Runtime + + + io.quarkus + quarkus-arc + + + io.quarkus + quarkus-rest-client-reactive-jackson + + + io.quarkiverse.langchain4j + quarkus-langchain4j-core + ${project.version} + + + + + + io.quarkus + quarkus-extension-maven-plugin + ${quarkus.version} + + + compile + + extension-descriptor + + + ${project.groupId}:${project.artifactId}-deployment:${project.version} + + + + + + maven-compiler-plugin + + + + io.quarkus + quarkus-extension-processor + ${quarkus.version} + + + + + + maven-jar-plugin + + + generate-codestart-jar + generate-resources + + jar + + + ${project.basedir}/src/main + + codestarts/** + + codestarts + true + + + + + + + + diff --git a/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/Message.java b/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/Message.java new file mode 100644 index 000000000..8542ead50 --- /dev/null +++ b/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/Message.java @@ -0,0 +1,5 @@ +package io.quarkiverse.langchain4j.watsonx; + +public record Message(String role, String content) { + +} diff --git a/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/Parameters.java b/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/Parameters.java new file mode 100644 index 000000000..570d26a8b --- /dev/null +++ b/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/Parameters.java @@ -0,0 +1,26 @@ +package io.quarkiverse.langchain4j.watsonx; + +public class Parameters { + + private final String decodingMethod; + private final Integer minNewTokens; + private final Integer maxNewTokens; + + public Parameters(String decodingMethod, Integer minNewTokens, Integer maxNewTokens) { + this.decodingMethod = decodingMethod; + this.minNewTokens = minNewTokens; + this.maxNewTokens = maxNewTokens; + } + + public String getDecodingMethod() { + return decodingMethod; + } + + public Integer getMinNewTokens() { + return minNewTokens; + } + + public Integer getMaxNewTokens() { + return maxNewTokens; + } +} diff --git a/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/TextGenerationRequest.java b/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/TextGenerationRequest.java new file mode 100644 index 000000000..6ae51423c --- /dev/null +++ b/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/TextGenerationRequest.java @@ -0,0 +1,7 @@ +package io.quarkiverse.langchain4j.watsonx; + +import java.util.List; + +public record TextGenerationRequest(String modelId, List messages, Parameters parameters) { + +} diff --git a/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/TextGenerationResponse.java b/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/TextGenerationResponse.java new file mode 100644 index 000000000..04ab15b6b --- /dev/null +++ b/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/TextGenerationResponse.java @@ -0,0 +1,8 @@ +package io.quarkiverse.langchain4j.watsonx; + +public record TextGenerationResponse(Results results) { + + public record Results(String generatedText) { + + } +} diff --git a/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxChatModel.java b/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxChatModel.java new file mode 100644 index 000000000..a849a8371 --- /dev/null +++ b/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxChatModel.java @@ -0,0 +1,186 @@ +package io.quarkiverse.langchain4j.watsonx; + +import java.net.URI; +import java.net.URISyntaxException; +import java.net.URL; +import java.time.Duration; +import java.util.List; +import java.util.concurrent.TimeUnit; + +import org.jboss.resteasy.reactive.client.api.LoggingScope; + +import dev.langchain4j.agent.tool.ToolSpecification; +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.data.message.SystemMessage; +import dev.langchain4j.data.message.UserMessage; +import dev.langchain4j.model.chat.ChatLanguageModel; +import dev.langchain4j.model.output.Response; +import io.quarkus.rest.client.reactive.QuarkusRestClientBuilder; + +public class WatsonxChatModel implements ChatLanguageModel { + + private final String token; + private final String modelId; + private final String version; + private final String decodingMethod; + private final Integer minNewTokens; + private final Integer maxNewTokens; + private final Double temperature; + private final Double topP; + private final Integer topK; + private final WatsonxRestApi client; + + public WatsonxChatModel(Builder config) { + QuarkusRestClientBuilder builder = QuarkusRestClientBuilder.newBuilder() + .baseUri(config.url) + .connectTimeout(config.timeout.toSeconds(), TimeUnit.SECONDS) + .readTimeout(config.timeout.toSeconds(), TimeUnit.SECONDS); + + if (config.logRequests || config.logResponses) { + builder.loggingScope(LoggingScope.REQUEST_RESPONSE); + builder.clientLogger(new WatsonxRestApi.WatsonClientLogger(config.logRequests, + config.logResponses)); + } + + this.client = builder.build(WatsonxRestApi.class); + this.token = config.accessToken; + this.modelId = config.modelId; + this.version = config.version; + this.decodingMethod = config.decodingMethod; + this.minNewTokens = config.minNewTokens; + this.maxNewTokens = config.maxNewTokens; + this.temperature = config.temperature; + this.topP = config.topP; + this.topK = config.topK; + } + + public static Builder builder() { + return new Builder(); + } + + @Override + public Response generate(List messages) { + + Parameters parameters = new Parameters(decodingMethod, minNewTokens, maxNewTokens); + + TextGenerationRequest request = new TextGenerationRequest(modelId, + messages.stream().map(cm -> new Message(getRole(cm), cm.text())).toList(), parameters); + + TextGenerationResponse textGenerationResponse = client.chat(request, token, version); + + return Response.from(AiMessage.from(textGenerationResponse.results().generatedText())); + } + + private String getRole(ChatMessage chatMessage) { + if (chatMessage instanceof SystemMessage) { + return "system"; + } else if (chatMessage instanceof UserMessage) { + return "user"; + } else if (chatMessage instanceof AiMessage) { + return "assistant"; + } + throw new IllegalArgumentException(chatMessage.getClass().getSimpleName() + " not supported"); + } + + @Override + public Response generate(List messages, List toolSpecifications) { + throw new IllegalArgumentException("Tools are currently not supported for WatsonX models"); + } + + @Override + public Response generate(List messages, ToolSpecification toolSpecification) { + throw new IllegalArgumentException("Tools are currently not supported for WatsonX models"); + } + + public static final class Builder { + + private String accessToken; + private String modelId; + private String version; + private Duration timeout = Duration.ofSeconds(15); + private String decodingMethod = "greedy"; + private Integer minNewTokens = 1; + private Integer maxNewTokens = 200; + private Double temperature; + + private URI url = URI.create("https://bam-api.res.ibm.com"); + private Integer topK; + private Double topP; + public boolean logResponses; + public boolean logRequests; + + public Builder modelId(String modelId) { + this.modelId = modelId; + return this; + } + + public Builder accessToken(String accessToken) { + this.accessToken = accessToken; + return this; + } + + public Builder version(String version) { + this.version = version; + return this; + } + + public Builder url(URL url) { + try { + this.url = url.toURI(); + } catch (URISyntaxException e) { + throw new RuntimeException(e); + } + return this; + } + + public Builder timeout(Duration timeout) { + this.timeout = timeout; + return this; + } + + public Builder decodingMethod(String decodingMethod) { + this.decodingMethod = decodingMethod; + return this; + } + + public Builder minNewTokens(Integer minNewTokens) { + this.minNewTokens = minNewTokens; + return this; + } + + public Builder maxNewTokens(Integer maxNewTokens) { + this.maxNewTokens = maxNewTokens; + return this; + } + + public Builder temperature(Double temperature) { + this.temperature = temperature; + return this; + } + + public Builder topK(Integer topK) { + this.topK = topK; + return this; + } + + public Builder topP(Double topP) { + this.topP = topP; + return this; + } + + public WatsonxChatModel build() { + return new WatsonxChatModel(this); + } + + public Builder logRequests(boolean logRequests) { + this.logRequests = logRequests; + return this; + } + + public Builder logResponses(boolean logResponses) { + this.logResponses = logResponses; + return this; + } + } +} diff --git a/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxRestApi.java b/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxRestApi.java new file mode 100644 index 000000000..7f2b817ec --- /dev/null +++ b/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxRestApi.java @@ -0,0 +1,162 @@ +package io.quarkiverse.langchain4j.watsonx; + +import static java.util.stream.Collectors.joining; +import static java.util.stream.StreamSupport.stream; + +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import jakarta.ws.rs.Consumes; +import jakarta.ws.rs.POST; +import jakarta.ws.rs.Path; +import jakarta.ws.rs.Produces; +import jakarta.ws.rs.QueryParam; +import jakarta.ws.rs.core.MediaType; + +import org.eclipse.microprofile.rest.client.annotation.ClientHeaderParam; +import org.jboss.logging.Logger; +import org.jboss.resteasy.reactive.client.api.ClientLogger; + +import com.fasterxml.jackson.databind.ObjectMapper; + +import io.quarkiverse.langchain4j.QuarkusJsonCodecFactory; +import io.quarkus.rest.client.reactive.NotBody; +import io.quarkus.rest.client.reactive.jackson.ClientObjectMapper; +import io.vertx.core.Handler; +import io.vertx.core.MultiMap; +import io.vertx.core.buffer.Buffer; +import io.vertx.core.http.HttpClientRequest; +import io.vertx.core.http.HttpClientResponse; + +/** + * This Microprofile REST client is used as the building block of all the API calls to Watsonx. + * The implementation is provided by the Reactive REST Client in Quarkus. + */ + +@Path("v2") +@ClientHeaderParam(name = "Authorization", value = "Bearer {token}") +@Consumes(MediaType.APPLICATION_JSON) +@Produces(MediaType.APPLICATION_JSON) +public interface WatsonxRestApi { + + @POST + @Path("text/chat") + TextGenerationResponse chat(TextGenerationRequest request, @NotBody String token, @QueryParam("version") String version); + + @ClientObjectMapper + static ObjectMapper objectMapper(ObjectMapper defaultObjectMapper) { + return QuarkusJsonCodecFactory.SnakeCaseObjectMapperHolder.MAPPER; + } + + /** + * Introduce a custom logger as the stock one logs at the DEBUG level by default... + */ + class WatsonClientLogger implements ClientLogger { + private static final Logger log = Logger.getLogger(WatsonClientLogger.class); + + private static final Pattern BEARER_PATTERN = Pattern.compile("(Bearer\\s*sk-)(\\w{2})(\\w+)(\\w{2})"); + + private final boolean logRequests; + private final boolean logResponses; + + public WatsonClientLogger(boolean logRequests, boolean logResponses) { + this.logRequests = logRequests; + this.logResponses = logResponses; + } + + @Override + public void setBodySize(int bodySize) { + // ignore + } + + @Override + public void logRequest(HttpClientRequest request, Buffer body, boolean omitBody) { + if (!logRequests || !log.isInfoEnabled()) { + return; + } + try { + log.infof("Request:\n- method: %s\n- url: %s\n- headers: %s\n- body: %s", + request.getMethod(), + request.absoluteURI(), + inOneLine(request.headers()), + bodyToString(body)); + } catch (Exception e) { + log.warn("Failed to log request", e); + } + } + + @Override + public void logResponse(HttpClientResponse response, boolean redirect) { + if (!logResponses || !log.isInfoEnabled()) { + return; + } + response.bodyHandler(new Handler<>() { + @Override + public void handle(Buffer body) { + try { + log.infof( + "Response:\n- status code: %s\n- headers: %s\n- body: %s", + response.statusCode(), + inOneLine(response.headers()), + bodyToString(body)); + } catch (Exception e) { + log.warn("Failed to log response", e); + } + } + }); + } + + private String bodyToString(Buffer body) { + if (body == null) { + return ""; + } + return body.toString(); + } + + private String inOneLine(MultiMap headers) { + + return stream(headers.spliterator(), false) + .map(header -> { + String headerKey = header.getKey(); + String headerValue = header.getValue(); + if (headerKey.equals("Authorization")) { + headerValue = maskAuthorizationHeaderValue(headerValue); + } else if (headerKey.equals("api-key")) { + headerValue = maskApiKeyHeaderValue(headerValue); + } + return String.format("[%s: %s]", headerKey, headerValue); + }) + .collect(joining(", ")); + } + + private static String maskAuthorizationHeaderValue(String authorizationHeaderValue) { + try { + + Matcher matcher = BEARER_PATTERN.matcher(authorizationHeaderValue); + + StringBuilder sb = new StringBuilder(); + while (matcher.find()) { + matcher.appendReplacement(sb, matcher.group(1) + matcher.group(2) + "..." + matcher.group(4)); + } + matcher.appendTail(sb); + + return sb.toString(); + } catch (Exception e) { + return "Failed to mask the API key."; + } + } + + private static String maskApiKeyHeaderValue(String apiKeyHeaderValue) { + try { + if (apiKeyHeaderValue.length() <= 4) { + return apiKeyHeaderValue; + } + return apiKeyHeaderValue.substring(0, 2) + + "..." + + apiKeyHeaderValue.substring(apiKeyHeaderValue.length() - 2); + } catch (Exception e) { + return "Failed to mask the API key."; + } + } + } +} diff --git a/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/runtime/WatsoxRecorder.java b/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/runtime/WatsoxRecorder.java new file mode 100644 index 000000000..aa42576e9 --- /dev/null +++ b/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/runtime/WatsoxRecorder.java @@ -0,0 +1,33 @@ +package io.quarkiverse.langchain4j.watsonx.runtime; + +import java.util.function.Supplier; + +import io.quarkiverse.langchain4j.watsonx.WatsonxChatModel; +import io.quarkiverse.langchain4j.watsonx.runtime.config.ChatModelConfig; +import io.quarkiverse.langchain4j.watsonx.runtime.config.Langchain4jWatsonxConfig; +import io.quarkus.runtime.annotations.Recorder; + +@Recorder +public class WatsoxRecorder { + + public Supplier chatModel(Langchain4jWatsonxConfig runtimeConfig) { + ChatModelConfig chatModelConfig = runtimeConfig.chatModel(); + + var builder = WatsonxChatModel.builder() + .accessToken(runtimeConfig.apiKey()) + .timeout(runtimeConfig.timeout()) + .modelId(chatModelConfig.modelId()) + .version(chatModelConfig.version()); + + if (runtimeConfig.baseUrl().isPresent()) { + builder.url(runtimeConfig.baseUrl().get()); + } + + return new Supplier<>() { + @Override + public Object get() { + return builder.build(); + } + }; + } +} diff --git a/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/runtime/config/ChatModelConfig.java b/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/runtime/config/ChatModelConfig.java new file mode 100644 index 000000000..1697f9e99 --- /dev/null +++ b/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/runtime/config/ChatModelConfig.java @@ -0,0 +1,20 @@ +package io.quarkiverse.langchain4j.watsonx.runtime.config; + +import io.quarkus.runtime.annotations.ConfigGroup; +import io.smallrye.config.WithDefault; + +@ConfigGroup +public interface ChatModelConfig { + + /** + * Model to use + */ + @WithDefault("meta-llama/llama-2-70b-chat") + String modelId(); + + /** + * Version to use + */ + @WithDefault("2024-01-10") + String version(); +} diff --git a/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/runtime/config/Langchain4jWatsonxConfig.java b/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/runtime/config/Langchain4jWatsonxConfig.java new file mode 100644 index 000000000..548b96bad --- /dev/null +++ b/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/runtime/config/Langchain4jWatsonxConfig.java @@ -0,0 +1,51 @@ +package io.quarkiverse.langchain4j.watsonx.runtime.config; + +import static io.quarkus.runtime.annotations.ConfigPhase.RUN_TIME; + +import java.net.URL; +import java.time.Duration; +import java.util.Optional; + +import io.quarkus.runtime.annotations.ConfigDocDefault; +import io.quarkus.runtime.annotations.ConfigRoot; +import io.smallrye.config.ConfigMapping; +import io.smallrye.config.WithDefault; + +@ConfigRoot(phase = RUN_TIME) +@ConfigMapping(prefix = "quarkus.langchain4j.watsonx") +public interface Langchain4jWatsonxConfig { + + /** + * Base URL where the Ollama serving is running + */ + @ConfigDocDefault("https://bam-api.res.ibm.com") + Optional baseUrl(); + + /** + * Watsonx API key + */ + String apiKey(); + + /** + * Timeout for Watsonx calls + */ + @WithDefault("10s") + Duration timeout(); + + /** + * Whether the Watsonx client should log requests + */ + @WithDefault("false") + Boolean logRequests(); + + /** + * Whether the Watsonx client should log responses + */ + @WithDefault("false") + Boolean logResponses(); + + /** + * Chat model related settings + */ + ChatModelConfig chatModel(); +} diff --git a/watsonx/runtime/src/main/resources/META-INF/beans.xml b/watsonx/runtime/src/main/resources/META-INF/beans.xml new file mode 100644 index 000000000..e69de29bb diff --git a/watsonx/runtime/src/main/resources/META-INF/quarkus-extension.yaml b/watsonx/runtime/src/main/resources/META-INF/quarkus-extension.yaml new file mode 100644 index 000000000..b136c392a --- /dev/null +++ b/watsonx/runtime/src/main/resources/META-INF/quarkus-extension.yaml @@ -0,0 +1,12 @@ +name: Quarkus Langchain4j pgvector embedding store +artifact: ${project.groupId}:${project.artifactId}:${project.version} +description: Provides the pgvector Embedding store for Quarkus Langchain4j +metadata: + keywords: + - ai + - langchain4j + - openai + - pgvector + categories: + - "miscellaneous" + status: "preview"