Skip to content

Commit

Permalink
Support for creating tools with anonymous functions and passing descr…
Browse files Browse the repository at this point in the history
…iptions on tool creations. Annotation based functions seem impossible in KMP, would only work on jvm target
  • Loading branch information
raulraja committed Jun 11, 2024
1 parent 9c8dcb7 commit bf6854b
Show file tree
Hide file tree
Showing 7 changed files with 160 additions and 34 deletions.
15 changes: 15 additions & 0 deletions core/src/commonMain/kotlin/com/xebia/functional/xef/AIEvent.kt
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,19 @@ sealed class AIEvent<out A> {
val totalTokens: Int,
)
}

fun debugPrint(): Unit =
when (this) {
// emoji for start is: 🚀
Start -> println("🚀 Starting...")
is Result -> println("🎉 $value")
is ToolExecutionRequest ->
println("🔧 Executing tool: ${tool.function.name} with input: $input")
is ToolExecutionResponse ->
println("🔨 Tool response: ${tool.function.name} resulted in: $output")
is Stop -> {
println("🛑 Stopping...")
println("📊 Usage: $usage")
}
}
}
77 changes: 67 additions & 10 deletions core/src/commonMain/kotlin/com/xebia/functional/xef/Tool.kt
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
package com.xebia.functional.xef

import com.xebia.functional.openai.generated.model.FunctionObject
import com.xebia.functional.xef.conversation.Description
import com.xebia.functional.xef.llm.FunctionCall
import com.xebia.functional.xef.llm.StreamedFunction
import com.xebia.functional.xef.llm.chatFunction
import kotlin.jvm.JvmName
import kotlin.reflect.KClass
import kotlin.reflect.KFunction1
import kotlin.reflect.KSuspendFunction1
import kotlin.reflect.KType
import kotlin.reflect.typeOf
import kotlinx.coroutines.flow.Flow
Expand All @@ -17,11 +20,14 @@ import kotlinx.serialization.builtins.SetSerializer
import kotlinx.serialization.descriptors.*
import kotlinx.serialization.serializer

