draft for parallel tool calls + abstract away serialization framework so AI features can be implemented outside of Kotlin
raulraja committed Jun 11, 2024
1 parent ede4497 commit 6d70c55
Showing 27 changed files with 573 additions and 440 deletions.
276 changes: 75 additions & 201 deletions core/src/commonMain/kotlin/com/xebia/functional/xef/AI.kt
@@ -1,237 +1,111 @@
 package com.xebia.functional.xef

-import com.xebia.functional.openai.generated.api.Chat
-import com.xebia.functional.openai.generated.api.Images
-import com.xebia.functional.openai.generated.model.CreateChatCompletionRequestModel
+import com.xebia.functional.openai.generated.model.CreateChatCompletionRequest
 import com.xebia.functional.xef.conversation.AiDsl
-import com.xebia.functional.xef.conversation.Conversation
-import com.xebia.functional.xef.conversation.Description
+import com.xebia.functional.xef.llm.models.modelType
+import com.xebia.functional.xef.llm.prompt
+import com.xebia.functional.xef.llm.promptStreaming
 import com.xebia.functional.xef.prompt.Prompt
 import kotlin.coroutines.cancellation.CancellationException
-import kotlin.reflect.KClass
-import kotlin.reflect.KType
-import kotlin.reflect.typeOf
-import kotlinx.serialization.ExperimentalSerializationApi
-import kotlinx.serialization.InternalSerializationApi
-import kotlinx.serialization.KSerializer
-import kotlinx.serialization.Serializable
-import kotlinx.serialization.descriptors.SerialKind
-import kotlinx.serialization.serializer
+import kotlinx.coroutines.flow.Flow

-sealed interface AI {
+class AI<out A>(private val config: AIConfig, val serializer: Tool<A>) {

-  @Serializable
-  @Description("The selected items indexes")
-  data class SelectedItems(
-    @Description("The selected items indexes") val selectedItems: List<Int>,
-  )
+  private fun runStreamingWithStringSerializer(prompt: Prompt): Flow<String> =
+    config.api.promptStreaming(prompt, config.conversation, config.tools)

-  data class Classification(
-    val name: String,
-    val description: String,
-  )
-
-  interface PromptClassifier {
-    fun template(input: String, output: String, context: String): String
-  }
-
-  interface PromptMultipleClassifier {
-    fun getItems(): List<Classification>
-
-    fun template(input: String): String {
-      val items = getItems()
-
-      return """
-        |Based on the <input>, identify whether the user is asking about one or more of the following items
-        |
-        |${
-          items.joinToString("\n") { item -> "<${item.name}>${item.description}</${item.name}>" }
-        }
-        |
-        |<items>
-        |${
-          items.mapIndexed { index, item -> "\t<item index=\"$index\">${item.name}</item>" }
-            .joinToString("\n")
-        }
-        |</items>
-        |<input>
-        |$input
-        |</input>
-      """
-    }
+  @PublishedApi
+  internal suspend operator fun invoke(prompt: Prompt): A =
+    when (val serializer = serializer) {
+      is Tool.Class -> config.api.prompt(prompt, config.conversation, serializer, config.tools)
+      is Tool.Contextual -> config.api.prompt(prompt, config.conversation, serializer, config.tools)
+      is Tool.Enumeration<A> -> runWithEnumSingleTokenSerializer(serializer, prompt)
+      is Tool.FlowOfStreamedFunctions<*> -> {
+        config.api.promptStreaming(prompt, config.conversation, serializer, config.tools) as A
+      }
+      is Tool.FlowOfStrings -> runStreamingWithStringSerializer(prompt) as A
+      is Tool.Primitive -> config.api.prompt(prompt, config.conversation, serializer, config.tools)
+      is Tool.Sealed ->
+        config.api.prompt(prompt, config.conversation, serializer, serializer.cases, config.tools)
+    }

-    @OptIn(ExperimentalSerializationApi::class)
-    fun KType.enumValuesName(
-      serializer: KSerializer<Any?> = serializer(this)
-    ): List<Classification> {
-      return if (serializer.descriptor.kind != SerialKind.ENUM) {
-        emptyList()
-      } else {
-        (0 until serializer.descriptor.elementsCount).map { index ->
-          val name =
-            serializer.descriptor
-              .getElementName(index)
-              .removePrefix(serializer.descriptor.serialName)
-          val description =
-            (serializer.descriptor.getElementAnnotations(index).first { it is Description }
-                as Description)
-              .value
-          Classification(name, description)
-        }
-      }
-    }
-  }
+  private suspend fun runWithEnumSingleTokenSerializer(
+    serializer: Tool.Enumeration<A>,
+    prompt: Prompt
+  ): A {
+    val encoding = prompt.model.modelType(forFunctions = false).encoding
+    val cases = serializer.cases
+    val logitBias =
+      cases
+        .flatMap {
+          val result = encoding.encode(it.function.name)
+          if (result.size > 1) {
+            error("Cannot encode enum case $it into one token")
+          }
+          result
+        }
+        .associate { "$it" to 100 }
+    val result =
+      config.api.createChatCompletion(
+        CreateChatCompletionRequest(
+          messages = prompt.messages,
+          model = prompt.model,
+          logitBias = logitBias,
+          maxTokens = 1,
+          temperature = 0.0
+        )
+      )
+    val choice = result.choices[0].message.content
+    val enumSerializer = serializer.enumSerializer
+    return if (choice != null) {
+      enumSerializer(choice)
+    } else {
+      error("Cannot decode enum case from $choice")
+    }
+  }

   companion object {

-    fun <A : Any> chat(
-      target: KType,
-      model: CreateChatCompletionRequestModel,
-      api: Chat,
-      conversation: Conversation,
-      enumSerializer: ((case: String) -> A)?,
-      caseSerializers: List<KSerializer<A>>,
-      serializer: () -> KSerializer<A>,
-    ): DefaultAI<A> =
-      DefaultAI(
-        target = target,
-        model = model,
-        api = api,
-        serializer = serializer,
-        conversation = conversation,
-        enumSerializer = enumSerializer,
-        caseSerializers = caseSerializers
-      )
-
-    fun images(
-      config: Config = Config(),
-    ): Images = OpenAI(config).images
-
-    @PublishedApi
-    internal suspend inline fun <reified A : Any> invokeEnum(
-      prompt: Prompt,
-      target: KType = typeOf<A>(),
-      config: Config = Config(),
-      api: Chat = OpenAI(config).chat,
-      conversation: Conversation = Conversation()
-    ): A =
-      chat(
-        target = target,
-        model = prompt.model,
-        api = api,
-        conversation = conversation,
-        enumSerializer = { @Suppress("UPPER_BOUND_VIOLATED") enumValueOf<A>(it) },
-        caseSerializers = emptyList()
-      ) {
-        serializer<A>()
-      }
-        .invoke(prompt)
-
     /**
      * Classify a prompt using a given enum.
      *
      * @param input The input to the model.
      * @param output The output to the model.
      * @param context The context to the model.
      * @param model The model to use.
      * @param target The target type to return.
      * @param api The chat API to use.
      * @param conversation The conversation to use.
      * @return The classified enum.
      * @throws IllegalArgumentException If no enum values are found.
      */
     @AiDsl
     @Throws(IllegalArgumentException::class, CancellationException::class)
     suspend inline fun <reified E> classify(
       input: String,
       output: String,
       context: String,
-      model: CreateChatCompletionRequestModel = CreateChatCompletionRequestModel.gpt_4o,
-      target: KType = typeOf<E>(),
-      config: Config = Config(),
-      api: Chat = OpenAI(config).chat,
-      conversation: Conversation = Conversation()
-    ): E where E : PromptClassifier, E : Enum<E> {
-      val value = enumValues<E>().firstOrNull() ?: error("No enum values found")
-      return invoke(
+      config: AIConfig = AIConfig(),
+    ): E where E : Enum<E>, E : PromptClassifier {
+      val value = enumValues<E>().firstOrNull() ?: error("No values to classify")
+      return AI<E>(
         prompt = value.template(input, output, context),
-        model = model,
-        target = target,
-        config = config,
-        api = api,
-        conversation = conversation
+        config = config
       )
     }

     @AiDsl
     @Throws(IllegalArgumentException::class, CancellationException::class)
     suspend inline fun <reified E> multipleClassify(
       input: String,
-      model: CreateChatCompletionRequestModel = CreateChatCompletionRequestModel.gpt_4o,
-      config: Config = Config(),
-      api: Chat = OpenAI(config).chat,
-      conversation: Conversation = Conversation()
-    ): List<E> where E : PromptMultipleClassifier, E : Enum<E> {
+      config: AIConfig = AIConfig(),
+    ): List<E> where E : Enum<E>, E : PromptMultipleClassifier {
       val values = enumValues<E>()
-      val value = values.firstOrNull() ?: error("No enum values found")
+      val value = values.firstOrNull() ?: error("No values to classify")
       val selected: SelectedItems =
-        invoke(
+        AI(
           prompt = value.template(input),
-          model = model,
-          config = config,
-          api = api,
-          conversation = conversation
+          serializer = Tool.fromKotlin<SelectedItems>(),
+          config = config
         )
       return selected.selectedItems.mapNotNull { values.elementAtOrNull(it) }
     }
-
-    @AiDsl
-    suspend inline operator fun <reified A : Any> invoke(
-      prompt: String,
-      target: KType = typeOf<A>(),
-      model: CreateChatCompletionRequestModel = CreateChatCompletionRequestModel.gpt_3_5_turbo_0125,
-      config: Config = Config(),
-      api: Chat = OpenAI(config).chat,
-      conversation: Conversation = Conversation()
-    ): A = chat(Prompt(model, prompt), target, config, api, conversation)
-
-    @AiDsl
-    suspend inline operator fun <reified A : Any> invoke(
-      prompt: Prompt,
-      target: KType = typeOf<A>(),
-      config: Config = Config(),
-      api: Chat = OpenAI(config).chat,
-      conversation: Conversation = Conversation()
-    ): A = chat(prompt, target, config, api, conversation)
-
-    @OptIn(InternalSerializationApi::class, ExperimentalSerializationApi::class)
-    @AiDsl
-    suspend inline fun <reified A : Any> chat(
-      prompt: Prompt,
-      target: KType = typeOf<A>(),
-      config: Config = Config(),
-      api: Chat = OpenAI(config).chat,
-      conversation: Conversation = Conversation()
-    ): A {
-      val kind =
-        (target.classifier as? KClass<*>)?.serializer()?.descriptor?.kind
-          ?: error("Cannot find SerialKind for $target")
-      return when (kind) {
-        SerialKind.ENUM -> invokeEnum<A>(prompt, target, config, api, conversation)
-        else -> {
-          chat(
-            target = target,
-            model = prompt.model,
-            api = api,
-            conversation = conversation,
-            enumSerializer = null,
-            caseSerializers = emptyList()
-          ) {
-            serializer<A>()
-          }
-            .invoke(prompt)
-        }
-      }
-    }
   }
 }
+
+@AiDsl
+suspend inline fun <reified A> AI(
+  prompt: String,
+  serializer: Tool<A> = Tool.fromKotlin<A>(),
+  config: AIConfig = AIConfig()
+): A = AI(Prompt(config.model, prompt), serializer, config)
+
+@AiDsl
+suspend inline fun <reified A> AI(
+  prompt: Prompt,
+  serializer: Tool<A> = Tool.fromKotlin<A>(),
+  config: AIConfig = AIConfig(),
+): A = AI(config, serializer).invoke(prompt)
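The two top-level AI functions above are the new public entry points: the target type's serializer is now an explicit Tool<A> parameter (defaulting to Tool.fromKotlin) instead of being derived through kotlinx reflection inside chat, which is what lets non-Kotlin frontends supply their own Tool. A minimal usage sketch, assuming a caller in the same package; Sentiment and the prompt text are hypothetical, not part of this commit:

import kotlinx.serialization.Serializable

// Hypothetical target type; anything Tool.fromKotlin can derive a schema for.
@Serializable
data class Sentiment(val label: String, val confidence: Double)

// Calls the new top-level AI(prompt, serializer, config) from this diff.
suspend fun classifyReview(review: String): Sentiment =
  AI(
    prompt = "Classify the sentiment of this review: $review",
    serializer = Tool.fromKotlin<Sentiment>(), // the default, spelled out here
    config = AIConfig()                        // model, api, conversation defaults
  )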
15 changes: 15 additions & 0 deletions core/src/commonMain/kotlin/com/xebia/functional/xef/AIConfig.kt
@@ -0,0 +1,15 @@
+package com.xebia.functional.xef
+
+import com.xebia.functional.openai.generated.api.Chat
+import com.xebia.functional.openai.generated.api.OpenAI
+import com.xebia.functional.openai.generated.model.CreateChatCompletionRequestModel
+import com.xebia.functional.xef.conversation.Conversation
+
+data class AIConfig(
+  val tools: List<Tool<*>> = emptyList(),
+  val model: CreateChatCompletionRequestModel = CreateChatCompletionRequestModel.gpt_4o,
+  val config: Config = Config(),
+  val openAI: OpenAI = OpenAI(config),
+  val api: Chat = openAI.chat,
+  val conversation: Conversation = Conversation(),
+)
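AIConfig gathers what the old overloads took as separate model/config/api/conversation parameters, plus the tools list that invoke threads through every prompt call. Each default is built from the field before it, so overriding an early field flows through the rest. A hedged sketch (the model case is the old gpt_3_5_turbo_0125 default from the AI.kt diff above):

// Sketch: openAI defaults to OpenAI(config) and api to openAI.chat,
// so only the fields you care about need to be overridden.
val draftConfig =
  AIConfig(
    model = CreateChatCompletionRequestModel.gpt_3_5_turbo_0125,
    tools = emptyList() // presumably where the parallel-tool-calls draft registers Tool<*> values
  )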
6 changes: 6 additions & 0 deletions core/src/commonMain/kotlin/com/xebia/functional/xef/Classification.kt
@@ -0,0 +1,6 @@
+package com.xebia.functional.xef
+
+data class Classification(
+  val name: String,
+  val description: String,
+)
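Classification moves out of the AI interface (see the deletions in the AI.kt diff) into its own top-level file. Illustrative values only, of the kind PromptMultipleClassifier.getItems() returns:

val supportTopics = listOf(
  Classification(name = "refund", description = "The user asks for a refund"),
  Classification(name = "shipping", description = "The user asks about delivery status"),
)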
core/src/commonMain/kotlin/com/xebia/functional/xef/Config.kt
@@ -23,13 +23,14 @@ data class Config(
     prettyPrint = false
     isLenient = true
     explicitNulls = false
-    classDiscriminator = "_type_"
+    classDiscriminator = TYPE_DISCRIMINATOR
   },
   val streamingPrefix: String = "data:",
   val streamingDelimiter: String = "data: [DONE]"
 ) {
   companion object {
     val DEFAULT = Config()
+    const val TYPE_DISCRIMINATOR = "_type_"
   }
 }
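Promoting the discriminator to a TYPE_DISCRIMINATOR constant gives serialization frontends outside this file one name to agree on. For context, kotlinx.serialization writes this discriminator into polymorphic payloads; a small sketch with hypothetical types:

import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable

@Serializable sealed interface Shape

@Serializable @SerialName("circle") data class Circle(val radius: Double) : Shape

// With Json { classDiscriminator = Config.TYPE_DISCRIMINATOR }:
// encodeToString<Shape>(Circle(1.0)) ==> {"_type_":"circle","radius":1.0}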
