diff --git a/src/main/kotlin/org/jetbrains/research/testspark/actions/llm/LLMSetupPanelBuilder.kt b/src/main/kotlin/org/jetbrains/research/testspark/actions/llm/LLMSetupPanelBuilder.kt index 873c634c9..f544f7040 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/actions/llm/LLMSetupPanelBuilder.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/actions/llm/LLMSetupPanelBuilder.kt @@ -34,7 +34,7 @@ class LLMSetupPanelBuilder(e: AnActionEvent, private val project: Project) : Pan private val defaultModulesArray = arrayOf("") private var modelSelector = ComboBox(defaultModulesArray) private var llmUserTokenField = JTextField(30) - private var platformSelector = ComboBox(arrayOf(llmSettingsState.openAIName, llmSettingsState.huggingFaceName)) + private var platformSelector = ComboBox(arrayOf(llmSettingsState.openAIName, llmSettingsState.huggingFaceName, llmSettingsState.geminiName)) private val backLlmButton = JButton(PluginLabelsBundle.get("back")) private val okLlmButton = JButton(PluginLabelsBundle.get("next")) private val junitSelector = JUnitCombobox(e) diff --git a/src/main/kotlin/org/jetbrains/research/testspark/helpers/LLMHelper.kt b/src/main/kotlin/org/jetbrains/research/testspark/helpers/LLMHelper.kt index 916da9537..e8a1ea58d 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/helpers/LLMHelper.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/helpers/LLMHelper.kt @@ -22,6 +22,7 @@ import org.jetbrains.research.testspark.tools.llm.error.LLMErrorManager import org.jetbrains.research.testspark.tools.llm.generation.LLMPlatform import org.jetbrains.research.testspark.tools.llm.generation.TestBodyPrinterFactory import org.jetbrains.research.testspark.tools.llm.generation.TestSuiteParserFactory +import org.jetbrains.research.testspark.tools.llm.generation.gemini.GeminiPlatform import org.jetbrains.research.testspark.tools.llm.generation.grazie.GrazieInfo import 
org.jetbrains.research.testspark.tools.llm.generation.grazie.GraziePlatform import org.jetbrains.research.testspark.tools.llm.generation.hf.HuggingFacePlatform @@ -74,6 +75,9 @@ object LLMHelper { if (platformSelector.selectedItem!!.toString() == settingsState.huggingFaceName) { models = getHuggingFaceModels() } + if (platformSelector.selectedItem!!.toString() == settingsState.geminiName) { + models = getGeminiModels(llmUserTokenField.text) + } modelSelector.model = DefaultComboBoxModel(models) for (index in llmPlatforms.indices) { if (llmPlatforms[index].name == settingsState.openAIName && @@ -219,7 +223,7 @@ object LLMHelper { * @return The list of LLMPlatforms. */ fun getLLLMPlatforms(): List<LLMPlatform> { - return listOf(OpenAIPlatform(), GraziePlatform(), HuggingFacePlatform()) + return listOf(OpenAIPlatform(), GraziePlatform(), HuggingFacePlatform(), GeminiPlatform()) } /** @@ -346,6 +350,48 @@ object LLMHelper { return arrayOf("") } + /** + * Retrieves a list of available models from the Google AI API. + * + * Note that this will only return models that support content generation because we need this for the + * test-generation queries. + * + * @param providedToken The authentication token provided by Google AI. + * @return An array of model IDs. If an error occurs during the request, an array with an empty string is returned. 
+ */ + fun getGeminiModels(providedToken: String): Array<String> { + val url = "https://generativelanguage.googleapis.com/v1beta/models?key=$providedToken" + + val httpRequest = HttpRequests.request(url) + val models = mutableListOf<String>() + + try { + httpRequest.connect { it -> + if ((it.connection as HttpURLConnection).responseCode == HttpURLConnection.HTTP_OK) { + val jsonObject = JsonParser.parseString(it.readString()).asJsonObject + val dataArray = jsonObject.getAsJsonArray("models") + for (dataObject in dataArray) { + val id = dataObject.asJsonObject.getAsJsonPrimitive("name").asString + val methods = dataObject.asJsonObject + .getAsJsonArray("supportedGenerationMethods") + .map { method -> method.asString } + if (methods.contains("generateContent")) { + models.add(id) + } + } + } + } + } catch (e: HttpRequests.HttpStatusException) { + return arrayOf("") + } + + if (models.isNotEmpty()) { + return models.toTypedArray() + } + + return arrayOf("") + } + + /** + * Retrieves the available Grazie models. 
* diff --git a/src/main/kotlin/org/jetbrains/research/testspark/settings/llm/LLMSettingsState.kt b/src/main/kotlin/org/jetbrains/research/testspark/settings/llm/LLMSettingsState.kt index 590ec3c1d..60fea6ce6 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/settings/llm/LLMSettingsState.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/settings/llm/LLMSettingsState.kt @@ -18,6 +18,9 @@ data class LLMSettingsState( var huggingFaceName: String = DefaultLLMSettingsState.huggingFaceName, var huggingFaceToken: String = DefaultLLMSettingsState.huggingFaceToken, var huggingFaceModel: String = DefaultLLMSettingsState.huggingFaceModel, + var geminiName: String = DefaultLLMSettingsState.geminiName, + var geminiToken: String = DefaultLLMSettingsState.geminiToken, + var geminiModel: String = DefaultLLMSettingsState.geminiModel, var currentLLMPlatformName: String = DefaultLLMSettingsState.currentLLMPlatformName, var maxLLMRequest: Int = DefaultLLMSettingsState.maxLLMRequest, var maxInputParamsDepth: Int = DefaultLLMSettingsState.maxInputParamsDepth, @@ -51,6 +54,9 @@ data class LLMSettingsState( val huggingFaceName: String = LLMDefaultsBundle.get("huggingFaceName") val huggingFaceToken: String = LLMDefaultsBundle.get("huggingFaceToken") val huggingFaceModel: String = LLMDefaultsBundle.get("huggingFaceModel") + val geminiName: String = LLMDefaultsBundle.get("geminiName") + val geminiToken: String = LLMDefaultsBundle.get("geminiToken") + val geminiModel: String = LLMDefaultsBundle.get("geminiModel") var currentLLMPlatformName: String = LLMDefaultsBundle.get("openAIName") val maxLLMRequest: Int = LLMDefaultsBundle.get("maxLLMRequest").toInt() val maxInputParamsDepth: Int = LLMDefaultsBundle.get("maxInputParamsDepth").toInt() diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/LlmSettingsArguments.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/LlmSettingsArguments.kt index 271cf4b49..344668285 100644 --- 
a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/LlmSettingsArguments.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/LlmSettingsArguments.kt @@ -58,6 +58,7 @@ class LlmSettingsArguments(private val project: Project) { llmSettingsState.openAIName -> llmSettingsState.openAIToken llmSettingsState.grazieName -> llmSettingsState.grazieToken llmSettingsState.huggingFaceName -> llmSettingsState.huggingFaceToken + llmSettingsState.geminiName -> llmSettingsState.geminiToken else -> "" } @@ -70,6 +71,7 @@ class LlmSettingsArguments(private val project: Project) { llmSettingsState.openAIName -> llmSettingsState.openAIModel llmSettingsState.grazieName -> llmSettingsState.grazieModel llmSettingsState.huggingFaceName -> llmSettingsState.huggingFaceModel + llmSettingsState.geminiName -> llmSettingsState.geminiModel else -> "" } } diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/RequestManagerFactory.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/RequestManagerFactory.kt index f05d55986..91db397ff 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/RequestManagerFactory.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/RequestManagerFactory.kt @@ -5,6 +5,7 @@ import org.jetbrains.research.testspark.core.generation.llm.network.RequestManag import org.jetbrains.research.testspark.services.LLMSettingsService import org.jetbrains.research.testspark.settings.llm.LLMSettingsState import org.jetbrains.research.testspark.tools.llm.LlmSettingsArguments +import org.jetbrains.research.testspark.tools.llm.generation.gemini.GeminiRequestManager import org.jetbrains.research.testspark.tools.llm.generation.grazie.GrazieRequestManager import org.jetbrains.research.testspark.tools.llm.generation.hf.HuggingFaceRequestManager import org.jetbrains.research.testspark.tools.llm.generation.openai.OpenAIRequestManager @@ -22,6 +23,7 @@ class 
StandardRequestManagerFactory(private val project: Project) : RequestManag llmSettingsState.openAIName -> OpenAIRequestManager(project) llmSettingsState.grazieName -> GrazieRequestManager(project) llmSettingsState.huggingFaceName -> HuggingFaceRequestManager(project) + llmSettingsState.geminiName -> GeminiRequestManager(project) else -> throw IllegalStateException("Unknown selected platform: $platform") } } diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/gemini/GeminiPlatform.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/gemini/GeminiPlatform.kt new file mode 100644 index 000000000..0fb6ed2a2 --- /dev/null +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/gemini/GeminiPlatform.kt @@ -0,0 +1,9 @@ +package org.jetbrains.research.testspark.tools.llm.generation.gemini + +import org.jetbrains.research.testspark.tools.llm.generation.LLMPlatform + +class GeminiPlatform( + override val name: String = "Gemini", + override var token: String = "", + override var model: String = "", +) : LLMPlatform \ No newline at end of file diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/gemini/GeminiRequestBody.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/gemini/GeminiRequestBody.kt new file mode 100644 index 000000000..365c1b780 --- /dev/null +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/gemini/GeminiRequestBody.kt @@ -0,0 +1,28 @@ +package org.jetbrains.research.testspark.tools.llm.generation.gemini + +data class GeminiRequest( + val contents: List<GeminiRequestBody>, +) + +data class GeminiRequestBody( + val parts: List<GeminiChatMessage>, +) + +data class GeminiChatMessage( + val text: String, +) + +data class GeminiReply( + val content: GeminiReplyContent, + val finishReason: String, + val avgLogprobs: Double, +) + +data class GeminiReplyContent( + val parts: List<GeminiReplyPart>, + val role: String?, +) + +data class GeminiReplyPart( + val text: String, 
+) \ No newline at end of file diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/gemini/GeminiRequestManager.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/gemini/GeminiRequestManager.kt new file mode 100644 index 000000000..68e284cce --- /dev/null +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/gemini/GeminiRequestManager.kt @@ -0,0 +1,122 @@ +package org.jetbrains.research.testspark.tools.llm.generation.gemini + +import com.google.gson.GsonBuilder +import com.google.gson.JsonParser +import com.intellij.openapi.project.Project +import com.intellij.util.io.HttpRequests +import com.intellij.util.io.HttpRequests.HttpStatusException +import org.jetbrains.research.testspark.bundles.llm.LLMMessagesBundle +import org.jetbrains.research.testspark.core.monitor.ErrorMonitor +import org.jetbrains.research.testspark.core.progress.CustomProgressIndicator +import org.jetbrains.research.testspark.core.test.TestsAssembler +import org.jetbrains.research.testspark.tools.ToolUtils +import org.jetbrains.research.testspark.tools.llm.LlmSettingsArguments +import org.jetbrains.research.testspark.tools.llm.error.LLMErrorManager +import org.jetbrains.research.testspark.tools.llm.generation.IJRequestManager +import java.net.HttpURLConnection + +class GeminiRequestManager(project: Project) : IJRequestManager(project) { + private val url = "https://generativelanguage.googleapis.com/v1beta/models/" + private val gson = GsonBuilder().create() + + private val llmErrorManager = LLMErrorManager() + + override fun send( + prompt: String, + indicator: CustomProgressIndicator, + testsAssembler: TestsAssembler, + errorMonitor: ErrorMonitor + ): SendResult { + val model = LlmSettingsArguments(project).getModel() + val apiURL = "$url$model:generateContent?key=$token" + val httpRequest = HttpRequests.post(apiURL, "application/json") + + val messages = chatHistory.map { + GeminiChatMessage(it.content) + } + + val 
geminiRequest = GeminiRequest(listOf(GeminiRequestBody(messages))) + + var sendResult = SendResult.OK + + try { + httpRequest.connect { request -> + request.write(gson.toJson(geminiRequest)) + + val connection = request.connection as HttpURLConnection + + when (val responseCode = connection.responseCode) { + HttpURLConnection.HTTP_OK -> { + assembleGeminiResponse(request, testsAssembler, indicator, errorMonitor) + } + + HttpURLConnection.HTTP_INTERNAL_ERROR -> { + llmErrorManager.errorProcess( + LLMMessagesBundle.get("serverProblems"), + project, + errorMonitor, + ) + sendResult = SendResult.OTHER + } + + HttpURLConnection.HTTP_BAD_REQUEST -> { + llmErrorManager.warningProcess( + LLMMessagesBundle.get("tooLongPrompt"), + project, + ) + sendResult = SendResult.PROMPT_TOO_LONG + } + + HttpURLConnection.HTTP_UNAUTHORIZED -> { + llmErrorManager.errorProcess( + LLMMessagesBundle.get("wrongToken"), + project, + errorMonitor, + ) + sendResult = SendResult.OTHER + } + + else -> { + llmErrorManager.errorProcess( + llmErrorManager.createRequestErrorMessage(responseCode), + project, + errorMonitor, + ) + sendResult = SendResult.OTHER + } + } + } + + } catch (e: HttpStatusException) { + log.error { "Error in sending request: ${e.message}" } + } + + return sendResult + } + + private fun assembleGeminiResponse( + httpRequest: HttpRequests.Request, + testsAssembler: TestsAssembler, + indicator: CustomProgressIndicator, + errorMonitor: ErrorMonitor, + ) { + while (true) { + if (ToolUtils.isProcessCanceled(errorMonitor, indicator)) return + + val text = httpRequest.reader.readText() + val result = + gson.fromJson( + JsonParser.parseString(text) + .asJsonObject["candidates"] + .asJsonArray[0].asJsonObject, + GeminiReply::class.java, + ) + + testsAssembler.consume(result.content.parts[0].text) + + if (result.finishReason == "STOP") break + } + + log.debug { testsAssembler.getContent() } + } +} \ No newline at end of file diff --git 
a/src/main/resources/properties/llm/LLMDefaults.properties b/src/main/resources/properties/llm/LLMDefaults.properties index f95c62ef1..757d37fc6 100644 --- a/src/main/resources/properties/llm/LLMDefaults.properties +++ b/src/main/resources/properties/llm/LLMDefaults.properties @@ -7,6 +7,9 @@ grazieModel= huggingFaceName=HuggingFace huggingFaceToken= huggingFaceModel= +geminiName=Google Gemini +geminiToken= +geminiModel=gemini-1.5-flash-latest huggingFaceInitialSystemPrompt=You are a helpful and honest code and programming assistant. Please, respond concisely and truthfully. maxLLMRequest=3 maxInputParamsDepth=2