Skip to content

Commit

Permalink
Gcp runtime (#371)
Browse files Browse the repository at this point in the history
* First approach for GCP runtime

* first rework of gcp instantiation [WIP]

* revert changes out of scope of 0.0.3 release and issue

* location and projectId now defaulted from env vars

* small changes according to pr comments

---------

Co-authored-by: Javi Pacheco <[email protected]>
  • Loading branch information
Intex32 and javipacheco authored Sep 1, 2023
1 parent 3105b29 commit 61a48cb
Show file tree
Hide file tree
Showing 9 changed files with 119 additions and 43 deletions.
Original file line number Diff line number Diff line change
@@ -1,34 +1,16 @@
package com.xebia.functional.xef.conversation.gpc

import arrow.core.nonEmptyListOf
import com.xebia.functional.xef.AIError
import com.xebia.functional.xef.conversation.llm.openai.OpenAI
import com.xebia.functional.xef.env.getenv
import com.xebia.functional.xef.gcp.GcpChat
import com.xebia.functional.xef.gcp.GcpConfig
import com.xebia.functional.xef.gcp.GcpEmbeddings
import com.xebia.functional.xef.llm.models.embeddings.RequestConfig
import com.xebia.functional.xef.gcp.GCP
import com.xebia.functional.xef.gcp.promptMessage
import com.xebia.functional.xef.prompt.Prompt

suspend fun main() {
OpenAI.conversation {
val token =
getenv("GCP_TOKEN") ?: throw AIError.Env.GCP(nonEmptyListOf("missing GCP_TOKEN env var"))

val gcp =
GcpChat("codechat-bison@001", GcpConfig(token, "xefdemo", "us-central1")).let(::autoClose)
val gcpEmbeddingModel =
GcpChat("codechat-bison@001", GcpConfig(token, "xefdemo", "us-central1")).let(::autoClose)

val embeddingResult =
GcpEmbeddings(gcpEmbeddingModel)
.embedQuery("strawberry donuts", RequestConfig(RequestConfig.Companion.User("user")))
println(embeddingResult)

GCP.conversation {
while (true) {
print("\n🤖 Enter your question: ")
val userInput = readlnOrNull() ?: break
val answer = gcp.promptMessage(Prompt(userInput))
if (userInput == "exit") break
val answer = promptMessage(Prompt(userInput))
println("\n🤖 $answer")
}
println("\n🤖 Done")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@ package com.xebia.functional.xef.conversation.gpc
import com.xebia.functional.gpt4all.conversation
import com.xebia.functional.xef.env.getenv
import com.xebia.functional.xef.gcp.GcpConfig
import com.xebia.functional.xef.gcp.VertexAIRegion
import com.xebia.functional.xef.gcp.pipelines.GcpPipelinesClient

suspend fun main() {
conversation {
val token = getenv("GCP_TOKEN") ?: error("missing gcp token")
val pipelineClient = autoClose(GcpPipelinesClient(GcpConfig(token, "xefdemo", "us-central1")))
val pipelineClient =
autoClose(GcpPipelinesClient(GcpConfig(token, "xefdemo", VertexAIRegion.US_CENTRAL1)))
val answer = pipelineClient.list()
println("\n🤖 $answer")
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package com.xebia.functional.xef.gcp

import arrow.core.nonEmptyListOf
import com.xebia.functional.xef.AIError
import com.xebia.functional.xef.conversation.Conversation
import com.xebia.functional.xef.conversation.PlatformConversation
import com.xebia.functional.xef.env.getenv
import com.xebia.functional.xef.store.LocalVectorStore
import com.xebia.functional.xef.store.VectorStore
import kotlin.jvm.JvmField
import kotlin.jvm.JvmOverloads
import kotlin.jvm.JvmStatic
import kotlin.jvm.JvmSynthetic

private const val GCP_TOKEN_ENV_VAR = "GCP_TOKEN"
private const val GCP_PROJECT_ID_VAR = "GCP_PROJECT_ID"
private const val GCP_LOCATION_VAR = "GCP_LOCATION"

class GCP(projectId: String? = null, location: VertexAIRegion? = null, token: String? = null) {
private val config =
GcpConfig(
token = token ?: tokenFromEnv(),
projectId = projectId ?: projectIdFromEnv(),
location = location ?: locationFromEnv(),
)

private fun tokenFromEnv(): String = fromEnv(GCP_TOKEN_ENV_VAR)

private fun projectIdFromEnv(): String = fromEnv(GCP_PROJECT_ID_VAR)

private fun locationFromEnv(): VertexAIRegion =
fromEnv(GCP_LOCATION_VAR).let { envVar ->
VertexAIRegion.entries.find { it.officialName == envVar }
}
?: throw AIError.Env.GCP(
nonEmptyListOf(
"invalid value for $GCP_LOCATION_VAR - valid values are ${VertexAIRegion.entries.map(VertexAIRegion::officialName)}"
)
)

private fun fromEnv(name: String): String =
getenv(name) ?: throw AIError.Env.GCP(nonEmptyListOf("missing $name env var"))

val CODECHAT by lazy { GcpModel("codechat-bison@001", config) }
val TEXT_EMBEDDING_GECKO by lazy { GcpModel("textembedding-gecko", config) }

@JvmField val DEFAULT_CHAT = CODECHAT
@JvmField val DEFAULT_EMBEDDING = TEXT_EMBEDDING_GECKO

fun supportedModels(): List<GcpModel> = listOf(CODECHAT, TEXT_EMBEDDING_GECKO)

companion object {

@JvmField val FromEnvironment: GCP = GCP()

@JvmSynthetic
suspend inline fun <A> conversation(
store: VectorStore,
noinline block: suspend Conversation.() -> A
): A = block(conversation(store))

@JvmSynthetic
suspend fun <A> conversation(block: suspend Conversation.() -> A): A =
block(conversation(LocalVectorStore(GcpEmbeddings(FromEnvironment.DEFAULT_EMBEDDING))))

@JvmStatic
@JvmOverloads
fun conversation(
store: VectorStore = LocalVectorStore(GcpEmbeddings(FromEnvironment.DEFAULT_EMBEDDING))
): PlatformConversation = Conversation(store)
}
}

suspend inline fun <A> GCP.conversation(noinline block: suspend Conversation.() -> A): A =
block(Conversation(LocalVectorStore(GcpEmbeddings(DEFAULT_EMBEDDING))))
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,12 @@ import com.xebia.functional.xef.AIError
import com.xebia.functional.xef.conversation.AutoClose
import com.xebia.functional.xef.conversation.autoClose
import com.xebia.functional.xef.llm.models.embeddings.EmbeddingRequest
import io.ktor.client.*
import io.ktor.client.HttpClient
import io.ktor.client.call.*
import io.ktor.client.call.body
import io.ktor.client.plugins.*
import io.ktor.client.request.*
import io.ktor.client.request.header
import io.ktor.client.request.post
import io.ktor.client.request.setBody
import io.ktor.client.statement.*
import io.ktor.client.statement.bodyAsText
import io.ktor.http.*
import io.ktor.http.ContentType
import io.ktor.http.HttpStatusCode
import io.ktor.http.contentType
Expand Down Expand Up @@ -85,7 +79,7 @@ class GcpClient(
)
val response =
http.post(
"https://${config.location}-aiplatform.googleapis.com/v1/projects/${config.projectId}/locations/us-central1/publishers/google/models/$modelId:predict"
"https://${config.location.officialName}-aiplatform.googleapis.com/v1/projects/${config.projectId}/locations/us-central1/publishers/google/models/$modelId:predict"
) {
header("Authorization", "Bearer ${config.token}")
contentType(ContentType.Application.Json)
Expand Down Expand Up @@ -137,7 +131,7 @@ class GcpClient(
)
val response =
http.post(
"https://${config.location}-aiplatform.googleapis.com/v1/projects/${config.projectId}/locations/${config.location}/publishers/google/models/$modelId:predict"
"https://${config.location.officialName}-aiplatform.googleapis.com/v1/projects/${config.projectId}/locations/${config.location.officialName}/publishers/google/models/$modelId:predict"
) {
header("Authorization", "Bearer ${config.token}")
contentType(ContentType.Application.Json)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@ package com.xebia.functional.xef.gcp
data class GcpConfig(
val token: String,
val projectId: String,
/** https://cloud.google.com/vertex-ai/docs/general/locations */
val location: String, // Supported us-central1 or europe-west4
/** [GCP locations](https://cloud.google.com/vertex-ai/docs/general/locations) */
val location: VertexAIRegion, // Supported us-central1 or europe-west4
)

enum class VertexAIRegion(val officialName: String) {
US_CENTRAL1("us-central1"),
EU_WEST4("europe-west4"),
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import kotlinx.uuid.UUID
import kotlinx.uuid.generateUUID

@OptIn(ExperimentalStdlibApi::class)
class GcpChat(modelId: String, config: GcpConfig) : Chat, Completion, AutoCloseable, Embeddings {
class GcpModel(modelId: String, config: GcpConfig) : Chat, Completion, AutoCloseable, Embeddings {
private val client: GcpClient = GcpClient(modelId, config)

override val name: String = client.modelId
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package com.xebia.functional.xef.gcp

import com.xebia.functional.xef.conversation.AiDsl
import com.xebia.functional.xef.conversation.Conversation
import com.xebia.functional.xef.llm.Chat
import com.xebia.functional.xef.prompt.Prompt
import kotlinx.coroutines.flow.Flow

@AiDsl
suspend fun Conversation.promptMessage(prompt: Prompt, model: Chat = GCP().DEFAULT_CHAT): String =
model.promptMessage(prompt, this)

@AiDsl
fun Conversation.promptStreaming(prompt: Prompt, model: Chat = GCP().DEFAULT_CHAT): Flow<String> =
model.promptStreaming(prompt, this)
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ class GcpPipelinesClient(
suspend fun list(): List<PipelineJob> {
val response =
http.get(
"https://${config.location}-aiplatform.googleapis.com/v1/projects/${config.projectId}/locations/${config.location}/pipelineJobs"
"https://${config.location.officialName}-aiplatform.googleapis.com/v1/projects/${config.projectId}/locations/${config.location.officialName}/pipelineJobs"
) {
header("Authorization", "Bearer ${config.token}")
contentType(ContentType.Application.Json)
Expand All @@ -92,7 +92,7 @@ class GcpPipelinesClient(
suspend fun get(pipelineJobName: String): PipelineJob? {
val response =
http.get(
"https://${config.location}-aiplatform.googleapis.com/v1/projects/${config.projectId}/locations/${config.location}/pipelineJobs/$pipelineJobName"
"https://${config.location.officialName}-aiplatform.googleapis.com/v1/projects/${config.projectId}/locations/${config.location.officialName}/pipelineJobs/$pipelineJobName"
) {
header("Authorization", "Bearer ${config.token}")
contentType(ContentType.Application.Json)
Expand All @@ -105,7 +105,7 @@ class GcpPipelinesClient(
suspend fun create(pipelineJobId: String?, pipelineJob: CreatePipelineJob): PipelineJob? {
val response =
http.post(
"https://${config.location}-aiplatform.googleapis.com/v1/projects/${config.projectId}/locations/${config.location}/pipelineJobs"
"https://${config.location.officialName}-aiplatform.googleapis.com/v1/projects/${config.projectId}/locations/${config.location.officialName}/pipelineJobs"
) {
header("Authorization", "Bearer ${config.token}")
contentType(ContentType.Application.Json)
Expand All @@ -120,7 +120,7 @@ class GcpPipelinesClient(
suspend fun cancel(pipelineJobName: String): Unit {
val response =
http.post(
"https://${config.location}-aiplatform.googleapis.com/v1/projects/${config.projectId}/locations/${config.location}/pipelineJobs/$pipelineJobName:cancel"
"https://${config.location.officialName}-aiplatform.googleapis.com/v1/projects/${config.projectId}/locations/${config.location.officialName}/pipelineJobs/$pipelineJobName:cancel"
) {
header("Authorization", "Bearer ${config.token}")
contentType(ContentType.Application.Json)
Expand All @@ -133,7 +133,7 @@ class GcpPipelinesClient(
suspend fun delete(pipelineJobName: String): Operation {
val response =
http.delete(
"https://${config.location}-aiplatform.googleapis.com/v1/projects/${config.projectId}/locations/${config.location}/pipelineJobs/$pipelineJobName"
"https://${config.location.officialName}-aiplatform.googleapis.com/v1/projects/${config.projectId}/locations/${config.location.officialName}/pipelineJobs/$pipelineJobName"
) {
header("Authorization", "Bearer ${config.token}")
contentType(ContentType.Application.Json)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,19 @@ import kotlin.jvm.JvmOverloads
import kotlin.jvm.JvmStatic
import kotlin.jvm.JvmSynthetic

private const val KEY_ENV_VAR = "OPENAI_TOKEN"
private const val HOST_ENV_VAR = "OPENAI_HOST"

class OpenAI(internal var token: String? = null, internal var host: String? = null) :
AutoCloseable, AutoClose by autoClose() {

private fun openAITokenFromEnv(): String {
return getenv("OPENAI_TOKEN")
?: throw AIError.Env.OpenAI(nonEmptyListOf("missing OPENAI_TOKEN env var"))
return getenv(KEY_ENV_VAR)
?: throw AIError.Env.OpenAI(nonEmptyListOf("missing $KEY_ENV_VAR env var"))
}

private fun openAIHostFromEnv(): String? {
return getenv("OPENAI_HOST")
return getenv(HOST_ENV_VAR)
}

fun getToken(): String {
Expand Down

0 comments on commit 61a48cb

Please sign in to comment.