Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

decompose class hierarchy of LLM implementations #397

Merged
merged 16 commits into from
Sep 22, 2023
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 11 additions & 35 deletions core/src/commonMain/kotlin/com/xebia/functional/xef/llm/Chat.kt
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package com.xebia.functional.xef.llm

import com.xebia.functional.tokenizer.ModelType
import com.xebia.functional.xef.AIError
import com.xebia.functional.xef.conversation.AiDsl
import com.xebia.functional.xef.conversation.Conversation
Expand All @@ -10,22 +9,18 @@ import com.xebia.functional.xef.prompt.templates.assistant
import kotlinx.coroutines.flow.*

interface Chat : LLM {
val modelType: ModelType

suspend fun createChatCompletion(request: ChatCompletionRequest): ChatCompletionResponse

suspend fun createChatCompletions(request: ChatCompletionRequest): Flow<ChatCompletionChunk>

fun tokensFromMessages(messages: List<Message>): Int

@AiDsl
fun promptStreaming(prompt: Prompt, scope: Conversation): Flow<String> = flow {
val messagesForRequestPrompt =
PromptCalculator.adaptPromptToConversationAndModel(prompt, scope, this@Chat)

val request =
ChatCompletionRequest(
model = name,
user = prompt.configuration.user,
messages = messagesForRequestPrompt.messages,
n = prompt.configuration.numberOfPredictions,
Expand All @@ -40,7 +35,12 @@ interface Chat : LLM {
.fold("", String::plus)
.also { finalText ->
val message = assistant(finalText)
MemoryManagement.addMemoriesAfterStream(this@Chat, request, scope, listOf(message))
MemoryManagement.addMemoriesAfterStream(
this@Chat,
request.messages.lastOrNull(),
scope,
listOf(message)
)
}
}

Expand All @@ -50,46 +50,22 @@ interface Chat : LLM {

@AiDsl
suspend fun promptMessages(prompt: Prompt, scope: Conversation): List<String> {

val adaptedPrompt = PromptCalculator.adaptPromptToConversationAndModel(prompt, scope, this@Chat)

fun chatRequest(): ChatCompletionRequest =
val request =
ChatCompletionRequest(
model = name,
user = adaptedPrompt.configuration.user,
messages = adaptedPrompt.messages,
n = adaptedPrompt.configuration.numberOfPredictions,
temperature = adaptedPrompt.configuration.temperature,
maxTokens = adaptedPrompt.configuration.minResponseTokens,
functions = listOfNotNull(adaptedPrompt.function),
functionCall = adaptedPrompt.function?.let { mapOf("name" to (it.name)) }
)

return MemoryManagement.run {
when (this@Chat) {
is ChatWithFunctions ->
// we only support functions for now with GPT_3_5_TURBO_FUNCTIONS
if (modelType == ModelType.GPT_3_5_TURBO_FUNCTIONS) {
val request = chatRequest()
createChatCompletionWithFunctions(request)
.choices
.addChoiceWithFunctionsToMemory(this@Chat, request, scope)
.mapNotNull { it.message?.functionCall?.arguments }
} else {
val request = chatRequest()
createChatCompletion(request)
.choices
.addChoiceToMemory(this@Chat, request, scope)
.mapNotNull { it.message?.content }
}
else -> {
val request = chatRequest()
createChatCompletion(request)
.choices
.addChoiceToMemory(this@Chat, request, scope)
.mapNotNull { it.message?.content }
}
}
createChatCompletion(request)
.choices
.addChoiceToMemory(this@Chat, request, scope)
.mapNotNull { it.message?.content }
}
}
}
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
package com.xebia.functional.xef.llm

import arrow.core.nel
import arrow.core.nonFatalOrThrow
import arrow.core.raise.catch
import com.xebia.functional.xef.AIError
import com.xebia.functional.xef.conversation.AiDsl
import com.xebia.functional.xef.conversation.Conversation
import com.xebia.functional.xef.llm.models.chat.ChatCompletionRequest
import com.xebia.functional.xef.llm.models.chat.ChatCompletionChunk
import com.xebia.functional.xef.llm.models.chat.ChatCompletionResponseWithFunctions
import com.xebia.functional.xef.llm.models.functions.CFunction
import com.xebia.functional.xef.llm.models.functions.FunChatCompletionRequest
import com.xebia.functional.xef.llm.models.functions.encodeJsonSchema
import com.xebia.functional.xef.prompt.Prompt
import io.github.oshai.kotlinlogging.KotlinLogging
Expand All @@ -17,12 +19,16 @@ import kotlinx.serialization.KSerializer
import kotlinx.serialization.descriptors.*
import kotlinx.serialization.json.*

interface ChatWithFunctions : Chat {
interface ChatWithFunctions : LLM {

suspend fun createChatCompletionWithFunctions(
request: ChatCompletionRequest
request: FunChatCompletionRequest
): ChatCompletionResponseWithFunctions

suspend fun createChatCompletionsWithFunctions(
request: FunChatCompletionRequest
): Flow<ChatCompletionChunk>

@OptIn(ExperimentalSerializationApi::class)
fun chatFunction(descriptor: SerialDescriptor): CFunction {
val fnName = descriptor.serialName.substringAfterLast(".")
Expand Down Expand Up @@ -60,11 +66,38 @@ interface ChatWithFunctions : Chat {
serializer: (json: String) -> A,
): A {
val promptWithFunctions = prompt.copy(function = function)
val adaptedPrompt =
PromptCalculator.adaptPromptToConversationAndModel(
promptWithFunctions,
scope,
this@ChatWithFunctions
)

val request =
FunChatCompletionRequest(
user = adaptedPrompt.configuration.user,
messages = adaptedPrompt.messages,
n = adaptedPrompt.configuration.numberOfPredictions,
temperature = adaptedPrompt.configuration.temperature,
maxTokens = adaptedPrompt.configuration.minResponseTokens,
functions = adaptedPrompt.function!!.nel(),
functionCall = mapOf("name" to (adaptedPrompt.function.name)),
)

return tryDeserialize(
serializer,
promptWithFunctions.configuration.maxDeserializationAttempts
) {
promptMessages(prompt = promptWithFunctions, scope = scope)
MemoryManagement.run {
createChatCompletionWithFunctions(request)
.choices
.addChoiceWithFunctionsToMemory(
this@ChatWithFunctions,
request.messages.lastOrNull(),
scope
)
.mapNotNull { it.message?.functionCall?.arguments }
}
}
}

Expand All @@ -84,16 +117,15 @@ interface ChatWithFunctions : Chat {
)

val request =
ChatCompletionRequest(
model = name,
FunChatCompletionRequest(
stream = true,
user = promptWithFunctions.configuration.user,
messages = messagesForRequestPrompt.messages,
functions = listOfNotNull(messagesForRequestPrompt.function),
n = promptWithFunctions.configuration.numberOfPredictions,
temperature = promptWithFunctions.configuration.temperature,
maxTokens = promptWithFunctions.configuration.minResponseTokens,
functionCall = mapOf("name" to (promptWithFunctions.function?.name ?: ""))
functions = promptWithFunctions.function!!.nel(),
functionCall = mapOf("name" to (promptWithFunctions.function.name)),
)

StreamedFunction.run {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
package com.xebia.functional.xef.llm

import com.xebia.functional.tokenizer.ModelType
import com.xebia.functional.xef.llm.models.text.CompletionRequest
import com.xebia.functional.xef.llm.models.text.CompletionResult

interface Completion : LLM {
val modelType: ModelType

suspend fun createCompletion(request: CompletionRequest): CompletionResult
}
30 changes: 28 additions & 2 deletions core/src/commonMain/kotlin/com/xebia/functional/xef/llm/LLM.kt
Original file line number Diff line number Diff line change
@@ -1,7 +1,33 @@
package com.xebia.functional.xef.llm

import com.xebia.functional.tokenizer.Encoding
import com.xebia.functional.tokenizer.ModelType
import com.xebia.functional.xef.llm.models.chat.Message

sealed interface LLM : AutoCloseable {
val name: String

override fun close() {}
val modelType: ModelType

@Deprecated("use modelType.name instead", replaceWith = ReplaceWith("modelType.name"))
val name
get() = modelType.name

fun tokensFromMessages(
messages: List<Message>
): Int { // TODO: naive implementation with magic numbers
fun Encoding.countTokensFromMessages(tokensPerMessage: Int, tokensPerName: Int): Int =
messages.sumOf { message ->
message.role.name.length
Intex32 marked this conversation as resolved.
Show resolved Hide resolved
countTokens(message.role.name) +
countTokens(message.content) +
tokensPerMessage +
tokensPerName
} + 3
return modelType.encoding.countTokensFromMessages(
tokensPerMessage = modelType.tokensPerMessage,
tokensPerName = modelType.tokensPerName
) + modelType.tokenPadding
}

override fun close() = Unit
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,11 @@ import io.ktor.util.date.*
internal object MemoryManagement {

internal suspend fun addMemoriesAfterStream(
chat: Chat,
request: ChatCompletionRequest,
chat: LLM,
lastRequestMessage: Message?,
Intex32 marked this conversation as resolved.
Show resolved Hide resolved
scope: Conversation,
messages: List<Message>,
) {
val lastRequestMessage = request.messages.lastOrNull()
val cid = scope.conversationId
if (cid != null && lastRequestMessage != null) {
val requestMemory =
Expand All @@ -37,12 +36,11 @@ internal object MemoryManagement {
}

internal suspend fun List<ChoiceWithFunctions>.addChoiceWithFunctionsToMemory(
chat: Chat,
request: ChatCompletionRequest,
chat: LLM,
requestUserMessage: Message?,
Intex32 marked this conversation as resolved.
Show resolved Hide resolved
scope: Conversation
): List<ChoiceWithFunctions> = also {
val firstChoice = firstOrNull()
val requestUserMessage = request.messages.lastOrNull()
val cid = scope.conversationId
if (requestUserMessage != null && firstChoice != null && cid != null) {
val role = firstChoice.message?.role?.uppercase()?.let { Role.valueOf(it) } ?: Role.USER
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ internal object PromptCalculator {
suspend fun adaptPromptToConversationAndModel(
prompt: Prompt,
scope: Conversation,
chat: Chat
chat: LLM
): Prompt {

// calculate tokens for history and context
Expand Down Expand Up @@ -55,7 +55,7 @@ internal object PromptCalculator {
memories.map { it.content }

private fun calculateMessagesFromHistory(
chat: Chat,
chat: LLM,
memories: List<Memory>,
maxHistoryTokens: Int
) =
Expand Down Expand Up @@ -94,7 +94,7 @@ internal object PromptCalculator {
return maxHistoryTokens
}

private fun calculateRemainingTokensForContext(chat: Chat, prompt: Prompt): Int {
private fun calculateRemainingTokensForContext(chat: LLM, prompt: Prompt): Int {
val maxContextLength: Int = chat.modelType.maxContextLength
val remainingTokens: Int = maxContextLength - prompt.configuration.minResponseTokens

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ package com.xebia.functional.xef.llm

import com.xebia.functional.xef.conversation.Conversation
import com.xebia.functional.xef.llm.StreamedFunction.Companion.PropertyType.*
import com.xebia.functional.xef.llm.models.chat.ChatCompletionRequest
import com.xebia.functional.xef.llm.models.chat.Message
import com.xebia.functional.xef.llm.models.functions.CFunction
import com.xebia.functional.xef.llm.models.functions.FunChatCompletionRequest
import com.xebia.functional.xef.llm.models.functions.FunctionCall
import com.xebia.functional.xef.prompt.templates.assistant
import kotlin.jvm.JvmSynthetic
Expand Down Expand Up @@ -39,7 +39,7 @@ sealed class StreamedFunction<out A> {
@JvmSynthetic
internal suspend fun <A> FlowCollector<StreamedFunction<A>>.streamFunctionCall(
chat: ChatWithFunctions,
request: ChatCompletionRequest,
request: FunChatCompletionRequest,
scope: Conversation,
serializer: (json: String) -> A,
function: CFunction
Expand All @@ -59,8 +59,15 @@ sealed class StreamedFunction<out A> {
// as the LLM is sending us chunks with malformed JSON
val example = createExampleFromSchema(schema)
chat
.createChatCompletions(request)
.onCompletion { MemoryManagement.addMemoriesAfterStream(chat, request, scope, messages) }
.createChatCompletionsWithFunctions(request)
.onCompletion {
MemoryManagement.addMemoriesAfterStream(
chat,
request.messages.lastOrNull(),
scope,
messages
)
}
.collect { responseChunk ->
// Each chunk is emitted from the LLM and it will include a delta.parameters with
// the function is streaming, the JSON received will be partial and usually malformed
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,7 @@
package com.xebia.functional.xef.llm.models.chat

import com.xebia.functional.xef.llm.models.functions.CFunction

data class ChatCompletionRequest(
val model: String,
val messages: List<Message>,
val functions: List<CFunction> = emptyList(),
val temperature: Double = 0.0,
val topP: Double = 1.0,
val n: Int = 1,
Expand All @@ -17,5 +13,4 @@ data class ChatCompletionRequest(
val logitBias: Map<String, Int> = emptyMap(),
val user: String?,
val streamToStandardOut: Boolean = false,
val functionCall: Map<String, String>? = null
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package com.xebia.functional.xef.llm.models.functions

import arrow.core.Nel
import com.xebia.functional.xef.llm.models.chat.Message

data class FunChatCompletionRequest(
val messages: List<Message>,
val temperature: Double = 0.0,
val topP: Double = 1.0,
val n: Int = 1,
val stream: Boolean = false,
val stop: List<String>? = null,
val maxTokens: Int? = null,
val presencePenalty: Double = 0.0,
val frequencyPenalty: Double = 0.0,
val logitBias: Map<String, Int> = emptyMap(),
val user: String?,
val streamToStandardOut: Boolean = false,
val functions: Nel<CFunction>,
javipacheco marked this conversation as resolved.
Show resolved Hide resolved
val functionCall: Map<String, String>? = null,
)
Loading