diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/prompt/PromptBuilder.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/prompt/PromptBuilder.kt index 00d712244..3976b8732 100644 --- a/core/src/commonMain/kotlin/com/xebia/functional/xef/prompt/PromptBuilder.kt +++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/prompt/PromptBuilder.kt @@ -23,25 +23,37 @@ interface PromptBuilder { @JvmSynthetic operator fun Message.unaryPlus() { - items.add(this) + addMessage(this) } @JvmSynthetic operator fun List.unaryPlus() { - items.addAll(this) + addMessages(this) } - fun addPrompt(prompt: Prompt): PromptBuilder = apply { items.addAll(prompt.messages) } + fun addPrompt(prompt: Prompt): PromptBuilder = apply { addMessages(prompt.messages) } - fun addMessage(message: Message): PromptBuilder = apply { items.add(message) } + fun addSystemMessage(message: String): PromptBuilder = apply { addMessage(system(message)) } - fun addSystemMessage(message: String): PromptBuilder = apply { items.add(system(message)) } + fun addAssistantMessage(message: String): PromptBuilder = apply { addMessage(assistant(message)) } - fun addAssistantMessage(message: String): PromptBuilder = apply { items.add(assistant(message)) } + fun addUserMessage(message: String): PromptBuilder = apply { addMessage(user(message)) } - fun addUserMessage(message: String): PromptBuilder = apply { items.add(user(message)) } + fun addMessage(message: Message): PromptBuilder = apply { + val lastMessageWithSameRole: Message? = items.lastMessageWithSameRole(message) + if (lastMessageWithSameRole != null) { + val messageUpdated = lastMessageWithSameRole.addContent(message) + items.remove(lastMessageWithSameRole) + items.add(messageUpdated) + } else { + items.add(message) + } + } - fun addMessages(messages: List): PromptBuilder = apply { items.addAll(messages) } + fun addMessages(messages: List): PromptBuilder = apply { + val last = items.removeLastOrNull() + items.addAll(((last?.let { listOf(it) } ?: emptyList()) + messages).flatten()) + } companion object { @@ -53,3 +65,22 @@ fun String.message(role: Role): Message = Message(role, this, role.name) inline fun A.message(role: Role): Message = Message(role, Json.encodeToString(serializer(), this), role.name) + +private fun List.flatten(): List = + fold(mutableListOf()) { acc, message -> + val lastMessageWithSameRole: Message? = acc.lastMessageWithSameRole(message) + if (lastMessageWithSameRole != null) { + val messageUpdated = lastMessageWithSameRole.addContent(message) + acc.remove(lastMessageWithSameRole) + acc.add(messageUpdated) + } else { + acc.add(message) + } + acc + } + +private fun Message.addContent(message: Message): Message = + copy(content = "${content}\n${message.content}") + +private fun List.lastMessageWithSameRole(message: Message): Message? = + lastOrNull()?.let { if (it.role == message.role) it else null } diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/prompt/templates/templates.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/prompt/templates/templates.kt index e7b8648ed..b8a0ebf43 100644 --- a/core/src/commonMain/kotlin/com/xebia/functional/xef/prompt/templates/templates.kt +++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/prompt/templates/templates.kt @@ -2,8 +2,6 @@ package com.xebia.functional.xef.prompt.templates import com.xebia.functional.xef.llm.models.chat.Message import com.xebia.functional.xef.llm.models.chat.Role -import com.xebia.functional.xef.prompt.PlatformPromptBuilder -import com.xebia.functional.xef.prompt.Prompt import com.xebia.functional.xef.prompt.message fun system(context: String): Message = context.message(Role.SYSTEM) @@ -18,14 +16,14 @@ inline fun assistant(data: A): Message = data.message(Role.ASSISTANT inline fun user(data: A): Message = data.message(Role.USER) -class StepsMessageBuilder : PlatformPromptBuilder() { +fun steps(role: Role, content: () -> List): Message = + content().mapIndexed { ix, elt -> "${ix + 1} - $elt" }.joinToString("\n").message(role) - override fun preprocess(elements: List): List = - elements.mapIndexed { ix, elt -> "${ix + 1} - ${elt.content}".message(elt.role) } -} +fun systemSteps(content: () -> List): Message = steps(Role.SYSTEM, content) + +fun assistantSteps(content: () -> List): Message = steps(Role.ASSISTANT, content) -fun steps(inside: StepsMessageBuilder.() -> Unit): Prompt = - StepsMessageBuilder().apply { inside() }.build() +fun userSteps(content: () -> List): Message = steps(Role.USER, content) fun writeSequenceOf( content: String, diff --git a/core/src/commonTest/kotlin/com/xebia/functional/xef/prompt/PromptBuilderSpec.kt b/core/src/commonTest/kotlin/com/xebia/functional/xef/prompt/PromptBuilderSpec.kt index 93cd30f72..560c9b378 100644 --- a/core/src/commonTest/kotlin/com/xebia/functional/xef/prompt/PromptBuilderSpec.kt +++ b/core/src/commonTest/kotlin/com/xebia/functional/xef/prompt/PromptBuilderSpec.kt @@ -2,10 +2,7 @@ package com.xebia.functional.xef.prompt import com.xebia.functional.xef.data.Question import com.xebia.functional.xef.llm.models.chat.Role -import com.xebia.functional.xef.prompt.templates.assistant -import com.xebia.functional.xef.prompt.templates.steps -import com.xebia.functional.xef.prompt.templates.system -import com.xebia.functional.xef.prompt.templates.user +import com.xebia.functional.xef.prompt.templates.* import io.kotest.core.spec.style.StringSpec import io.kotest.matchers.shouldBe @@ -45,8 +42,12 @@ class PromptBuilderSpec : listOf( "Test System".message(Role.SYSTEM), "Test Query".message(Role.USER), - "instruction 1".message(Role.ASSISTANT), - "instruction 2".message(Role.ASSISTANT) + """ + |instruction 1 + |instruction 2 + """ + .trimMargin() + .message(Role.ASSISTANT), ) messages shouldBe messagesExpected @@ -59,7 +60,7 @@ class PromptBuilderSpec : Prompt { +system("Test System") +user("Test Query") - +steps { instructions.forEach { +assistant(it) } } + +assistantSteps { instructions } } .messages @@ -67,8 +68,12 @@ class PromptBuilderSpec : listOf( "Test System".message(Role.SYSTEM), "Test Query".message(Role.USER), - "1 - instruction 1".message(Role.ASSISTANT), - "2 - instruction 2".message(Role.ASSISTANT) + """ + |1 - instruction 1 + |2 - instruction 2 + """ + .trimMargin() + .message(Role.ASSISTANT), ) messages shouldBe messagesExpected @@ -92,4 +97,41 @@ class PromptBuilderSpec : messages shouldBe messagesExpected } + + "Prompt should flatten the messages with the same role" { + val messages = + Prompt { + +system("Test System") + +user("User message 1") + +user("User message 2") + +assistant("Assistant message 1") + +user("User message 3") + +assistant("Assistant message 2") + +assistant("Assistant message 3") + +user("User message 4") + } + .messages + + val messagesExpected = + listOf( + "Test System".message(Role.SYSTEM), + """ + |User message 1 + |User message 2 + """ + .trimMargin() + .message(Role.USER), + "Assistant message 1".message(Role.ASSISTANT), + "User message 3".message(Role.USER), + """ + |Assistant message 2 + |Assistant message 3 + """ + .trimMargin() + .message(Role.ASSISTANT), + "User message 4".message(Role.USER), + ) + + messages shouldBe messagesExpected + } }) diff --git a/reasoning/src/commonMain/kotlin/com/xebia/functional/xef/reasoning/text/summarize/Summarize.kt b/reasoning/src/commonMain/kotlin/com/xebia/functional/xef/reasoning/text/summarize/Summarize.kt index 757c713c5..8cb38985b 100644 --- a/reasoning/src/commonMain/kotlin/com/xebia/functional/xef/reasoning/text/summarize/Summarize.kt +++ b/reasoning/src/commonMain/kotlin/com/xebia/functional/xef/reasoning/text/summarize/Summarize.kt @@ -5,8 +5,7 @@ import com.xebia.functional.tokenizer.truncateText import com.xebia.functional.xef.conversation.Conversation import com.xebia.functional.xef.llm.Chat import com.xebia.functional.xef.prompt.Prompt -import com.xebia.functional.xef.prompt.templates.assistant -import com.xebia.functional.xef.prompt.templates.steps +import com.xebia.functional.xef.prompt.templates.assistantSteps import com.xebia.functional.xef.prompt.templates.system import com.xebia.functional.xef.prompt.templates.user import com.xebia.functional.xef.reasoning.tools.Tool @@ -60,12 +59,11 @@ class Summarize( """ .trimMargin() ) - +steps { - (listOf( - "Summarize the `text` in max $summaryLength words", - "Reply with an empty response: ` ` if the text can't be summarized" - ) + instructions) - .forEach { +assistant(it) } + +assistantSteps { + listOf( + "Summarize the `text` in max $summaryLength words", + "Reply with an empty response: ` ` if the text can't be summarized" + ) + instructions } } diff --git a/reasoning/src/jvmMain/kotlin/com/xebia/functional/xef/reasoning/filesystem/ProduceTextFile.kt b/reasoning/src/jvmMain/kotlin/com/xebia/functional/xef/reasoning/filesystem/ProduceTextFile.kt index 1a8bf4f3a..433768a9f 100644 --- a/reasoning/src/jvmMain/kotlin/com/xebia/functional/xef/reasoning/filesystem/ProduceTextFile.kt +++ b/reasoning/src/jvmMain/kotlin/com/xebia/functional/xef/reasoning/filesystem/ProduceTextFile.kt @@ -4,10 +4,7 @@ import com.xebia.functional.xef.conversation.Conversation import com.xebia.functional.xef.io.DEFAULT import com.xebia.functional.xef.llm.ChatWithFunctions import com.xebia.functional.xef.prompt.Prompt -import com.xebia.functional.xef.prompt.templates.assistant -import com.xebia.functional.xef.prompt.templates.steps -import com.xebia.functional.xef.prompt.templates.system -import com.xebia.functional.xef.prompt.templates.user +import com.xebia.functional.xef.prompt.templates.* import com.xebia.functional.xef.reasoning.tools.Tool import kotlinx.uuid.UUID import kotlinx.uuid.generateUUID @@ -29,7 +26,7 @@ class ProduceTextFile( Prompt { +system("Convert output for a Text File") +user(input) - +steps { instructions.forEach { +assistant(it) } } + +assistantSteps { instructions } }, scope = scope, serializer = TxtFile.serializer()