Skip to content

Commit

Permalink
Add support for Google AI as an LLM
Browse files Browse the repository at this point in the history
  • Loading branch information
stephanlukasczyk committed Nov 24, 2024
1 parent 3224f28 commit 6817a86
Show file tree
Hide file tree
Showing 9 changed files with 221 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class LLMSetupPanelBuilder(e: AnActionEvent, private val project: Project) : Pan
private val defaultModulesArray = arrayOf("")
private var modelSelector = ComboBox(defaultModulesArray)
private var llmUserTokenField = JTextField(30)
private var platformSelector = ComboBox(arrayOf(llmSettingsState.openAIName, llmSettingsState.huggingFaceName))
private var platformSelector = ComboBox(arrayOf(llmSettingsState.openAIName, llmSettingsState.huggingFaceName, llmSettingsState.geminiName))
private val backLlmButton = JButton(PluginLabelsBundle.get("back"))
private val okLlmButton = JButton(PluginLabelsBundle.get("next"))
private val junitSelector = JUnitCombobox(e)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import org.jetbrains.research.testspark.tools.llm.error.LLMErrorManager
import org.jetbrains.research.testspark.tools.llm.generation.LLMPlatform
import org.jetbrains.research.testspark.tools.llm.generation.TestBodyPrinterFactory
import org.jetbrains.research.testspark.tools.llm.generation.TestSuiteParserFactory
import org.jetbrains.research.testspark.tools.llm.generation.gemini.GeminiPlatform
import org.jetbrains.research.testspark.tools.llm.generation.grazie.GrazieInfo
import org.jetbrains.research.testspark.tools.llm.generation.grazie.GraziePlatform
import org.jetbrains.research.testspark.tools.llm.generation.hf.HuggingFacePlatform
Expand Down Expand Up @@ -74,6 +75,9 @@ object LLMHelper {
if (platformSelector.selectedItem!!.toString() == settingsState.huggingFaceName) {
models = getHuggingFaceModels()
}
if (platformSelector.selectedItem!!.toString() == settingsState.geminiName) {
models = getGeminiModels(llmUserTokenField.text)
}
modelSelector.model = DefaultComboBoxModel(models)
for (index in llmPlatforms.indices) {
if (llmPlatforms[index].name == settingsState.openAIName &&
Expand Down Expand Up @@ -219,7 +223,7 @@ object LLMHelper {
* @return The list of LLMPlatforms.
*/
fun getLLLMPlatforms(): List<LLMPlatform> {
return listOf(OpenAIPlatform(), GraziePlatform(), HuggingFacePlatform())
return listOf(OpenAIPlatform(), GraziePlatform(), HuggingFacePlatform(), GeminiPlatform())
}

/**
Expand Down Expand Up @@ -346,6 +350,48 @@ object LLMHelper {
return arrayOf("")
}

/**
* Retrieves a list of available models from the Google AI API.
*
* Note that this will only return models that support content generation because we need this for the
* test-generation queries.
*
* @param providedToken The authentication token provided by Google AI.
* @return An array of model IDs. If an error occurs during the request, an array with an empty string is returned.
*/
fun getGeminiModels(providedToken: String): Array<String> {
val url = "https://generativelanguage.googleapis.com/v1beta/models?key=$providedToken"

val httpRequest = HttpRequests.request(url)
val models = mutableListOf<String>()

try {
httpRequest.connect { it ->
if ((it.connection as HttpURLConnection).responseCode == HttpURLConnection.HTTP_OK) {
val jsonObject = JsonParser.parseString(it.readString()).asJsonObject
val dataArray = jsonObject.getAsJsonArray("models")
for (dataObject in dataArray) {
val id = dataObject.asJsonObject.getAsJsonPrimitive("name").asString
val methods = dataObject.asJsonObject
.getAsJsonArray("supportedGenerationMethods")
.map { method -> method.asString }
if (methods.contains("generateContent")) {
models.add(id)
}
}
}
}
} catch (e: HttpRequests.HttpStatusException) {
return arrayOf("")
}

if (models.isNotEmpty()) {
return models.toTypedArray()
}

return arrayOf("")
}

/**
* Retrieves the available Grazie models.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ data class LLMSettingsState(
var huggingFaceName: String = DefaultLLMSettingsState.huggingFaceName,
var huggingFaceToken: String = DefaultLLMSettingsState.huggingFaceToken,
var huggingFaceModel: String = DefaultLLMSettingsState.huggingFaceModel,
var geminiName: String = DefaultLLMSettingsState.geminiName,
var geminiToken: String = DefaultLLMSettingsState.geminiToken,
var geminiModel: String = DefaultLLMSettingsState.geminiModel,
var currentLLMPlatformName: String = DefaultLLMSettingsState.currentLLMPlatformName,
var maxLLMRequest: Int = DefaultLLMSettingsState.maxLLMRequest,
var maxInputParamsDepth: Int = DefaultLLMSettingsState.maxInputParamsDepth,
Expand Down Expand Up @@ -51,6 +54,9 @@ data class LLMSettingsState(
val huggingFaceName: String = LLMDefaultsBundle.get("huggingFaceName")
val huggingFaceToken: String = LLMDefaultsBundle.get("huggingFaceToken")
val huggingFaceModel: String = LLMDefaultsBundle.get("huggingFaceModel")
val geminiName: String = LLMDefaultsBundle.get("geminiName")
val geminiToken: String = LLMDefaultsBundle.get("geminiToken")
val geminiModel: String = LLMDefaultsBundle.get("geminiModel")
var currentLLMPlatformName: String = LLMDefaultsBundle.get("openAIName")
val maxLLMRequest: Int = LLMDefaultsBundle.get("maxLLMRequest").toInt()
val maxInputParamsDepth: Int = LLMDefaultsBundle.get("maxInputParamsDepth").toInt()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ class LlmSettingsArguments(private val project: Project) {
llmSettingsState.openAIName -> llmSettingsState.openAIToken
llmSettingsState.grazieName -> llmSettingsState.grazieToken
llmSettingsState.huggingFaceName -> llmSettingsState.huggingFaceToken
llmSettingsState.geminiName -> llmSettingsState.geminiToken
else -> ""
}

Expand All @@ -70,6 +71,7 @@ class LlmSettingsArguments(private val project: Project) {
llmSettingsState.openAIName -> llmSettingsState.openAIModel
llmSettingsState.grazieName -> llmSettingsState.grazieModel
llmSettingsState.huggingFaceName -> llmSettingsState.huggingFaceModel
llmSettingsState.geminiName -> llmSettingsState.geminiModel
else -> ""
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import org.jetbrains.research.testspark.core.generation.llm.network.RequestManag
import org.jetbrains.research.testspark.services.LLMSettingsService
import org.jetbrains.research.testspark.settings.llm.LLMSettingsState
import org.jetbrains.research.testspark.tools.llm.LlmSettingsArguments
import org.jetbrains.research.testspark.tools.llm.generation.gemini.GeminiRequestManager
import org.jetbrains.research.testspark.tools.llm.generation.grazie.GrazieRequestManager
import org.jetbrains.research.testspark.tools.llm.generation.hf.HuggingFaceRequestManager
import org.jetbrains.research.testspark.tools.llm.generation.openai.OpenAIRequestManager
Expand All @@ -22,6 +23,7 @@ class StandardRequestManagerFactory(private val project: Project) : RequestManag
llmSettingsState.openAIName -> OpenAIRequestManager(project)
llmSettingsState.grazieName -> GrazieRequestManager(project)
llmSettingsState.huggingFaceName -> HuggingFaceRequestManager(project)
llmSettingsState.geminiName -> GeminiRequestManager(project)
else -> throw IllegalStateException("Unknown selected platform: $platform")
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package org.jetbrains.research.testspark.tools.llm.generation.gemini

import org.jetbrains.research.testspark.tools.llm.generation.LLMPlatform

class GeminiPlatform(
override val name: String = "Gemini",
override var token: String = "",
override var model: String = "",
) : LLMPlatform
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package org.jetbrains.research.testspark.tools.llm.generation.gemini

data class GeminiRequest(
val contents: List<GeminiRequestBody>,
)

data class GeminiRequestBody(
val parts: List<GeminiChatMessage>,
)

data class GeminiChatMessage(
val text: String,
)

data class GeminiReply(
val content: GeminiReplyContent,
val finishReason: String,
val avgLogprobs: Double,
)

data class GeminiReplyContent(
val parts: List<GeminiReplyPart>,
val role: String?,
)

data class GeminiReplyPart(
val text: String,
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
package org.jetbrains.research.testspark.tools.llm.generation.gemini

import com.google.gson.GsonBuilder
import com.google.gson.JsonParser
import com.intellij.openapi.project.Project
import com.intellij.util.io.HttpRequests
import com.intellij.util.io.HttpRequests.HttpStatusException
import org.jetbrains.research.testspark.bundles.llm.LLMMessagesBundle
import org.jetbrains.research.testspark.core.monitor.ErrorMonitor
import org.jetbrains.research.testspark.core.progress.CustomProgressIndicator
import org.jetbrains.research.testspark.core.test.TestsAssembler
import org.jetbrains.research.testspark.tools.ToolUtils
import org.jetbrains.research.testspark.tools.llm.LlmSettingsArguments
import org.jetbrains.research.testspark.tools.llm.error.LLMErrorManager
import org.jetbrains.research.testspark.tools.llm.generation.IJRequestManager
import java.net.HttpURLConnection

class GeminiRequestManager(project: Project) : IJRequestManager(project) {
private val url = "https://generativelanguage.googleapis.com/v1beta/models/"
private val gson = GsonBuilder().create()

private val llmErrorManager = LLMErrorManager()

override fun send(
prompt: String,
indicator: CustomProgressIndicator,
testsAssembler: TestsAssembler,
errorMonitor: ErrorMonitor
): SendResult {
val model = LlmSettingsArguments(project).getModel()
val apiURL = "$url$model:generateContent?key=$token"
val httpRequest = HttpRequests.post(apiURL, "application/json")

val messages = chatHistory.map {
GeminiChatMessage(it.content)
}

val geminiRequest = GeminiRequest(listOf(GeminiRequestBody(messages)))

var sendResult = SendResult.OK

try {
httpRequest.connect { request ->
request.write(gson.toJson(geminiRequest))

val connection = request.connection as HttpURLConnection

when (val responseCode = connection.responseCode) {
HttpURLConnection.HTTP_OK -> {
assembleGeminiResponse(request, testsAssembler, indicator, errorMonitor)
}

HttpURLConnection.HTTP_INTERNAL_ERROR -> {
llmErrorManager.errorProcess(
LLMMessagesBundle.get("serverProblems"),
project,
errorMonitor,
)
sendResult = SendResult.OTHER
}

HttpURLConnection.HTTP_BAD_REQUEST -> {
llmErrorManager.warningProcess(
LLMMessagesBundle.get("tooLongPrompt"),
project,
)
sendResult = SendResult.PROMPT_TOO_LONG
}

HttpURLConnection.HTTP_UNAUTHORIZED -> {
llmErrorManager.errorProcess(
LLMMessagesBundle.get("wrongToken"),
project,
errorMonitor,
)
sendResult = SendResult.OTHER
}

else -> {
llmErrorManager.errorProcess(
llmErrorManager.createRequestErrorMessage(responseCode),
project,
errorMonitor,
)
sendResult = SendResult.OTHER
}
}
}

} catch (e: HttpStatusException) {
log.error { "Error in sending request: ${e.message}" }
}

return sendResult
}

private fun assembleGeminiResponse(
httpRequest: HttpRequests.Request,
testsAssembler: TestsAssembler,
indicator: CustomProgressIndicator,
errorMonitor: ErrorMonitor,
) {
while (true) {
if (ToolUtils.isProcessCanceled(errorMonitor, indicator)) return

val text = httpRequest.reader.readText()
val result =
gson.fromJson(
JsonParser.parseString(text)
.asJsonObject["candidates"]
.asJsonArray[0].asJsonObject,
GeminiReply::class.java,
)

testsAssembler.consume(result.content.parts[0].text)

if (result.finishReason == "STOP") break
}

log.debug { testsAssembler.getContent() }
}
}
5 changes: 4 additions & 1 deletion src/main/resources/properties/llm/LLMDefaults.properties
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ grazieModel=
huggingFaceName=HuggingFace
huggingFaceToken=
huggingFaceModel=
geminiName=Google AI
geminiToken=
geminiModel=
huggingFaceInitialSystemPrompt=You are a helpful and honest code and programming assistant. Please, respond concisely and truthfully.
maxLLMRequest=3
maxInputParamsDepth=2
Expand All @@ -23,4 +26,4 @@ lineCurrentDefaultPromptIndex=0
defaultLLMRequests=["Add more comments to the test","Reformat the test","Improve variable names","Improve assertions","Increase the call sequences for a more complex scenario"]
provideTestSamples=true
junitVersionPriority=true
llmSetup=true
llmSetup=true

0 comments on commit 6817a86

Please sign in to comment.