Skip to content

Commit

Permalink
Generate tokens optimistically matching a pattern (#155)
Browse files Browse the repository at this point in the history
* feat: generate optimistically some pattern tokens, reducing the overall calls
* refactor: improve readability on patternPrompt
* example: update Alphabet.kt pessimistic tokens part
* style: spotless

---------

Co-authored-by: Juan Pedro Moreno <[email protected]>
  • Loading branch information
realdavidvega and juanpedromoreno authored Jun 5, 2023
1 parent da5c41e commit 5440ff0
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 39 deletions.
118 changes: 89 additions & 29 deletions core/src/jvmMain/kotlin/com/xebia/functional/xef/agents/LLMAgent.kt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.xebia.functional.xef.agents

import com.xebia.functional.tokenizer.Encoding
import com.xebia.functional.tokenizer.TokenVocabulary
import com.xebia.functional.xef.auto.AIScope
import com.xebia.functional.xef.llm.openai.ChatCompletionRequest
Expand All @@ -16,9 +17,11 @@ suspend fun AIScope.patternPrompt(
user: String = "testing",
n: Int = 1,
echo: Boolean = false,
temperature: Double = 0.0,
maxNewTokens: Int = 30,
stopAfterMatch: Boolean = true
temperature: Double = 0.5,
maxIterations: Int = 30,
maxTokensPerCompletion: Int = 1,
stopAfterMatch: Boolean = true,
logitBiasMaxSize: Int = 300
): String =
patternPrompt(
prompt,
Expand All @@ -28,11 +31,13 @@ suspend fun AIScope.patternPrompt(
n,
echo,
temperature,
maxNewTokens,
maxIterations,
maxTokensPerCompletion,
stopAfterMatch,
genTokens = 0,
iterations = 0,
partialCompletion = "",
tokenVocab = TokenVocabulary(model.modelType.encodingType)
tokensVocab = TokenVocabulary(model.modelType.encodingType),
logitBiasMaxSize = logitBiasMaxSize
)

private suspend fun AIScope.patternPrompt(
Expand All @@ -43,24 +48,58 @@ private suspend fun AIScope.patternPrompt(
n: Int,
echo: Boolean,
temperature: Double,
maxNewTokens: Int,
maxIterations: Int,
maxTokensPerCompletion: Int = 1,
stopAfterMatch: Boolean,
genTokens: Int,
iterations: Int,
partialCompletion: String,
tokenVocab: TokenVocabulary
tokensVocab: TokenVocabulary,
logitBiasMaxSize: Int = 300
): String {
if (genTokens >= maxNewTokens) return partialCompletion
if (iterations >= maxIterations) return partialCompletion

val logitBias: Map<String, Int> = tokenVocab.buildLogitBias(partialCompletion, pattern)
val patternLogitBias: Map<String, Int> =
tokensVocab.buildPatternLogitBias(partialCompletion, pattern, maxLength = logitBiasMaxSize)

val encoding: Encoding = model.modelType.encodingType.encoding

val logitBias: Map<String, Int> =
if (patternLogitBias.size < logitBiasMaxSize && maxTokensPerCompletion > 1) {
buildMap {
putAll(patternLogitBias)

patternLogitBias.entries.asSequence().forEach { (key: String) ->
val token: String = encoding.decode(listOf(key.toInt()))
val tokenLogitBias: Map<String, Int> =
tokensVocab.buildPatternLogitBias(
partialCompletion = partialCompletion + token,
pattern,
maxLength = logitBiasMaxSize - size
)
putAll(tokenLogitBias)
}
}
} else {
patternLogitBias
}

val outputCompletion: List<String> =
patternPrompt(model, user, prompt, echo, n, temperature, logitBias)
patternPrompt(model, user, prompt, echo, n, temperature, logitBias, maxTokensPerCompletion)

val nextPartialCompletion: String = partialCompletion + outputCompletion[0]
val nextPromptPlusCompletion: String = prompt + outputCompletion[0]
val output: String = outputCompletion[0]
val nextPartialCompletionOutput: String = partialCompletion + output

if (stopAfterMatch && pattern.matches(nextPartialCompletion)) {
return nextPartialCompletion
val cleanOutput: String =
nextPartialCompletionOutput
.removeValuesFromEndUntilRegexMet(tokensVocab.decodedTokens, pattern)
.replace(partialCompletion, "")
.ifEmpty { encoding.getLongestMatchingPattern(patternLogitBias) }

val nextCleanPartialCompletion: String = partialCompletion + cleanOutput
val nextPromptPlusCompletion: String = prompt + cleanOutput

if (stopAfterMatch && pattern.matches(nextCleanPartialCompletion)) {
return nextCleanPartialCompletion
}

println(nextPromptPlusCompletion)
Expand All @@ -73,11 +112,12 @@ private suspend fun AIScope.patternPrompt(
n,
echo,
temperature,
maxNewTokens,
maxIterations,
maxTokensPerCompletion,
stopAfterMatch,
genTokens = genTokens + 1,
nextPartialCompletion,
tokenVocab
iterations = iterations + 1,
nextCleanPartialCompletion,
tokensVocab
)
}

Expand All @@ -88,7 +128,8 @@ private suspend fun AIScope.patternPrompt(
echo: Boolean,
n: Int,
temperature: Double,
logitBias: Map<String, Int>
logitBias: Map<String, Int>,
maxTokensPerCompletion: Int = 1,
): List<String> =
when (model.kind) {
LLMModel.Kind.Completion -> {
Expand All @@ -100,7 +141,7 @@ private suspend fun AIScope.patternPrompt(
echo = echo,
n = n,
temperature = temperature,
maxTokens = 1,
maxTokens = maxTokensPerCompletion,
logitBias = logitBias
)
openAIClient.createCompletion(request).choices.map { it.text }
Expand All @@ -114,27 +155,46 @@ private suspend fun AIScope.patternPrompt(
temperature = temperature,
n = n,
user = user,
maxTokens = 1,
maxTokens = maxTokensPerCompletion,
logitBias = logitBias
)
openAIClient.createChatCompletion(request).choices.map { it.message.content }
}
}

/**
 * Builds an OpenAI `logit_bias` map that strongly favors tokens able to extend
 * [partialCompletion] toward a full match of [pattern].
 *
 * Each candidate token from the vocabulary is appended to [partialCompletion]
 * and kept only if the result is still a (partial) match of [pattern].
 *
 * @param partialCompletion text generated so far.
 * @param pattern regex the final completion must satisfy.
 * @param maxLength cap on map size (the OpenAI API accepts at most 300 entries).
 * @param bias weight assigned to each surviving token; 100 effectively forces
 *   the model to pick only from these tokens.
 * @return map from token id (as a string key, per the API contract) to bias.
 */
private fun TokenVocabulary.buildPatternLogitBias(
  partialCompletion: String,
  pattern: Regex,
  maxLength: Int = 300,
  bias: Int = 100
): Map<String, Int> = buildMap {
  decodedTokens
    .asSequence()
    // keep tokens that could still lead to a match of the pattern
    .filter { pattern.partialMatch(partialCompletion + it.value) }
    .take(maxLength)
    .forEach { put("${it.key}", bias) }
}

/**
 * Returns true when [input] fully matches this regex, or when the matcher ran
 * off the end of [input] while still on a viable path (i.e. appending more
 * characters could still produce a match).
 *
 * Note: [Matcher.hitEnd] reports on the *last* match attempt, so `matches()`
 * must be invoked first; `||` short-circuits only when `matches()` already
 * succeeded, in which case the partial-match answer is true anyway.
 */
private fun Regex.partialMatch(input: String): Boolean {
  val matcher: Matcher = toPattern().matcher(input)
  return matcher.matches() || matcher.hitEnd()
}

/**
 * Drops characters from the end of this string, one at a time, until the
 * remainder either fully matches [pattern] or could still be extended into a
 * match by some token in [tokens] (see [matchesRegex]), or becomes empty.
 *
 * Marked `tailrec` so the compiler rewrites the recursion into a loop —
 * the original plain recursion risked a StackOverflowError for each dropped
 * character on long completions. The call is in tail position, so behavior
 * is otherwise identical.
 *
 * @param tokens vocabulary of decoded tokens keyed by token id.
 * @param pattern regex the cleaned string should satisfy (at least partially).
 */
private tailrec fun String.removeValuesFromEndUntilRegexMet(
  tokens: Map<Int, String>,
  pattern: Regex
): String =
  if (isEmpty() || matchesRegex(tokens, pattern)) {
    this
  } else {
    dropLast(n = 1).removeValuesFromEndUntilRegexMet(tokens, pattern)
  }

/**
 * Decodes every token id present in [patternLogitBias] and returns the longest
 * decoded string, or the empty string when the bias map has no entries.
 */
private fun Encoding.getLongestMatchingPattern(patternLogitBias: Map<String, Int>): String {
  val decodedCandidates: List<String> =
    patternLogitBias.keys.map { tokenId -> decode(listOf(tokenId.toInt())) }
  return decodedCandidates.maxByOrNull(String::length).orEmpty()
}

/**
 * True when this string fully matches [pattern], or when appending at least one
 * vocabulary token from [tokens] keeps it on a viable path toward a match
 * (via [partialMatch]).
 */
private fun String.matchesRegex(tokens: Map<Int, String>, pattern: Regex): Boolean {
  if (pattern.matches(this)) return true
  return tokens.values.asSequence().any { token -> pattern.partialMatch(this + token) }
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,26 +3,33 @@ package com.xebia.functional.xef.auto.pattern
import com.xebia.functional.xef.agents.patternPrompt
import com.xebia.functional.xef.auto.ai
import com.xebia.functional.xef.auto.getOrElse
import com.xebia.functional.xef.auto.prompt
import kotlinx.serialization.json.Json

/**
 * Example: constrain completions to a JSON array of three lowercase letters.
 *
 * Runs the "pessimistic" strategy (one token per API call) and, when
 * [enableComparison] is flipped on, the "optimistic" strategy (up to three
 * tokens per call), printing the decoded list for each.
 */
suspend fun main() {
  // Toggle to also run the optimistic (multi-token) variant for comparison.
  val enableComparison = false

  ai {
    val goal = "Return the first three letters of the alphabet in the format of a JSON array: "
    // e.g. ["a", "b", "c"]
    val pattern = Regex("""\["[a-z]", "[a-z]", "[a-z]"]""")

    // Pessimistic: exactly one token generated per completion request.
    val pessimistic: String = patternPrompt(
      prompt = goal,
      pattern = pattern,
      maxIterations = 10,
      maxTokensPerCompletion = 1
    )
    val pessimisticDecoded: List<String> = Json.decodeFromString(pessimistic)
    println(pessimisticDecoded)

    if (enableComparison) {
      // Optimistic: up to three tokens per request, fewer round trips.
      val optimistic: String = patternPrompt(
        prompt = goal,
        pattern = pattern,
        maxIterations = 10,
        maxTokensPerCompletion = 3
      )
      val optimisticDecoded: List<String> = Json.decodeFromString(optimistic)
      println(optimisticDecoded)
    }

  }.getOrElse { println(it) }
}

0 comments on commit 5440ff0

Please sign in to comment.