From a1aff238c8425a12f540c62fcc08c62a089aab73 Mon Sep 17 00:00:00 2001 From: raulraja Date: Tue, 16 Jul 2024 12:35:42 +0200 Subject: [PATCH] Allow `@Schema` on Tool requests and read `description` from annotation when available. --- .../functional/xef/conversation/Schema.kt | 11 ++++ .../functional/xef/llm/ChatWithFunctions.kt | 31 ++++++++--- .../xef/functions/FunctionSchemaTests.kt | 55 +++++++++++++++++++ 3 files changed, 90 insertions(+), 7 deletions(-) create mode 100644 core/src/commonMain/kotlin/com/xebia/functional/xef/conversation/Schema.kt create mode 100644 core/src/commonTest/kotlin/com/xebia/functional/xef/functions/FunctionSchemaTests.kt diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/conversation/Schema.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/conversation/Schema.kt new file mode 100644 index 000000000..1ed8da2af --- /dev/null +++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/conversation/Schema.kt @@ -0,0 +1,11 @@ +package com.xebia.functional.xef.conversation + +import kotlinx.serialization.ExperimentalSerializationApi +import kotlinx.serialization.SerialInfo + +/** Schema for a tool request */ +@OptIn(ExperimentalSerializationApi::class) +@SerialInfo +@Retention(AnnotationRetention.RUNTIME) +@Target(AnnotationTarget.CLASS) +annotation class Schema(val value: String) diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/ChatWithFunctions.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/ChatWithFunctions.kt index 269b23da3..bbbce296a 100644 --- a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/ChatWithFunctions.kt +++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/ChatWithFunctions.kt @@ -7,9 +7,12 @@ import com.xebia.functional.openai.generated.api.Chat import com.xebia.functional.openai.generated.model.* import com.xebia.functional.xef.AIError import com.xebia.functional.xef.AIEvent +import com.xebia.functional.xef.Config import com.xebia.functional.xef.Tool import com.xebia.functional.xef.conversation.AiDsl import com.xebia.functional.xef.conversation.Conversation +import com.xebia.functional.xef.conversation.Description +import com.xebia.functional.xef.conversation.Schema import com.xebia.functional.xef.llm.models.functions.buildJsonSchema import com.xebia.functional.xef.prompt.Prompt import com.xebia.functional.xef.prompt.PromptBuilder.Companion.tool @@ -21,17 +24,31 @@ import kotlinx.serialization.ExperimentalSerializationApi import kotlinx.serialization.descriptors.SerialDescriptor import kotlinx.serialization.json.* -@OptIn(ExperimentalSerializationApi::class) fun chatFunction(descriptor: SerialDescriptor): FunctionObject { - val fnName = descriptor.serialName.substringAfterLast(".") - return chatFunction(fnName, buildJsonSchema(descriptor)) + val functionName = functionName(descriptor) + return FunctionObject( + name = functionName, + description = functionDescription(descriptor, functionName), + parameters = functionSchema(descriptor) + ) } -fun chatFunctions(descriptors: List): List = - descriptors.map(::chatFunction) +@OptIn(ExperimentalSerializationApi::class) +fun functionSchema(descriptor: SerialDescriptor): JsonObject = + descriptor.annotations.filterIsInstance().firstOrNull()?.value?.let { + Config.DEFAULT.json.decodeFromString(JsonObject.serializer(), it) + } ?: buildJsonSchema(descriptor) -fun chatFunction(fnName: String, schema: JsonObject): FunctionObject = - FunctionObject(fnName, "Generated function for $fnName", schema) +@OptIn(ExperimentalSerializationApi::class) +fun functionDescription(descriptor: SerialDescriptor, fnName: String): String = + (descriptor.annotations.filterIsInstance().firstOrNull()?.value + ?: defaultFunctionDescription(fnName)) + +fun defaultFunctionDescription(fnName: String): String = "Generated function for $fnName" + +@OptIn(ExperimentalSerializationApi::class) +fun functionName(descriptor: SerialDescriptor): String = + descriptor.serialName.substringAfterLast(".") data class UsageTracker( var llmCalls: Int = 0, diff --git a/core/src/commonTest/kotlin/com/xebia/functional/xef/functions/FunctionSchemaTests.kt b/core/src/commonTest/kotlin/com/xebia/functional/xef/functions/FunctionSchemaTests.kt new file mode 100644 index 000000000..6f6376af0 --- /dev/null +++ b/core/src/commonTest/kotlin/com/xebia/functional/xef/functions/FunctionSchemaTests.kt @@ -0,0 +1,55 @@ +package com.xebia.functional.xef.functions + +import com.xebia.functional.xef.conversation.Description +import com.xebia.functional.xef.conversation.Schema +import com.xebia.functional.xef.llm.chatFunction +import com.xebia.functional.xef.llm.defaultFunctionDescription +import com.xebia.functional.xef.llm.functionName +import com.xebia.functional.xef.llm.models.functions.buildJsonSchema +import io.kotest.core.spec.style.StringSpec +import io.kotest.matchers.shouldBe +import kotlinx.serialization.Serializable +import kotlinx.serialization.json.JsonObject + +class FunctionSchemaTests : + StringSpec({ + "Request has default description" { + val descriptor = Request.serializer().descriptor + val function = chatFunction(descriptor) + val fnName = functionName(descriptor) + function.description shouldBe defaultFunctionDescription(fnName) + } + + "Description can be set on request" { + val descriptor = RequestWithDescription.serializer().descriptor + val function = chatFunction(descriptor) + function.description shouldBe "Request With Description" + } + + "Schema can be generated on request" { + val descriptor = Request.serializer().descriptor + val function = chatFunction(descriptor) + function.parameters shouldBe buildJsonSchema(descriptor) + } + + "Schema can be set on request" { + val descriptor = RequestWithSchema.serializer().descriptor + val function = chatFunction(descriptor) + function.parameters shouldBe JsonObject(emptyMap()) + } + }) { + + @Serializable data class Request(val input: String) + + @Serializable + @Description("Request With Description") + data class RequestWithDescription(val input: String) + + @Serializable + @Description("Request with schema") + @Schema(""" + { + } + """) + data class RequestWithSchema(val input: String) +}