From bf6854bcb726e29968c0ddfa78df24a99d8fd933 Mon Sep 17 00:00:00 2001 From: raulraja Date: Tue, 11 Jun 2024 19:27:21 +0200 Subject: [PATCH] Support for creating tools with anonymous functions and passing descriptions on tool creations. Annotation based functions seem impossible in KMP, would only work on jvm target --- .../com/xebia/functional/xef/AIEvent.kt | 15 ++++ .../kotlin/com/xebia/functional/xef/Tool.kt | 77 ++++++++++++++++--- .../xef/conversation/Description.kt | 7 +- .../functional/xef/llm/ChatWithFunctions.kt | 4 +- .../functional/xef/llm/StreamedFunction.kt | 4 +- .../xef/dsl/chat/ParallelToolCalls.kt | 42 +++++----- .../dsl/chat/ParallelToolCallsAnonymous.kt | 45 +++++++++++ 7 files changed, 160 insertions(+), 34 deletions(-) create mode 100644 examples/src/main/kotlin/com/xebia/functional/xef/dsl/chat/ParallelToolCallsAnonymous.kt diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/AIEvent.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/AIEvent.kt index b8a424f87..17dd83333 100644 --- a/core/src/commonMain/kotlin/com/xebia/functional/xef/AIEvent.kt +++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/AIEvent.kt @@ -18,4 +18,19 @@ sealed class AIEvent { val totalTokens: Int, ) } + + fun debugPrint(): Unit = + when (this) { + // emoji for start is: 🚀 + Start -> println("🚀 Starting...") + is Result -> println("🎉 $value") + is ToolExecutionRequest -> + println("🔧 Executing tool: ${tool.function.name} with input: $input") + is ToolExecutionResponse -> + println("🔨 Tool response: ${tool.function.name} resulted in: $output") + is Stop -> { + println("🛑 Stopping...") + println("📊 Usage: $usage") + } + } } diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/Tool.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/Tool.kt index 31c70f480..9ca5e847d 100644 --- a/core/src/commonMain/kotlin/com/xebia/functional/xef/Tool.kt +++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/Tool.kt @@ -1,11 +1,14 @@ package com.xebia.functional.xef import com.xebia.functional.openai.generated.model.FunctionObject +import com.xebia.functional.xef.conversation.Description import com.xebia.functional.xef.llm.FunctionCall import com.xebia.functional.xef.llm.StreamedFunction import com.xebia.functional.xef.llm.chatFunction +import kotlin.jvm.JvmName import kotlin.reflect.KClass import kotlin.reflect.KFunction1 +import kotlin.reflect.KSuspendFunction1 import kotlin.reflect.KType import kotlin.reflect.typeOf import kotlinx.coroutines.flow.Flow @@ -17,11 +20,14 @@ import kotlinx.serialization.builtins.SetSerializer import kotlinx.serialization.descriptors.* import kotlinx.serialization.serializer -sealed class Tool(open val function: FunctionObject, open val invoke: (FunctionCall) -> A) { +sealed class Tool( + open val function: FunctionObject, + open val invoke: suspend (FunctionCall) -> A +) { data class Enumeration( override val function: FunctionObject, - override val invoke: (FunctionCall) -> E, + override val invoke: suspend (FunctionCall) -> E, val cases: List>, val enumSerializer: (String) -> E ) : Tool(function = function, invoke = invoke) @@ -31,17 +37,17 @@ sealed class Tool(open val function: FunctionObject, open val invoke: (Fu class FlowOfStreamedFunctions( override val function: FunctionObject, - override val invoke: (FunctionCall) -> A + override val invoke: suspend (FunctionCall) -> A ) : Tool(function = function, invoke = invoke) class FlowOfAIEvents( override val function: FunctionObject, - override val invoke: (FunctionCall) -> A + override val invoke: suspend (FunctionCall) -> A ) : Tool(function = function, invoke = invoke) data class Sealed( override val function: FunctionObject, - override val invoke: (FunctionCall) -> A, + override val invoke: suspend (FunctionCall) -> A, val cases: List, ) : Tool(function = function, invoke = invoke) { data class Case(val className: String, val tool: Tool<*>) @@ -49,17 +55,17 @@ sealed class Tool(open val function: FunctionObject, open val invoke: (Fu data class Contextual( override val function: FunctionObject, - override val invoke: (FunctionCall) -> A, + override val invoke: suspend (FunctionCall) -> A, ) : Tool(function = function, invoke = invoke) data class Callable( override val function: FunctionObject, - override val invoke: (FunctionCall) -> A, + override val invoke: suspend (FunctionCall) -> A, ) : Tool(function = function, invoke = invoke) data class Primitive( override val function: FunctionObject, - override val invoke: (FunctionCall) -> A + override val invoke: suspend (FunctionCall) -> A ) : Tool(function = function, invoke = invoke) companion object { @@ -227,10 +233,61 @@ sealed class Tool(open val function: FunctionObject, open val invoke: (Fu } } - inline fun toolOf(fn: KFunction1): Tool { + @JvmName("fromKotlinFunction1") + inline operator fun B> invoke( + name: String, + description: Description, + fn: F, + ): Tool { val tool = fromKotlin() return Callable( - function = tool.function.copy(name = fn.name, description = fn.name), + function = tool.function.copy(name = name, description = description.value), + invoke = { + val input = tool.invoke(it) + fn(input) + } + ) + } + + @JvmName("fromKotlinSuspendFunction1") + inline fun suspend( + name: String, + description: Description, + noinline fn: suspend (A) -> B, + ): Tool { + val tool = fromKotlin() + return Callable( + function = tool.function.copy(name = name, description = description.value), + invoke = { + val input = tool.invoke(it) + fn(input) + } + ) + } + + @JvmName("fromKotlinKFunction1") + inline operator fun > invoke( + fn: F, + description: Description = Description(fn.name) + ): Tool { + val tool = fromKotlin() + return Callable( + function = tool.function.copy(name = fn.name, description = description.value), + invoke = { + val input = tool.invoke(it) + fn(input) + } + ) + } + + @JvmName("fromKotlinKSuspendFunction1") + inline operator fun invoke( + fn: KSuspendFunction1, + description: Description = Description(fn.name) + ): Tool { + val tool = fromKotlin() + return Callable( + function = tool.function.copy(name = fn.name, description = description.value), invoke = { val input = tool.invoke(it) fn(input) diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/conversation/Description.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/conversation/Description.kt index 3cfe5d3fc..ba541f137 100644 --- a/core/src/commonMain/kotlin/com/xebia/functional/xef/conversation/Description.kt +++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/conversation/Description.kt @@ -9,5 +9,10 @@ import kotlinx.serialization.SerialInfo @OptIn(ExperimentalSerializationApi::class) @SerialInfo @Retention(AnnotationRetention.RUNTIME) -@Target(AnnotationTarget.CLASS, AnnotationTarget.PROPERTY, AnnotationTarget.FIELD) +@Target( + AnnotationTarget.CLASS, + AnnotationTarget.PROPERTY, + AnnotationTarget.FIELD, + AnnotationTarget.FUNCTION +) annotation class Description(val value: String) diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/ChatWithFunctions.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/ChatWithFunctions.kt index 72b7a07a8..0958449e0 100644 --- a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/ChatWithFunctions.kt +++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/ChatWithFunctions.kt @@ -100,7 +100,7 @@ private suspend fun Chat.promptWithFunctions( } else { val callRequestedMessages = listOf(assistantRequestedCallMessage(calls)) val resultMessages = callResultMessages(calls, tools, collector) - resultMessages.forEach { usageTracker.toolInvocations++ } + repeat(resultMessages.size) { usageTracker.toolInvocations++ } val promptWithToolOutputs = prompt.copy(messages = prompt.messages + callRequestedMessages + resultMessages) // recurse until the assistant decides to call the serializer @@ -256,7 +256,7 @@ fun Chat.promptStreaming( prompt: Prompt, scope: Conversation, function: FunctionObject, - serializer: (json: String) -> A, + serializer: suspend (json: String) -> A, ): Flow> = flow { val promptWithFunctions = prompt.copy(functions = listOf(function)) val adaptedPrompt = PromptCalculator.adaptPromptToConversationAndModel(promptWithFunctions, scope) diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/StreamedFunction.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/StreamedFunction.kt index cd14ebef4..6164b8a1d 100644 --- a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/StreamedFunction.kt +++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/StreamedFunction.kt @@ -43,7 +43,7 @@ sealed class StreamedFunction { prompt: Prompt, request: CreateChatCompletionRequest, scope: Conversation, - serializer: (json: String) -> A, + serializer: suspend (json: String) -> A, function: FunctionObject ) { val messages = mutableListOf() @@ -130,7 +130,7 @@ sealed class StreamedFunction { private suspend fun FlowCollector>.streamResult( functionCall: ChatCompletionMessageToolCallFunction, messages: MutableList, - serializer: (json: String) -> A + serializer: suspend (json: String) -> A ) { val arguments = functionCall.arguments messages.add(PromptBuilder.assistant("Function call: $functionCall")) diff --git a/examples/src/main/kotlin/com/xebia/functional/xef/dsl/chat/ParallelToolCalls.kt b/examples/src/main/kotlin/com/xebia/functional/xef/dsl/chat/ParallelToolCalls.kt index cfec55b3a..62a4b1ccd 100644 --- a/examples/src/main/kotlin/com/xebia/functional/xef/dsl/chat/ParallelToolCalls.kt +++ b/examples/src/main/kotlin/com/xebia/functional/xef/dsl/chat/ParallelToolCalls.kt @@ -4,15 +4,24 @@ import com.xebia.functional.xef.AI import com.xebia.functional.xef.AIConfig import com.xebia.functional.xef.AIEvent import com.xebia.functional.xef.Tool +import com.xebia.functional.xef.conversation.Description import kotlinx.coroutines.flow.Flow import kotlinx.serialization.Serializable -fun ballLocationInfo(input: String): String = "The ball is in the 47 cup." +val ballCupLocation = 47 + +suspend fun ballLocationInfoFromLastCupTried(input: Int): String { + val tip = if (input < ballCupLocation) "higher" else "lower" + val recommendedCup = + if (input < ballCupLocation) (input + 1)..ballCupLocation else ballCupLocation until input + return "The ball is not under cup number $input. Try a cup with a $tip number. We recommend trying cup ${recommendedCup.random()}, ${recommendedCup.random()}, ${recommendedCup.random()} next" +} fun lookUnderCupNumber(cupNumber: Int): String = - if (cupNumber == 47) "You found the ball and it's red and shiny." + if (cupNumber == ballCupLocation) + "You found the ball at $ballCupLocation's cup and it's red and shiny." else - "Nothing found under cup number $cupNumber. Use the ballLocationInfo tool to find which cup the ball is under." + "Nothing found under cup number $cupNumber. Use the ballLocationInfoFromLastCupTried tool to get tips as to where it may be sending the last cup number you tried to find the ball." @Serializable data class RevealedSecret(val secret: String) @@ -21,21 +30,16 @@ suspend fun main() { AI( prompt = "Where is the ball? use the available tools to find out.", config = - AIConfig(tools = listOf(Tool.toolOf(::ballLocationInfo), Tool.toolOf(::lookUnderCupNumber))) + AIConfig( + tools = + listOf( + Tool( + ::ballLocationInfoFromLastCupTried, + Description("Get a tip on where the ball is based on the last cup number tried.") + ), + Tool(::lookUnderCupNumber, Description("Look under a cup to find the ball.")) + ) + ) ) - revealedSecret.collect { - when (it) { - // emoji for start is: 🚀 - AIEvent.Start -> println("🚀 Starting...") - is AIEvent.Result -> println("🎉 ${it.value.secret}") - is AIEvent.ToolExecutionRequest -> - println("🔧 Executing tool: ${it.tool.function.name} with input: ${it.input}") - is AIEvent.ToolExecutionResponse -> - println("🔨 Tool response: ${it.tool.function.name} resulted in: ${it.output}") - is AIEvent.Stop -> { - println("🛑 Stopping...") - println("📊 Usage: ${it.usage}") - } - } - } + revealedSecret.collect { it.debugPrint() } } diff --git a/examples/src/main/kotlin/com/xebia/functional/xef/dsl/chat/ParallelToolCallsAnonymous.kt b/examples/src/main/kotlin/com/xebia/functional/xef/dsl/chat/ParallelToolCallsAnonymous.kt new file mode 100644 index 000000000..49ae1da64 --- /dev/null +++ b/examples/src/main/kotlin/com/xebia/functional/xef/dsl/chat/ParallelToolCallsAnonymous.kt @@ -0,0 +1,45 @@ +package com.xebia.functional.xef.dsl.chat + +import com.xebia.functional.xef.AI +import com.xebia.functional.xef.AIConfig +import com.xebia.functional.xef.AIEvent +import com.xebia.functional.xef.Tool +import com.xebia.functional.xef.conversation.Description +import kotlinx.coroutines.flow.Flow + +suspend fun ballLocationInfoFromLastCupTriedImpl(input: Int): String { + val tip = if (input < ballCupLocation) "higher" else "lower" + val recommendedCup = + if (input < ballCupLocation) (input + 1)..ballCupLocation else ballCupLocation until input + return "The ball is not under cup number $input. Try a cup with a $tip number. We recommend trying cup ${recommendedCup.random()}, ${recommendedCup.random()}, ${recommendedCup.random()} next" +} + +fun lookUnderCupNumberImpl(cupNumber: Int): String = + if (cupNumber == ballCupLocation) + "You found the ball at $ballCupLocation's cup and it's red and shiny." + else + "Nothing found under cup number $cupNumber. Use the ballLocationInfoFromLastCupTried tool to get tips as to where it may be sending the last cup number you tried to find the ball." + +suspend fun main() { + val revealedSecret: Flow> = + AI( + prompt = "Where is the ball? use the available tools to find out.", + config = + AIConfig( + tools = + listOf( + Tool.suspend( + "ballLocationInfoFromLastCupTried", + Description("Get a tip on where the ball is based on the last cup number tried.") + ) { lastTried: Int -> + ballLocationInfoFromLastCupTriedImpl(lastTried) + }, + Tool("lookUnderCupNumber", Description("Look under a cup to find the ball.")) { + cupNumber: Int -> + lookUnderCupNumberImpl(cupNumber) + } + ) + ) + ) + revealedSecret.collect { it.debugPrint() } +}