Skip to content

Commit

Permalink
Update to Langchain4J 0.25
Browse files Browse the repository at this point in the history
Closes: #185
  • Loading branch information
geoand committed Dec 23, 2023
1 parent 88f2dd4 commit d513115
Show file tree
Hide file tree
Showing 76 changed files with 2,012 additions and 137 deletions.
8 changes: 7 additions & 1 deletion chroma/deployment/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,14 @@
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-embeddings-all-minilm-l6-v2-q</artifactId>
<version>${langchain4j.version}</version>
<version>${langchain4j-embeddings.version}</version>
<scope>test</scope>
<exclusions>
<exclusion>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-core</artifactId>
</exclusion>
</exclusions>
</dependency>
</dependencies>
<build>
Expand Down
8 changes: 7 additions & 1 deletion core/deployment/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,14 @@
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-embeddings-all-minilm-l6-v2-q</artifactId>
<version>${langchain4j.version}</version>
<version>${langchain4j-embeddings.version}</version>
<scope>test</scope>
<exclusions>
<exclusion>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-core</artifactId>
</exclusion>
</exclusions>
</dependency>
</dependencies>
<build>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import static io.quarkiverse.langchain4j.deployment.Langchain4jDotNames.CHAT_MODEL;
import static io.quarkiverse.langchain4j.deployment.Langchain4jDotNames.EMBEDDING_MODEL;
import static io.quarkiverse.langchain4j.deployment.Langchain4jDotNames.IMAGE_MODEL;
import static io.quarkiverse.langchain4j.deployment.Langchain4jDotNames.MODERATION_MODEL;
import static io.quarkiverse.langchain4j.deployment.Langchain4jDotNames.STREAMING_CHAT_MODEL;

