Skip to content

Commit

Permalink
Merge pull request quarkiverse#1097 from jmartisk/lc4j-0.36.2
Browse files Browse the repository at this point in the history
Upgrade to LangChain4j 0.36.2
  • Loading branch information
geoand authored Nov 22, 2024
2 parents 0f77ff2 + 3168cce commit 2103f07
Show file tree
Hide file tree
Showing 16 changed files with 129 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
import io.quarkus.deployment.builditem.FeatureBuildItem;
import io.quarkus.deployment.builditem.IndexDependencyBuildItem;
import io.quarkus.deployment.builditem.ShutdownContextBuildItem;
import io.quarkus.deployment.builditem.nativeimage.RuntimeInitializedClassBuildItem;
import io.quarkus.deployment.logging.LogCleanupFilterBuildItem;
import io.quarkus.runtime.configuration.ConfigurationException;

Expand Down Expand Up @@ -608,4 +609,10 @@ void logCleanupFilters(BuildProducer<LogCleanupFilterBuildItem> logCleanupFilter
logCleanupFilters
.produce(new LogCleanupFilterBuildItem("ai.djl.huggingface.tokenizers.jni.LibUtils", Level.INFO, "Extracting"));
}

@BuildStep
public void nativeSupport(BuildProducer<RuntimeInitializedClassBuildItem> producer) {
// RetryUtils initializes a java.lang.Random instance
producer.produce(new RuntimeInitializedClassBuildItem("dev.langchain4j.internal.RetryUtils"));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;
import java.lang.reflect.Type;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
Expand All @@ -15,6 +16,7 @@
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.json.JsonReadFeature;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.JavaType;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.ObjectWriter;
import com.fasterxml.jackson.databind.PropertyNamingStrategies;
Expand Down Expand Up @@ -60,6 +62,23 @@ public <T> T fromJson(String json, Class<T> type) {
}
}

@Override
public <T> T fromJson(String json, Type type) {
JavaType javaType = ObjectMapperHolder.MAPPER.constructType(type);
try {
String sanitizedJson = sanitize(json, javaType.getRawClass());
return ObjectMapperHolder.MAPPER.readValue(sanitizedJson, javaType);
} catch (JsonProcessingException e) {
if ((e instanceof JsonParseException) && (javaType.isEnumType())) {
// this is the case where LangChain4j simply passes the string value of the enum to Json.fromJson()
// and Jackson does not handle it
Class<? extends Enum> enumClass = javaType.getRawClass().asSubclass(Enum.class);
return (T) Enum.valueOf(enumClass, json);
}
throw new UncheckedIOException(e);
}
}

