Commit e6b32cb

adjust AutoClose and scope of provider client

Intex32 committed Sep 18, 2023
1 parent 31a8e89 commit e6b32cb

Showing 7 changed files with 55 additions and 47 deletions.

@@ -45,8 +45,12 @@ class GCP(projectId: String? = null, location: VertexAIRegion? = null, token: St
   private fun fromEnv(name: String): String =
     getenv(name) ?: throw AIError.Env.GCP(nonEmptyListOf("missing $name env var"))
 
-  val CODECHAT by lazy { GcpChat(ModelType.TODO("codechat-bison@001"), config) }
-  val TEXT_EMBEDDING_GECKO by lazy { GcpEmbeddings(ModelType.TODO("textembedding-gecko"), config) }
+  val defaultClient = GcpClient(config)
+
+  val CODECHAT by lazy { GcpChat(ModelType.TODO("codechat-bison@001"), defaultClient) }
+  val TEXT_EMBEDDING_GECKO by lazy {
+    GcpEmbeddings(ModelType.TODO("textembedding-gecko"), defaultClient)
+  }
 
   @JvmField val DEFAULT_CHAT = CODECHAT
   @JvmField val DEFAULT_EMBEDDING = TEXT_EMBEDDING_GECKO
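The hunk above moves client construction out of the individual models: GCP now builds a single provider-scoped defaultClient that CODECHAT and TEXT_EMBEDDING_GECKO share. A minimal usage sketch, assuming the GCP constructor shown in the hunk header; the project id is a hypothetical placeholder:

// Sketch only: both models obtained from this GCP instance reuse one
// GcpClient (and one underlying HTTP client) instead of owning their own.
val gcp = GCP(projectId = "my-project")    // "my-project" is hypothetical
val chat = gcp.CODECHAT                    // lazily built against gcp.defaultClient
val embeddings = gcp.TEXT_EMBEDDING_GECKO  // same shared client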
@@ -18,7 +18,6 @@ import kotlinx.serialization.SerialName
 import kotlinx.serialization.Serializable
 
 class GcpClient(
-  val modelId: String,
   private val config: GcpConfig,
 ) : AutoClose by autoClose() {
   private val http: HttpClient = jsonHttpClient()
@@ -66,6 +65,7 @@ class GcpClient(
   )
 
   suspend fun promptMessage(
+    modelId: String,
     prompt: String,
     temperature: Double? = null,
     maxOutputTokens: Int? = null,
@@ -131,7 +131,7 @@ class GcpClient(
       )
     val response =
       http.post(
-        "https://${config.location.officialName}-aiplatform.googleapis.com/v1/projects/${config.projectId}/locations/${config.location.officialName}/publishers/google/models/$modelId:predict"
+        "https://${config.location.officialName}-aiplatform.googleapis.com/v1/projects/${config.projectId}/locations/${config.location.officialName}/publishers/google/models/${request.model}:predict"
       ) {
         header("Authorization", "Bearer ${config.token}")
         contentType(ContentType.Application.Json)
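With modelId removed from the constructor, a single GcpClient can serve any model; the id now travels with each call instead. A hedged call-site sketch based on the promptMessage signature above; prompt text and parameter values are illustrative:

// Sketch: promptMessage is a suspend function, so it must be called
// from a coroutine; the model id is now chosen per request.
suspend fun demo(client: GcpClient): String =
  client.promptMessage(
    modelId = "codechat-bison@001",
    prompt = "Summarize this commit in one sentence.",
    temperature = 0.2,
  )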
@@ -1,9 +1,7 @@
 package com.xebia.functional.xef.gcp.models
 
 import com.xebia.functional.tokenizer.ModelType
-import com.xebia.functional.xef.conversation.autoClose
 import com.xebia.functional.xef.gcp.GcpClient
-import com.xebia.functional.xef.gcp.GcpConfig
 import com.xebia.functional.xef.llm.Chat
 import com.xebia.functional.xef.llm.models.chat.*
 import com.xebia.functional.xef.llm.models.usage.Usage
@@ -16,27 +14,26 @@ import kotlinx.uuid.generateUUID

 class GcpChat(
   override val modelType: ModelType,
-  private val config: GcpConfig,
+  private val client: GcpClient,
 ) : Chat {
 
-  private val client: GcpClient = autoClose { GcpClient(modelType.name, config) }
-
   override suspend fun createChatCompletion(
     request: ChatCompletionRequest
   ): ChatCompletionResponse {
     val prompt: String = request.messages.buildPrompt()
     val response: String =
       client.promptMessage(
+        modelType.name,
         prompt,
         temperature = request.temperature,
         maxOutputTokens = request.maxTokens,
         topP = request.topP
       )
     return ChatCompletionResponse(
       UUID.generateUUID().toString(),
-      client.modelId,
+      modelType.name,
       getTimeMillis().toInt(),
-      client.modelId,
+      modelType.name,
       Usage.ZERO, // TODO: token usage - no information about usage provided by GCP
       listOf(Choice(Message(Role.ASSISTANT, response, Role.ASSISTANT.name), null, 0)),
     )
@@ -51,6 +48,7 @@ class GcpChat(
     val prompt: String = messages.buildPrompt()
     val response =
       client.promptMessage(
+        modelType.name,
         prompt,
         temperature = request.temperature,
         maxOutputTokens = request.maxTokens,
@@ -60,7 +58,7 @@
     ChatCompletionChunk(
       UUID.generateUUID().toString(),
       getTimeMillis().toInt(),
-      client.modelId,
+      modelType.name,
       listOf(ChatChunk(delta = ChatDelta(Role.ASSISTANT, response))),
       Usage
         .ZERO, // TODO: token usage - no information about usage provided by GCP for codechat
@@ -1,9 +1,7 @@
 package com.xebia.functional.xef.gcp.models
 
 import com.xebia.functional.tokenizer.ModelType
-import com.xebia.functional.xef.conversation.autoClose
 import com.xebia.functional.xef.gcp.GcpClient
-import com.xebia.functional.xef.gcp.GcpConfig
 import com.xebia.functional.xef.llm.Completion
 import com.xebia.functional.xef.llm.models.text.CompletionChoice
 import com.xebia.functional.xef.llm.models.text.CompletionRequest
@@ -15,24 +13,23 @@ import kotlinx.uuid.generateUUID

 class GcpCompletion(
   override val modelType: ModelType,
-  config: GcpConfig,
+  private val client: GcpClient,
 ) : Completion {
 
-  private val client: GcpClient = autoClose { GcpClient(modelType.name, config) }
-
   override suspend fun createCompletion(request: CompletionRequest): CompletionResult {
     val response: String =
       client.promptMessage(
+        modelType.name,
         request.prompt,
         temperature = request.temperature,
         maxOutputTokens = request.maxTokens,
         topP = request.topP
       )
     return CompletionResult(
       UUID.generateUUID().toString(),
-      client.modelId,
+      modelType.name,
       getTimeMillis(),
-      client.modelId,
+      modelType.name,
       listOf(CompletionChoice(response, 0, null, null)),
       Usage.ZERO, // TODO: token usage - no information about usage provided by GCP codechat model
     )
@@ -1,10 +1,7 @@
 package com.xebia.functional.xef.gcp.models
 
 import com.xebia.functional.tokenizer.ModelType
-import com.xebia.functional.xef.conversation.AutoClose
-import com.xebia.functional.xef.conversation.autoClose
 import com.xebia.functional.xef.gcp.GcpClient
-import com.xebia.functional.xef.gcp.GcpConfig
 import com.xebia.functional.xef.llm.Embeddings
 import com.xebia.functional.xef.llm.models.embeddings.Embedding
 import com.xebia.functional.xef.llm.models.embeddings.EmbeddingRequest
@@ -13,11 +10,8 @@ import com.xebia.functional.xef.llm.models.usage.Usage

 class GcpEmbeddings(
   override val modelType: ModelType,
-  private val config: GcpConfig,
-) : Embeddings, AutoClose by autoClose() {
-
-  private val client: GcpClient =
-    com.xebia.functional.xef.conversation.autoClose { GcpClient(modelType.name, config) }
+  private val client: GcpClient,
+) : Embeddings {
 
   override suspend fun createEmbeddings(request: EmbeddingRequest): EmbeddingResult {
     fun requestToEmbedding(it: GcpClient.EmbeddingPredictions): Embedding =
@@ -0,0 +1,14 @@
+package com.xebia.functional.xef.conversation.llm.openai
+
+import com.aallam.openai.client.Closeable
+import com.xebia.functional.xef.conversation.AutoClose
+
+/** integration to aallam's [Closeable] */
+internal fun <A : Closeable> AutoClose.autoClose(closeable: A): A {
+  val wrapper =
+    object : AutoCloseable {
+      override fun close() = closeable.close()
+    }
+  autoClose(wrapper)
+  return closeable
+}
@@ -57,59 +57,60 @@ class OpenAI(internal var token: String? = null, internal var host: String? = nu
     }
   }
 
-  fun createClient() =
+  val defaultClient =
     OpenAIClient(
-      host = getHost()?.let { OpenAIHost(it) } ?: OpenAIHost.OpenAI,
-      token = getToken(),
-      logging = LoggingConfig(LogLevel.None),
-      headers = mapOf("Authorization" to " Bearer ${getToken()}"),
-    )
+        host = getHost()?.let { OpenAIHost(it) } ?: OpenAIHost.OpenAI,
+        token = getToken(),
+        logging = LoggingConfig(LogLevel.None),
+        headers = mapOf("Authorization" to " Bearer ${getToken()}"),
+      )
+      .let { autoClose(it) }
 
-  val GPT_4 by lazy { autoClose(OpenAIFunChat(ModelType.GPT_4, createClient())) }
+  val GPT_4 by lazy { autoClose(OpenAIFunChat(ModelType.GPT_4, defaultClient)) }
 
   val GPT_4_0314 by lazy {
-    autoClose(OpenAIFunChat(ModelType.GPT_4, createClient())) // legacy
+    autoClose(OpenAIFunChat(ModelType.GPT_4, defaultClient)) // legacy
   }
 
-  val GPT_4_32K by lazy { autoClose(OpenAIFunChat(ModelType.GPT_4_32K, createClient())) }
+  val GPT_4_32K by lazy { autoClose(OpenAIFunChat(ModelType.GPT_4_32K, defaultClient)) }
 
-  val GPT_3_5_TURBO by lazy { autoClose(OpenAIChat(ModelType.GPT_3_5_TURBO, createClient())) }
+  val GPT_3_5_TURBO by lazy { autoClose(OpenAIChat(ModelType.GPT_3_5_TURBO, defaultClient)) }
 
   val GPT_3_5_TURBO_16K by lazy {
-    autoClose(OpenAIChat(ModelType.GPT_3_5_TURBO_16_K, createClient()))
+    autoClose(OpenAIChat(ModelType.GPT_3_5_TURBO_16_K, defaultClient))
   }
 
   val GPT_3_5_TURBO_FUNCTIONS by lazy {
-    autoClose(OpenAIFunChat(ModelType.GPT_3_5_TURBO_FUNCTIONS, createClient()))
+    autoClose(OpenAIFunChat(ModelType.GPT_3_5_TURBO_FUNCTIONS, defaultClient))
   }
 
   val GPT_3_5_TURBO_0301 by lazy {
-    autoClose(OpenAIChat(ModelType.GPT_3_5_TURBO, createClient())) // legacy
+    autoClose(OpenAIChat(ModelType.GPT_3_5_TURBO, defaultClient)) // legacy
   }
 
   val TEXT_DAVINCI_003 by lazy {
-    autoClose(OpenAICompletion(ModelType.TEXT_DAVINCI_003, createClient()))
+    autoClose(OpenAICompletion(ModelType.TEXT_DAVINCI_003, defaultClient))
   }
 
   val TEXT_DAVINCI_002 by lazy {
-    autoClose(OpenAICompletion(ModelType.TEXT_DAVINCI_002, createClient()))
+    autoClose(OpenAICompletion(ModelType.TEXT_DAVINCI_002, defaultClient))
   }
 
   val TEXT_CURIE_001 by lazy {
-    autoClose(OpenAICompletion(ModelType.TEXT_SIMILARITY_CURIE_001, createClient()))
+    autoClose(OpenAICompletion(ModelType.TEXT_SIMILARITY_CURIE_001, defaultClient))
   }
 
   val TEXT_BABBAGE_001 by lazy {
-    autoClose(OpenAICompletion(ModelType.TEXT_BABBAGE_001, createClient()))
+    autoClose(OpenAICompletion(ModelType.TEXT_BABBAGE_001, defaultClient))
   }
 
-  val TEXT_ADA_001 by lazy { autoClose(OpenAICompletion(ModelType.TEXT_ADA_001, createClient())) }
+  val TEXT_ADA_001 by lazy { autoClose(OpenAICompletion(ModelType.TEXT_ADA_001, defaultClient)) }
 
   val TEXT_EMBEDDING_ADA_002 by lazy {
-    autoClose(OpenAIEmbeddings(ModelType.TEXT_EMBEDDING_ADA_002, createClient()))
+    autoClose(OpenAIEmbeddings(ModelType.TEXT_EMBEDDING_ADA_002, defaultClient))
   }
 
-  val DALLE_2 by lazy { autoClose(OpenAIImages(ModelType.GPT_3_5_TURBO, createClient())) }
+  val DALLE_2 by lazy { autoClose(OpenAIImages(ModelType.GPT_3_5_TURBO, defaultClient)) }
 
   @JvmField val DEFAULT_CHAT = GPT_3_5_TURBO_16K
 
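Net effect on the OpenAI provider: where each model val previously spun up its own client with createClient() and registered it separately, one defaultClient is now created eagerly, registered once via .let { autoClose(it) }, and shared by every model. For orientation, a minimal sketch of the AutoClose registry pattern this relies on; the shape is an assumption for illustration, not xef's actual implementation:

// Assumed shape of the AutoClose pattern (illustration only).
interface AutoClose : AutoCloseable {
  fun <A : AutoCloseable> autoClose(a: A): A
}

fun autoClose(): AutoClose =
  object : AutoClose {
    private val resources = ArrayDeque<AutoCloseable>()

    override fun <A : AutoCloseable> autoClose(a: A): A {
      resources.addFirst(a) // later registrations close first
      return a
    }

    override fun close() = resources.forEach { it.close() }
  }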