diff --git a/.gitignore b/.gitignore index 396cd526c..943c68348 100644 --- a/.gitignore +++ b/.gitignore @@ -65,3 +65,6 @@ release.properties # Quarkus CLI .quarkus + +#Dolphin +.directory \ No newline at end of file diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java index e895b0636..53358a8d0 100644 --- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java @@ -9,7 +9,6 @@ import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.SEED_MEMORY; import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.V; -import java.io.FileNotFoundException; import java.io.IOException; import java.io.InputStream; import java.io.UncheckedIOException; @@ -117,8 +116,8 @@ public class AiServicesProcessor { public static final DotName MICROMETER_TIMED = DotName.createSimple("io.micrometer.core.annotation.Timed"); public static final DotName MICROMETER_COUNTED = DotName.createSimple("io.micrometer.core.annotation.Counted"); - private static final String DEFAULT_DELIMITER = "\n"; - private static final Predicate IS_METHOD_PARAMETER_ANNOTATION = ai -> ai.target() + public static final String DEFAULT_DELIMITER = "\n"; + public static final Predicate IS_METHOD_PARAMETER_ANNOTATION = ai -> ai.target() .kind() == AnnotationTarget.Kind.METHOD_PARAMETER; private static final Function METHOD_PARAMETER_POSITION_FUNCTION = ai -> Integer .valueOf(ai.target() @@ -1033,7 +1032,7 @@ private Optional gatherSystemMessageInfo instance = method.declaringClass().declaredAnnotation(LangChain4jDotNames.SYSTEM_MESSAGE); } if (instance != null) { - String systemMessageTemplate = getTemplateFromAnnotationInstance(instance); + String systemMessageTemplate = 
TemplateUtil.getTemplateFromAnnotationInstance(instance); if (systemMessageTemplate.isEmpty()) { throw illegalConfigurationForMethod("@SystemMessage's template parameter cannot be empty", method); } @@ -1061,7 +1060,7 @@ private AiServiceMethodCreateInfo.UserMessageInfo gatherUserMessageInfo(MethodIn AnnotationInstance userMessageInstance = method.declaredAnnotation(LangChain4jDotNames.USER_MESSAGE); if (userMessageInstance != null) { - String userMessageTemplate = getTemplateFromAnnotationInstance(userMessageInstance); + String userMessageTemplate = TemplateUtil.getTemplateFromAnnotationInstance(userMessageInstance); if (userMessageTemplate.contains("{{it}}")) { if (method.parametersCount() != 1) { @@ -1110,41 +1109,6 @@ private AiServiceMethodCreateInfo.UserMessageInfo gatherUserMessageInfo(MethodIn } } - /** - * Meant to be called with instances of {@link dev.langchain4j.service.SystemMessage} or - * {@link dev.langchain4j.service.UserMessage} - * - * @return the String value of the template or an empty string if not specified - */ - private String getTemplateFromAnnotationInstance(AnnotationInstance instance) { - AnnotationValue fromResourceValue = instance.value("fromResource"); - if (fromResourceValue != null) { - String fromResource = fromResourceValue.asString(); - if (!fromResource.startsWith("/")) { - fromResource = "/" + fromResource; - - } - try (InputStream is = Thread.currentThread().getContextClassLoader().getResourceAsStream(fromResource)) { - if (is != null) { - return new String(is.readAllBytes()); - } else { - throw new FileNotFoundException("Resource not found: " + fromResource); - } - } catch (IOException e) { - throw new UncheckedIOException(e); - } - } else { - AnnotationValue valueValue = instance.value(); - if (valueValue != null) { - AnnotationValue delimiterValue = instance.value("delimiter"); - String delimiter = delimiterValue != null ? 
delimiterValue.asString() : DEFAULT_DELIMITER; - return String.join(delimiter, valueValue.asStringArray()); - } - - } - return ""; - } - private Optional gatherMetricsTimedInfo(MethodInfo method, boolean addMicrometerMetrics) { if (!addMicrometerMetrics) { diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/LangChain4jDotNames.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/LangChain4jDotNames.java index d3243ccc8..e6668c00c 100644 --- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/LangChain4jDotNames.java +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/LangChain4jDotNames.java @@ -42,7 +42,7 @@ public class LangChain4jDotNames { static final DotName AI_SERVICES = DotName.createSimple(AiServices.class); static final DotName CREATED_AWARE = DotName.createSimple(CreatedAware.class); public static final DotName SYSTEM_MESSAGE = DotName.createSimple(SystemMessage.class); - static final DotName USER_MESSAGE = DotName.createSimple(UserMessage.class); + public static final DotName USER_MESSAGE = DotName.createSimple(UserMessage.class); static final DotName USER_NAME = DotName.createSimple(UserName.class); static final DotName MODERATE = DotName.createSimple(Moderate.class); static final DotName MEMORY_ID = DotName.createSimple(MemoryId.class); @@ -52,7 +52,7 @@ public class LangChain4jDotNames { static final DotName V = DotName.createSimple(dev.langchain4j.service.V.class); static final DotName MODEL_NAME = DotName.createSimple(ModelName.class); - static final DotName REGISTER_AI_SERVICES = DotName.createSimple(RegisterAiService.class); + public static final DotName REGISTER_AI_SERVICES = DotName.createSimple(RegisterAiService.class); static final DotName BEAN_CHAT_MODEL_SUPPLIER = DotName.createSimple( RegisterAiService.BeanChatLanguageModelSupplier.class); diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/TemplateUtil.java 
b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/TemplateUtil.java index f908aa24c..9d0efc880 100644 --- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/TemplateUtil.java +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/TemplateUtil.java @@ -1,15 +1,22 @@ package io.quarkiverse.langchain4j.deployment; +import java.io.FileNotFoundException; +import java.io.IOException; +import java.io.InputStream; +import java.io.UncheckedIOException; import java.util.Collections; import java.util.List; import java.util.stream.Collectors; +import org.jboss.jandex.AnnotationInstance; +import org.jboss.jandex.AnnotationValue; + import io.quarkiverse.langchain4j.QuarkusPromptTemplateFactory; import io.quarkus.qute.Engine; import io.quarkus.qute.Expression; import io.quarkus.qute.Template; -class TemplateUtil { +public class TemplateUtil { static List> parts(String templateStr) { Template template = Holder.ENGINE.parse(templateStr); @@ -20,6 +27,46 @@ static List> parts(String templateStr) { return expressions.stream().map(Expression::getParts).collect(Collectors.toList()); } + /** + * Meant to be called with instances of {@link dev.langchain4j.service.SystemMessage} or + * {@link dev.langchain4j.service.UserMessage} + * + * @return the template text, or an empty string if {@code instance} is {@code null} or specifies no template + */ + public static String getTemplateFromAnnotationInstance(AnnotationInstance instance) { + + if (instance == null) { + return ""; + } + + AnnotationValue fromResourceValue = instance.value("fromResource"); + if (fromResourceValue != null) { + String fromResource = fromResourceValue.asString(); + if (!fromResource.startsWith("/")) { + fromResource = "/" + fromResource; + + } + try (InputStream is = Thread.currentThread().getContextClassLoader().getResourceAsStream(fromResource)) { + if (is != null) { + return new String(is.readAllBytes(), java.nio.charset.StandardCharsets.UTF_8); /* decode explicitly as UTF-8 rather than with the platform default charset */ + } else { + throw new FileNotFoundException("Resource not found: " + 
fromResource); + } + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } else { + AnnotationValue valueValue = instance.value(); + if (valueValue != null) { + AnnotationValue delimiterValue = instance.value("delimiter"); + String delimiter = delimiterValue != null ? delimiterValue.asString() : AiServicesProcessor.DEFAULT_DELIMITER; + return String.join(delimiter, valueValue.asStringArray()); + } + + } + return ""; + } + private static class Holder { private static final Engine ENGINE = Engine.builder().addDefaults() .addParserHook(new QuarkusPromptTemplateFactory.MustacheTemplateVariableStyleParserHook()).build(); diff --git a/docs/modules/ROOT/pages/watsonx.adoc b/docs/modules/ROOT/pages/watsonx.adoc index 7e097f5fa..7a9c1287d 100644 --- a/docs/modules/ROOT/pages/watsonx.adoc +++ b/docs/modules/ROOT/pages/watsonx.adoc @@ -66,13 +66,25 @@ NOTE: To determine the API key, go to https://cloud.ibm.com/iam/apikeys and gene When creating prompts using watsonx.ai, it's important to follow the guidelines of the model you choose. Depending on the model, some special instructions may be required to ensure the desired output. For best results, always refer to the documentation provided for each model to maximize the effectiveness of your prompts. -For example, if you choose to use `ibm/granite-13b-chat-v2`, you can use the `<|system|>`, `<|user|>`, and `<|assistant|>` instructions: +To simplify the process of prompt creation, you can use the `prompt-formatter` property to automatically handle the addition of tags to your prompts. This property allows you to avoid manually adding tags by letting the system handle the formatting based on the model's requirements. This functionality is particularly useful for models such as `ibm/granite-13b-chat-v2`, `meta-llama/llama-3-405b-instruct`, and other supported models, ensuring consistent and accurate prompt structures without additional effort. 
+ +To enable this functionality, configure the `prompt-formatter` property in your `application.properties` file as follows: + +[source,properties,subs=attributes+] +---- +quarkus.langchain4j.watsonx.chat-model.prompt-formatter=true +---- + +When this property is set to `true`, the system will automatically format prompts with the appropriate tags. This helps to maintain prompt clarity and improves interaction with the LLM by ensuring that prompts follow the required structure. If set to `false`, you'll need to manage the tags manually. + +For example, if you choose to use `ibm/granite-13b-chat-v2` without using the `prompt-formatter`, you will need to manually add the `<|system|>`, `<|user|>` and `<|assistant|>` instructions: [source,properties,subs=attributes+] ---- quarkus.langchain4j.watsonx.api-key=hG-... quarkus.langchain4j.watsonx.base-url=https://us-south.ml.cloud.ibm.com quarkus.langchain4j.watsonx.chat-model.model-id=ibm/granite-13b-chat-v2 +quarkus.langchain4j.watsonx.chat-model.prompt-formatter=false ---- [source,java] @@ -96,30 +108,63 @@ public interface LLMService { } ---- +Enabling the `prompt-formatter` will result in: + +[source,properties,subs=attributes+] +---- +quarkus.langchain4j.watsonx.api-key=hG-... 
+quarkus.langchain4j.watsonx.base-url=https://us-south.ml.cloud.ibm.com +quarkus.langchain4j.watsonx.chat-model.model-id=ibm/granite-13b-chat-v2 +quarkus.langchain4j.watsonx.chat-model.prompt-formatter=true +---- + [source,java] ---- -@Path("/llm") -public class LLMResource { +@RegisterAiService +public interface LLMService { - @Inject - LLMService llmService; + public record Result(Integer result) {} - @GET - @Path("/calculator") - @Produces(MediaType.APPLICATION_JSON) - public Result calculator() { - return llmService.calculator(2, 2); - } + @SystemMessage(""" + You are a calculator and you must perform the mathematical operation + {response_schema} + """) + @UserMessage(""" + {firstNumber} + {secondNumber} + """) + public Result calculator(int firstNumber, int secondNumber); } ---- -[source,shell] ----- -❯ curl http://localhost:8080/llm/calculator -{"result":4} ----- +The `prompt-formatter` supports the following models: + +* `mistralai/mistral-large` +* `mistralai/mixtral-8x7b-instruct-v01` +* `sdaia/allam-1-13b-instruct` +* `meta-llama/llama-3-405b-instruct` +* `meta-llama/llama-3-1-70b-instruct` +* `meta-llama/llama-3-1-8b-instruct` +* `meta-llama/llama-3-70b-instruct` +* `meta-llama/llama-3-8b-instruct` +* `ibm/granite-13b-chat-v2` +* `ibm/granite-13b-instruct-v2` +* `ibm/granite-7b-lab` +* `ibm/granite-20b-code-instruct` +* `ibm/granite-34b-code-instruct` +* `ibm/granite-3b-code-instruct` +* `ibm/granite-8b-code-instruct` + +==== Tool Execution with Prompt Formatter + +In addition to simplifying prompt creation, the `prompt-formatter` property also enables the execution of tools for specific models. Tools allow for dynamic interactions within the model, enabling the AI to perform specific actions or fetch data as part of its response. + +When the `prompt-formatter` is enabled and a supported model is selected, the prompt will be automatically formatted to use the tools. 
More information about tools is available in the xref:./agent-and-tools.adoc[Agent and Tools] page. + +Currently, the following model supports tool execution: + +* `mistralai/mistral-large` -IMPORTANT: The `@SystemMessage` and `@UserMessage` are joined by default without spaces or new lines, if you want to change this behavior use the property `quarkus.langchain4j.watsonx.chat-model.prompt-joiner=<value>`. By adjusting this property, you can define your preferred way of joining messages and ensure that the prompt structure meets your specific needs. +IMPORTANT: The `@SystemMessage` and `@UserMessage` annotations are joined by default with a new line. If you want to change this behavior, use the property `quarkus.langchain4j.watsonx.chat-model.prompt-joiner=<value>`. By adjusting this property, you can define your preferred way of joining messages and ensure that the prompt structure meets your specific needs. This customization option is available only when the `prompt-formatter` property is set to `false`. When the `prompt-formatter` is enabled (set to `true`), the prompt formatting, including the addition of tags and message joining, is automatically handled. In this case, the `prompt-joiner` property will be ignored, and you will not have the ability to customize how messages are joined. NOTE: Sometimes it may be useful to use the `quarkus.langchain4j.watsonx.chat-model.stop-sequences` property to prevent the LLM model from returning more results than desired. 
diff --git a/model-providers/watsonx/deployment/src/main/java/io/quarkiverse/langchain4j/watsonx/deployment/WatsonxProcessor.java b/model-providers/watsonx/deployment/src/main/java/io/quarkiverse/langchain4j/watsonx/deployment/WatsonxProcessor.java index fcda4302a..33b9bfcc9 100644 --- a/model-providers/watsonx/deployment/src/main/java/io/quarkiverse/langchain4j/watsonx/deployment/WatsonxProcessor.java +++ b/model-providers/watsonx/deployment/src/main/java/io/quarkiverse/langchain4j/watsonx/deployment/WatsonxProcessor.java @@ -4,30 +4,39 @@ import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.EMBEDDING_MODEL; import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.STREAMING_CHAT_MODEL; import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.TOKEN_COUNT_ESTIMATOR; +import static io.quarkiverse.langchain4j.deployment.TemplateUtil.getTemplateFromAnnotationInstance; import java.util.List; import jakarta.enterprise.context.ApplicationScoped; import org.jboss.jandex.AnnotationInstance; +import org.jboss.logging.Logger; import io.quarkiverse.langchain4j.ModelName; +import io.quarkiverse.langchain4j.deployment.LangChain4jDotNames; import io.quarkiverse.langchain4j.deployment.items.ChatModelProviderCandidateBuildItem; import io.quarkiverse.langchain4j.deployment.items.EmbeddingModelProviderCandidateBuildItem; import io.quarkiverse.langchain4j.deployment.items.SelectedChatModelProviderBuildItem; import io.quarkiverse.langchain4j.deployment.items.SelectedEmbeddingModelCandidateBuildItem; import io.quarkiverse.langchain4j.runtime.NamedConfigUtil; +import io.quarkiverse.langchain4j.watsonx.deployment.items.WatsonxChatModelProviderBuildItem; +import io.quarkiverse.langchain4j.watsonx.prompt.PromptFormatter; +import io.quarkiverse.langchain4j.watsonx.prompt.PromptFormatterMapper; import io.quarkiverse.langchain4j.watsonx.runtime.WatsonxRecorder; import io.quarkiverse.langchain4j.watsonx.runtime.config.LangChain4jWatsonxConfig; 
+import io.quarkiverse.langchain4j.watsonx.runtime.config.LangChain4jWatsonxFixedRuntimeConfig; import io.quarkus.arc.deployment.SyntheticBeanBuildItem; import io.quarkus.deployment.annotations.BuildProducer; import io.quarkus.deployment.annotations.BuildStep; import io.quarkus.deployment.annotations.ExecutionTime; import io.quarkus.deployment.annotations.Record; +import io.quarkus.deployment.builditem.CombinedIndexBuildItem; import io.quarkus.deployment.builditem.FeatureBuildItem; public class WatsonxProcessor { + private static final Logger log = Logger.getLogger(WatsonxProcessor.class); private static final String FEATURE = "langchain4j-watsonx"; private static final String PROVIDER = "watsonx"; @@ -51,44 +60,124 @@ public void providerCandidates(BuildProducer selectedChatItem, - List selectedEmbedding, - BuildProducer beanProducer) { + BuildProducer chatModelBuilder) { + + var index = indexBuildItem.getIndex(); + var annotationInstances = index.getAnnotations(LangChain4jDotNames.REGISTER_AI_SERVICES); for (var selected : selectedChatItem) { - if (PROVIDER.equals(selected.getProvider())) { - String configName = selected.getConfigName(); - var chatModel = recorder.chatModel(config, configName); - var chatBuilder = SyntheticBeanBuildItem - .configure(CHAT_MODEL) - .setRuntimeInit() - .defaultBean() - .scope(ApplicationScoped.class) - .supplier(chatModel); - addQualifierIfNecessary(chatBuilder, configName); - beanProducer.produce(chatBuilder.done()); + if (!PROVIDER.equals(selected.getProvider())) { + continue; + } - var tokenizerBuilder = SyntheticBeanBuildItem - .configure(TOKEN_COUNT_ESTIMATOR) - .setRuntimeInit() - .defaultBean() - .scope(ApplicationScoped.class) - .supplier(chatModel); - addQualifierIfNecessary(tokenizerBuilder, configName); - beanProducer.produce(tokenizerBuilder.done()); + String configName = selected.getConfigName(); - var streamingBuilder = SyntheticBeanBuildItem - .configure(STREAMING_CHAT_MODEL) - .setRuntimeInit() - .defaultBean() - 
.scope(ApplicationScoped.class) - .supplier(recorder.streamingChatModel(config, configName)); - addQualifierIfNecessary(streamingBuilder, configName); - beanProducer.produce(streamingBuilder.done()); + String modelId = NamedConfigUtil.isDefault(configName) + ? fixedRuntimeConfig.defaultConfig().chatModel().modelId() + : fixedRuntimeConfig.namedConfig().get(configName).chatModel().modelId(); + + boolean promptFormatterIsEnabled = NamedConfigUtil.isDefault(configName) + ? fixedRuntimeConfig.defaultConfig().chatModel().promptFormatter() + : fixedRuntimeConfig.namedConfig().get(configName).chatModel().promptFormatter(); + + PromptFormatter promptFormatter = null; + + if (promptFormatterIsEnabled) { + promptFormatter = PromptFormatterMapper.get(modelId); + if (promptFormatter == null) { + log.warnf( + "The \"%s\" model does not have a PromptFormatter implementation, no tags are automatically generated.", + modelId); + } + } + + var registerAiService = annotationInstances.stream() + .filter(annotationInstance -> { + var modelName = annotationInstance.value("modelName"); + if (modelName == null) { + return configName.equals(NamedConfigUtil.DEFAULT_NAME); + } else { + return configName.equals(modelName.asString()); + } + }).findFirst(); + + if (!registerAiService.isEmpty()) { + + var classInfo = registerAiService.get().target().asClass(); + var tools = classInfo.annotation(LangChain4jDotNames.REGISTER_AI_SERVICES).value("tools"); + + if (tools != null && !PromptFormatterMapper.toolIsSupported(modelId)) { + throw new RuntimeException( + "The tool functionality is not supported for the model \"%s\"".formatted(modelId)); + } + + if (promptFormatter != null) { + var systemMessage = getTemplateFromAnnotationInstance( + classInfo.annotation(LangChain4jDotNames.SYSTEM_MESSAGE)); + var userMessage = getTemplateFromAnnotationInstance(classInfo.annotation(LangChain4jDotNames.USER_MESSAGE)); + var tokenAlreadyExist = promptFormatter.tokens().stream() + .filter(token -> 
systemMessage.contains(token) || userMessage.contains(token)) + .findFirst(); + + if (tokenAlreadyExist.isPresent()) { + log.warnf( + "The prompt in the AIService \"%s\" already contains one or more tags for the model \"%s\", the prompt-formatter option is disabled.", + classInfo.name().toString(), modelId); /* values are passed as warnf parameters so the message is formatted exactly once */ + promptFormatter = null; + } + } } + + chatModelBuilder.produce(new WatsonxChatModelProviderBuildItem(configName, promptFormatter)); + } + } + + @BuildStep + @Record(ExecutionTime.RUNTIME_INIT) + void generateBeans(WatsonxRecorder recorder, LangChain4jWatsonxConfig runtimeConfig, + LangChain4jWatsonxFixedRuntimeConfig fixedRuntimeConfig, + List<WatsonxChatModelProviderBuildItem> selectedChatItem, + List<SelectedEmbeddingModelCandidateBuildItem> selectedEmbedding, + BuildProducer<SyntheticBeanBuildItem> beanProducer) { + + for (var selected : selectedChatItem) { /* one chat model, token count estimator and streaming chat model bean per build item */ + + String configName = selected.getConfigName(); + PromptFormatter promptFormatter = selected.getPromptFormatter(); + + var chatModel = recorder.chatModel(runtimeConfig, fixedRuntimeConfig, configName, promptFormatter); + var chatBuilder = SyntheticBeanBuildItem + .configure(CHAT_MODEL) + .setRuntimeInit() + .defaultBean() + .scope(ApplicationScoped.class) + .supplier(chatModel); + addQualifierIfNecessary(chatBuilder, configName); + beanProducer.produce(chatBuilder.done()); + + var tokenizerBuilder = SyntheticBeanBuildItem + .configure(TOKEN_COUNT_ESTIMATOR) + .setRuntimeInit() + .defaultBean() + .scope(ApplicationScoped.class) + .supplier(chatModel); /* the token count estimator bean reuses the chat model supplier */ + addQualifierIfNecessary(tokenizerBuilder, configName); + beanProducer.produce(tokenizerBuilder.done()); + + var streamingBuilder = SyntheticBeanBuildItem + .configure(STREAMING_CHAT_MODEL) + .setRuntimeInit() + .defaultBean() + .scope(ApplicationScoped.class) + .supplier(recorder.streamingChatModel(runtimeConfig, fixedRuntimeConfig, configName, + promptFormatter)); + addQualifierIfNecessary(streamingBuilder, configName); + beanProducer.produce(streamingBuilder.done()); } for (var selected : selectedEmbedding) { @@ -100,7 +189,7 @@ void 
generateBeans(WatsonxRecorder recorder, LangChain4jWatsonxConfig config, .defaultBean() .unremovable() .scope(ApplicationScoped.class) - .supplier(recorder.embeddingModel(config, configName)); + .supplier(recorder.embeddingModel(runtimeConfig, configName)); addQualifierIfNecessary(builder, configName); beanProducer.produce(builder.done()); } diff --git a/model-providers/watsonx/deployment/src/main/java/io/quarkiverse/langchain4j/watsonx/deployment/items/WatsonxChatModelProviderBuildItem.java b/model-providers/watsonx/deployment/src/main/java/io/quarkiverse/langchain4j/watsonx/deployment/items/WatsonxChatModelProviderBuildItem.java new file mode 100644 index 000000000..a69131a40 --- /dev/null +++ b/model-providers/watsonx/deployment/src/main/java/io/quarkiverse/langchain4j/watsonx/deployment/items/WatsonxChatModelProviderBuildItem.java @@ -0,0 +1,23 @@ +package io.quarkiverse.langchain4j.watsonx.deployment.items; + +import io.quarkiverse.langchain4j.watsonx.prompt.PromptFormatter; +import io.quarkus.builder.item.MultiBuildItem; + +public final class WatsonxChatModelProviderBuildItem extends MultiBuildItem { + + private final String configName; + private final PromptFormatter promptFormatter; + + public WatsonxChatModelProviderBuildItem(String configName, PromptFormatter promptTemplate) { + this.configName = configName; + this.promptFormatter = promptTemplate; + } + + public String getConfigName() { + return configName; + } + + public PromptFormatter getPromptFormatter() { + return promptFormatter; + } +} diff --git a/model-providers/watsonx/deployment/src/test/java/com/ibm/langchain4j/watsonx/deployment/AiChatServiceTest.java b/model-providers/watsonx/deployment/src/test/java/com/ibm/langchain4j/watsonx/deployment/AiChatServiceTest.java index 2847e2ad5..d6cc94cd0 100644 --- a/model-providers/watsonx/deployment/src/test/java/com/ibm/langchain4j/watsonx/deployment/AiChatServiceTest.java +++ 
b/model-providers/watsonx/deployment/src/test/java/com/ibm/langchain4j/watsonx/deployment/AiChatServiceTest.java @@ -1,6 +1,5 @@ package com.ibm.langchain4j.watsonx.deployment; -import static com.github.tomakehurst.wiremock.core.WireMockConfiguration.options; import static org.junit.jupiter.api.Assertions.assertEquals; import java.util.Date; @@ -10,39 +9,19 @@ import org.jboss.shrinkwrap.api.ShrinkWrap; import org.jboss.shrinkwrap.api.spec.JavaArchive; -import org.junit.jupiter.api.AfterAll; -import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.RegisterExtension; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.github.tomakehurst.wiremock.WireMockServer; - -import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.service.SystemMessage; import dev.langchain4j.service.UserMessage; import io.quarkiverse.langchain4j.RegisterAiService; import io.quarkiverse.langchain4j.watsonx.bean.Parameters; import io.quarkiverse.langchain4j.watsonx.bean.TextGenerationRequest; -import io.quarkiverse.langchain4j.watsonx.client.WatsonxRestApi; import io.quarkiverse.langchain4j.watsonx.runtime.config.ChatModelConfig; import io.quarkiverse.langchain4j.watsonx.runtime.config.LangChain4jWatsonxConfig; import io.quarkus.test.QuarkusUnitTest; -public class AiChatServiceTest { - - static WireMockServer watsonxServer; - static WireMockServer iamServer; - static ObjectMapper mapper; - - @Inject - LangChain4jWatsonxConfig langchain4jWatsonConfig; - - @Inject - ChatLanguageModel chatModel; - - static WireMockUtil mockServers; +public class AiChatServiceTest extends WireMockAbstract { @RegisterExtension static QuarkusUnitTest unitTest = new QuarkusUnitTest() @@ -52,36 +31,11 @@ public class AiChatServiceTest { .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.project-id", WireMockUtil.PROJECT_ID) .setArchiveProducer(() -> 
ShrinkWrap.create(JavaArchive.class).addClass(WireMockUtil.class)); - @BeforeAll - static void beforeAll() { - mapper = WatsonxRestApi.objectMapper(new ObjectMapper()); - - watsonxServer = new WireMockServer(options().port(WireMockUtil.PORT_WATSONX_SERVER)); - watsonxServer.start(); - - iamServer = new WireMockServer(options().port(WireMockUtil.PORT_IAM_SERVER)); - iamServer.start(); - - mockServers = new WireMockUtil(watsonxServer, iamServer); - } - - @AfterAll - static void afterAll() { - watsonxServer.stop(); - iamServer.stop(); - } - - @BeforeEach - void beforeEach() { - watsonxServer.resetAll(); - iamServer.resetAll(); - } - @RegisterAiService @Singleton interface NewAIService { - @SystemMessage("This is a systemMessage\n") + @SystemMessage("This is a systemMessage") @UserMessage("This is a userMessage {text}") String chat(String text); } @@ -94,7 +48,7 @@ void chat() throws Exception { LangChain4jWatsonxConfig.WatsonConfig watsonConfig = langchain4jWatsonConfig.defaultConfig(); ChatModelConfig chatModelConfig = watsonConfig.chatModel(); - String modelId = chatModelConfig.modelId(); + String modelId = langchain4jWatsonFixedRuntimeConfig.defaultConfig().chatModel().modelId(); String projectId = watsonConfig.projectId(); String input = new StringBuilder() .append("This is a systemMessage") diff --git a/model-providers/watsonx/deployment/src/test/java/com/ibm/langchain4j/watsonx/deployment/AiEmbeddingTest.java b/model-providers/watsonx/deployment/src/test/java/com/ibm/langchain4j/watsonx/deployment/AiEmbeddingTest.java index 7ea245d64..80449b537 100644 --- a/model-providers/watsonx/deployment/src/test/java/com/ibm/langchain4j/watsonx/deployment/AiEmbeddingTest.java +++ b/model-providers/watsonx/deployment/src/test/java/com/ibm/langchain4j/watsonx/deployment/AiEmbeddingTest.java @@ -1,6 +1,5 @@ package com.ibm.langchain4j.watsonx.deployment; -import static com.github.tomakehurst.wiremock.core.WireMockConfiguration.options; import static 
org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; @@ -11,38 +10,18 @@ import org.jboss.shrinkwrap.api.ShrinkWrap; import org.jboss.shrinkwrap.api.spec.JavaArchive; -import org.junit.jupiter.api.AfterAll; -import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.RegisterExtension; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.github.tomakehurst.wiremock.WireMockServer; - import dev.langchain4j.data.embedding.Embedding; import dev.langchain4j.data.segment.TextSegment; import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.embedding.EmbeddingModel; import dev.langchain4j.model.output.Response; import io.quarkiverse.langchain4j.watsonx.bean.EmbeddingRequest; -import io.quarkiverse.langchain4j.watsonx.client.WatsonxRestApi; -import io.quarkiverse.langchain4j.watsonx.runtime.config.LangChain4jWatsonxConfig; import io.quarkus.test.QuarkusUnitTest; -public class AiEmbeddingTest { - - static WireMockServer watsonxServer; - static WireMockServer iamServer; - static ObjectMapper mapper; - - @Inject - LangChain4jWatsonxConfig langchain4jWatsonConfig; - - @Inject - ChatLanguageModel model; - - static WireMockUtil mockServers; +public class AiEmbeddingTest extends WireMockAbstract { @RegisterExtension static QuarkusUnitTest unitTest = new QuarkusUnitTest() @@ -52,30 +31,8 @@ public class AiEmbeddingTest { .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.project-id", WireMockUtil.PROJECT_ID) .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class).addClass(WireMockUtil.class)); - @BeforeAll - static void beforeAll() { - mapper = WatsonxRestApi.objectMapper(new ObjectMapper()); - - watsonxServer = new WireMockServer(options().port(WireMockUtil.PORT_WATSONX_SERVER)); - watsonxServer.start(); - - iamServer = new 
WireMockServer(options().port(WireMockUtil.PORT_IAM_SERVER)); - iamServer.start(); - - mockServers = new WireMockUtil(watsonxServer, iamServer); - } - - @AfterAll - static void afterAll() { - watsonxServer.stop(); - iamServer.stop(); - } - - @BeforeEach - void beforeEach() { - watsonxServer.resetAll(); - iamServer.resetAll(); - } + @Inject + ChatLanguageModel model; @Inject EmbeddingModel embeddingModel; diff --git a/model-providers/watsonx/deployment/src/test/java/com/ibm/langchain4j/watsonx/deployment/AllPropertiesTest.java b/model-providers/watsonx/deployment/src/test/java/com/ibm/langchain4j/watsonx/deployment/AllPropertiesTest.java index 3f428a791..bd83120f9 100644 --- a/model-providers/watsonx/deployment/src/test/java/com/ibm/langchain4j/watsonx/deployment/AllPropertiesTest.java +++ b/model-providers/watsonx/deployment/src/test/java/com/ibm/langchain4j/watsonx/deployment/AllPropertiesTest.java @@ -1,10 +1,10 @@ package com.ibm.langchain4j.watsonx.deployment; -import static com.github.tomakehurst.wiremock.core.WireMockConfiguration.options; import static org.assertj.core.api.Assertions.assertThat; import static org.awaitility.Awaitility.await; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; import java.time.Duration; import java.util.Date; @@ -16,15 +16,9 @@ import org.jboss.shrinkwrap.api.ShrinkWrap; import org.jboss.shrinkwrap.api.spec.JavaArchive; -import org.junit.jupiter.api.AfterAll; -import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.RegisterExtension; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.github.tomakehurst.wiremock.WireMockServer; - import dev.langchain4j.data.embedding.Embedding; import dev.langchain4j.data.message.AiMessage; import dev.langchain4j.model.chat.ChatLanguageModel; @@ -32,37 
+26,18 @@ import dev.langchain4j.model.chat.TokenCountEstimator; import dev.langchain4j.model.embedding.EmbeddingModel; import dev.langchain4j.model.output.Response; +import io.quarkiverse.langchain4j.watsonx.WatsonxChatModel; +import io.quarkiverse.langchain4j.watsonx.WatsonxStreamingChatModel; import io.quarkiverse.langchain4j.watsonx.bean.EmbeddingRequest; import io.quarkiverse.langchain4j.watsonx.bean.Parameters; import io.quarkiverse.langchain4j.watsonx.bean.Parameters.LengthPenalty; import io.quarkiverse.langchain4j.watsonx.bean.TextGenerationRequest; import io.quarkiverse.langchain4j.watsonx.bean.TokenizationRequest; -import io.quarkiverse.langchain4j.watsonx.client.WatsonxRestApi; -import io.quarkiverse.langchain4j.watsonx.runtime.config.LangChain4jWatsonxConfig; +import io.quarkiverse.langchain4j.watsonx.prompt.impl.NoopPromptFormatter; +import io.quarkus.arc.ClientProxy; import io.quarkus.test.QuarkusUnitTest; -public class AllPropertiesTest { - - static WireMockServer watsonxServer; - static WireMockServer iamServer; - static ObjectMapper mapper; - - @Inject - LangChain4jWatsonxConfig langchain4jWatsonConfig; - - @Inject - ChatLanguageModel chatModel; - - @Inject - StreamingChatLanguageModel streamingChatModel; - - @Inject - EmbeddingModel embeddingModel; - - @Inject - TokenCountEstimator tokenCountEstimator; - - static WireMockUtil mockServers; +public class AllPropertiesTest extends WireMockAbstract { @RegisterExtension static QuarkusUnitTest unitTest = new QuarkusUnitTest() @@ -76,7 +51,9 @@ public class AllPropertiesTest { .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.iam.base-url", WireMockUtil.URL_IAM_SERVER) .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.iam.timeout", "60s") .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.iam.grant-type", "grantME") - .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.chat-model.model-id", "my_super_model") + .overrideConfigKey("quarkus.langchain4j.watsonx.chat-model.model-id", 
"my_super_model") + .overrideConfigKey("quarkus.langchain4j.watsonx.chat-model.prompt-formatter", "true") + .overrideConfigKey("quarkus.langchain4j.watsonx.chat-model.prompt-joiner", "@") .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.chat-model.decoding-method", "greedy") .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.chat-model.length-penalty.decay-factor", "1.1") .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.chat-model.length-penalty.start-index", "0") @@ -90,38 +67,28 @@ public class AllPropertiesTest { .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.chat-model.repetition-penalty", "2.0") .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.chat-model.truncate-input-tokens", "0") .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.chat-model.include-stop-sequence", "false") - .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.chat-model.prompt-joiner", "@") .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.embedding-model.model-id", "my_super_embedding_model") .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class).addClass(WireMockUtil.class)); - @BeforeAll - static void beforeAll() { - mapper = WatsonxRestApi.objectMapper(new ObjectMapper()); - - watsonxServer = new WireMockServer(options().port(WireMockUtil.PORT_WATSONX_SERVER)); - watsonxServer.start(); - - iamServer = new WireMockServer(options().port(WireMockUtil.PORT_IAM_SERVER)); - iamServer.start(); - - mockServers = new WireMockUtil(watsonxServer, iamServer); - } - - @BeforeEach - void beforeEach() { - watsonxServer.resetAll(); - iamServer.resetAll(); + @Override + void handlerBeforeEach() { mockServers.mockIAMBuilder(200) .grantType(langchain4jWatsonConfig.defaultConfig().iam().grantType()) .response(WireMockUtil.BEARER_TOKEN, new Date()) .build(); } - @AfterAll - static void afterAll() { - watsonxServer.stop(); - iamServer.stop(); - } + @Inject + ChatLanguageModel chatModel; + + @Inject + StreamingChatLanguageModel streamingChatModel; + + @Inject 
+ EmbeddingModel embeddingModel; + + @Inject + TokenCountEstimator tokenCountEstimator; static Parameters parameters = Parameters.builder() .minNewTokens(10) @@ -138,41 +105,52 @@ static void afterAll() { .includeStopSequence(false) .build(); + @Test + void prompt_formatter() { + var unwrapChatModel = (WatsonxChatModel) ClientProxy.unwrap(chatModel); + assertTrue(unwrapChatModel.getPromptFormatter() instanceof NoopPromptFormatter); + + var unwrapStreamingChatModel = (WatsonxStreamingChatModel) ClientProxy.unwrap(streamingChatModel); + assertTrue(unwrapStreamingChatModel.getPromptFormatter() instanceof NoopPromptFormatter); + } + @Test void check_config() throws Exception { - var config = langchain4jWatsonConfig.defaultConfig(); - assertEquals(WireMockUtil.URL_WATSONX_SERVER, config.baseUrl().toString()); - assertEquals(WireMockUtil.URL_IAM_SERVER, config.iam().baseUrl().toString()); - assertEquals(WireMockUtil.API_KEY, config.apiKey()); - assertEquals(WireMockUtil.PROJECT_ID, config.projectId()); - assertEquals(Duration.ofSeconds(60), config.timeout().get()); - assertEquals(Duration.ofSeconds(60), config.iam().timeout().get()); - assertEquals("grantME", config.iam().grantType()); - assertEquals(true, config.logRequests().orElse(false)); - assertEquals(true, config.logResponses().orElse(false)); - assertEquals("aaaa-mm-dd", config.version()); - assertEquals("my_super_model", config.chatModel().modelId()); - assertEquals("greedy", config.chatModel().decodingMethod()); - assertEquals(1.1, config.chatModel().lengthPenalty().get().decayFactor().get()); - assertEquals(0, config.chatModel().lengthPenalty().get().startIndex().get()); - assertEquals(200, config.chatModel().maxNewTokens()); - assertEquals(10, config.chatModel().minNewTokens()); - assertEquals(2, config.chatModel().randomSeed().get()); - assertEquals(List.of("\n", "\n\n"), config.chatModel().stopSequences().get()); - assertEquals(1.5, config.chatModel().temperature()); - assertEquals(90, 
config.chatModel().topK().get()); - assertEquals(0.5, config.chatModel().topP().get()); - assertEquals(2.0, config.chatModel().repetitionPenalty().get()); - assertEquals(0, config.chatModel().truncateInputTokens().get()); - assertEquals(false, config.chatModel().includeStopSequence().get()); - assertEquals("@", config.chatModel().promptJoiner().get()); - assertEquals("my_super_embedding_model", config.embeddingModel().modelId()); + var runtimeConfig = langchain4jWatsonConfig.defaultConfig(); + var fixedRuntimeConfig = langchain4jWatsonFixedRuntimeConfig.defaultConfig(); + assertEquals(WireMockUtil.URL_WATSONX_SERVER, runtimeConfig.baseUrl().toString()); + assertEquals(WireMockUtil.URL_IAM_SERVER, runtimeConfig.iam().baseUrl().toString()); + assertEquals(WireMockUtil.API_KEY, runtimeConfig.apiKey()); + assertEquals(WireMockUtil.PROJECT_ID, runtimeConfig.projectId()); + assertEquals(Duration.ofSeconds(60), runtimeConfig.timeout().get()); + assertEquals(Duration.ofSeconds(60), runtimeConfig.iam().timeout().get()); + assertEquals("grantME", runtimeConfig.iam().grantType()); + assertEquals(true, runtimeConfig.logRequests().orElse(false)); + assertEquals(true, runtimeConfig.logResponses().orElse(false)); + assertEquals("aaaa-mm-dd", runtimeConfig.version()); + assertEquals("my_super_model", fixedRuntimeConfig.chatModel().modelId()); + assertEquals("greedy", runtimeConfig.chatModel().decodingMethod()); + assertEquals(1.1, runtimeConfig.chatModel().lengthPenalty().decayFactor().get()); + assertEquals(0, runtimeConfig.chatModel().lengthPenalty().startIndex().get()); + assertEquals(200, runtimeConfig.chatModel().maxNewTokens()); + assertEquals(10, runtimeConfig.chatModel().minNewTokens()); + assertEquals(2, runtimeConfig.chatModel().randomSeed().get()); + assertEquals(List.of("\n", "\n\n"), runtimeConfig.chatModel().stopSequences().get()); + assertEquals(1.5, runtimeConfig.chatModel().temperature()); + assertEquals(90, runtimeConfig.chatModel().topK().get()); + 
assertEquals(0.5, runtimeConfig.chatModel().topP().get()); + assertEquals(2.0, runtimeConfig.chatModel().repetitionPenalty().get()); + assertEquals(0, runtimeConfig.chatModel().truncateInputTokens().get()); + assertEquals(false, runtimeConfig.chatModel().includeStopSequence().get()); + assertEquals("@", runtimeConfig.chatModel().promptJoiner()); + assertEquals(true, fixedRuntimeConfig.chatModel().promptFormatter()); + assertEquals("my_super_embedding_model", runtimeConfig.embeddingModel().modelId()); } @Test void check_chat_model_config() throws Exception { var config = langchain4jWatsonConfig.defaultConfig(); - String modelId = config.chatModel().modelId(); + String modelId = langchain4jWatsonFixedRuntimeConfig.defaultConfig().chatModel().modelId(); String projectId = config.projectId(); TextGenerationRequest body = new TextGenerationRequest(modelId, projectId, "SystemMessage@UserMessage", parameters); @@ -207,7 +185,7 @@ void check_embedding_model() throws Exception { @Test void check_token_count_estimator() throws Exception { var config = langchain4jWatsonConfig.defaultConfig(); - String modelId = config.chatModel().modelId(); + String modelId = langchain4jWatsonFixedRuntimeConfig.defaultConfig().chatModel().modelId(); String projectId = config.projectId(); var body = new TokenizationRequest(modelId, "test", projectId); @@ -223,7 +201,7 @@ void check_token_count_estimator() throws Exception { @Test void check_chat_streaming_model_config() throws Exception { var config = langchain4jWatsonConfig.defaultConfig(); - String modelId = config.chatModel().modelId(); + String modelId = langchain4jWatsonFixedRuntimeConfig.defaultConfig().chatModel().modelId(); String projectId = config.projectId(); TextGenerationRequest body = new TextGenerationRequest(modelId, projectId, "SystemMessage@UserMessage", parameters); diff --git a/model-providers/watsonx/deployment/src/test/java/com/ibm/langchain4j/watsonx/deployment/CacheTokenTest.java 
b/model-providers/watsonx/deployment/src/test/java/com/ibm/langchain4j/watsonx/deployment/CacheTokenTest.java index f8c291883..198214867 100644 --- a/model-providers/watsonx/deployment/src/test/java/com/ibm/langchain4j/watsonx/deployment/CacheTokenTest.java +++ b/model-providers/watsonx/deployment/src/test/java/com/ibm/langchain4j/watsonx/deployment/CacheTokenTest.java @@ -1,6 +1,5 @@ package com.ibm.langchain4j.watsonx.deployment; -import static com.github.tomakehurst.wiremock.core.WireMockConfiguration.options; import static com.ibm.langchain4j.watsonx.deployment.WireMockUtil.streamingResponseHandler; import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; @@ -15,14 +14,9 @@ import org.jboss.shrinkwrap.api.ShrinkWrap; import org.jboss.shrinkwrap.api.spec.JavaArchive; -import org.junit.jupiter.api.AfterAll; -import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.RegisterExtension; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.github.tomakehurst.wiremock.WireMockServer; import com.github.tomakehurst.wiremock.stubbing.Scenario; import dev.langchain4j.data.message.AiMessage; @@ -30,16 +24,11 @@ import dev.langchain4j.model.chat.StreamingChatLanguageModel; import dev.langchain4j.model.chat.TokenCountEstimator; import dev.langchain4j.model.embedding.EmbeddingModel; -import io.quarkiverse.langchain4j.watsonx.client.WatsonxRestApi; -import io.quarkiverse.langchain4j.watsonx.runtime.config.LangChain4jWatsonxConfig; import io.quarkus.test.QuarkusUnitTest; -public class CacheTokenTest { +public class CacheTokenTest extends WireMockAbstract { static int cacheTimeout = 2000; - static WireMockServer watsonxServer; - static WireMockServer iamServer; - static ObjectMapper mapper; static String RESPONSE_401 = """ { "errors": [ @@ -54,8 +43,13 @@ public class CacheTokenTest { } """; - @Inject - LangChain4jWatsonxConfig config; + @RegisterExtension + 
static QuarkusUnitTest unitTest = new QuarkusUnitTest() + .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.base-url", WireMockUtil.URL_WATSONX_SERVER) + .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.iam.base-url", WireMockUtil.URL_IAM_SERVER) + .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.api-key", WireMockUtil.API_KEY) + .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.project-id", WireMockUtil.PROJECT_ID) + .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class).addClass(WireMockUtil.class)); @Inject ChatLanguageModel chatModel; @@ -69,41 +63,6 @@ public class CacheTokenTest { @Inject TokenCountEstimator tokenCountEstimator; - static WireMockUtil mockServers; - - @RegisterExtension - static QuarkusUnitTest unitTest = new QuarkusUnitTest() - .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.base-url", WireMockUtil.URL_WATSONX_SERVER) - .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.iam.base-url", WireMockUtil.URL_IAM_SERVER) - .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.api-key", WireMockUtil.API_KEY) - .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.project-id", WireMockUtil.PROJECT_ID) - .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class).addClass(WireMockUtil.class)); - - @BeforeAll - static void beforeAll() { - mapper = WatsonxRestApi.objectMapper(new ObjectMapper()); - - watsonxServer = new WireMockServer(options().port(WireMockUtil.PORT_WATSONX_SERVER)); - watsonxServer.start(); - - iamServer = new WireMockServer(options().port(WireMockUtil.PORT_IAM_SERVER)); - iamServer.start(); - - mockServers = new WireMockUtil(watsonxServer, iamServer); - } - - @AfterAll - static void afterAll() { - watsonxServer.stop(); - iamServer.stop(); - } - - @BeforeEach - void beforeEach() { - watsonxServer.resetAll(); - iamServer.resetAll(); - } - @Test void try_token_cache() throws InterruptedException { diff --git 
a/model-providers/watsonx/deployment/src/test/java/com/ibm/langchain4j/watsonx/deployment/DefaultPropertiesTest.java b/model-providers/watsonx/deployment/src/test/java/com/ibm/langchain4j/watsonx/deployment/DefaultPropertiesTest.java index b0725fc83..a98d851e5 100644 --- a/model-providers/watsonx/deployment/src/test/java/com/ibm/langchain4j/watsonx/deployment/DefaultPropertiesTest.java +++ b/model-providers/watsonx/deployment/src/test/java/com/ibm/langchain4j/watsonx/deployment/DefaultPropertiesTest.java @@ -1,6 +1,5 @@ package com.ibm.langchain4j.watsonx.deployment; -import static com.github.tomakehurst.wiremock.core.WireMockConfiguration.options; import static org.assertj.core.api.Assertions.assertThat; import static org.awaitility.Awaitility.await; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -18,15 +17,9 @@ import org.jboss.shrinkwrap.api.ShrinkWrap; import org.jboss.shrinkwrap.api.spec.JavaArchive; -import org.junit.jupiter.api.AfterAll; -import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.RegisterExtension; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.github.tomakehurst.wiremock.WireMockServer; - import dev.langchain4j.data.embedding.Embedding; import dev.langchain4j.data.message.AiMessage; import dev.langchain4j.model.chat.ChatLanguageModel; @@ -34,36 +27,16 @@ import dev.langchain4j.model.chat.TokenCountEstimator; import dev.langchain4j.model.embedding.EmbeddingModel; import dev.langchain4j.model.output.Response; +import io.quarkiverse.langchain4j.watsonx.WatsonxChatModel; import io.quarkiverse.langchain4j.watsonx.bean.EmbeddingRequest; import io.quarkiverse.langchain4j.watsonx.bean.Parameters; import io.quarkiverse.langchain4j.watsonx.bean.TextGenerationRequest; import io.quarkiverse.langchain4j.watsonx.bean.TokenizationRequest; -import io.quarkiverse.langchain4j.watsonx.client.WatsonxRestApi; -import 
io.quarkiverse.langchain4j.watsonx.runtime.config.LangChain4jWatsonxConfig; +import io.quarkiverse.langchain4j.watsonx.prompt.impl.NoopPromptFormatter; +import io.quarkus.arc.ClientProxy; import io.quarkus.test.QuarkusUnitTest; -public class DefaultPropertiesTest { - - static WireMockServer watsonxServer; - static WireMockServer iamServer; - static ObjectMapper mapper; - - @Inject - LangChain4jWatsonxConfig langchain4jWatsonConfig; - - @Inject - ChatLanguageModel model; - - @Inject - StreamingChatLanguageModel streamingChatModel; - - @Inject - EmbeddingModel embeddingModel; - - @Inject - TokenCountEstimator tokenCountEstimator; - - static WireMockUtil mockServers; +public class DefaultPropertiesTest extends WireMockAbstract { @RegisterExtension static QuarkusUnitTest unitTest = new QuarkusUnitTest() @@ -73,29 +46,8 @@ public class DefaultPropertiesTest { .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.project-id", WireMockUtil.PROJECT_ID) .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class).addClass(WireMockUtil.class)); - @BeforeAll - static void beforeAll() { - mapper = WatsonxRestApi.objectMapper(new ObjectMapper()); - - watsonxServer = new WireMockServer(options().port(WireMockUtil.PORT_WATSONX_SERVER)); - watsonxServer.start(); - - iamServer = new WireMockServer(options().port(WireMockUtil.PORT_IAM_SERVER)); - iamServer.start(); - - mockServers = new WireMockUtil(watsonxServer, iamServer); - } - - @AfterAll - static void afterAll() { - watsonxServer.stop(); - iamServer.stop(); - } - - @BeforeEach - void beforeEach() { - watsonxServer.resetAll(); - iamServer.resetAll(); + @Override + void handlerBeforeEach() { mockServers.mockIAMBuilder(200) .response("my_super_token", new Date()) .build(); @@ -108,46 +60,67 @@ void beforeEach() { .temperature(1.0) .build(); + @Inject + ChatLanguageModel chatModel; + + @Inject + StreamingChatLanguageModel streamingChatModel; + + @Inject + EmbeddingModel embeddingModel; + + @Inject + TokenCountEstimator 
tokenCountEstimator; + + @Test + void prompt_formatter() { + var unwrapChatModel = (WatsonxChatModel) ClientProxy.unwrap(chatModel); + assertTrue(unwrapChatModel.getPromptFormatter() instanceof NoopPromptFormatter); + } + @Test void check_config() throws Exception { - var config = langchain4jWatsonConfig.defaultConfig(); - assertEquals(Optional.empty(), config.timeout()); - assertEquals(Optional.empty(), config.iam().timeout()); - assertEquals(false, config.logRequests().orElse(false)); - assertEquals(false, config.logResponses().orElse(false)); - assertEquals(WireMockUtil.VERSION, config.version()); - assertEquals(WireMockUtil.DEFAULT_CHAT_MODEL, config.chatModel().modelId()); - assertEquals("greedy", config.chatModel().decodingMethod()); - assertEquals(null, config.chatModel().lengthPenalty().orElse(null)); - assertEquals(200, config.chatModel().maxNewTokens()); - assertEquals(0, config.chatModel().minNewTokens()); - assertEquals(null, config.chatModel().randomSeed().orElse(null)); - assertEquals(null, config.chatModel().stopSequences().orElse(null)); - assertEquals(1.0, config.chatModel().temperature()); - assertEquals("", config.chatModel().promptJoiner().orElse("")); - assertTrue(config.chatModel().topK().isEmpty()); - assertTrue(config.chatModel().topP().isEmpty()); - assertTrue(config.chatModel().repetitionPenalty().isEmpty()); - assertTrue(config.chatModel().truncateInputTokens().isEmpty()); - assertTrue(config.chatModel().includeStopSequence().isEmpty()); - assertEquals("urn:ibm:params:oauth:grant-type:apikey", config.iam().grantType()); - assertEquals(WireMockUtil.DEFAULT_EMBEDDING_MODEL, config.embeddingModel().modelId()); + var runtimeConfig = langchain4jWatsonConfig.defaultConfig(); + var fixedRuntimeConfig = langchain4jWatsonFixedRuntimeConfig.defaultConfig(); + assertEquals(Optional.empty(), runtimeConfig.timeout()); + assertEquals(Optional.empty(), runtimeConfig.iam().timeout()); + assertEquals(false, runtimeConfig.logRequests().orElse(false)); + 
assertEquals(false, runtimeConfig.logResponses().orElse(false)); + assertEquals(WireMockUtil.VERSION, runtimeConfig.version()); + assertEquals(WireMockUtil.DEFAULT_CHAT_MODEL, fixedRuntimeConfig.chatModel().modelId()); + assertEquals("greedy", runtimeConfig.chatModel().decodingMethod()); + assertEquals(null, runtimeConfig.chatModel().lengthPenalty().decayFactor().orElse(null)); + assertEquals(null, runtimeConfig.chatModel().lengthPenalty().startIndex().orElse(null)); + assertEquals(200, runtimeConfig.chatModel().maxNewTokens()); + assertEquals(0, runtimeConfig.chatModel().minNewTokens()); + assertEquals(null, runtimeConfig.chatModel().randomSeed().orElse(null)); + assertEquals(null, runtimeConfig.chatModel().stopSequences().orElse(null)); + assertEquals(1.0, runtimeConfig.chatModel().temperature()); + assertEquals("\n", runtimeConfig.chatModel().promptJoiner()); + assertEquals(false, fixedRuntimeConfig.chatModel().promptFormatter()); + assertTrue(runtimeConfig.chatModel().topK().isEmpty()); + assertTrue(runtimeConfig.chatModel().topP().isEmpty()); + assertTrue(runtimeConfig.chatModel().repetitionPenalty().isEmpty()); + assertTrue(runtimeConfig.chatModel().truncateInputTokens().isEmpty()); + assertTrue(runtimeConfig.chatModel().includeStopSequence().isEmpty()); + assertEquals("urn:ibm:params:oauth:grant-type:apikey", runtimeConfig.iam().grantType()); + assertEquals(WireMockUtil.DEFAULT_EMBEDDING_MODEL, runtimeConfig.embeddingModel().modelId()); } @Test void check_chat_model_config() throws Exception { var config = langchain4jWatsonConfig.defaultConfig(); - String modelId = config.chatModel().modelId(); + String modelId = langchain4jWatsonFixedRuntimeConfig.defaultConfig().chatModel().modelId(); String projectId = config.projectId(); - TextGenerationRequest body = new TextGenerationRequest(modelId, projectId, "SystemMessageUserMessage", parameters); + TextGenerationRequest body = new TextGenerationRequest(modelId, projectId, "SystemMessage\nUserMessage", parameters); 
mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_CHAT_API, 200) .body(mapper.writeValueAsString(body)) .response(WireMockUtil.RESPONSE_WATSONX_CHAT_API) .build(); - assertEquals("AI Response", model.generate(dev.langchain4j.data.message.SystemMessage.from("SystemMessage"), + assertEquals("AI Response", chatModel.generate(dev.langchain4j.data.message.SystemMessage.from("SystemMessage"), dev.langchain4j.data.message.UserMessage.from("UserMessage")).content().text()); } @@ -173,7 +146,7 @@ void check_embedding_model() throws Exception { @Test void check_token_count_estimator() throws Exception { var config = langchain4jWatsonConfig.defaultConfig(); - String modelId = config.chatModel().modelId(); + String modelId = langchain4jWatsonFixedRuntimeConfig.defaultConfig().chatModel().modelId(); String projectId = config.projectId(); var body = new TokenizationRequest(modelId, "test", projectId); @@ -189,10 +162,10 @@ void check_token_count_estimator() throws Exception { @Test void check_chat_streaming_model_config() throws Exception { var config = langchain4jWatsonConfig.defaultConfig(); - String modelId = config.chatModel().modelId(); + String modelId = langchain4jWatsonFixedRuntimeConfig.defaultConfig().chatModel().modelId(); String projectId = config.projectId(); - TextGenerationRequest body = new TextGenerationRequest(modelId, projectId, "SystemMessageUserMessage", parameters); + TextGenerationRequest body = new TextGenerationRequest(modelId, projectId, "SystemMessage\nUserMessage", parameters); mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_CHAT_STREAMING_API, 200) .body(mapper.writeValueAsString(body)) diff --git a/model-providers/watsonx/deployment/src/test/java/com/ibm/langchain4j/watsonx/deployment/HttpErrorTest.java b/model-providers/watsonx/deployment/src/test/java/com/ibm/langchain4j/watsonx/deployment/HttpErrorTest.java index 3b68795e1..8fde40341 100644 --- 
a/model-providers/watsonx/deployment/src/test/java/com/ibm/langchain4j/watsonx/deployment/HttpErrorTest.java +++ b/model-providers/watsonx/deployment/src/test/java/com/ibm/langchain4j/watsonx/deployment/HttpErrorTest.java @@ -1,6 +1,5 @@ package com.ibm.langchain4j.watsonx.deployment; -import static com.github.tomakehurst.wiremock.core.WireMockConfiguration.options; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertThrows; @@ -16,35 +15,15 @@ import org.jboss.resteasy.reactive.ClientWebApplicationException; import org.jboss.shrinkwrap.api.ShrinkWrap; import org.jboss.shrinkwrap.api.spec.JavaArchive; -import org.junit.jupiter.api.AfterAll; -import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.RegisterExtension; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.github.tomakehurst.wiremock.WireMockServer; - import dev.langchain4j.model.chat.ChatLanguageModel; import io.quarkiverse.langchain4j.watsonx.bean.WatsonxError; -import io.quarkiverse.langchain4j.watsonx.client.WatsonxRestApi; import io.quarkiverse.langchain4j.watsonx.exception.WatsonxException; -import io.quarkiverse.langchain4j.watsonx.runtime.config.LangChain4jWatsonxConfig; import io.quarkus.test.QuarkusUnitTest; -public class HttpErrorTest { - - static WireMockServer watsonxServer; - static WireMockServer iamServer; - static ObjectMapper mapper; - - @Inject - LangChain4jWatsonxConfig config; - - @Inject - ChatLanguageModel chatModel; - - static WireMockUtil mockServers; +public class HttpErrorTest extends WireMockAbstract { @RegisterExtension static QuarkusUnitTest unitTest = new QuarkusUnitTest() @@ -54,30 +33,8 @@ public class HttpErrorTest { .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.project-id", WireMockUtil.PROJECT_ID) .setArchiveProducer(() -> 
ShrinkWrap.create(JavaArchive.class).addClass(WireMockUtil.class)); - @BeforeAll - static void beforeAll() { - mapper = WatsonxRestApi.objectMapper(new ObjectMapper()); - - watsonxServer = new WireMockServer(options().port(WireMockUtil.PORT_WATSONX_SERVER)); - watsonxServer.start(); - - iamServer = new WireMockServer(options().port(WireMockUtil.PORT_IAM_SERVER)); - iamServer.start(); - - mockServers = new WireMockUtil(watsonxServer, iamServer); - } - - @AfterAll - static void afterAll() { - watsonxServer.stop(); - iamServer.stop(); - } - - @BeforeEach - void beforeEach() { - watsonxServer.resetAll(); - iamServer.resetAll(); - } + @Inject + ChatLanguageModel chatModel; @Test void error_404_model_not_supported() { diff --git a/model-providers/watsonx/deployment/src/test/java/com/ibm/langchain4j/watsonx/deployment/PromptFormatterExceptionTest.java b/model-providers/watsonx/deployment/src/test/java/com/ibm/langchain4j/watsonx/deployment/PromptFormatterExceptionTest.java new file mode 100644 index 000000000..b9833e2d2 --- /dev/null +++ b/model-providers/watsonx/deployment/src/test/java/com/ibm/langchain4j/watsonx/deployment/PromptFormatterExceptionTest.java @@ -0,0 +1,79 @@ +package com.ibm.langchain4j.watsonx.deployment; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.fail; + +import jakarta.inject.Singleton; + +import org.jboss.shrinkwrap.api.ShrinkWrap; +import org.jboss.shrinkwrap.api.spec.JavaArchive; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import dev.langchain4j.agent.tool.Tool; +import io.quarkiverse.langchain4j.RegisterAiService; +import io.quarkus.test.QuarkusUnitTest; + +public class PromptFormatterExceptionTest { + + @RegisterAiService(tools = Calculator.class) + interface AIService { + + } + + @Singleton + static class Calculator { + + @Tool("calculates the sum between two numbers") + double 
squareRoot(int firstNumber, int secondNumber) { + return firstNumber + secondNumber; + } + } + + @Nested + class ToolsModelNotSupported { + @RegisterExtension + static QuarkusUnitTest unitTest = new QuarkusUnitTest() + .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.base-url", WireMockUtil.URL_WATSONX_SERVER) + .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.iam.base-url", WireMockUtil.URL_IAM_SERVER) + .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.api-key", WireMockUtil.API_KEY) + .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.project-id", WireMockUtil.PROJECT_ID) + .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.embedding-model.model-id", + WireMockUtil.DEFAULT_CHAT_MODEL) + .overrideConfigKey("quarkus.langchain4j.watsonx.chat-model.prompt-formatter", "true") + .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class).addClasses(AIService.class, Calculator.class)) + .assertException(t -> { + assertThat(t).isInstanceOf(RuntimeException.class) + .hasMessage("The tool functionality is not supported for the model \"%s\"" + .formatted(WireMockUtil.DEFAULT_CHAT_MODEL)); + }); + + @Test + void test() { + fail("Should not be called"); + } + } + + @Nested + class ToolsPromptFormatterOff { + @RegisterExtension + static QuarkusUnitTest unitTest = new QuarkusUnitTest() + .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.base-url", WireMockUtil.URL_WATSONX_SERVER) + .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.iam.base-url", WireMockUtil.URL_IAM_SERVER) + .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.api-key", WireMockUtil.API_KEY) + .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.project-id", WireMockUtil.PROJECT_ID) + .overrideConfigKey("quarkus.langchain4j.watsonx.chat-model.prompt-formatter", "false") + .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class).addClasses(AIService.class, Calculator.class)) + .assertException(t -> { + assertThat(t).isInstanceOf(RuntimeException.class) + 
.hasMessage("The tool functionality is not supported for the model \"%s\"" + .formatted(WireMockUtil.DEFAULT_CHAT_MODEL)); + }); + + @Test + void test() { + fail("Should not be called"); + } + } +} diff --git a/model-providers/watsonx/deployment/src/test/java/com/ibm/langchain4j/watsonx/deployment/PromptFormatterForceDefaultTest.java b/model-providers/watsonx/deployment/src/test/java/com/ibm/langchain4j/watsonx/deployment/PromptFormatterForceDefaultTest.java new file mode 100644 index 000000000..606920dce --- /dev/null +++ b/model-providers/watsonx/deployment/src/test/java/com/ibm/langchain4j/watsonx/deployment/PromptFormatterForceDefaultTest.java @@ -0,0 +1,94 @@ +package com.ibm.langchain4j.watsonx.deployment; + +import static org.junit.jupiter.api.Assertions.assertTrue; + +import jakarta.inject.Inject; +import jakarta.inject.Singleton; + +import org.jboss.shrinkwrap.api.ShrinkWrap; +import org.jboss.shrinkwrap.api.spec.JavaArchive; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import dev.langchain4j.model.chat.ChatLanguageModel; +import dev.langchain4j.model.chat.StreamingChatLanguageModel; +import dev.langchain4j.service.SystemMessage; +import dev.langchain4j.service.UserMessage; +import io.quarkiverse.langchain4j.ModelName; +import io.quarkiverse.langchain4j.RegisterAiService; +import io.quarkiverse.langchain4j.watsonx.WatsonxChatModel; +import io.quarkiverse.langchain4j.watsonx.WatsonxStreamingChatModel; +import io.quarkiverse.langchain4j.watsonx.prompt.impl.NoopPromptFormatter; +import io.quarkus.arc.ClientProxy; +import io.quarkus.test.QuarkusUnitTest; + +public class PromptFormatterForceDefaultTest { + + @RegisterExtension + static QuarkusUnitTest unitTest = new QuarkusUnitTest() + + .overrideRuntimeConfigKey("quarkus.langchain4j.model1.chat-model.provider", "watsonx") + .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.model1.base-url", WireMockUtil.URL_WATSONX_SERVER) + 
.overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.model1.iam.base-url", WireMockUtil.URL_IAM_SERVER) + .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.model1.api-key", WireMockUtil.API_KEY) + .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.model1.project-id", WireMockUtil.PROJECT_ID) + .overrideConfigKey("quarkus.langchain4j.watsonx.model1.chat-model.prompt-formatter", "true") + .overrideRuntimeConfigKey("quarkus.langchain4j.model2.chat-model.provider", "watsonx") + .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.model2.base-url", WireMockUtil.URL_WATSONX_SERVER) + .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.model2.iam.base-url", WireMockUtil.URL_IAM_SERVER) + .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.model2.api-key", WireMockUtil.API_KEY) + .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.model2.project-id", WireMockUtil.PROJECT_ID) + .overrideConfigKey("quarkus.langchain4j.watsonx.model2.chat-model.prompt-formatter", "true") + .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class) + .addClass(WireMockUtil.class)); + + @RegisterAiService(modelName = "model1") + @Singleton + interface AIServiceWithTokenInSystemMessage { + @SystemMessage("<|system|>This is a systemMessage") + @UserMessage("{text}") + String chat(String text); + } + + @RegisterAiService(modelName = "model2") + @Singleton + interface AIServiceWithTokenInUserMessage { + @SystemMessage("This is a systemMessage") + @UserMessage("<|system|>{text}") + String chat(String text); + } + + @Inject + @ModelName("model1") + ChatLanguageModel model1ChatModel; + + @Inject + @ModelName("model1") + StreamingChatLanguageModel model1StreamingChatModel; + + @Inject + @ModelName("model2") + ChatLanguageModel model2ChatModel; + + @Inject + @ModelName("model2") + StreamingChatLanguageModel model2StreamingChatModel; + + @Test + void prompt_formatter_model_1() { + var unwrapChatModel = (WatsonxChatModel) ClientProxy.unwrap(model1ChatModel); + 
assertTrue(unwrapChatModel.getPromptFormatter() instanceof NoopPromptFormatter); + + var unwrapStreamingChatModel = (WatsonxStreamingChatModel) ClientProxy.unwrap(model1StreamingChatModel); + assertTrue(unwrapStreamingChatModel.getPromptFormatter() instanceof NoopPromptFormatter); + } + + @Test + void prompt_formatter_model_2() { + var unwrapChatModel = (WatsonxChatModel) ClientProxy.unwrap(model2ChatModel); + assertTrue(unwrapChatModel.getPromptFormatter() instanceof NoopPromptFormatter); + + var unwrapStreamingChatModel = (WatsonxStreamingChatModel) ClientProxy.unwrap(model2StreamingChatModel); + assertTrue(unwrapStreamingChatModel.getPromptFormatter() instanceof NoopPromptFormatter); + } +} diff --git a/model-providers/watsonx/deployment/src/test/java/com/ibm/langchain4j/watsonx/deployment/PromptFormatterModelTest.java b/model-providers/watsonx/deployment/src/test/java/com/ibm/langchain4j/watsonx/deployment/PromptFormatterModelTest.java new file mode 100644 index 000000000..256dc0aae --- /dev/null +++ b/model-providers/watsonx/deployment/src/test/java/com/ibm/langchain4j/watsonx/deployment/PromptFormatterModelTest.java @@ -0,0 +1,424 @@ +package com.ibm.langchain4j.watsonx.deployment; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.util.List; + +import org.jboss.shrinkwrap.api.ShrinkWrap; +import org.jboss.shrinkwrap.api.spec.JavaArchive; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.data.message.SystemMessage; +import dev.langchain4j.data.message.UserMessage; +import io.quarkiverse.langchain4j.watsonx.prompt.impl.GraniteCodePromptFormatter; +import io.quarkiverse.langchain4j.watsonx.prompt.impl.GranitePromptFormatter; +import io.quarkiverse.langchain4j.watsonx.prompt.impl.LlamaPromptFormatter; +import 
io.quarkiverse.langchain4j.watsonx.prompt.impl.MistralLargePromptFormatter; +import io.quarkiverse.langchain4j.watsonx.prompt.impl.MistralPromptFormatter; +import io.quarkiverse.langchain4j.watsonx.prompt.impl.NoopPromptFormatter; +import io.quarkus.test.QuarkusUnitTest; + +public class PromptFormatterModelTest { + + @RegisterExtension + static QuarkusUnitTest unitTest = new QuarkusUnitTest() + .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class)); + + @Test + void llama_prompt_formatter() { + + var promptFormatter = new LlamaPromptFormatter(); + var prompt1 = List.of( + SystemMessage.from("You are a poet"), + UserMessage.from("Write a poem about dog of ten lines")); + + assertEquals(""" + <|begin_of_text|><|start_header_id|>system<|end_header_id|> + + You are a poet<|eot_id|><|start_header_id|>user<|end_header_id|> + Write a poem about dog of ten lines<|eot_id|><|start_header_id|>assistant<|end_header_id|> + """, promptFormatter.format(prompt1)); + + var prompt2 = List.of( + SystemMessage.from("You are a poet"), + UserMessage.from("Write a poem about dog of ten lines"), + AiMessage.from("I'm an assistant")); + + assertEquals(""" + <|begin_of_text|><|start_header_id|>system<|end_header_id|> + + You are a poet<|eot_id|><|start_header_id|>user<|end_header_id|> + Write a poem about dog of ten lines<|eot_id|><|start_header_id|>assistant<|end_header_id|> + I'm an assistant<|eot_id|>""", promptFormatter.format(prompt2)); + + var prompt3 = List. of(UserMessage.from("Write a poem about dog of ten lines")); + + assertEquals(""" + <|begin_of_text|><|start_header_id|>user<|end_header_id|> + Write a poem about dog of ten lines<|eot_id|><|start_header_id|>assistant<|end_header_id|> + """, promptFormatter.format(prompt3)); + + var prompt4 = List. 
of( + UserMessage.from("Write a poem about dog of ten lines"), + AiMessage.from("I'm an assistant")); + + assertEquals(""" + <|begin_of_text|><|start_header_id|>user<|end_header_id|> + Write a poem about dog of ten lines<|eot_id|><|start_header_id|>assistant<|end_header_id|> + I'm an assistant<|eot_id|>""", promptFormatter.format(prompt4)); + + var prompt5 = List. of( + SystemMessage.from("You are a poet"), + AiMessage.from("I'm an assistant")); + + assertEquals(""" + <|begin_of_text|><|start_header_id|>system<|end_header_id|> + + You are a poet<|eot_id|><|start_header_id|>assistant<|end_header_id|> + I'm an assistant<|eot_id|>""", promptFormatter.format(prompt5)); + + var prompt6 = List.of( + SystemMessage.from("You are a poet"), + UserMessage.from("Write a poem about dog of ten lines"), + AiMessage.from("dog dog"), + UserMessage.from("Write a poem about cat of ten lines"), + AiMessage.from("cat cat")); + + assertEquals(""" + <|begin_of_text|><|start_header_id|>system<|end_header_id|> + + You are a poet<|eot_id|><|start_header_id|>user<|end_header_id|> + Write a poem about dog of ten lines<|eot_id|><|start_header_id|>assistant<|end_header_id|> + dog dog<|eot_id|><|start_header_id|>user<|end_header_id|> + Write a poem about cat of ten lines<|eot_id|><|start_header_id|>assistant<|end_header_id|> + cat cat<|eot_id|>""", promptFormatter.format(prompt6)); + } + + @Test + void mistral_large_prompt_formatter() { + + var promptFormatter = new MistralLargePromptFormatter(); + var prompt1 = List.of( + SystemMessage.from("You are a poet"), + UserMessage.from("Write a poem about dog of ten lines")); + + assertEquals("[INST] You are a poet [/INST][INST] Write a poem about dog of ten lines [/INST]", + promptFormatter.format(prompt1)); + + var prompt2 = List.of( + SystemMessage.from("You are a poet"), + UserMessage.from("Write a poem about dog of ten lines"), + AiMessage.from("I'm an assistant")); + + assertEquals( + "[INST] You are a poet [/INST][INST] Write a poem about dog of 
ten lines [/INST]I'm an assistant", + promptFormatter.format(prompt2)); + + var prompt3 = List. of(UserMessage.from("Write a poem about dog of ten lines")); + assertEquals("[INST] Write a poem about dog of ten lines [/INST]", promptFormatter.format(prompt3)); + + var prompt4 = List. of( + UserMessage.from("Write a poem about dog of ten lines"), + AiMessage.from("I'm an assistant")); + + assertEquals("[INST] Write a poem about dog of ten lines [/INST]I'm an assistant", + promptFormatter.format(prompt4)); + + var prompt5 = List. of( + SystemMessage.from("You are a poet"), + AiMessage.from("I'm an assistant")); + + assertEquals("[INST] You are a poet [/INST]I'm an assistant", + promptFormatter.format(prompt5)); + + var prompt6 = List.of( + SystemMessage.from("You are a poet"), + UserMessage.from("Write a poem about dog of ten lines"), + AiMessage.from("dog dog"), + UserMessage.from("Write a poem about cat of ten lines"), + AiMessage.from("cat cat")); + + assertEquals( + "[INST] You are a poet [/INST][INST] Write a poem about dog of ten lines [/INST]dog dog[INST] Write a poem about cat of ten lines [/INST]cat cat", + promptFormatter.format(prompt6)); + } + + @Test + void mistral_prompt_formatter() { + + var promptFormatter = new MistralPromptFormatter(); + var prompt1 = List.of( + SystemMessage.from("You are a poet"), + UserMessage.from("Write a poem about dog of ten lines")); + + assertEquals("[INST] You are a poet [/INST][INST] Write a poem about dog of ten lines [/INST]", + promptFormatter.format(prompt1)); + + var prompt2 = List.of( + SystemMessage.from("You are a poet"), + UserMessage.from("Write a poem about dog of ten lines"), + AiMessage.from("I'm an assistant")); + + assertEquals( + "[INST] You are a poet [/INST][INST] Write a poem about dog of ten lines [/INST]I'm an assistant", + promptFormatter.format(prompt2)); + + var prompt3 = List. 
of(UserMessage.from("Write a poem about dog of ten lines")); + assertEquals("[INST] Write a poem about dog of ten lines [/INST]", promptFormatter.format(prompt3)); + + var prompt4 = List. of( + UserMessage.from("Write a poem about dog of ten lines"), + AiMessage.from("I'm an assistant")); + + assertEquals("[INST] Write a poem about dog of ten lines [/INST]I'm an assistant", + promptFormatter.format(prompt4)); + + var prompt5 = List. of( + SystemMessage.from("You are a poet"), + AiMessage.from("I'm an assistant")); + + assertEquals("[INST] You are a poet [/INST]I'm an assistant", + promptFormatter.format(prompt5)); + + var prompt6 = List.of( + SystemMessage.from("You are a poet"), + UserMessage.from("Write a poem about dog of ten lines"), + AiMessage.from("dog dog"), + UserMessage.from("Write a poem about cat of ten lines"), + AiMessage.from("cat cat")); + + assertEquals( + "[INST] You are a poet [/INST][INST] Write a poem about dog of ten lines [/INST]dog dog[INST] Write a poem about cat of ten lines [/INST]cat cat", + promptFormatter.format(prompt6)); + } + + @Test + void granite_prompt_formatter() { + + var promptFormatter = new GranitePromptFormatter(); + var prompt1 = List.of( + SystemMessage.from("You are a poet"), + UserMessage.from("Write a poem about dog of ten lines")); + + assertEquals(""" + <|system|> + You are a poet + <|user|> + Write a poem about dog of ten lines + <|assistant|> + """, promptFormatter.format(prompt1)); + + var prompt2 = List.of( + SystemMessage.from("You are a poet"), + UserMessage.from("Write a poem about dog of ten lines"), + AiMessage.from("I'm an assistant")); + + assertEquals(""" + <|system|> + You are a poet + <|user|> + Write a poem about dog of ten lines + <|assistant|> + I'm an assistant""", promptFormatter.format(prompt2)); + + var prompt3 = List. 
of(UserMessage.from("Write a poem about dog of ten lines")); + assertEquals(""" + <|user|> + Write a poem about dog of ten lines + <|assistant|> + """, promptFormatter.format(prompt3)); + + var prompt4 = List. of( + UserMessage.from("Write a poem about dog of ten lines"), + AiMessage.from("I'm an assistant")); + + assertEquals(""" + <|user|> + Write a poem about dog of ten lines + <|assistant|> + I'm an assistant""", promptFormatter.format(prompt4)); + + var prompt5 = List. of( + SystemMessage.from("You are a poet"), + AiMessage.from("I'm an assistant")); + + assertEquals(""" + <|system|> + You are a poet + <|assistant|> + I'm an assistant""", promptFormatter.format(prompt5)); + + var prompt6 = List.of( + SystemMessage.from("You are a poet"), + UserMessage.from("Write a poem about dog of ten lines"), + AiMessage.from("dog dog"), + UserMessage.from("Write a poem about cat of ten lines"), + AiMessage.from("cat cat")); + + assertEquals(""" + <|system|> + You are a poet + <|user|> + Write a poem about dog of ten lines + <|assistant|> + dog dog + <|user|> + Write a poem about cat of ten lines + <|assistant|> + cat cat""", promptFormatter.format(prompt6)); + } + + @Test + void granite_code_prompt_formatter() { + + var promptFormatter = new GraniteCodePromptFormatter(); + var prompt1 = List.of( + SystemMessage.from("You are a poet"), + UserMessage.from("Write a poem about dog of ten lines")); + + assertEquals(""" + System: + You are a poet + + Question: + Write a poem about dog of ten lines + + Answer: + """, promptFormatter.format(prompt1)); + + var prompt2 = List.of( + SystemMessage.from("You are a poet"), + UserMessage.from("Write a poem about dog of ten lines"), + AiMessage.from("I'm an assistant")); + + assertEquals(""" + System: + You are a poet + + Question: + Write a poem about dog of ten lines + + Answer: + I'm an assistant + """, promptFormatter.format(prompt2)); + + var prompt3 = List. 
of(UserMessage.from("Write a poem about dog of ten lines")); + assertEquals(""" + Question: + Write a poem about dog of ten lines + + Answer: + """, promptFormatter.format(prompt3)); + + var prompt4 = List. of( + UserMessage.from("Write a poem about dog of ten lines"), + AiMessage.from("I'm an assistant")); + + assertEquals(""" + Question: + Write a poem about dog of ten lines + + Answer: + I'm an assistant + """, promptFormatter.format(prompt4)); + + var prompt5 = List. of( + SystemMessage.from("You are a poet"), + AiMessage.from("I'm an assistant")); + + assertEquals(""" + System: + You are a poet + + Answer: + I'm an assistant + """, promptFormatter.format(prompt5)); + + var prompt6 = List.of( + SystemMessage.from("You are a poet"), + UserMessage.from("Write a poem about dog of ten lines"), + AiMessage.from("dog dog"), + UserMessage.from("Write a poem about cat of ten lines"), + AiMessage.from("cat cat")); + + assertEquals(""" + System: + You are a poet + + Question: + Write a poem about dog of ten lines + + Answer: + dog dog + + Question: + Write a poem about cat of ten lines + + Answer: + cat cat + """, promptFormatter.format(prompt6)); + } + + @Test + void default_prompt_formatter() { + + var promptFormatter = new NoopPromptFormatter("\n"); + var promptFormatterWithDifferentJoin = new NoopPromptFormatter("@"); + + var prompt1 = List.of( + SystemMessage.from("You are a poet"), + UserMessage.from("Write a poem about dog of ten lines")); + + assertEquals(""" + You are a poet + Write a poem about dog of ten lines""", promptFormatter.format(prompt1)); + + assertEquals("You are a poet@Write a poem about dog of ten lines", promptFormatterWithDifferentJoin.format(prompt1)); + + var prompt2 = List.of( + SystemMessage.from("You are a poet"), + UserMessage.from("Write a poem about dog of ten lines"), + AiMessage.from("I'm an assistant")); + + assertEquals(""" + You are a poet + Write a poem about dog of ten lines + I'm an assistant""", promptFormatter.format(prompt2)); 
+ + var prompt3 = List. of(UserMessage.from("Write a poem about dog of ten lines")); + assertEquals("Write a poem about dog of ten lines", promptFormatter.format(prompt3)); + + var prompt4 = List. of( + UserMessage.from("Write a poem about dog of ten lines"), + AiMessage.from("I'm an assistant")); + + assertEquals(""" + Write a poem about dog of ten lines + I'm an assistant""", promptFormatter.format(prompt4)); + + var prompt5 = List. of( + SystemMessage.from("You are a poet"), + AiMessage.from("I'm an assistant")); + + assertEquals(""" + You are a poet + I'm an assistant""", promptFormatter.format(prompt5)); + + var prompt6 = List.of( + SystemMessage.from("You are a poet"), + UserMessage.from("Write a poem about dog of ten lines"), + AiMessage.from("dog dog"), + UserMessage.from("Write a poem about cat of ten lines"), + AiMessage.from("cat cat")); + + assertEquals(""" + You are a poet + Write a poem about dog of ten lines + dog dog + Write a poem about cat of ten lines + cat cat""", promptFormatter.format(prompt6)); + } +} diff --git a/model-providers/watsonx/deployment/src/test/java/com/ibm/langchain4j/watsonx/deployment/PromptFormatterTest.java b/model-providers/watsonx/deployment/src/test/java/com/ibm/langchain4j/watsonx/deployment/PromptFormatterTest.java new file mode 100644 index 000000000..637d185d8 --- /dev/null +++ b/model-providers/watsonx/deployment/src/test/java/com/ibm/langchain4j/watsonx/deployment/PromptFormatterTest.java @@ -0,0 +1,146 @@ +package com.ibm.langchain4j.watsonx.deployment; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.util.Date; + +import jakarta.inject.Inject; +import jakarta.inject.Singleton; + +import org.jboss.shrinkwrap.api.ShrinkWrap; +import org.jboss.shrinkwrap.api.spec.JavaArchive; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import dev.langchain4j.model.input.structured.StructuredPrompt; +import dev.langchain4j.service.SystemMessage; 
+import dev.langchain4j.service.UserMessage; +import dev.langchain4j.service.V; +import io.quarkiverse.langchain4j.RegisterAiService; +import io.quarkiverse.langchain4j.watsonx.bean.Parameters; +import io.quarkiverse.langchain4j.watsonx.bean.TextGenerationRequest; +import io.quarkiverse.langchain4j.watsonx.runtime.config.ChatModelConfig; +import io.quarkiverse.langchain4j.watsonx.runtime.config.LangChain4jWatsonxConfig; +import io.quarkus.test.QuarkusUnitTest; + +public class PromptFormatterTest extends WireMockAbstract { + + @RegisterExtension + static QuarkusUnitTest unitTest = new QuarkusUnitTest() + .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.base-url", WireMockUtil.URL_WATSONX_SERVER) + .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.iam.base-url", WireMockUtil.URL_IAM_SERVER) + .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.api-key", WireMockUtil.API_KEY) + .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.project-id", WireMockUtil.PROJECT_ID) + .overrideConfigKey("quarkus.langchain4j.watsonx.chat-model.model-id", "mistralai/mistral-large") + .overrideConfigKey("quarkus.langchain4j.watsonx.chat-model.prompt-formatter", "true") + .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class) + .addAsResource("messages/system.txt") + .addAsResource("messages/user.txt") + .addClass(WireMockUtil.class)); + + @Override + void handlerBeforeEach() { + mockServers.mockIAMBuilder(200) + .response("my_super_token", new Date()) + .build(); + } + + @RegisterAiService(chatMemoryProviderSupplier = RegisterAiService.NoChatMemoryProviderSupplier.class) + @Singleton + interface AIService { + @SystemMessage("You are a poet") + @UserMessage("Generate a poem about {topic}") + String poem(String topic); + } + + @RegisterAiService(chatMemoryProviderSupplier = RegisterAiService.NoChatMemoryProviderSupplier.class) + @Singleton + @SystemMessage("You are a poet") + interface SystemMessageOnClassAIService { + @UserMessage("Generate a poem about {topic}") 
+ String poem(String topic); + } + + @RegisterAiService(chatMemoryProviderSupplier = RegisterAiService.NoChatMemoryProviderSupplier.class) + @Singleton + interface AiServiceWithResources { + + @SystemMessage(fromResource = "messages/system.txt") + @UserMessage(fromResource = "messages/user.txt") + String poem(String topic); + } + + @StructuredPrompt("Generate a poem about {topic}") + static class PoemPrompt { + + private final String topic; + + public PoemPrompt(String topic) { + this.topic = topic; + } + + public String getTopic() { + return topic; + } + } + + @RegisterAiService(chatMemoryProviderSupplier = RegisterAiService.NoChatMemoryProviderSupplier.class) + @Singleton + interface StructuredPromptAIService { + + @SystemMessage("You are a poet") + String poem(PoemPrompt prompt); + } + + @RegisterAiService + @Singleton + interface AIRuntimeService { + @SystemMessage("You are a poet") + String poem(@UserMessage String prompt, @V("topic") String text); + } + + @Inject + AIService aiService; + + @Inject + SystemMessageOnClassAIService systemMessageOnClassAIService; + + @Inject + AiServiceWithResources aiServiceWithResources; + + @Inject + StructuredPromptAIService structuredPromptAIService; + + @Inject + AIRuntimeService aiRuntimeService; + + @Test + void tests() throws Exception { + + LangChain4jWatsonxConfig.WatsonConfig watsonConfig = langchain4jWatsonConfig.defaultConfig(); + ChatModelConfig chatModelConfig = watsonConfig.chatModel(); + String modelId = langchain4jWatsonFixedRuntimeConfig.defaultConfig().chatModel().modelId(); + String projectId = watsonConfig.projectId(); + + Parameters parameters = Parameters.builder() + .decodingMethod(chatModelConfig.decodingMethod()) + .temperature(chatModelConfig.temperature()) + .minNewTokens(chatModelConfig.minNewTokens()) + .maxNewTokens(chatModelConfig.maxNewTokens()) + .build(); + + TextGenerationRequest body = new TextGenerationRequest(modelId, projectId, + "[INST] You are a poet [/INST][INST] Generate a poem about 
dog [/INST]", parameters); + + mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_CHAT_API, 200) + .body(mapper.writeValueAsString(body)) + .response(WireMockUtil.RESPONSE_WATSONX_CHAT_API) + .build(); + + assertEquals("AI Response", aiService.poem("dog")); + assertEquals("AI Response", systemMessageOnClassAIService.poem("dog")); + assertEquals("AI Response", aiServiceWithResources.poem("dog")); + assertEquals("AI Response", structuredPromptAIService.poem(new PoemPrompt("dog"))); + assertEquals("AI Response", aiRuntimeService.poem("Generate a poem about {topic}", "dog")); + } +} diff --git a/model-providers/watsonx/deployment/src/test/java/com/ibm/langchain4j/watsonx/deployment/PromptFormatterToolsTest.java b/model-providers/watsonx/deployment/src/test/java/com/ibm/langchain4j/watsonx/deployment/PromptFormatterToolsTest.java new file mode 100644 index 000000000..9816a122a --- /dev/null +++ b/model-providers/watsonx/deployment/src/test/java/com/ibm/langchain4j/watsonx/deployment/PromptFormatterToolsTest.java @@ -0,0 +1,89 @@ +package com.ibm.langchain4j.watsonx.deployment; + +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.List; +import java.util.Map; + +import org.jboss.shrinkwrap.api.ShrinkWrap; +import org.jboss.shrinkwrap.api.spec.JavaArchive; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import dev.langchain4j.agent.tool.ToolExecutionRequest; +import dev.langchain4j.agent.tool.ToolParameters; +import dev.langchain4j.agent.tool.ToolSpecification; +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.data.message.SystemMessage; +import dev.langchain4j.data.message.ToolExecutionResultMessage; +import dev.langchain4j.data.message.UserMessage; +import io.quarkiverse.langchain4j.watsonx.prompt.impl.MistralLargePromptFormatter; +import io.quarkus.test.QuarkusUnitTest; + +public class PromptFormatterToolsTest { + + 
@RegisterExtension + static QuarkusUnitTest unitTest = new QuarkusUnitTest() + .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class)); + + List tools = List.of( + + ToolSpecification.builder() + .name("sum") + .description("Perform a subtraction between two numbers") + .parameters( + ToolParameters.builder() + .properties(Map.of("firstNumber", Map.of("type", "integer"), "secondNumber", + Map.of("type", "integer"))) + .required(List.of("firstNumber", "secondNumber")) + .type("object") + .build()) + .build()); + + ToolExecutionRequest toolExecutionRequest = ToolExecutionRequest.builder() + .id("1") + .name("sum") + .arguments("{\"firstNumber\":2,\"secondNumber\":2}\"}") + .build(); + + ToolExecutionResultMessage toolExecutionResultMessage = ToolExecutionResultMessage.from( + "1", + "sum", + "4"); + + List messages = List.of( + SystemMessage.from("You are a calculator"), + UserMessage.from("2 + 2"), + AiMessage.from(toolExecutionRequest), + toolExecutionResultMessage, + AiMessage.from("The result is 4")); + + @Test + void mistral_large_tools_test() { + MistralLargePromptFormatter promptFormatter = new MistralLargePromptFormatter(); + + String expected_1 = """ + [INST] You are a calculator [/INST][AVAILABLE_TOOLS] \ + [{"type":"function","function":{"name":"sum","description":"Perform a subtraction between two numbers","parameters":{"type":"object","properties":{"firstNumber":{"type":"integer"},"secondNumber":{"type":"integer"}},"required":["firstNumber","secondNumber"]}}}] \ + [/AVAILABLE_TOOLS][INST] 2 + 2 [/INST]\ + [TOOL_CALLS] [{"id":"1","name":"sum","arguments":{"firstNumber":2,"secondNumber":2}}]\ + [TOOL_RESULTS] {"content":4,"id":"1"} [/TOOL_RESULTS] The result is 4"""; + + String expected_2 = """ + [INST] You are a calculator [/INST][AVAILABLE_TOOLS] \ + [{"type":"function","function":{"name":"sum","description":"Perform a subtraction between two 
numbers","parameters":{"type":"object","properties":{"secondNumber":{"type":"integer"},"firstNumber":{"type":"integer"}},"required":["firstNumber","secondNumber"]}}}] \ + [/AVAILABLE_TOOLS][INST] 2 + 2 [/INST]\ + [TOOL_CALLS] [{"id":"1","name":"sum","arguments":{"firstNumber":2,"secondNumber":2}}]\ + [TOOL_RESULTS] {"content":4,"id":"1"} [/TOOL_RESULTS] The result is 4"""; + + boolean result = false; + + if (expected_1.equals(promptFormatter.format(messages, tools))) + result = true; + else if (expected_2.equals(promptFormatter.format(messages, tools))) + result = true; + + assertTrue(result); + } +} diff --git a/model-providers/watsonx/deployment/src/test/java/com/ibm/langchain4j/watsonx/deployment/TokenCountEstimatorTest.java b/model-providers/watsonx/deployment/src/test/java/com/ibm/langchain4j/watsonx/deployment/TokenCountEstimatorTest.java index 3a674feb1..2bec72472 100644 --- a/model-providers/watsonx/deployment/src/test/java/com/ibm/langchain4j/watsonx/deployment/TokenCountEstimatorTest.java +++ b/model-providers/watsonx/deployment/src/test/java/com/ibm/langchain4j/watsonx/deployment/TokenCountEstimatorTest.java @@ -1,6 +1,5 @@ package com.ibm.langchain4j.watsonx.deployment; -import static com.github.tomakehurst.wiremock.core.WireMockConfiguration.options; import static org.junit.jupiter.api.Assertions.assertEquals; import java.util.Date; @@ -10,15 +9,9 @@ import org.jboss.shrinkwrap.api.ShrinkWrap; import org.jboss.shrinkwrap.api.spec.JavaArchive; -import org.junit.jupiter.api.AfterAll; -import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.RegisterExtension; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.github.tomakehurst.wiremock.WireMockServer; - import dev.langchain4j.data.message.SystemMessage; import dev.langchain4j.data.message.UserMessage; import dev.langchain4j.data.segment.TextSegment; @@ -26,23 +19,10 @@ import 
dev.langchain4j.model.chat.TokenCountEstimator; import dev.langchain4j.model.input.Prompt; import io.quarkiverse.langchain4j.watsonx.bean.TokenizationRequest; -import io.quarkiverse.langchain4j.watsonx.client.WatsonxRestApi; import io.quarkiverse.langchain4j.watsonx.runtime.config.LangChain4jWatsonxConfig; import io.quarkus.test.QuarkusUnitTest; -public class TokenCountEstimatorTest { - - static WireMockServer watsonxServer; - static WireMockServer iamServer; - static ObjectMapper mapper; - - @Inject - LangChain4jWatsonxConfig langchain4jWatsonConfig; - - @Inject - ChatLanguageModel model; - - static WireMockUtil mockServers; +public class TokenCountEstimatorTest extends WireMockAbstract { @RegisterExtension static QuarkusUnitTest unitTest = new QuarkusUnitTest() @@ -52,30 +32,15 @@ public class TokenCountEstimatorTest { .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.project-id", WireMockUtil.PROJECT_ID) .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class).addClass(WireMockUtil.class)); - @BeforeAll - static void beforeAll() { - mapper = WatsonxRestApi.objectMapper(new ObjectMapper()); - - watsonxServer = new WireMockServer(options().port(WireMockUtil.PORT_WATSONX_SERVER)); - watsonxServer.start(); - - iamServer = new WireMockServer(options().port(WireMockUtil.PORT_IAM_SERVER)); - iamServer.start(); - - mockServers = new WireMockUtil(watsonxServer, iamServer); - } - - @AfterAll - static void afterAll() { - watsonxServer.stop(); - iamServer.stop(); + @Override + void handlerBeforeEach() { + mockServers.mockIAMBuilder(200) + .response(WireMockUtil.BEARER_TOKEN, new Date()) + .build(); } - @BeforeEach - void beforeEach() { - watsonxServer.resetAll(); - iamServer.resetAll(); - } + @Inject + ChatLanguageModel model; @Inject TokenCountEstimator tokenization; @@ -109,19 +74,24 @@ void token_count_estimator_prompt() throws Exception { @Test void token_count_estimator_list() throws Exception { - mockServer(); + + var modelId = 
langchain4jWatsonFixedRuntimeConfig.defaultConfig().chatModel().modelId(); + var input = "Write a tagline for an alumni\nassociation: Together we"; + var projectId = langchain4jWatsonxConfig.defaultConfig().projectId(); + var body = new TokenizationRequest(modelId, input, projectId); + + mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_TOKENIZER_API, 200) + .body(mapper.writeValueAsString(body)) + .response(WireMockUtil.RESPONSE_WATSONX_TOKENIZER_API.formatted(modelId)) + .build(); + assertEquals(11, tokenization.estimateTokenCount( - List.of(SystemMessage.from("Write a tagline for an alumni "), UserMessage.from("association: Together we")))); + List.of(SystemMessage.from("Write a tagline for an alumni"), UserMessage.from("association: Together we")))); } private String mockServer() throws Exception { - mockServers.mockIAMBuilder(200) - .response(WireMockUtil.BEARER_TOKEN, new Date()) - .build(); - - var config = langchain4jWatsonxConfig.defaultConfig(); - var modelId = config.chatModel().modelId(); + var modelId = langchain4jWatsonFixedRuntimeConfig.defaultConfig().chatModel().modelId(); var input = "Write a tagline for an alumni association: Together we"; var projectId = langchain4jWatsonxConfig.defaultConfig().projectId(); var body = new TokenizationRequest(modelId, input, projectId); diff --git a/model-providers/watsonx/deployment/src/test/java/com/ibm/langchain4j/watsonx/deployment/WireMockAbstract.java b/model-providers/watsonx/deployment/src/test/java/com/ibm/langchain4j/watsonx/deployment/WireMockAbstract.java new file mode 100644 index 000000000..ffb0871d6 --- /dev/null +++ b/model-providers/watsonx/deployment/src/test/java/com/ibm/langchain4j/watsonx/deployment/WireMockAbstract.java @@ -0,0 +1,59 @@ +package com.ibm.langchain4j.watsonx.deployment; + +import static com.github.tomakehurst.wiremock.core.WireMockConfiguration.options; + +import jakarta.inject.Inject; + +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; 
+import org.junit.jupiter.api.BeforeEach; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.github.tomakehurst.wiremock.WireMockServer; + +import io.quarkiverse.langchain4j.watsonx.client.WatsonxRestApi; +import io.quarkiverse.langchain4j.watsonx.runtime.config.LangChain4jWatsonxConfig; +import io.quarkiverse.langchain4j.watsonx.runtime.config.LangChain4jWatsonxFixedRuntimeConfig; + +public abstract class WireMockAbstract { + + static WireMockServer watsonxServer; + static WireMockServer iamServer; + static WireMockUtil mockServers; + static ObjectMapper mapper; + + @Inject + LangChain4jWatsonxConfig langchain4jWatsonConfig; + + @Inject + LangChain4jWatsonxFixedRuntimeConfig langchain4jWatsonFixedRuntimeConfig; + + @BeforeAll + static void beforeAll() { + mapper = WatsonxRestApi.objectMapper(new ObjectMapper()); + + watsonxServer = new WireMockServer(options().port(WireMockUtil.PORT_WATSONX_SERVER)); + watsonxServer.start(); + + iamServer = new WireMockServer(options().port(WireMockUtil.PORT_IAM_SERVER)); + iamServer.start(); + + mockServers = new WireMockUtil(watsonxServer, iamServer); + } + + @AfterAll + static void afterAll() { + watsonxServer.stop(); + iamServer.stop(); + } + + @BeforeEach + void beforeEach() { + watsonxServer.resetAll(); + iamServer.resetAll(); + handlerBeforeEach(); + } + + void handlerBeforeEach() { + }; +} diff --git a/model-providers/watsonx/deployment/src/test/java/com/ibm/langchain4j/watsonx/deployment/WireMockUtil.java b/model-providers/watsonx/deployment/src/test/java/com/ibm/langchain4j/watsonx/deployment/WireMockUtil.java index 684ccec73..e1183fa13 100644 --- a/model-providers/watsonx/deployment/src/test/java/com/ibm/langchain4j/watsonx/deployment/WireMockUtil.java +++ b/model-providers/watsonx/deployment/src/test/java/com/ibm/langchain4j/watsonx/deployment/WireMockUtil.java @@ -38,7 +38,7 @@ public class WireMockUtil { public static final String PROJECT_ID = "123123321321"; public static final String GRANT_TYPE = 
"urn:ibm:params:oauth:grant-type:apikey"; public static final String VERSION = "2024-03-14"; - public static final String DEFAULT_CHAT_MODEL = "ibm/granite-20b-multilingual"; + public static final String DEFAULT_CHAT_MODEL = "ibm/granite-13b-chat-v2"; public static final String DEFAULT_EMBEDDING_MODEL = "ibm/slate-125m-english-rtrvr"; public static final String IAM_200_RESPONSE = """ { @@ -52,7 +52,7 @@ public class WireMockUtil { """; public static String RESPONSE_WATSONX_CHAT_API = """ { - "model_id": "meta-llama/llama-2-70b-chat", + "model_id": "ibm/granite-13b-chat-v2", "created_at": "2024-01-21T17:06:14.052Z", "results": [ { diff --git a/model-providers/watsonx/deployment/src/test/resources/messages/system.txt b/model-providers/watsonx/deployment/src/test/resources/messages/system.txt new file mode 100644 index 000000000..1e925a999 --- /dev/null +++ b/model-providers/watsonx/deployment/src/test/resources/messages/system.txt @@ -0,0 +1 @@ +You are a poet \ No newline at end of file diff --git a/model-providers/watsonx/deployment/src/test/resources/messages/user.txt b/model-providers/watsonx/deployment/src/test/resources/messages/user.txt new file mode 100644 index 000000000..71d05b714 --- /dev/null +++ b/model-providers/watsonx/deployment/src/test/resources/messages/user.txt @@ -0,0 +1 @@ +Generate a poem about {topic} \ No newline at end of file diff --git a/model-providers/watsonx/pom.xml b/model-providers/watsonx/pom.xml index f3b075821..95fc90cc1 100644 --- a/model-providers/watsonx/pom.xml +++ b/model-providers/watsonx/pom.xml @@ -10,11 +10,8 @@ quarkus-langchain4j-watsonx-parent Quarkus LangChain4j - Watsonx - Parent pom - deployment runtime - - diff --git a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxChatModel.java b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxChatModel.java index 91636ba2f..7977174d3 100644 --- 
a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxChatModel.java +++ b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxChatModel.java @@ -4,6 +4,7 @@ import java.util.Objects; import java.util.concurrent.Callable; +import dev.langchain4j.agent.tool.ToolSpecification; import dev.langchain4j.data.message.AiMessage; import dev.langchain4j.data.message.ChatMessage; import dev.langchain4j.model.chat.ChatLanguageModel; @@ -16,36 +17,18 @@ import io.quarkiverse.langchain4j.watsonx.bean.TextGenerationResponse; import io.quarkiverse.langchain4j.watsonx.bean.TextGenerationResponse.Result; import io.quarkiverse.langchain4j.watsonx.bean.TokenizationRequest; +import io.quarkiverse.langchain4j.watsonx.prompt.PromptFormatterUtil; public class WatsonxChatModel extends WatsonxModel implements ChatLanguageModel, TokenCountEstimator { - public WatsonxChatModel(WatsonxModel.Builder config) { - super(config); + public WatsonxChatModel(WatsonxModel.Builder builder) { + super(builder); } @Override public Response generate(List messages) { - LengthPenalty lengthPenalty = null; - if (Objects.nonNull(decayFactor) || Objects.nonNull(startIndex)) { - lengthPenalty = new LengthPenalty(decayFactor, startIndex); - } - - Parameters parameters = Parameters.builder() - .decodingMethod(decodingMethod) - .lengthPenalty(lengthPenalty) - .minNewTokens(minNewTokens) - .maxNewTokens(maxNewTokens) - .randomSeed(randomSeed) - .stopSequences(stopSequences) - .temperature(temperature) - .topP(topP) - .topK(topK) - .repetitionPenalty(repetitionPenalty) - .truncateInputTokens(truncateInputTokens) - .includeStopSequence(includeStopSequence) - .build(); - + Parameters parameters = createParameters(); TextGenerationRequest request = new TextGenerationRequest(modelId, projectId, toInput(messages), parameters); Result result = retryOn(new Callable() { @@ -64,6 +47,41 @@ public TextGenerationResponse call() throws Exception { return 
Response.from(content, tokenUsage, finishReason); } + @Override + public Response generate(List messages, List toolSpecifications) { + Parameters parameters = createParameters(); + TextGenerationRequest request = new TextGenerationRequest(modelId, projectId, toInput(messages, toolSpecifications), + parameters); + + Result result = retryOn(new Callable() { + @Override + public TextGenerationResponse call() throws Exception { + return client.chat(request, version); + } + }).results().get(0); + + var finishReason = toFinishReason(result.stopReason()); + var tokenUsage = new TokenUsage( + result.inputTokenCount(), + result.generatedTokenCount()); + + AiMessage content; + + if (result.generatedText().startsWith(promptFormatter.toolExecution())) { + var tools = result.generatedText().replace(promptFormatter.toolExecution(), ""); + content = AiMessage.from(PromptFormatterUtil.toolExecutionRequest(tools)); + } else { + content = AiMessage.from(result.generatedText()); + } + + return Response.from(content, tokenUsage, finishReason); + } + + @Override + public Response generate(List messages, ToolSpecification toolSpecification) { + return generate(messages, List.of(toolSpecification)); + } + @Override public int estimateTokenCount(List messages) { @@ -77,4 +95,28 @@ public Integer call() throws Exception { } }); } + + private Parameters createParameters() { + LengthPenalty lengthPenalty = null; + if (Objects.nonNull(decayFactor) || Objects.nonNull(startIndex)) { + lengthPenalty = new LengthPenalty(decayFactor, startIndex); + } + + Parameters parameters = Parameters.builder() + .decodingMethod(decodingMethod) + .lengthPenalty(lengthPenalty) + .minNewTokens(minNewTokens) + .maxNewTokens(maxNewTokens) + .randomSeed(randomSeed) + .stopSequences(stopSequences) + .temperature(temperature) + .topP(topP) + .topK(topK) + .repetitionPenalty(repetitionPenalty) + .truncateInputTokens(truncateInputTokens) + .includeStopSequence(includeStopSequence) + .build(); + + return parameters; + } 
} diff --git a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxModel.java b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxModel.java index 31ed6d539..331dfd0ac 100644 --- a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxModel.java +++ b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxModel.java @@ -4,7 +4,6 @@ import java.time.Duration; import java.util.List; import java.util.Optional; -import java.util.StringJoiner; import java.util.concurrent.Callable; import java.util.concurrent.TimeUnit; @@ -12,12 +11,14 @@ import org.jboss.resteasy.reactive.client.api.LoggingScope; +import dev.langchain4j.agent.tool.ToolSpecification; import dev.langchain4j.data.message.ChatMessage; import dev.langchain4j.model.output.FinishReason; import io.quarkiverse.langchain4j.watsonx.bean.WatsonxError; import io.quarkiverse.langchain4j.watsonx.client.WatsonxRestApi; import io.quarkiverse.langchain4j.watsonx.client.filter.BearerTokenHeaderFactory; import io.quarkiverse.langchain4j.watsonx.exception.WatsonxException; +import io.quarkiverse.langchain4j.watsonx.prompt.PromptFormatter; import io.quarkus.rest.client.reactive.QuarkusRestClientBuilder; public abstract class WatsonxModel { @@ -38,42 +39,51 @@ public abstract class WatsonxModel { final Double repetitionPenalty; final Integer truncateInputTokens; final Boolean includeStopSequence; - final String promptJoiner; final WatsonxRestApi client; + final PromptFormatter promptFormatter; - public WatsonxModel(Builder config) { + public WatsonxModel(Builder builder) { - QuarkusRestClientBuilder builder = QuarkusRestClientBuilder.newBuilder() - .baseUrl(config.url) - .clientHeadersFactory(new BearerTokenHeaderFactory(config.tokenGenerator)) - .connectTimeout(config.timeout.toSeconds(), TimeUnit.SECONDS) - .readTimeout(config.timeout.toSeconds(), TimeUnit.SECONDS); + 
QuarkusRestClientBuilder restClientBuilder = QuarkusRestClientBuilder.newBuilder() + .baseUrl(builder.url) + .clientHeadersFactory(new BearerTokenHeaderFactory(builder.tokenGenerator)) + .connectTimeout(builder.timeout.toSeconds(), TimeUnit.SECONDS) + .readTimeout(builder.timeout.toSeconds(), TimeUnit.SECONDS); - if (config.logRequests || config.logResponses) { - builder.loggingScope(LoggingScope.REQUEST_RESPONSE); - builder.clientLogger(new WatsonxRestApi.WatsonClientLogger( - config.logRequests, - config.logResponses)); + if (builder.logRequests || builder.logResponses) { + restClientBuilder.loggingScope(LoggingScope.REQUEST_RESPONSE); + restClientBuilder.clientLogger(new WatsonxRestApi.WatsonClientLogger( + builder.logRequests, + builder.logResponses)); } - this.client = builder.build(WatsonxRestApi.class); - this.modelId = config.modelId; - this.version = config.version; - this.projectId = config.projectId; - this.decodingMethod = config.decodingMethod; - this.decayFactor = config.decayFactor; - this.startIndex = config.startIndex; - this.maxNewTokens = config.maxNewTokens; - this.minNewTokens = config.minNewTokens; - this.randomSeed = config.randomSeed; - this.stopSequences = config.stopSequences; - this.temperature = config.temperature; - this.topP = config.topP; - this.topK = config.topK; - this.repetitionPenalty = config.repetitionPenalty; - this.truncateInputTokens = config.truncateInputTokens; - this.includeStopSequence = config.includeStopSequence; - this.promptJoiner = config.promptJoiner; + this.client = restClientBuilder.build(WatsonxRestApi.class); + this.modelId = builder.modelId; + this.version = builder.version; + this.projectId = builder.projectId; + this.decodingMethod = builder.decodingMethod; + this.decayFactor = builder.decayFactor; + this.startIndex = builder.startIndex; + this.maxNewTokens = builder.maxNewTokens; + this.minNewTokens = builder.minNewTokens; + this.randomSeed = builder.randomSeed; + this.stopSequences = builder.stopSequences; 
+ this.temperature = builder.temperature; + this.topP = builder.topP; + this.topK = builder.topK; + this.repetitionPenalty = builder.repetitionPenalty; + this.truncateInputTokens = builder.truncateInputTokens; + this.includeStopSequence = builder.includeStopSequence; + + if (builder.promptFormatter != null) { + this.promptFormatter = builder.promptFormatter; + } else { + this.promptFormatter = null; + } + } + + public PromptFormatter getPromptFormatter() { + return promptFormatter; } public static Builder builder() { @@ -81,14 +91,11 @@ public static Builder builder() { } protected String toInput(List messages) { - StringJoiner joiner = new StringJoiner(promptJoiner); - for (ChatMessage message : messages) { - switch (message.type()) { - case AI, USER, SYSTEM -> joiner.add(message.text()); - case TOOL_EXECUTION_RESULT -> throw new IllegalArgumentException("Tool message is not supported"); - } - } - return joiner.toString(); + return promptFormatter.format(messages, List.of()); + } + + protected String toInput(List messages, List tools) { + return promptFormatter.format(messages, tools); } protected FinishReason toFinishReason(String stopReason) { @@ -156,7 +163,7 @@ public static final class Builder { public boolean logResponses; public boolean logRequests; private TokenGenerator tokenGenerator; - private String promptJoiner; + private PromptFormatter promptFormatter; public Builder modelId(String modelId) { this.modelId = modelId; @@ -253,8 +260,8 @@ public Builder tokenGenerator(TokenGenerator tokenGenerator) { return this; } - public Builder promptJoiner(String promptJoiner) { - this.promptJoiner = promptJoiner; + public Builder promptFormatter(PromptFormatter promptFormatter) { + this.promptFormatter = promptFormatter; return this; } diff --git a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/prompt/PromptFormatter.java 
b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/prompt/PromptFormatter.java new file mode 100644 index 000000000..d797625a0 --- /dev/null +++ b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/prompt/PromptFormatter.java @@ -0,0 +1,287 @@ +package io.quarkiverse.langchain4j.watsonx.prompt; + +import static dev.langchain4j.data.message.ChatMessageType.AI; +import static dev.langchain4j.data.message.ChatMessageType.SYSTEM; +import static java.util.function.Predicate.not; + +import java.util.List; +import java.util.Map; +import java.util.StringJoiner; +import java.util.function.Function; +import java.util.function.Predicate; +import java.util.stream.Stream; + +import jakarta.json.Json; + +import dev.langchain4j.agent.tool.Tool; +import dev.langchain4j.agent.tool.ToolParameters; +import dev.langchain4j.agent.tool.ToolSpecification; +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.data.message.ChatMessageType; +import dev.langchain4j.data.message.ToolExecutionResultMessage; + +/** + * The PromptFormatter interface defines the structure for handling and converting different types of {@link ChatMessage} + * objects + * into a specific string format for prompt generation. + */ +public interface PromptFormatter { + + /** + * Defines the string used to join multiple {@link ChatMessage} objects when constructing the prompt. + * + * @return the string used to join messages in the prompt. + */ + public default String joiner() { + return "\n"; + } + + /** + * Defines a start tag that wraps the entire prompt. + * + * @return the start tag for the prompt. + */ + public default String start() { + return ""; + } + + /** + * Defines an end tag that wraps the entire prompt. + * + * @return the end tag for the prompt. 
+ */ + public default String end() { + return ""; + } + + /** + * Returns the tag used to mark {@link SystemMessage} within the prompt. + * + * @return the tag representing a system message. + */ + String system(); + + /** + * Returns the tag used to mark {@link UserMessage} within the prompt. + * + * @return the tag representing a user message. + */ + String user(); + + /** + * Returns the tag used to mark {@link AiMessage} within the prompt. + * + * @return the tag representing an assistant message. + */ + String assistant(); + + /** + * Returns the tag used to mark {@link ToolExecutionResultMessage} within the prompt. + * + * @return the tag representing a tool execution message. + */ + default String toolResult() { + return ""; + } + + /** + * Returns the tag used by the LLM to request a {@link Tool} execution. + * + * @return the tag representing a tool request by the LLM. + */ + default String toolExecution() { + return ""; + } + + /** + * Converts a list of {@link ChatMessage} objects and {@link ToolSpecification} objects into a formatted prompt. + * + * @param messages the list of chat messages to be formatted. + * @param tools the list of tool specifications to be formatted. + * @return a string representing the formatted prompt. + */ + String format(List messages, List tools); + + /** + * Converts a list of {@link ChatMessage} into a formatted prompt. + * + * @param messages the list of chat messages to be formatted. + * @return a string representing the formatted prompt. + */ + default String format(List messages) { + return format(messages, null); + } + + /** + * Defines how to close a tag based on the message type. + * + * @param type the {@link ChatMessageType} for which the closing tag is being requested. + * @return the closing tag for the specified message type. + */ + String endOf(ChatMessageType type); + + /** + * Defines how to close a tag based on the message. 
+ * + * @param message the {@link ChatMessage} for which the closing tag is being requested. + * @return the closing tag for the specified message type. + */ + default String endOf(ChatMessage message) { + return endOf(message.type()); + } + + /** + * Returns the tag associated with a specific {@link ChatMessageType}. + * + * @param type the {@link ChatMessageType} for which the tag is being requested. + * @return the tag for the specified message type. + */ + default String tagOf(ChatMessageType type) { + return switch (type) { + case AI -> assistant(); + case SYSTEM -> system(); + case TOOL_EXECUTION_RESULT -> toolResult(); + case USER -> user(); + }; + } + + /** + * Returns the tag associated with a specific {@link ChatMessage}. + * + * @param message the {@link ChatMessage} for which the tag is being requested. + * @return the tag for the specified message. + */ + default String tagOf(ChatMessage message) { + return tagOf(message.type()); + } + + /** + * Returns a list of all relevant tags used in the prompt. + * + * @return a list of all relevant tags used in the prompt. + */ + default List tokens() { + return Stream.of(start(), end(), system(), user(), assistant(), toolResult()) + .map(String::trim) + .filter(not(String::isBlank)) + .toList(); + } + + /** + * Formats the system message from a list of {@link ChatMessage} objects. + * + * @param messages the list of chat messages from which the system message is formatted. + * @return a string representing the formatted system message. 
+ */ + default String systemMessageFormatter(List messages) { + return messages.stream() + .filter(new Predicate() { + @Override + public boolean test(ChatMessage message) { + return message.type().equals(SYSTEM); + } + }) + .findFirst() + .map(new Function() { + @Override + public String apply(ChatMessage message) { + return system() + message.text() + endOf(SYSTEM) + joiner(); + } + }) + .orElse(""); + } + + /** + * Formats a list of {@link ChatMessage} objects into a string by concatenating each message with its corresponding tag. + * + * @param messages the list of chat messages to be formatted. + * @return a string representing the formatted messages. + */ + default String messagesFormatter(List messages) { + + StringJoiner joiner = new StringJoiner(joiner(), "", ""); + var lastMessage = messages.get(messages.size() - 1); + + for (int i = 0; i < messages.size(); i++) { + + String text; + ChatMessage message = messages.get(i); + + if (message.type().equals(SYSTEM)) + continue; + + if (message instanceof ToolExecutionResultMessage toolExecutionResultMessage) { + + text = tagOf(message) + PromptFormatterUtil.convert(toolExecutionResultMessage) + endOf(message); + + } else if (message instanceof AiMessage aiMessage) { + + if (aiMessage.hasToolExecutionRequests()) { + text = toolExecution() + PromptFormatterUtil.convert(aiMessage.toolExecutionRequests()); + } else { + text = tagOf(message) + message.text() + endOf(message); + } + + } else { + text = tagOf(message) + message.text() + endOf(message); + } + + joiner.add(text); + } + + if (lastMessage.type() != AI && !tagOf(AI).isBlank()) { + joiner.add(tagOf(AI)); + } + + return joiner.toString(); + } + + /** + * Formats a list of {@link ToolSpecification} objects into a JSON string. + * + * @param tools the list of tool specifications to be formatted. + * @return a string representing the formatted tools in JSON format. 
+ */ + default String toolsFormatter(List tools) { + + if (tools == null || tools.isEmpty()) + return ""; + + var result = Json.createArrayBuilder(); + for (ToolSpecification tool : tools) { + + var json = Json.createObjectBuilder().add("type", "function"); + var function = Json.createObjectBuilder() + .add("name", tool.name()) + .add("description", tool.description()); + + ToolParameters toolParameters = tool.parameters(); + var parameters = Json.createObjectBuilder(); + + if (toolParameters != null && !toolParameters.properties().isEmpty()) { + + var properties = Json.createObjectBuilder(); + parameters.add("type", toolParameters.type()); + + for (Map.Entry> entry : toolParameters.properties().entrySet()) { + var key = entry.getKey(); + var value = entry.getValue(); + properties.add(key, PromptFormatterUtil.convert(value)); + } + + parameters.add("properties", properties.build()); + } + + var required = Json.createArrayBuilder(); + toolParameters.required().forEach(required::add); + + parameters.add("required", required); + function.add("parameters", parameters); + json.add("function", function); + result.add(json); + } + + return result.build().toString(); + } +} diff --git a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/prompt/PromptFormatterMapper.java b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/prompt/PromptFormatterMapper.java new file mode 100644 index 000000000..d6c6d3427 --- /dev/null +++ b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/prompt/PromptFormatterMapper.java @@ -0,0 +1,63 @@ +package io.quarkiverse.langchain4j.watsonx.prompt; + +import java.util.HashMap; +import java.util.Map; + +import io.quarkiverse.langchain4j.watsonx.prompt.impl.GraniteCodePromptFormatter; +import io.quarkiverse.langchain4j.watsonx.prompt.impl.GranitePromptFormatter; +import io.quarkiverse.langchain4j.watsonx.prompt.impl.LlamaPromptFormatter; +import 
io.quarkiverse.langchain4j.watsonx.prompt.impl.MistralLargePromptFormatter; +import io.quarkiverse.langchain4j.watsonx.prompt.impl.MistralPromptFormatter; + +/** + * Utility class to map the model names to the corresponding {@link PromptFormatter}. + */ +public class PromptFormatterMapper { + + static final Map promptFormatters = new HashMap<>(); + + static { + + MistralLargePromptFormatter mistralLargePromptFormatter = new MistralLargePromptFormatter(); + promptFormatters.put("mistralai/mistral-large", mistralLargePromptFormatter); + + MistralPromptFormatter mistralPromptFormatter = new MistralPromptFormatter(); + promptFormatters.put("mistralai/mixtral-8x7b-instruct-v01", mistralPromptFormatter); + promptFormatters.put("sdaia/allam-1-13b-instruct", mistralPromptFormatter); + + LlamaPromptFormatter llamaPromptFormatter = new LlamaPromptFormatter(); + promptFormatters.put("meta-llama/llama-3-405b-instruct", llamaPromptFormatter); + promptFormatters.put("meta-llama/llama-3-1-70b-instruct", llamaPromptFormatter); + promptFormatters.put("meta-llama/llama-3-1-8b-instruct", llamaPromptFormatter); + promptFormatters.put("meta-llama/llama-3-70b-instruct", llamaPromptFormatter); + promptFormatters.put("meta-llama/llama-3-8b-instruct", llamaPromptFormatter); + + GranitePromptFormatter granitePromptFormatter = new GranitePromptFormatter(); + promptFormatters.put("ibm/granite-13b-chat-v2", granitePromptFormatter); + promptFormatters.put("ibm/granite-13b-instruct-v2", granitePromptFormatter); + promptFormatters.put("ibm/granite-7b-lab", granitePromptFormatter); + + GraniteCodePromptFormatter graniteCodePromptFormatter = new GraniteCodePromptFormatter(); + promptFormatters.put("ibm/granite-20b-code-instruct", graniteCodePromptFormatter); + promptFormatters.put("ibm/granite-34b-code-instruct", graniteCodePromptFormatter); + promptFormatters.put("ibm/granite-3b-code-instruct", graniteCodePromptFormatter); + promptFormatters.put("ibm/granite-8b-code-instruct", 
graniteCodePromptFormatter); + } + + /** + * Retrieves the {@link PromptFormatter} associated with the specified model name. + * + * @param model the name of the model whose {@link PromptFormatter} is requested. + * @return the {@link PromptFormatter} corresponding to the model name, or null if the model is not found. + */ + public static PromptFormatter get(String model) { + return promptFormatters.get(model); + } + + public static boolean toolIsSupported(String model) { + return switch (model) { + case "mistralai/mistral-large" -> true; + default -> false; + }; + } +} diff --git a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/prompt/PromptFormatterUtil.java b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/prompt/PromptFormatterUtil.java new file mode 100644 index 000000000..324ddf27a --- /dev/null +++ b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/prompt/PromptFormatterUtil.java @@ -0,0 +1,162 @@ +package io.quarkiverse.langchain4j.watsonx.prompt; + +import java.io.StringReader; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.UUID; + +import jakarta.json.Json; +import jakarta.json.JsonArray; +import jakarta.json.JsonArrayBuilder; +import jakarta.json.JsonObject; +import jakarta.json.JsonReader; +import jakarta.json.JsonValue; + +import dev.langchain4j.agent.tool.ToolExecutionRequest; +import dev.langchain4j.data.message.ToolExecutionResultMessage; + +/** + * Utility class for handling various prompt-formatter related tasks. + */ +public class PromptFormatterUtil { + + /** + * Converts a Map into a JSON string representation. 
+ * + * @param map the map to convert + * @return a JSON string representing the map + */ + @SuppressWarnings("unchecked") + public static JsonObject convert(Map map) { + + var json = Json.createObjectBuilder(); + + for (Map.Entry entry : map.entrySet()) { + String key = entry.getKey(); + Object value = entry.getValue(); + + if (value instanceof Map cValue) { + json.add(key, convert(cValue)); + } else if (value instanceof Collection cValue) { + json.add(key, convert(cValue)); + } else if (value instanceof String cValue) { + json.add(key, cValue); + } else if (value instanceof Integer cValue) { + json.add(key, cValue); + } else if (value instanceof Boolean cValue) { + json.add(key, cValue); + } else { + json.add(key, value.toString()); + } + } + return json.build(); + } + + /** + * Converts a ToolExecutionResultMessage into a JSON string representation. + * + * @param toolExecutionResultMessage the {@link ToolExecutionResultMessage} to convert + * @return a JSON string representing the tool execution result message + */ + public static JsonObject convert(ToolExecutionResultMessage toolExecutionResultMessage) { + + JsonValue content = null; + if (toolExecutionResultMessage.text() != null) { + StringReader stringReader = new StringReader(toolExecutionResultMessage.text()); + try (JsonReader jsonReader = Json.createReader(stringReader)) { + content = jsonReader.readValue(); + } + } + + return Json.createObjectBuilder() + .add("content", content) + .add("id", toolExecutionResultMessage.id()) + .build(); + } + + /** + * Converts a {@List} of {@link ToolExecutionRequest} objects into a JSON string representation. 
+ * + * @param toolExecutionRequests the {@List} of {@link ToolExecutionRequest} objects to convert + * @return a JSON string representing the list of ToolExecutionRequest objects + */ + public static JsonArray convert(List toolExecutionRequests) { + var result = Json.createArrayBuilder(); + for (ToolExecutionRequest toolExecutionRequest : toolExecutionRequests) { + result.add(PromptFormatterUtil.convert(toolExecutionRequest)); + } + return result.build(); + } + + /** + * Parses a JSON string representing a {@List} of {@link ToolExecutionRequest} and converts it into a JSON representation. + * + * @param json the JSON string to parse + * @return a {@List} of {@link ToolExecutionRequest} objects + */ + public static List toolExecutionRequest(String json) { + + List result = new ArrayList<>(); + StringReader stringReader = new StringReader(json); + + try (JsonReader jsonReader = Json.createReader(stringReader)) { + var toolExecutionRequests = jsonReader.readArray(); + for (JsonValue toolExecutionRequest : toolExecutionRequests) { + var tool = toolExecutionRequest.asJsonObject(); + result.add( + ToolExecutionRequest.builder() + .id(UUID.randomUUID().toString()) + .name(tool.getString("name")) + .arguments(tool.getJsonObject("arguments").toString()) + .build()); + } + } + return result; + } + + // + // Converts a ToolExecutionRequest object into a JsonObject. + // + private static JsonObject convert(ToolExecutionRequest toolExecutionRequest) { + + JsonValue arguments = null; + if (toolExecutionRequest.arguments() != null) { + StringReader stringReader = new StringReader(toolExecutionRequest.arguments()); + try (JsonReader jsonReader = Json.createReader(stringReader)) { + arguments = jsonReader.readValue(); + } + } + + return Json.createObjectBuilder() + .add("id", toolExecutionRequest.id()) + .add("name", toolExecutionRequest.name()) + .add("arguments", toolExecutionRequest.arguments() != null ? 
arguments : Json.createObjectBuilder().build()) + .build(); + } + + // + // Converts a Collection of objects into a JsonArray. + // + @SuppressWarnings("unchecked") + private static JsonArray convert(Collection list) { + + JsonArrayBuilder jsonArrayBuilder = Json.createArrayBuilder(); + + for (Object value : list) { + if (value instanceof Map cValue) { + jsonArrayBuilder.add(convert(cValue)); + } else if (value instanceof String cValue) { + jsonArrayBuilder.add(cValue); + } else if (value instanceof Integer cValue) { + jsonArrayBuilder.add(cValue); + } else if (value instanceof Boolean cValue) { + jsonArrayBuilder.add(cValue); + } else { + jsonArrayBuilder.add(value.toString()); + } + } + return jsonArrayBuilder.build(); + } +} diff --git a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/prompt/impl/GraniteCodePromptFormatter.java b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/prompt/impl/GraniteCodePromptFormatter.java new file mode 100644 index 000000000..68ad25258 --- /dev/null +++ b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/prompt/impl/GraniteCodePromptFormatter.java @@ -0,0 +1,42 @@ +package io.quarkiverse.langchain4j.watsonx.prompt.impl; + +import java.util.List; + +import dev.langchain4j.agent.tool.ToolSpecification; +import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.data.message.ChatMessageType; +import io.quarkiverse.langchain4j.watsonx.prompt.PromptFormatter; + +/** + * Granite code prompt formatter. 
+ */ +public class GraniteCodePromptFormatter implements PromptFormatter { + + @Override + public String system() { + return "System:\n"; + } + + @Override + public String user() { + return "Question:\n"; + } + + @Override + public String assistant() { + return "Answer:\n"; + } + + @Override + public String endOf(ChatMessageType messageType) { + return "\n"; + } + + @Override + public String format(List messages, List tools) { + return """ + %s\ + %s\ + """.formatted(systemMessageFormatter(messages), messagesFormatter(messages)); + } +} diff --git a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/prompt/impl/GranitePromptFormatter.java b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/prompt/impl/GranitePromptFormatter.java new file mode 100644 index 000000000..5477322ba --- /dev/null +++ b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/prompt/impl/GranitePromptFormatter.java @@ -0,0 +1,42 @@ +package io.quarkiverse.langchain4j.watsonx.prompt.impl; + +import java.util.List; + +import dev.langchain4j.agent.tool.ToolSpecification; +import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.data.message.ChatMessageType; +import io.quarkiverse.langchain4j.watsonx.prompt.PromptFormatter; + +/** + * Granite prompt formatter. 
+ */ +public class GranitePromptFormatter implements PromptFormatter { + + @Override + public String system() { + return "<|system|>\n"; + } + + @Override + public String user() { + return "<|user|>\n"; + } + + @Override + public String assistant() { + return "<|assistant|>\n"; + } + + @Override + public String endOf(ChatMessageType messageType) { + return ""; + } + + @Override + public String format(List messages, List tools) { + return """ + %s\ + %s\ + """.formatted(systemMessageFormatter(messages), messagesFormatter(messages)); + } +} diff --git a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/prompt/impl/LlamaPromptFormatter.java b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/prompt/impl/LlamaPromptFormatter.java new file mode 100644 index 000000000..897db1483 --- /dev/null +++ b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/prompt/impl/LlamaPromptFormatter.java @@ -0,0 +1,53 @@ +package io.quarkiverse.langchain4j.watsonx.prompt.impl; + +import java.util.List; + +import dev.langchain4j.agent.tool.ToolSpecification; +import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.data.message.ChatMessageType; +import io.quarkiverse.langchain4j.watsonx.prompt.PromptFormatter; + +/** + * Llama 3.x prompt formatter. 
+ */ +public class LlamaPromptFormatter implements PromptFormatter { + + @Override + public String joiner() { + return ""; + } + + @Override + public String start() { + return "<|begin_of_text|>"; + } + + @Override + public String system() { + return "<|start_header_id|>system<|end_header_id|>\n\n"; + } + + @Override + public String user() { + return "<|start_header_id|>user<|end_header_id|>\n"; + } + + @Override + public String assistant() { + return "<|start_header_id|>assistant<|end_header_id|>\n"; + } + + @Override + public String format(List messages, List tools) { + return """ + %s\ + %s\ + %s\ + """.formatted(start(), systemMessageFormatter(messages), messagesFormatter(messages)); + } + + @Override + public String endOf(ChatMessageType type) { + return "<|eot_id|>"; + } +} diff --git a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/prompt/impl/MistralLargePromptFormatter.java b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/prompt/impl/MistralLargePromptFormatter.java new file mode 100644 index 000000000..9ff2140e3 --- /dev/null +++ b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/prompt/impl/MistralLargePromptFormatter.java @@ -0,0 +1,78 @@ +package io.quarkiverse.langchain4j.watsonx.prompt.impl; + +import java.util.List; + +import dev.langchain4j.agent.tool.ToolSpecification; +import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.data.message.ChatMessageType; +import io.quarkiverse.langchain4j.watsonx.prompt.PromptFormatter; + +/** + * MistralLarge prompt formatter. 
+ */ +public class MistralLargePromptFormatter implements PromptFormatter { + + @Override + public String joiner() { + return ""; + } + + @Override + public String start() { + return ""; + } + + @Override + public String system() { + return "[INST] "; + } + + @Override + public String user() { + return "[INST] "; + } + + @Override + public String assistant() { + return ""; + } + + @Override + public String toolResult() { + return "[TOOL_RESULTS] "; + } + + @Override + public String toolExecution() { + return "[TOOL_CALLS] "; + } + + @Override + public String format(List messages, List tools) { + if (tools != null && tools.size() > 0) { + return """ + %s\ + %s\ + [AVAILABLE_TOOLS] %s [/AVAILABLE_TOOLS]\ + %s\ + """.formatted(start(), systemMessageFormatter(messages), toolsFormatter(tools), + messagesFormatter(messages)); + } else { + return """ + %s\ + %s\ + %s\ + """.formatted(start(), systemMessageFormatter(messages), messagesFormatter(messages)); + } + } + + @Override + public String endOf(ChatMessageType type) { + return switch (type) { + case AI -> ""; + case SYSTEM -> " [/INST]"; + case USER -> " [/INST]"; + case TOOL_EXECUTION_RESULT -> " [/TOOL_RESULTS] "; + }; + } +} diff --git a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/prompt/impl/MistralPromptFormatter.java b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/prompt/impl/MistralPromptFormatter.java new file mode 100644 index 000000000..2bff38aea --- /dev/null +++ b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/prompt/impl/MistralPromptFormatter.java @@ -0,0 +1,58 @@ +package io.quarkiverse.langchain4j.watsonx.prompt.impl; + +import java.util.List; + +import dev.langchain4j.agent.tool.ToolSpecification; +import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.data.message.ChatMessageType; +import io.quarkiverse.langchain4j.watsonx.prompt.PromptFormatter; + +/** + * Mistral prompt 
formatter. + */ +public class MistralPromptFormatter implements PromptFormatter { + + @Override + public String joiner() { + return ""; + } + + @Override + public String start() { + return ""; + } + + @Override + public String system() { + return "[INST] "; + } + + @Override + public String user() { + return "[INST] "; + } + + @Override + public String assistant() { + return ""; + } + + @Override + public String format(List messages, List tools) { + return """ + %s\ + %s\ + %s\ + """.formatted(start(), systemMessageFormatter(messages), messagesFormatter(messages)); + } + + @Override + public String endOf(ChatMessageType type) { + return switch (type) { + case AI -> ""; + case SYSTEM -> " [/INST]"; + case USER -> " [/INST]"; + case TOOL_EXECUTION_RESULT -> ""; + }; + } +} diff --git a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/prompt/impl/NoopPromptFormatter.java b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/prompt/impl/NoopPromptFormatter.java new file mode 100644 index 000000000..52d9f05ae --- /dev/null +++ b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/prompt/impl/NoopPromptFormatter.java @@ -0,0 +1,52 @@ +package io.quarkiverse.langchain4j.watsonx.prompt.impl; + +import java.util.List; +import java.util.stream.Collectors; + +import dev.langchain4j.agent.tool.ToolSpecification; +import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.data.message.ChatMessageType; +import io.quarkiverse.langchain4j.watsonx.prompt.PromptFormatter; + +/* + * Prompt formatter used when the model used doesn't have a specific implementation or when the prompt-formatter property is set + * to false. 
+ */ +public class NoopPromptFormatter implements PromptFormatter { + + private String joiner; + + public NoopPromptFormatter(String joiner) { + this.joiner = joiner; + } + + @Override + public String joiner() { + return joiner; + } + + @Override + public String system() { + return ""; + } + + @Override + public String user() { + return ""; + } + + @Override + public String assistant() { + return ""; + } + + @Override + public String format(List messages, List tools) { + return messages.stream().map(ChatMessage::text).collect(Collectors.joining(joiner())); + } + + @Override + public String endOf(ChatMessageType messageType) { + return ""; + } +} diff --git a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/runtime/WatsonxRecorder.java b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/runtime/WatsonxRecorder.java index ef887723e..49b30abfa 100644 --- a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/runtime/WatsonxRecorder.java +++ b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/runtime/WatsonxRecorder.java @@ -11,6 +11,8 @@ import java.util.function.Function; import java.util.function.Supplier; +import org.jboss.logging.Logger; + import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.chat.DisabledChatLanguageModel; import dev.langchain4j.model.chat.DisabledStreamingChatLanguageModel; @@ -23,30 +25,43 @@ import io.quarkiverse.langchain4j.watsonx.WatsonxEmbeddingModel; import io.quarkiverse.langchain4j.watsonx.WatsonxModel; import io.quarkiverse.langchain4j.watsonx.WatsonxStreamingChatModel; +import io.quarkiverse.langchain4j.watsonx.prompt.PromptFormatter; +import io.quarkiverse.langchain4j.watsonx.prompt.impl.NoopPromptFormatter; import io.quarkiverse.langchain4j.watsonx.runtime.config.ChatModelConfig; -import io.quarkiverse.langchain4j.watsonx.runtime.config.ChatModelConfig.LengthPenaltyConfig; import 
io.quarkiverse.langchain4j.watsonx.runtime.config.EmbeddingModelConfig; import io.quarkiverse.langchain4j.watsonx.runtime.config.IAMConfig; import io.quarkiverse.langchain4j.watsonx.runtime.config.LangChain4jWatsonxConfig; +import io.quarkiverse.langchain4j.watsonx.runtime.config.LangChain4jWatsonxFixedRuntimeConfig; import io.quarkus.runtime.annotations.Recorder; import io.smallrye.config.ConfigValidationException; @Recorder public class WatsonxRecorder { + private static final Logger log = Logger.getLogger(WatsonxRecorder.class); + private static final String DUMMY_URL = "https://dummy.ai/api"; private static final String DUMMY_API_KEY = "dummy"; private static final String DUMMY_PROJECT_ID = "dummy"; private static final Map tokenGeneratorCache = new HashMap<>(); private static final ConfigValidationException.Problem[] EMPTY_PROBLEMS = new ConfigValidationException.Problem[0]; - public Supplier chatModel(LangChain4jWatsonxConfig runtimeConfig, String configName) { - LangChain4jWatsonxConfig.WatsonConfig watsonConfig = correspondingWatsonConfig(runtimeConfig, configName); + public Supplier chatModel(LangChain4jWatsonxConfig runtimeConfig, + LangChain4jWatsonxFixedRuntimeConfig fixedRuntimeConfig, + String configName, PromptFormatter promptFormatter) { - if (watsonConfig.enableIntegration()) { + LangChain4jWatsonxConfig.WatsonConfig watsonRuntimeConfig = correspondingWatsonRuntimeConfig(runtimeConfig, configName); + LangChain4jWatsonxFixedRuntimeConfig.WatsonConfig watsonFixedRuntimeConfig = correspondingWatsonFixedRuntimeConfig( + fixedRuntimeConfig, configName); + + if (promptFormatter != null && watsonFixedRuntimeConfig.chatModel().promptFormatter()) { + log.infof("The PromptFormatter for \"%s\" is loaded, the model tags are generated automatically.", + watsonFixedRuntimeConfig.chatModel().modelId()); + } - var builder = generateChatBuilder(watsonConfig, configName); + if (watsonRuntimeConfig.enableIntegration()) { + var builder = 
generateChatBuilder(watsonRuntimeConfig, watsonFixedRuntimeConfig, configName, promptFormatter); return new Supplier<>() { @Override public ChatLanguageModel get() { @@ -64,13 +79,16 @@ public ChatLanguageModel get() { } } - public Supplier streamingChatModel(LangChain4jWatsonxConfig runtimeConfig, String configName) { - LangChain4jWatsonxConfig.WatsonConfig watsonConfig = correspondingWatsonConfig(runtimeConfig, configName); + public Supplier streamingChatModel(LangChain4jWatsonxConfig runtimeConfig, + LangChain4jWatsonxFixedRuntimeConfig fixedRuntimeConfig, String configName, PromptFormatter promptFormatter) { - if (watsonConfig.enableIntegration()) { + LangChain4jWatsonxConfig.WatsonConfig watsonRuntimeConfig = correspondingWatsonRuntimeConfig(runtimeConfig, configName); + LangChain4jWatsonxFixedRuntimeConfig.WatsonConfig watsonFixedRuntimeConfig = correspondingWatsonFixedRuntimeConfig( + fixedRuntimeConfig, configName); - var builder = generateChatBuilder(watsonConfig, configName); + if (watsonRuntimeConfig.enableIntegration()) { + var builder = generateChatBuilder(watsonRuntimeConfig, watsonFixedRuntimeConfig, configName, promptFormatter); return new Supplier<>() { @Override public StreamingChatLanguageModel get() { @@ -89,7 +107,7 @@ public StreamingChatLanguageModel get() { } public Supplier embeddingModel(LangChain4jWatsonxConfig runtimeConfig, String configName) { - LangChain4jWatsonxConfig.WatsonConfig watsonConfig = correspondingWatsonConfig(runtimeConfig, configName); + LangChain4jWatsonxConfig.WatsonConfig watsonConfig = correspondingWatsonRuntimeConfig(runtimeConfig, configName); if (watsonConfig.enableIntegration()) { var configProblems = checkConfigurations(watsonConfig, configName); @@ -148,44 +166,42 @@ public TokenGenerator apply(String iamUrl) { }; } - private WatsonxModel.Builder generateChatBuilder(LangChain4jWatsonxConfig.WatsonConfig watsonConfig, String configName) { + private WatsonxModel.Builder generateChatBuilder( + 
LangChain4jWatsonxConfig.WatsonConfig watsonRuntimeConfig, + LangChain4jWatsonxFixedRuntimeConfig.WatsonConfig watsonFixedRuntimeConfig, + String configName, PromptFormatter promptFormatter) { - ChatModelConfig chatModelConfig = watsonConfig.chatModel(); - var configProblems = checkConfigurations(watsonConfig, configName); + ChatModelConfig chatModelConfig = watsonRuntimeConfig.chatModel(); + var configProblems = checkConfigurations(watsonRuntimeConfig, configName); if (!configProblems.isEmpty()) { throw new ConfigValidationException(configProblems.toArray(EMPTY_PROBLEMS)); } - String iamUrl = watsonConfig.iam().baseUrl().toExternalForm(); + String iamUrl = watsonRuntimeConfig.iam().baseUrl().toExternalForm(); TokenGenerator tokenGenerator = tokenGeneratorCache.computeIfAbsent(iamUrl, - createTokenGenerator(watsonConfig.iam(), watsonConfig.apiKey())); + createTokenGenerator(watsonRuntimeConfig.iam(), watsonRuntimeConfig.apiKey())); URL url; try { - url = new URL(watsonConfig.baseUrl()); + url = new URL(watsonRuntimeConfig.baseUrl()); } catch (Exception e) { throw new RuntimeException(e); } - Double decayFactor = null; - Integer startIndex = null; - - if (chatModelConfig.lengthPenalty().isPresent()) { - decayFactor = chatModelConfig.lengthPenalty().map(LengthPenaltyConfig::decayFactor).get().orElse(null); - startIndex = chatModelConfig.lengthPenalty().map(LengthPenaltyConfig::startIndex).get().orElse(null); - } + Double decayFactor = chatModelConfig.lengthPenalty().decayFactor().orElse(null); + Integer startIndex = chatModelConfig.lengthPenalty().startIndex().orElse(null); + String promptJoiner = chatModelConfig.promptJoiner(); return WatsonxChatModel.builder() - .promptJoiner(chatModelConfig.promptJoiner().orElse("")) .tokenGenerator(tokenGenerator) .url(url) - .timeout(watsonConfig.timeout().orElse(Duration.ofSeconds(10))) + .timeout(watsonRuntimeConfig.timeout().orElse(Duration.ofSeconds(10))) .logRequests(chatModelConfig.logRequests().orElse(false)) 
.logResponses(chatModelConfig.logResponses().orElse(false)) - .version(watsonConfig.version()) - .projectId(watsonConfig.projectId()) - .modelId(chatModelConfig.modelId()) + .version(watsonRuntimeConfig.version()) + .projectId(watsonRuntimeConfig.projectId()) + .modelId(watsonFixedRuntimeConfig.chatModel().modelId()) .decodingMethod(chatModelConfig.decodingMethod()) .decayFactor(decayFactor) .startIndex(startIndex) @@ -198,10 +214,11 @@ private WatsonxModel.Builder generateChatBuilder(LangChain4jWatsonxConfig.Watson .topP(firstOrDefault(null, chatModelConfig.topP())) .repetitionPenalty(firstOrDefault(null, chatModelConfig.repetitionPenalty())) .truncateInputTokens(chatModelConfig.truncateInputTokens().orElse(null)) - .includeStopSequence(chatModelConfig.includeStopSequence().orElse(null)); + .includeStopSequence(chatModelConfig.includeStopSequence().orElse(null)) + .promptFormatter(promptFormatter == null ? new NoopPromptFormatter(promptJoiner) : promptFormatter); } - private LangChain4jWatsonxConfig.WatsonConfig correspondingWatsonConfig(LangChain4jWatsonxConfig runtimeConfig, + private LangChain4jWatsonxConfig.WatsonConfig correspondingWatsonRuntimeConfig(LangChain4jWatsonxConfig runtimeConfig, String configName) { LangChain4jWatsonxConfig.WatsonConfig watsonConfig; if (NamedConfigUtil.isDefault(configName)) { @@ -212,6 +229,18 @@ private LangChain4jWatsonxConfig.WatsonConfig correspondingWatsonConfig(LangChai return watsonConfig; } + private LangChain4jWatsonxFixedRuntimeConfig.WatsonConfig correspondingWatsonFixedRuntimeConfig( + LangChain4jWatsonxFixedRuntimeConfig fixedRuntimeConfig, + String configName) { + LangChain4jWatsonxFixedRuntimeConfig.WatsonConfig watsonConfig; + if (NamedConfigUtil.isDefault(configName)) { + watsonConfig = fixedRuntimeConfig.defaultConfig(); + } else { + watsonConfig = fixedRuntimeConfig.namedConfig().get(configName); + } + return watsonConfig; + } + private List checkConfigurations(LangChain4jWatsonxConfig.WatsonConfig 
watsonConfig, String configName) { List configProblems = new ArrayList<>(); diff --git a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/runtime/config/ChatModelConfig.java b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/runtime/config/ChatModelConfig.java index 69e650d0b..8292d6242 100644 --- a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/runtime/config/ChatModelConfig.java +++ b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/runtime/config/ChatModelConfig.java @@ -10,16 +10,6 @@ @ConfigGroup public interface ChatModelConfig { - /** - * Model id to use. - *

- * To view the complete model list, click - * here. - */ - @WithDefault("ibm/granite-20b-multilingual") - String modelId(); - /** * Represents the strategy used for picking the tokens during generation of the output text. During text generation when * parameter @@ -42,7 +32,7 @@ public interface ChatModelConfig { * tokens * have been generated. */ - Optional lengthPenalty(); + LengthPenaltyConfig lengthPenalty(); /** * The maximum number of new tokens to be generated. The maximum supported value for this field depends on the model being @@ -161,7 +151,8 @@ public interface ChatModelConfig { * your * preferred way of concatenating messages to ensure that the prompt is structured in the correct way. */ - Optional promptJoiner(); + @WithDefault("\n") + String promptJoiner(); @ConfigGroup public interface LengthPenaltyConfig { diff --git a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/runtime/config/ChatModelFixedRuntimeConfig.java b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/runtime/config/ChatModelFixedRuntimeConfig.java new file mode 100644 index 000000000..274e1c0e2 --- /dev/null +++ b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/runtime/config/ChatModelFixedRuntimeConfig.java @@ -0,0 +1,30 @@ +package io.quarkiverse.langchain4j.watsonx.runtime.config; + +import io.quarkus.runtime.annotations.ConfigGroup; +import io.smallrye.config.WithDefault; + +@ConfigGroup +public interface ChatModelFixedRuntimeConfig { + + /** + * Model id to use. + * + *

+ * To view the complete model list, click + * here. + */ + @WithDefault("ibm/granite-13b-chat-v2") + String modelId(); + + /** + * Configuration property that enables or disables the functionality of the prompt formatter. + * + *

    + *
  • true: When enabled, prompts are automatically enriched with the specific tags defined by the model.
  • + *
  • false: Prompts will not be enriched with the model's tags.
  • + *
+ */ + @WithDefault("false") + boolean promptFormatter(); +} diff --git a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/runtime/config/LangChain4jWatsonxFixedRuntimeConfig.java b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/runtime/config/LangChain4jWatsonxFixedRuntimeConfig.java new file mode 100644 index 000000000..6fa6c0e68 --- /dev/null +++ b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/runtime/config/LangChain4jWatsonxFixedRuntimeConfig.java @@ -0,0 +1,40 @@ +package io.quarkiverse.langchain4j.watsonx.runtime.config; + +import static io.quarkus.runtime.annotations.ConfigPhase.BUILD_AND_RUN_TIME_FIXED; + +import java.util.Map; + +import io.quarkus.runtime.annotations.ConfigDocMapKey; +import io.quarkus.runtime.annotations.ConfigDocSection; +import io.quarkus.runtime.annotations.ConfigRoot; +import io.smallrye.config.ConfigMapping; +import io.smallrye.config.WithDefaults; +import io.smallrye.config.WithParentName; + +@ConfigRoot(phase = BUILD_AND_RUN_TIME_FIXED) +@ConfigMapping(prefix = "quarkus.langchain4j.watsonx") +public interface LangChain4jWatsonxFixedRuntimeConfig { + + /** + * Default model config. + */ + @WithParentName + WatsonConfig defaultConfig(); + + /** + * Named model config. 
+ */ + @ConfigDocSection + @ConfigDocMapKey("model-name") + @WithParentName + @WithDefaults + Map namedConfig(); + + interface WatsonConfig { + + /** + * Chat model related settings + */ + ChatModelFixedRuntimeConfig chatModel(); + } +} diff --git a/model-providers/watsonx/runtime/src/test/java/io/quarkiverse/langchain4j/watsonx/runtime/DisabledModelsWatsonRecorderTest.java b/model-providers/watsonx/runtime/src/test/java/io/quarkiverse/langchain4j/watsonx/runtime/DisabledModelsWatsonRecorderTest.java index d9737afa8..12dfa5bab 100644 --- a/model-providers/watsonx/runtime/src/test/java/io/quarkiverse/langchain4j/watsonx/runtime/DisabledModelsWatsonRecorderTest.java +++ b/model-providers/watsonx/runtime/src/test/java/io/quarkiverse/langchain4j/watsonx/runtime/DisabledModelsWatsonRecorderTest.java @@ -13,9 +13,12 @@ import io.quarkiverse.langchain4j.runtime.NamedConfigUtil; import io.quarkiverse.langchain4j.watsonx.runtime.config.LangChain4jWatsonxConfig; import io.quarkiverse.langchain4j.watsonx.runtime.config.LangChain4jWatsonxConfig.WatsonConfig; +import io.quarkiverse.langchain4j.watsonx.runtime.config.LangChain4jWatsonxFixedRuntimeConfig; class DisabledModelsWatsonRecorderTest { - LangChain4jWatsonxConfig config = mock(LangChain4jWatsonxConfig.class); + LangChain4jWatsonxConfig runtimeConfig = mock(LangChain4jWatsonxConfig.class); + LangChain4jWatsonxFixedRuntimeConfig fixedRuntimeConfig = mock(LangChain4jWatsonxFixedRuntimeConfig.class); + WatsonConfig defaultConfig = mock(WatsonConfig.class); WatsonxRecorder recorder = new WatsonxRecorder(); @@ -24,21 +27,23 @@ void setupMocks() { when(defaultConfig.enableIntegration()) .thenReturn(false); - when(config.defaultConfig()) + when(runtimeConfig.defaultConfig()) .thenReturn(defaultConfig); } @Test void disabledChatModel() { - assertThat(recorder.chatModel(config, NamedConfigUtil.DEFAULT_NAME).get()) + assertThat(recorder + .chatModel(runtimeConfig, fixedRuntimeConfig, NamedConfigUtil.DEFAULT_NAME, null) + .get()) 
.isNotNull() .isExactlyInstanceOf(DisabledChatLanguageModel.class); - assertThat(recorder.streamingChatModel(config, NamedConfigUtil.DEFAULT_NAME).get()) + assertThat(recorder.streamingChatModel(runtimeConfig, fixedRuntimeConfig, NamedConfigUtil.DEFAULT_NAME, null).get()) .isNotNull() .isExactlyInstanceOf(DisabledStreamingChatLanguageModel.class); - assertThat(recorder.embeddingModel(config, NamedConfigUtil.DEFAULT_NAME).get()) + assertThat(recorder.embeddingModel(runtimeConfig, NamedConfigUtil.DEFAULT_NAME).get()) .isNotNull() .isExactlyInstanceOf(DisabledEmbeddingModel.class); }