decompose class hierarchy of LLM implementations (#397)
* refactor hierarchy of LLM (first compiling version)

* first compiling version after a clean build; split the request types of Chat and FunChat

* small changes and fixes

* modify hierarchy for GCP

* adjust AutoClose and scope of provider client

* small changes

* small changes

* some models had the wrong capability

* changes according to comments and feedback

* storing full prompt and response messages (#440)

* storing full prompt and response messages

* updated the rest of the functions

* fixed a problem in message ordering and added tests

* fixed the Animal example

---------

Co-authored-by: Raúl Raja Martínez <[email protected]>
Co-authored-by: José Carlos Montañez <[email protected]>
Co-authored-by: Javi Pacheco <[email protected]>
4 people authored Sep 22, 2023
1 parent a50b96e commit 774c796
Showing 36 changed files with 848 additions and 742 deletions.
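
For orientation, here is a condensed sketch of the decomposed hierarchy as it stands after this commit, pieced together from the file diffs below. Signatures are abbreviated and default bodies omitted, so treat it as an illustration rather than a copy of the sources: Chat, ChatWithFunctions, and Completion now all sit directly under the sealed LLM interface, ChatWithFunctions no longer extends Chat, and function calling gets its own FunChatCompletionRequest.

// Condensed illustration only; pieced together from the diffs in this commit.
import com.xebia.functional.tokenizer.ModelType
import com.xebia.functional.xef.llm.models.chat.*
import com.xebia.functional.xef.llm.models.functions.FunChatCompletionRequest
import com.xebia.functional.xef.llm.models.text.CompletionRequest
import com.xebia.functional.xef.llm.models.text.CompletionResult
import kotlinx.coroutines.flow.Flow

sealed interface LLM : AutoCloseable {
  val modelType: ModelType
  fun tokensFromMessages(messages: List<Message>): Int   // default estimate derived from modelType
}

interface Chat : LLM {
  suspend fun createChatCompletion(request: ChatCompletionRequest): ChatCompletionResponse
  suspend fun createChatCompletions(request: ChatCompletionRequest): Flow<ChatCompletionChunk>
}

interface ChatWithFunctions : LLM {   // previously ChatWithFunctions : Chat
  suspend fun createChatCompletionWithFunctions(request: FunChatCompletionRequest): ChatCompletionResponseWithFunctions
  suspend fun createChatCompletionsWithFunctions(request: FunChatCompletionRequest): Flow<ChatCompletionChunk>
}

interface Completion : LLM {
  suspend fun createCompletion(request: CompletionRequest): CompletionResult
}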
47 changes: 9 additions & 38 deletions core/src/commonMain/kotlin/com/xebia/functional/xef/llm/Chat.kt
@@ -1,6 +1,5 @@
package com.xebia.functional.xef.llm

import com.xebia.functional.tokenizer.ModelType
import com.xebia.functional.xef.AIError
import com.xebia.functional.xef.conversation.AiDsl
import com.xebia.functional.xef.conversation.Conversation
@@ -10,22 +9,18 @@ import com.xebia.functional.xef.prompt.templates.assistant
import kotlinx.coroutines.flow.*

interface Chat : LLM {
val modelType: ModelType

suspend fun createChatCompletion(request: ChatCompletionRequest): ChatCompletionResponse

suspend fun createChatCompletions(request: ChatCompletionRequest): Flow<ChatCompletionChunk>

fun tokensFromMessages(messages: List<Message>): Int

@AiDsl
fun promptStreaming(prompt: Prompt, scope: Conversation): Flow<String> = flow {
val messagesForRequestPrompt =
PromptCalculator.adaptPromptToConversationAndModel(prompt, scope, this@Chat)

val request =
ChatCompletionRequest(
model = name,
user = prompt.configuration.user,
messages = messagesForRequestPrompt.messages,
n = prompt.configuration.numberOfPredictions,
@@ -39,8 +34,9 @@ interface Chat : LLM {
.onEach { emit(it) }
.fold("", String::plus)
.also { finalText ->
val message = assistant(finalText)
MemoryManagement.addMemoriesAfterStream(this@Chat, request, scope, listOf(message))
val aiResponseMessage = assistant(finalText)
val newMessages = prompt.messages + listOf(aiResponseMessage)
newMessages.addToMemory(this@Chat, scope)
}
}

@@ -50,46 +46,21 @@ interface Chat : LLM {

@AiDsl
suspend fun promptMessages(prompt: Prompt, scope: Conversation): List<String> {

val requestedMemories = prompt.messages.toMemory(this@Chat, scope)
val adaptedPrompt = PromptCalculator.adaptPromptToConversationAndModel(prompt, scope, this@Chat)

fun chatRequest(): ChatCompletionRequest =
val request =
ChatCompletionRequest(
model = name,
user = adaptedPrompt.configuration.user,
messages = adaptedPrompt.messages,
n = adaptedPrompt.configuration.numberOfPredictions,
temperature = adaptedPrompt.configuration.temperature,
maxTokens = adaptedPrompt.configuration.minResponseTokens,
functions = listOfNotNull(adaptedPrompt.function),
functionCall = adaptedPrompt.function?.let { mapOf("name" to (it.name)) }
)

return MemoryManagement.run {
when (this@Chat) {
is ChatWithFunctions ->
// we only support functions for now with GPT_3_5_TURBO_FUNCTIONS
if (modelType == ModelType.GPT_3_5_TURBO_FUNCTIONS) {
val request = chatRequest()
createChatCompletionWithFunctions(request)
.choices
.addChoiceWithFunctionsToMemory(this@Chat, request, scope)
.mapNotNull { it.message?.functionCall?.arguments }
} else {
val request = chatRequest()
createChatCompletion(request)
.choices
.addChoiceToMemory(this@Chat, request, scope)
.mapNotNull { it.message?.content }
}
else -> {
val request = chatRequest()
createChatCompletion(request)
.choices
.addChoiceToMemory(this@Chat, request, scope)
.mapNotNull { it.message?.content }
}
}
}
return createChatCompletion(request)
.choices
.addMessagesToMemory(this@Chat, scope, requestedMemories)
.mapNotNull { it.message?.content }
}
}
core/src/commonMain/kotlin/com/xebia/functional/xef/llm/ChatWithFunctions.kt
@@ -1,13 +1,15 @@
package com.xebia.functional.xef.llm

import arrow.core.nel
import arrow.core.nonFatalOrThrow
import arrow.core.raise.catch
import com.xebia.functional.xef.AIError
import com.xebia.functional.xef.conversation.AiDsl
import com.xebia.functional.xef.conversation.Conversation
import com.xebia.functional.xef.llm.models.chat.ChatCompletionRequest
import com.xebia.functional.xef.llm.models.chat.ChatCompletionChunk
import com.xebia.functional.xef.llm.models.chat.ChatCompletionResponseWithFunctions
import com.xebia.functional.xef.llm.models.functions.CFunction
import com.xebia.functional.xef.llm.models.functions.FunChatCompletionRequest
import com.xebia.functional.xef.llm.models.functions.encodeJsonSchema
import com.xebia.functional.xef.prompt.Prompt
import io.github.oshai.kotlinlogging.KotlinLogging
@@ -17,12 +19,16 @@ import kotlinx.serialization.KSerializer
import kotlinx.serialization.descriptors.*
import kotlinx.serialization.json.*

interface ChatWithFunctions : Chat {
interface ChatWithFunctions : LLM {

suspend fun createChatCompletionWithFunctions(
request: ChatCompletionRequest
request: FunChatCompletionRequest
): ChatCompletionResponseWithFunctions

suspend fun createChatCompletionsWithFunctions(
request: FunChatCompletionRequest
): Flow<ChatCompletionChunk>

@OptIn(ExperimentalSerializationApi::class)
fun chatFunction(descriptor: SerialDescriptor): CFunction {
val fnName = descriptor.serialName.substringAfterLast(".")
@@ -60,11 +66,33 @@ interface ChatWithFunctions : Chat {
serializer: (json: String) -> A,
): A {
val promptWithFunctions = prompt.copy(function = function)
val adaptedPrompt =
PromptCalculator.adaptPromptToConversationAndModel(
promptWithFunctions,
scope,
this@ChatWithFunctions
)

val request =
FunChatCompletionRequest(
user = adaptedPrompt.configuration.user,
messages = adaptedPrompt.messages,
n = adaptedPrompt.configuration.numberOfPredictions,
temperature = adaptedPrompt.configuration.temperature,
maxTokens = adaptedPrompt.configuration.minResponseTokens,
functions = adaptedPrompt.function!!.nel(),
functionCall = mapOf("name" to (adaptedPrompt.function.name)),
)

return tryDeserialize(
serializer,
promptWithFunctions.configuration.maxDeserializationAttempts
) {
promptMessages(prompt = promptWithFunctions, scope = scope)
val requestedMemories = prompt.messages.toMemory(this@ChatWithFunctions, scope)
createChatCompletionWithFunctions(request)
.choices
.addMessagesToMemory(this@ChatWithFunctions, scope, requestedMemories)
.mapNotNull { it.message?.functionCall?.arguments }
}
}

@@ -84,16 +112,15 @@ interface ChatWithFunctions : Chat {
)

val request =
ChatCompletionRequest(
model = name,
FunChatCompletionRequest(
stream = true,
user = promptWithFunctions.configuration.user,
messages = messagesForRequestPrompt.messages,
functions = listOfNotNull(messagesForRequestPrompt.function),
n = promptWithFunctions.configuration.numberOfPredictions,
temperature = promptWithFunctions.configuration.temperature,
maxTokens = promptWithFunctions.configuration.minResponseTokens,
functionCall = mapOf("name" to (promptWithFunctions.function?.name ?: ""))
functions = promptWithFunctions.function!!.nel(),
functionCall = mapOf("name" to (promptWithFunctions.function.name)),
)

StreamedFunction.run {
Expand All @@ -102,6 +129,7 @@ interface ChatWithFunctions : Chat {
) {
streamFunctionCall(
chat = this@ChatWithFunctions,
promptMessages = prompt.messages,
request = request,
scope = scope,
serializer = serializer,
core/src/commonMain/kotlin/com/xebia/functional/xef/llm/Completion.kt
@@ -1,11 +1,8 @@
package com.xebia.functional.xef.llm

import com.xebia.functional.tokenizer.ModelType
import com.xebia.functional.xef.llm.models.text.CompletionRequest
import com.xebia.functional.xef.llm.models.text.CompletionResult

interface Completion : LLM {
val modelType: ModelType

suspend fun createCompletion(request: CompletionRequest): CompletionResult
}
29 changes: 27 additions & 2 deletions core/src/commonMain/kotlin/com/xebia/functional/xef/llm/LLM.kt
@@ -1,7 +1,32 @@
package com.xebia.functional.xef.llm

import com.xebia.functional.tokenizer.Encoding
import com.xebia.functional.tokenizer.ModelType
import com.xebia.functional.xef.llm.models.chat.Message

sealed interface LLM : AutoCloseable {
val name: String

override fun close() {}
val modelType: ModelType

@Deprecated("use modelType.name instead", replaceWith = ReplaceWith("modelType.name"))
val name
get() = modelType.name

fun tokensFromMessages(
messages: List<Message>
): Int { // TODO: naive implementation with magic numbers
fun Encoding.countTokensFromMessages(tokensPerMessage: Int, tokensPerName: Int): Int =
messages.sumOf { message ->
countTokens(message.role.name) +
countTokens(message.content) +
tokensPerMessage +
tokensPerName
} + 3
return modelType.encoding.countTokensFromMessages(
tokensPerMessage = modelType.tokensPerMessage,
tokensPerName = modelType.tokensPerName
) + modelType.tokenPadding
}

override fun close() = Unit
}
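
The default tokensFromMessages above is, as the TODO notes, a naive estimate: per message it counts the encoded role name and content, adds the model-specific tokensPerMessage and tokensPerName constants, then adds 3 for the exchange and the model's tokenPadding. A hypothetical usage sketch (`llm` stands for any concrete LLM; the breakdown in the comment mirrors the formula above):

// Hypothetical usage of the default tokensFromMessages estimate (illustrative only).
val messages = listOf(
  Message(role = Role.USER, content = "Hello there", name = Role.USER.name)
)
val estimate = llm.tokensFromMessages(messages)
// estimate = countTokens("USER") + countTokens("Hello there")
//          + modelType.tokensPerMessage + modelType.tokensPerName
//          + 3 + modelType.tokenPadding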
@@ -5,100 +5,67 @@ import com.xebia.functional.xef.llm.models.chat.*
import com.xebia.functional.xef.store.Memory
import io.ktor.util.date.*

internal object MemoryManagement {
internal suspend fun List<Message>.addToMemory(chat: LLM, scope: Conversation) {
val memories = toMemory(chat, scope)
if (memories.isNotEmpty()) {
scope.store.addMemories(memories)
}
}

internal suspend fun addMemoriesAfterStream(
chat: Chat,
request: ChatCompletionRequest,
scope: Conversation,
messages: List<Message>,
) {
val lastRequestMessage = request.messages.lastOrNull()
val cid = scope.conversationId
if (cid != null && lastRequestMessage != null) {
val requestMemory =
Memory(
conversationId = cid,
content = lastRequestMessage,
timestamp = getTimeMillis(),
approxTokens = chat.tokensFromMessages(listOf(lastRequestMessage))
)
val responseMemories =
messages.map {
Memory(
conversationId = cid,
content = it,
timestamp = getTimeMillis(),
approxTokens = chat.tokensFromMessages(messages)
)
}
scope.store.addMemories(listOf(requestMemory) + responseMemories)
internal fun List<Message>.toMemory(chat: LLM, scope: Conversation): List<Memory> {
val cid = scope.conversationId
return if (cid != null) {
mapIndexed { delta, it ->
Memory(
conversationId = cid,
content = it,
// We are adding the delta to ensure that the timestamp is unique for every message.
// With this, we ensure that the order of the messages is preserved.
// We assume that the AI response time will be in the order of seconds.
timestamp = getTimeMillis() + delta,
approxTokens = chat.tokensFromMessages(listOf(it))
)
}
}
} else emptyList()
}

internal suspend fun List<ChoiceWithFunctions>.addChoiceWithFunctionsToMemory(
chat: Chat,
request: ChatCompletionRequest,
scope: Conversation
): List<ChoiceWithFunctions> = also {
val firstChoice = firstOrNull()
val requestUserMessage = request.messages.lastOrNull()
val cid = scope.conversationId
if (requestUserMessage != null && firstChoice != null && cid != null) {
val role = firstChoice.message?.role?.uppercase()?.let { Role.valueOf(it) } ?: Role.USER
internal suspend fun List<ChoiceWithFunctions>.addMessagesToMemory(
chat: LLM,
scope: Conversation,
previousMemories: List<Memory>
): List<ChoiceWithFunctions> = also {
val firstChoice = firstOrNull()
val cid = scope.conversationId
if (firstChoice != null && cid != null) {
val role = firstChoice.message?.role?.uppercase()?.let { Role.valueOf(it) } ?: Role.USER

val requestMemory =
Memory(
conversationId = cid,
content = requestUserMessage,
timestamp = getTimeMillis(),
approxTokens = chat.tokensFromMessages(listOf(requestUserMessage))
)
val firstChoiceMessage =
Message(
role = role,
content = firstChoice.message?.content
?: firstChoice.message?.functionCall?.arguments ?: "",
name = role.name
)
val firstChoiceMemory =
Memory(
conversationId = cid,
content = firstChoiceMessage,
timestamp = getTimeMillis(),
approxTokens = chat.tokensFromMessages(listOf(firstChoiceMessage))
)
scope.store.addMemories(listOf(requestMemory, firstChoiceMemory))
}
val firstChoiceMessage =
Message(
role = role,
content = firstChoice.message?.content
?: firstChoice.message?.functionCall?.arguments ?: "",
name = role.name
)

val newMessages = previousMemories + listOf(firstChoiceMessage).toMemory(chat, scope)
scope.store.addMemories(newMessages)
}
}

internal suspend fun List<Choice>.addChoiceToMemory(
chat: Chat,
request: ChatCompletionRequest,
scope: Conversation
): List<Choice> = also {
val firstChoice = firstOrNull()
val requestUserMessage = request.messages.lastOrNull()
val cid = scope.conversationId
if (requestUserMessage != null && firstChoice != null && cid != null) {
val role = firstChoice.message?.role?.name?.uppercase()?.let { Role.valueOf(it) } ?: Role.USER
val requestMemory =
Memory(
conversationId = cid,
content = requestUserMessage,
timestamp = getTimeMillis(),
approxTokens = chat.tokensFromMessages(listOf(requestUserMessage))
)
val firstChoiceMessage =
Message(role = role, content = firstChoice.message?.content ?: "", name = role.name)
val firstChoiceMemory =
Memory(
conversationId = cid,
content = firstChoiceMessage,
timestamp = getTimeMillis(),
approxTokens = chat.tokensFromMessages(listOf(firstChoiceMessage))
)
scope.store.addMemories(listOf(requestMemory, firstChoiceMemory))
}
internal suspend fun List<Choice>.addMessagesToMemory(
chat: Chat,
scope: Conversation,
previousMemories: List<Memory>
): List<Choice> = also {
val firstChoice = firstOrNull()
val cid = scope.conversationId
if (firstChoice != null && cid != null) {
val role = firstChoice.message?.role?.name?.uppercase()?.let { Role.valueOf(it) } ?: Role.USER

val firstChoiceMessage =
Message(role = role, content = firstChoice.message?.content ?: "", name = role.name)

val newMessages = previousMemories + listOf(firstChoiceMessage).toMemory(chat, scope)
scope.store.addMemories(newMessages)
}
}
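
Together with the Chat.kt changes above, the resulting flow is: the full prompt is turned into Memory entries up front via toMemory, whose mapIndexed delta keeps timestamps unique so message order is preserved, and addMessagesToMemory then stores those prompt memories together with the first choice of the response. A hypothetical sketch of how the pieces combine (names and wiring simplified; not part of this commit):

// Hypothetical illustration of the new memory helpers; assumes a Conversation whose
// conversationId is non-null and a ChatCompletionResponse already obtained from the model.
suspend fun storeExchange(
  chat: Chat,
  scope: Conversation,
  prompt: Prompt,
  response: ChatCompletionResponse
) {
  // Every prompt message becomes a Memory; timestamps are offset by index to preserve order.
  val promptMemories = prompt.messages.toMemory(chat, scope)
  // Persists promptMemories plus the first choice's message, returning the choices unchanged.
  response.choices.addMessagesToMemory(chat, scope, promptMemories)
}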