From c0adb63adbf1846f4e59e555dffbbea31d4a262e Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 12 Jun 2024 22:15:10 +0000 Subject: [PATCH 01/11] Bump quarkus-neo4j.version from 4.0.0 to 4.1.0 Bumps `quarkus-neo4j.version` from 4.0.0 to 4.1.0. Updates `io.quarkiverse.neo4j:quarkus-neo4j-deployment` from 4.0.0 to 4.1.0 - [Commits](https://github.com/quarkiverse/quarkus-neo4j/compare/4.0.0...4.1.0) Updates `io.quarkiverse.neo4j:quarkus-neo4j` from 4.0.0 to 4.1.0 - [Commits](https://github.com/quarkiverse/quarkus-neo4j/compare/4.0.0...4.1.0) --- updated-dependencies: - dependency-name: io.quarkiverse.neo4j:quarkus-neo4j-deployment dependency-type: direct:production update-type: version-update:semver-minor - dependency-name: io.quarkiverse.neo4j:quarkus-neo4j dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] --- pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pom.xml b/pom.xml index fbe66681e..65480df02 100644 --- a/pom.xml +++ b/pom.xml @@ -65,7 +65,7 @@ 3.6.0 0.200.0 1.3.2 - 4.0.0 + 4.1.0 From 58e2662bd976880d4b44ba973cf444593e6d30e9 Mon Sep 17 00:00:00 2001 From: Sergey Beryozkin Date: Wed, 1 May 2024 18:51:26 +0100 Subject: [PATCH 02/11] Add secure-fraud-detection demo --- pom.xml | 1 + samples/secure-fraud-detection/README.md | 118 +++++++++++++++ samples/secure-fraud-detection/pom.xml | 134 ++++++++++++++++++ .../langchain4j/sample/Customer.java | 16 +++ .../langchain4j/sample/CustomerConfig.java | 11 ++ .../sample/CustomerRepository.java | 19 +++ .../langchain4j/sample/FraudDetectionAi.java | 33 +++++ .../FraudDetectionContentRetriever.java | 56 ++++++++ .../sample/FraudDetectionResource.java | 26 ++++ .../FraudDetectionRetrievalAugmentor.java | 27 ++++ .../langchain4j/sample/LoginResource.java | 33 +++++ .../langchain4j/sample/LogoutResource.java | 29 ++++ .../sample/MissingCustomerException.java | 4 + 
.../MissingCustomerExceptionMapper.java | 18 +++ .../sample/MissingCustomerResource.java | 32 +++++ .../quarkiverse/langchain4j/sample/Setup.java | 51 +++++++ .../langchain4j/sample/Transaction.java | 26 ++++ .../sample/TransactionRepository.java | 19 +++ .../META-INF/resources/images/google.png | Bin 0 -> 1768 bytes .../resources/META-INF/resources/index.html | 133 +++++++++++++++++ .../src/main/resources/application.properties | 14 ++ .../resources/templates/fraudDetection.html | 18 +++ .../resources/templates/missingCustomer.html | 18 +++ 23 files changed, 836 insertions(+) create mode 100644 samples/secure-fraud-detection/README.md create mode 100644 samples/secure-fraud-detection/pom.xml create mode 100644 samples/secure-fraud-detection/src/main/java/io/quarkiverse/langchain4j/sample/Customer.java create mode 100644 samples/secure-fraud-detection/src/main/java/io/quarkiverse/langchain4j/sample/CustomerConfig.java create mode 100644 samples/secure-fraud-detection/src/main/java/io/quarkiverse/langchain4j/sample/CustomerRepository.java create mode 100644 samples/secure-fraud-detection/src/main/java/io/quarkiverse/langchain4j/sample/FraudDetectionAi.java create mode 100644 samples/secure-fraud-detection/src/main/java/io/quarkiverse/langchain4j/sample/FraudDetectionContentRetriever.java create mode 100644 samples/secure-fraud-detection/src/main/java/io/quarkiverse/langchain4j/sample/FraudDetectionResource.java create mode 100644 samples/secure-fraud-detection/src/main/java/io/quarkiverse/langchain4j/sample/FraudDetectionRetrievalAugmentor.java create mode 100644 samples/secure-fraud-detection/src/main/java/io/quarkiverse/langchain4j/sample/LoginResource.java create mode 100644 samples/secure-fraud-detection/src/main/java/io/quarkiverse/langchain4j/sample/LogoutResource.java create mode 100644 samples/secure-fraud-detection/src/main/java/io/quarkiverse/langchain4j/sample/MissingCustomerException.java create mode 100644 
samples/secure-fraud-detection/src/main/java/io/quarkiverse/langchain4j/sample/MissingCustomerExceptionMapper.java create mode 100644 samples/secure-fraud-detection/src/main/java/io/quarkiverse/langchain4j/sample/MissingCustomerResource.java create mode 100644 samples/secure-fraud-detection/src/main/java/io/quarkiverse/langchain4j/sample/Setup.java create mode 100644 samples/secure-fraud-detection/src/main/java/io/quarkiverse/langchain4j/sample/Transaction.java create mode 100644 samples/secure-fraud-detection/src/main/java/io/quarkiverse/langchain4j/sample/TransactionRepository.java create mode 100644 samples/secure-fraud-detection/src/main/resources/META-INF/resources/images/google.png create mode 100644 samples/secure-fraud-detection/src/main/resources/META-INF/resources/index.html create mode 100644 samples/secure-fraud-detection/src/main/resources/application.properties create mode 100644 samples/secure-fraud-detection/src/main/resources/templates/fraudDetection.html create mode 100644 samples/secure-fraud-detection/src/main/resources/templates/missingCustomer.html diff --git a/pom.xml b/pom.xml index fbe66681e..b0cb2f964 100644 --- a/pom.xml +++ b/pom.xml @@ -198,6 +198,7 @@ samples/cli-translator samples/review-triage samples/fraud-detection + samples/secure-fraud-detection samples/chatbot samples/chatbot-easy-rag samples/sql-chatbot diff --git a/samples/secure-fraud-detection/README.md b/samples/secure-fraud-detection/README.md new file mode 100644 index 000000000..b7bed77df --- /dev/null +++ b/samples/secure-fraud-detection/README.md @@ -0,0 +1,118 @@ +# Secure Fraud Detection Demo + +This demo showcases the implementation of a secure fraud detection system that is available only to users authenticated with Google. +It uses the `gpt-3.5-turbo` LLM; set the `quarkus.langchain4j.openai.chat-model.model-name` property to select a different model. + +## The Demo + +### Setup + +The demo requires that your Google account's full name and email are configured.
+You can use system or environment properties; see the `Running the Demo` section below. + +When the application starts, 5 transactions with random amounts between 1 and 1000 are generated for the registered user. +A random city is also assigned to each transaction. + +The setup is defined in the [Setup.java](./src/main/java/io/quarkiverse/langchain4j/sample/Setup.java) class. + +The registered user and transactions are stored in a PostgreSQL database. When running the demo in dev mode (recommended), the database is automatically created and populated. + +### Content Retrieval + +To enable fraud detection, we provide the LLM with access to the custom [FraudDetectionContentRetriever](./src/main/java/io/quarkiverse/langchain4j/sample/FraudDetectionContentRetriever.java) content retriever. + +`FraudDetectionContentRetriever` is registered by [FraudDetectionRetrievalAugmentor](./src/main/java/io/quarkiverse/langchain4j/sample/FraudDetectionRetrievalAugmentor.java). + +It can only be accessed securely and it retrieves transaction data for the currently authenticated user through two Panache repositories: + +- [CustomerRepository.java](./src/main/java/io/quarkiverse/langchain4j/sample/CustomerRepository.java) +- [TransactionRepository.java](./src/main/java/io/quarkiverse/langchain4j/sample/TransactionRepository.java) + +It extracts the authenticated user's name and email from an injected `JsonWebToken` ID token. + +### AI Service + +This demo leverages the AI service abstraction, with the interaction between the LLM and the application handled through the AIService interface. + +The `io.quarkiverse.langchain4j.sample.FraudDetectionAi` interface uses specific annotations to define the LLM: + +```java +@RegisterAiService(retrievalAugmentor = FraudDetectionRetrievalAugmentor.class) +``` + +For each message, the prompt is engineered to help the LLM understand the context and answer the request: + +```java + @SystemMessage(""" + You are a bank account fraud detection AI.
You have to detect frauds in transactions. + """) + @UserMessage(""" + Your task is to detect whether a fraud was committed for the customer. + + Answer with a **single** JSON document containing: + - the customer name in the 'customer-name' key + - the transaction limit in the 'transaction-limit' key + - the computed sum of all transactions committed during the last 15 minutes in the 'total' key + - the 'fraud' key set to true if the computed sum of all transactions is greater than the transaction limit + - the 'transactions' key containing an array of JSON objects. Each object must have transaction 'amount', 'city' and formatted 'time' keys. + - the 'explanation' key containing an explanation of your answer. + - the 'email' key containing the customer email if the fraud was detected. + + Your response must be just the raw JSON document, without ```json, ``` or anything else. Do not use null JSON properties. + """) +@Timeout(value = 2, unit = ChronoUnit.MINUTES) +String detectAmountFraudForCustomer(); +``` + +_Note:_ You can also use fault tolerance annotations in combination with the prompt annotations. + +### Using the AI service + +Once defined, you can inject the AI service as a regular bean, and use it: + +```java +package io.quarkiverse.langchain4j.sample; + +import io.quarkus.security.Authenticated; +import jakarta.ws.rs.GET; +import jakarta.ws.rs.Path; + +@Path("/fraud") +@Authenticated +public class FraudDetectionResource { + + private final FraudDetectionAi service; + + public FraudDetectionResource(FraudDetectionAi service) { + this.service = service; + } + + @GET + @Path("/amount") + public String detectBaseOnAmount() { + return service.detectAmountFraudForCustomer(); + } +} +``` + +`FraudDetectionResource` can only be accessed by authenticated users. + +## Google Authentication + +This demo requires users to authenticate with Google. 
+All you need to do is register an application with Google by following the steps listed in the [Quarkus Google](https://quarkus.io/guides/security-openid-connect-providers#google) section. +Name your Google application `Quarkus LangChain4j AI`, and make sure an allowed callback URL is set to `http://localhost:8080/login`. +Google will generate a client id and secret; use them to set the `quarkus.oidc.client-id` and `quarkus.oidc.credentials.secret` properties. + +## Running the Demo + +To run the demo, use the following command: + +```shell +mvn quarkus:dev -Dname="Firstname Familyname" -Demail=someuser@gmail.com +``` + +Note: use double quotes around your Google account's full name when registering it. + +Then open `http://localhost:8080`, log in with Google, and follow the provided application link to check for fraud. + diff --git a/samples/secure-fraud-detection/pom.xml b/samples/secure-fraud-detection/pom.xml new file mode 100644 index 000000000..fc976c2fc --- /dev/null +++ b/samples/secure-fraud-detection/pom.xml @@ -0,0 +1,134 @@ + + + 4.0.0 + + io.quarkiverse.langchain4j + quarkus-langchain4j-sample-secure-fraud-detection + Quarkus LangChain4j - Sample - Secure Fraud Detection + 1.0-SNAPSHOT + + + 3.13.0 + true + 17 + UTF-8 + UTF-8 + quarkus-bom + io.quarkus + 3.9.4 + true + 3.2.5 + 0.15.1 + + + + + + ${quarkus.platform.group-id} + ${quarkus.platform.artifact-id} + ${quarkus.platform.version} + pom + import + + + + + + + io.quarkus + quarkus-resteasy-reactive-jackson + + + io.quarkus + quarkus-oidc + + + io.quarkiverse.langchain4j + quarkus-langchain4j-openai + ${quarkus-langchain4j.version} + + + io.quarkus + quarkus-smallrye-fault-tolerance + + + io.quarkus + quarkus-jdbc-postgresql + + + io.quarkus + quarkus-hibernate-orm-panache + + + io.quarkus + quarkus-resteasy-reactive-qute + + + + + + io.quarkus + quarkus-maven-plugin + ${quarkus.platform.version} + + + + build + + + + + + maven-compiler-plugin + ${compiler-plugin.version} + + + maven-surefire-plugin + 3.2.5 +
+ + org.jboss.logmanager.LogManager + ${maven.home} + + + + + + + + + native + + + native + + + + + + maven-failsafe-plugin + 3.2.5 + + + + integration-test + verify + + + + ${project.build.directory}/${project.build.finalName}-runner + org.jboss.logmanager.LogManager + ${maven.home} + + + + + + + + + native + + + + diff --git a/samples/secure-fraud-detection/src/main/java/io/quarkiverse/langchain4j/sample/Customer.java b/samples/secure-fraud-detection/src/main/java/io/quarkiverse/langchain4j/sample/Customer.java new file mode 100644 index 000000000..d67c3f02e --- /dev/null +++ b/samples/secure-fraud-detection/src/main/java/io/quarkiverse/langchain4j/sample/Customer.java @@ -0,0 +1,16 @@ +package io.quarkiverse.langchain4j.sample; + +import jakarta.persistence.Entity; +import jakarta.persistence.GeneratedValue; +import jakarta.persistence.Id; + +@Entity +public class Customer { + + @Id + @GeneratedValue + public Long id; + public String name; + public String email; + public int transactionLimit; +} diff --git a/samples/secure-fraud-detection/src/main/java/io/quarkiverse/langchain4j/sample/CustomerConfig.java b/samples/secure-fraud-detection/src/main/java/io/quarkiverse/langchain4j/sample/CustomerConfig.java new file mode 100644 index 000000000..ecfb6a605 --- /dev/null +++ b/samples/secure-fraud-detection/src/main/java/io/quarkiverse/langchain4j/sample/CustomerConfig.java @@ -0,0 +1,11 @@ +package io.quarkiverse.langchain4j.sample; + +import io.smallrye.config.ConfigMapping; + +@ConfigMapping(prefix = "customer") +public interface CustomerConfig { + + String name(); + + String email(); +} diff --git a/samples/secure-fraud-detection/src/main/java/io/quarkiverse/langchain4j/sample/CustomerRepository.java b/samples/secure-fraud-detection/src/main/java/io/quarkiverse/langchain4j/sample/CustomerRepository.java new file mode 100644 index 000000000..f1659ec5f --- /dev/null +++ 
b/samples/secure-fraud-detection/src/main/java/io/quarkiverse/langchain4j/sample/CustomerRepository.java @@ -0,0 +1,19 @@ +package io.quarkiverse.langchain4j.sample; + +import io.quarkus.hibernate.orm.panache.PanacheRepository; +import jakarta.enterprise.context.ApplicationScoped; + +@ApplicationScoped +public class CustomerRepository implements PanacheRepository { + + /* + * Transaction limit for the customer. + */ + public int getTransactionLimit(String customerName, String customerEmail) { + Customer customer = find("name = ?1 and email = ?2", customerName, customerEmail).firstResult(); + if (customer == null) { + throw new MissingCustomerException(); + } + return customer.transactionLimit; + } +} diff --git a/samples/secure-fraud-detection/src/main/java/io/quarkiverse/langchain4j/sample/FraudDetectionAi.java b/samples/secure-fraud-detection/src/main/java/io/quarkiverse/langchain4j/sample/FraudDetectionAi.java new file mode 100644 index 000000000..7d93ddc19 --- /dev/null +++ b/samples/secure-fraud-detection/src/main/java/io/quarkiverse/langchain4j/sample/FraudDetectionAi.java @@ -0,0 +1,33 @@ +package io.quarkiverse.langchain4j.sample; + +import java.time.temporal.ChronoUnit; + +import org.eclipse.microprofile.faulttolerance.Timeout; + +import dev.langchain4j.service.SystemMessage; +import dev.langchain4j.service.UserMessage; +import io.quarkiverse.langchain4j.RegisterAiService; + +@RegisterAiService(retrievalAugmentor = FraudDetectionRetrievalAugmentor.class) +public interface FraudDetectionAi { + + @SystemMessage(""" + You are a bank account fraud detection AI. You have to detect frauds in transactions. + """) + @UserMessage(""" + Your task is to detect whether a fraud was committed for the customer. 
+ + Answer with a **single** JSON document containing: + - the customer name in the 'customer-name' key + - the transaction limit in the 'transaction-limit' key + - the computed sum of all transactions committed during the last 15 minutes in the 'total' key + - the 'fraud' key set to true if the computed sum of all transactions is greater than the transaction limit + - the 'transactions' key containing an array of JSON objects. Each object must have transaction 'amount', 'city' and formatted 'time' keys. + - the 'explanation' key containing an explanation of your answer. + - the 'email' key containing the customer email if the fraud was detected. + + Your response must be just the raw JSON document, without ```json, ``` or anything else. Do not use null JSON properties. + """) + @Timeout(value = 2, unit = ChronoUnit.MINUTES) + String detectAmountFraudForCustomer(); +} diff --git a/samples/secure-fraud-detection/src/main/java/io/quarkiverse/langchain4j/sample/FraudDetectionContentRetriever.java b/samples/secure-fraud-detection/src/main/java/io/quarkiverse/langchain4j/sample/FraudDetectionContentRetriever.java new file mode 100644 index 000000000..3526d0bf4 --- /dev/null +++ b/samples/secure-fraud-detection/src/main/java/io/quarkiverse/langchain4j/sample/FraudDetectionContentRetriever.java @@ -0,0 +1,56 @@ +package io.quarkiverse.langchain4j.sample; + +import java.time.ZoneOffset; +import java.util.List; + +import org.eclipse.microprofile.jwt.Claims; +import org.eclipse.microprofile.jwt.JsonWebToken; +import org.jboss.logging.Logger; + +import dev.langchain4j.rag.content.Content; +import dev.langchain4j.rag.content.retriever.ContentRetriever; +import dev.langchain4j.rag.query.Query; +import io.quarkus.oidc.IdToken; +import io.quarkus.security.Authenticated; +import io.vertx.core.json.JsonArray; +import io.vertx.core.json.JsonObject; +import jakarta.enterprise.context.ApplicationScoped; +import jakarta.inject.Inject; + +@ApplicationScoped +public class 
FraudDetectionContentRetriever implements ContentRetriever { + private static final Logger log = Logger.getLogger(FraudDetectionContentRetriever.class); + + @Inject + TransactionRepository transactionRepository; + + @Inject + CustomerRepository customerRepository; + + @Inject + @IdToken + JsonWebToken idToken; + + @Override + @Authenticated + public List retrieve(Query query) { + log.infof("Use customer name %s and email %s to retrieve content", idToken.getName(), + idToken.getClaim(Claims.email)); + + int transactionLimit = customerRepository.getTransactionLimit(idToken.getName(), + idToken.getClaim(Claims.email)); + + List transactions = transactionRepository.getTransactionsForCustomer(idToken.getName(), + idToken.getClaim(Claims.email)); + + JsonArray jsonTransactions = new JsonArray(); + for (Transaction t : transactions) { + jsonTransactions.add(JsonObject.of("customer-name", t.customerName, "customer-email", t.customerEmail, + "transaction-amount", t.amount, "transaction-city", t.city, + "transaction-time-in-seconds-from-the-epoch", t.time.toEpochSecond(ZoneOffset.UTC))); + } + + JsonObject json = JsonObject.of("transaction-limit", transactionLimit, "transactions", jsonTransactions); + return List.of(Content.from(json.toString())); + } +} diff --git a/samples/secure-fraud-detection/src/main/java/io/quarkiverse/langchain4j/sample/FraudDetectionResource.java b/samples/secure-fraud-detection/src/main/java/io/quarkiverse/langchain4j/sample/FraudDetectionResource.java new file mode 100644 index 000000000..e08895659 --- /dev/null +++ b/samples/secure-fraud-detection/src/main/java/io/quarkiverse/langchain4j/sample/FraudDetectionResource.java @@ -0,0 +1,26 @@ +package io.quarkiverse.langchain4j.sample; + +import io.quarkus.security.Authenticated; +import jakarta.ws.rs.GET; +import jakarta.ws.rs.Path; + +@Path("/fraud") +@Authenticated +public class FraudDetectionResource { + + private final FraudDetectionAi service; + + public FraudDetectionResource(FraudDetectionAi 
service) { + this.service = service; + } + + @GET + @Path("/amount") + public String detectBaseOnAmount() { + try { + return service.detectAmountFraudForCustomer(); + } catch (RuntimeException ex) { + throw (ex.getCause() instanceof MissingCustomerException) ? (MissingCustomerException) ex.getCause() : ex; + } + } +} diff --git a/samples/secure-fraud-detection/src/main/java/io/quarkiverse/langchain4j/sample/FraudDetectionRetrievalAugmentor.java b/samples/secure-fraud-detection/src/main/java/io/quarkiverse/langchain4j/sample/FraudDetectionRetrievalAugmentor.java new file mode 100644 index 000000000..d424b0df5 --- /dev/null +++ b/samples/secure-fraud-detection/src/main/java/io/quarkiverse/langchain4j/sample/FraudDetectionRetrievalAugmentor.java @@ -0,0 +1,27 @@ +package io.quarkiverse.langchain4j.sample; + +import java.util.function.Supplier; + +import org.eclipse.microprofile.context.ManagedExecutor; + +import dev.langchain4j.rag.DefaultRetrievalAugmentor; +import dev.langchain4j.rag.RetrievalAugmentor; +import jakarta.enterprise.context.ApplicationScoped; +import jakarta.inject.Inject; + +@ApplicationScoped +public class FraudDetectionRetrievalAugmentor implements Supplier { + + @Inject + FraudDetectionContentRetriever contentRetriever; + + @Inject + ManagedExecutor executor; + + @Override + public RetrievalAugmentor get() { + return DefaultRetrievalAugmentor.builder() + .executor(executor) + .contentRetriever(contentRetriever).build(); + } +} diff --git a/samples/secure-fraud-detection/src/main/java/io/quarkiverse/langchain4j/sample/LoginResource.java b/samples/secure-fraud-detection/src/main/java/io/quarkiverse/langchain4j/sample/LoginResource.java new file mode 100644 index 000000000..c09d3d8f8 --- /dev/null +++ b/samples/secure-fraud-detection/src/main/java/io/quarkiverse/langchain4j/sample/LoginResource.java @@ -0,0 +1,33 @@ +package io.quarkiverse.langchain4j.sample; + +import org.eclipse.microprofile.jwt.JsonWebToken; + +import io.quarkus.oidc.IdToken; 
+import io.quarkus.qute.Template; +import io.quarkus.qute.TemplateInstance; +import io.quarkus.security.Authenticated; +import jakarta.inject.Inject; +import jakarta.ws.rs.GET; +import jakarta.ws.rs.Path; +import jakarta.ws.rs.Produces; + +/** + * Login resource which returns a fraud detection page to the authenticated user + */ +@Path("/login") +@Authenticated +public class LoginResource { + + @Inject + @IdToken + JsonWebToken idToken; + + @Inject + Template fraudDetection; + + @GET + @Produces("text/html") + public TemplateInstance fraudDetection() { + return fraudDetection.data("name", idToken.getName()); + } +} diff --git a/samples/secure-fraud-detection/src/main/java/io/quarkiverse/langchain4j/sample/LogoutResource.java b/samples/secure-fraud-detection/src/main/java/io/quarkiverse/langchain4j/sample/LogoutResource.java new file mode 100644 index 000000000..7db2ad4c9 --- /dev/null +++ b/samples/secure-fraud-detection/src/main/java/io/quarkiverse/langchain4j/sample/LogoutResource.java @@ -0,0 +1,29 @@ +package io.quarkiverse.langchain4j.sample; + +import io.quarkus.oidc.OidcSession; +import io.quarkus.security.Authenticated; +import jakarta.inject.Inject; +import jakarta.ws.rs.GET; +import jakarta.ws.rs.Path; +import jakarta.ws.rs.core.Context; +import jakarta.ws.rs.core.Response; +import jakarta.ws.rs.core.UriInfo; + +/** + * Logout resource + */ +@Path("/logout") +@Authenticated +public class LogoutResource { + + @Inject + OidcSession session; + + @GET + public Response logout(@Context UriInfo uriInfo) { + // remove the local session cookie + session.logout().await().indefinitely(); + // redirect to the login page + return Response.seeOther(uriInfo.getBaseUriBuilder().path("login").build()).build(); + } +} diff --git a/samples/secure-fraud-detection/src/main/java/io/quarkiverse/langchain4j/sample/MissingCustomerException.java b/samples/secure-fraud-detection/src/main/java/io/quarkiverse/langchain4j/sample/MissingCustomerException.java new file mode 100644 
index 000000000..039d5e164 --- /dev/null +++ b/samples/secure-fraud-detection/src/main/java/io/quarkiverse/langchain4j/sample/MissingCustomerException.java @@ -0,0 +1,4 @@ +package io.quarkiverse.langchain4j.sample; + +public class MissingCustomerException extends RuntimeException { +} diff --git a/samples/secure-fraud-detection/src/main/java/io/quarkiverse/langchain4j/sample/MissingCustomerExceptionMapper.java b/samples/secure-fraud-detection/src/main/java/io/quarkiverse/langchain4j/sample/MissingCustomerExceptionMapper.java new file mode 100644 index 000000000..7ff0cf742 --- /dev/null +++ b/samples/secure-fraud-detection/src/main/java/io/quarkiverse/langchain4j/sample/MissingCustomerExceptionMapper.java @@ -0,0 +1,18 @@ +package io.quarkiverse.langchain4j.sample; + +import jakarta.ws.rs.core.Context; +import jakarta.ws.rs.core.Response; +import jakarta.ws.rs.core.UriInfo; +import jakarta.ws.rs.ext.ExceptionMapper; +import jakarta.ws.rs.ext.Provider; + +@Provider +public class MissingCustomerExceptionMapper implements ExceptionMapper { + @Context + UriInfo uriInfo; + + @Override + public Response toResponse(MissingCustomerException ex) { + return Response.seeOther(uriInfo.getBaseUriBuilder().path("missingCustomer").build()).build(); + } +} diff --git a/samples/secure-fraud-detection/src/main/java/io/quarkiverse/langchain4j/sample/MissingCustomerResource.java b/samples/secure-fraud-detection/src/main/java/io/quarkiverse/langchain4j/sample/MissingCustomerResource.java new file mode 100644 index 000000000..16d7c7aae --- /dev/null +++ b/samples/secure-fraud-detection/src/main/java/io/quarkiverse/langchain4j/sample/MissingCustomerResource.java @@ -0,0 +1,32 @@ +package io.quarkiverse.langchain4j.sample; + +import org.eclipse.microprofile.jwt.Claims; +import org.eclipse.microprofile.jwt.JsonWebToken; + +import io.quarkus.oidc.IdToken; +import io.quarkus.qute.Template; +import io.quarkus.qute.TemplateInstance; +import io.quarkus.security.Authenticated; +import 
jakarta.inject.Inject; +import jakarta.ws.rs.GET; +import jakarta.ws.rs.Path; +import jakarta.ws.rs.Produces; + +@Path("/missingCustomer") +@Authenticated +public class MissingCustomerResource { + + @Inject + @IdToken + JsonWebToken idToken; + + @Inject + Template missingCustomer; + + @GET + @Produces("text/html") + public TemplateInstance missingCustomer() { + return missingCustomer.data("given_name", idToken.getClaim("given_name")).data("name", idToken.getName()) + .data("email", idToken.getClaim(Claims.email)); + } +} diff --git a/samples/secure-fraud-detection/src/main/java/io/quarkiverse/langchain4j/sample/Setup.java b/samples/secure-fraud-detection/src/main/java/io/quarkiverse/langchain4j/sample/Setup.java new file mode 100644 index 000000000..b1a0c83de --- /dev/null +++ b/samples/secure-fraud-detection/src/main/java/io/quarkiverse/langchain4j/sample/Setup.java @@ -0,0 +1,51 @@ +package io.quarkiverse.langchain4j.sample; + +import java.time.LocalDateTime; +import java.util.List; +import java.util.Random; + +import io.quarkus.runtime.StartupEvent; +import jakarta.enterprise.context.ApplicationScoped; +import jakarta.enterprise.event.Observes; +import jakarta.inject.Inject; +import jakarta.transaction.Transactional; + +@ApplicationScoped +public class Setup { + + public static List CITIES = List.of("Paris", "Lyon", "Marseille", "Bordeaux", "Toulouse", "Nantes", "Brest", + "Clermont-Ferrand", "La Rochelle", "Lille", "Metz", "Strasbourg", "Nancy", "Valence", "Avignon", + "Montpellier", "Nime", "Arles", "Nice", "Cannes"); + + public static String getARandomCity() { + return CITIES.get(new Random().nextInt(CITIES.size())); + } + + @Inject + CustomerConfig config; + + @Transactional + public void init(@Observes StartupEvent ev, CustomerRepository customers, TransactionRepository transactions) { + customers.deleteAll(); + Random random = new Random(); + + var customer = new Customer(); + customer.name = config.name(); + customer.email = config.email(); + 
customer.transactionLimit = 1000; + customers.persist(customer); + + transactions.deleteAll(); // Delete all transactions + for (int i = 0; i < 5; i++) { + var transaction = new Transaction(); + transaction.customerName = customer.name; + transaction.customerEmail = customer.email; + transaction.amount = random.nextInt(1000) + 1; + transaction.time = LocalDateTime.now().minusMinutes(random.nextInt(20)); + transaction.city = getARandomCity(); + transactions.persist(transaction); + } + + System.out.println("Customer: " + customer.name + " - " + customer.email); + } +} diff --git a/samples/secure-fraud-detection/src/main/java/io/quarkiverse/langchain4j/sample/Transaction.java b/samples/secure-fraud-detection/src/main/java/io/quarkiverse/langchain4j/sample/Transaction.java new file mode 100644 index 000000000..111abcd9f --- /dev/null +++ b/samples/secure-fraud-detection/src/main/java/io/quarkiverse/langchain4j/sample/Transaction.java @@ -0,0 +1,26 @@ +package io.quarkiverse.langchain4j.sample; + +import java.time.LocalDateTime; + +import jakarta.persistence.Entity; +import jakarta.persistence.GeneratedValue; +import jakarta.persistence.Id; + +@Entity +public class Transaction { + + @Id + @GeneratedValue + public Long id; + + public double amount; + + public String customerName; + + public String customerEmail; + + public String city; + + public LocalDateTime time; + +} diff --git a/samples/secure-fraud-detection/src/main/java/io/quarkiverse/langchain4j/sample/TransactionRepository.java b/samples/secure-fraud-detection/src/main/java/io/quarkiverse/langchain4j/sample/TransactionRepository.java new file mode 100644 index 000000000..397602a5e --- /dev/null +++ b/samples/secure-fraud-detection/src/main/java/io/quarkiverse/langchain4j/sample/TransactionRepository.java @@ -0,0 +1,19 @@ +package io.quarkiverse.langchain4j.sample; + +import java.time.LocalDateTime; +import java.util.List; + +import io.quarkus.hibernate.orm.panache.PanacheRepository; +import 
jakarta.enterprise.context.ApplicationScoped; + +@ApplicationScoped +public class TransactionRepository implements PanacheRepository { + + /* + * List of transactions during the last 30 minutes. + */ + public List getTransactionsForCustomer(String customerName, String customerEmail) { + return find("customerName = ?1 and customerEmail = ?2 and time > ?3", customerName, customerEmail, + LocalDateTime.now().minusMinutes(30)).list(); + } +} diff --git a/samples/secure-fraud-detection/src/main/resources/META-INF/resources/images/google.png b/samples/secure-fraud-detection/src/main/resources/META-INF/resources/images/google.png new file mode 100644 index 0000000000000000000000000000000000000000..e2503df4e1444b9873941468508ccd199936d6c6 GIT binary patch literal 1768 zcmVP)Jou1o8{pKi)?@R z$HYuUmTlRh$s9G-NtJA%umqQ(h5F#3DAYk8%);GwE*34;-b+=oUs8I%bI*70k8{57 zeCM2U@Do{ic_|ojC2*W#5gR555P0^@-@R?+01Lob1P)Dr!dR-;8{r4Iu_zZ56!86a z`xbD**68SHJ~AQ_>>H8c5eNzjf?BN>^_)&8dU|@$W9>z2Yb!39FQKcm)52~Um7KCr zuh(1TzEc)Kk(k4tuU1dkx?uhS%+e)6qtRe2=6kSo5OaJ$gO^3 z^%EmU2vP@O!}@gSbh-pGh&LVOXaTu3_W9V?-dej3q9`($Os3;_7{Tnk z>NDPb4=R-k;|Z-+i}Ov5$%%=&+hDvD)U6ML=Ox(p}y%2s<8(B{-!2y#IeZ`-MBz-{{7LMTc8M2Ae&fcxx z4Hn2Xw5vrIll8RGTG{0KD?Z2VT{)8E#AkPP;-BJ=nWc|H4#?H>*@4jSv~Q)cdkv;6 zNllaW09)lg*ngS=}QtZ{$3Q#`r~! zpk-+s3ppt&mZ?OAYC@%FbpI@VJ(YG1ipc{QQf~4|?okmL9!Vi4p5kDXCRBR;2@cE1 zwR(0pl!@LK{n`pUq>#QK5KCY1eNw4|)a={0b4&-xeT+rYnPz z1n+X<75JZn2wEw6OMz2Id5nd=CishKlcg(-yPT{lrfPLMoRV(lLgFDfE#mbNhZbFy zHe^2ipx%j{!4Jr`f9n+8>PFaPX~$7%z_F^3G#lGH+9fA&QV^aaV>|n&K3Y5nazNIj?t>D_LK=%jQ(ax1 zbJmMVnm|=xG9|8Z5Tl90nwVEak}TmspfV7*oWo@;$zpiAzYO`wA3!%X7IHy`y6{)? 
z{$hP`C5`XZBm;v5tap;mN=}|q9l2OGBSri?4v_T>27`tAHg4HS*-2KMJ8$m5-xph6 zPD)BbfGPmv30huzzp8+9z_E>4&xp|k31>`a>SpQ^H2r=3$y7{_Ck`GugsWGtl$R8h z=p~g|*g08MT_ubc=v}OfcHSEouOZzvaOd^ zWaeg;QH$(W&n@9`<68SQvHiv@SLTa@($~$irnA53k`gc1)|wL7P&{d<)%z$Qs~L?( z96Wr8Wl@eDrNyNMzVv0afT{D~f~!Eoho)1UOiN9LtUBgCQ2Iot)o-<@G2nETGnW>4 z>j(Npa5Pa)P%H|M2#+GN)I<;EHK4l2ol}yPSdNMI&JMAO-oes2mYrBxL@%GHo>(S_ zm|q0%2Nl(9kEKxAN|hmoKlgry;5aMm4m7gmc8W5lQhlire&9b-ai9neJ9jz&0000< KMNUMnLSTYO@I>?g literal 0 HcmV?d00001 diff --git a/samples/secure-fraud-detection/src/main/resources/META-INF/resources/index.html b/samples/secure-fraud-detection/src/main/resources/META-INF/resources/index.html new file mode 100644 index 000000000..6fd3636ed --- /dev/null +++ b/samples/secure-fraud-detection/src/main/resources/META-INF/resources/index.html @@ -0,0 +1,133 @@ + + + + + Secure Fraud Detection + + + + + + +
+
+
+

Login

+ + + + +
Login with Google
+ +
+
+
+ +
+ +
+ + + + diff --git a/samples/secure-fraud-detection/src/main/resources/application.properties b/samples/secure-fraud-detection/src/main/resources/application.properties new file mode 100644 index 000000000..6c50ae0d2 --- /dev/null +++ b/samples/secure-fraud-detection/src/main/resources/application.properties @@ -0,0 +1,14 @@ +quarkus.langchain4j.openai.timeout=60s +quarkus.langchain4j.openai.chat-model.temperature=0 +quarkus.langchain4j.openai.api-key=${OPENAI_API_KEY} + +quarkus.oidc.provider=google +quarkus.oidc.client-id=${GOOGLE_CLIENT_ID} +quarkus.oidc.credentials.secret=${GOOGLE_CLIENT_SECRET} +quarkus.oidc.authentication.redirect-path=/login + +quarkus.langchain4j.openai.log-requests=true +quarkus.langchain4j.openai.log-responses=true + +customer.name=${name} +customer.email=${email} diff --git a/samples/secure-fraud-detection/src/main/resources/templates/fraudDetection.html b/samples/secure-fraud-detection/src/main/resources/templates/fraudDetection.html new file mode 100644 index 000000000..080a487af --- /dev/null +++ b/samples/secure-fraud-detection/src/main/resources/templates/fraudDetection.html @@ -0,0 +1,18 @@ + + + + +Secure Fraud Detection + + +

Hello {name}, please check fraud occurrences by amount:

+ + + + +
+ Detect fraud +
+ Logout + + diff --git a/samples/secure-fraud-detection/src/main/resources/templates/missingCustomer.html b/samples/secure-fraud-detection/src/main/resources/templates/missingCustomer.html new file mode 100644 index 000000000..12b1e4a8b --- /dev/null +++ b/samples/secure-fraud-detection/src/main/resources/templates/missingCustomer.html @@ -0,0 +1,18 @@ + + + + +Missing Customer + + + {given_name}, please make sure your Google account's full name and email are correctly registered at the startup using -Dname="{name}" and -Demail={email} system properties +

+ + + + +
+ Back to the main page +
+ + From f93eccd93566eaac603f3455b0e7e9d704a5ea72 Mon Sep 17 00:00:00 2001 From: Georgios Andrianakis Date: Fri, 14 Jun 2024 12:32:45 +0300 Subject: [PATCH 03/11] Update the email me a poem sample: * Use quarkus-mailpit - This is done to make the setup simpler * Introduce the LGTM observability Dev Service - This is done so users don't have to start anything manually - For this to work, users need to use -Dobservability * Use gpt-4o - This is done because GPT-3.x can sometimes result in multiple tool invocations - Improve the prompt to ensure that no markers are added * Update README accordingly --- samples/email-a-poem/README.md | 49 +++++++++++-------- samples/email-a-poem/pom.xml | 37 +++++++++++--- .../langchain4j/sample/MyAiService.java | 2 +- .../src/main/resources/application.properties | 13 +++-- 4 files changed, 68 insertions(+), 33 deletions(-) diff --git a/samples/email-a-poem/README.md b/samples/email-a-poem/README.md index da543407e..cb0d97742 100644 --- a/samples/email-a-poem/README.md +++ b/samples/email-a-poem/README.md @@ -11,38 +11,47 @@ A prerequisite to running this example is to provide your OpenAI API key. export QUARKUS_LANGCHAIN4J_OPENAI_API_KEY= ``` -To allow the application to send emails, start a mock SMTP server container: +Then, simply run the project in Dev mode: ``` -docker run -p 8025:8025 -p 1025:1025 docker.io/mailhog/mailhog +mvn quarkus:dev ``` -Alternatively, for podman users can use the following command: +> **_NOTE:_** +> When demoing observability is desired, execute `mvn quarkus:dev -Dobservability` -``` -podman run -p 8025:8025 -p 1025:1025 docker.io/mailhog/mailhog -``` +## Using the example -Then, simply run the project in Dev mode: +Open the application at http://localhost:8080 and click `Send me an email`. -``` -mvn quarkus:dev -``` +Quarkus will use a mock mailer which simply logs the email on the terminal. 
-## Using the example +## Viewing the sent email + +Go to the DevUI and click on the Mailpit UI + +## Viewing traces + +> **_NOTE:_** +> For this to be applicable, the application has to have been started using `mvn quarkus:dev -Dobservability` -Open the UI of the mock SMTP server at http://localhost:8025. This is where any -emails sent by the robot will appear. +The application has been configured to start the LGTM stack via [Dev Service](https://quarkus.io/guides/observability-devservices-lgtm). -To have the robot write a poem and send it to `sendMeALetter@quarkus.io` (the -actual address doesn't matter, for any address it will simply appear in the -SMTP server's UI), execute: +Find the host port on which Grafana is running by executing: ``` -curl http://localhost:8080/email-me-a-poem +GRAFANA_PORT=$(docker inspect $(docker container ls -q --filter 'label=quarkus-dev-service-lgtm=quarkus') --format '{{index (index (index .NetworkSettings.Ports "3000/tcp") 0) "HostPort"}}') +echo http://localhost:$GRAFANA_PORT ``` -If you don't have curl or a similar tool, simply opening the URL in your web -browser will work too. After this is done, open the SMTP server's UI and you -will see the email with a poem about Quarkus. +Open your browser at `http://localhost:${GRAFANA_PORT}` + +When prompted to login, use `admin:admin` as the username / password combination. + +From the menu on the top left, click on `Explore`. On the page, select `Tempo` as the datasource (next to `Outline`), then go to `Query type`, select `Search` and select `quarkus-langchain4j-sample-poem` from the dropdown options of `Service Name`. +Now hit `Run query` in the top right corner. 
+ +## Viewing metrics + +Simply open the application at http://localhost:8080/q/metrics diff --git a/samples/email-a-poem/pom.xml b/samples/email-a-poem/pom.xml index 342bfeb41..d4b41e330 100644 --- a/samples/email-a-poem/pom.xml +++ b/samples/email-a-poem/pom.xml @@ -15,7 +15,7 @@ UTF-8 quarkus-bom io.quarkus - 3.8.5 + 3.12.0.CR1 true 3.2.5 0.15.1 @@ -36,7 +36,7 @@ io.quarkus - quarkus-resteasy-reactive-jackson + quarkus-rest-jackson io.quarkiverse.langchain4j @@ -47,13 +47,11 @@ io.quarkus quarkus-mailer + - io.quarkus - quarkus-opentelemetry - - - io.quarkus - quarkus-micrometer-registry-prometheus + io.quarkiverse.mailpit + quarkus-mailpit + 1.1.0 @@ -88,6 +86,29 @@ + + observability + + + observability + + + + + io.quarkus + quarkus-opentelemetry + + + io.quarkus + quarkus-micrometer-registry-prometheus + + + io.quarkus + quarkus-observability-devservices-lgtm + provided + + + native diff --git a/samples/email-a-poem/src/main/java/io/quarkiverse/langchain4j/sample/MyAiService.java b/samples/email-a-poem/src/main/java/io/quarkiverse/langchain4j/sample/MyAiService.java index 4b3d81209..de093caba 100644 --- a/samples/email-a-poem/src/main/java/io/quarkiverse/langchain4j/sample/MyAiService.java +++ b/samples/email-a-poem/src/main/java/io/quarkiverse/langchain4j/sample/MyAiService.java @@ -16,7 +16,7 @@ public interface MyAiService { */ @SystemMessage("You are a professional poet") @UserMessage(""" - Write a poem about {topic}. The poem should be {lines} lines long. + Write a single poem about {topic}. The poem should be {lines} lines long and your response should only include them poem itself, nothing else. Then send this poem by email. Your response should include the poem. 
""") String writeAPoem(String topic, int lines); diff --git a/samples/email-a-poem/src/main/resources/application.properties b/samples/email-a-poem/src/main/resources/application.properties index 7542cbb58..36ae48b59 100644 --- a/samples/email-a-poem/src/main/resources/application.properties +++ b/samples/email-a-poem/src/main/resources/application.properties @@ -2,7 +2,12 @@ quarkus.langchain4j.timeout=60s quarkus.langchain4j.log-requests=true quarkus.langchain4j.log-responses=true -quarkus.mailer.from=acme@acme.org -quarkus.mailer.port=1025 -quarkus.mailer.host=localhost -%dev.quarkus.mailer.mock=false +quarkus.langchain4j.openai.chat-model.model-name=gpt-4o + +# mailer config +quarkus.mailer.from=demoer@langchain4j.ai + +# observability config +quarkus.otel.exporter.otlp.traces.protocol=http/protobuf +%test.quarkus.otel.exporter.otlp.traces.endpoint=http://${quarkus.otel-collector.url} +%dev.quarkus.otel.exporter.otlp.traces.endpoint=http://${quarkus.otel-collector.url} From 5ad14001ca7e1158b26611aaea1e043fa39cd9c4 Mon Sep 17 00:00:00 2001 From: Georgios Andrianakis Date: Mon, 17 Jun 2024 11:25:33 +0300 Subject: [PATCH 04/11] Temporarily disable chroma tests --- .../chroma/deployment/ChromaEmbeddingStoreCDITest.java | 2 ++ .../langchain4j/chroma/deployment/ChromaEmbeddingStoreTest.java | 2 ++ 2 files changed, 4 insertions(+) diff --git a/embedding-stores/chroma/deployment/src/test/java/io/quarkiverse/langchain4j/chroma/deployment/ChromaEmbeddingStoreCDITest.java b/embedding-stores/chroma/deployment/src/test/java/io/quarkiverse/langchain4j/chroma/deployment/ChromaEmbeddingStoreCDITest.java index 4fdb07248..807cc1a45 100644 --- a/embedding-stores/chroma/deployment/src/test/java/io/quarkiverse/langchain4j/chroma/deployment/ChromaEmbeddingStoreCDITest.java +++ b/embedding-stores/chroma/deployment/src/test/java/io/quarkiverse/langchain4j/chroma/deployment/ChromaEmbeddingStoreCDITest.java @@ -5,6 +5,7 @@ import org.jboss.shrinkwrap.api.ShrinkWrap; import 
org.jboss.shrinkwrap.api.spec.JavaArchive; import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.extension.RegisterExtension; import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel; @@ -17,6 +18,7 @@ /** * Tests injecting a ChromaEmbeddingStore using CDI, configured using properties. */ +@Disabled("temporarily disabled until we figure out what's going on") class ChromaEmbeddingStoreCDITest extends EmbeddingStoreIT { @RegisterExtension diff --git a/embedding-stores/chroma/deployment/src/test/java/io/quarkiverse/langchain4j/chroma/deployment/ChromaEmbeddingStoreTest.java b/embedding-stores/chroma/deployment/src/test/java/io/quarkiverse/langchain4j/chroma/deployment/ChromaEmbeddingStoreTest.java index 23954eb5a..74a91c227 100644 --- a/embedding-stores/chroma/deployment/src/test/java/io/quarkiverse/langchain4j/chroma/deployment/ChromaEmbeddingStoreTest.java +++ b/embedding-stores/chroma/deployment/src/test/java/io/quarkiverse/langchain4j/chroma/deployment/ChromaEmbeddingStoreTest.java @@ -6,6 +6,7 @@ import org.jboss.shrinkwrap.api.ShrinkWrap; import org.jboss.shrinkwrap.api.spec.JavaArchive; import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.extension.RegisterExtension; import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel; @@ -15,6 +16,7 @@ import io.quarkus.logging.Log; import io.quarkus.test.QuarkusUnitTest; +@Disabled("temporarily disabled until we figure out what's going on") class ChromaEmbeddingStoreTest extends EmbeddingStoreIT { @RegisterExtension From e8ec408d128a1afd4cb2e9c21f8c781df6b5c8ca Mon Sep 17 00:00:00 2001 From: Jan Martiska Date: Mon, 17 Jun 2024 12:27:24 +0200 Subject: [PATCH 05/11] Implement logging for Chroma and attempt to fix the test --- .../includes/quarkus-langchain4j-chroma.adoc | 34 +++++ .../ChromaEmbeddingStoreCDITest.java | 4 +- .../deployment/ChromaEmbeddingStoreTest.java | 4 +- 
.../chroma/ChromaEmbeddingStore.java | 121 +++++++++++++++++- .../chroma/runtime/ChromaConfig.java | 15 +++ .../chroma/runtime/ChromaRecorder.java | 4 +- 6 files changed, 170 insertions(+), 12 deletions(-) diff --git a/docs/modules/ROOT/pages/includes/quarkus-langchain4j-chroma.adoc b/docs/modules/ROOT/pages/includes/quarkus-langchain4j-chroma.adoc index 2e4ad5f5f..7627c1bb1 100644 --- a/docs/modules/ROOT/pages/includes/quarkus-langchain4j-chroma.adoc +++ b/docs/modules/ROOT/pages/includes/quarkus-langchain4j-chroma.adoc @@ -157,6 +157,40 @@ endif::add-copy-button-to-env-var[] | +a| [[quarkus-langchain4j-chroma_quarkus-langchain4j-chroma-log-requests]]`link:#quarkus-langchain4j-chroma_quarkus-langchain4j-chroma-log-requests[quarkus.langchain4j.chroma.log-requests]` + + +[.description] +-- +Whether requests to Chroma should be logged + +ifdef::add-copy-button-to-env-var[] +Environment variable: env_var_with_copy_button:+++QUARKUS_LANGCHAIN4J_CHROMA_LOG_REQUESTS+++[] +endif::add-copy-button-to-env-var[] +ifndef::add-copy-button-to-env-var[] +Environment variable: `+++QUARKUS_LANGCHAIN4J_CHROMA_LOG_REQUESTS+++` +endif::add-copy-button-to-env-var[] +--|boolean +|`false` + + +a| [[quarkus-langchain4j-chroma_quarkus-langchain4j-chroma-log-responses]]`link:#quarkus-langchain4j-chroma_quarkus-langchain4j-chroma-log-responses[quarkus.langchain4j.chroma.log-responses]` + + +[.description] +-- +Whether responses from Chroma should be logged + +ifdef::add-copy-button-to-env-var[] +Environment variable: env_var_with_copy_button:+++QUARKUS_LANGCHAIN4J_CHROMA_LOG_RESPONSES+++[] +endif::add-copy-button-to-env-var[] +ifndef::add-copy-button-to-env-var[] +Environment variable: `+++QUARKUS_LANGCHAIN4J_CHROMA_LOG_RESPONSES+++` +endif::add-copy-button-to-env-var[] +--|boolean +|`false` + + a|icon:lock[title=Fixed at build time] 
[[quarkus-langchain4j-chroma_quarkus-langchain4j-chroma-devservices-container-env-container-env]]`link:#quarkus-langchain4j-chroma_quarkus-langchain4j-chroma-devservices-container-env-container-env[quarkus.langchain4j.chroma.devservices.container-env]` diff --git a/embedding-stores/chroma/deployment/src/test/java/io/quarkiverse/langchain4j/chroma/deployment/ChromaEmbeddingStoreCDITest.java b/embedding-stores/chroma/deployment/src/test/java/io/quarkiverse/langchain4j/chroma/deployment/ChromaEmbeddingStoreCDITest.java index 807cc1a45..bbb7517dc 100644 --- a/embedding-stores/chroma/deployment/src/test/java/io/quarkiverse/langchain4j/chroma/deployment/ChromaEmbeddingStoreCDITest.java +++ b/embedding-stores/chroma/deployment/src/test/java/io/quarkiverse/langchain4j/chroma/deployment/ChromaEmbeddingStoreCDITest.java @@ -5,7 +5,6 @@ import org.jboss.shrinkwrap.api.ShrinkWrap; import org.jboss.shrinkwrap.api.spec.JavaArchive; import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.extension.RegisterExtension; import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel; @@ -18,11 +17,12 @@ /** * Tests injecting a ChromaEmbeddingStore using CDI, configured using properties. 
*/ -@Disabled("temporarily disabled until we figure out what's going on") class ChromaEmbeddingStoreCDITest extends EmbeddingStoreIT { @RegisterExtension static final QuarkusUnitTest unitTest = new QuarkusUnitTest() + // .overrideConfigKey("quarkus.langchain4j.log-requests", "true") + // .overrideConfigKey("quarkus.langchain4j.log-responses", "true") .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class)); @Inject diff --git a/embedding-stores/chroma/deployment/src/test/java/io/quarkiverse/langchain4j/chroma/deployment/ChromaEmbeddingStoreTest.java b/embedding-stores/chroma/deployment/src/test/java/io/quarkiverse/langchain4j/chroma/deployment/ChromaEmbeddingStoreTest.java index 74a91c227..fcc37363a 100644 --- a/embedding-stores/chroma/deployment/src/test/java/io/quarkiverse/langchain4j/chroma/deployment/ChromaEmbeddingStoreTest.java +++ b/embedding-stores/chroma/deployment/src/test/java/io/quarkiverse/langchain4j/chroma/deployment/ChromaEmbeddingStoreTest.java @@ -6,7 +6,6 @@ import org.jboss.shrinkwrap.api.ShrinkWrap; import org.jboss.shrinkwrap.api.spec.JavaArchive; import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.extension.RegisterExtension; import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel; @@ -16,7 +15,6 @@ import io.quarkus.logging.Log; import io.quarkus.test.QuarkusUnitTest; -@Disabled("temporarily disabled until we figure out what's going on") class ChromaEmbeddingStoreTest extends EmbeddingStoreIT { @RegisterExtension @@ -47,6 +45,8 @@ protected ChromaEmbeddingStore embeddingStore() { embeddingStore = ChromaEmbeddingStore.builder() .baseUrl(chromaUrl) .collectionName(randomUUID()) + // .logRequests(true) + // .logResponses(true) .build(); } return embeddingStore; diff --git a/embedding-stores/chroma/runtime/src/main/java/io/quarkiverse/langchain4j/chroma/ChromaEmbeddingStore.java 
b/embedding-stores/chroma/runtime/src/main/java/io/quarkiverse/langchain4j/chroma/ChromaEmbeddingStore.java index d8051fa2c..57e47d1e9 100644 --- a/embedding-stores/chroma/runtime/src/main/java/io/quarkiverse/langchain4j/chroma/ChromaEmbeddingStore.java +++ b/embedding-stores/chroma/runtime/src/main/java/io/quarkiverse/langchain4j/chroma/ChromaEmbeddingStore.java @@ -4,7 +4,9 @@ import static dev.langchain4j.internal.Utils.randomUUID; import static java.time.Duration.ofSeconds; import static java.util.Collections.singletonList; +import static java.util.stream.Collectors.joining; import static java.util.stream.Collectors.toList; +import static java.util.stream.StreamSupport.stream; import java.net.URI; import java.net.URISyntaxException; @@ -17,6 +19,10 @@ import jakarta.ws.rs.WebApplicationException; +import org.jboss.logging.Logger; +import org.jboss.resteasy.reactive.client.api.ClientLogger; +import org.jboss.resteasy.reactive.client.api.LoggingScope; + import dev.langchain4j.data.document.Metadata; import dev.langchain4j.data.embedding.Embedding; import dev.langchain4j.data.segment.TextSegment; @@ -31,6 +37,9 @@ import io.quarkiverse.langchain4j.chroma.runtime.QueryResponse; import io.quarkus.arc.impl.LazyValue; import io.quarkus.rest.client.reactive.QuarkusRestClientBuilder; +import io.vertx.core.buffer.Buffer; +import io.vertx.core.http.HttpClientRequest; +import io.vertx.core.http.HttpClientResponse; /** * Represents a store for embeddings using the Chroma backend. @@ -49,11 +58,14 @@ public class ChromaEmbeddingStore implements EmbeddingStore { * @param baseUrl The base URL of the Chroma service. * @param collectionName The name of the collection in the Chroma service. If not specified, "default" will be used. * @param timeout The timeout duration for the Chroma client. If not specified, 5 seconds will be used. + * @param logRequests Whether to log requests. + * @param logResponses Whether to log responses. 
*/ - public ChromaEmbeddingStore(String baseUrl, String collectionName, Duration timeout) { + public ChromaEmbeddingStore(String baseUrl, String collectionName, Duration timeout, + boolean logRequests, boolean logResponses) { String effectiveCollectionName = getOrDefault(collectionName, "default"); - this.chromaClient = new ChromaClient(baseUrl, getOrDefault(timeout, ofSeconds(5))); + this.chromaClient = new ChromaClient(baseUrl, getOrDefault(timeout, ofSeconds(5)), logRequests, logResponses); this.collectionId = new LazyValue<>(new Supplier() { @Override @@ -79,6 +91,8 @@ public static class Builder { private String baseUrl; private String collectionName; private Duration timeout; + private boolean logRequests; + private boolean logResponses; /** * @param baseUrl The base URL of the Chroma service. @@ -107,8 +121,18 @@ public Builder timeout(Duration timeout) { return this; } + public Builder logRequests(boolean logRequests) { + this.logRequests = logRequests; + return this; + } + + public Builder logResponses(boolean logResponses) { + this.logResponses = logResponses; + return this; + } + public ChromaEmbeddingStore build() { - return new ChromaEmbeddingStore(this.baseUrl, this.collectionName, this.timeout); + return new ChromaEmbeddingStore(this.baseUrl, this.collectionName, this.timeout, logRequests, logResponses); } } @@ -237,13 +261,19 @@ private static class ChromaClient { private final ChromaCollectionsRestApi chromaApi; - ChromaClient(String baseUrl, Duration timeout) { + ChromaClient(String baseUrl, Duration timeout, boolean logRequests, boolean logResponses) { try { - chromaApi = QuarkusRestClientBuilder.newBuilder() + var builder = QuarkusRestClientBuilder.newBuilder() .baseUri(new URI(baseUrl)) .connectTimeout(timeout.toSeconds(), TimeUnit.SECONDS) - .readTimeout(timeout.toSeconds(), TimeUnit.SECONDS) - .build(ChromaCollectionsRestApi.class); + .readTimeout(timeout.toSeconds(), TimeUnit.SECONDS); + + if (logRequests || logResponses) { + 
builder.loggingScope(LoggingScope.REQUEST_RESPONSE); + builder.clientLogger(new ChromaEmbeddingStore.ChromaClientLogger(logRequests, logResponses)); + } + + chromaApi = builder.build(ChromaCollectionsRestApi.class); } catch (URISyntaxException e) { throw new RuntimeException(e); } @@ -280,7 +310,84 @@ public void deleteAllEmbeddings(String collectionId, int dimension) { if (!queryResponse.getIds().get(0).isEmpty()) { DeleteEmbeddingsRequest request = new DeleteEmbeddingsRequest(queryResponse.getIds().get(0)); List deletedIds = chromaApi.deleteEmbeddings(collectionId, request); + // TODO: why do we have to do this twice? for some reason + // embeddings sometimes remain in the db after the first delete, + // even though the response says they were deleted + chromaApi.deleteEmbeddings(collectionId, request); } } } + + static class ChromaClientLogger implements ClientLogger { + private static final Logger log = Logger.getLogger(ChromaClientLogger.class); + + private final boolean logRequests; + private final boolean logResponses; + + public ChromaClientLogger(boolean logRequests, boolean logResponses) { + this.logRequests = logRequests; + this.logResponses = logResponses; + } + + @Override + public void setBodySize(int bodySize) { + // ignore + } + + @Override + public void logRequest(HttpClientRequest request, Buffer body, boolean omitBody) { + if (!logRequests || !log.isInfoEnabled()) { + return; + } + try { + log.infof("Request:\n- method: %s\n- url: %s\n- headers: %s\n- body: %s", + request.getMethod(), + request.absoluteURI(), + inOneLine(request.headers()), + bodyToString(body)); + } catch (Exception e) { + log.warn("Failed to log request", e); + } + } + + @Override + public void logResponse(HttpClientResponse response, boolean redirect) { + if (!logResponses || !log.isInfoEnabled()) { + return; + } + response.bodyHandler(new io.vertx.core.Handler<>() { + @Override + public void handle(Buffer body) { + try { + log.infof( + "Response:\n- status code: %s\n- 
headers: %s\n- body: %s", + response.statusCode(), + inOneLine(response.headers()), + bodyToString(body)); + } catch (Exception e) { + log.warn("Failed to log response", e); + } + } + }); + } + + private String bodyToString(Buffer body) { + if (body == null) { + return ""; + } + return body.toString(); + } + + private String inOneLine(io.vertx.core.MultiMap headers) { + + return stream(headers.spliterator(), false) + .map(header -> { + String headerKey = header.getKey(); + String headerValue = header.getValue(); + return String.format("[%s: %s]", headerKey, headerValue); + }) + .collect(joining(", ")); + } + + } } diff --git a/embedding-stores/chroma/runtime/src/main/java/io/quarkiverse/langchain4j/chroma/runtime/ChromaConfig.java b/embedding-stores/chroma/runtime/src/main/java/io/quarkiverse/langchain4j/chroma/runtime/ChromaConfig.java index 8d9b1d774..bcd6b785c 100644 --- a/embedding-stores/chroma/runtime/src/main/java/io/quarkiverse/langchain4j/chroma/runtime/ChromaConfig.java +++ b/embedding-stores/chroma/runtime/src/main/java/io/quarkiverse/langchain4j/chroma/runtime/ChromaConfig.java @@ -5,6 +5,7 @@ import java.time.Duration; import java.util.Optional; +import io.quarkus.runtime.annotations.ConfigDocDefault; import io.quarkus.runtime.annotations.ConfigRoot; import io.smallrye.config.ConfigMapping; import io.smallrye.config.WithDefault; @@ -29,4 +30,18 @@ public interface ChromaConfig { */ Optional timeout(); + /** + * Whether requests to Chroma should be logged + */ + @ConfigDocDefault("false") + @WithDefault("${quarkus.langchain4j.log-requests}") + Optional logRequests(); + + /** + * Whether responses from Chroma should be logged + */ + @ConfigDocDefault("false") + @WithDefault("${quarkus.langchain4j.log-requests}") + Optional logResponses(); + } diff --git a/embedding-stores/chroma/runtime/src/main/java/io/quarkiverse/langchain4j/chroma/runtime/ChromaRecorder.java 
b/embedding-stores/chroma/runtime/src/main/java/io/quarkiverse/langchain4j/chroma/runtime/ChromaRecorder.java index 0abb5af40..18dab4368 100644 --- a/embedding-stores/chroma/runtime/src/main/java/io/quarkiverse/langchain4j/chroma/runtime/ChromaRecorder.java +++ b/embedding-stores/chroma/runtime/src/main/java/io/quarkiverse/langchain4j/chroma/runtime/ChromaRecorder.java @@ -14,7 +14,9 @@ public Supplier chromaStoreSupplier(ChromaConfig config) { public ChromaEmbeddingStore get() { return new ChromaEmbeddingStore(config.url(), config.collectionName(), - config.timeout().orElse(Duration.ofSeconds(5))); + config.timeout().orElse(Duration.ofSeconds(5)), + config.logRequests().orElse(false), + config.logResponses().orElse(false)); } }; } From 3dad8cdf986e830626bcd74da84811961dae0985 Mon Sep 17 00:00:00 2001 From: Laurent Perez Date: Mon, 17 Jun 2024 20:41:04 +0000 Subject: [PATCH 06/11] (doc) small ollama tyop --- .../langchain4j/ollama/devservices/OllamaContainer.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/model-providers/ollama/devservices/src/main/java/io/quarkiverse/langchain4j/ollama/devservices/OllamaContainer.java b/model-providers/ollama/devservices/src/main/java/io/quarkiverse/langchain4j/ollama/devservices/OllamaContainer.java index 2dc55d63e..7906d592f 100644 --- a/model-providers/ollama/devservices/src/main/java/io/quarkiverse/langchain4j/ollama/devservices/OllamaContainer.java +++ b/model-providers/ollama/devservices/src/main/java/io/quarkiverse/langchain4j/ollama/devservices/OllamaContainer.java @@ -106,7 +106,7 @@ static String getModelId(OllamaConfig config) { String modelId = ConfigProvider.getConfig().getOptionalValue("quarkus.langchain4j.ollama.chat-model.model-id", String.class).orElse(""); - // if not found search through named mailers until we find one + // if not found search through named models until we find one if ("".equals(modelId)) { // check for all configs for (String key : 
ConfigProvider.getConfig().getPropertyNames()) { From 5bd63219ca56751cff1f0cd708bcd27c552c14cc Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 17 Jun 2024 22:57:55 +0000 Subject: [PATCH 07/11] Bump org.apache.maven.plugins:maven-failsafe-plugin from 3.2.5 to 3.3.0 Bumps [org.apache.maven.plugins:maven-failsafe-plugin](https://github.com/apache/maven-surefire) from 3.2.5 to 3.3.0. - [Release notes](https://github.com/apache/maven-surefire/releases) - [Commits](https://github.com/apache/maven-surefire/compare/surefire-3.2.5...surefire-3.3.0) --- updated-dependencies: - dependency-name: org.apache.maven.plugins:maven-failsafe-plugin dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] --- samples/chatbot-easy-rag/pom.xml | 2 +- samples/chatbot/pom.xml | 2 +- samples/cli-translator/pom.xml | 2 +- samples/email-a-poem/pom.xml | 2 +- samples/fraud-detection/pom.xml | 2 +- samples/review-triage/pom.xml | 2 +- samples/secure-fraud-detection/pom.xml | 2 +- samples/sql-chatbot/pom.xml | 2 +- 8 files changed, 8 insertions(+), 8 deletions(-) diff --git a/samples/chatbot-easy-rag/pom.xml b/samples/chatbot-easy-rag/pom.xml index 28283a8bc..c03457180 100644 --- a/samples/chatbot-easy-rag/pom.xml +++ b/samples/chatbot-easy-rag/pom.xml @@ -115,7 +115,7 @@ maven-failsafe-plugin - 3.2.5 + 3.3.0 diff --git a/samples/chatbot/pom.xml b/samples/chatbot/pom.xml index 28ad53fe1..cf7e1f86d 100644 --- a/samples/chatbot/pom.xml +++ b/samples/chatbot/pom.xml @@ -115,7 +115,7 @@ maven-failsafe-plugin - 3.2.5 + 3.3.0 diff --git a/samples/cli-translator/pom.xml b/samples/cli-translator/pom.xml index ca4ec4795..e7a54001d 100644 --- a/samples/cli-translator/pom.xml +++ b/samples/cli-translator/pom.xml @@ -88,7 +88,7 @@ maven-failsafe-plugin - 3.2.5 + 3.3.0 diff --git a/samples/email-a-poem/pom.xml b/samples/email-a-poem/pom.xml index d4b41e330..d8776d18c 100644 --- 
a/samples/email-a-poem/pom.xml +++ b/samples/email-a-poem/pom.xml @@ -120,7 +120,7 @@ maven-failsafe-plugin - 3.2.5 + 3.3.0 diff --git a/samples/fraud-detection/pom.xml b/samples/fraud-detection/pom.xml index 7b82a7b20..605eddf73 100644 --- a/samples/fraud-detection/pom.xml +++ b/samples/fraud-detection/pom.xml @@ -99,7 +99,7 @@ maven-failsafe-plugin - 3.2.5 + 3.3.0 diff --git a/samples/review-triage/pom.xml b/samples/review-triage/pom.xml index 39a0a5656..999a95289 100644 --- a/samples/review-triage/pom.xml +++ b/samples/review-triage/pom.xml @@ -100,7 +100,7 @@ maven-failsafe-plugin - 3.2.5 + 3.3.0 diff --git a/samples/secure-fraud-detection/pom.xml b/samples/secure-fraud-detection/pom.xml index fc976c2fc..598e1d330 100644 --- a/samples/secure-fraud-detection/pom.xml +++ b/samples/secure-fraud-detection/pom.xml @@ -107,7 +107,7 @@ maven-failsafe-plugin - 3.2.5 + 3.3.0 diff --git a/samples/sql-chatbot/pom.xml b/samples/sql-chatbot/pom.xml index 5e192426e..915777f5e 100644 --- a/samples/sql-chatbot/pom.xml +++ b/samples/sql-chatbot/pom.xml @@ -123,7 +123,7 @@ maven-failsafe-plugin - 3.2.5 + 3.3.0 From 3ebba16e40f57234fcf6ae3f5c215c88aadbf393 Mon Sep 17 00:00:00 2001 From: Jan Martiska Date: Tue, 18 Jun 2024 09:53:23 +0200 Subject: [PATCH 08/11] Enable RAG together with streaming chat in the Dev UI --- .../src/main/resources/dev-ui/qwc-chat.js | 55 ++++++++++--------- .../runtime/devui/ChatJsonRPCService.java | 11 +++- 2 files changed, 37 insertions(+), 29 deletions(-) diff --git a/core/deployment/src/main/resources/dev-ui/qwc-chat.js b/core/deployment/src/main/resources/dev-ui/qwc-chat.js index 9f212b5b2..d87125304 100644 --- a/core/deployment/src/main/resources/dev-ui/qwc-chat.js +++ b/core/deployment/src/main/resources/dev-ui/qwc-chat.js @@ -124,10 +124,10 @@ export class QwcChat extends LitElement { this._streamingChatEnabled = this._streamingChatEnabled && !this._ragEnabled; this.render(); }}"/> -

{ - if (jsonRpcResponse.result.error) { - this._showError(jsonRpcResponse.result.error); - this._hideProgressBar(); - } else if (jsonRpcResponse.result.message) { - this._updateMessage(index, jsonRpcResponse.result.message); - this._hideProgressBar(); - } else { - msg += jsonRpcResponse.result.token; - this._updateMessage(index, msg); - } - }) - .onError((error) => { + try { + this._observer = this.jsonRpc.streamingChat({message: message, ragEnabled: this._ragEnabled}) + .onNext(jsonRpcResponse => { + if (jsonRpcResponse.result.error) { + this._showError(jsonRpcResponse.result.error); + this._hideProgressBar(); + } else if (jsonRpcResponse.result.augmentedMessage) { + // replace the last user message with the augmented message + this._updateMessage(index - 1, jsonRpcResponse.result.augmentedMessage); + } else if (jsonRpcResponse.result.message) { + this._updateMessage(index, jsonRpcResponse.result.message); + this._hideProgressBar(); + } else { + msg += jsonRpcResponse.result.token; + this._updateMessage(index, msg); + } + }) + .onError((error) => { + this._showError(error); + this._hideProgressBar(); + }); + } catch (error) { this._showError(error); this._hideProgressBar(); - }); - } catch(error) { - this._showError(error); - this._hideProgressBar(); - } + } } else { this.jsonRpc.chat({message: message, ragEnabled: this._ragEnabled}).then(jsonRpcResponse => { this._showResponse(jsonRpcResponse); @@ -229,7 +232,7 @@ export class QwcChat extends LitElement { } } - } + } _showResponse(jsonRpcResponse) { if (jsonRpcResponse.result === false) { diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/devui/ChatJsonRPCService.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/devui/ChatJsonRPCService.java index 378ee93f1..717905c13 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/devui/ChatJsonRPCService.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/devui/ChatJsonRPCService.java @@ 
-27,6 +27,7 @@ import dev.langchain4j.model.chat.StreamingChatLanguageModel; import dev.langchain4j.model.output.Response; import dev.langchain4j.model.output.TokenUsage; +import dev.langchain4j.rag.AugmentationRequest; import dev.langchain4j.rag.RetrievalAugmentor; import dev.langchain4j.rag.query.Metadata; import io.quarkiverse.langchain4j.runtime.ToolsRecorder; @@ -149,9 +150,11 @@ public Multi streamingChat(String message, boolean ragEnabled) { if (retrievalAugmentor != null && ragEnabled) { UserMessage userMessage = UserMessage.from(message); Metadata metadata = Metadata.from(userMessage, currentMemoryId.get(), memory.messages()); - memory.add(retrievalAugmentor.augment(userMessage, metadata)); + AugmentationRequest augmentationRequest = new AugmentationRequest(userMessage, metadata); + ChatMessage augmentedMessage = retrievalAugmentor.augment(augmentationRequest).chatMessage(); + memory.add(augmentedMessage); + em.emit(new JsonObject().put("augmentedMessage", augmentedMessage.text())); } else { - memory.add(new UserMessage(message)); } @@ -201,7 +204,9 @@ public ChatResultPojo chat(String message, boolean ragEnabled) { if (retrievalAugmentor != null && ragEnabled) { UserMessage userMessage = UserMessage.from(message); Metadata metadata = Metadata.from(userMessage, currentMemoryId.get(), memory.messages()); - memory.add(retrievalAugmentor.augment(userMessage, metadata)); + AugmentationRequest augmentationRequest = new AugmentationRequest(userMessage, metadata); + ChatMessage augmentedMessage = retrievalAugmentor.augment(augmentationRequest).chatMessage(); + memory.add(augmentedMessage); } else { memory.add(new UserMessage(message)); } From 5c3ac136be28c471a03ae0794bc0cea451c2ddd0 Mon Sep 17 00:00:00 2001 From: Jan Martiska Date: Tue, 18 Jun 2024 10:55:17 +0200 Subject: [PATCH 09/11] Invoke tools together with Streaming chat in the Dev UI --- .../src/main/resources/dev-ui/qwc-chat.js | 24 ++++- .../runtime/devui/ChatJsonRPCService.java | 98 +++++++++++++++---- 
2 files changed, 99 insertions(+), 23 deletions(-) diff --git a/core/deployment/src/main/resources/dev-ui/qwc-chat.js b/core/deployment/src/main/resources/dev-ui/qwc-chat.js index d87125304..5712d504d 100644 --- a/core/deployment/src/main/resources/dev-ui/qwc-chat.js +++ b/core/deployment/src/main/resources/dev-ui/qwc-chat.js @@ -191,12 +191,12 @@ export class QwcChat extends LitElement { let message = e.detail.value; if (message && message.trim().length > 0) { this._cementSystemMessage(); - this._addUserMessage(message); + var indexUserMessage = this._addUserMessage(message); this._showProgressBar(); if (this._streamingChatEnabled) { var msg = ""; - var index = this._addBotMessage(msg); + var index = null; try { this._observer = this.jsonRpc.streamingChat({message: message, ragEnabled: this._ragEnabled}) .onNext(jsonRpcResponse => { @@ -205,11 +205,25 @@ export class QwcChat extends LitElement { this._hideProgressBar(); } else if (jsonRpcResponse.result.augmentedMessage) { // replace the last user message with the augmented message - this._updateMessage(index - 1, jsonRpcResponse.result.augmentedMessage); + this._updateMessage(indexUserMessage, jsonRpcResponse.result.augmentedMessage); + } else if (jsonRpcResponse.result.toolExecutionRequest) { + var item = jsonRpcResponse.result.toolExecutionRequest; + this._addToolMessage(`Request to execute the following tool: + Request ID = ${item.id}, + tool name = ${item.name}, + arguments = ${item.arguments}`); + } else if (jsonRpcResponse.result.toolExecutionResult) { + var item = jsonRpcResponse.result.toolExecutionResult; + this._addToolMessage(`Tool execution result for request ID = ${item.id}, + tool name = ${item.toolName}, + status = ${item.text}`); } else if (jsonRpcResponse.result.message) { this._updateMessage(index, jsonRpcResponse.result.message); this._hideProgressBar(); - } else { + } else { // a new token from the stream + if(index === null) { + index = this._addBotMessage(msg); + } msg += 
jsonRpcResponse.result.token; this._updateMessage(index, msg); } @@ -321,7 +335,7 @@ status = ${item.toolExecutionResult.text}`); } _addUserMessage(message){ - this._addMessage(message, "Me", 1); + return this._addMessage(message, "Me", 1); } _addStyledMessage(message, user, colorIndex, className){ diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/devui/ChatJsonRPCService.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/devui/ChatJsonRPCService.java index 717905c13..6afe796ac 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/devui/ChatJsonRPCService.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/devui/ChatJsonRPCService.java @@ -33,6 +33,8 @@ import io.quarkiverse.langchain4j.runtime.ToolsRecorder; import io.quarkiverse.langchain4j.runtime.devui.json.ChatMessagePojo; import io.quarkiverse.langchain4j.runtime.devui.json.ChatResultPojo; +import io.quarkiverse.langchain4j.runtime.devui.json.ToolExecutionRequestPojo; +import io.quarkiverse.langchain4j.runtime.devui.json.ToolExecutionResultPojo; import io.quarkiverse.langchain4j.runtime.tool.QuarkusToolExecutor; import io.quarkiverse.langchain4j.runtime.tool.QuarkusToolExecutorFactory; import io.quarkiverse.langchain4j.runtime.tool.ToolMethodCreateInfo; @@ -40,11 +42,13 @@ import io.quarkus.arc.Arc; import io.quarkus.logging.Log; import io.smallrye.mutiny.Multi; +import io.smallrye.mutiny.subscription.MultiEmitter; import io.vertx.core.json.JsonObject; @ActivateRequestContext public class ChatJsonRPCService { + public static final int MAX_SEQUENTIAL_TOOL_EXECUTIONS = 20; private final ChatLanguageModel model; private final Optional streamingModel; @@ -147,6 +151,7 @@ public Multi streamingChat(String message, boolean ragEnabled) { return Multi.createFrom().emitter(em -> { try { + // invoke RAG is applicable if (retrievalAugmentor != null && ragEnabled) { UserMessage userMessage = UserMessage.from(message); Metadata metadata = 
Metadata.from(userMessage, currentMemoryId.get(), memory.messages()); @@ -160,25 +165,31 @@ public Multi streamingChat(String message, boolean ragEnabled) { StreamingChatLanguageModel streamingModel = this.streamingModel.orElseThrow(IllegalStateException::new); - streamingModel.generate(memory.messages(), new StreamingResponseHandler() { - @Override - public void onComplete(Response response) { - memory.add(response.content()); - String message = response.content().text(); - em.emit(new JsonObject().put("message", message)); - em.complete(); - } + // invoke tools if applicable + Response modelResponse; + if (toolSpecifications.isEmpty()) { + streamingModel.generate(memory.messages(), new StreamingResponseHandler() { + @Override + public void onComplete(Response response) { + memory.add(response.content()); + String message = response.content().text(); + em.emit(new JsonObject().put("message", message)); + em.complete(); + } - @Override - public void onNext(String token) { - em.emit(new JsonObject().put("token", token)); - } + @Override + public void onNext(String token) { + em.emit(new JsonObject().put("token", token)); + } - @Override - public void onError(Throwable error) { - em.fail(error); - } - }); + @Override + public void onError(Throwable error) { + em.fail(error); + } + }); + } else { + executeWithToolsAndStreaming(memory, em, MAX_SEQUENTIAL_TOOL_EXECUTIONS); + } } catch (Throwable t) { // restore the memory from the backup memory.clear(); @@ -231,7 +242,7 @@ public ChatResultPojo chat(String message, boolean ragEnabled) { // FIXME: this was basically copied from `dev.langchain4j.service.DefaultAiServices`, // maybe it could be extracted into a reusable piece of code - public Response executeWithTools(ChatMemory memory) { + private Response executeWithTools(ChatMemory memory) { Response response = model.generate(memory.messages(), toolSpecifications); int MAX_SEQUENTIAL_TOOL_EXECUTIONS = 20; int executionsLeft = MAX_SEQUENTIAL_TOOL_EXECUTIONS; @@ -258,4 
+269,55 @@ public Response executeWithTools(ChatMemory memory) { return Response.from(response.content(), new TokenUsage(), response.finishReason()); } + private void executeWithToolsAndStreaming(ChatMemory memory, + MultiEmitter em, + int toolExecutionsLeft) { + toolExecutionsLeft--; + if (toolExecutionsLeft == 0) { + throw new RuntimeException( + "Something is wrong, exceeded " + MAX_SEQUENTIAL_TOOL_EXECUTIONS + " sequential tool executions"); + } + int finalToolExecutionsLeft = toolExecutionsLeft; + streamingModel.get().generate(memory.messages(), toolSpecifications, new StreamingResponseHandler() { + @Override + public void onComplete(Response response) { + AiMessage aiMessage = response.content(); + memory.add(aiMessage); + if (!aiMessage.hasToolExecutionRequests()) { + em.emit(new JsonObject().put("message", aiMessage.text())); + em.complete(); + } else { + for (ToolExecutionRequest toolExecutionRequest : aiMessage.toolExecutionRequests()) { + ToolExecutor toolExecutor = toolExecutors.get(toolExecutionRequest.name()); + ToolExecutionRequestPojo toolExecutionRequestPojo = new ToolExecutionRequestPojo( + toolExecutionRequest.id(), + toolExecutionRequest.name(), + toolExecutionRequest.arguments()); + em.emit(new JsonObject().put("toolExecutionRequest", toolExecutionRequestPojo)); + String toolExecutionResult = toolExecutor.execute(toolExecutionRequest, currentMemoryId.get()); + ToolExecutionResultMessage toolExecutionResultMessage = ToolExecutionResultMessage.from( + toolExecutionRequest, + toolExecutionResult); + memory.add(toolExecutionResultMessage); + ToolExecutionResultPojo toolExecutionResultPojo = new ToolExecutionResultPojo( + toolExecutionResultMessage.id(), + toolExecutionResultMessage.toolName(), toolExecutionResultMessage.text()); + em.emit(new JsonObject().put("toolExecutionResult", toolExecutionResultPojo)); + } + executeWithToolsAndStreaming(memory, em, finalToolExecutionsLeft); + } + } + + @Override + public void onNext(String token) { + 
em.emit(new JsonObject().put("token", token)); + } + + @Override + public void onError(Throwable error) { + throw new RuntimeException(error); + } + }); + } + } From 02bd03fea5e0ffb73bedb236e8a91851c2c07d08 Mon Sep 17 00:00:00 2001 From: Jan Martiska Date: Tue, 18 Jun 2024 12:24:25 +0200 Subject: [PATCH 10/11] Avoid blocking the event loop with streaming+devui --- .../runtime/devui/ChatJsonRPCService.java | 53 ++++++++++--------- 1 file changed, 28 insertions(+), 25 deletions(-) diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/devui/ChatJsonRPCService.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/devui/ChatJsonRPCService.java index 6afe796ac..a62685a61 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/devui/ChatJsonRPCService.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/devui/ChatJsonRPCService.java @@ -42,6 +42,7 @@ import io.quarkus.arc.Arc; import io.quarkus.logging.Log; import io.smallrye.mutiny.Multi; +import io.smallrye.mutiny.infrastructure.Infrastructure; import io.smallrye.mutiny.subscription.MultiEmitter; import io.vertx.core.json.JsonObject; @@ -149,7 +150,7 @@ public Multi streamingChat(String message, boolean ragEnabled) { // removing single messages List chatMemoryBackup = memory.messages(); - return Multi.createFrom().emitter(em -> { + Multi stream = Multi.createFrom().emitter(em -> { try { // invoke RAG is applicable if (retrievalAugmentor != null && ragEnabled) { @@ -198,6 +199,8 @@ public void onError(Throwable error) { em.fail(t); } }); + // run on a worker thread because the retrieval augmentor might be blocking + return stream.runSubscriptionOn(Infrastructure.getDefaultWorkerPool()); } public ChatResultPojo chat(String message, boolean ragEnabled) { @@ -281,31 +284,31 @@ private void executeWithToolsAndStreaming(ChatMemory memory, streamingModel.get().generate(memory.messages(), toolSpecifications, new StreamingResponseHandler() { @Override 
public void onComplete(Response response) { - AiMessage aiMessage = response.content(); - memory.add(aiMessage); - if (!aiMessage.hasToolExecutionRequests()) { - em.emit(new JsonObject().put("message", aiMessage.text())); - em.complete(); - } else { - for (ToolExecutionRequest toolExecutionRequest : aiMessage.toolExecutionRequests()) { - ToolExecutor toolExecutor = toolExecutors.get(toolExecutionRequest.name()); - ToolExecutionRequestPojo toolExecutionRequestPojo = new ToolExecutionRequestPojo( - toolExecutionRequest.id(), - toolExecutionRequest.name(), - toolExecutionRequest.arguments()); - em.emit(new JsonObject().put("toolExecutionRequest", toolExecutionRequestPojo)); - String toolExecutionResult = toolExecutor.execute(toolExecutionRequest, currentMemoryId.get()); - ToolExecutionResultMessage toolExecutionResultMessage = ToolExecutionResultMessage.from( - toolExecutionRequest, - toolExecutionResult); - memory.add(toolExecutionResultMessage); - ToolExecutionResultPojo toolExecutionResultPojo = new ToolExecutionResultPojo( - toolExecutionResultMessage.id(), - toolExecutionResultMessage.toolName(), toolExecutionResultMessage.text()); - em.emit(new JsonObject().put("toolExecutionResult", toolExecutionResultPojo)); + // run on a worker thread because the tool might be blocking + Infrastructure.getDefaultExecutor().execute(() -> { + AiMessage aiMessage = response.content(); + memory.add(aiMessage); + if (!aiMessage.hasToolExecutionRequests()) { + em.emit(new JsonObject().put("message", aiMessage.text())); + em.complete(); + } else { + for (ToolExecutionRequest toolExecutionRequest : aiMessage.toolExecutionRequests()) { + ToolExecutor toolExecutor = toolExecutors.get(toolExecutionRequest.name()); + ToolExecutionRequestPojo toolExecutionRequestPojo = new ToolExecutionRequestPojo( + toolExecutionRequest.id(), toolExecutionRequest.name(), toolExecutionRequest.arguments()); + em.emit(new JsonObject().put("toolExecutionRequest", toolExecutionRequestPojo)); + String 
toolExecutionResult = toolExecutor.execute(toolExecutionRequest, currentMemoryId.get()); + ToolExecutionResultMessage toolExecutionResultMessage = ToolExecutionResultMessage + .from(toolExecutionRequest, toolExecutionResult); + memory.add(toolExecutionResultMessage); + ToolExecutionResultPojo toolExecutionResultPojo = new ToolExecutionResultPojo( + toolExecutionResultMessage.id(), toolExecutionResultMessage.toolName(), + toolExecutionResultMessage.text()); + em.emit(new JsonObject().put("toolExecutionResult", toolExecutionResultPojo)); + } + executeWithToolsAndStreaming(memory, em, finalToolExecutionsLeft); } - executeWithToolsAndStreaming(memory, em, finalToolExecutionsLeft); - } + }); } @Override From 567b97eebea910ff32f03f457a9031c4940eef07 Mon Sep 17 00:00:00 2001 From: Jan Martiska Date: Wed, 19 Jun 2024 13:25:45 +0200 Subject: [PATCH 11/11] Add a link to the workshop into our docs --- docs/modules/ROOT/pages/index.adoc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/modules/ROOT/pages/index.adoc b/docs/modules/ROOT/pages/index.adoc index ae06eb94a..c01eee948 100644 --- a/docs/modules/ROOT/pages/index.adoc +++ b/docs/modules/ROOT/pages/index.adoc @@ -20,6 +20,8 @@ image::llms-big-picture.png[width=600,align="center"] == Quick Overview +NOTE: If you're interested in a guided tutorial, there is also a https://github.com/quarkusio/quarkus-langchain4j-workshop[workshop] available that walks you through the basics of using Quarkus with LangChain4j. + To incorporate Quarkus LangChain4j into your Quarkus project, add the following Maven dependency: [source,xml,subs=attributes+]