From 0de4d57eba663cc6b52911aaa6302c3cecfeafb1 Mon Sep 17 00:00:00 2001 From: Javi Pacheco Date: Wed, 13 Sep 2023 16:04:00 +0200 Subject: [PATCH 1/2] Flatten messages in PromptBuilder --- .../functional/xef/prompt/PromptBuilder.kt | 35 +++++++++++++---- .../xef/prompt/PromptBuilderSpec.kt | 38 +++++++++++++++++++ 2 files changed, 66 insertions(+), 7 deletions(-) diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/prompt/PromptBuilder.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/prompt/PromptBuilder.kt index 00d712244..460e45c61 100644 --- a/core/src/commonMain/kotlin/com/xebia/functional/xef/prompt/PromptBuilder.kt +++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/prompt/PromptBuilder.kt @@ -23,23 +23,23 @@ interface PromptBuilder { @JvmSynthetic operator fun Message.unaryPlus() { - items.add(this) + addMessage(this) } @JvmSynthetic operator fun List.unaryPlus() { - items.addAll(this) + addMessages(this) } - fun addPrompt(prompt: Prompt): PromptBuilder = apply { items.addAll(prompt.messages) } + fun addPrompt(prompt: Prompt): PromptBuilder = apply { addMessages(prompt.messages) } - fun addMessage(message: Message): PromptBuilder = apply { items.add(message) } + fun addSystemMessage(message: String): PromptBuilder = apply { addMessage(system(message)) } - fun addSystemMessage(message: String): PromptBuilder = apply { items.add(system(message)) } + fun addAssistantMessage(message: String): PromptBuilder = apply { addMessage(assistant(message)) } - fun addAssistantMessage(message: String): PromptBuilder = apply { items.add(assistant(message)) } + fun addUserMessage(message: String): PromptBuilder = apply { addMessage(user(message)) } - fun addUserMessage(message: String): PromptBuilder = apply { items.add(user(message)) } + fun addMessage(message: Message): PromptBuilder = apply { items.add(message) } fun addMessages(messages: List): PromptBuilder = apply { items.addAll(messages) } @@ -53,3 +53,24 @@ fun String.message(role: Role): Message = Message(role, this, role.name) inline fun A.message(role: Role): Message = Message(role, Json.encodeToString(serializer(), this), role.name) + +fun Prompt.flatten(): Prompt = + Prompt( + messages.fold(mutableListOf()) { acc, message -> + val lastMessageWithSameRole: Message? = acc.lastMessageWithSameRole(message) + if (lastMessageWithSameRole != null) { + val messageUpdated = + lastMessageWithSameRole.copy( + content = "${lastMessageWithSameRole.content}\n${message.content}" + ) + acc.remove(lastMessageWithSameRole) + acc.add(messageUpdated) + } else { + acc.add(message) + } + acc + } + ) + +private fun List.lastMessageWithSameRole(message: Message): Message? = + lastOrNull()?.let { if (it.role == message.role) it else null } diff --git a/core/src/commonTest/kotlin/com/xebia/functional/xef/prompt/PromptBuilderSpec.kt b/core/src/commonTest/kotlin/com/xebia/functional/xef/prompt/PromptBuilderSpec.kt index 93cd30f72..4fe75e2b5 100644 --- a/core/src/commonTest/kotlin/com/xebia/functional/xef/prompt/PromptBuilderSpec.kt +++ b/core/src/commonTest/kotlin/com/xebia/functional/xef/prompt/PromptBuilderSpec.kt @@ -92,4 +92,42 @@ class PromptBuilderSpec : messages shouldBe messagesExpected } + + "flatten method should flatten the messages with the same role" { + val messages = + Prompt { + +system("Test System") + +user("User message 1") + +user("User message 2") + +assistant("Assistant message 1") + +user("User message 3") + +assistant("Assistant message 2") + +assistant("Assistant message 3") + +user("User message 4") + } + .flatten() + .messages + + val messagesExpected = + listOf( + "Test System".message(Role.SYSTEM), + """ + |User message 1 + |User message 2 + """ + .trimMargin() + .message(Role.USER), + "Assistant message 1".message(Role.ASSISTANT), + "User message 3".message(Role.USER), + """ + |Assistant message 2 + |Assistant message 3 + """ + .trimMargin() + .message(Role.ASSISTANT), + "User message 4".message(Role.USER), + ) + + messages shouldBe messagesExpected + } }) From aab9123848df3c4ddbf39d3ccb2adb3f3e58357a Mon Sep 17 00:00:00 2001 From: Javi Pacheco Date: Mon, 18 Sep 2023 10:21:09 +0200 Subject: [PATCH 2/2] Flatten by default in PromptBuilder --- .../functional/xef/prompt/PromptBuilder.kt | 46 +++++++++++-------- .../xef/prompt/templates/templates.kt | 14 +++--- .../xef/prompt/PromptBuilderSpec.kt | 26 ++++++----- .../xef/reasoning/text/summarize/Summarize.kt | 14 +++--- .../reasoning/filesystem/ProduceTextFile.kt | 7 +-- 5 files changed, 57 insertions(+), 50 deletions(-) diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/prompt/PromptBuilder.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/prompt/PromptBuilder.kt index 460e45c61..3976b8732 100644 --- a/core/src/commonMain/kotlin/com/xebia/functional/xef/prompt/PromptBuilder.kt +++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/prompt/PromptBuilder.kt @@ -39,9 +39,21 @@ interface PromptBuilder { fun addUserMessage(message: String): PromptBuilder = apply { addMessage(user(message)) } - fun addMessage(message: Message): PromptBuilder = apply { items.add(message) } + fun addMessage(message: Message): PromptBuilder = apply { + val lastMessageWithSameRole: Message? = items.lastMessageWithSameRole(message) + if (lastMessageWithSameRole != null) { + val messageUpdated = lastMessageWithSameRole.addContent(message) + items.remove(lastMessageWithSameRole) + items.add(messageUpdated) + } else { + items.add(message) + } + } - fun addMessages(messages: List): PromptBuilder = apply { items.addAll(messages) } + fun addMessages(messages: List): PromptBuilder = apply { + val last = items.removeLastOrNull() + items.addAll(((last?.let { listOf(it) } ?: emptyList()) + messages).flatten()) + } companion object { @@ -54,23 +66,21 @@ fun String.message(role: Role): Message = Message(role, this, role.name) inline fun A.message(role: Role): Message = Message(role, Json.encodeToString(serializer(), this), role.name) -fun Prompt.flatten(): Prompt = - Prompt( - messages.fold(mutableListOf()) { acc, message -> - val lastMessageWithSameRole: Message? = acc.lastMessageWithSameRole(message) - if (lastMessageWithSameRole != null) { - val messageUpdated = - lastMessageWithSameRole.copy( - content = "${lastMessageWithSameRole.content}\n${message.content}" - ) - acc.remove(lastMessageWithSameRole) - acc.add(messageUpdated) - } else { - acc.add(message) - } - acc +private fun List.flatten(): List = + fold(mutableListOf()) { acc, message -> + val lastMessageWithSameRole: Message? = acc.lastMessageWithSameRole(message) + if (lastMessageWithSameRole != null) { + val messageUpdated = lastMessageWithSameRole.addContent(message) + acc.remove(lastMessageWithSameRole) + acc.add(messageUpdated) + } else { + acc.add(message) } - ) + acc + } + +private fun Message.addContent(message: Message): Message = + copy(content = "${content}\n${message.content}") private fun List.lastMessageWithSameRole(message: Message): Message? = lastOrNull()?.let { if (it.role == message.role) it else null } diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/prompt/templates/templates.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/prompt/templates/templates.kt index e7b8648ed..b8a0ebf43 100644 --- a/core/src/commonMain/kotlin/com/xebia/functional/xef/prompt/templates/templates.kt +++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/prompt/templates/templates.kt @@ -2,8 +2,6 @@ package com.xebia.functional.xef.prompt.templates import com.xebia.functional.xef.llm.models.chat.Message import com.xebia.functional.xef.llm.models.chat.Role -import com.xebia.functional.xef.prompt.PlatformPromptBuilder -import com.xebia.functional.xef.prompt.Prompt import com.xebia.functional.xef.prompt.message fun system(context: String): Message = context.message(Role.SYSTEM) @@ -18,14 +16,14 @@ inline fun assistant(data: A): Message = data.message(Role.ASSISTANT inline fun user(data: A): Message = data.message(Role.USER) -class StepsMessageBuilder : PlatformPromptBuilder() { +fun steps(role: Role, content: () -> List): Message = + content().mapIndexed { ix, elt -> "${ix + 1} - $elt" }.joinToString("\n").message(role) - override fun preprocess(elements: List): List = - elements.mapIndexed { ix, elt -> "${ix + 1} - ${elt.content}".message(elt.role) } -} +fun systemSteps(content: () -> List): Message = steps(Role.SYSTEM, content) + +fun assistantSteps(content: () -> List): Message = steps(Role.ASSISTANT, content) -fun steps(inside: StepsMessageBuilder.() -> Unit): Prompt = - StepsMessageBuilder().apply { inside() }.build() +fun userSteps(content: () -> List): Message = steps(Role.USER, content) fun writeSequenceOf( content: String, diff --git a/core/src/commonTest/kotlin/com/xebia/functional/xef/prompt/PromptBuilderSpec.kt b/core/src/commonTest/kotlin/com/xebia/functional/xef/prompt/PromptBuilderSpec.kt index 4fe75e2b5..560c9b378 100644 --- a/core/src/commonTest/kotlin/com/xebia/functional/xef/prompt/PromptBuilderSpec.kt +++ b/core/src/commonTest/kotlin/com/xebia/functional/xef/prompt/PromptBuilderSpec.kt @@ -2,10 +2,7 @@ package com.xebia.functional.xef.prompt import com.xebia.functional.xef.data.Question import com.xebia.functional.xef.llm.models.chat.Role -import com.xebia.functional.xef.prompt.templates.assistant -import com.xebia.functional.xef.prompt.templates.steps -import com.xebia.functional.xef.prompt.templates.system -import com.xebia.functional.xef.prompt.templates.user +import com.xebia.functional.xef.prompt.templates.* import io.kotest.core.spec.style.StringSpec import io.kotest.matchers.shouldBe @@ -45,8 +42,12 @@ class PromptBuilderSpec : listOf( "Test System".message(Role.SYSTEM), "Test Query".message(Role.USER), - "instruction 1".message(Role.ASSISTANT), - "instruction 2".message(Role.ASSISTANT) + """ + |instruction 1 + |instruction 2 + """ + .trimMargin() + .message(Role.ASSISTANT), ) messages shouldBe messagesExpected @@ -59,7 +60,7 @@ class PromptBuilderSpec : Prompt { +system("Test System") +user("Test Query") - +steps { instructions.forEach { +assistant(it) } } + +assistantSteps { instructions } } .messages @@ -67,8 +68,12 @@ class PromptBuilderSpec : listOf( "Test System".message(Role.SYSTEM), "Test Query".message(Role.USER), - "1 - instruction 1".message(Role.ASSISTANT), - "2 - instruction 2".message(Role.ASSISTANT) + """ + |1 - instruction 1 + |2 - instruction 2 + """ + .trimMargin() + .message(Role.ASSISTANT), ) messages shouldBe messagesExpected @@ -93,7 +98,7 @@ class PromptBuilderSpec : messages shouldBe messagesExpected } - "flatten method should flatten the messages with the same role" { + "Prompt should flatten the messages with the same role" { val messages = Prompt { +system("Test System") @@ -105,7 +110,6 @@ class PromptBuilderSpec : +assistant("Assistant message 3") +user("User message 4") } - .flatten() .messages val messagesExpected = diff --git a/reasoning/src/commonMain/kotlin/com/xebia/functional/xef/reasoning/text/summarize/Summarize.kt b/reasoning/src/commonMain/kotlin/com/xebia/functional/xef/reasoning/text/summarize/Summarize.kt index 757c713c5..8cb38985b 100644 --- a/reasoning/src/commonMain/kotlin/com/xebia/functional/xef/reasoning/text/summarize/Summarize.kt +++ b/reasoning/src/commonMain/kotlin/com/xebia/functional/xef/reasoning/text/summarize/Summarize.kt @@ -5,8 +5,7 @@ import com.xebia.functional.tokenizer.truncateText import com.xebia.functional.xef.conversation.Conversation import com.xebia.functional.xef.llm.Chat import com.xebia.functional.xef.prompt.Prompt -import com.xebia.functional.xef.prompt.templates.assistant -import com.xebia.functional.xef.prompt.templates.steps +import com.xebia.functional.xef.prompt.templates.assistantSteps import com.xebia.functional.xef.prompt.templates.system import com.xebia.functional.xef.prompt.templates.user import com.xebia.functional.xef.reasoning.tools.Tool @@ -60,12 +59,11 @@ class Summarize( """ .trimMargin() ) - +steps { - (listOf( - "Summarize the `text` in max $summaryLength words", - "Reply with an empty response: ` ` if the text can't be summarized" - ) + instructions) - .forEach { +assistant(it) } + +assistantSteps { + listOf( + "Summarize the `text` in max $summaryLength words", + "Reply with an empty response: ` ` if the text can't be summarized" + ) + instructions } } diff --git a/reasoning/src/jvmMain/kotlin/com/xebia/functional/xef/reasoning/filesystem/ProduceTextFile.kt b/reasoning/src/jvmMain/kotlin/com/xebia/functional/xef/reasoning/filesystem/ProduceTextFile.kt index 1a8bf4f3a..433768a9f 100644 --- a/reasoning/src/jvmMain/kotlin/com/xebia/functional/xef/reasoning/filesystem/ProduceTextFile.kt +++ b/reasoning/src/jvmMain/kotlin/com/xebia/functional/xef/reasoning/filesystem/ProduceTextFile.kt @@ -4,10 +4,7 @@ import com.xebia.functional.xef.conversation.Conversation import com.xebia.functional.xef.io.DEFAULT import com.xebia.functional.xef.llm.ChatWithFunctions import com.xebia.functional.xef.prompt.Prompt -import com.xebia.functional.xef.prompt.templates.assistant -import com.xebia.functional.xef.prompt.templates.steps -import com.xebia.functional.xef.prompt.templates.system -import com.xebia.functional.xef.prompt.templates.user +import com.xebia.functional.xef.prompt.templates.* import com.xebia.functional.xef.reasoning.tools.Tool import kotlinx.uuid.UUID import kotlinx.uuid.generateUUID @@ -29,7 +26,7 @@ class ProduceTextFile( Prompt { +system("Convert output for a Text File") +user(input) - +steps { instructions.forEach { +assistant(it) } } + +assistantSteps { instructions } }, scope = scope, serializer = TxtFile.serializer()