Skip to content

Commit

Permalink
Merge pull request #830 from andreadimaio/prompt_formatter
Browse files Browse the repository at this point in the history
Add PromptFormatter functionality
  • Loading branch information
geoand authored Sep 2, 2024
2 parents 1bb873c + 5f90e39 commit 7d1119d
Show file tree
Hide file tree
Showing 40 changed files with 2,406 additions and 616 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,6 @@ release.properties

# Quarkus CLI
.quarkus

# Dolphin
.directory
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.SEED_MEMORY;
import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.V;

import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;
Expand Down Expand Up @@ -117,8 +116,8 @@ public class AiServicesProcessor {

public static final DotName MICROMETER_TIMED = DotName.createSimple("io.micrometer.core.annotation.Timed");
public static final DotName MICROMETER_COUNTED = DotName.createSimple("io.micrometer.core.annotation.Counted");
private static final String DEFAULT_DELIMITER = "\n";
private static final Predicate<AnnotationInstance> IS_METHOD_PARAMETER_ANNOTATION = ai -> ai.target()
public static final String DEFAULT_DELIMITER = "\n";
public static final Predicate<AnnotationInstance> IS_METHOD_PARAMETER_ANNOTATION = ai -> ai.target()
.kind() == AnnotationTarget.Kind.METHOD_PARAMETER;
private static final Function<AnnotationInstance, Integer> METHOD_PARAMETER_POSITION_FUNCTION = ai -> Integer
.valueOf(ai.target()
Expand Down Expand Up @@ -1033,7 +1032,7 @@ private Optional<AiServiceMethodCreateInfo.TemplateInfo> gatherSystemMessageInfo
instance = method.declaringClass().declaredAnnotation(LangChain4jDotNames.SYSTEM_MESSAGE);
}
if (instance != null) {
String systemMessageTemplate = getTemplateFromAnnotationInstance(instance);
String systemMessageTemplate = TemplateUtil.getTemplateFromAnnotationInstance(instance);
if (systemMessageTemplate.isEmpty()) {
throw illegalConfigurationForMethod("@SystemMessage's template parameter cannot be empty", method);
}
Expand Down Expand Up @@ -1061,7 +1060,7 @@ private AiServiceMethodCreateInfo.UserMessageInfo gatherUserMessageInfo(MethodIn

AnnotationInstance userMessageInstance = method.declaredAnnotation(LangChain4jDotNames.USER_MESSAGE);
if (userMessageInstance != null) {
String userMessageTemplate = getTemplateFromAnnotationInstance(userMessageInstance);
String userMessageTemplate = TemplateUtil.getTemplateFromAnnotationInstance(userMessageInstance);

if (userMessageTemplate.contains("{{it}}")) {
if (method.parametersCount() != 1) {
Expand Down Expand Up @@ -1110,41 +1109,6 @@ private AiServiceMethodCreateInfo.UserMessageInfo gatherUserMessageInfo(MethodIn
}
}

/**
 * Meant to be called with instances of {@link dev.langchain4j.service.SystemMessage} or
 * {@link dev.langchain4j.service.UserMessage}
 *
 * @param instance the annotation instance to read the template from
 * @return the String value of the template or an empty string if not specified
 * @throws UncheckedIOException if {@code fromResource} points to a missing resource
 *         (wrapping a {@link FileNotFoundException}) or the resource cannot be read
 */
private String getTemplateFromAnnotationInstance(AnnotationInstance instance) {
    AnnotationValue fromResourceValue = instance.value("fromResource");
    if (fromResourceValue != null) {
        String fromResource = fromResourceValue.asString();
        // Classpath lookups via getResourceAsStream require an absolute path.
        if (!fromResource.startsWith("/")) {
            fromResource = "/" + fromResource;
        }
        try (InputStream is = Thread.currentThread().getContextClassLoader().getResourceAsStream(fromResource)) {
            if (is == null) {
                // Caught below and rethrown as UncheckedIOException.
                throw new FileNotFoundException("Resource not found: " + fromResource);
            }
            // Decode explicitly as UTF-8: the no-charset String constructor uses the
            // platform default charset (pre-JDK 18), which breaks templates on
            // non-UTF-8 systems.
            return new String(is.readAllBytes(), java.nio.charset.StandardCharsets.UTF_8);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }
    AnnotationValue valueValue = instance.value();
    if (valueValue != null) {
        // Multi-line templates are joined with the configured delimiter (newline by default).
        AnnotationValue delimiterValue = instance.value("delimiter");
        String delimiter = delimiterValue != null ? delimiterValue.asString() : DEFAULT_DELIMITER;
        return String.join(delimiter, valueValue.asStringArray());
    }
    return "";
}

private Optional<AiServiceMethodCreateInfo.MetricsTimedInfo> gatherMetricsTimedInfo(MethodInfo method,
boolean addMicrometerMetrics) {
if (!addMicrometerMetrics) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ public class LangChain4jDotNames {
static final DotName AI_SERVICES = DotName.createSimple(AiServices.class);
static final DotName CREATED_AWARE = DotName.createSimple(CreatedAware.class);
public static final DotName SYSTEM_MESSAGE = DotName.createSimple(SystemMessage.class);
static final DotName USER_MESSAGE = DotName.createSimple(UserMessage.class);
public static final DotName USER_MESSAGE = DotName.createSimple(UserMessage.class);
static final DotName USER_NAME = DotName.createSimple(UserName.class);
static final DotName MODERATE = DotName.createSimple(Moderate.class);
static final DotName MEMORY_ID = DotName.createSimple(MemoryId.class);
Expand All @@ -52,7 +52,7 @@ public class LangChain4jDotNames {
static final DotName V = DotName.createSimple(dev.langchain4j.service.V.class);

static final DotName MODEL_NAME = DotName.createSimple(ModelName.class);
static final DotName REGISTER_AI_SERVICES = DotName.createSimple(RegisterAiService.class);
public static final DotName REGISTER_AI_SERVICES = DotName.createSimple(RegisterAiService.class);

static final DotName BEAN_CHAT_MODEL_SUPPLIER = DotName.createSimple(
RegisterAiService.BeanChatLanguageModelSupplier.class);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,22 @@
package io.quarkiverse.langchain4j.deployment;

import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;

import org.jboss.jandex.AnnotationInstance;
import org.jboss.jandex.AnnotationValue;

import io.quarkiverse.langchain4j.QuarkusPromptTemplateFactory;
import io.quarkus.qute.Engine;
import io.quarkus.qute.Expression;
import io.quarkus.qute.Template;

class TemplateUtil {
public class TemplateUtil {

static List<List<Expression.Part>> parts(String templateStr) {
Template template = Holder.ENGINE.parse(templateStr);
Expand All @@ -20,6 +27,46 @@ static List<List<Expression.Part>> parts(String templateStr) {
return expressions.stream().map(Expression::getParts).collect(Collectors.toList());
}

/**
 * Meant to be called with instances of {@link dev.langchain4j.service.SystemMessage} or
 * {@link dev.langchain4j.service.UserMessage}
 *
 * @param instance the annotation instance to read the template from, may be {@code null}
 * @return the String value of the template or an empty string if not specified
 * @throws UncheckedIOException if {@code fromResource} points to a missing resource
 *         (wrapping a {@link FileNotFoundException}) or the resource cannot be read
 */
public static String getTemplateFromAnnotationInstance(AnnotationInstance instance) {
    if (instance == null) {
        return "";
    }

    AnnotationValue fromResourceValue = instance.value("fromResource");
    if (fromResourceValue != null) {
        String fromResource = fromResourceValue.asString();
        // Classpath lookups via getResourceAsStream require an absolute path.
        if (!fromResource.startsWith("/")) {
            fromResource = "/" + fromResource;
        }
        try (InputStream is = Thread.currentThread().getContextClassLoader().getResourceAsStream(fromResource)) {
            if (is == null) {
                // Caught below and rethrown as UncheckedIOException.
                throw new FileNotFoundException("Resource not found: " + fromResource);
            }
            // Decode explicitly as UTF-8: the no-charset String constructor uses the
            // platform default charset (pre-JDK 18), which breaks templates on
            // non-UTF-8 systems.
            return new String(is.readAllBytes(), java.nio.charset.StandardCharsets.UTF_8);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }
    AnnotationValue valueValue = instance.value();
    if (valueValue != null) {
        // Multi-line templates are joined with the configured delimiter (newline by default).
        AnnotationValue delimiterValue = instance.value("delimiter");
        String delimiter = delimiterValue != null ? delimiterValue.asString() : AiServicesProcessor.DEFAULT_DELIMITER;
        return String.join(delimiter, valueValue.asStringArray());
    }
    return "";
}

private static class Holder {
private static final Engine ENGINE = Engine.builder().addDefaults()
.addParserHook(new QuarkusPromptTemplateFactory.MustacheTemplateVariableStyleParserHook()).build();
Expand Down
79 changes: 62 additions & 17 deletions docs/modules/ROOT/pages/watsonx.adoc
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,25 @@ NOTE: To determine the API key, go to https://cloud.ibm.com/iam/apikeys and gene

When creating prompts using watsonx.ai, it's important to follow the guidelines of the model you choose. Depending on the model, some special instructions may be required to ensure the desired output. For best results, always refer to the documentation provided for each model to maximize the effectiveness of your prompts.

For example, if you choose to use `ibm/granite-13b-chat-v2`, you can use the `<|system|>`, `<|user|>`, and `<|assistant|>` instructions:
To simplify the process of prompt creation, you can use the `prompt-formatter` property to automatically handle the addition of tags to your prompts. This property allows you to avoid manually adding tags by letting the system handle the formatting based on the model's requirements. This functionality is particularly useful for models such as `ibm/granite-13b-chat-v2`, `meta-llama/llama-3-405b-instruct`, and other supported models, ensuring consistent and accurate prompt structures without additional effort.

To enable this functionality, configure the `prompt-formatter` property in your `application.properties` file as follows:

[source,properties,subs=attributes+]
----
quarkus.langchain4j.watsonx.chat-model.prompt-formatter=true
----

When this property is set to `true`, the system will automatically format prompts with the appropriate tags. This helps to maintain prompt clarity and improves interaction with the LLM by ensuring that prompts follow the required structure. If set to `false`, you'll need to manage the tags manually.

For example, if you choose to use `ibm/granite-13b-chat-v2` without using the `prompt-formatter`, you will need to manually add the `<|system|>`, `<|user|>`, and `<|assistant|>` instructions:

[source,properties,subs=attributes+]
----
quarkus.langchain4j.watsonx.api-key=hG-...
quarkus.langchain4j.watsonx.base-url=https://us-south.ml.cloud.ibm.com
quarkus.langchain4j.watsonx.chat-model.model-id=ibm/granite-13b-chat-v2
quarkus.langchain4j.watsonx.chat-model.prompt-formatter=false
----

[source,java]
Expand All @@ -96,30 +108,63 @@ public interface LLMService {
}
----

Enabling the `prompt-formatter` will result in:

[source,properties,subs=attributes+]
----
quarkus.langchain4j.watsonx.api-key=hG-...
quarkus.langchain4j.watsonx.base-url=https://us-south.ml.cloud.ibm.com
quarkus.langchain4j.watsonx.chat-model.model-id=ibm/granite-13b-chat-v2
quarkus.langchain4j.watsonx.chat-model.prompt-formatter=true
----

[source,java]
----
@Path("/llm")
public class LLMResource {
@RegisterAiService
public interface LLMService {
@Inject
LLMService llmService;
public record Result(Integer result) {}
@GET
@Path("/calculator")
@Produces(MediaType.APPLICATION_JSON)
public Result calculator() {
return llmService.calculator(2, 2);
}
@SystemMessage("""
You are a calculator and you must perform the mathematical operation
{response_schema}
""")
@UserMessage("""
{firstNumber} + {secondNumber}
""")
public Result calculator(int firstNumber, int secondNumber);
}
----

[source,shell]
----
❯ curl http://localhost:8080/llm/calculator
{"result":4}
----
The `prompt-formatter` supports the following models:

* `mistralai/mistral-large`
* `mistralai/mixtral-8x7b-instruct-v01`
* `sdaia/allam-1-13b-instruct`
* `meta-llama/llama-3-405b-instruct`
* `meta-llama/llama-3-1-70b-instruct`
* `meta-llama/llama-3-1-8b-instruct`
* `meta-llama/llama-3-70b-instruct`
* `meta-llama/llama-3-8b-instruct`
* `ibm/granite-13b-chat-v2`
* `ibm/granite-13b-instruct-v2`
* `ibm/granite-7b-lab`
* `ibm/granite-20b-code-instruct`
* `ibm/granite-34b-code-instruct`
* `ibm/granite-3b-code-instruct`
* `ibm/granite-8b-code-instruct`

==== Tool Execution with Prompt Formatter

In addition to simplifying prompt creation, the `prompt-formatter` property also enables the execution of tools for specific models. Tools allow for dynamic interactions within the model, enabling the AI to perform specific actions or fetch data as part of its response.

When the `prompt-formatter` is enabled and a supported model is selected, the prompt will be automatically formatted to use the tools. More information about tools is available in the xref:./agent-and-tools.adoc[Agent and Tools] page.

Currently, the following model supports tool execution:

* `mistralai/mistral-large`

IMPORTANT: The `@SystemMessage` and `@UserMessage` are joined by default without spaces or new lines, if you want to change this behavior use the property `quarkus.langchain4j.watsonx.chat-model.prompt-joiner=<value>`. By adjusting this property, you can define your preferred way of joining messages and ensure that the prompt structure meets your specific needs.
IMPORTANT: The `@SystemMessage` and `@UserMessage` annotations are joined by default with a new line. If you want to change this behavior, use the property `quarkus.langchain4j.watsonx.chat-model.prompt-joiner=<value>`. By adjusting this property, you can define your preferred way of joining messages and ensure that the prompt structure meets your specific needs. This customization option is available only when the `prompt-formatter` property is set to `false`. When the `prompt-formatter` is enabled (set to `true`), the prompt formatting, including the addition of tags and message joining, is automatically handled. In this case, the `prompt-joiner` property will be ignored, and you will not have the ability to customize how messages are joined.

NOTE: Sometimes it may be useful to use the `quarkus.langchain4j.watsonx.chat-model.stop-sequences` property to prevent the LLM model from returning more results than desired.

Expand Down
Loading

0 comments on commit 7d1119d

Please sign in to comment.