decompose class hierarchy of LLM implementations #397

Merged: 16 commits, Sep 22, 2023
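For orientation, here is a minimal sketch of the class hierarchy as it looks after this PR, reconstructed only from the diffs below. The request and response types are reduced to empty stand-ins so the sketch compiles on its own; the real library types are richer, and deprecated members are omitted.

import kotlinx.coroutines.flow.Flow

// Stand-in types (assumptions for illustration; the real ones live in the xef models packages).
class ModelType(val name: String)
class Message
class ChatCompletionRequest
class ChatCompletionResponse
class ChatCompletionChunk
class FunChatCompletionRequest
class ChatCompletionResponseWithFunctions
class CompletionRequest
class CompletionResult

// Common base: modelType and token counting now live here instead of on each capability.
sealed interface LLM : AutoCloseable {
    val modelType: ModelType
    fun tokensFromMessages(messages: List<Message>): Int = 0 // real default counts tokens via modelType.encoding
    override fun close() = Unit
}

// Plain chat completions.
interface Chat : LLM {
    suspend fun createChatCompletion(request: ChatCompletionRequest): ChatCompletionResponse
    suspend fun createChatCompletions(request: ChatCompletionRequest): Flow<ChatCompletionChunk>
}

// Function calling is now a sibling of Chat rather than a subtype of it.
interface ChatWithFunctions : LLM {
    suspend fun createChatCompletionWithFunctions(
        request: FunChatCompletionRequest
    ): ChatCompletionResponseWithFunctions
    suspend fun createChatCompletionsWithFunctions(
        request: FunChatCompletionRequest
    ): Flow<ChatCompletionChunk>
}

// Text completions.
interface Completion : LLM {
    suspend fun createCompletion(request: CompletionRequest): CompletionResult
}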
47 changes: 9 additions & 38 deletions core/src/commonMain/kotlin/com/xebia/functional/xef/llm/Chat.kt
@@ -1,6 +1,5 @@
package com.xebia.functional.xef.llm

import com.xebia.functional.tokenizer.ModelType
import com.xebia.functional.xef.AIError
import com.xebia.functional.xef.conversation.AiDsl
import com.xebia.functional.xef.conversation.Conversation
@@ -10,22 +9,18 @@ import com.xebia.functional.xef.prompt.templates.assistant
import kotlinx.coroutines.flow.*

interface Chat : LLM {
val modelType: ModelType

suspend fun createChatCompletion(request: ChatCompletionRequest): ChatCompletionResponse

suspend fun createChatCompletions(request: ChatCompletionRequest): Flow<ChatCompletionChunk>

fun tokensFromMessages(messages: List<Message>): Int

@AiDsl
fun promptStreaming(prompt: Prompt, scope: Conversation): Flow<String> = flow {
val messagesForRequestPrompt =
PromptCalculator.adaptPromptToConversationAndModel(prompt, scope, this@Chat)

val request =
ChatCompletionRequest(
model = name,
user = prompt.configuration.user,
messages = messagesForRequestPrompt.messages,
n = prompt.configuration.numberOfPredictions,
@@ -39,8 +34,9 @@ interface Chat : LLM {
.onEach { emit(it) }
.fold("", String::plus)
.also { finalText ->
val message = assistant(finalText)
MemoryManagement.addMemoriesAfterStream(this@Chat, request, scope, listOf(message))
val aiResponseMessage = assistant(finalText)
val newMessages = prompt.messages + listOf(aiResponseMessage)
newMessages.addToMemory(this@Chat, scope)
}
}

@@ -50,46 +46,21 @@

@AiDsl
suspend fun promptMessages(prompt: Prompt, scope: Conversation): List<String> {

val requestedMemories = prompt.messages.toMemory(this@Chat, scope)
val adaptedPrompt = PromptCalculator.adaptPromptToConversationAndModel(prompt, scope, this@Chat)

fun chatRequest(): ChatCompletionRequest =
val request =
ChatCompletionRequest(
model = name,
user = adaptedPrompt.configuration.user,
messages = adaptedPrompt.messages,
n = adaptedPrompt.configuration.numberOfPredictions,
temperature = adaptedPrompt.configuration.temperature,
maxTokens = adaptedPrompt.configuration.minResponseTokens,
functions = listOfNotNull(adaptedPrompt.function),
functionCall = adaptedPrompt.function?.let { mapOf("name" to (it.name)) }
)

return MemoryManagement.run {
when (this@Chat) {
is ChatWithFunctions ->
// we only support functions for now with GPT_3_5_TURBO_FUNCTIONS
if (modelType == ModelType.GPT_3_5_TURBO_FUNCTIONS) {
val request = chatRequest()
createChatCompletionWithFunctions(request)
.choices
.addChoiceWithFunctionsToMemory(this@Chat, request, scope)
.mapNotNull { it.message?.functionCall?.arguments }
} else {
val request = chatRequest()
createChatCompletion(request)
.choices
.addChoiceToMemory(this@Chat, request, scope)
.mapNotNull { it.message?.content }
}
else -> {
val request = chatRequest()
createChatCompletion(request)
.choices
.addChoiceToMemory(this@Chat, request, scope)
.mapNotNull { it.message?.content }
}
}
}
return createChatCompletion(request)
.choices
.addMessagesToMemory(this@Chat, scope, requestedMemories)
.mapNotNull { it.message?.content }
}
}
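The promptStreaming flow above both forwards each streamed chunk to the caller and folds the chunks into the final assistant text before storing it. The following standalone sketch (not part of the PR) isolates that pattern with plain strings:

import kotlinx.coroutines.flow.*
import kotlinx.coroutines.runBlocking

// Re-emit every chunk as it arrives, then hand the concatenated text to a
// completion callback (in the PR, that callback turns the text into memories).
fun streamAndCollect(chunks: Flow<String>, onComplete: suspend (String) -> Unit): Flow<String> = flow {
    chunks
        .onEach { emit(it) }          // forward each chunk downstream
        .fold("", String::plus)       // accumulate the full response text
        .also { finalText -> onComplete(finalText) }
}

fun main() = runBlocking {
    streamAndCollect(flowOf("Hel", "lo, ", "world")) { full -> println("\nstored: $full") }
        .collect { chunk -> print(chunk) }
}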
core/src/commonMain/kotlin/com/xebia/functional/xef/llm/ChatWithFunctions.kt
@@ -1,13 +1,15 @@
package com.xebia.functional.xef.llm

import arrow.core.nel
import arrow.core.nonFatalOrThrow
import arrow.core.raise.catch
import com.xebia.functional.xef.AIError
import com.xebia.functional.xef.conversation.AiDsl
import com.xebia.functional.xef.conversation.Conversation
import com.xebia.functional.xef.llm.models.chat.ChatCompletionRequest
import com.xebia.functional.xef.llm.models.chat.ChatCompletionChunk
import com.xebia.functional.xef.llm.models.chat.ChatCompletionResponseWithFunctions
import com.xebia.functional.xef.llm.models.functions.CFunction
import com.xebia.functional.xef.llm.models.functions.FunChatCompletionRequest
import com.xebia.functional.xef.llm.models.functions.encodeJsonSchema
import com.xebia.functional.xef.prompt.Prompt
import io.github.oshai.kotlinlogging.KotlinLogging
@@ -17,12 +19,16 @@ import kotlinx.serialization.KSerializer
import kotlinx.serialization.descriptors.*
import kotlinx.serialization.json.*

interface ChatWithFunctions : Chat {
interface ChatWithFunctions : LLM {

suspend fun createChatCompletionWithFunctions(
request: ChatCompletionRequest
request: FunChatCompletionRequest
): ChatCompletionResponseWithFunctions

suspend fun createChatCompletionsWithFunctions(
request: FunChatCompletionRequest
): Flow<ChatCompletionChunk>

@OptIn(ExperimentalSerializationApi::class)
fun chatFunction(descriptor: SerialDescriptor): CFunction {
val fnName = descriptor.serialName.substringAfterLast(".")
@@ -60,11 +66,33 @@ interface ChatWithFunctions : Chat {
serializer: (json: String) -> A,
): A {
val promptWithFunctions = prompt.copy(function = function)
val adaptedPrompt =
PromptCalculator.adaptPromptToConversationAndModel(
promptWithFunctions,
scope,
this@ChatWithFunctions
)

val request =
FunChatCompletionRequest(
user = adaptedPrompt.configuration.user,
messages = adaptedPrompt.messages,
n = adaptedPrompt.configuration.numberOfPredictions,
temperature = adaptedPrompt.configuration.temperature,
maxTokens = adaptedPrompt.configuration.minResponseTokens,
functions = adaptedPrompt.function!!.nel(),
functionCall = mapOf("name" to (adaptedPrompt.function.name)),
)

return tryDeserialize(
serializer,
promptWithFunctions.configuration.maxDeserializationAttempts
) {
promptMessages(prompt = promptWithFunctions, scope = scope)
val requestedMemories = prompt.messages.toMemory(this@ChatWithFunctions, scope)
createChatCompletionWithFunctions(request)
.choices
.addMessagesToMemory(this@ChatWithFunctions, scope, requestedMemories)
.mapNotNull { it.message?.functionCall?.arguments }
}
}

@@ -84,16 +112,15 @@ interface ChatWithFunctions : Chat {
)

val request =
ChatCompletionRequest(
model = name,
FunChatCompletionRequest(
stream = true,
user = promptWithFunctions.configuration.user,
messages = messagesForRequestPrompt.messages,
functions = listOfNotNull(messagesForRequestPrompt.function),
n = promptWithFunctions.configuration.numberOfPredictions,
temperature = promptWithFunctions.configuration.temperature,
maxTokens = promptWithFunctions.configuration.minResponseTokens,
functionCall = mapOf("name" to (promptWithFunctions.function?.name ?: ""))
functions = promptWithFunctions.function!!.nel(),
functionCall = mapOf("name" to (promptWithFunctions.function.name)),
)

StreamedFunction.run {
@@ -102,6 +129,7 @@
) {
streamFunctionCall(
chat = this@ChatWithFunctions,
promptMessages = prompt.messages,
request = request,
scope = scope,
serializer = serializer,
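chatFunction above derives the exposed function name from a serializer's descriptor by taking the last segment of its serialName. A standalone illustration of just that step (not part of the PR; the Ticket class is a made-up example):

import kotlinx.serialization.Serializable
import kotlinx.serialization.serializer

@Serializable
data class Ticket(val title: String, val priority: Int)

fun main() {
    val descriptor = serializer<Ticket>().descriptor
    // serialName is the (possibly qualified) class name; keep only the last segment.
    val fnName = descriptor.serialName.substringAfterLast(".")
    println(fnName) // prints: Ticket
}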
core/src/commonMain/kotlin/com/xebia/functional/xef/llm/Completion.kt
@@ -1,11 +1,8 @@
package com.xebia.functional.xef.llm

import com.xebia.functional.tokenizer.ModelType
import com.xebia.functional.xef.llm.models.text.CompletionRequest
import com.xebia.functional.xef.llm.models.text.CompletionResult

interface Completion : LLM {
val modelType: ModelType

suspend fun createCompletion(request: CompletionRequest): CompletionResult
}
29 changes: 27 additions & 2 deletions core/src/commonMain/kotlin/com/xebia/functional/xef/llm/LLM.kt
@@ -1,7 +1,32 @@
package com.xebia.functional.xef.llm

import com.xebia.functional.tokenizer.Encoding
import com.xebia.functional.tokenizer.ModelType
import com.xebia.functional.xef.llm.models.chat.Message

sealed interface LLM : AutoCloseable {
val name: String

override fun close() {}
val modelType: ModelType

@Deprecated("use modelType.name instead", replaceWith = ReplaceWith("modelType.name"))
val name
get() = modelType.name

fun tokensFromMessages(
messages: List<Message>
): Int { // TODO: naive implementation with magic numbers
fun Encoding.countTokensFromMessages(tokensPerMessage: Int, tokensPerName: Int): Int =
messages.sumOf { message ->
countTokens(message.role.name) +
countTokens(message.content) +
tokensPerMessage +
tokensPerName
} + 3
return modelType.encoding.countTokensFromMessages(
tokensPerMessage = modelType.tokensPerMessage,
tokensPerName = modelType.tokensPerName
) + modelType.tokenPadding
}

override fun close() = Unit
}
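The default tokensFromMessages above reduces to straightforward arithmetic over per-message token counts: each message contributes its role and content tokens plus two per-message constants, and the total gets a fixed +3 plus the model's token padding. A standalone sketch of that arithmetic with assumed example numbers (not real tokenizer output):

// Per-message counts as a tokenizer might report them (values assumed for illustration).
data class CountedMessage(val roleTokens: Int, val contentTokens: Int)

fun estimateTokens(
    messages: List<CountedMessage>,
    tokensPerMessage: Int,
    tokensPerName: Int,
    tokenPadding: Int
): Int {
    val perMessage = messages.sumOf { it.roleTokens + it.contentTokens + tokensPerMessage + tokensPerName }
    return perMessage + 3 + tokenPadding // the +3 mirrors the constant in the default implementation
}

fun main() {
    // Two messages of 1 role token and 20 content tokens each, with assumed
    // model constants tokensPerMessage = 3, tokensPerName = 1, tokenPadding = 5:
    // (1 + 20 + 3 + 1) * 2 + 3 + 5 = 58
    println(estimateTokens(listOf(CountedMessage(1, 20), CountedMessage(1, 20)), 3, 1, 5))
}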
@@ -5,100 +5,67 @@ import com.xebia.functional.xef.llm.models.chat.*
import com.xebia.functional.xef.store.Memory
import io.ktor.util.date.*

internal object MemoryManagement {
internal suspend fun List<Message>.addToMemory(chat: LLM, scope: Conversation) {
val memories = toMemory(chat, scope)
if (memories.isNotEmpty()) {
scope.store.addMemories(memories)
}
}

internal suspend fun addMemoriesAfterStream(
chat: Chat,
request: ChatCompletionRequest,
scope: Conversation,
messages: List<Message>,
) {
val lastRequestMessage = request.messages.lastOrNull()
val cid = scope.conversationId
if (cid != null && lastRequestMessage != null) {
val requestMemory =
Memory(
conversationId = cid,
content = lastRequestMessage,
timestamp = getTimeMillis(),
approxTokens = chat.tokensFromMessages(listOf(lastRequestMessage))
)
val responseMemories =
messages.map {
Memory(
conversationId = cid,
content = it,
timestamp = getTimeMillis(),
approxTokens = chat.tokensFromMessages(messages)
)
}
scope.store.addMemories(listOf(requestMemory) + responseMemories)
internal fun List<Message>.toMemory(chat: LLM, scope: Conversation): List<Memory> {
val cid = scope.conversationId
return if (cid != null) {
mapIndexed { delta, it ->
Memory(
conversationId = cid,
content = it,
// We are adding the delta to ensure that the timestamp is unique for every message.
// With this, we ensure that the order of the messages is preserved.
// We assume that the AI response time will be in the order of seconds.
timestamp = getTimeMillis() + delta,
approxTokens = chat.tokensFromMessages(listOf(it))
)
}
}
} else emptyList()
}

internal suspend fun List<ChoiceWithFunctions>.addChoiceWithFunctionsToMemory(
chat: Chat,
request: ChatCompletionRequest,
scope: Conversation
): List<ChoiceWithFunctions> = also {
val firstChoice = firstOrNull()
val requestUserMessage = request.messages.lastOrNull()
val cid = scope.conversationId
if (requestUserMessage != null && firstChoice != null && cid != null) {
val role = firstChoice.message?.role?.uppercase()?.let { Role.valueOf(it) } ?: Role.USER
internal suspend fun List<ChoiceWithFunctions>.addMessagesToMemory(
chat: LLM,
scope: Conversation,
previousMemories: List<Memory>
): List<ChoiceWithFunctions> = also {
val firstChoice = firstOrNull()
val cid = scope.conversationId
if (firstChoice != null && cid != null) {
val role = firstChoice.message?.role?.uppercase()?.let { Role.valueOf(it) } ?: Role.USER

val requestMemory =
Memory(
conversationId = cid,
content = requestUserMessage,
timestamp = getTimeMillis(),
approxTokens = chat.tokensFromMessages(listOf(requestUserMessage))
)
val firstChoiceMessage =
Message(
role = role,
content = firstChoice.message?.content
?: firstChoice.message?.functionCall?.arguments ?: "",
name = role.name
)
val firstChoiceMemory =
Memory(
conversationId = cid,
content = firstChoiceMessage,
timestamp = getTimeMillis(),
approxTokens = chat.tokensFromMessages(listOf(firstChoiceMessage))
)
scope.store.addMemories(listOf(requestMemory, firstChoiceMemory))
}
val firstChoiceMessage =
Message(
role = role,
content = firstChoice.message?.content
?: firstChoice.message?.functionCall?.arguments ?: "",
name = role.name
)

val newMessages = previousMemories + listOf(firstChoiceMessage).toMemory(chat, scope)
scope.store.addMemories(newMessages)
}
}

internal suspend fun List<Choice>.addChoiceToMemory(
chat: Chat,
request: ChatCompletionRequest,
scope: Conversation
): List<Choice> = also {
val firstChoice = firstOrNull()
val requestUserMessage = request.messages.lastOrNull()
val cid = scope.conversationId
if (requestUserMessage != null && firstChoice != null && cid != null) {
val role = firstChoice.message?.role?.name?.uppercase()?.let { Role.valueOf(it) } ?: Role.USER
val requestMemory =
Memory(
conversationId = cid,
content = requestUserMessage,
timestamp = getTimeMillis(),
approxTokens = chat.tokensFromMessages(listOf(requestUserMessage))
)
val firstChoiceMessage =
Message(role = role, content = firstChoice.message?.content ?: "", name = role.name)
val firstChoiceMemory =
Memory(
conversationId = cid,
content = firstChoiceMessage,
timestamp = getTimeMillis(),
approxTokens = chat.tokensFromMessages(listOf(firstChoiceMessage))
)
scope.store.addMemories(listOf(requestMemory, firstChoiceMemory))
}
internal suspend fun List<Choice>.addMessagesToMemory(
chat: Chat,
scope: Conversation,
previousMemories: List<Memory>
): List<Choice> = also {
val firstChoice = firstOrNull()
val cid = scope.conversationId
if (firstChoice != null && cid != null) {
val role = firstChoice.message?.role?.name?.uppercase()?.let { Role.valueOf(it) } ?: Role.USER

val firstChoiceMessage =
Message(role = role, content = firstChoice.message?.content ?: "", name = role.name)

val newMessages = previousMemories + listOf(firstChoiceMessage).toMemory(chat, scope)
scope.store.addMemories(newMessages)
}
}
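toMemory above adds the message's index to the current time so that messages written in the same millisecond keep their order when memories are later sorted by timestamp. A standalone sketch of that trick (not part of the PR; it uses the JVM clock for simplicity):

data class StoredMessage(val content: String, val timestamp: Long)

fun main() {
    val now = System.currentTimeMillis()
    // All three messages are created "at the same time"; the index keeps their order stable.
    val batch = listOf("system prompt", "user question", "assistant reply")
        .mapIndexed { delta, content -> StoredMessage(content, now + delta) }

    // Sorting by timestamp reproduces the original order.
    println(batch.sortedBy { it.timestamp }.map { it.content })
}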