Add support for Google AI as an LLM #416

Open · wants to merge 7 commits into base: development
build.gradle.kts (2 changes: 1 addition & 1 deletion)
@@ -236,7 +236,7 @@ intellijPlatform {
}
freeArgs = listOf(
"-mute",
"TemplateWordInPluginId,ForbiddenPluginIdPrefix"
"TemplateWordInPluginId,ForbiddenPluginIdPrefix",
)
}
}
@@ -34,7 +34,7 @@ class LLMSetupPanelBuilder(e: AnActionEvent, private val project: Project) : Pan
private val defaultModulesArray = arrayOf("")
private var modelSelector = ComboBox(defaultModulesArray)
private var llmUserTokenField = JTextField(30)
- private var platformSelector = ComboBox(arrayOf(llmSettingsState.openAIName, llmSettingsState.huggingFaceName))
+ private var platformSelector = ComboBox(arrayOf(llmSettingsState.openAIName, llmSettingsState.huggingFaceName, llmSettingsState.geminiName))
private val backLlmButton = JButton(PluginLabelsBundle.get("back"))
private val okLlmButton = JButton(PluginLabelsBundle.get("next"))
private val junitSelector = JUnitCombobox(e)
@@ -146,6 +146,10 @@ class LLMSetupPanelBuilder(e: AnActionEvent, private val project: Project) : Pan
llmSettingsState.huggingFaceToken = llmPlatforms[index].token
llmSettingsState.huggingFaceModel = llmPlatforms[index].model
}
if (llmPlatforms[index].name == llmSettingsState.geminiName) {
llmSettingsState.geminiToken = llmPlatforms[index].token
llmSettingsState.geminiModel = llmPlatforms[index].model
}
}
llmSettingsState.junitVersion = junitSelector.selectedItem!! as JUnitVersion

@@ -22,6 +22,7 @@ import org.jetbrains.research.testspark.tools.llm.error.LLMErrorManager
import org.jetbrains.research.testspark.tools.llm.generation.LLMPlatform
import org.jetbrains.research.testspark.tools.llm.generation.TestBodyPrinterFactory
import org.jetbrains.research.testspark.tools.llm.generation.TestSuiteParserFactory
import org.jetbrains.research.testspark.tools.llm.generation.gemini.GeminiPlatform
import org.jetbrains.research.testspark.tools.llm.generation.grazie.GrazieInfo
import org.jetbrains.research.testspark.tools.llm.generation.grazie.GraziePlatform
import org.jetbrains.research.testspark.tools.llm.generation.hf.HuggingFacePlatform
@@ -74,6 +75,9 @@ object LLMHelper {
if (platformSelector.selectedItem!!.toString() == settingsState.huggingFaceName) {
models = getHuggingFaceModels()
}
if (platformSelector.selectedItem!!.toString() == settingsState.geminiName) {
models = getGeminiModels(llmUserTokenField.text)
}
modelSelector.model = DefaultComboBoxModel(models)
for (index in llmPlatforms.indices) {
if (llmPlatforms[index].name == settingsState.openAIName &&
@@ -94,6 +98,12 @@
modelSelector.selectedItem = settingsState.huggingFaceModel
llmPlatforms[index].model = modelSelector.selectedItem!!.toString()
}
if (llmPlatforms[index].name == settingsState.geminiName &&
llmPlatforms[index].name == platformSelector.selectedItem!!.toString()
) {
modelSelector.selectedItem = settingsState.geminiModel
llmPlatforms[index].model = modelSelector.selectedItem!!.toString()
}
}
modelSelector.isEnabled = true
if (models.contentEquals(arrayOf(""))) modelSelector.isEnabled = false
@@ -131,6 +141,12 @@ object LLMHelper {
llmUserTokenField.text = settingsState.huggingFaceToken
llmPlatforms[index].token = settingsState.huggingFaceToken
}
if (llmPlatforms[index].name == settingsState.geminiName &&
llmPlatforms[index].name == platformSelector.selectedItem!!.toString()
) {
llmUserTokenField.text = settingsState.geminiToken
llmPlatforms[index].token = settingsState.geminiToken
}
}
}

@@ -219,7 +235,7 @@ object LLMHelper {
* @return The list of LLMPlatforms.
*/
fun getLLLMPlatforms(): List<LLMPlatform> {
- return listOf(OpenAIPlatform(), GraziePlatform(), HuggingFacePlatform())
+ return listOf(OpenAIPlatform(), GraziePlatform(), HuggingFacePlatform(), GeminiPlatform())
}

/**
@@ -346,6 +362,48 @@ object LLMHelper {
return arrayOf("")
}

/**
* Retrieves a list of available models from the Google AI API.
*
* Note that this only returns models that support content generation, since the
* test-generation queries require the generateContent method.
*
* @param providedToken The authentication token provided by Google AI.
* @return An array of model IDs. If an error occurs during the request, an array with an empty string is returned.
*/
fun getGeminiModels(providedToken: String): Array<String> {
val url = "https://generativelanguage.googleapis.com/v1beta/models?key=$providedToken"

val httpRequest = HttpRequests.request(url)
val models = mutableListOf<String>()

try {
httpRequest.connect { request ->
if ((request.connection as HttpURLConnection).responseCode == HttpURLConnection.HTTP_OK) {
val jsonObject = JsonParser.parseString(request.readString()).asJsonObject
val dataArray = jsonObject.getAsJsonArray("models")
for (dataObject in dataArray) {
val id = dataObject.asJsonObject.getAsJsonPrimitive("name").asString
val methods = dataObject.asJsonObject
.getAsJsonArray("supportedGenerationMethods")
.map { method -> method.asString }
if (methods.contains("generateContent")) {
models.add(id.removePrefix("models/"))
}
}
}
}
} catch (e: HttpRequests.HttpStatusException) {
return arrayOf("")
}

if (models.isNotEmpty()) {
return models.toTypedArray()
}

return arrayOf("")
}
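For illustration, a standalone sketch (not part of this diff) of the filtering above, run against an abridged sample of the v1beta models payload; the model names in the sample are made up:

import com.google.gson.JsonParser

fun main() {
    // Abridged sample of the GET /v1beta/models response shape consumed by getGeminiModels().
    val sample = """
        {"models": [
          {"name": "models/gemini-pro",
           "supportedGenerationMethods": ["generateContent", "countTokens"]},
          {"name": "models/embedding-001",
           "supportedGenerationMethods": ["embedContent"]}
        ]}
    """.trimIndent()

    // Same selection logic as above: keep only models that support generateContent.
    val models = JsonParser.parseString(sample).asJsonObject
        .getAsJsonArray("models")
        .map { it.asJsonObject }
        .filter { m -> m.getAsJsonArray("supportedGenerationMethods").any { it.asString == "generateContent" } }
        .map { it.getAsJsonPrimitive("name").asString.removePrefix("models/") }

    println(models) // prints: [gemini-pro]
}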

/**
* Retrieves the available Grazie models.
*
@@ -45,7 +45,7 @@ class LLMSettingsComponent(private val project: Project) : SettingsComponent {

// Models
private var modelSelector = ComboBox(arrayOf(""))
- private var platformSelector = ComboBox(arrayOf(llmSettingsState.openAIName, llmSettingsState.huggingFaceName))
+ private var platformSelector = ComboBox(arrayOf(llmSettingsState.openAIName, llmSettingsState.huggingFaceName, llmSettingsState.geminiName))

// Default LLM Requests
private var defaultLLMRequestsSeparator =
@@ -46,6 +46,10 @@ class LLMSettingsConfigurable(private val project: Project) : SettingsConfigurab
settingsComponent!!.llmPlatforms[index].token = llmSettingsState.huggingFaceToken
settingsComponent!!.llmPlatforms[index].model = llmSettingsState.huggingFaceModel
}
if (settingsComponent!!.llmPlatforms[index].name == llmSettingsState.geminiName) {
settingsComponent!!.llmPlatforms[index].token = llmSettingsState.geminiToken
settingsComponent!!.llmPlatforms[index].model = llmSettingsState.geminiModel
}
}
settingsComponent!!.currentLLMPlatformName = llmSettingsState.currentLLMPlatformName
settingsComponent!!.maxLLMRequest = llmSettingsState.maxLLMRequest
@@ -89,6 +93,10 @@ class LLMSettingsConfigurable(private val project: Project) : SettingsConfigurab
modified = modified or (settingsComponent!!.llmPlatforms[index].token != llmSettingsState.huggingFaceToken)
modified = modified or (settingsComponent!!.llmPlatforms[index].model != llmSettingsState.huggingFaceModel)
}
if (settingsComponent!!.llmPlatforms[index].name == llmSettingsState.geminiName) {
modified = modified or (settingsComponent!!.llmPlatforms[index].token != llmSettingsState.geminiToken)
modified = modified or (settingsComponent!!.llmPlatforms[index].model != llmSettingsState.geminiModel)
}
}
modified = modified or (settingsComponent!!.currentLLMPlatformName != llmSettingsState.currentLLMPlatformName)
modified = modified or (settingsComponent!!.maxLLMRequest != llmSettingsState.maxLLMRequest)
@@ -150,6 +158,10 @@ class LLMSettingsConfigurable(private val project: Project) : SettingsConfigurab
llmSettingsState.huggingFaceToken = settingsComponent!!.llmPlatforms[index].token
llmSettingsState.huggingFaceModel = settingsComponent!!.llmPlatforms[index].model
}
if (settingsComponent!!.llmPlatforms[index].name == llmSettingsState.geminiName) {
llmSettingsState.geminiToken = settingsComponent!!.llmPlatforms[index].token
llmSettingsState.geminiModel = settingsComponent!!.llmPlatforms[index].model
}
}
llmSettingsState.currentLLMPlatformName = settingsComponent!!.currentLLMPlatformName
llmSettingsState.maxLLMRequest = settingsComponent!!.maxLLMRequest
@@ -18,6 +18,9 @@ data class LLMSettingsState(
var huggingFaceName: String = DefaultLLMSettingsState.huggingFaceName,
var huggingFaceToken: String = DefaultLLMSettingsState.huggingFaceToken,
var huggingFaceModel: String = DefaultLLMSettingsState.huggingFaceModel,
var geminiName: String = DefaultLLMSettingsState.geminiName,
var geminiToken: String = DefaultLLMSettingsState.geminiToken,
var geminiModel: String = DefaultLLMSettingsState.geminiModel,
var currentLLMPlatformName: String = DefaultLLMSettingsState.currentLLMPlatformName,
var maxLLMRequest: Int = DefaultLLMSettingsState.maxLLMRequest,
var maxInputParamsDepth: Int = DefaultLLMSettingsState.maxInputParamsDepth,
@@ -51,6 +54,9 @@ data class LLMSettingsState(
val huggingFaceName: String = LLMDefaultsBundle.get("huggingFaceName")
val huggingFaceToken: String = LLMDefaultsBundle.get("huggingFaceToken")
val huggingFaceModel: String = LLMDefaultsBundle.get("huggingFaceModel")
val geminiName: String = LLMDefaultsBundle.get("geminiName")
val geminiToken: String = LLMDefaultsBundle.get("geminiToken")
val geminiModel: String = LLMDefaultsBundle.get("geminiModel")
var currentLLMPlatformName: String = LLMDefaultsBundle.get("openAIName")
val maxLLMRequest: Int = LLMDefaultsBundle.get("maxLLMRequest").toInt()
val maxInputParamsDepth: Int = LLMDefaultsBundle.get("maxInputParamsDepth").toInt()
@@ -58,6 +58,7 @@ class LlmSettingsArguments(private val project: Project) {
llmSettingsState.openAIName -> llmSettingsState.openAIToken
llmSettingsState.grazieName -> llmSettingsState.grazieToken
llmSettingsState.huggingFaceName -> llmSettingsState.huggingFaceToken
llmSettingsState.geminiName -> llmSettingsState.geminiToken
else -> ""
}

@@ -70,6 +71,7 @@
llmSettingsState.openAIName -> llmSettingsState.openAIModel
llmSettingsState.grazieName -> llmSettingsState.grazieModel
llmSettingsState.huggingFaceName -> llmSettingsState.huggingFaceModel
llmSettingsState.geminiName -> llmSettingsState.geminiModel
else -> ""
}
}
@@ -5,6 +5,7 @@ import org.jetbrains.research.testspark.core.generation.llm.network.RequestManag
import org.jetbrains.research.testspark.services.LLMSettingsService
import org.jetbrains.research.testspark.settings.llm.LLMSettingsState
import org.jetbrains.research.testspark.tools.llm.LlmSettingsArguments
import org.jetbrains.research.testspark.tools.llm.generation.gemini.GeminiRequestManager
import org.jetbrains.research.testspark.tools.llm.generation.grazie.GrazieRequestManager
import org.jetbrains.research.testspark.tools.llm.generation.hf.HuggingFaceRequestManager
import org.jetbrains.research.testspark.tools.llm.generation.openai.OpenAIRequestManager
@@ -22,6 +23,7 @@ class StandardRequestManagerFactory(private val project: Project) : RequestManag
llmSettingsState.openAIName -> OpenAIRequestManager(project)
llmSettingsState.grazieName -> GrazieRequestManager(project)
llmSettingsState.huggingFaceName -> HuggingFaceRequestManager(project)
llmSettingsState.geminiName -> GeminiRequestManager(project)
else -> throw IllegalStateException("Unknown selected platform: $platform")
}
}
@@ -0,0 +1,9 @@
package org.jetbrains.research.testspark.tools.llm.generation.gemini

import org.jetbrains.research.testspark.tools.llm.generation.LLMPlatform

class GeminiPlatform(
override val name: String = "Google AI",
override var token: String = "",
override var model: String = "",
) : LLMPlatform
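Note: the default name "Google AI" has to match the geminiName value supplied by LLMDefaultsBundle (see the LLMSettingsState defaults below), since the setup panel and settings code above look platforms up by this display name.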
@@ -0,0 +1,28 @@
package org.jetbrains.research.testspark.tools.llm.generation.gemini

data class GeminiRequest(
val contents: List<GeminiRequestBody>,
)

data class GeminiRequestBody(
val parts: List<GeminiChatMessage>,
)

data class GeminiChatMessage(
val text: String,
)

data class GeminiReply(
val content: GeminiReplyContent,
val finishReason: String,
val avgLogprobs: Double,
)

data class GeminiReplyContent(
val parts: List<GeminiReplyPart>,
val role: String?,
)

data class GeminiReplyPart(
val text: String,
)
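The classes above mirror the generateContent wire format. A minimal standalone sketch (not part of this diff; assumes Gson and the data classes above on the classpath, and the prompt and reply text are made up) of the JSON they map to:

import com.google.gson.GsonBuilder
import org.jetbrains.research.testspark.tools.llm.generation.gemini.*

fun main() {
    val gson = GsonBuilder().create()

    // Request side: serializes to {"contents":[{"parts":[{"text":"..."}]}]}
    val request = GeminiRequest(
        listOf(GeminiRequestBody(listOf(GeminiChatMessage("Generate a JUnit test for Stack")))),
    )
    println(gson.toJson(request))

    // Reply side: one element of the response's "candidates" array deserializes
    // into GeminiReply (as GeminiRequestManager does below).
    val candidateJson =
        """{"content":{"parts":[{"text":"@Test fun pushPop() {}"}],"role":"model"},"finishReason":"STOP","avgLogprobs":-0.05}"""
    val reply = gson.fromJson(candidateJson, GeminiReply::class.java)
    println(reply.content.parts[0].text) // prints: @Test fun pushPop() {}
}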
@@ -0,0 +1,121 @@
package org.jetbrains.research.testspark.tools.llm.generation.gemini

import com.google.gson.GsonBuilder
import com.google.gson.JsonParser
import com.intellij.openapi.project.Project
import com.intellij.util.io.HttpRequests
import com.intellij.util.io.HttpRequests.HttpStatusException
import org.jetbrains.research.testspark.bundles.llm.LLMMessagesBundle
import org.jetbrains.research.testspark.core.monitor.ErrorMonitor
import org.jetbrains.research.testspark.core.progress.CustomProgressIndicator
import org.jetbrains.research.testspark.core.test.TestsAssembler
import org.jetbrains.research.testspark.tools.ToolUtils
import org.jetbrains.research.testspark.tools.llm.LlmSettingsArguments
import org.jetbrains.research.testspark.tools.llm.error.LLMErrorManager
import org.jetbrains.research.testspark.tools.llm.generation.IJRequestManager
import java.net.HttpURLConnection

class GeminiRequestManager(project: Project) : IJRequestManager(project) {
private val url = "https://generativelanguage.googleapis.com/v1beta/models/"
private val gson = GsonBuilder().create()

private val llmErrorManager = LLMErrorManager()

override fun send(
prompt: String,
indicator: CustomProgressIndicator,
testsAssembler: TestsAssembler,
errorMonitor: ErrorMonitor,
): SendResult {
val model = LlmSettingsArguments(project).getModel()
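// The Generative Language API authenticates with the API key passed as a query parameter.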
val apiURL = "$url$model:generateContent?key=$token"
val httpRequest = HttpRequests.post(apiURL, "application/json")

val messages = chatHistory.map {
GeminiChatMessage(it.content)
}

val geminiRequest = GeminiRequest(listOf(GeminiRequestBody(messages)))

var sendResult = SendResult.OK

try {
httpRequest.connect { request ->
request.write(gson.toJson(geminiRequest))

val connection = request.connection as HttpURLConnection

when (val responseCode = connection.responseCode) {
HttpURLConnection.HTTP_OK -> {
assembleGeminiResponse(request, testsAssembler, indicator, errorMonitor)
}

HttpURLConnection.HTTP_INTERNAL_ERROR -> {
llmErrorManager.errorProcess(
LLMMessagesBundle.get("serverProblems"),
project,
errorMonitor,
)
sendResult = SendResult.OTHER
}

HttpURLConnection.HTTP_BAD_REQUEST -> {
llmErrorManager.warningProcess(
LLMMessagesBundle.get("tooLongPrompt"),
project,
)
sendResult = SendResult.PROMPT_TOO_LONG
}

HttpURLConnection.HTTP_UNAUTHORIZED -> {
llmErrorManager.errorProcess(
LLMMessagesBundle.get("wrongToken"),
project,
errorMonitor,
)
sendResult = SendResult.OTHER
}

else -> {
llmErrorManager.errorProcess(
llmErrorManager.createRequestErrorMessage(responseCode),
project,
errorMonitor,
)
sendResult = SendResult.OTHER
}
}
}
} catch (e: HttpStatusException) {
log.error { "Error in sending request: ${e.message}" }
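// Note: sendResult stays OK on this path; the failure is only logged.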
}

return sendResult
}

private fun assembleGeminiResponse(
httpRequest: HttpRequests.Request,
testsAssembler: TestsAssembler,
indicator: CustomProgressIndicator,
errorMonitor: ErrorMonitor,
) {
while (true) {
if (ToolUtils.isProcessCanceled(errorMonitor, indicator)) return

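// generateContent is non-streaming: readText() consumes the whole body in one pass,
// so the first candidate is expected to carry finishReason == "STOP" and end the loop.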
val text = httpRequest.reader.readText()
val result =
gson.fromJson(
JsonParser.parseString(text)
.asJsonObject["candidates"]
.asJsonArray[0].asJsonObject,
GeminiReply::class.java,
)

testsAssembler.consume(result.content.parts[0].text)

if (result.finishReason == "STOP") break
}

log.debug { testsAssembler.getContent() }
}
}