sealed class Tool<out A>(open val function: FunctionObject, open val invoke: (FunctionCall) -> A) {
sealed class Tool<out A>(
open val function: FunctionObject,
open val invoke: suspend (FunctionCall) -> A
) {

data class Enumeration<out E>(
override val function: FunctionObject,
override val invoke: (FunctionCall) -> E,
override val invoke: suspend (FunctionCall) -> E,
val cases: List<Tool<E>>,
val enumSerializer: (String) -> E
) : Tool<E>(function = function, invoke = invoke)
Expand All @@ -31,35 +37,35 @@ sealed class Tool<out A>(open val function: FunctionObject, open val invoke: (Fu

class FlowOfStreamedFunctions<out A>(
override val function: FunctionObject,
override val invoke: (FunctionCall) -> A
override val invoke: suspend (FunctionCall) -> A
) : Tool<A>(function = function, invoke = invoke)

class FlowOfAIEvents<out A>(
override val function: FunctionObject,
override val invoke: (FunctionCall) -> A
override val invoke: suspend (FunctionCall) -> A
) : Tool<A>(function = function, invoke = invoke)

data class Sealed<A>(
override val function: FunctionObject,
override val invoke: (FunctionCall) -> A,
override val invoke: suspend (FunctionCall) -> A,
val cases: List<Case>,
) : Tool<A>(function = function, invoke = invoke) {
data class Case(val className: String, val tool: Tool<*>)
}

data class Contextual<A>(
override val function: FunctionObject,
override val invoke: (FunctionCall) -> A,
override val invoke: suspend (FunctionCall) -> A,
) : Tool<A>(function = function, invoke = invoke)

data class Callable<A>(
override val function: FunctionObject,
override val invoke: (FunctionCall) -> A,
override val invoke: suspend (FunctionCall) -> A,
) : Tool<A>(function = function, invoke = invoke)

data class Primitive<A>(
override val function: FunctionObject,
override val invoke: (FunctionCall) -> A
override val invoke: suspend (FunctionCall) -> A
) : Tool<A>(function = function, invoke = invoke)

companion object {
Expand Down Expand Up @@ -227,10 +233,61 @@ sealed class Tool<out A>(open val function: FunctionObject, open val invoke: (Fu
}
}

inline fun <reified A, B> toolOf(fn: KFunction1<A, B>): Tool<B> {
@JvmName("fromKotlinFunction1")
inline operator fun <reified A, reified B, reified F : (A) -> B> invoke(
name: String,
description: Description,
fn: F,
): Tool<B> {
val tool = fromKotlin<A>()
return Callable(
function = tool.function.copy(name = fn.name, description = fn.name),
function = tool.function.copy(name = name, description = description.value),
invoke = {
val input = tool.invoke(it)
fn(input)
}
)
}

@JvmName("fromKotlinSuspendFunction1")
inline fun <reified A, reified B> suspend(
name: String,
description: Description,
noinline fn: suspend (A) -> B,
): Tool<B> {
val tool = fromKotlin<A>()
return Callable(
function = tool.function.copy(name = name, description = description.value),
invoke = {
val input = tool.invoke(it)
fn(input)
}
)
}

@JvmName("fromKotlinKFunction1")
inline operator fun <reified A, reified B, reified F : KFunction1<A, B>> invoke(
fn: F,
description: Description = Description(fn.name)
): Tool<B> {
val tool = fromKotlin<A>()
return Callable(
function = tool.function.copy(name = fn.name, description = description.value),
invoke = {
val input = tool.invoke(it)
fn(input)
}
)
}

@JvmName("fromKotlinKSuspendFunction1")
inline operator fun <reified A, B> invoke(
fn: KSuspendFunction1<A, B>,
description: Description = Description(fn.name)
): Tool<B> {
val tool = fromKotlin<A>()
return Callable(
function = tool.function.copy(name = fn.name, description = description.value),
invoke = {
val input = tool.invoke(it)
fn(input)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,10 @@ import kotlinx.serialization.SerialInfo
@OptIn(ExperimentalSerializationApi::class)
@SerialInfo
@Retention(AnnotationRetention.RUNTIME)
@Target(AnnotationTarget.CLASS, AnnotationTarget.PROPERTY, AnnotationTarget.FIELD)
@Target(
AnnotationTarget.CLASS,
AnnotationTarget.PROPERTY,
AnnotationTarget.FIELD,
AnnotationTarget.FUNCTION
)
annotation class Description(val value: String)
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ private suspend fun <A> Chat.promptWithFunctions(
} else {
val callRequestedMessages = listOf(assistantRequestedCallMessage(calls))
val resultMessages = callResultMessages(calls, tools, collector)
resultMessages.forEach { usageTracker.toolInvocations++ }
repeat(resultMessages.size) { usageTracker.toolInvocations++ }
val promptWithToolOutputs =
prompt.copy(messages = prompt.messages + callRequestedMessages + resultMessages)
// recurse until the assistant decides to call the serializer
Expand Down Expand Up @@ -256,7 +256,7 @@ fun <A> Chat.promptStreaming(
prompt: Prompt,
scope: Conversation,
function: FunctionObject,
serializer: (json: String) -> A,
serializer: suspend (json: String) -> A,
): Flow<StreamedFunction<A>> = flow {
val promptWithFunctions = prompt.copy(functions = listOf(function))
val adaptedPrompt = PromptCalculator.adaptPromptToConversationAndModel(promptWithFunctions, scope)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ sealed class StreamedFunction<out A> {
prompt: Prompt,
request: CreateChatCompletionRequest,
scope: Conversation,
serializer: (json: String) -> A,
serializer: suspend (json: String) -> A,
function: FunctionObject
) {
val messages = mutableListOf<ChatCompletionRequestMessage>()
Expand Down Expand Up @@ -130,7 +130,7 @@ sealed class StreamedFunction<out A> {
private suspend fun <A> FlowCollector<StreamedFunction<A>>.streamResult(
functionCall: ChatCompletionMessageToolCallFunction,
messages: MutableList<ChatCompletionRequestMessage>,
serializer: (json: String) -> A
serializer: suspend (json: String) -> A
) {
val arguments = functionCall.arguments
messages.add(PromptBuilder.assistant("Function call: $functionCall"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,24 @@ import com.xebia.functional.xef.AI
import com.xebia.functional.xef.AIConfig
import com.xebia.functional.xef.AIEvent
import com.xebia.functional.xef.Tool
import com.xebia.functional.xef.conversation.Description
import kotlinx.coroutines.flow.Flow
import kotlinx.serialization.Serializable

fun ballLocationInfo(input: String): String = "The ball is in the 47 cup."
val ballCupLocation = 47

suspend fun ballLocationInfoFromLastCupTried(input: Int): String {
val tip = if (input < ballCupLocation) "higher" else "lower"
val recommendedCup =
if (input < ballCupLocation) (input + 1)..ballCupLocation else ballCupLocation until input
return "The ball is not under cup number $input. Try a cup with a $tip number. We recommend trying cup ${recommendedCup.random()}, ${recommendedCup.random()}, ${recommendedCup.random()} next"
}

fun lookUnderCupNumber(cupNumber: Int): String =
if (cupNumber == 47) "You found the ball and it's red and shiny."
if (cupNumber == ballCupLocation)
"You found the ball at $ballCupLocation's cup and it's red and shiny."
else
"Nothing found under cup number $cupNumber. Use the ballLocationInfo tool to find which cup the ball is under."
"Nothing found under cup number $cupNumber. Use the ballLocationInfoFromLastCupTried tool to get tips as to where it may be sending the last cup number you tried to find the ball."

@Serializable data class RevealedSecret(val secret: String)

Expand All @@ -21,21 +30,16 @@ suspend fun main() {
AI(
prompt = "Where is the ball? use the available tools to find out.",
config =
AIConfig(tools = listOf(Tool.toolOf(::ballLocationInfo), Tool.toolOf(::lookUnderCupNumber)))
AIConfig(
tools =
listOf(
Tool(
::ballLocationInfoFromLastCupTried,
Description("Get a tip on where the ball is based on the last cup number tried.")
),
Tool(::lookUnderCupNumber, Description("Look under a cup to find the ball."))
)
)
)
revealedSecret.collect {
when (it) {
// emoji for start is: 🚀
AIEvent.Start -> println("🚀 Starting...")
is AIEvent.Result -> println("🎉 ${it.value.secret}")
is AIEvent.ToolExecutionRequest ->
println("🔧 Executing tool: ${it.tool.function.name} with input: ${it.input}")
is AIEvent.ToolExecutionResponse ->
println("🔨 Tool response: ${it.tool.function.name} resulted in: ${it.output}")
is AIEvent.Stop -> {
println("🛑 Stopping...")
println("📊 Usage: ${it.usage}")
}
}
}
revealedSecret.collect { it.debugPrint() }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package com.xebia.functional.xef.dsl.chat

import com.xebia.functional.xef.AI
import com.xebia.functional.xef.AIConfig
import com.xebia.functional.xef.AIEvent
import com.xebia.functional.xef.Tool
import com.xebia.functional.xef.conversation.Description
import kotlinx.coroutines.flow.Flow

suspend fun ballLocationInfoFromLastCupTriedImpl(input: Int): String {
val tip = if (input < ballCupLocation) "higher" else "lower"
val recommendedCup =
if (input < ballCupLocation) (input + 1)..ballCupLocation else ballCupLocation until input
return "The ball is not under cup number $input. Try a cup with a $tip number. We recommend trying cup ${recommendedCup.random()}, ${recommendedCup.random()}, ${recommendedCup.random()} next"
}

fun lookUnderCupNumberImpl(cupNumber: Int): String =
if (cupNumber == ballCupLocation)
"You found the ball at $ballCupLocation's cup and it's red and shiny."
else
"Nothing found under cup number $cupNumber. Use the ballLocationInfoFromLastCupTried tool to get tips as to where it may be sending the last cup number you tried to find the ball."

suspend fun main() {
val revealedSecret: Flow<AIEvent<RevealedSecret>> =
AI(
prompt = "Where is the ball? use the available tools to find out.",
config =
AIConfig(
tools =
listOf(
Tool.suspend(
"ballLocationInfoFromLastCupTried",
Description("Get a tip on where the ball is based on the last cup number tried.")
) { lastTried: Int ->
ballLocationInfoFromLastCupTriedImpl(lastTried)
},
Tool("lookUnderCupNumber", Description("Look under a cup to find the ball.")) {
cupNumber: Int ->
lookUnderCupNumberImpl(cupNumber)
}
)
)
)
revealedSecret.collect { it.debugPrint() }
}

0 comments on commit bf6854b

Please sign in to comment.