Skip to content

Commit

Permalink
Flatten messages in PromptBuilder (#424)
Browse files Browse the repository at this point in the history
* Flatten messages in PromptBuilder

* Flatten by default in PromptBuilder
  • Loading branch information
javipacheco authored Sep 18, 2023
1 parent 8155a54 commit e6d7d50
Show file tree
Hide file tree
Showing 5 changed files with 104 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,25 +23,37 @@ interface PromptBuilder {

@JvmSynthetic
operator fun Message.unaryPlus() {
items.add(this)
addMessage(this)
}

@JvmSynthetic
operator fun List<Message>.unaryPlus() {
items.addAll(this)
addMessages(this)
}

fun addPrompt(prompt: Prompt): PromptBuilder = apply { items.addAll(prompt.messages) }
fun addPrompt(prompt: Prompt): PromptBuilder = apply { addMessages(prompt.messages) }

fun addMessage(message: Message): PromptBuilder = apply { items.add(message) }
fun addSystemMessage(message: String): PromptBuilder = apply { addMessage(system(message)) }

fun addSystemMessage(message: String): PromptBuilder = apply { items.add(system(message)) }
fun addAssistantMessage(message: String): PromptBuilder = apply { addMessage(assistant(message)) }

fun addAssistantMessage(message: String): PromptBuilder = apply { items.add(assistant(message)) }
fun addUserMessage(message: String): PromptBuilder = apply { addMessage(user(message)) }

fun addUserMessage(message: String): PromptBuilder = apply { items.add(user(message)) }
fun addMessage(message: Message): PromptBuilder = apply {
val lastMessageWithSameRole: Message? = items.lastMessageWithSameRole(message)
if (lastMessageWithSameRole != null) {
val messageUpdated = lastMessageWithSameRole.addContent(message)
items.remove(lastMessageWithSameRole)
items.add(messageUpdated)
} else {
items.add(message)
}
}

fun addMessages(messages: List<Message>): PromptBuilder = apply { items.addAll(messages) }
fun addMessages(messages: List<Message>): PromptBuilder = apply {
val last = items.removeLastOrNull()
items.addAll(((last?.let { listOf(it) } ?: emptyList()) + messages).flatten())
}

companion object {

Expand All @@ -53,3 +65,22 @@ fun String.message(role: Role): Message = Message(role, this, role.name)

inline fun <reified A> A.message(role: Role): Message =
Message(role, Json.encodeToString(serializer(), this), role.name)

private fun List<Message>.flatten(): List<Message> =
fold(mutableListOf()) { acc, message ->
val lastMessageWithSameRole: Message? = acc.lastMessageWithSameRole(message)
if (lastMessageWithSameRole != null) {
val messageUpdated = lastMessageWithSameRole.addContent(message)
acc.remove(lastMessageWithSameRole)
acc.add(messageUpdated)
} else {
acc.add(message)
}
acc
}

private fun Message.addContent(message: Message): Message =
copy(content = "${content}\n${message.content}")

private fun List<Message>.lastMessageWithSameRole(message: Message): Message? =
lastOrNull()?.let { if (it.role == message.role) it else null }
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@ package com.xebia.functional.xef.prompt.templates

import com.xebia.functional.xef.llm.models.chat.Message
import com.xebia.functional.xef.llm.models.chat.Role
import com.xebia.functional.xef.prompt.PlatformPromptBuilder
import com.xebia.functional.xef.prompt.Prompt
import com.xebia.functional.xef.prompt.message

fun system(context: String): Message = context.message(Role.SYSTEM)
Expand All @@ -18,14 +16,14 @@ inline fun <reified A> assistant(data: A): Message = data.message(Role.ASSISTANT

inline fun <reified A> user(data: A): Message = data.message(Role.USER)

class StepsMessageBuilder : PlatformPromptBuilder() {
fun steps(role: Role, content: () -> List<String>): Message =
content().mapIndexed { ix, elt -> "${ix + 1} - $elt" }.joinToString("\n").message(role)

override fun preprocess(elements: List<Message>): List<Message> =
elements.mapIndexed { ix, elt -> "${ix + 1} - ${elt.content}".message(elt.role) }
}
fun systemSteps(content: () -> List<String>): Message = steps(Role.SYSTEM, content)

fun assistantSteps(content: () -> List<String>): Message = steps(Role.ASSISTANT, content)

fun steps(inside: StepsMessageBuilder.() -> Unit): Prompt =
StepsMessageBuilder().apply { inside() }.build()
fun userSteps(content: () -> List<String>): Message = steps(Role.USER, content)

fun writeSequenceOf(
content: String,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,7 @@ package com.xebia.functional.xef.prompt

import com.xebia.functional.xef.data.Question
import com.xebia.functional.xef.llm.models.chat.Role
import com.xebia.functional.xef.prompt.templates.assistant
import com.xebia.functional.xef.prompt.templates.steps
import com.xebia.functional.xef.prompt.templates.system
import com.xebia.functional.xef.prompt.templates.user
import com.xebia.functional.xef.prompt.templates.*
import io.kotest.core.spec.style.StringSpec
import io.kotest.matchers.shouldBe

Expand Down Expand Up @@ -45,8 +42,12 @@ class PromptBuilderSpec :
listOf(
"Test System".message(Role.SYSTEM),
"Test Query".message(Role.USER),
"instruction 1".message(Role.ASSISTANT),
"instruction 2".message(Role.ASSISTANT)
"""
|instruction 1
|instruction 2
"""
.trimMargin()
.message(Role.ASSISTANT),
)

messages shouldBe messagesExpected
Expand All @@ -59,16 +60,20 @@ class PromptBuilderSpec :
Prompt {
+system("Test System")
+user("Test Query")
+steps { instructions.forEach { +assistant(it) } }
+assistantSteps { instructions }
}
.messages

val messagesExpected =
listOf(
"Test System".message(Role.SYSTEM),
"Test Query".message(Role.USER),
"1 - instruction 1".message(Role.ASSISTANT),
"2 - instruction 2".message(Role.ASSISTANT)
"""
|1 - instruction 1
|2 - instruction 2
"""
.trimMargin()
.message(Role.ASSISTANT),
)

messages shouldBe messagesExpected
Expand All @@ -92,4 +97,41 @@ class PromptBuilderSpec :

messages shouldBe messagesExpected
}

"Prompt should flatten the messages with the same role" {
val messages =
Prompt {
+system("Test System")
+user("User message 1")
+user("User message 2")
+assistant("Assistant message 1")
+user("User message 3")
+assistant("Assistant message 2")
+assistant("Assistant message 3")
+user("User message 4")
}
.messages

val messagesExpected =
listOf(
"Test System".message(Role.SYSTEM),
"""
|User message 1
|User message 2
"""
.trimMargin()
.message(Role.USER),
"Assistant message 1".message(Role.ASSISTANT),
"User message 3".message(Role.USER),
"""
|Assistant message 2
|Assistant message 3
"""
.trimMargin()
.message(Role.ASSISTANT),
"User message 4".message(Role.USER),
)

messages shouldBe messagesExpected
}
})
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@ import com.xebia.functional.tokenizer.truncateText
import com.xebia.functional.xef.conversation.Conversation
import com.xebia.functional.xef.llm.Chat
import com.xebia.functional.xef.prompt.Prompt
import com.xebia.functional.xef.prompt.templates.assistant
import com.xebia.functional.xef.prompt.templates.steps
import com.xebia.functional.xef.prompt.templates.assistantSteps
import com.xebia.functional.xef.prompt.templates.system
import com.xebia.functional.xef.prompt.templates.user
import com.xebia.functional.xef.reasoning.tools.Tool
Expand Down Expand Up @@ -60,12 +59,11 @@ class Summarize(
"""
.trimMargin()
)
+steps {
(listOf(
"Summarize the `text` in max $summaryLength words",
"Reply with an empty response: ` ` if the text can't be summarized"
) + instructions)
.forEach { +assistant(it) }
+assistantSteps {
listOf(
"Summarize the `text` in max $summaryLength words",
"Reply with an empty response: ` ` if the text can't be summarized"
) + instructions
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,7 @@ import com.xebia.functional.xef.conversation.Conversation
import com.xebia.functional.xef.io.DEFAULT
import com.xebia.functional.xef.llm.ChatWithFunctions
import com.xebia.functional.xef.prompt.Prompt
import com.xebia.functional.xef.prompt.templates.assistant
import com.xebia.functional.xef.prompt.templates.steps
import com.xebia.functional.xef.prompt.templates.system
import com.xebia.functional.xef.prompt.templates.user
import com.xebia.functional.xef.prompt.templates.*
import com.xebia.functional.xef.reasoning.tools.Tool
import kotlinx.uuid.UUID
import kotlinx.uuid.generateUUID
Expand All @@ -29,7 +26,7 @@ class ProduceTextFile(
Prompt {
+system("Convert output for a Text File")
+user(input)
+steps { instructions.forEach { +assistant(it) } }
+assistantSteps { instructions }
},
scope = scope,
serializer = TxtFile.serializer()
Expand Down

0 comments on commit e6d7d50

Please sign in to comment.