Expand Down Expand Up @@ -48,18 +49,21 @@ public void handleProviders(BeanDiscoveryFinishedBuildItem beanDiscoveryFinished
List<ChatModelProviderCandidateBuildItem> chatCandidateItems,
List<EmbeddingModelProviderCandidateBuildItem> embeddingCandidateItems,
List<ModerationModelProviderCandidateBuildItem> moderationCandidateItems,
List<ImageModelProviderCandidateBuildItem> imageCandidateItems,
List<RequestChatModelBeanBuildItem> requestChatModelBeanItems,
List<RequestModerationModelBeanBuildItem> requestModerationModelBeanBuildItems,
LangChain4jBuildConfig buildConfig,
BuildProducer<SelectedChatModelProviderBuildItem> selectedChatProducer,
BuildProducer<SelectedEmbeddingModelCandidateBuildItem> selectedEmbeddingProducer,
BuildProducer<SelectedModerationModelProviderBuildItem> selectedModerationProducer,
BuildProducer<SelectedImageModelProviderBuildItem> selectedImageProducer,
List<InProcessEmbeddingBuildItem> inProcessEmbeddingBuildItems) {

boolean chatModelBeanRequested = false;
boolean streamingChatModelBeanRequested = false;
boolean embeddingModelBeanRequested = false;
boolean moderationModelBeanRequested = false;
boolean imageModelBeanRequested = false;
for (InjectionPointInfo ip : beanDiscoveryFinished.getInjectionPoints()) {
DotName requiredName = ip.getRequiredType().name();
if (CHAT_MODEL.equals(requiredName)) {
Expand All @@ -70,6 +74,8 @@ public void handleProviders(BeanDiscoveryFinishedBuildItem beanDiscoveryFinished
embeddingModelBeanRequested = true;
} else if (MODERATION_MODEL.equals(requiredName)) {
moderationModelBeanRequested = true;
} else if (IMAGE_MODEL.equals(requiredName)) {
imageModelBeanRequested = true;
}
}
if (!requestChatModelBeanItems.isEmpty()) {
Expand Down Expand Up @@ -107,6 +113,15 @@ public void handleProviders(BeanDiscoveryFinishedBuildItem beanDiscoveryFinished
"ModerationModel",
"moderation-model")));
}
if (imageModelBeanRequested) {
selectedImageProducer.produce(
new SelectedImageModelProviderBuildItem(
selectProvider(
imageCandidateItems,
buildConfig.moderationModel().provider(),
"ImageModel",
"image-model")));
}

}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.image.ImageModel;
import dev.langchain4j.model.input.structured.StructuredPrompt;
import dev.langchain4j.model.input.structured.StructuredPromptProcessor;
import dev.langchain4j.model.moderation.ModerationModel;
Expand All @@ -28,6 +29,7 @@ public class Langchain4jDotNames {
public static final DotName STREAMING_CHAT_MODEL = DotName.createSimple(StreamingChatLanguageModel.class);
public static final DotName EMBEDDING_MODEL = DotName.createSimple(EmbeddingModel.class);
public static final DotName MODERATION_MODEL = DotName.createSimple(ModerationModel.class);
public static final DotName IMAGE_MODEL = DotName.createSimple(ImageModel.class);
static final DotName AI_SERVICES = DotName.createSimple(AiServices.class);
static final DotName CREATED_AWARE = DotName.createSimple(CreatedAware.class);
static final DotName SYSTEM_MESSAGE = DotName.createSimple(SystemMessage.class);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package io.quarkiverse.langchain4j.deployment.items;

import io.quarkus.builder.item.MultiBuildItem;

public final class ImageModelProviderCandidateBuildItem extends MultiBuildItem implements ProviderHolder {

private final String provider;

public ImageModelProviderCandidateBuildItem(String provider) {
this.provider = provider;
}

public String getProvider() {
return provider;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package io.quarkiverse.langchain4j.deployment.items;

import io.quarkus.builder.item.SimpleBuildItem;

public final class SelectedImageModelProviderBuildItem extends SimpleBuildItem {

private final String provider;

public SelectedImageModelProviderBuildItem(String provider) {
this.provider = provider;
}

public String getProvider() {
return provider;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.ChatMessageSerializer;
import dev.langchain4j.data.message.UserMessage;
import io.quarkus.test.QuarkusUnitTest;

Expand Down Expand Up @@ -93,16 +94,16 @@ void should_serialize_and_deserialize_list_with_all_types_of_messages() {
.name("calculator")
.arguments("{}")
.build()),
toolExecutionResultMessage("calculator", "4"));
toolExecutionResultMessage("12345", "calculator", "4"));

String json = messagesToJson(messages);
String json = ChatMessageSerializer.messagesToJson(messages);
assertThat(json).isEqualTo("[" +
"{\"text\":\"Hello from system\",\"type\":\"SYSTEM\"}," +
"{\"text\":\"Hello from user\",\"type\":\"USER\"}," +
"{\"name\":\"Klaus\",\"text\":\"Hello from Klaus\",\"type\":\"USER\"}," +
"{\"text\":\"Hello from AI\",\"type\":\"AI\"}," +
"{\"toolExecutionRequest\":{\"name\":\"calculator\",\"arguments\":\"{}\"},\"type\":\"AI\"}," +
"{\"toolName\":\"calculator\",\"text\":\"4\",\"type\":\"TOOL_EXECUTION_RESULT\"}" +
"{\"toolExecutionRequests\":[{\"name\":\"calculator\",\"arguments\":\"{}\"}],\"type\":\"AI\"}," +
"{\"text\":\"4\",\"id\":\"12345\",\"toolName\":\"calculator\",\"type\":\"TOOL_EXECUTION_RESULT\"}" +
"]");

List<ChatMessage> deserializedMessages = messagesFromJson(json);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@
import static org.assertj.core.data.Percentage.withPercentage;

import java.io.IOException;
import java.lang.reflect.Field;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;

import org.jboss.shrinkwrap.api.ShrinkWrap;
import org.jboss.shrinkwrap.api.spec.JavaArchive;
Expand Down Expand Up @@ -250,7 +252,7 @@ void should_serialize_to_and_deserialize_from_json() {
String json = originalEmbeddingStore.serializeToJson();
InMemoryEmbeddingStore<TextSegment> deserializedEmbeddingStore = InMemoryEmbeddingStore.fromJson(json);

assertThat(deserializedEmbeddingStore).isEqualTo(originalEmbeddingStore);
assertThat(entries(deserializedEmbeddingStore)).isEqualTo(entries(originalEmbeddingStore));
}

@Test
Expand All @@ -262,7 +264,7 @@ void should_serialize_to_and_deserialize_from_file() throws IOException {
originalEmbeddingStore.serializeToFile(filePath);
InMemoryEmbeddingStore<TextSegment> deserializedEmbeddingStore = InMemoryEmbeddingStore.fromFile(filePath);

assertThat(deserializedEmbeddingStore).isEqualTo(originalEmbeddingStore);
assertThat(entries(deserializedEmbeddingStore)).isEqualTo(entries(originalEmbeddingStore));
}

private InMemoryEmbeddingStore<TextSegment> createEmbeddingStore() {
Expand All @@ -279,4 +281,25 @@ private InMemoryEmbeddingStore<TextSegment> createEmbeddingStore() {

return embeddingStore;
}

private static final Field IN_MEMORY_EMBEDDING_STORE_FIELD;

static {
try {
Field f = InMemoryEmbeddingStore.class.getDeclaredField("entries");
f.setAccessible(true);
IN_MEMORY_EMBEDDING_STORE_FIELD = f;
} catch (NoSuchFieldException e) {
throw new RuntimeException(e);
}
}

private static CopyOnWriteArrayList<Object> entries(InMemoryEmbeddingStore<TextSegment> embeddingStore) {
try {
return (CopyOnWriteArrayList<Object>) IN_MEMORY_EMBEDDING_STORE_FIELD.get(embeddingStore);
} catch (IllegalAccessException e) {
throw new RuntimeException(e);
}
}

}
16 changes: 8 additions & 8 deletions core/runtime/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -83,49 +83,49 @@
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-embeddings-all-minilm-l6-v2-q</artifactId>
<version>${langchain4j.version}</version>
<version>${langchain4j-embeddings.version}</version>
<optional>true</optional>
</dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-embeddings-all-minilm-l6-v2</artifactId>
<version>${langchain4j.version}</version>
<version>${langchain4j-embeddings.version}</version>
<optional>true</optional>
</dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-embeddings-bge-small-en-q</artifactId>
<version>${langchain4j.version}</version>
<version>${langchain4j-embeddings.version}</version>
<optional>true</optional>
</dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-embeddings-bge-small-en</artifactId>
<version>${langchain4j.version}</version>
<version>${langchain4j-embeddings.version}</version>
<optional>true</optional>
</dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-embeddings-bge-small-zh-q</artifactId>
<version>${langchain4j.version}</version>
<version>${langchain4j-embeddings.version}</version>
<optional>true</optional>
</dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-embeddings-bge-small-zh</artifactId>
<version>${langchain4j.version}</version>
<version>${langchain4j-embeddings.version}</version>
<optional>true</optional>
</dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-embeddings-e5-small-v2-q</artifactId>
<version>${langchain4j.version}</version>
<version>${langchain4j-embeddings.version}</version>
<optional>true</optional>
</dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-embeddings-e5-small-v2</artifactId>
<version>${langchain4j.version}</version>
<version>${langchain4j-embeddings.version}</version>
<optional>true</optional>
</dependency>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@

import static dev.langchain4j.data.message.ToolExecutionResultMessage.toolExecutionResultMessage;
import static dev.langchain4j.data.message.UserMessage.userMessage;
import static dev.langchain4j.internal.Exceptions.runtime;
import static dev.langchain4j.service.AiServices.removeToolMessages;
import static dev.langchain4j.service.AiServices.verifyModerationIfNeeded;
import static dev.langchain4j.service.ServiceOutputParser.parse;
import static java.util.stream.Collectors.joining;

import java.lang.reflect.Array;
Expand Down Expand Up @@ -35,9 +37,9 @@
import dev.langchain4j.model.input.structured.StructuredPromptProcessor;
import dev.langchain4j.model.moderation.Moderation;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import dev.langchain4j.service.AiServiceContext;
import dev.langchain4j.service.AiServiceTokenStream;
import dev.langchain4j.service.ServiceOutputParser;
import dev.langchain4j.service.TokenStream;
import io.quarkiverse.langchain4j.audit.Audit;
import io.quarkiverse.langchain4j.audit.AuditService;
Expand All @@ -50,6 +52,8 @@ public class AiServiceMethodImplementationSupport {

private static final Logger log = Logger.getLogger(AiServiceMethodImplementationSupport.class);

private static final int MAX_SEQUENTIAL_TOOL_EXECUTIONS = 10;

/**
* This method is called by the implementations of each ai service method.
*/
Expand Down Expand Up @@ -141,40 +145,50 @@ private static Object doImplement(AiServiceMethodCreateInfo createInfo, Object[]
Future<Moderation> moderationFuture = triggerModerationIfNeeded(context, createInfo, messages);

log.debug("Attempting to obtain AI response");
Response<AiMessage> response = context.toolSpecifications != null
? context.chatModel.generate(messages, context.toolSpecifications)
: context.chatModel.generate(messages);
Response<AiMessage> response = context.toolSpecifications == null
? context.chatModel.generate(messages)
: context.chatModel.generate(messages, context.toolSpecifications);
log.debug("AI response obtained");
if (audit != null) {
audit.addLLMToApplicationMessage(response);
}
TokenUsage tokenUsageAccumulator = response.tokenUsage();

verifyModerationIfNeeded(moderationFuture);

ToolExecutionRequest toolExecutionRequest;
while (true) { // TODO limit number of cycles
int executionsLeft = MAX_SEQUENTIAL_TOOL_EXECUTIONS;
while (true) {

if (executionsLeft-- == 0) {
throw runtime("Something is wrong, exceeded %s sequential tool executions",
MAX_SEQUENTIAL_TOOL_EXECUTIONS);
}

AiMessage aiMessage = response.content();

if (context.hasChatMemory()) {
context.chatMemory(memoryId).add(response.content());
}

toolExecutionRequest = response.content().toolExecutionRequest();
if (toolExecutionRequest == null) {
log.debug("No tool execution request found - computation is complete");
if (!aiMessage.hasToolExecutionRequests()) {
break;
}

ToolExecutor toolExecutor = context.toolExecutors.get(toolExecutionRequest.name());
log.debugv("Attempting to execute tool {0}", toolExecutionRequest);
String toolExecutionResult = toolExecutor.execute(toolExecutionRequest, memoryId);
log.debugv("Result of {0} is '{1}'", toolExecutionRequest, toolExecutionResult);
ToolExecutionResultMessage toolExecutionResultMessage = toolExecutionResultMessage(toolExecutionRequest.name(),
toolExecutionResult);
if (audit != null) {
audit.addApplicationToLLMMessage(toolExecutionResultMessage);
}

ChatMemory chatMemory = context.chatMemory(memoryId);
chatMemory.add(toolExecutionResultMessage);

for (ToolExecutionRequest toolExecutionRequest : aiMessage.toolExecutionRequests()) {
log.debugv("Attempting to execute tool {0}", toolExecutionRequest);
ToolExecutor toolExecutor = context.toolExecutors.get(toolExecutionRequest.name());
String toolExecutionResult = toolExecutor.execute(toolExecutionRequest, memoryId);
log.debugv("Result of {0} is '{1}'", toolExecutionRequest, toolExecutionResult);
ToolExecutionResultMessage toolExecutionResultMessage = ToolExecutionResultMessage.from(
toolExecutionRequest,
toolExecutionResult);
if (audit != null) {
audit.addApplicationToLLMMessage(toolExecutionResultMessage);
}
chatMemory.add(toolExecutionResultMessage);
}

log.debug("Attempting to obtain AI response");
response = context.chatModel.generate(chatMemory.messages(), context.toolSpecifications);
Expand All @@ -183,9 +197,12 @@ private static Object doImplement(AiServiceMethodCreateInfo createInfo, Object[]
if (audit != null) {
audit.addLLMToApplicationMessage(response);
}

tokenUsageAccumulator = tokenUsageAccumulator.add(response.tokenUsage());
}

return ServiceOutputParser.parse(response, returnType);
response = Response.from(response.content(), tokenUsageAccumulator, response.finishReason());
return parse(response, returnType);
}

private static Future<Moderation> triggerModerationIfNeeded(AiServiceContext context,
Expand Down
Loading

0 comments on commit d513115

Please sign in to comment.