Skip to content

Commit

Permalink
Merge pull request #1100 from edeandrea/migrate-to-JsonSchemaElement
Browse files Browse the repository at this point in the history
Migrate to the JsonSchemaElement API
  • Loading branch information
geoand authored Nov 25, 2024
2 parents 9ed4d03 + 37ccdf2 commit 71e0996
Show file tree
Hide file tree
Showing 14 changed files with 352 additions and 152 deletions.

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package io.quarkiverse.langchain4j.runtime.tool;

import dev.langchain4j.model.chat.request.json.JsonArraySchema;
import dev.langchain4j.model.chat.request.json.JsonSchemaElement;
import io.quarkus.runtime.ObjectSubstitution;
import io.quarkus.runtime.annotations.RecordableConstructor;

public class JsonArraySchemaObjectSubstitution
implements ObjectSubstitution<JsonArraySchema, JsonArraySchemaObjectSubstitution.Serialized> {
@Override
public Serialized serialize(JsonArraySchema obj) {
return new Serialized(obj.description(), obj.items());
}

@Override
public JsonArraySchema deserialize(Serialized obj) {
return JsonArraySchema.builder()
.description(obj.description)
.items(obj.items)
.build();
}

public record Serialized(String description, JsonSchemaElement items) {
@RecordableConstructor
public Serialized {
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package io.quarkiverse.langchain4j.runtime.tool;

import dev.langchain4j.model.chat.request.json.JsonBooleanSchema;
import io.quarkus.runtime.ObjectSubstitution;
import io.quarkus.runtime.annotations.RecordableConstructor;

public class JsonBooleanSchemaObjectSubstitution
implements ObjectSubstitution<JsonBooleanSchema, JsonBooleanSchemaObjectSubstitution.Serialized> {
@Override
public Serialized serialize(JsonBooleanSchema obj) {
return new Serialized(obj.description());
}

@Override
public JsonBooleanSchema deserialize(Serialized obj) {
return JsonBooleanSchema.builder()
.description(obj.description)
.build();
}

public record Serialized(String description) {
@RecordableConstructor
public Serialized {
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package io.quarkiverse.langchain4j.runtime.tool;

import java.util.List;

import dev.langchain4j.model.chat.request.json.JsonEnumSchema;
import io.quarkus.runtime.ObjectSubstitution;
import io.quarkus.runtime.annotations.RecordableConstructor;

public class JsonEnumSchemaObjectSubstitution
implements ObjectSubstitution<JsonEnumSchema, JsonEnumSchemaObjectSubstitution.Serialized> {
@Override
public Serialized serialize(JsonEnumSchema obj) {
return new Serialized(obj.description(), obj.enumValues());
}

@Override
public JsonEnumSchema deserialize(Serialized obj) {
return JsonEnumSchema.builder()
.description(obj.description)
.enumValues(obj.enumValues)
.build();
}

public record Serialized(String description, List<String> enumValues) {
@RecordableConstructor
public Serialized {
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package io.quarkiverse.langchain4j.runtime.tool;

import dev.langchain4j.model.chat.request.json.JsonIntegerSchema;
import io.quarkus.runtime.ObjectSubstitution;
import io.quarkus.runtime.annotations.RecordableConstructor;

public final class JsonIntegerSchemaObjectSubstitution
implements ObjectSubstitution<JsonIntegerSchema, JsonIntegerSchemaObjectSubstitution.Serialized> {
@Override
public Serialized serialize(JsonIntegerSchema obj) {
return new Serialized(obj.description());
}

@Override
public JsonIntegerSchema deserialize(Serialized obj) {
return JsonIntegerSchema.builder()
.description(obj.description)
.build();
}

public record Serialized(String description) {
@RecordableConstructor
public Serialized {
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package io.quarkiverse.langchain4j.runtime.tool;

import dev.langchain4j.model.chat.request.json.JsonNumberSchema;
import io.quarkus.runtime.ObjectSubstitution;
import io.quarkus.runtime.annotations.RecordableConstructor;

public class JsonNumberSchemaObjectSubstitution
implements ObjectSubstitution<JsonNumberSchema, JsonNumberSchemaObjectSubstitution.Serialized> {
@Override
public Serialized serialize(JsonNumberSchema obj) {
return new Serialized(obj.description());
}

@Override
public JsonNumberSchema deserialize(Serialized obj) {
return JsonNumberSchema.builder()
.description(obj.description)
.build();
}

public record Serialized(String description) {
@RecordableConstructor
public Serialized {
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package io.quarkiverse.langchain4j.runtime.tool;

import java.util.List;
import java.util.Map;

import dev.langchain4j.model.chat.request.json.JsonObjectSchema;
import dev.langchain4j.model.chat.request.json.JsonSchemaElement;
import io.quarkus.runtime.ObjectSubstitution;
import io.quarkus.runtime.annotations.RecordableConstructor;

public class JsonObjectSchemaObjectSubstitution
implements ObjectSubstitution<JsonObjectSchema, JsonObjectSchemaObjectSubstitution.Serialized> {
@Override
public Serialized serialize(JsonObjectSchema obj) {
return new Serialized(obj.description(), obj.properties(), obj.required(), obj.additionalProperties(),
obj.definitions());
}

@Override
public JsonObjectSchema deserialize(Serialized obj) {
return JsonObjectSchema.builder()
.description(obj.description)
.properties(obj.properties)
.required(obj.required)
.additionalProperties(obj.additionalProperties)
.definitions(obj.definitions)
.build();
}

public record Serialized(String description, Map<String, JsonSchemaElement> properties, List<String> required,
Boolean additionalProperties, Map<String, JsonSchemaElement> definitions) {
@RecordableConstructor
public Serialized {
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package io.quarkiverse.langchain4j.runtime.tool;

import dev.langchain4j.model.chat.request.json.JsonReferenceSchema;
import io.quarkus.runtime.ObjectSubstitution;
import io.quarkus.runtime.annotations.RecordableConstructor;

public class JsonReferenceSchemaObjectSubstitution
implements ObjectSubstitution<JsonReferenceSchema, JsonReferenceSchemaObjectSubstitution.Serialized> {
public Serialized serialize(JsonReferenceSchema obj) {
return new Serialized(obj.reference());
}

public JsonReferenceSchema deserialize(Serialized obj) {
return JsonReferenceSchema.builder()
.reference(obj.reference)
.build();
}

public record Serialized(String reference) {
@RecordableConstructor
public Serialized {
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package io.quarkiverse.langchain4j.runtime.tool;

import dev.langchain4j.model.chat.request.json.JsonStringSchema;
import io.quarkus.runtime.ObjectSubstitution;
import io.quarkus.runtime.annotations.RecordableConstructor;

public final class JsonStringSchemaObjectSubstitution
implements ObjectSubstitution<JsonStringSchema, JsonStringSchemaObjectSubstitution.Serialized> {
@Override
public Serialized serialize(JsonStringSchema obj) {
return new Serialized(obj.description());
}

@Override
public JsonStringSchema deserialize(Serialized obj) {
return JsonStringSchema.builder()
.description(obj.description)
.build();
}

public record Serialized(String description) {
@RecordableConstructor
public Serialized {
}
}
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@

import org.jboss.resteasy.reactive.RestQuery;

import dev.langchain4j.agent.tool.P;
import dev.langchain4j.agent.tool.Tool;
import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.memory.chat.ChatMemoryProvider;
import dev.langchain4j.memory.chat.MessageWindowChatMemory;
import dev.langchain4j.model.output.structured.Description;
import io.quarkiverse.langchain4j.RegisterAiService;

@Path("assistant-with-tool")
Expand All @@ -27,8 +29,12 @@ public AssistantWithToolsResource(Assistant assistant) {
this.assistant = assistant;
}

@Description("Some test data")
public static class TestData {
@Description("The foo field")
String foo;

@Description("The bar field")
Integer bar;
Double baz;

Expand All @@ -54,8 +60,8 @@ public interface Assistant {
public static class Calculator {

@Tool("Calculates the length of a string")
int stringLength(String s) {
return s.length();
int stringLength(@P(value = "The string to compute the length of", required = false) String s) {
return (s == null) ? 0 : s.length();
}

@Tool("Calculates the sum of two numbers")
Expand All @@ -80,7 +86,7 @@ public TestData evaluateTestObject(List<TestData> data) {
}

@Tool("Calculates all factors of the provided integer.")
List<Integer> getFactors(int x) {
List<Integer> getFactors(@P("The integer to get factor") int x) {
return java.util.stream.IntStream.rangeClosed(1, x)
.filter(i -> x % i == 0)
.boxed()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,12 +127,7 @@ static Tool toTool(ToolSpecification toolSpecification) {
.name(toolSpecification.name())
.description(toolSpecification.description());

if (toolSpecification.toolParameters() != null) {
for (Map.Entry<String, Map<String, Object>> p : toolSpecification.toolParameters().properties().entrySet()) {
builder.addParameter(p.getKey(), p.getValue(),
toolSpecification.toolParameters().required().contains(p.getKey()));
}
} else if (toolSpecification.parameters() != null) {
if (toolSpecification.parameters() != null) {
for (Map.Entry<String, JsonSchemaElement> p : toolSpecification.parameters().properties().entrySet()) {
builder.addParameter(p.getKey(), JsonSchemaElementHelper.toMap(p.getValue()),
toolSpecification.parameters().required().contains(p.getKey()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import com.fasterxml.jackson.core.type.TypeReference;

import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.agent.tool.ToolParameters;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
Expand Down Expand Up @@ -142,23 +141,12 @@ static List<Tool> toTools(Collection<ToolSpecification> toolSpecifications) {
}

private static Tool toTool(ToolSpecification toolSpecification) {
Tool.Function.Parameters functionParameters;
if (toolSpecification.toolParameters() != null) {
functionParameters = toFunctionParameters(toolSpecification.toolParameters());
} else {
functionParameters = toFunctionParameters(toolSpecification.parameters());
}
Tool.Function.Parameters functionParameters = toFunctionParameters(toolSpecification.parameters());

return new Tool(Tool.Type.FUNCTION, new Tool.Function(toolSpecification.name(), toolSpecification.description(),
functionParameters));
}

private static Tool.Function.Parameters toFunctionParameters(ToolParameters toolParameters) {
if (toolParameters == null) {
return Tool.Function.Parameters.empty();
}
return Tool.Function.Parameters.objectType(toolParameters.properties(), toolParameters.required());
}

private static Tool.Function.Parameters toFunctionParameters(JsonObjectSchema parameters) {
if (parameters == null) {
return Tool.Function.Parameters.empty();
Expand Down
Loading

0 comments on commit 71e0996

Please sign in to comment.