private <T> String sanitize(String original, Class<T> type) {
if (String.class.equals(type)) {
return original;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
package io.quarkiverse.langchain4j.runtime.aiservice;

import static org.apache.commons.lang3.StringUtils.EMPTY;

import java.lang.reflect.Type;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -202,7 +200,7 @@ public String getUserMessageTemplate() {
Optional<String> userMessageTemplateOpt = this.getUserMessageInfo().template()
.flatMap(AiServiceMethodCreateInfo.TemplateInfo::text);

return userMessageTemplateOpt.orElse(EMPTY);
return userMessageTemplateOpt.orElse("");
}

public boolean isSwitchToWorkerThread() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import dev.langchain4j.service.AiServiceContext;
import dev.langchain4j.service.tool.ToolExecution;
import dev.langchain4j.service.tool.ToolExecutor;
import io.smallrye.mutiny.infrastructure.Infrastructure;
import io.vertx.core.Context;
Expand All @@ -38,6 +39,7 @@ public class QuarkusAiServiceStreamingResponseHandler implements StreamingRespon

private final Consumer<String> tokenHandler;
private final Consumer<Response<AiMessage>> completionHandler;
private final Consumer<ToolExecution> toolExecuteHandler;
private final Consumer<Throwable> errorHandler;

private final List<ChatMessage> temporaryMemory;
Expand All @@ -51,6 +53,7 @@ public class QuarkusAiServiceStreamingResponseHandler implements StreamingRespon
QuarkusAiServiceStreamingResponseHandler(AiServiceContext context,
Object memoryId,
Consumer<String> tokenHandler,
Consumer<ToolExecution> toolExecuteHandler,
Consumer<Response<AiMessage>> completionHandler,
Consumer<Throwable> errorHandler,
List<ChatMessage> temporaryMemory,
Expand All @@ -62,6 +65,7 @@ public class QuarkusAiServiceStreamingResponseHandler implements StreamingRespon

this.tokenHandler = ensureNotNull(tokenHandler, "tokenHandler");
this.completionHandler = completionHandler;
this.toolExecuteHandler = toolExecuteHandler;
this.errorHandler = errorHandler;

this.temporaryMemory = new ArrayList<>(temporaryMemory);
Expand Down Expand Up @@ -116,6 +120,12 @@ public void run() {
ToolExecutionResultMessage toolExecutionResultMessage = ToolExecutionResultMessage.from(
toolExecutionRequest,
toolExecutionResult);
ToolExecution toolExecution = ToolExecution.builder()
.request(toolExecutionRequest).result(toolExecutionResult)
.build();
if (toolExecuteHandler != null) {
toolExecuteHandler.accept(toolExecution);
}
QuarkusAiServiceStreamingResponseHandler.this.addToMemory(toolExecutionResultMessage);
}

Expand All @@ -126,6 +136,7 @@ public void run() {
context,
memoryId,
tokenHandler,
toolExecuteHandler,
completionHandler,
errorHandler,
temporaryMemory,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import dev.langchain4j.rag.content.Content;
import dev.langchain4j.service.AiServiceContext;
import dev.langchain4j.service.TokenStream;
import dev.langchain4j.service.tool.ToolExecution;
import dev.langchain4j.service.tool.ToolExecutor;
import io.vertx.core.Context;

Expand All @@ -44,12 +45,14 @@ public class QuarkusAiServiceTokenStream implements TokenStream {
private Consumer<List<Content>> contentsHandler;
private Consumer<Throwable> errorHandler;
private Consumer<Response<AiMessage>> completionHandler;
private Consumer<ToolExecution> toolExecuteHandler;

private int onNextInvoked;
private int onCompleteInvoked;
private int onRetrievedInvoked;
private int onErrorInvoked;
private int ignoreErrorsInvoked;
private int toolExecuteInvoked;

public QuarkusAiServiceTokenStream(List<ChatMessage> messages,
List<ToolSpecification> toolSpecifications,
Expand Down Expand Up @@ -82,6 +85,13 @@ public TokenStream onRetrieved(Consumer<List<Content>> contentsHandler) {
return this;
}

@Override
public TokenStream onToolExecuted(Consumer<ToolExecution> toolExecuteHandler) {
this.toolExecuteHandler = toolExecuteHandler;
this.toolExecuteInvoked++;
return this;
}

@Override
public TokenStream onComplete(Consumer<Response<AiMessage>> completionHandler) {
this.completionHandler = completionHandler;
Expand Down Expand Up @@ -110,6 +120,7 @@ public void start() {
context,
memoryId,
tokenHandler,
toolExecuteHandler,
completionHandler,
errorHandler,
initTemporaryMemory(context, messages),
Expand Down Expand Up @@ -150,6 +161,10 @@ private void validateConfiguration() {
throw new IllegalConfigurationException("onRetrieved must be invoked at most 1 time");
}

if (toolExecuteInvoked > 1) {
throw new IllegalConfigurationException("onToolExecuted must be invoked at most 1 time");
}

if (onErrorInvoked + ignoreErrorsInvoked != 1) {
throw new IllegalConfigurationException("One of onError or ignoreErrors must be invoked exactly 1 time");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import dev.langchain4j.agent.tool.ToolParameters;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.model.chat.request.json.JsonObjectSchema;
import io.quarkus.runtime.ObjectSubstitution;
import io.quarkus.runtime.annotations.RecordableConstructor;

Expand All @@ -10,26 +11,36 @@ public class ToolSpecificationObjectSubstitution

@Override
public Serialized serialize(ToolSpecification obj) {
return new Serialized(obj.name(), obj.description(), obj.parameters());
return new Serialized(obj.name(), obj.description(), obj.toolParameters(), obj.parameters());
}

@Override
public ToolSpecification deserialize(Serialized obj) {
return ToolSpecification.builder()
ToolSpecification.Builder builder = ToolSpecification.builder()
.name(obj.name)
.description(obj.description)
.parameters(obj.parameters).build();
.description(obj.description);
if (obj.toolParameters != null) {
builder.parameters(obj.toolParameters);
}
if (obj.parameters != null) {
builder.parameters(obj.parameters);
}
return builder.build();
}

public static class Serialized {
private final String name;
private final String description;
private final ToolParameters parameters;
private final ToolParameters toolParameters;
private final JsonObjectSchema parameters;

@RecordableConstructor
public Serialized(String name, String description, ToolParameters parameters) {
public Serialized(String name, String description,
ToolParameters toolParameters,
JsonObjectSchema parameters) {
this.name = name;
this.description = description;
this.toolParameters = toolParameters;
this.parameters = parameters;
}

Expand All @@ -41,7 +52,11 @@ public String getDescription() {
return description;
}

public ToolParameters getParameters() {
public ToolParameters getToolParameters() {
return toolParameters;
}

public JsonObjectSchema getParameters() {
return parameters;
}

Expand Down
2 changes: 1 addition & 1 deletion docs/modules/ROOT/pages/includes/attributes.adoc
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
:project-version: 0.21.0
:langchain4j-version: 0.35.0
:langchain4j-version: 0.36.2
:examples-dir: ./../examples/
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ void blocking() {
"text" : "Hello, how are you today?"
} ]
} ],
"system" : [ ],
"max_tokens" : 1024,
"stream" : false,
"top_k" : 40
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,7 @@ public void onComplete(Response<AiMessage> response) {
"text" : "Hello, how are you today?"
} ]
} ],
"system" : [ ],
"max_tokens" : 1024,
"stream" : true,
"top_k" : 40
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
import com.github.tjake.jlama.safetensors.prompt.Tool;

import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.model.chat.request.json.JsonSchemaElement;
import dev.langchain4j.model.chat.request.json.JsonSchemaElementHelper;
import dev.langchain4j.model.output.FinishReason;

/**
Expand Down Expand Up @@ -125,8 +127,16 @@ static Tool toTool(ToolSpecification toolSpecification) {
.name(toolSpecification.name())
.description(toolSpecification.description());

for (Map.Entry<String, Map<String, Object>> p : toolSpecification.parameters().properties().entrySet()) {
builder.addParameter(p.getKey(), p.getValue(), toolSpecification.parameters().required().contains(p.getKey()));
if (toolSpecification.toolParameters() != null) {
for (Map.Entry<String, Map<String, Object>> p : toolSpecification.toolParameters().properties().entrySet()) {
builder.addParameter(p.getKey(), p.getValue(),
toolSpecification.toolParameters().required().contains(p.getKey()));
}
} else if (toolSpecification.parameters() != null) {
for (Map.Entry<String, JsonSchemaElement> p : toolSpecification.parameters().properties().entrySet()) {
builder.addParameter(p.getKey(), JsonSchemaElementHelper.toMap(p.getValue()),
toolSpecification.parameters().required().contains(p.getKey()));
}
}

return Tool.from(builder.build());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@
import dev.langchain4j.model.mistralai.internal.api.MistralAiEmbeddingRequest;
import dev.langchain4j.model.mistralai.internal.api.MistralAiEmbeddingResponse;
import dev.langchain4j.model.mistralai.internal.api.MistralAiModelResponse;
import dev.langchain4j.model.mistralai.internal.api.MistralAiModerationRequest;
import dev.langchain4j.model.mistralai.internal.api.MistralAiModerationResponse;
import io.quarkiverse.langchain4j.QuarkusJsonCodecFactory;
import io.quarkus.rest.client.reactive.NotBody;
import io.smallrye.mutiny.Multi;
Expand Down Expand Up @@ -82,6 +84,10 @@ Multi<MistralAiChatCompletionResponse> streamingChatCompletion(MistralAiChatComp
@GET
MistralAiModelResponse models(@NotBody String token);

@Path("moderations")
@POST
MistralAiModerationResponse moderation(MistralAiModerationRequest mistralAiModerationRequest, @NotBody String token);

/**
* The point of this is to properly set the {@code stream} value of the request
* so users don't have to remember to set it manually
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
import dev.langchain4j.model.mistralai.internal.api.MistralAiEmbeddingRequest;
import dev.langchain4j.model.mistralai.internal.api.MistralAiEmbeddingResponse;
import dev.langchain4j.model.mistralai.internal.api.MistralAiModelResponse;
import dev.langchain4j.model.mistralai.internal.api.MistralAiModerationRequest;
import dev.langchain4j.model.mistralai.internal.api.MistralAiModerationResponse;
import dev.langchain4j.model.mistralai.internal.api.MistralAiUsage;
import dev.langchain4j.model.mistralai.internal.client.MistralAiClient;
import dev.langchain4j.model.mistralai.internal.client.MistralAiClientBuilderFactory;
Expand Down Expand Up @@ -116,6 +118,12 @@ public MistralAiEmbeddingResponse embedding(MistralAiEmbeddingRequest request) {
return restApi.embedding(request, apiKey);
}

// TODO: we don't provide support for MistralAiModerationModel yet
@Override
public MistralAiModerationResponse moderation(MistralAiModerationRequest mistralAiModerationRequest) {
return restApi.moderation(mistralAiModerationRequest, apiKey);
}

@Override
public MistralAiModelResponse listModels() {
return restApi.models(apiKey);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
import dev.langchain4j.data.message.TextContent;
import dev.langchain4j.data.message.ToolExecutionResultMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.model.chat.request.json.JsonObjectSchema;
import dev.langchain4j.model.chat.request.json.JsonSchemaElementHelper;
import io.quarkiverse.langchain4j.QuarkusJsonCodecFactory;

// TODO: this could use a lot of refactoring
Expand Down Expand Up @@ -140,8 +142,14 @@ static List<Tool> toTools(Collection<ToolSpecification> toolSpecifications) {
}

private static Tool toTool(ToolSpecification toolSpecification) {
Tool.Function.Parameters functionParameters;
if (toolSpecification.toolParameters() != null) {
functionParameters = toFunctionParameters(toolSpecification.toolParameters());
} else {
functionParameters = toFunctionParameters(toolSpecification.parameters());
}
return new Tool(Tool.Type.FUNCTION, new Tool.Function(toolSpecification.name(), toolSpecification.description(),
toFunctionParameters(toolSpecification.parameters())));
functionParameters));
}

private static Tool.Function.Parameters toFunctionParameters(ToolParameters toolParameters) {
Expand All @@ -150,4 +158,13 @@ private static Tool.Function.Parameters toFunctionParameters(ToolParameters tool
}
return Tool.Function.Parameters.objectType(toolParameters.properties(), toolParameters.required());
}

private static Tool.Function.Parameters toFunctionParameters(JsonObjectSchema parameters) {
if (parameters == null) {
return Tool.Function.Parameters.empty();
}
return Tool.Function.Parameters.objectType(JsonSchemaElementHelper.toMap(parameters.properties()),
parameters.required());
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ public static Parameters objectType(Map<String, Map<String, Object>> properties,
public static Parameters empty() {
return new Parameters(OBJECT_TYPE, Collections.emptyMap(), Collections.emptyList());
}

}
}
}
Loading

0 comments on commit 2103f07

Please sign in to comment.